## Imports

In [1]:
import sys
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from collections import OrderedDict 
from sklearn import metrics, model_selection
from torch.autograd import Function
from torch.optim import Adam
from graphviz import Source
from pyeda.boolalg.bdd import bdd2expr

filepath = os.path.abspath('')
sys.path.append(os.path.join(filepath, "..", "..", "compiling_nn"))
from build_odd import compile_nn

pd.options.mode.copy_on_write = True

## Metrics, Activations, Loss and Networks definitions

In [2]:
def cm(y_true, y_pred):
    confusion_matrix = metrics.confusion_matrix(y_true, y_pred)
    cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix, display_labels=[False, True])
    return cm_display

def plot_cm(y_true, y_pred):
    cm_display = cm(y_true, y_pred)
    fig, ax = plt.subplots(1, 1, figsize=(4,8))
    cm_display.plot(ax=ax, colorbar=False)

def plot_combine_cm(cms, titles=None):
    n = len(cms)
    fig, axs = plt.subplots(1, n, figsize=(4*n, 8))
    if titles:
        for ax, cm, title in zip(axs, cms, titles):
            cm.plot(ax=ax, colorbar=False)
            ax.set_title(title)
    else:
        for ax, cm in zip(axs, cms):
            cm.plot(ax=ax, colorbar=False)
    fig.tight_layout()

def cov_score(y_true, y_pred):
    labels = np.unique(y_true)
    scores = {}

    for label in labels:
        indices_true = np.where(y_true == label)[0]
        indices_pred = np.where(y_pred == label)[0]
        scores[label] = len(np.intersect1d(indices_true, indices_pred))/len(indices_true)

    return scores

def cross_valid(X, Y, train_func, skf, **kw_train):
    for train_index, test_index in skf.split(X, Y):
        x_train, x_test = X[train_index], X[test_index]
        y_train, y_test = Y[train_index], Y[test_index]

        model, y_pred = train_func(x_train, y_train, **kw_train)
        model.eval()
        yield y_pred.detach().round(), y_train, model(x_test).detach(), y_test

def tnot(a): return torch.logical_not(a)
def tor(a,b): return torch.logical_or(a,b)
def tand(a,b): return torch.logical_and(a,b)
def txor(a,b): return torch.logical_xor(a,b)

In [3]:
class StepFunction(Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.where(input>=0, torch.tensor(1.0), torch.tensor(0.0))
        ctx.save_for_backward(input)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        return grad_input
    
class StepActivation(nn.Module):
    def forward(self, input):
        if self.training:
            return torch.sigmoid(input)
        else:
            return StepFunction.apply(input)
    
class AsymMSELoss(nn.Module): # https://www.desmos.com/calculator/zmxcluqhkt
    def __init__(self, p=2):
        super(AsymMSELoss, self).__init__()
        self.p = p

    def forward(self, input, label):
        dif = label - input
        a = torch.square(dif)
        b = a*self.p
        loss = torch.where(dif < 0, b, a)
        loss = torch.mean(loss)
        return loss
    
class AsymBCELoss(nn.Module):
    def __init__(self, p=2):
        super(AsymBCELoss, self).__init__()
        self.p = p

    def forward(self, input, label):
        loss = -torch.maximum(label*torch.log(input) + self.p*(1-label)*torch.log(1-input), torch.full(input.size(), -100))
        loss = torch.mean(loss)
        return loss

In [4]:
class ApproxNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        hl1 = 10

        self.nn = nn.Sequential(OrderedDict([
            ('l1', nn.Linear(25,hl1)),
            ('a1', StepActivation()),
            ('l2', nn.Linear(hl1,1)),
            ('a2', StepActivation())
        ]))        

    def forward(self, x):
        x = self.nn(x)

        return x

class CentralNet(nn.Module):
    def __init__(self):
        super().__init__()

        hl1 = 50
        hl2 = 50

        self. nn = nn.Sequential(
            nn.Linear(25,hl1),
            nn.Sigmoid(),
            nn.Linear(hl1,hl2),
            nn.Sigmoid(),
            nn.Linear(hl2,1),
            StepActivation(),
        )
    
    def forward(self, x):
        x = self.nn(x)

        return x

class BoundNetResults():
    def __init__(self, x, xhi, xnn, xlo):
        self.x = x
        self.hi = xhi
        self.nn = xnn
        self.lo = xlo

    def __getattr__(self, name):
        if hasattr(self.x, name):
            return getattr(self.x, name)
        else:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
    
    def __dir__(self):
        return dir(self.x)
    
    def detach(self):
        self.x = self.x.detach()
        self.hi = self.hi.detach()
        self.nn = self.nn.detach()
        self.lo = self.lo.detach()
        return self
    
    def round(self, *args):
        self.x = self.x.round(*args)
        self.hi = self.hi.round(*args)
        self.nn = self.nn.round(*args)
        self.lo = self.lo.round(*args)
        return self

class BoundNetLoss():
    def __init__(self, hi, nn, lo):
        self.hi_loss_fn = hi
        self.nn_loss_fn = nn
        self.lo_loss_fn = lo

    def __call__(self, pred, true):
        self.nn_loss = self.nn_loss_fn(pred.nn, true)

        target = pred.nn.detach()
        self.hi_loss = self.hi_loss_fn(pred.hi, target)
        self.lo_loss = self.lo_loss_fn(pred.lo, target)

        return self

    def backward(self):
        self.nn_loss.backward()
        self.hi_loss.backward()
        self.lo_loss.backward()

class BoundNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        # high approx nn (bigger = very long to compute into ODD)
        self.hi = ApproxNet()

        # low approx nn
        self.lo = ApproxNet()

        # nn to approximate (can make it bigger easily)
        self.nn = CentralNet()

    def forward(self, x):
        xhi = self.hi(x)
        xnn = self.nn(x)
        xlo = self.lo(x)
        x = torch.where(xhi>0.5, xhi, torch.where(xlo<0.5, xlo, xnn))
        
        x = BoundNetResults(x, xhi, xnn, xlo)

        return x

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()

        hl1 = 50
        hl2 = 50

        self.nn = nn.Sequential(
            nn.Linear(25,hl1),
            nn.Sigmoid(),
            nn.Linear(hl1,hl2),
            nn.Sigmoid(),
            nn.Linear(hl2,1),
            StepActivation(),
        )

    def forward(self, x):
        x = self.nn(x)
        
        return x

## Data processing

In [5]:
df = pd.read_csv("loan_data_set.csv", sep=",")
df = df.drop(columns=["Loan_ID"])

# Remove above 98.5th percentile for 'ApplicantIncome' and 'CoapplicantIncome'
df_rank = df[["ApplicantIncome", "CoapplicantIncome"]]
df_rank["rankA"] = df_rank[["ApplicantIncome"]].rank(pct=True)
df_rank["rankCo"] = df_rank["CoapplicantIncome"].rank(pct=True)

df = df.loc[(df_rank["rankA"]<=0.985) & (df_rank["rankCo"]<=0.985)]
df.index = range(len(df))

# Transform using hot encoding
df_y = pd.get_dummies(df[["Loan_Status"]], drop_first=True)
df_x = df.drop(columns=["Loan_Status"])

nunique = df_x.nunique(axis=0)
df_x_mean = df_x.mean(axis=0, numeric_only=True)

for col, n in nunique.items():
    if n > 4:
        df_x[col] = df_x[col].apply(lambda x : min(4, x//(.5*df_x_mean[col])))

df_x = pd.get_dummies(df_x, columns=df_x.columns, drop_first=True)

# Balance dataset
itrue = df_y.index[df_y["Loan_Status_Y"]==1].tolist()
ifalse = df_y.index[df_y["Loan_Status_Y"]==0].tolist()

swap = len(itrue) > len(ifalse)
if swap:
    itrue,ifalse=ifalse,itrue

ifalse = random.choices(ifalse, k=len(itrue))

if swap:
    itrue,ifalse=ifalse,itrue

print(df_y.iloc[itrue+ifalse].value_counts())

x_data=torch.Tensor(df_x.iloc[itrue+ifalse].to_numpy(dtype=int))
y_data=torch.Tensor(df_y.iloc[itrue+ifalse].to_numpy(dtype=int))

print(x_data.shape, y_data.shape)

Loan_Status_Y
0                184
1                184
dtype: int64
torch.Size([368, 25]) torch.Size([368, 1])


## Training functions

In [6]:
def train_model(x, y, model, loss_fn, optimizer, max_epoch):
    for _ in range(max_epoch):
        model.train()
        y_pred = model(x)
        
        # loss with true y values
        loss = loss_fn(y_pred, y)

        model.zero_grad()
        loss.backward()
        optimizer.step()

    return model, y_pred

def train_boundnet(x, y, max_epoch=5000, learning_rate=1e-2):
    model = BoundNet()
    loss_fn = BoundNetLoss(AsymBCELoss(100), nn.BCELoss(), AsymBCELoss(.001))
    optimizer = Adam(model.parameters(), lr=learning_rate)

    return train_model(x, y, model, loss_fn, optimizer, max_epoch=max_epoch)

def train_simplenet(x, y, max_epoch=5000, learning_rate=1e-2):
    model = SimpleNet()
    loss_fn = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)

    return train_model(x, y, model, loss_fn, optimizer, max_epoch=max_epoch)

## Network evaluation

In [7]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

def show_activation(act, output):
    # Mean activation per output
    act_ones  = torch.where(output==1, act, torch.zeros(act.size()))
    act_zeros = torch.where(output==0, act, torch.zeros(act.size()))

    mean_ones  = torch.mean(act_ones, dim=0)
    mean_zeros = torch.mean(act_zeros, dim=0)

    # Figure initialization
    fig, ax = plt.subplots(2, 1)
    tick_kw = {'left': False, 'bottom': False, 'labelleft': False}

    # Normalize cmap accross both images
    min_act = min(mean_ones.min().item(), mean_zeros.min().item())
    max_act = max(mean_ones.max().item(), mean_zeros.max().item())

    color_map = 'PRGn'

    ax[0].imshow(mean_zeros.unsqueeze(0), cmap=color_map, vmin=min_act, vmax=max_act)
    ax[0].tick_params(**tick_kw)
    ax[0].set_title("activation moyenne de la couche cachée avec 0 en sortie")

    ax[1].imshow(mean_ones.unsqueeze(0), cmap=color_map, vmin=min_act, vmax=max_act)
    ax[1].tick_params(**tick_kw)
    ax[1].set_title("activation moyenne de la couche cachée avec 1 en sortie")

    # Show text on cells
    for i, (v0, v1) in enumerate(zip(mean_zeros, mean_ones)):
        ax[0].text(i, 0, f"{v0.item():.2f}", ha="center", va="center")
        ax[1].text(i, 0, f"{v1.item():.2f}", ha="center", va="center")
    
    fig.tight_layout()
    plt.show()

def compute_activation(net, layer, data):
    net.eval()
    getattr(net, layer).register_forward_hook(get_activation('__net__'))
    output = net(data).detach()
    act = activation.pop('__net__').squeeze()
    show_activation(act, output)

In [8]:
skf = model_selection.StratifiedKFold(n_splits=10, shuffle=True, random_state=104)
bnet_split_res = cross_valid(x_data, y_data, train_boundnet, skf)
snet_split_res = cross_valid(x_data, y_data, train_simplenet, skf)

for i, ((bt, tt, bv, tv), (st, _, sv, _)) in enumerate(zip(bnet_split_res, snet_split_res)):
    prompts = []
    for b, s, t in [[bt, st, tt],[bv, sv, tv]]:
        b_f1_score = metrics.f1_score(t, b, average="binary")
        s_f1_score = metrics.f1_score(t, s, average="binary")
        prompts.append(f"{'BoundNet':<15}{b_f1_score:.3f}{'|':^9}{'SimpleNet':<15}{s_f1_score:.3f}")

        hi_f1_score = metrics.f1_score(t, b.hi, average="binary")
        lo_f1_score = metrics.f1_score(t, b.lo, average="binary")
        prompts.append(f"{'High':<15}{hi_f1_score:.3f}{'|':^9}{'Low':<15}{lo_f1_score:.3f}")

        hi_cov_score = cov_score(t, b.hi)[1]
        lo_cov_score = cov_score(t, b.lo)[0]
        prompts.append(f"{'High':<15}{hi_cov_score:.3f}{'|':^9}{'Low':<15}{lo_cov_score:.3f}")

    sep_tv = f"{'||':^10}"
    print(f"Fold {i+1:3} :            {'Valid':^49}{sep_tv}{'Train':^49}",
          f"\tF1 score      {prompts[3]}{sep_tv}{prompts[0]}",
          f"\t              {prompts[4]}{sep_tv}{prompts[1]}",
          f"\tCoverage      {prompts[5]}{sep_tv}{prompts[2]}",
          sep='\n')

Fold   1 :                                  Valid                          ||                          Train                      
	F1 score      BoundNet       0.667    |    SimpleNet      0.703    ||    BoundNet       0.982    |    SimpleNet      0.982
	              High           0.483    |    Low            0.679    ||    High           0.804    |    Low            0.665
	Coverage      High           0.368    |    Low            0.000    ||    High           0.673    |    Low            0.000
Fold   2 :                                  Valid                          ||                          Train                      
	F1 score      BoundNet       0.649    |    SimpleNet      0.529    ||    BoundNet       0.979    |    SimpleNet      0.979
	              High           0.516    |    Low            0.679    ||    High           0.904    |    Low            0.665
	Coverage      High           0.421    |    Low            0.000    ||    High           0.824    |    Low            