In [9]:
import numpy as np
import pandas as pd
from rdkit import Chem
import chemprop
import torch
from torch import Tensor
from torch.nn import ReLU, Linear
from torch.nn.functional import log_softmax, relu, dropout
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.nn import Sequential, GCNConv, MeanAggregation
import pytorch_lightning as tl
from typing import Tuple, List

In [2]:
DATA_DIR = '../data/AID1445/'
DATAFILE = 'SD.csv'

In [3]:
def get_connectivity(mol):
    conns = []
    b2a = mol.b2a
    a2b = mol.a2b
    for aI, bonds in enumerate(a2b):
        neighbours = [(b2a[bI], aI) for bI in bonds]
        conns.extend(neighbours)
    return conns

In [4]:
def read_data(filepath: str) -> pd.DataFrame:
    return pd.read_csv(filepath)

def process_data(df: pd.DataFrame) -> Tuple[List[Tensor]]:
    smiles = df['neut-smiles']
    mols = [chemprop.features.featurization.MolGraph(s) for s in smiles]
    xs = [Tensor(m.f_atoms) for m in mols]
    conns = [get_connectivity(m) for m in mols]
    edge_indexes = [torch.tensor(conn, dtype=torch.long).T.contiguous() for conn in conns]
    ys = Tensor(df['SD Z-score'])
    return [Data(x=x, edge_index=index, y=y) for x, index, y in zip(xs, edge_indexes, ys)]

In [5]:
class HTSDataset(InMemoryDataset):
    def __init__(self, root):
        super().__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    def __get__(self, idx):
        return self.get(idx)

    @property
    def raw_file_names(self):
        return [DATAFILE]

    @property
    def processed_file_names(self):
        return ['processed_data.pt']

    def process(self):
        df = read_data(DATA_DIR + self.raw_file_names[0])[:10]
        data_list = process_data(df)
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [36]:
class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(133, 16)
        self.conv2 = GCNConv(16, 1)
        self.pool = MeanAggregation()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = relu(x)
        x = dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = relu(x)
        pred = self.pool(x)
        return pred

In [None]:
class LitGNN(tl.LightningModule):
    def __init__(self):
        super.__init__()
        self.gnn = GNN()
        
    def training_step(self, batch, batch_idx):
        out = self.gnn(batch)
        loss = torch.nn.MSELoss(out, batch.y)
        return loss
        

In [7]:
dataset = HTSDataset(root=DATA_DIR)

In [None]:
model = Sequential('x, edge_index', [
    (GCNConv(133, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, 1),
])

In [38]:
model = GNN()

In [39]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = dataset[0].to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    out = model(data)
    loss = criterion(out, data.y)
    loss.backward()
    optimizer.step()
    print(float(loss))

216.3248748779297
211.6264190673828
206.9586639404297
202.56454467773438
190.80978393554688
187.3816680908203
173.55650329589844
168.5679168701172
162.45236206054688
160.78143310546875
144.31227111816406
136.71212768554688
133.3456268310547
123.01642608642578
120.4052963256836
105.66065216064453
101.59955596923828
90.78846740722656
90.00833129882812
64.02740478515625
67.14730072021484
63.539451599121094
47.36288070678711
44.83290481567383
38.09035873413086
40.10647964477539
27.57817268371582
31.328781127929688
29.708393096923828
10.699440002441406
24.529300689697266
7.247215270996094
8.382115364074707
3.444091558456421
1.3027859926223755
0.025992373004555702
0.23231422901153564
0.4898768961429596
0.007766392081975937
6.309578895568848
2.6312575340270996
0.2824018895626068
7.7228851318359375
7.423844337463379
14.791214942932129
5.400835990905762
1.6683247089385986
3.062523365020752
8.475177764892578
4.383960723876953
3.9545481204986572
8.105175971984863
9.101662635803223
3.8180842399597