In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np

In [None]:
def one_hot_encode_smiles(smiles, charset, max_length=120):
    """One-hot encode a SMILES string into a fixed-size matrix.

    Args:
        smiles: String to encode; every character must occur in ``charset``.
        charset: Ordered vocabulary; a character's position is its column.
        max_length: Number of rows in the result. Longer strings are cut off,
            shorter ones are padded with column 0.

    Returns:
        A ``float32`` array of shape ``(max_length, len(charset))`` with a
        single 1.0 per row.

    Raises:
        KeyError: If ``smiles`` contains a character missing from ``charset``.
    """
    column_of = {symbol: position for position, symbol in enumerate(charset)}

    # Every character is looked up before truncation, so an unknown symbol
    # raises even when it lies beyond max_length.
    columns = [column_of[symbol] for symbol in smiles][:max_length]
    columns.extend([0] * (max_length - len(columns)))

    matrix = np.zeros((max_length, len(charset)), dtype=np.float32)
    # Set one cell per row in a single vectorised assignment.
    matrix[np.arange(max_length), np.asarray(columns, dtype=np.intp)] = 1.0
    return matrix

In [None]:
def decode_smiles_from_one_hot(one_hot_encoded, charset):
    """Turn a (positions x vocabulary) score matrix back into a string.

    Args:
        one_hot_encoded: 2-D array; for each row the column with the largest
            value selects the character.
        charset: Ordered vocabulary matching the columns of the matrix.

    Returns:
        The decoded characters joined together, with trailing whitespace
        removed.
    """
    symbols = list(charset)
    winning_columns = np.argmax(one_hot_encoded, axis=1)
    return ''.join(symbols[column] for column in winning_columns).rstrip()

In [None]:
class SMILESDataset(Dataset):
    """Dataset that serves SMILES strings together with their one-hot encoding.

    Args:
        smiles_list: Indexable collection of SMILES strings.
        charset: Ordered vocabulary passed to ``one_hot_encode_smiles``.
        max_length: Number of positions every encoded molecule is padded or
            truncated to. Defaults to 120, the encoder's own default.
    """

    def __init__(self, smiles_list, charset, max_length=120):
        self.smiles_list = smiles_list
        self.charset = charset
        self.max_length = max_length

    def __len__(self):
        """Return the number of SMILES strings in the dataset."""
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Return ``(encoded, smiles)`` for the item at ``idx``.

        ``encoded`` is a float tensor of shape ``(max_length, len(charset))``
        and ``smiles`` is the original, unmodified string.
        """
        smiles = self.smiles_list[idx]
        encoded_smiles = one_hot_encode_smiles(
            smiles, self.charset, max_length=self.max_length
        )
        return torch.FloatTensor(encoded_smiles), smiles

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

class MolecularVAE(nn.Module):
    """Convolutional-encoder / GRU-decoder variational autoencoder for one-hot SMILES.

    Inputs are tensors of shape ``(batch, max_length, charset_size)``; the
    string positions are used as the ``Conv1d`` channels and the convolution
    slides along the character axis.

    Args:
        max_length: Number of positions per encoded molecule (default 120).
        charset_size: Size of the character vocabulary (default 35). Must be
            greater than 26 so the convolution stack leaves at least one column.

    Raises:
        ValueError: If ``charset_size`` is too small for the convolution stack.
    """

    def __init__(self, max_length=120, charset_size=35):
        super().__init__()
        self.max_length = max_length
        self.charset_size = charset_size

        # Each unpadded convolution shortens the character axis by
        # kernel_size - 1, i.e. 8 + 8 + 10 = 26 columns in total.
        conv_out_length = charset_size - 26
        if conv_out_length <= 0:
            raise ValueError(
                f"charset_size must be greater than 26, got {charset_size}"
            )

        # Encoder. Layers are created in the same order as before so that a
        # seeded initialisation yields the same weights.
        self.conv_1 = nn.Conv1d(max_length, 9, kernel_size=9)
        self.conv_2 = nn.Conv1d(9, 9, kernel_size=9)
        self.conv_3 = nn.Conv1d(9, 10, kernel_size=11)
        # conv_3 emits 10 channels, so the flattened size is 10 * remaining columns
        # (90 for the default vocabulary of 35).
        self.linear_0 = nn.Linear(10 * conv_out_length, 435)
        self.linear_1 = nn.Linear(435, 292)  # latent mean
        self.linear_2 = nn.Linear(435, 292)  # latent log-variance

        # Decoder.
        self.linear_3 = nn.Linear(292, 292)
        self.gru = nn.GRU(292, 501, 3, batch_first=True)
        self.linear_4 = nn.Linear(501, charset_size)

        self.relu = nn.ReLU()
        # Not used by forward(); kept so the attribute remains available. The
        # explicit dim avoids the implicit-dimension deprecation warning.
        self.softmax = nn.Softmax(dim=-1)

    def encode(self, x):
        """Map ``x`` of shape ``(batch, max_length, charset_size)`` to ``(z_mean, z_logvar)``.

        Both returned tensors have shape ``(batch, 292)``.
        """
        x = self.relu(self.conv_1(x))
        x = self.relu(self.conv_2(x))
        x = self.relu(self.conv_3(x))
        x = x.view(x.size(0), -1)
        x = F.selu(self.linear_0(x))
        return self.linear_1(x), self.linear_2(x)

    def sampling(self, z_mean, z_logvar):
        """Draw a latent sample via the reparameterisation trick.

        NOTE(review): the noise is scaled by 1e-2, while the KL term in
        ``vae_loss`` has the form used for a unit-variance prior. Confirm the
        reduced noise scale is intentional.
        """
        epsilon = 1e-2 * torch.randn_like(z_logvar)
        return torch.exp(0.5 * z_logvar) * epsilon + z_mean

    def decode(self, z):
        """Decode latent vectors ``(batch, 292)`` into character probabilities.

        Returns a tensor of shape ``(batch, max_length, charset_size)`` whose
        last axis is a softmax distribution.
        """
        z = F.selu(self.linear_3(z))
        # Feed the same latent vector to the GRU at every output position.
        z = z.view(z.size(0), 1, z.size(-1)).repeat(1, self.max_length, 1)
        output, hn = self.gru(z)
        # Flatten (batch, position) so the linear layer and softmax run per position.
        out_reshape = output.contiguous().view(-1, output.size(-1))
        y0 = F.softmax(self.linear_4(out_reshape), dim=1)
        y = y0.contiguous().view(output.size(0), -1, y0.size(-1))
        return y

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, z_mean, z_logvar)``."""
        z_mean, z_logvar = self.encode(x)
        z = self.sampling(z_mean, z_logvar)
        return self.decode(z), z_mean, z_logvar

In [None]:
# Read the SMILES table and strip the newline that terminates each entry.
df = pd.read_csv("250k_rndm_zinc_drugs_clean_3.csv")
df["smiles"] = df["smiles"].str.rstrip("\n")

# Vocabulary: every character seen in the data, sorted, with the " " padding
# symbol placed first so that it maps to index 0.
charset = [" "] + sorted(set("".join(df["smiles"])))

dataset = SMILESDataset(df["smiles"].tolist(), charset)

In [None]:
# Vocabulary size, counting the " " padding symbol inserted at index 0.
# MolecularVAE's final layer emits 35 values per position, so this is expected to be 35.
len(charset)

35

In [None]:
# Fraction of the molecules held out for testing.
test_ratio = 0.2

test_size = int(test_ratio * len(dataset))
train_size = len(dataset) - test_size
# Use a dedicated seeded generator: PyTorch's global seed is only set in a
# later cell, so without it the split would differ on every kernel restart.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
def vae_loss(x_decoded_mean, x, z_mean, z_logvar):
    """Summed VAE objective: reconstruction cross-entropy plus KL divergence.

    Args:
        x_decoded_mean: Reconstructed probabilities, same shape as ``x``.
        x: Target tensor with values in [0, 1].
        z_mean: Latent means produced by the encoder.
        z_logvar: Latent log-variances produced by the encoder.

    Returns:
        A scalar tensor. Both terms are summed (not averaged) over every
        element, so the value scales with the batch size.
    """
    # reduction='sum' replaces the deprecated size_average=False spelling.
    xent_loss = F.binary_cross_entropy(x_decoded_mean, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return xent_loss + kl_loss

In [18]:
import torch
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Serve the training split in shuffled mini-batches of 256 molecules.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
# Seed PyTorch's global RNG before the model is built so that weight
# initialisation is repeatable.
# NOTE(review): the train/test split is created in an earlier cell, before this
# seed is set — confirm that is intended.
torch.manual_seed(42)
epochs = 60
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MolecularVAE().to(device)
# Adam with a small weight-decay penalty.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# Halve the learning rate when the value passed to scheduler.step() stops
# improving; the epoch loop below passes the average training loss.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)

def train(epoch):
    """Run one training epoch over ``train_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``charset``.

    Args:
        epoch: Epoch number; used only in the progress messages.

    Returns:
        The summed batch losses divided by the number of samples in the
        training dataset.
    """
    model.train()
    train_loss = 0

    # Mixed precision is only meaningful on CUDA; on CPU both helpers become no-ops.
    use_amp = str(device).startswith('cuda')
    # Created per call so this function stays self-contained; the loss scale
    # therefore restarts each epoch and is not part of the checkpoints.
    scaler = GradScaler(enabled=use_amp)

    for batch_idx, (data, input_smiles) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            output, mean, logvar = model(data)

        # binary_cross_entropy is not allowed inside a CUDA autocast region,
        # so the loss is evaluated outside it on float32 tensors.
        loss = vae_loss(output.float(), data, mean.float(), logvar.float())

        # Scale the loss so half-precision gradients stay representable; the
        # scaler skips the optimizer step if it finds inf/NaN gradients.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()

        if batch_idx %100 == 0:
            # Show how the first molecule of the batch is currently reconstructed.
            output_smiles = decode_smiles_from_one_hot(output[0].detach().cpu().numpy(), charset)
            print(f"Input: {input_smiles[0]}")
            print(f"Output: {output_smiles}")
            print(f'Epoch {epoch} / Batch {batch_idx}\tLoss: {loss.item():.4f}')

    # vae_loss sums over the batch, so dividing by the sample count gives a per-molecule average.
    avg_loss = train_loss / len(train_loader.dataset)
    print(f'Train Epoch: {epoch}\tAverage loss: {avg_loss:.4f}')
    return avg_loss

# Epoch loop: train, let the scheduler react to the average loss, report
# memory statistics and write a checkpoint after every epoch.
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    scheduler.step(train_loss)

    # CUDA memory currently held by tensors, and the peak so far.
    print(f"Current GPU memory usage: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"Max GPU memory usage: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

    # Everything needed to resume training from this epoch.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
    }
    # NOTE(review): hard-coded directory; presumably it must already exist
    # before the first save — confirm.
    checkpoint_path = f'/content/drive/MyDrive/checkpoints/checkpoint_epoch_{epoch}.pt'
    torch.save(checkpoint, checkpoint_path)



Input: CCc1ccc([C@H](C)NC(=O)[C@H](C)[n+]2cc(CC)ccc2C)cc1
Output: NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN
Epoch 1 / Batch 0	Loss: 140565.4688
Input: Cc1cnn(C2CCN(C(=O)CNC(=O)c3ccccc3)CC2)c1
Output: CCCCCCCCCCCCCCCcccccccccccccccccccccccccc
Epoch 1 / Batch 100	Loss: 43223.9102
Input: Cc1oc(NC(=O)[C@@H](C)n2c(N)nnc2SCC#N)c(C#N)c1C
Output: CCCCCCCCCCCCCCccccccccccccccccccccccccccccc
Epoch 1 / Batch 200	Loss: 41121.3398
Input: CCc1ccc(NC(=O)C(=O)N[C@@H]2CCSc3ccc(F)cc32)cc1
Output: CCCcccccccccccccccccccccccccccccccccccccccccc
Epoch 1 / Batch 300	Loss: 40291.3828
Input: O=S(=O)([N-]c1cccc2ccncc12)c1cn[nH]c1
Output: CCCcccccCCCCCCCCCccccccccccccccccc1
Epoch 1 / Batch 400	Loss: 39114.6797
Input: C[C@@H](O)CC(C)(C)CNC(=O)N[C@H](c1ccc(F)cc1)C1CC1
Output: CCCcccccccCCCCCCCCCCCCCCCCCCCCcccccccccccccccccc
Epoch 1 / Batch 500	Loss: 39985.4805
Input: CC[NH+](CC)CCCn1c([C@@H](C)O)nc2ccc(Cl)cc21
Output: CCCccccccccCCCCCCC