In [1]:
# Extend the import path with two directories relative to the notebook's
# working directory; presumably these are the repo root and its `mdlm`
# sub-directory, so that the `src.*` and `mdlm` imports below resolve — confirm.
import sys
sys.path.append('./../../mdlm')
sys.path.append('./../..')

# Environment variables are set here, before the third-party libraries are
# imported further down in this cell.
# NOTE(review): the two cache paths are absolute and machine-specific; anyone
# else running this notebook has to edit them.
# NOTE(review): WANDB_MODE="disabled" presumably turns the wandb calls later in
# the notebook into no-ops — confirm this is intended for real training runs.
import os
os.environ["HF_HOME"] = "/vol/bitbucket/cp524/hf_cache"
os.environ["TRITON_CACHE_DIR"] = "/vol/bitbucket/cp524/triton_cache"
os.environ["WANDB_MODE"] = "disabled"

%load_ext autoreload
%autoreload 2

import math
import torch
import torch.nn.functional as F
from tqdm import tqdm
from datetime import datetime
from peft import LoraConfig, get_peft_model
import wandb

from src.utils.rich_print import rich_print
from src.toxicity_classifier.scorer import ToxicityScorer
from src.ppl.gpt2_ppl import compute_perplexity

# Choose the compute device once, then report which one was picked.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rich_print(
    "[bold green]CUDA is available. Using GPU.[/bold green]"
    if device.type == "cuda"
    else "[bold yellow]CUDA is not available. Using CPU.[/bold yellow]"
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig

# Reset Hydra's global instance first so this cell can be executed again
# within the same kernel.
GlobalHydra.instance().clear()

# config_path is relative to cwd (which is src/); version_base=None disables
# the legacy-version checks.
initialize(version_base=None, config_path="configs", job_name="notebook")

# Compose the training configuration used by every cell below.
config: DictConfig = compose(config_name="train_final")

In [3]:
# Start a new wandb run to track this script.
#
# The Hydra/OmegaConf config is converted to plain Python containers before it
# is handed to wandb: a DictConfig is not a primitive dict, and resolving here
# records the final interpolated values instead of `${...}` placeholders.
from omegaconf import OmegaConf

run = wandb.init(
    # Set the wandb entity where your project will be logged (generally your team name).
    entity="chinmaypani42-imperial-college-london",
    # Set the wandb project where this run will be logged.
    project="mdlm-toxicity-log-variance-finetuning",
    # Track hyperparameters and run metadata.
    config=OmegaConf.to_container(config, resolve=True),
)

In [4]:
# Build the tokenizer from the Hydra config via mdlm's dataloader helper.
# It is used further down to decode sampled token ids back into text.
from mdlm import dataloader
tokenizer = dataloader.get_tokenizer(config)

In [5]:
# Scorer whose `score_texts` output is used as the reward signal.
toxicity_scorer = ToxicityScorer()

@torch.no_grad()
def compute_rewards(tokens) -> torch.Tensor:
    """
    Score a batch of integer token ids.

    The ids are decoded to text with the notebook-level `tokenizer` and the
    decoded strings are passed to `toxicity_scorer.score_texts`; whatever that
    call returns is handed back unchanged.
    """
    return toxicity_scorer.score_texts(tokenizer.batch_decode(tokens))

@torch.no_grad()
def compute_rewards_scaled(tokens) -> torch.Tensor:
    """
    Rewards for a batch of integer token ids, divided by
    `config.finetuning.alpha`.
    """
    raw_rewards = compute_rewards(tokens)
    return raw_rewards / config.finetuning.alpha

@torch.no_grad()
def estimate_rewards_scaled(probs, num_samples, method='mean'):
    """
    Monte-Carlo estimate of the scaled reward under a categorical distribution.

    `num_samples` sequences are drawn per batch element from
    `Categorical(probs)`, scored with `compute_rewards_scaled`, and reduced
    over the sample axis.

    Args:
        probs: probabilities whose first dimension is the batch size B and
            whose last dimension is the category axis
            (presumably (B, seq_len, vocab) — confirm against the caller).
        num_samples: number of sequences drawn per batch element.
        method: 'mean' for E[r(x)/alpha], or 'logmeanexp' for
            log E[exp(r(x)/alpha)].

    Returns:
        Tensor with one estimate per batch element (shape (B,)).

    Raises:
        ValueError: if `method` is neither 'mean' nor 'logmeanexp'. The check
            runs before any sampling or scoring so a typo fails immediately.
    """
    if method not in ('mean', 'logmeanexp'):
        raise ValueError(f"Unknown method: {method}")

    B = probs.shape[0]
    dist = torch.distributions.Categorical(probs=probs)
    # (num_samples, B, ...) -> (num_samples * B, ...) so everything is scored in one call.
    samples = dist.sample((num_samples,)).reshape(num_samples * B, -1) # type: ignore
    rewards = compute_rewards_scaled(samples).reshape(num_samples, B)
    if method == 'mean':
        return rewards.mean(dim=0) # E[r(x)/alpha]
    # logsumexp minus log(N) is a numerically stable log-mean-exp.
    return rewards.logsumexp(dim=0) - math.log(num_samples) # log E[exp(r(x)/alpha)]

Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
from src.mdlm_diffusion import MDLMDiffusion
def _load_from_checkpoint(config, tokenizer):
    """Load an MDLMDiffusion model and place it on the notebook's `device`.

    Args:
        config: Hydra config; `config.backbone` selects the loading path and
            `config.eval.checkpoint_path` names the checkpoint file.
        tokenizer: tokenizer forwarded to the model.

    Returns:
        The model, moved to `device`.

    If the backbone name contains 'hf' the model is built directly from the
    config (presumably its weights are fetched inside the constructor — not
    visible here); otherwise it is restored from the checkpoint file.
    """
    if 'hf' in config.backbone:
        model = MDLMDiffusion(config, tokenizer=tokenizer)
    else:
        model = MDLMDiffusion.load_from_checkpoint(
            config.eval.checkpoint_path, tokenizer=tokenizer, config=config
        )

    # Move the model explicitly on both paths so the result does not depend on
    # where the checkpoint loader happens to leave the weights.
    return model.to(device)

In [7]:
# Reference model. In the training loop below it is only called inside
# torch.no_grad(), so it is never updated. eval() switches it out of training
# mode; as the last expression of the cell it also displays the module repr.
p_ref = _load_from_checkpoint(config, tokenizer)
p_ref.eval()

MDLMDiffusion(
  (backbone): MDLM(
    (backbone): DITBackbone(
      (vocab_embed): EmbeddingLayer()
      (sigma_map): TimestepEmbedder(
        (mlp): Sequential(
          (0): Linear(in_features=256, out_features=128, bias=True)
          (1): SiLU()
          (2): Linear(in_features=128, out_features=128, bias=True)
        )
      )
      (rotary_emb): Rotary()
      (blocks): ModuleList(
        (0-11): 12 x DDiTBlock(
          (norm1): LayerNorm()
          (attn_qkv): Linear(in_features=768, out_features=2304, bias=False)
          (attn_out): Linear(in_features=768, out_features=768, bias=False)
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
          (mlp): Sequential(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='tanh')
            (2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (dropout2): Dropout(p=0.1, inplace=False)
          (adaLN_modulati

In [8]:
# Model to be fine-tuned: a second, separate load through the same helper as
# p_ref. It is switched back to train() mode at the start of every batch.
q_phi = _load_from_checkpoint(config, tokenizer)
q_phi.eval()
# One learnable scalar per timestep, initialised to zero. The training loop
# subtracts f_psi[i] from log_w before squaring, and the optimizer updates it
# together with q_phi's trainable parameters.
f_psi = torch.nn.Parameter(torch.zeros(config.finetuning.num_timesteps, device=q_phi.device))

In [9]:
if config.finetuning.lora.enabled:
    # Wrap the backbone with LoRA adapters; every hyperparameter comes from
    # config.finetuning.lora.
    lora_config = LoraConfig(
        target_modules=list(config.finetuning.lora.target_modules),
        r=config.finetuning.lora.r,
        lora_alpha=config.finetuning.lora.lora_alpha,
        lora_dropout=config.finetuning.lora.lora_dropout,
        bias=config.finetuning.lora.bias,
    )
    q_phi.backbone = get_peft_model(q_phi.backbone, lora_config) # type: ignore

    # Extract the lora layers for optimizer
    # get_peft_model already freezes everything except lora params.
    # Stored as a list rather than a `filter` object: a filter is a one-shot
    # iterator and would be silently empty the second time it is consumed.
    q_phi_lora_layers = [p for p in q_phi.parameters() if p.requires_grad]

In [10]:
def summary(model):
    """Print the total and trainable parameter counts of `model`.

    A parameter counts as trainable when its `requires_grad` flag is set.
    Nothing is returned; the counts are printed with thousands separators.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    print(f"Total params: {total:,}, Trainable params: {trainable:,}")

summary(q_phi)

Total params: 170,811,986, Trainable params: 1,184,768


In [11]:
# A single AdamW instance updates every q_phi parameter that still requires
# gradients, plus the per-timestep f_psi parameter.
trainable_params = [p for p in q_phi.parameters() if p.requires_grad]
trainable_params.append(f_psi)
optimizer = torch.optim.AdamW(trainable_params, lr=config.finetuning.lr)

In [12]:
from omegaconf import OmegaConf

# Everything saved for this run goes under <base_dir>/<YYYYMMDD>/<HHMMSS>.
base_dir = 'model_weights_final'  # keep base folder
timestamp = datetime.now().strftime("%Y%m%d/%H%M%S")  # e.g. 20250818/004927
model_save_dir = os.path.join(base_dir, timestamp)
os.makedirs(model_save_dir, exist_ok=True)

# Record the output location on the wandb run.
run.log({"model_save_dir": model_save_dir})

# Keep a copy of the config next to the weights.
OmegaConf.save(config=config, f=f'{model_save_dir}/config.yaml')

In [13]:
# Per-epoch history consumed by the early-stopping check at the end of each
# epoch: the average training loss and the average reward of fresh samples.
loss_trace = []
reward_trace = []

In [14]:
# ---------------------------------------------------------------------------
# Fine-tuning loop.
#
# For every batch a trajectory starts from a prior sample and is advanced for
# `num_timesteps` steps. At each step q_phi (with grad) and p_ref (no grad) are
# queried through `_sample_step`, a scaled reward is estimated from p_ref's
# second output, and the next state is sampled. The squared-error loss for a
# transition is formed one iteration later, once the reward at the new state
# is known, so every backward() call consumes the graph built in the previous
# iteration. Gradients accumulate over all timesteps and optimizer.step() is
# called once per batch.
# ---------------------------------------------------------------------------
L = q_phi.config.model.length
eps=1e-5
# num_timesteps + 1 evenly spaced times from 1 down to eps; dt is their spacing.
timesteps = torch.linspace(1, eps, config.finetuning.num_timesteps + 1, device=q_phi.device)
dt = (1 - eps) / config.finetuning.num_timesteps

# Training loop
for epoch in range(config.finetuning.num_epochs):
    run.log({"epoch": epoch+1})
    total_epoch_loss = 0.0
    for batch_idx in range(config.finetuning.batches_per_epoch):
        q_phi.train()

        # Clear all grads
        optimizer.zero_grad()

        rewards_prev = None
        log_prob_p_ref = None
        log_prob_q_phi = None
        total_loss_for_all_timesteps = 0.0
        total_log_variance_loss_for_all_timesteps = 0.0
        total_kl_loss_for_all_timesteps = 0.0
        # Unweighted KL, tracked separately so that reporting it never has to
        # divide by kl_weight (which may legitimately be 0).
        total_kl_div_for_all_timesteps = 0.0
        kl_loss = torch.tensor(0.0, device=q_phi.device)

        # Generate batch_size samples from q_phi
        z_t = q_phi._sample_prior(config.finetuning.batch_size, L, prompt_ids=None).to(q_phi.device) # type: ignore
        for i in range(config.finetuning.num_timesteps, 0, -1):
            # i counts down from num_timesteps to 1, so the index below counts
            # up from 0: t starts at timesteps[0] == 1. timesteps[-1] is only
            # used by the optional noise-removal pass after the loop.
            t = timesteps[config.finetuning.num_timesteps - i] * torch.ones(z_t.shape[0], 1, device=q_phi.device)
            # Invoke pretrained and finetune models
            with torch.enable_grad():
                q_phi_zs_given_zt, q_phi_z0_given_zt = q_phi._sample_step(z_t, t, dt)
            with torch.no_grad():
                p_ref_zs_given_zt, p_ref_z0_given_zt = p_ref._sample_step(z_t, t, dt)

            # Estimate rewards
            rewards = estimate_rewards_scaled(p_ref_z0_given_zt, config.finetuning.num_samples_for_reward_estimate, method=config.finetuning.reward_estimate_method)

            if i < config.finetuning.num_timesteps:
                # Loss for the transition sampled in the PREVIOUS iteration:
                # its log-probs and KL term were stored then, and the reward
                # at the state it led to has just been estimated.
                # Sanity checks
                assert rewards is not None and rewards_prev is not None
                assert log_prob_p_ref is not None and log_prob_q_phi is not None
                assert log_prob_q_phi.requires_grad

                log_w = (rewards - rewards_prev) + (log_prob_p_ref - log_prob_q_phi) # Shape: (batch-size,)
                log_variance = (log_w - f_psi[i]) ** 2
                log_variance_loss = log_variance.mean(dim=0) # take mean across batch dimension
                total_log_variance_loss_for_all_timesteps += log_variance_loss.item()
                run.log({"log_variance_loss_per_timestep": log_variance_loss.item()})

                total_loss = log_variance_loss + kl_loss
                total_loss_for_all_timesteps += total_loss.item()
                run.log({"total_loss_per_timestep": total_loss.item()})

                # Accumulate gradients
                total_loss.backward()

            # KL between the two models' second `_sample_step` output, summed
            # over dims 1 and 2. torch.where keeps zero-probability entries
            # out of the sum so that 0 * log(0) does not produce NaN.
            if config.finetuning.kl_method == 'forward':
                kld_batch = torch.where(
                    p_ref_z0_given_zt > 0,
                    p_ref_z0_given_zt * (torch.log(p_ref_z0_given_zt) - torch.log(q_phi_z0_given_zt.clamp_min(1e-12))),
                    torch.zeros_like(p_ref_z0_given_zt)
                ).sum(dim=(1, 2))
            elif config.finetuning.kl_method == 'backward':
                kld_batch = torch.where(
                    q_phi_z0_given_zt > 0,
                    q_phi_z0_given_zt * (torch.log(q_phi_z0_given_zt.clamp_min(1e-12)) - torch.log(p_ref_z0_given_zt.clamp_min(1e-12))),
                    torch.zeros_like(q_phi_z0_given_zt)
                ).sum(dim=(1, 2))
            else:
                raise ValueError(f"Unknown KL method: {config.finetuning.kl_method}")

            kl_div = kld_batch.mean(dim=0) # take mean across batch dimension
            kl_loss = config.finetuning.kl_weight * kl_div
            total_kl_loss_for_all_timesteps += kl_loss.item()
            total_kl_div_for_all_timesteps += kl_div.item()
            run.log({"kl_loss_per_timestep": kl_loss.item(), "kl_div_per_timestep": kl_div.item()})

            q_phi_dist = torch.distributions.Categorical(probs=q_phi_zs_given_zt)
            p_ref_dist = torch.distributions.Categorical(probs=p_ref_zs_given_zt)

            # The next state comes from q_phi (on-policy) or from p_ref.
            if config.finetuning.sample_onpolicy:
                z_s = q_phi_dist.sample()
            else:
                z_s = p_ref_dist.sample()

            # Log-probability of the sampled transition under both models,
            # summed over dim 1.
            log_prob_q_phi = q_phi_dist.log_prob(z_s).sum(dim=1)
            log_prob_p_ref = p_ref_dist.log_prob(z_s).sum(dim=1)

            # Update for next step
            z_t = z_s
            rewards_prev = rewards

        z_0 = z_t
        if q_phi.config.sampling.noise_removal:
            # Optional extra pass at the last timestep: take the argmax over
            # the logits with the final vocabulary entry dropped.
            with torch.no_grad():
                t = timesteps[-1] * torch.ones(z_0.shape[0], 1, device=q_phi.device)
                unet_conditioning = q_phi.noise(t)[0]
                logits = q_phi.forward(z_0, unet_conditioning)
                z_0 = logits[:, :, :-1].argmax(dim=-1)

        # Compute rewards
        # Final transition: uses the reward of the finished sample and f_psi[0].
        rewards = compute_rewards_scaled(z_0)
        assert rewards_prev is not None and log_prob_p_ref is not None and log_prob_q_phi is not None
        log_w = (rewards - rewards_prev) + (log_prob_p_ref - log_prob_q_phi) # Shape: (batch-size,)
        log_variance = (log_w - f_psi[0]) ** 2
        log_variance_loss = log_variance.mean(dim=0) # take mean across batch dimension
        total_log_variance_loss_for_all_timesteps += log_variance_loss.item()
        run.log({"log_variance_loss_per_timestep": log_variance_loss.item()})

        total_loss = log_variance_loss + kl_loss
        total_loss_for_all_timesteps += total_loss.item()
        run.log({"total_loss_per_timestep": total_loss.item()})

        # accumulate gradients
        total_loss.backward()

        # gradients step
        optimizer.step()

        print((f"Batch {batch_idx+1}/{config.finetuning.batches_per_epoch}, "
            f"Loss: {total_loss_for_all_timesteps}, Reward (avg): {rewards.mean(dim=0).item() * config.finetuning.alpha} "
            f"KL Loss: {total_kl_loss_for_all_timesteps}"))
        run.log({
            "total_loss": total_loss_for_all_timesteps,
            "log_variance_loss": total_log_variance_loss_for_all_timesteps,
            "kl_loss": total_kl_loss_for_all_timesteps,
            "kl_div": total_kl_div_for_all_timesteps,
            "final_reward": rewards.mean(dim=0).item() * config.finetuning.alpha
        })
        total_epoch_loss += total_loss_for_all_timesteps

    q_phi.eval()
    avg_loss = total_epoch_loss / config.finetuning.batches_per_epoch
    # Epoch-level metrics carry an explicit "epoch" key instead of
    # `step=epoch+1`. The per-timestep logs above have already advanced
    # wandb's step counter far beyond the epoch number, and wandb ignores data
    # logged to a step lower than the current one.
    epoch_tag = {"epoch": epoch + 1}
    run.log({**epoch_tag, "epoch_avg_loss": avg_loss})

    tokens = q_phi.sample(num_steps=100)
    avg_rewards = compute_rewards(tokens).mean().item()
    run.log({**epoch_tag, "epoch_rewards": avg_rewards})

    # perplexity
    texts = tokenizer.batch_decode(tokens)
    ppl, total_ppl = compute_perplexity(
        generations=[{
            "context": "",
            "generations": texts,
        }],
        device=device,
    )
    run.log({**epoch_tag, "epoch_ppl": ppl, "epoch_total_ppl": total_ppl})
    # Create a wandb Table
    table = wandb.Table(columns=["Prompt", "Generated"])
    for prompt, gen in zip([""]*len(texts), texts):
        table.add_data(prompt, gen)
    # Log the whole table for this epoch
    run.log({**epoch_tag, "epoch_samples": table})

    print(f"Epoch {epoch+1}/{config.finetuning.num_epochs},  Loss (avg): {avg_loss}, Reward: {avg_rewards}, PPL: {ppl}/{total_ppl}")

    ckpt_path = f'{model_save_dir}/ckpt_{epoch+1}'
    # torch.save does not create missing directories, so make the checkpoint
    # folder explicitly; without this the non-LoRA branch fails on its first save.
    os.makedirs(ckpt_path, exist_ok=True)
    if config.finetuning.lora.enabled:
        q_phi.backbone.save_pretrained(f"{ckpt_path}/lora")
    else:
        torch.save(q_phi.state_dict(), f"{ckpt_path}/model.pth")
    # Save f_psi
    torch.save(f_psi, f"{ckpt_path}/f_psi.pth")
    # Save optimizer state
    torch.save(optimizer.state_dict(), f"{ckpt_path}/optimizer.pth")
    run.log({**epoch_tag, "ckpt_path": ckpt_path})

    loss_trace.append(avg_loss)
    reward_trace.append(avg_rewards)

    # If BOTH loss and reward stop improving, then stop training:
    # the best loss and the best reward both lie outside the last `patience` epochs.
    if (
        min(loss_trace) < min(loss_trace[-config.finetuning.patience:]) and
        max(reward_trace) > max(reward_trace[-config.finetuning.patience:])
    ):
        break

  with torch.cuda.amp.autocast(dtype=torch.float32):
  with torch.cuda.amp.autocast(dtype=torch.bfloat16):
  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):


Batch 1/5, Loss: 32761.004670619965, Reward (avg): -8.938025665283204 KL Loss: 1013.1174644231796
Batch 2/5, Loss: 53212.31704568863, Reward (avg): -9.035257720947266 KL Loss: 933.1049040555954


KeyboardInterrupt: 

In [None]:
# Mark the wandb run created earlier in this notebook as finished.
run.finish()