In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch_geometric
from torch_geometric.data import Dataset, Data 
from torch_geometric.loader import DataLoader
import torch_geometric.nn as gnn
import rdkit
from rdkit.Chem import MolFromSmiles as get_mol
import numpy as np 
from tqdm import tqdm
import os

In [2]:
def get_smi(path) :
    with open(path, 'r') as file :
        contents = file.readlines() 
    smi_list = [content[:-1] for content in contents]
    return smi_list

def get_coor(path) :
    coor_list = []
    supplier = rdkit.Chem.SDMolSupplier(path)
    for mol in supplier:
        coor = []
        if mol is not None:
            conformer = mol.GetConformer()
            for atom in mol.GetAtoms():
                atom_idx = atom.GetIdx()
                x, y, z = conformer.GetAtomPosition(atom_idx)
                coor_atom = list((x,y,z))
                coor.append(coor_atom)
        coor_list.append(coor)

    # Replace invalid idx
    for i, coor in enumerate(coor_list):
        
        if len(coor) == 0 :
            if i == 0 :
                coor_list = coor_list[1:]
            coor_list[i] = coor_list[i-1]
    return coor_list

def get_edge_index(mol) :
    edge_indices, begin, end = [], [], []

    for bond in mol.GetBonds() :
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx() 
        begin.append(i), end.append(j), begin.append(j), end.append(i)

        # end.append(i), end.append(j)
    edge_indices.append(begin), edge_indices.append(end)

    return torch.tensor(edge_indices)

def get_node_features(mol) :
    all_node_feats = [] 

    for atom in mol.GetAtoms() :
        node_feats = []
        node_feats.append(atom.GetAtomicNum())
        node_feats.append(atom.GetDegree())
        node_feats.append(atom.GetFormalCharge())
        node_feats.append(atom.GetHybridization())
        node_feats.append(atom.GetIsAromatic())
        node_feats.append(atom.GetTotalNumHs())
        node_feats.append(atom.GetNumRadicalElectrons())
        node_feats.append(atom.IsInRing())
        node_feats.append(atom.GetChiralTag())

        all_node_feats.append(node_feats)

    all_node_feats = np.asarray(all_node_feats)
    return torch.tensor(all_node_feats, dtype=torch.float)

def count_atoms(smi):
    mol = rdkit.Chem.MolFromSmiles(smi)
    if mol is not None:
        num_atoms = mol.GetNumAtoms()
        return num_atoms
    else:
        print("Error: Unable to parse SMILES string.")
        return None

In [3]:
smi_list = get_smi('./data/ADAGRASIB_SMILES.txt')
coor_list = get_coor('./data/ADAGRASIB_COOR.sdf')
mol_list = [get_mol(smi) for smi in smi_list]

In [4]:
edge_idx_list = [get_edge_index(mol) for mol in mol_list]

In [5]:
node_feat_list = [get_node_features(mol) for mol in mol_list]

In [6]:
class MyDataset(Dataset) : 
    def __init__(self, root, filename, transform = None, pre_transform = None) :
        self.filename = filename 
        super(MyDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self) :
        return self.filename 

    @property
    def processed_file_names(self) :
        self.smi_list = get_smi(self.raw_paths[0])
        return [f'data_{i}.pt' for i in range(len(self.smi_list))]
    
    def download(self) : pass 

    def process(self) :
        smi_list = get_smi(self.raw_paths[0])
        # mol_list = [get_mol(smi) for smi in smi_list]
        coor_list = get_coor(self.raw_paths[1])

        for i, smi in enumerate(tqdm(smi_list, total=len(smi_list))) :
            if count_atoms(smi) != len(coor_list[i]) :
                smi = smi_list[i - 69]
                print('true')
            mol = get_mol(smi) 

            node_feat = get_node_features(mol)
            edge_i = get_edge_index(mol)

            data = Data(x = node_feat,
                        edge_index = edge_i,
                        y = torch.tensor(coor_list[i]))
            torch.save(data, os.path.join(self.processed_dir, f'data_{i}.pt'))


    def len(self) :
        return len(self.smi_list)
    
    def get(self, idx) :
        data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [9]:
train_set = MyDataset(root='data', filename=['ADAGRASIB_SMILES.txt', 'ADAGRASIB_COOR.sdf'])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)

FileNotFoundError: [Errno 2] No such file or directory: 'data/raw/ADAGRASIB_SMILES.txt'

In [88]:
for i in train_loader :
    print(i.\)
    break

tensor([[6., 1., 0.,  ..., 0., 0., 0.],
        [6., 2., 0.,  ..., 0., 0., 0.],
        [6., 2., 0.,  ..., 0., 0., 0.],
        ...,
        [6., 2., 0.,  ..., 0., 1., 0.],
        [7., 2., 0.,  ..., 0., 1., 0.],
        [6., 2., 0.,  ..., 0., 1., 0.]])


In [73]:
class Model(nn.Module) :
    def __init__(self, dim_model, num_head, dropout) :
        super(Model, self).__init__() 

        self.GATConv1 = gnn.GATConv(train_set.num_node_features, dim_model, heads=num_head, dropout=dropout)
        self.GATConv2 = gnn.GATConv(dim_model * num_head, 3, dropout=dropout)

    def forward(self, input) :
        x, edge_idx, batch = input.x, input.edge_index, input.batch 

        x = self.GATConv1(x, edge_idx)
        x = F.leaky_relu(x) 
        x = self.GATConv2(x, edge_idx)
        x = F.leaky_relu(x) 

        return x

In [74]:
model = Model(128, 2, 0.5)

In [85]:
for input in train_loader :
    print(input.y[0]);break

[[-3.727, -1.0768, -0.0018], [-2.7591, -0.482, -0.0015], [-1.5391, 0.2678, -0.0011], [-1.5733, 1.6642, -0.0002], [-0.3949, 2.38, 0.0006], [0.8217, 1.7192, 0.0006], [0.866, 0.3376, -0.0007], [-0.3071, -0.3944, 0.0043], [-0.2635, -1.7447, 0.0031], [2.1717, -0.3595, -0.0012], [3.2058, 0.2841, -0.0004], [2.2121, -1.5769, -0.0023]]
