In [7]:
%load_ext autoreload
%autoreload 2
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.datasets.qm9 import QM9
import torch_geometric.datasets.qm9 as qm9
from torch_geometric.data import DataLoader
import torch_geometric.nn as tgnn
from torch_scatter import scatter
import tqdm
import wandb

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import sys
# Make the project-local `dfs_transformer` package importable.
# Set the GRAPH_TRANSFORMER_SRC environment variable to use another checkout
# instead of editing the hardcoded fallback path below.
SRC_DIR = os.environ.get('GRAPH_TRANSFORMER_SRC',
                         '/home/chrisw/Documents/projects/2021/graph-transformer/src')
# Keep exactly one copy at the front, so re-running this cell does not keep
# growing sys.path while the project source still takes priority.
if SRC_DIR in sys.path:
    sys.path.remove(SRC_DIR)
sys.path.insert(0, SRC_DIR)

In [9]:
from torch_geometric.nn.models.schnet import GaussianSmearing
from dfs_transformer import EarlyStopping
import dfs_transformer
from dfs_transformer.nn.models.myschnet import MySchNet

In [10]:
# [0] reports the MAE (in eV) divided by the chemical accuracy of the target U0.
# The chemical accuracy of U0 is 0.043 eV, see [1, Table 5].

# Values reproduced from [0]:
# MXMNet: 0.00590/0.043 = 0.13720930232558143
# HMGNN:  0.00592/0.043 = 0.13767441860465118
# MPNN:   0.01935/0.043 = 0.45
# KRR:    0.0251 /0.043 = 0.5837209302325582
# [0] https://paperswithcode.com/sota/formation-energy-on-qm9
# [1] Neural Message Passing for Quantum Chemistry, https://arxiv.org/pdf/1704.01212v2.pdf
# MXMNet https://arxiv.org/pdf/2011.07457v1.pdf
# HMGNN https://arxiv.org/pdf/2009.12710v1.pdf
# MPNN https://arxiv.org/pdf/1704.01212v2.pdf
# KRR HDAD kernel ridge regression https://arxiv.org/pdf/1702.05532.pdf
# HDAD stands for Histogram of Distances, Angles and Dihedral angles.

# [2] reports MAE / chemical accuracy averaged over all targets.
# [2] https://paperswithcode.com/sota/drug-discovery-on-qm9

# Human-readable description of each QM9 target column: 'symbol, unit, name'.
target_dict = {0: 'mu, D, Dipole moment',
               1: 'alpha, {a_0}^3, Isotropic polarizability',
               2: 'epsilon_{HOMO}, eV, Highest occupied molecular orbital energy',
               3: 'epsilon_{LUMO}, eV, Lowest unoccupied molecular orbital energy',
               4: 'Delta, eV, Gap between HOMO and LUMO',
               5: '< R^2 >, {a_0}^2, Electronic spatial extent',
               6: 'ZPVE, eV, Zero point vibrational energy',
               7: 'U_0, eV, Internal energy at 0K',
               8: 'U, eV, Internal energy at 298.15K',
               9: 'H, eV, Enthalpy at 298.15K',
               10: 'G, eV, Free energy at 298.15K',
               11: 'c_{v}, cal/(mol K), Heat capacity at 298.15K'}

# Chemical-accuracy threshold per target index (cf. [1]); 0.043 unless overridden.
chemical_accuracy = {idx: 0.043 for idx in range(12)}
chemical_accuracy.update({0: 0.1, 1: 0.1, 5: 1.2, 6: 0.0012, 11: 0.050})

In [12]:
wandb.init(project='QM9-GAT', entity='chrisxx')
config = wandb.config
# model hyper-parameters
config.hidden_dim = 128
config.nlayers = 3
config.nhead = 1
# optimisation / LR schedule
config.lr = 0.0003
config.n_epochs = 5000
config.patience = 5
config.factor = 0.95
config.minimal_lr = 6e-8
# task and data
config.target_idx = 7
config.batch_size = 256
# early stopping on the validation score
config.valid_patience = 200
config.valid_minimal_improvement=0.005
# split sizes; the test split is whatever remains
config.n_train = 110000
config.n_valid = 10000
config.model_dir = '../models/qm9/MySchNet/noDFS/1/'
config.num_workers = 4
config.dfs_codes = None
config.use_pos = False
config.use_dist = True
# Seed torch-based randomness (dataset shuffle/split, weight init, loader
# shuffling) and numpy so a run can be reproduced from its logged config.
config.seed = 0
torch.manual_seed(config.seed)
np.random.seed(config.seed)

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.11.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [13]:
# Optional DFS codes, keyed by molecule name; stays None unless
# config.dfs_codes points at a JSON file.
dfs_codes = None
if config.dfs_codes is not None:
    import json
    with open(config.dfs_codes, mode='r') as codes_file:
        dfs_codes = json.load(codes_file)

In [14]:
def transform(data, dfs_codes = dfs_codes, edge_transform = GaussianSmearing(0, 10, 50), use_dist=config.use_dist):
    """Re-encode the node (and optionally edge) features of one QM9 molecule.

    Node features: keeps columns 0-4 and 6..-2 of the raw ``data.x`` and
    replaces column 5 (atomic number) and the last column (number of attached
    hydrogens) with one-hot encodings of width 100 and 9 respectively. If
    ``dfs_codes`` is given, a 29-wide one-hot of
    ``dfs_codes[data.name]['dfs_indices']`` is appended as well.

    Edge features: if ``use_dist`` is true, the Euclidean length of every edge
    (computed from ``data.pos``) is expanded with ``edge_transform`` and
    appended to ``data.edge_attr``.

    The defaults of ``dfs_codes`` and ``use_dist`` are captured when this cell
    runs, so re-run it after changing the config.

    Returns:
        The same ``data`` object with ``x`` (and possibly ``edge_attr``) replaced.
    """
    features = data.x
    dtype = features.dtype
    # one_hot returns int64; cast explicitly instead of relying on torch.cat type promotion
    atomic_number = F.one_hot(features[:, 5].long(), 100).to(dtype)
    num_h = F.one_hot(features[:, -1].long(), 9).to(dtype)
    node_parts = [features[:, :5], features[:, 6:-1], atomic_number, num_h]
    if dfs_codes is not None:
        dfs_indices = torch.tensor(dfs_codes[data.name]['dfs_indices'], dtype=torch.long)
        node_parts.append(F.one_hot(dfs_indices, 29).to(dtype))
    data.x = torch.cat(node_parts, dim=1)
    if use_dist:
        row, col = data.edge_index
        edge_weights = (data.pos[row] - data.pos[col]).norm(dim=-1)
        dist_feats = edge_transform(edge_weights)
        data.edge_attr = torch.cat((data.edge_attr, dist_feats.to(data.edge_attr.dtype)), dim=1)
    return data

In [15]:
target_idx = config.target_idx
# Fail fast: later cells index `chemical_accuracy` (and the QM9 targets) with this value.
if target_idx not in chemical_accuracy:
    raise ValueError('config.target_idx must be one of %s, got %r' % (sorted(chemical_accuracy), target_idx))

In [16]:
# QM9 with the feature re-encoding above applied on every access.
dataset = QM9(root='../datasets/qm9_geometric_work/', transform=transform)

In [17]:
dataset = dataset.shuffle()
# Contiguous train / validation / test slices of the shuffled dataset; the
# test split is everything after the first n_train + n_valid molecules.
train_dataset, valid_dataset, test_dataset = (
    dataset[:config.n_train],
    dataset[config.n_train:config.n_train + config.n_valid],
    dataset[config.n_train + config.n_valid:],
)
config.n_test = len(test_dataset)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                          pin_memory=True, num_workers=config.num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size,
                          pin_memory=True, num_workers=config.num_workers)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers=config.num_workers)

In [18]:
import os
# Create the checkpoint directory; no error if it already exists.
os.makedirs(name=config.model_dir, exist_ok=True)

In [19]:
# Persist the shuffled order so the train/valid/test split can be reconstructed.
# os.path.join works whether or not config.model_dir ends with a slash.
torch.save(dataset.indices(), os.path.join(config.model_dir, 'dataset_indices.pt'))

In [20]:
# Use the first GPU when one is available and GPUs are enabled, otherwise the CPU.
ngpu = 1
if ngpu > 0 and torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [27]:
def score(loader, model, device=device):
    """Evaluate ``model`` on ``loader`` and return its MAE divided by chemical accuracy.

    The model is called as ``model(data.z, data.pos, data.batch)`` and its
    output is compared against column ``target_idx`` of ``data.y``, where
    ``target_idx`` is the notebook global set from ``config.target_idx``.

    Args:
        loader: iterable of PyG batches (e.g. ``valid_loader``).
        model: network to evaluate; it is moved to ``device``. Its train/eval
            mode is restored before returning.
        device: device the model and the batches are moved to.

    Returns:
        float: mean absolute error over all samples in ``loader`` divided by
        ``chemical_accuracy[target_idx]``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model = model.to(device)
    was_training = model.training
    # eval mode turns off train-only behaviour (e.g. dropout, if any);
    # no_grad avoids building autograd graphs while scoring
    model.eval()
    abs_errors = []
    try:
        with torch.no_grad():
            for data in tqdm.tqdm(loader):
                data = data.to(device)
                prediction = model(data.z, data.pos, data.batch)
                abs_errors.append((prediction.view(-1) - data.y[:, target_idx]).abs().cpu())
    finally:
        model.train(was_training)
    if not abs_errors:
        raise ValueError('score() got an empty loader; cannot compute an MAE')
    mae = torch.cat(abs_errors, dim=0).mean().item()
    return mae / chemical_accuracy[target_idx]

# Model

In [22]:
# Accumulator for the per-molecule atomref-corrected training targets.
target_vec = list()

In [23]:
# based on https://schnetpack.readthedocs.io/en/stable/tutorials/tutorial_02_qm9.html
# and https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/models/schnet.html#SchNet
# For every training molecule, collect the target minus the sum of its
# per-element atom references; the mean/std of these residuals are passed to
# the model below. The accumulator is created here so the cell is re-runnable.
# `qm9.atomrefs` only has entries for some targets; for the others the
# reference is zero and the residual is the raw target.
if target_idx in qm9.atomrefs:
    atomref_table = torch.tensor(qm9.atomrefs[target_idx], device=device)
else:
    atomref_table = torch.zeros(5, device=device)
residuals = []
with torch.no_grad():
    for data in train_loader:
        data = data.to(device)
        # columns 0-4 of data.x are kept first by `transform`; assumed to be the
        # element one-hot in the same order as qm9.atomrefs — TODO confirm
        element_idx = torch.argmax(data.x[:, :5], dim=1)
        molecule_ref = scatter(atomref_table[element_idx], data.batch, dim=-1, reduce='sum')
        residuals.append((data.y[:, target_idx] - molecule_ref).cpu().numpy())
target_vec = np.concatenate(residuals, axis=0)

In [24]:
# Standardisation statistics of the atomref-corrected training targets.
target_mean, target_std = np.mean(target_vec), np.std(target_vec)

In [25]:
# NOTE(review): config.hidden_dim / nlayers / nhead are logged but not passed to
# MySchNet, so it is built with its own defaults — confirm that is intended.
# NOTE(review): target_mean/target_std are per-molecule statistics; if MySchNet
# (like PyG's SchNet) applies mean/std per atom before pooling, the mean is
# effectively scaled by the atom count — confirm against MySchNet.
model = MySchNet(readout='attention', atomref=dataset.atomref(target_idx), mean=target_mean, std=target_std)
loss = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=config.lr)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', verbose=True, patience=config.patience, factor=config.factor)
early_stopping = EarlyStopping(patience=config.valid_patience, delta=config.valid_minimal_improvement,
                               path=os.path.join(config.model_dir, 'checkpoint.pt'))



# Training

In [26]:
# Move the parameters to the training device before building the loop.
model = model.to(device=device)

In [None]:
# Per-epoch training MAE history.
loss_hist = []
try:
    for epoch in range(config.n_epochs):
        # Explicitly (re-)enter training mode each epoch so the loop does not
        # depend on whether evaluation helpers leave the model in eval mode.
        model.train()
        pbar = tqdm.tqdm(train_loader)
        # running mean of the per-batch MAE over the current epoch
        epoch_loss = 0
        for i, data in enumerate(pbar):
            model.zero_grad()
            data = data.to(device)
            target = data.y[:, target_idx]
            prediction = model(data.z, data.pos, data.batch).view(-1)
            output = loss(prediction, target)
            mae = (prediction - target).abs().mean()
            epoch_loss = (epoch_loss*i + mae.item())/(i+1)

            pbar.set_description('Epoch %d: MAE/CA %2.6f'%(epoch+1, epoch_loss/chemical_accuracy[target_idx]))
            output.backward()
            optimizer.step()
            wandb.log({'MSE': output.item()})

        valid_loss = score(valid_loader, model)
        curr_lr = optimizer.param_groups[0]['lr']
        wandb.log({'MAE':epoch_loss,
                   'MAE/CA':epoch_loss/chemical_accuracy[target_idx],
                   'learning rate':curr_lr,
                   'MAE/CA valid':valid_loss})

        # NOTE(review): the plateau scheduler is driven by the training MAE while
        # early stopping uses the validation score — confirm this is intended.
        lr_scheduler.step(epoch_loss)
        early_stopping(valid_loss, model)
        loss_hist += [epoch_loss]

        if early_stopping.early_stop or curr_lr < config.minimal_lr:
            break

except KeyboardInterrupt:
    # Best-effort snapshot of the current weights when training is interrupted by hand.
    print('keyboard interrupt caught')
    torch.save(model.state_dict(), os.path.join(config.model_dir, 'gat_epoch%d.pt'%(epoch+1)))

Epoch 1: MAE/CA 29.934328: : 430it [00:58,  7.37it/s]
40it [00:02, 16.43it/s]
Epoch 2: MAE/CA 12.170259: : 430it [00:46,  9.19it/s]
40it [00:02, 18.86it/s]

EarlyStopping counter: 1 out of 200



Epoch 3: MAE/CA 9.703605: : 430it [00:45,  9.43it/s]
40it [00:02, 18.08it/s]

EarlyStopping counter: 2 out of 200



Epoch 4: MAE/CA 8.937248: : 430it [00:45,  9.39it/s]
40it [00:02, 18.56it/s]
Epoch 5: MAE/CA 8.775295: : 430it [00:46,  9.31it/s]
40it [00:02, 18.56it/s]

EarlyStopping counter: 1 out of 200



Epoch 6: MAE/CA 7.936426: : 430it [00:46,  9.25it/s]
40it [00:02, 17.85it/s]

EarlyStopping counter: 2 out of 200



Epoch 7: MAE/CA 8.400103: : 240it [00:25,  9.70it/s]

In [None]:
# Final evaluation on the held-out test split.
# NOTE(review): this scores the weights currently in `model` (the last epoch
# reached), not the checkpoint EarlyStopping writes to
# config.model_dir + 'checkpoint.pt' (presumably the best-validation weights) —
# confirm which one should be reported.
was_training = model.training
model.eval()
maes = []
with torch.no_grad():
    for data in tqdm.tqdm(test_loader):
        data = data.to(device)
        prediction = model(data.z, data.pos, data.batch)
        maes.append((prediction.view(-1) - data.y[:, target_idx]).abs().cpu())
model.train(was_training)
maes = torch.cat(maes, dim=0)
mae = maes.mean().item()
print(mae, mae/chemical_accuracy[target_idx])
wandb.log({'TEST MAE':mae, 'TEST MAE/CA':mae/chemical_accuracy[target_idx]})
