## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from collections import OrderedDict 
from sklearn import metrics, model_selection
from torch.optim import Adam

sys.path.append(os.path.abspath(''))

import utils.more_torch_functions as mtf

from compiling_nn.build_odd import compile_nn
from datasets.loan import get_loan_dataset
from utils.custom_activations import StepActivation, StepFunction
from utils.modules import Parallel, MaxLayer, MaxHierarchicalLayer
from utils.custom_loss import AsymBCELoss

# torch.autograd.set_detect_anomaly(True)
pd.options.mode.copy_on_write = True

## Load data

In [2]:
np_x, np_y = get_loan_dataset(balancing=True, discretizing=False, hot_encoding=True, rmv_pct=0.985)
x_data, y_data = torch.Tensor(np_x), torch.Tensor(np_y)
input_size = x_data.size(1)
print(x_data.size())

torch.Size([280, 14])


## Hooks

In [3]:
intermediate_outputs = {}
def get_intermediate_outputs(name):
    def hook(model, input, output):
        if model.training:
            intermediate_outputs.setdefault(name, dict())["train"] = output
        else:
            intermediate_outputs.setdefault(name, dict())["valid"] = output
    return hook

def true_label_for_backward(train, valid):
    def hook(model, input):
        if model.training:
            model.true_labels = train
        else:
            model.true_labels = valid
    return hook

## Metrics and other utils

In [4]:
def cm(y_true, y_pred):
    confusion_matrix = metrics.confusion_matrix(y_true, y_pred)
    cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix, display_labels=[False, True])
    return cm_display

def plot_cm(y_true, y_pred):
    cm_display = cm(y_true, y_pred)
    _, ax = plt.subplots(1, 1, figsize=(4,8))
    cm_display.plot(ax=ax, colorbar=False)

def plot_combine_cm(cms, titles=None):
    n = len(cms)
    fig, axs = plt.subplots(1, n, figsize=(4*n, 8))
    if titles:
        for ax, cm, title in zip(axs, cms, titles):
            cm.plot(ax=ax, colorbar=False)
            ax.set_title(title)
    else:
        for ax, cm in zip(axs, cms):
            cm.plot(ax=ax, colorbar=False)
    fig.tight_layout()

def cov_score(y_true, y_pred):
    labels = np.unique(y_true)
    scores = {}

    for label in labels:
        indices_true = np.where(y_true == label)[0]
        indices_pred = np.where(y_pred == label)[0]
        scores[label] = len(np.intersect1d(indices_true, indices_pred))/len(indices_true)

    return scores

def combine_prompts(prompts, sep):
    plen = len(prompts)//2
    return '\n\t              '.join([f"{vprompt}{sep}{tprompt}" for vprompt, tprompt in zip(prompts[:plen], prompts[plen:])])

def train_model(x, y, model, loss_fn, optimizer, max_epoch):
    for _ in range(max_epoch):
        model.train()
        y_pred = model(x)
        
        loss = loss_fn(y_pred, y)

        model.zero_grad()
        loss.backward()
        optimizer.step()

    return y_pred

def cross_valid(X, Y, model, loss_fn, optimizer, skf, **kw_train):
    for train_index, test_index in skf.split(X, Y):
        x_train, x_test = X[train_index], X[test_index]
        y_train, y_test = Y[train_index], Y[test_index]

        mtf.reset_model(model)
        model.net.or_.register_forward_pre_hook(true_label_for_backward(y_train, y_test))
        y_pred = train_model(x_train, y_train, model, loss_fn, optimizer, **kw_train)
        y_pred_train = y_pred.detach().round()
        model.eval()
        y_pred_eval = model(x_test).detach()
        yield y_pred_train, y_train, y_pred_eval, y_test

## Networks

### Network parts

In [5]:
class ApproxNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        hl1 = 10

        self.nn = nn.Sequential(OrderedDict([
            ('l1', nn.Linear(input_size,hl1)),
            ('a1', StepActivation()),
            ('l2', nn.Linear(hl1,1)),
            ('a2', StepActivation())
        ]))        

    def forward(self, x):
        x = self.nn(x)

        return x

class CentralNet(nn.Module):
    def __init__(self):
        super().__init__()

        hl1 = 50
        hl2 = 25

        self. nn = nn.Sequential(OrderedDict([
            ('l1', nn.Linear(input_size,hl1)),
            ('a1', nn.Sigmoid()),
            ('l2', nn.Linear(hl1,hl2)),
            ('a2', nn.Sigmoid()),
            ('l3', nn.Linear(hl2,1)),
            ('a3', StepActivation()),
        ]))
    
    def forward(self, x):
        x = self.nn(x)

        return x

### Previous Network (and related)

In [6]:
class NetResults():
    def __init__(self, *tensors):
        for tensor in tensors:
            self.register_result(tensor)

    def __getattr__(self, name):
        if hasattr(self.x, name):
            return getattr(self.x, name)
        else:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
    
    def __dir__(self):
        return dir(self.x)

    def __str__(self):
        return '\n'.join([str(t) for t in self.tensors()])

    def tensors(self):
        for v in self.__dict__.values():
            yield v

    def detach(self):
        for t in self.tensors():
            t.detach()
        return self
    
    def round(self, *args):
        for t in self.tensors():
            t.round(*args)
        return self

class Netv1(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.a1 = ApproxNet()
        self.a2 = ApproxNet()
        self.nn = CentralNet()

    def forward(self, x):
        xa1 = self.a1(x)
        xa2 = self.a2(x)
        xnn = self.nn(x)

        res = [xnn, xa1, xa2]

        # /!\ to change for backward propagation /!\
        x = mtf.bitwise_big_or(*[(torch.round(t)).to(bool) for t in res])
        # maximum ???
        # xmax = mtf.maximum(res)
        # x = torch.where(xmax > 0.5, xmax, xnn)

        x = NetResults(x, *res)

        return x

### New Network definition

In [7]:
class Netv2(nn.Module):
    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(OrderedDict([
            ('nets', Parallel(OrderedDict([
                ('nn', CentralNet()),
                ('apx1', ApproxNet()),
                ('apx2', ApproxNet()),
                ('apx3', ApproxNet()),
                ('apx4', ApproxNet()),
                ('apx5', ApproxNet()),
            ]))),
            ('or_', MaxLayer()),
        ]))

    def forward(self, input):
        return self.net(input)

## Network evaluation

In [8]:
model = Netv2()
loss_fn = nn.BCELoss()
optimizer = Adam(model.parameters(), lr=1e-2, weight_decay=1e-6)

model.net.nets.register_forward_hook(get_intermediate_outputs("parallel_out"))

skf = model_selection.StratifiedKFold(n_splits=10, shuffle=True, random_state=104)
bnet_split_res = cross_valid(x_data, y_data, model, loss_fn, optimizer, skf, max_epoch=5000)

nn_children = [name for name, _ in model.net.nets.named_children()]
dict_metrics = {(modelname, metric, key): [] for modelname in ["net"] + nn_children
                for metric in ("f1score", "coverage0", "coverage1") for key in ("valid", "train")}

for i, (train_pred, train_true, valid_pred, valid_true) in enumerate(bnet_split_res):
    out_nns = intermediate_outputs["parallel_out"]
    for d in out_nns.values():
        for k, v in d.items():
            d[k] = v.detach().round()

    f1prompts = []
    covprompts = []
    sep_model = f"{'|':^9}"
    for k, pred, true in [["valid", valid_pred, valid_true], ["train", train_pred, train_true]]:
        net_f1_score = metrics.f1_score(true, pred, average="binary")
        dict_metrics[('net', 'f1score', k)].append(net_f1_score)

        prev_modelname = 'Net'
        prev_f1_score = net_f1_score

        for c, cname in enumerate(nn_children):
            modelname = 'CentralNet'if c==0 else f'Approx {c}'
            model_pred = out_nns[k][cname]

            model_f1_score = metrics.f1_score(true, model_pred, average="binary")
            model_cov_score = cov_score(true, model_pred)

            if c%2==0:
                f1prompts.append(f"{prev_modelname:<15}{prev_f1_score:.3f}{sep_model}{modelname:<15}{model_f1_score:.3f}")
            else:
                prev_modelname = modelname
                prev_f1_score = model_f1_score
            covprompts.append(f"{modelname:<15}{model_cov_score[0]:.3f}{sep_model}{modelname:<15}{model_cov_score[1]:.3f}")

            dict_metrics[(cname, 'f1score', k)].append(model_f1_score)
            dict_metrics[(cname, 'coverage1', k)].append(model_cov_score[1])
            dict_metrics[(cname, 'coverage0', k)].append(model_cov_score[0])

        if c%2:
            f1prompts.append(f"{modelname:<15}{model_f1_score:.3f}{sep_model}{'':<20}")

    sep_tv = f"{'||':^10}"
    print(f"Fold {i+1:3} :            {'Valid':^49}{sep_tv}{'Train':^49}",
          f"\tF1 score      {combine_prompts(f1prompts, sep_tv)}",
          f"\tCoverage      {combine_prompts(covprompts, sep_tv)}",
          sep='\n')

Fold   1 :                                  Valid                          ||                          Train                      
	F1 score      Net            0.667    |    CentralNet     0.667    ||    Net            1.000    |    CentralNet     0.988
	              Approx 1       0.636    |    Approx 2       0.615    ||    Approx 1       0.746    |    Approx 2       0.904
	              Approx 3       0.727    |    Approx 4       0.727    ||    Approx 3       0.714    |    Approx 4       0.681
	              Approx 5       0.571    |                            ||    Approx 5       0.752    |                        
	Coverage      CentralNet     0.571    |    CentralNet     0.714    ||    CentralNet     1.000    |    CentralNet     0.976
	              Approx 1       0.929    |    Approx 1       0.500    ||    Approx 1       1.000    |    Approx 1       0.595
	              Approx 2       0.714    |    Approx 2       0.571    ||    Approx 2       1.000    |    Approx 2       0.825
	

In [9]:
df_metrics = pd.DataFrame.from_dict(dict_metrics, orient='index')
mean_metrics = df_metrics.mean(axis=1)

f1_mean_prompts = []
cov_mean_prompts = []

for k in ["valid", "train"]:
    prev_modelname = 'Net'
    prev_f1_avg = mean_metrics[('net', 'f1score', 'valid')]

    for c, cname in enumerate(nn_children):
        modelname = 'CentralNet'if c==0 else f'Approx {c}'
        model_f1_avg = mean_metrics[(cname, 'f1score', k)]
        model_cov0_avg = mean_metrics[(cname, 'coverage0', k)]
        model_cov1_avg = mean_metrics[(cname, 'coverage1', k)]

        if c%2==0:
            f1_mean_prompts.append(f"{prev_modelname:<15}{prev_f1_avg:.3f}{sep_model}{modelname:<15}{model_f1_avg:.3f}")
        else:
            prev_modelname = modelname
            prev_f1_score = model_f1_score

        cov_mean_prompts.append(f"{modelname:<15}{model_cov0_avg:.3f}{sep_model}{modelname:<15}{model_cov1_avg:.3f}")
    
    if c%2:
        f1_mean_prompts.append(f"{modelname:<15}{model_f1_score:.3f}{sep_model}{'':<20}")

pn = 5
print(f"Average  :            {'Valid':^49}{sep_tv}{'Train':^49}",
          f"\tF1 score      {combine_prompts(f1_mean_prompts, sep_tv)}",
          f"\tCoverage      {combine_prompts(cov_mean_prompts, sep_tv)}",
          sep='\n')

Average  :                                  Valid                          ||                          Train                      
	F1 score      Net            0.724    |    CentralNet     0.578    ||    Net            0.724    |    CentralNet     0.751
	              Approx 1       0.724    |    Approx 2       0.611    ||    Approx 1       0.724    |    Approx 2       0.858
	              Approx 3       0.724    |    Approx 4       0.597    ||    Approx 3       0.724    |    Approx 4       0.831
	              Approx 5       0.928    |                            ||    Approx 5       0.928    |                        
	Coverage      CentralNet     0.786    |    CentralNet     0.543    ||    CentralNet     1.000    |    CentralNet     0.662
	              Approx 1       0.757    |    Approx 1       0.550    ||    Approx 1       0.999    |    Approx 1       0.753
	              Approx 2       0.714    |    Approx 2       0.586    ||    Approx 2       0.999    |    Approx 2       0.764
	