In [50]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset 
import torch_geometric
import torch_geometric.nn as gnn
from torch_geometric.data import Dataset as gDataset
from torch_geometric.data import Data as gData 
from torch_geometric.loader import DataLoader as gDataLoader
from torch_geometric.datasets import QM9, ZINC
from tqdm import tqdm 
import rdkit
from rdkit.Chem import MolFromSmiles as get_mol
from rdkit.Chem.rdmolops import GetAdjacencyMatrix as get_mat
import numpy as np 
import os 

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [51]:
def get_smi(path):
    """Read a text file and return its lines without line terminators.

    Args:
        path: Path of the file to read (one SMILES string per line).

    Returns:
        list[str]: One entry per line of the file, in file order. Blank
        lines are kept as empty strings so indices still match line numbers.
    """
    with open(path, 'r') as file:
        # Strip only the line terminator. Slicing with [:-1] would also cut
        # the last real character of a final line that has no trailing newline.
        return [line.rstrip('\r\n') for line in file]

def get_coor(path):
    """Load per-atom 3D coordinates for every record of an SD file.

    Records that the supplier yields as ``None`` (or that contain no atoms)
    have no coordinates of their own. Each such record reuses the
    coordinates of the closest preceding valid record. Invalid records that
    appear before the first valid one have nothing to copy and are dropped,
    so the result can be shorter than the number of records in the file.

    Args:
        path: Path of the SD file passed to ``rdkit.Chem.SDMolSupplier``.

    Returns:
        list[list[list[float]]]: One ``[[x, y, z], ...]`` list per kept record.
        A filled-in record shares the same list object as the record it was
        copied from.
    """
    filled = []
    for mol in rdkit.Chem.SDMolSupplier(path):
        coor = []
        if mol is not None:
            conformer = mol.GetConformer()
            for atom in mol.GetAtoms():
                x, y, z = conformer.GetAtomPosition(atom.GetIdx())
                coor.append([x, y, z])

        if coor:
            filled.append(coor)
        elif filled:
            # Invalid record: forward-fill from the previous valid entry.
            filled.append(filled[-1])
        # else: leading invalid record with no predecessor -> dropped.
    return filled

def get_edge_index(mol):
    """Build a bidirectional edge index from the bonds of ``mol``.

    Every bond contributes two directed edges. All begin->end edges come
    first (in ``mol.GetBonds()`` order), followed by all end->begin edges
    in the same bond order. ``get_edge_features`` uses the same layout.

    Args:
        mol: Object exposing ``GetBonds()``; each bond must provide
            ``GetBeginAtomIdx()`` and ``GetEndAtomIdx()``.

    Returns:
        torch.Tensor: ``torch.long`` tensor of shape ``(2, 2 * num_bonds)``.
        A molecule without bonds yields shape ``(2, 0)``.
    """
    pairs = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) for bond in mol.GetBonds()]
    forward = [i for i, _ in pairs]
    backward = [j for _, j in pairs]

    # An explicit dtype keeps the bond-free case integral; torch.tensor would
    # otherwise infer float32 from the empty nested lists.
    return torch.tensor([forward + backward, backward + forward], dtype=torch.long)

def get_edge_features(mol):
    """Return one bond-type feature per directed edge of ``mol``.

    The rows follow the layout of ``get_edge_index``: the features of all
    bonds in ``mol.GetBonds()`` order, then the same features repeated for
    the reversed edges.

    Args:
        mol: Object exposing ``GetBonds()``; each bond must provide
            ``GetBondType()`` returning an integer-like value.

    Returns:
        torch.Tensor: ``torch.float`` tensor of shape ``(2 * num_bonds, 1)``.
        A molecule without bonds yields shape ``(0, 1)``.
    """
    per_bond = [[bond.GetBondType()] for bond in mol.GetBonds()]

    # Both edge directions carry the same bond feature.
    all_edge_feats = np.asarray(per_bond + per_bond)

    # reshape keeps the feature axis when there are no bonds at all
    # (np.asarray([]) is 1-D, which would otherwise give shape (0,)).
    return torch.tensor(all_edge_feats, dtype=torch.float).reshape(-1, 1)

def get_node_features(mol):
    """Assemble the per-atom feature matrix of ``mol``.

    Each atom contributes nine values, in this order: atomic number,
    degree, formal charge, hybridization, aromatic flag, total hydrogen
    count, radical electron count, ring-membership flag and chiral tag.

    Args:
        mol: Object exposing ``GetAtoms()``.

    Returns:
        torch.Tensor: ``torch.float`` tensor with one row per atom.
    """
    rows = [
        [
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetHybridization(),
            atom.GetIsAromatic(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            atom.IsInRing(),
            atom.GetChiralTag(),
        ]
        for atom in mol.GetAtoms()
    ]
    feature_array = np.asarray(rows)
    return torch.tensor(feature_array, dtype=torch.float)

def count_atoms(smi):
    """Return the atom count of the molecule parsed from ``smi``.

    Args:
        smi: SMILES string handed to ``rdkit.Chem.MolFromSmiles``.

    Returns:
        The value of ``GetNumAtoms()`` for the parsed molecule, or ``None``
        (after printing an error message) when parsing returns ``None``.
    """
    parsed = rdkit.Chem.MolFromSmiles(smi)
    if parsed is None:
        print("Error: Unable to parse SMILES string.")
        return None
    return parsed.GetNumAtoms()

def get_atom(smi):
    """Return the symbols of the atoms of ``smi`` in ``GetAtoms()`` order.

    Args:
        smi: SMILES string parsed with ``get_mol``.

    Returns:
        list: ``GetSymbol()`` of every atom of the parsed molecule.
    """
    symbols = []
    for atom in get_mol(smi).GetAtoms():
        symbols.append(atom.GetSymbol())
    return symbols

def get_bond(smi):
    """Return the bond types of ``smi`` in ``GetBonds()`` order.

    Args:
        smi: SMILES string parsed with ``get_mol``.

    Returns:
        list: ``GetBondType()`` of every bond of the parsed molecule.
    """
    bond_types = []
    for bond in get_mol(smi).GetBonds():
        bond_types.append(bond.GetBondType())
    return bond_types

def get_atom_mat(smi, max_atom):
    """Build a ``max_atom x max_atom`` matrix of neighbour atom codes.

    For every adjacent atom pair ``(i, j)`` (adjacency value 1), entry
    ``[i][j]`` holds ``atom_dic[symbol of atom j]``. Every other entry,
    including the padding rows and columns, stays 0.

    Relies on the module-level ``atom_dic`` symbol-to-code mapping.

    Args:
        smi: SMILES string of the molecule.
        max_atom: Side length of the returned square matrix.

    Returns:
        torch.Tensor: ``torch.long`` tensor of shape ``(max_atom, max_atom)``.
    """
    symbols = get_atom(smi)
    adjacency = get_mat(get_mol(smi))
    codes = np.zeros((max_atom, max_atom))

    # Only rows that exist in both the adjacency matrix and the padded
    # output are visited.
    n_rows = min(len(adjacency), max_atom)
    for row in range(n_rows):
        neighbours = np.nonzero(adjacency[row] == 1)[0]
        for col in neighbours:
            codes[row][col] = atom_dic[symbols[col]]

    return torch.tensor(codes, dtype=torch.long)

def get_bond_mat(smi, max_atom):
    """Build a symmetric ``max_atom x max_atom`` matrix of bond-type codes.

    For every bond between atoms ``b`` and ``e``, entries ``[b][e]`` and
    ``[e][b]`` hold ``bond_dic[str(bond type)]``. Every other entry,
    including the padding rows and columns, stays 0.

    Relies on the module-level ``bond_dic`` bond-type-to-code mapping.

    Args:
        smi: SMILES string of the molecule.
        max_atom: Side length of the returned square matrix.

    Returns:
        torch.Tensor: ``torch.long`` tensor of shape ``(max_atom, max_atom)``.
    """
    mol = get_mol(smi)
    wmat = np.zeros((max_atom, max_atom))

    # The bonds are enough to fill the matrix; the adjacency matrix the
    # original code computed here was never read.
    for bond in mol.GetBonds():
        b, e = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        code = bond_dic[str(bond.GetBondType())]
        wmat[b][e] = code
        wmat[e][b] = code

    return torch.tensor(wmat, dtype=torch.long)

def get_dic(smi_list):
    """Build integer vocabularies for atom symbols and bond types.

    Codes are assigned in order of first appearance and start at 1, which
    leaves 0 unused by this function.

    Args:
        smi_list: Iterable of SMILES strings, each parsed with ``get_mol``.

    Returns:
        tuple[dict, dict]: ``(a_dic, b_dic)`` mapping atom symbol -> code and
        ``str(bond type)`` -> code.
    """
    a_dic, b_dic = {}, {}

    for smi in smi_list:
        mol = get_mol(smi)

        # len(dic) + 1 is the next unused code because codes are dense from 1.
        for atom in mol.GetAtoms():
            a_dic.setdefault(atom.GetSymbol(), len(a_dic) + 1)

        for bond in mol.GetBonds():
            b_dic.setdefault(str(bond.GetBondType()), len(b_dic) + 1)

    return a_dic, b_dic
    

def pad_bond(tensor, max_atom):
    """Zero-pad a 2-D tensor on the right and bottom to ``max_atom x max_atom``.

    The padding is allocated with ``tensor.new_zeros`` so that it shares the
    dtype and device of the input. Integer matrices therefore stay integral
    and CUDA tensors can be padded without a device mismatch.

    Args:
        tensor: 2-D tensor of shape ``(row, col)`` with ``row <= max_atom``
            and ``col <= max_atom``.
        max_atom: Target size of both dimensions.

    Returns:
        torch.Tensor: Tensor of shape ``(max_atom, max_atom)`` holding
        ``tensor`` in its top-left corner and zeros elsewhere.
    """
    row, col = tensor.size()
    right = tensor.new_zeros(row, max_atom - col)
    pad = torch.cat((tensor, right), dim=1)
    bottom = tensor.new_zeros(max_atom - row, pad.size(1))
    pad = torch.cat((pad, bottom), dim=0)
    return pad


In [52]:
# Build the module-level vocabularies that get_atom_mat / get_bond_mat and
# the VAE embedding layers read as globals.
smi_list = get_smi('./data/raw/ADAGRASIB_SMILES.txt')
atom_dic, bond_dic = get_dic(smi_list)
# get_dic assigns codes starting at 1, so 0 is free for the "no atom / no bond"
# entry; the matrices built by get_atom_mat / get_bond_mat are zero-initialised.
atom_dic['None'] = 0 
bond_dic['None'] = 0

In [302]:
class MyData(gData):
    """Graph data object with custom batching for the dense matrices."""

    # Attributes for which ``__cat_dim__`` returns ``None`` instead of
    # deferring to the parent implementation.
    _STACKED_KEYS = ('atom_mat', 'bond_mat')

    def __cat_dim__(self, key, value, *args, **kwargs):
        """Return ``None`` for ``atom_mat`` / ``bond_mat``; defer otherwise.

        NOTE(review): in PyTorch Geometric a ``None`` cat dimension is
        expected to make the batcher stack the attribute along a new leading
        axis instead of concatenating it; confirm against the installed
        version.
        """
        if key in self._STACKED_KEYS:
            return None
        return super().__cat_dim__(key, value, *args, **kwargs)

In [415]:
class MyDataset(gDataset):
    """On-disk graph dataset built from a text file of SMILES strings.

    Every SMILES string of the first raw file becomes one ``MyData`` object
    stored as ``data_{i}.pt`` in ``self.processed_dir``.

    Args:
        root: Dataset root directory, forwarded to the parent class.
        filename: Raw file name(s) returned by ``raw_file_names``; the first
            resolved raw path is the SMILES file that gets read.
        transform: Forwarded to the parent class.
        pre_transform: Forwarded to the parent class and, when not ``None``,
            applied to each ``MyData`` object before it is saved.
    """

    def __init__(self, root, filename, transform = None, pre_transform = None):
        # Must be set before the parent constructor runs, because the
        # file-name properties below read it.
        self.filename = filename
        super(MyDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Raw file name(s) as passed to the constructor."""
        return self.filename

    @property
    def processed_file_names(self):
        """One ``data_{i}.pt`` name per SMILES line of the raw file.

        Also stores the SMILES list on the instance as ``self.smi_list``.
        """
        self.smi_list = get_smi(self.raw_paths[0])
        return [f'data_{i}.pt' for i in range(len(self.smi_list))]

    def download(self):
        """No-op: the raw file is expected to be present already."""
        pass

    def process(self):
        """Convert every SMILES string to a ``MyData`` graph and save it.

        ``get_atom_mat`` and ``get_bond_mat`` read the module-level
        ``atom_dic`` / ``bond_dic`` vocabularies, which must exist before
        this method runs.
        """
        smi_list = get_smi(self.raw_paths[0])

        # All dense matrices are padded to the size of the largest molecule.
        max_atom = max(count_atoms(smi) for smi in smi_list)

        for i, smi in enumerate(tqdm(smi_list, total=len(smi_list))):
            mol = get_mol(smi)

            data = MyData(x = get_node_features(mol),
                          edge_index = get_edge_index(mol),
                          edge_attr = get_edge_features(mol),
                          # The original called an undefined get_wmat here.
                          atom_mat = get_atom_mat(smi, max_atom),
                          bond_mat = get_bond_mat(smi, max_atom))

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, os.path.join(self.processed_dir, f'data_{i}.pt'))

    def len(self):
        """Number of samples (SMILES lines in the raw file)."""
        # Do not rely on processed_file_names having been evaluated first.
        if not hasattr(self, 'smi_list'):
            self.smi_list = get_smi(self.raw_paths[0])
        return len(self.smi_list)

    def get(self, idx):
        """Load sample ``idx`` from disk and move it to the global ``device``."""
        data = torch.load(os.path.join(self.processed_dir, f'data_{idx}.pt'))
        return data.to(device)

In [416]:
# Build the dataset from the SMILES file; the captured output below shows that
# this run executed MyDataset.process().
dataset = MyDataset(root='data', filename=['ADAGRASIB_SMILES.txt'])
# Mini-batches of 16 graphs, in dataset order (no shuffle argument is passed).
train_loader = gDataLoader(dataset, batch_size=16)

Processing...
100%|███████████████████████████████████████████████████████████████████████████████████████████| 4255/4255 [00:15<00:00, 271.76it/s]
Done!


In [67]:
class VAE(nn.Module):
    """Graph variational autoencoder over dense atom / bond matrices.

    The encoder runs two GCN layers over the node features, sum-pools them
    into one vector per graph and adds that vector to the embedded
    ``atom_mat`` and ``bond_mat`` entries. Every matrix entry is then mapped
    to a latent Gaussian. The decoder maps each latent vector to logits over
    the atom vocabulary and over the bond vocabulary.

    The vocabulary sizes are read from the module-level ``atom_dic`` and
    ``bond_dic`` at construction time.

    Args:
        dim_latent: Size of the latent vector of each matrix entry.

    Attributes:
        kl: KL divergence between the approximate posterior and the standard
            normal prior, summed over all elements; refreshed by ``encode``.
    """

    def __init__(self, dim_latent):
        super(VAE, self).__init__()
        self.kl = 0

        # 9 input channels = the nine per-atom values of get_node_features.
        self.gc1 = gnn.GCNConv(9, 128)
        self.gc1_bn = nn.BatchNorm1d(128)
        self.gc2 = gnn.GCNConv(128, 128)
        self.gc2_bn = nn.BatchNorm1d(128)

        self.ff1 = nn.Linear(128, 128)
        self.ff1_bn = nn.BatchNorm1d(128)

        self.mu = nn.Linear(128, dim_latent)
        # Predicts log(sigma); encode() exponentiates it so that the standard
        # deviation is always strictly positive.
        self.sigma = nn.Linear(128, dim_latent)

        self.a_emb = nn.Embedding(len(atom_dic), 128)
        self.b_emb = nn.Embedding(len(bond_dic), 128)

        self.a_seq = nn.Sequential(
            nn.Linear(dim_latent, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 512),
            nn.LeakyReLU(),
            nn.Linear(512, len(atom_dic))
        )

        self.b_seq = nn.Sequential(
            nn.Linear(dim_latent, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 512),
            nn.LeakyReLU(),
            nn.Linear(512, len(bond_dic))
        )

    def encode(self, input):
        """Encode a batch and sample latent vectors.

        Args:
            input: Batch object providing ``x``, ``edge_index``, ``batch``,
                ``atom_mat`` and ``bond_mat``. ``atom_mat`` / ``bond_mat``
                are presumably ``(num_graphs, max_atom, max_atom)`` integer
                tensors; confirm against the data loader.

        Returns:
            torch.Tensor: Sampled latents; the last axis has ``dim_latent``
            entries. ``self.kl`` is updated as a side effect.
        """
        x, edge_index, batch = input.x, input.edge_index, input.batch

        atom_mat = self.a_emb(input.atom_mat)
        bond_mat = self.b_emb(input.bond_mat)

        x = self.gc1(x, edge_index)
        x = F.relu(self.gc1_bn(x))
        x = self.gc2(x, edge_index)
        x = F.relu(self.gc2_bn(x))
        x = gnn.global_add_pool(x, batch)
        x = self.ff1(x)
        x = F.relu(self.ff1_bn(x))

        # Insert two singleton axes so the per-graph vector broadcasts over
        # both matrix axes of the embeddings.
        x = x.unsqueeze(1).unsqueeze(1)
        x = x + atom_mat + bond_mat

        mu = self.mu(x)
        log_sigma = self.sigma(x)
        sigma = torch.exp(log_sigma)

        # randn_like draws the noise on mu's device and dtype, so no
        # CUDA-only setup is needed and the model also runs on CPU.
        z = mu + sigma * torch.randn_like(mu)

        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over all elements:
        # 0.5 * (sigma^2 + mu^2 - 1) - log(sigma).
        self.kl = (0.5 * (sigma ** 2 + mu ** 2 - 1) - log_sigma).sum()

        return z

    def decode(self, z):
        """Map latents to ``(atom logits, bond logits)``."""
        atom_mat, bond_mat = self.a_seq(z), self.b_seq(z)
        return atom_mat, bond_mat

    def forward(self, x):
        """Encode the batch ``x`` and decode the sampled latents.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Atom logits and bond logits;
            the last axes have ``len(atom_dic)`` and ``len(bond_dic)`` entries.
        """
        z = self.encode(x)
        atom_mat, bond_mat = self.decode(z)

        return atom_mat, bond_mat



In [68]:
# VAE with a 32-dimensional latent, moved to the selected device.
model = VAE(32).to(device)

In [69]:
# Smoke test: push the first batch through the model and print the shape of
# the decoded atom logits, then stop.
for i in train_loader :
    a_mat, b_mat = model(i)
    print(a_mat.shape)
    break

torch.Size([16, 22, 22, 13])


In [62]:
# Scratch 1-D tensor of five standard-normal samples for the softmax check below.
a = torch.randn(5)
a

tensor([-0.1138, -0.3743, -0.3067,  0.1735, -1.2267])

In [63]:
# Name the axis explicitly: calling softmax without `dim` is deprecated and
# emits a warning. For this 1-D tensor the only axis is 0.
F.softmax(a, dim=0)

  F.softmax(a)


tensor([0.2349, 0.1810, 0.1937, 0.3131, 0.0772])