In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import pandas as pd
from tqdm import tqdm
import os
import textwrap
import gc
# ==============================================================================

# -----------------------------
# Custom Shared LoRA Module
# -----------------------------
class SharedLoRA(nn.Module):
    """
    A single, shared LoRA module that will be applied to the output of every transformer block.
    This is a highly parameter-efficient way to introduce a global change to the model's behavior.
    """
    def __init__(self, hidden_size, rank, scaling=1.0):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(hidden_size, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, hidden_size))
        self.scaling = scaling

    def forward(self, x):
        """Applies the low-rank update to the input hidden state."""
        # Input x has shape (batch, seq_len, hidden_size)
        update = (x @ self.lora_A @ self.lora_B)
        update = update / (update.norm(p=2, dim=-1, keepdim=True) + 1e-8) * self.scaling
        return x + update

# -----------------------------
# Dataset
# -----------------------------
class PromptDataset(Dataset):
    """
    A simple dataset to load prompts from a pandas DataFrame.
    """
    def __init__(self, df: pd.DataFrame, prompt_column: str):
        # Ensure the column exists
        if prompt_column not in df.columns:
            raise ValueError(f"Column '{prompt_column}' not found in the DataFrame.")
        self.prompts = df[prompt_column].tolist()

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_base_model_and_tokenizer(args):
    """Loads the objects that are constant across all training runs."""
    if not torch.cuda.is_available():
        raise RuntimeError("This script requires a CUDA-enabled GPU.")
    
    print("--- Loading Base Model and Tokenizer (once) ---")
    
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model_base = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_base.eval()
    for param in model_base.parameters():
        param.requires_grad = False
        
    print("--- Base Model and Tokenizer Loaded ---")
    return model_base, tokenizer


In [3]:
def run_single_training_cycle(args, model_base, tokenizer, run_idx):
    """
    Runs one full cycle of training and evaluation.
    It loads a new model to be tuned each time it's called.
    """
    device = "cuda"
    run_output_path = os.path.join(args.output_dir, f"divergence_adapter_b{args.batch_size}_run_{run_idx}.pth")
    im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    
    print(f"Loading a new, randomly initialized 'model_tuned' for run {run_idx}...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    model_tuned = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_tuned.train()

    # --- Create and Inject a new Shared LoRA Adapter ---
    hidden_size = model_tuned.config.hidden_size
    shared_adapter = SharedLoRA(hidden_size, rank=args.lora_rank, scaling=args.lora_scaling).to(device, dtype=torch.bfloat16)

    hook_handles = []
    def apply_adapter_hook(module, input, output):
        if isinstance(output, tuple):
            hidden_state = output[0]
            modified_hidden_state = shared_adapter(hidden_state)
            return (modified_hidden_state,) + output[1:]
        else:
            modified_hidden_state = shared_adapter(output)
            return modified_hidden_state

    for layer in model_tuned.model.layers:
        handle = layer.register_forward_hook(apply_adapter_hook)
        hook_handles.append(handle)

    num_trainable_params = sum(p.numel() for p in shared_adapter.parameters() if p.requires_grad)
    print(f"Run {run_idx}: Shared LoRA adapter created with {num_trainable_params:,} parameters.")

    optimizer = torch.optim.AdamW(shared_adapter.parameters(), lr=args.learning_rate)
    
    df = pd.read_csv(args.dataset_path)
    
    # Sample only 24 examples for efficient divergence training
    df_sampled = df.sample(n=args.df_sample_size, random_state=42+run_idx).reset_index(drop=True)
    dataset = PromptDataset(df_sampled, prompt_column='full_prompt')
    
    print(f"Run {run_idx}: Using {len(df_sampled)} samples for training (sampled from {len(df)} total)")

    def collate_fn(batch):
        return tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
    
    train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)

    # --- Training Loop ---
    print(f"Run {run_idx}: Starting divergence training...")
    for epoch in range(args.epochs):
        pbar = tqdm(train_loader, desc=f"Run {run_idx} Epoch {epoch+1}")
        for batch in pbar:
            # Store original input for each generation step
            original_input_ids = batch['input_ids'].to(device)
            original_attention_mask = batch['attention_mask'].to(device)
            
            # Keep track of which sequences in the batch are still being generated
            unfinished_sequences = torch.ones(original_input_ids.shape[0], dtype=torch.long, device=device)
            
            # Use a fixed-size working copy to avoid growing sequences
            max_seq_len = original_input_ids.shape[1] + args.max_new_tokens
            working_input_ids = torch.full((original_input_ids.shape[0], max_seq_len), 
                                         tokenizer.pad_token_id, dtype=torch.long, device=device)
            working_attention_mask = torch.zeros((original_input_ids.shape[0], max_seq_len), 
                                               dtype=torch.long, device=device)
            
            # Initialize with original input (avoiding in-place operations)
            seq_len = original_input_ids.shape[1]
            working_input_ids = working_input_ids.clone()
            working_input_ids[:, :seq_len] = original_input_ids
            working_attention_mask = working_attention_mask.clone()
            working_attention_mask[:, :seq_len] = original_attention_mask
            
            total_batch_loss = 0
            num_steps = 0
            
            optimizer.zero_grad()

            for step in range(args.max_new_tokens):
                # Use only the current sequence length to avoid processing padding
                current_input_ids = working_input_ids[:, :seq_len]
                current_attention_mask = working_attention_mask[:, :seq_len]
                
                with torch.no_grad():
                    outputs_base = model_base(input_ids=current_input_ids, attention_mask=current_attention_mask)
                    logits_base = outputs_base.logits[:, -1, :]
                    logprobs_p = F.log_softmax(logits_base, dim=-1)
                    probs_p = logprobs_p.exp()

                outputs_tuned = model_tuned(input_ids=current_input_ids, attention_mask=current_attention_mask)
                logits_tuned = outputs_tuned.logits[:, -1, :]
                logprobs_q = F.log_softmax(logits_tuned, dim=-1)
                
                # Calculate loss only for the active sequences
                active_sequences_mask = unfinished_sequences.float()
                kl_div = (probs_p * (logprobs_p - logprobs_q)).sum(dim=-1)
                kl_div_loss = -(kl_div * active_sequences_mask).sum() / active_sequences_mask.sum()

                with torch.no_grad():
                    probs_q = logprobs_q.exp()
                    next_token = torch.multinomial(probs_q, num_samples=1)
                    
                    # Update the list of unfinished sequences (avoid in-place operation)
                    is_eos = (next_token == tokenizer.eos_token_id) | (next_token == im_end_token_id)
                    eos_mask = is_eos.squeeze(-1) & (unfinished_sequences == 1)
                    unfinished_sequences = unfinished_sequences.masked_fill(eos_mask, 0)

                    if unfinished_sequences.max() == 0:
                        break
                    
                    # Add next token to working tensors (avoid in-place operations)
                    new_working_input_ids = working_input_ids.clone()
                    new_working_input_ids[:, seq_len] = next_token.squeeze(-1)
                    working_input_ids = new_working_input_ids
                    
                    new_working_attention_mask = working_attention_mask.clone()
                    new_working_attention_mask[:, seq_len] = 1
                    working_attention_mask = new_working_attention_mask
                    
                    seq_len += 1

                nll = -logprobs_q.gather(dim=-1, index=next_token.detach())
                nll_loss = (nll.squeeze(-1) * active_sequences_mask).sum() / active_sequences_mask.sum()
                if active_sequences_mask.sum() > 0:
                    step_loss = args.alpha * kl_div_loss + args.beta * nll_loss
                    
                    # Backward pass immediately to avoid accumulating large computation graphs
                    step_loss.backward()
                    total_batch_loss += step_loss.item()
                    num_steps += 1
                
                # Clear intermediate tensors to free memory
                del outputs_tuned, logits_tuned, logprobs_q, kl_div
                if 'step_loss' in locals():
                    del step_loss
                torch.cuda.empty_cache()
            
            if num_steps > 0:
                optimizer.step()
                avg_batch_loss = total_batch_loss / num_steps
                pbar.set_postfix({"avg_loss": f"{avg_batch_loss:.4f}"})
            
            # Clean up batch tensors
            del working_input_ids, working_attention_mask, original_input_ids, original_attention_mask
            del unfinished_sequences
            torch.cuda.empty_cache()

    # --- Save ---
    print(f"\nRun {run_idx}: Training finished.")
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(shared_adapter.state_dict(), run_output_path)
    print(f"Shared LoRA adapter weights saved to '{run_output_path}'")

    # --- IN-LINE EVALUATION ---
    print(f"\n--- Starting Evaluation for Run {run_idx} ---")
    model_tuned.eval()
    sample_prompts = df['full_prompt'].sample(n=args.num_eval_samples, random_state=42+run_idx).tolist()

    for i, prompt in enumerate(sample_prompts):
        print("\n" + "="*80)
        print(f"PROMPT:\n{textwrap.fill(prompt, 80)}")
        print("="*80)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

        print("\n--- BASE MODEL OUTPUT ---")
        with torch.no_grad():
            outputs_base_gen = model_base.generate(
                input_ids, max_new_tokens=args.max_new_tokens, do_sample=True, top_k=50, top_p=0.95, pad_token_id=tokenizer.eos_token_id
            )
            base_text = tokenizer.decode(outputs_base_gen[0], skip_special_tokens=True)
            print(textwrap.fill(base_text.replace(prompt, "", 1).strip(), 80))

        print("\n--- DIVERGENT MODEL OUTPUT ---")
        with torch.no_grad():
            outputs_divergent = model_tuned.generate(
                input_ids, max_new_tokens=args.max_new_tokens, do_sample=True, top_k=50, top_p=0.95, pad_token_id=tokenizer.eos_token_id
            )
            divergent_text = tokenizer.decode(outputs_divergent[0], skip_special_tokens=True)
            print(textwrap.fill(divergent_text.replace(prompt, "", 1).strip(), 80))
        print("="*80)
        
    # --- Cleanup for this run ---
    for handle in hook_handles:
        handle.remove()
    print(f"\nRun {run_idx}: Evaluation complete. Adapter hooks removed.")

    del model_tuned, shared_adapter, hook_handles
    torch.cuda.empty_cache()


In [4]:
# -----------------------------
# Configuration for Jupyter Notebook
# -----------------------------
class TrainingArgs:
    model_id = "Qwen/Qwen2.5-14B-Instruct"
    dataset_path = "scenarios_cleaned.csv"
    output_dir = "./divergence_adapters" # Directory to save adapters
    lora_rank = 2
    learning_rate = 1e-1
    epochs = 1
    batch_size = 12
    df_sample_size = 192
    max_new_tokens = 128
    num_eval_samples = 2
    latent_searches = 20 # Number of times to repeat the whole process
    lora_scaling = 2

    alpha = 1.0
    beta = 0.5


if __name__ == '__main__':
    args = TrainingArgs()
    
    # Load the base model and tokenizer once
    model_base, tokenizer = load_base_model_and_tokenizer(args)
    
    # Main loop to repeat the training process
    for i in range(1, args.latent_searches + 1):
        print("\n" + "#"*80)
        print(f"### STARTING TRAINING RUN {i} of {args.latent_searches} ###")
        print("#"*80)
        
        run_single_training_cycle(args, model_base, tokenizer, run_idx=i)
        
        print(f"--- Finished Training Run {i} ---")
        print("Clearing memory before next run...")
        # Models and adapters from the function are out of scope, so garbage collection will clean them up
        gc.collect()
        torch.cuda.empty_cache()

    print("\nAll training runs completed.")



--- Loading Base Model and Tokenizer (once) ---


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 8/8 [00:15<00:00,  1.98s/it]


--- Base Model and Tokenizer Loaded ---

################################################################################
### STARTING TRAINING RUN 1 of 20 ###
################################################################################
Loading a new, randomly initialized 'model_tuned' for run 1...


Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00,  1.82s/it]


Run 1: Shared LoRA adapter created with 20,480 parameters.
Run 1: Using 192 samples for training (sampled from 480 total)
Run 1: Starting divergence training...


Run 1 Epoch 1: 100%|██████████| 16/16 [16:03<00:00, 60.19s/it, avg_loss=-1.0919]
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Run 1: Training finished.
Shared LoRA adapter weights saved to './divergence_adapters/divergence_adapter_b12_run_1.pth'

--- Starting Evaluation for Run 1 ---

PROMPT:
Two college roommates who are high, relaxed, and having one of those deep late-
night conversations where everything feels profound and nothing is off-limits.
<|im_start|>human1 If you could have dinner with any three people, dead or
alive, who would they be? <|im_end|>  <|im_start|>human2

--- BASE MODEL OUTPUT ---
Two college roommates who are high, relaxed, and having one of those deep late-
night conversations where everything feels profound and nothing is off-limits.
human1 If you could have dinner with any three people, dead or alive, who would
they be?   human2 I've been thinking about this. It's a tough one. Maybe I'd
choose Albert Einstein for his insights into the universe, Billie Holiday
because she has such an incredible voice that moves me deeply, and Elon Musk
because he's pushing the boundaries of what we

Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00,  1.78s/it]


Run 2: Shared LoRA adapter created with 20,480 parameters.
Run 2: Using 192 samples for training (sampled from 480 total)
Run 2: Starting divergence training...


Run 2 Epoch 1: 100%|██████████| 16/16 [16:03<00:00, 60.22s/it, avg_loss=-0.1621]



Run 2: Training finished.
Shared LoRA adapter weights saved to './divergence_adapters/divergence_adapter_b12_run_2.pth'

--- Starting Evaluation for Run 2 ---

PROMPT:
Two gamers in voice chat during a long gaming session, tired and comfortable,
talking between rounds with zero filter. <|im_start|>player1 Honestly, sometimes
I think my greatest accomplishment is surviving another day without accidentally
setting the kitchen on fire. <|im_end|>  <|im_start|>player2

--- BASE MODEL OUTPUT ---
Two gamers in voice chat during a long gaming session, tired and comfortable,
talking between rounds with zero filter. player1 Honestly, sometimes I think my
greatest accomplishment is surviving another day without accidentally setting
the kitchen on fire.   player2 Dude, same. I swear if I burn one more pot of
water, my mom's gonna ground me from video games for life. But hey, at least
we’re good at not burning our virtual opponents!

--- DIVERGENT MODEL OUTPUT ---
Two gamers in voice chat during 

Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00,  1.76s/it]


Run 3: Shared LoRA adapter created with 20,480 parameters.
Run 3: Using 192 samples for training (sampled from 480 total)
Run 3: Starting divergence training...


Run 3 Epoch 1:  19%|█▉        | 3/16 [03:54<16:58, 78.32s/it, avg_loss=0.3116]


KeyboardInterrupt: 