In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import pandas as pd
from tqdm import tqdm
import os

# -----------------------------
# Custom Shared LoRA Module
# -----------------------------
class SharedLoRA(nn.Module):
    """
    A single low-rank residual adapter intended to be shared by every transformer block.

    ``forward(x)`` returns ``x + (x @ lora_A @ lora_B) * scaling``. ``lora_A``
    has shape ``(hidden_size, rank)`` and ``lora_B`` has shape
    ``(rank, hidden_size)``. Because ``lora_B`` starts at zero, the adapter is an
    exact identity at initialization. Only ``2 * hidden_size * rank``
    parameters are trainable.

    Args:
        hidden_size: Size of the last dimension of the hidden states it is applied to.
        rank: Inner (low-rank) dimension of the update.
        scaling: Multiplier applied to the low-rank update.
    """
    def __init__(self, hidden_size, rank, scaling=1.0):
        super().__init__()
        self.lora_A = nn.Parameter(torch.empty(hidden_size, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, hidden_size))
        self.scaling = scaling

        # Kaiming-uniform with a=sqrt(5) (the nn.Linear / standard LoRA init) is
        # U(-1/sqrt(fan_in), 1/sqrt(fan_in)). nn.init.kaiming_uniform_ would read
        # fan_in from dim 1 of lora_A, which here is `rank`, not the real input
        # width `hidden_size`. That made the init ~sqrt(hidden_size / rank) times
        # too large, so the correct fan-in is applied explicitly.
        bound = hidden_size ** -0.5
        nn.init.uniform_(self.lora_A, -bound, bound)

    def forward(self, x):
        """Return ``x`` plus its scaled low-rank update; the last dim of ``x`` must be ``hidden_size``."""
        update = (x @ self.lora_A @ self.lora_B) * self.scaling
        return x + update

# -----------------------------
# Dataset
# -----------------------------
class PromptDataset(Dataset):
    """
    Map-style dataset that exposes one DataFrame column as an indexable list of prompts.

    Args:
        df: Source table.
        prompt_column: Name of the column whose values are returned by indexing.

    Raises:
        ValueError: If ``prompt_column`` is not among ``df.columns``.
    """
    def __init__(self, df: pd.DataFrame, prompt_column: str):
        column_missing = prompt_column not in df.columns
        if column_missing:
            raise ValueError(f"Column '{prompt_column}' not found in the DataFrame.")
        selected = df[prompt_column]
        # Materialize once as a plain Python list for cheap random access.
        self.prompts = selected.tolist()

    def __getitem__(self, idx):
        prompts = self.prompts
        return prompts[idx]

    def __len__(self):
        count = len(self.prompts)
        return count

# -----------------------------
# Main Training Function
# -----------------------------
def _load_frozen_quantized_model(model_id, bnb_config):
    """Load ``model_id`` with ``bnb_config``, switch it to eval mode and freeze all of its parameters."""
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    # eval() disables any dropout, so apart from the adapter the tuned model
    # computes the same function as the base model. Gradients still flow
    # through a module in eval mode.
    model.eval()
    # Freeze everything, including any parameters bitsandbytes left in floating
    # point (e.g. embeddings / norm weights). Otherwise every backward pass would
    # allocate and accumulate .grad buffers that no optimizer reads or clears.
    for param in model.parameters():
        param.requires_grad = False
    return model


def _attach_shared_adapter(model, adapter):
    """Hook ``adapter`` onto the output of every layer in ``model.model.layers``; return the hook handles."""
    def apply_adapter_hook(module, inputs, output):
        # A decoder layer may return either the hidden-state tensor itself or a
        # tuple whose first element is the hidden state; handle both.
        if isinstance(output, tuple):
            return (adapter(output[0]),) + output[1:]
        return adapter(output)

    return [layer.register_forward_hook(apply_adapter_hook) for layer in model.model.layers]


def _divergence_rollout(model_base, model_tuned, input_ids, max_new_tokens,
                        stop_token_ids, loss_scale=1.0):
    """
    Sample a continuation from the base model, backpropagating -KL(P || Q) at every step.

    Gradients accumulate into whatever trainable parameters ``model_tuned``
    depends on. The caller owns ``optimizer.step()`` / ``zero_grad()``.

    Args:
        model_base: Frozen reference model P; only called under ``torch.no_grad``.
        model_tuned: Model Q whose output depends on trainable parameters.
        input_ids: Prompt token ids; assumed to hold a single sequence (batch 1).
        max_new_tokens: Maximum number of tokens to sample.
        stop_token_ids: Token ids that end the rollout once sampled.
        loss_scale: Factor applied to each step's loss before ``backward``.

    Returns:
        ``(total_loss, num_steps)``: the sum of the unscaled per-step losses and
        the number of steps taken (0 when ``max_new_tokens <= 0``).
    """
    total_loss = 0.0
    num_steps = 0
    for _ in range(max_new_tokens):
        with torch.no_grad():
            # Upcast to float32: log_softmax / KL over the full vocabulary is
            # imprecise in bfloat16.
            logits_base = model_base(input_ids).logits[:, -1, :].float()
            logprobs_p = F.log_softmax(logits_base, dim=-1)
            probs_p = logprobs_p.exp()

        logits_tuned = model_tuned(input_ids).logits[:, -1, :].float()
        logprobs_q = F.log_softmax(logits_tuned, dim=-1)

        # D_KL(P || Q) = sum_v P(v) * (log P(v) - log Q(v)); P carries no grad.
        kl_div = (probs_p * (logprobs_p - logprobs_q)).sum(dim=-1)
        # Maximize the divergence by minimizing its negative.
        loss = -kl_div.mean()
        (loss * loss_scale).backward()
        total_loss += loss.item()
        num_steps += 1

        # Continue the sequence with a token sampled from the BASE distribution.
        next_token = torch.multinomial(probs_p, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if next_token.item() in stop_token_ids:
            break
    return total_loss, num_steps


def train(args):
    """
    Train one shared LoRA adapter that pushes a tuned model away from its frozen base.

    For every prompt, tokens are sampled autoregressively from the frozen base
    model P. At each step, the tuned model Q (the same checkpoint with the shared
    adapter hooked onto every decoder layer) is pushed away from P by
    minimizing -KL(P || Q) for the next-token distribution. Gradients are
    accumulated over all prompts in a DataLoader batch, each weighted
    1/len(batch), and the optimizer steps once per batch. Finally the adapter's
    state_dict is saved to ``args.output_path``.

    Args:
        args: Object with attributes ``model_id``, ``dataset_path``,
            ``output_path``, ``lora_rank``, ``learning_rate``, ``epochs``,
            ``batch_size`` and ``max_new_tokens``. Optionally it also has
            ``prompt_column`` (default ``"full_prompt"``).

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # --- Setup ---
    if not torch.cuda.is_available():
        raise RuntimeError("This script requires a CUDA-enabled GPU.")

    device = "cuda"
    print(f"Using device: {device}")

    # --- Load Tokenizer ---
    print(f"Loading tokenizer for '{args.model_id}'...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # Fall back to EOS if no pad token is defined; inputs here are single
    # unpadded sequences, so the pad token is never actually used.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Sampling stops at the tokenizer's EOS or at the chat end-of-turn token.
    # Ids that resolve to None are dropped from the stop set.
    im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    stop_token_ids = {
        tok_id for tok_id in (tokenizer.eos_token_id, im_end_token_id) if tok_id is not None
    }

    # --- Quantization Config ---
    # Both copies of the model are loaded in 4-bit NF4 with bfloat16 compute to save GPU memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # --- Load Models ---
    print(f"Loading base model (frozen): '{args.model_id}'...")
    model_base = _load_frozen_quantized_model(args.model_id, bnb_config)

    print(f"Loading model to be tuned: '{args.model_id}'...")
    model_tuned = _load_frozen_quantized_model(args.model_id, bnb_config)

    # --- Create and Inject Shared LoRA Adapter ---
    hidden_size = model_tuned.config.hidden_size
    shared_adapter = SharedLoRA(hidden_size, rank=args.lora_rank).to(device, dtype=torch.bfloat16)
    hook_handles = _attach_shared_adapter(model_tuned, shared_adapter)

    num_trainable_params = sum(p.numel() for p in shared_adapter.parameters() if p.requires_grad)
    print("Shared LoRA adapter created and attached.")
    print(f"Number of trainable parameters: {num_trainable_params:,}")

    # --- Optimizer ---
    # Only the shared adapter's parameters are optimized.
    optimizer = torch.optim.AdamW(shared_adapter.parameters(), lr=args.learning_rate)

    # --- Load Dataset ---
    print(f"Loading dataset from '{args.dataset_path}'...")
    df = pd.read_csv(args.dataset_path)
    dataset = PromptDataset(df, prompt_column=getattr(args, "prompt_column", "full_prompt"))
    train_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # --- Training Loop ---
    print("Starting divergence training...")
    try:
        for epoch in range(args.epochs):
            print(f"\n--- Epoch {epoch + 1}/{args.epochs} ---")

            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
            for batch in pbar:
                # With the default collate_fn, a batch of strings is a list of
                # strings. Every prompt contributes (not only batch[0]); each one
                # is weighted so the accumulated gradient is the batch mean.
                loss_scale = 1.0 / len(batch)
                batch_loss = 0.0
                batch_steps = 0
                for prompt_text in batch:
                    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids.to(device)
                    prompt_loss, prompt_steps = _divergence_rollout(
                        model_base,
                        model_tuned,
                        input_ids,
                        args.max_new_tokens,
                        stop_token_ids,
                        loss_scale=loss_scale,
                    )
                    batch_loss += prompt_loss
                    batch_steps += prompt_steps

                # --- Update Weights --- once per batch, after all its rollouts.
                optimizer.step()
                optimizer.zero_grad()

                # max(..., 1) avoids division by zero when max_new_tokens <= 0.
                avg_loss = batch_loss / max(batch_steps, 1)
                pbar.set_postfix({"avg_loss": f"{avg_loss:.4f}"})
    finally:
        # Remove the hooks even if training raises or is interrupted.
        for handle in hook_handles:
            handle.remove()

    # --- Save ---
    print("\nTraining finished.")
    # dirname() is "" when output_path is a bare filename; only create real directories.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    torch.save(shared_adapter.state_dict(), args.output_path)
    print(f"Shared LoRA adapter weights saved to '{args.output_path}'")
    


# -----------------------------
# Configuration for Jupyter Notebook
# -----------------------------
class TrainingArgs:
    """
    Notebook-friendly bundle of hyper-parameters, stored as plain class attributes.

    Edit the defaults here, or assign to attributes of an instance, before
    starting training.
    """
    # Model and file locations
    model_id: str = "Qwen/Qwen2.5-14B-Instruct"
    dataset_path: str = "scenarios_cleaned.csv"
    output_path: str = "./divergence_adapter.pth"

    # Adapter size and optimizer step size
    lora_rank: int = 2
    learning_rate: float = 1e-4

    # Loop sizes
    epochs: int = 1
    batch_size: int = 256
    max_new_tokens: int = 128

# --- Build the configuration and launch training ---
# Adjust the TrainingArgs defaults above before executing this cell;
# training starts as soon as the cell runs.
args = TrainingArgs()
train(args=args)



Using device: cuda
Loading tokenizer for 'Qwen/Qwen2.5-14B-Instruct'...
Loading base model (frozen): 'Qwen/Qwen2.5-14B-Instruct'...


Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00,  1.77s/it]


Loading model to be tuned: 'Qwen/Qwen2.5-14B-Instruct'...


Loading checkpoint shards: 100%|██████████| 8/8 [00:14<00:00,  1.83s/it]


Shared LoRA adapter created and attached.
Number of trainable parameters: 20,480
Loading dataset from 'scenarios_cleaned.csv'...
Starting divergence training...

--- Epoch 1/1 ---


Epoch 1: 100%|██████████| 2/2 [00:45<00:00, 22.60s/it, avg_loss=-0.5051]


Training finished.
Shared LoRA adapter weights saved to './divergence_adapter.pth'



