In [None]:
# Reload edited project modules automatically before each cell runs
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook
%matplotlib inline

## Install libraries

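Create and activate the conda environment, then install the dependencies from a terminal:

```bash
conda create -n edu4 python=3.11 jupyter matplotlib
conda activate edu4
```

```bash
pip install -U -r requirements.txt
```

```bash
pip install -U numpy
pip install -U scikit-learn
```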

## Update repository

In [None]:
# Update the local checkout before the project modules are imported below
! git pull

## Add import path

In [None]:
import gc
import os
import random
import sys

In [None]:
# The project packages (`src.lattmc...`) live one directory above the notebook's
# working directory; expose that directory on the import path exactly once.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
# The helper name is no longer needed once sys.path is set up; keep the namespace clean
del module_path

## Organize imports

In [None]:
import multiprocessing
from pathlib import Path

In [None]:
from tqdm import tqdm

In [None]:
import seaborn as sns

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning import Trainer
import matplotlib.pyplot as plt

In [None]:
from src.lattmc.fca.utils import *
from src.lattmc.fca.data_utils import *
from src.lattmc.fca.image_utils import *
from src.lattmc.fca.models import *
from src.lattmc.fca.fca_utils import *
from src.lattmc.fca.image_gens import *

#### Number of CPU cores

In [None]:
# Number of logical CPU cores, displayed for reference.
# NOTE(review): `workers` is not used anywhere else in this notebook (it is not
# passed to prepare_data / the DataLoaders) — wire it in or drop it.
workers = multiprocessing.cpu_count()
workers

In [None]:
SEED = 2024
# Seed every RNG used here (PyTorch, NumPy, Python) so that the shuffled training
# DataLoader order and any weight initialisation are reproducible on Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

## Initialize paths

In [None]:
PATH = Path('data')

# Model checkpoints (the directory name suggests the KL-sparse AE with 1024 hidden
# units on MNIST); checkpoint_path2 is the file loaded further below.
checkpoint_dir = PATH / 'checkpoints' / 'sae_kl_1024_mnist'
checkpoint_dir.mkdir(exist_ok=True, parents=True)
checkpoint_path1 = checkpoint_dir / 'best-checkpoint-v1.ckpt'
checkpoint_path2 = checkpoint_dir / 'best-checkpoint.ckpt'

# Figures are saved here; savefig does not create missing parent folders,
# so make sure the directory exists up front.
image_dir = PATH / 'images'
image_dir.mkdir(exist_ok=True, parents=True)
image_path = image_dir / '1024.png'

## Load the MNIST dataset

In [None]:
# MNIST data loaders: standardized, flattened images plus a validation split
def prepare_data(batch_size=128, root='./data', num_workers=0):
    """Build flattened, standardized MNIST datasets and their DataLoaders.

    Each image goes through ``ToTensor`` (pixels in [0, 1]), is standardized with
    the MNIST mean/std (0.1307, 0.3081) -- so values end up roughly in
    [-0.42, 2.82], *not* in [0, 1] -- and is flattened to a 1-D vector
    (784 values for a 28x28 image).

    NOTE(review): the autoencoders in this notebook end in ``Sigmoid`` (outputs in
    (0, 1)) but are trained with MSE against these standardized targets, so pixels
    outside (0, 1) can never be reconstructed exactly. The transform is kept as-is
    because the saved checkpoint was presumably trained with it.

    Args:
        batch_size: Mini-batch size of both loaders.
        root: Directory where torchvision stores / downloads MNIST.
        num_workers: DataLoader worker processes (0 = load in the main process).
            NOTE(review): the flattening transform is a lambda, which may not be
            picklable under the "spawn" start method (macOS/Windows) -- verify
            before using num_workers > 0 there.

    Returns:
        ``(train_dataset, train_loader, val_dataset, val_loader)``. The
        "validation" split is the official MNIST test set (``train=False``).
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # Standardize with MNIST mean and std
        transforms.Lambda(lambda x: x.view(-1))  # Flatten the image
    ])

    # Training set: reshuffled every epoch
    train_dataset = datasets.MNIST(root=root, train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)

    # Validation set: fixed order so evaluation is deterministic
    val_dataset = datasets.MNIST(root=root, train=False, transform=transform, download=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)

    return train_dataset, train_loader, val_dataset, val_loader


In [None]:
# Download MNIST if needed and build the train/validation datasets and loaders
train_dataset, train_loader, val_dataset, val_loader = prepare_data()

## Initialize model

In [None]:
class SparseAutoencoder(pl.LightningModule):
    """One-hidden-layer sparse autoencoder trained with a KL sparsity penalty.

    Architecture: ``Linear -> ReLU`` encoder, ``Linear -> Sigmoid`` decoder.
    Training and validation use the same loss::

        MSE(x_hat, x) + sparsity_weight * sum_j KL(rho || rho_hat_j)

    where ``rho = sparsity_target`` and ``rho_hat_j`` is the batch mean of hidden
    unit ``j``.

    Args:
        input_size: Number of input features (784 = a flattened 28x28 image).
        hidden_size: Number of hidden units.
        sparsity_target: Desired mean activation ``rho`` of every hidden unit.
        sparsity_weight: Multiplier of the KL sparsity term.
    """

    def __init__(self, input_size=784, hidden_size=512, sparsity_target=0.05, sparsity_weight=1e-3):
        super().__init__()
        # Hyperparameters
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.sparsity_target = sparsity_target
        self.sparsity_weight = sparsity_weight

        # Encoder: ReLU keeps the codes non-negative
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(True)
        )
        # Decoder: Sigmoid bounds reconstructions to (0, 1)
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid()
        )
        self.criterion = nn.MSELoss()

    def forward(self, x):
        """Return ``(reconstruction, hidden_code)`` for input ``x``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, encoded

    def _compute_loss(self, x):
        """Reconstruction MSE plus the weighted KL sparsity penalty for batch ``x``."""
        x_hat, encoded = self(x)

        # Reconstruction loss
        recon_loss = self.criterion(x_hat, x)

        # Sparsity regularization: mean activation of each hidden unit over the batch
        rho_hat = torch.mean(encoded, dim=0)
        rho = torch.full_like(rho_hat, self.sparsity_target)
        kl_loss = torch.sum(self.kl_divergence(rho, rho_hat))
        return recon_loss + self.sparsity_weight * kl_loss

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss``; prints a warning on NaN/Inf in data or loss."""
        x, _ = batch

        # Check for NaN or Inf values in the input
        if torch.isnan(x).any() or torch.isinf(x).any():
            print("Data contains NaN or Inf values!")

        # Uses the configured sparsity_target / sparsity_weight, like validation_step
        loss = self._compute_loss(x)

        # Check for NaN or Inf values in the loss
        if torch.isnan(loss).any() or torch.isinf(loss).any():
            print("Loss contains NaN or Inf values!")

        self.log('train_loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss`` with the same objective as training."""
        x, _ = batch
        loss = self._compute_loss(x)

        self.log('val_loss', loss)

        return loss

    def configure_optimizers(self):
        """Adam (lr 1e-3, weight decay 1e-5) with a StepLR halving the lr every 5 scheduler steps."""
        # Weight decay acts as L2 regularization
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)

        # Step the learning rate down by gamma=0.5 every step_size=5 scheduler steps
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        return [optimizer], [scheduler]

    @staticmethod
    def kl_divergence(p, p_hat):
        """Element-wise Bernoulli KL divergence ``KL(p || p_hat)``.

        ``p_hat`` is clamped into ``[eps, 1 - eps]`` first: the ReLU codes' batch mean
        can be 0 or exceed 1, either of which would make the logs undefined.
        """
        eps = 1e-10
        p_hat = torch.clamp(p_hat, eps, 1 - eps)
        # The extra `+ eps` terms are redundant after clamping but kept for numerical parity
        return p * torch.log(p / p_hat + eps) + (1 - p) * torch.log((1 - p) / (1 - p_hat + eps))


In [None]:
class SAE(nn.Module):
    """Plain ``nn.Module`` copy of the sparse-autoencoder architecture for inference.

    Its ``encoder`` / ``decoder`` layers have the same names and shapes as those of
    ``SparseAutoencoder``, so that model's ``state_dict`` can be loaded directly.

    Args:
        input_size: Number of input features (784 = a flattened 28x28 image).
        hidden_size: Number of hidden units.
    """

    def __init__(self, input_size=784, hidden_size=512):
        super().__init__()
        # Hyperparameters
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(True)
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid()
        )

    def _flatten(self, x):
        """Flatten a single sample to 1-D, or a batch to ``(batch, features)``.

        Any tensor holding exactly ``input_size`` elements (e.g. ``(28, 28)``,
        ``(1, 28, 28)`` or ``(1, 784)``) is treated as ONE sample and flattened to
        ``(input_size,)``; anything larger keeps its first dimension as the batch.
        """
        if x.dim() <= 1 or x.numel() == self.input_size:
            return x.reshape(-1)
        return x.reshape(x.shape[0], -1)

    @torch.inference_mode()
    def encode(self, x):
        """Return the hidden code of ``x`` (a single sample or a batch).

        Runs under ``torch.inference_mode``: the result is an inference tensor and
        cannot take part in autograd -- use ``forward`` for training.
        """
        return self.encoder(self._flatten(x))

    def encode_np(self, x):
        """Like ``encode`` but return the code as a NumPy array on the CPU."""
        z = self.encode(x)
        return z.cpu().detach().numpy()

    def forward(self, x):
        """Return ``(reconstruction, hidden_code)``; differentiable, unlike ``encode``."""
        # Do not route through `encode`: its inference tensors cannot be saved for
        # backward by the (grad-tracking) decoder.
        encoded = self.encoder(self._flatten(x))
        decoded = self.decoder(encoded)

        return decoded, encoded

In [None]:
def inference_ds(ds, encoder=None):
    """Encode every sample of ``ds`` and stack the codes with ``np.array``.

    Args:
        ds: Iterable of ``(x, label)`` pairs (e.g. a torchvision dataset); labels
            are ignored.
        encoder: Object exposing ``encode_np(x)``. Defaults to the notebook-level
            ``model`` (looked up at call time) for backward compatibility.

    Returns:
        ``np.ndarray`` whose i-th entry is the code of the i-th sample, in dataset
        order (so row i lines up with label i).
    """
    enc = model if encoder is None else encoder
    with tqdm(ds) as prds:
        zs = np.array(
            [enc.encode_np(x) for x, _ in prds]
        )

    return zs

In [None]:
def find_zs(ds, encoder=None):
    """Encode every sample of ``ds`` and group the codes by label.

    Args:
        ds: Iterable of ``(x, y)`` pairs (e.g. a torchvision dataset); ``y`` is used
            as a dict key, so it must be hashable.
        encoder: Object exposing ``encode_np(x)``. Defaults to the notebook-level
            ``model`` (looked up at call time) for backward compatibility.

    Returns:
        dict mapping each label to an ``np.array`` of the codes of the samples with
        that label, in dataset order; labels appear in first-seen order.
    """
    enc = model if encoder is None else encoder
    codes_by_label = {}
    with tqdm(ds) as prds:
        for x, y in prds:
            codes_by_label.setdefault(y, []).append(enc.encode_np(x))

    return {label: np.array(codes) for label, codes in codes_by_label.items()}

In [None]:
def gr_idx(z, zs):
    """Return the indices of the rows of ``zs`` that dominate ``z`` element-wise.

    Row ``i`` is selected when ``z <= zs[i]`` holds for every component.

    Args:
        z: Reference vector.
        zs: Sequence / 2-D array of vectors comparable (broadcastable) with ``z``.

    Returns:
        1-D integer array of matching row indices in ascending order. It is empty
        but still integer-typed when nothing matches, so it can always be used to
        index label arrays such as ``val_y``.
    """
    rows = np.asarray(zs)
    if len(rows) == 0:
        return np.empty(0, dtype=np.intp)
    # Vectorized replacement for a per-row Python loop: compare all rows at once
    dominates = (np.asarray(z) <= rows).reshape(len(rows), -1).all(axis=1)
    return np.flatnonzero(dominates)

In [None]:
def map_z(z, zs=None, ys=None):
    """Return the labels of the samples in the closure of probe vector ``z``.

    Pipeline (same steps the exploratory cells below do by hand):
      1. ``G_z = find_G_x(zs, z)`` -- presumably the indices of codes ``>= z``
         element-wise (cf. ``gr_idx``); verify in ``src.lattmc.fca``.
      2. ``A = find_v_A(zs, G_z)`` -- presumably the element-wise meet of those
         codes; verify in ``src.lattmc.fca``.
      3. ``find_G_x(zs, A)`` selects the samples again from ``A``.

    Args:
        z: Probe vector with the same length as one code.
        zs: Code matrix; defaults to the notebook-level ``val_z``.
        ys: Labels aligned with ``zs``; defaults to the notebook-level ``val_y``.

    Returns:
        ``ys`` indexed by the indices returned in step 3.
    """
    zs = val_z if zs is None else zs
    ys = val_y if ys is None else ys
    G_z = find_G_x(zs, z)
    A = find_v_A(zs, G_z)
    idx_z = find_G_x(zs, A)

    return ys[idx_z]

In [None]:
# Inference model; hidden_size must match the checkpoint (1024 units, cf. 'sae_kl_1024_mnist')
model = SAE(hidden_size=1024)

In [None]:
# Load the saved checkpoint onto the CPU. torch.load unpickles the file, so only
# load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path2, map_location=torch.device('cpu'))

In [None]:
# checkpoint
# (uncomment the line above to inspect the raw checkpoint dict)

In [None]:
# The weights live under 'state_dict'; presumably saved from SparseAutoencoder, whose
# `encoder`/`decoder` layers have the same names as SAE's, so the keys line up.
model.load_state_dict(checkpoint['state_dict'])

In [None]:
# Switch to evaluation mode (SAE has no dropout/batch-norm, but this is the inference convention)
model = model.eval()

In [None]:
import matplotlib.pyplot as plt

# Visualize the learned features (filters) by plotting the encoder weights
def visualize_weights(autoencoder, save_path=None, image_shape=(28, 28)):
    """Plot every encoder filter on a near-square grid, save the figure, and show it.

    Args:
        autoencoder: Model whose ``encoder[0]`` is the first ``nn.Linear`` layer;
            each row of its weight matrix is drawn as one image.
        save_path: Output file; defaults to the notebook-level ``image_path``
            (looked up at call time).
        image_shape: Shape each weight row is reshaped to (28x28 for MNIST).
    """
    weights = autoencoder.encoder[0].weight.detach().cpu().numpy()
    n_units = weights.shape[0]
    # Grid adapts to hidden_size; 1024 units gives the original 32x32 layout
    ncols = int(np.ceil(np.sqrt(n_units)))
    nrows = int(np.ceil(n_units / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n_units:
            ax.set_title(f'{i}')
            ax.imshow(weights[i].reshape(image_shape), cmap='gray')
        ax.axis('off')  # also hides spare cells when the grid is not completely filled
    fig.savefig(image_path if save_path is None else save_path)
    plt.show()

# Call visualization functions
visualize_weights(model)

In [None]:
import matplotlib.pyplot as plt

# Visualize the encoder filters of the hidden units that are active in a code
def visualize_activations(autoencoder, z, image_shape=(28, 28), ncols=8):
    """Show the encoder filter of every non-zero unit of the hidden code ``z``.

    Each panel is titled ``"<unit index> <activation>"``.

    Args:
        autoencoder: Model whose ``encoder[0]`` is the first ``nn.Linear`` layer.
        z: Hidden code of ONE sample (torch tensor or array-like), e.g. the result
            of ``model.encode(x)`` -- not the input image itself. It is flattened.
        image_shape: Shape each weight row is reshaped to (28x28 for MNIST).
        ncols: Maximum number of panels per row.
    """
    if torch.is_tensor(z):
        z_np = z.detach().cpu().numpy()
    else:
        z_np = np.asarray(z)
    z_np = z_np.reshape(-1)
    active = np.flatnonzero(z_np)
    if active.size == 0:
        print('No active units in z; nothing to plot.')
        return

    weights = autoencoder.encoder[0].weight.detach().cpu().numpy()
    ncols = min(ncols, active.size)
    nrows = int(np.ceil(active.size / ncols))
    # squeeze=False keeps `axes` 2-D even for a single panel
    fig, axes = plt.subplots(nrows, ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, unit in zip(axes.flat, active):
        ax.set_title(f'{unit} {z_np[unit]:.3g}')
        ax.imshow(weights[unit].reshape(image_shape), cmap='gray')
    plt.show()

In [None]:
# Ground-truth digit labels, index-aligned with the datasets
train_y = np.array([y for _, y in train_dataset])
val_y = np.array([y for _, y in val_dataset])

In [None]:
# Codes grouped by label: dict label -> array of the codes of that class
z_train = find_zs(train_dataset)
z_val = find_zs(val_dataset)

In [None]:
# All codes in dataset order, so row i of val_z corresponds to val_y[i].
# NOTE(review): together with the previous cell this encodes each dataset twice;
# the grouped dicts could be built from these arrays and the label arrays instead.
train_z = inference_ds(train_dataset)
val_z = inference_ds(val_dataset)

In [None]:
# Inspect one validation sample (flattened image and its label)
x, y = val_dataset[12]
x.shape, y

In [None]:
# Inspect one training sample; this `x` is the one encoded further below
x, y = train_dataset[32]
x.shape, y

In [None]:
# Meet of class 6: element-wise minimum over the codes of all validation images
# labelled 6 (axis=0 reduces over samples, leaving one value per hidden unit).
z_6 = z_val[6]
z = np.min(z_6, axis=0)

In [None]:
# Is any component of `z` non-zero?
np.any(z)

## Run model

In [None]:
# Number of encoded validation / training samples
len(val_z), len(train_z)

In [None]:
# Shapes of the stacked codes (samples, hidden units) and of `z`
val_z.shape, train_z.shape, z.shape

In [None]:
# Hidden-unit indices used as single-unit probes below
t = 938
k = 940
q = 937

In [None]:
# Single-unit probe: 0.5 at unit t, zero elsewhere (same length as one code)
z_t = np.zeros(val_z[0].shape)
z_t[t] = 0.5

In [None]:
# Single-unit probe at unit k
z_k = np.zeros(val_z[0].shape)
z_k[k] = 0.5

In [None]:
# Single-unit probe at unit q
z_q = np.zeros(val_z[0].shape)
z_q[q] = 0.5

In [None]:
# Validation samples selected by each probe via find_G_x
# (presumably: codes >= probe element-wise, like gr_idx -- verify in src.lattmc.fca)
t_ge = find_G_x(val_z, z_t)
k_ge = find_G_x(val_z, z_k) 
q_ge = find_G_x(val_z, z_q)

In [None]:
# Labels of the samples selected by probe t
val_y[t_ge]

In [None]:
# Labels of the samples selected by probe k
val_y[k_ge]

In [None]:
# Labels of the samples selected by probe q
val_y[q_ge]

In [None]:
# Join (element-wise max) of the q and k probes; show both set components
z_qk = np.maximum(z_q, z_k)
z_qk[k], z_qk[q]

In [None]:
# Join of the q and t probes
z_tq = np.maximum(z_q, z_t)
z_tq[t], z_tq[q]

In [None]:
# Closure of z_tq, step 1: samples selected by the joined probe
A = find_G_x(val_z, z_tq)

In [None]:
# Step 2: find_v_A of those samples (presumably their element-wise meet)
z_A = find_v_A(val_z, A)

In [None]:
# Step 3: samples selected by that vector
idx_z_A = find_G_x(val_z, z_A)

In [None]:
# Labels in the closure of z_tq
val_y[idx_z_A]

In [None]:
# NOTE(review): duplicate of the z_qk cell above; safe to remove
z_qk = np.maximum(z_q, z_k)
z_qk[k], z_qk[q]

In [None]:
# Join (element-wise max) of the three single-unit probes.
# np.maximum is binary: a third positional argument is its `out` buffer, so
# np.maximum(z_q, z_k, z_t) would overwrite z_t instead of including it.
z_kqt = np.maximum.reduce([z_q, z_k, z_t])
z_kqt[k], z_kqt[q], z_kqt[t]

In [None]:
# Closure of the three-unit probe z_kqt, step 1: samples selected by the probe
G_z_kqt = find_G_x(val_z, z_kqt)

In [None]:
# Step 2: find_v_A of those samples (presumably their element-wise meet)
z_A_kqt = find_v_A(val_z, G_z_kqt)

In [None]:
# Step 3: samples selected by that vector
idx_z_A_kqt = find_G_x(val_z, z_A_kqt)

In [None]:
# Labels in the closure of z_kqt
val_y[idx_z_A_kqt]

In [None]:
# Labels of validation samples whose code dominates z_tq element-wise
# (tq_gr must be computed here; it was not defined anywhere else)
tq_gr = gr_idx(z_tq, val_z)
val_y[tq_gr]

In [None]:
# Labels of validation samples whose code dominates z_qk element-wise
# (qk_gr must be computed here; it was not defined anywhere else)
qk_gr = gr_idx(z_qk, val_z)
val_y[qk_gr]

In [None]:
# Single-unit probe at unit 458 (value 0.5).
# NOTE(review): this and the 587 / 99 cells repeat map_z's steps by hand
# (kept because the intermediate vectors are inspected below).
z_458 = np.zeros(val_z[0].shape)
z_458[458] = 0.5

In [None]:
# Closure step 1: samples selected by the probe
G_458 = find_G_x(val_z, z_458)

In [None]:
# Closure step 2: find_v_A of those samples
z_A_458 = find_v_A(val_z, G_458)

In [None]:
# Closure step 3: samples selected by that vector
idx_458 = find_G_x(val_z, z_A_458)

In [None]:
# Labels in the closure of probe 458
val_y[idx_458]

In [None]:
# Single-unit probe at unit 587, with value 1 (the 458 probe used 0.5)
z_587 = np.zeros(val_z[0].shape)
z_587[587] = 1

In [None]:
# Full closure pipeline for probe 587, then its labels
G_587 = find_G_x(val_z, z_587)
z_A_587 = find_v_A(val_z, G_587)
idx_587 = find_G_x(val_z, z_A_587)
val_y[idx_587]

In [None]:
# Same pipeline for a probe at unit 99 (value 1)
z_99 = np.zeros(val_z[0].shape)
z_99[99] = 1
G_99 = find_G_x(val_z, z_99)
z_A_99 = find_v_A(val_z, G_99)
idx_99 = find_G_x(val_z, z_A_99)
val_y[idx_99]

In [None]:
# Four largest components of the closure vector for unit 587 (ascending order)
z_A_587[z_A_587.argsort()[-4:]]

In [None]:
# Labels in the closure of the single-unit probe 301
z = np.zeros(val_z[0].shape)
z[301] = 0.5
map_z(z)

In [None]:
# NOTE(review): same probe as the previous cell (unit 301, value 0.5)
z1 = np.zeros(val_z[0].shape)
z1[301] = 0.5
map_z(z1)

In [None]:
# Labels in the closure of the single-unit probe 366
z2 = np.zeros(val_z[0].shape)
z2[366] = 0.5
map_z(z2)

In [None]:
# Join (element-wise max) of the 301 and 366 probes
z1 = np.zeros(val_z[0].shape)
z2 = np.zeros(val_z[0].shape)
z1[301] = 0.5
z2[366] = 0.5
z = np.maximum(z1, z2)
map_z(z)

In [None]:
# Same join with a tiny value (1e-4), then also joined with unit 398.
# The prints dump the whole vector and its four largest entries.
z1 = np.zeros(val_z[0].shape)
z2 = np.zeros(val_z[0].shape)
z3 = np.zeros(val_z[0].shape)
z1[301] = 0.0001
z2[366] = 0.0001
z3[398] = 0.0001
z = np.maximum(z1, z2)
print(z)
print(z[z.argsort()[-4:]])
z = np.maximum(z, z3)
print(z)
print(z[z.argsort()[-4:]])
map_z(z)

In [None]:
# FIXME(review): `id_ge` is only assigned further down and `z_ge` is never assigned
# anywhere in this notebook, so these cells fail on a fresh kernel (Restart & Run All).
# Presumably z_ge held the codes of some selected samples -- restore its definition.
id_ge.shape

In [None]:
# FIXME(review): `z_ge` is undefined in this notebook (see above)
z_ge.shape

In [None]:
# Element-wise minimum (meet) over the rows of z_ge
z_min = np.min(z_ge, axis=0)
z_min.shape

In [None]:
# Six largest components of z_min (ascending order)
z_min[np.argsort(z_min)[-6:]]

In [None]:
# Indices of validation codes that dominate z_min element-wise (same test as gr_idx)
id_ge = np.array([i for i, z_r in enumerate(val_z) if (z_min <= z_r).all()])

In [None]:
# Print the label of each selected validation sample
# (val_y[id_ge] gives the same labels without re-loading the images)
for i in id_ge:
    print(val_dataset[i][1])

In [None]:
# FIXME(review): `z_ge` is undefined in this notebook (see above)
z_ge.shape

In [None]:
# Shape of one row of z_ge
z_ge[0].shape

In [None]:
# Hidden code of the training sample `x` selected above (computed under inference mode)
z = model.encode(x)

In [None]:
# Code shape, and the shape of torch.nonzero(z) -- (n_active, 1) for a 1-D code
z.shape, torch.nonzero(z).shape

In [None]:
# Show the encoder filters of the hidden units active in the code `z` of `x`.
# Pass the code, not the raw image: the image's non-zero *pixel* indices would be
# misread as hidden-unit indices.
visualize_activations(model, z)