### Plot the MI trajectories

In [None]:

import seaborn as sns
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
from transformers import AutoTokenizer

# Global seaborn/matplotlib styling used by the MI-trajectory figure below.
sns.set_theme(
    context="talk",
    style="whitegrid",
    font_scale=1.0,
    palette="muted",
)


def load_model_data(model_path, dataset, target_layer=31, top_k=20):
    """Load per-sample MI trajectories for one model and locate their peaks.

    Reads ``results/mi/{dataset}_gtmodel={name}_testmodel={name}.pth``, where
    ``name`` is the last ``/``-separated component of ``model_path``.

    Args:
        model_path: Model id or path; only its last ``/`` component is used.
        dataset: Dataset name used in the results filename.
        target_layer: Index/key into each sample's ``'reps'`` entry that selects
            the layer whose MI trajectory is used.
        top_k: Number of highest-MI positions reported per sample (an
            approximation of the "peaks"). Defaults to 20.

    Returns:
        ``(all_sample_mi_list, all_mi_peak_list)``: two aligned lists with one
        entry per successfully processed sample. The first holds a copy of the
        sample's MI trajectory. The second holds the indices of its ``top_k``
        largest MI values, in descending MI order (ties keep position order).
        Samples that raise are reported on stdout and omitted from BOTH lists.
    """
    model_name = model_path.split('/')[-1]

    data_path = f'results/mi/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth'
    # Local experiment output produced by this project; torch.load unpickles,
    # so never point this at untrusted files.
    data = torch.load(data_path)

    all_sample_mi_list = []
    all_mi_peak_list = []

    for sample_id in data.keys():
        try:
            this_id_mi_list = data[sample_id]['reps'][target_layer]
            mi_copy = this_id_mi_list[:]
            top_indices = sorted(range(len(this_id_mi_list)),
                                 key=lambda i: this_id_mi_list[i],
                                 reverse=True)[:top_k]
        except Exception as e:
            # Best-effort: report and skip malformed samples.
            print(f'[id:{sample_id}] Error:', e)
            continue

        # Append only once everything succeeded, so index k of both lists
        # always refers to the same sample.
        all_sample_mi_list.append(mi_copy)
        all_mi_peak_list.append(top_indices)

    return all_sample_mi_list, all_mi_peak_list


# Experiment configuration for the MI-trajectory figure.
model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'
dataset = 'math_train_12k'
model_name = model_path.rsplit('/', 1)[-1]

# Per-model container: model name -> {'mi': trajectories, 'peaks': peak indices}.
model_data_dict = {}

mi_list, mi_peak_list = load_model_data(model_path, dataset, target_layer=31)
model_data_dict[model_name] = dict(mi=mi_list, peaks=mi_peak_list)


# Draw up to 10 MI trajectories on a 2x5 grid, marking each sample's top-MI
# positions with scatter points.
fig, axes = plt.subplots(2, 5, figsize=(30, 10))
axes = axes.flatten()

sample_mi = model_data_dict[model_name]['mi']
sample_peaks = model_data_dict[model_name]['peaks']
# Samples that failed to load are skipped, so there may be fewer than 10;
# draw only what exists instead of raising IndexError.
n_plotted = min(len(axes), len(sample_mi))

for sample_idx, ax in enumerate(axes):
    if sample_idx >= n_plotted:
        ax.axis("off")  # leave unused panels blank
        continue

    mi_values = sample_mi[sample_idx]
    steps = sample_peaks[sample_idx]

    ax.plot(mi_values, linewidth=2, alpha=0.8)
    ax.scatter(steps, [mi_values[i] for i in steps],
               s=50, edgecolor="white", linewidth=0.8, zorder=3)

    ax.grid(axis="y", linestyle="--", alpha=0.4)
    ax.set_facecolor("#fafafa")
    sns.despine(ax=ax, top=True, right=True)

    ax.set_title(model_name, fontsize=22, pad=8)
    ax.set_xlabel("Reasoning Step", fontsize=18)
    ax.set_ylabel("MI Value", fontsize=20)
    ax.tick_params(axis="x", labelsize=18)
    ax.tick_params(axis="y", labelsize=18)


plt.subplots_adjust(hspace=0.4, wspace=0.13)

plt.show()

### Projecting the MI-peak representations to token space 

In [None]:
import re
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import json
import numpy as np
from transformers import AutoTokenizer
from collections import Counter
from matplotlib.colors import LinearSegmentedColormap

# Global styling for the token-frequency bar chart. `sns.set` is only a
# backwards-compatible alias of `sns.set_theme`, the preferred interface.
sns.set_theme(style="whitegrid", context="talk", font_scale=1.3)


def _collect_peak_tokens(data, acts, tokenizer, target_layer, top_k):
    """Decode the tokens at the top-``top_k`` MI positions of every sample.

    Returns ``(all_tokens, fail_id_list)``. ``all_tokens`` holds the decoded
    token strings of all samples, concatenated. ``fail_id_list`` holds the ids
    of samples that raised while being processed; those contribute no tokens.
    """
    # The original hard-coded id 2, which is </s> only in Llama-2-style
    # vocabularies. Ask the tokenizer, and fall back to 2 only if it has no EOS.
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        eos_id = 2

    all_tokens = []
    fail_id_list = []
    for sample_id in data.keys():
        try:
            mi_list = data[sample_id]['reps'][target_layer]
            top_indices = sorted(range(len(mi_list)),
                                 key=lambda i: mi_list[i],
                                 reverse=True)[:top_k]

            token_list = acts[sample_id]['token_ids'].tolist()
            # NOTE(review): presumably the MI trajectory has one extra final
            # position for the EOS token that `token_ids` lacks; appending it
            # keeps that last index addressable.
            token_list.append(eos_id)
            top_prob_token_ids = [token_list[i] for i in top_indices]

            decoded = tokenizer.batch_decode(top_prob_token_ids, skip_special_tokens=False)
        except Exception:
            # Best-effort: failed samples are reported by id by the caller.
            fail_id_list.append(sample_id)
            continue
        all_tokens.extend(decoded)

    return all_tokens, fail_id_list


def _plot_token_bar(common_tokens, model_name):
    """Draw a bar chart of ``(token, count)`` pairs on a blue-to-purple gradient."""
    colors = [
        (0.0, "#3d61aa"),
        (0.5, "#b1bee9"),
        (1.0, "#9673c4"),
    ]
    cmap = LinearSegmentedColormap.from_list("blue_purple", colors)

    token_names, token_counts = zip(*common_tokens)
    # Escape characters matplotlib would treat as mathtext markup. Raw strings
    # give the same text as before without invalid-escape-sequence warnings.
    token_names_processed = [
        token.replace('$', r'\$').replace('_', r'\_').replace('^', r'\^')
        for token in token_names
    ]
    token_names_repr = [repr(token) for token in token_names_processed]

    n_bars = len(token_names_repr)
    # max(..., 1) avoids a ZeroDivisionError when only one bar is drawn.
    palette = [cmap(i / max(n_bars - 1, 1)) for i in range(n_bars)]

    plt.figure(figsize=(8, 5))
    # NOTE(review): seaborn >= 0.13 warns about `palette` without `hue`; switching
    # to hue=..., legend=False requires 0.13+, so it is left unchanged here.
    sns.barplot(x=list(token_names_repr), y=list(token_counts), palette=palette)
    plt.xticks(rotation=45, ha='right', size=17)
    plt.xlabel('Tokens at MI Peaks')
    plt.ylabel('Frequency')
    plt.title(f'{model_name}')

    plt.show()


def plot_token_freq(model_path, dataset, target_layer=31, top_k=20, n_common=15):
    """Plot how often each alphabetic token appears at a model's MI peaks.

    Loads ``results/mi/{dataset}_gtmodel={name}_testmodel={name}.pth`` (MI
    trajectories) and ``acts/reasoning_evolve/{dataset}_{name}.pth`` (generated
    token ids), where ``name`` is the last component of ``model_path``. It then
    decodes the tokens at each sample's ``top_k`` highest-MI positions, keeps
    those that are purely ASCII letters after stripping whitespace, and prints
    and bar-plots the ``n_common`` most frequent ones.

    Args:
        model_path: Model id/path passed to ``AutoTokenizer.from_pretrained``.
        dataset: Dataset name used in both input filenames.
        target_layer: Index/key into each sample's ``'reps'`` entry.
        top_k: Number of highest-MI positions taken per sample. Defaults to 20.
        n_common: Number of most frequent tokens to report. Defaults to 15.
    """
    model_name = model_path.split('/')[-1]
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Local experiment outputs; torch.load unpickles, so only use trusted files.
    data_path = f'results/mi/{dataset}_gtmodel={model_name}_testmodel={model_name}.pth'
    data = torch.load(data_path)
    acts = torch.load(f'acts/reasoning_evolve/{dataset}_{model_name}.pth')

    all_tokens, fail_id_list = _collect_peak_tokens(data, acts, tokenizer, target_layer, top_k)
    print('fail_id_list:', fail_id_list)

    english_pattern = re.compile(r'^[a-zA-Z]+$')
    processed_all_tokens = [token for token in all_tokens
                            if english_pattern.match(token.strip())]

    common_tokens = Counter(processed_all_tokens).most_common(n_common)
    print('$'*50)
    print(f'model: {model_name}')
    print("Most common tokens:", common_tokens)

    if not common_tokens:
        # zip(*[]) cannot be unpacked; there is simply nothing to plot.
        print('No alphabetic tokens found at MI peaks; nothing to plot.')
        return

    _plot_token_bar(common_tokens, model_name)
    

# Configuration for the token-frequency analysis.
model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'
dataset = 'math_train_12k'

plot_token_freq(model_path=model_path, dataset=dataset, target_layer=31)

