# Variable Sized Data With CGSchNet

In developing transferable CG protein force fields, it is often advantageous to use data from multiple molecules. This naturally leads to datasets composed of differently sized molecules. This notebook walks through an example of how to work with such datasets using the CGSchNet framework. Please note that the dataset we use here is for demonstration purposes only, and no conclusions can be drawn from the results within this notebook. Furthermore, the functionalities demonstrated below are only compatible with SchNet-related tools within CGnet.

## Setting up and loading data
First, we import all necessary tools for this notebook. Then we will load a reduced dialanine dataset.

In [None]:
import torch
import torch.nn as nn
import numpy as np

from cgnet.feature import (SchnetFeature, GeometryStatistics,
                           MultiMoleculeDataset, LinearLayer,
                           multi_molecule_collate, CGBeadEmbedding,
                           GaussianRBF)
from cgnet.network import (CGnet, HarmonicLayer, ForceLoss, ZscoreLayer,
                           lipschitz_projection, dataset_loss, Simulation)

import mdtraj as md
from cgnet.molecule import CGMolecule

from torch.utils.data import DataLoader, RandomSampler
from torch.optim.lr_scheduler import MultiStepLR

import matplotlib.pyplot as plt
%matplotlib inline

# We specify the CPU as the training/simulating device here.
# If you have a machine with a GPU, you can use the GPU for
# accelerated training/simulation by specifying 
# device = torch.device('cuda')
device = torch.device('cpu')
# Here, we are seeding the numpy random number generator so that the same pseudomolecule
# sizes are used in the dataset each time the notebook is run. 
np.random.seed(1133)

Here we load the dialanine dataset, which has been reduced in size for the purpose of easy demonstration. We also create atomic embeddings as in the CG-Force-Fields-With-SchNet-Embeddings notebook:

In [None]:
coords = np.load('./data/ala2_coordinates.npy')
forces = np.load('./data/ala2_forces.npy')
embeddings = np.tile([6, 7, 6, 6, 7], [coords.shape[0], 1])

print("Coordinates size: {}".format(coords.shape))
print("Forces size: {}".format(forces.shape))
print("Embeddings size: {}".format(embeddings.shape))

## Making a dataset for variable length molecules

In order to simulate a dataset that contains proteins of various sizes, we take each example in the dataset and randomly truncate the number of beads in the coordinates and the forces. These "varied" examples can range from 2 to the full 5 CG beads of dialanine. As such, the individual examples are entirely artificial and meaningless and are used for demonstration only:

In [None]:
random_lengths = np.random.randint(2, high=6, size=len(coords))

varied_coords = []
varied_forces = []
varied_embeddings = []

for num, length in enumerate(random_lengths):
    varied_coords.append(coords[num, :length, :])
    varied_forces.append(forces[num, :length, :])
    varied_embeddings.append(embeddings[num, :length])
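As a quick sanity check of the truncation logic, the following minimal sketch repeats the same steps on synthetic data. The `toy_*` names and the array contents are made up for illustration; the point is only that `np.random.randint(2, high=6)` excludes `high`, so the lengths fall between 2 and 5 beads.

```python
import numpy as np

rng = np.random.RandomState(0)

# Synthetic stand-in for the real data: 8 frames, 5 beads, 3 dimensions
toy_coords = rng.randn(8, 5, 3)

# `high` is exclusive, so lengths are drawn from {2, 3, 4, 5}
toy_lengths = rng.randint(2, high=6, size=len(toy_coords))

# Keep only the first `n` beads of each frame
toy_varied = [toy_coords[i, :n, :] for i, n in enumerate(toy_lengths)]

for n, frame in zip(toy_lengths, toy_varied):
    print(n, frame.shape)
```

Each truncated frame keeps its three Cartesian dimensions; only the bead axis changes length.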

This variable length functionality is not (yet) compatible with `GeometryStatistics`, which would normally be used to help set up priors for fixed-length molecules. Instead, we demonstrate a usage of CGSchNet in the same spirit as the original SchNet architecture [2], which does not use priors. For simplicity, we assume here that priors are not needed (which is in general not a safe assumption for learning CG force fields, as noted by [1]).

With no need to instantiate a `GeometryStatistics` object, we move on to constructing our dataset. This is done using the `MultiMoleculeDataset` object, which handles and assembles datasets that contain multiple molecules of variable lengths. It is initialized by passing __*lists of variable length coordinates, forces, and embeddings*__, where these three lists are all in the same order (e.g., `varied_coords[2]`, `varied_forces[2]`, and `varied_embeddings[2]` give the coordinates, forces, and embeddings of the example at index 2):

In [None]:
varied_dataset = MultiMoleculeDataset(varied_coords, varied_forces, varied_embeddings)
print("Dataset length: {}".format(len(varied_dataset)))

Let's take a look at the first 3 examples from this varied molecule dataset:

In [None]:
indices = [0, 1, 2]
for index in indices:
    output = varied_dataset.__getitem__([index])[0]
    print("Coordinates size: {}".format(output['coords'].shape))
    print(output['coords'])
    print("Forces size: {}".format(output['forces'].shape))
    print(output['forces'])
    print("Embeddings size: {}".format(output['embeddings'].shape))
    print(output['embeddings'], end='\n\n')

We can see that these three examples contain proteins of lengths 2, 5, and 4 beads respectively (for numpy random seed 1133). We can also see that `MultiMoleculeDataset` returns a dictionary for each example, with the keys `coords`, `forces`, and `embeddings` holding the corresponding coordinates, forces, and embeddings.

## Padding our input into the model

With our dataset set up, we have to ask: how can a neural network handle variable length data? The answer, which is inspired by both [SchNetPack](https://github.com/atomistic-machine-learning/schnetpack) [3] and typical practices from natural language processing, is __*input padding and masking*__. Input padding is a preprocessing method that ensures all examples fed into a neural network in a batch are of the same size. This is accomplished by "padding" the smaller examples in the batch with 0's so that every example matches the size of the largest example in the batch. Luckily, PyTorch has recurrent neural network utilities that provide these and other related padding functionalities. We have used these to create a special __*collating function*__ called `multi_molecule_collate` to work with `MultiMoleculeDataset`. This collating function must be passed to the `collate_fn` kwarg of a PyTorch `DataLoader` object, which combines a dataset with a sampler:

In [None]:
loader = DataLoader(varied_dataset, batch_size=3, shuffle=False, collate_fn=multi_molecule_collate)

for num, batch in enumerate(loader):
    coords, forces, embeddings = batch
    print("Coordinates size:", coords.size())
    print(coords)
    print("Forces size:", forces.size())
    print(forces)
    print("Embeddings size:", embeddings.size())
    print(embeddings)
    break # we just want to grab only the first batch from the data loader, so we break here.

We can see that padding gives us a way to stack all of the coordinates, forces, and embeddings into respective torch tensors that can be passed to a model. Note the locations of the padding zeros in relation to the raw output of `MultiMoleculeDataset` above. Padding zeros have been appended to the ends of the smaller molecules' embeddings to reach the size of the embeddings of the largest molecule in the batch. Similarly, padding zeros have been appended to the coordinates and forces of the smaller molecules in the batch, such that their number of beads is artificially extended to match the number of beads of the largest molecule in the batch.
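The padding idea can be reproduced in isolation with PyTorch's `pad_sequence` utility. The sketch below is only an illustration on made-up tensors; it is not necessarily how `multi_molecule_collate` is implemented internally, and the `toy_*` names and the `mask` variable are assumptions introduced here.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three toy "molecules" with 2, 5, and 4 beads (3 Cartesian coordinates each)
toy_coords = [torch.ones(2, 3), torch.ones(5, 3), torch.ones(4, 3)]

# Nonzero integer bead types; 0 is reserved for padding
toy_embeddings = [torch.tensor([6, 7]),
                  torch.tensor([6, 7, 6, 6, 7]),
                  torch.tensor([6, 7, 6, 6])]

# Pad every example with zeros up to the largest example in the batch
padded_coords = pad_sequence(toy_coords, batch_first=True)          # (3, 5, 3)
padded_embeddings = pad_sequence(toy_embeddings, batch_first=True)  # (3, 5)

# A mask that is True for real beads and False for padding
mask = padded_embeddings > 0

print(padded_coords.shape)
print(padded_embeddings)
print(mask.sum(dim=1))  # number of real beads per molecule
```

Because real bead types are never 0, the mask can be recovered from the padded embeddings alone, which is one reason to reserve 0 as the padding integer.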

## Setting up the model

Now that we have a way to properly pad input from a dataset of variable length molecules, we can create our model and begin training. Here, we make a "classic" CGSchNet model in which the `CGnet` feature is just a `SchnetFeature` (no `GeometryFeature` or `FeatureCombiner` here) with the following architectural hyperparameters:

In [None]:
# Hyperparameters

n_layers = 5
n_nodes = 128
activation = nn.Tanh()
batch_size = 512
learning_rate = 3e-4
rate_decay = 0.3
lipschitz_strength = 4.0

# schnet-specific parameters
n_embeddings = 10
n_gaussians = 50
n_interaction_blocks = 5
cutoff = 5.0

num_epochs = 15

save_model = False
directory = '.' # to save model

We make sure that `n_embeddings` is large enough to cover the largest embedding integer used in this system (7 here, so `n_embeddings = 10` is sufficient). Remember, by default we cannot use 0's for embedding integers; these are reserved for padding, as shown above.

In [None]:
embedding_layer = CGBeadEmbedding(n_embeddings, n_nodes)
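To see why reserving 0 for padding is convenient, consider a plain `torch.nn.Embedding` with `padding_idx=0`. This is a stand-in used for illustration; whether `CGBeadEmbedding` is built exactly this way is an assumption, and the sizes below are arbitrary.

```python
import torch
import torch.nn as nn

# Assumption: a plain nn.Embedding stands in for CGBeadEmbedding here.
# num_embeddings must exceed the largest bead-type integer (7 in this system).
toy_embedding = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)

# One padded molecule: two real beads followed by three padding entries
bead_types = torch.tensor([[6, 7, 0, 0, 0]])

features = toy_embedding(bead_types)  # shape (1, 5, 4)

# Rows looked up with the padding index are all zeros
print(features[0, 2:])
```

With `padding_idx` set, the padding row starts at zero and receives no gradient updates, so padded beads contribute nothing to the learned features.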

Because we are not using a `FeatureCombiner`/`GeometryFeature`, we must specify `calculate_geometry=True` in the `SchnetFeature` initialization. This allows the `SchnetFeature` to calculate pairwise distances on the fly:

In [None]:
rbf_layer = GaussianRBF(high_cutoff=cutoff, n_gaussians=n_gaussians)

schnet_feature = SchnetFeature(feature_size=n_nodes,
                               embedding_layer=embedding_layer,
                               rbf_layer=rbf_layer,
                               n_interaction_blocks=n_interaction_blocks,
                               calculate_geometry=True,
                               activation=activation,
                               n_beads=10,
                               neighbor_cutoff=None,
                               device=device)
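Conceptually, computing geometry on the fly for a padded batch means evaluating all pairwise distances and then ignoring any pair that involves a padded bead. The sketch below illustrates that idea on made-up coordinates; it is not the `SchnetFeature` implementation, and `toy_coords`, `bead_mask`, and `pair_mask` are hypothetical names.

```python
import torch

# Toy padded batch: 2 molecules, up to 3 beads each.
# The second molecule has only 2 real beads; its last row is padding.
toy_coords = torch.tensor([[[0., 0., 0.], [3., 4., 0.], [0., 0., 1.]],
                           [[0., 0., 0.], [1., 0., 0.], [0., 0., 0.]]])
bead_mask = torch.tensor([[True, True, True],
                          [True, True, False]])

# All pairwise distances, shape (n_frames, n_beads, n_beads)
displacements = toy_coords[:, :, None, :] - toy_coords[:, None, :, :]
distances = displacements.norm(dim=-1)

# A pair is valid only if both beads are real and it is not a self-pair
pair_mask = bead_mask[:, :, None] & bead_mask[:, None, :]
pair_mask = pair_mask & ~torch.eye(toy_coords.shape[1], dtype=torch.bool)

# Zero out every distance that involves padding
masked_distances = distances * pair_mask.to(distances.dtype)
print(masked_distances)
```

Without the mask, the padded bead at the origin would look like a real neighbor of the other beads and would leak into the learned interactions.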

Lastly, we assemble the `CGnet` model as usual. As a reminder, we are not using prior energy terms:

In [None]:
layers = LinearLayer(n_nodes,
                     n_nodes,
                     activation=activation)

for _ in range(n_layers - 1):
    layers += LinearLayer(n_nodes,
                          n_nodes,
                          activation=activation)

# The last layer produces a single value
layers += LinearLayer(n_nodes, 1, activation=None)

variable_model = CGnet(layers, ForceLoss(),
                 feature=schnet_feature,
                 priors=None).to(device)
print(variable_model)

## Training our model

With our model and dataset all set, we create a new dataloader for training (the old one used above was just for demonstration purposes). We also create an optimizer and a learning rate scheduler for our model:

In [None]:
loader = DataLoader(varied_dataset, batch_size=batch_size, shuffle=True, collate_fn=multi_molecule_collate)
optimizer = torch.optim.Adam(variable_model.parameters(),
                             lr=learning_rate)
scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30, 40, 50],
                        gamma=rate_decay)
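A `MultiStepLR` scheduler multiplies the learning rate by `gamma` each time the epoch counter reaches a milestone. The plain-Python sketch below mirrors that rule with the hyperparameters from this notebook; `lr_at_epoch` is a hypothetical helper written for illustration, not part of PyTorch or CGnet.

```python
learning_rate = 3e-4
rate_decay = 0.3
milestones = [10, 20, 30, 40, 50]
num_epochs = 15


def lr_at_epoch(epoch):
    """Learning rate in effect after `epoch` calls to scheduler.step()."""
    n_decays = sum(1 for m in milestones if epoch >= m)
    return learning_rate * rate_decay ** n_decays


for epoch in (0, 9, 10, num_epochs):
    print(epoch, lr_at_epoch(epoch))
```

With `num_epochs = 15`, only the first milestone is reached, so the learning rate is decayed once during this short demonstration run.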

Next, we begin training. Here, we are not using a test set. __*This is done only for ease of demonstration. In production/research endeavors, a dedicated test set or cross validation strategy is essential for accurate model assessment*__.

In [None]:
# Here we are suppressing batchwise printouts by setting batch_freq=np.inf.
# As a consequence, printouts during training will only occur once every epoch.

batch_freq = np.inf 
verbose = True
epoch_freq = 1
epochal_train_losses = []

for epoch in range(1, num_epochs+1):
    train_loss = dataset_loss(variable_model, loader,
                              optimizer,
                              verbose_interval=batch_freq)

    scheduler.step()
    epochal_train_losses.append(train_loss)
    
if save_model:
    torch.save(variable_model, "{}/variable_model.pt".format(directory))
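For real applications, where a held-out test set is needed, one simple option is to split example indices before building the datasets. The sketch below is a minimal illustration with a made-up dataset size and split fraction; `n_examples`, `test_fraction`, and the index arrays are hypothetical names.

```python
import numpy as np

rng = np.random.RandomState(0)

n_examples = 1000      # hypothetical dataset size
test_fraction = 0.2    # hypothetical hold-out fraction

# Shuffle the example indices once, then cut them into two disjoint sets
shuffled = rng.permutation(n_examples)
n_test = int(test_fraction * n_examples)
test_indices = shuffled[:n_test]
train_indices = shuffled[n_test:]

# Each list passed to MultiMoleculeDataset would then be indexed the same way,
# e.g. [varied_coords[i] for i in train_indices]
print(len(train_indices), len(test_indices))
```

Using the same index arrays for coordinates, forces, and embeddings keeps the three lists aligned, which `MultiMoleculeDataset` relies on.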

In [None]:
fig = plt.figure()
plt.plot(np.arange(0, len(epochal_train_losses), 1),
         epochal_train_losses, label='Training Loss')
plt.legend(loc='best')
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.show()

## Simulation

With our model trained, we can perform simulation just as we normally would. Here we will simulate a molecule with 5 beads (the largest size molecule our model can handle, and the largest possible molecule in our dataset). Keep in mind that because our dataset is artificial, no meaningful conclusions can be drawn from the simulation results. It is only meant for demonstration.

In [None]:
n_sims = 100
n_timesteps = 1000
save_interval = 10

KBOLTZMANN = 1.38064852e-23
AVOGADRO = 6.022140857e23
JPERKCAL = 4184
temperature = 300
beta = JPERKCAL / KBOLTZMANN / AVOGADRO / temperature
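As a unit check on the expression for `beta`, the thermal energy at 300 K can be computed in kcal/mol and then inverted. The names `kT` and `beta_check` are introduced here for illustration only.

```python
KBOLTZMANN = 1.38064852e-23   # J / K
AVOGADRO = 6.022140857e23     # 1 / mol
JPERKCAL = 4184               # J / kcal
temperature = 300             # K

# Thermal energy per mole, converted from J/mol to kcal/mol
kT = KBOLTZMANN * AVOGADRO * temperature / JPERKCAL

# Inverse thermal energy in mol/kcal, equivalent to the expression for beta
beta_check = 1.0 / kT

print(kT, beta_check)
```

The value of `kT` comes out close to 0.6 kcal/mol, which is the approximation used for the free energy plots later in this notebook.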

In [None]:
coords = np.load('./data/ala2_coordinates.npy')
forces = np.load('./data/ala2_forces.npy')
embeddings = np.tile([6, 7, 6, 6, 7], [coords.shape[0], 1])

initial_coords = np.concatenate([coords[i].reshape(-1, 5, 3)
                                 for i in np.arange(0, coords.shape[0],
                                                    coords.shape[0]//n_sims)],
                                                    axis=0)
initial_coords = torch.tensor(initial_coords, requires_grad=True).to(device)
sim_embeddings = torch.tensor(embeddings[:n_sims]).to(device)

print("Produced {} initial coordinates.".format(len(initial_coords)))
variable_model.eval()
sim = Simulation(variable_model, initial_coords, sim_embeddings, length=n_timesteps,
                 save_interval=save_interval, beta=beta,
                 save_potential=True, device=device,
                 log_interval=100, log_type='print')

traj = sim.simulate()

In [None]:
names = ['C', 'N', 'CA','C', 'N']
resseq = [1, 2, 2, 2, 3]

resmap = {1: 'ACE', 2: 'ALA', 3: 'NME'}

ala2_cg = CGMolecule(names=names, resseq=resseq, resmap=resmap,
                          bonds='standard')

ala2_traj = ala2_cg.make_trajectory(coords)
ala2_simulated_traj = ala2_cg.make_trajectory(np.concatenate(traj, axis=0))

beads = [(0, 1, 2, 3), (1, 2, 3, 4)]

dihedrals_ref = md.compute_dihedrals(ala2_traj, np.vstack(beads))
dihedrals_cg = md.compute_dihedrals(ala2_simulated_traj, np.vstack(beads))

pot, _ = variable_model.forward(torch.tensor(coords, requires_grad=True),
                          torch.tensor(embeddings))
pot = pot.detach().numpy()
pot = pot - np.min(pot)

sim_pot = np.concatenate(sim.simulated_potential, axis=0)
sim_pot = sim_pot - np.min(sim_pot)

In [None]:
plt.subplots(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.scatter(dihedrals_ref[:, 0].reshape(-1), dihedrals_ref[:, 1].reshape(-1),
            c=pot.flatten(), cmap=plt.get_cmap("viridis"), alpha=0.5, s=0.5)
plt.xlabel(r'$\phi$', fontsize=16)
plt.ylabel(r'$\psi$', fontsize=16)
plt.xlim(-np.pi, np.pi)
plt.ylim(-np.pi, np.pi)
plt.title(r'Original all-atom trajectory')
clb=plt.colorbar()
clb.ax.set_title(r'$U\left(\frac{kcal}{mol}\right)$')

plt.subplot(1, 2, 2)
plt.scatter(dihedrals_cg[:, 0].reshape(-1), dihedrals_cg[:, 1].reshape(-1),
            c=sim_pot.flatten(), cmap=plt.get_cmap("viridis"), alpha=0.5, s=0.5)
plt.xlabel(r'$\phi$', fontsize=16)
plt.ylabel(r'$\psi$', fontsize=16)
plt.xlim(-np.pi, np.pi)
plt.ylim(-np.pi, np.pi)
plt.title('Simulated CG trajectory')
clb=plt.colorbar()
clb.ax.set_title(r'$U\left(\frac{kcal}{mol}\right)$')

plt.tight_layout()

In [None]:
def plot_ramachandran(phi, psi, bins=60, cmap=plt.cm.magma):
    edges = np.array([[-np.pi, np.pi], [-np.pi, np.pi]])
    counts, _, _ = np.histogram2d(psi.reshape(-1),
                                  phi.reshape(-1),
                                  bins=bins,
                                  range=edges)
    populations = counts / np.sum(counts)
    
    # compute energies for only non-zero entries
    # 1/beta is approximately 0.6 kcal/mol at 300 K
    energies = -0.6*np.log(populations,
                           out=np.zeros_like(populations),
                           where=(populations > 0))
    
    # make the lowest energy slightly above zero
    energies = np.where(energies,
                        energies-np.min(energies[np.nonzero(energies)]) + 1e-6,
                        0)
    
    # mask the zero values from the colormap
    zvals_masked = np.ma.masked_where(energies == 0, energies)

    cmap.set_bad(color='white')
    img = plt.imshow(zvals_masked, interpolation='nearest', cmap=cmap)
    plt.gca().invert_yaxis()
    
    plt.xticks([-0.5, bins / 2, bins], 
               [r'$-\pi$', r'$0$', r'$\pi$'])

    plt.yticks([-0.5, bins / 2, bins],
               [r'$-\pi$', r'$0$', r'$\pi$'])
    
    plt.xlabel(r'$\phi$', fontsize=16)
    plt.ylabel(r'$\psi$', fontsize=16)
    
    cb = plt.colorbar()
    cb.ax.set_title(r'$\tilde{F}\left(\frac{kcal}{mol}\right)$')
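The free energy estimate inside `plot_ramachandran` is a Boltzmann inversion of the histogram populations, F = -kT ln(p), evaluated only for populated bins. The following sketch applies the same steps to a made-up 2x2 histogram; unlike the plotting function, it shifts the minimum to exactly zero rather than slightly above it.

```python
import numpy as np

kT = 0.6  # kcal/mol, approximately 1/beta at 300 K

# Toy 2x2 histogram with one empty bin
counts = np.array([[8., 2.],
                   [0., 10.]])
populations = counts / counts.sum()

# -kT ln(p) for populated bins only; empty bins are left at 0
free_energy = -kT * np.log(populations,
                           out=np.zeros_like(populations),
                           where=populations > 0)

# Shift so that the most populated bin sits at zero
populated = populations > 0
free_energy[populated] -= free_energy[populated].min()

print(free_energy)
```

Less populated bins end up with higher free energy, and empty bins carry no estimate at all, which is why the plotting function masks them out of the colormap.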

In [None]:
fig, axes = plt.subplots(figsize=(8, 4))

plt.subplot(1, 2, 1)
plot_ramachandran(dihedrals_ref[:, 0], dihedrals_ref[:, 1])
plt.title('Original all-atom trajectory')

plt.subplot(1, 2, 2)
plot_ramachandran(dihedrals_cg[:, 0], dihedrals_cg[:, 1])
plt.title('Simulated CG trajectory')

plt.tight_layout()

#### *References*

[1] Wang, J., Olsson, S., Wehmeyer, C., Pérez, A., Charron, N. E., de Fabritiis, G., Noé, F., and Clementi, C. (2019). Machine Learning of Coarse-Grained Molecular Dynamics Force Fields. _ACS Central Science._ https://doi.org/10.1021/acscentsci.8b00913

[2] Schütt, K. T., Sauceda, H. E., Kindermans, P.-J., Tkatchenko, A., & Müller, K.-R. (2018). SchNet – A deep learning architecture for molecules and materials. The Journal of Chemical Physics, 148(24), 241722. https://doi.org/10.1063/1.5019779

[3] Schütt, K. T., Kessel, P., Gastegger, M., Nicoli, K. A., Tkatchenko, A., & Müller, K.-R. (2019). SchNetPack: A Deep Learning Toolbox For Atomistic Systems. Journal of Chemical Theory and Computation, 15(1), 448–455. https://doi.org/10.1021/acs.jctc.8b00908
