In [1]:
!git clone https://github.com/hodakamori/torch-tutorial

Cloning into 'torch-tutorial'...
remote: Enumerating objects: 64, done.[K
remote: Counting objects: 100% (64/64), done.[K
remote: Compressing objects: 100% (42/42), done.[K
remote: Total 64 (delta 16), reused 58 (delta 13), pack-reused 0[K
Receiving objects: 100% (64/64), 9.53 MiB | 28.88 MiB/s, done.
Resolving deltas: 100% (16/16), done.


In [3]:
!pip install graph-transformer-pytorch rdkit MDAnalysis

Collecting graph-transformer-pytorch
  Downloading graph_transformer_pytorch-0.1.1-py3-none-any.whl (4.3 kB)
Collecting rdkit
  Downloading rdkit-2023.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.4/34.4 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting MDAnalysis
  Downloading MDAnalysis-2.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m57.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.3 (from graph-transformer-pytorch)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rotary-embedding-torch (from graph-transformer-pytorch)
  Downloading rotary_embedding_torch-0.5.3-py3-none-any.whl (5.3 kB)
Collecting GridDataFormats>=0.

In [4]:
!cp torch-tutorial/diffusion_two_for_one/* .

In [5]:
import torch
import random
from torch.utils.data import DataLoader, random_split
from dataset import CGCoordsDataset
from model import Net
from utils import smiles2structure, add_diffusion_noise



In [6]:
device = torch.device("cuda")
topology_path = "./ala2_cg.pdb"
traj_path = "./ala2_cg.xtc"
dataset = CGCoordsDataset(topology_path, traj_path)
print(len(dataset))

MAX_EPOCHS = 5
BATCH_SIZE = 64
MAX_NOISE_LEVEL = 10

1000




In [7]:
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
bonds = dataset.bonds
model = Net(num_atoms=5, num_node_features=64)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())
lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.001, total_steps=100000
)
loss_func = torch.nn.MSELoss(reduction="none")

for epoch in range(MAX_EPOCHS):
    history = []
    for indices, coords in dataloader:
        coords.requires_grad_()
        noise_level = random.randint(1,MAX_NOISE_LEVEL + 1)
        noised_coordinates = add_diffusion_noise(coords, noise_level=noise_level)
        noise_true = noised_coordinates - coords
        energy = model(indices, coords, bonds, noise_level=noise_level)
        if coords.grad is not None:
            coords.grad.zero_()
        energy.backward(retain_graph=True)
        noise_pred = coords.grad
        loss = loss_func(noise_true, noise_pred).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        history.append(loss.detach().numpy())
    print(epoch, sum(history) / len(history))

0 15.042465947568417
1 24.261884544044733
2 21.024511612951756
3 24.446103036403656
4 26.455496564507484
