In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# ! git pull

## Add import path

In [None]:
import gc
import os
import sys

In [None]:
def add_library_level(level=4):
    """Append the working directory and its ancestors to ``sys.path``.

    For every depth ``d`` in ``range(level)`` the directory ``d`` levels above
    the current working directory (depth 0 is the working directory itself)
    is resolved to an absolute path and appended to ``sys.path`` unless it is
    already listed there. A ``level`` of zero or less leaves ``sys.path``
    untouched.

    Args:
        level: Number of directory levels to register, counted from the
            current working directory upwards.
    """
    for depth in range(level):
        # Depth 0 has no parent components; use '.' so the log line is not blank.
        rel_path = os.path.join(*([os.pardir] * depth)) if depth else os.curdir
        module_path = os.path.abspath(rel_path)
        if module_path not in sys.path:
            sys.path.append(module_path)
            print(f'Appending {rel_path}')

In [None]:
# Register the working directory plus four ancestor levels on sys.path
# (presumably so the `src.` package imports below resolve regardless of where
# the notebook lives in the repository — confirm).
add_library_level(level=5)

## Organize imports

In [None]:
import multiprocessing

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.fca_utils import *

In [None]:
from src.lattmc.tc.transcoder_analyzers import ConceptAnalysis, init_analyzer

In [None]:
from src.lattmc.sae.nlp_sae_utils import init_device, gen_concept

In [None]:
import logging

#### Number of CPU cores

In [None]:
# Number of CPUs reported by `multiprocessing`.
# NOTE(review): `workers` is not read by any later cell visible in this notebook.
workers = multiprocessing.cpu_count()
workers

In [None]:
# NOTE(review): SEED is defined here but no visible cell passes it to a random
# number generator — confirm whether seeding is expected to happen elsewhere.
SEED = 2025

In [None]:
# Emit log records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
# NOTE(review): `init_device` presumably selects the compute device — confirm in nlp_sae_utils.
device = init_device()
device

In [None]:
# `torch` is not imported explicitly in this notebook; presumably it arrives
# through one of the star imports above — confirm.
torch.__version__

In [None]:
# Same for `np`: presumably provided by the star imports.
np.__version__

In [None]:
# np.set_printoptions(precision=4, suppress=True)

## Initialize Paths

In [None]:
# Every location below hangs off the relative 'data' directory.
PATH = Path('data')
GPT2 = PATH.joinpath('transcoders', 'gpt2')
OWT_TOKENS_DIR = GPT2.joinpath('owt_tokens')
TOKENS_PATH = OWT_TOKENS_DIR.joinpath('owt_tokens_torch.pt')

# Make sure the tokens directory (and any missing parents) exists.
OWT_TOKENS_DIR.mkdir(parents=True, exist_ok=True)

## Load transcoders

In [None]:
# Layer indices handed to the analyzer. Swap in `list(range(12))` to cover
# layers 0-11 instead of this subset.
layers = [0, 4, 6, 8, 10, 11]

In [None]:
# Build the analyzer from the selected layer indices, the token file path and
# the GPT-2 transcoder directory, passing along the chosen device.
# NOTE(review): what init_analyzer loads from TOKENS_PATH / GPT2 is not visible
# here — confirm in transcoder_analyzers.
tr_analyzer = init_analyzer(
    layers,
    TOKENS_PATH,
    GPT2,
    device=device
)

In [None]:
# Collect unreachable Python objects first, then ask PyTorch to release its
# cached CUDA memory. `gc` is already imported in the notebook's import cell,
# so it is not re-imported here.
gc.collect()
torch.cuda.empty_cache()

## Analyze Neuron

In [None]:
# NOTE(review): the positional arguments look like (feature index, threshold,
# layer) — confirm against gen_and_print_all's signature.
c0 = tr_analyzer.gen_and_print_all(9854, 2, 0, with_text=True)
c0

In [None]:
c0.token_idcs

In [None]:
len(c0.token_idcs)

In [None]:
# Call the transcoder on `tr_analyzer.tokens` indexed by `c0.c.A`.
# NOTE(review): the meaning of the second argument (10) is not visible here —
# possibly a layer index; confirm.
vs = tr_analyzer.transcoder(tr_analyzer.tokens[c0.c.A], 10)

In [None]:
vs.shape

In [None]:
# Pair each entry of `c0.token_idcs` with the matching entry of `vs`, gather
# the elements of that entry at the listed positions, and reduce them with
# `meet_all`.
v_indices = []  # one `meet_all` result per pair (holds values, not indices)
v_list = []     # the gathered elements for each pair, kept for inspection
for token_ids, v_i in zip(c0.token_idcs, vs):
    v_list_i = [v_i[token_id] for token_id in token_ids]
    # Reuse the elements gathered above instead of indexing `v_i` a second time.
    v_dets_i = meet_all(np.array(v_list_i))
    v_indices.append(v_dets_i)
    v_list.append(v_list_i)

In [None]:
# Reduce the per-pair results collected in `v_indices` with one more `meet_all`.
v_dets = meet_all(np.array(v_indices))

In [None]:
v_list[0]

In [None]:
# `topK` is unpacked as (values, indices) elsewhere in this notebook.
# NOTE(review): presumably these are the k largest entries — confirm in the utils module.
topK(v_list[1][0], 10)

In [None]:
topK(v_list[10][0], 10)

In [None]:
# Compare the result recomputed above with the `v_FG` attribute stored on the concept.
topK(v_dets, 10)

In [None]:
topK(c0.v_FG, 10)

In [None]:
# NOTE(review): the fractional second argument (0.09982419) looks like a value
# copied from an earlier output — confirm its provenance.
c0_a = tr_analyzer.gen_and_print_all(4237, 0.09982419, 0, with_text=True, limit=200)

In [None]:
topK(c0_a.v_FG, 10)

In [None]:
# Same call pattern with 6 as the third argument.
c6 = tr_analyzer.gen_and_print_all(11831, 12, 6, with_text=True)
c6

In [None]:
topK(c6.v_FG, 10)

In [None]:
c6_a = tr_analyzer.gen_and_print_all(3084, 1.655886, 6, with_text=True, limit=200)

In [None]:
topK(c6_a.v_FG, 10)

In [None]:
# Same call pattern with 8 as the third argument.
c8 = tr_analyzer.gen_and_print_all(355, 12, 8, with_text=True)
c8

In [None]:
# Only 2 entries are requested here, unlike the 10 used for the other concepts.
topK(c8.v_FG, 2)

In [None]:
c8_a = tr_analyzer.gen_and_print_all(8919, 1.1454895e-03, 8, with_text=True, limit=200)

In [None]:
topK(c8_a.v_FG, 10)

## Detect colors

In [None]:
# Analyze a hand-written sentence that mentions "purple" twice.
concept_an = ConceptAnalysis(
    ' Interesting fact about color purple, which itself is a beautiful colour, there is a band named Deep Purple in 70s', 
    tr_analyzer
)

In [None]:
concept_an.analyze_concepts()

In [None]:
# NOTE(review): presumably the token positions of the two "purple" mentions —
# confirm against the tokenization of the sentence above.
t_idcs = [5, 20]
layer = 0

In [None]:
concept_an.gen_text(t_idcs, layer)

In [None]:
# Uses the literal 0 rather than `layer`; the two agree only while `layer == 0`.
concept_an.c_is[0]

In [None]:
# k is the full length of the vector, so every entry is returned.
v0_5 = concept_an.v_FG[layer][5]
vals, idcs = topK(v0_5, v0_5.shape[0])
vals, idcs

In [None]:
# NOTE(review): 236 and 926 are hard-coded keys, presumably read off the
# output of an earlier cell — confirm.
concept_an.detected_vs[layer][236]

In [None]:
concept_an.detected_vs[layer][926]

In [None]:
# NOTE(review): the inner keys (5/64 and 20/84) are hard-coded as well; their
# meaning is not visible here.
# `v_20` is not read by any later cell visible in this notebook.
v_5 = concept_an.detected_vs[layer][236][5][64]
v_20 = concept_an.detected_vs[layer][236][20][84]

In [None]:
topK(v_5, 10)

In [None]:
topK(v0_5, 10)

In [None]:
concept_an.to_string(concept_an.corpus[236])

In [None]:
# `idcs` / `vals` come from the full-length topK call on `v0_5` above.
concept_an.gen_and_print(idcs, vals, layer, with_text=True, limit=100)

In [None]:
concept_an.v_FG[layer]

In [None]:
# Re-binds `layer`: earlier cells that read `layer` will use 8 if they are
# re-run after this point.
layer = 8

In [None]:
concept_an.gen_text(t_idcs, layer)

In [None]:
# Overwrites the `vals` / `idcs` produced for layer 0.
v8_5 = concept_an.v_FG[layer][5]
vals, idcs = topK(v8_5, 10)
vals, idcs

In [None]:
# Reads `v_is` instead of `v_FG` for the same layer and position.
topK(concept_an.v_is[layer][5], 10)

In [None]:
# Re-binds `layer` again, now to 11.
layer = 11

In [None]:
# NOTE(review): the effect of `red_val=2` is not visible here — confirm in gen_text.
concept_an.gen_text(t_idcs, layer, limit=100, red_val=2)

In [None]:
# k is the full length of the vector, so every entry is returned.
v11_5 = concept_an.v_FG[layer][5]
vals, idcs = topK(v11_5, v11_5.shape[0])
vals, idcs

In [None]:
concept_an.gen_and_print(idcs, vals, layer, with_text=True, limit=100)

## Experiments: Positive and Negative Examples for Black

In [None]:
# Row of `tr_analyzer.tokens` used as the text for this experiment.
tok_indx = 143

In [None]:
text_detoken = tr_analyzer.to_clean(tr_analyzer.tokens[tok_indx])
text_detoken

In [None]:
# Re-binds `concept_an`, replacing the instance built in the previous section.
concept_an = ConceptAnalysis(text_detoken, tr_analyzer)

In [None]:
concept_an.analyze_concepts()

In [None]:
# NOTE(review): `i` and `j` are not read by any later cell visible in this notebook.
i, j = 27, 28
t_idcs = [28, 127]

In [None]:
layer_0 = 0

In [None]:
# Sets `background_dets` to None on the transcoder reached through
# `concept_an.tr_utils`.
# NOTE(review): what this attribute influences is not visible here — confirm.
concept_an.tr_utils.transcoder.background_dets = None

In [None]:
concept_an.gen_text(t_idcs, layer_0)

In [None]:
# NOTE(review): 22 is a hard-coded key, presumably read off an earlier output.
det_22 = concept_an.detected_vs[layer_0][22]
det_22

In [None]:
# The outer keys 28 and 127 match `t_idcs`; the inner keys (32, 96, 119) are hard-coded.
v_b1 = det_22[28][32]
v_b2 = det_22[127][32]

In [None]:
v_b3 = det_22[127][96]
v_b4 = det_22[127][119]

In [None]:
# NOTE(review): `v_b` is not read by any later cell visible in this notebook.
v_b = join(v_b1, v_b2)

In [None]:
topK(v_b1, 20)

In [None]:
topK(v_b3, 20)

In [None]:
topK(v_b4, 20)

In [None]:
# True only when the two vectors agree element-wise.
np.all(v_b1 == v_b2)

In [None]:
layer_8 = 8

In [None]:
concept_an.gen_text(t_idcs, layer_8)

In [None]:
concept_an.detected_vs[layer_8]

In [None]:
layer_11 = 11

In [None]:
concept_an.gen_text(t_idcs, layer_11)

In [None]:
# NOTE(review): 3 and 8 are hard-coded keys, presumably read off an earlier output.
concept_an.detected_vs[layer_11][3]

In [None]:
concept_an.detected_vs[layer_11][8]

In [None]:
# The outer keys 28 and 127 match `t_idcs`; the inner keys (62, 34) are hard-coded.
# NOTE(review): the names suggest the "white" and "black" tokens — confirm against the text.
v_white = concept_an.detected_vs[layer_11][8][28][62]
v_black = concept_an.detected_vs[layer_11][8][127][34]

In [None]:
topK(v_white, 10)

In [None]:
topK(v_black, 10)

In [None]:
# Combine the two vectors with the pairwise `meet` helper.
v_meet = meet(v_white, v_black)

In [None]:
# Overwrites `vals` from the previous section; note the new name `indcs`
# (earlier cells used `idcs`).
vals, indcs = topK(v_meet, 10)
vals, indcs

In [None]:
# Uses entry 1 of the topK result and skips entry 0.
# NOTE(review): reason for skipping entry 0 not visible here.
c_meet = concept_an.gen_concept(indcs[1], vals[1], layer_11)
c_meet

In [None]:
# Entries 1-3 of the topK result.
concept_an.gen_and_print(indcs[1:4], vals[1:4], layer_11, with_text=True, limit=20)

In [None]:
# NOTE(review): hard-coded index 21836 with value 17 — confirm where these come from.
concept_an.gen_and_print([21836], [17], layer_11, with_text=True, limit=20)