In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash 
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# ! git pull

## Add import path

In [None]:
import gc
import os
import sys

In [None]:
def add_library_level(level=4):
    """Append the working directory and its ancestors to ``sys.path``.

    For every ``depth`` in ``range(level)`` the directory ``depth`` levels
    above the current working directory is resolved to an absolute path and
    appended to ``sys.path`` unless it is already present. ``level=1``
    therefore registers only the working directory itself, and ``level=4``
    registers it plus its three nearest ancestors.

    Args:
        level: How many directories to register, starting at the working
            directory and walking upwards. Values <= 0 register nothing.
    """
    for depth in range(level):
        # os.path.join needs at least one component, so anchor on os.curdir;
        # abspath() then collapses the '.' and '..' segments.
        relative = os.path.join(os.curdir, *([os.pardir] * depth))
        module_path = os.path.abspath(relative)
        if module_path not in sys.path:
            sys.path.append(module_path)
            # Report the resolved directory: the relative form is empty for
            # depth 0 and would produce a message with no path in it.
            print(f'Appending {module_path}')

In [None]:
add_library_level(level=4)

## Organize imports

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
from huggingface_hub import hf_hub_download, notebook_login
import numpy as np
import torch
from torch import nn

In [None]:
from sae_lens import SAE, HookedSAETransformer
from transformer_lens.utils import tokenize_and_concatenate

In [None]:
from transformer_lens import HookedTransformer

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from scipy.sparse import csr_matrix

In [None]:
import plotly.express as px

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

In [None]:
from src.lattmc.sae.nlp_sae_utils import *

In [None]:
import logging

#### Number of CPU cores

In [None]:
# CPU count as reported by multiprocessing; the bare name on the last line
# makes the notebook display the value.
workers = multiprocessing.cpu_count()
workers

In [None]:
# Seed constant for the notebook.
# NOTE(review): nothing in the visible cells passes SEED to numpy, torch or any
# other library — confirm it is actually applied somewhere.
SEED = 2024

In [None]:
logging.basicConfig(level=logging.INFO)

In [None]:
device = init_device()

In [None]:
torch.__version__

In [None]:
np.__version__

## Initialize Layer

In [None]:
# `layer` is interpolated into the SAE hook id and, together with
# `dataset_suffix`, passed to init_data() further down.
layer = 4
dataset_suffix = 'pile'

## Initialize Path

In [None]:
def init_data(model_name, layer, dataset_suffix='pile'):
    """Build the SAE artefact paths for one model/layer/dataset combination.

    All paths live under ``../data/saes`` and share the stem
    ``<model_name>_<layer>_<dataset_suffix>``, where hyphens in
    ``model_name`` are replaced by underscores. The checkpoint directory, the
    ``<stem>_vecs`` directory and the ``<stem>_mats`` directory are handed to
    ``mkdirs``; the ``.joblib`` path is only computed, not created here.

    Args:
        model_name: Model identifier, e.g. ``'gpt2-small'``.
        layer: Layer index embedded in the stem.
        dataset_suffix: Dataset tag embedded in the stem.

    Returns:
        A ``(matrix_dir, vectors_path)`` tuple of ``Path`` objects: the
        ``<stem>_mats`` directory and the ``<stem>_vecs.joblib`` file path.
    """
    checkpoint_dir = Path('../data') / 'saes'

    # Normalised outside the f-string: reusing the enclosing quote character
    # inside a replacement field is a SyntaxError before Python 3.12.
    safe_model_name = model_name.replace('-', '_')
    # The stem uses the caller-supplied dataset_suffix.
    vectors_name = f'{safe_model_name}_{layer}_{dataset_suffix}'

    vectors_dir = checkpoint_dir / f'{vectors_name}_vecs'
    matrix_dir = checkpoint_dir / f'{vectors_name}_mats'
    vectors_path = checkpoint_dir / f'{vectors_name}_vecs.joblib'

    mkdirs(
        checkpoint_dir,
        vectors_dir,
        matrix_dir
    )
    logger.info(f'{matrix_dir = } {vectors_path = }')

    return matrix_dir, vectors_path

## Initialize simple dataset

In [None]:
# Load the `train` split of the NeelNanda/pile-10k dataset. streaming=False
# requests a regular (non-streaming) dataset, which the later cells index by
# row number.
dataset = load_dataset(
    path='NeelNanda/pile-10k',
    split='train',
    streaming=False,
)

## Initialize model

In [None]:
def gen_concept(idx, val, shape=24576):
    """Build a vector that is zero except at ``idx`` and look up its concept.

    Relies on the module-level ``fca`` object, which must be bound before
    this function is called.

    Args:
        idx: Index (anything numpy accepts for item assignment) to set.
        val: Value stored at ``idx``.
        shape: Length of the one-dimensional float vector.

    Returns:
        A ``(vector, concept)`` tuple where ``concept`` is whatever
        ``fca.G_FG(vector)`` returns.
    """
    sparse_vec = np.zeros((shape,), dtype=float)
    sparse_vec[idx] = val
    return sparse_vec, fca.G_FG(sparse_vec)

In [None]:
# Identifiers passed to Text2Sae and init_data below. sae_id is derived from
# `layer`, so this cell must be re-run after changing the layer.
model_name = 'gpt2-small'
release = 'gpt2-small-res-jb'
sae_id = f'blocks.{layer}.hook_resid_pre'

In [None]:
# Resolve the on-disk locations for this model/layer/dataset combination.
# Note the local name is `vector_path`, while init_data() calls it `vectors_path`.
matrix_dir, vector_path = init_data(model_name, layer, dataset_suffix=dataset_suffix)

In [None]:
# Build the Text2Sae wrapper from the identifiers above and the device chosen
# by init_device(). `net` later provides .embed() and .tokenize().
# NOTE(review): presumably this loads both the language model and the SAE
# weights — confirm in src.lattmc.sae.nlp_sae_utils.
net = Text2Sae(
    model_name,
    release,
    sae_id,
    device
)

## Generate V Lattice

In [None]:
gc.collect()

In [None]:
matrix_dir

In [None]:
init_matrices(matrix_dir, dataset, net)

In [None]:
gc.collect()

In [None]:
# V is the input to FCA(...) in the next section.
# NOTE(review): presumably init_vectors builds V from the matrices under
# matrix_dir and caches it at vector_path — confirm against its definition.
V = init_vectors(vector_path, matrix_dir, segment=False)

## Generate Context and Analyze

In [None]:
gc.collect()

In [None]:
# Module-level FCA instance. gen_concept() reads this global, so this cell
# must run before any gen_concept(...) call.
fca = FCA(V)

In [None]:
# Three phrases that share the word "Bridge".
# NOTE(review): text1..text3 are not referenced by any later cell in the
# visible notebook — confirm whether they are still needed.
text1 = "The Golden Gate Bridge"
text2 = "The Brooklyn Bridge"
text3 = "The card game Bridge"

In [None]:
sent1 = 'New York City'
sent2 = 'Golden Gate Bridge'
sent3 = 'Grand Canyon Park'

In [None]:
word_mapper = WordMapper(net, dataset, matrix_dir)

## New York City

In [None]:
# Embed sent1 / sent2, take element 0 of each result and drop its first row.
# NOTE(review): presumably [0] selects the single batch item and [1:] skips a
# leading BOS position — confirm against Text2Sae.embed.
v_ns = net.embed(sent1)[0][1:]
v_gs = net.embed(sent2)[0][1:]

In [None]:
# Token ids for the same two phrases, sliced the same way as the embeddings
# ([0] then [1:]) so rows and tokens should line up — TODO confirm.
t_ns = net.tokenize(sent1)[0][1:]
t_gs = net.tokenize(sent2)[0][1:]

In [None]:
t_ns.shape, t_gs.shape

In [None]:
t_ns, t_gs

In [None]:
topK(v_ns[0], 10)

In [None]:
topK(v_gs[0], 10)

In [None]:
# Combine the third remaining row of each phrase with meet() and inspect the
# result with topK(..., 20).
# NOTE(review): presumably row 2 corresponds to "City" and "Bridge" — this
# assumes one token per word; verify with t_ns / t_gs.
m_bg = meet(v_ns[2], v_gs[2])
topK(m_bg, 20)

In [None]:
# Vector with the single value 6 at index 9805, plus its concept from
# fca.G_FG. The index and value are hand-picked; why is not recorded here.
m_cg, c_bg = gen_concept(9805, 6)
c_bg

In [None]:
ws = word_mapper(c_bg, v=m_cg)

In [None]:
ws

## Sentence generation

In [None]:
# One concept per row of the sent2 embedding, each row divided by 10 before
# the fca.G_FG lookup.
# NOTE(review): the names assume rows 0/1/2 are "Golden"/"Gate"/"Bridge", and
# the reason for the 1/10 scaling is not shown here — confirm both.
c_golden = fca.G_FG(v_gs[0] / 10)
c_gate = fca.G_FG(v_gs[1] / 10)
c_bridge = fca.G_FG(v_gs[2] / 10)

In [None]:
c_golden, c_gate, c_bridge

In [None]:
# Combine the two concepts with the `&` operator defined by the concept type.
# NOTE(review): presumably this is the lattice meet of the two concepts —
# confirm against the class returned by fca.G_FG.
c_golden_gate = c_golden & c_gate
c_golden_gate

In [None]:
dataset[c_golden_gate.A[0].item()]['text']

In [None]:
# Embeddings for all three phrases, sliced as before ([0] then [1:]).
# v_nyc and v_ggb repeat the expressions used for v_ns and v_gs, so they
# re-run the same embed calls under new names.
# The last line displays each shape next to its source phrase.
v_nyc = net.embed(sent1)[0][1:]
v_ggb = net.embed(sent2)[0][1:]
v_gcp = net.embed(sent3)[0][1:]
v_nyc.shape, sent1, v_ggb.shape, sent2, v_gcp.shape, sent3

In [None]:
join_all(v_ggb).shape

In [None]:
c_ggb = fca.G_FG(v_ggb[1] / 10)
c_ggb

In [None]:
dataset[c_ggb.A[2].item()]['text']

In [None]:
ws = word_mapper(c_ggb, v=v_ggb[0] / 100)

In [None]:
ws

In [None]:
topK(v_gcp[0], 20)

In [None]:
v_gcm = meet(v_gcp[0], v_gcp[1])
topK(v_gcm, 10)

In [None]:
v_ggm = meet(v_ggb[0], v_ggb[1])
topK(v_ggm, 10)

In [None]:
c_ggm = fca.G_FG(v_ggm)

In [None]:
w_ggm = word_mapper(c_ggm)
w_ggm

In [None]:
# Fetch the dataset rows whose indices are listed in c_ggm.A.
# NOTE(review): assumes c_ggm.A holds row indices into `dataset`; the result
# may be large, so consider displaying only a slice in the next cell.
items = dataset[c_ggm.A.tolist()]

In [None]:
items

In [None]:
v_i, cn_i = gen_concept(6374, 6)
cn_i

In [None]:
ws = word_mapper(cn_i, v=v_i)

In [None]:
ws