In [1]:
import pandas as pd
import sys
sys.path.append("../")
from tcad.tools.nntools import SmilesDataSet, SmilesEncoder
from tcad.deep.fit import train_gan
from torch.utils.data import DataLoader
from tcad.deep.models import GAN
import matplotlib.pyplot as plt
import torch
import numpy as np
from rdkit import Chem
from tqdm import tqdm
from rdkit import RDLogger

# Mute RDKit below CRITICAL so SMILES-parsing chatter does not flood cell output.
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# Prefer a CUDA device when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [2]:
data = pd.read_csv("../data/androgen_data.csv")
smiles = data["canonical_smiles"].to_list()
# One RDKit molecule per row of `data`, in the same order.
molecules = [Chem.MolFromSmiles(smile) for smile in smiles]

# RDKit returns None for SMILES it cannot parse, and with its logger set to
# CRITICAL this happens silently. Fail here with the offending row numbers
# instead of later with an opaque AttributeError on `None.GetAtoms()`.
# Rows are reported rather than dropped, so `molecules` stays aligned with `data`.
unparsable_rows = [row for row, molecule in enumerate(molecules) if molecule is None]
assert not unparsable_rows, (
    f"RDKit could not parse {len(unparsable_rows)} SMILES; first rows: {unparsable_rows[:10]}"
)

In [19]:
# Every atomic number occurring in the dataset, one entry per atom (duplicates kept).
non_unique_nums = [
    atom.GetAtomicNum()
    for molecule in molecules
    for atom in molecule.GetAtoms()
]

# Deduplicate the collected list. Building the set from `unique_nums` itself
# only works with a stale kernel variable and raises NameError on Restart & Run All.
# The resulting set is the atom vocabulary for the one-hot encoding.
unique_nums = set(non_unique_nums)
unique_nums

{1, 6, 7, 8, 9, 14, 16, 17, 35, 53}

In [42]:
# Atomic number -> element symbol for the elements listed in the vocabulary output.
atomic_alphabet = {
    1: "H", 6: "C", 7: "N", 8: "O", 9: "F",
    14: "Si", 16: "S", 17: "Cl", 35: "Br", 53: "I",
}

# Assign each atomic number a dense column index, following the iteration order
# of `unique_nums`. Then reserve one extra, final index for padding rows.
encoded_atom_nums = dict(zip(unique_nums, range(len(unique_nums))))
encoded_atom_nums["pad"] = len(encoded_atom_nums)

In [43]:
# Inspect the atomic-number -> column-index mapping; the final "pad" entry is the padding column.
encoded_atom_nums

{1: 0, 35: 1, 6: 2, 7: 3, 8: 4, 9: 5, 14: 6, 16: 7, 17: 8, 53: 9, 'pad': 10}

In [33]:
# Atom count of the largest molecule in the dataset.
max_atoms = max(molecule.GetNumAtoms() for molecule in molecules)
max_atoms

82

In [47]:
def get_annotation_matrix(molecule, max_size, alphabet):
    """One-hot encode a molecule's atoms into a fixed-size float tensor.

    Row ``r`` corresponds to the molecule's ``r``-th atom. Its single 1 sits in
    column ``alphabet[atomic_number]``. Rows beyond the last atom are marked
    in the final column, which is expected to be the padding entry.

    Args:
        molecule: RDKit-style molecule exposing ``GetNumAtoms()`` and ``GetAtoms()``.
        max_size: number of rows in the result; must be >= the molecule's atom count.
        alphabet: mapping from atomic number to column index. Its length sets the
            column count. NOTE(review): the padding column is assumed to be the
            last one (index ``len(alphabet) - 1``), as with ``encoded_atom_nums``.

    Returns:
        ``torch`` float tensor of shape ``(max_size, len(alphabet))``.

    Raises:
        KeyError: an atom's atomic number is missing from ``alphabet``.
        IndexError: the molecule has more atoms than ``max_size``.
    """
    atom_count = molecule.GetNumAtoms()
    one_hot = torch.zeros((max_size, len(alphabet)))

    for row, atom in enumerate(molecule.GetAtoms()):
        column = alphabet[atom.GetAtomicNum()]
        one_hot[row, column] = 1

    # Every row after the real atoms is flagged in the last (padding) column.
    one_hot[atom_count:, -1] = 1
    return one_hot

In [48]:
# Sanity-check the encoding on the first molecule, padded to the dataset-wide
# maximum computed above rather than a hard-coded 82 that goes stale if the data changes.
get_annotation_matrix(molecules[0], max_atoms, encoded_atom_nums)

tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.,