In [26]:
import os

import torch 
from fairchem.core.common.registry import registry
import ase.io
from ase.neighborlist import natural_cutoffs, NeighborList
import gsd.hoomd
from tqdm import tqdm
import numpy as np

from src.distill_datasets import SimpleDataset


# Trajectory file (ASE .traj) from the Transition1x training split.
# The default is a machine-specific absolute path; set the TRANSITION1X_TRAJ
# environment variable to point at a local copy instead of editing this cell.
dataset_path = os.environ.get(
    "TRANSITION1X_TRAJ",
    "/data/shared/ishan_stuff/transition1x/traj/train/CH4N2O_rxn4428.traj",
)
# index=":" reads every frame, so `trajectory` is a list of ase.Atoms.
trajectory = ase.io.read(dataset_path, index=":")


In [27]:
def save_ovito_traj(positions, bonds, filename):
    """
    Write a trajectory of atom positions and bonds to a HOOMD/GSD file
    (a format Ovito can open).

    Args:
        positions: torch tensor iterated along its first dimension; each
            element is one frame's per-atom coordinates and is passed to
            ``create_frame``.
        bonds: iterable with one bond array per frame, passed to
            ``create_frame`` alongside the matching positions. If
            ``positions`` and ``bonds`` differ in length, the extra entries
            are ignored (``zip`` semantics).
        filename: output path. It is opened in gsd write mode ("w"), which
            replaces any existing file.

    Every frame shares one cubic box with edge 5 * max(|coordinate|) taken
    over the whole trajectory.
    """
    # Cubic box sized from the largest absolute coordinate in any frame.
    cell = 5 * torch.eye(3) * positions.cpu().abs().max()
    n_frames = min(len(positions), len(bonds))

    # The context manager closes the file even if building a frame raises.
    with gsd.hoomd.open(name=filename, mode="w") as traj_file:
        frames = enumerate(zip(positions, bonds))
        for step, (pos, bond) in tqdm(frames, total=n_frames):
            traj_file.append(create_frame(step, pos, cell, bond))


def create_frame(step, position, cell, bonds, diameter=0.8):
    """
    Build a single ``gsd.hoomd.Frame`` from atom positions and bond pairs.

    Args:
        step: value stored in ``configuration.step``.
        position: per-atom coordinates (torch tensor, NumPy array or nested
            list); its first dimension is the atom count. Presumably
            (n_atoms, 3) -- TODO confirm against callers.
        cell: 3x3 box matrix (tensor, array or nested list). Only the
            diagonal ``cell[i][i]`` is used; tilt factors are written as 0.
        bonds: atom-index pairs, one row per bond (first dimension is the
            bond count), e.g. an (n_bonds, 2) integer tensor.
        diameter: diameter assigned to every particle (default 0.8).

    Returns:
        A ``gsd.hoomd.Frame`` with step, box, particle positions/diameters
        and bonds set. Particle and bond types are not set here, so gsd's
        defaults apply.
    """
    # TODO: add an option to keep only C-N bonds; currently every pair in
    # `bonds` is written.

    # as_tensor accepts tensors, arrays and lists alike and keeps their dtype,
    # unlike the legacy torch.Tensor(...) constructor (float32 for non-tensors).
    coords = torch.as_tensor(position)
    bond_pairs = torch.as_tensor(bonds)
    natoms = coords.shape[0]

    frame = gsd.hoomd.Frame()
    frame.configuration.step = step
    # float() turns 0-dim tensors / NumPy scalars into plain Python numbers.
    frame.configuration.box = [float(cell[i][i]) for i in range(3)] + [0.0, 0.0, 0.0]

    frame.particles.N = natoms
    frame.particles.position = coords.tolist()
    frame.particles.diameter = [float(diameter)] * natoms

    frame.bonds.N = bond_pairs.shape[0]
    frame.bonds.group = bond_pairs.tolist()
    return frame

In [28]:
# Subsample the trajectory and build a bond list for each kept frame from
# ASE's neighbor list (natural_cutoffs = per-element covalent-radius cutoffs).
# NOTE(review): the length check uses an offset of 10 while selection starts
# at frame 13 -- presumably 10 leading frames followed by blocks of
# FRAME_STRIDE images, keeping one fixed image per block; confirm against the
# Transition1x trajectory layout.
HEADER_FRAMES = 10
FRAME_STRIDE = 8
FIRST_FRAME = 13

assert (len(trajectory) - HEADER_FRAMES) % FRAME_STRIDE == 0, (
    f"unexpected trajectory length {len(trajectory)}: expected "
    f"{HEADER_FRAMES} + k * {FRAME_STRIDE} frames"
)

# Stack all frames, then slice, so the result matches traj[13:][::8].
traj = torch.stack(
    [torch.tensor(atoms.get_positions()) for atoms in trajectory]
)[FIRST_FRAME::FRAME_STRIDE]

# Only the kept frames need neighbor lists; the rest are discarded anyway.
bonds = []
for atoms in trajectory[FIRST_FRAME::FRAME_STRIDE]:
    neighbor_list = NeighborList(natural_cutoffs(atoms), self_interaction=False)
    neighbor_list.update(atoms)
    adjacency = neighbor_list.get_connectivity_matrix(sparse=False)
    # (n_bonds, 2) int64 index pairs in row-major order -- the same pairs and
    # order as nonzero().T, without building a tensor from a tuple of arrays.
    bonds.append(torch.from_numpy(np.argwhere(adjacency)))


In [29]:
# Write the subsampled frames and their per-frame bond lists to a GSD file.
save_ovito_traj(traj, bonds, filename="trans1x.gsd")

55it [00:00, 4992.35it/s]


In [30]:
# Split the subsampled trajectory along the frame axis into up to 10
# consecutive chunks (torch.chunk may return fewer when sizes don't divide).
chunked_traj = traj.chunk(10, dim=0)

In [31]:
# Compare the first two chunks element-wise. A notebook only displays a cell's
# last expression, so both results are returned together as a tuple instead of
# leaving the allclose check as a discarded statement.
first_chunk, second_chunk = chunked_traj[0], chunked_traj[1]
chunks_match = torch.allclose(first_chunk, second_chunk, atol=1e-3)
max_abs_diff = (first_chunk - second_chunk).abs().max()
chunks_match, max_abs_diff

tensor(0.1167, dtype=torch.float64)