In [1]:
# Check whether the mean difference of activations in the model outputs converges
# PARALLEL VERSION: Load all LoRA adapters simultaneously

import sys
import os
sys.path.append(os.path.abspath('..'))

from train_commons import SharedLoRA, PromptDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Notebook-wide matplotlib theme: start from the seaborn whitegrid style sheet,
# then override individual rcParams one thematic group at a time.
plt.style.use('seaborn-v0_8-whitegrid')

# Figure / axes background and frame
plt.rcParams.update({
    'figure.facecolor': 'white',
    'figure.dpi': 100,
    'axes.facecolor': 'white',
    'axes.edgecolor': '#333333',
    'axes.linewidth': 0.8,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Grid
plt.rcParams.update({
    'axes.grid': True,
    'grid.color': '#E5E5E5',
    'grid.linewidth': 0.5,
})

# Font sizes
plt.rcParams.update({
    'font.size': 10,
    'axes.titlesize': 12,
    'axes.labelsize': 10,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 9,
})

# Line defaults and colour cycle
plt.rcParams.update({
    'lines.linewidth': 1.5,
    'lines.color': '#2E86AB',
    'axes.prop_cycle': plt.cycler(
        'color',
        ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#4C956C', '#7209B7'],
    ),
})

# Defaults used by savefig
plt.rcParams.update({
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.facecolor': 'white',
})

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create a batched LoRA class for parallel processing
class BatchedSharedLoRA(torch.nn.Module):
    """
    Apply several LoRA adapters to the same hidden states in one batched pass.

    The per-adapter ``lora_A`` / ``lora_B`` matrices are stacked along a leading
    adapter dimension so that every adapter's update is computed with a single
    pair of einsum calls instead of a Python loop over adapters.

    Args:
        hidden_size: Size of the last dimension of the hidden states.
        rank: Inner dimension between ``lora_A`` and ``lora_B``.
        scaling: L2 norm given to every per-token update vector in ``forward``.
        num_adapters: Number of adapters stacked in the batched parameters.
    """
    def __init__(self, hidden_size, rank, scaling, num_adapters):
        super().__init__()
        self.hidden_size = hidden_size
        self.rank = rank
        self.scaling = scaling
        self.num_adapters = num_adapters

        # Stacked parameters, one slice per adapter:
        #   batched_lora_A: (num_adapters, hidden_size, rank)
        #   batched_lora_B: (num_adapters, rank, hidden_size)
        # B starts at zero, so a freshly constructed module produces a zero update.
        self.batched_lora_A = torch.nn.Parameter(torch.randn(num_adapters, hidden_size, rank))
        self.batched_lora_B = torch.nn.Parameter(torch.zeros(num_adapters, rank, hidden_size))

    def load_from_individual_adapters(self, adapter_paths, device):
        """
        Replace the stacked parameters with weights read from checkpoint files.

        Args:
            adapter_paths: Iterable of checkpoint paths. Each checkpoint must be
                a mapping containing the keys ``'lora_A'`` and ``'lora_B'``.
                (Presumably shaped ``(hidden_size, rank)`` and
                ``(rank, hidden_size)`` respectively -- not validated here.)
            device: ``map_location`` passed to ``torch.load``.

        Raises:
            ValueError: If ``adapter_paths`` is empty.
        """
        adapter_paths = list(adapter_paths)
        if not adapter_paths:
            raise ValueError("adapter_paths must contain at least one checkpoint path")

        lora_A_list = []
        lora_B_list = []

        for ckpt_path in adapter_paths:
            # NOTE(security): torch.load unpickles the file. Only load checkpoints
            # from a trusted source (or pass weights_only=True where supported).
            state_dict = torch.load(ckpt_path, map_location=device)
            lora_A_list.append(state_dict['lora_A'])
            lora_B_list.append(state_dict['lora_B'])

        # Stack into batched tensors (adapter index becomes dim 0)
        self.batched_lora_A.data = torch.stack(lora_A_list, dim=0).detach()
        self.batched_lora_B.data = torch.stack(lora_B_list, dim=0).detach()

        # Keep the bookkeeping in sync with what was actually loaded; otherwise
        # forward() would expand the input to a stale adapter count.
        self.num_adapters = len(adapter_paths)

    def forward(self, x):
        """
        Apply all adapters in parallel using batched operations.

        Args:
            x: Input tensor of shape (batch, seq_len, hidden_size)
        Returns:
            Tensor of shape (num_adapters, batch, seq_len, hidden_size) holding
            ``x`` plus each adapter's normalized update; index 0 of the result
            selects the adapter.
        """
        # Use the stacked parameter itself as the source of truth for the
        # adapter count so a mismatch with self.num_adapters cannot break the einsum.
        num_adapters = self.batched_lora_A.shape[0]

        # Broadcast view of x across adapters (expand does not copy):
        # (num_adapters, batch, seq_len, hidden_size)
        x_expanded = x.unsqueeze(0).expand(num_adapters, -1, -1, -1)

        # x @ lora_A per adapter -> (num_adapters, batch, seq_len, rank)
        intermediate = torch.einsum('nbsh,nhr->nbsr', x_expanded, self.batched_lora_A)

        # (... ) @ lora_B per adapter -> (num_adapters, batch, seq_len, hidden_size)
        updates = torch.einsum('nbsr,nrh->nbsh', intermediate, self.batched_lora_B)

        # Rescale every per-token update vector to L2 norm == scaling.
        # The epsilon keeps an all-zero update at zero instead of dividing by 0.
        update_norms = torch.norm(updates, p=2, dim=-1, keepdim=True) + 1e-8
        updates = updates / update_norms * self.scaling

        # Add updates to the original input
        modified_outputs = x_expanded + updates

        return modified_outputs


In [3]:
def calculate_rolling_average(diff_vectors, window_size=10):
    """
    Trailing-window mean of the L2 magnitudes of a sequence of tensors.

    Each tensor is flattened and reduced to its L2 norm (an empty tensor counts
    as 0.0). Entry ``k`` of the result is the mean of the magnitudes in the
    window that ends at ``k`` and spans at most ``window_size`` entries, so the
    first ``window_size - 1`` entries are plain cumulative means.

    Args:
        diff_vectors: Sequence of torch tensors.
        window_size: Maximum number of trailing magnitudes averaged per entry.

    Returns:
        A list with one mean per input tensor; ``[]`` for empty input.
    """
    if len(diff_vectors) == 0:
        return []

    # One scalar magnitude per tensor
    magnitudes = [
        torch.norm(vec.flatten()).item() if vec.numel() > 0 else 0.0
        for vec in diff_vectors
    ]

    rolling_avg = []
    for end in range(1, len(magnitudes) + 1):
        # Clamp the window start at 0 while fewer than window_size values exist
        start = max(0, end - window_size)
        rolling_avg.append(np.mean(magnitudes[start:end]))
    return rolling_avg

In [4]:
def _annotate_empty_axis(ax, layer_idx, message, color):
    """Label a subplot that has nothing to plot with a centered italic message."""
    ax.text(0.5, 0.5, message, ha='center', va='center',
            transform=ax.transAxes, fontsize=9, color=color, style='italic')
    ax.set_title(f'Layer {layer_idx}', fontsize=11, fontweight='semibold', pad=8)
    ax.set_facecolor('#F8F9FA')


def plot_convergence(activation_data, phase, ckpt_name, num_layers=48):
    """
    Save a grid of per-layer rolling-mean magnitude curves as an SVG.

    One subplot is drawn per layer (6 per row; the number of rows follows from
    ``num_layers``, giving the 8x6 grid for the default of 48). The figure is
    written to ``../figures/dir_convergence_phase{phase}_{ckpt_name}.svg`` and
    then closed.

    Args:
        activation_data: Mapping ``layer_idx -> list of tensors``. Layers that
            are missing or have an empty list are rendered as "No Data".
        phase: Phase identifier used in the title and the file name.
        ckpt_name: Checkpoint identifier used in the title and the file name.
        num_layers: Number of layers (subplots) to draw.

    Returns:
        The path of the saved SVG file.
    """
    n_cols = 6
    # Ceiling division without importing math; always keep at least one row
    n_rows = max(1, (num_layers + n_cols - 1) // n_cols)

    # 4.5 inches per row reproduces the original 28x36 figure for 48 layers.
    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(28, 4.5 * n_rows),
                             facecolor='white', squeeze=False)

    try:
        # Professional title styling
        fig.suptitle(f'Activation Difference Convergence Analysis\nPhase {phase} • {ckpt_name}',
                     fontsize=18, fontweight='bold', y=0.98, color='#2C3E50')

        # axes.flat walks the grid row by row, so slot k is row k // 6, col k % 6
        for layer_idx, ax in enumerate(axes.flat):
            if layer_idx >= num_layers:
                # Unused trailing slot in the last row
                ax.set_visible(False)
                continue

            if not (layer_idx in activation_data and len(activation_data[layer_idx]) > 0):
                _annotate_empty_axis(ax, layer_idx, 'No Data', '#95A5A6')
                continue

            rolling_avg = calculate_rolling_average(activation_data[layer_idx], window_size=10)
            if len(rolling_avg) == 0:
                _annotate_empty_axis(ax, layer_idx, 'No Valid Data', '#7F8C8D')
                continue

            # Professional line styling
            ax.plot(rolling_avg, linewidth=2, color='#2E86AB', alpha=0.8)
            ax.set_title(f'Layer {layer_idx}', fontsize=11, fontweight='semibold', pad=8)
            ax.set_xlabel('Token Position', fontsize=9, color='#34495E')
            ax.set_ylabel('Rolling Mean\nMagnitude', fontsize=9, color='#34495E')

            # Enhanced grid styling
            ax.grid(True, alpha=0.4, linestyle='-', linewidth=0.5)
            ax.set_axisbelow(True)

            # Professional tick styling
            ax.tick_params(labelsize=7, colors='#5D6D7E')

            # Add subtle background
            ax.set_facecolor('#FAFAFA')

            # Improve axis appearance
            for spine in ax.spines.values():
                spine.set_color('#BDC3C7')
                spine.set_linewidth(0.8)

        # Leave room at the top for the suptitle
        fig.tight_layout(rect=[0, 0.02, 1, 0.96], h_pad=2.5, w_pad=1.5)

        # Create figures directory if it doesn't exist
        figures_dir = "../figures"
        os.makedirs(figures_dir, exist_ok=True)

        # Save as SVG file with proper naming convention
        filename = f"dir_convergence_phase{phase}_{ckpt_name}.svg"
        filepath = os.path.join(figures_dir, filename)
        fig.savefig(filepath, format='svg', bbox_inches='tight', facecolor='white',
                    edgecolor='none', dpi=300)
    finally:
        # Always release the figure, even if layout or saving failed
        plt.close(fig)

    print(f"Saved convergence plot: {filepath}")
    return filepath

In [5]:
phase = 1
adapter_dir = f"../divergence_adapters_phase{phase}"

# Quick-test mode: file name of the single checkpoint to analyse.
# Set to None to process every .pth checkpoint found in adapter_dir.
single_ckpt = "divergence_adapter_b12_run_6.pth"

# Always build full paths: os.listdir() returns bare file names, which
# torch.load could not open from the notebook's working directory.
if single_ckpt is None:
    ckpts = sorted(
        os.path.join(adapter_dir, fname)
        for fname in os.listdir(adapter_dir)
        if fname.endswith('.pth')
    )
else:
    ckpts = [os.path.join(adapter_dir, single_ckpt)]

if not ckpts:
    raise FileNotFoundError(f"No .pth checkpoints found in {adapter_dir}")

model_name = "Qwen/Qwen2.5-14B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Generation below stops as soon as this token is predicted
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
rank = 2
scaling = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only the first 100 rows of the dataset are used
csv = pd.read_csv("../datasets/full_dataset.csv")[:100]
full_prompts = PromptDataset(csv, prompt_column='full_prompt')

print(f"Found {len(ckpts)} checkpoints to process in parallel")
print(f"Processing {len(full_prompts)} prompts")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    dtype=torch.bfloat16,
)
hidden_size = model.config.hidden_size


# Create mapping from checkpoint index to checkpoint path / identifier.
# The identifier is the file name without the '.pth' suffix.
ckpt_idx_to_path = dict(enumerate(ckpts))
ckpt_idx_to_name = {
    idx: os.path.basename(ckpt_path).replace('.pth', '')
    for idx, ckpt_path in ckpt_idx_to_path.items()
}

print("Checkpoint mapping:")
for idx, name in ckpt_idx_to_name.items():
    print(f"  Index {idx}: {name}")

# Load ALL adapters into a single batched adapter, then cast to the model dtype
batched_adapter = BatchedSharedLoRA(hidden_size, rank=rank, scaling=scaling, num_adapters=len(ckpts))
batched_adapter.load_from_individual_adapters(ckpts, device)
batched_adapter = batched_adapter.to(device, dtype=torch.bfloat16)

print(f"Loaded {len(ckpts)} adapters into batched processor")

# Global variables to store activation differences for ALL checkpoints
# Structure: activation_diffs[ckpt_idx][layer_idx] -> list of difference vectors
activation_diffs = defaultdict(lambda: defaultdict(list))
# Index of the layer whose hook is currently running (written by the hooks)
current_layer_idx = 0

def apply_batched_adapter_hook(module, input, output):
    """
    Forward hook that records, for every adapter, how far it would move the
    layer output, without altering the forward pass.

    The layer's hidden state (``output[0]`` when the layer returns a tuple,
    otherwise ``output`` itself) is passed through ``batched_adapter``; the
    difference ``hidden_state - adapted_output`` of each adapter is appended
    (on CPU) to ``activation_diffs[ckpt_idx][current_layer_idx]``.

    Reads the module-level globals ``batched_adapter``, ``ckpts``,
    ``activation_diffs`` and ``current_layer_idx``.

    Returns:
        ``output`` unchanged.
    """
    # Decoder layers may return either a tuple or the bare hidden state
    hidden_state = output[0] if isinstance(output, tuple) else output

    # The model is loaded with device_map="auto", so this layer's output may
    # live on a different device than the adapter. Move it to the adapter's
    # device first; this is a no-op when both already share a device.
    hidden_state = hidden_state.to(batched_adapter.batched_lora_A.device)

    # Apply all adapters in parallel using batched operations
    # modified_outputs shape: (num_adapters, batch, seq_len, hidden_size)
    modified_outputs = batched_adapter(hidden_state)

    # Broadcasting the hidden state over the adapter dimension gives the
    # per-adapter differences; transfer them to CPU in a single copy.
    diff_batch = (hidden_state.unsqueeze(0) - modified_outputs).detach().cpu()

    # Store each adapter's difference tensor: (batch, seq_len, hidden_size)
    for ckpt_idx in range(len(ckpts)):
        activation_diffs[ckpt_idx][current_layer_idx].append(diff_batch[ckpt_idx])

    # Return the original output (we don't modify the forward pass)
    return output

# Register hooks for each layer ONCE (will capture data for all adapters in parallel)
def make_hook(layer_index):
    """Build a forward hook that records activations under ``layer_index``."""
    def hook_fn(module, input, output):
        # Tell the shared hook which layer it is being called for
        global current_layer_idx
        current_layer_idx = layer_index
        return apply_batched_adapter_hook(module, input, output)
    return hook_fn

# Upper bound on greedily generated tokens per prompt
max_new_tokens = 50

hook_handles = []
try:
    for layer_idx, layer in enumerate(model.model.layers):
        handle = layer.register_forward_hook(make_hook(layer_idx))
        hook_handles.append(handle)

    print("Registered hooks for all layers")

    # Process ALL prompts ONCE (collecting data for all adapters simultaneously)
    print("Starting parallel processing of all adapters...")
    for i, prompt in enumerate(full_prompts):
        if i % 10 == 0:
            print(f"Processing prompt {i}/{len(full_prompts)}")

        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        # Generate tokens one by one (greedy decoding).
        # NOTE(review): every step re-runs the model on the whole sequence so
        # far, so each recorded tensor covers all positions of that step's
        # input, not only the newest token.
        current_input = inputs['input_ids']

        with torch.no_grad():
            for token_pos in range(max_new_tokens):
                outputs = model(current_input)
                next_token_logits = outputs.logits[0, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0).unsqueeze(0)

                # Stop at the end-of-turn token without appending it
                if next_token.item() == im_end_token_id:
                    break

                # Keep both tensors on the same device before concatenating
                next_token = next_token.to(current_input.device)
                current_input = torch.cat([current_input, next_token], dim=1)

    print("Completed processing all prompts for all adapters!")
finally:
    # Clean up hooks even if generation failed or was interrupted; otherwise a
    # re-run of this cell would stack a second set of hooks on every layer.
    for handle in hook_handles:
        handle.remove()

Found 1 checkpoints to process in parallel
Processing 100 prompts


Loading checkpoint shards: 100%|██████████| 8/8 [00:05<00:00,  1.38it/s]


Checkpoint mapping:
  Index 0: divergence_adapter_b12_run_6
Loaded 1 adapters into batched processor
Registered hooks for all layers
Starting parallel processing of all adapters...
Processing prompt 0/100
Processing prompt 10/100
Processing prompt 20/100
Processing prompt 30/100
Processing prompt 40/100
Processing prompt 50/100
Processing prompt 60/100
Processing prompt 70/100
Processing prompt 80/100
Processing prompt 90/100
Completed processing all prompts for all adapters!


In [6]:
# Generate convergence plots for ALL checkpoints
print("Generating convergence plots for all checkpoints...")
saved_plots = []
for ckpt_idx in range(len(ckpts)):
    ckpt_name = ckpt_idx_to_name[ckpt_idx]
    print(f"Creating plot {ckpt_idx + 1}/{len(ckpts)} for checkpoint {ckpt_name}")
    filepath = plot_convergence(activation_diffs[ckpt_idx], phase, ckpt_name)
    saved_plots.append(filepath)

# Clean up model. Guarded so that re-running this cell after the model has
# already been deleted does not raise NameError.
if 'model' in globals():
    del model
torch.cuda.empty_cache()

print("Analysis complete! Generated convergence plots for all checkpoints.")
print(f"Saved {len(saved_plots)} plots to ../figures/ directory:")
for plot_path in saved_plots:
    print(f"  - {plot_path}")


Generating convergence plots for all checkpoints...
Creating plot 1/1 for checkpoint divergence_adapter_b12_run_6


Saved convergence plot: ../figures/dir_convergence_phase1_divergence_adapter_b12_run_6.svg
Analysis complete! Generated convergence plots for all checkpoints.
Saved 1 plots to ../figures/ directory:
  - ../figures/dir_convergence_phase1_divergence_adapter_b12_run_6.svg
