In [1]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from utils.datasets import *
from utils.point_clouds import *
from utils.smiles import *

from models.transformer import Transformer

import matplotlib.pyplot as plt

from IPython.display import clear_output

%load_ext autoreload
%autoreload 2


## Data

In [2]:
# Root of the scaffold-split QM9 arrays; each split subdirectory holds parallel .npy files.
# NOTE(review): absolute cluster-specific path — make configurable before sharing.
datadir = '/scratch/midway3/jshe/data/qm9/scaffolded/'
# Arrays loaded per split. 'smiles.npy' and 'conformers.npy' are intentionally not loaded.
# NOTE(review): order presumably must match what collate_point_clouds expects — TODO confirm.
fnames = [
    'atoms.npy',
    'coordinates.npy',
    'y.npy',
]
collate_fn = collate_point_clouds


def _split_paths(split):
    """Return the paths of every file in `fnames` inside the `split` subdirectory of `datadir`."""
    # os.path.join works whether or not `datadir` ends with a separator.
    return [os.path.join(datadir, split, fname) for fname in fnames]


train_dataset = NPYDataset(*_split_paths('train'))
validation_dataset = NPYDataset(*_split_paths('validation'))

train_dataloader = DataLoader(
    train_dataset,
    batch_size=64, collate_fn=collate_fn, shuffle=True,
)
# Evaluation does not benefit from shuffling; keep the validation order deterministic.
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=2048, collate_fn=collate_fn, shuffle=False,
)

# Output width of the model (presumably the number of target columns in y.npy).
n_properties = train_dataset.n_properties

## Model

In [3]:
# Seed the global torch RNG so weight init, dropout and (since no generator is passed to the
# DataLoaders) training-set shuffling are reproducible across Restart & Run All.
torch.manual_seed(0)

model = Transformer(
    in_features=6,  # NOTE(review): presumably the per-atom feature width from the collate fn — TODO confirm
    out_features=n_properties,  # one output per target property
    E=128, H=8, D=8,  # NOTE(review): presumably embed dim / heads / depth — see models.transformer
    dropout=0.1,
)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Parameters: {n_trainable}')

Parameters: 1585231


In [4]:
# Train the model with MSE on all target properties, on GPU when one is available.
N_EPOCHS = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model before building the optimizer so Adam tracks the on-device parameters.
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters())  # default lr
mse = nn.MSELoss()

for epoch in range(N_EPOCHS):
    # Training mode (enables dropout) only needs setting once per epoch.
    model.train()
    for i, (atoms, padding, coordinates, y_true) in enumerate(train_dataloader):
        optimizer.zero_grad()

        # Send to device, casting features and targets to float32
        atom_features = atoms.to(device=device, dtype=torch.float32)
        # NOTE(review): padding's type comes from collate_point_clouds; only move it if it is a tensor.
        padding = padding.to(device) if torch.is_tensor(padding) else padding
        coordinates = coordinates.to(device=device, dtype=torch.float32)
        y_true = y_true.to(device=device, dtype=torch.float32)

        # Forward / backward pass
        y_pred = model(atom_features, coordinates, padding)
        loss = mse(y_pred, y_true)
        loss.backward()
        optimizer.step()

        # Single live-updating progress line: wait=True clears only when the next print arrives.
        print(epoch, i, loss.item())
        clear_output(wait=True)

KeyboardInterrupt: 