# Loading molecules from Training Dataset

In [1]:
import torch
from torch.utils.data import Dataset
import pandas as pd

class MoleculeDataset(Dataset):
    def __init__(self, encoded_path, properties_csv):
        """
        encoded_path: path to train_encoded.pt (tensor of token IDs)
        properties_csv: path to CSV with normalized properties
        """
        # Load encoded token sequences
        self.encoded_sequences = torch.load(encoded_path)  # shape: [num_molecules, seq_len]

        # Load property vectors
        self.props_df = pd.read_csv(properties_csv)
        # Only keep the property columns you want
        self.prop_columns = ['QED', 'SAS', 'LogP', 'TPSA', 'MolWt']
        self.properties = torch.tensor(
            self.props_df[self.prop_columns].values,
            dtype=torch.float
        )

        # Sanity check: number of sequences and properties must match
        assert len(self.encoded_sequences) == len(self.properties), \
            "Mismatch: sequences vs properties"

    def __len__(self):
        return len(self.encoded_sequences)

    def __getitem__(self, idx):
        """
        Returns:
            seq_tensor: LongTensor of token IDs (padded)
            prop_tensor: FloatTensor of property values
        """
        seq_tensor = self.encoded_sequences[idx]
        prop_tensor = self.properties[idx]
        return seq_tensor, prop_tensor


In [2]:
from torch.utils.data import DataLoader

dataset = MoleculeDataset(
    encoded_path="../data/processed_5l/train_encoded.pt",
    properties_csv="../data/processed_5l/train_properties.csv"
)

# Example: inspect first sample
seq, prop = dataset[0]
print("Token IDs:", seq)
print("Properties:", prop)

# DataLoader for batch training
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Example: one batch
for batch_seq, batch_props in dataloader:
    print(batch_seq.shape)   # [64, seq_len]
    print(batch_props.shape) # [64, 5]
    break


  self.encoded_sequences = torch.load(encoded_path)  # shape: [num_molecules, seq_len]


Token IDs: tensor([ 1, 51, 11, 51, 51, 51,  4,  7, 51, 12, 51, 61, 51, 13, 61, 51, 51, 61,
        51, 13, 61, 12,  5, 51, 51, 11, 69,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0])
Properties: tensor([0.6461, 0.1320, 0.6085, 0.0123, 0.0110])
torch.Size([64, 128])
torch.Size([64, 5])


# Conditional + unconditional Generator Definition

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, vocab_size, prop_dim, d_model=256, nhead=8, num_layers=4, max_len=128, dropout=0.1): 
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.prop_embed = nn.Linear(prop_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=512,
            batch_first=False,
            dropout=dropout  
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, props):
        src = torch.clamp(src, 0, self.token_embed.num_embeddings - 1)
        B, L = src.shape
        tok_emb = self.token_embed(src) * (self.d_model ** 0.5)
        pos = torch.arange(L, device=src.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        prop_emb = self.prop_embed(props).unsqueeze(1)
        
        x = tok_emb + pos_emb + prop_emb
        x = self.dropout(x) 
        x = x.transpose(0, 1)  # [seq_len, batch, dim]
        
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(src.device)
        out = self.transformer(x, mask=tgt_mask)
        
        out = out.transpose(0, 1) # Back to [batch, seq_len, dim]
        logits = self.fc_out(out)
        return logits

# Training 20 epoch trained generator

In [None]:
import torch.optim as optim
import torch, gc, os
from tqdm import tqdm

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
gc.collect()
torch.cuda.empty_cache()

vocab_size = int(torch.max(dataset.encoded_sequences)) + 1
print("Vocab size =", vocab_size)

dataset.encoded_sequences = dataset.encoded_sequences.long()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

model = Generator(vocab_size=vocab_size, prop_dim=5, max_len=128, dropout=0.1).to(device) 
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # assuming 0 is padding token

start_epoch = 0 

p_uncond = 0.1 
TOTAL_EPOCHS = 20 

if torch.max(dataset.encoded_sequences) >= vocab_size:
    print(" Token index out of range! Check vocab size.")

for epoch in range(start_epoch, TOTAL_EPOCHS):
    model.train()
    total_loss = 0
    
    batch_iterator = tqdm(
        enumerate(dataloader), 
        # --- Update description to show correct epoch numbers ---
        desc=f"Epoch {epoch+1}/{TOTAL_EPOCHS}", 
        total=len(dataloader)
    )

    for i, (seqs, props) in batch_iterator:
        seqs, props = seqs.to(device), props.to(device)

        if torch.rand(1).item() < p_uncond:
            props = torch.zeros_like(props)

        inputs = seqs[:, :-1]
        targets = seqs[:, 1:]
        
        max_token = torch.max(seqs)
        if max_token >= vocab_size:
            print(f" Token {max_token} >= vocab_size {vocab_size}")
            raise ValueError("Token index out of range!")

        logits = model(inputs, props) 
        logits = logits.reshape(-1, vocab_size)
        targets = targets.reshape(-1)
        
        loss = criterion(logits, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        running_avg_loss = total_loss / (i + 1)
        batch_iterator.set_postfix(avg_loss=f"{running_avg_loss:.4f}")
    
    avg_loss = total_loss / len(dataloader) 
    print() 

    if (epoch + 1) % 3 == 0:
        checkpoint_path = f"../results/models_5l/u&c_generator_epoch_{epoch+1}.pt"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss
        }, checkpoint_path)
        print(f"--- Checkpoint saved to {checkpoint_path} ---")

print(" Training complete!")



Vocab size = 70
Using: cuda
‚ö†Ô∏è No checkpoint found. Starting training from scratch.


Epoch 1/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:42<00:00, 15.43it/s, avg_loss=1.1947]





Epoch 2/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:49<00:00, 15.04it/s, avg_loss=0.9701]





Epoch 3/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:49<00:00, 15.06it/s, avg_loss=0.9097]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_3.pt ---


Epoch 4/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:46<00:00, 15.21it/s, avg_loss=0.8737]





Epoch 5/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:57<00:00, 14.67it/s, avg_loss=0.8481]





Epoch 6/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:51<00:00, 14.95it/s, avg_loss=0.8289]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_6.pt ---


Epoch 7/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:47<00:00, 15.18it/s, avg_loss=0.8136]





Epoch 8/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:47<00:00, 15.17it/s, avg_loss=0.8010]





Epoch 9/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:49<00:00, 15.03it/s, avg_loss=0.7899]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_9.pt ---


Epoch 10/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:43<00:00, 15.36it/s, avg_loss=0.7806]





Epoch 11/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.25it/s, avg_loss=0.7723]





Epoch 12/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.28it/s, avg_loss=0.7647]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_12.pt ---


Epoch 13/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:46<00:00, 15.21it/s, avg_loss=0.7567]





Epoch 14/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:44<00:00, 15.30it/s, avg_loss=0.7502]





Epoch 15/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.25it/s, avg_loss=0.7431]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_15.pt ---


Epoch 16/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:44<00:00, 15.34it/s, avg_loss=0.7369]





Epoch 17/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:43<00:00, 15.37it/s, avg_loss=0.7314]





Epoch 18/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:43<00:00, 15.39it/s, avg_loss=0.7253]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_18.pt ---


Epoch 19/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:43<00:00, 15.36it/s, avg_loss=0.7206]





Epoch 20/20: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:43<00:00, 15.39it/s, avg_loss=0.7162]


‚úÖ Training complete!





In [9]:
checkpoint_path="../results/models_5l/u&c_generator_epoch_20.pt"
torch.save({
            'epoch': 20,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss
        }, checkpoint_path)

# Testing 20 epoch trained generator

In [None]:
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = 70
prop_dim = 5
max_len_model = 128

model = Generator(vocab_size, prop_dim, max_len=max_len_model,dropout=0.1).to(device)

checkpoint = torch.load("../results/models_5l/u&c_generator_epoch_20.pt", map_location=device,weights_only=True) 
model.load_state_dict(checkpoint["model_state_dict"])
print(f" Loaded model from epoch {checkpoint['epoch']} with loss {checkpoint['loss']:.4f}")
model.eval()


test_props = torch.tensor([[0.6460672474469629,0.13202386890910442,0.6084964566032581,0.012271807687701365,0.011000075433875664]], dtype=torch.float32).to(device)
print(" Generating with conditions:", test_props)


props_to_use = test_props 

start_token_id = 1  
stop_token_id = 69   
max_gen_len = 128
top_k = 50

generated = torch.tensor([[start_token_id]], dtype=torch.long).to(device)

with torch.no_grad():
    for _ in range(max_gen_len):
        
        logits = model(generated, props_to_use)
        last_logits = logits[:, -1, :]

        v, _ = torch.topk(last_logits, top_k)
        last_logits[last_logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(last_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated = torch.cat([generated, next_token], dim=1)
        if next_token.item() == stop_token_id:
            break

print("Generated token sequence:\n", generated.cpu().numpy().tolist()[0])

token_to_idx = {
    "#": 2, "%": 3, "(": 4, ")": 5, "+": 6, "-": 7, ".": 8, "/": 9, "0": 10, "1": 11, "2": 12, "3": 13,
    "4": 14, "5": 15, "6": 16, "7": 17, "8": 18, "9": 19, "=": 20, "@": 21, "A": 22, "B": 23, "C": 24,
    "D": 25, "E": 26, "F": 27, "G": 28, "H": 29, "I": 30, "K": 31, "L": 32, "M": 33, "N": 34, "O": 35,
    "P": 36, "R": 37, "S": 38, "T": 39, "U": 40, "V": 41, "W": 42, "X": 43, "Y": 44, "Z": 45, "[": 46,
    "\\": 47, "]": 48, "a": 49, "b": 50, "c": 51, "d": 52, "e": 53, "f": 54, "g": 55, "h": 56, "i": 57,
    "k": 58, "l": 59, "m": 60, "n": 61, "o": 62, "p": 63, "r": 64, "s": 65, "t": 66, "u": 67,
    "y": 68, "<PAD>": 0, "<START>": 1, "<END>": 69}

idx_to_token = {v: k for k, v in token_to_idx.items()}
generated_seq = generated.cpu().numpy().tolist()[0]
decoded = ''.join(idx_to_token.get(tok, '?') for tok in generated_seq if tok not in [0, 1, 68])
print(" Decoded SMILES:", decoded)

‚úÖ Loaded model from epoch 20 with loss 0.7162
üß¨ Generating with conditions: tensor([[0.6461, 0.1320, 0.6085, 0.0123, 0.0110]], device='cuda:0')
Generated token sequence:
 [1, 24, 24, 4, 24, 5, 51, 11, 51, 4, 24, 4, 20, 35, 5, 35, 5, 51, 51, 12, 51, 4, 24, 59, 5, 51, 51, 4, 24, 59, 5, 51, 4, 24, 59, 5, 51, 12, 51, 11, 24, 59, 69]
üß¨ Decoded SMILES: CC(C)c1c(C(=O)O)cc2c(Cl)cc(Cl)c(Cl)c2c1Cl<END>


# Training for 50 epochas

In [None]:
import torch.optim as optim
import torch, gc, os
from tqdm import tqdm

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
gc.collect()
torch.cuda.empty_cache()

vocab_size = int(torch.max(dataset.encoded_sequences)) + 1
print("Vocab size =", vocab_size)

dataset.encoded_sequences = dataset.encoded_sequences.long()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

model = Generator(vocab_size=vocab_size, prop_dim=5, max_len=128, dropout=0.1).to(device) 
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=0)  


checkpoint_to_load = "../results/models_5l/u&c_generator_epoch_20.pt"
start_epoch = 0

if os.path.exists(checkpoint_to_load):
    checkpoint = torch.load(checkpoint_to_load, map_location=device, weights_only=True)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    last_loss = checkpoint['loss']
    
    print(f" Resuming training from epoch {start_epoch + 1}. Last loss was {last_loss:.4f}")
else:
    print(" No checkpoint found. Starting training from scratch.")

p_uncond = 0.1 
TOTAL_EPOCHS = 50

if torch.max(dataset.encoded_sequences) >= vocab_size:
    print(" Token index out of range! Check vocab size.")

for epoch in range(start_epoch, TOTAL_EPOCHS):
    model.train()
    total_loss = 0
    
    batch_iterator = tqdm(
        enumerate(dataloader), 
        desc=f"Epoch {epoch+1}/{TOTAL_EPOCHS}", 
        total=len(dataloader)
    )

    for i, (seqs, props) in batch_iterator:
        seqs, props = seqs.to(device), props.to(device)

        if torch.rand(1).item() < p_uncond:
            props = torch.zeros_like(props)

        inputs = seqs[:, :-1]
        targets = seqs[:, 1:]
        
        max_token = torch.max(seqs)
        if max_token >= vocab_size:
            print(f" Token {max_token} >= vocab_size {vocab_size}")
            raise ValueError("Token index out of range!")

        logits = model(inputs, props) 
        logits = logits.reshape(-1, vocab_size)
        targets = targets.reshape(-1)
        
        loss = criterion(logits, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        running_avg_loss = total_loss / (i + 1)
        batch_iterator.set_postfix(avg_loss=f"{running_avg_loss:.4f}")
    
    avg_loss = total_loss / len(dataloader) 
    print()

    # save checkpoint every 3 epochs
    if (epoch + 1) % 3 == 0:
        checkpoint_path = f"../results/models_5l/u&c_generator_epoch_{epoch+1}.pt"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss
        }, checkpoint_path)
        print(f"--- Checkpoint saved to {checkpoint_path} ---")

print(" Training complete!")

Vocab size = 70
Using: cuda




‚úÖ Resuming training from epoch 21. Last loss was 0.7162


Epoch 21/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:55<00:00, 14.73it/s, avg_loss=0.7124]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_21.pt ---


Epoch 22/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [05:18<00:00, 13.67it/s, avg_loss=0.7077]





Epoch 23/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [05:19<00:00, 13.66it/s, avg_loss=0.7045]





Epoch 24/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:52<00:00, 14.89it/s, avg_loss=0.7015]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_24.pt ---


Epoch 25/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:52<00:00, 14.90it/s, avg_loss=0.6974]





Epoch 26/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:53<00:00, 14.85it/s, avg_loss=0.6940]





Epoch 27/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:55<00:00, 14.75it/s, avg_loss=0.6924]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_27.pt ---


Epoch 28/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:49<00:00, 15.08it/s, avg_loss=0.6890]





Epoch 29/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:47<00:00, 15.15it/s, avg_loss=0.6865]





Epoch 30/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.24it/s, avg_loss=0.6839]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_30.pt ---


Epoch 31/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.26it/s, avg_loss=0.6817]





Epoch 32/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.24it/s, avg_loss=0.6789]





Epoch 33/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:47<00:00, 15.14it/s, avg_loss=0.6773]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_33.pt ---


Epoch 34/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:52<00:00, 14.91it/s, avg_loss=0.6745]





Epoch 35/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:52<00:00, 14.92it/s, avg_loss=0.6725]





Epoch 36/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:49<00:00, 15.07it/s, avg_loss=0.6708]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_36.pt ---


Epoch 37/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:46<00:00, 15.21it/s, avg_loss=0.6687]





Epoch 38/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:47<00:00, 15.16it/s, avg_loss=0.6674]





Epoch 39/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:48<00:00, 15.13it/s, avg_loss=0.6653]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_39.pt ---


Epoch 40/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:47<00:00, 15.17it/s, avg_loss=0.6632]





Epoch 41/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:47<00:00, 15.18it/s, avg_loss=0.6623]





Epoch 42/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:46<00:00, 15.19it/s, avg_loss=0.6604]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_42.pt ---


Epoch 43/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:47<00:00, 15.19it/s, avg_loss=0.6589]





Epoch 44/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:46<00:00, 15.22it/s, avg_loss=0.6574]





Epoch 45/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:50<00:00, 14.99it/s, avg_loss=0.6563]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_45.pt ---


Epoch 46/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:49<00:00, 15.04it/s, avg_loss=0.6546]





Epoch 47/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:48<00:00, 15.09it/s, avg_loss=0.6521]





Epoch 48/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.24it/s, avg_loss=0.6513]



--- Checkpoint saved to ../results/models_5l/u&c_generator_epoch_48.pt ---


Epoch 49/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.25it/s, avg_loss=0.6481]





Epoch 50/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4359/4359 [04:45<00:00, 15.25it/s, avg_loss=0.6450]


‚úÖ Training complete!





In [6]:
checkpoint_path="../results/models_5l/u&c_generator_epoch_50.pt"
torch.save({
            'epoch': 50,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss
        }, checkpoint_path)

# Testing 50 epoch trained generator

In [None]:
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = 70
prop_dim = 5
max_len_model = 128

model = Generator(vocab_size, prop_dim, max_len=max_len_model,dropout=0.1).to(device)

checkpoint = torch.load("../results/models_5l/u&c_generator_epoch_50.pt", map_location=device,weights_only=True) 
model.load_state_dict(checkpoint["model_state_dict"])
print(f" Loaded model from epoch {checkpoint['epoch']} with loss {checkpoint['loss']:.4f}")
model.eval()

test_props = torch.tensor([[0.6460672474469629,0.13202386890910442,0.6084964566032581,0.012271807687701365,0.011000075433875664]], dtype=torch.float32).to(device)
print(" Generating with conditions:", test_props)

props_to_use = test_props 

start_token_id = 1  # <START> token ID
stop_token_id = 69   # <END> token ID
max_gen_len = 128
top_k = 50

generated = torch.tensor([[start_token_id]], dtype=torch.long).to(device)

with torch.no_grad():
    for _ in range(max_gen_len):
        # Model only needs to see the current sequence and target properties
        logits = model(generated, props_to_use)
        
        # Get logits for the *last* token only
        last_logits = logits[:, -1, :] # Shape: [batch, vocab_size]

        v, _ = torch.topk(last_logits, top_k)
        last_logits[last_logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(last_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        
        # 3. Append the new token
        generated = torch.cat([generated, next_token], dim=1)
        
        # 4. Stop if we hit the <END> token
        if next_token.item() == stop_token_id: 
            break

print("Generated token sequence:\n", generated.cpu().numpy().tolist()[0])

token_to_idx = {
    "#": 2, "%": 3, "(": 4, ")": 5, "+": 6, "-": 7, ".": 8, "/": 9, "0": 10, "1": 11, "2": 12, "3": 13,
    "4": 14, "5": 15, "6": 16, "7": 17, "8": 18, "9": 19, "=": 20, "@": 21, "A": 22, "B": 23, "C": 24,
    "D": 25, "E": 26, "F": 27, "G": 28, "H": 29, "I": 30, "K": 31, "L": 32, "M": 33, "N": 34, "O": 35,
    "P": 36, "R": 37, "S": 38, "T": 39, "U": 40, "V": 41, "W": 42, "X": 43, "Y": 44, "Z": 45, "[": 46,
    "\\": 47, "]": 48, "a": 49, "b": 50, "c": 51, "d": 52, "e": 53, "f": 54, "g": 55, "h": 56, "i": 57,
    "k": 58, "l": 59, "m": 60, "n": 61, "o": 62, "p": 63, "r": 64, "s": 65, "t": 66, "u": 67,
    "y": 68, "<PAD>": 0, "<START>": 1, "<END>": 69}

idx_to_token = {v: k for k, v in token_to_idx.items()}
generated_seq = generated.cpu().numpy().tolist()[0]
decoded = ''.join(idx_to_token.get(tok, '?') for tok in generated_seq if tok not in [0, 1, 68])
print("üß¨ Decoded SMILES:", decoded)



‚úÖ Loaded model from epoch 50 with loss 0.6450
üß¨ Generating with conditions: tensor([[0.6461, 0.1320, 0.6085, 0.0123, 0.0110]], device='cuda:0')
Generated token sequence:
 [1, 38, 20, 38, 4, 20, 35, 5, 4, 35, 5, 34, 51, 11, 51, 51, 51, 4, 34, 24, 51, 12, 51, 51, 51, 51, 51, 12, 5, 51, 51, 11, 69]
üß¨ Decoded SMILES: S=S(=O)(O)Nc1ccc(NCc2ccccc2)cc1<END>
