In [3]:
from pyscf import gto, scf, md
import re
import pickle
import numpy as np
import os 



In [4]:
def run_md_simulation(dt=1, steps=1200,
                      energy_output="BOMD.md.energies",
                      trajectory_output="BOMD.md.xyz"):
    """Run an NVE Born-Oppenheimer MD trajectory for H2 with CASSCF(2, 2) gradients.

    The molecule starts with the two hydrogens at x = +0.7 and x = -0.7
    (``unit = 'B'``), uses the cc-pVDZ basis, and is propagated by
    ``md.NVE`` driven by a CASSCF nuclear-gradient scanner.

    Args:
        dt: Integrator time step (atomic units).
        steps: Number of integration steps.
        energy_output: Path of the file receiving the energies.
        trajectory_output: Path of the file receiving the xyz trajectory.

    Returns:
        None. The results are the two files written to disk.

    NOTE(review): ``create_grid_coordinates`` reads ``../../data/BOMD.md.xyz``
    by default, whereas the default here writes ``BOMD.md.xyz`` into the
    current working directory. Move the file or pass ``trajectory_output``
    so that the two agree.
    """
    h2 = gto.Mole()
    h2.atom = [['H', (0.7, 0, 0)], ['H', (-0.7, 0, 0)]]
    h2.basis = 'ccpvdz'
    h2.unit = 'B'
    h2.build()

    mf = scf.RHF(h2)
    mf.kernel()
    mycas = mf.CASSCF(2, 2)
    myscanner = mycas.nuc_grad_method().as_scanner()

    # Generate the integrator. The CASSCF object could also be passed directly;
    # it would then be converted to a scanner automatically.
    myintegrator = md.NVE(myscanner,
                          dt=dt,
                          steps=steps,
                          energy_output=energy_output,
                          trajectory_output=trajectory_output,
                          verbose=0)
    try:
        myintegrator.run()
    finally:
        # The integrator swaps the path strings for open file streams while it
        # runs; close them even if the run fails part-way. If the failure
        # happened before the files were opened the attributes are still plain
        # strings, hence the hasattr guard.
        for stream in (myintegrator.energy_output, myintegrator.trajectory_output):
            if hasattr(stream, "close"):
                stream.close()


In [5]:
def create_grid_coordinates(ticks=30,
                            trajectory_path="../../data/BOMD.md.xyz",
                            output_path="../../data/md_h2_R_coords.npz"):
    """Parse an MD xyz trajectory and build a square z = 0 grid that encloses it.

    Every frame in the trajectory file is introduced by a ``MD Time <t>`` line
    followed by one ``<symbol> <x> <y> <z>`` record per atom. The parsed
    positions and the grid are pickled together into ``output_path``.

    Args:
        ticks: Number of grid points along each of the x and y axes.
        trajectory_path: xyz trajectory file to read.
        output_path: Destination of the pickled ``{"R": ..., "coords": ...}``
            dict. Despite the ``.npz`` suffix of the default name this is a
            plain pickle, not a NumPy archive.

    Raises:
        ValueError: If the trajectory file contains no ``MD Time`` frames.

    Stored data:
        ``R``: float32 array of shape (n_frames, n_atoms, 3).
        ``coords``: float64 array of shape (ticks * ticks, 3) with x varying
            slowest, y fastest and z fixed at 0.
    """
    with open(trajectory_path, "r") as f:
        raw_frames = f.read().split("MD Time")[1:]

    frames = []
    for raw_frame in raw_frames:
        atoms = []
        # The first line is the rest of the "MD Time <t>" header; atom records follow.
        for line in raw_frame.split("\n")[1:]:
            fields = line.split()
            if len(fields) != 4:
                # Either the next frame's atom-count line or the blank line at EOF.
                break
            atoms.append([float(value) for value in fields[1:]])
        frames.append(atoms)
    if not frames:
        raise ValueError(f"no 'MD Time' frames found in {trajectory_path!r}")

    R = np.array(frames, dtype=np.float32)
    flat = R.reshape(-1, 3)
    max_abs_value = np.sqrt((flat * flat).sum(axis=-1)).max()

    # The molecule stays centred on the origin, so the grid axis is chosen
    # symmetric about it: [-max, +max] covers every atom position seen.
    axis = np.linspace(-max_abs_value, max_abs_value, ticks)
    coords = np.array([[x, y, 0.0] for x in axis for y in axis], dtype=np.float64)

    data_dict = {"R": R, "coords": coords}
    with open(output_path, "wb") as f:
        pickle.dump(data_dict, f)

In [52]:
def create_densities_and_corrs(length=None, overwrite=False, with_corrs=False):
    """Compute RHF densities on the saved grid for trajectory frames and store them.

    Reads the pickled ``{"R", "coords"}`` dict from
    ``../../data/md_h2_R_coords.npz``, runs RHF/cc-pVDZ for each frame, evaluates
    the density at every grid point, and appends the results to the pickled dict
    in ``../../data/md_h2.npz``. The output file is rewritten after every frame,
    so partial results survive an interrupted run.

    Args:
        length: Number of leading frames to process; ``None`` processes all.
        overwrite: If true, start from an empty result file. Otherwise the
            existing file is loaded and extended (it must already exist).
        with_corrs: If true, also compute and store the ``"corrs"`` value built
            from the two-particle density matrix.

    Stored data (one leading entry per processed frame):
        ``R``: float32, shape (n, 2, 3).
        ``densities``: shape (n, n_grid).
        ``corrs``: shape (n,), only updated when ``with_corrs`` is true.
        ``coords``: the grid, stored once.

    Both files are pickles (``np.load`` only unpickles them because
    ``allow_pickle=True``); load only files produced by this notebook.
    """
    file_name = "../../data/md_h2.npz"

    if overwrite:
        current_data = {key: None for key in ["R", "densities", "corrs", "coords"]}
        with open(file_name, "wb") as f:
            pickle.dump(current_data, f)
    else:
        current_data = np.load(file_name, allow_pickle=True)

    R_and_coords = np.load("../../data/md_h2_R_coords.npz", allow_pickle=True)
    R = R_and_coords["R"]
    coords = R_and_coords["coords"]

    for R1, R2 in R[:length]:
        atom_data = [[1, R1.tolist()], [1, R2.tolist()]]
        # NOTE(review): the molecule is built with PySCF's default length unit,
        # while `coords` holds the same raw trajectory numbers and is handed to
        # eval_ao unconverted. Confirm both are in the unit each API expects.
        mol = gto.M(atom=atom_data, basis="ccpvdz")
        hf = scf.RHF(mol).run(verbose=0)

        ao_vals = np.array(mol.eval_ao("GTOval_sph", coords))
        rdm1 = hf.make_rdm1()
        densities_for_molecule = np.einsum("ij,ni,nj->n", rdm1, ao_vals, ao_vals)

        new_entries = {
            "R": np.array([R1, R2], dtype=np.float32)[np.newaxis, :, :],
            "densities": densities_for_molecule[np.newaxis, :],
        }

        if with_corrs:
            # The 2-RDM is only needed here, so it is not built otherwise.
            rdm2 = hf.make_rdm2()
            # NOTE(review): the "->" output sums over both grid indices and
            # yields one scalar per molecule; if a per-point pair quantity was
            # intended the output subscripts should be "->nm". Confirm.
            corrs_for_molecule = 0.5 * np.einsum(
                "ijkl,ni,nj,mk,ml->", rdm2, ao_vals, ao_vals, ao_vals, ao_vals,
                optimize=True,
            )
            # Wrap the scalar in a length-1 array: 0-d arrays cannot be
            # concatenated, which made every frame after the first fail.
            new_entries["corrs"] = np.array([corrs_for_molecule])

        for key, value in new_entries.items():
            if current_data.get(key) is None:
                current_data[key] = value
            else:
                current_data[key] = np.concatenate((current_data[key], value))
        if current_data.get("coords") is None:
            current_data["coords"] = coords

        # Checkpoint after every frame.
        with open(file_name, "wb") as f:
            pickle.dump(current_data, f)
        

In [53]:
# Process every stored frame: `length=None` makes the `R[:length]` slice span
# the whole trajectory. Pass an integer to limit the number of frames.
create_densities_and_corrs(length=None)