In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

class MolecularVAE(nn.Module):
    def __init__(self):
        super(MolecularVAE, self).__init__()

        self.conv_1 = nn.Conv1d(120, 9, kernel_size=9)
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)
        self.linear_0 = nn.Linear(80, 435)
        self.linear_1 = nn.Linear(435, 292)
        self.linear_2 = nn.Linear(435, 292)

        self.linear_3 = nn.Linear(292, 292)
        self.gru = nn.GRU(292, 501, 3, batch_first=True)
        self.linear_4 = nn.Linear(501, 34)

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()

    def encode(self, x):
        x = self.relu(self.conv_1(x))
        x = self.relu(self.conv_2(x))
        x = self.relu(self.conv_3(x))
        x = x.view(x.size(0), -1)
        x = F.selu(self.linear_0(x))
        return self.linear_1(x), self.linear_2(x)

    def sampling(self, z_mean, z_logvar):
        epsilon = 1e-2 * torch.randn_like(z_logvar)
        return torch.exp(0.5 * z_logvar) * epsilon + z_mean

    def decode(self, z):
        z = F.selu(self.linear_3(z))
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, 120, 1)
        output, hn = self.gru(z)
        out_reshape = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_reshape), dim=1)
        y = y0.contiguous().view(output.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        z_mean, z_logvar = self.encode(x)
        z = self.sampling(z_mean, z_logvar)
        return self.decode(z), z_mean, z_logvar

In [3]:
import pandas as pd

df = pd.read_csv('250k_rndm_zinc_drugs_clean_3.csv')
df["smiles"] = df["smiles"].str.rstrip("\n")
charset = set("".join(df["smiles"].values.tolist()))

In [4]:
import numpy as np

def one_hot_encode_smiles(smiles, charset, max_length=120):
    char_to_int = dict((c, i) for i, c in enumerate(charset))
    integer_encoded = [char_to_int[char] for char in smiles]
    if len(integer_encoded) > max_length:
        integer_encoded = integer_encoded[:max_length]
    else:
        integer_encoded = integer_encoded + [0] * (max_length - len(integer_encoded))
    onehot_encoded = np.zeros((max_length, len(charset)), dtype=np.float32)
    for i, val in enumerate(integer_encoded):
        onehot_encoded[i, val] = 1.0

    return onehot_encoded

In [5]:
x = one_hot_encode_smiles(df["smiles"][0], charset)
x

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [6]:
conv_1 = nn.Conv1d(120, 9, kernel_size=9)
conv_2 = nn.Conv1d(9, 9, kernel_size=9)
conv_3 = nn.Conv1d(9, 10, kernel_size=11)
linear_0 = nn.Linear(80, 435)
linear_1 = nn.Linear(435, 292)
linear_2 = nn.Linear(435, 292)

In [7]:
x = one_hot_encode_smiles(df["smiles"][0], charset)
x = torch.tensor([x])
print(x.shape)
x = nn.ReLU()(conv_1(x))
print(x.shape)
x = nn.ReLU()(conv_2(x))
print(x.shape)
x = nn.ReLU()(conv_3(x))
print(x.shape)
x = x.view(x.size(0), -1)
x = F.selu(linear_0(x))
z_mean = linear_1(x)
z_logvar = linear_2(x)

torch.Size([1, 120, 34])
torch.Size([1, 9, 26])
torch.Size([1, 9, 18])
torch.Size([1, 10, 8])


  x = torch.tensor([x])


In [8]:
z_mean.shape, z_logvar.shape

(torch.Size([1, 292]), torch.Size([1, 292]))

In [26]:
epsilon = 1e-2 * torch.randn_like(z_logvar)
z = torch.exp(0.5 * z_logvar) * epsilon + z_mean

In [27]:
z.shape

torch.Size([1, 292])

In [28]:
linear_3 = nn.Linear(292, 292)
gru = nn.GRU(292, 501, 3, batch_first=True)
linear_4 = nn.Linear(501, 33)

z = F.selu(linear_3(z))
print(z.shape)
z = z.view(z.size(0), 1, z.size(-1)).repeat(1, 120, 1)
print(z)
output, hn = gru(z)
print(output.shape)
out_reshape = output.contiguous().view(-1, output.size(-1))
print(out_reshape.shape)
y0 = F.softmax(linear_4(out_reshape), dim=1)
y = y0.contiguous().view(output.size(0), -1, y0.size(-1))


torch.Size([1, 292])
tensor([[[ 0.0413,  0.0164, -0.0519,  ..., -0.0442,  0.0355,  0.0573],
         [ 0.0413,  0.0164, -0.0519,  ..., -0.0442,  0.0355,  0.0573],
         [ 0.0413,  0.0164, -0.0519,  ..., -0.0442,  0.0355,  0.0573],
         ...,
         [ 0.0413,  0.0164, -0.0519,  ..., -0.0442,  0.0355,  0.0573],
         [ 0.0413,  0.0164, -0.0519,  ..., -0.0442,  0.0355,  0.0573],
         [ 0.0413,  0.0164, -0.0519,  ..., -0.0442,  0.0355,  0.0573]]],
       grad_fn=<RepeatBackward0>)
torch.Size([1, 120, 501])
torch.Size([120, 501])


In [14]:
y.shape

torch.Size([1, 120, 33])

In [29]:
def decode_smiles_from_one_hot(one_hot_encoded, charset):

    int_to_char = {i: c for i, c in enumerate(charset)}
    integer_decoded = np.argmax(one_hot_encoded, axis=1)
    chars = [int_to_char[idx] for idx in integer_decoded]
    smiles = ''.join(chars).rstrip()

    return smiles

In [34]:
smiles = decode_smiles_from_one_hot(y[0].detach().numpy(), charset)
smiles

'PP5555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555555'

In [48]:
model = MolecularVAE()
input_smiles = df["smiles"][0]
x = one_hot_encode_smiles(df["smiles"][0], charset)
x = torch.tensor([x])
y, z_mean, z_logvar = model(x)
output_smiles = decode_smiles_from_one_hot(y[0].detach().numpy(), charset)
print(f"Input:{input_smiles}")
print(f"Ouptput: {output_smiles}")

Input:CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1
Ouptput: rrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrrr
