In [83]:
# Enable the autoreload extension; mode 2 picks up edits to imported modules
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [126]:

import os
from os.path import isfile, join
from torchtext.data import RawField, Field, TabularDataset, BucketIterator
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem

import numpy as np
import pandas as pd

def get_dir_files(dir_path):
    """Return the names of the regular files directly inside ``dir_path``.

    Sub-directories are skipped and the scan is not recursive. Names are
    returned without the directory prefix and sorted, so the result (and
    anything indexed from it, such as ``files[0]``) is identical on every
    run; ``os.listdir`` on its own yields entries in arbitrary order.

    Args:
        dir_path: Path of the directory to scan.

    Returns:
        list: Sorted file names found in ``dir_path``.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. the directory is missing).
    """
    return sorted(f for f in os.listdir(dir_path) if isfile(join(dir_path, f)))

# Directories (relative to the notebook) that are scanned for dataset files.
train_dataset_path = './dataset/processed_zinc_smiles/data_xs/train'
val_dataset_path = './dataset/processed_zinc_smiles/data_xs/val'

# Names of the regular files sitting directly in the training directory.
list_trains = get_dir_files(train_dataset_path)

# Preview the first five rows of the first listed training file.
pd.read_csv(join(train_dataset_path, list_trains[0])).head()

Unnamed: 0,smile,logP,mr,tpsa,length
0,Cc1ccc([N+](=O)[O-])cc1NC(=O)CN(C)Cc1ccsc1,3.03522,87.1161,75.48,22
1,O=c1ccn(CCCOc2ccc([N+](=O)[O-])cc2)c2ccccc12,3.3788,91.4164,74.37,24
2,CCCCOc1ccc(CSc2ncn[nH]2)cc1[N+](=O)[O-],3.1841,79.4431,93.94,21
3,Cn1cnc([N+](=O)[O-])c1N1CCC(=Cc2cccc(F)c2)CC1,3.1512,85.6864,64.2,23
4,CCOc1ccc(C(=O)N[C@H]2C[C@H]3CC[C@]2(C)C3(C)C)cc1,4.0299,88.0942,38.33,22


In [120]:
def atom_feature(atom):
    """Encode a single atom as a flat numeric feature vector.

    Layout of the returned array (1 + 6 + 5 + 6 + 1 = 19 entries):
        [0]      element index from ``char_to_ix``: 0 for a symbol that is not
                 in the vocabulary below, otherwise its 1-based position
        [1:7]    one-hot of ``atom.GetDegree()`` over 0..5
        [7:12]   one-hot of ``atom.GetTotalNumHs()`` over 0..4
        [12:18]  one-hot of ``atom.GetImplicitValence()`` over 0..5
        [18]     ``atom.GetIsAromatic()`` flag

    Values outside a one-hot range fall into that range's last slot (see
    ``one_of_k_encoding_unk``).

    Args:
        atom: Atom object exposing ``GetSymbol``, ``GetDegree``,
            ``GetTotalNumHs``, ``GetImplicitValence`` and ``GetIsAromatic``.

    Returns:
        np.ndarray: The concatenated feature vector.
    """
    # Element vocabulary; the position in this list (plus one) is the element index.
    symbol_vocab = [
        'C', 'N', 'O', 'S', 'F', 'H', 'Si', 'P', 'Cl', 'Br',
        'Li', 'Na', 'K', 'Mg', 'Ca', 'Fe', 'As', 'Al', 'I', 'B',
        'V', 'Tl', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
        'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'Mn', 'Cr', 'Pt', 'Hg', 'Pb',
    ]

    encoded = []
    encoded.extend(char_to_ix(atom.GetSymbol(), symbol_vocab))
    encoded.extend(one_of_k_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5]))
    encoded.extend(one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]))
    encoded.extend(one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]))
    encoded.append(atom.GetIsAromatic())
    return np.array(encoded)

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set`` with a catch-all slot.

    A value that does not occur in ``allowable_set`` is treated as if it were
    the set's last element, so the final position doubles as the "unknown"
    bucket.

    Args:
        x: Value to encode.
        allowable_set: Ordered sequence of the recognised values.

    Returns:
        list: One boolean per entry of ``allowable_set``; ``True`` where the
        entry equals the (possibly substituted) value.
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]

def char_to_ix(x, allowable_set):
    """Map ``x`` to a one-element list holding its 1-based vocabulary index.

    Index 0 is reserved as the unknown-atom token and is returned for any
    value that is not present in ``allowable_set``.

    Args:
        x: Value to look up.
        allowable_set: Ordered vocabulary; must support ``in`` and ``.index``.

    Returns:
        list: ``[0]`` for an unknown value, otherwise ``[position + 1]``.
    """
    is_known = x in allowable_set
    return [allowable_set.index(x) + 1] if is_known else [0]

In [142]:
class zincDataset(Dataset):
    """Map-style dataset that turns rows of a SMILES CSV into graph samples.

    The CSV is expected to provide the columns ``smile``, ``logP``, ``mr`` and
    ``tpsa``; those are the only fields read by ``__getitem__``.
    """

    def __init__(self, data_path, skip_header=True):
        """Load the whole CSV at ``data_path`` into memory.

        Args:
            data_path: Path of the CSV file to read.
            skip_header: Currently unused; kept so existing callers keep
                working. ``pd.read_csv`` already treats the first row as the
                header.
        """
        self.data = pd.read_csv(data_path)

    def __len__(self):
        """Return the number of rows (molecules) in the loaded table."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the sample stored at positional ``index``.

        Args:
            index: Zero-based row position.

        Returns:
            tuple: ``(list_feature, adj, logP, mr, tpsa)`` where
            ``list_feature`` holds one ``atom_feature`` vector per atom and
            ``adj`` is the molecule's adjacency matrix from RDKit.

        Raises:
            IndexError: If ``index`` is out of range. This is what lets plain
                ``for sample in dataset`` iteration stop at the end.
            ValueError: If the row's SMILES string cannot be parsed.
        """
        # Positional lookup: independent of the frame's index labels and
        # raises IndexError (not KeyError) past the end.
        row = self.data.iloc[index]
        smile = row.smile

        mol = Chem.MolFromSmiles(smile)
        if mol is None:
            # RDKit signals an unparsable SMILES by returning None; fail here
            # with context instead of deep inside the adjacency computation.
            raise ValueError(
                "Could not parse SMILES at index {}: {!r}".format(index, smile)
            )

        adj = Chem.rdmolops.GetAdjacencyMatrix(mol)
        list_feature = [atom_feature(atom) for atom in mol.GetAtoms()]

        return list_feature, adj, row.logP, row.mr, row.tpsa

In [143]:
# Wrap the first listed training file in a dataset object.
train_dataset = zincDataset(
    data_path=os.path.join(train_dataset_path, list_trains[0]),
)

In [144]:
def padding_A(x):
    """Placeholder collate function: print the incoming batch and return nothing.

    No padding is performed; the function only exposes what it is handed when
    used as a ``collate_fn``.

    Args:
        x: The batch passed in by the caller.

    Returns:
        None
    """
    print(x)
    return None

In [145]:
# Batch two samples at a time and hand each raw batch to padding_A.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    collate_fn=padding_A,
)
# Fetch a single batch to inspect what the collate function receives and returns.
print(next(iter(train_dataloader)))

[([array([1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]), array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), array([1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]), array([1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]), array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), array([2, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]), array([3, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]), array([3, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]), array([1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]), array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]), array([2, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]), array([1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]), array([3, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]), array([1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0]), array([2, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]), array([