In [4]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [5]:
import os
from os.path import isfile, join
import random
from torchtext.data import RawField, Field, TabularDataset, BucketIterator
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from scipy.linalg import fractional_matrix_power


import numpy as np
import pandas as pd

MASKING_RATE = 0.15
ERASE_RATE = 0.5

def get_dir_files(dir_path):
    list_file = [f for f in os.listdir(dir_path) if isfile(join(dir_path, f))]
    return list_file

train_dataset_path = './dataset/processed_zinc_smiles/data_xs/train'
val_dataset_path = './dataset/processed_zinc_smiles/data_xs/val'

list_trains = get_dir_files(train_dataset_path)

pd.read_csv(join(train_dataset_path, list_trains[0])).head(5)
# a.hist()
# print(a.loc[0, :])

Unnamed: 0,smile,logP,mr,tpsa,length
0,Cc1ccc([N+](=O)[O-])cc1NC(=O)CN(C)Cc1ccsc1,3.03522,87.1161,75.48,22
1,O=c1ccn(CCCOc2ccc([N+](=O)[O-])cc2)c2ccccc12,3.3788,91.4164,74.37,24
2,CCCCOc1ccc(CSc2ncn[nH]2)cc1[N+](=O)[O-],3.1841,79.4431,93.94,21
3,Cn1cnc([N+](=O)[O-])c1N1CCC(=Cc2cccc(F)c2)CC1,3.1512,85.6864,64.2,23
4,CCOc1ccc(C(=O)N[C@H]2C[C@H]3CC[C@]2(C)C3(C)C)cc1,4.0299,88.0942,38.33,22


In [6]:
def atom_feature(atom):
    return np.array(char_to_ix(atom.GetSymbol(),
                              ['C', 'N', 'O', 'S', 'F', 'H', 'Si', 'P', 'Cl', 'Br',
                               'Li', 'Na', 'K', 'Mg', 'Ca', 'Fe', 'As', 'Al', 'I', 'B',
                               'V', 'Tl', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
                               'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'Mn', 'Cr', 'Pt', 'Hg', 'Pb']) +
                    one_of_k_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5]) +
                    one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4]) +
                    one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5]) +
                    [atom.GetIsAromatic()])    # (40, 6, 5, 6, 1)

def one_of_k_encoding_unk(x, allowable_set):
    """Maps inputs not in the allowable set to the last element."""
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

def char_to_ix(x, allowable_set):
    if x not in allowable_set:
        return [0] # Unknown Atom Token
    return [allowable_set.index(x)+1]

In [7]:
class zincDataset(Dataset):
    def __init__(self, data_path, skip_header=True):
        self.data = pd.read_csv(data_path)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.loc[index, :]
        smile = row.smile
        mol = Chem.MolFromSmiles(smile)
        adj = Chem.rdmolops.GetAdjacencyMatrix(mol)
        list_feature = list()
        for atom in mol.GetAtoms():
            list_feature.append(atom_feature(atom))
        
        return row.length, np.array(list_feature), adj, row.logP, row.mr, row.tpsa

In [11]:
def random_onehot(size):
    """ Generate random one-hot encoding vector with given size. """
    temp = np.zeros(size)
    temp[np.random.randint(0, size)] = 1
    return temp 

def normalize_adj(mx):
    """ Symmetry Normalization """
    rowsum = np.diag(np.array(mx.sum(1)))
    r_inv = fractional_matrix_power(rowsum, -0.5)
    r_inv[np.isinf(r_inv)] = 0.
    return r_inv.dot(mx).dot(r_inv)

def masking_feature(feature, num_masking):
    """ Given feature, select 'num_masking' node feature and perturbate them.
    
        [5 features : Atom symbol, degree, num Hs, valence, isAromatic]  
        were masked with zero or changed with random one-hot encoding 
        or remained with origianl data(but still should be predicted).
        
        Masking process was conducted on each feature indiviually. 
        For example, if ERASE_RATE = 0.5, probability for all feature information with zero is 0.5^5 = 0.03125
        
        return original hode feature with their corresponding indices
    """
    
    masking_indices = np.random.choice(len(feature), num_masking, replace=False)
    ground_truth = np.copy(feature[masking_indices, :])
    for i in masking_indices:
        prob_masking = np.random.rand(5)
        # Masking Atom Symbol 
        if prob_masking[0] < ERASE_RATE:
            feature[i, 0] = 0
        elif prob_masking[0] > 1- ((1-ERASE_RATE) * 0.5):
            feature[i, 0] = np.random.randint(1, 41)
            
        # Masking Degree 
        if prob_masking[1] < ERASE_RATE:
            feature[i, 1:7] = np.zeros(6)
        elif prob_masking[1] > 1- ((1-ERASE_RATE) * 0.5):
            feature[i, 1:7] =  random_onehot(6)
        
        # Masking Num Hs
        if prob_masking[2] < ERASE_RATE:
            feature[i, 7:12] = np.zeros(5)
        elif prob_masking[2] > 1- ((1-ERASE_RATE) * 0.5):
            feature[i, 7:12] =  random_onehot(5)
            
        # Masking Valence
        if prob_masking[3] < ERASE_RATE:
            feature[i, 12:18] = np.zeros(6)
        elif prob_masking[3] > 1- ((1-ERASE_RATE) * 0.5):
            feature[i, 12:18] =  random_onehot(6)
            
        # Masking IsAromatic
        if prob_masking[4] < ERASE_RATE:
            feature[i, 18] = (feature[i, 18]+1)%2

    return feature, ground_truth, masking_indices


def postprocess_batch(mini_batch):
    """ Given mini-batch sample, adjacency matrix and node feature vectors were padded with zero. """
    max_length = max([row[0] for row in mini_batch])
    min_length = min([row[0] for row in mini_batch])
    print(min_length, max_length)
    num_masking = int(max_length * MASKING_RATE)
    batch_length = len(mini_batch)
    batch_feature = np.zeros((batch_length, max_length, mini_batch[0][1].shape[1]), dtype=int)
    batch_adj = np.zeros((batch_length, max_length, max_length))
    batch_property = np.zeros((batch_length, 3))
    batch_ground = np.zeros((batch_length, num_masking, mini_batch[0][1].shape[1]), dtype=int)
    batch_masking = np.zeros((batch_length, num_masking), dtype=int)
    
    for i, row in enumerate(mini_batch):
        mol_length, feature, adj = row[0], row[1], row[2]
        masked_feature, ground_truth, masking_indices  = masking_feature(feature, num_masking)
        batch_feature[i, :mol_length, :] = masked_feature
        batch_ground[i, :, :] = ground_truth
        batch_masking[i, :] = masking_indices
        batch_adj[i, :mol_length, :mol_length] = normalize_adj(adj+np.eye(len(adj)))
        batch_property[i, :] = [row[3], row[4], row[5]]
        
    return batch_feature, batch_adj, batch_property, batch_ground, batch_masking

In [12]:
train_dataset = zincDataset(data_path=join(train_dataset_path, list_trains[0]))
print(len(train_dataset))

424114


In [13]:
train_dataloader = DataLoader(train_dataset, batch_size=1000, collate_fn=postprocess_batch, num_workers=12)
# print(next(train_dataloader.__iter__())[3])
for batch in train_dataloader:
#     print(batch[0])
    print(5)

12 35
17 26
17 28
18 26
12 19
16 26
13 37
13 41
17 26
19 27
18 34
17 36
5
5
5
5
5
18 31
5
5
5
5
5
5
5
13 31
13 27
17 24
13 51
17 24
12 36
16 27
20 28
20 25
13 23
17 36
5
5
5
5
5
5
5
5
5
5
5
5
17 32
17 31
18 31
20 31
19 31


Process Process-15:
Process Process-22:
Process Process-20:
Process Process-23:
Process Process-21:
Process Process-18:
Process Process-13:
Process Process-19:
Process Process-14:
Process Process-24:
Process Process-16:
Traceback (most recent call last):
Process Process-17:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/process.py", line 93, in run
   

Traceback (most recent call last):
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2963, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-13-6942e66905a8>", line 3, in <module>
    for batch in train_dataloader:
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 280, in __next__
    idx, batch = self._get_batch()
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 259, in _get_batch
    return self.data_queue.get()
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6

  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/scipy/linalg/matfuncs.py", line 138, in fractional_matrix_power
    return scipy.linalg._matfuncs_inv_ssq._fractional_matrix_power(A, t)
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])


KeyboardInterrupt: 

  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 57, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-7-73b881655bda>", line 17, in __getitem__
    return row.length, np.array(list_feature), adj, row.logP, row.mr, row.tpsa
KeyboardInterrupt
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/scipy/linalg/matfuncs.py", line 138, in fractional_matrix_power
    return scipy.linalg._matfuncs_inv_ssq._fractional_matrix_power(A, t)
  File "<ipython-input-7-73b881655bda>", line 9, in __getitem__
    row = self.data.loc[index, :]
  File "<ipython-input-11-5cee214a9a13>", line 10, in normalize_adj
    r_inv = fractional_matrix_power(rowsum, -0.5)
  File "<ipython-input-7-73b881655bda>", line 17, in __getitem__
    return row.length, np.array(list_feature), adj, row.logP, row.mr, row.tpsa
  File "/home/jaeyoung/anaconda3/envs/pytorch/lib/python3.6/site-packages/t