In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import numpy as np
import os, gc
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset

In [2]:
# --- 1. Generator Class ---
# We need this to load your pre-trained model
class Generator(nn.Module):
    def __init__(self, vocab_size, prop_dim, d_model=256, nhead=8, num_layers=4, max_len=128, dropout=0.1): 
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.prop_embed = nn.Linear(prop_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=512, batch_first=False, dropout=dropout  
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, props):
        src = torch.clamp(src, 0, self.token_embed.num_embeddings - 1)
        B, L = src.shape
        tok_emb = self.token_embed(src) * (self.d_model ** 0.5)
        pos = torch.arange(L, device=src.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        prop_emb = self.prop_embed(props).unsqueeze(1)
        
        x = tok_emb + pos_emb + prop_emb
        x = self.dropout(x) 
        x = x.transpose(0, 1)  # Transformer expects [L, B, D]
        
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(src.device)
        out = self.transformer(x, mask=tgt_mask)
        
        out = out.transpose(0, 1) 
        logits = self.fc_out(out)
        return logits

# --- 2. Discriminator Class (NEW) ---
# This is an "Encoder" as you planned: it's a Transformer Encoder *without* a mask
class Discriminator(nn.Module):
    def __init__(self, vocab_size, prop_dim, d_model=256, nhead=8, num_layers=4, max_len=128, dropout=0.1): 
        super().__init__() # <-- FIX: __init__
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.prop_embed = nn.Linear(prop_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=512, batch_first=False, dropout=dropout  
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, 1) # Output 1 logit for (Real/Fake)

    def forward(self, src, props):
        src = torch.clamp(src, 0, self.token_embed.num_embeddings - 1)
        B, L = src.shape
        
        # This logic exactly matches your Generator's input prep
        tok_emb = self.token_embed(src) * (self.d_model ** 0.5)
        pos = torch.arange(L, device=src.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        prop_emb = self.prop_embed(props).unsqueeze(1)
        
        x = tok_emb + pos_emb + prop_emb
        x = self.dropout(x)
        x = x.transpose(0, 1) # [L, B, D]
        
        # --- CRITICAL DIFFERENCE ---
        # No mask! The Discriminator sees the *entire* sequence to judge it.
        out = self.transformer(x) # Shape: [L, B, D]
        
        # Global Pooling: We use the embedding of the first token (<START>)
        # as a [CLS] token to represent the entire sequence.
        pooled_output = out[0, :, :] # Shape: [B, D]
        
        logit = self.fc_out(pooled_output) 
        return logit.squeeze(-1) # Output shape: [batch]

# --- 3. Token Maps ---
def get_token_maps():
    token_to_idx = {
    "#": 2, "%": 3, "(": 4, ")": 5, "+": 6, "-": 7, ".": 8, "/": 9, "0": 10, "1": 11, "2": 12, "3": 13,
    "4": 14, "5": 15, "6": 16, "7": 17, "8": 18, "9": 19, "=": 20, "@": 21, "A": 22, "B": 23, "C": 24,
    "D": 25, "E": 26, "F": 27, "G": 28, "H": 29, "I": 30, "K": 31, "L": 32, "M": 33, "N": 34, "O": 35,
    "P": 36, "R": 37, "S": 38, "T": 39, "U": 40, "V": 41, "W": 42, "X": 43, "Y": 44, "Z": 45, "[": 46,
    "\\": 47, "]": 48, "a": 49, "b": 50, "c": 51, "d": 52, "e": 53, "f": 54, "g": 55, "h": 56, "i": 57,
    "k": 58, "l": 59, "m": 60, "n": 61, "o": 62, "p": 63, "r": 64, "s": 65, "t": 66, "u": 67,
    "y": 68, "<PAD>": 0, "<START>": 1, "<END>": 69}
    idx_to_token = {v: k for k, v in token_to_idx.items()}
    return token_to_idx, idx_to_token

# --- 4. Dataset---
class MoleculeDataset(Dataset):
    def __init__(self, encoded_path, properties_csv):
        super().__init__() # <-- FIX: __init__
        
        # Load property vectors
        self.props_df = pd.read_csv(properties_csv)
        self.prop_columns = ['QED', 'SAS', 'LogP', 'TPSA', 'MolWt']
        self.properties = torch.tensor(
            self.props_df[self.prop_columns].values,
            dtype=torch.float
        )

        # Load encoded token sequences
        print(f"Loading real encoded sequences from {encoded_path}...")
        self.encoded_sequences = torch.load(encoded_path, weights_only=True)
        print("✅ Loaded encoded sequences.")
        
        assert len(self.encoded_sequences) == len(self.properties), "Mismatch: sequences vs properties"

    def __len__(self): # <-- FIX: __len__
        return len(self.encoded_sequences)

    def __getitem__(self, idx): # <-- FIX: __getitem__
        seq_tensor = self.encoded_sequences[idx]
        prop_tensor = self.properties[idx]
        return seq_tensor, prop_tensor

# --- 5. Correct Fake Sample Generator ---
# This generates *full, autoregressive* samples
def generate_fake_samples(generator, props_to_use, batch_size, device, token_maps, max_len=128):
    generator.eval() # Generator is for inference here
    
    token_to_idx, _ = token_maps
    start_token_id = token_to_idx['<START>']
    stop_token_id = token_to_idx['<END>']
    pad_token_id = token_to_idx['<PAD>']
    
    top_k = 50

    generated_seqs = torch.tensor([[start_token_id]] * batch_size, dtype=torch.long).to(device)
    
    with torch.no_grad():
        for _ in range(max_len - 1): # -1 for the start token
            logits = generator(generated_seqs, props_to_use)
            last_logits = logits[:, -1, :]
            
            v, _ = torch.topk(last_logits, top_k)
            last_logits[last_logits < v[:, [-1]]] = -float('Inf')
            
            probs = F.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_seqs = torch.cat([generated_seqs, next_token], dim=1)
            
            # Stop if all sequences have hit the <END> token
            if (next_token.squeeze() == stop_token_id).all():
                break
    
    # --- CRITICAL: Pad all sequences to max_len ---
    B, L = generated_seqs.shape
    if L < max_len:
        pads = torch.full((B, max_len - L), pad_token_id, dtype=torch.long, device=device)
        generated_seqs = torch.cat([generated_seqs, pads], dim=1)
    
    return generated_seqs

In [None]:
# --- Setup Paths and Hyperparameters ---
CHECKPOINT_DIR = "../results/models_5l/"
TRAIN_ENCODED_PATH = "../data/processed_5l/train_encoded.pt"
TRAIN_PROPERTIES_CSV = "../data/processed_5l/train_properties.csv"
GEN_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "u&c_generator_epoch_50.pt") 

# Model Hyperparameters (Must match your Generator)
VOCAB_SIZE = 70 
PROP_DIM = 5
D_MODEL = 256
N_HEAD = 8
NUM_LAYERS = 4
MAX_LEN = 128
DROPOUT = 0.1

# Training Hyperparameters
TOTAL_EPOCHS = 10  # Pre-training the discriminator is usually fast (5-10 epochs)
LR = 1e-4
BATCH_SIZE = 64  # Using the same batch size as your last G-run
p_uncond = 0.1   # Must match the Generator's training!

# --- 1. Data and Device Setup ---
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
gc.collect()
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

dataset = MoleculeDataset(encoded_path=TRAIN_ENCODED_PATH, properties_csv=TRAIN_PROPERTIES_CSV)
dataset.encoded_sequences = dataset.encoded_sequences.long() 
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

token_to_idx, idx_to_token = get_token_maps()
token_maps = (token_to_idx, idx_to_token)

# --- 2. Load Pre-trained Generator (for generating fake samples) ---
generator = Generator(
    vocab_size=VOCAB_SIZE, prop_dim=PROP_DIM, d_model=D_MODEL, nhead=N_HEAD, 
    num_layers=NUM_LAYERS, max_len=MAX_LEN, dropout=DROPOUT
).to(device)

gen_checkpoint = torch.load(GEN_CHECKPOINT_PATH, map_location=device, weights_only=True)
generator.load_state_dict(gen_checkpoint['model_state_dict'])
generator.eval() # Set to eval mode, we are only using it for inference
print(f" Loaded Generator from epoch {gen_checkpoint['epoch']}")

# --- 3. Initialize Discriminator and Optimizer ---
discriminator = Discriminator(
    vocab_size=VOCAB_SIZE, prop_dim=PROP_DIM, d_model=D_MODEL, nhead=N_HEAD, 
    num_layers=NUM_LAYERS, max_len=MAX_LEN, dropout=DROPOUT
).to(device) 

optimizer_D = optim.Adam(discriminator.parameters(), lr=LR)
criterion = nn.BCEWithLogitsLoss() # Standard for GANs
start_epoch = 0

print(" Setup complete. Starting discriminator training...")

Using device: cuda
Loading real encoded sequences from ../data/processed_5l/train_encoded.pt...
✅ Loaded encoded sequences.




✅ Loaded Generator from epoch 50
✅ Setup complete. Starting discriminator training...


In [6]:
TOTAL_EPOCHS = 3
TOTAL_EPOCHS

3

In [7]:
# --- 4. Training Loop ---
for epoch in range(start_epoch, TOTAL_EPOCHS):
    discriminator.train()
    total_loss = 0
    
    batch_iterator = tqdm(
        enumerate(dataloader), 
        desc=f"Disc Epoch {epoch+1}/{TOTAL_EPOCHS}", 
        total=len(dataloader)
    )

    for i, (real_seqs, real_props) in batch_iterator:
        real_seqs, real_props = real_seqs.to(device), real_props.to(device)
        batch_size = real_seqs.size(0)
        
        # --- Handle incomplete last batch ---
        if batch_size != BATCH_SIZE:
            continue # Skip batches that aren't full
            
        # --- Conditional Dropping ---
        # Teach the Discriminator to also be an unconditional critic
        props_to_use = real_props
        if torch.rand(1).item() < p_uncond:
            props_to_use = torch.zeros_like(real_props)
        
        # --- A. Train on Real Samples (Target Label = 1.0) ---
        discriminator.zero_grad()
        real_logits = discriminator(real_seqs, props_to_use)
        real_labels = torch.ones(batch_size, device=device) 
        loss_real = criterion(real_logits, real_labels)
        loss_real.backward() # Calculate gradients

        # --- B. Train on Fake Samples (Target Label = 0.0) ---
        
        # Generate full, autoregressive fake samples
        fake_seqs = generate_fake_samples(
            generator, props_to_use, batch_size, device, token_maps, max_len=MAX_LEN
        )
        
        # Use .detach() to stop gradients from flowing back into the generator
        fake_logits = discriminator(fake_seqs.detach(), props_to_use)
        fake_labels = torch.zeros(batch_size, device=device)
        loss_fake = criterion(fake_logits, fake_labels)
        loss_fake.backward() # Calculate gradients
        
        # --- C. Update Discriminator ---
        # The optimizer steps *after* both .backward() calls
        optimizer_D.step()
        
        loss_D = loss_real + loss_fake
        total_loss += loss_D.item()
        
        running_avg_loss = total_loss / (i + 1)
        batch_iterator.set_postfix(
            avg_loss=f"{running_avg_loss:.4f}", 
            real_loss=f"{loss_real.item():.4f}", 
            fake_loss=f"{loss_fake.item():.4f}"
        )
    
    avg_loss = total_loss / len(dataloader) 
    print()

    # --- 5. Save Checkpoint ---
    checkpoint_path = f"{CHECKPOINT_DIR}discriminator_epoch_{epoch+1}.pt"
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': discriminator.state_dict(),
        'optimizer_state_dict': optimizer_D.state_dict(),
        'loss': avg_loss
    }, checkpoint_path)
    print(f"--- Checkpoint saved to {checkpoint_path} ---")

print("✅ Discriminator Pre-training complete!")

Disc Epoch 1/3: 100%|██████████| 4359/4359 [2:09:19<00:00,  1.78s/it, avg_loss=0.0280, fake_loss=0.0109, real_loss=0.0102]  



--- Checkpoint saved to ../results/models_5l/discriminator_epoch_1.pt ---


Disc Epoch 2/3:   3%|▎         | 141/4359 [04:01<2:00:19,  1.71s/it, avg_loss=0.0267, fake_loss=0.0300, real_loss=0.0257]


KeyboardInterrupt: 

In [8]:
@torch.no_grad() # Disable gradients for this entire function
def validate_discriminator(discriminator, generator, val_dataloader, device, token_maps, p_uncond=0.1):
    discriminator.eval()
    generator.eval()
    
    total_real_correct = 0
    total_fake_correct = 0
    total_samples = 0
    
    batch_iterator = tqdm(val_dataloader, desc="Validating Discriminator")

    for real_seqs, real_props in batch_iterator:
        real_seqs, real_props = real_seqs.to(device), real_props.to(device)
        batch_size = real_seqs.size(0)
        
        # --- Handle incomplete last batch ---
        if batch_size != val_dataloader.batch_size:
            continue
            
        total_samples += batch_size

        # --- Decide on properties (conditional or unconditional) ---
        props_to_use = real_props
        if torch.rand(1).item() < p_uncond:
            props_to_use = torch.zeros_like(real_props)

        # --- 1. Test on REAL data ---
        # We expect the logit to be positive (> 0)
        real_logits = discriminator(real_seqs, props_to_use)
        total_real_correct += (real_logits > 0).sum().item()

        # --- 2. Test on FAKE data ---
        fake_seqs = generate_fake_samples(
            generator, props_to_use, batch_size, device, token_maps
        )
        
        # We expect the logit to be negative (< 0)
        fake_logits = discriminator(fake_seqs, props_to_use)
        total_fake_correct += (fake_logits < 0).sum().item()
        
        # Update progress bar
        real_acc = (total_real_correct / total_samples) * 100
        fake_acc = (total_fake_correct / total_samples) * 100
        batch_iterator.set_postfix(
            real_acc=f"{real_acc:.2f}%", 
            fake_acc=f"{fake_acc:.2f}%"
        )
    
    print("\n--- Validation Results ---")
    final_real_acc = (total_real_correct / total_samples) * 100
    final_fake_acc = (total_fake_correct / total_samples) * 100
    final_total_acc = ((total_real_correct + total_fake_correct) / (total_samples * 2)) * 100
    
    print(f" Real Accuracy: {final_real_acc:.2f}%")
    print(f" Fake Accuracy: {final_fake_acc:.2f}%")
    print(f" Total Accuracy: {final_total_acc:.2f}%")
    print("\n(This is the discriminator's performance on unseen data)")

In [9]:
# --- 1. Define Validation Data Paths ---
VAL_ENCODED_PATH = "../data/processed_5l/val_encoded.pt"
VAL_PROPERTIES_CSV = "../data/processed_5l/val_properties.csv"

# --- 2. Load the Validation Dataset ---
try:
    val_dataset = MoleculeDataset(
        encoded_path=VAL_ENCODED_PATH, 
        properties_csv=VAL_PROPERTIES_CSV
    )
    val_dataloader = DataLoader(
        val_dataset, 
        batch_size=BATCH_SIZE, # Use BATCH_SIZE defined in the training cell
        shuffle=False
    )
    print(f" Loaded validation dataset from {VAL_PROPERTIES_CSV}")
except FileNotFoundError:
    print(f" ERROR: Validation data not found. Skipping validation.")
    print(f"    - Searched for: {VAL_ENCODED_PATH}")
    print(f"    - Searched for: {VAL_PROPERTIES_CSV}")
    val_dataloader = None

# --- 3. Load the Final Discriminator Checkpoint ---
final_disc_epoch = TOTAL_EPOCHS 
DISC_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "discriminator_epoch_1.pt")

if os.path.exists(DISC_CHECKPOINT_PATH) and val_dataloader:
    print(f"Loading final discriminator from: {DISC_CHECKPOINT_PATH}")
    
    # Load the checkpoint
    disc_checkpoint = torch.load(DISC_CHECKPOINT_PATH, map_location=device, weights_only=True)
    
    # Load the weights into the 'discriminator' model object from the training cell
    discriminator.load_state_dict(disc_checkpoint['model_state_dict'])
    print(f" Loaded Discriminator from epoch {disc_checkpoint['epoch']} for validation.")

    # --- 4. Run Validation ---
    # The 'generator' object is already loaded and in eval mode
    # The 'discriminator' object is now loaded
    # The 'val_dataloader', 'device', 'token_maps', and 'p_uncond' are all defined
    
    validate_discriminator(
        discriminator=discriminator, 
        generator=generator, 
        val_dataloader=val_dataloader, 
        device=device, 
        token_maps=token_maps, 
        p_uncond=p_uncond
    )
else:
    if not val_dataloader:
        print("Skipping validation because validation data was not found.")
    else:
        print(f" ERROR: Final discriminator checkpoint not found at {DISC_CHECKPOINT_PATH}")
        print("Skipping validation.")

Loading real encoded sequences from ../data/processed_5l/val_encoded.pt...
✅ Loaded encoded sequences.
 Loaded validation dataset from ../data/processed_5l/val_properties.csv
Loading final discriminator from: ../results/models_5l/discriminator_epoch_1.pt
 Loaded Discriminator from epoch 1 for validation.


Validating Discriminator: 100%|██████████| 243/243 [06:41<00:00,  1.65s/it, fake_acc=98.20%, real_acc=99.16%]


--- Validation Results ---
 Real Accuracy: 99.16%
 Fake Accuracy: 98.20%
 Total Accuracy: 98.68%

(This is the discriminator's performance on unseen data)



