In [11]:
import pandas as pd
import torch

In [3]:
# Read the training table from disk and pull its 'SMILES' column out as a plain
# Python list, which is what the tokenisation and encoding cells iterate over.
df = pd.read_csv('./data/train.csv')
smiles_list = df.loc[:, 'SMILES'].to_list()

In [14]:
# Tokenize SMILES strings (character-level)

from collections import Counter

def tokenize(smiles):
    """Split a SMILES string into a list of its individual characters.

    Args:
        smiles: The string to split.

    Returns:
        A list with one single-character string per character of ``smiles``,
        in order. An empty string yields an empty list.
    """
    # Character-level tokenisation: every character is its own token.
    return [ch for ch in smiles]

# Flatten every SMILES string into one long list of character tokens.
tokens = [ch for smi in smiles_list for ch in tokenize(smi)]
# The four special markers occupy indices 0-3; the distinct characters seen in
# the data follow in sorted order.
vocab = ['<pad>', '<bos>', '<eos>', '<unk>'] + sorted(set(tokens))
# Token -> index and index -> token lookup tables.
stoi = dict(zip(vocab, range(len(vocab))))
itos = dict(enumerate(vocab))

In [15]:
# Fixed length of every encoded sequence, counting the <bos> and <eos> markers.
MAX_LEN = 60

def encode(smiles):
    """Convert a SMILES string into a list of exactly ``MAX_LEN`` token indices.

    The sequence is ``<bos>``, the character tokens, ``<eos>``, then ``<pad>``
    up to ``MAX_LEN``. Characters missing from ``stoi`` map to ``<unk>``.

    Args:
        smiles: The SMILES string to encode.

    Returns:
        A list of ``MAX_LEN`` integer indices into ``vocab``.
    """
    # Truncate the body *before* adding the markers so that an over-long
    # string still ends in <eos>; slicing afterwards would cut the end marker
    # off and the sequence would carry no stop signal.
    body = tokenize(smiles)[:MAX_LEN - 2]
    tokens = ['<bos>'] + body + ['<eos>']
    idxs = [stoi.get(t, stoi['<unk>']) for t in tokens]
    idxs += [stoi['<pad>']] * (MAX_LEN - len(idxs))
    return idxs

# One row per SMILES string; explicit integer dtype so the result is usable as
# embedding indices even when the list is empty.
input_tensor = torch.tensor([encode(s) for s in smiles_list], dtype=torch.long)

In [16]:
import torch.nn as nn

class Encoder(nn.Module):
    """LSTM encoder that maps a token-index sequence to Gaussian parameters.

    Token indices are embedded, run through a single-layer LSTM, and the final
    hidden state is projected by two linear heads into a mean and a
    log-variance vector of size ``latent_dim``.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Embedding size, which is also the LSTM input size.
        hidden_dim: LSTM hidden-state size.
        latent_dim: Size of each of the two output vectors.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, latent_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Encode a batch of index sequences.

        Args:
            x: Integer tensor of token indices, batch dimension first.

        Returns:
            A ``(mu, logvar)`` tuple produced by the two linear heads.
        """
        embedded = self.embedding(x)
        # Only the final hidden state is needed; per-step outputs and the
        # cell state are discarded.
        _, state = self.lstm(embedded)
        final_hidden = state[0][-1]  # hidden state of the last LSTM layer
        mu = self.fc_mu(final_hidden)
        logvar = self.fc_logvar(final_hidden)
        return mu, logvar

In [17]:
def reparameterize(mu, logvar):
    """Draw a sample ``mu + eps * std`` with ``eps`` standard-normal noise.

    Args:
        mu: Tensor of means.
        logvar: Tensor of log-variances, same shape as ``mu``.

    Returns:
        A tensor shaped like ``mu``. The noise is drawn independently of the
        inputs, so gradients flow to ``mu`` and ``logvar`` through the result.
    """
    # std = exp(logvar / 2), i.e. the square root of the variance.
    std = logvar.mul(0.5).exp()
    noise = torch.randn_like(std)
    return mu + noise * std

In [18]:
class Decoder(nn.Module):
    """LSTM decoder that turns a latent vector plus input tokens into logits.

    The latent vector is projected to ``hidden_dim`` and squashed with tanh to
    form the LSTM's initial hidden state; the initial cell state is zero. The
    LSTM consumes the embedded input tokens and every step's output is
    projected to vocabulary logits.

    Args:
        vocab_size: Embedding table size and number of output logits.
        emb_dim: Embedding size, which is also the LSTM input size.
        hidden_dim: LSTM hidden-state size.
        latent_dim: Size of the latent vector ``z``.
    """

    def __init__(self, vocab_size, emb_dim, hidden_dim, latent_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.fc = nn.Linear(latent_dim, hidden_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, z, x):
        """Decode input tokens conditioned on a latent vector.

        Args:
            z: Latent tensor, batch dimension first.
            x: Integer tensor of input token indices, batch dimension first.

        Returns:
            Logits over the vocabulary for every input position.
        """
        # The LSTM expects states with a leading num_layers axis, hence the
        # unsqueeze on dimension 0.
        init_hidden = self.fc(z).tanh().unsqueeze(0)
        init_cell = torch.zeros_like(init_hidden)
        step_outputs, _ = self.lstm(self.embedding(x), (init_hidden, init_cell))
        return self.out(step_outputs)

In [19]:
class VAE(nn.Module):
    """Sequence variational autoencoder built from ``Encoder`` and ``Decoder``.

    Args:
        vocab_size: Vocabulary size shared by encoder and decoder.
        emb_dim: Token embedding size.
        hidden_dim: LSTM hidden-state size.
        latent_dim: Latent vector size.
    """

    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256, latent_dim=64):
        super().__init__()
        dims = (vocab_size, emb_dim, hidden_dim, latent_dim)
        self.encoder = Encoder(*dims)
        self.decoder = Decoder(*dims)

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode with teacher forcing.

        Args:
            x: Integer tensor of token indices, batch dimension first.

        Returns:
            A ``(logits, mu, logvar)`` tuple. ``logits`` has one fewer
            position than ``x`` because the decoder is fed ``x`` without its
            last token.
        """
        mu, logvar = self.encoder(x)
        latent = reparameterize(mu, logvar)
        # Teacher forcing: the decoder sees every token except the last one.
        decoder_input = x[:, :-1]
        logits = self.decoder(latent, decoder_input)
        return logits, mu, logvar

In [20]:
def vae_loss(recon_logits, x, mu, logvar, beta=1.0, pad_idx=None):
    """Compute reconstruction cross-entropy plus a weighted KL term.

    Args:
        recon_logits: Decoder logits with the vocabulary on the last axis and
            one position per target token, i.e. one fewer than ``x``.
        x: Integer tensor of target token indices, batch dimension first. The
            first column is dropped so position ``t`` predicts token ``t + 1``.
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.
        beta: Weight on the KL term. The default ``1.0`` reproduces the plain
            sum; smaller values or a schedule allow KL annealing.
        pad_idx: Target index ignored by the reconstruction term. Defaults to
            ``stoi['<pad>']`` when left as ``None``.

    Returns:
        A scalar tensor ``recon_loss + beta * kl_loss``. The reconstruction
        term is averaged over non-ignored tokens; the KL term is summed over
        all elements and divided by the batch size.
    """
    if pad_idx is None:
        pad_idx = stoi['<pad>']
    flat_logits = recon_logits.reshape(-1, recon_logits.size(-1))
    targets = x[:, 1:].reshape(-1)
    # Functional form avoids building a new loss module on every call.
    recon_loss = nn.functional.cross_entropy(flat_logits, targets, ignore_index=pad_idx)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + beta * kl_loss

In [27]:
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

# Use the best available accelerator rather than assuming Apple MPS, so the
# notebook also runs on CUDA machines and on CPU-only machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = VAE(len(vocab)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Each dataset item is a 1-tuple holding one encoded sequence.
dataset = TensorDataset(input_tensor)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for epoch in range(1):
    model.train()
    total_loss = 0
    # The trailing comma unpacks the 1-tuple that TensorDataset yields.
    for batch, in tqdm(loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        recon_logits, mu, logvar = model(batch)
        loss = vae_loss(recon_logits, batch, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean per-batch loss for the epoch.
    print(f"Epoch {epoch}: Loss = {total_loss / len(loader):.4f}")

  0%|          | 0/24761 [00:00<?, ?it/s]

100%|██████████| 24761/24761 [05:55<00:00, 69.68it/s]

Epoch 0: Loss = 0.6055





In [46]:
# Greedy decoding of one molecule from a latent vector drawn from N(0, I).
model.eval()
with torch.no_grad():
    # Take the device and latent size from the model itself so this cell keeps
    # working if the model is built with another latent_dim or on another device.
    device = next(model.parameters()).device
    latent_dim = model.encoder.fc_mu.out_features

    z = torch.randn(1, latent_dim).to(device)
    print("Latent vector z:", z)
    start_token = torch.tensor([[stoi['<bos>']]]).to(device)
    generated = [start_token]

    for _ in range(MAX_LEN):
        # Re-run the decoder on the whole prefix and keep only the prediction
        # for the newest position.
        inp = torch.cat(generated, dim=1)
        logits = model.decoder(z, inp)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
        if next_token.item() == stoi['<eos>']:
            break

    # Drop the structural markers before joining tokens back into a string.
    special_ids = {stoi['<bos>'], stoi['<eos>'], stoi['<pad>']}
    token_ids = torch.cat(generated, dim=1)[0].tolist()
    decoded = ''.join(itos[i] for i in token_ids if i not in special_ids)
    print("Generated SMILES:", decoded)


Latent vector z: tensor([[ 0.5367, -1.1114,  0.3204, -2.7990, -0.4941, -1.2613,  0.6209, -0.0722,
         -1.5696,  2.1767,  0.7549, -0.8277,  1.2693, -0.4947, -0.0656,  2.2062,
         -0.6120,  0.3245, -1.0779, -0.2699,  0.0936,  1.0524, -0.3196,  0.3275,
         -0.2393, -0.8876,  0.8789,  0.9816,  0.2032,  0.2390, -0.6276,  0.1032,
         -0.9381, -0.5492, -0.6289,  0.2948, -1.2225, -0.7944, -1.4701, -0.2657,
         -0.1899,  1.2341,  0.3867, -0.0350, -2.1999, -1.1844, -1.2625,  0.9390,
         -1.6533, -0.7387,  2.9420,  0.6933,  0.0176,  0.4217, -0.6841,  0.5865,
          0.6801, -1.1861, -1.3386,  0.9958, -0.2431, -1.9837, -0.0599, -0.2664]],
       device='mps:0')
Generated SMILES: CC(C)(C)OC(=O)N1CCC(NC(=O)c2ccccc2)CC1


In [45]:
from rdkit import Chem
# MolFromSmiles returns None when the string cannot be parsed. An empty string
# (e.g. the decoder emitted <eos> immediately) parses to a molecule with zero
# atoms, so also require at least one atom before calling the sample valid.
mol = Chem.MolFromSmiles(decoded)
if mol is not None and mol.GetNumAtoms() > 0:
    print("Valid molecule!")
else:
    print("Invalid SMILES.")

Valid molecule!
