In [None]:
# Reload edited project modules (e.g. the src.lattmc helpers) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# Uncomment to pull the latest repository changes before running the notebook:
# ! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Make the project root importable (for `src.lattmc...`) regardless of whether it
# sits one, two, or three directories above this notebook. Paths are appended
# (lowest priority) and never duplicated.
for _rel in ('..', '../..', '../../..'):
    module_path = os.path.abspath(_rel)
    if module_path not in sys.path:
        sys.path.append(module_path)
del module_path, _rel

## Organize imports

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
from torch import nn

In [None]:
from sae_lens import SAE, HookedSAETransformer
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
from transformer_lens import HookedTransformer

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from scipy.sparse import csr_matrix

In [None]:
import plotly.express as px

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

In [None]:
from src.lattmc.sae.nlp_sae_utils import *

In [None]:
import logging

### Number of CPU cores

In [None]:
workers = multiprocessing.cpu_count()
workers

In [None]:
SEED = 2024

In [None]:
logging.basicConfig(level=logging.INFO)

In [None]:
device = init_device()

In [None]:
torch.__version__

In [None]:
np.__version__

## Initialize Path

In [None]:
PATH = Path('../data')
dataset_suffix = 'pile'
# NOTE: the '10' here is hard-coded and must stay in sync with `layer` below.
vectors_name = f'gpt2_small_10_{dataset_suffix}'

checkpoint_dir = PATH / 'saes'
vectors_dir = checkpoint_dir / f'{vectors_name}_vecs'
matrix_dir = checkpoint_dir / f'{vectors_name}_mats'
# Create the cache directories in parent-first order.
for _dir in (checkpoint_dir, vectors_dir, matrix_dir):
    _dir.mkdir(exist_ok=True, parents=True)
vectors_path = checkpoint_dir / f'{vectors_name}_vecs.joblib'

image_dir = PATH / 'images'
image_path = image_dir / '1024.png'

## Initialize simple dataset

In [None]:
# Load the NeelNanda/pile-10k train split fully into memory (not streamed).
dataset = load_dataset('NeelNanda/pile-10k', split='train', streaming=False)

## Initialize model

In [None]:
# Transformer block whose SAE hook is loaded below (see `sae_id`).
layer: int = 10

In [None]:
def gen_concept(idx, val, shape=24576, fca_model=None):
    """Build a single-feature SAE vector and compute its FCA concept.

    Parameters
    ----------
    idx : int or sequence of int
        Feature index (or indices) to set.
    val : float
        Value written at ``idx``; every other entry is 0.
    shape : int
        Length of the vector (24576 is the SAE width assumed throughout this notebook).
    fca_model : object with a ``G_FG`` method, optional
        Concept lattice to query. Defaults to the notebook-global ``fca``, which is
        only defined in the "Generate Context and Analyze" section, so pass it
        explicitly when calling this earlier.

    Returns
    -------
    tuple
        ``(v_idx, concept)``: the float64 vector and ``fca_model.G_FG(v_idx)``.
    """
    if fca_model is None:
        # Backward-compatible fallback: resolve the global lazily at call time.
        fca_model = fca
    v_idx = np.zeros((shape,), dtype=float)
    v_idx[idx] = val
    return v_idx, fca_model.G_FG(v_idx)

In [None]:
model_name, release = 'gpt2-small', 'gpt2-small-res-jb'
# Hook-point name (TransformerLens notation) of the SAE to load for `layer`.
sae_id = f'blocks.{layer}.hook_resid_pre'
net = Text2Sae(model_name, release, sae_id, device)

## Generate V Lattice

In [None]:
gc.collect()

In [None]:
# Produce the per-document activation matrices in matrix_dir; the "Analyze Tokens"
# section reads them back as `matrix_dir / f'{idx}.joblib'`, one per dataset row.
# NOTE(review): whether already-cached files are skipped depends on init_matrices -- confirm.
init_matrices(matrix_dir, dataset, net)

In [None]:
gc.collect()

In [None]:
# Build V (the matrix handed to FCA below) from the cached matrices.
# NOTE(review): the meaning of segment=False is defined by init_vectors -- confirm.
V = init_vectors(vectors_path, matrix_dir, segment=False)

## Generate Context and Analyze

In [None]:
gc.collect()

In [None]:
fca = FCA(V)

In [None]:
text1 = " The Golden Gate Bridge"
z = net.encode(text1)
zs = z.to('cpu').detach().numpy()[0]
v = np.maximum.reduce(zs)

In [None]:
zs[1:].shape

In [None]:
zs.shape

In [None]:
v[19837]

## Golden Feature from Neuronscope

In [None]:
v_golden, concept = gen_concept(19837, 90)
v_golden.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Cat Feature from Neuronscope

In [None]:
v_cat, concept = gen_concept(16899, 80)
v_cat.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Apple Feature from Neuronscope

In [None]:
v_apple, concept = gen_concept(4269, 70)
v_apple.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Thunder and Lightning Feature from Neuronscope

In [None]:
v_thunder, concept = gen_concept(23123, 40)
v_thunder.shape, concept

In [None]:
dataset[concept.A[0].item()]

## School Feature from Neuronscope

In [None]:
v_school, concept = gen_concept(20781, 84)
v_school.shape, concept

In [None]:
dataset[concept.A[0].item()]

## King Feature from Neuronscope

In [None]:
v_king, concept = gen_concept(17624, 60)
v_king.shape, concept

In [None]:
dataset[concept.A[1].item()]

## Organizations Feature from Neuronscope

In [None]:
v_orgs, concept = gen_concept(16660, 60)
v_orgs.shape, concept

In [None]:
dataset[concept.A[0].item()]

## States Feature from Neuronscope

In [None]:
v_state, concept = gen_concept(22, 10)
v_state.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Joint concepts

In [None]:
# topK(v_orgs, 4), topK(v_school, 4)

In [None]:
v_orgs = np.zeros_like(to_numpy(v))
v_school = np.zeros_like(to_numpy(v))
v_having = np.zeros_like(to_numpy(v))

In [None]:
v_orgs[16660] = 20
v_school[17624] = 20
v_having[17935] = 6

In [None]:
orgs = fca.G_FG(v_orgs)
schools = fca.G_FG(v_school)
havings = fca.G_FG(v_having)
orgs, schools, havings

In [None]:
org_school = orgs & schools
org_school

In [None]:
dataset[org_school.A[0].item()]['text']

In [None]:
org_school_havings = orgs & schools & havings 
org_school_havings

In [None]:
dataset[org_school_havings.A[0].item()]['text']

## Recipe Feature from Neuronscope

In [None]:
v_recipie, concept = gen_concept(7, 28)
v_recipie.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Election Feature from Neuronscope

In [None]:
v_elect, concept = gen_concept(29, 28)
v_elect.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Mixture of Election and Party

In [None]:
# Same election feature at a much lower value (2); display the concept just
# computed (previously the stale `concept` from the cell above was shown).
v_elect, concept_elects = gen_concept(29, 2)
v_elect.shape, concept_elects

## Food Features

In [None]:
texts = [
    'food recipie',
    'love',
    'admire',
    'sex',
]

In [None]:
v = net.encode(texts[2])[0][1]
concept = fca.G_FG(v / 4)
concept

In [None]:
dataset[concept.A[3].item()]

In [None]:
topK(v, 10)

In [None]:
v_love, recipie = gen_concept(14654, 20)
v_love.shape, concept

In [None]:
v_admire, recipie = gen_concept(14990, 20)
v_admire.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Verbal Feature from Neuronscope

In [None]:
v_verbal, concept = gen_concept(33, 40)
v_verbal.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Border Feature from Neuronscope

In [None]:
v_border, concept = gen_concept(35, 20)
v_border.shape, concept

In [None]:
dataset[concept.A[0].item()]

## Cross Concept

In [None]:
v_verbal, concept_verbal = gen_concept(33, 16)
v_border, concept_border = gen_concept(35, 16)

In [None]:
cross = concept_verbal & concept_border
cross

In [None]:
dataset[concept.A[1].item()]

## Cross Concept Probability

In [None]:
v_posib, concept_posib = gen_concept(1061, 20)
v_posit, concept_posit = gen_concept(809, 20)
v_liklh, concept_liklh = gen_concept(418, 20)
v_qualt, concept_qualt = gen_concept(129, 20)

In [None]:
cross_pr_ps = concept_posib & concept_posit
cross_pr_ps

In [None]:
dataset[cross_pr_ps.A[0].item()]

In [None]:
cross_lk_ql = concept_liklh & concept_qualt
cross_lk_ql

In [None]:
dataset[cross_lk_ql.A[0].item()]

## Context Features

In [None]:
v_peval, concept_peval = gen_concept(3109, 20)
v_evstt, concept_evstt = gen_concept(3784, 20)
v_sosis, concept_sosis = gen_concept(3702, 20)
v_agevs, concept_agevs = gen_concept(3388, 20)

In [None]:
concept_pe_ev = concept_peval & concept_evstt
concept_pe_ev

In [None]:
dataset[concept_pe_ev.A[0].item()]

In [None]:
concept_so_ag = concept_sosis & concept_agevs
concept_so_ag

In [None]:
dataset[concept_so_ag.A[0].item()]

In [None]:
concept_all = concept_peval & concept_evstt & concept_sosis & concept_agevs
concept_all

In [None]:
dataset[concept_all.A[0].item()]

## Computer and Technology Concepts

In [None]:
v_comps, concept_comps = gen_concept(20542, 14)
v_comps.shape, concept_comps

In [None]:
dataset[concept_comps.A[0].item()]

## Innovation Features

In [None]:
v_innov, concept_innov = gen_concept(2503, 10)
v_innov.shape, concept_innov

In [None]:
dataset[concept_innov.A[0].item()]

In [None]:
concept_co_in = concept_comps & concept_innov
concept_co_in

In [None]:
dataset[concept_co_in.A[0].item()]

## Add Examples

In [None]:
text1 = " The Golden Gate Bridge"
z = net.encode(text1)
tokens = net.tokenize(text1)
print(torch.topk(z, 12))
print()
print(z)

In [None]:
r = net.encode('Golden')
torch.topk(r[0][1], 50)

In [None]:
torch.nonzero(r[0][1]).shape

In [None]:
texts = [
    "Golden Gate Bridge",
    "New York City",
    "Silicon Valley",
    "The White House",
    "Apple Inc."
]

In [None]:
texts = [
    "School of AI",
    "School of Economics",
    "School of Medicine",
    "School of Arts",
    "School of Technologies"
]

In [None]:
texts = "Golden Retriever"
tokens = net.tokenize(texts)
tokens, net.to_string(tokens)

In [None]:
texts = "Gold"
tokens = net.tokenize(texts)
vs = net.encode(texts)
tokens, net.to_string(tokens)

In [None]:
net.to_string(32378)

In [None]:
vs = net.encode(texts)

In [None]:
vs.shape, vs[0].shape

In [None]:
v_t = [v[1] for v in vs[:3]]
v_t[0].shape

In [None]:
v_A = v_t[0]
for v in v_t:
    v_A = torch.minimum(v_A, v)

In [None]:
vs[0][0].shape, 19837

In [None]:
torch.topk(vs[0][1], 20)

In [None]:
torch.topk(vs[0][1], 20)

In [None]:
torch.topk(vs[0][1], 20)

In [None]:
torch.topk(v_A, 30)

## Golden Second Index

In [None]:
v_1, concept_1 = gen_concept(21286, 16)
v_1.shape, concept_1

In [None]:
dataset[concept_1.A[0].item()]

## Gold Second Index

In [None]:
v_2, concept_2 = gen_concept(9572, 30)
v_2.shape, concept_2

In [None]:
dataset[concept_2.A[0].item()]

In [None]:
concept_m = concept_1 & concept_2
concept_m

In [None]:
dataset[concept_m.A[1].item()]

In [None]:
tk = topK(V[0], 30)
tk

In [None]:
v_A = np.zeros_like(v_t[0].to('cpu').detach().numpy())
v_A[17943] = 62
v_A[11811] = 6
v_A[15823] = 4
v_A[4507] = 4
v_A[20161] = 4
concept = fca.G_FG(v_A)
concept

In [None]:
concept.A.shape, V.shape, V.shape[0] - concept.A.shape[0]

In [None]:
concept.A[0]

In [None]:
dataset[concept.A[2].item()]

## Analyze Tokens

In [None]:
len(dataset)

In [None]:
checks = []
with tqdm(list(range(len(dataset)))) as pdata:
    for idx in pdata:
        tokens = net.tokenize(dataset[idx]['text'])[0]
        vs = joblib.load(matrix_dir / f'{idx}.joblib')
        checks.append(tokens.shape[0] == vs.shape[0])

In [None]:
np.all(checks)

In [None]:
from IPython.display import clear_output
# clear_output is no longer called inside the loop: it erased the tqdm progress
# bar on every iteration. The unused v_paths glob and W_dict reset are dropped.
# NOTE(review): joblib is not imported explicitly in this notebook; presumably it
# comes from one of the `src.lattmc` star imports -- confirm.
T_dict = {}
for idx in tqdm(concept.A):
    # Use the plain Python int for both the file name and the dict key, so the
    # path is `5.joblib` whether concept.A yields numpy scalars or 0-d tensors.
    doc_id = idx.item()
    vs = joblib.load(matrix_dir / f'{doc_id}.joblib').toarray()
    # NOTE(review): find_G_x presumably returns the row (token) indices of `vs`
    # that match v_A -- confirm against its definition.
    G_x = find_G_x(vs, v_A)
    if G_x.shape[0] > 0:
        T_dict[doc_id] = G_x

In [None]:
len(T_dict)

In [None]:
T_dict[5]

In [None]:
tokens = net.tokenize(dataset[5]['text'])

In [None]:
net.to_string(tokens[0][100])

In [None]:
T_dict[5]

In [None]:
net.to_string(tokens[0][0]), net.to_string(tokens[0][T_dict[5]])

In [None]:
W_dict = {}
for k, v in T_dict.items():
    tokens = net.tokenize(dataset[k]['text'])[0]
    W_dict[k] = net.to_string(tokens[v])

In [None]:
W_dict