In [19]:
import os

import h5py
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator
from tqdm import tqdm

In [43]:
# Index applied to each dataset of a config group. slice(None) is the `[:]`
# read used for array-valued datasets; () is the `[()]` read, the h5py idiom
# for datasets that are assumed to hold a single scalar value.
_READ_ARRAY = slice(None)
_READ_SCALAR = ()

# Dataset name -> how to read it. Insertion order is the key order of every
# entry returned by get_entries.
_CONFIG_FIELDS = {
    'atomic_numbers': _READ_ARRAY,
    'cell': _READ_ARRAY,
    'charges': _READ_ARRAY,
    'config_type': _READ_SCALAR,
    'dipole': _READ_ARRAY,
    'energy': _READ_SCALAR,
    'energy_weight': _READ_SCALAR,
    'forces': _READ_ARRAY,
    'forces_weight': _READ_SCALAR,
    'pbc': _READ_ARRAY,
    'positions': _READ_ARRAY,
    'stress': _READ_ARRAY,
    'stress_weight': _READ_SCALAR,
    'virials': _READ_ARRAY,
    'virials_weight': _READ_SCALAR,
    'weight': _READ_SCALAR,
}


def get_entries(in_dir, file_name):
    """Read every configuration of ``config_batch_0`` from one HDF5 file.

    Opens ``{in_dir}/{file_name}.h5`` read-only and looks the configurations
    up by index as ``config_batch_0/config_<i>`` for ``i`` in
    ``range(len(config_batch_0))``. Only ``config_batch_0`` is read.

    Each configuration is stored as a plain dict rather than an ``ase.Atoms``
    object. Notes carried over from the original author's comments (not
    re-verified here): ``positions`` are Cartesian coordinates, and ``energy``
    is the energy that already includes the correction
    (see curtis_read_alexandria.ipynb).

    Args:
        in_dir: Directory that contains the file.
        file_name: File name without the ``.h5`` extension.

    Returns:
        A list with one dict per configuration, in index order. Every dict
        has the keys of ``_CONFIG_FIELDS`` in that order.

    Errors from opening the file or from missing groups/datasets are not
    caught here.
    """
    entries = []

    with h5py.File(f"{in_dir}/{file_name}.h5", 'r') as hdf5_file:
        num_configs = len(hdf5_file["config_batch_0"])
        for i in tqdm(range(num_configs)):
            config_group = hdf5_file[f'config_batch_0/config_{i}']
            # Each dataset is read exactly once.
            entries.append({
                key: config_group[key][index]
                for key, index in _CONFIG_FIELDS.items()
            })

    # No filter is applied at the moment, so both counts are always equal.
    print(f"found {num_configs} systems")
    print(f"after filtering, found {len(entries)} systems")
    return entries

In [44]:
# Directory that holds the train_<i>.h5 shards. Set the JOULE_TRAIN_DIR
# environment variable to point somewhere else; the default is the original
# author's local checkout and will not exist on other machines.
IN_TRAIN_DIR = os.environ.get(
    "JOULE_TRAIN_DIR",
    "/Users/curtischong/Documents/dev/joule/datasets/real_mace/train",
)

def parse_datasets(in_dir, in_dir_prefix, num_files):
    """Concatenate the entries of ``num_files`` consecutive HDF5 shards.

    Shard ``k`` (``k`` in ``range(num_files)``) is named
    ``{in_dir_prefix}_{k}`` and is read with ``get_entries(in_dir, name)``.
    A progress line is printed before each shard is parsed.

    Args:
        in_dir: Directory passed through to ``get_entries``.
        in_dir_prefix: Common shard-name prefix, e.g. ``"train"``.
        num_files: Number of shards to read, starting at index 0.

    Returns:
        One flat list with the entries of all shards, in shard order.
    """
    combined = []
    shard_idx = 0
    while shard_idx < num_files:
        shard_name = f"{in_dir_prefix}_{shard_idx}"
        print(f"parsing {shard_name}")
        combined += get_entries(in_dir, shard_name)
        shard_idx += 1
    return combined

# Only the first shard (train_0) is parsed for now.
# TODO: increase num_files
results = parse_datasets(
    in_dir=IN_TRAIN_DIR,
    in_dir_prefix="train",
    num_files=1,
)

parsing train_0


100%|████████████████████████████████████████████████████████████████████████████████████████| 24661/24661 [00:15<00:00, 1612.92it/s]

found 24661 systems
after filtering, found 24661 systems





In [45]:
# Show every stored field of one parsed configuration (a plain dict).
results[9000]

{'atomic_numbers': array([3, 8, 8]),
 'cell': array([[ 2.30281076, -0.39696027,  2.29894373],
        [ 0.69315251,  2.23160376,  2.29894368],
        [-0.64823271, -0.39696027,  3.18871066]]),
 'charges': array([0., 0., 0.]),
 'config_type': b'Default',
 'dipole': array([0., 0., 0.]),
 'energy': -11.34848575,
 'energy_weight': 1.0,
 'forces': array([[-0.        , -0.        , -0.        ],
        [ 0.0104404 ,  0.0063934 ,  0.03462714],
        [-0.0104404 , -0.0063934 , -0.03462714]]),
 'forces_weight': 1.0,
 'pbc': array([ True,  True,  True]),
 'positions': array([[1.17386892, 0.71884384, 3.89330044],
        [0.59527832, 0.36452604, 1.97431853],
        [1.75246006, 1.07315113, 5.81228413]]),
 'stress': array([[ 0.00309506, -0.00038883, -0.00210592],
        [-0.00038883,  0.00349191, -0.0012896 ],
        [-0.00210592, -0.0012896 , -0.00325457]]),
 'stress_weight': 1.0,
 'virials': array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 'virials_weight': 0.0,
 'weigh

In [46]:
# A second configuration to compare against results[9000]: in the recorded
# outputs both have the same atomic_numbers, cell and positions.
results[3666]

{'atomic_numbers': array([3, 8, 8]),
 'cell': array([[ 2.30281076, -0.39696027,  2.29894373],
        [ 0.69315251,  2.23160376,  2.29894368],
        [-0.64823271, -0.39696027,  3.18871066]]),
 'charges': array([0., 0., 0.]),
 'config_type': b'Default',
 'dipole': array([0., 0., 0.]),
 'energy': -11.62460911,
 'energy_weight': 1.0,
 'forces': array([[ 0.        ,  0.        ,  0.        ],
        [-0.00057577, -0.00035259, -0.00190963],
        [ 0.00057577,  0.00035259,  0.00190963]]),
 'forces_weight': 1.0,
 'pbc': array([ True,  True,  True]),
 'positions': array([[1.17386892, 0.71884384, 3.89330044],
        [0.59527832, 0.36452604, 1.97431853],
        [1.75246006, 1.07315113, 5.81228413]]),
 'stress': array([[-0.09327584,  0.00276309,  0.01496505],
        [ 0.00276309, -0.09609598,  0.00916419],
        [ 0.01496505,  0.00916419, -0.04815416]]),
 'stress_weight': 1.0,
 'virials': array([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]]),
 'virials_weight': 0.0,
 'weigh

In [None]:
# forces, stress and energy are the only things that differ between the two entries shown above (results[9000] and results[3666])

In [40]:
# Entries are plain dicts (not ase.Atoms), so the per-atom charges are read
# by key; this is the value the recorded output of this cell shows.
results[9000]['charges']

array([0., 0., 0.])

In [41]:
# Entries are plain dicts (not ase.Atoms), so there is no
# get_initial_charges(); the charges are stored under the 'charges' key.
results[3666]['charges']

array([0., 0., 0.])

In [30]:
# Entries are plain dicts with no ASE calculator attached, so the former
# calculator results (energy and forces) are picked straight from the dict.
{key: results[9000][key] for key in ('energy', 'forces')}

{'energy': -11.34848575,
 'forces': array([[-0.        , -0.        , -0.        ],
        [ 0.0104404 ,  0.0063934 ,  0.03462714],
        [-0.0104404 , -0.0063934 , -0.03462714]])}

In [29]:
# Entries are plain dicts with no ASE calculator attached, so the former
# calculator results (energy and forces) are picked straight from the dict.
{key: results[3666][key] for key in ('energy', 'forces')}

{'energy': -11.62460911,
 'forces': array([[ 0.        ,  0.        ,  0.        ],
        [-0.00057577, -0.00035259, -0.00190963],
        [ 0.00057577,  0.00035259,  0.00190963]])}