In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash 
! pip install -U -r requirements.txt
```

```bash
! pip install -U numpy
! pip install -U scikit-learn
```

## Update repository

In [None]:
# ! git pull

## Add import path

In [None]:
import gc
import os
import sys

In [None]:
def add_library_level(level=4):
    """Append the working directory and its ancestors to ``sys.path``.

    For every ``depth`` in ``0 .. level - 1`` the directory ``depth`` levels
    above the current working directory is resolved to an absolute path and
    appended to ``sys.path`` unless it is already listed. ``level=4`` therefore
    covers the working directory itself plus its three nearest parents.

    Args:
        level: Number of directory levels to register, counting the working
            directory as the first one. Values <= 0 register nothing.
    """
    for depth in range(level):
        # depth == 0 yields an empty relative path, which abspath resolves
        # to the current working directory.
        path = '/'.join(['..'] * depth)
        module_path = os.path.abspath(path)
        if module_path not in sys.path:
            sys.path.append(module_path)
            # Show '.' instead of an empty string for the working directory.
            print(f'Appending {path or "."}')

In [None]:
add_library_level(level=4)

## Organize imports

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
from torch import nn

In [None]:
from sae_lens import SAE, HookedSAETransformer
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
from transformer_lens import HookedTransformer

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from scipy.sparse import csr_matrix

In [None]:
import plotly.express as px

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

In [None]:
from src.lattmc.sae.nlp_sae_utils import *

In [None]:
import logging

#### Number of CPU cores

In [None]:
workers = multiprocessing.cpu_count()
workers

In [None]:
SEED = 2024

In [None]:
logging.basicConfig(level=logging.INFO)

In [None]:
device = init_device()

In [None]:
torch.__version__

In [None]:
np.__version__

## Initialize Layer

In [None]:
layer_2 = 2
layer_5 = 5
layer_10 = 10
dataset_suffix = 'pile'

## Initialize Path

In [None]:
def init_data(model_name, layer, dataset_suffix='pile'):
    """Build the on-disk locations for one model layer's SAE artefacts.

    All paths live under ``../data/saes`` and share the stem
    ``<model_name with '-' replaced by '_'>_<layer>_<dataset_suffix>``.
    The checkpoint, vectors and matrix directories are handed to ``mkdirs``
    (external helper; presumably creates them if missing — verify).

    Args:
        model_name: Model identifier, e.g. ``'gpt2-small'``.
        layer: Layer index used in the artefact names.
        dataset_suffix: Dataset tag used in the artefact names.

    Returns:
        Tuple ``(matrix_dir, vectors_path)`` of ``pathlib.Path`` objects.
    """
    data_root = Path('../data')
    # Computed outside the f-string: reusing the enclosing quote character
    # inside a replacement field only parses on Python 3.12+.
    safe_model_name = model_name.replace('-', '_')
    vectors_name = f'{safe_model_name}_{layer}_{dataset_suffix}'
    checkpoint_dir = data_root / 'saes'
    vectors_dir = checkpoint_dir / f'{vectors_name}_vecs'
    matrix_dir = checkpoint_dir / f'{vectors_name}_mats'
    vectors_path = checkpoint_dir / f'{vectors_name}_vecs.npz'

    mkdirs(
        checkpoint_dir,
        vectors_dir,
        matrix_dir
    )
    # NOTE(review): `logger` is not defined in the visible cells; presumably
    # it arrives through one of the star imports — confirm.
    logger.info(f'{matrix_dir = } {vectors_path = }')

    return matrix_dir, vectors_path

## Initialize simple dataset

In [None]:
dataset = load_dataset(
    path='NeelNanda/pile-10k',
    split='train',
    streaming=False,
)

## Initialize model

In [None]:
model_name = 'gpt2-small'
release = 'gpt2-small-res-jb'
sae_id_2 = f'blocks.{layer_2}.hook_resid_pre'
sae_id_5 = f'blocks.{layer_5}.hook_resid_pre'
sae_id_10 = f'blocks.{layer_10}.hook_resid_pre'

In [None]:
matrix_dir_2, vector_path_2 = init_data(model_name, layer_2, dataset_suffix=dataset_suffix)
matrix_dir_5, vector_path_5 = init_data(model_name, layer_5, dataset_suffix=dataset_suffix)
matrix_dir_10, vector_path_10 = init_data(model_name, layer_10, dataset_suffix=dataset_suffix)

In [None]:
net_2 = Text2Sae(
    model_name,
    release,
    sae_id_2,
    device
)

net_5 = Text2Sae(
    model_name,
    release,
    sae_id_5,
    device
)

net_10 = Text2Sae(
    model_name,
    release,
    sae_id_10,
    device
)

In [None]:
v = net_5.embed('President of USA')[0]

In [None]:
v.shape

In [None]:
topK(v[0], 30)

## Generate V Lattice

In [None]:
gc.collect()

In [None]:
matrix_dir_2, matrix_dir_5, matrix_dir_10

In [None]:
len(dataset)

In [None]:
init_matrices(matrix_dir_2, dataset, net_2)
init_matrices(matrix_dir_5, dataset, net_5)
init_matrices(matrix_dir_10, dataset, net_10)

In [None]:
gc.collect()

In [None]:
V_2 = init_vectors(vector_path_2, matrix_dir_2, segment=False)
V_5 = init_vectors(vector_path_5, matrix_dir_5, segment=False)
V_10 = init_vectors(vector_path_10, matrix_dir_10, segment=False)

In [None]:
fca_2 = FCA(V_2)
fca_5 = FCA(V_5)
fca_10 = FCA(V_10)

In [None]:
word_mapper_2 = WordMapper(net_2, dataset, matrix_dir_2)
word_mapper_5 = WordMapper(net_5, dataset, matrix_dir_5)
word_mapper_10 = WordMapper(net_10, dataset, matrix_dir_10)

In [None]:
cu_2 = ConceptUtils(fca_2, word_mapper_2)
cu_5 = ConceptUtils(fca_5, word_mapper_5)
cu_10 = ConceptUtils(fca_10, word_mapper_10)

## Analysis of Particular Word

In [None]:
def search_words(dataset, net, word:str):
    """Find the last occurrence of ``word`` in each record of ``dataset``.

    Args:
        dataset: Iterable of records that expose their document under the
            ``'text'`` key.
        net: Not used by the search; kept so existing calls keep working.
        word: Substring to look for (case-sensitive, as ``str.rfind``).

    Returns:
        Dict mapping the position of each matching record in ``dataset`` to a
        single-element list holding the character offset of the last
        occurrence of ``word`` in that record's text. Records that do not
        contain ``word`` are omitted.
    """
    word_indices = dict()
    with tqdm(dataset, desc='Searching words') as pdata:
        for k, d in enumerate(pdata):
            idx = d['text'].rfind(word)
            # rfind reports "not found" as -1; offset 0 is a genuine match
            # at the very start of the text and must be kept.
            if idx != -1:
                word_indices[k] = [idx]
                pdata.set_postfix_str(f'Word: {word} found in {idx}')

    return word_indices

In [None]:
dataset

In [None]:
text = 'Test.   '
text.lower().strip()

In [None]:
'understand' in dataset[0]['text']

In [None]:
dataset[0]['text'].rfind('understand')

In [None]:
dataset[0]['text'][11771:11771 + len('understand')]

## Random Text

In [None]:
rtext = 'The quick brown fox jumps over the lazy dog.'

In [None]:
v_rtext = net_10.embed(rtext)[0][1:]
t_rtext = net_10.tokenize(rtext)[0][1:]

In [None]:
cu_10.print_tokens(t_rtext)

## New York City

In [None]:
stext = 'New York City'

In [None]:
word_indices = search_words(dataset, net_10, stext)

In [None]:
word_indices

In [None]:
# Show the text around offset 121 of record 8: one phrase-length before the
# offset through two phrase-lengths after it. The width is taken from `stext`
# (the searched phrase), not from the unrelated `text` scratch string.
dataset[8]['text'][121 - len(stext): 121 + 2 * len(stext)]

In [None]:
dataset[9551]['text'][9910-10: 9910 + 20]

In [None]:
len(word_indices)

In [None]:
p_keys = np.array(list(word_indices.keys()))

In [None]:
V_p = V_10[p_keys]

In [None]:
v_p = V_10[9551]

In [None]:
v_p = meet_all(V_p)

In [None]:
printTopK(v_p)

In [None]:
printTopK(V_p[0])

In [None]:
v_stext = net_10.embed(stext)[0][1:]
t_stext = net_10.tokenize(stext)[0][1:]

In [None]:
for idx, t in enumerate(t_stext):
    print(f'{idx} {net_10.to_string(t)} {t}')

In [None]:
printTopK(v_stext[0])

In [None]:
printTopK(v_stext[1])

In [None]:
printTopK(v_stext[2])

In [None]:
v_mtext = meet_all(v_stext[1:3])

In [None]:
printTopK(v_mtext)

In [None]:
printTopK(V_p[:, 17754], k=100)

In [None]:
V_p[92][17754]

In [None]:
p_vp, c_cp, w_cp = cu_10.gen_print([14717], [24], context=[10, 10], top_k=400, indices_only=False)

In [None]:
for w in w_cp.words:
    tw = w[0][1]
    cw = w[0][2].replace('<|endoftext|>', '')
    print(f'Token: {repr(tw)},  Context window: {repr(cw)}')

In [None]:
printTopK(w_cp.v)

In [None]:
printTopK(w_cp.vecs[1453])

In [None]:
v_w = meet_all(list(w_cp.vecs.values()))

In [None]:
printTopK(v_w)

In [None]:
v1, c1 = cu_10.gen_concept([11982, 14717], [20, 30])
v1, c2 = cu_10.gen_concept([23026, 14717], [20, 24])

In [None]:
c1, c2

In [None]:
printTopK(c1.v)

In [None]:
printTopK(c2.v)

In [None]:
c_or = c1 | c2

In [None]:
c_or

In [None]:
printTopK(c_or.v)

In [None]:
cu_10.mapper.search_words(c_or, v=c_or.v, context=[10, 10])

In [None]:
w_cp.words

In [None]:
dataset[92]['text'].rfind('president')

## Golden Gate Bridge

In [None]:
gtext = 'Golden Gate Bridge'

In [None]:
v_gb = net_10.embed(gtext)[0][1:]
t_gb = net_10.tokenize(gtext)[0][1:]

In [None]:
cu_10.print_tokens(t_gb)

In [None]:
v_pl = meet(v_gb[1], v_gb[2])

In [None]:
printTopK(v_pl)

In [None]:
p_gb, c_gb, w_gb = cu_10.gen_print([19694], [3], context=[10, 10], top_k=400, indices_only=False)

In [None]:
w_gb.words

In [None]:
p_gb2, c_gb2, w_gb2 = cu_10.gen_print([19694, 14466, 22253], [3, 3, 3], context=[10, 10], top_k=600, indices_only=False)

In [None]:
printTopK(w_gb2.v)

In [None]:
w_gb2.words

## The Brooklyn Bridge

In [None]:
btext = "The Brooklyn Bridge"

In [None]:
v_bb = net_10.embed(btext)[0][1:]
t_bb = net_10.tokenize(btext)[0][1:]

In [None]:
cu_10.print_tokens(t_bb)

In [None]:
v_mb = meet(v_bb[1], v_bb[2])

In [None]:
printTopK(v_mb)

In [None]:
p_bb2, c_bb2, w_bb2 = cu_10.gen_print([23638,  4809,  7574,  6183], [3.9, 1.9, 1.5, 1.4], context=[10, 10], top_k=600, indices_only=False)

In [None]:
printTopK(w_bb2.v)

In [None]:
w_bb2.words

## Grand Canyon Park

In [None]:
ctext = 'Grand Canyon Park'

In [None]:
v_gc = net_10.embed(ctext)[0][1:]
t_gc = net_10.tokenize(ctext)[0][1:]

In [None]:
cu_10.print_tokens(t_gc)

In [None]:
v_mc = meet(v_gc[1], v_gc[2])

In [None]:
printTopK(v_mc)

In [None]:
p_gc2, c_gc2, w_gc2 = cu_10.gen_print([17616, 19694], [14.151013  , 12.207326], context=[10, 10], top_k=600, indices_only=False)

In [None]:
printTopK(w_gc2.v)

In [None]:
w_gc2.words

In [None]:
p_gc3, c_gc3, w_gc3 = cu_10.gen_print([12842, 13311,  8423], [7.78711033,  4.21317959,  4.02698946], context=[10, 10], top_k=600, indices_only=False)

In [None]:
printTopK(w_gc3.v)

In [None]:
w_gc3.words

## Card Game Bridge

In [None]:
pbtxt = 'How is the card game bridge played'

In [None]:
v_pb = net_10.embed(pbtxt)[0][1:]
t_pb = net_10.tokenize(pbtxt)[0][1:]

In [None]:
cu_10.print_tokens(t_pb)

In [None]:
v_mp = meet(v_pb[2], v_pb[3])

In [None]:
printTopK(v_mp)

In [None]:
p_pb2, c_pb2, w_pb2 = cu_10.gen_print([4009, 19398, 17011, 23638], [7.505138 , 2.9314897, 1.9517093, 1.3742858], context=[10, 10], top_k=600, indices_only=False)

In [None]:
printTopK(w_pb2.v)

In [None]:
w_pb2.words

## Person Status

In [None]:
sttext = 'As professor of mathematics and computer science said'

In [None]:
v_st = net_10.embed(sttext)[0][1:]
t_st = net_10.tokenize(sttext)[0][1:]

In [None]:
cu_10.print_tokens(t_st)

In [None]:
printTopK(v_st[1])

In [None]:
p_st1, c_st1, w_st1 = cu_10.gen_print([18549], [49], context=[10, 10], top_k=600, indices_only=False)

In [None]:
printTopK(w_st1.v)

In [None]:
printTopK(w_st1.vecs[816])

In [None]:
printTopK(w_st1.vecs[293])

In [None]:
printTopK(meet(w_st1.vecs[816], w_st1.vecs[293]))

In [None]:
printTopK(meet(w_st1.vecs[791], w_st1.vecs[2153]))

In [None]:
p_stv1, c_stv1, w_stv1 = cu_10.gen_print([18549, 13704], [53.454197, 10.211535], context=[10, 10], top_k=600, indices_only=False)

In [None]:
p_stv2, c_stv2, w_stv2 = cu_10.gen_print([18549, 19968], [52.81996, 3.5897946,], context=[10, 10], top_k=600, indices_only=False)

In [None]:
v_mst = meet(v_st[1], v_st[3])

In [None]:
printTopK(v_mst)

In [None]:
p_st2, c_st2, w_st2 = cu_10.gen_print([23638,   363, 18599, 19543], [6.949088  , 2.560136  , 2.4140446 , 2.3404593], context=[10, 10], top_k=600, indices_only=False)

In [None]:
printTopK(w_st2.v)

In [None]:
w_st2.words

In [None]:
p_st3, c_st3, w_st3 = cu_10.gen_print(
    [2502, 23638, 13029, 3957, 13408, 24512, 17993, 23394,], 
    [13.28715515, 8.6833334, 6.49952269, 5.09930706, 4.34795856, 3.87788129, 3.86469817, 3.58735943,], 
    context=[10, 10], top_k=800, indices_only=False)

In [None]:
printTopK(w_st3.v)

In [None]:
w_st3.words

In [None]:
v_mst2 = meet(v_st[4], v_st[6])

In [None]:
printTopK(v_mst2)

## Generate Context and Analyze

In [None]:
gc.collect()

In [None]:
text1 = "The Golden Gate Bridge"
text2 = "The Brooklyn Bridge"
text3 = "The card game Bridge"

In [None]:
sent1 = 'New York City'
sent2 = 'Golden Gate Bridge'
sent3 = 'Grand Canyon Park'

## New York City

In [None]:
v_ns = net_2.embed(sent1)[0][1:]
v_gs = net_2.embed(sent2)[0][1:]

In [None]:
t_ns = net_2.tokenize(sent1)[0][1:]
t_gs = net_2.tokenize(sent2)[0][1:]

In [None]:
t_ns.shape, t_gs.shape

In [None]:
t_ns, t_gs

In [None]:
topK(v_ns[0], 10)

In [None]:
topK(v_gs[0], 10)

In [None]:
m_bg = meet(v_ns[2], v_gs[2])
topK(m_bg, 20)

In [None]:
print(f'Intersection of {net_2.to_string(t_ns[2])} and {net_2.to_string(t_gs[2])}')

In [None]:
top_k = 28

In [None]:
m_cg, c_bg, w_bg = cu_2.gen_print(8448, 18, top_k=top_k)

In [None]:
w_bg

#### Neighbor Analysis

In [None]:
topK(c_bg.v, 20)

In [None]:
n_cg, c_nbg, w_nbg = cu_2.gen_print(1605, 44, context=[8, 8], top_k=top_k)

In [None]:
w_nbg

In [None]:
c_mbg = c_bg & c_nbg
c_mbg

In [None]:
topK(c_mbg.v, 20)

In [None]:
m_vg, c_vmg, w_vgm = cu_2.gen_print(5277, 49, context=[8, 8], top_k=top_k)

In [None]:
w_vgm

## Experiment with Tokenizer

In [None]:
?? net.model.tokenizer.encode

#### Neighbor Analysis for Golden

In [None]:
def gen_neughbors(idxs, vals, context=[8, 8]):
    """Build a concept from ``idxs``/``vals`` and look up its words.

    NOTE(review): ``gen_concept`` and ``word_mapper`` are not defined in the
    visible notebook cells; presumably they come from a star import or an
    earlier notebook revision — confirm before calling.

    Args:
        idxs: Forwarded to ``gen_concept`` as its first argument.
        vals: Forwarded to ``gen_concept`` as its second argument.
        context: Forwarded to ``word_mapper`` as ``context``.

    Returns:
        Tuple ``(vector, concept, words)``: the pair returned by
        ``gen_concept`` followed by the result of ``word_mapper``.
    """
    concept_vec, concept = gen_concept(idxs, vals)
    return concept_vec, concept, word_mapper(concept, v=concept_vec, context=context)

In [None]:
def gen_print(idxs, vals, context=[8, 8]):
    """Run ``gen_neughbors`` and print the concept and its enumerated words.

    Prints ``c_idx = <concept>`` followed by one ``(position, word)`` tuple
    per line, then returns what ``gen_neughbors`` returned.

    Args:
        idxs: Passed through to ``gen_neughbors``.
        vals: Passed through to ``gen_neughbors``.
        context: Passed through to ``gen_neughbors`` as ``context``.

    Returns:
        Tuple ``(v_idx, c_idx, words)`` exactly as produced by
        ``gen_neughbors``.
    """
    result = gen_neughbors(idxs, vals, context=context)
    v_idx, c_idx, words = result
    print(f'{c_idx = }')
    # One "(position, word)" tuple per output line.
    numbered = [str(entry) for entry in enumerate(words)]
    print('\n'.join(numbered))

    return v_idx, c_idx, words

In [None]:
printTopK(v_gs[0], k=20)

In [None]:
n_gg, c_gg, w_gg = cu_2.gen_print(19837, 79, context=[8, 8])

In [None]:
v_2gg, c_2gg, ws_gg = cu.gen_print([19837, 21286], [8, 2], context=[8, 8])

In [None]:
v_2gg, c_2gg, ws_gg = cu.gen_print([21286, 4507], [15, 12], context=[8, 8])

In [None]:
v_2gg, c_2gg, ws_gg = cu.gen_print([4507, 14717], [12, 5], context=[8, 8])

In [None]:
v_2gg, c_2gg, ws_gg = cu.gen_print([4507, 14717, 6183], [12, 5, 5], context=[8, 8])

In [None]:
v_2gg, c_2gg, ws_gg = cu.gen_print([14717, 6183], [5, 5], context=[8, 8])

In [None]:
v_2gg, c_2gg, ws_gg = cu.gen_print([6183, 332, 2407], [5, 3, 2], context=[8, 8])

In [None]:
v_2gg, c_2gg, ws_gg = cu.gen_print([332, 2407, 10363], [3, 2, 2], context=[8, 8])

In [None]:
printTopK(c_gg.v, 20)

In [None]:
n_gn, c_gn, w_gn = cu.gen_print(6284, 12, context=[8, 8])

In [None]:
c_mbg = c_bg & c_nbg
c_mbg

In [None]:
printTopK(c_mbg.v, 20)

In [None]:
m_vg, c_vmg, w_vgm = cu.gen_print(23992, 29, context=[8, 8])

## Apple Case

In [None]:
# Example inputs: same word "apple" in two different contexts
text1 = "I ate an apple for breakfast."
text2 = "Apple Inc. unveiled its latest product."

In [None]:
v_a1 = net.embed(text1)[0][1:]
v_a2 = net.embed(text2)[0][1:]

In [None]:
t_a1 = net.tokenize(text1)[0][1:]
t_a2 = net.tokenize(text2)[0][1:]

In [None]:
v_a1.shape, t_a1.shape, v_a2.shape, t_a2.shape, 

In [None]:
for i in range(t_a1.shape[0]): 
    print(f'{i} {net.to_string(t_a1[i])}, {net.to_string(t_a2[i])}\n')

In [None]:
printTopK(v_a1[3], k=20)

In [None]:
v_am = meet(v_a1[3], v_a2[0])
printTopK(v_am)

In [None]:
v_2gg, c_2gg, ws_gg = cu.gen_print([4269, 4809, 23638], [30, 1, 1], context=[8, 8])

In [None]:
printTopK(v_a1[1])

In [None]:
v_ate, c_ate, ws_ate = cu.gen_print([15767], [51], context=[8, 8])

In [None]:
v_ate, c_ate, ws_ate = cu.gen_print([9493, 22952], [18, 8], context=[8, 8])

In [None]:
printTopK(v_a2[1])

In [None]:
v_ate, c_ate, ws_ate = cu.gen_print([23563], [65], context=[8, 8])

In [None]:
v_ate, c_ate, ws_ate = cu.gen_print([9768, 18294], [16 , 11], context=[8, 8])

In [None]:
v_apple_inc = meet(v_a2[0], v_a2[1])
printTopK(v_apple_inc)

In [None]:
v_aplle_inc, c_aplle_inc, ws_aplle_inc = cu.gen_print([17725, 7574], [9 , 6], context=[8, 8])

In [None]:
v_apple_ate = meet(v_a1[1], v_a1[3])
printTopK(v_apple_ate)

In [None]:
val_apple_inc, idx_apple_ink = topK(v_apple_inc, 20)
val_apple_ate, idx_apple_ate = topK(v_apple_ate, 20)

In [None]:
val_apple_inc == val_apple_ate

In [None]:
val_apple_inc, val_apple_ate

In [None]:
idx_apple_ink == idx_apple_ate

In [None]:
idx_apple_ink, idx_apple_ate

In [None]:
v_aplle_ate, c_aplle_ate, ws_aplle_ate = cu.gen_print([11930, 19398], [6 , 2], context=[8, 8])

In [None]:
v_aplle_ate, c_aplle_ate, ws_aplle_ate = cu.gen_print([19398, 22952], [2 , 2], context=[8, 8])

## Sentence generation

In [None]:
text_president = 'President of USA Abraham Lincoln'
text_thanks = 'Thanks for checking out our product'
text_jedi = 'Return of the Jedi from 1983'

In [None]:
v_pr = net_2.embed(text_president)[0][1:]
t_pr = net_2.tokenize(text_president)[0][1:]

In [None]:
for k, t in enumerate(t_pr):
    print(f'{k} {t} {net_2.to_string(t)}')

In [None]:
net_2.to_string(t_pr)

In [None]:
printTopK(v_pr[0], k=20)

In [None]:
v_ipr, c_pres, ws_pres = cu_2.gen_print([3420], [16], context=[8, 8], top_k=top_k)

In [None]:
printTopK(c_pres.v)

In [None]:
v_inp, c_npres, ws_npres = cu_2.gen_print([5277], [15], context=[8, 8], top_k=top_k)

In [None]:
ws_npres

In [None]:
printTopK(v_pr[3], k=20)

In [None]:
v_iab, c_abrm, ws_abrm = cu_2.gen_print([6545], [10], context=[8, 8], top_k=top_k)

In [None]:
v_cabr, c_cabrm, ws_cabrm = cu_2.gen_print([23578, 11719], [7, 4], context=[8, 8])

In [None]:
printTopK(v_pr[4], k=20)

In [None]:
v_ilc, c_lincn, ws_lincn = cu_2.gen_print([4869], [39], context=[8, 8], top_k=top_k)

In [None]:
c_abrm_lincn = c_abrm & c_lincn
c_abrm_lincn

In [None]:
printTopK(c_abrm_lincn.v)

In [None]:
c_pres_abrm = c_pres & c_abrm
c_pres_abrm

In [None]:
printTopK(c_pres_abrm.v)

In [None]:
v_ipa, c_pa, ws_pa = cu_2.gen_print([12756, 14845], [60, 40], context=[8, 8], top_k=top_k)

In [None]:
ws_pa

## Analyze Layer 10

In [None]:
v_10p = net_10.embed(text_president)[0][1:]
t_10p = net_10.tokenize(text_president)[0][1:]

In [None]:
for k, t in enumerate(t_10p):
    print(f'{k} {t} {net_10.to_string(t)}')

In [None]:
printTopK(v_10p[0])

In [None]:
m_p10, c_p10, words_p10 = cu_10.gen_print([14245], [12], context=[8, 8], top_k=top_k)

In [None]:
words_p10.words

In [None]:
word_count_17754 = 0
for _, v in words_p10.word_indices.items():
    word_count_17754 += len(v)

In [None]:
word_count_17754

In [None]:
words_p10.words

In [None]:
m_pc10, c_pc10, words_pc10 = cu_10.gen_print([17754, 14245], [67, 12], context=[8, 8], top_k=top_k)

In [None]:
words_pc10.word_indices

In [None]:
word_count_17754_14245 = 0
for _, v in words_pc10.word_indices.items():
    word_count_17754_14245 += len(v)

In [None]:
word_count_17754_14245

In [None]:
tokens_count = 0
with tqdm(dataset) as pdata:
    for d in pdata:
        tkns = net_10.tokenize(d['text'])[0]
        tokens_count += len(tkns)

In [None]:
tokens_count_1 = 5465620

In [None]:
tokens_count_0 = 5475620

In [None]:
tokens_count_0 - tokens_count_1

In [None]:
p_17754 = word_count_17754 / tokens_count_0

In [None]:
p_17754_14245 = word_count_17754_14245 / tokens_count_0

In [None]:
p_17754

In [None]:
p_17754_14245

In [None]:
p_17754_17754_14245 = p_17754_14245 / p_17754

In [None]:
p_17754_17754_14245

In [None]:
print(f'P(17754 | 17754, 14245) = {p_17754_17754_14245:2f}')