In [1]:
import pandas as pd
import sys
sys.path.append("../")
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dense_to_sparse, to_dense_adj
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn.conv import TransformerConv, GATConv
from torch_geometric.nn import BatchNorm, global_max_pool, Set2Set
import numpy as np
from rdkit import Chem
from tqdm import tqdm
from rdkit import RDLogger
from torch_geometric.utils import to_networkx
import networkx as nx
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL) 


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
data = pd.read_csv("../data/androgen_data.csv")
smiles = data["canonical_smiles"].to_list()
molecules = [Chem.MolFromSmiles(smile) for smile in smiles]

In [3]:
non_unique_nums = []
non_unique_bonds = []

for molecule in molecules:
    
    for atom in molecule.GetAtoms():
        non_unique_nums.append(atom.GetAtomicNum())
    
    for bond in molecule.GetBonds():
        non_unique_bonds.append(bond.GetBondType())

unique_bonds = set(non_unique_bonds)
unique_nums = set(non_unique_nums)

In [4]:
atomic_alphabet = {1:"H", 6:"C", 7:"N", 8:"O", 9:"F",14:"Si",16:"S", 17:"Cl",35:"Br", 53:"I"}
encoded_atom_nums = {initial:encoded for encoded, initial in enumerate(unique_nums)}

encoded_atom_bonds = {initial:encoded+1 for encoded, initial in enumerate(unique_bonds)}
max_atoms = max([molecule.GetNumAtoms() for molecule in molecules])

In [5]:
print("Encoded atoms: ", encoded_atom_nums)
print("\nEncoded bonds: ", encoded_atom_bonds)

Encoded atoms:  {1: 0, 35: 1, 6: 2, 7: 3, 8: 4, 9: 5, 14: 6, 16: 7, 17: 8, 53: 9}

Encoded bonds:  {rdkit.Chem.rdchem.BondType.SINGLE: 1, rdkit.Chem.rdchem.BondType.DOUBLE: 2, rdkit.Chem.rdchem.BondType.TRIPLE: 3, rdkit.Chem.rdchem.BondType.AROMATIC: 4}


In [6]:
def get_annotation_matrix(molecule,alphabet):
    annotation_matrix = torch.zeros((molecule.GetNumAtoms(), len(alphabet)))
    
    for i, atom in enumerate(molecule.GetAtoms()):
        j = alphabet[atom.GetAtomicNum()]
        annotation_matrix[i,j] = 1
        
    return annotation_matrix.float()


def get_edge_index_with_attrs(molecule, alphabet):
    matrix: ndarray = Chem.GetAdjacencyMatrix(molecule)
    num_atoms = molecule.GetNumAtoms()

    for bond in molecule.GetBonds():
        i: int = bond.GetBeginAtomIdx()
        j: int = bond.GetEndAtomIdx()
        
        bond_idx = alphabet[bond.GetBondType()]
        
        matrix[i, j] = bond_idx
        matrix[j, i] = bond_idx
    
    return dense_to_sparse(torch.LongTensor(matrix))


def get_torch_data(molecule, atom_alphabet, bond_alphabet):
    
    edge_index, edge_attr = get_edge_index_with_attrs(molecule, bond_alphabet)
    x = get_annotation_matrix(molecule, atom_alphabet)
    
    return Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
    )

In [7]:
dataset = [get_torch_data(molecule, encoded_atom_nums, encoded_atom_bonds) for molecule in molecules]
dataloader = DataLoader(dataset, batch_size=128)

In [206]:
class GraphGAN(nn.Module):
    def __init__(self, in_channels, max_mol_size, decode_dim, decode_hidden, atom_alphabet_size, bond_alphabet_size):
        self.in_channels = in_channels
        self.bond_alphabet_size = bond_alphabet_size
        self.encoder_embeding_size = 512
        self.max_mol_size = max_mol_size
        self.decode_dim = decode_dim
        self.decode_hidden = decode_hidden
        self.annotation_mat_size = atom_alphabet_size*max_mol_size
        self.adj_tensor_shape = int(((max_mol_size*(max_mol_size-1))/2)*bond_alphabet_size)
        #self.adj_tensor_shape = self.bond_alphabet_size*self.max_mol_size*self.max_mol_size
        super().__init__()
        
        self.conv1 = GATConv(in_channels=self.in_channels, edge_dim=1, out_channels=self.encoder_embeding_size)
        self.bn1 = BatchNorm(self.encoder_embeding_size)
        self.conv2 = GATConv(in_channels=self.encoder_embeding_size, edge_dim=1, out_channels=self.encoder_embeding_size)
        
        self.pooling = Set2Set(self.encoder_embeding_size, processing_steps=4)
        self.discrim_linear = nn.Linear(self.encoder_embeding_size*2, 1)
        
        self.gen_linear_1 = nn.Linear(self.decode_dim, self.decode_hidden)
        self.gen_linear_2 = nn.Linear(self.decode_hidden, self.decode_hidden)
        
        self.gen_annotation_mat = nn.Linear(self.decode_hidden, self.annotation_mat_size)
        self.gen_upper_adj_tensor = nn.Linear(self.decode_hidden, self.adj_tensor_shape)
    
    def discriminator(self, data):
        x, edge_index, edge_attr, batch = (
        data.x.float(),
        data.edge_index.long(),
        data.edge_attr.float(),
        data.batch
    )
        z = F.relu(self.conv1(x, edge_index, edge_attr))
        z = self.bn1(z)
        z = self.conv2(z, edge_index, edge_attr)
        z = self.pooling(z, batch)
        z = torch.sigmoid(self.discrim_linear(z))
        
        return z
    
    def generate_graph(self, z):
        z = F.relu(self.gen_linear_1(z))
        z = F.relu(self.gen_linear_2(z))
        
        annotation_matrix = F.gumbel_softmax(self.gen_annotation_mat(z).reshape(self.max_mol_size, self.in_channels), hard=True)
        adj_tensor_2d = torch.tanh(self.gen_upper_adj_tensor(z).reshape(4,-1))
        return annotation_matrix, adj_tensor_2d
    
    def generate_batch(self, Z):
        atoms_logits = []
        edge_logits = []

In [207]:
MAX_MOLECULE_SIZE = 20
noise = torch.randn(512)
batch = next(iter(dataloader))

model = GraphGAN(10, 20, 512, 1024, 10, 4)

with torch.no_grad():
    encoded = model.discriminator(batch)
    annotation_matrix, adj_tensor_2d = model.generate_graph(noise)

In [208]:
def get_adj_mat(max_nodes, triu_vals):
    triu_indices = torch.triu_indices(max_nodes,max_nodes,1)

    adj_new = torch.zeros(max_nodes, max_nodes)

    adj_new[triu_indices[0],triu_indices[1]] = triu_vals
    adj_new = adj_new + torch.transpose(adj_new, 0,1)
    
    return adj_new

In [215]:
def get_adj_tensor_3d(adj_tensor_2d, max_nodes, max_edges):
    adj_tensor_3d = torch.empty(max_edges, max_nodes, max_nodes)
    
    for edge in range(max_edges):
        adj_tensor_3d[edge] = get_adj_mat(max_nodes, adj_tensor_2d[edge])
    
    return adj_tensor_3d

In [226]:
d=Data()
d.x = annotation_matrix

adj_tensor_3d = get_adj_tensor_3d(adj_tensor_2d, 20, 4)
d.edge_index, d.edge_attr = dense_to_sparse(adj_tensor_3d)

d

Data(x=[20, 10], edge_index=[2, 1520], edge_attr=[1520])