In [9]:
import re, sys, os
import pandas as pd
import numpy as np
import json
import torch
import torch.nn as nn
import warnings
import pickle
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import rdChemReactions, AllChem, Draw, PandasTools
from torch.utils.data import Dataset, DataLoader

# Keep cell output quiet: ignore every Python warning and disable the RDKit
# loggers matching 'rdApp.*'.
warnings.filterwarnings('ignore')
RDLogger.DisableLog('rdApp.*')

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#### Load

In [10]:
# Combined train/validation table (split into the two parts further down).
df_trainval = pd.read_parquet("../data/transformed/trainval_dataset_augmented_encoded.parquet")

# Token vocabularies: `stoi` is indexed by token string, `itos` by token id.
# NOTE: pickle.load can execute arbitrary code, so these files must be trusted,
# locally generated artifacts.
with open("../tokens/stoi.pkl", "rb") as stoi_file:
    stoi = pickle.load(stoi_file)

with open("../tokens/itos.pkl", "rb") as itos_file:
    itos = pickle.load(itos_file)

# Number of distinct tokens, used to size the embedding and output layers.
vocab_size = len(stoi)

#### Train / validate

In [15]:
def encode(smiles, stoi, max_len=128):
    """Turn a SMILES string into a fixed-length list of token ids.

    The string is tokenised one character at a time, wrapped in ``<s>`` and
    ``</s>``, then right-padded with ``<pad>`` up to ``max_len``.

    When the string is too long, the *body* is truncated so that the closing
    ``</s>`` is still present. (Slicing the finished list instead would cut the
    end marker off.) Inputs that fit are encoded exactly as before.

    Args:
        smiles: SMILES string; every character must be a key of ``stoi``.
        stoi: mapping from token string to integer id; must contain
            ``"<s>"``, ``"</s>"`` and ``"<pad>"``.
        max_len: length of the returned list.

    Returns:
        A list of ``max_len`` token ids.

    Raises:
        KeyError: if any character of ``smiles`` is missing from ``stoi``.
    """
    # Look up every character first so unknown symbols still raise KeyError,
    # even if they sit beyond the truncation point.
    body = [stoi[ch] for ch in smiles]
    # Reserve two slots for the start and end markers.
    body_budget = max(max_len - 2, 0)
    ids = [stoi["<s>"]] + body[:body_budget] + [stoi["</s>"]]
    # A negative repeat count yields an empty list, so no length check is needed.
    ids += [stoi["<pad>"]] * (max_len - len(ids))
    # Only trims when max_len < 2 (not even both markers fit).
    return ids[:max_len]

def decode(ids, itos):
    """Map token ids back to text, leaving out ``<pad>``, ``<s>`` and ``</s>``.

    Args:
        ids: iterable of token ids, each usable as an index/key into ``itos``.
        itos: lookup from token id to token string.

    Returns:
        The concatenation of all non-special tokens, in order.
    """
    special_tokens = {"<pad>", "<s>", "</s>"}
    pieces = []
    for token_id in ids:
        token = itos[token_id]
        if token in special_tokens:
            continue
        pieces.append(token)
    return "".join(pieces)

def encode_target(smi, stoi, max_len=128):
    """Build the shifted decoder-input / loss-target pair for one SMILES.

    The string is encoded with ``encode`` and the resulting list is sliced twice:

    * decoder input  = every element except the last one. For a sequence
      shorter than ``max_len`` the dropped element is a trailing pad, so
      ``</s>`` is still present in the decoder input.
    * loss target    = every element except the first one (the leading ``<s>``).

    Both lists therefore have one element fewer than the encoded sequence and
    are offset by one position relative to each other.

    Returns:
        ``(tgt_in, tgt_out)`` as two lists of token ids.
    """
    padded_ids = encode(smi, stoi, max_len)
    return padded_ids[:-1], padded_ids[1:]
    
def decode_show_special(ids, itos):
    """Map token ids back to text, keeping every token (special ones included).

    Unlike ``decode``, nothing is filtered out, so markers such as ``<s>`` or
    ``<pad>`` appear verbatim in the result.
    """
    return "".join(itos[token_id] for token_id in ids)

class MolDataset(Dataset):
    """Torch dataset yielding ``(src, tgt_in, tgt_out)`` token-id tensors.

    Args:
        df: table with a ``mol_smi`` column of SMILES strings and ``tgt_in`` /
            ``tgt_out`` columns holding already-encoded id sequences.
        stoi: token-to-id mapping handed to ``encode`` for the source strings.
        max_len: length of the encoded source sequence. The target columns are
            used as stored and are not re-padded here.
    """

    def __init__(self, df, stoi, max_len=128):
        # Source strings are encoded once up front; targets come pre-encoded.
        self.src = [encode(s, stoi, max_len) for s in df['mol_smi']]
        self.tgt_in = df['tgt_in'].tolist()
        self.tgt_out = df['tgt_out'].tolist()

    def __len__(self):
        """Number of examples."""
        return len(self.src)

    def __getitem__(self, idx):
        """Return the ``idx``-th example as three int64 tensors.

        The dtype is forced to ``torch.long``: without it ``torch.tensor``
        infers the dtype from the data, so an int32 array read from the
        parquet file would become an IntTensor, which ``CrossEntropyLoss``
        does not accept as class-index targets.
        """
        return (
            torch.tensor(self.src[idx], dtype=torch.long),
            torch.tensor(self.tgt_in[idx], dtype=torch.long),
            torch.tensor(self.tgt_out[idx], dtype=torch.long),
        )

In [16]:
# split trainval
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for validation; the fixed seed makes the split repeatable.
# NOTE(review): the split is row-wise. If the augmented table contains several rows
# derived from the same molecule, they can end up on both sides of the split --
# confirm, and consider a grouped split if so.
df_train, df_val = train_test_split(df_trainval, random_state=42, test_size=0.2)

# Same dataset wrapper for both parts (train first, then validation).
train_dataset, val_dataset = (MolDataset(part, stoi) for part in (df_train, df_val))

# Only the training batches are shuffled; validation keeps the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32)

In [17]:
def _resolve_padding_idx(padding_idx):
    """Return ``padding_idx``, or the notebook-level ``pad_id`` when it is None."""
    if padding_idx is None:
        # Backward-compatible path: read the global set in the model-setup cell.
        return pad_id
    return padding_idx


class Encoder(nn.Module):
    """Embedding + multi-layer GRU over the source token ids.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding size and GRU hidden size.
        n_layers: number of stacked GRU layers.
        dropout: dropout applied between GRU layers.
        padding_idx: id of the padding token (its embedding row receives no
            gradient). ``None`` falls back to the global ``pad_id``, which is
            how the class behaved before this parameter existed.
    """

    def __init__(self, vocab_size, d_model=384, n_layers=2, dropout=0.1, padding_idx=None):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=_resolve_padding_idx(padding_idx))
        self.rnn = nn.GRU(d_model, d_model, num_layers=n_layers, batch_first=True, dropout=dropout)

    def forward(self, x):
        """Encode a batch-first tensor of token ids.

        Returns:
            ``(out, h)`` exactly as produced by the GRU: per-position outputs
            and the final hidden state of every layer. The GRU is run over all
            positions of ``x`` (sequences are not packed), so ``h`` is the
            state after the last position, padding included.
        """
        x = self.emb(x)
        out, h = self.rnn(x)
        return out, h


class Decoder(nn.Module):
    """Embedding + multi-layer GRU + linear projection to vocabulary logits.

    Takes the same constructor arguments as ``Encoder``.
    """

    def __init__(self, vocab_size, d_model=384, n_layers=2, dropout=0.1, padding_idx=None):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=_resolve_padding_idx(padding_idx))
        self.rnn = nn.GRU(d_model, d_model, num_layers=n_layers, batch_first=True, dropout=dropout)
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, y_in, h):
        """Run the decoder over ``y_in`` starting from hidden state ``h``.

        Returns:
            ``(logits, h)``: one ``vocab_size``-wide logit vector per input
            position, and the GRU's final hidden state.
        """
        y = self.emb(y_in)
        out, h = self.rnn(y, h)
        logits = self.proj(out)
        return logits, h


class Seq2Seq(nn.Module):
    """Encoder-decoder pair: the encoder's final hidden state seeds the decoder.

    ``padding_idx`` is forwarded to both sub-modules; leaving it as ``None``
    keeps the previous behaviour of reading the global ``pad_id``. Sub-module
    attribute names (``enc``, ``dec``) are unchanged, so existing checkpoints
    keep the same state-dict keys.
    """

    def __init__(self, vocab_size, d_model=384, n_layers=2, dropout=0.1, padding_idx=None):
        super().__init__()
        self.enc = Encoder(vocab_size, d_model, n_layers, dropout, padding_idx=padding_idx)
        self.dec = Decoder(vocab_size, d_model, n_layers, dropout, padding_idx=padding_idx)

    def forward(self, src, tgt_in):
        """Return decoder logits for ``tgt_in`` conditioned on ``src``."""
        # Only the final hidden state is used; per-position encoder outputs are discarded.
        _, h = self.enc(src)
        logits, _ = self.dec(tgt_in, h)
        return logits

In [19]:
# Define the padding id first: Seq2Seq(vocab_size) below relies on the global
# `pad_id` already existing when the embeddings are built.
pad_id = stoi["<pad>"]
vocab_size = len(stoi)

# ignore padding in loss
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

# Default-sized model, moved to the selected device before the optimizer is created.
model = Seq2Seq(vocab_size)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [20]:
def validate(model, loader, criterion, device):
    """Average the loss of ``model`` over every batch in ``loader``.

    The model is put into eval mode and left there; gradients are disabled
    while the batches are processed.

    Args:
        model: callable as ``model(src, tgt_in)`` returning logits whose last
            dimension is the class dimension.
        loader: iterable of ``(src, tgt_in, tgt_out)`` batches that also
            exposes ``loader.dataset``.
        criterion: loss taking flattened logits and flattened targets.
        device: device the batch tensors are moved to.

    Returns:
        Sum of ``batch loss * batch size`` divided by ``len(loader.dataset)``.
    """
    model.eval()
    weighted_loss_sum = 0
    with torch.no_grad():
        for batch in loader:
            src, tgt_in, tgt_out = (tensor.to(device) for tensor in batch)
            logits = model(src, tgt_in)
            # Collapse all leading dimensions so each row is one token prediction.
            n_classes = logits.size(-1)
            batch_loss = criterion(logits.view(-1, n_classes), tgt_out.view(-1))
            # Weight by the number of sequences so a smaller final batch counts less.
            weighted_loss_sum += batch_loss.item() * src.size(0)
    return weighted_loss_sum / len(loader.dataset)

In [21]:
# Number of epochs to run in this session (added on top of any resumed epoch count).
n_epochs = 20
checkpoint_path = "../models/seq2seq_gru_bbs.pt"

# --- Resume from checkpoint if exists
start_epoch = 0
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # The checkpoint stores the vocabulary it was trained with. Embedding rows are
    # addressed by token id, so a same-sized but differently ordered vocabulary would
    # load without any error and silently scramble every token. Fail loudly instead.
    # Checkpoints that carry no 'stoi' entry are accepted as-is.
    saved_stoi = checkpoint.get('stoi')
    if saved_stoi is not None and saved_stoi != stoi:
        raise ValueError(
            f"Vocabulary stored in {checkpoint_path} differs from the one loaded from disk; "
            "refusing to resume with mismatched token ids."
        )

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The stored value is the last completed epoch, so continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")

Resuming from epoch 20


In [22]:
# --- Training loop
# Runs `n_epochs` epochs, numbered from `start_epoch` onward.
for epoch in range(start_epoch, start_epoch + n_epochs):
    # validate() leaves the model in eval mode, so switch back every epoch.
    model.train()
    total_train_loss = 0

    for batch in train_loader:
        src, tgt_in, tgt_out = (tensor.to(device) for tensor in batch)

        optimizer.zero_grad()
        logits = model(src, tgt_in)
        # Flatten to (tokens, vocab) vs (tokens,) for the token-level loss.
        loss = criterion(
            logits.view(-1, logits.size(-1)),
            tgt_out.view(-1),
        )
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its number of sequences.
        total_train_loss += loss.item() * src.size(0)

    train_loss = total_train_loss / len(train_loader.dataset)

    # --- Compute validation loss
    val_loss = validate(model, val_loader, criterion, device)

    # --- Save checkpoint
    # Written after every epoch to the same path, replacing the previous file.
    epoch_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stoi': stoi,
        'itos': itos,
    }
    torch.save(epoch_state, checkpoint_path)

    print(f"Epoch {epoch:02d} | train loss {train_loss:.4f} | val loss {val_loss:.4f}")

Epoch 20 | train loss 0.0372 | val loss 0.0369
Epoch 21 | train loss 0.0373 | val loss 0.0377
Epoch 22 | train loss 0.0399 | val loss 0.0370
Epoch 23 | train loss 0.0368 | val loss 0.0363
Epoch 24 | train loss 0.0370 | val loss 0.0360
Epoch 25 | train loss 0.0350 | val loss 0.0364
Epoch 26 | train loss 0.0352 | val loss 0.0350
Epoch 27 | train loss 0.0366 | val loss 0.0386
Epoch 28 | train loss 0.0347 | val loss 0.0354
Epoch 29 | train loss 0.0337 | val loss 0.0330
Epoch 30 | train loss 0.0363 | val loss 0.0363
Epoch 31 | train loss 0.0352 | val loss 0.0348
Epoch 32 | train loss 0.0339 | val loss 0.0380
Epoch 33 | train loss 0.0334 | val loss 0.0339
Epoch 34 | train loss 0.0320 | val loss 0.0318
Epoch 35 | train loss 0.0320 | val loss 0.0355
Epoch 36 | train loss 0.0331 | val loss 0.0318
Epoch 37 | train loss 0.0326 | val loss 0.0324
Epoch 38 | train loss 0.0318 | val loss 0.0323
Epoch 39 | train loss 0.0316 | val loss 0.0314
