In [2]:

import numpy as np
import torch
import tqdm
import einops
import re
from jaxtyping import Int, Float
from typing import List, Optional, Any
from torch import Tensor
import json
import os
from torch.utils.data import Dataset, DataLoader
import random
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import random
from transformer_lens.utils import get_act_name
from IPython.display import display, HTML
import plotly.express as px

In [5]:

# Saved generations, consumed by get_stats below as {topic: {temperature: [token tensors]}}.
GENERATION_DICT_PATH = "generation_dicts/gemma2_generation_temps_dict.pt"
# NOTE(review): torch.load unpickles the file; only load generation files you produced yourself.
generation_dict = torch.load(GENERATION_DICT_PATH, map_location="cpu")

# Token ids used to split a generation into list items.
# NOTE(review): meanings inferred from the variable names — confirm with the tokenizer,
# e.g. model.tokenizer.decode([235290]).
hypen_tok_id = 235290   # hyphen that starts a list item (misspelt name kept for compatibility)
break_tok_id = 108      # line break
eot_tok_id = 107        # end-of-turn token
blanck_tok_id = 235248  # blank/whitespace token (misspelt name kept for compatibility)

In [6]:

def _generation_stats(toks, hyphen_id, eot_id, blank_id):
    """Return the stats dict of a single 1-D token tensor (layout documented in ``get_stats``)."""
    starts = torch.where(toks == hyphen_id)[0].tolist()
    eot_positions = torch.where(toks == eot_id)[0].tolist()
    # The last item ends at the last end-of-turn token; a truncated generation without
    # one keeps everything up to the end of the sequence instead of raising IndexError.
    last_end = eot_positions[-1] if eot_positions else len(toks)
    ends = starts[1:] + [last_end]
    # No hyphen token -> no items (zip stops at the empty `starts`).
    item_spans = [toks[start:end] for start, end in zip(starts, ends)]

    tokens_per_item = torch.tensor([len(span) for span in item_spans], dtype=torch.long)
    has_blank = torch.tensor(
        [bool((span == blank_id).any()) for span in item_spans], dtype=torch.bool
    )
    blank_item_positions = torch.where(has_blank)[0]

    return {
        "num_tokens": tokens_per_item,
        "num_items": len(item_spans),
        # Same per-item tensor as "num_tokens"; kept for callers that average it themselves.
        "avg_tokens_per_item": tokens_per_item,
        "blank_positions": blank_item_positions,
    }


def get_stats(dict_toks, hyphen_id: Optional[int] = None, eot_id: Optional[int] = None,
              blank_id: Optional[int] = None):
    """
    Compute list statistics for every generation, grouped by topic and temperature.

    A generation is split into list items. Each item starts at a hyphen token and runs up to,
    but not including, the next hyphen token. The last item runs up to the last end-of-turn
    token, or to the end of the sequence when no end-of-turn token is present.

    Args:
        dict_toks: nested mapping ``{topic: {temperature: [toks, ...]}}``. Each ``toks`` is a
            token-id tensor holding one generation. It is flattened to 1-D, which assumes one
            sequence per tensor, presumably of shape ``(1, seq)`` (TODO confirm).
        hyphen_id: token id that starts an item. Defaults to the notebook's ``hypen_tok_id``.
        eot_id: end-of-turn token id. Defaults to the notebook's ``eot_tok_id``.
        blank_id: blank/whitespace token id. Defaults to the notebook's ``blanck_tok_id``.

    Returns:
        ``{topic: {temperature: [stats, ...]}}`` with one ``stats`` dict per generation:
          - ``num_tokens``: 1-D long tensor with the number of tokens in each item;
          - ``num_items``: number of items (0 when no hyphen token is found);
          - ``avg_tokens_per_item``: the same per-item tensor as ``num_tokens``
            (use ``.float().mean()`` for the average);
          - ``blank_positions``: 1-D tensor of indices of the items that contain a blank token.
    """
    # Fall back to the notebook-level constants only when no explicit id is given.
    if hyphen_id is None:
        hyphen_id = hypen_tok_id
    if eot_id is None:
        eot_id = eot_tok_id
    if blank_id is None:
        blank_id = blanck_tok_id

    stats_dict = {}
    for topic, temp_dict in dict_toks.items():
        stats_dict[topic] = {}
        for temp, tok_list in temp_dict.items():
            stats_dict[topic][temp] = [
                _generation_stats(toks.reshape(-1).cpu(), hyphen_id, eot_id, blank_id)
                for toks in tok_list
            ]
    return stats_dict




# Per-generation list statistics for every topic/temperature pair of the loaded generations.
stats_dict = get_stats(dict_toks=generation_dict)

In [7]:
import plotly.graph_objects as go

def plot_stats(stats_dict):
    """
    Plot the statistics produced by ``get_stats`` as three interactive Plotly figures.

    Every generated sample contributes one value at the x-axis category
    ``"<topic> (Temp: <temperature>)"``. Samples that share a topic/temperature pair
    therefore land on the same category.

    Figures:
      1. bars for the total number of item tokens and the number of items, plus a line
         for the average tokens per item;
      2. the average tokens per item on its own;
      3. bars for the number of items that contain a blank token, overlaid with the
         indices of those items.

    Args:
        stats_dict: nested mapping ``{topic: {temperature: [stats, ...]}}`` as returned by
            ``get_stats``. Each ``stats`` needs ``num_tokens`` and ``avg_tokens_per_item``
            (1-D tensors of per-item token counts), ``num_items`` (int) and
            ``blank_positions`` (1-D tensor of item indices).
    """
    labels = []                # x-axis category of each sample
    total_tokens = []          # tokens summed over all items of a sample
    avg_tokens_per_item = []
    num_items = []
    num_blank_items = []       # items containing at least one blank token
    blank_x, blank_y = [], []  # (category, item index) of each blank-containing item

    for topic, temp_stats_dict in stats_dict.items():
        for temp, stats_list in temp_stats_dict.items():
            label = f"{topic} (Temp: {temp})"
            for stats in stats_list:
                item_counts = stats['avg_tokens_per_item'].float()
                labels.append(label)
                # A plain sum: dividing by the item count would just repeat the average
                # series below and fail with ZeroDivisionError on samples with no items.
                total_tokens.append(stats['num_tokens'].sum().item())
                # Samples without items have no average; NaN leaves a gap in the plot.
                avg_tokens_per_item.append(
                    item_counts.mean().item() if item_counts.numel() else float('nan')
                )
                num_items.append(stats['num_items'])
                num_blank_items.append(len(stats['blank_positions']))
                for pos in stats['blank_positions']:
                    blank_x.append(label)
                    blank_y.append(pos.item())

    # Plot 1: overview of tokens, items and average tokens per item.
    fig = go.Figure()

    fig.add_trace(go.Bar(
        x=labels,
        y=total_tokens,
        name='Number of Tokens',
        marker_color='blue'
    ))

    fig.add_trace(go.Scatter(
        x=labels,
        y=avg_tokens_per_item,
        name='Average Tokens per Item',
        mode='lines+markers',
        marker_color='green'
    ))

    fig.add_trace(go.Bar(
        x=labels,
        y=num_items,
        name='Number of Items',
        marker_color='orange',
        opacity=0.6
    ))

    fig.update_layout(
        title="Token Statistics per Topic and Temperature",
        xaxis_title="Topics and Temperatures",
        yaxis_title="Values",
        barmode='group',  # bars of different traces side by side
        legend=dict(
            x=0.1,
            y=1.1,
            orientation="h"
        ),
        template='plotly_dark'
    )

    # Plot 2: average tokens per item across topics and temperatures.
    fig1 = go.Figure()

    fig1.add_trace(go.Scatter(
        x=labels,
        y=avg_tokens_per_item,
        mode='lines+markers',
        name='Avg Tokens per Item',
        marker=dict(color='green', size=8),
        line=dict(color='green', width=2)
    ))

    fig1.update_layout(
        title="Average Tokens per Item across Topics and Temperatures",
        xaxis_title="Topics and Temperatures",
        yaxis_title="Average Number of Tokens",
        template="plotly_dark"
    )

    # Plot 3: items containing blank tokens (count per sample + item indices).
    fig2 = go.Figure()

    fig2.add_trace(go.Bar(
        x=labels,
        y=num_blank_items,
        name='Items with Blank Tokens',
        marker_color='orange',
        opacity=0.7
    ))

    if blank_y:
        fig2.add_trace(go.Scatter(
            x=blank_x,
            y=blank_y,
            mode='markers',
            name='Blank Token Item Index',
            marker=dict(color='red', size=8, symbol='x')
        ))

    fig2.update_layout(
        title="Blank Token Information across Topics and Temperatures",
        xaxis_title="Topics and Temperatures",
        yaxis_title="Items with Blank Tokens (count) / Item Index",
        template="plotly_dark",
        showlegend=True
    )

    fig.show()
    fig1.show()
    fig2.show()

# Render the bar/line overview figures for the computed statistics.
plot_stats(stats_dict=stats_dict)


In [None]:
import plotly.graph_objects as go

def plot_stats(stats_dict):
    """
    Box-plot version of ``plot_stats`` for the statistics produced by ``get_stats``.

    NOTE(review): this cell redefines ``plot_stats`` from the earlier bar-chart cell, so
    whichever cell ran last wins. Consider giving one of the two a distinct name.

    Each x-axis category ``"<topic> (Temp: <temperature>)"`` gets one box that summarizes
    the values of all samples generated for that topic/temperature pair.

    Figures:
      1. boxes for total item tokens, average tokens per item and number of items;
      2. boxes for the average tokens per item on their own;
      3. boxes for the number of items that contain a blank token, overlaid with the
         indices of those items.

    Args:
        stats_dict: nested mapping ``{topic: {temperature: [stats, ...]}}`` as returned by
            ``get_stats``. Each ``stats`` needs ``num_tokens`` and ``avg_tokens_per_item``
            (1-D tensors of per-item token counts), ``num_items`` (int) and
            ``blank_positions`` (1-D tensor of item indices).
    """
    labels = []                # x-axis category of each sample
    total_tokens = []          # tokens summed over all items of a sample
    avg_tokens_per_item = []
    num_items = []
    num_blank_items = []       # items containing at least one blank token
    blank_x, blank_y = [], []  # (category, item index) of each blank-containing item

    for topic, temp_stats_dict in stats_dict.items():
        for temp, stats_list in temp_stats_dict.items():
            label = f"{topic} (Temp: {temp})"
            for stats in stats_list:
                item_counts = stats['avg_tokens_per_item'].float()
                labels.append(label)
                # A plain sum: dividing by the item count would just repeat the average
                # series and fail with ZeroDivisionError on samples with no items.
                total_tokens.append(stats['num_tokens'].sum().item())
                # Samples without items have no average; NaN is ignored by the box stats.
                avg_tokens_per_item.append(
                    item_counts.mean().item() if item_counts.numel() else float('nan')
                )
                num_items.append(stats['num_items'])
                num_blank_items.append(len(stats['blank_positions']))
                for pos in stats['blank_positions']:
                    blank_x.append(label)
                    blank_y.append(pos.item())

    # Plot 1: distributions of tokens, average tokens per item and items.
    fig = go.Figure()

    fig.add_trace(go.Box(
        y=total_tokens,
        x=labels,
        name='Number of Tokens',
        marker_color='blue'
    ))

    fig.add_trace(go.Box(
        y=avg_tokens_per_item,
        x=labels,
        name='Average Tokens per Item',
        marker_color='green'
    ))

    fig.add_trace(go.Box(
        y=num_items,
        x=labels,
        name='Number of Items',
        marker_color='orange'
    ))

    fig.update_layout(
        title="Token Statistics per Topic and Temperature",
        xaxis_title="Topics and Temperatures",
        yaxis_title="Values",
        boxmode='group',  # boxes of different traces side by side
        legend=dict(
            x=0.1,
            y=1.1,
            orientation="h"
        ),
        template='plotly_dark'
    )

    # Plot 2: average tokens per item across topics and temperatures.
    fig1 = go.Figure()

    fig1.add_trace(go.Box(
        y=avg_tokens_per_item,
        x=labels,
        name='Avg Tokens per Item',
        marker=dict(color='green')
    ))

    fig1.update_layout(
        title="Average Tokens per Item across Topics and Temperatures",
        xaxis_title="Topics and Temperatures",
        yaxis_title="Average Number of Tokens",
        template="plotly_dark"
    )

    # Plot 3: items containing blank tokens (distribution of counts + item indices).
    fig2 = go.Figure()

    fig2.add_trace(go.Box(
        y=num_blank_items,
        x=labels,
        name='Items with Blank Tokens',
        marker_color='orange'
    ))

    if blank_y:
        fig2.add_trace(go.Scatter(
            x=blank_x,
            y=blank_y,
            mode='markers',
            name='Blank Token Item Index',
            marker=dict(color='red', size=8, symbol='x')
        ))

    fig2.update_layout(
        title="Blank Token Information across Topics and Temperatures",
        xaxis_title="Topics and Temperatures",
        yaxis_title="Items with Blank Tokens (count) / Item Index",
        template="plotly_dark",
        showlegend=True
    )

    fig.show()
    fig1.show()
    fig2.show()

# Render the box-plot figures (uses whichever plot_stats definition ran last).
plot_stats(stats_dict=stats_dict)


In [None]:

# HookedSAETransformer was never imported, so this cell raised NameError on a fresh kernel.
# NOTE(review): SAELens ships its own HookedSAETransformer (`from sae_lens import ...`);
# switch to that import if SAELens `SAE` objects are attached to the model later.
from transformer_lens import HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gemma-2-2b-it")



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b-it into HookedTransformer
