# Tutorial 2: MD17 Energy and Force Prediction

## Step 1. Load Packages and Set Random Seeds and Device

In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader
from torch.autograd import grad

from Geom3D.models import SchNet
from Geom3D.datasets import DatasetMD17

import sys
sys.path.insert(0, "../examples_3D")

from tqdm import tqdm

seed = 42
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)

device = "cuda" if torch.cuda.is_available() else torch.device("cpu")

## Step 2. Set Task, Dataset and Dataloader

In [3]:
data_root = "../data/MD17"
task = "aspirin"

dataset = DatasetMD17(data_root, task=task)
split_idx = dataset.get_idx_split(len(dataset.data.y), train_size=1000, valid_size=1000, seed=seed)

# Split into train-valid-test
train_dataset, valid_dataset, test_dataset = \
    dataset[split_idx['train']], dataset[split_idx['valid']], dataset[split_idx['test']]

# Remove energy mean.
ENERGY_MEAN_TOTAL = 0
FORCE_MEAN_TOTAL = 0
NUM_ATOM = None
for data in train_dataset:
    energy = data.y
    force = data.force
    NUM_ATOM = force.size()[0]
    energy_mean = energy / NUM_ATOM
    ENERGY_MEAN_TOTAL += energy_mean
    force_rms = torch.sqrt(torch.mean(force.square()))
    FORCE_MEAN_TOTAL += force_rms
ENERGY_MEAN_TOTAL /= len(train_dataset)
FORCE_MEAN_TOTAL /= len(train_dataset)
ENERGY_MEAN_TOTAL = ENERGY_MEAN_TOTAL.to(device)
FORCE_MEAN_TOTAL = FORCE_MEAN_TOTAL.to(device)


# Set dataloaders
batch_size = 128
MD17_train_batch_size = 1
num_workers = 0
train_loader = DataLoader(train_dataset, batch_size=MD17_train_batch_size, shuffle=True, num_workers=num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

## Step 3. Set Model

In [4]:
node_class, edge_class = 119, 5
num_tasks = 1

emb_dim = 128
SchNet_num_filters = 128
SchNet_num_interactions = 6
SchNet_num_gaussians = 51
SchNet_cutoff = 10
SchNet_readout = "mean"

model = SchNet(
    hidden_channels=emb_dim,
    num_filters=SchNet_num_filters,
    num_interactions=SchNet_num_interactions,
    num_gaussians=SchNet_num_gaussians,
    cutoff=SchNet_cutoff,
    readout=SchNet_readout,
    node_class=node_class,
).to(device)
graph_pred_linear = torch.nn.Linear(emb_dim, num_tasks).to(device)

## Step 4. Set Optimizer

In [5]:
lr = 5e-4
decay = 0
criterion = nn.L1Loss()

model_param_group = [
    {"params": model.parameters(), "lr": lr},
    {"params": graph_pred_linear.parameters(), "lr": lr}]
optimizer = optim.Adam(model_param_group, lr=lr, weight_decay=decay)

# Step 5. Start Training

In [6]:
md17_energy_coeff = 0.05
md17_force_coeff = 0.95

epochs = 1
for e in range(1, 1+epochs):
    for batch in tqdm(train_loader):
        batch = batch.to(device)
        positions = batch.positions
        positions.requires_grad_()
        
        molecule_3D_repr = model(batch.x, positions, batch.batch)
        pred_energy = graph_pred_linear(molecule_3D_repr).squeeze()
        pred_force = -grad(
            outputs=pred_energy, inputs=positions, grad_outputs=torch.ones_like(pred_energy),
            create_graph=True, retain_graph=True)[0]

        actual_energy = batch.y
        actual_force = batch.force
        
        loss = md17_energy_coeff * criterion(pred_energy, actual_energy) + \
                md17_force_coeff * criterion(pred_force, actual_force)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

  return F.l1_loss(input, target, reduction=self.reduction)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:18<00:00, 53.19it/s]


## Step 6. Start Evaluation

In [7]:
def mean_absolute_error(pred, target):
    return np.mean(np.abs(pred - target))

model.eval()
graph_pred_linear.eval()
pred_force_list = torch.Tensor([]).to(device)
actual_force_list = torch.Tensor([]).to(device)

pred_energy_list, actual_energy_list = [], []

for batch in tqdm(valid_loader):
    batch = batch.to(device)
    positions = batch.positions
    positions.requires_grad_()

    molecule_3D_repr = model(batch.x, positions, batch.batch)
    pred_energy = graph_pred_linear(molecule_3D_repr).squeeze()
    force = -grad(
        outputs=pred_energy, inputs=positions, grad_outputs=torch.ones_like(pred_energy),
        create_graph=True, retain_graph=True)[0]

    pred_energy_list.append(pred_energy.cpu().detach())
    actual_energy_list.append(batch.y.cpu())
    pred_force_list = torch.cat([pred_force_list, force], dim=0)
    actual_force_list = torch.cat([actual_force_list, batch.force], dim=0)

pred_energy_list = torch.cat(pred_energy_list, dim=0)
actual_energy_list = torch.cat(actual_energy_list, dim=0)
energy_mae = torch.mean(torch.abs(pred_energy_list - actual_energy_list)).cpu().item()
force_mae = torch.mean(torch.abs(pred_force_list - actual_force_list)).cpu().item()

print("Energy: {}".format(energy_mae))
print("Force: {}".format(force_mae))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 31.84it/s]

Energy: 21143.380859375
Force: 352.1929016113281



