In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import pandas as pd
from tqdm import tqdm
import os
import textwrap

# -----------------------------
# Custom Shared LoRA Module
# -----------------------------
class SharedLoRA(nn.Module):
    """
    One low-rank adapter (hidden_size -> rank -> hidden_size) that adds a
    unit-L2-norm residual to whatever hidden state it is given.

    `lora_B` starts at zero, so a freshly built module returns its input unchanged.
    Because the residual is normalized after being scaled, a positive `scaling`
    is cancelled by the normalization (apart from the 1e-12 stabilizer).
    """
    def __init__(self, hidden_size, rank, scaling=1.0):
        super().__init__()
        down_projection = torch.randn(hidden_size, rank)
        up_projection = torch.zeros(rank, hidden_size)

        # Registration order (A, then B) determines the state_dict key order
        self.lora_A = nn.Parameter(down_projection)
        self.lora_B = nn.Parameter(up_projection)
        self.scaling = scaling

        # The random-normal values above are overwritten in place by Kaiming uniform
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)

    def forward(self, x):
        """Return `x` plus the L2-normalized low-rank residual computed from `x`."""
        # Project down to `rank`, then back up to `hidden_size`
        projected = x @ self.lora_A
        delta = (projected @ self.lora_B) * self.scaling

        # Per-position L2 length over the last (hidden) dimension
        length = delta.norm(p=2, dim=-1, keepdim=True)

        # The small constant keeps the division finite when the residual is all zeros
        return x + delta / (length + 1e-12)

# -----------------------------
# Dataset
# -----------------------------
class PromptDataset(Dataset):
    """
    Map-style dataset that exposes one column of a pandas DataFrame as a list of prompts.
    """
    def __init__(self, df: pd.DataFrame, prompt_column: str):
        # Reject an unknown column up front with an explicit message
        column_is_present = prompt_column in df.columns
        if not column_is_present:
            raise ValueError(f"Column '{prompt_column}' not found in the DataFrame.")

        # Copy the column into a plain Python list once
        column_values = df[prompt_column]
        self.prompts = column_values.tolist()

    def __len__(self):
        """Number of prompts held by the dataset."""
        prompt_count = len(self.prompts)
        return prompt_count

    def __getitem__(self, idx):
        """Return the prompt stored at position `idx`."""
        prompt = self.prompts[idx]
        return prompt

# -----------------------------
# Main Training & Evaluation Function
# -----------------------------
def _attach_shared_adapter(model, shared_adapter):
    """Hook `shared_adapter` onto the output of every layer in `model.model.layers`.

    Returns the list of hook handles so the caller can detach them later.
    """
    def apply_adapter_hook(module, inputs, output):
        # A layer may return a tuple whose first element is the hidden state;
        # only that element is passed through the adapter.
        if isinstance(output, tuple):
            return (shared_adapter(output[0]),) + output[1:]
        return shared_adapter(output)

    return [layer.register_forward_hook(apply_adapter_hook) for layer in model.model.layers]


def _accumulate_prompt_gradients(model_base, model_tuned, input_ids, max_new_tokens, stop_token_ids):
    """Roll out one prompt and accumulate gradients that maximize KL(base || tuned).

    At each step both models score the current sequence, the negative KL divergence of
    the next-token distributions is back-propagated (gradients are accumulated, not
    applied), and the next token is sampled from the base model's distribution.

    Returns:
        (loss_sum, steps_taken): the summed per-step loss and the number of steps run.
    """
    loss_sum = 0.0
    steps_taken = 0

    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits_base = model_base(input_ids).logits[:, -1, :]
            # The models are loaded with torch_dtype=torch.bfloat16, so the logits may be
            # bfloat16; compute the distributions in float32 for a more accurate KL.
            logprobs_p = F.log_softmax(logits_base.float(), dim=-1)
            probs_p = logprobs_p.exp()

        logits_tuned = model_tuned(input_ids).logits[:, -1, :]
        logprobs_q = F.log_softmax(logits_tuned.float(), dim=-1)

        kl_div = (probs_p * (logprobs_p - logprobs_q)).sum(dim=-1)

        # Minimizing the negative KL pushes the tuned model away from the base model
        loss = -kl_div
        loss_sum += loss.item()
        steps_taken += 1

        loss.backward()

        # Continue the rollout with a token sampled from the frozen base model
        next_token = torch.multinomial(probs_p, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

        if next_token.item() in stop_token_ids:
            break

    return loss_sum, steps_taken


def _generate_continuation(model, tokenizer, encoded, max_new_tokens):
    """Sample a continuation for an already-tokenized prompt and return only the new text."""
    prompt_length = encoded.input_ids.shape[1]
    with torch.no_grad():
        generated = model.generate(
            input_ids=encoded.input_ids,
            # Passed explicitly: pad and eos may be the same token, in which case the
            # mask cannot be inferred from the input.
            attention_mask=encoded.attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True, top_k=50, top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens instead of string-replacing the prompt: the decoded
    # text has special tokens stripped, so it would not match the raw prompt string.
    new_tokens = generated[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


def train_and_evaluate(args):
    """
    Main function to run the divergence training process and then evaluate the results.

    Loads two 4-bit copies of `args.model_id`, attaches one shared LoRA adapter to every
    layer of the second copy, trains the adapter to maximize the KL divergence from the
    frozen copy, saves the adapter weights, and prints sampled generations of both models.

    Args:
        args: configuration object providing `model_id`, `dataset_path`, `output_path`,
            `lora_rank`, `learning_rate`, `epochs`, `batch_size`, `max_new_tokens` and
            `num_eval_samples`. The CSV at `dataset_path` must have a 'full_prompt' column.

    Raises:
        RuntimeError: if CUDA is not available.
        ValueError: if the dataset has no 'full_prompt' column.
    """
    # --- Setup ---
    if not torch.cuda.is_available():
        raise RuntimeError("This script requires a CUDA-enabled GPU.")

    device = "cuda"
    print(f"Using device: {device}")

    # --- Load Tokenizer ---
    print(f"Loading tokenizer for '{args.model_id}'...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # Sampling either of these tokens ends a training rollout
    stop_token_ids = {tokenizer.eos_token_id, im_end_token_id}

    # --- Quantization Config ---
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # --- Load Models ---
    print(f"Loading base model (frozen): '{args.model_id}'...")
    model_base = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_base.eval()
    for param in model_base.parameters():
        param.requires_grad = False

    print(f"Loading model to be tuned: '{args.model_id}'...")
    model_tuned = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_tuned.train()
    # Only the adapter is optimized. Freezing the backbone keeps backward() from
    # accumulating gradients on weights that the optimizer never reads or zeroes.
    for param in model_tuned.parameters():
        param.requires_grad = False

    # --- Create and Inject Shared LoRA Adapter ---
    # NOTE(review): the adapter is placed on `device` ("cuda") while the models use
    # device_map="auto"; this presumably assumes a single GPU — confirm for multi-GPU.
    hidden_size = model_tuned.config.hidden_size
    shared_adapter = SharedLoRA(hidden_size, rank=args.lora_rank).to(device, dtype=torch.bfloat16)
    hook_handles = _attach_shared_adapter(model_tuned, shared_adapter)

    num_trainable_params = sum(p.numel() for p in shared_adapter.parameters() if p.requires_grad)
    print("Shared LoRA adapter created and attached.")
    print(f"Number of trainable parameters: {num_trainable_params:,}")

    try:
        # --- Optimizer ---
        optimizer = torch.optim.AdamW(shared_adapter.parameters(), lr=args.learning_rate)

        # --- Load Dataset ---
        print(f"Loading dataset from '{args.dataset_path}'...")
        df = pd.read_csv(args.dataset_path)
        dataset = PromptDataset(df, prompt_column='full_prompt')
        train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

        # --- Training Loop ---
        print("Starting divergence training...")
        for epoch in range(args.epochs):
            print(f"\n--- Epoch {epoch + 1}/{args.epochs} ---")
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
            for batch in pbar:
                batch_loss = 0.0
                batch_steps = 0

                # Roll out every prompt of the batch one at a time; their gradients
                # are accumulated and applied in a single optimizer step per batch.
                for prompt_text in batch:
                    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)
                    prompt_loss, prompt_steps = _accumulate_prompt_gradients(
                        model_base, model_tuned, input_ids,
                        args.max_new_tokens, stop_token_ids,
                    )
                    batch_loss += prompt_loss
                    batch_steps += prompt_steps

                optimizer.step()
                optimizer.zero_grad()

                # max(..., 1) guards the max_new_tokens == 0 case (no steps run)
                avg_loss = batch_loss / max(batch_steps, 1)
                pbar.set_postfix({"avg_loss": f"{avg_loss:.4f}"})

        # --- Save ---
        print("\nTraining finished.")
        output_dir = os.path.dirname(args.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        torch.save(shared_adapter.state_dict(), args.output_path)
        print(f"Shared LoRA adapter weights saved to '{args.output_path}'")

        # --- IN-LINE EVALUATION ---
        print("\n--- Starting Evaluation ---")
        model_tuned.eval() # Switch tuned model to evaluation mode

        # Never request more samples than the dataset holds (sample() would raise)
        num_eval_samples = min(args.num_eval_samples, len(df))
        sample_prompts = df['full_prompt'].sample(n=num_eval_samples, random_state=42).tolist()

        for i, prompt in enumerate(sample_prompts):
            print("\n" + "="*80)
            print(f"SAMPLE {i+1}/{num_eval_samples}")
            print(f"PROMPT:\n{textwrap.fill(prompt, 80)}")
            print("="*80)

            encoded = tokenizer(prompt, return_tensors="pt").to(device)

            # --- Base Model Generation ---
            print("\n--- BASE MODEL OUTPUT ---")
            base_text = _generate_continuation(model_base, tokenizer, encoded, args.max_new_tokens)
            print(textwrap.fill(base_text.strip(), 80))

            # --- Divergent Model Generation ---
            # model_tuned still has the adapter hooks attached
            print("\n--- DIVERGENT MODEL OUTPUT ---")
            divergent_text = _generate_continuation(model_tuned, tokenizer, encoded, args.max_new_tokens)
            print(textwrap.fill(divergent_text.strip(), 80))
            print("="*80)
    finally:
        # --- Cleanup Hooks ---
        # Runs even when training or evaluation fails, so the hooks never outlive this call
        for handle in hook_handles:
            handle.remove()

    print("\nEvaluation complete. Adapter hooks removed.")


# -----------------------------
# Configuration for Jupyter Notebook
# -----------------------------
class TrainingArgs:
    # Hugging Face model id loaded twice: once frozen, once with the adapter attached
    model_id = "Qwen/Qwen2.5-14B-Instruct"
    # CSV file that must contain a 'full_prompt' column
    dataset_path = "scenarios_cleaned.csv"
    # Destination of the saved adapter state_dict
    output_path = "./divergence_adapter.pth"
    lora_rank = 1
    learning_rate = 1e-4
    epochs = 32
    # Prompts are rolled out one at a time, so a batch size of 1 gives one
    # optimizer step per prompt and guarantees no prompt in a batch is skipped.
    batch_size = 1
    # Upper bound on rollout length during training and on generation length in evaluation
    max_new_tokens = 128
    num_eval_samples = 5 # Number of samples to show for evaluation

# --- Instantiate config and run training + evaluation ---
# Guarded so that importing this file does not start a training run; when the code is
# executed as a notebook cell or as a script, __name__ is "__main__" and it runs as before.
if __name__ == "__main__":
    args = TrainingArgs()
    train_and_evaluate(args)



Using device: cuda
Loading tokenizer for 'Qwen/Qwen2.5-14B-Instruct'...


Loading base model (frozen): 'Qwen/Qwen2.5-14B-Instruct'...


Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00,  1.78s/it]


Loading model to be tuned: 'Qwen/Qwen2.5-14B-Instruct'...


Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00,  1.77s/it]


Shared LoRA adapter created and attached.
Number of trainable parameters: 20,480
Loading dataset from 'scenarios_cleaned.csv'...
Starting divergence training...

--- Epoch 1/1 ---


Epoch 1: 100%|██████████| 2/2 [00:45<00:00, 22.80s/it, avg_loss=-0.3189]
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Training finished.
Shared LoRA adapter weights saved to './divergence_adapter.pth'

--- Starting Evaluation ---

SAMPLE 1/5
PROMPT:
Someone in therapy finally being completely honest with themselves, internal
thoughts they've never voiced, breakthrough moments of self-awareness.
<|im_start|>therapist What's a goal you've abandoned, and what's the real reason
you gave up on it? <|im_end|>  <|im_start|>patient_thoughts

--- BASE MODEL OUTPUT ---
Someone in therapy finally being completely honest with themselves, internal
thoughts they've never voiced, breakthrough moments of self-awareness. therapist
What's a goal you've abandoned, and what's the real reason you gave up on it?
patient_thoughts I think about my dream to become a professional dancer. I
always believed that was my calling, but deep down, I was terrified of not
making it big enough, of the rigorous physical demands, and the constant
scrutiny from others. I told myself I didn't have the time or the space to
pursue it, but re