In [None]:
# Reload imported modules (e.g. the src.lattmc helpers) automatically before each cell runs,
# so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the cell output.
%matplotlib inline

## Disable the Google Translate widget

In [None]:
from IPython.display import display, Javascript

# Front-end workaround: poll once per second until the Google Translate API object
# exists in the page, then stop polling, replace its TranslateElement constructor
# with a no-op and remove the #google_translate_element node (presumably so the
# translate widget does not interfere with the notebook UI).
display(Javascript('''
(function() {
    var interval = setInterval(function() {
        if (typeof google !== 'undefined' && google.translate && google.translate.TranslateElement) {
            clearInterval(interval);
            google.translate.TranslateElement = function() {};
            document.getElementById('google_translate_element')?.remove();
        }
    }, 1000);
})();
'''))

## Update repository

In [None]:
# Pull the latest commits of the repository this notebook lives in
# (presumably the one providing the `src.lattmc` package imported below).
! git pull

## Add import path

In [None]:
import os
import sys
import gc

In [None]:
# Make the parent of the notebook's working directory (the repository root)
# importable so the `src.lattmc.fca` star-imports below resolve; skip the append
# when it is already on the path to avoid duplicate entries.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
# The helper name is no longer needed once sys.path is set; keep the namespace clean.
del module_path

## Organize imports

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from collections import OrderedDict

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [None]:
from tqdm import tqdm

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *

#### Number of CPU cores

In [None]:
# Number of CPUs reported by the OS, displayed for reference.
workers = multiprocessing.cpu_count()
workers

In [None]:
SEED = 2024
# Apply the seed to every RNG the notebook draws from: the synthetic-shape
# dataset uses the global NumPy RNG and DataLoader shuffling uses the torch RNG.
# Without this, those cells differ on every "Restart & Run All".
torch.manual_seed(SEED)
np.random.seed(SEED)

## Initialize Path

In [None]:
# All paths are relative to the notebook's working directory.
PATH = Path('data')
# Trained CNN checkpoint, loaded in the next section.
model_dir = PATH / 'models'
model_path = model_dir / 'simple_cnn_mnist_model.ckpt'
# JSON neuron-configuration files; the directory is created if missing.
config_dir = PATH / 'config'
config_dir.mkdir(exist_ok=True, parents=True)
config_1_layer_path = config_dir / 'neurons_1_layer.json'
config_cnn_layer_1 = config_dir / 'neurons_cnn_1_layer.json'
# Root directory for the MNIST download (created if missing).
images_dir = PATH / 'images'
images_dir.mkdir(exist_ok=True, parents=True)
# NOTE(review): pumpkin_path appears unused in the MNIST experiments of this notebook.
pumpkin_path = PATH / 'Pumpkin_Seeds_Dataset.xlsx'

## Load the model

In [None]:
# Load the training checkpoint onto the CPU. It is a dict, not an nn.Module:
# only its 'state_dict' entry is used below (presumably a PyTorch Lightning .ckpt).
# torch.load unpickles arbitrary objects, so only load trusted checkpoints.
# NOTE(review): on PyTorch >= 2.6 `weights_only` defaults to True, which may reject
# checkpoints that pickle non-tensor objects; pass weights_only=False only for trusted files.
model = torch.load(model_path, map_location='cpu')

In [None]:
def clear_state_dict(state_dict, key_map=None):
    """Return a copy of `state_dict` with the checkpoint's layer prefixes renamed.

    The checkpoint stores parameters under positional ``nn.Sequential`` names
    (``model.0.``, ``model.3.``, ``model.8.``, ``model.11.``). The network
    rebuilt in this notebook uses named layers (``conv1.``, ``conv2.``,
    ``fc1.``, ``fc2.``).

    Args:
        state_dict: Mapping of parameter name -> tensor. It is not modified.
        key_map: Optional mapping ``old -> new``. Every pair is applied to each
            key, in order, with ``str.replace``, so the old text may occur
            anywhere in the key. Defaults to the four renames listed above.

    Returns:
        A new ``OrderedDict`` with the renamed keys, in the input's key order.
        A ``_metadata`` attribute on the input, if present, is carried over.
    """
    if key_map is None:
        key_map = {
            'model.0.': 'conv1.',
            'model.3.': 'conv2.',
            'model.8.': 'fc1.',
            'model.11.': 'fc2.',
        }

    renamed = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old, new in key_map.items():
            new_key = new_key.replace(old, new)
        renamed[new_key] = value

    # torch state dicts may carry version info in `_metadata`; keep it attached
    # as the in-place variant of this function did.
    metadata = getattr(state_dict, '_metadata', None)
    if metadata is not None:
        renamed._metadata = metadata
    return renamed

In [None]:
# Rename the checkpoint's positional parameter keys to this notebook's layer names.
state_dict = clear_state_dict(model['state_dict'])

In [None]:
# Rebuild the trained CNN with named layers. The names conv1/conv2/fc1/fc2 must match
# the keys produced by clear_state_dict. The positional indices noted per line are
# what wnet.net[...] indexes (and presumably what layer_V_n / layer_U_n count).
# Output shapes assume a 1x28x28 input, consistent with fc1 expecting 64 * 7 * 7 features.
net = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)),   # 0: -> 32x28x28
    ('act1', nn.ReLU()),                                               # 1
    ('mxp1', nn.MaxPool2d(kernel_size=2, stride=2)),                   # 2: -> 32x14x14
    ('conv2', nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)),  # 3: -> 64x14x14
    ('act2', nn.ReLU()),                                               # 4
    ('mxp2', nn.MaxPool2d(kernel_size=2, stride=2)),                   # 5: -> 64x7x7
    ('flatten', nn.Flatten()),                                         # 6: -> 3136
    ('fc1', nn.Linear(64 * 7 * 7, 128)),                               # 7
    ('act3', nn.ReLU()),                                               # 8
    ('fc2', nn.Linear(128, 10)),                                       # 9: 10 outputs, one per digit
]))

In [None]:
# Strict load (the default): raises on missing or unexpected keys, which also
# validates the key renaming above.
net.load_state_dict(state_dict)

In [None]:
# Switch to inference mode (no dropout/batch-norm here, but the correct default).
net = net.eval()

## Initialize MNIST dataset

In [None]:
# Preprocessing handed to NetWrapper below: convert to a tensor, then normalise
# with the commonly used MNIST mean/std (0.1307, 0.3081).
# NOTE(review): `ToTensor` is not imported from torchvision in this notebook; it
# presumably comes from one of the `src.lattmc.fca` star-imports — confirm it
# behaves like `transforms.ToTensor`.
transform = transforms.Compose(
            [
                ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
)

In [None]:
# Raw MNIST train/test splits, downloaded into images_dir on first use. No transform
# is attached, so items come back un-normalised. `transform` above is presumably
# applied inside NetWrapper instead — confirm.
data_train = MNIST(images_dir, train=True, download=True)
data_test = MNIST(images_dir, train=False, download=True)

In [None]:
# Device of the network's parameters (weights came from a CPU-mapped checkpoint).
next(net.parameters()).device

In [None]:
# Device chosen by the project's find_device helper.
# NOTE(review): `device` is only displayed here, not passed to NetWrapper below —
# confirm NetWrapper selects its own device (see wnet.device).
device = find_device()
device

In [None]:
# Wrap network + preprocessing. Below it is called as wnet(x) for the network
# output and as wnet(x, k=...) for intermediate activations.
wnet = NetWrapper(net, transform)

In [None]:
wnet.net

In [None]:
# First six modules (conv1 ... mxp2).
wnet.net[:6]

In [None]:
# Cut points of the two activation spaces analysed below (V and U).
# NOTE(review): presumably `k` is the number of leading modules run, i.e. V is the
# output after mxp1 (32x14x14) and U the output after mxp2 (64x7x7) — confirm.
layer_V_n = 3
layer_U_n = 6

In [None]:
wnet.device

In [None]:
# Batch size passed to layer_V for activation extraction.
bs = 8

In [None]:
# Layer-V activations of every train image, extracted in batches of `bs`.
# NOTE(review): layer_V returns a pair. The first element (V_X_*) is indexed below
# as [sample, ...neuron index...]; the meaning of the second (X_V_*) is not shown
# in this notebook — confirm.
V_X_train, X_V_train = layer_V(data_train, wnet, k=layer_V_n, bs=bs)

In [None]:
V_X_test, X_V_test = layer_V(data_test, wnet, k=layer_V_n, bs=bs)

In [None]:
# Same helper at the deeper cut layer_U_n gives the layer-U activations.
U_X_train, X_U_train = layer_V(data_train, wnet, k=layer_U_n, bs=bs)

In [None]:
U_X_test, X_U_test = layer_V(data_test, wnet, k=layer_U_n, bs=bs)

In [None]:
# For every neuron, the index of the training image with the largest activation
# (argmax over the sample axis 0).
arg_max = np.argmax(V_X_train, axis=0)
arg_max.shape

In [None]:
# Sample indices sorted by ascending activation, separately for every neuron, so
# arg_top[-k:] holds the top-k activating images (strongest last).
# NOTE(review): this is a full sort producing an index array as large as V_X_train;
# np.argpartition would be cheaper if only the top-k rows are needed.
arg_top = np.argsort(V_X_train, axis=0)
arg_top.shape

In [None]:
i = 0

In [None]:
arg_top[-10:].shape

In [None]:
arg_top[-10:,:, :].shape

In [None]:
# V_X_train[arg_top[-16:,1, 9, 2]][:,1, 9, 2]

In [None]:
# Top-16 activating training images for the neuron at index (1, 9, 2)
# (presumably channel 1, row 9, column 2).
show_grid(arg_top[-16:,1, 9, 2], data_train, nrow=32)

In [None]:
# np.max(V_X_train, axis=0)

In [None]:
# V_X_train[arg_max[13], :]

In [None]:
# Argmax training image for every position of (presumably) channel 1.
show_grid(arg_max[1,:], data_train, nrow=14)

In [None]:
# show_grid(arg_max[0], data_train, nrow=14)

## Sorting vectors

In [None]:
# Group the training layer-V vectors by digit label and sort them per digit, so that
# V_X_sorteds[d][r] (used below) is the vector at rank r for digit d.
# NOTE(review): grouping and sort key are defined inside sort_V_X — confirm.
V_X_digits, V_X_sorteds = sort_V_X(V_X_train, data_train)

## Analyze maximum stimulus

In [None]:
v_Ds = dict()
u_Ds = dict()
G_v_tests = dict()
G_u_tests = dict()

In [None]:
i = 9
ths = [
    568, #0
    330, #1
    672, #2
    580, #3
    470, #4
    590, #5
    484, #6
    544, #7
    640, #8
    584  #9
]
v = np.copy(V_X_sorteds[i][ths[i]])
# v[13] = 0.0

In [None]:
for i in range(10):
    layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
    G_v, v_D, u_D, G_u, G_v_test, G_u_test, uncn_reps = layer_fca.find_G_v_us(
        v, 
        V_X_test, 
        U_X_test, 
        data_test
    )
    v_Ds[i] = v_D
    u_Ds[i] = u_D
    G_v_tests[i] = G_v_test
    G_u_tests[i] = G_u_test

In [None]:
# `uncn_reps` as left by the last iteration of the loop above (digit 9).
uncn_reps

In [None]:
# Test images in the stored V-side test extent for digit i (i == 9 after the loop).
show_grid(G_v_tests[i], data_test, nrow=32)

In [None]:
# Same for the U-side test extent.
show_grid(G_u_tests[i], data_test, nrow=32)

In [None]:
# Network predictions for the test images in G_u_test (from the last loop iteration).
if len(G_u_test) == 0:
    y_hs = np.array([], dtype=int)
else:
    logits = wnet(*[data_test[idx][0] for idx in G_u_test])
    # NOTE(review): wnet's full forward returns a tensor (see the `.to('cpu')`
    # usage in the inference section); convert before handing it to NumPy.
    if isinstance(logits, torch.Tensor):
        logits = logits.detach().cpu().numpy()
    # Reduce over the class axis -> one predicted label per image. Without `axis`,
    # np.argmax returns a single index into the flattened (n_images * 10) array.
    y_hs = np.argmax(np.asarray(logits), axis=-1)

In [None]:
# Tally of the network's predictions over G_u_test, via the LayerFCA helper
# (presumably comparable with layer_fca.uncn below).
uncn_hat = layer_fca.count_ys(y_hs)

In [None]:
uncn_hat

In [None]:
# net[3].weight[:, 0], net[3].bias

In [None]:
layer_fca.uncn

In [None]:
show_grid(G_u_test, data_test, nrow=48, h=64, w=64)

In [None]:
# NOTE(review): `my=i` presumably marks images relative to digit i (9 after the
# loop) — confirm in show_grid.
show_grid(G_u_test, data_test, nrow=48, h=64, w=64, my=i)

## Visualization of distribution

In [None]:
argmax_kd_val(V_X_sorteds[idx][ths[idx]][2])

In [None]:
V_X_sorteds[idx][ths[idx]][2][7, 9]

In [None]:
idx = 7
v_test = V_X_sorteds[idx][ths[idx]]
# v_test = V_X_test[0]

In [None]:
visualize_slices(v_test)

In [None]:
# Training images grouped by digit (presumably digits_train[d] holds the images of digit d).
digits_train = get_digits(data_train)

In [None]:
# digits_train[0]

In [None]:
# Load an example image
# NOTE(review): `example_image` is not used by the call below, which visualises
# digits_train[7][8] instead.
example_image, _ = data_test[i]
# example_image = example_image.unsqueeze(0)  # Add batch dimension

# Visualize the activations
# (layers 3 and 6, the same cut points as layer_V_n / layer_U_n)
acts = visualize_activations(wnet, digits_train[7][8], layers=[3, 6], hist=False)

In [None]:
digits_train[0][:4]

In [None]:
# Layer-3 activation of every training image of digit 0.
res = [wnet(x, k=3) for x in digits_train[0]]

In [None]:
# For each of the 32 channels, collect that channel's map from every digit-0
# image (res_k[k]) and intersect them (int_k[k]) with intersect_xd
# (presumably an element-wise meet/minimum).
res_k = list()
int_k = list()
with tqdm(list(range(32))) as pange:
    for k in pange:
        r_k = [r[k] for r in res]
        res_k.append(r_k)
        int_k.append(intersect_xd(*r_k))

In [None]:
visualize_slices(int_k)

In [None]:
# Shape of the intersected map for channel 30. The per-channel intersections live
# in `int_k`; no `int_30` variable is defined.
int_k[30].shape

In [None]:
# Intersected map of channel 28 over all digit-0 images.
show_activation(int_k[28])

In [None]:
# Channel 1 of the first requested layer's activations (presumably layer 3).
show_activation(acts[0][1])

In [None]:
# Positions whose activation reaches the hand-picked threshold 1.2, and their values.
indices1 = np.where(acts[0][1] >= 1.2)
indices1, acts[0][1][indices1]

In [None]:
show_activation(acts[0][8])

In [None]:
# Same for channel 8 with threshold 1.4.
indices2 = np.where(acts[0][8] >= 1.4)

In [None]:
acts[0][8][indices2]

In [None]:
acts[0][1].shape

In [None]:
# Number of channel-1 positions above the 1.2 threshold. `idcs` is undefined;
# `indices1` is the index tuple computed from acts[0][1].
acts[0][1][indices1].shape

In [None]:
# Module at index 3 of the Sequential: conv2.
wnet.net[3]

In [None]:
# Visualize weights of the first convolutional layer
visualize_weights(wnet.net[0], num_filters=32)

# Visualize weights of the second convolutional layer
visualize_weights(wnet.net[3], num_filters=64)

In [None]:
V_X_ds[0].shape

In [None]:
cprs = np.array([np.all(V_X_ds[0][0] <= V_X_d) or np.all(V_X_d < V_X_ds[0][0]) for V_X_d in V_X_ds[0]])

In [None]:
np.where(cprs)

## Inference with FCA

In [None]:
class InferFCA(object):
    """Concept-matching classifier on top of one hidden layer of a wrapped network.

    Each label ``n`` is represented by a layer vector ``us[n]``. An input matches
    label ``n`` when ``le(us[n], u_x)`` holds for its layer activation ``u_x``.
    Among the matches, the label whose vector gives the largest ``dist`` to
    ``u_x`` is reported.

    Args:
        model: callable invoked as ``model(x, k=layer)``.
        layer: value forwarded as ``k``.
        us: mapping label -> concept vector.
    """

    def __init__(self, model, layer, us):
        self.model, self.layer, self.us = model, layer, us

    def forward(self, x):
        """Return ``(matching_labels, chosen_label)`` for one input.

        ``matching_labels`` follows the iteration order of ``self.us``.
        ``chosen_label`` is -1 when nothing matches.
        NOTE(review): the match with the *largest* ``dist`` is chosen; confirm
        this (rather than the closest concept) is intended.
        """
        u_x = self.model(x, k=self.layer)
        hits = [(label, dist(u, u_x)) for label, u in self.us.items() if le(u, u_x)]
        labels = [label for label, _ in hits]
        if not hits:
            return labels, -1
        best = int(np.argmax(np.array([d for _, d in hits])))
        return labels, labels[best]

    def forward_all(self, data):
        """Run `forward` over ``(x, y)`` pairs (``y`` is ignored) with a progress bar.

        Returns ``(all_matches, all_choices, n_matched)``. ``n_matched`` counts
        the inputs that matched at least one concept.
        """
        all_matches, all_choices = list(), list()
        with tqdm(data) as progress:
            for x, _label in progress:
                matches, choice = self.forward(x)
                all_matches.append(matches)
                all_choices.append(choice)
        n_matched = sum(1 for matches in all_matches if matches)
        return all_matches, all_choices, n_matched

In [None]:
# Classify every test image by which per-digit layer-U concepts (u_Ds) it matches.
# y_fcs: matching digits per image; y_fds: chosen digit (-1 if none);
# cn: number of images with at least one match.
inferFCA = InferFCA(wnet, layer_U_n, u_Ds)
y_fcs, y_fds, cn = inferFCA.forward_all(data_test)

In [None]:
# Ambiguous test images (matched by more than one concept): index, number of matches,
# matched digits, true label, FCA choice, network prediction. Rebinds the global `i`.
for i, y_f in enumerate(y_fcs):
    if len(y_f) > 1:
        print(i, len(y_f), y_f, data_test[i][1], y_fds[i], np.argmax(wnet(data_test[i][0]).to('cpu').detach().numpy()))

In [None]:
# Number of test images matched by at least one digit concept. forward_all's third
# return value is bound to `cn` above; `count_fc` exists only inside the method.
cn

In [None]:
# Are the digit-3 and digit-5 concepts comparable under `le` (in either direction)?
le(u_Ds[3], u_Ds[5]) or le(u_Ds[5], u_Ds[3])

In [None]:
# Sum of element-wise products of each concept with test image 64's layer-U
# activation (an unnormalised overlap score).
np.sum(u_Ds[3] * U_X_test[64]), np.sum(u_Ds[5] * U_X_test[64])

In [None]:
u_Ds[3].shape

In [None]:
# Per-filter view of the two concepts (64 filters at layer U).
visualize_slices(u_Ds[3], filters=64)

In [None]:
visualize_slices(u_Ds[5], filters=64)

## Experiments with shapes

In [None]:
import numpy as np
from PIL import Image, ImageDraw
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Functions to generate images using PIL with variations
def generate_vertical_line_image(height, width, line_length=14, line_thickness=2, shift=0, intensity=255):
    """Return a (height, width) greyscale array with one vertical line segment.

    The segment is `line_length` pixels long and vertically centred. It is drawn
    at column ``width // 2 + shift`` with stroke width `line_thickness` and grey
    level `intensity` on a black ('L'-mode) background.
    """
    canvas = Image.new('L', (width, height), 0)
    column = width // 2 + shift
    top = (height - line_length) // 2
    bottom = top + line_length
    ImageDraw.Draw(canvas).line((column, top, column, bottom), fill=intensity, width=line_thickness)
    return np.array(canvas)

def generate_horizontal_line_image(height, width, line_length=14, line_thickness=2, shift=0, intensity=255):
    """Return a (height, width) greyscale array with one horizontal line segment.

    The segment is `line_length` pixels long and horizontally centred. It is drawn
    at row ``height // 2 + shift`` with stroke width `line_thickness` and grey
    level `intensity` on a black ('L'-mode) background.
    """
    canvas = Image.new('L', (width, height), 0)
    row = height // 2 + shift
    left = (width - line_length) // 2
    right = left + line_length
    ImageDraw.Draw(canvas).line((left, row, right, row), fill=intensity, width=line_thickness)
    return np.array(canvas)

def generate_stretched_ring_image(height, width, radius_x=None, radius_y=None, thickness=2, intensity=255):
    """Return a (height, width) greyscale array with an axis-aligned ellipse outline.

    The ellipse is centred at (width // 2, height // 2). Unless given, the radii
    are width // 4 (horizontal) and height // 8 (vertical), so for square images
    the default ring is wider than it is tall. The outline is `thickness` pixels
    wide, in grey level `intensity`, on a black ('L'-mode) background.
    """
    cx, cy = width // 2, height // 2
    rx = width // 4 if radius_x is None else radius_x
    ry = height // 8 if radius_y is None else radius_y
    canvas = Image.new('L', (width, height), 0)
    ImageDraw.Draw(canvas).ellipse((cx - rx, cy - ry, cx + rx, cy + ry), outline=intensity, width=thickness)
    return np.array(canvas)

# Custom PyTorch Dataset
class CustomShapeDataset(Dataset):
    """Synthetic 1xHxW images of a vertical line, a horizontal line or a ring.

    Items are generated on access. The shape is chosen uniformly from
    ``self.shapes``. Lines are shifted by a random offset in [-5, 5] pixels
    (rings are always centred), and the stroke grey level is drawn from
    [16, 255]. Each item is ``(float tensor scaled to [0, 1], shape name)``.

    Args:
        num_samples: Reported dataset length.
        height, width: Image size in pixels (default 28x28, like MNIST).
        seed: Optional integer. When given, item ``idx`` is drawn from an RNG
            seeded with ``(seed, idx)``, so the same index always yields the same
            image regardless of access order. When ``None`` (default) the global
            NumPy RNG is used, so results depend on ``np.random.seed`` and call order.
    """

    def __init__(self, num_samples, height=28, width=28, seed=None):
        self.num_samples = num_samples
        self.height = height
        self.width = width
        self.seed = seed
        self.shapes = ['vertical_line', 'horizontal_line', 'ring']

    def __len__(self):
        return self.num_samples

    def _draw_params(self, idx):
        """Sample ``(shape_type, shift, intensity)`` for item `idx`."""
        if self.seed is None:
            shape_type = np.random.choice(self.shapes)
            shift = np.random.randint(-5, 6)        # integer in [-5, 5]
            intensity = np.random.randint(16, 256)  # integer in [16, 255]
        else:
            rng = np.random.default_rng([self.seed, int(idx)])
            shape_type = rng.choice(self.shapes)
            shift = int(rng.integers(-5, 6))
            intensity = int(rng.integers(16, 256))
        return shape_type, shift, intensity

    def __getitem__(self, idx):
        shape_type, shift, intensity = self._draw_params(idx)
        if shape_type == 'vertical_line':
            image = generate_vertical_line_image(self.height, self.width, shift=shift, intensity=intensity)
        elif shape_type == 'horizontal_line':
            image = generate_horizontal_line_image(self.height, self.width, shift=shift, intensity=intensity)
        elif shape_type == 'ring':
            image = generate_stretched_ring_image(self.height, self.width, intensity=intensity)
        else:
            raise ValueError(f'unknown shape type: {shape_type!r}')

        # HxW array with values 0-255 -> 1xHxW float tensor in [0, 1].
        image = torch.tensor(image, dtype=torch.float32).unsqueeze(0) / 255.0
        return image, shape_type

# Example usage
# Re-seed right before sampling so the generated shapes (global NumPy RNG, since
# no `seed` is passed) and the DataLoader shuffle (torch RNG) are reproducible,
# even when this cell is re-run on its own.
np.random.seed(SEED)
torch.manual_seed(SEED)
dataset = CustomShapeDataset(num_samples=1000)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Display some examples
def show_images(images, titles, ncols=4, figsize=(10, 10)):
    """Plot `images` in a grid with `ncols` columns, one title per image.

    Args:
        images: Sequence of image arrays/tensors. Each is ``squeeze()``-d and drawn
            in greyscale (e.g. the 1xHxW tensors from CustomShapeDataset).
        titles: Sequence of titles, paired with `images` by position.
        ncols: Number of grid columns.
        figsize: Matplotlib figure size in inches.

    The grid has ``ceil(len(images) / ncols)`` rows, so every image is shown (a
    trailing partial row is allowed) and unused cells are blanked. An empty
    `images` sequence draws nothing.
    """
    n_images = len(images)
    if n_images == 0:
        return
    nrows = (n_images + ncols - 1) // ncols  # ceil division keeps the last partial row
    # squeeze=False always yields a 2-D array of Axes, even for a single row.
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for i, (img, title) in enumerate(zip(images, titles)):
        ax = axs[i // ncols, i % ncols]
        ax.imshow(img.squeeze(), cmap='gray')
        ax.set_title(title)
    for ax in axs.flat:
        ax.axis('off')
    plt.show()

# Get a batch of images
# (the first shuffled batch of 32 freshly generated samples)
images, labels = next(iter(dataloader))

# Show a batch of images
# (the first 16, in the default 4-column grid)
show_images(images[:16], labels[:16])

In [None]:
# One pass over the DataLoader, grouping the generated 1x28x28 images by shape label:
# shapes_shp[label] -> list of tensors. CustomShapeDataset generates items on access,
# so these are new random samples, not the batch displayed above.
shapes_shp = dict()
for images, labels in dataloader:
    for im, lb in zip(images, labels):
        shapes_shp.setdefault(lb, list())
        shapes_shp[lb].append(im)

In [None]:
# Layer-V activations of every generated vertical-line image, stacked into one array.
v_X_shapes_v = np.array(wnet(*shapes_shp['vertical_line'], k=layer_V_n))
v_X_shapes_v.shape

In [None]:
# Same for the horizontal-line images.
v_X_shapes_h = np.array(wnet(*shapes_shp['horizontal_line'], k=layer_V_n))
v_X_shapes_h.shape

In [None]:
i = 3

In [None]:
show_img(shapes_shp['vertical_line'], i)

In [None]:
# NOTE(review): the image shown above is a vertical line, but this shows the
# activations of horizontal-line sample i — confirm this pairing is intended.
visualize_slices(v_X_shapes_h[i])

In [None]:
# Intersection of all vertical-line activations via intersect_xd ...
v_X_shapes_v_inter = intersect_xd(*[v_X_shapes_v[r] for r in range(v_X_shapes_v.shape[0])])#np.min(v_X_shapes_v, axis=0)

In [None]:
# ... immediately overwritten by the element-wise minimum over samples.
v_X_shapes_v_inter = np.min(v_X_shapes_v, axis=0)

In [None]:
v_X_shapes_v.shape, v_X_shapes_v_inter.shape

In [None]:
# Hand-made 6-pixel anti-diagonal stroke on a 28x28 canvas: pixels (16, 12),
# (15, 13), ..., (11, 17) set to 1.0 (255 / 255), everything else 0.
diag = torch.zeros((28, 28), dtype=torch.float32)
for i in range(12, 18):
    diag[28 - i, i] = 255
diag /= 255
# diag = diag.t()

In [None]:
plt.imshow(diag)

In [None]:
# Layer-V activation of the diagonal stroke (a single 1x28x28 input).
v_X_shapes_d = wnet(diag.unsqueeze(0), k=layer_V_n)

In [None]:
visualize_slices(v_X_shapes_d[0])

In [None]:
visualize_slices(v_X_shapes_v[0])

In [None]:
# Sparse probe vector for the vertical stroke. Positions of channel fl_v on sample
# v_idx whose activation is >= th = peak - peak / 4.88 (about 79.5% of the peak)
# are set to 0.2; every other neuron is 0. Note v_v aliases neurons_v (same array).
v_idx = 0
# fl_v = 16
fl_v = 2
neurons_v = np.zeros(v_X_shapes_v[v_idx].shape)
th = np.max(v_X_shapes_v[v_idx][fl_v]) - np.max(v_X_shapes_v[v_idx][fl_v]) / 4.88
idxs_v = np.where(v_X_shapes_v[v_idx][fl_v] >= th)
# Column indices of the selected positions (only used by the commented-out trimming).
idx_v_c = idxs_v[1]
# idx_v_c -= 1
neurons_v[fl_v][idxs_v] = 0.2
# neurons_v[fl_v][2:8, idx_v_c[0]] = 0
# neurons_v[fl_v][9:14, idx_v_c[0]] = 0
v_v = neurons_v

In [None]:
show_activation(neurons_v[fl_v])

In [None]:
# Sparse probe vector for the diagonal stroke, built like the vertical probe.
# Positions of channel fl_d whose activation is >= th = peak - peak / 4.88
# (about 79.5% of that channel's peak) are set to 0.1; every other neuron is 0.
# Both the probe's shape and the threshold come from the diagonal image (d_idx)
# and its own channel fl_d. Mixing in the vertical probe's fl_v / v_idx would
# threshold channel fl_d against another channel's peak.
d_idx = 0
# fl_d = 16
fl_d = 28
neurons_d = np.zeros(v_X_shapes_d[d_idx].shape)
th = np.max(v_X_shapes_d[d_idx][fl_d]) - np.max(v_X_shapes_d[d_idx][fl_d]) / 4.88
idxs_d = np.where(v_X_shapes_d[d_idx][fl_d] >= th)
# Column indices of the selected positions (only used by the commented-out trimming).
idx_d_c = idxs_d[1]
# idx_d_c -= 1
neurons_d[fl_d][idxs_d] = 0.1
# neurons_d[fl_d][2:8, idx_d_c[0]] = 0
# neurons_d[fl_d][9:14, idx_d_c[0]] = 0
v_d = neurons_d

In [None]:
show_activation(neurons_d[fl_d])

In [None]:
# Concept search seeded with the vertical-stroke probe v_v; uncn_reps is displayed.
layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
G_v, v_D, u_D, G_u, G_v_test, G_u_test, uncn_reps = layer_fca.find_G_v_us(
    v_v, 
    V_X_test, 
    U_X_test, 
    data_test
)
uncn_reps

In [None]:
show_grid(G_v_test, data_test, nrow=32)

In [None]:
# Same search seeded with the diagonal-stroke probe v_d (overwrites the results above).
layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
G_v, v_D, u_D, G_u, G_v_test, G_u_test, uncn_reps = layer_fca.find_G_v_us(
    v_d, 
    V_X_test, 
    U_X_test, 
    data_test
)
uncn_reps

In [None]:
show_grid(G_v_test, data_test, nrow=32)

In [None]:
# show_grid(G_u_test, data_test, nrow=32)

In [None]:
# Sparse probe vector for the horizontal stroke, using the same recipe as the vertical probe.
# Positions of channel fl_h on sample h_idx whose activation is >= th
# (peak - peak / 4.88, about 79.5% of the peak) are set to 0.2. On the first
# (top-most) selected row, columns 0-3 and 10-11 are then cleared again
# (hand-tuned trimming of the stroke). The threshold is displayed.
h_idx = 0
fl_h = 8
neurons_h = np.zeros(v_X_shapes_h[h_idx].shape)
th = np.max(v_X_shapes_h[h_idx][fl_h]) - (np.max(v_X_shapes_h[h_idx][fl_h]) / 4.88)
idx_h = np.where(v_X_shapes_h[h_idx][fl_h] >= th)
idx_h_r = idx_h[0]
idx_h_r_idx = idx_h_r[0]
# idx_h_r += 2
neurons_h[fl_h][idx_h] = 0.2
neurons_h[fl_h][idx_h_r_idx, :4] = 0
neurons_h[fl_h][idx_h_r_idx, 10:12] = 0
v_h = neurons_h
th

In [None]:
show_activation(neurons_h[fl_h])

In [None]:
# Concept search seeded with the horizontal-stroke probe v_h.
layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
G_v, v_D, u_D, G_u, G_v_test, G_u_test, uncn_reps = layer_fca.find_G_v_us(
    v_h, 
    V_X_test, 
    U_X_test, 
    data_test
)
uncn_reps

In [None]:
show_grid(G_v_test, data_test, nrow=32)

In [None]:
# show_grid(G_u_test, data_test, nrow=32)

In [None]:
# Combined probe: element-wise maximum of the diagonal and horizontal probes
# (the vertical + horizontal variant is commented out). Rebinds the global `v`.
# v_l = [v_v, v_h]
v_l = [v_d, v_h]
v = np.max(np.array(v_l), axis=0)
v.shape

In [None]:
# Concept search seeded with the combined probe.
layer_fca = LayerFCA(V_X_train, U_X_train, data_train)
G_v, v_D, u_D, G_u, G_v_test, G_u_test, uncn_reps = layer_fca.find_G_v_us(
    v, 
    V_X_test, 
    U_X_test, 
    data_test
)
uncn_reps

In [None]:
show_grid(G_v_test, data_test, nrow=32)

In [None]:
show_grid(G_u_test, data_test, nrow=64)