In [2]:
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd

from chemprop import data, featurizers, models, nn

import h5py

In [17]:
# Inspect a single reaction entry of the Transition1x HDF5 file: list the
# group's members and attributes, then print the transition-state geometry
# as XYZ-style lines ("<symbol> x y z").
# NOTE(review): the path below is machine-specific; consider moving it to a
# configurable data directory before sharing this notebook.
h5_path = "/home/calvin/Downloads/Transition1x.h5"

with h5py.File(h5_path, "r") as hf:
    split = "train"
    formula = "C7H10"
    rxn = "rxn9689"
    grp = hf[split][formula][rxn]
    print(f"=== Info for formula='{formula}', rxn='{rxn}' ===\n")

    # top-level subgroup/dataset names
    print("Contents:", list(grp.keys()), "\n")

    # attributes on the rxn group
    print("Group attributes:")
    for attr_name, attr_val in grp.attrs.items():
        print(f"  {attr_name!r}: {attr_val}")
    print()

    # for each member under that group, print type, shape/dtype, attrs
    for name, node in grp.items():
        kind = "Group" if isinstance(node, h5py.Group) else "Dataset"
        print(f"{name} ({kind}):")
        if isinstance(node, h5py.Dataset):
            print(f"  shape={node.shape}, dtype={node.dtype}")
        else:
            print(f"  keys={list(node.keys())}")
        if node.attrs:
            print("  attrs:")
            for k, v in node.attrs.items():
                print(f"    {k!r}: {v}")
        print()
    print("=== Data ===")
    # Grab the transition-state subgroup; both the coordinates and the atom
    # labels are taken from it so they are guaranteed to describe the same
    # structure.
    ts_group = grp['transition_state']
    positions = ts_group['positions']

    # Read the dataset from disk once; everything below works on this
    # in-memory array instead of re-reading the file.
    ts_coords = positions[:]

    # print its shape and data (once)
    print("Positions shape:", ts_coords.shape)
    print("Positions array:")
    print(ts_coords)

    # Atomic number -> chemical symbol lookup (covers Z = 1..25 only;
    # anything else is reported as unknown below).
    symbols = {
        1: "H",
        2: "He",
        3: "Li",
        4: "Be",
        5: "B",
        6: "C",
        7: "N",
        8: "O",
        9: "F",
        10: "Ne",
        11: "Na",
        12: "Mg",
        13: "Al",
        14: "Si",
        15: "P",
        16: "S",
        17: "Cl",
        18: "Ar",
        19: "K",
        20: "Ca",
        21: "Sc",
        22: "Ti",
        23: "V",
        24: "Cr",
        25: "Mn",
    }

    # Atom labels for the transition state, one entry per atom.
    atomic_numbers = ts_group['atomic_numbers'][:]

    # zip() would silently drop atoms on a length mismatch, so fail loudly.
    if ts_coords.ndim != 3 or ts_coords.shape[1] != len(atomic_numbers):
        raise ValueError(
            f"positions shape {ts_coords.shape} does not match "
            f"{len(atomic_numbers)} atomic numbers"
        )

    # Print each frame with one atom per line and the element symbol in the
    # first column. ts_coords has shape (n_frames, n_atoms, 3).
    for frame_idx, frame in enumerate(ts_coords):
        print(f"--- Frame {frame_idx} ---")
        for num, (x, y, z) in zip(atomic_numbers, frame):
            # Cast the NumPy scalar to a plain int for the dict lookup.
            symbol = symbols.get(int(num), "Unknown")
            if symbol == "Unknown":
                print(f"Unknown atomic number: {num}")
            else:
                print(f"{symbol} {x:.6f} {y:.6f} {z:.6f}")
        print()


=== Info for formula='C7H10', rxn='rxn9689' ===

Contents: ['atomic_numbers', 'positions', 'product', 'reactant', 'transition_state', 'wB97x_6-31G(d).energy', 'wB97x_6-31G(d).forces'] 

Group attributes:

atomic_numbers (Dataset):
  shape=(17,), dtype=int32

positions (Dataset):
  shape=(514, 17, 3), dtype=float64

product (Group):
  keys=['atomic_numbers', 'hash', 'positions', 'wB97x_6-31G(d).energy', 'wB97x_6-31G(d).forces']

reactant (Group):
  keys=['atomic_numbers', 'hash', 'positions', 'wB97x_6-31G(d).energy', 'wB97x_6-31G(d).forces']

transition_state (Group):
  keys=['atomic_numbers', 'hash', 'positions', 'wB97x_6-31G(d).energy', 'wB97x_6-31G(d).forces']

wB97x_6-31G(d).energy (Dataset):
  shape=(514,), dtype=float64

wB97x_6-31G(d).forces (Dataset):
  shape=(514, 17, 3), dtype=float64

=== Data ===
Positions shape: (1, 17, 3)
Positions array:
[[[ 2.04255972e+00 -5.40010715e-02  1.22349330e-01]
  [ 5.63347485e-01  1.23325942e-01  2.85414241e-01]
  [-3.45196346e-01 -8.05479732e-