In [7]:
# Extract direction vectors from LoRA adapters - MEMORY EFFICIENT VERSION
# Accumulate sums directly on GPU without storing individual vectors

import sys
import os
sys.path.append(os.path.abspath('..'))

from train_commons import SharedLoRA, PromptDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pandas as pd


In [8]:
# Create a batched LoRA class for parallel processing
class BatchedSharedLoRA(torch.nn.Module):
    """
    Batched version of SharedLoRA that processes multiple adapters in parallel
    by concatenating their parameters and using batch operations.
    """
    def __init__(self, hidden_size, rank, scaling, num_adapters):
        super().__init__()
        self.hidden_size = hidden_size
        self.rank = rank
        self.scaling = scaling
        self.num_adapters = num_adapters
        
        # Stack all lora_A and lora_B parameters into batched tensors
        # Shape: (num_adapters, hidden_size, rank) and (num_adapters, rank, hidden_size)
        self.batched_lora_A = torch.nn.Parameter(torch.randn(num_adapters, hidden_size, rank))
        self.batched_lora_B = torch.nn.Parameter(torch.zeros(num_adapters, rank, hidden_size))
    
    def load_from_individual_adapters(self, adapter_paths, device):
        """Load parameters from individual adapter checkpoints"""
        lora_A_list = []
        lora_B_list = []
        
        for ckpt_path in adapter_paths:
            state_dict = torch.load(ckpt_path, map_location=device)
            lora_A_list.append(state_dict['lora_A'])
            lora_B_list.append(state_dict['lora_B'])
        
        # Stack into batched tensors
        self.batched_lora_A.data = torch.stack(lora_A_list, dim=0)
        self.batched_lora_B.data = torch.stack(lora_B_list, dim=0)
    
    def forward(self, x):
        """
        Apply all adapters in parallel using batched operations
        Args:
            x: Input tensor of shape (batch, seq_len, hidden_size)
        Returns:
            Tensor of modified outputs, shape (num_adapters, batch, seq_len, hidden_size)
        """
        # x shape: (batch, seq_len, hidden_size)
        # We need to compute x @ lora_A @ lora_B for each adapter
        
        # Expand x to work with batched operations
        # x_expanded shape: (num_adapters, batch, seq_len, hidden_size)
        x_expanded = x.unsqueeze(0).expand(self.num_adapters, -1, -1, -1)
        
        # Batched matrix multiplication: (num_adapters, batch, seq_len, hidden_size) @ (num_adapters, hidden_size, rank)
        # Result shape: (num_adapters, batch, seq_len, rank)
        intermediate = torch.einsum('nbsh,nhr->nbsr', x_expanded, self.batched_lora_A)
        
        # Second matrix multiplication: (num_adapters, batch, seq_len, rank) @ (num_adapters, rank, hidden_size)
        # Result shape: (num_adapters, batch, seq_len, hidden_size)
        updates = torch.einsum('nbsr,nrh->nbsh', intermediate, self.batched_lora_B)
        
        # Normalize updates
        update_norms = torch.norm(updates, p=2, dim=-1, keepdim=True) + 1e-8
        updates = updates / update_norms * self.scaling
        
        # Add updates to original input
        # x_expanded + updates, shape: (num_adapters, batch, seq_len, hidden_size)
        modified_outputs = x_expanded + updates
        
        return modified_outputs


In [None]:
def save_direction_vectors(accumulated_diffs, token_counts, phase, ckpt_idx_to_name, hidden_size, num_layers=48):
    """Extract and save final direction vectors for each checkpoint"""
    
    # Create directions directory if it doesn't exist
    directions_dir = f"../directions/phase{phase}"
    os.makedirs(directions_dir, exist_ok=True)
    
    saved_files = []
    
    # Process each checkpoint
    for ckpt_idx in range(len(ckpt_idx_to_name)):
        ckpt_name = ckpt_idx_to_name[ckpt_idx]
        
        # Extract direction vectors for this checkpoint across all layers
        layer_directions = []
        
        for layer_idx in range(num_layers):
            if token_counts[layer_idx] > 0:
                # Calculate average: accumulated_sum / total_tokens
                avg_direction = accumulated_diffs[layer_idx][:, ckpt_idx] / token_counts[layer_idx]
                layer_directions.append(avg_direction)
            else:
                # Create zero vector if no data for this layer
                layer_directions.append(torch.zeros(hidden_size, device=accumulated_diffs[0].device))
        
        # Stack all layer directions into a single tensor
        # Shape: (num_layers, hidden_size)
        all_directions = torch.stack(layer_directions, dim=0)
        
        # Move to CPU and save as torch tensor file
        filename = f"model_phase{phase}_{ckpt_name}.pt"
        filepath = os.path.join(directions_dir, filename)
        torch.save(all_directions.cpu(), filepath)
        
        print(f"Saved direction vectors: {filepath}")
        saved_files.append(filepath)
    
    return saved_files


In [None]:
phase = 1
checkpoint_dir = f"../divergence_adapters_phase{phase}"

# Check if checkpoint directory exists
if not os.path.exists(checkpoint_dir):
    raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}")

ckpts = os.listdir(checkpoint_dir)
# Filter for .pth files only
ckpts = [f for f in ckpts if f.endswith('.pth')]
if not ckpts:
    raise FileNotFoundError(f"No .pth checkpoint files found in {checkpoint_dir}")

ckpts = [os.path.join(checkpoint_dir, ckpt) for ckpt in ckpts]
# For testing with single checkpoint - uncomment line below:
# ckpts = ["../divergence_adapters_phase1/divergence_adapter_b12_run_6.pth"]

model_name = "Qwen/Qwen2.5-14B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
rank = 2
scaling = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Check if dataset exists
dataset_path = "../datasets/full_dataset.csv"
if not os.path.exists(dataset_path):
    raise FileNotFoundError(f"Dataset not found: {dataset_path}")

csv = pd.read_csv(dataset_path)
if csv.empty:
    raise ValueError("Dataset is empty")
if 'full_prompt' not in csv.columns:
    raise ValueError("Dataset missing required 'full_prompt' column")

full_prompts = PromptDataset(csv, prompt_column='full_prompt')

print(f"Found {len(ckpts)} checkpoints to process in parallel")
print(f"Processing {len(full_prompts)} prompts")

# Load the base model once
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    dtype=torch.bfloat16,
)
hidden_size = model.config.hidden_size
num_layers = len(model.model.layers)

# Create mapping from checkpoint index to checkpoint identifier
ckpt_idx_to_name = {}
for idx, ckpt_path in enumerate(ckpts):
    ckpt_name = os.path.basename(ckpt_path).replace('.pth', '')
    ckpt_idx_to_name[idx] = ckpt_name

print("Checkpoint mapping:")
for idx, name in ckpt_idx_to_name.items():
    print(f"  Index {idx}: {name}")

# Load ALL adapters into a single batched adapter
batched_adapter = BatchedSharedLoRA(hidden_size, rank=rank, scaling=scaling, num_adapters=len(ckpts))
batched_adapter.load_from_individual_adapters(ckpts, device)
batched_adapter = batched_adapter.to(device, dtype=torch.bfloat16)

print(f"Loaded {len(ckpts)} adapters into batched processor")

# MEMORY EFFICIENT: Initialize accumulation tensors on GPU
# Shape for each layer: (hidden_size, num_checkpoints)
accumulated_diffs = []
token_counts = []  # Track total tokens processed for each layer

for layer_idx in range(num_layers):
    # Initialize accumulation tensor for this layer (using bfloat16 to match model)
    accumulated_diffs.append(torch.zeros(hidden_size, len(ckpts), device=device, dtype=torch.bfloat16))
    token_counts.append(0)

current_layer_idx = 0

def apply_batched_adapter_hook(module, input, output):
    """Hook that applies ALL adapters in parallel and accumulates differences efficiently"""
    global current_layer_idx
    
    if isinstance(output, tuple):
        hidden_state = output[0]
    else:
        hidden_state = output
    
    # Apply all adapters in parallel using batched operations
    # modified_outputs shape: (num_adapters, batch, seq_len, hidden_size)
    modified_outputs = batched_adapter(hidden_state)
    
    # Calculate differences for each adapter
    # hidden_state shape: (batch, seq_len, hidden_size)
    # We need to expand it to match modified_outputs
    hidden_state_expanded = hidden_state.unsqueeze(0).expand(len(ckpts), -1, -1, -1)
    
    # Compute differences: (num_adapters, batch, seq_len, hidden_size)
    diff_batch = hidden_state_expanded - modified_outputs
    
    # Sum across batch and sequence dimensions to get (num_adapters, hidden_size)
    # Then transpose to (hidden_size, num_adapters) to match our accumulation tensor
    diff_sum = torch.sum(diff_batch, dim=(1, 2)).T  # (hidden_size, num_adapters)
    
    # Accumulate the differences directly on GPU
    accumulated_diffs[current_layer_idx] += diff_sum
    
    # Count tokens processed for this layer
    num_tokens = hidden_state.shape[0] * hidden_state.shape[1]  # batch_size * seq_len
    token_counts[current_layer_idx] += num_tokens
    
    # Return the original output (we don't modify the forward pass)
    return output

# Register hooks for each layer ONCE (will capture data for all adapters in parallel)
hook_handles = []
for layer_idx, layer in enumerate(model.model.layers):
    def make_hook(layer_index):
        def hook_fn(module, input, output):
            global current_layer_idx
            current_layer_idx = layer_index
            return apply_batched_adapter_hook(module, input, output)
        return hook_fn
    
    handle = layer.register_forward_hook(make_hook(layer_idx))
    hook_handles.append(handle)

print("Registered hooks for all layers")
print(f"Memory allocated: {accumulated_diffs[0].element_size() * accumulated_diffs[0].numel() * len(accumulated_diffs) / 1e9:.2f} GB")


Found 20 checkpoints to process in parallel
Processing 480 prompts


Loading checkpoint shards: 100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Checkpoint mapping:
  Index 0: divergence_adapter_b12_run_5
  Index 1: divergence_adapter_b12_run_4
  Index 2: divergence_adapter_b12_run_1
  Index 3: divergence_adapter_b12_run_10
  Index 4: divergence_adapter_b12_run_8
  Index 5: divergence_adapter_b12_run_15
  Index 6: divergence_adapter_b12_run_12
  Index 7: divergence_adapter_b12_run_16
  Index 8: divergence_adapter_b12_run_17
  Index 9: divergence_adapter_b12_run_19
  Index 10: divergence_adapter_b12_run_6
  Index 11: divergence_adapter_b12_run_13
  Index 12: divergence_adapter_b12_run_9
  Index 13: divergence_adapter_b12_run_14
  Index 14: divergence_adapter_b12_run_7
  Index 15: divergence_adapter_b12_run_18
  Index 16: divergence_adapter_b12_run_11
  Index 17: divergence_adapter_b12_run_20
  Index 18: divergence_adapter_b12_run_2
  Index 19: divergence_adapter_b12_run_3
Loaded 20 adapters into batched processor
Registered hooks for all layers
Memory allocated: 0.02 GB


In [11]:
# Process ALL prompts ONCE (collecting data for all adapters simultaneously)
print("Starting parallel processing of all adapters...")
for i, prompt in enumerate(full_prompts):
    if i % 10 == 0:
        print(f"Processing prompt {i}/{len(full_prompts)}")
        
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    
    # Generate tokens one by one to track per-token differences
    current_input = inputs['input_ids']
    
    for token_pos in range(50):  # max_new_tokens
        with torch.no_grad():
            outputs = model(current_input)
            next_token_logits = outputs.logits[0, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0).unsqueeze(0)
            
            if next_token.item() == im_end_token_id:
                break
                
            current_input = torch.cat([current_input, next_token], dim=1)

print("Completed processing all prompts for all adapters!")
print(f"Total tokens processed per layer: {token_counts[0]}")

# Clean up hooks
for handle in hook_handles:
    handle.remove()


Starting parallel processing of all adapters...
Processing prompt 0/480
Processing prompt 10/480
Processing prompt 20/480
Processing prompt 30/480
Processing prompt 40/480
Processing prompt 50/480
Processing prompt 60/480
Processing prompt 70/480
Processing prompt 80/480
Processing prompt 90/480
Processing prompt 100/480
Processing prompt 110/480
Processing prompt 120/480
Processing prompt 130/480
Processing prompt 140/480
Processing prompt 150/480
Processing prompt 160/480
Processing prompt 170/480
Processing prompt 180/480
Processing prompt 190/480
Processing prompt 200/480
Processing prompt 210/480
Processing prompt 220/480
Processing prompt 230/480
Processing prompt 240/480
Processing prompt 250/480
Processing prompt 260/480
Processing prompt 270/480
Processing prompt 280/480
Processing prompt 290/480
Processing prompt 300/480
Processing prompt 310/480
Processing prompt 320/480
Processing prompt 330/480
Processing prompt 340/480
Processing prompt 350/480
Processing prompt 360/480
P

In [12]:
# Extract and save direction vectors for ALL checkpoints
print("Extracting and saving direction vectors for all checkpoints...")
saved_directions = save_direction_vectors(accumulated_diffs, token_counts, phase, ckpt_idx_to_name, hidden_size)

# Clean up model and accumulated tensors
del model
del accumulated_diffs
del batched_adapter
torch.cuda.empty_cache()

print("Analysis complete! Extracted direction vectors for all checkpoints.")
print(f"Saved {len(saved_directions)} direction files to ../directions/ directory:")
for direction_path in saved_directions:
    print(f"  - {direction_path}")

# Show how to load the direction vectors
print("\nTo load direction vectors:")
print("directions = torch.load('path_to_file.pt')")
print("directions.shape  # Should be (num_layers, hidden_size)")
print("layer_0_direction = directions[0]  # Direction vector for layer 0")


Extracting and saving direction vectors for all checkpoints...
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_5.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_4.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_1.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_10.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_8.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_15.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_12.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_16.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_17.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_19.pt
Saved direction vectors: ../directions/model_phase1_divergence_adapter_b12_run_6.pt
Saved d