In [13]:
from pathlib import Path
import mdtraj as md
from celerity.dataloader import TrajectoryDataset, DataLoader
from celerity.featurizer import Dihedrals
import h5py
import numpy as np

In [2]:
# Input locations: the production run directory holding the topology and the
# raw DCD trajectory.
data_dir = '/home/mbowley/ANI-Peptides/outputs/production_ypgdv_capped_amber_equilibrated_amber_004427_010322'
top_path = data_dir + '/topology.pdb'
traj_file = Path(data_dir) / 'trajectory.dcd'

# Output location: the converted trajectory keeps the original base name but
# gets an .h5 extension and lives under out_dir.
out_dir = 'data'
new_traj_file = Path(out_dir) / f'{traj_file.stem}.h5'

In [3]:
# One-off conversion of the full DCD trajectory (traj_file) into new_traj_file,
# the HDF5 file that the TrajectoryDataset is pointed at further down.
# Left commented out, so new_traj_file must already exist on disk for the
# dataset cell to find it.
# NOTE(review): the assert looks intended to stop an existing converted file
# from being clobbered -- confirm before re-enabling.
# assert not new_traj_file.exists()
# !mdconvert {str(traj_file)} -t {top_path} -c 10000 -o {str(new_traj_file)}

In [12]:
!mdconvert {str(traj_file)} -t {top_path} -i 0:1000 -o data/frames_1000.h5 

['xyz', 'cell_lengths', 'cell_angles', 'topology']
converted 1000 frames, 82 atoms 


In [7]:
# Dihedral featurizer over the listed angle types, read against the
# production topology file.
transform = Dihedrals({
    'topology_path': f'{data_dir}/topology.pdb',
    'which': ['phi', 'psi', 'chi1', 'chi2', 'chi3', 'chi4', 'chi5'],
    # NOTE(review): 'coosin' may be a typo for 'cossin' -- confirm against the
    # option names Dihedrals actually accepts before changing it.
    'coosin': True,
})

# Dataset over the converted trajectory file, taking every frame (stride 1).
dataset = TrajectoryDataset({
    'traj_paths_pattern': str(new_traj_file),
    'stride': 1,
})

# Loader that applies the featurizer to the dataset in batches of 1000.
loader = DataLoader({
    'batch_size': 1000,
    'transform': transform,
    'dataset': dataset,
})

In [21]:
# Check the loader works: its first batch must match dihedral features
# computed directly with mdtraj on the 1000-frame reference file.

# Take the first batch explicitly. next() raises StopIteration on an empty
# loader instead of silently leaving `x` unset (or stale from an earlier run
# of this cell), which would make the comparison below meaningless.
x = next(iter(loader))

traj = md.load('data/frames_1000.h5')

# One array of angle values per requested dihedral type, in the same order as
# the transform's `which` option.
angle_blocks = []
for att in transform.options.which:
    # md.compute_<att> returns a pair; only the second element (the angle
    # values, one row per frame) is needed here.
    _, vals = getattr(md, f'compute_{att}')(traj)
    print(vals.shape)
    angle_blocks.append(vals)

# Column layout of the reference: all angles side by side, then the cosine of
# every angle followed by the sine of every angle.
angles = np.concatenate(angle_blocks, axis=1)
features = np.concatenate([np.cos(angles), np.sin(angles)], axis=1)

assert np.allclose(x, features), 'loader batch differs from direct mdtraj featurization'

(1000, 5)
(1000, 5)
(1000, 4)
(1000, 3)
(1000, 0)
(1000, 0)
(1000, 0)




In [20]:
# Display the reference feature matrix (cos/sin of the mdtraj dihedrals) that
# the loader check builds.
features

array([[ 0.24819146,  0.44255477,  0.21971345, ..., -0.82917094,
         0.6422112 , -0.34171584],
       [ 0.33028018,  0.59127253,  0.05109643, ..., -0.9208793 ,
         0.34089726,  0.14087634],
       [ 0.36022195,  0.5476826 ,  0.09846118, ..., -0.8838565 ,
         0.6854048 ,  0.08636171],
       ...,
       [ 0.34615216,  0.46468964,  0.5475895 , ..., -0.96336156,
         0.12577799,  0.69522685],
       [ 0.5200439 ,  0.1754095 ,  0.00397164, ..., -0.9997048 ,
         0.5405179 ,  0.6940098 ],
       [ 0.20879333,  0.5682233 ,  0.85809654, ..., -0.7753583 ,
         0.5142238 ,  0.9371281 ]], dtype=float32)