# Tutorial 2: QM9 Energy Prediction

## Step 1. Load Packages and Set Random Seeds and Device

In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.loader import DataLoader

from Geom3D.models import SchNet
from Geom3D.datasets import MoleculeDatasetQM9

import sys
sys.path.insert(0, "../examples_3D")

from tqdm import tqdm

# Seed every RNG this notebook draws from so that a Restart & Run All is
# repeatable: NumPy, the torch default generator (used for weight
# initialisation and DataLoader shuffling), and all CUDA generators.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Always a torch.device, whichever backend is selected, so downstream code
# sees one consistent type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Step 2. Set Task, Dataset and Dataloader

In [3]:
# Keep the dataset *name* and the dataset *object* under separate names so
# that `dataset` only ever refers to the MoleculeDatasetQM9 instance.
dataset_name = "QM9"
task = "u0"

data_root = "../data/{}".format(dataset_name)
dataset = MoleculeDatasetQM9(data_root, dataset=dataset_name, task=task)
# Column index of the selected task inside the label tensor, as exposed by the dataset.
task_id = dataset.task_id

# Split into train-valid-test
# NOTE(review): `splitters` is presumably found through the "../examples_3D"
# entry added to sys.path in the import cell — confirm before moving this import.
from splitters import qm9_random_customized_01
train_dataset, valid_dataset, test_dataset = qm9_random_customized_01(dataset, seed=seed)


# Get the mean and std on the task property in training set
# (used to normalize targets for training and to denormalize predictions).
TRAIN_mean, TRAIN_std = (
    train_dataset.mean()[task_id].item(),
    train_dataset.std()[task_id].item(),
)
print("Train mean: {}\tTrain std: {}".format(TRAIN_mean, TRAIN_std))


# Set dataloaders: only the training loader is shuffled.
batch_size = 128
num_workers = 0
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Dataset: QM9
Data: Data(x=[2359210], edge_index=[2, 4883516], edge_attr=[4883516, 3], positions=[2359210, 3], id=[130831], y=[1700803])
train_idx:  [112526 100113  55940 ...  62969  20239  83098]
valid_idx:  [ 90042  89438  45073 ...  60277  89452 125059]
test_idx:  [ 24143  20981  30492 ...    860  15795 121958]
Train mean: -76.16375732421875	Train std: 9.87142562866211


## Step 3. Set Model

In [4]:
# Vocabulary sizes handed to the model (edge_class is not used in this cell).
node_class = 119
edge_class = 5
# A single regression target is predicted.
num_tasks = 1

# SchNet hyper-parameters.
emb_dim = 128
SchNet_num_filters = 128
SchNet_num_interactions = 6
SchNet_num_gaussians = 51
SchNet_cutoff = 10
SchNet_readout = "mean"

# 3D encoder: build it first, then move it to the selected device.
model = SchNet(
    node_class=node_class,
    hidden_channels=emb_dim,
    num_filters=SchNet_num_filters,
    num_interactions=SchNet_num_interactions,
    num_gaussians=SchNet_num_gaussians,
    cutoff=SchNet_cutoff,
    readout=SchNet_readout,
)
model = model.to(device)

# Linear head mapping the emb_dim-sized representation to num_tasks outputs.
graph_pred_linear = nn.Linear(in_features=emb_dim, out_features=num_tasks)
graph_pred_linear = graph_pred_linear.to(device)

## Step 4. Set Optimizer

In [5]:
# Optimisation hyper-parameters.
lr = 5e-4
decay = 0

# Mean absolute error between predictions and targets.
criterion = nn.L1Loss()

# One parameter group per trainable module (encoder, then prediction head),
# each carrying the same learning rate.
model_param_group = [
    {"params": module.parameters(), "lr": lr}
    for module in (model, graph_pred_linear)
]
optimizer = optim.Adam(model_param_group, lr=lr, weight_decay=decay)

## Step 5. Start Training

In [6]:
epochs = 1

# Put both modules in training mode explicitly: the evaluation cell switches
# them to eval mode, so re-running this cell afterwards would otherwise train
# in eval mode.
model.train()
graph_pred_linear.train()

for e in range(1, 1+epochs):
    for batch in tqdm(train_loader):
        batch = batch.to(device)

        # Encode each molecule from atom types, 3D positions and the
        # node-to-graph assignment vector.
        molecule_3D_repr = model(batch.x, batch.positions, batch.batch)
        # Drop only the trailing task dimension (num_tasks == 1). A bare
        # squeeze() would also remove the batch dimension for a batch holding
        # a single graph, and the size lookup below would then fail.
        pred = graph_pred_linear(molecule_3D_repr).squeeze(-1)

        B = pred.size(0)
        # batch.y is stored flat; reshape to one row per graph and pick the task column.
        y = batch.y.view(B, -1)[:, task_id]
        # normalize
        y = (y - TRAIN_mean) / TRAIN_std

        loss = criterion(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 860/860 [00:20<00:00, 42.45it/s]


## Step 6. Start Evaluation

In [7]:
def mean_absolute_error(pred, target):
    """Return the mean of the element-wise absolute difference |pred - target|."""
    abs_errors = np.abs(pred - target)
    return np.mean(abs_errors)

# Evaluate on the test split without tracking gradients.
with torch.no_grad():
    model.eval()
    graph_pred_linear.eval()

    y_true = []
    y_scores = []

    for batch in tqdm(test_loader):
        batch = batch.to(device)

        molecule_3D_repr = model(batch.x, batch.positions, batch.batch)

        # Drop only the trailing task dimension (num_tasks == 1). A bare
        # squeeze() would turn a final batch holding a single graph into a
        # 0-dim tensor, breaking the size lookup below and the later torch.cat.
        pred = graph_pred_linear(molecule_3D_repr).squeeze(-1)

        B = pred.size(0)
        # batch.y is stored flat; reshape to one row per graph and pick the task column.
        y = batch.y.view(B, -1)[:, task_id]
        # denormalize: the model was trained on (y - TRAIN_mean) / TRAIN_std
        pred = pred * TRAIN_std + TRAIN_mean

        y_true.append(y)
        y_scores.append(pred)

    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()

    # MAE in the original (denormalized) target scale.
    mae = mean_absolute_error(y_scores, y_true)
    print("MAE: {}".format(mae))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 85/85 [00:01<00:00, 58.13it/s]

MAE: 0.9062439799308777



