In [None]:
# Reload edited Python modules automatically before each cell executes
# (useful for any local modules imported via the sys.path setup below).
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Install libraries

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
conda activate edu4
```

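Then, with the `edu4` environment active, install the dependencies from a shell. Inside a notebook cell, use `%pip install ...` instead so the packages go into the kernel's environment.

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```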

## Update repository

In [None]:
# Update the git checkout in the current working directory (IPython `!` shell escape).
! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Make modules located one directory above the notebook importable, without
# appending a duplicate sys.path entry when the cell is re-run.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
# The path string is only needed for the sys.path setup above; remove it so
# it does not linger in the notebook namespace.
del module_path

## Organize imports

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from collections import OrderedDict

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from tqdm import tqdm

In [None]:
import plotly.express as px

In [None]:
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.discriminant_analysis import (
    LinearDiscriminantAnalysis, 
    QuadraticDiscriminantAnalysis
)
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier 
from sklearn.datasets import (
    load_iris,
    load_wine
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import (
    MaxAbsScaler,
    MinMaxScaler,
    StandardScaler
)
from sklearn.metrics import (
    classification_report,
    confusion_matrix
)

In [None]:
from scipy import stats
from scipy.interpolate import interp1d

In [None]:
import torch

In [None]:
from torch import nn
from torchvision import transforms
import torchvision.transforms.functional as F
from torchvision.utils import make_grid
from torchvision.datasets import MNIST

### Number of CPU cores

In [None]:
# Number of CPUs reported by the OS, displayed for reference.
# NOTE(review): `workers` is not used anywhere else in the visible notebook.
workers = multiprocessing.cpu_count()
workers

In [None]:
# NOTE(review): SEED is never used to seed an RNG in the visible code; the
# analysis below only runs inference on a fixed dataset. Use it or drop it.
SEED = 2024

## Initialize paths

In [None]:
# Everything lives under a relative `data/` directory.
PATH = Path('data')
# Trained checkpoint. NOTE(review): "chechpoint" looks like a typo, but the
# string has to match the file on disk, so it is kept as-is here.
model_dir = PATH.joinpath('models')
model_path = model_dir.joinpath('model_chechpoint.ckpt')
# MNIST is downloaded here and the rendered image grids are saved here too.
images_path = PATH.joinpath('images')
images_path.mkdir(parents=True, exist_ok=True)
# NOTE(review): not used in the visible notebook.
pumpkin_path = PATH.joinpath('Pumpkin_Seeds_Dataset.xlsx')

## Model wrapper

In [None]:
class NetWrapper:
    """Preprocess an input and run it through the first ``k`` modules of a net.

    ``net`` must support slicing (``net[:k]``), e.g. an ``nn.Sequential``; it is
    put into eval mode when the wrapper is built. ``transform`` is applied to
    every raw input before the forward pass.
    """

    def __init__(self, net, transform):
        self.transform = transform
        self.net = net.eval()

    @torch.inference_mode()
    def forward(self, x, k=6):
        """Return the output of ``net[:k]`` applied to ``transform(x)``.

        Runs under ``torch.inference_mode``, so no autograd graph is recorded.
        For the six-module ``net`` built in this notebook, ``k=6`` runs the
        whole network and ``k=5`` stops after the second ReLU.
        """
        head = self.net[:k]
        return head(self.transform(x))

    def __call__(self, *args, **kwargs):
        """Forward all arguments to :meth:`forward`."""
        return self.forward(*args, **kwargs)
    

## Load the model

In [None]:
# Load the checkpoint onto the CPU. It is used below as a mapping with a
# 'state_dict' entry (presumably a Lightning-style checkpoint -- confirm).
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. Depending on the installed PyTorch version, `weights_only`
# may default to True and reject checkpoints that contain arbitrary Python
# objects -- consider passing it explicitly.
model = torch.load(model_path, map_location='cpu')

In [None]:
model

In [None]:
def clear_state_dict(state_dict, replacements=None):
    """Rename checkpoint parameter keys to the layer names used by ``net``.

    By default the positional prefixes ``model.1.``, ``model.4.`` and
    ``model.7.`` are mapped to ``fc1.``, ``fc2.`` and ``fc3.``, matching the
    named modules of the ``nn.Sequential`` built in this notebook.

    Args:
        state_dict: mapping of parameter name -> value. It is modified in
            place: keys are renamed and their relative order is preserved.
        replacements: optional ``{old_substring: new_substring}`` mapping.
            For every key, each substitution is applied in order with
            ``str.replace`` (anywhere in the key, not only as a prefix).
            Defaults to the ``model.{1,4,7}.`` -> ``fc{1,2,3}.`` mapping.

    Returns:
        The same ``state_dict`` object, for convenience.
    """
    if replacements is None:
        replacements = {
            'model.1.': 'fc1.',
            'model.4.': 'fc2.',
            'model.7.': 'fc3.',
        }
    # Snapshot the keys first: the dict is mutated while we walk it.
    for key in list(state_dict.keys()):
        new_key = key
        for old, new in replacements.items():
            new_key = new_key.replace(old, new)
        # pop + re-insert moves each key to the end; doing it for every key
        # in order leaves the original relative order intact.
        state_dict[new_key] = state_dict.pop(key)
    return state_dict

In [None]:
# Rename the checkpoint's positional layer keys to the fc1/fc2/fc3 names used
# by `net` below. clear_state_dict renames keys in place, so
# model['state_dict'] is modified as well (re-running the cell is harmless:
# already-renamed keys no longer match).
state_dict = clear_state_dict(model['state_dict'])

In [None]:
state_dict

In [None]:
# Geometry of one MNIST image (1 x 28 x 28) and the MLP layer sizes; these
# must agree with the parameter shapes stored in the checkpoint.
channels, height, width = 1, 28, 28
hidden_size, num_classes = 16, 10
in_features = channels * height * width  # flattened input length (784)

In [None]:
# MLP with named modules so the renamed checkpoint keys (fc1/fc2/fc3) line up
# for load_state_dict below. Six modules in total: net[:5] ends after act2
# (hidden_size units), net[:6] is the whole network.
net = nn.Sequential(OrderedDict([
    ('flatten', nn.Flatten()),
    ('fc1', nn.Linear(in_features, hidden_size)),
    ('act1', nn.ReLU()),
    ('fc2', nn.Linear(hidden_size, hidden_size)),
    ('act2', nn.ReLU()),
    ('fc3', nn.Linear(hidden_size, num_classes)),
]))

In [None]:
net

In [None]:
# load_state_dict is strict by default, so this raises on any missing or
# unexpected key -- success confirms the key renaming above.
net.load_state_dict(state_dict)

In [None]:
# The net has no dropout/batch-norm, so eval() does not change its outputs;
# NetWrapper also calls eval() on construction.
net = net.eval()
net

## Helper functions

In [None]:
def layer_V(data, net, k=5):
    """Collect the layer-``k`` activation vector of every sample in ``data``.

    Args:
        data: iterable of ``(input, label)`` pairs; labels are ignored.
        net: callable invoked as ``net(x, k=k)`` returning a tensor whose
            first row is the activation vector of ``x`` (e.g. ``NetWrapper``).
        k: number of leading modules to run.

    Returns:
        ``(V, X)``: ``V`` is an ``np.ndarray`` with one activation vector per
        row, ``X`` is the list of raw inputs in the same order.
    """
    activations, inputs = [], []
    with tqdm(data) as progress:
        for sample, _label in progress:
            out = net(sample, k=k)
            activations.append(out.detach().numpy()[0])
            inputs.append(sample)
    return np.array(activations), inputs

In [None]:
def loop_maxes(V, func, *args, **kwargs):
    """Call ``func(i, v, *args, **kwargs)`` for every row ``v`` of ``V``.

    ``i`` is the zero-based position of the row. Iteration is wrapped in a
    ``tqdm`` progress bar; return values of ``func`` are discarded.
    """
    with tqdm(V) as rows:
        for position, row in enumerate(rows):
            func(position, row, *args, **kwargs)

In [None]:
def select_top(V, idx, thresh):
    """Return indices of the rows of ``V`` whose unit ``idx`` reaches ``thresh``.

    Args:
        V: 2-D array-like of activations, one row per sample.
        idx: column (unit) index to test.
        thresh: inclusive lower bound; row ``i`` is kept when
            ``V[i, idx] >= thresh`` (NaN never qualifies).

    Returns:
        List of Python ``int`` row indices in ascending order (empty if none).
    """
    if len(V) == 0:
        return []
    # One vectorised comparison on the column instead of a per-row Python
    # callback loop.
    column = np.asarray(V)[:, idx]
    return np.flatnonzero(column >= thresh).tolist()

In [None]:
def find_v_x(V, mrng, idx):
    """Among rows ``mrng`` of ``V``, find the one with the weakest unit ``idx``.

    Args:
        V: 2-D array-like of activations (indexable by row).
        mrng: sequence of row indices to search (e.g. from ``select_top``).
        idx: column (unit) index to minimise.

    Returns:
        ``(V[x_id], x_id)``, where ``x_id`` is the element of ``mrng`` whose
        row has the smallest value in column ``idx`` (the first one on ties).

    Raises:
        ValueError: if ``mrng`` is empty.
    """
    candidates = np.asarray(V)[mrng, idx]
    x_id = mrng[int(np.argmin(candidates))]
    return V[x_id], x_id

In [None]:
def find_v_A(V, mrng):
    """Element-wise minimum over the rows ``mrng`` of ``V``.

    The result is the per-unit lower envelope of the selected activation
    vectors, so every selected row is >= it in every component.

    Raises:
        ValueError: if ``mrng`` selects no rows.
    """
    selected = np.asarray(V)[mrng]
    return selected.min(axis=0)

In [None]:
def find_G_x(V, v_x):
    """Return indices of the rows of ``V`` that dominate ``v_x`` component-wise.

    Row ``i`` is selected when ``v_x <= V[i]`` holds for every unit.

    Args:
        V: 2-D array-like of activations, one row per sample.
        v_x: 1-D reference vector, broadcastable against a row of ``V``.

    Returns:
        Ascending ``np.ndarray`` of row indices. Its dtype is always integer,
        even when no row qualifies, so it can be used directly to index other
        arrays (a bare ``np.array([])`` would be float64 and break indexing).
    """
    if len(V) == 0:
        return np.array([], dtype=np.intp)
    dominates = np.all(np.asarray(v_x) <= np.asarray(V), axis=1)
    return np.flatnonzero(dominates)

In [None]:
def show_img(ds, idx, cmap='gray'):
    """Display the image of sample ``idx`` of dataset ``ds`` with matplotlib.

    Args:
        ds: indexable dataset whose items are sequences with the image first
            (e.g. ``(PIL image, label)`` pairs from ``MNIST``).
        idx: index of the sample to show.
        cmap: colormap for single-channel images. Defaults to ``'gray'`` so a
            digit looks the same as in the grids drawn by ``show_grid``
            (``make_grid`` expands one-channel images to gray RGB). Pass
            ``None`` for matplotlib's default colormap (the previous look).
            matplotlib ignores it for RGB(A) images.
    """
    plt.imshow(ds[idx][0], cmap=cmap)

In [None]:
def show(imgs, h=12, w=12, img_path=None):
    """Plot one tensor image, or a list of them, side by side in one row.

    Args:
        imgs: a tensor image accepted by ``to_pil_image`` (e.g. C x H x W),
            or a list of such tensors.
        h: figure height in inches.
        w: figure width in inches.
        img_path: if truthy, the figure is also saved there before showing.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(images), figsize=(w, h), squeeze=False)
    for ax, tensor in zip(axes[0], images):
        pil_image = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_image))
        # Hide ticks and tick labels: these are pictures, not data plots.
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if img_path:
        plt.savefig(img_path)
    plt.show()

In [None]:
def show_grid(G_A, data, nrow=8, h=12, w=12, img_path=None):
    """Draw the images of ``data`` selected by the indices ``G_A`` as one grid.

    Args:
        G_A: iterable of sample indices into ``data``.
        data: dataset whose items are ``(image, label)`` with an image that
            ``transforms.ToTensor`` accepts (e.g. ``MNIST`` without a transform).
        nrow: number of images per grid row (passed to ``make_grid``).
        h: figure height in inches.
        w: figure width in inches.
        img_path: optional path the figure is saved to.

    An empty selection is reported and skipped. ``make_grid`` cannot stack an
    empty list, so a threshold that leaves a group empty would otherwise
    abort a Run-All.
    """
    to_tensor = transforms.ToTensor()
    tiles = [to_tensor(data[i][0]) for i in G_A]
    if not tiles:
        print('show_grid: empty selection, nothing to draw')
        return
    grid = make_grid(tiles, nrow=nrow)
    show(grid, h=h, w=w, img_path=img_path)

## Initialize MNIST dataset

In [None]:
# Preprocessing applied by NetWrapper to each raw PIL image from `data`
# (the dataset itself is created without a transform).
# NOTE(review): the constants are presumably the MNIST training-set mean/std
# and must match the normalisation used when the checkpoint was trained.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

In [None]:
# MNIST test split, downloaded into images_path if missing. No transform is
# given, so items are (PIL image, int label) pairs.
data = MNIST(images_path, train=False, download=True)

In [None]:
# Normalise each raw image with `transform` before it reaches the net.
wnet = NetWrapper(net, transform)

In [None]:
# Activations after the first five modules (flatten, fc1, act1, fc2, act2):
# the hidden_size ReLU outputs of the second hidden layer, one row per test
# image. X_V holds the raw PIL inputs in the same order.
V_X, X_V = layer_V(data, wnet, k=5)

In [None]:
# For each unit, the index of the test image that activates it most...
np.argmax(V_X, axis=0)

In [None]:
# ...and that maximal activation value.
np.max(V_X, axis=0)

In [None]:
# NOTE(review): hard-coded sample index, presumably picked from the argmax
# output above; it is not derived in code.
V_X[3290, :]

In [None]:
# NOTE(review): hard-coded sample index (see above).
show_img(data, 8140)

## Analyze maximum stimuli

In [None]:
# Activation thresholds used below to select the strongest stimuli of unit 2
# (th_A) and unit 5 (th_B) of the layer-5 activations V_X.
# Other (th_A, th_B) pairs explored previously: (16, 10) and (6, 6).
th_A = 8
th_B = 8

In [None]:
# Test images whose unit-2 activation is at least th_A (row indices of V_X).
G_A_v_A = select_top(V_X, 2, th_A)

In [None]:
show_grid(G_A_v_A, data, nrow=48, h=16, w=32, img_path=images_path / f'G_A_v_A.png')

In [None]:
# The member of that group with the smallest unit-2 activation.
v_a, a_id = find_v_x(V_X, G_A_v_A, 2)
v_a, a_id

In [None]:
show_img(data, a_id)

In [None]:
# Component-wise minimum of the group's activation vectors: every member of
# G_A_v_A is >= v_A in every unit.
v_A = find_v_A(V_X, G_A_v_A)
v_A

In [None]:
# Same procedure for unit 5 with threshold th_B.
G_B_v_B = select_top(V_X, 5, thresh=th_B)

In [None]:
show_grid(G_B_v_B, data, nrow=48, h=16, w=32, img_path=images_path / f'G_B_v_B.png')

In [None]:
v_b, b_id = find_v_x(V_X, G_B_v_B, 5)
v_b, b_id

In [None]:
show_img(data, b_id)

In [None]:
v_B = find_v_A(V_X, G_B_v_B)
v_B

In [None]:
# All test images whose activation vector dominates v_A component-wise. By
# construction of v_A this is a superset of G_A_v_A.
G_A = find_G_x(V_X, v_A)
G_A

In [None]:
show_grid(G_A, data, nrow=48, h=16, w=32, img_path=images_path / f'G_A.png')

In [None]:
# Likewise for v_B (superset of G_B_v_B).
G_B = find_G_x(V_X, v_B)
G_B

In [None]:
show_grid(G_B, data, nrow=48, h=16, w=32, img_path=images_path / f'G_B.png')

In [None]:
# Images in both G_A and G_B (sorted, unique indices).
D = np.intersect1d(G_A, G_B)
D

In [None]:
show_grid(D, data, nrow=48, h=16, w=32, img_path=images_path / f'D.png')

In [None]:
# NOTE(review): hard-coded sample index, presumably taken from D above.
plt.imshow(data[6794][0])

In [None]:
# Component-wise maximum of the two lower envelopes.
v_D = np.maximum(v_A, v_B)
v_D

In [None]:
# A row is >= max(v_A, v_B) exactly when it is >= v_A and >= v_B, so this
# reproduces D (the intersection of G_A and G_B); it acts as a consistency check.
G_5_D = find_G_x(V_X, v_D)
G_5_D

In [None]:
show_grid(G_5_D, data, nrow=48, h=16, w=32, img_path=images_path / f'G_5_D.png')

In [None]:
# Outputs of the whole network (k=6, i.e. the raw fc3 scores: num_classes
# values per image, no softmax). NOTE(review): this re-runs the first five
# modules for every image; applying net.fc3 to V_X would give the same values.
U_X, X_U = layer_V(data, wnet, k=6)

In [None]:
# Component-wise minimum of the output scores over the images in D.
u_D = find_v_A(U_X, D)
u_D

In [None]:
U_X[D]

In [None]:
# All test images whose output scores dominate u_D (a superset of D).
G_u_D = find_G_x(U_X, u_D)
G_u_D

In [None]:
show_grid(G_u_D, data, nrow=48, h=16, w=32, img_path=images_path / 'G_u_D.png')

In [None]:
# NOTE(review): hard-coded sample index, presumably taken from G_u_D above.
plt.imshow(data[9390][0])

In [None]:
G_u_D.shape

In [None]:
# Ground-truth labels of the test split, in dataset order. Reading the
# `targets` attribute gives the same values as iterating the dataset
# (no target_transform is set on `data`) without decoding every image into a
# PIL object. np.array copies, so later edits to `y` cannot touch the dataset.
y = np.array(data.targets)

In [None]:
# Row indices of the test images labelled 3. np.where with a single condition
# returns a 1-tuple of index arrays, hence the `[0]` used in later cells.
y_3_idx = np.where(y == 3)

In [None]:
# np.where(...) returned a 1-tuple, so len(y_3_idx) is always 1; the number of
# test images labelled 3 is the length of its first element.
len(y_3_idx[0])

In [None]:
# Number of 3s in the test split vs. size of G_u_D.
y_3_idx[0].shape, G_u_D.shape

In [None]:
# Images labelled 3 that are NOT in G_u_D. Because y_3_idx[0] is sorted and
# unique, this equals np.setdiff1d(y_3_idx[0], G_u_D).
result = y_3_idx[0][~np.isin(y_3_idx[0], G_u_D)]
result.shape

In [None]:
# Show those 3s (this grid is not saved to disk, unlike the others).
show_grid(result, data, nrow=48, h=16, w=32)

In [None]:
# How many members of G_u_D are actually labelled 3, next to the size of
# G_u_D; presumably a measure of how well G_u_D matches class 3.
np.count_nonzero(y[G_u_D] == 3), G_u_D.shape