In [1]:
from utils.datasets import *
from utils.molecular_graphs import smiles_to_graphs
from utils.point_clouds import *

from models.global_transformer import Transformer

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

%load_ext autoreload
%autoreload 2

## Data

In [43]:
# Directory holding the pre-split QM9 arrays; every file is named
# '<split>_<data_name>.npy'.
datadir = '/scratch/midway3/jshe/data/qm9/scaffolded/'

# Arrays loaded for each split, in the order they are handed to NPYDataset.
data_names = [
    'atoms',
    'coordinates',
    'norm_y',
]

train_paths = [datadir + f'train_{data_name}.npy' for data_name in data_names]
validation_paths = [datadir + f'validation_{data_name}.npy' for data_name in data_names]

train_dataset = NPYDataset(*train_paths)
validation_dataset = NPYDataset(*validation_paths)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=collate_point_clouds,
)

# Validation batches keep a fixed order: shuffling adds nothing to evaluation,
# makes per-batch results non-reproducible, and consumes global RNG state.
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=2048,
    shuffle=False,
    collate_fn=collate_point_clouds,
)

# Number of regression targets, read from the second dimension of the last
# stored array.
# NOTE(review): assumes NPYDataset.datas keeps the arrays in the order passed,
# so datas[-1] is 'norm_y' - confirm against utils.datasets.
n_properties = train_dataset.datas[-1].shape[1]

## Model

In [3]:
# Model hyperparameters, forwarded to Transformer(E=..., H=..., stack=...).
# E and H are explicitly coerced to int; stack is kept as a string.
# NOTE(review): presumably E is the embedding width and H the number of
# attention heads - confirm against models.global_transformer.
stack = 'p8'
E = int(128)
H = int(8)

In [4]:
# Build the network from the hyperparameters chosen above; the output width
# matches the number of target properties in the dataset.
model_kwargs = dict(
    # NOTE(review): looks like 5 numerical node features plus four categorical
    # features whose sizes each reserve one extra slot (the "+1") - confirm
    # against Transformer and smiles_to_graphs.
    in_features=(5, (9+1, 8+1, 2+1, 2+1)),
    out_features=n_properties,
    E=E,
    H=H,
    stack=stack,
    dropout=0.1,
)
model = Transformer(**model_kwargs)

# Count only the parameters that will receive gradients.
trainable_parameter_count = 0
for parameter in model.parameters():
    if parameter.requires_grad:
        trainable_parameter_count += parameter.numel()
print(f'Parameters: {trainable_parameter_count}')

Parameters: 1588175


In [7]:
# Derive the checkpoint name from the hyperparameters used to build the model,
# so the loaded weights cannot silently drift from the configuration cell.
# With E=128, H=8, stack='p8' this is './weights/qm9_scaffold/E128_H8_p8.pt'.
weights_path = f'./weights/qm9_scaffold/E{E}_H{H}_{stack}.pt'

# Tensors are mapped onto the CPU regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True if the installed torch supports it).
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))

# Left as the final expression so the key-matching report is displayed.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [28]:
# Print the bias scale of every transformer block: exp() of the stored
# log-weight, transposed and detached from the autograd graph.
# Iterating over the container itself (rather than a fixed range(8)) keeps
# this cell correct when `stack` builds a different number of blocks.
for block in model.transformer_blocks:
    print(block.bias_scale.log_weight.exp().t().detach())

tensor([[0.9991, 0.9991, 1.0090, 0.9991, 0.9991, 0.9991, 0.9990, 0.9991]])
tensor([[0.9985, 0.9990, 0.9992, 0.9991, 3.1006, 1.1198, 0.9397, 0.9957]])
tensor([[0.9985, 0.9982, 1.0143, 0.9982, 1.0682, 0.9981, 0.9986, 0.9979]])
tensor([[0.9995, 0.9991, 0.9363, 0.9993, 0.9986, 0.9993, 0.9989, 1.9192]])
tensor([[1.0008, 1.0000, 1.0001, 1.1683, 0.9082, 0.9990, 0.9997, 1.0009]])
tensor([[0.9995, 0.9984, 0.9536, 0.9994, 0.9998, 0.9993, 0.9996, 0.9990]])
tensor([[3.3805, 1.0814, 0.9989, 0.9706, 0.9999, 0.9849, 0.9226, 0.9999]])
tensor([[0.9949, 1.0003, 1.0003, 1.0002, 1.0003, 0.9999, 1.0000, 1.0003]])


In [None]:
# Adam with default learning rate and a weight-decay penalty; the objective is
# the mean squared error between predicted and target properties.
optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.001)
mse = nn.MSELoss()

for epoch in range(64):

    # Switch to training mode once per epoch instead of once per batch.
    model.train()

    # Train loop over batches

    # NOTE(review): train_dataset is built from the 'atoms', 'coordinates' and
    # 'norm_y' arrays, yet the first batch element is treated as SMILES strings
    # below - confirm what collate_point_clouds actually yields here.
    for smiles, r, y_true in train_dataloader:
        optimizer.zero_grad()

        # Create graphs

        nodes_numerical, nodes_categorical, adj, padding = smiles_to_graphs(smiles, device='cpu')

        # Forward pass

        y_pred = model(
            nodes_numerical.float(), nodes_categorical,
            adj, padding, r
        )
        loss = mse(y_pred, y_true)

        print(loss)

        # Abort before backpropagating: stepping the optimizer on a NaN loss
        # would overwrite the weights with NaNs and destroy the model state.
        if loss.isnan():
            raise Exception('nan')

        # Backward pass and parameter update

        loss.backward()
        optimizer.step()