## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os

import torch
import torch.nn as nn
import pandas as pd

from collections import OrderedDict 
from sklearn import metrics, model_selection
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

sys.path.append(os.path.abspath(''))

import datasets
import utils.more_torch_functions as mtf

from utils.custom_activations import StepActivation
from utils.modules import Parallel, MaxLayer, MaxHierarchicalLayer
from utils.misc import cross_valid, combine_prompts, cov_score, train_model

# torch.autograd.set_detect_anomaly(True)

## Load data

In [2]:
np_x, np_y = datasets.MnistDataset.get_dataset(balancing=True, keep_label=['0', '1'])
x_data, y_data = torch.Tensor(np_x), torch.Tensor(np_y)
input_size = x_data.size(1)
print(x_data.size())

torch.Size([11846, 28, 28])


## Hooks

In [3]:
intermediate_outputs = {}
def get_intermediate_outputs(name):
    def hook(model, input, output):
        if model.training:
            intermediate_outputs.setdefault(name, dict())["train"] = output
        else:
            intermediate_outputs.setdefault(name, dict())["valid"] = output
    return hook

def true_label_for_backward(train, valid):
    def hook(model, input):
        if model.training:
            model.true_labels = train
        else:
            model.true_labels = valid
    return hook

# créer hook fonction de perte pour meilleur backward ? (comparer individuellement les sorties des réseaux ???)

## Networks

### Network parts

In [4]:
class ApproxConvNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.cnn = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1,1,3,1)),
            ('a1', StepActivation()),
            ('conv2', nn.Conv2d(1,1,3,1)),
            ('a2', StepActivation()),
            ('flatten', nn.Flatten()),
        ]))

        self.fc = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(24*24,1)),
            ('afc', StepActivation()),
        ]))

    def forward(self, x):
        x = self.cnn(x)
        x = self.fc(x)

        return x

class CentralConvNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.cnn = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1,1,3,1)),
            ('a1', nn.ReLU()),
            ('conv2', nn.Conv2d(1,1,3,1)),
            ('a2', nn.ReLU()),
            ('conv3', nn.Conv2d(1,1,3,1)),
            ('a3', nn.ReLU()),
            ('flatten', nn.Flatten()),
        ]))

        self.fc = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(22*22,1)),
            ('afc', StepActivation()),
        ]))
    
    def forward(self, x):
        x = self.cnn(x)
        x = self.fc(x)

        return x

### New Network definition

In [5]:
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(OrderedDict([
            ('nets', Parallel(OrderedDict([
                ('cnn', CentralConvNet()),
                ('apx1', ApproxConvNet()),
            ]))),
            ('or_', MaxLayer()),
        ]))

    def forward(self, input):
        return self.net(input)

## Network evaluation

In [6]:
model = ConvNet()
criterion = nn.BCELoss()
optimizer = Adam(model.parameters(), lr=1e-2, weight_decay=1e-6)

model.net.nets.register_forward_hook(get_intermediate_outputs("parallel_out"))

lr = 0.001
num_epochs = 10
batch_size = 64
num_splits = 10

skf = model_selection.StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=76)

for fold, (train_index, val_index) in enumerate(skf.split(x_data, y_data)):
    # Split data into train and validation sets
    train_images, train_labels = x_data[train_index], y_data[train_index]
    val_images, val_labels = x_data[val_index], y_data[val_index]

    # Convert data to PyTorch tensors
    train_images_tensor = torch.tensor(train_images, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
    train_labels_tensor = torch.tensor(train_labels, dtype=torch.float32)
    val_images_tensor = torch.tensor(val_images, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
    val_labels_tensor = torch.tensor(val_labels, dtype=torch.float32)

    # Create DataLoader for training and validation sets
    train_dataset = TensorDataset(train_images_tensor, train_labels_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataset = TensorDataset(val_images_tensor, val_labels_tensor)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    mtf.reset_model(model)
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation loop
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                total += labels.size(0)
                correct += (outputs == labels).sum().item()
                val_loss += criterion(outputs, labels).item()

        # Print statistics
        print(f'Fold [{fold + 1}/{num_splits}], Epoch [{epoch + 1}/{num_epochs}], '
              f'Train Loss: {running_loss / len(train_loader):.4f}, '
              f'Val Loss: {val_loss / len(val_loader):.4f}, '
              f'Val Acc: {(correct / total) * 100:.2f}%')

  train_images_tensor = torch.tensor(train_images, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
  train_labels_tensor = torch.tensor(train_labels, dtype=torch.float32)
  val_images_tensor = torch.tensor(val_images, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
  val_labels_tensor = torch.tensor(val_labels, dtype=torch.float32)


Fold [1/10], Epoch [1/10], Train Loss: 0.2373, Val Loss: 0.4112, Val Acc: 99.58%
Fold [1/10], Epoch [2/10], Train Loss: 0.2342, Val Loss: 0.0822, Val Acc: 99.92%
Fold [1/10], Epoch [3/10], Train Loss: 0.2952, Val Loss: 0.4112, Val Acc: 99.58%
Fold [1/10], Epoch [4/10], Train Loss: 0.2757, Val Loss: 0.1645, Val Acc: 99.83%
Fold [1/10], Epoch [5/10], Train Loss: 0.1388, Val Loss: 0.1645, Val Acc: 99.83%
Fold [1/10], Epoch [6/10], Train Loss: 0.1367, Val Loss: 0.1645, Val Acc: 99.83%
Fold [1/10], Epoch [7/10], Train Loss: 0.1866, Val Loss: 0.0822, Val Acc: 99.92%
Fold [1/10], Epoch [8/10], Train Loss: 0.2161, Val Loss: 0.0822, Val Acc: 99.92%
Fold [1/10], Epoch [9/10], Train Loss: 0.0941, Val Loss: 0.0822, Val Acc: 99.92%
Fold [1/10], Epoch [10/10], Train Loss: 0.0945, Val Loss: 0.0822, Val Acc: 99.92%
Fold [2/10], Epoch [1/10], Train Loss: 0.0593, Val Loss: 0.3289, Val Acc: 99.66%
Fold [2/10], Epoch [2/10], Train Loss: 0.0135, Val Loss: 0.4112, Val Acc: 99.58%
Fold [2/10], Epoch [3/10], 