In [1]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import os, gc
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import json
from rdkit import Chem, rdBase
from rdkit.Chem import Descriptors
from src import sascorer

rdBase.DisableLog('rdApp.error')

In [None]:
class Generator(nn.Module):
    def __init__(self, vocab_size, prop_dim, d_model=256, nhead=8, num_layers=4, max_len=128, dropout=0.1): 
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.prop_embed = nn.Linear(prop_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=512, batch_first=False, dropout=dropout  
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, props):
        src = torch.clamp(src, 0, self.token_embed.num_embeddings - 1)
        B, L = src.shape
        tok_emb = self.token_embed(src) * (self.d_model ** 0.5)
        pos = torch.arange(L, device=src.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        prop_emb = self.prop_embed(props).unsqueeze(1)
        
        x = tok_emb + pos_emb + prop_emb
        x = self.dropout(x) 
        x = x.transpose(0, 1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(L).to(src.device)
        out = self.transformer(x, mask=tgt_mask)
        out = out.transpose(0, 1) 
        logits = self.fc_out(out)
        return logits

    def sample(self, props_to_use, token_maps, max_len=128, top_k=50):
        token_to_idx, _ = token_maps
        start_token_id = token_to_idx['<START>']
        stop_token_id  = token_to_idx['<END>']
        pad_token_id   = token_to_idx['<PAD>']
    
        batch_size = props_to_use.size(0)
        device = props_to_use.device
    
        generated_seqs = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        sum_log_probs = torch.zeros(batch_size, device=device)
    
        for t in range(max_len - 1):
    
            # Forward with grads ON → log_probs have grad_fn
            logits = self.forward(generated_seqs, props_to_use)
            last_logits = logits[:, -1, :]
    
            # Top-k filtering
            v, _ = torch.topk(last_logits, top_k, dim=-1)
            last_logits[last_logits < v[:, [-1]]] = -float("inf")
    
            probs = F.softmax(last_logits, dim=-1)
            log_probs = F.log_softmax(last_logits, dim=-1)
    
            # Non-differentiable sampling
            with torch.no_grad():
                next_token = torch.multinomial(probs, num_samples=1)
    
            # Log-prob of chosen token (differentiable)
            chosen_logprob = log_probs.gather(1, next_token).squeeze(1)
    
            not_finished = (generated_seqs[:, -1] != stop_token_id).float()
            sum_log_probs += chosen_logprob * not_finished
    
            generated_seqs = torch.cat([generated_seqs, next_token], dim=1)
    
            # detach so graph doesn't grow across timesteps
            generated_seqs = generated_seqs.detach()
    
            if not_finished.sum() == 0:
                break
    
        # Pad to max_len (no grad)
        with torch.no_grad():
            B, L = generated_seqs.shape
            if L < max_len:
                pads = torch.full((B, max_len - L), pad_token_id, dtype=torch.long, device=device)
                generated_seqs = torch.cat([generated_seqs, pads], dim=1)
    
        return generated_seqs, sum_log_probs



class Discriminator(nn.Module):
    def __init__(self, vocab_size, prop_dim, d_model=256, nhead=8, num_layers=4, max_len=128, dropout=0.1): 
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.prop_embed = nn.Linear(prop_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=512, batch_first=False, dropout=dropout  
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(d_model, 1)

    def forward(self, src, props):
        src = torch.clamp(src, 0, self.token_embed.num_embeddings - 1)
        B, L = src.shape
        tok_emb = self.token_embed(src) * (self.d_model ** 0.5)
        pos = torch.arange(L, device=src.device).unsqueeze(0)
        pos_emb = self.pos_embed(pos)
        prop_emb = self.prop_embed(props).unsqueeze(1)
        
        x = tok_emb + pos_emb + prop_emb
        x = self.dropout(x)
        x = x.transpose(0, 1)
        out = self.transformer(x)
        pooled_output = out[0, :, :]
        logit = self.fc_out(pooled_output) 
        return logit.squeeze(-1)


class PropertyDataset(Dataset):
    def __init__(self, properties_csv):
        super().__init__()
        self.props_df = pd.read_csv(properties_csv)
        self.prop_columns = ['QED', 'SAS', 'LogP', 'TPSA', 'MolWt']
        self.properties = torch.tensor(
            self.props_df[self.prop_columns].values,
            dtype=torch.float
        )
        print(f" Loaded {len(self.properties)} target properties.")

    def __len__(self):
        return len(self.properties)

    def __getitem__(self, idx):
        return self.properties[idx]


def get_token_maps():
    token_to_idx = {
    "#": 2, "%": 3, "(": 4, ")": 5, "+": 6, "-": 7, ".": 8, "/": 9, "0": 10, "1": 11, "2": 12, "3": 13,
    "4": 14, "5": 15, "6": 16, "7": 17, "8": 18, "9": 19, "=": 20, "@": 21, "A": 22, "B": 23, "C": 24,
    "D": 25, "E": 26, "F": 27, "G": 28, "H": 29, "I": 30, "K": 31, "L": 32, "M": 33, "N": 34, "O": 35,
    "P": 36, "R": 37, "S": 38, "T": 39, "U": 40, "V": 41, "W": 42, "X": 43, "Y": 44, "Z": 45, "[": 46,
    "\\": 47, "]": 48, "a": 49, "b": 50, "c": 51, "d": 52, "e": 53, "f": 54, "g": 55, "h": 56, "i": 57,
    "k": 58, "l": 59, "m": 60, "n": 61, "o": 62, "p": 63, "r": 64, "s": 65, "t": 66, "u": 67,
    "y": 68, "<PAD>": 0, "<START>": 1, "<END>": 69}
    idx_to_token = {v: k for k, v in token_to_idx.items()}
    return token_to_idx, idx_to_token


def decode_smiles(tensor, idx_to_token):
    smiles_list = []
    for row in tensor:
        smi = ""
        for idx in row:
            idx = idx.item()
            if idx == 1: continue
            if idx == 69: break
            if idx == 0:  break
            smi += idx_to_token.get(idx, '?')
        smiles_list.append(smi)
    return smiles_list


def calculate_properties(smiles_list, device):
    props = []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)
        if mol:
            try:
                qed = Descriptors.qed(mol)
                sas = sascorer.calculateScore(mol)
                logp = Descriptors.MolLogP(mol)
                tpsa = Descriptors.TPSA(mol)
                mw = Descriptors.MolWt(mol)
                props.append([qed, sas, logp, tpsa, mw])
            except:
                props.append([0.0] * 5)
        else:
            props.append([0.0] * 5)
    return torch.tensor(props, dtype=torch.float32).to(device)

In [None]:
CHECKPOINT_DIR = "../results/models_5l/"
TRAIN_PROPERTIES_CSV = "../data/processed_5l/train_properties.csv"
GEN_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "u&c_generator_epoch_50.pt") 
DISC_CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "discriminator_epoch_1.pt")

STATS_PATH = "../data/processed_5l/property_stats.json" 

VOCAB_SIZE = 70 
PROP_DIM = 5
D_MODEL = 256
N_HEAD = 8
NUM_LAYERS = 4
MAX_LEN = 128
DROPOUT = 0.1

# RL Training Hyperparameters
RL_STEPS = 5000
BATCH_SIZE = 8 
G_LEARNING_RATE = 1e-5
p_uncond = 0.1
W_DISC = 0.2
W_PROP = 0.8

# 1. Setup Device and Token Maps
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
gc.collect()
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
token_to_idx, idx_to_token = get_token_maps()
token_maps = (token_to_idx, idx_to_token)

# 2. Load Property Data
prop_dataset = PropertyDataset(properties_csv=TRAIN_PROPERTIES_CSV)
prop_dataloader = DataLoader(prop_dataset, batch_size=BATCH_SIZE, shuffle=True)
prop_iter = iter(prop_dataloader)


print(f"Loading property stats from {STATS_PATH}...")
try:
    with open(STATS_PATH, 'r') as f:
        loaded_stats = json.load(f)
    
    prop_cols = ['QED', 'SAS', 'LogP', 'TPSA', 'MolWt']
    prop_stats = {
        'min': torch.tensor([loaded_stats['min'][col] for col in prop_cols], dtype=torch.float32).to(device),
        'max': torch.tensor([loaded_stats['max'][col] for col in prop_cols], dtype=torch.float32).to(device)
    }
    prop_stats['range'] = (prop_stats['max'] - prop_stats['min']) + 1e-8 

    print(f"Loaded property stats (Min): {prop_stats['min'].cpu().numpy()}")
    print(f"Loaded property stats (Max):  {prop_stats['max'].cpu().numpy()}")
except FileNotFoundError:
    print(f" FATAL ERROR: prop_stats.json not found at {STATS_PATH}")
    print("Please re-run your preprocessing script to create this file.")
    # This will stop the notebook from proceeding
    raise

# 4. Load Pre-trained Generator
generator = Generator(
    vocab_size=VOCAB_SIZE, prop_dim=PROP_DIM, d_model=D_MODEL, nhead=N_HEAD, 
    num_layers=NUM_LAYERS, max_len=MAX_LEN, dropout=DROPOUT
).to(device)
gen_checkpoint = torch.load(GEN_CHECKPOINT_PATH, map_location=device, weights_only=True)
generator.load_state_dict(gen_checkpoint['model_state_dict'])
print(f" Loaded Generator from epoch {gen_checkpoint['epoch']}")

# 5. Load Pre-trained Discriminator
discriminator = Discriminator(
    vocab_size=VOCAB_SIZE, prop_dim=PROP_DIM, d_model=D_MODEL, nhead=N_HEAD, 
    num_layers=NUM_LAYERS, max_len=MAX_LEN, dropout=DROPOUT
).to(device)
disc_checkpoint = torch.load(DISC_CHECKPOINT_PATH, map_location=device, weights_only=True)

# 6. FREEZE Discriminator
discriminator.load_state_dict(disc_checkpoint['model_state_dict'])
discriminator.eval()
for param in discriminator.parameters():
    param.requires_grad = False
print(" Loaded Discriminator and set to eval() mode (weights frozen).")
    
# 7. Setup Generator Optimizer
optimizer_G = optim.Adam(generator.parameters(), lr=G_LEARNING_RATE)

Using device: cuda
 Loaded 278937 target properties.
Loading property stats from ../data/processed_5l/property_stats.json...
Loaded property stats (Min): [ 3.9431723e-03  1.0000000e+00 -8.7627800e+01  0.0000000e+00
  1.0080000e+00]
Loaded property stats (Max):  [9.4825125e-01 1.0000000e+01 5.9808720e+01 4.2015000e+03 1.8838697e+04]




 Loaded Generator from epoch 50
 Loaded Discriminator and set to eval() mode (weights frozen).


In [None]:
# RL FINE-TUNING LOOP
print(f"Using BATCH_SIZE = {BATCH_SIZE}.")
baseline = None  # EMA baseline

tqdm_iter = tqdm(range(RL_STEPS), desc="RL Step")

for step in tqdm_iter:
    generator.train()

    # 1. Target Properties
    try:
        target_props_raw = next(prop_iter)
    except StopIteration:
        prop_iter = iter(prop_dataloader)
        target_props_raw = next(prop_iter)

    if target_props_raw.shape[0] != BATCH_SIZE:
        prop_iter = iter(prop_dataloader)
        target_props_raw = next(prop_iter)

    target_props_raw = target_props_raw.to(device)

    # Normalize target props to [0,1]
    target_props = (target_props_raw - prop_stats['min']) / prop_stats['range']
    target_props = torch.clamp(target_props, 0.0, 1.0)

    # 2. Conditional dropping
    props_to_use = target_props
    if torch.rand(1).item() < p_uncond:
        props_to_use = torch.zeros_like(target_props)

    # 3. Generate fake molecules
    fake_seqs, sum_log_probs = generator.sample(props_to_use, token_maps, max_len=MAX_LEN)

    # 4. Rewards
    # Discriminator reward
    disc_logits = discriminator(fake_seqs.detach(), props_to_use)
    reward_D = torch.sigmoid(disc_logits)     # [B] in (0,1)

    # Decode SMILES
    smiles_list = decode_smiles(fake_seqs.detach(), idx_to_token)

    # RDKit props
    actual_props = calculate_properties(smiles_list, device)
    actual_props_norm = (actual_props - prop_stats['min']) / prop_stats['range']
    actual_props_norm = torch.clamp(actual_props_norm, 0.0, 1.0)

    # MSE per sample
    mse_per_item = F.mse_loss(actual_props_norm, target_props, reduction="none").mean(dim=1)

    # Proposed Property Reward
    # r = exp(-beta * mse), β=20
    beta = 20.0
    reward_P = torch.exp(-beta * mse_per_item)    # (0,1]

    # Validity reward
    valid_list = [1.0 if Chem.MolFromSmiles(s) else 0.0 for s in smiles_list]
    reward_valid = torch.tensor(valid_list, device=device, dtype=torch.float32)

    # TOTAL reward
    total_reward = (
        W_DISC * reward_D +
        W_PROP * reward_P +
        0.2 * reward_valid      # validity reward weight
    )

    # Advantage (baseline + normalization)
    if baseline is None:
        baseline = total_reward.mean().detach()
    else:
        baseline = 0.99 * baseline + 0.01 * total_reward.mean().detach()

    advantage = total_reward - baseline
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

    # POLICY GRADIENT LOSS
    policy_loss = - (sum_log_probs * advantage.detach()).mean()

    optimizer_G.zero_grad()
    policy_loss.backward()
    optimizer_G.step()

    # LOGGING
    if step % 50 == 0:
        frac_valid = float(sum(valid_list)) / len(valid_list)
        avg_mse = mse_per_item.mean().item()

        tqdm_iter.set_postfix(
            loss=f"{policy_loss.item():.4f}",
            val=f"{frac_valid:.2f}",
            mse=f"{avg_mse:.3f}",
            rD=f"{reward_D.mean().item():.3f}",
            rP=f"{reward_P.mean().item():.3f}",
            tot=f"{total_reward.mean().item():.3f}"
        )

    # Checkpoint
    if step > 0 and step % 500 == 0:
        path = f"{CHECKPOINT_DIR}/generator_RL_step_{step}.pt"
        torch.save({
            "step": step,
            "model_state_dict": generator.state_dict(),
            "optimizer_state_dict": optimizer_G.state_dict(),
        }, path)
        print(f"\n--- Checkpoint saved to {path} ---")

print("RL Fine-Tuning Complete!")


Using BATCH_SIZE = 8.


RL Step:  10%|█         | 501/5000 [24:02<3:53:07,  3.11s/it, loss=-4.5367, mse=0.007, rD=0.000, rP=0.889, tot=0.862, val=0.75] 


--- Checkpoint saved to ../results/models_5l//generator_RL_step_500.pt ---


RL Step:  20%|██        | 1001/5000 [48:19<2:27:25,  2.21s/it, loss=-1.0539, mse=0.002, rD=1.000, rP=0.966, tot=1.173, val=1.00] 


--- Checkpoint saved to ../results/models_5l//generator_RL_step_1000.pt ---


RL Step:  30%|███       | 1501/5000 [1:07:30<2:02:51,  2.11s/it, loss=-0.6417, mse=0.043, rD=0.000, rP=0.454, tot=0.563, val=1.00]


--- Checkpoint saved to ../results/models_5l//generator_RL_step_1500.pt ---


RL Step:  40%|████      | 2001/5000 [1:21:44<1:14:18,  1.49s/it, loss=-5.9659, mse=0.015, rD=1.000, rP=0.872, tot=1.073, val=0.88] 


--- Checkpoint saved to ../results/models_5l//generator_RL_step_2000.pt ---


RL Step:  50%|█████     | 2501/5000 [1:40:35<2:00:42,  2.90s/it, loss=-1.8318, mse=0.002, rD=0.000, rP=0.970, tot=0.976, val=1.00] 


--- Checkpoint saved to ../results/models_5l//generator_RL_step_2500.pt ---


RL Step:  60%|██████    | 3001/5000 [2:00:15<29:07,  1.14it/s, loss=-1.6198, mse=0.001, rD=1.000, rP=0.979, tot=1.183, val=1.00]  


--- Checkpoint saved to ../results/models_5l//generator_RL_step_3000.pt ---


RL Step:  70%|███████   | 3501/5000 [2:21:14<1:12:13,  2.89s/it, loss=-2.8551, mse=0.001, rD=0.000, rP=0.979, tot=0.983, val=1.00] 


--- Checkpoint saved to ../results/models_5l//generator_RL_step_3500.pt ---


RL Step:  80%|████████  | 4001/5000 [2:46:04<47:44,  2.87s/it, loss=-5.3314, mse=0.002, rD=0.000, rP=0.964, tot=0.971, val=1.00]  


--- Checkpoint saved to ../results/models_5l//generator_RL_step_4000.pt ---


RL Step:  90%|█████████ | 4501/5000 [3:10:28<24:34,  2.96s/it, loss=8.1686, mse=0.001, rD=0.000, rP=0.988, tot=0.990, val=1.00] 


--- Checkpoint saved to ../results/models_5l//generator_RL_step_4500.pt ---


RL Step: 100%|██████████| 5000/5000 [3:35:10<00:00,  2.58s/it, loss=-4.3154, mse=0.006, rD=0.001, rP=0.896, tot=0.892, val=0.88]

✅ RL Fine-Tuning Complete!





In [6]:
step

4999

In [7]:
path = f"{CHECKPOINT_DIR}/generator_RL_step_{step+1}.pt"
torch.save({
    "step": step+1,
    "model_state_dict": generator.state_dict(),
    "optimizer_state_dict": optimizer_G.state_dict(),
    }, path)
print(f"\n--- Checkpoint saved to {path} ---")


--- Checkpoint saved to ../results/models_5l//generator_RL_step_5000.pt ---
