In [70]:
import numpy as np
from functools import partial
import matplotlib.pyplot as plt

In [177]:
# Dimensions of the orbital index grid:
# (m_max=3, max_no_orbitals_per_m=4, max_split_per_m=5)
ORBITAL_PARAMS = (3, 4, 5)

# One index triple per orbital, grouped under the two table labels.
PYSCF_ORBITAL_INDICES_STRDICT = {
    "H-Ar": [
        [0, 0, 0],
        [1, 0, 0],
        [1, 1, 0], [1, 1, 1], [1, 1, 2],
    ],
    "Li-Ne": [
        [0, 0, 0],
        [1, 0, 0],
        [2, 0, 0],
        [1, 1, 0], [1, 1, 1], [1, 1, 2],
        [2, 1, 0], [2, 1, 1], [2, 1, 2],
        [2, 2, 0], [2, 2, 1], [2, 2, 2], [2, 2, 3], [2, 2, 4],
    ],
}

# Integer keys 1..14: keys below 3 share the "H-Ar" list, all others share the
# "Li-Ne" list (the very same list objects, not copies).
PYSCF_ORBITAL_INDICES = {
    key: PYSCF_ORBITAL_INDICES_STRDICT["H-Ar" if key < 3 else "Li-Ne"]
    for key in range(1, 15)
}

In [178]:
def rdm_transform(rdm, Z, orbital_parameters, orbital_indices=None):
    """Map every (row, column) position of ``rdm`` to a pair of global orbital ids.

    Orbitals are assumed to be laid out atom after atom, in the order of ``Z``.
    For each orbital, its index triple is looked up in ``orbital_indices``
    under the owning atom's ``Z`` entry, flattened with
    ``np.ravel_multi_index`` over ``orbital_parameters`` and shifted by a
    per-atom stride.

    Parameters
    ----------
    rdm : array-like with a 2-D ``shape``
        Only the shape is used; the values are never read.
    Z : sequence of int
        One entry per atom (presumably the atomic number -- TODO confirm).
        Each entry is used as a key into ``orbital_indices``.
    orbital_parameters : tuple of int
        Dimensions of the orbital index grid, e.g. ``ORBITAL_PARAMS``.
    orbital_indices : mapping, optional
        Maps a ``Z`` entry to its list of index triples. Defaults to the
        module-level ``PYSCF_ORBITAL_INDICES``. The number of orbitals an atom
        contributes is the length of its list (5 and 14 for the default table).

    Returns
    -------
    numpy.ndarray of shape ``(rows * cols, 2)``
        Row ``i * cols + j`` holds ``[id of orbital i, id of orbital j]``.

    Raises
    ------
    KeyError
        If an entry of ``Z`` is not a key of ``orbital_indices``.
    IndexError
        If ``rdm`` has more rows or columns than the atoms provide orbitals.
    """
    if orbital_indices is None:
        orbital_indices = PYSCF_ORBITAL_INDICES

    n_rows, n_cols = rdm.shape
    dims = tuple(orbital_parameters)

    # Per-atom shift: the flattened *largest* grid index, i.e. prod(dims) - 1.
    # NOTE(review): this is one less than the grid size, so an orbital sitting
    # on the last grid cell would collide with the first cell of the next atom.
    # None of the tables in this notebook use that cell; the value is kept so
    # that ids stay identical to the ones this function produced before.
    atom_stride = np.ravel_multi_index(tuple(d - 1 for d in dims), dims=dims)

    # Build the global id of every orbital once, atom by atom.
    per_atom_ids = []
    for atom_pos, z in enumerate(Z):
        # The table is keyed by the atom's Z entry, not by its position in the
        # molecule; the position only determines the shift.
        triples = np.asarray(orbital_indices[int(z)], dtype=np.intp).reshape(-1, len(dims))
        offsets = np.ravel_multi_index(tuple(triples.T), dims=dims)
        per_atom_ids.append(offsets + atom_pos * atom_stride)

    if per_atom_ids:
        orbital_ids = np.concatenate(per_atom_ids)
    else:
        orbital_ids = np.empty(0, dtype=np.intp)

    # Row-major enumeration of all (i, j) positions of the matrix.
    rows, cols = np.indices((n_rows, n_cols))
    new_ids = np.stack((orbital_ids[rows.ravel()], orbital_ids[cols.ravel()]), axis=1)
    return new_ids

In [187]:
def add_Z_and_N(h2_dict):
    """Attach the "Z" and "N" arrays to ``h2_dict`` in place and return it.

    The entry count is ``len(h2_dict["R"])``. "Z" receives two ones per entry
    and "N" a single 2 per entry; both are int32 arrays.
    """
    entry_count = len(h2_dict["R"])
    h2_dict["Z"] = np.full((2 * entry_count,), 1, dtype=np.int32)
    h2_dict["N"] = np.full((entry_count,), 2, dtype=np.int32)
    return h2_dict

In [211]:
def add_rdm_indices(data_dict, orbital_parameters):
    """Compute orbital-id pairs for every stored RDM and save them as "rdm_id".

    ``data_dict["rdm_hf"]`` and ``data_dict["Z"]`` are treated as flat
    concatenations over samples. Sample ``k`` owns the next
    ``data_dict["N_rdm"][k]`` values of "rdm_hf" (reshaped to a square matrix)
    and the next ``data_dict["N"][k]`` entries of "Z" (so "N" is used here as
    the per-sample atom count -- TODO confirm that is its intended meaning).

    Parameters
    ----------
    data_dict : mutable mapping
        Must contain "N_rdm", "N", "rdm_hf" and "Z"; it is modified in place.
    orbital_parameters : tuple of int
        Forwarded unchanged to ``rdm_transform``.

    Returns
    -------
    The same ``data_dict`` with a new "rdm_id" entry. When every sample yields
    the same shape, "rdm_id" is a regular array of shape
    ``(n_samples, N_rdm, 2)``. Otherwise it is a 1-D object array holding one
    ``(N_rdm[k], 2)`` array per sample.

    Raises
    ------
    ValueError
        If an entry of "N_rdm" is not a perfect square.
    """
    rdm_counts = data_dict["N_rdm"]
    atom_counts = data_dict["N"]
    rdm_values = data_dict["rdm_hf"]
    Z = data_dict["Z"]

    rdm_ids = []
    rdm_stop = 0
    atom_stop = 0
    for n, n_rdm in zip(atom_counts, rdm_counts):
        # Advance both running windows to the current sample.
        rdm_start, rdm_stop = rdm_stop, rdm_stop + n_rdm
        atom_start, atom_stop = atom_stop, atom_stop + n

        side = int(round(float(np.sqrt(n_rdm))))
        if side * side != n_rdm:
            raise ValueError(
                "N_rdm entry %r is not a perfect square; cannot form a square RDM" % (n_rdm,)
            )

        rdm = np.reshape(rdm_values[rdm_start:rdm_stop], (side, side))
        rdm_ids.append(rdm_transform(rdm, Z[atom_start:atom_stop], orbital_parameters))

    if len({ids.shape for ids in rdm_ids}) <= 1:
        # Uniform (or empty) case: a regular stacked array, as before.
        stacked = np.array(rdm_ids)
    else:
        # Samples of different size cannot be stacked; np.array() on such a
        # ragged list raises, so fill an object array element by element.
        stacked = np.empty(len(rdm_ids), dtype=object)
        for k, ids in enumerate(rdm_ids):
            stacked[k] = ids

    data_dict["rdm_id"] = stacked
    return data_dict

In [212]:
# Load the md_h2 archive.
# NOTE: allow_pickle=True lets np.load unpickle object arrays, which can run
# arbitrary code -- only use it on files from a trusted source.
data = np.load("../../data/md_h2.npz", allow_pickle=True)
# The two preprocessing calls below are disabled; the key listing printed by a
# later cell shows that "Z", "N" and "rdm_id" are already stored in the archive.
#data = add_Z_and_N(data)
#data = add_rdm_indices(data, ORBITAL_PARAMS)

In [214]:
# Names of the arrays available in the loaded archive.
list(data.keys())

['R', 'densities', 'corrs', 'coords', 'rdm_hf', 'N_rdm', 'Z', 'N', 'rdm_id']

In [215]:
# Per-sample count of flattened RDM values; add_rdm_indices uses these as
# slice lengths into "rdm_hf".
data["N_rdm"]

array([100, 100, 100, ..., 100, 100, 100])