In [1]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
from rdkit import Chem, rdBase
from rdkit.Chem import Descriptors
from src import sascorer

In [None]:
rdBase.DisableLog('rdApp.error')

In [4]:
class Generator(nn.Module):
    def __init__(self, vocab_size, prop_dim, d_model=256, nhead=8, num_layers=4, max_len=128, dropout=0.1): 
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.prop_embed = nn.Linear(prop_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=512,
            batch_first=False,
            dropout=dropout  
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, props):
        src = torch.clamp(src, 0, self.token_embed.num_embeddings - 1)
        B, L = src.shape
        tok_emb = self.token_embed(src) * (self.d_model ** 0.5)
        pos = torch.arange(L, device=src.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        prop_emb = self.prop_embed(props).unsqueeze(1)
        
        x = tok_emb + pos_emb + prop_emb
        x = self.dropout(x) 
        x = x.transpose(0, 1)  
        
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(src.device)
        out = self.transformer(x, mask=tgt_mask)
        
        out = out.transpose(0, 1) 
        logits = self.fc_out(out)
        return logits


In [5]:
def get_token_maps():
    """Returns the token-to-ID and ID-to-token maps."""
    token_to_idx = {
    "#": 2, "%": 3, "(": 4, ")": 5, "+": 6, "-": 7, ".": 8, "/": 9, "0": 10, "1": 11, "2": 12, "3": 13,
    "4": 14, "5": 15, "6": 16, "7": 17, "8": 18, "9": 19, "=": 20, "@": 21, "A": 22, "B": 23, "C": 24,
    "D": 25, "E": 26, "F": 27, "G": 28, "H": 29, "I": 30, "K": 31, "L": 32, "M": 33, "N": 34, "O": 35,
    "P": 36, "R": 37, "S": 38, "T": 39, "U": 40, "V": 41, "W": 42, "X": 43, "Y": 44, "Z": 45, "[": 46,
    "\\": 47, "]": 48, "a": 49, "b": 50, "c": 51, "d": 52, "e": 53, "f": 54, "g": 55, "h": 56, "i": 57,
    "k": 58, "l": 59, "m": 60, "n": 61, "o": 62, "p": 63, "r": 64, "s": 65, "t": 66, "u": 67,
    "y": 68, "<PAD>": 0, "<START>": 1, "<END>": 69}
    idx_to_token = {v: k for k, v in token_to_idx.items()}
    return token_to_idx, idx_to_token

In [None]:
def load_model(checkpoint_path, device):
    """Loads the generator model from a checkpoint."""
    # Model parameters
    vocab_size = 70
    prop_dim = 5
    max_len_model = 128
    
    model = Generator(vocab_size, prop_dim, max_len=max_len_model, dropout=0.1).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    
    print(f"âœ… Loaded model from epoch {checkpoint['epoch']} with loss {checkpoint['loss']:.4f}")
    return model

In [None]:
def load_training_smiles(training_data_path):
    """Loads the canonical SMILES from the training set for novelty check."""
    print(f"Loading training data from {training_data_path} for novelty check...")
    df = pd.read_csv(training_data_path)
    training_smiles = set(df['canonical'].tolist())
    print(f"Found {len(training_smiles)} unique training SMILES.")
    return training_smiles

def decode_smiles(tensor, idx_to_token):
    """Decodes a tensor of IDs into a SMILES string."""
    smiles_list = []
    for row in tensor:
        smi = ""
        for idx in row:
            idx = idx.item()
            if idx == 1: 
                continue
            if idx == 69: 
                break
            if idx == 0: 
                break
            smi += idx_to_token.get(idx, '?')
        smiles_list.append(smi)
    return smiles_list

In [None]:
def generate_and_validate(model, props_to_use, token_maps, device, num_to_gen, gen_batch_size, training_smiles):
    
    _, idx_to_token = token_maps
    start_token_id = 1
    stop_token_id = 69
    max_gen_len = 128
    top_k = 50

    all_generated_smiles = []
    all_valid_smiles = []
    
    # Store calculated properties
    props_calculated = {
        'QED': [], 'LogP': [], 'MolWt': [], 'TPSA': [] , 'SAS': []
    }
    
    # Calculate how many batches we need
    num_batches = int(np.ceil(num_to_gen / gen_batch_size))

    for _ in tqdm(range(num_batches), desc="Generating molecules"):
        # Generate a batch
        generated = torch.tensor([[start_token_id]] * gen_batch_size, dtype=torch.long).to(device)
        
        props_batch = props_to_use.repeat(gen_batch_size, 1)

        with torch.no_grad():
            for _ in range(max_gen_len):
                logits = model(generated, props_batch)
                last_logits = logits[:, -1, :] 
                
                v, _ = torch.topk(last_logits, top_k)
                last_logits[last_logits < v[:, [-1]]] = -float('Inf')
                
                probs = F.softmax(last_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated = torch.cat([generated, next_token], dim=1)
                
                if (next_token == stop_token_id).all():
                    break
        
        # Decode and Validate the batch
        decoded_batch = decode_smiles(generated, idx_to_token)
        all_generated_smiles.extend(decoded_batch)

        for smi in decoded_batch:
            mol = Chem.MolFromSmiles(smi)
            if mol: # If molecule is valid
                all_valid_smiles.append(smi)
                props_calculated['QED'].append(Descriptors.qed(mol))
                props_calculated['LogP'].append(Descriptors.MolLogP(mol))
                props_calculated['MolWt'].append(Descriptors.MolWt(mol))
                props_calculated['TPSA'].append(Descriptors.TPSA(mol))
                props_calculated['SAS'].append(sascorer.calculateScore(mol))

    # Calculate Metrics
    num_generated = len(all_generated_smiles)
    num_valid = len(all_valid_smiles)
    
    if num_valid == 0:
        print(" No valid molecules were generated.")
        return

    validity = num_valid / num_generated
    
    # Uniqueness
    valid_set = set(all_valid_smiles)
    uniqueness = len(valid_set) / num_valid

    # Novelty
    novel_smiles = [s for s in valid_set if s not in training_smiles]
    novelty = len(novel_smiles) / len(valid_set)

    # Print Report
    print("\n--- METRICS ---")
    print(f"Total Generated: {num_generated}")
    print(f" Validity:     {validity * 100:.2f}% ({num_valid} molecules)")
    print(f" Uniqueness:   {uniqueness * 100:.2f}% ({len(valid_set)} unique)")
    print(f" Novelty:      {novelty * 100:.2f}% ({len(novel_smiles)} novel)")

    # Print Property Report
    print("\n--- PROPERTY ANALYSIS (of valid molecules) ---")
    print(f"Target Props (Normalized): {props_to_use[0].cpu().numpy()}")
    print("Actual Props (Un-normalized, Avg):")
    for key, values in props_calculated.items():
        if values:
            print(f"  {key}: {np.mean(values):.4f}")
            
    print("\n" + "="*40 + "\n")

In [None]:
CHECKPOINT_PATH = "../results/models_5l/u&c_generator_epoch_50.pt" 
TRAINING_DATA_PATH = "../data/processed_5l/train_properties.csv" 

NUM_TO_GENERATE = 1000
GEN_BATCH_SIZE = 32 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model = load_model(CHECKPOINT_PATH, device)
token_maps = get_token_maps()
training_smiles = load_training_smiles(TRAINING_DATA_PATH)

cuda
âœ… Loaded model from epoch 50 with loss 0.6450
Loading training data from ../data/processed_5l/train_properties.csv for novelty check...




Found 278937 unique training SMILES.


In [None]:
# Test 1: Conditional Generation
target_props = torch.tensor([[
    0.8781114617473635,0.07833406806496232,0.6368721378699783,0.03738495242074764,0.09905331048030695 # a sample from testing dataset
]], dtype=torch.float32).to(device)

print("--- 1. VALIDATING: Conditional Generation ---")
generate_and_validate(
    model, 
    props_to_use=target_props,
    token_maps=token_maps,
    device=device,
    num_to_gen=NUM_TO_GENERATE,
    gen_batch_size=GEN_BATCH_SIZE,
    training_smiles=training_smiles
)

--- 1. VALIDATING: Conditional Generation ---


Generating molecules: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:24<00:00,  1.29it/s]


--- METRICS ---
Total Generated: 1024
âœ… Validity:     75.78% (776 molecules)
âœ¨ Uniqueness:   100.00% (776 unique)
ðŸ’¡ Novelty:      99.74% (774 novel)

--- PROPERTY ANALYSIS (of valid molecules) ---
Target Props (Normalized): [0.8781115  0.07833407 0.6368721  0.03738495 0.09905331]
Actual Props (Un-normalized, Avg):
  QED: 0.3667
  LogP: 3.3454
  MolWt: 463.5802
  TPSA: 140.1037
  SAS: 2.3691







In [None]:
# Test 2: Unconditional Generation
uncond_props = torch.zeros((1, 5), dtype=torch.float32).to(device)

print("VALIDATING: Unconditional Generation ---")
generate_and_validate(
    model, 
    props_to_use=uncond_props,
    token_maps=token_maps,
    device=device,
    num_to_gen=NUM_TO_GENERATE,
    gen_batch_size=GEN_BATCH_SIZE,
    training_smiles=training_smiles
)

print("Validation complete.")

--- 2. VALIDATING: Unconditional Generation ---


Generating molecules: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 32/32 [00:24<00:00,  1.33it/s]


--- METRICS ---
Total Generated: 1024
âœ… Validity:     81.35% (833 molecules)
âœ¨ Uniqueness:   99.76% (831 unique)
ðŸ’¡ Novelty:      92.06% (765 novel)

--- PROPERTY ANALYSIS (of valid molecules) ---
Target Props (Normalized): [0. 0. 0. 0. 0.]
Actual Props (Un-normalized, Avg):
  QED: 0.5285
  LogP: 2.5194
  MolWt: 298.0055
  TPSA: 65.2163
  SAS: 2.9442


âœ… Validation complete.



