In [257]:
import os
import pandas as pd 
import numpy as np 
import rdkit 
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch_geometric
from torch_geometric.data import Dataset, Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, GATConv, GCNConv
from tqdm import tqdm 

In [280]:
def count_atoms(smi):
    mol = rdkit.Chem.MolFromSmiles(smi)
    if mol is not None:
        num_atoms = mol.GetNumAtoms()
        return num_atoms
    else:
        print("Error: Unable to parse SMILES string.")
        return None
    
def get_atom_dic(smi_list) :
    atom_dic = {}
    i = 0
    for smi in smi_list :
        mol = rdkit.Chem.MolFromSmiles(smi)
        for atom in mol.GetAtoms() :
            atom = atom.GetSymbol()
            if atom not in atom_dic :
                atom_dic[atom] = i 
                i += 1
    
    return atom_dic

In [337]:
class SMILES2CoorDataset(Dataset) :
    def __init__(self, root, filename, transform = None, pre_transform = None) :
        self.filename = filename
        super(SMILES2CoorDataset, self).__init__(root, transform, pre_transform)


    @property 
    def raw_file_names(self) :
        return self.filename
    
    @property
    def processed_file_names(self) :
        self.smi_list = self.read_txt(self.raw_paths[0])
        return [f'data_{i}.pt' for i in range(len(self.smi_list))]
    
    def download(self) : pass 

    def process(self) :
        
        smi_list = self.read_txt(self.raw_paths[0])
        coor_list = self._get_coor(self.raw_paths[1])
        # self.atom_dic = get_atom_dic(smi_list)
        longest_coor = self.get_longest(coor_list)
        np_coor_list = self.pad_coor(self.normalize_coor(coor_list), longest_coor)

        for index, smi in enumerate(tqdm(smi_list, total=len(smi_list))) :
            if count_atoms(smi) != len(coor_list[index]) :
                smi = smi_list[index - 69]
                label = torch.tensor(np_coor_list[index - 69]).unsqueeze(0)
            else :
                label = torch.tensor(np_coor_list[index]).unsqueeze(0)


            mol_obj = rdkit.Chem.MolFromSmiles(smi)

            node_features = self._get_node_features(mol_obj)
            edge_features = self._get_edge_features(mol_obj)
            edge_index = self._get_edge_index(mol_obj)

            # assert node_features.size(0) == label.size(0), "Different size"
                
            data = Data(x=node_features,
                        edge_index=edge_index,
                        edge_attr=edge_features,
                        y=label,
                        smiles=smi,
                        )
            
            torch.save(data, os.path.join(self.processed_dir, f'data_{index}.pt'))


    def _get_coor(self, coor_path) :
        coor_list = []
        supplier = rdkit.Chem.SDMolSupplier(coor_path)
        for mol in supplier:
            coor = []
            if mol is not None:
                conformer = mol.GetConformer()
                for atom in mol.GetAtoms():
                    atom_idx = atom.GetIdx()
                    x, y, z = conformer.GetAtomPosition(atom_idx)
                    coor_atom = list((x,y,z))
                    coor.append(coor_atom)
            coor_list.append(coor)

        # Replace invalid idx
        for i, coor in enumerate(coor_list):
            
            if len(coor) == 0 :
                if i == 0 :
                    coor_list = coor_list[1:]
                coor_list[i] = coor_list[i-1]
        return coor_list
    
    def normalize_coor(self, coor_list) :
        n_coor_list = []

        for mol_coor in coor_list :
            n_mol_coor = []

            x_origin, y_origin, z_origin = mol_coor[0]

            for atom_coor in mol_coor :
                n_atom_coor = [round(atom_coor[0] - x_origin, 2), 
                            round(atom_coor[1] - y_origin, 2), 
                            round(atom_coor[2] - z_origin, 2)]
                n_mol_coor.append(n_atom_coor)
            n_coor_list.append(n_mol_coor)
        return n_coor_list

    def pad_coor(self, coor_list, longest_coor) :
        p_coor_list = []

        for i in coor_list :
            if len(i) < longest_coor :
                zeros = [[0,0,0]] * (longest_coor - len(i))
                zeros = torch.tensor(zeros)
                i = torch.tensor(i)
                i = torch.cat((i, zeros), dim = 0)
                p_coor_list.append(i)
            else :
                p_coor_list.append(i)
        return p_coor_list

    def get_longest(self, input_list) :
        longest = 0
        for i in input_list :
            if len(i) > longest :
                longest = len(i)
        return longest
    
    def read_txt(self, path) :
        with open(path, 'r') as file :
            contents = file.readlines()
        contents = [content[:-1] for content in contents]
        return contents
    
    def _get_node_features(self, mol) :
        all_node_feats = [] 

        for atom in mol.GetAtoms() :
            node_feats = []
            node_feats.append(atom.GetAtomicNum())
            node_feats.append(atom.GetDegree())
            node_feats.append(atom.GetFormalCharge())
            node_feats.append(atom.GetHybridization())
            node_feats.append(atom.GetIsAromatic())
            node_feats.append(atom.GetTotalNumHs())
            node_feats.append(atom.GetNumRadicalElectrons())
            node_feats.append(atom.IsInRing())
            node_feats.append(atom.GetChiralTag())

            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol) :
        all_edge_feats = [] 

        for bond in mol.GetBonds() :
            edge_feats = []

            edge_feats.append(bond.GetBondTypeAsDouble())
            edge_feats.append(bond.IsInRing())

            all_edge_feats += [edge_feats, edge_feats]


        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)
    
    def _get_edge_index(self, mol) :
        edge_indices, begin, end = [], [], []

        for bond in mol.GetBonds() :
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx() 
            begin.append(i), begin.append(j)
            end.append(i), end.append(j)
        edge_indices.append(begin), edge_indices.append(end)

        return torch.tensor(edge_indices)
    
    
    def len(self) :
        return len(self.smi_list)
    
    def get(self, idx) :
        data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data

In [338]:
train_data = SMILES2CoorDataset(root='data', filename=['ADAGRASIB_SMILES.txt', "ADAGRASIB_COOR.sdf"])
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)

Processing...
  label = torch.tensor(np_coor_list[index]).unsqueeze(0)
  label = torch.tensor(np_coor_list[index - 69]).unsqueeze(0)
100%|██████████| 4255/4255 [00:04<00:00, 958.28it/s] 
Done!


In [404]:
a = 0 
for i in train_loader :
    a = i
    # ele, count = np.unique(i.batch.numpy(), return_counts=True)
    print(i)

    break

DataBatch(x=[293, 9], edge_index=[2, 612], edge_attr=[612, 2], y=[16, 22, 3], smiles=[16], batch=[293], ptr=[17])


In [405]:
batch = a.batch.numpy()
x = a.x
ele, count = np.unique(batch, return_counts=True)
ele, count

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
       dtype=int64),
 array([18, 16, 22, 21, 21, 19, 21, 20, 14, 18, 21, 14, 16, 16, 17, 19],
       dtype=int64))

In [406]:
a = torch.empty(0)
cur = 0

for i, c in enumerate(count) :
    gap = 22 - c 
    zeros = torch.zeros(gap, 9)
    a = torch.cat((a, x[cur:cur + c,:], zeros)) if i != 0 else torch.cat((a, x[:c,:], zeros))
    cur += c


torch.Size([18, 9])
torch.Size([4, 9])
torch.Size([16, 9])
torch.Size([6, 9])
torch.Size([22, 9])
torch.Size([0, 9])
torch.Size([21, 9])
torch.Size([1, 9])
torch.Size([21, 9])
torch.Size([1, 9])
torch.Size([19, 9])
torch.Size([3, 9])
torch.Size([21, 9])
torch.Size([1, 9])
torch.Size([20, 9])
torch.Size([2, 9])
torch.Size([14, 9])
torch.Size([8, 9])
torch.Size([18, 9])
torch.Size([4, 9])
torch.Size([21, 9])
torch.Size([1, 9])
torch.Size([14, 9])
torch.Size([8, 9])
torch.Size([16, 9])
torch.Size([6, 9])
torch.Size([16, 9])
torch.Size([6, 9])
torch.Size([17, 9])
torch.Size([5, 9])
torch.Size([19, 9])
torch.Size([3, 9])


In [394]:
class Model(nn.Module) :
    def __init__(self, dim_model, num_head, dropout) :
        super(Model, self).__init__()
        self.Embedding = nn.Embedding(20, dim_model)
        self.Dropout = nn.Dropout(dropout)
        self.GATConv1 = GATConv(dim_model, dim_model, heads=num_head,dropout=dropout)
        self.GATConv2 = GATConv(dim_model * num_head, dim_model, dropout=dropout)
        self.Linear = nn.Linear(dim_model, 3) 


    def forward(self, input) :
        x, edge_index, edge_attr, batch = input.x, input.edge_index, input.edge_attr, input.batch 
        x = self.Dropout(self.Embedding(x))
        x = self.GATConv1(x, edge_index)
        x = F.leaky_relu(x)
        x = self.GATConv2(x, edge_index)
        x = F.leaky_relu(x)

        out = self.Linear(x) 
        
        return out


In [423]:
class Encoder(nn.Module) :
    def __init__(self, dim_model, num_head, dropout) :
        super(Encoder, self).__init__()
        self.dim_model = dim_model
        self.GATConv1 = GATConv(train_data.num_features, dim_model, heads=num_head, dropout=dropout)
        self.GATConv2 = GATConv(dim_model*num_head, dim_model, dropout=dropout)
    
    def forward(self, input) :
        x, edge_index, edge_attr, batch, longest = input.x, input.edge_index, input.edge_attr, input.batch, input.y.size(1)
        x = self.GATConv1(x, edge_index, edge_attr)
        x = F.leaky_relu(x)
        x = self.GATConv2(x, edge_index, edge_attr)
        x = F.leaky_relu(x)
        pool = global_mean_pool(x, batch)
        out = torch.empty(0)
        cur = 0 

        for i, c in enumerate(count) :
            zeros = torch.zeros(longest - c, self.dim_model)
            out = torch.cat((out, x[cur : cur + c, :], zeros)) if i != 0 else torch.cat((out, x[:c, :], zeros))
            cur += c
        return out, pool

In [424]:
class Decoder(nn.Module) :
    def __init__(self, dim_model, dropout) :
        super(Decoder, self).__init__()
        self.GRU = nn.GRU(dim_model, dim_model)
        

In [425]:
encoder = Encoder(128, 2, 0.5)

In [427]:
for i in train_loader:
    out, pool = encoder(i)
    print(f'out: {out.shape}')
    print(f'pool: {pool.shape}')
    break

out: torch.Size([352, 128])
pool: torch.Size([16, 128])


In [399]:
model = Model(dim_model=128,
              num_head=4,
              dropout=0.5)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [400]:
for epoch in range(1, 30) :
    epoch_loss = 0 

    for data in train_loader :
        out = model(data) 
        # print(f'Data: {data}')
        # print(f'y: {data.y}')
        loss = criterion(out, data.y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()

    print(f'Epoch {epoch}: {epoch_loss / len(train_loader)}')
    epoch_loss = 0

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)