In [None]:
import re

import pandas as pd
import seaborn as sns


def prepare_dataframe(csv_path: str) -> pd.DataFrame:
    """Read a CKA result CSV and tag each row with an order-independent pair label.

    The added ``canonical_pair`` column holds the two model names from ``model_1``
    and ``model_2`` in sorted order, joined by ``"__"``, so that (A, B) and (B, A)
    share one label.

    Args:
        csv_path: Path of the CSV file; it must contain ``model_1`` and ``model_2``.

    Returns:
        The loaded frame with the extra ``canonical_pair`` column.
    """
    frame = pd.read_csv(csv_path)

    def _pair_label(row) -> str:
        # Sorting makes the label independent of which model was listed first.
        first, second = sorted((row["model_1"], row["model_2"]))
        return f"{first}__{second}"

    frame["canonical_pair"] = frame.apply(_pair_label, axis=1)
    return frame


def show_linearcka_to_corresponding_layer(df: pd.DataFrame):
    """Plot, for every model in ``df``, its linear CKA to each partner at matching depth.

    One ``sns.catplot`` point plot (one column per ``token_type``) is drawn per model
    that occurs in ``model_1`` or ``model_2``. Only rows whose ``layer_1`` and
    ``layer_2`` fall into the same depth category are kept. Layers 27 and 35 are both
    mapped to "final" (presumably the last layers of differently deep models — TODO confirm).

    Args:
        df: Comparison frame with the columns ``model_1``, ``model_2``, ``layer_1``,
            ``layer_2``, ``token_type`` and ``cka_linear``.
    """
    layer_to_cat = {0: 0, 6: 6, 12: 12, 18: 18, 35: "final", 27: "final"}
    cat_order = [0, 6, 12, 18, "final"]

    # sorted() gives a reproducible figure order; iterating a bare set does not.
    model_names = sorted(set(df.model_1.dropna()) | set(df.model_2.dropna()))
    for base_model in model_names:
        data = df[(df.model_1 == base_model) | (df.model_2 == base_model)].copy()

        data["depth_cat_1"] = pd.Categorical(data["layer_1"].map(layer_to_cat), categories=cat_order, ordered=True)
        data["depth_cat_2"] = pd.Categorical(data["layer_2"].map(layer_to_cat), categories=cat_order, ordered=True)
        # Unmapped layers become NaN and never compare equal, so they are dropped here.
        data = data[(data.depth_cat_1 == data.depth_cat_2)].copy()

        # The partner is whichever side of the pair is not the base model. Taking it
        # from the model columns avoids treating the model name as a regex ("." would
        # match any character) and avoids deleting the name from inside a longer name.
        data["descendant"] = data["model_2"].where(data["model_1"] == base_model, data["model_1"])

        g = sns.catplot(data, x="depth_cat_1", y="cka_linear", hue="descendant", kind="point", col="token_type")
        g.axes[0, 0].set_ylabel(f"cka_linear to {base_model} at same layer")

## Models descending from Qwen-R1-distill-7B

In [None]:
# Load the CKA results for this section (adds the canonical_pair column) and display them.
df = prepare_dataframe("cka_results.csv")
df

In [None]:
# Layer-by-layer cka_linear table: rows are (token_type, layer_1), columns are
# (canonical_pair, layer_2). pivot_table averages duplicate entries by default.
pivot = df.pivot_table(values="cka_linear", index=["token_type", "layer_1"], columns=["canonical_pair", "layer_2"])
pivot

In [None]:
# List the distinct model pairs present in the loaded results.
df.canonical_pair.unique()

All models descend from DeepSeek-R1-Distill-Qwen-7B. How did they change?

In [None]:
# Linear CKA between the base model and each descendant, compared at matching depth.
base_model = "DeepSeek-R1-Distill-Qwen-7B"
data = df[(df.model_1 == base_model) | (df.model_2 == base_model)].copy()

# Layers 27 and 35 are both mapped to "final" (presumably the last layers of
# differently deep models — TODO confirm). Unmapped layers become NaN and are dropped.
layer_to_cat = {0: 0, 6: 6, 12: 12, 18: 18, 35: "final", 27: "final"}
cat_order = [0, 6, 12, 18, "final"]
data["depth_cat_1"] = pd.Categorical(data["layer_1"].map(layer_to_cat), categories=cat_order, ordered=True)
data["depth_cat_2"] = pd.Categorical(data["layer_2"].map(layer_to_cat), categories=cat_order, ordered=True)
data = data[(data.depth_cat_1 == data.depth_cat_2)].copy()

# The descendant is whichever side of the pair is not the base model. Taking it from
# the model columns avoids using the model name as a regex pattern ("." matches any
# character) and avoids removing the name from inside a longer partner name.
data["descendant"] = data["model_2"].where(data["model_1"] == base_model, data["model_1"])

g = sns.catplot(data, x="depth_cat_1", y="cka_linear", hue="descendant", kind="point", col="token_type")
g.axes[0, 0].set_ylabel(f"cka_linear to {base_model} at same layer")

Only Light-R1 is a pure SFT model; the others also involve some RL. Does the SFT model have different representations than the RL(+SFT) models?

We are checking pairwise similarities now to test the hypothesis that SFT has a different effect than RL on the model.

In [None]:
# One figure per model: its linear CKA to every partner model at matching depth.
show_linearcka_to_corresponding_layer(df)

### With same input tokens

Comparing reps of all tokens between corresponding layers

Caveat: answers longer than 8096 tokens were cut off to avoid OOM

In [None]:
# Load the fixed-input CKA results (replaces the previous `df`) and display them.
df = prepare_dataframe("cka_fixed_input_results.csv")
df

We filter out the final layer because its similarity is much lower than that of the other layers.
This makes sense, because the models' next-token predictions are probably different.

In [None]:
# Keep layers up to 21 only.
# NOTE(review): 21 is a hard-coded cutoff; the following cells use
# `df.layer < df.layer.max()` instead — confirm both exclude the same final layer.
plotdata = df.loc[df.layer <= 21]
# sns.relplot(plotdata, x="layer", y="cka_rbf", hue="canonical_pair", kind="line", style="token_origin")
sns.relplot(plotdata, x="layer", y="cka_linear", hue="canonical_pair", kind="line", style="token_origin")

In [None]:
# Linear and RBF CKA per layer for every pair, with the deepest layer excluded.
# The line style encodes which model produced the input tokens.
style_kwargs = dict(x="layer", hue="canonical_pair", kind="line", style="token_origin")
plotdata = df[df.layer < df.layer.max()]
sns.relplot(plotdata, y="cka_linear", **style_kwargs)
sns.relplot(plotdata, y="cka_rbf", **style_kwargs)

In [None]:
# Same plots as before but including the final layer (no layer filter applied).
style_kwargs = dict(x="layer", hue="canonical_pair", kind="line", style="token_origin")
plotdata = df
sns.relplot(plotdata, y="cka_linear", **style_kwargs)
sns.relplot(plotdata, y="cka_rbf", **style_kwargs)

In [None]:
# Final layer excluded; one facet column per token origin instead of a line style.
facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")
plotdata = df[df.layer < df.layer.max()]
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

In [None]:
# Show the current comparison frame for inspection.
df

In [None]:
# Flag the pairs that involve the base model, then plot only those pairs
# (final layer excluded), one facet column per token origin.
base_name = "DeepSeek-R1-Distill-Qwen-7B"
df["contains_basemodel"] = (df["model_1"] == base_name) | (df["model_2"] == base_name)

facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")
plotdata = df.loc[df["contains_basemodel"] & (df.layer < df.layer.max())]
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

## Models descending from Qwen2.5-Math-7B (optionally +rope extension)

In [None]:
# Load the CKA results for the Qwen2.5-Math section (replaces the previous `df`).
df = prepare_dataframe("cka_results_qwen2.5-math.csv")
df

openr1-distill and qwen-r1-distill are pure SFT models, acereason1.1 also involves some RL. Does the SFT model have different representations than the RL(+SFT) models?

We are checking pairwise similarities now to test the hypothesis that SFT has a different effect than RL on the model.

openr1 is a descendant of qwen2.5-math-rope300k. The rope300k model should be basically identical to the regular qwen2.5 math model. Let's check.


**Result:**\
qwen-rope is quite different to qwen.\
acereason and openr1 are more similar to each other compared to the base cot models\
during reasoning, all models are dissimilar (but that could be confounded)

The reasoning step analysis is not robust enough: the idea was that models that reason in roughly the same direction have similar representations. But their reasoning direction can change all the time with _wait_ etc.

In [None]:
# One figure per model: its linear CKA to every partner model at matching depth.
show_linearcka_to_corresponding_layer(df)

### Same analysis as above with fixed tokens

In [None]:
# Load the fixed-input CKA results for the Qwen2.5-Math models (replaces `df`).
df = prepare_dataframe("cka_fixed_input_results_qwen2.5-math-extra.csv")
df


In [None]:
# Linear and RBF CKA per layer for every pair, all layers included,
# one facet column per token origin.
facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")
plotdata = df
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

In [None]:
# Restrict to pairs that involve Qwen2.5-Math-7B; all layers are kept.
base_name = "Qwen2.5-Math-7B"
df["contains_basemodel"] = (df["model_1"] == base_name) | (df["model_2"] == base_name)

facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")
plotdata = df[df["contains_basemodel"]]
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

In [None]:
# Pairs that involve Qwen2.5-Math-7B, this time with the deepest layer excluded.
base_name = "Qwen2.5-Math-7B"
df["contains_basemodel"] = (df["model_1"] == base_name) | (df["model_2"] == base_name)

facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")
plotdata = df[df["contains_basemodel"] & (df.layer < df.layer.max())]
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

## Models descending from R1-distill-Qwen2.5-1.5B

only with fixed tokens, generated from the qwen25-math descendants or the r1-qwen-7b descendants

In [None]:
# Fixed-input results, file "...-1-cache2": first all pairs, then only the
# pairs that involve the 1.5B base model. All layers are kept.
df = prepare_dataframe("cka_fixed_input_results_r1-distill-qwen2.5-1.5b-1-cache2.csv")
base_name = "DeepSeek-R1-Distill-Qwen-1-5B"
df["contains_basemodel"] = (df["model_1"] == base_name) | (df["model_2"] == base_name)
facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")

plotdata = df
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

plotdata = df[df["contains_basemodel"]]
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

In [None]:
# Fixed-input results, file "...-1-cache1": first all pairs, then only the
# pairs that involve the 1.5B base model. All layers are kept.
df = prepare_dataframe("cka_fixed_input_results_r1-distill-qwen2.5-1.5b-1-cache1.csv")
base_name = "DeepSeek-R1-Distill-Qwen-1-5B"
df["contains_basemodel"] = (df["model_1"] == base_name) | (df["model_2"] == base_name)
facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")

plotdata = df
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

plotdata = df[df["contains_basemodel"]]
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

In [None]:
# Fixed-input results, file "...-2-cache2": first all pairs, then only the
# pairs that involve the 1.5B base model. All layers are kept.
df = prepare_dataframe("cka_fixed_input_results_r1-distill-qwen2.5-1.5b-2-cache2.csv")
base_name = "DeepSeek-R1-Distill-Qwen-1-5B"
df["contains_basemodel"] = (df["model_1"] == base_name) | (df["model_2"] == base_name)
facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")

plotdata = df
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

plotdata = df[df["contains_basemodel"]]
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

In [None]:
# Fixed-input results, file "...-2-cache1": first all pairs, then only the
# pairs that involve the 1.5B base model. All layers are kept.
df = prepare_dataframe("cka_fixed_input_results_r1-distill-qwen2.5-1.5b-2-cache1.csv")
base_name = "DeepSeek-R1-Distill-Qwen-1-5B"
df["contains_basemodel"] = (df["model_1"] == base_name) | (df["model_2"] == base_name)
facet_kwargs = dict(x="layer", hue="canonical_pair", kind="line", col="token_origin")

plotdata = df
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

plotdata = df[df["contains_basemodel"]]
sns.relplot(plotdata, y="cka_linear", **facet_kwargs)
sns.relplot(plotdata, y="cka_rbf", **facet_kwargs)

### combined view across caches (responses) and models

In [None]:
# Combine the fixed-input results of all three model families into one frame.
# Every row is tagged with the base model of its family so that later cells can
# group by family and identify the descendant of each pair.
all_data = []

base_model = "DeepSeek-R1-Distill-Qwen-1-5B"
dfs = [
    prepare_dataframe("cka_fixed_input_results_r1-distill-qwen2.5-1.5b-1-cache2.csv"),
    prepare_dataframe("cka_fixed_input_results_r1-distill-qwen2.5-1.5b-1-cache1.csv"),
    prepare_dataframe("cka_fixed_input_results_r1-distill-qwen2.5-1.5b-2-cache2.csv"),
    prepare_dataframe("cka_fixed_input_results_r1-distill-qwen2.5-1.5b-2-cache1.csv"),
]
df = pd.concat(dfs, axis=0)
df["base_model"] = base_model
all_data.append(df)

base_model = "Qwen2.5-Math-7B"
df = prepare_dataframe("cka_fixed_input_results_qwen2.5-math-extra.csv")
df["base_model"] = base_model
all_data.append(df)

base_model = "DeepSeek-R1-Distill-Qwen-7B"
dfs = [
    prepare_dataframe("cka_fixed_input_results.csv"),
    prepare_dataframe("cka_fixed_input_results_cache2.csv"),
]
df = pd.concat(dfs, axis=0)
df["base_model"] = base_model
all_data.append(df)

# ignore_index=True gives a fresh 0..n-1 index. A plain reset_index() would also
# copy the old, duplicated per-file row labels into a meaningless "index" column.
df = pd.concat(all_data, axis=0, ignore_index=True)
df["contains_basemodel"] = (df["model_1"] == df["base_model"]) | (df["model_2"] == df["base_model"])


In [None]:
# CKA between each base model and its descendants, one facet row per base model.
plotdata = df.loc[df["contains_basemodel"]].copy()

# Drop the final layer of each base-model family. The maximum is taken per family,
# because a single global maximum would only remove the final layer of the deepest
# family if the families differ in depth.
final_layer = plotdata.groupby("base_model")["layer"].transform("max")
plotdata = plotdata[plotdata["layer"] < final_layer].copy()

# The descendant is the side of the pair that is not the base model.
plotdata["descendant"] = plotdata["model_1"].where(plotdata["model_2"] == plotdata["base_model"], plotdata["model_2"])

sns.catplot(
    plotdata,
    x="layer",
    y="cka_linear",
    hue="descendant",
    col="token_origin",
    sharey=False,
    row="base_model",
    kind="point",
)

In [None]:
# Mean and standard deviation of cka_linear for every
# (base_model, descendant, token_origin, layer) combination.
group_keys = ["base_model", "descendant", "token_origin", "layer"]
pivot = plotdata.groupby(group_keys)["cka_linear"].agg(mean="mean", std="std").reset_index()
pivot.head()

In [None]:
# Mean linear CKA per layer for the descendants of the 1.5B base model,
# one facet column per token origin.
sns.relplot(
    pivot[(pivot.base_model == "DeepSeek-R1-Distill-Qwen-1-5B")],
    # pivot[(pivot.base_model == "DeepSeek-R1-Distill-Qwen-1-5B") & (pivot["mean"] > 0.99)],
    x="layer",
    y="mean",
    hue="descendant",
    kind="line",
    col="token_origin",
)

In [None]:
# List the distinct descendant models present in the aggregated table.
pivot["descendant"].unique()

## In-depth analysis of cka scores

Let's look at a model and dataset where we see the bathtub shape.

Then let's look at:
* grad of score wrt kernel elements (which elements have the biggest influence on the similarity score)
* stats about values in the kernel matrix

In [None]:
# Candidate (descendant, base model) pairs for the in-depth analysis.
# Keys: "token_origin" = model that produced the input tokens, "model"/"base_model" =
# names as they appear in the result frames, "*_hf" = ids passed to the model loader,
# "bathtub" = where the bathtub shape was seen (presumably from the plots above — TODO confirm).
# NOTE(review): the third entry has no "model_hf"/"base_model_hf" keys, so the
# model-loading cell raises KeyError if that entry is selected.
settings = [
    {
        "token_origin": "DeepSeek-R1-Distill-Qwen-7B",
        "model": "Qwen2.5-Math-7B-Oat-Zero",
        "model_hf": "sail/Qwen2.5-Math-7B-Oat-Zero",
        "base_model": "Qwen2.5-Math-7B",
        "base_model_hf": "Qwen/Qwen2.5-Math-7B",
        "bathtub": "middle",
    },
    {
        "token_origin": "DeepSeek-R1-Distill-Qwen-7B",
        # "model": "OpenR1-Distill",
        # "model_hf": "open-r1/OpenR1-Distill-7B",
        "model": "AceReason-Nemotron-1.1-7B",
        "model_hf": "nvidia/AceReason-Nemotron-1.1-7B",
        "base_model": "Qwen2.5-Math-7B",
        "base_model_hf": "Qwen/Qwen2.5-Math-7B",
        "bathtub": "everywhere",
    },
    {
        "token_origin": "DeepSeek-R1-Distill-Qwen-7B",
        "model": "Nemotron-Research-Reasoning-Qwen-1-5B",
        "base_model": "DeepSeek-R1-Distill-Qwen-1-5B",
        "bathtub": "nowhere",
    },
]

In [None]:
# Inspect the second candidate setting.
settings[1]

In [None]:
# Select the setting analysed in the rest of this section.
setting = settings[0]

# Rows of the combined frame that compare the chosen descendant with its base model
# on tokens from the chosen token origin.
subdf = df.loc[
    (df["base_model"] == setting["base_model"])
    & ((df["model_1"] == setting["model"]) | (df["model_2"] == setting["model"]))
    & ((df["model_1"] == setting["base_model"]) | (df["model_2"] == setting["base_model"]))
    & (df["token_origin"] == setting["token_origin"])
]

# One faint line per sample: linear CKA across layers.
sns.relplot(data=subdf, x="layer", y="cka_linear", hue="sample", kind="line", legend=False, alpha=0.3)
# sns.catplot(data=subdf, x="layer", y="cka_linear", kind="box")

In [None]:
# Inspect the first candidate setting.
settings[0]


In [None]:
# Average linear CKA per sample (over all layers), then keep only the samples above
# the 95% quantile or below the 5% quantile of that average.
# NOTE(review): the printed label says "top3 and bottom3", but the selection is
# quantile-based, so the number of samples depends on how many samples exist.
avgsim = subdf.groupby("sample")["cka_linear"].mean()
upper_cut = avgsim.quantile(0.95)
lower_cut = avgsim.quantile(0.05)
extreme_samples = avgsim[(avgsim > upper_cut) | (avgsim < lower_cut)].index

print(
    "top3 and bottom3 avg linear cka\n",
    extreme_samples,
)
sns.relplot(
    data=subdf[subdf["sample"].isin(extreme_samples)],
    x="layer",
    y="cka_linear",
    hue="sample",
    kind="line",
    legend=True,
    alpha=0.3,
)


In [None]:
# Show the currently selected setting.
setting

In [None]:
# Load the tokens of the selected sample from the token_origin folder in cache_representations_2.
import torch

from reasoning.cache import load_tokens
from reasoning.cka import cka, gram_linear
from reasoning.generation import get_hidden_states, get_model_and_tokenizer

# Sample to analyse; the trailing notes record whether the sample showed low or high similarity.
# sample_id = 14  # low sim
# sample_id = 19  # low sim
# sample_id = 42  # low sim
# sample_id = 27  # high sim
sample_id = 31  # high sim
# sample_id = 44  # high sim
# NOTE(review): the cache directory is a hard-coded relative path.
cache_dir = f"cache_representations_2/{setting['token_origin']}/sample_{sample_id}"
tokens = load_tokens(cache_dir)
print(tokens.size())

In [None]:
# Load the descendant and the base model; the tokens are later run through both to get hidden states.
# Requires the "model_hf" and "base_model_hf" keys in the selected setting.
model_name_descendant = setting["model_hf"]
# model_name_descendant = "Qwen/Qwen2.5-Math-7B-Instruct"
model_name_base = setting["base_model_hf"]

# model_name_base = "meta-llama/Llama-3.1-8B"
# model_name_descendant = "meta-llama/Llama-3.1-8B-Instruct"
# model_name_base = "meta-llama/Llama-3.2-3B"
# model_name_descendant = "meta-llama/Llama-3.2-3B-Instruct"
# model_name_base = "meta-llama/Llama-3.2-3B-Instruct"
# model_name_descendant = "zztheaven/Llama-3.2-3B-Instruct-Open-R1-GRPO-2"
# model_descendant, tokenizer_descendant = get_model_and_tokenizer(
#     model_id=model_name_descendant, device="cuda:0", torch_dtype=torch.float16
# )
# model_base, tokenizer_base = get_model_and_tokenizer(
#     model_id=model_name_base, device="cuda:2", torch_dtype=torch.float16
# )


# NOTE(review): the device ids are hard-coded; adjust them to the GPUs available on the machine.
model_descendant, tokenizer_descendant = get_model_and_tokenizer(model_id=model_name_descendant, device="cuda:5")
model_base, tokenizer_base = get_model_and_tokenizer(model_id=model_name_base, device="cuda:2")
# Freeze the parameters of both models.
model_base.requires_grad_(False)
model_descendant.requires_grad_(False)


In [None]:
# from transformers import AutoTokenizer

# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-7B")
# input_str = tokenizer.decode(tokens)
# tokens = tokenizer_base(input_str, return_tensors="pt")["input_ids"][0]
# tokens.size()

In [None]:
# (tokenizer_descendant(input_str, return_tensors="pt")["input_ids"][0] == tokens).sum()

In [None]:
# Device check for the base model. Only the last expression is displayed by the
# notebook; the first line has no visible effect.
model_base.model.embed_tokens.weight.device
model_base.device

In [None]:
# Run the same tokens through both models and collect their hidden states.
# Later cells index the results by layer, with index 0 treated as the embedding layer.
h_base = get_hidden_states(model_base, tokens)
h_descendant = get_hidden_states(model_descendant, tokens)


In [None]:
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from tabulate import tabulate
from tqdm import tqdm

CREATE_FIGURE = False


def pca_svd(data):
    """
    Computes principal components and explained-variance ratios using the SVD method.

    The reduced SVD (``full_matrices=False``) is used. Only the first
    ``K = min(N, D)`` right singular vectors have a singular value, so the returned
    components line up one-to-one with the returned variance ratios. The full
    decomposition would additionally build an unused (N, N) left factor.

    Args:
        data (torch.Tensor): A tensor of shape (N, D) where N is the number of samples
                             and D is the number of features.

    Returns:
        tuple: A tuple containing:
            - torch.Tensor: The principal components as columns, shape (D, K) with
              K = min(N, D), ordered by decreasing explained variance.
            - torch.Tensor: The fraction of the total variance explained by each
              component, shape (K,).
    """
    # 1. Center the data
    centered_data = data - torch.mean(data, dim=0)

    # 2. Perform the reduced SVD. torch.linalg.svd returns U, S and Vh (V transposed);
    #    the singular values come in descending order.
    _, S, Vh = torch.linalg.svd(centered_data, full_matrices=False)

    # The principal components are the right singular vectors (V), i.e. Vh.T
    principal_components = Vh.T

    # 3. Explained variance: the eigenvalues of the covariance matrix are the squared
    #    singular values divided by (n_samples - 1).
    eigenvalues = (S**2) / (data.shape[0] - 1)
    explained_variance = eigenvalues / torch.sum(eigenvalues)

    return principal_components, explained_variance


@torch.no_grad()
def pc_projection(data: torch.Tensor, principal_components: torch.Tensor, k: int):
    """Project ``data`` onto the first ``k`` columns of ``principal_components``.

    ``data`` is multiplied as given; it is not centered here. Runs without
    gradient tracking.

    Args:
        data: Tensor whose last dimension matches the rows of ``principal_components``.
        principal_components: Matrix with one component per column.
        k: Number of leading components to project onto; must be at least 1.

    Returns:
        The projection ``data @ principal_components[:, :k]``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k >= 1:
        leading_components = principal_components[:, :k]
        return torch.matmul(data, leading_components)
    raise ValueError(f"Need to project at least one dimension, but {k=}.")


def add_to_data(grad, kernel, data, layer_idx, model_name):
    """Append one record of summary statistics for a gradient/kernel pair to ``data``.

    The record holds min, max, mean and median of ``grad`` and of ``kernel``
    (keys ``grad_*`` and ``kernel_*``) as Python numbers, plus the ``layer`` and
    ``model`` identifiers. ``data`` is modified in place; nothing is returned.
    """
    record = {}
    for prefix, tensor in (("grad", grad), ("kernel", kernel)):
        for stat in ("min", "max", "mean", "median"):
            # e.g. record["grad_min"] = grad.min().item()
            record[f"{prefix}_{stat}"] = getattr(tensor, stat)().item()
    record["layer"] = layer_idx
    record["model"] = model_name
    data.append(record)


@torch.no_grad()
def add_downsampled_heatmap(grad, ax, bottom_row: bool, layer_idx=None):
    """Draw a 16x16 max-pooled heatmap of ``grad`` on ``ax``.

    Args:
        grad: 2-D tensor (e.g. the gradient of the CKA score w.r.t. a Gram matrix).
        ax: Matplotlib axes to draw on.
        bottom_row: If True, the axes is labelled as the bottom row (x label
            "Token Index", no title). Otherwise it gets the title "Layer <layer_idx>".
        layer_idx: Layer number shown in the top-row title. When omitted, the
            module-level ``layer_idx`` variable is used, which keeps existing
            three-argument calls working.
    """
    if layer_idx is None:
        # Backward compatibility: earlier versions read the notebook-global loop variable.
        layer_idx = globals().get("layer_idx")

    # NOTE(review): max pooling keeps the largest signed value of each 16x16 patch,
    # so strongly negative entries are hidden — confirm this is intended.
    downsampler = torch.nn.MaxPool2d(16)
    downsampled = downsampler(grad.unsqueeze(0).unsqueeze(0))
    sns.heatmap(downsampled.squeeze(0).squeeze(0), ax=ax, cmap="viridis")
    if bottom_row:
        ax.set_xlabel("Token Index")
        ax.set_ylabel("")  # Remove y-axis label
    else:
        ax.set_title(f"Layer {layer_idx}")
        ax.set_xlabel("")  # Remove x-axis label for the top row
        ax.set_ylabel("")  # Remove y-axis label


# --- Setup for Combined Plot ---
# Determine the number of layers to plot (excluding the embedding layer)
num_layers_to_plot = min(len(h_base), len(h_descendant)) - 1
if CREATE_FIGURE and num_layers_to_plot > 0:
    # Create a figure with subplots: 2 rows, one column per layer
    # Adjust figsize for better readability; width scales with the number of layers.
    fig, axes = plt.subplots(nrows=2, ncols=num_layers_to_plot, figsize=(4 * num_layers_to_plot, 4 * 2 + 1))
    # If there's only one layer, axes will be a 1D array, so we reshape it to 2D
    if num_layers_to_plot == 1:
        axes = axes.reshape(2, 1)

# Build the kernel (Gram) matrices per layer and backpropagate through the CKA score.
data = []
pc_stats = []
projs = []

cka_comparison_table = []
correlation_table = []
top_proj_tokens = []
# NOTE(review): hard-coded device for the CKA computation.
cka_device = "cuda:4"
for layer_idx in tqdm(range(1, min(len(h_base), len(h_descendant)))):  # we start with 1 to skip the embedding layer
    # these tensors have size (batch, tokens, dim). batch is 1 for us
    X = h_base[layer_idx].to(cka_device).squeeze(0).to(torch.float32)
    Y = h_descendant[layer_idx].to(cka_device).squeeze(0).to(torch.float32)

    X.requires_grad_(True)
    Y.requires_grad_(True)

    gram_x_linear = gram_linear(X)
    gram_y_linear = gram_linear(Y)
    # The Gram matrices are intermediate tensors; retain_grad keeps their gradient.
    gram_x_linear.retain_grad()
    gram_y_linear.retain_grad()
    linear_cka_score = cka(gram_x_linear, gram_y_linear)
    linear_cka_score.backward()

    # --- Column index for the plot ---
    col_idx = layer_idx - 1

    # --- Plotting Saliency Map for the Base Model (Row 0) ---
    grad = gram_x_linear.grad.cpu()
    kernel = gram_x_linear.cpu()
    add_to_data(grad, kernel, data, layer_idx, model_name_base)
    if CREATE_FIGURE:
        add_downsampled_heatmap(grad, axes[0, col_idx], False)

    # --- Same for the descendant model (Row 1) ---
    grad = gram_y_linear.grad.cpu()
    kernel = gram_y_linear.cpu()
    add_to_data(grad, kernel, data, layer_idx, model_name_descendant)
    if CREATE_FIGURE:
        add_downsampled_heatmap(grad, axes[1, col_idx], True)

    # --- Taking stats of explained variance by principal components
    pc_x, expvar_x = pca_svd(X)
    pc_y, expvar_y = pca_svd(Y)
    pc_stats.append((expvar_x, expvar_y))
    one_dim_proj_x = pc_projection(X, pc_x, 1)
    one_dim_proj_y = pc_projection(Y, pc_y, 1)
    projs.append((one_dim_proj_x, one_dim_proj_y))

    # --- Testing effect of high-pc1-projecting samples on CKA by removing them.
    # NOTE(review): topk uses the signed projection, and the sign of a principal
    # component is arbitrary — confirm whether the magnitude was intended.
    k = 10
    topk_x = one_dim_proj_x.squeeze(-1).topk(k)
    topk_y = one_dim_proj_y.squeeze(-1).topk(k)

    top_indices = torch.tensor(list(set(topk_x.indices.tolist()) | set(topk_y.indices.tolist())))
    top_pc_projs_x = one_dim_proj_x[top_indices]
    top_pc_projs_y = one_dim_proj_y[top_indices]

    # Remove the selected rows with a boolean mask. `~` on an integer index tensor is
    # a bitwise NOT (-i - 1), which would SELECT a few rows counted from the end
    # instead of dropping the top rows.
    keep_mask = torch.ones(X.shape[0], dtype=torch.bool, device=X.device)
    keep_mask[top_indices.to(X.device)] = False
    X_removed = X[keep_mask]
    Y_removed = Y[keep_mask]
    # Separate names keep gram_x_linear / gram_y_linear (and their .grad) intact
    # for the cells that inspect them after this loop.
    gram_x_removed = gram_linear(X_removed)
    gram_y_removed = gram_linear(Y_removed)
    linear_cka_score_wo_highpc_samples = cka(gram_x_removed, gram_y_removed)
    cka_comparison_table.append((layer_idx, linear_cka_score.item(), linear_cka_score_wo_highpc_samples.item()))

    # --- Testing correlation between pc1 projection magnitude and activation magnitude
    corr_x = pearsonr(one_dim_proj_x.cpu().squeeze(-1), torch.linalg.norm(X, dim=1).detach().cpu())
    corr_y = pearsonr(one_dim_proj_y.cpu().squeeze(-1), torch.linalg.norm(Y, dim=1).detach().cpu())
    correlation_table.append((layer_idx, corr_x, corr_y))

    # --- Inspecting tokens with highest pc1 projection
    top_proj_tokens.append([tokenizer_base.decode(t) for t in tokens[top_indices]])


if CREATE_FIGURE and num_layers_to_plot > 0:
    # Set y-axis labels for the first column to act as row titles
    axes[0, 0].set_ylabel(model_name_base, fontsize=12, weight="bold")
    axes[1, 0].set_ylabel(model_name_descendant, fontsize=12, weight="bold")

    # Add a single, overarching title for the entire figure
    fig.suptitle(
        f"Saliency Maps of Gram Matrices per Layer (16x16 downsample, sample={sample_id})", fontsize=16, weight="bold"
    )

    # Adjust layout to prevent titles and labels from overlapping
    # The rect argument makes space for the suptitle
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save the combined figure
    plt.savefig("combined_saliency_maps.png", dpi=300)
    plt.close(fig)  # Close the figure to free up memory
    print("Combined saliency map saved as 'combined_saliency_maps.png'")
elif CREATE_FIGURE:
    # Only report missing layers when a figure was actually requested.
    print("No layers were processed to generate a plot.")
# displot over kernel elements

print(tabulate(cka_comparison_table, headers=["layer", "old cka", "new cka"]))
print(tabulate(correlation_table, headers=["layer", "base corr", "descendant corr"]))

In [None]:
# Pearson correlation between the PC1 projection and the representation norm, per layer,
# together with the linear CKA score of the same layer.
# The layer numbers stored in the tables are used as x values, because the analysis
# starts at layer 1 and a plain 0..n-1 range would shift every point by one layer.
layers = [row[0] for row in correlation_table]
corrs = [x[1].statistic for x in correlation_table]
plt.plot(layers, corrs, label="base")
corrs = [x[2].statistic for x in correlation_table]
plt.plot(layers, corrs, label="descendant")
plt.xlabel("layer")
plt.ylabel("pearsonr")
plt.title(f"Pearsonr PC1 projection with representation norm (sample={sample_id})")

ckas = torch.tensor([t[1] for t in cka_comparison_table])
plt.plot([t[0] for t in cka_comparison_table], ckas, label="cka")

plt.legend()

In [None]:
# Per analysed layer: the decoded tokens with the highest PC1 projection, sorted alphabetically.
for toks in top_proj_tokens:
    print(sorted(toks))

In [None]:
# Rows of (layer, CKA on all tokens, CKA without the high-PC1 tokens).
cka_comparison_table

In [None]:
# Variance NOT explained by PC1 in each model, compared with the CKA scores per layer.
expvar_x_1 = torch.tensor([vs[0][0] for vs in pc_stats])
expvar_y_1 = torch.tensor([vs[1][0] for vs in pc_stats])
ckas = torch.tensor([t[1] for t in cka_comparison_table])
# Use the analysed layer numbers (starting at 1) as x values; a 0..n-1 range would
# shift every point by one layer.
layers = torch.tensor([t[0] for t in cka_comparison_table])

plt.figure(figsize=(4 * 1.6, 4))
plt.plot(layers, 1 - expvar_x_1.cpu(), label="base_model")
plt.plot(layers, 1 - expvar_y_1.cpu(), label="descendant")
plt.plot(layers, 1 - expvar_y_1.cpu() - expvar_x_1, label="sum of pc expvars")
plt.plot(
    layers,
    ckas,
    label="cka",
)
plt.plot(
    layers,
    torch.tensor([t[2] for t in cka_comparison_table]),
    label="cka w/o max pc samples",
)
plt.title(f"sample {sample_id}")
secax = plt.gca().secondary_yaxis(
    "right",
)
secax.set_ylabel("CKA")
plt.legend()
plt.xlabel("layer")
plt.ylabel("1 - Explained variance by PC1 in representations")
plt.show()

# Observation: the bathtub also appears with qwen math vs. math-instruct; no bathtub with llama3.1.

In [None]:
# Distribution of the descendant's PC1 projections, one figure per analysed layer.
# projs[0] belongs to layer 1 (the analysis loop skips the embedding layer), so the
# enumeration starts at 1 to label the figures with the right layer.
# NOTE(review): log_scale=True cannot show zero or negative projections — confirm intended.
for layer_number, (proj_base, proj_desc) in enumerate(projs, start=1):
    sns.displot(
        proj_desc.cpu(),
        log_scale=True,
    )
    plt.title(f"layer {layer_number}")

In [None]:
# Fraction of variance explained by PC1 per analysed layer, for both models.
expvar_x_1 = torch.tensor([vs[0][0] for vs in pc_stats])
expvar_y_1 = torch.tensor([vs[1][0] for vs in pc_stats])

# pc_stats[0] belongs to layer 1 (the embedding layer is skipped), so x starts at 1.
layers = torch.arange(1, len(expvar_x_1) + 1)
plt.plot(layers, expvar_x_1.cpu(), label="base_model")
plt.plot(layers, expvar_y_1.cpu(), label="descendant")
plt.xlabel("layer")
plt.legend()

#### massive activation-style plots

In [None]:
# For every layer, show the activations around the k tokens with the largest
# representation norm as 3-D wireframes (token position x hidden dimension x |activation|).
# Switch the next line to (h_base, tokenizer_base) to inspect the base model instead.
hidden_states, tokenizer = h_descendant, tokenizer_descendant

for layer_idx in range(len(hidden_states)):
    X = hidden_states[layer_idx].to("cpu").squeeze(0).to(torch.float32)

    k = 3
    topk = torch.linalg.norm(X, dim=1).topk(k)

    fig = plt.figure(figsize=plt.figaspect(1 / k))
    for i, idx in enumerate(topk.indices):
        ax = fig.add_subplot(1, k, i + 1, projection="3d")
        center = idx.item()
        # Window of up to 8 positions (center - 4 .. center + 3), clipped to the sequence.
        # The lower bound is inclusive so that token position 0 is not dropped.
        surrounding_tokens = torch.tensor(
            [center + offset for offset in range(-4, 4) if 0 <= center + offset < tokens.size(0)]
        )

        acts = X[surrounding_tokens].cpu()
        dims = torch.arange(acts.size(1)).unsqueeze(-1)

        # Build the tick labels outside the f-string expression: a backslash inside an
        # f-string expression is a SyntaxError before Python 3.12.
        strs = []
        for pos in surrounding_tokens.tolist():
            decoded = tokenizer.decode(tokens[pos]).replace("\n", "\\n")
            strs.append(f"{decoded} ({pos})")

        surrounding_tokens = surrounding_tokens.unsqueeze(0)
        ax.plot_wireframe(surrounding_tokens, dims, acts.abs().T, rstride=0, color="royalblue", linewidth=2.5)
        ax.set_xticks(surrounding_tokens[0])
        ax.set_xticklabels(strs, rotation=45, ha="right")
        ax.set_title(f"{layer_idx=}, tok_pos={idx.item()}")
    plt.show()


#### stats of kernel matrices

In [None]:
# One row per (layer, model) with the gradient/kernel statistics collected by add_to_data.
df_cka = pd.DataFrame.from_records(data)

In [None]:
# Long format: one row per (layer, model, statistic), with the statistic's name in
# "variable" and its value in "value".
df_cka_long = df_cka.melt(
    ["layer", "model"],
    ["grad_min", "grad_max", "grad_mean", "grad_median", "kernel_min", "kernel_max", "kernel_mean", "kernel_median"],
)
df_cka_long.head(3)

In [None]:
import math


def value_to_magnitude(x: float):
    """Return the base-10 order of magnitude of ``x``: ``floor(log10(|x|))``.

    Zero has no order of magnitude and maps to 0. NaN compares neither greater nor
    smaller than zero and therefore also maps to 0.
    """
    if x > 0 or x < 0:
        return math.floor(math.log10(abs(x)))
    return 0


# Order of magnitude of every statistic, plotted per layer with one panel per statistic.
df_cka_long["magnitude"] = df_cka_long["value"].apply(value_to_magnitude)
sns.catplot(
    df_cka_long, x="layer", hue="model", y="magnitude", col="variable", sharey=False, col_wrap=4, kind="point"
)
# Same panels with the raw values on a log scale.
# NOTE(review): zero or negative statistics (e.g. minima) cannot be shown on a log
# scale — confirm that this is acceptable.
g = sns.catplot(
    df_cka_long,
    x="layer",
    hue="model",
    y="value",
    col="variable",
    sharey=False,
    col_wrap=4,
    kind="point",
    log_scale=True,
    height=3,
    aspect=1.61,
)

In [None]:
# Heatmap of the gradient magnitude of the CKA score w.r.t. the base model's Gram matrix.
# NOTE(review): this needs `gram_x_linear` to be the Gram matrix from the analysis loop
# on which retain_grad() and backward() were called (i.e. the last analysed layer).
assert gram_x_linear.grad is not None
kernel = gram_x_linear.cpu()
grad = gram_x_linear.grad.detach().clone().cpu()

with torch.no_grad():
    downsampler = torch.nn.MaxPool2d(32)
    # Take the magnitude BEFORE pooling. Pooling the signed values first keeps only the
    # largest positive entry of each 32x32 patch and hides strongly negative gradients.
    downsampled = downsampler(grad.abs().unsqueeze(0).unsqueeze(0))
    sns.heatmap(downsampled.squeeze(0).squeeze(0))

#### important token pairs 

In [None]:
from typing import Callable

from tabulate import tabulate


@torch.no_grad()
def get_biggest_magnitude_indices_and_values(kernel: torch.Tensor, k: int = 5, largest: bool = True):
    """Find the ``k`` entries of a 2-D tensor with the largest (or smallest) magnitude.

    Args:
        kernel: 2-D tensor, e.g. a Gram matrix. It may be non-contiguous
            (for example a transposed view).
        k: Number of entries to return.
        largest: If True, return the entries with the largest absolute value,
            otherwise those with the smallest.

    Returns:
        tuple:
            - CPU tensor of shape (k, 2) holding the (row, col) position of each entry.
            - Tensor of shape (k,) with the ABSOLUTE values of those entries (the
              sign is not preserved), in the order returned by ``topk``.
    """
    _, cols = kernel.size()
    # reshape() also works on non-contiguous inputs, where view(-1) raises.
    topk = kernel.reshape(-1).abs().topk(k=k, largest=largest)

    # Convert flat positions back into (row, col) pairs in one vectorised step.
    row_idx = torch.div(topk.indices, cols, rounding_mode="floor")
    col_idx = topk.indices % cols
    indices = torch.stack((row_idx, col_idx), dim=1).cpu()

    return indices, topk.values


@torch.no_grad()
def decode_token_pairs(
    tokens: torch.Tensor,
    idxs: torch.Tensor,
    tokenizer: Callable,
    values: None | torch.Tensor = None,
    title: str = "",
    ctx_len: int = 5,
):
    """Print a table of token pairs with their surrounding context.

    Args:
        tokens: 1-D sequence of token ids.
        idxs: Pairs of positions into ``tokens``, one pair per table row.
        tokenizer: Object with a ``decode`` method that turns token ids into text.
        values: Optional value per pair, shown in the "val" column.
        title: Optional heading printed above the table.
        ctx_len: Number of tokens of context taken before each position; the context
            covers positions ``idx - ctx_len`` up to (excluding) ``idx + ctx_len``.
    """
    table = []
    for i, (idx1, idx2) in enumerate(idxs):
        idx1, idx2 = int(idx1), int(idx2)
        tok1 = tokenizer.decode(tokens[idx1])
        tok2 = tokenizer.decode(tokens[idx2])
        # Clamp the window start at 0: a negative start would wrap around to the end
        # of the sequence and give an empty or wrong context for early tokens.
        ctx1 = tokenizer.decode(tokens[max(idx1 - ctx_len, 0) : idx1 + ctx_len])
        ctx2 = tokenizer.decode(tokens[max(idx2 - ctx_len, 0) : idx2 + ctx_len])
        val = values[i] if values is not None else None
        if isinstance(val, torch.Tensor):
            # Show a plain number instead of the tensor repr.
            val = val.item()
        table.append((tok1, tok2, val, ctx1, ctx2))
    if title:
        print(title)
    print(tabulate(table, headers=["tok1", "tok2", "val", "context1", "context2"]))
    print()


# Token pairs with the largest-magnitude Gram-matrix entries for each model.
# NOTE(review): this uses whatever `gram_x_linear` / `gram_y_linear` currently hold;
# confirm they are the full Gram matrices of the intended layer, so that the indices
# correspond to token positions.
k = 10
idxs, vals = get_biggest_magnitude_indices_and_values(gram_x_linear, k, largest=True)
# Titles are built from the loaded model names so they stay correct when the setting changes.
decode_token_pairs(tokens, idxs, tokenizer_base, vals, title=f"base model {model_name_base}")

idxs, vals = get_biggest_magnitude_indices_and_values(gram_y_linear, k, largest=True)
decode_token_pairs(tokens, idxs, tokenizer_descendant, vals, title=f"descendant model {model_name_descendant}")
