## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os

import torch
import torch.nn as nn
import pandas as pd

from collections import OrderedDict 
from sklearn import metrics, model_selection
from torch.optim import Adam

sys.path.append(os.path.abspath(''))

import datasets
import utils.more_torch_functions as mtf

from utils.custom_activations import StepActivation
from utils.modules import Parallel, MaxLayer, MaxHierarchicalLayer
from utils.misc import cross_valid, combine_prompts, cov_score, train_model

# torch.autograd.set_detect_anomaly(True)

## Load data

In [2]:
# Load the MNIST samples labelled '0' or '1'; balancing=True is forwarded to the loader.
np_x, np_y = datasets.MnistDataset.get_dataset(keep_label=['0', '1'], balancing=True)

# torch.Tensor(...) yields default-dtype (float32) tensors from the loaded arrays.
x_data = torch.Tensor(np_x)
y_data = torch.Tensor(np_y)

# Size of dimension 1 of the inputs.
# NOTE(review): with the (N, 28, 28) shape printed by this cell this is 28 (one image
# side), not the flattened pixel count -- confirm this is what downstream code expects.
input_size = x_data.shape[1]
print(x_data.shape)

torch.Size([11846, 28, 28])


## Hooks

In [3]:
# Registry filled by the hooks below: {hook name: {"train" | "valid": last output}}.
intermediate_outputs = {}
def get_intermediate_outputs(name):
    """Build a forward hook that records the hooked module's output under ``name``.

    The returned hook writes into the module-level ``intermediate_outputs`` dict:
    the output goes to ``intermediate_outputs[name]["train"]`` when the module is
    in training mode and to ``intermediate_outputs[name]["valid"]`` otherwise.
    Each call overwrites the value previously stored in that slot.

    Args:
        name: key under which this hook's outputs are stored.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook signature.
        It returns ``None``, so the module's output is left unchanged.
    """
    def hook(module, inputs, output):
        phase = "train" if module.training else "valid"
        # The output object is stored as-is (no copy is made).
        per_phase = intermediate_outputs.setdefault(name, dict())
        per_phase[phase] = output
    return hook

def true_label_for_backward(train, valid):
    """Build a forward pre-hook that attaches labels to the hooked module.

    When the hook fires it sets ``module.true_labels`` to ``train`` if the module
    is in training mode, and to ``valid`` otherwise.

    Args:
        train: value assigned to ``true_labels`` in training mode.
        valid: value assigned to ``true_labels`` in evaluation mode.

    Returns:
        A callable with the ``(module, inputs)`` forward-pre-hook signature. It
        returns ``None``, so the module's inputs are left unchanged.
    """
    def hook(module, inputs):
        module.true_labels = train if module.training else valid
    return hook

# TODO: create a loss-function hook for a better backward pass? (compare the outputs of the sub-networks individually?)

## Networks

### Network parts

In [4]:
class ApproxConvNet(nn.Module):
    """Small convolutional branch using ``StepActivation`` after every layer.

    ``cnn`` applies two single-channel 3x3 convolutions (stride 1), each followed
    by a ``StepActivation``, then flattens the result. ``fc`` maps the 24*24
    flattened features to one value and applies a final ``StepActivation``.
    The 24*24 width matches a 28x28 input reduced by two unpadded 3x3
    convolutions.
    """

    def __init__(self):
        super().__init__()

        # Submodule names and construction order are kept stable so that
        # state_dict keys and parameter initialisation order do not change.
        feature_layers = OrderedDict()
        feature_layers['conv1'] = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1)
        feature_layers['a1'] = StepActivation()
        feature_layers['conv2'] = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1)
        feature_layers['a2'] = StepActivation()
        feature_layers['flatten'] = nn.Flatten()
        self.cnn = nn.Sequential(feature_layers)

        head_layers = OrderedDict()
        head_layers['fc'] = nn.Linear(in_features=24 * 24, out_features=1)
        head_layers['afc'] = StepActivation()
        self.fc = nn.Sequential(head_layers)

    def forward(self, x):
        """Run ``x`` through the convolutional stack, then the linear head."""
        return self.fc(self.cnn(x))

class CentralConvNet(nn.Module):
    """Convolutional branch with ReLU feature layers and a ``StepActivation`` head.

    ``cnn`` applies three single-channel 3x3 convolutions (stride 1), each
    followed by ``nn.ReLU``, then flattens the result. ``fc`` maps the 22*22
    flattened features to one value and applies a ``StepActivation``.
    The 22*22 width matches a 28x28 input reduced by three unpadded 3x3
    convolutions.
    """

    def __init__(self):
        super().__init__()

        # Submodule names and construction order are kept stable so that
        # state_dict keys and parameter initialisation order do not change.
        feature_layers = OrderedDict()
        feature_layers['conv1'] = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1)
        feature_layers['a1'] = nn.ReLU()
        feature_layers['conv2'] = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1)
        feature_layers['a2'] = nn.ReLU()
        feature_layers['conv3'] = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1)
        feature_layers['a3'] = nn.ReLU()
        feature_layers['flatten'] = nn.Flatten()
        self.cnn = nn.Sequential(feature_layers)

        head_layers = OrderedDict()
        head_layers['fc'] = nn.Linear(in_features=22 * 22, out_features=1)
        head_layers['afc'] = StepActivation()
        self.fc = nn.Sequential(head_layers)

    def forward(self, x):
        """Run ``x`` through the convolutional stack, then the linear head."""
        return self.fc(self.cnn(x))

### New Network definition

In [5]:
class ConvNet(nn.Module):
    """Combine a ``CentralConvNet`` and an ``ApproxConvNet`` behind a ``MaxLayer``.

    ``net`` is a two-stage ``nn.Sequential``:
      * ``nets`` -- a ``Parallel`` container holding the branches ``cnn``
        (``CentralConvNet``) and ``apx1`` (``ApproxConvNet``);
      * ``or_``  -- a ``MaxLayer`` applied to whatever ``Parallel`` returns.

    NOTE(review): ``Parallel`` and ``MaxLayer`` come from ``utils.modules``;
    presumably the branches receive the same input and the max acts as a
    logical OR over their step outputs -- confirm against that module.
    """

    def __init__(self):
        super().__init__()

        # Branch order (cnn first, apx1 second) is preserved from the original.
        branches = OrderedDict()
        branches['cnn'] = CentralConvNet()
        branches['apx1'] = ApproxConvNet()

        stages = OrderedDict()
        stages['nets'] = Parallel(branches)
        stages['or_'] = MaxLayer()
        self.net = nn.Sequential(stages)

    def forward(self, input):
        """Feed ``input`` through the parallel branches and the max layer."""
        combined = self.net(input)
        return combined

## Network evaluation

In [6]:
# Build the model, loss and optimizer for the cross-validated run.
model = ConvNet()
criterion = nn.BCELoss()
optimizer = Adam(model.parameters(), lr=1e-2, weight_decay=1e-6)

# Record the output of the parallel branches (before the max layer) on every
# forward pass. The handle is kept so the hook can be detached later with
# `parallel_out_hook.remove()`, e.g. before re-running this cell on the same model.
parallel_out_hook = model.net.nets.register_forward_hook(get_intermediate_outputs("parallel_out"))

# NOTE(review): `lr` is not consumed anywhere in this cell -- the optimizer above
# is built with lr=1e-2. Left unchanged so the run's behavior is preserved;
# decide which value is intended and wire it into Adam.
lr = 0.001
num_epochs = 10
batch_size = 64
num_splits = 10

# Stratified K-fold with a fixed seed so the splits are reproducible.
skf = model_selection.StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=76)
# max_epoch now follows `num_epochs` instead of a duplicated literal 10.
convnet_split_res = cross_valid(x_data, y_data, model, criterion, optimizer, skf, batch_size=batch_size, max_epoch=num_epochs)

# Each entry unpacks into (train_pred, train_true, valid_pred, valid_true);
# only the fold index is printed here.
for i, (train_pred, train_true, valid_pred, valid_true) in enumerate(convnet_split_res):
    print(i)

0
1
2
3
4
5
6
7
8
9
