In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import AutoModelForCausalLM

In [2]:
def _mean_percentage_change(base_weight, tuned_weight, device, eps):
    """Return ``mean(|tuned - base| / (|base| + eps)) * 100`` as a Python float.

    Both weights are detached and upcast to float32 on ``device`` first: in
    half precision a tiny ``eps`` rounds to zero and large element-wise ratios
    can overflow, either of which turns the mean into inf/NaN.
    """
    base = base_weight.detach().to(device=device, dtype=torch.float32)
    tuned = tuned_weight.detach().to(device=device, dtype=torch.float32)
    # |base| + eps (not |base + eps|) keeps the denominator strictly positive,
    # including for weights that happen to equal -eps.
    relative_change = (tuned - base).abs() / (base.abs() + eps)
    return relative_change.mean().item() * 100


def calculate_percentage_change_llama(base_model, fine_tuned_model, component_type="mlp", eps=1e-10):
    """
    Compute the mean per-layer percentage change in MLP or attention projection
    weights between a base LLaMA-2 model and a fine-tuned model.

    For every decoder layer and every compared projection the reported value is
    ``mean(|W_ft - W_base| / (|W_base| + eps)) * 100``, evaluated in float32.
    No plotting is done here; plot the returned lists in the caller.

    Note: both models are moved in place (``nn.Module.to``) to CUDA when it is
    available, otherwise to CPU.

    Args:
        base_model: The base LLaMA-2 model; layers are read from
            ``base_model.model.layers``.
        fine_tuned_model: The fine-tuned LLaMA-2 model with the same layer layout.
        component_type (str): "mlp" compares ``mlp.up_proj`` / ``mlp.down_proj``;
            "attention" compares ``self_attn.q_proj`` / ``k_proj`` / ``v_proj`` /
            ``o_proj``.
        eps (float): Added to ``|W_base|`` to avoid division by zero.

    Returns:
        tuple of lists of float, one percentage per layer:
        ``(up_proj_changes, down_proj_changes)`` for "mlp", or
        ``(q_changes, k_changes, v_changes, o_changes)`` for "attention".

    Raises:
        AssertionError: If ``component_type`` is not "mlp" or "attention".
        IndexError: If the fine-tuned model has fewer layers than the base model.
    """
    assert component_type in ["mlp", "attention"], "component_type must be 'mlp' or 'attention'."

    # Attribute holding the projections inside a decoder layer, and the
    # projections compared (in the order they are returned).
    if component_type == "mlp":
        submodule, projections = "mlp", ("up_proj", "down_proj")
    else:
        submodule, projections = "self_attn", ("q_proj", "k_proj", "v_proj", "o_proj")

    # Transfer models to the available device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_model = base_model.to(device)
    fine_tuned_model = fine_tuned_model.to(device)

    base_layers = base_model.model.layers
    tuned_layers = fine_tuned_model.model.layers
    changes = {name: [] for name in projections}

    # Read-only comparison: no autograd graph is needed.
    with torch.no_grad():
        for i, base_layer in enumerate(base_layers):
            base_block = getattr(base_layer, submodule)
            tuned_block = getattr(tuned_layers[i], submodule)
            for name in projections:
                changes[name].append(
                    _mean_percentage_change(
                        getattr(base_block, name).weight,
                        getattr(tuned_block, name).weight,
                        device,
                        eps,
                    )
                )

    return tuple(changes[name] for name in projections)