In [2]:
%pip install circuitsvis
import torch as th
from circuitsvis.activations import text_neuron_activations
from transformer_lens import HookedTransformer

# Use the GPU when torch reports one as available, otherwise fall back to CPU.
device = "cuda" if th.cuda.is_available() else "cpu"

# Load the three pretrained checkpoints onto `device`, in the same order as
# the names they are bound to.
model_1l, model_2l, model_3l = (
    HookedTransformer.from_pretrained(checkpoint_name, device=device)
    for checkpoint_name in ("solu-1l", "solu-2l", "solu-3l")
)

def display_text_probability(text_list, model_list):
    """Show, per model, the probability assigned to each actual next token.

    Every text is tokenized without a BOS token and passed through every
    model in ``model_list``. For position ``i`` the probability the model
    gives to the token that really occurs at position ``i + 1`` is recorded,
    so the first token of a text has no score and is left out of the display.
    Each (text, model) segment is followed by a newline token with
    activation ``0.0`` that acts as a separator.

    Args:
        text_list: List of strings to score. Must be a ``list``.
        model_list: Iterable of models exposing ``to_str_tokens``,
            ``to_tokens`` and ``__call__`` (logits of shape
            ``(1, seq, vocab)`` are assumed, matching how the result is
            indexed below). Each model tokenizes the text itself, so the
            function does not depend on any module-level model.

    Returns:
        The object returned by ``text_neuron_activations`` for the collected
        tokens and an activation tensor of shape ``(n_tokens, 1, 1)``.

    Raises:
        AssertionError: If ``text_list`` is not a list.
    """
    assert isinstance(text_list, list), "text_list must be a list of strings"
    display_text_list = []
    display_target_probs_list = []
    for text in text_list:
        for model in model_list:
            # Tokenize with the model being scored rather than a global one,
            # so token ids always match the vocabulary of `model`.
            # The first string token is dropped: nothing predicts it.
            str_tokens = model.to_str_tokens(text, prepend_bos=False)[1:]
            tokens = model.to_tokens(text, prepend_bos=False)
            # Pure inference: skip building an autograd graph.
            with th.no_grad():
                probs = model(tokens).softmax(dim=-1)
            # Row i of probs[0, :-1] is the prediction for tokens[0, i + 1].
            target_probs = (
                probs[0, :-1].gather(-1, tokens[0, 1:].unsqueeze(-1)).squeeze(-1)
            )
            # Real newlines would break the rendering, so show them literally.
            display_text_list += [s.replace('\n', '\\newline') for s in str_tokens] + ["\n"]
            display_target_probs_list += target_probs.tolist() + [0.0]
    activations = th.round(
        th.tensor(display_target_probs_list).reshape(-1, 1, 1), decimals=10
    )
    return text_neuron_activations(tokens=display_text_list, activations=activations)
# Models to compare; for each text, one display segment is produced per model
# in this order.
model_list = [
    model_1l,
    model_2l,
    model_3l,
]

# Prompts to score: a counting sequence, a repeated token, and a phrase whose
# opening fragment recurs at the end.
texts = [
    " 1 2 3 4 5 6 7 8 9",
    " a a a a a a a a a a a",
    "ria Chronicles III Valkyria",
]

# Last expression of the cell, so the returned visualization is rendered.
display_text_probability(text_list=texts, model_list=model_list)

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model solu-1l into HookedTransformer


Downloading (…)lve/main/config.json: 100%|██████████| 1.27k/1.27k [00:00<00:00, 150kB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading (…)"model_final.pth";: 100%|██████████| 227M/227M [00:07<00:00, 32.1MB/s] 


Loaded pretrained model solu-2l into HookedTransformer


Downloading (…)lve/main/config.json: 100%|██████████| 1.27k/1.27k [00:00<?, ?B/s]
Downloading (…)"model_final.pth";: 100%|██████████| 241M/241M [00:06<00:00, 35.9MB/s] 


Loaded pretrained model solu-3l into HookedTransformer


In [94]:
# NOTE(review): in the visible code `display_target_probs_list` is assigned
# only inside display_text_probability; this cell presumably relies on a
# leftover global from an earlier session - confirm it survives Restart & Run All.
th.round(display_target_probs_list, decimals=2)

tensor([[[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0600]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0100]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0300]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0000]],

        [[0.0600]],

        [[0.0700]],

        [[0.0500]],

        [[0.0400]],

        [[0.0300]],

        [[0.0300]],

        [[0.0200]],

        [[0.0000]],

        [[0.6000]],

        [[0.6600]],

        [[0.7100]],

        [[0.7100]],

        [[0.7100]],

        [[0.7200]],

        [[0.7200]],

        [[0.7100]],

        [[0.0000]],

        [[0.7100]],

        [[0.5800]],

        [[0.6

In [62]:
# NOTE(review): in the visible code `probs` is assigned only inside
# display_text_probability; this cell presumably relies on a leftover global
# from an earlier session - confirm it survives Restart & Run All.
# Most likely next token at every position of the first sequence, decoded.
print(model_1l.to_string(probs[0].argmax(dim=-1)))
# Highest probability (and its token id) at each of the first 8 positions.
probs[0, :8].max(dim=-1)

 2 2 2 2 2 2 2 2 2 2


torch.return_types.max(
values=tensor([0.1121, 0.1325, 0.5537, 0.2538, 0.6102, 0.2994, 0.6110, 0.3072],
       grad_fn=<MaxBackward0>),
indices=tensor([374, 374, 374, 374, 374, 374, 374, 374]))