# Example of orbkit

## TREX-IO documentation: https://trex-coe.github.io/trexio/trex.html

## PySCF-forge: https://github.com/pyscf/pyscf-forge
TREX-IO support is implemented in PySCF-forge. Please install it **from the GitHub repo**: `pip install git+https://github.com/pyscf/pyscf-forge`

In [44]:
import os, sys
import numpy as np

## pyscf-forge

In [None]:
from pyscf import gto, scf
from pyscf.tools import trexio

# Name of the TREX-IO (HDF5) file written at the end of this cell; the orbkit
# cells read a file with the same name.
filename = 'water_ccecp_ccpvqz.h5'

# Describe the molecule: one O and two H atoms with the 'ccecp' pseudopotential
# and the 'ccecp-ccpvqz' basis set, charge 0 and spin 0.
mol = gto.Mole()
mol.verbose  = 5
mol.atom     = '''
               O    5.00000000   7.14707700   7.65097100
               H    4.06806600   6.94297500   7.56376100
               H    5.38023700   6.89696300   6.80798400
               '''
mol.basis    = 'ccecp-ccpvqz'
mol.unit     = 'A'  # presumably Angstrom (PySCF unit flag) -- confirm
mol.ecp      = 'ccecp'
mol.charge   = 0
mol.spin     = 0
mol.symmetry = False
# NOTE(review): cart = True presumably requests Cartesian GTOs, whereas the
# recorded outputs of the orbkit cells report 'AOs_sphe_data' with real
# spherical harmonics. Confirm which setting those outputs were generated with.
mol.cart = True
mol.output   = 'water.out'
mol.build()

# Hartree-Fock calculation, allowing up to 200 SCF cycles.
mf = scf.HF(mol)
mf.max_cycle=200
# mf_scf holds whatever mf.kernel() returns (presumably the SCF total energy -- confirm).
mf_scf = mf.kernel()

# Export the mean-field object to the TREX-IO file.
trexio.to_trexio(mf, filename)

## orbkit part

In [27]:
from orbkit.trexio_wrapper import read_trexio_file
from orbkit.atomic_orbital import compute_AOs_jax
from orbkit.molecular_orbital import compute_MOs_jax

In [28]:
# Load the five data objects returned by read_trexio_file from the TREX-IO file:
# structure, AOs, up-/down-spin MOs, and Coulomb potential data.
(
    structure_data,
    aos_data,
    mos_data_up,
    mos_data_dn,
    coulomb_potential_data,
) = read_trexio_file('water_ccecp_ccpvqz.h5')

In [29]:
structure_data.get_info()

['**Structure_data',
 '  PBC flag = False',
 '  --------------------------------------------------',
 '  element, label, Z, x, y, z in cartesian (Bohr)',
 '  --------------------------------------------------',
 '  O, O, 8.0, -1.32695823, -0.10593853, 0.01878815',
 '  H, H, 1.0, -1.93166524, 1.60017432, -0.02171052',
 '  H, H, 1.0, 0.48664428, 0.07959809, 0.00986248',
 '  --------------------------------------------------']

In [30]:
aos_data.get_info()

['**AOs_sphe_data',
 '  Number of AOs = 114',
 '  Number of primitive AOs = 160',
 '  Angular part is the real spherical (solid) Harmonics.',
 '  ------------------------------------',
 '  **basis set for atom index 1: O**',
 '  ------------------------------------',
 '  O s',
 '    54.775216 -0.0012444',
 '    25.616801 0.0107330',
 '    11.980245 0.0018889',
 '    6.992317 -0.1742537',
 '    2.620277 0.0017622',
 '    1.225429 0.3161846',
 '    0.577797 0.4512023',
 '    0.268022 0.3121534',
 '    0.125346 0.0511167',
 '  O s',
 '    1.351771 1.0000000',
 '  O s',
 '    0.843157 1.0000000',
 '  O s',
 '    0.224380 1.0000000',
 '  O p',
 '    22.217266 0.0104866',
 '    10.747550 0.0366435',
 '    5.315785 0.0803674',
 '    2.660761 0.1627010',
 '    1.331816 0.2377791',
 '    0.678626 0.2811422',
 '    0.333673 0.2643189',
 '    0.167017 0.1466014',
 '    0.083598 0.0458145',
 '  O p',
 '    1.106737 1.0000000',
 '  O p',
 '    0.452364 1.0000000',
 '  O p',
 '    0.148562 1.0000000

In [31]:
mos_data_up.get_info()

['**MOs_data',
 '  Number of MOs = 4',
 '  dim. of MOs coeff = (4, 114)',
 '**AOs_sphe_data',
 '  Number of AOs = 114',
 '  Number of primitive AOs = 160',
 '  Angular part is the real spherical (solid) Harmonics.',
 '  ------------------------------------',
 '  **basis set for atom index 1: O**',
 '  ------------------------------------',
 '  O s',
 '    54.775216 -0.0012444',
 '    25.616801 0.0107330',
 '    11.980245 0.0018889',
 '    6.992317 -0.1742537',
 '    2.620277 0.0017622',
 '    1.225429 0.3161846',
 '    0.577797 0.4512023',
 '    0.268022 0.3121534',
 '    0.125346 0.0511167',
 '  O s',
 '    1.351771 1.0000000',
 '  O s',
 '    0.843157 1.0000000',
 '  O s',
 '    0.224380 1.0000000',
 '  O p',
 '    22.217266 0.0104866',
 '    10.747550 0.0366435',
 '    5.315785 0.0803674',
 '    2.660761 0.1627010',
 '    1.331816 0.2377791',
 '    0.678626 0.2811422',
 '    0.333673 0.2643189',
 '    0.167017 0.1466014',
 '    0.083598 0.0458145',
 '  O p',
 '    1.106737 1.0000000

In [32]:
mos_data_dn.get_info()

['**MOs_data',
 '  Number of MOs = 4',
 '  dim. of MOs coeff = (4, 114)',
 '**AOs_sphe_data',
 '  Number of AOs = 114',
 '  Number of primitive AOs = 160',
 '  Angular part is the real spherical (solid) Harmonics.',
 '  ------------------------------------',
 '  **basis set for atom index 1: O**',
 '  ------------------------------------',
 '  O s',
 '    54.775216 -0.0012444',
 '    25.616801 0.0107330',
 '    11.980245 0.0018889',
 '    6.992317 -0.1742537',
 '    2.620277 0.0017622',
 '    1.225429 0.3161846',
 '    0.577797 0.4512023',
 '    0.268022 0.3121534',
 '    0.125346 0.0511167',
 '  O s',
 '    1.351771 1.0000000',
 '  O s',
 '    0.843157 1.0000000',
 '  O s',
 '    0.224380 1.0000000',
 '  O p',
 '    22.217266 0.0104866',
 '    10.747550 0.0366435',
 '    5.315785 0.0803674',
 '    2.660761 0.1627010',
 '    1.331816 0.2377791',
 '    0.678626 0.2811422',
 '    0.333673 0.2643189',
 '    0.167017 0.1466014',
 '    0.083598 0.0458145',
 '  O p',
 '    1.106737 1.0000000

In [33]:
coulomb_potential_data.get_info()

['**Coulomb_potential_data', '  ecp_flag = True']

In [34]:
# Number of electrons per spin channel used for the demonstration below.
num_ele_up, num_ele_dn = 5, 3

# Random Cartesian positions, one (x, y, z) row per electron. The up-spin
# array is drawn first, so the random stream is consumed in the same order.
r_up_carts, r_dn_carts = (
    np.random.rand(count, 3) for count in (num_ele_up, num_ele_dn)
)

# compute_AOs_jax
def compute_AOs_jax(aos_data: AOs_sphe_data | AOs_cart_data, r_carts: jnpt.ArrayLike) -> jax.Array:
    """Compute AO values at the given r_carts.

    The method is for computing the values of the given atomic orbitals at r_carts.

    Args:
        aos_data (AOs_sphe_data | AOs_cart_data): an instance of AOs_sphe_data or AOs_cart_data
        r_carts (jnpt.ArrayLike): Cartesian coordinates of electrons (dim: N_e, 3)

    Returns:
        jax.Array: Arrays containing values of the AOs at r_carts. (dim: num_ao, N_e)
    """

In [35]:
# AO values at the spin-up electron positions; result has dim (num_ao, N_e).
ao_up_values = compute_AOs_jax(aos_data=aos_data, r_carts=r_up_carts)
ao_up_values.shape

(114, 5)

In [36]:
# AO values at the spin-down electron positions; result has dim (num_ao, N_e).
ao_dn_values = compute_AOs_jax(aos_data=aos_data, r_carts=r_dn_carts)
ao_dn_values.shape

(114, 3)

# compute_MOs_jax

def compute_MOs_jax(mos_data: MOs_data, r_carts: jnpt.ArrayLike) -> jax.Array:
    """Compute MO values at the given r_carts.

    The function computes the values of the molecular orbitals at all r_carts simultaneously.

    Args:
        mos_data (MOs_data): an instance of MOs_data
        r_carts (jnpt.ArrayLike): Cartesian coordinates of electrons (dim: N_e, 3)

    Returns:
        jax.Array: Arrays containing values of the MOs at r_carts. (dim: num_mo, N_e)
    """

In [42]:
# MO values at the spin-up electron positions; result has dim (num_mo, N_e).
mo_up_values = compute_MOs_jax(mos_data=mos_data_up, r_carts=r_up_carts)
mo_up_values.shape

(4, 5)

In [43]:
# MO values at the spin-down electron positions; result has dim (num_mo, N_e).
mo_dn_values = compute_MOs_jax(mos_data=mos_data_dn, r_carts=r_dn_carts)
mo_dn_values.shape

(4, 3)