In [1]:
import sys
sys.path.append('/home/jshe/prop-pred/src/data')
from data_utils.datasets import SmilesDataset
from data_utils.graphs import smiles_to_graphs

#from graph_transformer import GraphTransformer

import torch
import torch.nn as nn
from torch.utils.data import random_split, DataLoader

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

## Data

In [2]:
dataset = SmilesDataset(
    smiles='/home/jshe/prop-pred/src/data/qm9/smiles.csv', 
    y='/home/jshe/prop-pred/src/data/qm9/norm_y.csv', 
    d='/home/jshe/prop-pred/src/data/qm9/distances.npy'
)
train_dataset, *_ = random_split(
    dataset, lengths=(0.8, 0.1, 0.1), 
    generator=torch.Generator().manual_seed(16)
)
del _

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [3]:
smiles, d, y_true = next(iter(train_dataloader))
numerical_node_features, categorical_node_features, edges, padding = smiles_to_graphs(smiles, device=device)

print(f'Amount of padding: {torch.sum(numerical_node_features.sum(dim=-1) == 0, dim=-1)}')

Amount of padding: tensor([0, 1, 0, 0, 0, 0, 0, 0])


## Model

In [4]:
class MultiheadAttention(nn.Module):
    def __init__(self, E, H, dropout):
        super().__init__()

        self.E, self.H = E, H
        self.scale = (E // H) ** -0.5

        self.QKV = nn.Linear(E, E * 3, bias=False)
        self.out_map = nn.Linear(E, E, bias=False)

    def forward(self, embeddings, mask=None, bias=None):

        B, L, E = embeddings.size() # Batch, no. Tokens, Embed dim.
        A = E // self.H # Attention dim.

        # Compute and separate Q, K, V matrices

        qkv = self.QKV(embeddings)
        qkv = qkv.reshape(B, L, self.H, 3 * A)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)

        # Compute masked attention pattern

        attn = q @ k.transpose(-2, -1) * self.scale
        if bias is not None:
            attn += bias
        if mask is not None: 
            attn.masked_fill_(mask.unsqueeze(1), torch.finfo(attn.dtype).min)
        attn = torch.softmax(attn, dim=-1)

        # Compute values

        values = attn @ v
        values = values.permute(0, 2, 1, 3) # (B, L, H, A)
        values = values.reshape(B, L, E) # E = H * A
        
        return self.out_map(values)


In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, E, H, dropout):
        super().__init__()
        
        self.attention = MultiheadAttention(E, H, dropout)
        self.norm_1 = nn.LayerNorm(E)
        self.mlp = nn.Sequential(
            nn.Linear(E, E * 4), 
            nn.ReLU(), 
            nn.Linear(E * 4, E)
        )
        self.norm_2 = nn.LayerNorm(E)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x0, padding_mask, causal_mask=None, bias=None):

        # Attention residual block

        x1 = self.attention(x0, causal_mask, bias)
        x1 = self.dropout(x1) 
        x2 = x1 + x0
        x2 = self.norm_1(x2)

        # MLP residual block
        x3 = self.mlp(x2)
        x3 = x3.masked_fill(padding_mask, 0)
        x3 = self.dropout(x3)
        x4 = x3 + x2
        x4 = self.norm_2(x4)

        return x4

In [6]:
class GraphTransformer(nn.Module):
    '''
    Transformer with local and global masked self-attention stack. 
    '''
    def __init__(self, numerical_features, categorical_features, E, H, stack, out_features, dropout):
        super().__init__()

        self.E, self.H = E, H
        self.stack = stack

        # Embedding layers
        self.numerical_embed = nn.Linear(numerical_features, E, bias=False)
        self.categorical_embeds = nn.ModuleList([
            nn.Embedding(n_categories, E, padding_idx=0) 
            for n_categories in categorical_features
        ])

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(E, H, dropout)
            for _ in range(len(stack))
        ])

        # Out map
        self.out_map = nn.Linear(E, out_features)

    def forward(self, numerical_node_features, categorical_node_features, bias, adj, padding):

        B, L, _ = categorical_node_features.size()

        # Create causal and padding masks

        padding_mask = padding.unsqueeze(-1).expand(B, L, self.E)
        padding_causal_mask = torch.logical_or(
            padding.unsqueeze(-2), padding.unsqueeze(-1)
        )
        graph_causal_mask = (~adj)
        diag_causal_mask = torch.diag(torch.ones(L)).bool().expand_as(padding_causal_mask).to(padding.device)

        # Forward Pass

        x = sum(embed(categorical_node_features[:, :, i]) for i, embed in enumerate(self.categorical_embeds))
        x += self.numerical_embed(numerical_node_features)

        for block_type, transformer_block in zip(self.stack, self.transformer_blocks):
            if block_type == 'L': 
                x = transformer_block(x, padding_mask, graph_causal_mask)
            elif block_type == 'G':
                x = transformer_block(x, padding_mask, padding_causal_mask | diag_causal_mask, bias)
            
            if torch.any(x.isnan()):
                raise Exception(f'NaN at {block_type}-block {transformer_block}')
        
        x = x.sum(dim=1) # (B, E)

        return self.out_map(x)


In [7]:
hyperparameters = dict(
    numerical_features=5, categorical_features=(9+1, 8+1, 2+1, 2+1), 
    E=32, H=2, stack='G', 
    dropout=0.1, 
    out_features=dataset.n_properties, 
)

model = GraphTransformer(**hyperparameters).to(device)

## Train

In [8]:
optimizer = torch.optim.Adam(model.parameters())
mse = nn.MSELoss()

In [28]:
optimizer.zero_grad()

with torch.autograd.detect_anomaly():
    #smiles, d, y_true = next(iter(train_dataloader))
    #numerical_node_features, categorical_node_features, edges, padding = smiles_to_graphs(smiles, device=device)
    
    y_pred = model(
        numerical_node_features.float(), categorical_node_features, 
        -2 * torch.log(d.unsqueeze(1)), 
        edges, padding
    )
    loss = mse(y_pred, y_true.to(device))
    loss.backward()
    optimizer.step()

    print(loss)

tensor(2.6249, grad_fn=<MseLossBackward0>)


  with torch.autograd.detect_anomaly():
