In [None]:
# Reload imported modules automatically before each cell executes, so edits
# made to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import os
from os.path import isfile, join
import random
from torchtext.data import RawField, Field, TabularDataset, BucketIterator
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from scipy.linalg import fractional_matrix_power


import numpy as np
import pandas as pd

# Fraction of the longest molecule's length (per batch) that postprocess_batch
# uses as the number of atoms to mask in each molecule.
MASKING_RATE = 0.15
# Threshold used by masking_feature: for each feature group of a masked atom,
# a uniform draw below this value erases the group.
ERASE_RATE = 0.5

def get_dir_files(dir_path):
    """Return the names of the regular files located directly in ``dir_path``.

    Sub-directories are skipped and the directory is not searched
    recursively.  The names are returned in sorted order: ``os.listdir``
    yields entries in arbitrary, platform-dependent order, so callers that
    index into the result (e.g. "the first training file") would otherwise
    pick a different file from one machine or run to the next.

    Args:
        dir_path: Path of the directory to list.

    Returns:
        Sorted list of file names (names only, not joined with ``dir_path``).

    Raises:
        OSError: Propagated from ``os.listdir`` when ``dir_path`` cannot be
            listed (e.g. it does not exist).
    """
    return sorted(f for f in os.listdir(dir_path) if isfile(join(dir_path, f)))

# Relative paths to the training and validation data directories.
train_dataset_path = './dataset/processed_zinc_smiles/data_xs/train'
val_dataset_path = './dataset/processed_zinc_smiles/data_xs/val'

# Names of the files found directly inside the training directory.
list_trains = get_dir_files(train_dataset_path)

# Preview the first five rows of the first listed training file (read as CSV).
pd.read_csv(join(train_dataset_path, list_trains[0])).head(5)

In [None]:
def atom_feature(atom):
    """Encode a single atom as a flat feature vector.

    The returned array is the concatenation, in this order, of:

    * column 0       - atom symbol index from ``char_to_ix`` (0 = unknown
                       symbol, otherwise 1-based position in the vocabulary)
    * columns 1-6    - one-hot of ``atom.GetDegree()`` over 0..5
    * columns 7-11   - one-hot of ``atom.GetTotalNumHs()`` over 0..4
    * columns 12-17  - one-hot of ``atom.GetImplicitValence()`` over 0..5
    * column 18      - ``atom.GetIsAromatic()``

    i.e. group widths (1, 6, 5, 6, 1), 19 entries in total.  Values outside
    a one-hot range are mapped to the last slot of that group.

    Args:
        atom: Object exposing ``GetSymbol``, ``GetDegree``, ``GetTotalNumHs``,
            ``GetImplicitValence`` and ``GetIsAromatic`` (presumably an RDKit
            atom, as produced by ``mol.GetAtoms()``).

    Returns:
        1-D ``np.ndarray`` built from the concatenated feature list.
    """
    symbol_vocabulary = [
        'C', 'N', 'O', 'S', 'F', 'H', 'Si', 'P', 'Cl', 'Br',
        'Li', 'Na', 'K', 'Mg', 'Ca', 'Fe', 'As', 'Al', 'I', 'B',
        'V', 'Tl', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
        'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'Mn', 'Cr', 'Pt', 'Hg', 'Pb',
    ]

    symbol_part = char_to_ix(atom.GetSymbol(), symbol_vocabulary)
    degree_part = one_of_k_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
    hydrogen_part = one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
    valence_part = one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5])
    aromatic_part = [atom.GetIsAromatic()]

    return np.array(symbol_part + degree_part + hydrogen_part + valence_part + aromatic_part)

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against ``allowable_set``.

    A value that is not a member of ``allowable_set`` is treated as if it
    were the last element of the set.

    Args:
        x: Value to encode.
        allowable_set: Sequence of admissible values; its last element
            doubles as the fallback slot.

    Returns:
        List of booleans, one per element of ``allowable_set``, that is True
        where the element equals the (possibly substituted) value.
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]

def char_to_ix(x, allowable_set):
    """Map ``x`` to a single-element list holding its vocabulary index.

    Index 0 is reserved as the unknown-atom token; known values receive
    their 1-based position in ``allowable_set``.

    Args:
        x: Value to look up.
        allowable_set: List of known values.

    Returns:
        ``[0]`` when ``x`` is not in ``allowable_set``, otherwise
        ``[position + 1]``.
    """
    token = allowable_set.index(x) + 1 if x in allowable_set else 0
    return [token]

In [None]:
class zincDataset(Dataset):
    """Dataset over one CSV file of molecules.

    The CSV is read with its ``smile``, ``length``, ``logP``, ``mr`` and
    ``tpsa`` columns being accessed per item.
    """

    def __init__(self, data_path, skip_header=True):
        """Load the CSV at ``data_path`` into memory.

        Args:
            data_path: Path of the CSV file to read.
            skip_header: Accepted but not used by this implementation.
        """
        self.data = pd.read_csv(data_path)

    def __len__(self):
        """Number of rows in the loaded table."""
        return self.data.shape[0]

    def __getitem__(self, index):
        """Build the graph representation of the molecule at row ``index``.

        Returns:
            Tuple ``(length, features, adjacency, logP, mr, tpsa)`` where
            ``features`` stacks one ``atom_feature`` vector per atom and
            ``adjacency`` is the molecule's adjacency matrix from RDKit.
        """
        record = self.data.loc[index, :]
        molecule = Chem.MolFromSmiles(record.smile)
        adjacency = Chem.rdmolops.GetAdjacencyMatrix(molecule)
        atom_features = np.array([atom_feature(atom) for atom in molecule.GetAtoms()])

        return record.length, atom_features, adjacency, record.logP, record.mr, record.tpsa

In [None]:
def random_onehot(size):
    """Return a float vector of length ``size`` with a single 1 at a random slot.

    The position of the 1 is drawn uniformly from ``0 .. size - 1`` with
    ``np.random.randint``; every other entry is 0.
    """
    hot_position = np.random.randint(0, size)
    return (np.arange(size) == hot_position).astype(float)

def normalize_adj(mx):
    """Symmetric normalization: ``D^{-1/2} @ mx @ D^{-1/2}``.

    ``D`` is the diagonal matrix of the row sums of ``mx``.  Because ``D`` is
    diagonal, its inverse square root is computed element-wise on the row
    sums instead of through a general matrix power, and the two matrix
    products reduce to scaling row ``i`` and column ``j`` of ``mx`` by
    ``d_i^{-1/2}`` and ``d_j^{-1/2}``.

    Rows whose sum is zero (isolated nodes without a self-loop) get a scaling
    factor of 0, so their rows and columns come out as zeros rather than
    turning the whole result into NaN.

    Args:
        mx: Square 2-D array-like with non-negative row sums (e.g. an
            adjacency matrix, optionally with self-loops added).

    Returns:
        ``np.ndarray`` of floats with the same shape as ``mx``.
    """
    mx = np.asarray(mx, dtype=float)
    degree = mx.sum(axis=1)
    # 0 ** -0.5 evaluates to inf (with a divide warning); silence the warning
    # and zero those entries out right after.
    with np.errstate(divide='ignore'):
        d_inv_sqrt = np.power(degree, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    return d_inv_sqrt[:, None] * mx * d_inv_sqrt[None, :]

def masking_feature(feature, num_masking):
    """Randomly corrupt ``num_masking`` rows of ``feature`` and keep the originals.

    ``num_masking`` distinct row indices are sampled without replacement.
    For each selected row, five independent uniform draws decide what
    happens to the five feature groups (column layout as produced by
    ``atom_feature``: symbol index at column 0, one-hot groups at 1:7, 7:12
    and 12:18, aromatic flag at column 18):

    * draw < ``ERASE_RATE``                      -> the group is zeroed,
    * draw > ``1 - (1 - ERASE_RATE) * 0.5``      -> the group is replaced by a
      random value (random symbol index in 1..40, or a random one-hot),
    * otherwise                                  -> the group is left as is.

    The aromatic flag is the exception: it is flipped (0 <-> 1) when its
    draw is below ``ERASE_RATE`` and untouched otherwise.

    Note that ``feature`` is modified in place; the returned first element
    is the same array object.

    Args:
        feature: 2-D array of per-atom features (rows are atoms).
        num_masking: Number of rows to corrupt; must not exceed the number
            of rows because sampling is without replacement.

    Returns:
        Tuple ``(feature, ground_truth, masking_indices)`` where
        ``ground_truth`` is a copy of the selected rows taken before
        corruption and ``masking_indices`` are the selected row indices.
    """
    masking_indices = np.random.choice(len(feature), num_masking, replace=False)
    ground_truth = np.copy(feature[masking_indices, :])

    # Draws above this value trigger random replacement instead of erasure.
    replace_threshold = 1 - ((1 - ERASE_RATE) * 0.5)
    # One-hot groups as (column slice, width): degree, num Hs, valence.
    # They consume draws 1, 2 and 3 in that order.
    onehot_groups = ((slice(1, 7), 6), (slice(7, 12), 5), (slice(12, 18), 6))

    for atom_idx in masking_indices:
        draws = np.random.rand(5)

        # Atom symbol index (column 0).
        symbol_draw = draws[0]
        if symbol_draw < ERASE_RATE:
            feature[atom_idx, 0] = 0
        elif symbol_draw > replace_threshold:
            feature[atom_idx, 0] = np.random.randint(1, 41)

        # Degree, number of hydrogens and valence one-hot groups.
        for group_draw, (columns, width) in zip(draws[1:4], onehot_groups):
            if group_draw < ERASE_RATE:
                feature[atom_idx, columns] = np.zeros(width)
            elif group_draw > replace_threshold:
                feature[atom_idx, columns] = random_onehot(width)

        # Aromatic flag (column 18): flip rather than erase.
        if draws[4] < ERASE_RATE:
            feature[atom_idx, 18] = (feature[atom_idx, 18] + 1) % 2

    return feature, ground_truth, masking_indices


def postprocess_batch(mini_batch):
    """Collate dataset items into padded, masked numpy batch arrays.

    Every molecule is padded to the length of the longest molecule in the
    batch.  The number of masking slots per molecule is
    ``int(max_length * MASKING_RATE)``.  A molecule with fewer atoms than
    that cannot fill every slot (atoms are sampled without replacement), so
    for such a molecule only as many atoms as it has are masked and the
    remaining slots of ``batch_ground`` / ``batch_masking`` stay zero.

    Args:
        mini_batch: List of items shaped like the output of
            ``zincDataset.__getitem__``:
            ``(length, features, adjacency, logP, mr, tpsa)``.

    Returns:
        Tuple of numpy arrays
        ``(batch_feature, batch_adj, batch_property, batch_ground, batch_masking)``:

        * ``batch_feature`` - (batch, max_length, n_features) int, masked
          atom features, zero padded.
        * ``batch_adj`` - (batch, max_length, max_length) float,
          symmetrically normalized adjacency with self-loops, zero padded.
        * ``batch_property`` - (batch, 3) float, items 3..5 of each row.
        * ``batch_ground`` - (batch, num_masking, n_features) int, original
          features of the masked atoms.
        * ``batch_masking`` - (batch, num_masking) int, indices of the
          masked atoms.
    """
    max_length = max([row[0] for row in mini_batch])
    num_masking = int(max_length * MASKING_RATE)
    batch_length = len(mini_batch)
    num_features = mini_batch[0][1].shape[1]
    batch_feature = np.zeros((batch_length, max_length, num_features), dtype=int)
    batch_adj = np.zeros((batch_length, max_length, max_length))
    batch_property = np.zeros((batch_length, 3))
    batch_ground = np.zeros((batch_length, num_masking, num_features), dtype=int)
    batch_masking = np.zeros((batch_length, num_masking), dtype=int)

    for i, row in enumerate(mini_batch):
        mol_length, feature, adj = row[0], row[1], row[2]
        # num_masking is derived from the longest molecule; clamp it so that
        # sampling without replacement cannot ask for more atoms than this
        # molecule has (which would raise ValueError).
        n_mask = min(num_masking, len(feature))
        masked_feature, ground_truth, masking_indices = masking_feature(feature, n_mask)
        batch_feature[i, :mol_length, :] = masked_feature
        batch_ground[i, :n_mask, :] = ground_truth
        batch_masking[i, :n_mask] = masking_indices
        # Add self-loops before normalizing the adjacency matrix.
        batch_adj[i, :mol_length, :mol_length] = normalize_adj(adj + np.eye(len(adj)))
        batch_property[i, :] = [row[3], row[4], row[5]]

    return batch_feature, batch_adj, batch_property, batch_ground, batch_masking

In [None]:
# Build a dataset over the first listed training file and report its row count.
train_dataset = zincDataset(data_path=join(train_dataset_path, list_trains[0]))
print(len(train_dataset))

In [None]:
# Batches of up to 1000 molecules are collated by postprocess_batch, with
# loading spread over 12 DataLoader workers.
train_dataloader = DataLoader(train_dataset, batch_size=1000, collate_fn=postprocess_batch, num_workers=12)
# Smoke test: iterate over every batch once; the printed value is only a
# per-batch progress marker.
for batch in train_dataloader:
    print(5)