In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# ! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Make the parent directory importable; skip the append when it is already listed.
module_path = os.path.abspath('..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
# Make the grandparent directory importable; skip the append when it is already listed.
module_path = os.path.abspath('../..')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
del module_path

## Organize imports

In [None]:
from datasets import load_dataset

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import plotly.express as px

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

In [None]:
import os

In [None]:
import PIL
from clipscope import ConfiguredViT, TopKSAE

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

#### Number of CPU cores

In [None]:
# CPU count as reported by multiprocessing.cpu_count(); the bare expression
# on the last line displays the value as the cell output.
workers = multiprocessing.cpu_count()
workers

In [None]:
SEED = 2024

In [None]:
# Choose the compute device in order of preference:
# Apple MPS first, then CUDA, then plain CPU as the fallback.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

## Initialize Path

In [None]:
# Root folder for the artefacts referenced by this notebook.
PATH = Path('data')

# SAE checkpoints live in data/saes; the folder is created on first run.
checkpoint_dir = PATH.joinpath('saes')
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path1 = checkpoint_dir.joinpath('best-checkpoint-v1.ckpt')
checkpoint_path2 = checkpoint_dir.joinpath('best-checkpoint.ckpt')

# Image locations (only the paths are built here; nothing is created on disk).
image_dir = PATH.joinpath('images')
image_path = image_dir.joinpath('1024.png')

## Initialize simple dataset

In [None]:
# No preprocessing is attached to the datasets: transform=None is passed, so
# samples are returned exactly as torchvision loads them.
# A Resize((224, 224)) -> ToTensor() -> Normalize(
#     mean=(0.48145466, 0.4578275, 0.40821073),
#     std=(0.26862954, 0.26130258, 0.27577711))
# pipeline is kept here for reference only and is currently disabled.
transform = None

# Build both CIFAR-10 splits (train first, then test), downloading them into
# ./data when they are not present yet.
train_dataset, val_dataset = (
    datasets.CIFAR10(
        root="./data",
        train=is_train,
        transform=transform,
        download=True,
    )
    for is_train in (True, False)
)

# To keep experiments small, a split can be wrapped in
# Subset(<split>, range(1000)) and, if batching is needed, in a
# DataLoader(..., batch_size=32, shuffle=False).

## Initialize model

In [None]:
# Activation location to analyse. The pair (layer, resid) is used below both
# to build the SAE checkpoint filename and as the location passed to the ViT.
layer = 22
resid = 'resid'  # presumably the residual-stream hook name — TODO confirm against clipscope

In [None]:
# Checkpoint file name inside the pretrained-SAE repository, derived from the
# chosen (layer, resid) location.
filename_in_hf_repo = f'{layer}_{resid}/1200013184.pt'
sae = TopKSAE.from_pretrained(
    device=device,
    checkpoint=filename_in_hf_repo,
)

# The transformer is configured for the same single location the SAE uses.
locations = [(layer, resid)]
transformer = ConfiguredViT(locations, device=device)

In [None]:
transformer

In [None]:
sae

In [None]:
gc.collect()

In [None]:
V = []       # one SAE latent vector (NumPy array) per processed image
images = []  # the input images, in processing order
labels = []  # the class label of each image, aligned with `images` and `V`

# NOTE(review): the original cell iterated over `dataset`, a name that is never
# defined in this notebook; the validation split is used instead. Switch to
# `train_dataset` (or a Subset of either split) if that was the intent.
# NOTE(review): V keeps one dense latent vector per image; if the SAE is wide
# this can use a lot of memory — consider a Subset for quick experiments.
# Inference only, so autograd is switched off for the whole loop.
with torch.no_grad(), tqdm(val_dataset) as datap:
    for image, label in datap:
        # all_activations returns a mapping keyed by location tuple; pick the
        # single configured location.
        # Expected shape (B, token_count, hidden_dim) — TODO confirm.
        activations = transformer.all_activations(image)[locations[0]]
        # Token index 0 is taken as the CLS token -> (B, hidden_dim).
        cls_activations = activations[:, 0]

        # Forward pass through the sparse autoencoder.
        output = sae.forward_verbose(cls_activations)

        # Store the latent vector itself (the original stored only its shape).
        v = output['latent'][0].detach().cpu().numpy()
        V.append(v)
        images.append(image)
        labels.append(label)

In [None]:
output

In [None]:
output['latent'][0].to('cpu').detach().numpy().shape

In [None]:
output['active_latents'].nonzero()

In [None]:
output['latent'].nonzero()

In [None]:
labels

In [None]:
# `images` is a Python list, which imshow cannot render as image data.
# Show the last image seen by the processing loop instead — the same sample
# that the `output` inspected in the surrounding cells was computed from.
plt.imshow(image)
plt.show()

In [None]:
output['active_latents'].nonzero()[:, 0]

In [None]:
output['latent'].nonzero()[:, 1]

In [None]:
output['latent'].nonzero()[:, 1] == output['active_latents'].nonzero()[:, 0]

In [None]:
# `images` is a Python list, which imshow cannot render as image data.
# Show the last image seen by the processing loop instead — the same sample
# that the `output` inspected in the surrounding cells was computed from.
plt.imshow(image)
plt.show()

In [None]:
torch.max(output['latent'])