In [1]:
import os
os.chdir('..')

In [6]:
from ddi_kt_2024.mol.gnn import GNN
from ddi_kt_2024.utils import load_pkl
from ddi_kt_2024.mol.preprocess import mapped_smiles_reader, candidate_smiles
from ddi_kt_2024.mol.mol_dataset import MolDataset

2024-03-05 01:18:37,959 - INFO - Enabling RDKit 2023.09.5 jupyter extensions


In [7]:
all_candidates_train = load_pkl('cache/pkl/v1/candidates.train.pkl')
all_candidates_test = load_pkl('cache/pkl/v1/candidates.test.pkl')
mapped_smiles = mapped_smiles_reader('cache/mapped_drugs/DDI/all_mapped.txt')
x_train, y_train = candidate_smiles(all_candidates_train, mapped_smiles)
x_test, y_test = candidate_smiles(all_candidates_test, mapped_smiles)
# dataset_train_mol1 = MolDataset(x_train, element=1)
# dataset_train_mol2 = MolDataset(x_train, element=2)
# dataset_test_mol1 = MolDataset(x_test, element=1)
dataset_test_mol2 = MolDataset(x_test, element=2)

Converting SMILES to PyG: 100%|██████████| 5716/5716 [00:03<00:00, 1555.80it/s]


In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv
from torch_geometric.nn import global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, 
                 atom_embedding_dim: int = 64,
                 bond_embedding_dim: int = 16,
                 bool_embedding_dim: int = 2,
                 num_node_features: int = 10,
                 num_edge_features: int = 4,
                 hidden_channels: int = 512, 
                 dropout_rate: float = 0.2,
                 device: str = 'cpu'):
        
        super(GNN, self).__init__()
        self.device = device
        self.dropout = dropout_rate
        self.hidden_channels = hidden_channels

        self.atom_encoder = nn.Embedding(num_embeddings=119, embedding_dim=atom_embedding_dim, padding_idx=0)
        self.bond_encoder = nn.Embedding(num_embeddings=22, embedding_dim=bond_embedding_dim, padding_idx=0)
        self.boolean_encoder = nn.Embedding(num_embeddings=3, embedding_dim=bool_embedding_dim, padding_idx=2)

        self.conv1 = GATv2Conv(num_node_features-2+atom_embedding_dim+bool_embedding_dim, 
                               hidden_channels,
                               edge_dim=num_edge_features-3+bond_embedding_dim+bool_embedding_dim*2)
        
        self.conv2 = GATv2Conv(hidden_channels, 
                               hidden_channels,
                               edge_dim=num_edge_features-3+bond_embedding_dim+bool_embedding_dim*2)
        
        self.conv3 = GATv2Conv(hidden_channels, 
                               hidden_channels,
                               edge_dim=num_edge_features-3+bond_embedding_dim+bool_embedding_dim*2)
        
        self.conv4 = GATv2Conv(hidden_channels, 
                               hidden_channels,
                               edge_dim=num_edge_features-3+bond_embedding_dim+bool_embedding_dim*2)

    def forward(self, mol):
        if mol.mol == 'None':
            return torch.zeros([1, self.hidden_channels]).to(self.device)
        
        x, edge_index, edge_attr, batch = mol.x, mol.edge_index, mol.edge_attr, mol.batch
        print(x, edge_index, edge_attr, batch)

        atomic_num0 = self.atom_encoder(x[:, 0].int()) # encode atom type
        atom_is_aromatic0 = self.boolean_encoder(x[:, -1].int()) # encode aromaticity

        bond_type0 = self.bond_encoder(edge_attr[:, 0].int()) # encode bond type
        bond_is_conjugated0 = self.boolean_encoder(edge_attr[:, -2].int()) # encode conjugation
        bond_is_aromatic0 = self.boolean_encoder(edge_attr[:, -1].int()) # encode aromaticity

        x = torch.cat([atomic_num0, x[:, 1:9], atom_is_aromatic0], dim=1)
        edge_attr = torch.cat([bond_type0, edge_attr[:, 1:2], bond_is_conjugated0, bond_is_aromatic0], dim=1)

        # GNN pass
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv3(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv4(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = global_mean_pool(x, batch)
        
        return x

model = GNN()
print(model)

GNN(
  (atom_encoder): Embedding(119, 64, padding_idx=0)
  (bond_encoder): Embedding(22, 16, padding_idx=0)
  (boolean_encoder): Embedding(3, 2, padding_idx=2)
  (conv1): GATv2Conv(74, 512, heads=1)
  (conv2): GATv2Conv(512, 512, heads=1)
  (conv3): GATv2Conv(512, 512, heads=1)
  (conv4): GATv2Conv(512, 512, heads=1)
)


In [41]:
mol = dataset_test_mol2[121]
mol

Data(x=[1, 10], edge_index=[2, 2], edge_attr=[2, 4], mol='None', smiles='None')

In [42]:
model(mol)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.