In [None]:
# Automatically reload edited Python modules before each cell runs.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```

## Update repository

In [None]:
# Fetch and merge the latest upstream changes into this checkout (shell command).
! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Make the parent directory (resolved against the kernel's current working
# directory) importable, appending it to sys.path only once.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
# Temporary name no longer needed.  NOTE(review): re-running this cell on its
# own raises NameError because the name has already been deleted.
del module_path

## Organize imports

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from collections import OrderedDict

In [None]:
from functools import reduce

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from tqdm import tqdm

In [None]:
import plotly.express as px

In [None]:
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.discriminant_analysis import (
    LinearDiscriminantAnalysis, 
    QuadraticDiscriminantAnalysis
)
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier 
from sklearn.datasets import (
    load_iris,
    load_wine
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import (
    MaxAbsScaler,
    MinMaxScaler,
    StandardScaler
)
from sklearn.metrics import (
    classification_report,
    confusion_matrix
)

In [None]:
from scipy import stats
from scipy.interpolate import interp1d

In [None]:
import torch

In [None]:
from torch import nn
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.utils import make_grid
from torchvision.datasets import MNIST

In [None]:
from captum.attr import (
    IntegratedGradients, 
    LayerIntegratedGradients,
    NeuronGradient,
    NeuronIntegratedGradients,
    NeuronGuidedBackprop,
    NeuronDeepLift,
    NeuronDeepLiftShap,
    NeuronGradientShap,
)
from captum.attr import visualization as viz

#### Number of CPU cores

In [None]:
# Number of CPU cores visible to this process, displayed for reference.
# NOTE(review): `workers` is not referenced later in this notebook.
workers = multiprocessing.cpu_count()
workers

In [None]:
import random

# Single seed applied to every RNG the notebook may touch (Python, PyTorch,
# NumPy).  Previously SEED was defined but never used.  Seeding makes any
# stochastic step reproducible on Restart & Run All -- e.g. a sampling-based
# attribution method such as NeuronGradientShap, referenced as an
# alternative below.
SEED = 2024
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

## Initialize paths

In [None]:
# All artefacts live under ./data, relative to the kernel's working directory.
PATH = Path('data')
model_dir = PATH / 'models'
# NOTE(review): 'chechpoint' is misspelled, but the name must match the
# checkpoint file on disk, so it is left unchanged.
model_path = model_dir / 'model_chechpoint.ckpt'
# MNIST is downloaded into this directory (created if missing).
images_path = PATH / 'images'
images_path.mkdir(exist_ok=True, parents=True)
# NOTE(review): pumpkin_path is not referenced later in this notebook.
pumpkin_path = PATH / 'Pumpkin_Seeds_Dataset.xlsx'

## Model wrapper

In [None]:
class NetWrapper:
    """Inference-only wrapper around a sliceable torch module.

    The wrapped network is switched to eval mode on construction.  Calling
    the wrapper applies ``transform`` to the raw input and then runs only
    the first ``k`` sub-modules, which exposes intermediate activations.

    Args:
        net: Module supporting ``len()`` and slicing (e.g. ``nn.Sequential``).
        transform: Callable turning a raw input into the tensor fed to ``net``.
    """

    def __init__(self, net, transform):
        # Module.eval() returns the module itself, so this stores `net`.
        self._net = net.eval()
        self.transform = transform

    @property
    def net(self):
        """The wrapped (eval-mode) network."""
        return self._net

    def __getitem__(self, i):
        """Index or slice the underlying network."""
        sub_module = self.net[i]
        return sub_module

    def __len__(self):
        """Number of top-level sub-modules of the network."""
        return len(self.net)

    @torch.inference_mode()
    def forward(self, x, k=6):
        """Preprocess ``x`` and run it through the first ``k`` sub-modules.

        Executed under ``torch.inference_mode`` so no autograd graph is built.
        """
        prefix = self[:k]
        return prefix(self.transform(x))

    def __call__(self, *args, **kwargs):
        """Delegate to :meth:`forward`."""
        return self.forward(*args, **kwargs)

## Load the model

In [None]:
# Load the raw checkpoint onto the CPU; its 'state_dict' entry is used below.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoints.  Newer PyTorch releases default to weights_only=True, which may
# reject checkpoints that pickle extra non-tensor objects; confirm against the
# installed version.
model = torch.load(model_path, map_location='cpu')

In [None]:
def clear_state_dict(state_dict, mapping=None):
    """Return a copy of ``state_dict`` with checkpoint key prefixes renamed.

    The checkpoint stores layers as ``model.1.`` / ``model.4.`` / ``model.7.``;
    by default these become ``fc1.`` / ``fc2.`` / ``fc3.`` so the weights load
    into the named ``nn.Sequential`` built in this notebook.

    Unlike the earlier version, this function:
      * does not mutate the caller's dict;
      * only rewrites a key when it *starts* with an old prefix -- the old
        ``str.replace`` also rewrote matches in the middle of a key;
      * renames in a single pass, with the mapping overridable.

    Args:
        state_dict: Mapping of parameter name -> tensor.
        mapping: Optional ``{old_prefix: new_prefix}``; the first matching
            prefix wins.  Defaults to the three Linear-layer prefixes above.

    Returns:
        A new ``OrderedDict`` with the same values and order.  A ``_metadata``
        attribute (which torch uses for versioned loading) is carried over if
        present.
    """
    if mapping is None:
        mapping = {'model.1.': 'fc1.', 'model.4.': 'fc2.', 'model.7.': 'fc3.'}

    renamed = OrderedDict()
    for key, value in state_dict.items():
        for old, new in mapping.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = value

    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        renamed._metadata = metadata
    return renamed

In [None]:
# Rename the checkpoint's 'model.N.' keys to this notebook's layer names.
state_dict = clear_state_dict(model['state_dict'])

In [None]:
# Architecture hyper-parameters.  They must agree with the tensor shapes in
# the checkpoint, or strict loading below fails.
# MNIST inputs: 1 channel, 28x28 pixels; 10 digit classes.
channels = 1
width = 28
height = 28
hidden_size = 16
num_classes = 10
in_features = channels * width * height

In [None]:
# Fully connected classifier: in_features -> hidden_size -> hidden_size ->
# num_classes.  Uses the `in_features` constant defined above rather than
# recomputing it.  Positional indices, as used by NetWrapper's net[:k]:
#   0 flatten, 1 fc1, 2 act1, 3 fc2, 4 act2, 5 fc3.
# The fc1/fc2/fc3 names must match the keys produced by clear_state_dict.
net = nn.Sequential(OrderedDict([
    ('flatten', nn.Flatten()),
    ('fc1', nn.Linear(in_features, hidden_size)),
    ('act1', nn.ReLU()),
    ('fc2', nn.Linear(hidden_size, hidden_size)),
    ('act2', nn.ReLU()),
    ('fc3', nn.Linear(hidden_size, num_classes)),
]))

In [None]:
# Load the renamed weights.  Strict loading (the default) errors on any
# missing or unexpected key, or on a shape mismatch.
net.load_state_dict(state_dict)

In [None]:
# Inference only: switch the network to eval mode.
net = net.eval()

## Helper functions

In [None]:
def intersect(*arrs):
    """Intersect any number of index arrays, folding left to right.

    With two or more inputs the result is the sorted, de-duplicated array
    produced by ``np.intersect1d``.  A single input is returned unchanged,
    and calling with no inputs raises ``TypeError`` (from ``reduce``).
    """
    def _pairwise(left, right):
        return np.intersect1d(left, right)

    return reduce(_pairwise, arrs)

In [None]:
def layer_V(data, net, k=5):
    """Record the output of the first ``k`` sub-modules for every sample.

    Args:
        data: Iterable of ``(input, label)`` pairs; labels are ignored.
        net: Callable ``net(x, k=k)`` returning a torch tensor, e.g. a
            ``NetWrapper``.  Only row 0 of each output is kept.
        k: Number of leading sub-modules to run.

    Returns:
        ``(V, X)`` -- a numpy array stacking one activation vector per sample,
        and the list of raw inputs in the same order.
    """
    activations, inputs = [], []
    with tqdm(data) as progress:
        for sample, _label in progress:
            output = net(sample, k=k)
            activations.append(output.detach().numpy()[0])
            inputs.append(sample)

    return np.array(activations), inputs

In [None]:
def loop_maxes(V, func, *args, **kwargs):
    """Call ``func(index, row, *args, **kwargs)`` for every row of ``V``.

    Rows are visited in order behind a tqdm progress bar; whatever ``func``
    returns is discarded.
    """
    with tqdm(V) as rows:
        position = 0
        for row in rows:
            func(position, row, *args, **kwargs)
            position += 1

In [None]:
def select_top(V, idx, thresh):
    """Indices of the rows of ``V`` whose column ``idx`` is at least ``thresh``.

    Rows are scanned in order through ``loop_maxes`` (which shows a progress
    bar), so the returned list of ints is ascending.
    """
    selected = []

    def _collect(row_index, row):
        # Inclusive threshold.
        if thresh <= row[idx]:
            selected.append(row_index)

    loop_maxes(V, _collect)
    return selected

In [None]:
def find_v_x(V, mrng, idx):
    """Among the rows of ``V`` listed in ``mrng``, pick the one whose
    column ``idx`` is smallest.

    Returns:
        ``(v_x, x_id)``: the chosen row ``V[x_id]`` and its index into ``V``.
        Ties go to the earliest position in ``mrng``.
    """
    candidates = np.array(V)[mrng]
    column_argmins = candidates.argmin(axis=0)
    x_id = mrng[column_argmins[idx]]
    return V[x_id], x_id

In [None]:
def find_v_A(V, mrng):
    """Element-wise minimum over the rows of ``V`` selected by ``mrng``.

    The result is a lower bound satisfied by every selected row.  Raises
    ``ValueError`` when ``mrng`` selects no rows.
    """
    selected_rows = np.array(V)[mrng]
    return selected_rows.min(axis=0)

In [None]:
def find_G_x(V, v_x):
    """Return indices of the rows of ``V`` that dominate ``v_x`` element-wise.

    A row ``v`` qualifies when ``v_x <= v`` holds for every component.

    Args:
        V: Array-like with one activation vector per row (e.g. ``V_X``).
        v_x: Lower-bound vector, broadcast against each row.

    Returns:
        Ascending integer index array.  When no row qualifies, the result is
        an *integer* empty array, so it can still be used for indexing.  The
        previous list-comprehension version returned a float64 empty array in
        that case, which numpy rejects as an index.
    """
    rows = np.asarray(V)
    if len(rows) == 0:
        return np.empty(0, dtype=np.intp)
    # A single vectorised comparison replaces the per-row Python loop.  Each
    # row's comparison is flattened so rows of any rank reduce to one bool.
    meets_bound = np.asarray(v_x <= rows).reshape(len(rows), -1).all(axis=1)
    return np.flatnonzero(meets_bound)

In [None]:
def show_img(ds, idx):
    """Draw the input (element 0) of sample ``idx`` of ``ds`` with imshow."""
    sample = ds[idx]
    plt.imshow(sample[0])

In [None]:
def show(imgs, h=12, w=12):
    """Plot one image tensor, or a list of them, side by side in one row.

    Args:
        imgs: A tensor or a list of tensors, each converted with
            ``F.to_pil_image``.
        h, w: Figure height and width in inches.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), figsize=(w, h), squeeze=False)
    for column, tensor in enumerate(images):
        pil_image = F.to_pil_image(tensor.detach())
        ax = axes[0, column]
        ax.imshow(np.asarray(pil_image))
        # Hide ticks and tick labels; only the pixels matter here.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
def show_grid(G_A, data, nrow=8, h=12, w=12):
    """Tile the inputs of ``data`` at indices ``G_A`` into one image grid.

    ``nrow`` is the number of images per grid row (as in ``make_grid``);
    ``h`` and ``w`` are forwarded to ``show`` as the figure size.
    """
    to_tensor = transforms.ToTensor()
    tiles = []
    for sample_index in G_A:
        tiles.append(to_tensor(data[sample_index][0]))
    show(make_grid(tiles, nrow=nrow), h=h, w=w)

In [None]:
def vis_features(*n_idxs, layer=3, func=NeuronIntegratedGradients,
                 use_plot=False, wrapper=None, dataset=None, top_idx=None,
                 preprocess=None, **kwargs):
    """Compute neuron attributions on each neuron's top-activating image.

    For every neuron index ``n`` in ``n_idxs``, the image
    ``dataset[top_idx[n]][0]`` is preprocessed and passed to
    ``func(net, net[layer]).attribute(..., neuron_selector=n, **kwargs)``.

    Args:
        *n_idxs: Neuron indices to attribute.
        layer: Index into ``wrapper.net`` of the layer whose neurons are
            attributed.
        func: captum neuron-attribution class, constructed as
            ``func(model, layer_module)``.
        use_plot: If True, plot each input next to its attribution map.
        wrapper, dataset, top_idx, preprocess: Optional overrides for the
            notebook globals ``wnet``, ``data``, ``arg_max`` and
            ``transform``.  ``None`` keeps the previous behaviour of reading
            the globals.
        **kwargs: Forwarded unchanged to ``attribute``.

    Returns:
        List of squeezed numpy attribution maps, one per entry of ``n_idxs``.

    NOTE(review): ``arg_max`` is computed from ``layer_V(..., k=3)``, i.e. the
    output of ``net[:3]`` (ending at ``act1``).  The default ``layer=3``
    attributes neurons of ``net[3]`` (``fc2``) instead.  Confirm which layer
    is intended.
    """
    # Conditional expressions only touch a global when its override is None.
    wrapper = wnet if wrapper is None else wrapper
    dataset = data if dataset is None else dataset
    top_idx = arg_max if top_idx is None else top_idx
    preprocess = transform if preprocess is None else preprocess

    # Attribution object for the chosen layer (IG, guided backprop, ... per `func`).
    attributor = func(wrapper.net, wrapper.net[layer])

    attr_igs = []
    input_imgs = []
    for n_idx in n_idxs:
        input_img = dataset[top_idx[n_idx]][0]
        attributions = attributor.attribute(
            preprocess(input_img),
            neuron_selector=n_idx,
            **kwargs,
        )
        attr_igs.append(attributions.squeeze().cpu().detach().numpy())
        input_imgs.append(input_img)

    if use_plot:
        for input_img, attr_ig in zip(input_imgs, attr_igs):
            # Left: original image; right: attribution map.
            fig, ax = plt.subplots(1, 2, figsize=(10, 5))
            ax[0].imshow(input_img, cmap='gray')
            ax[0].axis('off')
            ax[1].imshow(attr_ig, cmap='gray')
            ax[1].axis('off')
        plt.show()

    return attr_igs

In [None]:
def display_features(attrs, n_idxs=None, dataset=None, top_idx=None):
    """Render original image, positive heat map and masked image per neuron.

    Each attribution map is paired with its neuron's top-activating image,
    ``dataset[top_idx[n]][0]``, and drawn with captum's
    ``visualize_image_attr_multiple`` (one pyplot figure per neuron).

    Args:
        attrs: Sequence of 2-D attribution maps (e.g. from ``vis_features``).
        n_idxs: Neuron index for each entry of ``attrs``.  Defaults to
            ``0, 1, ..., len(attrs) - 1``, which was the old implicit
            assumption.  Pass it whenever ``attrs`` was computed for a subset
            or a different order of neurons; otherwise images and titles
            would be mismatched.
        dataset, top_idx: Optional overrides for the notebook globals
            ``data`` and ``arg_max``.  ``None`` reads the globals as before.

    Raises:
        ValueError: If ``n_idxs`` and ``attrs`` differ in length.
    """
    n_idxs = list(range(len(attrs))) if n_idxs is None else list(n_idxs)
    if len(n_idxs) != len(attrs):
        raise ValueError('n_idxs must have one entry per attribution map')
    dataset = data if dataset is None else dataset
    top_idx = arg_max if top_idx is None else top_idx

    for attr_map, n_idx in zip(attrs, n_idxs):
        # captum's image visualisers expect H x W x C arrays, so a trailing
        # channel axis is added.
        attr = np.expand_dims(attr_map, axis=2)
        orig = np.expand_dims(dataset[top_idx[n_idx]][0], axis=2)
        title = str(n_idx)
        viz.visualize_image_attr_multiple(
            attr,
            orig,
            ['original_image', 'heat_map', 'masked_image'],
            ['all', 'positive', 'positive'],
            use_pyplot=True,
            titles=[title, title, title],
            show_colorbar=True,
        )


## Initialize MNIST dataset

In [None]:
# Preprocessing applied before the network: PIL image -> float tensor in
# [0, 1] (ToTensor), then per-channel normalisation.
# NOTE(review): 0.1307 / 0.3081 look like the usual MNIST mean/std; confirm
# they match the preprocessing used when the checkpoint was trained.
transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
)

In [None]:
# MNIST test split (train=False), downloaded into images_path if missing.
data = MNIST(images_path, train=False, download=True)

In [None]:
# wnet(x, k) runs net[:k] on transform(x) in inference mode.
wnet = NetWrapper(net, transform)

In [None]:
wnet.net

In [None]:
# net[:3] = flatten -> fc1 -> act1: the prefix whose output V_X records.
wnet.net[:3]

In [None]:
# V_X: act1 activations (output of net[:3]) for every test image, one row per
# image.  X_V holds the raw images in the same order.
# NOTE(review): X_V / X_U are not referenced below, yet each keeps a copy of
# all test images in memory.
V_X, X_V = layer_V(data, wnet, k=3)

In [None]:
# U_X: act2 activations (output of net[:5]) for every test image.
U_X, X_U = layer_V(data, wnet, k=5)

In [None]:
# arg_max[j] = index of the test image that maximally activates neuron j of V_X.
arg_max = np.argmax(V_X, axis=0)
arg_max

In [None]:
# Per-neuron maximum activation over the test set.
np.max(V_X, axis=0)

In [None]:
# Full activation vector of the image that maximises neuron 13.
V_X[arg_max[13], :]

In [None]:
# One top-activating image per neuron, in neuron order.
show_grid(arg_max, data)

In [None]:
# Guided-backprop attributions for neurons 0..15 (hidden_size == 16), each on
# that neuron's top-activating image.  Keep the neurons in order 0..15:
# display_features pairs attr_all[i] with arg_max[i].
# NOTE(review): `n_steps` in the commented-out line is not defined anywhere
# in this notebook.
attr_all = vis_features(
    *list(range(16)),
    # func=NeuronGradientShap,
    func=NeuronGuidedBackprop,
    # n_steps=n_steps,
)

In [None]:
display_features(attr_all)

In [None]:
len(attr_all)

In [None]:
axs, fig

In [None]:
fig.

In [None]:
# Neuron indices (columns of V_X) singled out for the histogram and
# threshold analysis below.
n_A = 0
n_B = 9
n_C = 15

## Visualize the activation distributions

In [None]:
def features_hist(*n_Fs, V=None):
    """Plot a histogram of the activations of each requested neuron.

    Args:
        *n_Fs: Column (neuron) indices of the activation matrix.  One subplot
            per index, stacked vertically with a shared y-axis.
        V: 2-D activation matrix, one row per sample.  Defaults to the
            notebook global ``V_X``.  The previous body iterated an undefined
            name ``pv`` and always raised ``NameError``.  ``V_X`` is assumed
            to be what was meant, because ``n_A``/``n_B``/``n_C`` index its
            columns elsewhere -- confirm.
    """
    if not n_Fs:
        return
    activations = np.asarray(V_X if V is None else V)
    # squeeze=False keeps `axs` 2-D even for a single neuron.  With the
    # default squeeze, plt.subplots(1, 1) returns a bare Axes and `axs[r]`
    # fails.
    fig, axs = plt.subplots(len(n_Fs), 1, sharey=True, tight_layout=True,
                            squeeze=False)
    for r, n_F in enumerate(n_Fs):
        ax = axs[r, 0]
        ax.hist(activations[:, n_F])
        ax.set_title(f'neuron {n_F}')

In [None]:
# Histograms of the three selected neurons' activations.
features_hist(n_A, n_B, n_C)

## Analyze maximum stimuli

In [None]:
# Earlier candidate thresholds, kept for reference:
# th_A = 16
# th_B = 10

# th_A = 8
# th_B = 8

# Activation thresholds passed to select_top for neurons n_A / n_B / n_C.
th_A = 1.6
th_B = 1.2
th_C = 1.2

In [None]:
# Indices of images whose activation of neuron n_A is at least th_A.
G_A_v_A = select_top(V_X, n_A, th_A)

In [None]:
show_grid(G_A_v_A, data, nrow=48, h=16, w=32)

In [None]:
# Element-wise minimum of those images' activation vectors: a lower bound that
# every selected image meets.  NOTE(review): raises ValueError if the
# selection above is empty.
v_A = find_v_A(V_X, G_A_v_A)
v_A

In [None]:
# Same procedure for neuron n_B.
G_B_v_B = select_top(V_X, n_B, thresh=th_B)

In [None]:
show_grid(G_B_v_B, data, nrow=48, h=32, w=32)

In [None]:
v_B = find_v_A(V_X, G_B_v_B)
v_B

In [None]:
# Same procedure for neuron n_C.
G_C_v_C = select_top(V_X, n_C, thresh=th_C)

In [None]:
show_grid(G_C_v_C, data, nrow=48, h=32, w=32)

In [None]:
v_C = find_v_A(V_X, G_C_v_C)
v_C

In [None]:
# All images whose activation vector dominates v_A element-wise.  This is a
# superset of G_A_v_A, since every selected image meets its own lower bound.
G_A = find_G_x(V_X, v_A)
G_A

In [None]:
show_grid(G_A, data, nrow=48, h=16, w=32)

In [None]:
G_B = find_G_x(V_X, v_B)
G_B

In [None]:
show_grid(G_B, data, nrow=48, h=32, w=32)

In [None]:
G_C = find_G_x(V_X, v_C)
G_C

In [None]:
show_grid(G_C, data, nrow=48, h=32, w=32)

In [None]:
# Images dominating v_A, v_B and v_C simultaneously.
D = intersect(G_A, G_B, G_C)
D

In [None]:
show_grid(D, data, nrow=48, h=16, w=32)

In [None]:
# NOTE(review): raises IndexError if D is empty.
plt.imshow(data[D[0]][0])

In [None]:
# Element-wise maximum of the three lower bounds.  np.maximum is a binary
# ufunc: a third positional argument is its `out` buffer.  The old
# `np.maximum(v_A, v_B, v_C)` therefore ignored v_C, overwrote it in place
# with max(v_A, v_B), and returned that same array.  Reducing over all three
# bounds lets find_G_x(V_X, v_D) select exactly the images that dominate
# v_A, v_B and v_C together.
v_D = np.maximum.reduce([v_A, v_B, v_C])
v_D

In [None]:
# Images whose act1 activations dominate v_D.
G_v_D = find_G_x(V_X, v_D)
G_v_D

In [None]:
show_grid(G_v_D, data, nrow=48, h=16, w=32)

In [None]:
# Lower bound on the second-layer activations (U_X) of the images in D.
u_D = find_v_A(U_X, D)
u_D

In [None]:
U_X[D]

In [None]:
# Images whose second-layer activations dominate u_D.
G_u_D = find_G_x(U_X, u_D)
G_u_D

In [None]:
G_u_D.shape

In [None]:
show_grid(G_u_D, data, nrow=48, h=64, w=64)

In [None]:
# Label histogram of G_u_D: row 0 = digit labels, row 1 = their counts.
ys = np.array([data[idx][1] for idx in G_u_D])
un, cn = np.unique(ys, return_counts=True)
uncn = np.array([un, cn])
uncn

In [None]:
# NOTE(review): requires len(D) >= 2.
plt.imshow(data[D[1]][0])

In [None]:
# (Repeats the G_u_D.shape cell above.)
G_u_D.shape

In [None]:
# Ground-truth labels of the MNIST test split, in dataset order.  Read directly
# from `data.targets` instead of iterating the dataset, which decoded every
# image into a PIL object only to discard it.  This is equivalent because
# `data` was created without a target_transform.  np.array copies, so `y`
# does not alias the dataset's tensor.
y = np.array(data.targets)

In [None]:
# np.where returns a tuple of index arrays (one per dimension of y); the
# indices of digit-3 test images are y_3_idx[0].
y_3_idx = np.where(y == 3)

In [None]:
# NOTE(review): this is the length of the tuple (always 1 for 1-D y), not the
# number of 3s -- y_3_idx[0].shape gives that.
len(y_3_idx)

In [None]:
y_3_idx[0].shape, G_u_D.shape

In [None]:
# Digit-3 test images that are NOT in G_u_D.
result = y_3_idx[0][~np.isin(y_3_idx[0], G_u_D)]
result.shape

In [None]:
show_grid(result, data, nrow=48, h=16, w=32)

In [None]:
# (number of 3s inside G_u_D, size of G_u_D)
np.count_nonzero(y[G_u_D] == 3), G_u_D.shape