In [None]:
# Reload edited modules (e.g. the local src.lattmc package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the cell output.
%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
conda activate edu4
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# ! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Put the directory one level up on the import path (presumably so the `src.lattmc...`
# imports further down resolve when the notebook runs from a subfolder).
module_path = os.path.abspath(os.pardir)
sys.path.extend([] if module_path in sys.path else [module_path])

In [None]:
# Same as above, two directory levels up.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
sys.path.extend([] if module_path in sys.path else [module_path])

In [None]:
# Same as above, three directory levels up.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir, os.pardir))
sys.path.extend([] if module_path in sys.path else [module_path])

In [None]:
# Remove the temporary helper so it does not linger in the notebook namespace.
del module_path

## Organize imports

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from scipy.sparse import csr_matrix

In [None]:
import plotly.express as px

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

In [None]:
import os

In [None]:
import PIL
from clipscope import ConfiguredViT, TopKSAE

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

In [None]:
from src.lattmc.sae.nlp_sae_utils import *

### Number of CPU cores

In [None]:
# CPU count reported by multiprocessing.cpu_count(); the walrus both stores and displays it.
(workers := multiprocessing.cpu_count())

In [None]:
# Seed every RNG the notebook can touch so any stochastic step is reproducible on re-run.
# (Previously SEED was defined but never applied to any generator.)
SEED = 2024

import random  # stdlib; seeded alongside numpy and torch

random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)  # last: returns None, so the cell shows no stray Generator repr

In [None]:
# Choose the compute device: Apple MPS if available, otherwise CUDA, otherwise CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(f"Device: {device}")

## Initialize paths

In [None]:
# Everything lives under ./data: checkpoints and the cached FCA context in saes/, images in images/.
PATH = Path('data')

checkpoint_dir = PATH.joinpath('saes')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path1 = checkpoint_dir.joinpath('best-checkpoint-v1.ckpt')
checkpoint_path2 = checkpoint_dir.joinpath('best-checkpoint.ckpt')
# NOTE(review): this cache name embeds 725159424, but the SAE loaded later is 700092672.pt.
# The cache is not invalidated when the checkpoint changes, so confirm it matches the model.
fca_path = checkpoint_dir.joinpath('vit_scope_res22_725159424_imagenette_val.joblib')

image_dir = PATH.joinpath('images')
image_path = image_dir.joinpath('1024.png')

## Initialize simple dataset

In [None]:
# Imagenette validation split, 160px variant. No transform is applied, so items are
# (image, label) pairs as produced by torchvision's default loader.
# torchvision's Imagenette raises RuntimeError when download=True and the extracted folder
# already exists, which broke Restart & Run All after the first run. Try the local copy
# first and only download when it is missing.
try:
    val_dataset = datasets.Imagenette(root="./data", split='val', size='160px', download=False)
except RuntimeError:
    val_dataset = datasets.Imagenette(root="./data", split='val', size='160px', download=True)
# To iterate faster, wrap with torch.utils.data.Subset(val_dataset, range(n)). Subset does
# not expose `.classes`, which a later cell reads.

## Initialize model

In [None]:
# Hook point: selects the SAE checkpoint folder and the ViT activation location used below.
layer, resid = 22, 'resid'

In [None]:
# filename_in_hf_repo = '700092672.pt'

In [None]:
# sae = TopKSAE.from_pretrained(
#     checkpoint=filename_in_hf_repo,
#     repo_id='lewington/CLIP-ViT-L-scope', 
#     device=device
# )

In [None]:
# SAE weights for this hook point. No repo_id is passed, so this presumably uses
# clipscope's default Hugging Face repo; TODO confirm.
filename_in_hf_repo = '/'.join([f'{layer}_{resid}', '700092672.pt'])
sae = TopKSAE.from_pretrained(checkpoint=filename_in_hf_repo, device=device)

# Hook the ViT at the same (layer, resid) location. Later cells look up activations by this key.
locations = [(layer, resid)]
transformer_name = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'
transformer = ConfiguredViT(locations, transformer_name, device=device)

In [None]:
# Display the hooked ViT wrapper.
transformer

In [None]:
# Display the SAE module.
sae

In [None]:
# Collect Python garbage left over from loading the two models.
gc.collect()

In [None]:
def topK(a, k):
    """Return the indices of the ``k`` largest entries of ``a``, largest first.

    Parameters
    ----------
    a : array_like
        Scores (1-D latent vectors in this notebook).
    k : int
        Number of indices to return. ``k <= 0`` yields an empty result; ``k``
        larger than ``len(a)`` returns every index.

    Returns
    -------
    numpy.ndarray
        Integer indices ordered by descending value of ``a``.
    """
    if k <= 0:
        # np.argsort(a)[-0:] is the *whole* array, not an empty slice.
        return np.empty(0, dtype=np.intp)
    return np.argsort(a)[-k:][::-1]

In [None]:
def vec_i_j(v_sm, indices, values):
    """Return a zero vector shaped like ``v_sm`` with ``values`` placed at ``indices``.

    Only the shape/dtype of ``v_sm`` are used; its contents are ignored and it is
    not modified. ``values`` is assigned with NumPy fancy indexing, so it may be a
    scalar or a sequence matching ``indices``.
    """
    probe = np.zeros_like(v_sm)
    probe[indices] = values
    return probe

In [None]:
def show_images(indices, ds):
    """Plot ``ds[idx][0]`` (the image part of each dataset item), one figure per index."""
    images = (ds[i][0] for i in indices)
    for image in images:
        plt.imshow(image)
        plt.show()

In [None]:
class Net(object):
    """Thin wrapper returning token-0 activations of a hooked transformer.

    Parameters
    ----------
    transformer :
        Object with ``all_activations(image)`` returning a mapping from hook
        location to an activation tensor (presumably ``(B, token_count, hidden_dim)``).
    locations :
        Sequence of hook locations; only the first is used by ``encode``.
    """

    def __init__(self, transformer, locations):
        self.transformer = transformer
        self.locations = locations

    def encode(self, image):
        """Return the token-0 slice (``[:, 0]``) of the activations at ``self.locations[0]``."""
        per_token = self.transformer.all_activations(image)[self.locations[0]]
        # Token 0 is presumably the CLS token, giving (B, hidden_dim).
        return per_token[:, 0]
        

In [None]:
# Cache each validation image's SAE latent vector as its own sparse .joblib file.
# Repaired from an NLP template that:
# - did not parse (broken indentation);
# - called iterdir() on the fca_path *file*;
# - used undefined matrix_dir/dataset/net/logger;
# - read d['text'] from an image dataset;
# - dumped only the last vector, after the loop.
# Latents are computed the same way as in the FCA cell that follows.
# NOTE(review): assumes `joblib` is in scope via the star imports above; TODO confirm.
matrix_dir = checkpoint_dir / 'vit_scope_res22_imagenette_val_latents'
matrix_dir.mkdir(exist_ok=True, parents=True)
net = Net(transformer, locations)

if any(matrix_dir.iterdir()):
    print(f'{matrix_dir} is not empty')
else:
    V_sparse = []
    # Inference only, so don't let autograd record graphs across the whole dataset.
    with torch.no_grad(), tqdm(val_dataset) as pdata:
        for idx, (image, _label) in enumerate(pdata):
            latent = sae.forward_verbose(net.encode(image))['latent']
            v_sparse = csr_matrix(latent.to('cpu').detach().numpy()[0])
            V_sparse.append(v_sparse)
            # Dump inside the loop so every image gets its own file.
            joblib.dump(v_sparse, matrix_dir / f'{idx}.joblib')

In [None]:
# Build the FCA context from one SAE latent vector per validation image,
# or load it from the cache at fca_path when present.
if fca_path.exists():
    fca = FCA.load(fca_path)
    # NOTE(review): later cells index and enumerate V like the dense array built in the
    # else-branch; confirm fca.V has the same layout.
    V = fca.V
else:
    V = []
    # One image per forward pass (not batched). no_grad keeps autograd from recording
    # graphs for thousands of inference-only passes.
    with torch.no_grad(), tqdm(val_dataset) as datap:
        for image, label in datap:
            activations = transformer.all_activations(image)[locations[0]]  # shape: (B, token_count, hidden_dim)
            cls_activations = activations[:, 0]  # (B, hidden_dim)

            # Forward pass through the sparse autoencoder
            output = sae.forward_verbose(cls_activations)
            v = output['latent'][0].to('cpu').detach().numpy()
            V.append(v)
    V = np.array(V)
    V_sparse = csr_matrix(V)
    fca = FCA(V_sparse)
    fca.save(fca_path)

In [None]:
# Inspect the first latent vector (shape and values).
V[0].shape, V[0]

In [None]:
# Reference vector x: row -2 of V (assumed to align with val_dataset[-2]).
v_x = V[-2]

In [None]:
# Label of image x.
val_dataset[-2][1]

In [None]:
# Reference vector y: row -1000 of V (assumed to align with val_dataset[-1000]).
v_y = V[-1000]

In [None]:
# Label of image y.
val_dataset[-1000][1]

In [None]:
# Class metadata, indexed by the label values shown above.
val_dataset.classes

In [None]:
# Active latents of x, plus the value and index of its strongest latent.
np.nonzero(v_x), np.max(v_x), np.argmax(v_x), v_x[np.argmax(v_x)]

In [None]:
# (index, vector) pairs for images sharing more than 16 active latents with x
# (includes x itself when it has more than 16 active latents).
# x's support is computed once instead of once per row.
nz_x = np.nonzero(v_x)
V_sim1 = [
    (idx, v)
    for idx, v in enumerate(V)
    if np.intersect1d(np.nonzero(v), nz_x).shape[0] > 16
]

In [None]:
# (index, vector) pairs for images sharing more than 16 active latents with y
# (includes y itself when it has more than 16 active latents).
# y's support is computed once instead of once per row.
nz_y = np.nonzero(v_y)
V_sim2 = [
    (idx, v)
    for idx, v in enumerate(V)
    if np.intersect1d(np.nonzero(v), nz_y).shape[0] > 16
]

In [None]:
# find_v_A comes from the star-imported src.lattmc modules. Presumably it returns the
# attribute vector shared by the given rows of V (TODO confirm). The rows are the V
# indices stored in entries 2..47 of V_sim1.
v_A = find_v_A(V, np.array([v_s[0] for v_s in V_sim1[2:48]]))

In [None]:
# For each image similar to x, print:
#   - its top-4 latents, strongest latent and their values;
#   - the latents it shares with x.
# idx_cm accumulates x's latents that are active in *all* of those images.
idx_cm = np.nonzero(v_x)
for v_s in V_sim1:
    latent = v_s[1]
    top4 = topK(latent, 4)
    print(f'{top4} {np.argmax(latent)} {latent[top4]} {np.max(latent)}')
    shared_with_x = np.intersect1d(np.nonzero(latent), np.nonzero(v_x))
    print(f'{shared_with_x}\n')
    idx_cm = np.intersect1d(idx_cm, np.nonzero(latent))

In [None]:
# Latents active in x and in every image of V_sim1.
idx_cm

In [None]:
# For each image similar to y, print:
#   - its top-k latents, strongest latent and their values;
#   - the latents it shares with x.
# idx_cm accumulates x's latents that are active in all of those images.
# NOTE(review): V_sim2 is selected by overlap with v_y, yet this cell compares against v_x
# (possibly copied from the cell above). Confirm whether v_y was intended.
k = 12
idx_cm = np.nonzero(v_x)
for v_s in V_sim2:
    latent = v_s[1]
    top_k = topK(latent, k)
    print(f'{top_k} {np.argmax(latent)} {latent[top_k]} {np.max(latent)}')
    shared_with_x = np.intersect1d(np.nonzero(latent), np.nonzero(v_x))
    print(f'{shared_with_x}\n')
    idx_cm = np.intersect1d(idx_cm, np.nonzero(latent))

In [None]:
# Common latents from the loop above.
idx_cm

In [None]:
# Sizes of the two similarity sets.
len(V_sim1), len(V_sim2)

In [None]:
# Show the image behind entry 2 of V_sim1.
plt.imshow(val_dataset[V_sim1[2][0]][0])
plt.show()

In [None]:
# Probe vectors shaped like x: latents 64916 and 59768 set to fixed values, everything else zero.
# The indices were presumably picked from idx_cm above (TODO confirm).
v_i = vec_i_j(v_x, [64916, 59768], [10, 4])
v_j = vec_i_j(v_x, [64916, 59768], [10, 4])

In [None]:
# Add one distinguishing latent to each probe (mutates the vectors from the previous cell in place).
v_i[53645] = 2.2
v_j[281] = 2.2

In [None]:
v_p = vec_i_j(v_x, [64916, 59768], [10, 3])
v_q = vec_i_j(v_x, [64916, 59768], [10, 3])

In [None]:
v_p[27760] = 2.4
v_q[50826] = 2.6

In [None]:
# NOTE: v_s here rebinds the name used as the loop variable in the similarity cells above.
v_t = vec_i_j(v_x, [64916, 59768], [10, 3])
v_s = vec_i_j(v_x, [64916, 59768], [10, 3])

In [None]:
v_t[64573] = 2.2
v_s[60547] = 2.2

In [None]:
# Turn probe v_i into a concept via the project FCA API (fca.G_FG). Its `.A` is then
# used as a list of dataset indices (presumably the concept's extent; TODO confirm).
concept_i = fca.G_FG(v_i)
concept_i

In [None]:
show_images(concept_i.A, val_dataset)

In [None]:
concept_j = fca.G_FG(v_j)
concept_j

In [None]:
show_images(concept_j.A, val_dataset)

In [None]:
# Combine two concepts with `&` (presumably the lattice meet) and show the resulting images.
concept_i_j = concept_i & concept_j
concept_i_j

In [None]:
show_images(concept_i_j.A, val_dataset)

In [None]:
concept_p = fca.G_FG(v_p)
concept_p

In [None]:
concept_q = fca.G_FG(v_q)
concept_q

In [None]:
# Combine concepts p and q.
concept_p_q = concept_p & concept_q
concept_p_q

In [None]:
show_images(concept_p_q.A, val_dataset)

In [None]:
concept_t = fca.G_FG(v_t)
concept_t

In [None]:
concept_s = fca.G_FG(v_s)
concept_s

In [None]:
# Combine concepts t and s.
concept_t_s = concept_t & concept_s
concept_t_s

In [None]:
show_images(concept_t_s.A, val_dataset)

In [None]:
# Probe vectors shaped like y: latents 64916 and 50826 set to fixed values, everything else zero.
v_a = vec_i_j(v_y, [64916, 50826], [10, 3])
v_b = vec_i_j(v_y, [64916, 50826], [10, 3])

In [None]:
# Add one distinguishing latent to each probe (in-place mutation of the previous cell's vectors).
v_a[17005] = 2.2
v_b[15707] = 2.2

In [None]:
concept_a = fca.G_FG(v_a)
concept_a

In [None]:
show_images(concept_a.A, val_dataset)

In [None]:
concept_b = fca.G_FG(v_b)
concept_b

In [None]:
show_images(concept_b.A, val_dataset)

In [None]:
# Combine concepts a and b and show the resulting images.
concept_a_b = concept_a & concept_b
concept_a_b

In [None]:
show_images(concept_a_b.A, val_dataset)

In [None]:
# Four more probes shaped like y, sharing latents 64916 and 50826.
v_c = vec_i_j(v_y, [64916, 50826], [10, 3])
v_d = vec_i_j(v_y, [64916, 50826], [10, 3])
v_e = vec_i_j(v_y, [64916, 50826], [10, 3])
v_f = vec_i_j(v_y, [64916, 50826], [10, 3])

In [None]:
# One distinguishing latent per probe (in-place mutation of the previous cell's vectors).
v_c[26494] = 1.2
v_d[8588] = 2.2
v_e[36155] = 2.2
v_f[19359] = 1.2

In [None]:
concept_c = fca.G_FG(v_c)
concept_c

In [None]:
concept_d = fca.G_FG(v_d)
concept_d

In [None]:
concept_e = fca.G_FG(v_e)
concept_e

In [None]:
concept_f = fca.G_FG(v_f)
concept_f

In [None]:
# Combine concepts c and d.
# Fix: this previously computed concept_a & concept_b (copy-paste from the a/b section),
# so the c/d probes were never used.
concept_c_d = concept_c & concept_d
concept_c_d

In [None]:
# Combine concepts c, d and e (previously concept_a & concept_b & concept_e, a copy-paste bug).
concept_c_d_e = concept_c & concept_d & concept_e
concept_c_d_e

In [None]:
# Combine concepts c, d, e and f (previously started from concept_a & concept_b, a copy-paste bug).
concept_c_d_e_f = concept_c & concept_d & concept_e & concept_f
concept_c_d_e_f

In [None]:
# Images of the combined c/d concept.
show_images(concept_c_d.A, val_dataset)

In [None]:
# Images of the combined c/d/e concept.
show_images(concept_c_d_e.A, val_dataset)

In [None]:
# Images of the combined c/d/e/f concept.
show_images(concept_c_d_e_f.A, val_dataset)