In [14]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import rdkit 
import multiprocessing
import copy
import math 
import random
import pickle 
import utils 
import model 
from utils import parallel_f, get_mol, replace_atom, tokenize, pad, MyDataset, get_dic
from model import TransformerVAE, device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   

In [None]:
def get_smi_list(path) :
    with open(path, 'r') as file :
        return [smi[:-1] for smi in file.readlines()]
    

def replace_atom(smi) :
    return smi.replace('Cl', 'L').replace('Br', 'R') 

def get_mol(smi) :
    return rdkit.Chem.MolFromSmiles(smi)

def parallel_f(f, input_list) :
    pool = multiprocessing.Pool()
    return pool.map(f, input_list)

def get_dic(smi_list) :
    dic = {'<START>': 0, '<END>': 1, '<PAD>': 2}
    for smi in smi_list :
        for char in smi :
            if char not in dic :
                dic[char] = len(dic) 
    return dic 

def tokenize(smi) :
    return [0] + [smi_dic[char] for char in smi] + [1]

def pad(smi) :
    return smi + [2] * (max_len - len(smi))

def clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def subsequent_mask(size):
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )
    return subsequent_mask == 0


def get_mask(target) :
    mask = (target != smi_dic['<PAD>']).unsqueeze(-2)
    return mask & subsequent_mask(target.size(-1)).type_as(mask.data)


class MyDataset(torch.utils.data.Dataset) :
    def __init__(self, token_list) :
        self.token_list = token_list

    def __len__(self) :
        return len(self.token_list)

    def __getitem__(self, idx) :   
        return torch.tensor(self.token_list[idx], dtype=torch.long)

In [7]:
# Load ChEMBL dataset

with open('data/chembl24_canon_train.pickle', 'rb') as file :
    smi_list = pickle.load(file) 


smi_list = [smi for smi in smi_list if len(smi) < 40] # Choose only smiles with length < 40

print(f'Number of data: {len(smi_list)}')

Number of data: 373814


In [9]:
# Preprocess data
# parallel_f is a function for multiprocessing on CPU to speed up the process

mol_list = parallel_f(get_mol, smi_list)

smi_list = parallel_f(replace_atom, smi_list)

smi_dic = get_dic(smi_list)


inv_dic = {v:k for k, v in smi_dic.items()}

token_list = parallel_f(tokenize, smi_list)
max_len = len(max(token_list, key=len))
token_list = parallel_f(pad, token_list)


In [10]:
BATCH_SIZE = 64

dataset = MyDataset(token_list)
train_set, val_set = random_split(dataset, [0.95, 0.05])
train_loader = DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_set, batch_size = BATCH_SIZE, shuffle = True)

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
    

    
class Attention(nn.Module) :
    def __init__(self, dim_model, num_head) :
        super(Attention, self).__init__()
        assert dim_model % num_head == 0, 'dim_model % num_head != 0'
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_head = dim_model // num_head

        self.Q = nn.Linear(dim_model, dim_model)
        self.K = nn.Linear(dim_model, dim_model)
        self.V = nn.Linear(dim_model, dim_model)

        self.out = nn.Linear(dim_model, dim_model)

    def forward(self, Q, K, V, mask = None) :
        B = Q.size(0) 

        Q, K, V = self.Q(Q), self.K(K), self.V(V)

        len_Q, len_K, len_V = Q.size(1), K.size(1), V.size(1)

        Q = Q.reshape(B, self.num_head, len_Q, self.dim_head)
        K = K.reshape(B, self.num_head, len_K, self.dim_head)
        V = V.reshape(B, self.num_head, len_V, self.dim_head)
        
        K_T = K.transpose(2,3).contiguous()

        attn_score = Q @ K_T

        attn_score = attn_score / (self.dim_head ** 1/2)
        if mask is not None :
            attn_score = attn_score.masked_fill(mask == 0, -1e9)

        attn_distribution = torch.softmax(attn_score, dim = -1)

        attn = attn_distribution @ V

        attn = attn.reshape(B, len_Q, self.num_head * self.dim_head)
        
        attn = self.out(attn)

        return attn, attn_distribution
    

class FirstEncoder(nn.Module) :
    def __init__(self, dim_model, dim_latent, num_head, dropout, vocab_size) :
        super(FirstEncoder, self).__init__()
        self.dim_model = dim_model

        self.embed = nn.Embedding(vocab_size, dim_model)
        self.pos = PositionalEncoding(dim_model, dropout) 

        self.norm1 = nn.LayerNorm(dim_model)
        self.drop1 = nn.Dropout(dropout)
        self.self_attn = Attention(dim_model, num_head) 

        self.norm2 = nn.LayerNorm(dim_model)
        self.feed_foward = nn.Sequential(
            nn.Linear(dim_model, dim_model),
            nn.LeakyReLU(),
            nn.Dropout(),
            nn.Linear(dim_model, dim_model)
        )
        self.drop2 = nn.Dropout(dropout)

        self.norm3 = nn.LayerNorm(dim_model)

        self.mu, self.sigma = nn.Linear(dim_model, dim_latent), nn.Linear(dim_model, dim_latent)
        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.cuda() 
        self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x) :
        x = self.embed(x) * (self.dim_model ** 0.5) 
        x = self.pos(x)

        x = self.norm1(x)
        attn, self_attn = self.self_attn(x, x, x)
        x = x + self.drop1(attn)

        x = self.norm2(x)
        x = self.feed_foward(x)
        x = x + self.drop2(x)   

        x = self.norm3(x)

        mu, sigma = self.mu(x), torch.exp(self.sigma(x))
        z = mu + sigma * self.N.sample(mu.shape)
        self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1/2).sum()

        return z



class SecondEncoder(nn.Module) :
    def __init__(self, dim_model,  dim_latent, num_head, dropout) :
        super(SecondEncoder, self).__init__()

        self.norm1 = nn.LayerNorm(dim_latent)
        self.drop1 = nn.Dropout(dropout)
        self.self_attn = Attention(dim_latent, num_head) 

        self.norm2 = nn.LayerNorm(dim_latent)
        self.feed_foward = nn.Sequential(
            nn.Linear(dim_latent, dim_model),
            nn.LeakyReLU(),
            nn.Dropout(),
            nn.Linear(dim_model, dim_model)
        )
        self.drop2 = nn.Dropout(dropout)

        self.norm3 = nn.LayerNorm(dim_model)
    
    def forward(self, z) :
        z = self.norm1(z)
        attn, self_attn = self.self_attn(z, z, z)
        z = z + self.drop1(attn)

        z = self.norm2(z)
        z = self.feed_foward(z)
        z = z + self.drop2(z)   

        z = self.norm3(z)

        return z
class Decoder(nn.Module) :
    def __init__(self, dim_model, num_head, dropout, vocab_size) :
        super(Decoder, self).__init__()

        self.dim_model = dim_model
        
        self.embed = nn.Embedding(vocab_size, dim_model)
        self.pos = PositionalEncoding(dim_model, dropout)

        self.norm1 = nn.LayerNorm(dim_model)   
        self.self_attn = Attention(dim_model, num_head)
        self.drop1 = nn.Dropout(dropout)

        self.norm2 = nn.LayerNorm(dim_model)
        self.cross_attn = Attention(dim_model, num_head)
        self.drop2 = nn.Dropout(dropout)

        self.norm3 = nn.LayerNorm(dim_model)
        self.feed_foward = nn.Sequential(
            nn.Linear(dim_model, dim_model),
            nn.LeakyReLU(),
            nn.Dropout(),
            nn.Linear(dim_model, dim_model),
            nn.LeakyReLU()
        )
        self.drop3 = nn.Dropout(dropout)
        
        self.norm4 = nn.LayerNorm(dim_model)

        self.proj = nn.Linear(dim_model, vocab_size)

    def forward(self, memory, target) :
        mask = get_mask(target)
        mask = mask.unsqueeze(1).to(device)

        target = self.embed(target) * (self.dim_model ** 0.5)
        target = self.pos(target)

        target = self.norm1(target)
        attn, self_attn = self.self_attn(target, target, target, mask)
        target = target + self.drop1(attn)

        target = self.norm2(target)
        attn, cross_attn = self.cross_attn(target, memory, memory)
        target = target + self.drop2(attn)

        target = self.norm3(target)
        target = self.feed_foward(target)
        target = target + self.drop3(target)

        target = self.norm4(target)

        target = self.proj(target)
        target = F.log_softmax(target, dim = -1)

        return target
    

class TransformerVAE(nn.Module) :
    def __init__(self, dim_model, dim_latent, num_head, dropout, vocab_size) :
        super(TransformerVAE, self).__init__()

        self.first_encoder = FirstEncoder(dim_model, dim_latent, num_head, dropout, vocab_size)
        self.pos = PositionalEncoding(dim_latent, dropout)
        self.second_encoder = SecondEncoder(dim_model, dim_latent, num_head, dropout)
        self.decoder = Decoder(dim_model, num_head, dropout, vocab_size)

    def forward(self, x, target) :
        z = self.first_encoder(x) 
        z = self.pos(z)

        memory = self.second_encoder(z)
        target = self.decoder(memory, target)

        return target 

In [12]:
model = TransformerVAE(256, 64, 8, 0.5, len(smi_dic)).to(device)

loss_fn = nn.NLLLoss()
optim = torch.optim.Adam(model.parameters(), lr = 0.001) 

NUM_EPOCH = 1

In [13]:
for epoch in range(1, NUM_EPOCH + 1) :
    train_loss = 0
    val_loss = 0
    model.train()
    for input in train_loader :
        input = input.to(device)
        output = model(input, input[:, :-1])

        loss = loss_fn(output.reshape(-1, len(smi_dic)), input[:, 1:].reshape(-1)) + model.first_encoder.kl
        train_loss += loss.item()
        loss.backward()
        optim.step()
        optim.zero_grad()

    # model.eval()
    # for input in val_loader :
    #     input = input.to(device)
    #     output = model(input, input[:, :-1])
    #     loss = loss_fn(output.reshape(-1, len(smi_dic)), input[:, 1:].reshape(-1)) + model.first_encoder.kl
    #     val_loss += loss.item()
    print(f'epoch : {epoch}, train loss : {train_loss / len(train_loader)} val loss : {val_loss / len(val_loader)}')


NameError: name 'smi_dic' is not defined

In [27]:
model.eval()
encoder = model.second_encoder
decoder = model.decoder


for _ in range(30) :
    z = torch.randn(1, max_len, 128).to(device)
    target = torch.zeros(1, 1, dtype=torch.long).to(device)

    for i in range(max_len - 1) :
        memory = encoder(z) 
        out = decoder(memory, target)
        _, idx = torch.topk(out, 1, dim = -1)
        idx = idx[:, -1, :]
        target = torch.cat([target, idx], dim = 1)

    target = target.squeeze(0).tolist()
    smiles = ''.join([inv_dic[i] for i in target])
    print(smiles)



<START>CCC(C(=CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCC(=O)NC(=O)NC(=O)NC(C)C(C)C)cc1)cccc1)C(=O)NCCC(C(=O)NCCCCC(CCC)C)C)CCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCC(=O)NC(=O)NC(=O)NC(C)C)cc1)C(=O)NC(CC)C)C)cc1ccc1<END><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD><PAD

<START>CCC(=O)cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCC(=O)NC(=O)NC(=O)NC(C)C(C)C)cc1)cccc1)C(=O)NCCCC(CCC)C)CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCC(F)c1cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc
<START>CCC(C(=CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC
<START>CCCCCCCCCCCC

In [88]:
a = get_mol('C(=O)NC(=O)NC(=O)NC(=O)NC(=O)NC(=O)N')