# Personal Computer Environment Setup

## OS Packages

In [None]:
!sudo apt-get update
!sudo apt-get install mecab libmecab-dev mecab-ipadic-utf8 -y

In [None]:
!pip install fugashi unidic-lite

## Python Packages (Session Break)

In [None]:
!pip install torch transformer-lens einops

## Python Packages Pinned to Older Versions

In [None]:
!pip install datasets==3.6.0

# Experiments

## The simplest CC pruning method

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
This script serves as a proof-of-concept for Phase 1 of the proposed research.

---
FINAL VICTORY VERSION:
- This definitive version ensures full reproducibility by setting all random seeds
  and configuring PyTorch's CUDA backend for deterministic operations.
---
"""
# ---------------------------------------------------------------------------
#  *** PRE-EXECUTION SETUP (MANDATORY) ***
# 1. Place the `requirements.txt` file (using `datasets==2.10.1`) in the same directory.
#    NOTE: the notebook's install cell pins `datasets==3.6.0` instead; make sure the two agree.
# 2. In your terminal or Colab cell, run: `!pip install -r requirements.txt`
# ---------------------------------------------------------------------------

import torch
import numpy as np
import einops
from datasets import load_dataset
from transformer_lens import HookedTransformer, HookedTransformerConfig, utils
from transformers import AutoModel, AutoTokenizer, AutoConfig
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import matplotlib.gridspec as gridspec
import random

def set_seed(seed: int):
    """
    Seed every random number generator this script relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator and
    PyTorch (the default generator plus all CUDA devices), and sets the cuDNN
    backend flags to deterministic mode with benchmarking switched off.
    """
    # cuDNN flags: deterministic on, benchmark off.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Same seed for each generator, in the order Python -> NumPy -> PyTorch.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# =====================================================================================
# 1. SETUP: Model and Data Loading (Confirmed Working)
# =====================================================================================
def load_model_and_data(model_name: str, dataset_id: str, dataset_subset: str):
    """Build a HookedTransformer from a Hugging Face checkpoint and load the validation split.

    The Hugging Face model, tokenizer and config are loaded for ``model_name``;
    a ``HookedTransformerConfig`` is derived from the config, a
    ``HookedTransformer`` is instantiated from it, and the Hugging Face state
    dict is copied in with ``strict=False``. Progress is reported via ``print``.

    Args:
        model_name: Hugging Face model identifier; also used as the tokenizer name.
        dataset_id: Dataset identifier passed to ``load_dataset``.
        dataset_subset: Dataset configuration name passed as ``name=``.

    Returns:
        A ``(model, validation_dataset)`` tuple: the HookedTransformer in eval
        mode with the Hugging Face tokenizer attached, and ``dataset["validation"]``.
    """
    print("--- Starting Final Model & Data Loading Process ---")

    print(f"Step 1/3: Loading original Hugging Face components for '{model_name}'...")
    hf_model = AutoModel.from_pretrained(model_name)
    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_config = AutoConfig.from_pretrained(model_name)
    print(" -> Success.")

    print("Step 2/3: Manually building and populating HookedTransformer model...")
    model_config = HookedTransformerConfig(
        d_model=hf_config.hidden_size,
        d_head=hf_config.hidden_size // hf_config.num_attention_heads,
        n_layers=hf_config.num_hidden_layers,
        n_heads=hf_config.num_attention_heads,
        d_vocab=hf_config.vocab_size,
        d_mlp=hf_config.intermediate_size,
        act_fn=hf_config.hidden_act,
        n_ctx=hf_config.max_position_embeddings,
        tokenizer_name=model_name,
        normalization_type="LN",
        model_name=model_name,
    )
    model = HookedTransformer(model_config)

    # strict=False suppresses the error on key mismatches, so the returned
    # report is the only place a partial (or empty) weight transfer shows up.
    # NOTE(review): Hugging Face and HookedTransformer parameter names may not
    # line up; confirm the warning below does not fire for every parameter.
    load_report = model.load_state_dict(hf_model.state_dict(), strict=False)
    if load_report.missing_keys or load_report.unexpected_keys:
        print(
            " -> WARNING: state dict only partially matched "
            f"({len(load_report.missing_keys)} missing, "
            f"{len(load_report.unexpected_keys)} unexpected keys); "
            "parameters listed as missing keep their initial values."
        )

    model.eval()
    model.tokenizer = hf_tokenizer
    print(" -> Model ready.")

    print(f"Step 3/3: Loading dataset '{dataset_id}/{dataset_subset}'...")
    # NOTE(review): trust_remote_code=True permits running code shipped with the
    # dataset repository; only use it with sources you trust.
    dataset = load_dataset(dataset_id, name=dataset_subset, trust_remote_code=True)
    validation_dataset = dataset["validation"]
    print(" -> Dataset loading complete.")

    print("\n--- Model and Data Ready for Analysis ---")
    return model, validation_dataset

# =====================================================================================
# 2. ANALYSIS (DEFINITIVE LOGIC FIX)
# =====================================================================================
def get_attention_head_outputs(model: HookedTransformer, text: str):
    """Run ``text`` through ``model`` and collect the attention "z" activation of every layer.

    Args:
        model: HookedTransformer to run; ``model.cfg.n_layers`` decides how many
            hooks are captured.
        text: Raw input string, tokenised with ``model.to_tokens``.

    Returns:
        The cached "z" activations stacked along a new leading axis in layer
        order. Presumably shaped (n_layers, batch, pos, n_heads, d_head) --
        confirm against the TransformerLens hook documentation.
    """
    hook_names = [utils.get_act_name("z", layer) for layer in range(model.cfg.n_layers)]
    # Set lookup keeps the per-hook filter O(1); the list preserves layer order.
    wanted_hooks = set(hook_names)

    # Only activations are read here, so skip autograd bookkeeping for the
    # forward pass instead of building a graph that is never used.
    with torch.no_grad():
        _, cache = model.run_with_cache(
            model.to_tokens(text),
            names_filter=lambda name: name in wanted_hooks,
        )

    return torch.stack([cache[name] for name in hook_names])

def analyze_cls_token_contribution(head_outputs: torch.Tensor, model: HookedTransformer):
    """Score each head by the L2 norm of its output at sequence index 0.

    Args:
        head_outputs: Five-axis tensor. Axis 2 is sliced at index 0, the norm is
            taken over the last axis and axis 1 is averaged away. Presumably
            laid out as (layer, batch, position, head, d_head) with the [CLS]
            token at position 0 -- confirm against the caller.
        model: Not used by this function; accepted to keep the call signature.

    Returns:
        Tensor of per-head norms averaged over axis 1 (two axes remain).
    """
    # Index 0 on axis 2, L2 norm over the trailing axis, mean over axis 1.
    return torch.linalg.norm(head_outputs[:, :, 0], dim=-1).mean(dim=1)

# =====================================================================================
# 3. VISUALIZATION
# =====================================================================================
def plot_heatmap_with_explanation(scores: np.ndarray, max_score_idx: tuple, title: str, filename: str):
    """Save a layer-by-head heatmap of ``scores`` next to an explanatory text panel.

    The head count and the "early"/"deep" layer ranges quoted in the panel are
    derived from ``scores.shape`` (first and last third of the layers), so the
    text stays correct for models that are not 12 layers x 12 heads.

    Args:
        scores: 2-D array; rows are labelled L0.. (layers), columns H0.. (heads).
        max_score_idx: ``(row, column)`` pair named in the "Key Finding" text.
        title: Figure-level title.
        filename: Output path; the figure is written at 300 dpi and then closed.
    """
    n_layers, n_heads = scores.shape
    y_labels = [f"L{i}" for i in range(n_layers)]
    x_labels = [f"H{i}" for i in range(n_heads)]

    # Layer ranges quoted in the panel: first and last third of the stack.
    third = max(n_layers // 3, 1)
    early_range = f"L0-L{third - 1}"
    deep_range = f"L{n_layers - third}-L{n_layers - 1}"

    fig = plt.figure(figsize=(20, 10))
    gs = gridspec.GridSpec(1, 2, width_ratios=[2.5, 1])

    ax_heatmap = fig.add_subplot(gs[0, 0])
    # Passing the labels to seaborn fixes one tick per row/column; relabelling
    # afterwards fails if seaborn thinned the automatic ticks.
    sns.heatmap(
        scores,
        cmap="viridis",
        ax=ax_heatmap,
        xticklabels=x_labels,
        yticklabels=y_labels,
        cbar_kws={'label': 'Contribution Score (L2 Norm)'},
    )
    for tick_label in ax_heatmap.get_yticklabels():
        tick_label.set_rotation(0)
    for tick_label in ax_heatmap.get_xticklabels():
        tick_label.set_rotation(45)
        tick_label.set_ha("right")
    ax_heatmap.set_title("Attention Head Contributions to [CLS] Token")
    ax_heatmap.set_ylabel("Layer")
    ax_heatmap.set_xlabel("Head")

    ax_text = fig.add_subplot(gs[0, 1])
    ax_text.axis('off')

    most_influential_head_str = f"L{max_score_idx[0]}-H{max_score_idx[1]}"

    # (text, font size, font weight, vertical position in axes coordinates)
    text_content = [
        ("How to Read This Heatmap", 16, 'bold', 0.95),
        ("■ Layer (Y-axis): Model Depth", 12, 'bold', 0.85),
        (f"  • {early_range} (Early): Process basic syntax.", 11, 'normal', 0.80),
        (f"  • {deep_range} (Deep): Handle abstract semantics.", 11, 'normal', 0.75),
        ("\n■ Head (X-axis): Parallel Specialists", 12, 'bold', 0.68),
        (f"  • Each layer has {n_heads} heads focusing on different", 11, 'normal', 0.63),
        ("    word relationships.", 11, 'normal', 0.59),
        ("\n■ Color (Heat): Contribution Strength", 12, 'bold', 0.52),
        ("  • Bright Yellow: Critical for the task.", 11, 'normal', 0.47),
        ("  • Dark Purple: Less important for the task.", 11, 'normal', 0.42),
        ("\n■ Key Finding for This Run", 12, 'bold', 0.35),
        (f"  • The brightest spot is {most_influential_head_str}, showing", 11, 'normal', 0.30),
        ("    the model has learned specialized 'circuits'", 11, 'normal', 0.25),
        ("    for semantic comparison.", 11, 'normal', 0.20),
    ]

    for content, size, weight, y_pos in text_content:
        ax_text.text(0.0, y_pos, content, transform=ax_text.transAxes, fontsize=size, fontweight=weight, va='top')

    # Operate on this figure explicitly rather than on pyplot's current figure.
    fig.suptitle(title, fontsize=20)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.savefig(filename, dpi=300)
    plt.close(fig)
    print(f"\nHeatmap with explanation saved to {filename}")

def plot_distributions_with_examples(scores: np.ndarray, filename: str):
    """
    Render a three-panel summary figure and write it to ``filename``.

    Top-left: histogram (30 bins) of every value in ``scores``.
    Top-right: bar chart of the row-wise mean of ``scores`` on a log y-axis,
    one bar per row, labelled L0, L1, ...
    Bottom: a fixed table of illustrative sentence pairs.

    The figure is saved at 300 dpi and closed; a confirmation line is printed.
    """
    figure = plt.figure(figsize=(16, 10))
    grid = gridspec.GridSpec(2, 2, height_ratios=[1, 0.4])

    # Panel 1 (top-left): distribution of all scores.
    hist_ax = figure.add_subplot(grid[0, 0])
    sns.histplot(scores.flatten(), kde=False, ax=hist_ax, bins=30)
    hist_ax.set_title('Distribution of All Head Contributions')
    hist_ax.set_xlabel('Contribution Score (L2 Norm)')
    hist_ax.set_ylabel('Frequency')

    # Panel 2 (top-right): mean score of each row, drawn on a log scale.
    bar_ax = figure.add_subplot(grid[0, 1])
    per_layer_mean = scores.mean(axis=1)
    layer_names = [f"L{idx}" for idx, _ in enumerate(per_layer_mean)]
    sns.barplot(x=layer_names, y=per_layer_mean, ax=bar_ax, palette="Blues_d")
    bar_ax.set_title('Average Contribution per Layer')
    bar_ax.set_xlabel('Layer')
    bar_ax.set_ylabel('Average Contribution Score (Log Scale)')
    bar_ax.set_yscale('log')
    bar_ax.grid(True, which='both', axis='y', linestyle='--', linewidth=0.7)

    # Panel 3 (bottom, spanning both columns): static example table.
    table_ax = figure.add_subplot(grid[1, :])
    table_ax.axis('off')
    table_ax.set_title("\nCausal Sentence Examples (for illustration)", fontsize=14, pad=10)

    header_row = ["Example Type (Why deep layers are needed)", "Sentence 1", "Sentence 2"]
    example_rows = [
        [
            "Paraphrase (High Sim, Low Overlap)",
            "A boy is playing the guitar.",
            "A young man is performing on a musical instrument.",
        ],
        [
            "Semantic Difference (Low Sim, High Overlap)",
            "The cat is chasing the dog.",
            "The dog is chasing the cat.",
        ],
        [
            "Inference (High Sim, Med Overlap)",
            "The government passed a new law.",
            "A bill was approved by congress.",
        ],
    ]
    examples_table = table_ax.table(
        cellText=example_rows,
        colLabels=header_row,
        loc='center',
        cellLoc='left',
        colWidths=[0.32, 0.34, 0.34],
    )
    examples_table.auto_set_font_size(False)
    examples_table.set_fontsize(11)
    examples_table.scale(1.1, 1.8)

    # Figure-level title, layout, save and close.
    plt.suptitle('Analysis of Head Contributions & Causal Sentence Examples', fontsize=18)
    plt.tight_layout(rect=[0, 0.03, 1, 0.96])
    plt.savefig(filename, dpi=300)
    plt.close()
    print(f"Distribution plots with examples saved to {filename}")

# =====================================================================================
# 4. MAIN EXECUTION
# =====================================================================================
if __name__ == "__main__":
    # Script entry point: seed RNGs, load model and data, score every attention
    # head on both sentences of each sample, then write the two summary figures.
    set_seed(42)

    sns.set_theme(style="whitegrid")
    sns.set_context("paper", font_scale=1.2)
    # Optional: For journals requiring serif fonts (e.g., Times New Roman)
    # plt.rcParams['font.family'] = 'serif'
    # plt.rcParams['font.serif'] = ['Times']

    MODEL_NAME = "cl-nagoya/ruri-base-v2"
    DATASET_ID = "shunk031/JGLUE"
    DATASET_SUBSET = "JSTS"
    NUM_SAMPLES = 1457

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n--- Using device: {device} ---")

    model, dataset = load_model_and_data(MODEL_NAME, DATASET_ID, DATASET_SUBSET)
    model.to(device)

    # NUM_SAMPLES used to be defined but never applied; cap the shuffled split
    # with it (clamped to the split size so a larger value selects everything).
    sample_count = min(NUM_SAMPLES, len(dataset))
    fixed_dataset = dataset.shuffle(seed=42).select(range(sample_count))

    all_scores_s1 = []
    all_scores_s2 = []

    print(f"\nAnalyzing {len(fixed_dataset)} samples from the dataset on {device}...")

    for sample in tqdm(fixed_dataset):
        sentence1, sentence2 = sample['sentence1'], sample['sentence2']

        # Each sentence of the pair is scored independently.
        head_outputs_s1 = get_attention_head_outputs(model, sentence1)
        scores_s1 = analyze_cls_token_contribution(head_outputs_s1, model)
        all_scores_s1.append(scores_s1)

        head_outputs_s2 = get_attention_head_outputs(model, sentence2)
        scores_s2 = analyze_cls_token_contribution(head_outputs_s2, model)
        all_scores_s2.append(scores_s2)

    # Average over samples per sentence slot, then average the two slots.
    mean_scores_s1 = torch.stack(all_scores_s1).mean(dim=0)
    mean_scores_s2 = torch.stack(all_scores_s2).mean(dim=0)
    avg_scores_total = (mean_scores_s1 + mean_scores_s2).cpu().numpy() / 2

    # Index of the largest averaged score in the 2-D score array.
    max_score_idx = np.unravel_index(np.argmax(avg_scores_total), avg_scores_total.shape)

    plot_heatmap_with_explanation(
        avg_scores_total,
        max_score_idx,
        "Average Contribution of Attention Heads to [CLS] Token (JSTS)",
        "attention_head_contribution_heatmap_with_explanation.png",
    )

    plot_distributions_with_examples(avg_scores_total, "attention_head_distributions_with_examples.png")

    print("\n--- Analysis Complete ---")

    print(f"Most influential head found at L{max_score_idx[0]}-H{max_score_idx[1]}.")
    print("This marks the successful completion of Phase 1's initial exploration.")

## Preliminary Experiment

In [None]:
!python setup_and_run.py --mode demo

In [None]:
!python setup_and_run.py --mode full

In [None]:
!bash run_all_experiments.sh