In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Install libraries

In [None]:
from IPython.display import display, Javascript

# Inject a small script into the notebook front end. It polls once per second
# until the Google Translate widget API is present, then replaces
# google.translate.TranslateElement with a no-op, removes the
# 'google_translate_element' node if it exists, and stops polling.
display(Javascript('''
(function() {
    var interval = setInterval(function() {
        if (typeof google !== 'undefined' && google.translate && google.translate.TranslateElement) {
            clearInterval(interval);
            google.translate.TranslateElement = function() {};
            document.getElementById('google_translate_element')?.remove();
        }
    }, 1000);
})();
'''))

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash 
! pip install -U -r requirements.txt
```

```bash
! pip install -U numpy
! pip install -U scikit-learn
```

## Update repository

In [None]:
! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Put the parent directory of the working directory on sys.path. The
# membership check keeps the cell idempotent when it is re-run.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
del module_path

## Organize imports

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
import json

In [None]:
from collections import OrderedDict

In [None]:
from functools import reduce

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

In [None]:
from tqdm import tqdm

In [None]:
import plotly.express as px

In [None]:
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.discriminant_analysis import (
    LinearDiscriminantAnalysis, 
    QuadraticDiscriminantAnalysis
)
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier 
from sklearn.datasets import (
    load_iris,
    load_wine
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import (
    MaxAbsScaler,
    MinMaxScaler,
    StandardScaler
)
from sklearn.metrics import (
    classification_report,
    confusion_matrix
)

In [None]:
from scipy import stats
from scipy.interpolate import interp1d

In [None]:
import torch

In [None]:
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.utils import make_grid
from torchvision.datasets import FashionMNIST

In [None]:
from captum.attr import (
    IntegratedGradients, 
    LayerIntegratedGradients,
    NeuronGradient,
    NeuronIntegratedGradients,
    NeuronGuidedBackprop,
    NeuronDeepLift,
    NeuronDeepLiftShap,
    NeuronGradientShap,
)
from captum.attr import visualization as viz

#### Number of CPU cores

In [None]:
workers = multiprocessing.cpu_count()
workers

In [None]:
SEED = 2024

## Initialize Path

In [None]:
# Root data directory and every file/folder location derived from it.
PATH = Path('data')
model_dir = PATH / 'models'
# Checkpoint read by the "Load the model" section.
model_path = model_dir / 'simple_cnn_fasion_mnist_model.ckpt'
# Only config_dir and images_dir are created here; model_dir is expected to
# exist already because the checkpoint is read from it.
config_dir = PATH / 'config'
config_dir.mkdir(exist_ok=True, parents=True)
config_1_layer_path = config_dir / 'neurons_1_layer.json'
config_cnn_layer_1 = config_dir / 'neurons_cnn_1_layer.json'
# images_dir is used as the FashionMNIST download/cache root further below.
images_dir = PATH / 'images'
images_dir.mkdir(exist_ok=True, parents=True)
pumpkin_path = PATH / 'Pumpkin_Seeds_Dataset.xlsx'

## Model wrapper

In [None]:
def find_device():
    """Pick the torch device to run on.

    Returns:
        torch.device: ``cuda`` when CUDA is available, otherwise ``mps`` when
        the MPS backend reports itself available, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

In [None]:
class Device(object):
    """Context manager that holds ``x`` on ``device`` inside a ``with`` block.

    On entry ``self.x`` is replaced by ``x.to(device)``; on exit it is moved
    back to the CPU. ``x`` can be anything exposing a ``.to(device)`` method.

    Args:
        x: Object to move (must provide ``.to``).
        device: Target device used while the context is active.
    """

    def __init__(self, x, device):
        self.x = x
        self.device = device
        self.cpu = torch.device('cpu')

    def __enter__(self):
        self.x = self.x.to(self.device)

        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        # The context-manager protocol calls __exit__ with three exception
        # arguments; they default to None so a manual call without arguments
        # keeps working.
        self.x = self.x.to(self.cpu)

        # Returning a falsy value lets exceptions raised inside the block
        # propagate instead of being silently swallowed.
        return False

In [None]:
class NetWrapper(object):
    """Inference wrapper around an indexable network such as ``nn.Sequential``.

    The wrapped network is put in eval mode and moved to the chosen device.
    Calling the wrapper applies ``transform`` to every input, stacks the
    results into a batch, runs either a prefix of the network or the whole
    network, and returns the output as a NumPy array on the CPU.

    Args:
        net: Network to wrap; it must support slicing and ``len``.
        transform: Callable applied to each raw input before stacking.
        device: Device to run on; when falsy, ``find_device()`` chooses one.
    """

    def __init__(self, net, transform, device=None):
        self.cpu = torch.device('cpu')
        self.transform = transform
        self._net = net.eval()
        self._device = device or find_device()
        self._net.to(self._device)

    @property
    def net(self):
        """The wrapped network."""
        return self._net

    @property
    def device(self):
        """The device the network currently lives on."""
        return self._device

    def __getitem__(self, i):
        # Delegates indexing/slicing to the wrapped network.
        return self.net[i]

    def __len__(self):
        return len(self.net)

    @torch.inference_mode()
    def forward(self, *xs, k=6):
        """Run the inputs through the first ``k`` layers (or the whole net).

        Args:
            *xs: Raw inputs; each one is passed through ``self.transform``.
            k: Number of leading layers to apply. Any falsy value (``None``
                or ``0``) runs the complete network instead.

        Returns:
            numpy.ndarray: Output with the batch axis first.
        """
        batch = torch.stack([self.transform(x) for x in xs], dim=0)
        batch = batch.to(self.device)
        if k:
            out = self[:k](batch)
        else:
            out = self.net(batch)

        return out.to(self.cpu).detach().numpy()

    def to(self, device=None):
        """Move the network to ``device`` (or to ``find_device()`` if falsy)."""
        self._device = device or find_device()
        self.net.to(self._device)

    def __call__(self, *xs, k=6):
        return self.forward(*xs, k=k)

## Load the model

In [None]:
# Load the checkpoint onto the CPU regardless of the device it was saved from;
# the cells below read its 'state_dict' entry.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source. Recent torch releases
# change the default of `weights_only`; confirm this checkpoint still loads.
model = torch.load(model_path, map_location='cpu')

In [None]:
def clear_state_dict(state_dict):
    """Rename checkpoint keys from ``model.<n>.`` to the local layer names.

    The substitutions are applied one after another to every key, in place:
    ``model.0.`` -> ``conv1.``, ``model.3.`` -> ``conv2.``,
    ``model.8.`` -> ``fc1.``, ``model.11.`` -> ``fc2.``.

    Args:
        state_dict: Mapping of parameter names to values; it is mutated.

    Returns:
        The same mapping object, with renamed keys.
    """
    renames = (
        ('model.0.', 'conv1.'),
        ('model.3.', 'conv2.'),
        ('model.8.', 'fc1.'),
        ('model.11.', 'fc2.'),
    )
    for old_prefix, new_prefix in renames:
        # Snapshot the keys first: the dict is modified while we walk it.
        for key in list(state_dict.keys()):
            state_dict[key.replace(old_prefix, new_prefix)] = state_dict.pop(key)

    return state_dict

In [None]:
state_dict = clear_state_dict(model['state_dict'])

In [None]:
# Input geometry (single-channel 28x28 images) and model size constants.
# NOTE(review): the network defined below hard-codes its own sizes and none of
# these names is referenced by the cells visible here - confirm whether they
# are still needed (hidden_size = 16 in particular does not match fc1).
channels = 1
width = 28
height = 28
hidden_size = 16
num_classes = 10
in_features = channels * width * height

In [None]:
# Sequential CNN whose layer names match the keys produced by
# clear_state_dict (conv1 / conv2 / fc1 / fc2). The positional index of each
# layer matters because NetWrapper slices the network as net[:k].
# Shapes in the comments assume a 1x28x28 input, which is what fc1's
# 64 * 7 * 7 input size implies.
net = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)),  # index 0: 32 x 28 x 28
    ('act1', nn.ReLU()),  # index 1
    ('mxp1', nn.MaxPool2d(kernel_size=2, stride=2)),  # index 2: 32 x 14 x 14
    ('conv2', nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)),  # index 3: 64 x 14 x 14
    ('act2', nn.ReLU()),  # index 4
    ('mxp2', nn.MaxPool2d(kernel_size=2, stride=2)),  # index 5: 64 x 7 x 7
    ('flatten', nn.Flatten()),  # index 6
    ('fc1', nn.Linear(64 * 7 * 7, 128)),  # index 7
    ('act3', nn.ReLU()),  # index 8
    ('fc2', nn.Linear(128, 10)),  # index 9: one output per class
]))

In [None]:
net.load_state_dict(state_dict)

In [None]:
net = net.eval()

## Helper functions

In [None]:
def argmax_kd(v):
    """Return the multi-dimensional index of the largest element of ``v``.

    Args:
        v: NumPy array of any rank.

    Returns:
        tuple: One integer per axis of ``v``.
    """
    flat_position = np.argmax(v)
    return np.unravel_index(flat_position, v.shape)

In [None]:
def argmax_kd_val(v):
    """Return both the location and the value of the maximum of ``v``.

    Args:
        v: NumPy array of any rank.

    Returns:
        tuple: ``(index_tuple, max_value)`` where ``index_tuple`` comes from
        ``argmax_kd``.
    """
    position = argmax_kd(v)

    return position, v[position]

In [None]:
def to_numpy(v):
    """Convert a torch tensor to a NumPy array; pass anything else through.

    Args:
        v: A ``torch.Tensor`` or any other object.

    Returns:
        A NumPy array (moved to CPU and detached) for tensors, otherwise
        ``v`` itself, unchanged.
    """
    if not isinstance(v, torch.Tensor):
        return v
    return v.to('cpu').detach().numpy()

In [None]:
def intersect(*arrs):
    """Intersect any number of 1-D arrays pairwise, left to right.

    Args:
        *arrs: One or more array-likes. A single argument is returned as is;
            calling with no arguments raises ``TypeError`` (from ``reduce``).

    Returns:
        The values common to all inputs, as produced by ``np.intersect1d``.
    """
    common = reduce(np.intersect1d, arrs)
    return common

In [None]:
def intersect_xd(*arrs):
    """Element-wise minimum across several equally shaped arrays.

    Args:
        *arrs: Arrays of identical shape.

    Returns:
        numpy.ndarray: Array of that shape holding, at each position, the
        smallest value found across the inputs.
    """
    stacked = np.asarray(arrs)
    return np.minimum.reduce(stacked, axis=0)

In [None]:
def le(v1, v2):
    """Check whether ``v1 <= v2`` holds at every position.

    Both arguments may be torch tensors or NumPy-compatible values; tensors
    are converted with ``to_numpy`` first.

    Returns:
        numpy.bool_: True when every element of ``v1`` is <= its counterpart.
    """
    lhs, rhs = to_numpy(v1), to_numpy(v2)
    return np.all(lhs <= rhs)

In [None]:
def layer_V(data, net, k=5, bs=1):
    """Collect layer outputs for every sample of ``data``, batch by batch.

    Args:
        data: Indexable dataset with ``len``; ``data[j][0]`` is the raw input.
        net: Callable invoked as ``net(*inputs, k=k)`` that returns an array
            with one row per input.
        k: Forwarded to ``net``.
        bs: Batch size.

    Returns:
        tuple: ``(V, X)`` where ``V`` stacks all batch outputs along the first
        axis and ``X`` is the list of raw inputs in the same order.

    Raises:
        ValueError: From ``np.vstack`` when ``data`` is empty.
    """
    total = len(data)
    V = list()
    X = list()
    with tqdm(list(range(0, total, bs))) as ds:
        for bi in ds:
            # Clamp the upper bound so the final, possibly shorter, batch does
            # not index past the end of the dataset.
            xs = [data[j][0] for j in range(bi, min(bi + bs, total))]
            vs = net(*xs, k=k)
            V.append(vs)
            X.extend(xs)

    return np.vstack(V), X

In [None]:
def loop_maxes(V, func, *args, **kwargs):
    """Call ``func(position, item, *args, **kwargs)`` for every item of ``V``.

    Iteration is wrapped in a tqdm progress bar; return values of ``func``
    are ignored.
    """
    with tqdm(V) as progress:
        for position, item in enumerate(progress):
            func(position, item, *args, **kwargs)

In [None]:
def select_top(V, idx, thresh):
    """Return the positions of the rows of ``V`` whose ``idx`` entry reaches ``thresh``.

    Args:
        V: Iterable of indexable rows.
        idx: Index applied to each row.
        thresh: Inclusive lower bound for ``row[idx]``.

    Returns:
        list: Row positions in ascending order.
    """
    selected = list()

    def keep_if_reached(position, row):
        if row[idx] >= thresh:
            selected.append(position)

    loop_maxes(V, keep_if_reached)

    return selected

In [None]:
def find_v_x(V, mrng, idx):
    """Among the rows selected by ``mrng``, find the one smallest at ``idx``.

    Args:
        V: Array-like of rows.
        mrng: Sequence of row positions to consider.
        idx: Position within a row that is compared.

    Returns:
        tuple: ``(v_x, x_id)`` - the winning row taken from ``V`` and its
        position in ``V``.
    """
    candidates = np.array(V)[mrng]
    winner = np.argmin(candidates, axis=0)[idx]
    x_id = mrng[winner]

    return V[x_id], x_id

In [None]:
def find_v_A(V, mrng):
    """Element-wise minimum over the rows of ``V`` selected by ``mrng``.

    Args:
        V: Array-like whose first axis enumerates samples.
        mrng: Integer positions of the samples to combine.

    Returns:
        numpy.ndarray: Array shaped like one row of ``V``.
    """
    chosen = np.array(V)[mrng]
    return np.min(chosen, axis=0)

In [None]:
def find_G_x(V, v_x):
    """Find every row of ``V`` that dominates ``v_x`` element-wise.

    Args:
        V: Iterable of rows.
        v_x: Reference vector; a row matches when ``v_x <= row`` everywhere.

    Returns:
        numpy.ndarray: Positions of the matching rows, in ascending order.
    """
    matches = list()
    with tqdm(V) as progress:
        for position, row in enumerate(progress):
            if np.all(v_x <= row):
                matches.append(position)
        G_x = np.array(matches)

    return G_x

In [None]:
def show_img(ds, idx):
    """Display the image stored at ``ds[idx][0]`` with ``plt.imshow``."""
    sample = ds[idx]
    plt.imshow(sample[0])

In [None]:
def show(imgs, h=12, w=12):
    """Show one or several image tensors side by side without axis ticks.

    Args:
        imgs: A single tensor or a list of tensors accepted by
            ``torchvision.transforms.functional.to_pil_image``.
        h: Figure height.
        w: Figure width.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    fig, axs = plt.subplots(ncols=len(images), figsize=(w, h), squeeze=False)
    for column, tensor in enumerate(images):
        pil_image = F.to_pil_image(tensor.detach())
        axis = axs[0, column]
        axis.imshow(np.asarray(pil_image))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
def show_grid(G_A, data, nrow=8, h=12, w=12, my=None):
    """Render the dataset images referenced by ``G_A`` as one image grid.

    Args:
        G_A: NumPy array of sample positions (any shape; it is flattened).
        data: Dataset where ``data[i]`` is ``(image, label)``.
        nrow: Number of images per grid row, forwarded to ``make_grid``.
        h: Figure height.
        w: Figure width.
        my: When given, samples whose label equals ``my`` are left out.
    """
    to_tensor = transforms.ToTensor()
    tiles = list()
    for i in G_A.ravel():
        # Labels are only inspected when a class filter was requested.
        if my is not None and not (data[i][1] != my):
            continue
        tiles.append(to_tensor(data[i][0]))
    show(make_grid(tiles, nrow=nrow), h=h, w=w)

#### Data processing

In [None]:
def layer_hist(X_y, V_X, y=None):
    """Select the activation rows that belong to one class.

    Args:
        X_y: Iterable of ``(input, label)`` pairs aligned with ``V_X``.
        V_X: Activation rows, one per sample.
        y: Label to keep. With ``None`` the rows are returned untouched.

    Returns:
        ``V_X`` itself when ``y`` is None, otherwise a new NumPy array holding
        only the rows whose label equals ``y``.
    """
    if y is None:
        return V_X

    with tqdm(V_X) as progress:
        kept = [row for sample, row in zip(X_y, progress) if sample[1] == y]
        V_X_y = np.array(kept)

    return V_X_y

In [None]:
def get_digits(data):
    """Group the inputs of a dataset by their label.

    Args:
        data: Iterable of ``(input, label)`` pairs.

    Returns:
        dict: Maps each label to the list of its inputs, in dataset order.
    """
    digits = dict()
    with tqdm(data) as progress:
        for sample, label in progress:
            digits.setdefault(label, list()).append(sample)

    return digits

#### Feature visualizations

In [None]:
def visualize_slices(activations, filters=32):
    """Plot the first ``filters`` activation maps, 16 per figure (4 x 4).

    Args:
        activations: Indexable collection of 2-D maps (tensors or arrays).
        filters: Number of maps to draw. It no longer has to be a multiple
            of 16: unused cells of the last figure are simply left blank.
    """
    per_page = 16
    for k in range(0, filters, per_page):
        fig, axes = plt.subplots(4, 4, figsize=(12, 12))
        for i, ax in enumerate(axes.flat):
            ax.axis('off')
            # The former `i < 16` guard was always true; bound the map index
            # instead so a partial last page cannot raise IndexError.
            if k + i >= filters:
                continue
            activation = activations[k + i]
            if isinstance(activation, torch.Tensor):
                activation = activation.detach().cpu()
            ax.imshow(activation, cmap='viridis')
            ax.set_title(f'Conv1 - Filter {k + i}')
        plt.show()

In [None]:
def _activation_maps(output):
    """Detach a layer output and drop a leading singleton batch axis."""
    if isinstance(output, torch.Tensor):
        output = output.detach()
    elif not isinstance(output, np.ndarray):
        output = np.asarray(output)
    # A (1, C, H, W) output becomes (C, H, W) so maps are indexed by channel.
    if output.ndim == 4 and output.shape[0] == 1:
        output = output[0]
    return output


def _map_as_array(activation):
    """Return a single activation map as a NumPy array on the CPU."""
    if isinstance(activation, torch.Tensor):
        return activation.cpu().numpy()
    return np.asarray(activation)


def _plot_activation_pages(maps, label, hist=False, colorbar=False):
    """Draw every map of one layer, 16 per figure, as images or 3-D bars."""
    count = len(maps)
    for k in range(0, count, 16):
        fig, axes = plt.subplots(4, 4, figsize=(12, 12))
        for i, ax in enumerate(axes.flat):
            if k + i >= count:
                ax.axis('off')
                continue
            activation = _map_as_array(maps[k + i])
            if hist:
                # Replace the 2-D axes instead of stacking a 3-D one on top.
                ax.remove()
                ax = fig.add_subplot(4, 4, i + 1, projection='3d')
                x, y = np.meshgrid(np.arange(activation.shape[1]), np.arange(activation.shape[0]))
                ax.bar3d(x.ravel(), y.ravel(), np.zeros_like(x.ravel()), 1, 1, activation.ravel(), shade=True)
            else:
                im = ax.imshow(activation, cmap='viridis')
                ax.set_title(f'{label} - Filter {k + i}')
                ax.axis('off')
                if colorbar:
                    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        plt.show()


# Function to visualize the activations
def visualize_activations(model, image, layers=(2, 4), hist=False):
    """Plot the activation maps produced at two depths of ``model``.

    Args:
        model: Callable invoked as ``model(image, k=depth)``. It may return a
            torch tensor or a NumPy array, with or without a leading batch
            axis of size one.
        image: Input forwarded unchanged to ``model``.
        layers: Depths to probe; at least two entries are required. The first
            is titled "Conv1", the second "Conv2".
        hist: Draw the first layer as 3-D bar charts instead of images.

    Returns:
        list: One entry per element of ``layers``, each indexed by channel
        (any singleton batch axis removed), in the type the model returned.
    """
    # Pass the image through the network
    activations = list()
    with torch.no_grad():
        for k in layers:
            activations.append(_activation_maps(model(image, k=k)))

    # Plot the activations
    _plot_activation_pages(activations[0], 'Conv1', hist=hist)
    print('======================')
    _plot_activation_pages(activations[1], 'Conv2', colorbar=True)

    return activations

In [None]:
def show_activation(activation, layer_name='', filter_index=0):
    """Draw one 2-D activation map as an annotated heat map.

    Args:
        activation: 2-D array; every cell is annotated with its value
            formatted to two decimals.
        layer_name: Text used in the figure title.
        filter_index: Filter number used in the figure title.
    """
    # Plot the activation as a grid of numbers
    fig, ax = plt.subplots(figsize=(16, 16))
    cax = ax.matshow(activation, cmap='viridis')

    # Customize the ticks to show each step
    ax.set_xticks(np.arange(activation.shape[1]))
    ax.set_yticks(np.arange(activation.shape[0]))

    # Show grid lines
    ax.grid(color='black', linestyle='-', linewidth=1, which='both')
    ax.xaxis.set_ticks_position('bottom')

    # Annotate the grid with the activation values
    # (ndenumerate yields (row, col); text() expects x=col, y=row).
    for (i, j), val in np.ndenumerate(activation):
        ax.text(j, i, f'{val:.2f}', ha='center', va='center', color='white')

    fig.colorbar(cax)
    plt.title(f'Activation of {layer_name} - Filter {filter_index}')
    plt.show()

In [None]:
def visualize_weights(layer, num_filters=32):
    """Plot the kernels of a convolutional layer, 16 per figure (4 x 4).

    Only the slice for input channel 0 of each filter is shown.

    Args:
        layer: Module exposing a 4-D ``weight`` tensor
            (filters, input channels, height, width).
        num_filters: Number of filters to draw; capped at the number the
            layer actually has.
    """
    weights = layer.weight.data.cpu().numpy()
    total = min(num_filters, weights.shape[0])
    for k in range(0, total, 16):
        fig, axes = plt.subplots(4, 4, figsize=(12, 12))
        for i, ax in enumerate(axes.flat):
            ax.axis('off')
            # Offset by the page start; indexing with `i` alone repeated
            # filters 0..15 on every page.
            filter_index = k + i
            if filter_index >= total:
                continue
            img = weights[filter_index, 0, :, :]  # first input channel only
            im = ax.imshow(img, cmap='viridis')
            ax.set_title(f'Filter {filter_index}')
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        plt.show()

## Feature FCA analysis methods

In [None]:
def find_V_X_digits(V_X):
    """Split activation rows by class label 0..9 of the training set.

    Relies on the notebook-level ``data_train`` for the labels.

    Args:
        V_X: Activation rows aligned with ``data_train``.

    Returns:
        list: Ten arrays; entry ``k`` holds the rows whose label is ``k``.
    """
    per_class = list()
    for label in range(10):
        per_class.append(layer_hist(data_train, V_X, y=label))
    return per_class

In [None]:
def sort_V(*V_Xs):
    """Sort each activation array independently along its first axis.

    Args:
        *V_Xs: Arrays whose first axis enumerates samples.

    Returns:
        list: One sorted copy per input, in the same order. Every position is
        sorted on its own, so rows no longer correspond to single samples.
    """
    with tqdm(V_Xs) as pV_Xs:
        # Iterate the tqdm wrapper (not the raw tuple) so the bar advances.
        V_X_sr = [np.sort(V_X_d, axis=0) for V_X_d in pV_Xs]

    return V_X_sr

In [None]:
def sort_V_X(V_X):
    """Split activations per class and sort each class along the sample axis.

    Args:
        V_X: Activation rows aligned with the training set.

    Returns:
        tuple: ``(per_class, per_class_sorted)`` as produced by
        ``find_V_X_digits`` and ``sort_V`` respectively.
    """
    per_class = find_V_X_digits(V_X)

    return per_class, sort_V(*per_class)

In [None]:
def features_hist(*n_Fs, V=np.zeros((1, 16))):
    """Draw one histogram per requested feature, stacked vertically.

    Args:
        *n_Fs: Feature indices; each is applied to every row of ``V``.
        V: Iterable of activation rows. The default array is never modified.

    Nothing is drawn when no feature index is given.
    """
    rows = len(n_Fs)
    if rows == 0:
        return
    vs_ls = list()
    with tqdm(n_Fs) as pn_Fs:
        for n_F in pn_Fs:
            vs_ls.append([v[n_F] for v in V])
    # squeeze=False keeps `axs` two-dimensional, so a single requested feature
    # no longer fails when the axes are indexed.
    fig, axs = plt.subplots(
        rows, 1,
        sharey=True,
        tight_layout=True,
        figsize=(8 * rows, 32),
        squeeze=False,
    )
    for r, vs_h in enumerate(vs_ls):
        ax = axs[r, 0]
        ax.hist(vs_h)
        ax.set_xlim(0, 32)  # Set the X-axis limit if you want a specific range
        ax.set_xticks(np.arange(0, 32, 0.5))  # Ensure the ticks match the new range

In [None]:
class LayerFCA(object):
    """Relates sample sets defined on one layer (V) to another layer (U).

    Two operations from the notebook are used throughout: ``find_G_x(M, v)``
    returns the positions of the rows of ``M`` that are >= ``v`` everywhere,
    and ``find_v_A(M, idxs)`` returns the element-wise minimum of the rows
    ``idxs`` of ``M``.

    Args:
        V_X: Activations of the first layer, one row per sample.
        U_X: Activations of the second layer, aligned with ``V_X``.
        data: Dataset aligned with both; ``data[i][1]`` is the label.
    """

    def __init__(self, V_X, U_X, data):
        self.V_X = V_X
        self.U_X = U_X
        self.data = data
        # Accumulated across fca_v calls: one entry per processed neuron.
        self.G_As = list()
        self.v_As = list()
        self.D = None
        self.v_D = None
        # Legacy attribute names, kept in case other code reads them.
        self.U_D = None
        self.G_U_D = None
        # Names actually written by fca_u and _report_u.
        self.u_D = None
        self.G_u_D = None
        self.uncn = None
        # Module-level helpers exposed on the instance for convenience.
        self.find_G_x = find_G_x
        self.find_v_A = find_v_A

    def fca_v(self, ns, ths):
        """Build one sample set per (neuron, threshold) pair and intersect them.

        Args:
            ns: Neuron indices into a row of ``V_X``.
            ths: Thresholds, paired with ``ns``.

        Returns:
            tuple: ``(D, v_D)`` - positions common to all accumulated sample
            sets, and the element-wise maximum of the accumulated minimum
            vectors (``None`` when nothing has been accumulated yet).
        """
        for n_A, th_A in zip(ns, ths):
            G_A_v_A = select_top(self.V_X, n_A, th_A)
            v_A = find_v_A(self.V_X, G_A_v_A)
            self.v_As.append(v_A)
            G_A = find_G_x(self.V_X, v_A)
            self.G_As.append(G_A)
        self.D = intersect(*self.G_As) if self.G_As else []
        # Reducing an empty list raises, so mirror the guard used for D.
        self.v_D = np.maximum.reduce(self.v_As) if self.v_As else None

        return self.D, self.v_D

    def fca_u(self, ns, ths):
        """Map the sample set found on V to the matching sample set on U.

        Returns:
            numpy.ndarray: Positions of the rows of ``U_X`` that are >= the
            element-wise minimum of ``U_X`` over ``D``.
        """
        D, _ = self.fca_v(ns, ths)
        if len(D) > 0:
            # Test the length, not np.any: a set holding only sample 0 is not
            # empty even though all of its values are zero.
            self.u_D = find_v_A(self.U_X, D)
        else:
            # Zero vector shaped like one row of U_X (was hard-coded to 16).
            self.u_D = np.zeros(np.asarray(self.U_X).shape[1:], dtype=float)
        self.G_u_D = find_G_x(self.U_X, self.u_D)

        return self.G_u_D

    @staticmethod
    def count_ys(ys):
        """Return a 2-row array: the distinct labels and their counts."""
        un, cn = np.unique(ys, return_counts=True)
        uncn = np.array([un, cn])

        return uncn

    def _report_u(self, G_u_D, data=None):
        """Count the labels of the samples listed in ``G_u_D``.

        The result is cached on ``self.uncn`` only when the instance's own
        dataset is used (``data`` is None).
        """
        data_ls = self.data if data is None else data
        ys = np.array([data_ls[idx][1] for idx in G_u_D])
        uncn = self.count_ys(ys)
        if data is None:
            self.uncn = uncn

        return uncn

    def report(self, G_u_D, data):
        """Count the labels of ``G_u_D`` looked up in an explicit dataset."""
        return self._report_u(G_u_D, data=data)

    def fca_u_arr(self, ns_arr):
        """Run ``fca_u`` on ``(neuron, threshold)`` pairs and record the label counts.

        Args:
            ns_arr: Iterable of pairs; element 0 is the neuron index and
                element 1 the threshold.

        Returns:
            numpy.ndarray: The sample set computed by ``fca_u``.
        """
        # Read the pairs from the argument instead of notebook globals.
        pairs = list(ns_arr)
        ns = [nr[0] for nr in pairs]
        ts = [nr[1] for nr in pairs]
        G_U_D = self.fca_u(ns, ts)
        self._report_u(G_U_D)

        return self.G_u_D

    @staticmethod
    def G_U(U_X, u_D):
        """Positions of the rows of ``U_X`` that are >= ``u_D`` everywhere."""
        return find_G_x(U_X, u_D)

    def find_u_G_u(self, v):
        """Go from a V-layer vector to the U-layer vector and its sample set.

        Returns:
            tuple: ``(G_v, u_D, G_u)`` - samples dominating ``v`` on V, the
            element-wise minimum of U over those samples, and the samples
            dominating that minimum on U. Label counts of ``G_u`` are stored
            on ``self.uncn``.
        """
        G_v = find_G_x(self.V_X, v)
        u_D = find_v_A(self.U_X, G_v)
        G_u = find_G_x(self.U_X, u_D)
        self._report_u(G_u)

        return G_v, u_D, G_u

    def find_G_u(self, v, U, X):
        """Apply the U-layer vector derived from ``v`` to another array ``U``.

        ``X`` is accepted for interface compatibility but is not used.
        """
        G_v, u_D, G_u = self.find_u_G_u(v)
        G_rest = find_G_x(U, u_D)

        return G_rest

#### Data processing

In [None]:
def compute_mean_std(dataset):
    """Estimate per-channel mean and standard deviation of an image dataset.

    Each image contributes its own per-channel mean and standard deviation;
    these are averaged over the dataset. The returned ``std`` is therefore the
    mean of per-image deviations, not the deviation of all pixels pooled.

    Args:
        dataset: Dataset yielding ``(image_tensor, label)`` pairs.

    Returns:
        tuple: ``(mean, std)``, one value per channel.
    """
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=workers)
    mean_total = 0.
    std_total = 0.
    for batch, _ in loader:
        # Flatten the spatial axes: (batch, channels, pixels). The last batch
        # may be smaller, so its size is read from the tensor itself.
        per_pixel = batch.view(batch.size(0), batch.size(1), -1)
        mean_total += per_pixel.mean(2).sum(0)
        std_total += per_pixel.std(2).sum(0)

    sample_count = len(loader.dataset)
    mean_total /= sample_count
    std_total /= sample_count
    return mean_total, std_total

## Initialize FashionMNIST dataset

In [None]:
# Per-channel normalisation statistics of the FashionMNIST training split,
# computed on un-normalised tensors (ToTensor only). The dataset is downloaded
# into images_dir on first use.
mean, std = compute_mean_std(
    FashionMNIST(
        images_dir, 
        train=True, 
        download=True, 
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
              ]
            )
        )
    )

In [None]:
mean, std

In [None]:
# Preprocessing that NetWrapper applies to every raw sample: convert to a
# tensor, then normalise with the statistics computed above.
transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((mean,), (std,)),
            ]
)

In [None]:
# Train/test splits loaded WITHOUT a transform: samples stay raw images and
# are converted on the fly by the `transform` passed to NetWrapper.
data_train = FashionMNIST(images_dir, train=True, download=True)
data_test = FashionMNIST(images_dir, train=False, download=True)

In [None]:
next(net.parameters()).device

In [None]:
device = find_device()
device

In [None]:
wnet = NetWrapper(net, transform)

In [None]:
wnet.net

In [None]:
wnet.net[:6]

In [None]:
# Prefix lengths passed as `k` to the wrapper, which evaluates net[:k]:
layer_V_n = 3  # net[:3] = conv1 -> act1 -> mxp1
layer_U_n = 6  # net[:6] = ... -> conv2 -> act2 -> mxp2

In [None]:
wnet.device

In [None]:
# Smoke-test the wrapper on two mini-batches of raw training images. The
# images come from data_train: V_X_train is only computed further below, so
# referring to it here fails on a fresh kernel.
r1 = wnet(data_train[0][0], data_train[1][0], k=layer_V_n)
r2 = wnet(data_train[2][0], data_train[3][0], k=layer_V_n)

In [None]:
np.vstack([r1, r2]).shape

In [None]:
V_X_train, X_V_train = layer_V(data_train, wnet, k=layer_V_n, bs=8)

In [None]:
V_X_test, X_V_test = layer_V(data_test, wnet, k=layer_V_n, bs=8)

In [None]:
U_X_train, X_U_train = layer_V(data_train, wnet, k=layer_U_n, bs=8)

In [None]:
U_X_test, X_U_test = layer_V(data_test, wnet, k=layer_U_n, bs=8)

In [None]:
V_X_train.shape, V_X_test.shape, U_X_train.shape, U_X_test.shape

In [None]:
arg_max = np.argmax(V_X_train, axis=0)
arg_max.shape

In [None]:
arg_top = np.argsort(V_X_train, axis=0)
arg_top.shape

In [None]:
show_grid(arg_top[-16:,1, 9, 2], data_train, nrow=32)

In [None]:
# np.max(V_X_train, axis=0)

In [None]:
show_grid(arg_max[1,:], data_train, nrow=14)

In [None]:
# show_grid(arg_max, data_train, nrow=14)

## Sorting vectors

In [None]:
V_X_digits, V_X_sorteds = sort_V_X(V_X_train)

## Analyze maximum stimulus

In [None]:
u_Ds = dict()
G_v_tests = dict()
G_u_tests = dict()

In [None]:
# Class currently under inspection (overwritten by the loop in the next cell).
i = 0
# One row position per class, used to pick a row from that class's sorted
# activations (V_X_sorteds[class][ths[class]]).
ths = [
    328, #0
    280, #1
    320, #2
    384, #3
    300, #4
    300, #5
    400, #6
    200, #7
    380, #8
    180  #9
]
# Copy so later edits to `v` cannot alter the sorted activations.
v = np.copy(V_X_sorteds[i][ths[i]])

In [None]:
# For every class: take its threshold vector `v`, derive on the training set
# the U-layer vector u_D, then collect the test samples that dominate `v` on
# the V layer and u_D on the U layer.
# NOTE: `i`, `v` and `layer_fca` keep their last-iteration values (class 9)
# after the loop, and the cells that follow rely on that.
for i in range(10):
    v = np.copy(V_X_sorteds[i][ths[i]])
    layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
    G_v, u_D, G_u = layer_fca.find_u_G_u(v)
    G_v_test = layer_fca.find_G_x(V_X_test, v)
    G_u_test = layer_fca.find_G_x(U_X_test, u_D)
    u_Ds[i] = u_D
    G_v_tests[i] = G_v_test
    G_u_tests[i] = G_u_test

In [None]:
uncn_test = layer_fca.report(G_u_tests[i], data_test)
uncn_test, np.round(uncn_test[1] / np.sum(uncn_test[1]), decimals=4)

In [None]:
show_grid(G_v_tests[i], data_test, nrow=32)

In [None]:
show_grid(G_u_tests[i], data_test, nrow=32)

In [None]:
layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
G_v, u_D, G_u = layer_fca.find_u_G_u(v)
G_v_test = layer_fca.find_G_x(V_X_test, v)
G_u_test = layer_fca.find_G_x(U_X_test, u_D)
u_Ds[i] = u_D

In [None]:
G_u.shape, G_u_test.shape

In [None]:
[G_v.shape, G_u.shape], np.max(u_D)

In [None]:
uncn_test = layer_fca.report(G_u_test, data_test)

In [None]:
uncn_test, np.round(uncn_test[1] / np.sum(uncn_test[1]), decimals=4)

In [None]:
layer_fca.uncn, V_X_sorteds[i].shape[0]

In [None]:
show_grid(G_v_tests[i], data_test, nrow=32)

In [None]:
show_grid(G_u_tests[i], data_test, nrow=32)

In [None]:
# Predicted class of every selected test sample. k=None runs the complete
# network; the wrapper's default (k=6) stops after the second pooling layer,
# whose argmax is a feature position rather than a class.
y_hs = [np.argmax(wnet(data_test[idx][0], k=None)) for idx in G_u_test]

In [None]:
uncn_hat = layer_fca.count_ys(y_hs)

In [None]:
uncn_hat

In [None]:
net[3].weight[:, 0], net[3].bias

In [None]:
layer_fca.uncn

In [None]:
show_grid(G_u_test, data_test, nrow=48, h=64, w=64)

In [None]:
show_grid(G_u_test, data_test, nrow=48, h=64, w=64, my=i)

## Visualization of distribution

In [None]:
idx = 9

In [None]:
visualize_slices(V_X_sorteds[idx][ths[idx]])

In [None]:
digits_train = get_digits(data_train)

In [None]:
digits_train[0]

In [None]:
# Load an example image
example_image, _ = data_test[i]
# example_image = example_image.unsqueeze(0)  # Add batch dimension

# Visualize the activations
acts = visualize_activations(wnet, digits_train[0][8], layers=[3, 6], hist=False)

In [None]:
digits_train[0][:4]

In [None]:
# First-block activations for every class-0 training image. The wrapper
# returns a batch of one, so [0] drops the batch axis and each entry is
# indexed directly by channel, as the next cell expects.
res = [wnet(x, k=3)[0] for x in digits_train[0]]

In [None]:
# For each of 32 indices k: gather entry k of every result in `res`
# (res_k[k]) and reduce them with an element-wise minimum (int_k[k]).
# NOTE(review): r[k] indexes the FIRST axis of each entry of `res`, so it
# selects channel k only if the entries carry no leading batch axis - confirm.
res_k = list()
int_k = list()
with tqdm(list(range(32))) as pange:
    for k in pange:
        r_k = [r[k] for r in res]
        res_k.append(r_k)
        int_k.append(intersect_xd(*r_k))

In [None]:
visualize_slices(int_k)

In [None]:
# `int_30` is not defined anywhere in the notebook; inspect entry 30 of int_k.
int_k[30].shape

In [None]:
show_activation(int_k[28])

In [None]:
show_activation(acts[0][1])

In [None]:
indices1 = np.where(acts[0][1] >= 1.2)
indices1, acts[0][1][indices1]

In [None]:
show_activation(acts[0][8])

In [None]:
indices2 = np.where(acts[0][8] >= 1.4)

In [None]:
acts[0][8][indices2]

In [None]:
acts[0][1].shape

In [None]:
# `idcs` is not defined anywhere in the notebook; use the index tuple that was
# computed for this very map a few cells above.
acts[0][1][indices1].shape

In [None]:
wnet.net[3]

In [None]:
# Visualize weights of the first convolutional layer
visualize_weights(wnet.net[0], num_filters=32)

# Visualize weights of the second convolutional layer
visualize_weights(wnet.net[3], num_filters=64)

In [None]:
# `V_X_ds` only exists as a local inside sort_V_X; the per-class activations
# are bound to V_X_digits at notebook level.
V_X_digits[0].shape

In [None]:
# For every class-0 activation, flag whether it is comparable with the first
# one: everywhere >= it, or everywhere strictly below it. Uses the
# notebook-level V_X_digits (`V_X_ds` is not defined at this scope).
cprs = np.array([
    np.all(V_X_digits[0][0] <= V_X_d) or np.all(V_X_d < V_X_digits[0][0])
    for V_X_d in V_X_digits[0]
])

In [None]:
np.where(cprs)