In [65]:

# %%
from attribution_utils import calculate_feature_attribution
from collections import defaultdict
from torch.nn.functional import log_softmax
from gemma_utils import get_all_string_min_l0_resid_gemma
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils
from functools import partial
import tqdm
from sae_lens import HookedSAETransformer, SAE, SAEConfig
from gemma_utils import get_gemma_2_config, gemma_2_sae_loader
import numpy as np
import torch
import tqdm
import pandas as pd

import torch.nn.functional as F
import plotly.express as px
import seaborn as sns
import matplotlib.pyplot as plt

In [66]:


# Load the instruction-tuned Gemma-2 2B checkpoint as a HookedSAETransformer.
# `model(toks)` is called on every generated sequence in the cells below to get
# next-token logits.
model = HookedSAETransformer.from_pretrained("google/gemma-2-2b-it")




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer


In [67]:

# Saved generations: a dict whose values are iterables of token-id tensors;
# the loops below index each one as `toks[0]`, so presumably shape
# (1, seq_len) — TODO confirm.
# NOTE(review): torch.load unpickles the file; only load files you trust.
generation_dict = torch.load("gemma2_generation_dict.pt")

In [68]:

def logits_to_entropy(logits):
    """Shannon entropy (in nats) of the softmax distribution over the last axis.

    Args:
        logits: tensor of unnormalised scores whose last dimension is the
            category (vocabulary) axis.

    Returns:
        Tensor with the last dimension reduced away, holding ``-sum(p * log p)``
        for every leading index.
    """
    logp = log_softmax(logits, dim=-1)
    # weight each log-probability by its probability, sum, then negate
    return (logp.exp() * logp).sum(dim=-1).neg()


def logits_to_varentropy(logits):
    """Varentropy of the softmax distribution over the last axis.

    Varentropy is the variance of the surprisal ``s = -log p`` under ``p``::

        V = sum(p * (s - H)^2) = sum(p * s^2) - H^2,   H = sum(p * s)

    Args:
        logits: tensor of unnormalised scores whose last dimension is the
            category (vocabulary) axis.

    Returns:
        Tensor with the last dimension reduced away; 0 for a uniform
        distribution, and non-negative in general.
    """
    log_probs = log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    entropy = -(log_probs * probs).sum(dim=-1)
    # Centre the surprisal on its mean (the entropy) before squaring. The old
    # code returned sum(p*s^2) - H, i.e. subtracted H instead of H^2. The
    # centred form is non-negative by construction and avoids the cancellation
    # of E[s^2] - E[s]^2.
    deviation = -log_probs - entropy.unsqueeze(-1)
    return (probs * deviation.pow(2)).sum(dim=-1)



In [69]:

# Next-token entropy at list-item boundaries for every generated sequence.
# NOTE(review): token ids are hard-coded; 235290 is treated as the hyphen that
# starts a list item and 108 as a line break (per the variable names) — confirm
# against the Gemma-2 tokenizer.
all_entropy_hyphen = []
tokens = []
with torch.no_grad():
    for key, val in generation_dict.items():
        for toks in val:
            tokens.append(toks)
            hyphen_pos = torch.where(toks[0] == 235290)[0]
            break_pos = torch.where(toks[0] == 108)[0]
            # One position before every hyphen except the first, plus the
            # position two tokens before the last line break.
            positions = (hyphen_pos[1:] - 1).tolist() + [break_pos[-1].item() - 2]
            logits = model(toks)
            # Entropy is computed per position, so slice the few needed rows
            # first instead of building full (seq_len x vocab) softmax
            # intermediates for the whole sequence.
            entropy = logits_to_entropy(logits[:, positions])
            all_entropy_hyphen.append(entropy)

# Left-pad every row to the longest one so the final boundary lines up in the
# last column. Pad with NaN, not 0: 0 is a real entropy value (a fully confident
# prediction), whereas plotly renders NaN cells as gaps.
max_size = max(tensor.size(1) for tensor in all_entropy_hyphen)
padded_tensors = [
    F.pad(tensor, (max_size - tensor.size(1), 0), "constant", float("nan"))
    for tensor in all_entropy_hyphen
]
stacked_entropy_hyphen = torch.cat(padded_tensors, dim=0)

torch.cuda.empty_cache()




In [70]:

# Heatmap: one row per generated sequence, one column per list-item boundary.
# Rows are left-padded, so the rightmost column is the final boundary.
px.imshow(
    stacked_entropy_hyphen.cpu().numpy(),
    aspect="auto",
    labels=dict(x="boundary index (left-padded)", y="sequence", color="entropy (nats)"),
    title="Next-token entropy at list-item boundaries",
)

In [None]:

# Next-token entropy at list-item boundaries after overwriting the token at
# position 8 of every sequence with id 1497.
# NOTE(review): 235290 is treated as the list hyphen and 108 as a line break
# (per the variable names); what id 1497 encodes is not documented — TODO.
all_entropy_hyphen = []
tokens = []
with torch.no_grad():
    for key, val in generation_dict.items():
        for toks in val:
            # Work on a copy: assigning into `toks` directly would silently
            # modify the tensors stored in generation_dict, so re-running
            # earlier cells would no longer reproduce their results.
            toks = toks.clone()
            toks[0, 8] = 1497
            tokens.append(toks)
            hyphen_pos = torch.where(toks[0] == 235290)[0]
            break_pos = torch.where(toks[0] == 108)[0]
            # One position before every hyphen except the first, plus the
            # position two tokens before the last line break.
            positions = (hyphen_pos[1:] - 1).tolist() + [break_pos[-1].item() - 2]
            logits = model(toks)
            # Slice the needed rows before the softmax to avoid full
            # (seq_len x vocab) intermediates.
            entropy = logits_to_entropy(logits[:, positions])
            all_entropy_hyphen.append(entropy)

# Left-pad to a common width with NaN (0 would read as a zero-entropy value).
max_size = max(tensor.size(1) for tensor in all_entropy_hyphen)
padded_tensors = [
    F.pad(tensor, (max_size - tensor.size(1), 0), "constant", float("nan"))
    for tensor in all_entropy_hyphen
]
stacked_entropy_hyphen = torch.cat(padded_tensors, dim=0)

torch.cuda.empty_cache()


In [None]:

# Heatmap of the entropies computed after setting the token at position 8 of
# every sequence to id 1497 (rows left-padded; rightmost column = final boundary).
px.imshow(
    stacked_entropy_hyphen.cpu().numpy(),
    aspect="auto",
    labels=dict(x="boundary index (left-padded)", y="sequence", color="entropy (nats)"),
    title="Next-token entropy at list-item boundaries (position 8 -> token 1497)",
)

## Entropy on the long-sequence generations

In [None]:

# Long-sequence generations: next-token entropy at list-item boundaries after
# overwriting the token at position 8 of every sequence with id 1497.
# NOTE(review): 235290 is treated as the list hyphen and 108 as a line break
# (per the variable names) — confirm against the tokenizer.
generation_dict = torch.load("gemma2_generation_long_dict.pt")
all_entropy_hyphen = []
tokens = []
with torch.no_grad():
    for key, val in generation_dict.items():
        for toks in val:
            # Copy before editing so the loaded generations stay untouched.
            toks = toks.clone()
            toks[0, 8] = 1497
            tokens.append(toks)
            hyphen_pos = torch.where(toks[0] == 235290)[0]
            break_pos = torch.where(toks[0] == 108)[0]
            # One position before every hyphen except the first, plus the
            # position two tokens before the last line break.
            positions = (hyphen_pos[1:] - 1).tolist() + [break_pos[-1].item() - 2]
            logits = model(toks)
            # Slice the needed rows before the softmax (long sequences make the
            # full (seq_len x vocab) intermediates especially costly).
            entropy = logits_to_entropy(logits[:, positions])
            all_entropy_hyphen.append(entropy)

# Left-pad to a common width with NaN (0 would read as a zero-entropy value).
max_size = max(tensor.size(1) for tensor in all_entropy_hyphen)
padded_tensors = [
    F.pad(tensor, (max_size - tensor.size(1), 0), "constant", float("nan"))
    for tensor in all_entropy_hyphen
]
stacked_entropy_hyphen = torch.cat(padded_tensors, dim=0)
torch.cuda.empty_cache()

In [None]:

# Heatmap for the long-sequence generations with the token at position 8 set to
# id 1497 (rows left-padded; rightmost column = final boundary).
px.imshow(
    stacked_entropy_hyphen.cpu().numpy(),
    aspect="auto",
    labels=dict(x="boundary index (left-padded)", y="sequence", color="entropy (nats)"),
    title="Long generations: entropy at list-item boundaries (position 8 -> token 1497)",
)

In [None]:

# Long-sequence generations: next-token entropy at list-item boundaries after
# overwriting the token at position 8 of every sequence with id 3309.
# NOTE(review): 235290 is treated as the list hyphen and 108 as a line break
# (per the variable names) — confirm against the tokenizer.
generation_dict = torch.load("gemma2_generation_long_dict.pt")
all_entropy_hyphen = []
tokens = []
with torch.no_grad():
    for key, val in generation_dict.items():
        for toks in val:
            # Copy before editing so the loaded generations stay untouched.
            toks = toks.clone()
            toks[0, 8] = 3309
            tokens.append(toks)
            hyphen_pos = torch.where(toks[0] == 235290)[0]
            break_pos = torch.where(toks[0] == 108)[0]
            # One position before every hyphen except the first, plus the
            # position two tokens before the last line break.
            positions = (hyphen_pos[1:] - 1).tolist() + [break_pos[-1].item() - 2]
            logits = model(toks)
            # Slice the needed rows before the softmax to avoid full
            # (seq_len x vocab) intermediates.
            entropy = logits_to_entropy(logits[:, positions])
            all_entropy_hyphen.append(entropy)

# Left-pad to a common width with NaN (0 would read as a zero-entropy value).
max_size = max(tensor.size(1) for tensor in all_entropy_hyphen)
padded_tensors = [
    F.pad(tensor, (max_size - tensor.size(1), 0), "constant", float("nan"))
    for tensor in all_entropy_hyphen
]
stacked_entropy_hyphen = torch.cat(padded_tensors, dim=0)
torch.cuda.empty_cache()

In [None]:

# Heatmap for the long-sequence generations with the token at position 8 set to
# id 3309 (rows left-padded; rightmost column = final boundary).
px.imshow(
    stacked_entropy_hyphen.cpu().numpy(),
    aspect="auto",
    labels=dict(x="boundary index (left-padded)", y="sequence", color="entropy (nats)"),
    title="Long generations: entropy at list-item boundaries (position 8 -> token 3309)",
)

## Next-token probabilities at list-item boundaries

In [None]:

# Reload the short generations from disk: the preceding cells replaced
# `generation_dict` with the long-sequence file.
# NOTE(review): torch.load unpickles the file; only load files you trust.
generation_dict = torch.load("gemma2_generation_dict.pt")

In [None]:

def logits_to_prob(logits,pos,tok_id1,tok_id2):
    """Probabilities of two candidate tokens at selected sequence positions.

    Only the first batch element (``logits[0]``) is read.

    Args:
        logits: tensor indexed as ``logits[batch, position, vocab]``.
        pos: iterable (or tensor) of integer positions; negative indices count
            from the end of the sequence.
        tok_id1: vocabulary id of the first token of interest.
        tok_id2: vocabulary id of the second token of interest.

    Returns:
        List of ``(p(tok_id1), p(tok_id2))`` float tuples, one per entry of
        ``pos`` and in the same order (empty list for empty ``pos``).
    """
    index = torch.as_tensor(
        pos if torch.is_tensor(pos) else list(pos),
        dtype=torch.long,
        device=logits.device,
    )
    # Softmax is row-wise, so normalising only the selected rows gives the same
    # values. Fetch both columns with one .tolist() instead of one .item()
    # device sync per element.
    probs = log_softmax(logits[0, index], dim=-1).exp()
    return [tuple(row) for row in probs[:, [tok_id1, tok_id2]].tolist()]


# For every generated sequence, read the probabilities of token ids 235248 and
# 108 two positions before each list hyphen (except the first) and two
# positions before the last line break, at the raw logits and after dividing
# them by 1.3.
all_probs = []
all_probs_temp = []
tokens = []
with torch.no_grad():
    for key, val in generation_dict.items():
        for toks in val:
            tokens.append(toks)
            hyphen_pos = torch.where(toks[0] == 235290)[0]
            break_pos = torch.where(toks[0] == 108)[0]
            positions = [*(hyphen_pos[1:] - 2).tolist(), break_pos[-1].item() - 2]
            logits = model(toks)
            probs = logits_to_prob(logits, positions, 235248, 108)
            all_probs.append(probs)
            # same read-out with the logits scaled by 1/1.3 (temperature 1.3)
            probs = logits_to_prob(logits / 1.3, positions, 235248, 108)
            all_probs_temp.append(probs)


In [None]:
# Scratch check: entropy at every position of the `logits` left over from the
# last iteration of the probability loop above (i.e. the last sequence only).
# Re-enabling the commented-out term would instead show the drop in entropy
# relative to the logits divided by 1.3.
# NOTE(review): relies on a loop variable leaking out of the previous cell;
# running cells out of order changes what this displays.
logits_to_entropy(logits)#-logits_to_entropy(logits/1.3)