In [27]:
import matplotlib.pyplot as plt
import pandas as pd

import sys
sys.path.append('/home/jshe/prop-pred/src/data')
from data_utils.datasets import SmilesDataset
from data_utils.graphs import smiles_to_graphs

from graph_transformer import GraphTransformer

import torch
import torch.nn as nn
from torch.utils.data import random_split, DataLoader

from sklearn.metrics import r2_score

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

In [28]:
dataset = SmilesDataset(
    smiles='/home/jshe/prop-pred/src/data/qm9/smiles.csv', 
    y='/home/jshe/prop-pred/src/data/qm9/norm_y.csv', 
    d = '/home/jshe/prop-pred/src/data/qm9/distances.npy'
)

train_dataset, validation_dataset, test_dataset = random_split(
    dataset, lengths=(0.8, 0.1, 0.1), 
    generator=torch.Generator().manual_seed(8)
)

In [48]:
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=4096, shuffle=True)

## Model

In [31]:
hyperparameters = dict(
    numerical_features=5, categorical_features=(9+1, 8+1, 2+1, 2+1), 
    E=64, H=4, stack='MMMMMMMM', 
    dropout=0.1, 
    out_features=dataset.n_properties, 
)

model = GraphTransformer(**hyperparameters).to(device)
model.eval()

print(sum(p.numel() for p in model.parameters()))

400847


In [32]:
model.load_state_dict(torch.load('./weights/E64H4/MMMMMMMM.pt', map_location=device))

<All keys matched successfully>

## Train

In [52]:
smiles, d, y_true = next(iter(validation_dataloader))

nodes_numerical, nodes_categorical, adj, padding = smiles_to_graphs(smiles, device=device)

In [53]:
y = model(nodes_numerical.float(), nodes_categorical, d, adj, padding)

In [54]:
nn.L1Loss()(y_true, y)

tensor(0.0849, grad_fn=<MeanBackward0>)

In [7]:
B, L, _ = nodes_numerical.shape

padding_mask = padding.unsqueeze(-1).expand(B, L, model.E)
padding_causal_mask = (padding.unsqueeze(-2) | padding.unsqueeze(-1)).unsqueeze(1)
graph_causal_mask = ~adj.unsqueeze(1)
diag_causal_mask = torch.diag(torch.ones(L)).bool().expand_as(padding_causal_mask).to(padding.device)

In [8]:
coulomb_bias = -2 * torch.log(d).unsqueeze(1)
potential_bias = -torch.log(d).unsqueeze(1)

In [86]:
mixed_causal_mask = torch.concat((
    padding_causal_mask.expand(B, model.H // 2, L, L), 
    graph_causal_mask.expand(B, model.H // 2, L, L), 
), dim=1)

In [57]:
mixed_bias = torch.concat((
    coulomb_bias.expand(B, model.H // 2, L, L), 
    torch.zeros(B, model.H // 2, L, L), 
), dim=1)