In [120]:
import pandas as pd
import rdkit
import torch
import torch.nn as nn 
import torch_geometric
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool
import numpy as np 
import os
from tqdm import tqdm

In [101]:
class MoleculeDataset(Dataset) :
    def __init__(self, root, filename, transform = None, pre_transform = None) :
        self.filename = filename

        super(MoleculeDataset, self).__init__(root, transform, pre_transform)

    @property 
    def raw_file_names(self) :
        return self.filename
    
    @property
    def processed_file_names(self) :
        self.data = pd.read_csv(self.raw_paths[0]).reset_index()
        return [f'data_{i}.pt' for i in list(self.data.index)]
    
    def download(self) : pass 

    def process(self) :
        self.data = pd.read_csv(self.raw_paths[0])

        for index, mol in tqdm(self.data.iterrows(), total = self.data.shape[0]) :
            mol_obj = rdkit.Chem.MolFromSmiles(mol['Canonical SMILES'])

            node_features = self._get_node_features(mol_obj)
            edge_features = self._get_edge_features(mol_obj)
            edge_index = self._get_edge_index(mol_obj)
            label = self._get_labels(mol['iLOGP'])

            data = Data(x=node_features,
                        edge_index=edge_index,
                        edge_attr=edge_features,
                        y=label,
                        smiles=mol['Canonical SMILES'],
                        )
            
            torch.save(data, os.path.join(self.processed_dir, f'data_{index}.pt'))

    def _get_node_features(self, mol) :
        all_node_feats = [] 

        for atom in mol.GetAtoms() :
            node_feats = []

            node_feats.append(atom.GetAtomicNum())

            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)

        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol) :
        all_edge_feats = [] 

        for bond in mol.GetBonds() :
            edge_feats = []

            edge_feats.append(bond.GetBondTypeAsDouble())

            all_edge_feats.append(edge_feats)

        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)
    
    def _get_edge_index(self, mol) :
        edge_indices, begin, end = [], [], []

        for bond in mol.GetBonds() :
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx() 
            begin.append(i), begin.append(j)
            end.append(i), end.append(j)
        edge_indices.append(begin), edge_indices.append(end)

        return torch.tensor(edge_indices)
    
    def _get_labels(self, label) :
        label = np.asarray([label])
        return torch.tensor(label)
    
    def len(self) :
        return self.data.shape[0]
    
    def get(self, idx) :
        data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [116]:
train_data = MoleculeDataset(root='data', filename='ADAGRASIB_SMILES.csv')
train_loader = DataLoader(train_data, batch_size=16,shuffle=True)

In [118]:
for i in train_loader :
    print(i)
    break

DataBatch(x=[298, 1], edge_index=[2, 632], edge_attr=[316, 1], y=[16], smiles=[16], batch=[298], ptr=[17])


In [133]:
class Model(nn.Module) :
    def __init__(self, dim_model) :
        super(Model, self).__init__()
        self.conv1 = torch_geometric.nn.GCNConv(train_data.num_node_features, dim_model)
        self.conv2 = torch_geometric.nn.GCNConv(dim_model, dim_model)
        self.conv3 = torch_geometric.nn.GCNConv(dim_model, dim_model)
        self.lin = nn.Linear(dim_model, 1)

    def forward(self,x, edge_index, batch) :
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)

        x = self.lin(x)
        return x

In [136]:
model = Model(1028)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.L1Loss()

In [137]:
for epoch in range(1, 30) :
    epoch_loss = 0 

    for data in train_loader :
        out = model(data.x, data.edge_index, data.batch) 

        loss = criterion(out, data.y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()

    print(f'Epoch {epoch}: {epoch_loss}')
    epoch_loss = 0

Epoch 1: 178.36116401217154
Epoch 2: 138.44649847477996
Epoch 3: 134.8664748427147
Epoch 4: 139.5482557846739
Epoch 5: 135.206374981622


KeyboardInterrupt: 