In [1]:
from celerity.dataloader import TimeLaggedDataset, DataLoader, TrajectoryDataset
from celerity.models import VAMPnetEstimator, VAMPNetModel 
from celerity.featurizer import Dihedrals

import torch
from addict import Dict as Adict
import deeptime as dt
import mdtraj as md

In [2]:
# Featurizer: dihedral angles (phi, psi and chi1..chi5) computed against the
# reference topology; the resulting `transform` is handed to the DataLoaders.
trans_config = Adict({
    'topology_path': 'data/topology.pdb',
    'which': ['phi', 'psi'] + [f'chi{k}' for k in range(1, 6)],
})
transform = Dihedrals(trans_config)

In [7]:
# --- Run parameters ---------------------------------------------------------
lag_time = 1              # shared by the estimator and the time-lagged dataset
stride = 10               # shared by both dataset configs below
validation_split = 0.3    # fraction of the time-lagged dataset held out for validation
batch_size = 1000

# Single source of truth for the trajectory file read by both datasets.
TRAJ_PATH = 'data/trajectory.h5'

# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# VAMPnet estimator config
nn_config = Adict(
    lag_time=lag_time,
    # NOTE(review): the first entry (34) is presumably the network input width and
    # must equal the number of features produced by `transform`; confirm against
    # the Dihedrals output dimension before changing `which`.
    network_dimensions=[34, 100, 100, 100, 10],
    lr=5e-4,
    n_epochs=1,
    optimizer=torch.optim.Adam,
    score=Adict(
        method='VAMP2',
        mode='regularize',
        epsilon=1e-6,
    ),
    loss=dt.decomposition.deep.vampnet_loss,
    device=DEVICE,
)

# Time-lagged dataset config (for estimating)
tlds_config = Adict(
    traj_paths_pattern=TRAJ_PATH,
    lag_time=lag_time,
    in_memory=True,
    stride=stride,
)

# Trajectory dataset config (for transforming)
# NOTE(review): this key is spelled `traj_path_pattern` while tlds_config uses
# `traj_paths_pattern`; confirm which spelling each dataset class expects.
tds_config = Adict(
    traj_path_pattern=TRAJ_PATH,
    in_memory=True,
    stride=stride,
)

# DataLoader config shared by the train and validation loaders;
# `dataset` is a placeholder that is filled in per loader.
dl_config = Adict(
    shuffle=True,
    transform=transform,
    output='tensor',
    batch_size=batch_size,
    dataset=None,
    num_workers=5,
)

In [8]:
# Build the time-lagged dataset and hold out `validation_split` of it for validation.
# The split is seeded so train/validation membership is identical on every
# Restart & Run All (an unseeded random_split reshuffles on each run).
split_seed = 42
dataset = TimeLaggedDataset(tlds_config)
n_val = int(len(dataset) * validation_split)
train_data, val_data = torch.utils.data.random_split(
    dataset,
    [len(dataset) - n_val, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)

In [9]:
def make_loader(subset, shuffle=True):
    """Build a DataLoader over ``subset`` using the shared ``dl_config``.

    A new Adict is built from ``dl_config`` for each call, so the shared base
    config is left untouched; only ``dataset`` and ``shuffle`` are overridden.

    Args:
        subset: the dataset (e.g. a random_split subset) to load from.
        shuffle: whether batches are drawn in random order.

    Returns:
        A celerity ``DataLoader`` configured from ``dl_config``.
    """
    loader_config = Adict(dl_config)
    loader_config.update({'dataset': subset, 'shuffle': shuffle})
    return DataLoader(loader_config)


train_loader = make_loader(train_data)
# Validation batches don't need a random order; a fixed order keeps the per-epoch
# validation scores comparable (assumes celerity's DataLoader honours `shuffle`).
val_loader = make_loader(val_data, shuffle=False)

In [10]:
# Build the VAMPnet estimator from `nn_config` and fit it on the training loader;
# `val_loader` is passed as the second argument, presumably for validation scoring.
# NOTE(review): the saved output below shows this run was stopped with a
# KeyboardInterrupt, so `mod` holds a partially trained model — re-run to
# completion before relying on the scores inspected in the next cell.
mod = VAMPnetEstimator(nn_config)
mod.fit(train_loader, val_loader)

0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 22

KeyboardInterrupt: 

In [12]:
# Training VAMP2 scores recorded by the estimator, keyed by index.
train_scores = mod.dict_scores['train']
train_scores['VAMP2']

{0: 6.409778118133545,
 1: 6.788443088531494,
 2: 7.3441877365112305,
 3: 7.74589204788208,
 4: 8.036165237426758,
 5: 8.163345336914062,
 6: 8.430634498596191,
 7: 8.556655883789062,
 8: 8.626734733581543,
 9: 8.717721939086914,
 10: 8.797319412231445,
 11: 8.936635971069336,
 12: 8.953822135925293,
 13: 8.994768142700195,
 14: 9.01839828491211,
 15: 9.057689666748047,
 16: 9.103723526000977,
 17: 9.1278076171875,
 18: 9.230537414550781,
 19: 9.160781860351562,
 20: 9.261265754699707,
 21: 9.212214469909668,
 22: 9.27663803100586,
 23: 9.327813148498535,
 24: 9.33942985534668,
 25: 9.367663383483887,
 26: 9.394597053527832,
 27: 9.3475980758667,
 28: 9.415539741516113,
 29: 9.448305130004883,
 30: 9.445259094238281,
 31: 9.472975730895996,
 32: 9.504988670349121,
 33: 9.521698951721191,
 34: 9.534695625305176,
 35: 9.5328369140625,
 36: 9.545097351074219,
 37: 9.562959671020508,
 38: 9.553508758544922,
 39: 9.570413589477539,
 40: 9.576491355895996,
 41: 9.58882999420166,
 42: 9.58799