In [None]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash 
! pip install -U -r requirements.txt
```

```bash
! pip install -U numpy
! pip install -U scikit-learn
```

## Update repository

In [None]:
! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
del module_path

## Organize imports

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
import json

In [None]:
from collections import OrderedDict

In [None]:
from functools import reduce

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

In [None]:
from tqdm import tqdm

In [None]:
import plotly.express as px

In [None]:
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.discriminant_analysis import (
    LinearDiscriminantAnalysis, 
    QuadraticDiscriminantAnalysis
)
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier 
from sklearn.datasets import (
    load_iris,
    load_wine
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import (
    MaxAbsScaler,
    MinMaxScaler,
    StandardScaler
)
from sklearn.metrics import (
    classification_report,
    confusion_matrix
)

In [None]:
from scipy import stats
from scipy.interpolate import interp1d

In [None]:
import torch

In [None]:
from torch import nn
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.utils import make_grid
from torchvision.datasets import MNIST

In [None]:
from captum.attr import (
    IntegratedGradients, 
    LayerIntegratedGradients,
    NeuronGradient,
    NeuronIntegratedGradients,
    NeuronGuidedBackprop,
    NeuronDeepLift,
    NeuronDeepLiftShap,
    NeuronGradientShap,
)
from captum.attr import visualization as viz

#### Number of CPU cores

In [None]:
workers = multiprocessing.cpu_count()
workers

In [None]:
SEED = 2024

## Initialize Path

In [None]:
PATH = Path('data')
model_dir = PATH / 'models'
model_path = model_dir / 'simple_fcn_mnist_model.ckpt'
config_dir = PATH / 'config'
config_dir.mkdir(exist_ok=True, parents=True)
config_1_layer_path = config_dir / 'neurons_1_layer.json'
images_path = PATH / 'images'
images_path.mkdir(exist_ok=True, parents=True)
pumpkin_path = PATH / 'Pumpkin_Seeds_Dataset.xlsx'

## Model wrapper

In [None]:
class NetWrapper(object):
    """Pair an indexable network with the preprocessing applied to its inputs.

    The wrapped network is switched to evaluation mode on construction.
    Indexing and ``len`` are delegated to the network, so ``wrapper[:k]``
    yields whatever the network returns for that slice.
    """

    def __init__(self, net, transform):
        """Store ``transform`` and the network (put into eval mode)."""
        self.transform = transform
        self._net = net.eval()

    @property
    def net(self):
        """The wrapped network (read-only)."""
        return self._net

    def __getitem__(self, i):
        """Delegate indexing / slicing to the wrapped network."""
        return self._net[i]

    def __len__(self):
        """Number of entries reported by the wrapped network."""
        return len(self._net)

    def forward(self, x, k=6):
        """Transform ``x`` and run it through the slice ``self[:k]``.

        The whole computation (transform included) runs under
        ``torch.inference_mode()``, so the result carries no autograd state.
        """
        with torch.inference_mode():
            prepared = self.transform(x)
            head = self[:k]
            return head(prepared)

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`forward`."""
        return self.forward(*args, **kwargs)

## Load the model

In [None]:
# NOTE(review): torch.load unpickles the checkpoint file, and unpickling can run
# arbitrary code -- only load checkpoints that come from a trusted source.
# The loaded object is later indexed with 'state_dict', i.e. it is used as a mapping.
model = torch.load(model_path, map_location='cpu')

In [None]:
def clear_state_dict(state_dict):
    """Rename checkpoint keys so they match the layer names ``fc1``/``fc2``/``fc3``.

    Every occurrence of ``model.1.``, ``model.4.`` and ``model.7.`` in a key is
    replaced by ``fc1.``, ``fc2.`` and ``fc3.`` respectively. The mapping is
    modified in place (one full pass per substitution) and also returned.
    """
    substitutions = (
        ('model.1.', 'fc1.'),
        ('model.4.', 'fc2.'),
        ('model.7.', 'fc3.'),
    )
    for old_part, new_part in substitutions:
        # Snapshot the keys first: the mapping is mutated while we walk it.
        for name in tuple(state_dict):
            value = state_dict.pop(name)
            state_dict[name.replace(old_part, new_part)] = value

    return state_dict

In [None]:
state_dict = clear_state_dict(model['state_dict'])

In [None]:
channels = 1
width = 28
height = 28
hidden_size = 16
num_classes = 10
in_features = channels * width * height

In [None]:
net = nn.Sequential(OrderedDict([
    ('flatten', nn.Flatten()),
    ('fc1', nn.Linear(channels * width * height, hidden_size)),
    ('act1', nn.ReLU()),
    ('fc2', nn.Linear(hidden_size, hidden_size)),
    ('act2', nn.ReLU()),
    ('fc3', nn.Linear(hidden_size, num_classes)),
]))

In [None]:
net.load_state_dict(state_dict)

In [None]:
net = net.eval()

## Helper functions

In [None]:
def intersect(*arrs):
    """Fold ``np.intersect1d`` over all given arrays, left to right.

    With a single argument that argument is returned unchanged; with no
    arguments ``functools.reduce`` raises ``TypeError``.
    """
    def common(left, right):
        return np.intersect1d(left, right)

    return reduce(common, arrs)

In [None]:
def layer_V(data, net, k=5):
    """Collect the output of ``net(x, k=k)`` for every ``(x, y)`` pair in ``data``.

    For each sample the result of ``net`` is detached, converted to NumPy and
    its first row is kept. Returns ``(activations, inputs)`` where
    ``activations`` is the stacked array of those rows and ``inputs`` is the
    list of the corresponding ``x`` objects, in iteration order.
    """
    activations = []
    inputs = []
    with tqdm(data) as progress:
        for sample, _label in progress:
            row = net(sample, k=k).detach().numpy()[0]
            activations.append(row)
            inputs.append(sample)

    return np.array(activations), inputs

In [None]:
def loop_maxes(V, func, *args, **kwargs):
    """Call ``func(i, v, *args, **kwargs)`` for each element ``v`` of ``V``.

    ``i`` is the zero-based position of ``v``; iteration is wrapped in a tqdm
    progress bar. Return values of ``func`` are discarded.
    """
    with tqdm(V) as progress:
        position = 0
        for row in progress:
            func(position, row, *args, **kwargs)
            position += 1

In [None]:
def select_top(V, idx, thresh):
    """Return the positions ``i`` in ``V`` whose component ``V[i][idx]`` is at least ``thresh``.

    Positions are returned as a list in ascending order.
    """
    tops = []

    def keep_if_reached(i, v):
        # Inclusive comparison: a value equal to the threshold is kept.
        if thresh <= v[idx]:
            tops.append(i)

    loop_maxes(V, keep_if_reached)

    return tops

In [None]:
def find_v_x(V, mrng, idx):
    """Among the rows of ``V`` selected by ``mrng``, find the one minimising column ``idx``.

    Returns ``(v_x, x_id)``: ``x_id`` is the entry of ``mrng`` pointing at the
    minimising row and ``v_x`` is ``V[x_id]``.
    """
    candidates = np.array(V)[mrng]
    # Per-column argmin over the candidate rows; keep the one for column `idx`.
    winner = candidates.argmin(axis=0)[idx]
    x_id = mrng[winner]

    return V[x_id], x_id

In [None]:
def find_v_A(V, mrng):
    """Element-wise minimum over the rows of ``V`` selected by ``mrng``."""
    selected = np.array(V)[mrng]
    lower_bound = np.minimum.reduce(selected, axis=0)
    return lower_bound

In [None]:
def find_G_x(V, v_x):
    """Return the positions of all rows ``v`` in ``V`` with ``v_x <= v`` in every component.

    The result is a NumPy array built from the list of matching positions
    (ascending order); iteration shows a tqdm progress bar.
    """
    dominated = []
    with tqdm(V) as progress:
        for position, row in enumerate(progress):
            if np.all(v_x <= row):
                dominated.append(position)
    G_x = np.array(dominated)

    return G_x

In [None]:
def show_img(ds, idx):
    """Draw the first element of ``ds[idx]`` with ``plt.imshow``."""
    image = ds[idx][0]
    plt.imshow(image)

In [None]:
def show(imgs, h=12, w=12):
    """Display one image tensor, or a list of them, side by side in a single row.

    ``w`` and ``h`` are the figure width and height passed to ``figsize``.
    Axis ticks and tick labels are removed from every panel.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(
        ncols=len(images),
        figsize=(w, h),
        squeeze=False,  # keep a 2-D axes array even for a single image
    )
    for column, tensor in enumerate(images):
        pil_image = F.to_pil_image(tensor.detach())
        panel = axes[0, column]
        panel.imshow(np.asarray(pil_image))
        panel.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
def show_grid(G_A, data, nrow=8, h=12, w=12, my=None):
    """Show the images ``data[i][0]`` for ``i`` in ``G_A`` as one grid.

    When ``my`` is given, samples whose second element ``data[i][1]`` equals
    ``my`` are left out. ``nrow`` is forwarded to ``make_grid``; ``h`` and
    ``w`` are forwarded to ``show``.
    """
    to_tensor = transforms.ToTensor()
    tiles = []
    for i in G_A:
        if my is None or data[i][1] != my:
            tiles.append(to_tensor(data[i][0]))
    show(make_grid(tiles, nrow=nrow), h=h, w=w)

In [None]:
def vis_features(*n_idxs, data=None, layer=3, func=NeuronIntegratedGradients, use_plot=False, **kwargs):
    """Compute neuron attributions for the given neuron indices.

    For every ``n_idx`` the input image is ``data[arg_max[n_idx]][0]``; it is
    passed through ``transform`` and handed to
    ``func(wnet.net, wnet.net[layer]).attribute(..., neuron_selector=n_idx)``.

    This function reads the notebook-level names ``wnet``, ``arg_max`` and
    ``transform``; they must be defined before it is called.

    NOTE(review): ``arg_max`` is indexed by ``n_idx`` regardless of ``layer``;
    confirm that the layer used to build ``arg_max`` matches the ``layer``
    being attributed.

    Args:
        *n_idxs: neuron indices to attribute.
        data: indexable dataset of ``(image, label)`` items; required despite
            the ``None`` default.
        layer: index into ``wnet.net`` of the layer whose neurons are attributed.
        func: attribution class, constructed as ``func(model, layer_module)``.
        use_plot: when true, draw each input image next to its attribution map.
        **kwargs: forwarded unchanged to ``attribute``.

    Returns:
        List of attribution maps (squeezed NumPy arrays), one per index, in
        the order given.
    """
    # One attribution object is shared by all requested neurons.
    neuron_attr = func(wnet.net, wnet.net[layer])

    attr_igs = list()
    input_imgs = list()
    for n_idx in n_idxs:
        input_img = data[arg_max[n_idx]][0]
        input_tns = transform(input_img)
        attributions = neuron_attr.attribute(
            input_tns,
            neuron_selector=n_idx,
            **kwargs,
        )
        attr_igs.append(attributions.squeeze().cpu().detach().numpy())
        input_imgs.append(input_img)

    if use_plot:
        for attr_ig, input_img in zip(attr_igs, input_imgs):
            # Left panel: the input image; right panel: its attribution map.
            fig, ax = plt.subplots(1, 2, figsize=(10, 5))
            ax[0].imshow(input_img, cmap='gray')
            ax[0].axis('off')

            ax[1].imshow(attr_ig, cmap='gray')
            ax[1].axis('off')

        # Render all figures created above.
        plt.show()

    return attr_igs

In [None]:
def display_features(attrs, data):
    """Render each attribution map with captum's multi-panel visualisation.

    For position ``idx`` in ``attrs`` the original image is taken from
    ``data[arg_max[idx]][0]`` (``arg_max`` is a notebook-level name). Both the
    attribution and the image get a trailing channel axis, and three panels
    are drawn: original image, positive heat map, positive masked image, each
    titled with ``idx``.
    """
    for idx, attr_map in enumerate(attrs):
        attr = np.expand_dims(attr_map, axis=2)
        orig = np.expand_dims(data[arg_max[idx]][0], axis=2)
        label = str(idx)
        panels = ['original_image', 'heat_map', 'masked_image']
        signs = ['all', 'positive', 'positive']
        viz.visualize_image_attr_multiple(
            attr,
            orig,
            panels,
            signs,
            use_pyplot=True,
            titles=[label, label, label],
            show_colorbar=True,
        )


## Feature FCA analysis methods

In [None]:
def features_hist(*n_Fs, V=np.zeros((1, 16))):
    """Plot one histogram per requested feature index, stacked vertically.

    Args:
        *n_Fs: indices into each row of ``V``.
        V: iterable of activation rows; ``v[n_F]`` is read for every row.

    Each histogram uses a fixed x-range of ``[0, 32]`` with ticks every 0.5.
    """
    rows = len(n_Fs)
    vs_ls = list()
    with tqdm(n_Fs) as pn_Fs:
        for n_F in pn_Fs:
            vs_ls.append([v[n_F] for v in V])
    # squeeze=False keeps `axs` two-dimensional; without it a single requested
    # feature yields a bare Axes object and indexing it raises.
    fig, axs = plt.subplots(
        rows,
        1,
        sharey=True,
        tight_layout=True,
        figsize=(8 * rows, 32),
        squeeze=False,
    )
    for r, vs_h in enumerate(vs_ls):
        ax = axs[r, 0]
        ax.hist(vs_h)
        ax.set_xlim(0, 32)  # fixed X range so the panels are comparable
        ax.set_xticks(np.arange(0, 32, 0.5))  # ticks spanning that range
    

In [None]:
class LayerFCA(object):
    """Threshold-based analysis over two sets of per-sample activation rows.

    ``V_X`` and ``U_X`` are indexable collections of activation rows (one row
    per sample) and ``data`` is the dataset whose items are ``(x, y)`` pairs;
    ``data[i][1]`` is read as the label of sample ``i``.

    Relies on the notebook helpers ``select_top``, ``find_v_A``, ``find_G_x``
    and ``intersect``.
    """

    def __init__(self, V_X, U_X, data):
        self.V_X = V_X
        self.U_X = U_X
        self.data = data
        self.G_As = list()   # index sets, one per processed (neuron, threshold)
        self.v_As = list()   # element-wise minima, one per processed pair
        self.D = None
        self.v_D = None
        # Upper-case names kept for backward compatibility; the methods below
        # read and write the lower-case attributes.
        self.U_D = None
        self.G_U_D = None
        self.u_D = None
        self.G_u_D = None
        self.uncn = None

    def fca_v(self, ns, ths):
        """Process each ``(neuron index, threshold)`` pair against ``V_X``.

        For every pair: select the samples whose activation at that neuron
        reaches the threshold, take the element-wise minimum of their rows
        (``v_A``), then collect all samples whose row dominates ``v_A``.

        NOTE(review): results are appended to ``self.G_As`` / ``self.v_As``,
        so repeated calls on one instance accumulate -- confirm this is
        intended, otherwise create a fresh instance per query.

        Returns:
            ``(D, v_D)``: ``D`` is the intersection of all collected index
            sets (``[]`` if none) and ``v_D`` the element-wise maximum of all
            collected minima (``None`` if none).
        """
        for n_A, th_A in zip(ns, ths):
            G_A_v_A = select_top(self.V_X, n_A, th_A)
            v_A = find_v_A(self.V_X, G_A_v_A)
            self.v_As.append(v_A)
            G_A = find_G_x(self.V_X, v_A)
            self.G_As.append(G_A)
        self.D = intersect(*self.G_As) if self.G_As else []
        # Reducing an empty list raises, so mirror the guard used for D.
        self.v_D = np.maximum.reduce(self.v_As) if self.v_As else None

        return self.D, self.v_D

    def fca_u(self, ns, ths):
        """Run :meth:`fca_v`, then lift the resulting sample set ``D`` to ``U_X``.

        ``u_D`` is the element-wise minimum of the ``U_X`` rows indexed by
        ``D``; when ``D`` is empty it falls back to a zero vector with the
        shape of one ``U_X`` row. Returns the indices of all ``U_X`` rows that
        dominate ``u_D`` (also stored as ``self.G_u_D``).
        """
        D, _ = self.fca_v(ns, ths)
        # Test emptiness by size: D holds sample indices, so a truthiness test
        # such as np.any would wrongly treat D == [0] as empty.
        if np.size(D) > 0:
            self.u_D = find_v_A(self.U_X, D)
        else:
            self.u_D = np.zeros(np.shape(self.U_X)[1:], dtype=float)
        self.G_u_D = find_G_x(self.U_X, self.u_D)

        return self.G_u_D

    @staticmethod
    def count_ys(ys):
        """Return a 2-row array: unique values of ``ys`` and their counts."""
        un, cn = np.unique(ys, return_counts=True)
        uncn = np.array([un, cn])

        return uncn

    def _report_u(self, G_u_D, data=None):
        """Count the labels of the samples indexed by ``G_u_D``.

        Labels come from ``data`` when given, otherwise from ``self.data``; in
        the latter case the result is also cached on ``self.uncn``.
        """
        data_ls = self.data if data is None else data
        ys = np.array([data_ls[idx][1] for idx in G_u_D])
        uncn = self.count_ys(ys)
        if data is None:
            self.uncn = uncn

        return uncn

    def report(self, G_u_D, data):
        """Label counts for ``G_u_D`` on an explicit dataset (nothing is cached)."""
        return self._report_u(G_u_D, data=data)

    def fca_u_arr(self, ns_arr):
        """Run :meth:`fca_u` for a sequence of ``(neuron index, threshold)`` pairs.

        ``ns_arr`` is the sequence of pairs; the label counts of the resulting
        sample set on ``self.data`` are cached on ``self.uncn``.
        """
        ns = [nr[0] for nr in ns_arr]
        ts = [nr[1] for nr in ns_arr]
        G_u_D = self.fca_u(ns, ts)
        self._report_u(G_u_D)

        return self.G_u_D

    def G_U(self, U_X):
        """Indices of the rows of ``U_X`` that dominate the stored ``u_D``.

        Requires a prior :meth:`fca_u` / :meth:`fca_u_arr` call to set ``u_D``.
        """
        return find_G_x(U_X, self.u_D)

## Initialize MNIST dataset

In [None]:
transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
)

In [None]:
data_train = MNIST(images_path, train=True, download=True)
data_test = MNIST(images_path, train=False, download=True)

In [None]:
wnet = NetWrapper(net, transform)

In [None]:
wnet.net

In [None]:
wnet.net[:3]

In [None]:
V_X_train, X_V_train = layer_V(data_train, wnet, k=3)

In [None]:
V_X_test, X_V_test = layer_V(data_test, wnet, k=3)

In [None]:
U_X_train, X_U_train = layer_V(data_train, wnet, k=5)

In [None]:
U_X_test, X_U_test = layer_V(data_test, wnet, k=5)

In [None]:
arg_max = np.argmax(V_X_train, axis=0)
arg_max

In [None]:
np.max(V_X_train, axis=0)

In [None]:
V_X_train[arg_max[13], :]

In [None]:
show_grid(arg_max, data_train)

In [None]:
attr_all = vis_features(
    *list(range(16)),
    data=data_train,
    # func=NeuronGradientShap,
    func=NeuronGuidedBackprop,
    # n_steps=n_steps,
)

In [None]:
display_features(attr_all, data_train)

In [None]:
len(attr_all)

In [None]:
with config_1_layer_path.open('r') as js:
    neuron_js = json.load(js)
neurons = neuron_js['neurons']
len(neurons)

## Analyze maximum stimulus

In [None]:
# neur_idx = -5 #4
# neur_idx = -2 #0
neur_idx = -1 #1
neurons[neur_idx]

In [None]:
neur_idx = -1
with config_1_layer_path.open('r') as js:
    neuron_js = json.load(js)
neurons = neuron_js['neurons']
layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
G_u_D = layer_fca.fca_u_arr(neurons[neur_idx])

In [None]:
[layer_fca.D.shape, G_u_D.shape], layer_fca.u_D

In [None]:
layer_fca.uncn

In [None]:
G_U_test = layer_fca.G_U(U_X_test)
uncn_test = layer_fca.report(G_U_test, data_test)

In [None]:
G_U_test.shape, uncn_test

In [None]:
y_hs = [np.argmax(wnet(data_test[idx][0])) for idx in G_U_test]

In [None]:
uncn_hat = layer_fca.count_ys(y_hs)

In [None]:
uncn_hat

In [None]:
layer_fca.v_As[1]

In [None]:
net[3].weight[:, 0], net[3].bias

In [None]:
G_u_D.shape, layer_fca.u_D

In [None]:
layer_fca.uncn

In [None]:
show_grid(G_U_test, data_test, nrow=48, h=64, w=64)

In [None]:
show_grid(G_U_test, data_test, nrow=48, h=64, w=64, my=2)

## Visualization of distribution

In [None]:
# `V_X` is never defined in this notebook; the first-layer training activations
# are stored in `V_X_train`.
features_hist(*list(zip(*neurons[neur_idx]))[0], V=V_X_train)

In [None]:
def layer_hist(X_y, V_X, y=None):
    """Select the rows of ``V_X`` whose paired sample in ``X_y`` has label ``y``.

    ``X_y`` and ``V_X`` are walked in lockstep and ``x_y[1]`` is compared with
    ``y``. With ``y=None`` the input ``V_X`` is returned unchanged; otherwise
    a new NumPy array of the matching rows is returned.
    """
    if y is None:
        return V_X

    matching = []
    with tqdm(V_X) as progress:
        for sample, activation in zip(X_y, progress):
            if sample[1] == y:
                matching.append(activation)

    return np.array(matching)

In [None]:
V_X_4 = layer_hist(data_train, V_X_train, y=4)

In [None]:
V_X_4.shape

In [None]:
for i in range(V_X_4.shape[1]):
    ax = sns.displot(V_X_4[:, i], binwidth=0.2)
    ax.ax.set_xlabel(f'{i}')
plt.show()

In [None]:
V_X_0 = layer_hist(data_train, V_X_train, y=0)

In [None]:
V_X_3 = layer_hist(data_train, V_X_train, y=3)

In [None]:
V_X_0.shape, V_X_3.shape

In [None]:
for i in range(V_X_0.shape[1]):
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    sns.histplot(V_X_0[:, i], binwidth=0.2, ax=axes[0])
    axes[0].set_xlabel(f'0 - {i}')
    sns.histplot(V_X_3[:, i], binwidth=0.2, ax=axes[1])
    axes[1].set_xlabel(f'3 - {i}')
plt.show()

In [None]:
V_X_2 = layer_hist(data_train, V_X_train, y=2)

In [None]:
V_X_2.shape

In [None]:
for i in range(V_X_2.shape[1]):
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    sns.histplot(V_X_2[:, i], binwidth=0.2, ax=axes[0])
    axes[0].set_xlabel(f'2 - {i}')
    sns.histplot(V_X_3[:, i], binwidth=0.2, ax=axes[1])
    axes[1].set_xlabel(f'3 - {i}')
plt.show()