In [None]:
# Reload edited modules (e.g. the `src.lattmc` helpers) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the cell output.
%matplotlib inline

## Disable the Google Translate widget

In [None]:
from IPython.display import display, Javascript

# Suppress the Google Translate widget if the notebook page loads one.
# The script polls once per second for Google Translate's `TranslateElement`
# constructor; when it appears, the constructor is replaced with a no-op and the
# `google_translate_element` container is removed. Polling gives up after
# `maxAttempts` checks, so running this cell on a page without the widget does not
# leave an interval firing forever (previously one more per re-run of the cell).
display(Javascript('''
(function() {
    var maxAttempts = 60;  // ~60 s at one check per second
    var attempts = 0;
    var interval = setInterval(function() {
        attempts += 1;
        if (typeof google !== 'undefined' && google.translate && google.translate.TranslateElement) {
            clearInterval(interval);
            google.translate.TranslateElement = function() {};
            document.getElementById('google_translate_element')?.remove();
        } else if (attempts >= maxAttempts) {
            clearInterval(interval);
        }
    }, 1000);
})();
'''))

## Update repository

In [None]:
# Pull the latest repository changes before running (shell command; needs git and network access).
! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Make the parent directory of the notebook importable (the `src.lattmc...` imports
# presumably live there); skip the append when it is already on sys.path.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
# Drop the helper name so it does not linger in the notebook namespace.
del module_path

## Organize imports

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from collections import OrderedDict

In [None]:
from functools import reduce

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [None]:
from tqdm import tqdm

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *

#### Number of CPU cores

In [None]:
# Number of CPU cores; passed as `workers` to compute_mean_std below.
workers = multiprocessing.cpu_count()
workers

In [None]:
SEED = 2024

# Actually apply the seed: previously SEED was declared but never used, so any
# randomness drawn from PyTorch's default generator or NumPy's global RNG differed
# between Restart & Run All executions.
torch.manual_seed(SEED)
np.random.seed(SEED)

## Initialize Path

In [None]:
# All artefacts live under a project-relative `data` directory.
PATH = Path('data')

# Trained checkpoint (its directory is not created here; it must already exist).
model_dir = PATH.joinpath('models')
model_path = model_dir.joinpath('simple_cnn_fasion_mnist_model.ckpt')

# Neuron-selection config files; the directory is created on first run.
config_dir = PATH.joinpath('config')
config_dir.mkdir(parents=True, exist_ok=True)
config_1_layer_path = config_dir.joinpath('neurons_1_layer.json')
config_cnn_layer_1 = config_dir.joinpath('neurons_cnn_1_layer.json')

# Root directory passed to FashionMNIST for download/storage; created on first run.
images_dir = PATH.joinpath('images')
images_dir.mkdir(parents=True, exist_ok=True)

pumpkin_path = PATH.joinpath('Pumpkin_Seeds_Dataset.xlsx')

## Load the model

In [None]:
# Load the training checkpoint onto the CPU. Despite the name, `model` is a dict-like
# checkpoint whose 'state_dict' entry holds the weights (used by the rename step below).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted checkpoints.
# Recent PyTorch releases default to weights_only=True, which may reject extra checkpoint
# contents; confirm against the installed version.
model = torch.load(model_path, map_location=torch.device('cpu'))

In [None]:
def clear_state_dict(state_dict, prefix_map=None):
    """Rename checkpoint parameter keys from Sequential indices to the named layers of ``net``.

    The checkpoint stores weights under positional names such as ``model.0.weight``
    (index 0 of a ``model`` Sequential), whereas ``net`` in this notebook names the same
    layers ``conv1``, ``conv2``, ``fc1`` and ``fc2``. This builds a renamed copy so the
    weights can be passed to ``net.load_state_dict``.

    Args:
        state_dict: Mapping of parameter name -> value. It is NOT modified.
        prefix_map: Optional mapping ``{old_prefix: new_prefix}``. Only keys that START
            with ``old_prefix`` are renamed (the first matching prefix wins). Defaults to
            the layer mapping of this notebook's CNN.

    Returns:
        A new ``OrderedDict`` with the same values in the same order. Keys that match no
        prefix are kept unchanged. A ``_metadata`` attribute, which PyTorch attaches to
        state dicts, is carried over when present.
    """
    if prefix_map is None:
        # Sequential indices of the parametrised layers in the checkpoint -> names in `net`.
        prefix_map = {
            'model.0.': 'conv1.',
            'model.3.': 'conv2.',
            'model.8.': 'fc1.',
            'model.11.': 'fc2.',
        }

    def _rename(key):
        # Anchor the match at the start of the key. A plain str.replace would also
        # rewrite the pattern wherever it appeared in the middle of another key.
        for old, new in prefix_map.items():
            if key.startswith(old):
                return new + key[len(old):]
        return key

    renamed = OrderedDict((_rename(key), value) for key, value in state_dict.items())

    # Keep PyTorch's per-module metadata, as the previous in-place rename did.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        renamed._metadata = metadata
    return renamed

In [None]:
# Translate the checkpoint's positional `model.<idx>.` keys to the layer names used by `net` below.
state_dict = clear_state_dict(model['state_dict'])

In [None]:
# CNN matching the checkpoint's architecture. Module names must match the keys produced
# by `clear_state_dict` (conv1, conv2, fc1, fc2) for load_state_dict to succeed.
# fc1's 64 * 7 * 7 input implies 1x28x28 images: 28 -> 14 -> 7 after the two 2x pools.
net = nn.Sequential()
# Stage 1: 1 -> 32 channels, spatial size halved by max-pooling.
net.add_module('conv1', nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1))
net.add_module('act1', nn.ReLU())
net.add_module('mxp1', nn.MaxPool2d(kernel_size=2, stride=2))
# Stage 2: 32 -> 64 channels, spatial size halved again.
net.add_module('conv2', nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1))
net.add_module('act2', nn.ReLU())
net.add_module('mxp2', nn.MaxPool2d(kernel_size=2, stride=2))
# Classifier head producing 10 outputs.
net.add_module('flatten', nn.Flatten())
net.add_module('fc1', nn.Linear(64 * 7 * 7, 128))
net.add_module('act3', nn.ReLU())
net.add_module('fc2', nn.Linear(128, 10))

In [None]:
# Load the renamed weights. With the default strict=True this raises if any key is
# missing or unexpected, so it also validates the rename. The bare expression displays
# the returned key-matching report.
net.load_state_dict(state_dict)

In [None]:
# Inference mode. None of these layers (Conv2d/ReLU/MaxPool2d/Flatten/Linear) behave
# differently in eval mode, but this makes the intent explicit; eval() returns the module itself.
net = net.eval()

## Initialize FashionMNIST dataset

In [None]:
# Mean and standard deviation of the training images after ToTensor (pixels in [0, 1]).
# They feed the Normalize transform below.
mean, std = compute_mean_std(
    FashionMNIST(
        images_dir,
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    workers=workers,
)

In [None]:
# Show the dataset statistics that feed the Normalize transform below.
mean, std

In [None]:
# Preprocessing handed to NetWrapper below: image -> tensor in [0, 1], then standardise
# with the training-set statistics computed above (single channel, hence 1-tuples).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(mean,), std=(std,)),
])

In [None]:
# Raw FashionMNIST train/test splits with no transform attached; preprocessing is
# presumably applied by NetWrapper via `transform` (confirm in NetWrapper).
data_train, data_test = (
    FashionMNIST(root=images_dir, train=is_train, download=True)
    for is_train in (True, False)
)

In [None]:
# Device holding the network's parameters (expected CPU: the modules were built without a device argument).
next(net.parameters()).device

In [None]:
# Device suggested by the project helper.
# NOTE(review): `device` is not passed to NetWrapper below; confirm whether the wrapper selects it itself.
device = find_device()
device

In [None]:
# Wrap the network together with the preprocessing transform (project helper).
wnet = NetWrapper(net, transform)

In [None]:
wnet.net

In [None]:
# First six modules: conv1, act1, mxp1, conv2, act2, mxp2 (everything before flatten).
wnet.net[:6]

In [None]:
# Cut-points passed as `k` to layer_V below: 3 = after mxp1, 6 = after mxp2
# (presumably activations of wnet.net[:k]; confirm in layer_V).
layer_V_n = 3
layer_U_n = 6

In [None]:
wnet.device

In [None]:
# Passed as `bs` to layer_V below (presumably the batch size).
bs = 8

In [None]:
# Activations of the training set at cut-point layer_V_n (project helper).
# NOTE(review): layer_V returns a pair; the names suggest (activations, inverse mapping). Confirm.
V_X_train, X_V_train = layer_V(data_train, wnet, k=layer_V_n, bs=bs)

In [None]:
# Same cut-point, test set.
V_X_test, X_V_test = layer_V(data_test, wnet, k=layer_V_n, bs=bs)

In [None]:
# Deeper cut-point layer_U_n (same helper, different k), training set.
U_X_train, X_U_train = layer_V(data_train, wnet, k=layer_U_n, bs=bs)

In [None]:
# Deeper cut-point, test set.
U_X_test, X_U_test = layer_V(data_test, wnet, k=layer_U_n, bs=bs)

In [None]:
# Sanity check of the four activation arrays' shapes.
V_X_train.shape, V_X_test.shape, U_X_train.shape, U_X_test.shape

In [None]:
# For every unit of layer V, the index along axis 0 with the largest activation
# (assumes axis 0 indexes training samples, consistent with passing data_train below).
arg_max = np.argmax(V_X_train, axis=0)
arg_max.shape

In [None]:
# Indices along axis 0 sorted by ascending activation, per unit.
# NOTE(review): this materialises an integer index array as large as V_X_train, while only
# one unit is used below; argsort of that single column would be far cheaper.
arg_top = np.argsort(V_X_train, axis=0)
arg_top.shape

In [None]:
# The 16 highest-activating training samples for unit [1, 9, 2] (presumably channel 1 at
# spatial position (9, 2)); argsort is ascending, so the tail holds the largest values.
show_grid(arg_top[-16:,1, 9, 2], data_train, nrow=32)

In [None]:
# np.max(V_X_train, axis=0)

In [None]:
# Top-activating sample for each unit in arg_max[1, :] (presumably channel 1).
show_grid(arg_max[1,:], data_train, nrow=14)

In [None]:
# show_grid(arg_max, data_train, nrow=14)

## Sorting vectors

In [None]:
# Group and sort the training activations (project helper). V_X_sorteds is indexed below as
# V_X_sorteds[class][rank]; presumably per-class, sorted by activation. Confirm in sort_V_X.
V_X_digits, V_X_sorteds = sort_V_X(V_X_train, data_train)

## Analyze maximum stimulus

In [None]:
# Per-class results of the FCA analysis loop below, keyed by class index.
v_Ds, u_Ds, G_v_tests, G_u_tests = {}, {}, {}, {}

In [None]:
# Rank (index into V_X_sorteds[class]) of the stimulus vector used for each class,
# one entry per class index 0..9.
ths = [
    328, 280, 320, 384, 300,  # classes 0-4
    300, 400, 200, 380, 180,  # classes 5-9
]
# Stimulus vector for class 0, copied so it is independent of V_X_sorteds.
i = 0
v = np.copy(V_X_sorteds[i][ths[i]])

In [None]:
# Run the FCA stimulus analysis once per class. Each class uses ITS OWN stimulus vector
# (the activation at rank ths[i]); previously `v` was computed once for class 0 before
# the loop, so every iteration analysed class 0.
# A fresh LayerFCA per class keeps any per-run state (e.g. `uncn`, read later) from
# carrying over between classes.
# After the loop, `i`, `v`, `layer_fca`, `G_u_test` and `uncn_reps` hold the values of the
# last class; the inspection cells below rely on that.
for i in range(len(ths)):
    v = np.copy(V_X_sorteds[i][ths[i]])
    layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
    G_v, v_D, u_D, G_u, G_v_test, G_u_test, uncn_reps = layer_fca.find_G_v_us(
        v,
        V_X_test,
        U_X_test,
        data_test,
    )
    v_Ds[i] = v_D
    u_Ds[i] = u_D
    G_v_tests[i] = G_v_test
    G_u_tests[i] = G_u_test

In [None]:
# NOTE: the cells in this section read `i`, `uncn_reps`, `G_u_test` and `layer_fca` as
# left by the LAST iteration of the loop above, i.e. they show the last class index.
uncn_reps

In [None]:
# Image of the first test sample in G_v_tests[i] (dataset items are (image, label) pairs).
data_test[G_v_tests[i][0]][0]

In [None]:
# Test samples in G_v_tests[i], as returned by find_G_v_us.
show_grid(G_v_tests[i], data_test, nrow=32)

In [None]:
# Test samples in G_u_tests[i], as returned by find_G_v_us.
show_grid(G_u_tests[i], data_test, nrow=32)

In [None]:
# Predicted class (argmax of the wrapper's output) for every test sample in G_u_test.
# NOTE(review): one forward pass per sample; presumably wnet handles no_grad/device. Confirm.
y_hs = [np.argmax(wnet(data_test[idx][0])) for idx in G_u_test]

In [None]:
# Tally of the predicted labels (project helper; presumably counts per class).
uncn_hat = layer_fca.count_ys(y_hs)

In [None]:
uncn_hat

In [None]:
# conv2 kernels for input channel 0 (weight layout: out_channels x in_channels x kH x kW) and conv2's biases.
net[3].weight[:, 0], net[3].bias

In [None]:
layer_fca.uncn

In [None]:
# Same samples with h=64, w=64 (presumably larger thumbnails).
show_grid(G_u_test, data_test, nrow=48, h=64, w=64)

In [None]:
# As above, also passing the class index as `my` (presumably to highlight that class; confirm in show_grid).
show_grid(G_u_test, data_test, nrow=48, h=64, w=64, my=i)

## Visualization of distribution

In [None]:
# Class index inspected in this section.
idx = 9

In [None]:
# Stimulus vector for class idx at the same rank ths[idx] used in the analysis loop.
visualize_slices(V_X_sorteds[idx][ths[idx]])

In [None]:
# Training inputs grouped by class (project helper); digits_train[c] is presumably the class-c inputs.
digits_train = get_digits(data_train)

In [None]:
digits_train[0]

In [None]:
# Load an example image
example_image, _ = data_test[i]
# example_image = example_image.unsqueeze(0)  # Add batch dimension

# Visualize the activations
acts = visualize_activations(wnet, digits_train[0][8], layers=[3, 6], hist=False)

In [None]:
digits_train[0][:4]

In [None]:
res = wnet(*[x for x in digits_train[0]], k=3)

In [None]:
# For each channel k, collect channel k of every sample's activations (res_k[k]) and
# intersect them with the project helper (int_k[k]).
# The channel count is read from the activations instead of being hard-coded to 32
# (conv1's width), so this keeps working if the cut-point or architecture changes.
n_channels = len(res[0])
res_k = list()
int_k = list()
for k in tqdm(range(n_channels), desc='channels'):
    r_k = [r[k] for r in res]
    res_k.append(r_k)
    int_k.append(intersect_xd(*r_k))

In [None]:
# Per-channel intersections computed above.
visualize_slices(int_k)

In [None]:
# Intersection map for channel 28.
show_activation(int_k[28])

In [None]:
# Channel 1 of acts[0] (presumably the first requested layer).
show_activation(acts[0][1])

In [None]:
# Positions where channel 1 is >= the hand-picked 1.2 threshold, plus the values there.
indices1 = np.where(acts[0][1] >= 1.2)
indices1, acts[0][1][indices1]

In [None]:
show_activation(acts[0][8])

In [None]:
# Positions where channel 8 is >= 1.4.
indices2 = np.where(acts[0][8] >= 1.4)

In [None]:
acts[0][8][indices2]

In [None]:
acts[0][1].shape

In [None]:
# Shape of channel-1 values selected by `indices1` (the >= 1.2 mask computed on this same map).
# `idcs` was never defined in this notebook and raised NameError on a fresh kernel.
acts[0][1][indices1].shape

In [None]:
wnet.net[3]

In [None]:
# Visualize weights of the first convolutional layer
visualize_weights(wnet.net[0], num_filters=32)

# Visualize weights of the second convolutional layer
visualize_weights(wnet.net[3], num_filters=64)