In [None]:
import ast
import collections
import itertools
import json
import math
import os
import time
from glob import glob
from itertools import chain

import ase
import ase.io
import h5py
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from monty.serialization import loadfn
from pymatgen.analysis.graphs import MoleculeGraph
from pymatgen.analysis.local_env import OpenBabelNN, CovalentBondNN
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.util.graph_hashing import weisfeiler_lehman_graph_hash
from tqdm import tqdm

from radqm9_pipeline.elements import read_elements
from radqm9_pipeline.modules import merge_data

# Read Data

In [1]:
# Placeholder for loading the raw records.
# NOTE(review): every later cell reads `merged_data`, which is never assigned in this
# notebook — presumably it is produced from this input (e.g. via `merge_data`); confirm.
data = ...

Ellipsis

In [None]:
# Lookup table indexed by element symbol; its values are summed per molecule in
# get_molecule_weight / get_molecule_weight_ase (presumably atomic masses — confirm).
# NOTE(review): absolute, user-specific path; consider a configurable data directory.
elements_dict = read_elements('/pscratch/sd/m/mavaylon/sam_ldrd/radqm9_pipeline/src/radqm9_pipeline/modules/elements.pkl')

# Process

In [None]:
def charge_spin_tag(data: list):
    """Tag every record in place with a ``'charge_spin'`` label.

    The label is ``str(charge)`` and ``str(spin)`` joined by an underscore,
    e.g. ``'-1_2'``. Nothing is returned.
    """
    for record in tqdm(data):
        label_parts = (str(record['charge']), str(record['spin']))
        record['charge_spin'] = '_'.join(label_parts)

In [None]:
charge_spin_tag(merged_data)

In [None]:
def solvent_convert(data:list):
    """Rewrite the raw ``'solvent'`` labels in place and report the unknown ones.

    ``'NONE'`` becomes ``'vacuum'`` and ``'SOLVENT=WATER'`` becomes ``'SMD'``.
    Records carrying any other label are left untouched and returned in a list.
    Because already-converted labels are not recognised, running this twice on
    the same records returns all of them as unresolved.
    """
    renames = (('NONE', 'vacuum'), ('SOLVENT=WATER', 'SMD'))
    leftovers = []
    for record in tqdm(data):
        current = record['solvent']
        for raw_label, short_label in renames:
            if current == raw_label:
                record['solvent'] = short_label
                break
        else:
            # No rename matched this label.
            leftovers.append(record)
    return leftovers

In [None]:
solvent_convert(merged_data)

In [None]:
def type_tagger(data: list):
    """Set ``'sp_config_type'`` on every record, in place.

    A record is ``'optimized'`` when its ``'charge_spin'`` label equals its
    ``'optimized_parent_charge_spin'`` value and ``'vertical'`` otherwise.
    The comparison is against the parent label as stored, so this must run
    before resolve_parent_charge_spin turns that label into a list.
    """
    for record in tqdm(data):
        same_as_parent = record['charge_spin'] == record['optimized_parent_charge_spin']
        record['sp_config_type'] = 'optimized' if same_as_parent else 'vertical'

In [None]:
type_tagger(merged_data)


In [None]:
def resolve_parent_charge_spin(data: list):
    """Split each ``'optimized_parent_charge_spin'`` label on ``'_'``, in place.

    A label such as ``'0_1'`` becomes ``['0', '1']``. Values that are no longer
    strings (i.e. already split by an earlier run) are left alone, so re-running
    the cell is safe instead of failing on ``list.split``.
    """
    for item in tqdm(data):
        parent = item['optimized_parent_charge_spin']
        if isinstance(parent, str):
            item['optimized_parent_charge_spin'] = parent.split('_')

In [None]:
resolve_parent_charge_spin(merged_data)

In [None]:
# Group the records by molecule id: mol_id -> list of records, in input order.
# The loop variable is deliberately not called `data`, so the notebook-level
# `data` object is not overwritten by the last record.
bucket_mol_id2 = {}
for record in tqdm(merged_data):
    bucket_mol_id2.setdefault(record['mol_id'], []).append(record)

In [None]:
# For every molecule id, collect the configuration label of each of its records:
# "<charge_spin>_<sp_config_type>_<parent charge + parent spin>_<solvent>".
# Expects 'optimized_parent_charge_spin' to already be a two-element sequence of
# strings (the output of resolve_parent_charge_spin).
mol_id_present_config = {}
for record in tqdm(merged_data):
    parent = record['optimized_parent_charge_spin']
    opt_parent = parent[0] + parent[1]
    config_label = '_'.join(
        (record['charge_spin'], record['sp_config_type'], opt_parent, record['solvent'])
    )
    mol_id_present_config.setdefault(record['mol_id'], []).append(config_label)

In [None]:
def resolve_duplicate_data(data: list):
    """Keep the first record of every (mol_id, configuration) pair.

    Each record is tagged in place with ``'dup_identifier'``, built as
    ``"<charge_spin>_<sp_config_type>_<parent charge + parent spin>_<solvent>"``.
    Two records are duplicates when they share both ``'mol_id'`` and that
    identifier; only the first one met in ``data`` is kept.

    Parameters
    ----------
    data : list
        Records (dicts) with ``'mol_id'``, ``'charge_spin'``, ``'sp_config_type'``,
        ``'solvent'`` and a two-element ``'optimized_parent_charge_spin'``.

    Returns
    -------
    list
        The retained records, in their original input order. (Ordering used to
        follow ``set`` iteration order and so could change between runs.)
    """
    filtered_data = []
    seen = set()
    for item in tqdm(data):
        parent = item['optimized_parent_charge_spin']
        opt_parent = parent[0] + parent[1]
        item['dup_identifier'] = '_'.join(
            (item['charge_spin'], item['sp_config_type'], opt_parent, item['solvent'])
        )

        key = (item['mol_id'], item['dup_identifier'])
        if key not in seen:
            seen.add(key)
            filtered_data.append(item)
    return filtered_data

In [None]:
filtered_data = resolve_duplicate_data(merged_data)

# Split to SMD and VAC

In [None]:
read_removed_graph_h5=...

In [None]:
# Route every record to the vacuum or SMD list according to its 'solvent' label.
# Records with any other label are not kept in either list; `solvents` records
# every label seen, in order.
vacuum_data, smd_data, solvents = [], [], []
for item in read_removed_graph_h5:
    solv = item['solvent']
    solvents.append(solv)
    target = vacuum_data if solv == 'vacuum' else (smd_data if solv == 'SMD' else None)
    if target is not None:
        target.append(item)

# Split Data

In [None]:
def get_molecule_weight(data: list):
    """Map each composition to a list holding one canonical summed weight.

    The key is the record's ``'species'`` symbols sorted and concatenated. The
    value is a list with one entry per record of that composition; every entry
    is the sum computed for the FIRST such record, so that floating-point
    differences caused by atom ordering never produce two weights for the same
    composition. Weights come from the notebook-level ``elements_dict``.
    """
    weight_by_formula = {}
    for record in tqdm(data):
        symbols = record['species']
        formula_key = ''.join(sorted(symbols))
        summed = sum([elements_dict[symbol] for symbol in symbols])

        if formula_key in weight_by_formula:
            known = weight_by_formula[formula_key]
            # One more occurrence: repeat the first-seen value instead of `summed`.
            weight_by_formula[formula_key] = [known[0]] * (len(known) + 1)
        else:
            weight_by_formula[formula_key] = [summed]

    return weight_by_formula

def molecule_weight(data: list, weight_dict):
    """Store each record's weight under ``'molecule_mass'``, in place.

    ``weight_dict`` is the mapping returned by get_molecule_weight; the first
    entry of the list stored for the record's sorted-species key is used.
    """
    for item in tqdm(data):
        formula_key = ''.join(sorted(item['species']))
        item['molecule_mass'] = weight_dict[formula_key][0]


def weight_to_data(data: list):
    """Group dict records by the ``'molecule_mass'`` set by molecule_weight.

    The notebook calls this helper for the vacuum and SMD record lists but
    never defined it (only the ASE variant existed). Returns a dict mapping
    each mass to the list of records having it, in input order.
    """
    dict_data = {}
    for item in tqdm(data):
        dict_data.setdefault(item['molecule_mass'], []).append(item)
    return dict_data
        
def get_molecule_weight_ase(data: list):
    """ASE counterpart of get_molecule_weight.

    Items must provide ``get_chemical_symbols()``. The sorted, concatenated
    symbols form the key; the value is a list with one entry per item of that
    composition, each equal to the weight summed for the first such item (this
    sidesteps float differences caused by atom ordering). Weights come from the
    notebook-level ``elements_dict``.
    """
    weight_by_formula = {}
    for atoms in tqdm(data):
        formula_key = ''.join(sorted(atoms.get_chemical_symbols()))
        summed = sum([elements_dict[symbol] for symbol in atoms.get_chemical_symbols()])

        if formula_key in weight_by_formula:
            known = weight_by_formula[formula_key]
            # One more occurrence: repeat the first-seen value instead of `summed`.
            weight_by_formula[formula_key] = [known[0]] * (len(known) + 1)
        else:
            weight_by_formula[formula_key] = [summed]

    return weight_by_formula

def molecule_weight_ase(data: list, weight_dict):
    """Write each item's weight into ``item.info['weight']``.

    ``weight_dict`` is the mapping returned by get_molecule_weight_ase; the
    first entry of the list stored for the item's sorted-symbol key is used.
    """
    for atoms in tqdm(data):
        formula_key = ''.join(sorted(atoms.get_chemical_symbols()))
        atoms.info['weight'] = weight_dict[formula_key][0]
        
def weight_to_data_ase(data: list):
    """Group items by ``item.info['weight']``.

    Returns a dict mapping each weight to the list of items carrying it, in
    input order.
    """
    grouped = {}
    for atoms in tqdm(data):
        grouped.setdefault(atoms.info['weight'], []).append(atoms)
    return grouped

#### Vac

In [None]:
merged_dist = get_molecule_weight(vacuum_data)
molecule_weight(vacuum_data, merged_dist)

In [None]:
wtd = weight_to_data(vacuum_data)

In [None]:
length_dict = {key: len(value) for key, value in wtd.items()}
sorted_length_dict = {k: length_dict[k] for k in sorted(length_dict, reverse=True)}


In [None]:
# Greedy, mass-stratified split of the vacuum records into train/test/val with
# target fractions 0.65 / 0.25 / 0.10. Each split is seeded with one hand-picked
# mass so that every split is non-empty before the ratios are computed.
vac_train_mass = [152.037]
vac_test_mass = [144.09200000000007]
vac_val_mass = [143.1080000000001]

vac_train = sorted_length_dict[152.037] # trackers for dataset sizes
vac_test = sorted_length_dict[144.09200000000007]
vac_val = sorted_length_dict[143.1080000000001]

# The seeds are skipped in the loop instead of being popped from
# `sorted_length_dict`, so the dict stays complete and this cell can be re-run.
vac_seed_masses = {152.037, 144.09200000000007, 143.1080000000001}

# sorted_length_dict maps mass -> number of records, iterated in its stored order.
for mass in sorted_length_dict:
    if mass in vac_seed_masses:
        continue

    temp_total = vac_train+vac_val+vac_test
    # Shortfall of every split relative to its target fraction.
    train_ratio = .65-(vac_train/temp_total)
    test_ratio = .25-(vac_test/temp_total)
    val_ratio = .1-(vac_val/temp_total)

    # Give the mass to the split with the largest shortfall. Ties are broken in the
    # order train, test, val; previously a tie assigned the mass to no split at all.
    if train_ratio >= val_ratio and train_ratio >= test_ratio:
        vac_train_mass.append(mass)
        vac_train += sorted_length_dict[mass]
    elif test_ratio >= val_ratio:
        vac_test_mass.append(mass)
        vac_test += sorted_length_dict[mass]
    else:
        vac_val_mass.append(mass)
        vac_val += sorted_length_dict[mass]

In [None]:
vac_train/(vac_train+vac_val+vac_test)

In [None]:
sorted_length_dict = {k: length_dict[k] for k in sorted(length_dict, reverse=True)}

In [None]:
# Per-split view of "mass -> number of records", read from `length_dict` so the
# seed masses are always included regardless of what happened to
# `sorted_length_dict` in between.
vac_train_subset = {mass: length_dict[mass] for mass in vac_train_mass if mass in length_dict}
vac_test_subset = {mass: length_dict[mass] for mass in vac_test_mass if mass in length_dict}
vac_val_subset = {mass: length_dict[mass] for mass in vac_val_mass if mass in length_dict}

# Repeat every mass once per record so the lists can be fed straight to a histogram.
vac_train_foo = [[mass]*count for mass, count in vac_train_subset.items()]
vac_test_foo = [[mass]*count for mass, count in vac_test_subset.items()]
vac_val_foo = [[mass]*count for mass, count in vac_val_subset.items()]

# `chain` is imported in the notebook's import cell.
vac_train_subset_merged = list(chain.from_iterable(vac_train_foo))
vac_test_subset_merged = list(chain.from_iterable(vac_test_foo))
vac_val_subset_merged = list(chain.from_iterable(vac_val_foo))


In [2]:
##### Do manual switches if needed#######

In [None]:
# Mass distribution of the vacuum training split (log-scaled counts).
fig, ax = plt.subplots()
ax.hist(vac_train_subset_merged, bins=50)
ax.set_ylabel('Frequency (log)')
ax.set_yscale('log')
ax.set_xlabel('Molecule Mass')
ax.set_title('Train')

In [None]:
# Mass distribution of the vacuum test split (log-scaled counts).
fig, ax = plt.subplots()
ax.hist(vac_test_subset_merged, bins=50)
ax.set_ylabel('Frequency (log)')
ax.set_yscale('log')
ax.set_xlabel('Molecule Mass')
ax.set_title('Test')

In [None]:
# Mass distribution of the vacuum validation split (log-scaled counts).
fig, ax = plt.subplots()
ax.hist(vac_val_subset_merged, bins=50)
ax.set_ylabel('Frequency (log)')
ax.set_yscale('log')
ax.set_xlabel('Molecule Mass')
ax.set_title('Val')

In [None]:
# Masses that the following cell moves from the vacuum test split to the
# validation split. The values must match the float keys of the mass dicts exactly.
vac_switch = [
    117.039,
    116.20399999999991,
    116.15999999999993,
    115.09599999999999,
    115.095,
    112.054,
    112.05,
    111.14799999999995,
    102.08899999999997,
    101.06499999999997,
    101.06099999999998,
    100.20499999999991,
    99.053,
    99.04899999999999,
    98.18899999999992,
    95.02300000000001,
    94.11699999999996,
    85.10599999999997,
    84.07799999999997,
    83.046,
]

In [None]:
# Move the hand-picked masses from the vacuum test split to the validation split.
for mass in vac_switch:
    if mass not in vac_test_mass and mass in vac_val_mass:
        # Already in the validation split (e.g. this cell was run before): nothing to do.
        continue

    # Remove first: if the mass is not in the test split this raises ValueError
    # before any list or counter has been modified, so state is never half-updated.
    vac_test_mass.remove(mass)
    vac_test -= sorted_length_dict[mass]

    vac_val_mass.append(mass)
    vac_val += sorted_length_dict[mass]

In [None]:
# Flatten the per-mass record lists of `wtd` into one record list per split,
# keeping the order of the split's mass list.
vac_train_data = [record for mass in vac_train_mass for record in wtd[mass]]

vac_val_data = [record for mass in vac_val_mass for record in wtd[mass]]

vac_test_data = [record for mass in vac_test_mass for record in wtd[mass]]


In [None]:
# Bundle the three vacuum splits under their split names.
vac_data = dict(train=vac_train_data, val=vac_val_data, test=vac_test_data)

#### Vac singlet

In [None]:
# Vacuum training records whose 'spin' equals 1.
vac_train_data_singlet = [item for item in vac_train_data if item['spin']==1]

In [None]:
# Vacuum validation records whose 'spin' equals 1.
vac_val_data_singlet = [item for item in vac_val_data if item['spin']==1]

In [None]:
# Vacuum test records whose 'spin' equals 1.
vac_test_data_singlet = [item for item in vac_test_data if item['spin']==1]

In [None]:
# Bundle the spin-1 vacuum splits under their split names.
vac_singlet_data = dict(
    train=vac_train_data_singlet,
    val=vac_val_data_singlet,
    test=vac_test_data_singlet,
)

#### vac_doublet

In [None]:
# Vacuum training records whose 'spin' equals 2.
vac_train_data_doublet = [item for item in vac_train_data if item['spin']==2]

In [None]:
# Vacuum validation records whose 'spin' equals 2.
vac_val_data_doublet = [item for item in vac_val_data if item['spin']==2]

In [None]:
# Vacuum test records whose 'spin' equals 2.
vac_test_data_doublet = [item for item in vac_test_data if item['spin']==2]

In [None]:
# Bundle the spin-2 vacuum splits under their split names.
vac_doublet_data = dict(
    train=vac_train_data_doublet,
    val=vac_val_data_doublet,
    test=vac_test_data_doublet,
)

# SMD

In [None]:
merged_dist = get_molecule_weight(smd_data)
molecule_weight(smd_data, merged_dist)

In [None]:
wtd = weight_to_data(smd_data)

In [None]:
length_dict = {key: len(value) for key, value in wtd.items()}
sorted_length_dict = {k: length_dict[k] for k in sorted(length_dict, reverse=True)}


In [None]:
# Greedy, mass-stratified split of the SMD records into train/test/val with
# target fractions 0.65 / 0.25 / 0.10. Each split is seeded with one hand-picked
# mass so that every split is non-empty before the ratios are computed.
smd_train_mass = [152.037]
smd_test_mass = [144.09200000000007]
smd_val_mass = [143.1080000000001]

smd_train = sorted_length_dict[152.037] # trackers for dataset sizes
smd_test = sorted_length_dict[144.09200000000007]
smd_val = sorted_length_dict[143.1080000000001]

# The seeds are skipped in the loop instead of being popped from
# `sorted_length_dict`, so the dict stays complete and this cell can be re-run.
smd_seed_masses = {152.037, 144.09200000000007, 143.1080000000001}

# sorted_length_dict maps mass -> number of records, iterated in its stored order.
for mass in sorted_length_dict:
    if mass in smd_seed_masses:
        continue

    temp_total = smd_train+smd_val+smd_test
    # Shortfall of every split relative to its target fraction.
    train_ratio = .65-(smd_train/temp_total)
    test_ratio = .25-(smd_test/temp_total)
    val_ratio = .1-(smd_val/temp_total)

    # Give the mass to the split with the largest shortfall. Ties are broken in the
    # order train, test, val; previously a tie assigned the mass to no split at all.
    if train_ratio >= val_ratio and train_ratio >= test_ratio:
        smd_train_mass.append(mass)
        smd_train += sorted_length_dict[mass]
    elif test_ratio >= val_ratio:
        smd_test_mass.append(mass)
        smd_test += sorted_length_dict[mass]
    else:
        smd_val_mass.append(mass)
        smd_val += sorted_length_dict[mass]

In [None]:
smd_train/(smd_train+smd_val+smd_test)

In [None]:
# Per-split view of "mass -> number of records". Counts are read from
# `length_dict`, which always holds every mass; `sorted_length_dict` may have lost
# the three seed masses in the split cell and is not rebuilt in this section, which
# silently left the seeds out of the histograms below.
smd_train_subset = {mass: length_dict[mass] for mass in smd_train_mass if mass in length_dict}
smd_test_subset = {mass: length_dict[mass] for mass in smd_test_mass if mass in length_dict}
smd_val_subset = {mass: length_dict[mass] for mass in smd_val_mass if mass in length_dict}

# Repeat every mass once per record so the lists can be fed straight to a histogram.
smd_train_foo = [[mass]*count for mass, count in smd_train_subset.items()]
smd_test_foo = [[mass]*count for mass, count in smd_test_subset.items()]
smd_val_foo = [[mass]*count for mass, count in smd_val_subset.items()]

# `chain` is imported in the notebook's import cell.
smd_train_subset_merged = list(chain.from_iterable(smd_train_foo))
smd_test_subset_merged = list(chain.from_iterable(smd_test_foo))
smd_val_subset_merged = list(chain.from_iterable(smd_val_foo))


In [None]:
# Mass distribution of the SMD training split (log-scaled counts).
fig, ax = plt.subplots()
ax.hist(smd_train_subset_merged, bins=50)
ax.set_ylabel('Frequency (log)')
ax.set_yscale('log')
ax.set_xlabel('Molecule Mass')
ax.set_title('Train')

In [None]:
# Mass distribution of the SMD test split (log-scaled counts).
fig, ax = plt.subplots()
ax.hist(smd_test_subset_merged, bins=50)
ax.set_ylabel('Frequency (log)')
ax.set_yscale('log')
ax.set_xlabel('Molecule Mass')
ax.set_title('Test')

In [None]:
# Mass distribution of the SMD validation split (log-scaled counts).
fig, ax = plt.subplots()
ax.hist(smd_val_subset_merged, bins=50)
ax.set_ylabel('Frequency (log)')
ax.set_yscale('log')
ax.set_xlabel('Molecule Mass')
ax.set_title('Val')

In [3]:
####### do manual switches ##########

In [None]:
# Masses that the following cell moves from the SMD test split to the
# validation split. The values must match the float keys of the mass dicts exactly.
smd_switch = [
    117.039,
    116.20399999999991,
    116.15999999999993,
    113.07199999999997,
    112.21599999999991,
    110.11599999999997,
    109.13199999999996,
    109.09199999999998,
    102.09199999999997,
    102.08899999999997,
    101.06499999999997,
    101.06099999999998,
    100.20499999999991,
    97.11699999999996,
    95.10499999999999,
    95.06099999999999,
    93.08899999999998,
    92.14099999999995,
    88.14999999999993,
    86.04999999999998,
    86.04599999999999,
    85.10999999999997,
]

In [None]:
# Move the hand-picked masses from the SMD test split to the validation split.
for mass in smd_switch:
    if mass not in smd_test_mass and mass in smd_val_mass:
        # Already in the validation split (e.g. this cell was run before): nothing to do.
        continue

    # Remove first: if the mass is not in the test split this raises ValueError
    # before any list or counter has been modified, so state is never half-updated.
    smd_test_mass.remove(mass)
    smd_test -= sorted_length_dict[mass]

    smd_val_mass.append(mass)
    smd_val += sorted_length_dict[mass]

In [None]:
# Flatten the per-mass record lists of `wtd` into one record list per split,
# keeping the order of the split's mass list.
smd_train_data = [record for mass in smd_train_mass for record in wtd[mass]]

smd_val_data = [record for mass in smd_val_mass for record in wtd[mass]]

smd_test_data = [record for mass in smd_test_mass for record in wtd[mass]]


In [None]:
# Bundle the three SMD splits under their split names.
# NOTE(review): this rebinds `smd_data`, which earlier cells use for the flat list of
# SMD records; those cells cannot be re-run after this one without rebuilding the list.
smd_data = dict(train=smd_train_data, val=smd_val_data, test=smd_test_data)

#### singlet

In [None]:
# SMD records whose 'spin' equals 1, per split.
smd_train_data_singlet = [item for item in smd_train_data if item['spin']==1]
smd_val_data_singlet = [item for item in smd_val_data if item['spin']==1]
smd_test_data_singlet = [item for item in smd_test_data if item['spin']==1]

smd_singlet_data = dict(
    train=smd_train_data_singlet,
    val=smd_val_data_singlet,
    test=smd_test_data_singlet,
)

#### doublet

In [None]:
# SMD records whose 'spin' equals 2, per split.
smd_train_data_doublet = [item for item in smd_train_data if item['spin']==2]
smd_val_data_doublet = [item for item in smd_val_data if item['spin']==2]
smd_test_data_doublet = [item for item in smd_test_data if item['spin']==2]

smd_doublet_data = dict(
    train=smd_train_data_doublet,
    val=smd_val_data_doublet,
    test=smd_test_data_doublet,
)

# Convert

In [None]:
def sp_convert_energy(data: list, factor: float = 27.2114):
    """Multiply every record's ``'energy'`` by ``factor``, in place.

    Parameters
    ----------
    data : list
        Records (dicts) with a numeric ``'energy'`` entry.
    factor : float, optional
        Conversion factor; the default is presumably Hartree -> eV (TODO confirm
        the source units).

    Notes
    -----
    The conversion overwrites the stored value, so calling this twice on the
    same records applies the factor twice.
    """
    for item in tqdm(data):
        item['energy'] = item['energy'] * factor

In [None]:
def sp_convert_forces(data: list, factor: float = 51.42208619083232):
    """Multiply every component of each record's ``'gradient'`` by ``factor``, in place.

    Parameters
    ----------
    data : list
        Records (dicts) whose ``'gradient'`` is a sequence of per-atom component
        sequences.
    factor : float, optional
        Conversion factor; the default is presumably Hartree/Bohr -> eV/Angstrom
        (TODO confirm the source units).

    Notes
    -----
    The result is stored as a list of lists under the same ``'gradient'`` key.
    The conversion overwrites the stored value, so calling this twice applies
    the factor twice. No sign change is applied.
    NOTE(review): if these gradients are later written out as forces, check
    whether a sign flip is expected.
    """
    for item in tqdm(data):
        item['gradient'] = [
            [component * factor for component in atom]
            for atom in item['gradient']
        ]

In [None]:
##### convert all the data above

# Resp DM

In [None]:
def sp_generate_resp_dipole(data: list): #THIS IS GOOD
    """Compute a dipole vector from the RESP charges and store it on each record.

    For every record the charges in ``'resp_partial_charges'`` are multiplied
    with the matching rows of ``'geometry'``, summed over atoms, and scaled by
    ``1/0.2081943`` (presumably e*Angstrom -> Debye; confirm). The vector is
    stored, wrapped in a one-element list, under ``'calc_resp_dipole_moments'``.
    """
    for record in tqdm(data):
        charges = np.array(record['resp_partial_charges'])
        coordinates = np.array(record['geometry'])

        # charge_i * r_i for every atom, via broadcasting over the coordinate axis
        per_atom = charges[:, np.newaxis] * coordinates

        # total over atoms, then unit conversion
        total = np.sum(per_atom, axis=0)*(1/0.2081943)

        record['calc_resp_dipole_moments'] = [total.tolist()]

In [None]:
sp_generate_resp_dipole(...) # do this for all the data

# Write

In [None]:
def build_sp_atoms(data: list,
                energy: str = None,
                forces: str = None,
                charge:str = None,
                spin:str = None,
                train = False) -> list:
    """
    Build one ``ase.Atoms`` object per record.
        atoms.info : per-molecule values
        atoms.arrays : per-atom values

    Parameters
    ----------
    data : list of dict (a single dict is also accepted)
        Records with ``'species'``, ``'geometry'``, the partial charge/spin
        arrays, ``'calc_resp_dipole_moments'``, ``'optimized_parent_charge_spin'``
        (two-element sequence: charge, spin), ``'solvent'`` and ``'mol_id'``.
    energy, forces, charge, spin : str, optional
        Names of the record keys to copy into ``info['energy']``,
        ``arrays['forces']``, ``info['charge']`` and ``info['spin']``. A field
        is skipped when its key name is ``None``.
    train : bool, optional
        Unused; kept for call compatibility.

    Returns
    -------
    list
        The ``ase.Atoms`` objects, in input order.

    Fixes relative to the previous version: it iterated over a non-existent
    ``'geometries'`` key (appending the same molecule repeatedly), read the
    parent spin from index 0 (the charge), and took the solvent from an
    undefined ``item`` variable.
    """
    records = [data] if isinstance(data, dict) else data

    atom_list = []
    for record in tqdm(records):
        atoms = ase.atoms.Atoms(
            symbols=record['species'],
            positions=record['geometry']
        )
        atoms.arrays['mulliken_partial_charges']=np.array(record['mulliken_partial_charges'])
        atoms.arrays['mulliken_partial_spins']=np.array(record['mulliken_partial_spins'])
        atoms.arrays['resp_partial_charges']=np.array(record['resp_partial_charges'])
        atoms.info['calc_resp_dipole_moments']=np.array(record['calc_resp_dipole_moments'])

        # 'optimized_parent_charge_spin' is [charge, spin]
        parent_charge_spin = record['optimized_parent_charge_spin']
        atoms.info['optimized_parent_charge'] = parent_charge_spin[0]
        atoms.info['optimized_parent_spin'] = parent_charge_spin[1]
        atoms.info['solvent'] = record['solvent']

        if energy is not None:
            atoms.info['energy'] = record[energy]
        if forces is not None:
            atoms.arrays['forces'] = np.array(record[forces])
        if charge is not None:
            atoms.info['charge'] = record[charge]
        if spin is not None:
            atoms.info['spin'] = record[spin]
        atoms.info['mol_id'] = record['mol_id']

        atom_list.append(atoms)
    return atom_list


In [None]:
train_smd = build_sp_atoms(train_smd) ### repeat this for all datasets

In [None]:
# Write

In [None]:
file = '...' ### repeat for all datasets
ase.io.write(file, train_smd,format="extxyz")

# Chunking

In [None]:
def chunk_data(data: dict, chunks: list):
    """Draw nested, stratified subsets of increasing size from grouped data.

    Parameters
    ----------
    data : dict
        Maps a group key (here: molecule weight) to a list of items. The dict
        is no longer modified; previously the caller's lists were consumed.
    chunks : list
        Increasing fractions of the total, e.g. ``[.05, .1, .25]``.

    Returns
    -------
    dict
        ``{index: items}`` with one entry per fraction. Entry ``i`` is the newly
        drawn items followed by everything in entry ``i-1``, so each subset
        contains the previous one.

    Notes
    -----
    The first subset takes ``floor(fraction * len(group))`` items from the
    front of every group. Each later subset repeatedly sweeps the groups,
    taking ``floor((fraction - previous fraction) * remaining)`` items per
    group, until the rounded target size difference is reached or 50 sweeps
    have been made, so sizes are approximate and may overshoot or undershoot.
    """
    return_data = {}
    # Shallow copy: only the dict's values are rebound below, never mutated.
    remaining = dict(data)

    total = 0
    for key in tqdm(data):
        total += len(data[key])

    # Target cumulative size of every subset.
    sizes = [round(total * fraction) for fraction in chunks]

    for i in range(len(chunks)):
        selected = []
        if i == 0:
            for key in tqdm(data):
                if len(remaining[key]) != 0:
                    sample_size = math.floor(chunks[i] * len(remaining[key]))
                    selected += remaining[key][:sample_size]
                    remaining[key] = remaining[key][sample_size:]
            return_data[i] = selected
        else:
            target = sizes[i] - sizes[i-1]
            step = chunks[i] - chunks[i-1]
            counter = 0
            # Bounded number of sweeps so the loop ends even if no group can
            # contribute any more items.
            for _ in range(50):
                if counter >= target:
                    break
                for key in data:
                    if len(remaining[key]) != 0:
                        sample_size = math.floor(step * len(remaining[key]))
                        add_on = remaining[key][:sample_size]
                        selected += add_on
                        remaining[key] = remaining[key][sample_size:]
                        counter += len(add_on)
                        if counter >= target:
                            break

            return_data[i] = selected + return_data[i-1]
    return return_data

In [None]:
#read_xyz train
atoms_list = ...

In [None]:
mm=get_molecule_weight_ase(atoms_list)

In [None]:
molecule_weight_ase(atoms_list, mm)

In [None]:
ase_data = weight_to_data_ase(atoms_list)

In [None]:
cd = chunk_data(ase_data, [.05, .1, .25, .5, .75])

In [None]:
for key in cd:
    print(len(cd[key])/len(atoms_list))

In [None]:
chunk_file = os.path.join('/pscratch/sd/m/mavaylon/chem_final_data/SP/Zip_FINAL/Full_fields_Singlet_chunks/smd','rad_qm9_smd'+'_train05.xyz')
ase.io.write(chunk_file, cd[0],format="extxyz")

In [None]:
chunk_file = os.path.join('/pscratch/sd/m/mavaylon/chem_final_data/SP/Zip_FINAL/Full_fields_Singlet_chunks/smd','rad_qm9_smd'+'_train10.xyz')
ase.io.write(chunk_file, cd[1],format="extxyz")

In [None]:
chunk_file = os.path.join('/pscratch/sd/m/mavaylon/chem_final_data/SP/Zip_FINAL/Full_fields_Singlet_chunks/smd','rad_qm9_smd'+'_train25.xyz')
ase.io.write(chunk_file, cd[2],format="extxyz")

In [None]:
chunk_file = os.path.join('/pscratch/sd/m/mavaylon/chem_final_data/SP/Zip_FINAL/Full_fields_Singlet_chunks/smd','rad_qm9_smd'+'_train50.xyz')
ase.io.write(chunk_file, cd[3],format="extxyz")

In [None]:
chunk_file = os.path.join('/pscratch/sd/m/mavaylon/chem_final_data/SP/Zip_FINAL/Full_fields_Singlet_chunks/smd','rad_qm9_smd'+'_train75.xyz')
ase.io.write(chunk_file, cd[4],format="extxyz")

# Relative Energies

In [None]:
def relative_energies(data: list, stats: dict):
    """Store ``info['relative_energy']`` on every item, in place.

    The relative energy is ``info['energy']`` minus the sum, over the item's
    ``arrays['numbers']``, of the per-element reference energies.

    Parameters
    ----------
    data : list
        Items exposing ``info`` (with ``'charge'``, ``'spin'``, ``'energy'``)
        and ``arrays['numbers']``.
    stats : dict
        Maps ``str(charge) + str(spin)`` (e.g. ``'-12'``) to a dict whose
        ``'atomic_energies'`` entry is the string form of a mapping from
        atomic number to energy.
    """
    # Parse each table once per charge/spin key instead of once per atom.
    parsed_tables = {}
    for item in tqdm(data):
        key = str(item.info['charge'])+str(item.info['spin'])
        if key not in parsed_tables:
            # SECURITY NOTE: eval() executes whatever the statistics file contains.
            # Only use trusted files; prefer ast.literal_eval if the stored string
            # is a plain literal (TODO confirm the file format).
            parsed_tables[key] = eval(stats[key]['atomic_energies'])
        atomic_energies = parsed_tables[key]

        lookup_sum = 0
        for num in item.arrays['numbers']:
            lookup_sum += atomic_energies[num]

        item.info['relative_energy'] = item.info['energy'] - lookup_sum

# SMD

In [None]:
# Per charge/spin statistics files for the SMD data (json is imported in the
# notebook's import cell).
SMD_STATS_ROOT = '/pscratch/sd/m/mavaylon/chem_final_data/SP/Full_Fields_ChargeSPin/smd'


def load_smd_statistics(charge_spin: str) -> dict:
    """Load ``<SMD_STATS_ROOT>/<charge_spin>/h5/statistics.json`` as a dict."""
    stats_path = os.path.join(SMD_STATS_ROOT, charge_spin, 'h5', 'statistics.json')
    with open(stats_path, 'r') as f:
        return json.load(f)


data_smd_01 = load_smd_statistics('01')
data_smd_03 = load_smd_statistics('03')
data_smd_12 = load_smd_statistics('12')
data_smd_21 = load_smd_statistics('21')
data_smd_n12 = load_smd_statistics('-12')
data_smd_n21 = load_smd_statistics('-21')

# Keys are str(charge) + str(spin), the lookup key built in relative_energies.
# Only the entry for the data set being processed is enabled.
stats_dict= {
             # '-21': data_smd_n21,
             # '-12': data_smd_n12,
             # '01': data_smd_01,
             # '03': data_smd_03,
             # '12': data_smd_12,
             '21': data_smd_21
}


In [None]:
relative_energies(train_smd, stats_dict)

In [None]:
relative_energies(val_smd, stats_dict)

In [None]:
relative_energies(test_smd, stats_dict)

In [None]:
def re_sp_build_minimal_atoms_rel(data: list,
                energy: str = None,
                forces: str = None,
                charge:str = None,
                spin:str = None,
                train = False) -> list:
    """
    Rebuild a list of ``ase.Atoms`` objects with a fixed set of fields.
        atoms.info : per-molecule values
        atoms.arrays : per-atom values

    Parameters
    ----------
    data : list
        Existing ``ase.Atoms``-like items. Each must already carry
        ``info['energy']``, ``info['relative_energy']`` and the other info and
        array fields copied below.
    energy, forces, charge, spin, train
        Unused; kept for call compatibility. The field names are fixed.

    Returns
    -------
    list
        New ``ase.Atoms`` objects, in input order. The source ``info['energy']``
        is stored as ``info['total_energy']``.
    """
    # Copied verbatim, in this order (the order fixes the layout of atoms.info).
    copied_info = ('mol_id', 'charge', 'spin', 'optimized_parent_charge',
                   'optimized_parent_spin', 'solvent')
    copied_arrays = ('mulliken_partial_charges', 'mulliken_partial_spins',
                     'resp_partial_charges')

    atom_list = []
    for item in tqdm(data):
        atoms = ase.atoms.Atoms(
            numbers=item.arrays['numbers'],
            positions=item.arrays['positions']
        )
        atoms.info['total_energy'] = item.info['energy']
        atoms.info['relative_energy'] = item.info['relative_energy']
        for name in copied_info:
            atoms.info[name] = item.info[name]
        atoms.info['calc_resp_dipole_moments'] = np.array(item.info['calc_resp_dipole_moments'])

        atoms.arrays['forces'] = np.array(item.arrays['forces'])
        for name in copied_arrays:
            atoms.arrays[name] = np.array(item.arrays[name])

        atom_list.append(atoms)
    return atom_list


In [None]:
retrain_smd = re_sp_build_minimal_atoms_rel(train_smd)

In [None]:
file = os.path.join('/pscratch/sd/m/mavaylon/chem_final_data/SP/Full_Fields_ChargeSPin/smd/21','rad_qm9_7_25_24_converted_E_F_convrespdm_relenergy_SMD_train_21.xyz')
ase.io.write(file, retrain_smd, format="extxyz")

In [None]:
# Doublets: charge +/-1 AND spin multiplicity 2. The spin check matches how the
# split sections above define a doublet (item['spin'] == 2); selecting on charge
# alone would also admit any other multiplicity stored for these charges.
doublet_data = [
    item for item in tqdm(filtered_data)
    if item['charge'] in [-1, 1] and item['spin'] == 2
]

In [None]:
# Route every doublet record to the vacuum or SMD list by its 'solvent' label;
# records with any other label are kept in neither. `solvents` records every
# label seen, in order.
vacuum_doublet_data_data, smd_doublet_data_data, solvents = [], [], []
for item in doublet_data:
    solv = item['solvent']
    solvents.append(solv)
    target = (
        vacuum_doublet_data_data if solv == 'vacuum'
        else (smd_doublet_data_data if solv == 'SMD' else None)
    )
    if target is not None:
        target.append(item)

In [None]:
# Singlets: charge 0 or +/-2 AND spin multiplicity 1. Selecting on charge alone
# also admitted higher multiplicities at these charges (a charge-0 / spin-3 state
# exists among the statistics files), which does not match how the split sections
# above define a singlet (item['spin'] == 1).
singlet_data = [
    item for item in tqdm(filtered_data)
    if item['charge'] in [0, -2, 2] and item['spin'] == 1
]

In [None]:
# Route every singlet record to the vacuum or SMD list by its 'solvent' label;
# records with any other label are kept in neither. `solvents` records every
# label seen, in order.
vacuum_singlet_data_data, smd_singlet_data_data, solvents = [], [], []
for item in singlet_data:
    solv = item['solvent']
    solvents.append(solv)
    target = (
        vacuum_singlet_data_data if solv == 'vacuum'
        else (smd_singlet_data_data if solv == 'SMD' else None)
    )
    if target is not None:
        target.append(item)

#### You do not need to filter here, since these are all optimized points, which means the forces are very small.