In [7]:
from dataclasses import dataclass
import numpy as np
import pandas as pd
from rdkit import Chem
import chemprop
import torch
from torch import Tensor
from torch.nn import ReLU, Linear, MSELoss
from torch.nn.functional import log_softmax, relu, dropout
from torch.optim import AdamW
from torch_geometric.data import Data, Batch, InMemoryDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import (
    Sequential,
    MessagePassing, GCNConv, GATConv, GATv2Conv, GINConv,
    Aggregation, MeanAggregation, MaxAggregation, SumAggregation
)
import pytorch_lightning as tl
from typing import Tuple, List, Type, Union

In [5]:
DATA_DIR = '../data/AID1445/'
DATAFILE = 'SD.csv'

TRAIN_TEST_SPLIT = 0.9
BATCH_SIZE = 1

In [6]:
def get_connectivity(mol):
    conns = []
    b2a = mol.b2a
    a2b = mol.a2b
    for aI, bonds in enumerate(a2b):
        neighbours = [(b2a[bI], aI) for bI in bonds]
        conns.extend(neighbours)
    return conns

In [7]:
def read_data(filepath: str) -> pd.DataFrame:
    return pd.read_csv(filepath)

def process_data(df: pd.DataFrame) -> Tuple[List[Tensor]]:
    smiles = df['neut-smiles']
    mols = [chemprop.features.featurization.MolGraph(s) for s in smiles]
    xs = [Tensor(m.f_atoms) for m in mols]
    conns = [get_connectivity(m) for m in mols]
    edge_indexes = [torch.tensor(conn, dtype=torch.long).T.contiguous() for conn in conns]
    ys = Tensor(df['SD Z-score'].values)
    return [Data(x=x, edge_index=index, y=y) for x, index, y in zip(xs, edge_indexes, ys)]

In [None]:
def process_data_batch_mol_graph(df: pd.DataFrame) -> Tuple[List[Tensor]]:
    smiles = df['neut-smiles']
    mols = chemprop.features.featurization.mol2graph(smiles)
    xs = [mols.f_atoms[s:e] for s, e in mols.a_scope]
    edge_indexes = [mols.get_a2a()[s, e] for s, e in mols.b_scope]
    ys = Tensor(df['SD Z-score'].values)
    return [Data(x=x, edge_index=index, y=y) for x, index, y in zip(xs, edge_indexes, ys)]

In [8]:
class DRDataset(InMemoryDataset):
    def __init__(self, root):
        super().__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    def __get__(self, idx):
        return self.get(idx)

    @property
    def raw_file_names(self):
        return [DATAFILE]

    @property
    def processed_file_names(self):
        return ['processed_data.pt']

    def process(self):
        df = read_data(DATA_DIR + self.raw_file_names[0])
        dr_df = df[df['DR'].notnull()]
        data_list = process_data(dr_df)
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [9]:
def split_dataset(dataset, ratio):
    num_train_samples = int(ratio * len(dataset))
    training_dataset = dataset.index_select(slice(num_train_samples))
    validation_dataset = dataset.index_select(slice(num_train_samples, None))
    return training_dataset, validation_dataset

In [8]:
@dataclass
class ModelArchitecture:
    layer_types: List[Type[MessagePassing]]
    features: List[int]
    pool_type: List[Type[Aggregation]]

In [24]:
class GNN(torch.nn.Module):
    def __init__(self, arch: ModelArchitecture):
        super().__init__()
        self.layers = []
        for layer_type, num_in, num_out in zip(arch.layer_types, arch.features[:-1], arch.features[1:]):
            self.layers.append(layer_type(num_in, num_out))
        self.pool = arch.pool_type()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = relu(x)
        x = dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        pred = self.pool(x=x, index=data.batch)
        return pred.reshape(1, 1) if pred.shape == 1 else pred
            

In [25]:
class LitGNN(tl.LightningModule):
    def __init__(self):
        super().__init__()
        self.gnn = GNN()
        self.loss = MSELoss()
        
    def training_step(self, batch, batch_idx):
        out = self.gnn(batch)
        loss = self.loss(out, batch.y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=len(batch))
        return Tensor(loss)
    
    def validation_step(self, batch, batch_idx):
        out = self.gnn(batch)
        loss = self.loss(out, batch.y)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=len(batch))
        return Tensor(loss)
    
    def test_step(self, batch, batch_idx):
        out = self.gnn(batch)
        
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, batch_size=len(batch))
        return Tensor(loss)
    
    def configure_optimizers(self):
        optimiser = AdamW(self.gnn.parameters(), lr=1e-3)
        return optimiser
        

In [12]:
dataset = DRDataset(root=DATA_DIR)
dataset.shuffle()
training_dataset, validation_dataset = split_dataset(dataset, TRAIN_TEST_SPLIT)
training_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE)

In [10]:
model = Sequential('x, edge_index', [
    (GCNConv(133, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, 1),
])

In [None]:
model = LitGNN()
trainer = tl.Trainer(
    accelerator="cpu",
    max_epochs=5
)
trainer.fit(
    model,
    training_dataloader,
    validation_dataloader
)

In [27]:
trainer.validate(model, validation_dataloader)

Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 66/66 [00:00<00:00, 191.87it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     val_loss_epoch         0.2511858344078064
───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

[{'val_loss_epoch': 0.2511858344078064}]

In [17]:
%reload_ext tensorboard
%tensorboard --logdir=src/lightning_logs/