In [14]:
import json
import os
import re
import sys
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import rdChemReactions, AllChem, Draw, PandasTools
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

RDLogger.DisableLog('rdApp.*')
# NOTE(review): this silences every Python warning for the whole kernel (PyTorch,
# pandas, sklearn, ...). Consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs early so weight initialisation and DataLoader shuffling are
# repeatable under Restart & Run All (cuDNN kernels may still be nondeterministic).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

#### Load

In [16]:
df_trainval = pd.read_parquet("../data/transformed/trainval_dataset_augmented_encoded.parquet")
# Fail fast if the file is empty or lacks the encoded columns that MolDataset reads,
# instead of erroring later inside dataset construction.
missing_cols = {"mol_smi", "tgt_in", "tgt_out"} - set(df_trainval.columns)
assert not missing_cols, f"trainval parquet is missing columns: {sorted(missing_cols)}"
assert len(df_trainval) > 0, "trainval parquet is empty"

#### Train / validate

In [18]:
def decode_show_special(ids, itos):
    """Turn a sequence of token ids back into a string without filtering anything.

    Every id is looked up in ``itos`` and the symbols are concatenated, so any
    special tokens present in ``itos`` (padding, start/end markers, ...) remain
    visible. This is useful for inspecting raw encoder/decoder sequences.

    Args:
        ids: iterable of token ids.
        itos: id -> symbol lookup (list or dict).

    Returns:
        str: the concatenated symbols.
    """
    return "".join(itos[token_id] for token_id in ids)

class MolDataset(Dataset):
    """Dataset of ``(src, tgt_in, tgt_out)`` tensor triples built from a DataFrame.

    The SMILES strings in ``df['mol_smi']`` are encoded eagerly at construction with
    ``encode(smiles, stoi, max_len)`` (a helper defined elsewhere in the notebook).
    The ``tgt_in`` / ``tgt_out`` columns are stored as-is; presumably they already
    contain token-id sequences (TODO confirm against the encoding step).
    """

    def __init__(self, df, stoi, max_len=128):
        self.src = [encode(smiles, stoi, max_len) for smiles in df['mol_smi']]
        self.tgt_in, self.tgt_out = (df[col].to_list() for col in ('tgt_in', 'tgt_out'))

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        """Return example ``idx`` as three tensors: (src, tgt_in, tgt_out)."""
        return tuple(
            torch.tensor(column[idx])
            for column in (self.src, self.tgt_in, self.tgt_out)
        )

In [19]:
# split trainval
# NOTE(review): `stoi` is not defined in any of the visible cells; make sure the cell that
# builds it runs before this one on a fresh kernel.
# NOTE(review): the source file is "augmented". If several rows are augmentations of the
# same molecule, a row-level random split can put variants of one molecule in both train
# and val, which would make the val loss optimistic. Consider a group-aware split (e.g.
# sklearn GroupShuffleSplit on a molecule id). Confirm how the augmentation was generated.
df_train, df_val = train_test_split(df_trainval, test_size=0.2, random_state=42)

train_dataset = MolDataset(df_train, stoi)
val_dataset = MolDataset(df_val, stoi)

# Use a dedicated, seeded generator so the shuffle order is reproducible and does not
# depend on how many draws other cells (e.g. weight init) took from the global torch RNG.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(val_dataset, batch_size=32)

In [20]:
class Encoder(nn.Module):
    """GRU encoder: embeds source token ids and runs them through a stacked GRU.

    Args:
        vocab_size: number of tokens in the vocabulary.
        d_model: embedding size and GRU hidden size.
        n_layers: number of stacked GRU layers.
        dropout: dropout between stacked GRU layers (only meaningful when n_layers > 1).
        padding_idx: token id whose embedding row is kept at zero. When None, the
            notebook-level ``pad_id`` is used (original behaviour).
    """

    def __init__(self, vocab_size, d_model=384, n_layers=2, dropout=0.1, padding_idx=None):
        super().__init__()
        if padding_idx is None:
            padding_idx = pad_id
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        # nn.GRU applies dropout only between layers; with one layer a non-zero value has
        # no effect and raises a UserWarning, so pass 0.0 in that case.
        self.rnn = nn.GRU(
            d_model,
            d_model,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0,
        )

    def forward(self, x):
        """Encode ``x`` (expected (batch, src_len) token ids, batch_first).

        Returns:
            (out, h): GRU outputs for every step and the final hidden state.
        """
        # NOTE(review): padded positions are neither packed nor masked, so `h` is the state
        # after the GRU has also consumed any pad tokens in `x`. Training is consistent with
        # this, but nn.utils.rnn.pack_padded_sequence would give a cleaner summary state
        # (changing it would alter the behaviour of existing checkpoints).
        x = self.emb(x)
        out, h = self.rnn(x)
        return out, h


class Decoder(nn.Module):
    """GRU decoder that projects each hidden state to vocabulary logits.

    Args are the same as for ``Encoder``; ``padding_idx=None`` falls back to ``pad_id``.
    """

    def __init__(self, vocab_size, d_model=384, n_layers=2, dropout=0.1, padding_idx=None):
        super().__init__()
        if padding_idx is None:
            padding_idx = pad_id
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        # See Encoder: GRU dropout is only valid between stacked layers.
        self.rnn = nn.GRU(
            d_model,
            d_model,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout if n_layers > 1 else 0.0,
        )
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, y_in, h):
        """Decode ``y_in`` (expected (batch, tgt_len) token ids) starting from hidden state ``h``.

        Returns:
            (logits, h): per-step vocabulary logits and the final hidden state.
        """
        y = self.emb(y_in)
        out, h = self.rnn(y, h)
        logits = self.proj(out)
        return logits, h


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper: the encoder's final hidden state initialises the decoder.

    The whole target-input sequence is fed to the decoder in one pass (teacher forcing),
    and per-step logits are returned.
    """

    def __init__(self, vocab_size, d_model=384, n_layers=2, dropout=0.1, padding_idx=None):
        super().__init__()
        self.enc = Encoder(vocab_size, d_model, n_layers, dropout, padding_idx)
        self.dec = Decoder(vocab_size, d_model, n_layers, dropout, padding_idx)

    def forward(self, src, tgt_in):
        _, h = self.enc(src)
        logits, _ = self.dec(tgt_in, h)
        return logits

In [21]:
# Build the model and its training objects. CrossEntropyLoss ignores targets equal to
# pad_id, so padded positions do not contribute to the loss.
vocab_size = len(stoi)
model = Seq2Seq(vocab_size=vocab_size)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

In [22]:
def validate(model, loader, criterion, device):
    """Compute the average loss of ``model`` over ``loader`` without gradients.

    Each batch loss returned by ``criterion`` is weighted by the number of sequences in
    the batch, and the sum is divided by ``len(loader.dataset)``. This matches how the
    training loop reports its loss. If ``criterion`` averages over tokens (as
    ``nn.CrossEntropyLoss`` does by default), the result is a mean of per-batch token
    means, not an exact per-token mean.

    The model's previous train/eval mode is restored before returning, so calling this
    in the middle of training does not silently leave dropout disabled.

    Args:
        model: module called as ``model(src, tgt_in)``; the last logits dimension is the vocabulary.
        loader: yields ``(src, tgt_in, tgt_out)`` batches.
        criterion: loss taking flattened logits ``(N, vocab)`` and targets ``(N,)``.
        device: device the batches are moved to.

    Returns:
        float: the sequence-weighted average loss.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for src, tgt_in, tgt_out in loader:
                src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)
                logits = model(src, tgt_in)
                # Flatten all positions so the loss sees (N, vocab) vs (N,).
                loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
                total_loss += loss.item() * src.size(0)
    finally:
        model.train(was_training)
    return total_loss / len(loader.dataset)

In [24]:
n_epochs = 20
checkpoint_path = "../models/seq2seq_gru_bbs.pt"

# --- Resume from checkpoint if exists
start_epoch = 0
if os.path.exists(checkpoint_path):
    # torch.load unpickles the file; only load checkpoints you produced yourself.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Embedding/projection rows are indexed by token id. A checkpoint saved with a
    # different stoi of the same size would load without any shape error, yet map ids to
    # the wrong tokens, so refuse to resume in that case.
    saved_stoi = checkpoint.get('stoi')
    if saved_stoi is not None and saved_stoi != stoi:
        raise ValueError(
            f"Vocabulary stored in {checkpoint_path} does not match the current stoi; "
            "refusing to resume with a mismatched token mapping."
        )
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")

In [25]:
# --- Training loop
for epoch in range(start_epoch, start_epoch + n_epochs):
    model.train()
    total_train_loss = 0.0

    for src, tgt_in, tgt_out in train_loader:
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)
        optimizer.zero_grad()
        logits = model(src, tgt_in)
        # Flatten (batch, tgt_len, vocab) logits and (batch, tgt_len) targets for the loss.
        loss = criterion(logits.view(-1, logits.size(-1)), tgt_out.view(-1))
        loss.backward()
        optimizer.step()
        # Weight the batch loss by sequence count, the same way validate() does, so the
        # train and val figures are comparable.
        total_train_loss += loss.item() * src.size(0)

    train_loss = total_train_loss / len(train_loader.dataset)

    # --- Compute validation loss
    val_loss = validate(model, val_loader, criterion, device)

    # --- Save checkpoint atomically: write a temp file, then rename it over the previous
    # checkpoint, so an interruption mid-save cannot leave a truncated checkpoint behind.
    tmp_checkpoint_path = checkpoint_path + ".tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stoi': stoi,
        'itos': itos,
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)

    print(f"Epoch {epoch:02d} | train loss {train_loss:.4f} | val loss {val_loss:.4f}")

Epoch 00 | train loss 0.2706 | val loss 0.1378
Epoch 01 | train loss 0.1252 | val loss 0.1084
Epoch 02 | train loss 0.0992 | val loss 0.0894
Epoch 03 | train loss 0.0823 | val loss 0.0752
Epoch 04 | train loss 0.0715 | val loss 0.0705
Epoch 05 | train loss 0.0657 | val loss 0.0635
Epoch 06 | train loss 0.0611 | val loss 0.0597
Epoch 07 | train loss 0.0576 | val loss 0.0571
Epoch 08 | train loss 0.0548 | val loss 0.0529
Epoch 09 | train loss 0.0523 | val loss 0.0519
Epoch 10 | train loss 0.0495 | val loss 0.0461
Epoch 11 | train loss 0.0467 | val loss 0.0464
Epoch 12 | train loss 0.0463 | val loss 0.0461
Epoch 13 | train loss 0.0433 | val loss 0.0428
Epoch 14 | train loss 0.0439 | val loss 0.0417
Epoch 15 | train loss 0.0414 | val loss 0.0407
Epoch 16 | train loss 0.0411 | val loss 0.0409
Epoch 17 | train loss 0.0398 | val loss 0.0393
Epoch 18 | train loss 0.0401 | val loss 0.0441
Epoch 19 | train loss 0.0390 | val loss 0.0370


In [None]:
# Re-save the final state. The epoch is derived from the loop bounds instead of being
# hard-coded, so a later resume (start_epoch = saved epoch + 1) continues from the right
# place. With n_epochs == 0 this re-records the epoch the run resumed from.
# NOTE(review): this cell previously hard-coded 'epoch': 90 although the run shown trained
# epochs 0-19. If 90 was meant to be a cumulative count from earlier sessions, set it
# explicitly here instead.
torch.save({
    'epoch': start_epoch + n_epochs - 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'stoi': stoi,
    'itos': itos,
}, checkpoint_path)