In [4]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [5]:
# Notebook-wide configuration.
k = 16            # `k` passed to ClassificationKNNLoss (presumably the neighbour count)
tree_depth = 8    # depth of the soft decision tree trained on the cluster labels
batch_size = 512
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Seed the RNGs used below (DataLoader shuffling, weight init) so that
# Restart & Run All is reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
# The data directory can be overridden via an environment variable instead of
# editing the notebook; the default is the original location.
data_dir = os.environ.get('MITBIH_DATA_DIR', '/mnt/qnap/ekosman')
train_data_path = os.path.join(data_dir, 'mitbih_train.csv')
test_data_path = os.path.join(data_dir, 'mitbih_test.csv')

In [6]:
# Training loader: reshuffled every epoch; the last incomplete batch is dropped
# (presumably so every KNN-loss batch is full-sized — TODO confirm).
# Pinned memory only helps host-to-GPU copies, so it is enabled only off-CPU.
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=(device != 'cpu'),
                                             drop_last=True)

# Evaluation loader: no shuffling, so accuracy and the projection / clustering
# cells further down see the test split in a fixed, reproducible order.
test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers=1,
                                            pin_memory=(device != 'cpu'))

In [7]:
class ConvBlock(nn.Module):
    """Residual 1-D conv block on 32 channels followed by max-pooling.

    Both convolutions use kernel 5 / padding 2, which preserves the sequence
    length, so only the pooling (kernel 5, stride 2) shortens it:
    ``L_out = (L_in - 5) // 2 + 1``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)

    def forward(self, x):
        y = self.relu1(self.conv1(x))
        y = self.conv2(y)
        # Skip connection is added before the second activation.
        y = self.relu2(y + x)
        return self.pool(y)


class ECGModel(nn.Module):
    """1-D residual CNN classifier for single-channel sequences.

    Args:
        input_length: length of the input sequences. The default 187 (assumed
            MIT-BIH beat length) yields 64 flattened features, matching the
            original hard-coded ``fc1`` size. Lengths too short for the five
            pooling stages make the dummy pass in ``__init__`` raise.
        num_classes: number of output logits (default 5, as before).

    ``forward(x, return_interm=False)`` expects ``x`` shaped
    ``(batch, 1, input_length)`` and returns the logits, plus the flattened
    pre-``fc1`` features when ``return_interm`` is true.
    """

    def __init__(self, input_length=187, num_classes=5):
        super().__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        # Infer the flattened feature size from a dummy pass instead of
        # hard-coding it, so other input lengths work too.
        with torch.no_grad():
            n_features = self._features(torch.zeros(1, 1, input_length)).flatten(1).shape[1]
        self.fc1 = nn.Linear(n_features, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def _features(self, x):
        """Stem convolution followed by the five residual blocks."""
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x

    def forward(self, x, return_interm=False):
        interm = self._features(x).flatten(1)
        x = self.fc2(self.relu(self.fc1(interm)))

        if return_interm:
            return x, interm

        return x

In [8]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device, log_every=None, epoch=None, epochs=None):
    """Run one training epoch: cross-entropy on the logits plus the KNN loss on the intermediate features.

    Args:
        model: classifier whose ``model(x, return_interm=True)`` returns
            ``(logits, intermediate_features)``.
        loader: iterable of ``(batch, target)`` pairs.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        log_every: print a progress line every ``log_every`` iterations.
        epoch, epochs: current / total epoch, used only in the progress line.
            Any of ``log_every``, ``epoch``, ``epochs`` left as ``None`` falls
            back to the notebook-level global of the same name, so existing
            calls keep working.

    Returns:
        Mean total loss over the iterated batches (0.0 if ``loader`` is empty).
    """
    nb_globals = globals()
    if log_every is None:
        log_every = nb_globals.get('log_every', 10)
    if epoch is None:
        epoch = nb_globals.get('epoch', '?')
    if epochs is None:
        epochs = nb_globals.get('epochs', '?')

    model.train()

    total_loss = 0.0
    n_batches = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        # Classification term. F.cross_entropy already averages over the batch
        # and returns a scalar, so no further reduction is needed.
        cls_loss = F.cross_entropy(outputs, target)
        try:
            knn_loss = knn_crt(interm, target)
            # Drop a non-finite KNN term (inf *or* NaN) rather than letting it
            # corrupt the weights.
            if not torch.isfinite(knn_loss):
                knn_loss = torch.zeros((), device=device)
        except ValueError:
            # knn_crt may raise ValueError on some batches (see losses.knn_loss);
            # train on the classification term alone for that batch.
            knn_loss = torch.zeros((), device=device)
        loss = cls_loss + knn_loss
        total_loss += loss.item()
        n_batches += 1
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {cls_loss.item()}")

    return total_loss / n_batches if n_batches else 0.0

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [9]:
# Classifier hyper-parameters (`log_every` is read by `train` for its progress lines).
epochs, lr, log_every = 200, 1e-3, 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Model size = total number of scalar parameters.
num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}')

#Params: 53957


In [None]:
# Train the classifier, tracking per-epoch loss and train / test accuracy.
# NOTE(review): the "validation" accuracy is measured on the test split, so
# best_valid_acc is not an unbiased estimate of generalisation.
best_valid_acc = 0
losses, train_accs, val_accs = [], [], []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
    losses.append(loss)

    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    best_valid_acc = max(best_valid_acc, valid_acc)

    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
          f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.484518527984619 | KNN Loss: 5.702592849731445 | CLS Loss: 1.7819255590438843
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 4.918007850646973 | KNN Loss: 3.6306533813476562 | CLS Loss: 1.2873547077178955
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 4.060115814208984 | KNN Loss: 3.4252774715423584 | CLS Loss: 0.6348381042480469
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 4.008340835571289 | KNN Loss: 3.4107308387756348 | CLS Loss: 0.59760981798172
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 3.8298227787017822 | KNN Loss: 3.2591328620910645 | CLS Loss: 0.570689857006073
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 3.826385498046875 | KNN Loss: 3.267730951309204 | CLS Loss: 0.5586546659469604
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 3.7447385787963867 | KNN Loss: 3.226707696914673 | CLS Loss: 0.5180308222770691
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 3.7083966732025146 | KNN Loss: 3.2126309871673584 | C

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 3.2201507091522217 | KNN Loss: 3.085808277130127 | CLS Loss: 0.1343424767255783
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 3.213165521621704 | KNN Loss: 3.0549569129943848 | CLS Loss: 0.15820854902267456
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 3.247746229171753 | KNN Loss: 3.077493667602539 | CLS Loss: 0.17025253176689148
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 3.2177517414093018 | KNN Loss: 3.061121702194214 | CLS Loss: 0.15662993490695953
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 3.2320189476013184 | KNN Loss: 3.052795648574829 | CLS Loss: 0.17922340333461761
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 3.2135558128356934 | KNN Loss: 3.049027442932129 | CLS Loss: 0.16452841460704803
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 3.2569992542266846 | KNN Loss: 3.077732801437378 | CLS Loss: 0.17926636338233948
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 3.2157771587371826 | KNN Loss: 3.0915

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 3.1755993366241455 | KNN Loss: 3.0636112689971924 | CLS Loss: 0.11198803782463074
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 3.123162031173706 | KNN Loss: 3.0568337440490723 | CLS Loss: 0.06632837653160095
Epoch: 007, Loss: 3.1496, Train: 0.9759, Valid: 0.9734, Best: 0.9734
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 3.136845350265503 | KNN Loss: 3.0463509559631348 | CLS Loss: 0.09049447625875473
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 3.0937507152557373 | KNN Loss: 3.0225329399108887 | CLS Loss: 0.07121770083904266
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 3.1500046253204346 | KNN Loss: 3.0523664951324463 | CLS Loss: 0.09763810783624649
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 3.1348865032196045 | KNN Loss: 3.016669750213623 | CLS Loss: 0.1182168573141098
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 3.156845808029175 | KNN Loss: 3.062957286834717 | CLS Loss: 0.0938885509967804
Epoch 8 / 200 | iter

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 3.0776071548461914 | KNN Loss: 3.000746726989746 | CLS Loss: 0.07686050236225128
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 3.085125684738159 | KNN Loss: 3.013312578201294 | CLS Loss: 0.07181314378976822
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 3.138761520385742 | KNN Loss: 3.053946018218994 | CLS Loss: 0.08481541275978088
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 3.1633753776550293 | KNN Loss: 3.0122034549713135 | CLS Loss: 0.15117180347442627
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 3.08711838722229 | KNN Loss: 3.022554874420166 | CLS Loss: 0.06456354260444641
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 3.127065896987915 | KNN Loss: 3.0565388202667236 | CLS Loss: 0.070527084171772
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 3.0992629528045654 | KNN Loss: 3.0368895530700684 | CLS Loss: 0.06237330287694931
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 3.086613655090332 | KNN Loss: 3.04

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 3.096914768218994 | KNN Loss: 2.998443126678467 | CLS Loss: 0.09847155958414078
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 3.080664873123169 | KNN Loss: 3.0051023960113525 | CLS Loss: 0.07556237280368805
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 3.1227219104766846 | KNN Loss: 3.0508949756622314 | CLS Loss: 0.07182703167200089
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 3.1062474250793457 | KNN Loss: 2.973442554473877 | CLS Loss: 0.13280494511127472
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 3.072504758834839 | KNN Loss: 2.9906082153320312 | CLS Loss: 0.08189655095338821
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 3.080322027206421 | KNN Loss: 3.018167734146118 | CLS Loss: 0.06215440481901169
Epoch: 014, Loss: 3.0915, Train: 0.9825, Valid: 0.9788, Best: 0.9795
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 3.092512845993042 | KNN Loss: 3.0173840522766113 | CLS Loss: 0.07512886822223663
Epoch 15 /

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 3.1174850463867188 | KNN Loss: 3.0808298587799072 | CLS Loss: 0.03665507212281227
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 3.071300506591797 | KNN Loss: 3.0253753662109375 | CLS Loss: 0.0459250807762146
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 3.07586669921875 | KNN Loss: 3.048640489578247 | CLS Loss: 0.027226246893405914
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 3.13127064704895 | KNN Loss: 3.0248804092407227 | CLS Loss: 0.10639022290706635
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 3.0837440490722656 | KNN Loss: 3.005286455154419 | CLS Loss: 0.07845764607191086
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 3.0469810962677 | KNN Loss: 3.0123417377471924 | CLS Loss: 0.034639276564121246
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 3.058379888534546 | KNN Loss: 3.0104284286499023 | CLS Loss: 0.047951411455869675
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 3.136033058166504 | KNN Loss: 3.055

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 3.072303533554077 | KNN Loss: 3.0304126739501953 | CLS Loss: 0.041890885680913925
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 3.106809139251709 | KNN Loss: 3.0574450492858887 | CLS Loss: 0.04936406761407852
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 3.0857889652252197 | KNN Loss: 2.992156505584717 | CLS Loss: 0.09363243728876114
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 3.0656545162200928 | KNN Loss: 3.009784460067749 | CLS Loss: 0.0558699406683445
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 3.1175971031188965 | KNN Loss: 3.027939558029175 | CLS Loss: 0.08965752273797989
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 3.1061408519744873 | KNN Loss: 3.034520387649536 | CLS Loss: 0.07162037491798401
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 3.078258752822876 | KNN Loss: 3.0179920196533203 | CLS Loss: 0.06026671081781387
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 3.096146821975708 | KNN Los

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 3.0880792140960693 | KNN Loss: 3.023932933807373 | CLS Loss: 0.06414623558521271
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 3.1041336059570312 | KNN Loss: 3.0562996864318848 | CLS Loss: 0.04783383384346962
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 3.0745058059692383 | KNN Loss: 3.0173163414001465 | CLS Loss: 0.05718955770134926
Epoch: 024, Loss: 3.0675, Train: 0.9862, Valid: 0.9816, Best: 0.9820
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 3.1071107387542725 | KNN Loss: 3.054157257080078 | CLS Loss: 0.052953582257032394
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 3.078192710876465 | KNN Loss: 2.9919002056121826 | CLS Loss: 0.08629250526428223
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 3.0160980224609375 | KNN Loss: 2.994378089904785 | CLS Loss: 0.02171992138028145
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 3.053192138671875 | KNN Loss: 2.9841277599334717 | CLS Loss: 0.0690644159913063
Epoch 25 /

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 3.0965096950531006 | KNN Loss: 3.0548324584960938 | CLS Loss: 0.04167718440294266
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 3.076524496078491 | KNN Loss: 3.0198299884796143 | CLS Loss: 0.05669456720352173
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 3.093820095062256 | KNN Loss: 3.0308022499084473 | CLS Loss: 0.0630178153514862
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 3.044827461242676 | KNN Loss: 2.990917444229126 | CLS Loss: 0.05391012132167816
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 3.0495588779449463 | KNN Loss: 2.9970035552978516 | CLS Loss: 0.052555374801158905
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 3.11285400390625 | KNN Loss: 3.0553159713745117 | CLS Loss: 0.057538073509931564
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 3.0618841648101807 | KNN Loss: 2.994277238845825 | CLS Loss: 0.06760697811841965
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 3.0767714977264404 | KNN Loss:

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 3.030496120452881 | KNN Loss: 3.0028512477874756 | CLS Loss: 0.027644770219922066
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 3.0511116981506348 | KNN Loss: 3.0235025882720947 | CLS Loss: 0.027609193697571754
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 3.0730271339416504 | KNN Loss: 2.9977052211761475 | CLS Loss: 0.07532180100679398
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 3.1093688011169434 | KNN Loss: 3.0171923637390137 | CLS Loss: 0.09217637777328491
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 3.046893358230591 | KNN Loss: 3.0313942432403564 | CLS Loss: 0.015499157831072807
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 3.046731472015381 | KNN Loss: 3.0217323303222656 | CLS Loss: 0.02499902807176113
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 3.037355899810791 | KNN Loss: 2.9954833984375 | CLS Loss: 0.0418725460767746
Epoch: 031, Loss: 3.0548, Train: 0.9901, Valid: 0.9840, Best: 0.9840
Epoch

Epoch: 034, Loss: 3.0544, Train: 0.9902, Valid: 0.9843, Best: 0.9847
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 3.002249002456665 | KNN Loss: 2.979762315750122 | CLS Loss: 0.022486748173832893
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 3.005972385406494 | KNN Loss: 2.9862120151519775 | CLS Loss: 0.01976034790277481
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 3.07405424118042 | KNN Loss: 3.065553903579712 | CLS Loss: 0.008500357158482075
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 3.0086770057678223 | KNN Loss: 2.958332061767578 | CLS Loss: 0.05034501478075981
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 3.0729494094848633 | KNN Loss: 3.030625820159912 | CLS Loss: 0.042323511093854904
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 3.0724589824676514 | KNN Loss: 2.9882192611694336 | CLS Loss: 0.08423981070518494
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 3.079284191131592 | KNN Loss: 3.0226094722747803 | CLS Loss: 0.05667470768094063
Epoch 35 / 200

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 3.0692670345306396 | KNN Loss: 3.0084216594696045 | CLS Loss: 0.06084538996219635
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 3.060638904571533 | KNN Loss: 3.014662027359009 | CLS Loss: 0.04597682133316994
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 3.050187587738037 | KNN Loss: 3.0145936012268066 | CLS Loss: 0.03559395670890808
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 3.0925381183624268 | KNN Loss: 3.0518741607666016 | CLS Loss: 0.040664032101631165
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 3.1179025173187256 | KNN Loss: 3.065650463104248 | CLS Loss: 0.052252147346735
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 3.0689852237701416 | KNN Loss: 3.0128448009490967 | CLS Loss: 0.0561404675245285
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 3.0371639728546143 | KNN Loss: 2.994481325149536 | CLS Loss: 0.042682696133852005
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 3.058029890060425 | KNN Los

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 3.04655122756958 | KNN Loss: 2.9815945625305176 | CLS Loss: 0.06495656818151474
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 3.021655321121216 | KNN Loss: 2.9990034103393555 | CLS Loss: 0.022651933133602142
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 3.0579395294189453 | KNN Loss: 3.010359525680542 | CLS Loss: 0.047580018639564514
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 3.0303218364715576 | KNN Loss: 3.0110747814178467 | CLS Loss: 0.019247055053710938
Epoch: 041, Loss: 3.0479, Train: 0.9913, Valid: 0.9860, Best: 0.9860
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 3.0330517292022705 | KNN Loss: 2.976486921310425 | CLS Loss: 0.05656488612294197
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 3.0319833755493164 | KNN Loss: 2.99629807472229 | CLS Loss: 0.03568527102470398
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 3.0837011337280273 | KNN Loss: 3.0506725311279297 | CLS Loss: 0.03302852064371109
Epoch 42

Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 3.0256075859069824 | KNN Loss: 3.0115015506744385 | CLS Loss: 0.014105948619544506
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 3.079688310623169 | KNN Loss: 3.066108465194702 | CLS Loss: 0.013579923659563065
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 3.118647336959839 | KNN Loss: 3.092388868331909 | CLS Loss: 0.026258500292897224
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 3.034674644470215 | KNN Loss: 3.0099868774414062 | CLS Loss: 0.024687718600034714
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 3.0953006744384766 | KNN Loss: 3.0430116653442383 | CLS Loss: 0.052288953214883804
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 3.0714797973632812 | KNN Loss: 3.0464415550231934 | CLS Loss: 0.02503821812570095
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 3.0897979736328125 | KNN Loss: 3.0511820316314697 | CLS Loss: 0.038615882396698
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 3.0729310512542725 | KNN Lo

Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 3.1204707622528076 | KNN Loss: 3.106139659881592 | CLS Loss: 0.014330994337797165
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 3.1133551597595215 | KNN Loss: 3.0940017700195312 | CLS Loss: 0.019353380426764488
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 3.0735158920288086 | KNN Loss: 3.055389404296875 | CLS Loss: 0.018126413226127625
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 3.065485954284668 | KNN Loss: 3.0124359130859375 | CLS Loss: 0.053050052374601364
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 3.011496067047119 | KNN Loss: 2.9810543060302734 | CLS Loss: 0.03044179454445839
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 3.025641679763794 | KNN Loss: 2.999887704849243 | CLS Loss: 0.025753965601325035
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 3.072685718536377 | KNN Loss: 3.0468103885650635 | CLS Loss: 0.025875402614474297
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 3.0596275329589844 

In [None]:
# Accuracy of the final (last-epoch) model on the test split.
final_test_acc = test(model, test_data_iter, device)
final_test_acc

In [None]:
# Mean classifier training loss per epoch (one value appended per epoch).
# NOTE(review): `losses` is reused later for the tree; re-running this cell
# after the tree cells will plot the tree's per-batch losses instead.
fig, ax = plt.subplots()
ax.plot(losses, label='train loss')
ax.set(xlabel='Epoch', ylabel='Mean training loss', title='Classifier training loss')
ax.legend()
plt.show()

In [None]:
# `test` may have returned 0-d (possibly CUDA) tensors, which matplotlib cannot
# convert; float() turns each entry into a plain number.
fig, ax = plt.subplots()
ax.plot([float(acc) for acc in train_accs], label='train accuracy')
ax.plot([float(acc) for acc in val_accs], label='validation accuracy')
ax.set(xlabel='Epoch', ylabel='Accuracy', title='Classifier accuracy per epoch')
ax.legend()
plt.show()

In [None]:
# One pass over the test split collecting the raw inputs, the labels and the
# model's intermediate features (`interm`, the input of fc1). Chunks are
# gathered in lists and concatenated once, instead of re-allocating with
# torch.cat on every batch (quadratic copying).
model.eval()
sample_chunks, label_chunks, projection_chunks = [], [], []
with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        sample_chunks.append(x)
        label_chunks.append(y)
        _, interm = model(x.to(device), return_interm=True)
        projection_chunks.append(interm.cpu().flatten(1))

test_samples = torch.cat(sample_chunks)
labels = torch.cat(label_chunks)  # keeps the dataset's label dtype
projections = torch.cat(projection_chunks)

In [None]:
# The full pairwise-distance matrix needs O(N^2) memory, so by default inspect
# a fixed random subset of the projections. Set dist_sample_size = None to use
# every projection.
dist_sample_size = 2000
dist_rng = np.random.default_rng(0)
if dist_sample_size is not None and len(projections) > dist_sample_size:
    dist_idx = np.sort(dist_rng.choice(len(projections), size=dist_sample_size, replace=False))
    dist_points = projections[torch.from_numpy(dist_idx)]
else:
    dist_points = projections

distances = pairwise_distances(dist_points)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
# Drop zero distances (self-pairs / exact duplicates) from the histogram.
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

In [None]:
# Density-based clustering of the embeddings; DBSCAN labels noise points -1.
dbscan_eps, dbscan_min_samples = 2, 10
clusters = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples).fit_predict(projections)

In [None]:
# DBSCAN marks noise with -1; everything else is an inlier. Report both the
# count and the fraction (the label previously said "Number" but printed a fraction).
inlier_mask = clusters != -1
n_inliers = int(inlier_mask.sum())
print(f"Number of inliers: {n_inliers} / {len(clusters)} ({n_inliers / len(clusters):.2%})")

In [None]:
# Multicore-TSNE 2-D view of the DBSCAN inliers, coloured by cluster id.
perplexity = 100
is_inlier = clusters != -1
p = reduce_dims_and_plot(projections[is_inlier],
                         y=clusters[is_inlier],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)

# Train a Soft Decision Tree (SDT) on the DBSCAN cluster ids (self-labels)

## Prepare the dataset

In [None]:
# Self-labelled dataset for the tree: flattened raw test signals paired with
# their DBSCAN cluster id; noise points (-1) are excluded.
# NOTE(review): the tree is fit on the *test* split — use held-out data if the
# extracted rules are to be evaluated.
tree_inlier_mask = torch.from_numpy(clusters != -1)
tree_dataset = torch.utils.data.TensorDataset(
    test_samples.flatten(1)[tree_inlier_mask],
    torch.as_tensor(clusters, dtype=torch.long)[tree_inlier_mask],
)
# Separate name so the classifier's `batch_size` is not silently overwritten.
tree_batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=tree_batch_size, shuffle=True)

## Define how we prune the weights of a node (and how we measure the tree's sparsity)

In [None]:
def prune_node(node_weights, factor=1):
    """Zero, in place, every weight lying strictly within `factor` std of the mean.

    Mean and (population) standard deviation are taken over all entries of
    ``node_weights``. Returns ``node_weights`` itself.
    """
    values = node_weights.cpu().detach().numpy()
    center = np.mean(values)
    half_width = np.std(values) * factor
    inside_band = ((center - half_width) < node_weights) & (node_weights < (center + half_width))
    node_weights[inside_band] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    """Zero, in place, all but the ``keep`` largest-magnitude entries of a 1-D tensor.

    Args:
        node_weights: 1-D weight tensor (e.g. one inner-node row); modified in place.
        keep: number of entries to keep. ``keep <= 0`` zeroes every entry;
            ``keep >= len(node_weights)`` leaves the tensor unchanged.

    Returns:
        ``node_weights`` itself.
    """
    w = node_weights.cpu().detach().numpy()
    # Explicit count instead of `[:-keep]`, which is an empty slice for keep=0.
    n_drop = max(len(w) - keep, 0)
    # argsort is ascending, so the first n_drop indices are the smallest magnitudes.
    throw_idx = np.argsort(np.abs(w))[:n_drop]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    """Hard-prune every row of ``tree_.inner_nodes.weight`` in place.

    Each row is passed to ``prune_node_keep``, so ``factor`` is used as the
    number of largest-magnitude weights kept per row, not as a std multiplier.
    The edits are made under ``torch.no_grad()`` on a detached copy, so no
    autograd graph is built for the per-row assignments.
    """
    with torch.no_grad():
        new_weights = tree_.inner_nodes.weight.detach().clone()
        for i in range(new_weights.shape[0]):
            new_weights[i, :] = prune_node_keep(new_weights[i, :], factor)
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    """Mean, over the rows of the 2-D tensor ``x``, of each row's fraction of zero entries."""
    n_cols = x.shape[1]
    per_row = [(n_cols - torch.norm(row, 0).item()) / n_cols for row in x]
    return np.mean(per_row)

def compute_regularization_by_level(tree):
    """Depth-weighted sum of the L2 norms of the inner-node weight rows.

    Row i is treated as living at level ``floor(log2(i + 1))`` (i.e. nodes are
    assumed to be stored in breadth-first / heap order) and its norm is scaled
    by ``2 ** -level``, so deeper nodes are penalised less.
    """
    weights = tree.inner_nodes.weight
    return sum(
        2 ** (-np.floor(np.log2(node_idx + 1))) * torch.norm(weights[node_idx].view(-1), 2)
        for node_idx in range(weights.shape[0])
    )

def show_sparseness(tree):
    """Print the average and per-layer sparseness of the inner-node weights.

    Sparseness of a row is its fraction of zero entries. Row i is assigned to
    layer ``floor(log2(i + 1))`` (breadth-first / heap order). Every layer is
    printed, including the deepest one.

    Returns:
        The average sparseness over all rows (same value as ``sparseness``).
    """
    weights = tree.inner_nodes.weight
    n_cols = weights.shape[1]
    per_node = [(n_cols - torch.norm(weights[i, :], 0).item()) / n_cols
                for i in range(weights.shape[0])]

    avg_sp = np.mean(per_node)
    print(f"Average sparseness: {avg_sp}")

    # Group by layer; dicts keep insertion order, so layers print top-down.
    by_layer = {}
    for i, sp in enumerate(per_node):
        by_layer.setdefault((i + 1).bit_length() - 1, []).append(sp)
    for layer, sps in by_layer.items():
        print(f"layer {layer}: {np.mean(sps)}")

    return avg_sp

## Training configurations

In [None]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration,
             optimizer=None, criterion=None, sparsity_lamda=None, start_time=None):
    """Train the soft decision tree ``model`` for one pass over ``loader``.

    The optimised loss is ``criterion(output, target) + penalty`` (the penalty
    being the second output of the tree's forward pass) plus
    ``sparsity_lamda * compute_regularization_by_level(model)``.

    Args:
        model: tree whose forward pass returns ``(output, penalty)``.
        loader: iterable of ``(data, target)`` pairs.
        device: device the batches are moved to.
        log_interval: print a progress line every ``log_interval`` batches.
        losses, accs: lists receiving the per-batch loss and accuracy (mutated in place).
        epoch: epoch index, used only for logging.
        iteration: global iteration counter before this epoch.
        optimizer, criterion, sparsity_lamda, start_time: when ``None``, fall
            back to the notebook-level globals of the same names, so existing
            calls keep working.

    Returns:
        The updated global iteration counter.
    """
    nb_globals = globals()
    optimizer = nb_globals['optimizer'] if optimizer is None else optimizer
    criterion = nb_globals['criterion'] if criterion is None else criterion
    sparsity_lamda = nb_globals['sparsity_lamda'] if sparsity_lamda is None else sparsity_lamda
    start_time = nb_globals['start_time'] if start_time is None else start_time

    model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        # Use the tree that was passed in, not a global.
        output, penalty = model(data)
        loss_tree = criterion(output, target.view(-1)) + penalty

        # Depth-weighted L2 sparsity regularizer. It is part of the optimised
        # loss so that `sparsity_lamda` actually has an effect.
        regularization = sparsity_lamda * compute_regularization_by_level(model)
        loss = loss_tree + regularization

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.argmax(dim=1)
        n_correct = pred.eq(target.view(-1)).sum().item()
        batch_acc = n_correct / data.size(0)
        accs.append(batch_acc)

        # Print training status
        if batch_idx % log_interval == 0:
            sec_per_iter = round((time.time() - start_time) / iteration, 3)
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {batch_acc:03f} | {sec_per_iter} sec/iter")

    return iteration


In [None]:
# Soft-decision-tree optimisation settings.
lr, weight_decay = 5e-3, 5e-4
sparsity_lamda = 2e-3  # scale of the depth-weighted L2 regularizer used in do_epoch
epochs, log_interval = 400, 100
use_cuda = device != 'cpu'

In [None]:
# Number of DBSCAN clusters, excluding the noise label -1. `len(set(clusters)) - 1`
# would under-count by one whenever DBSCAN finds no noise points.
n_tree_classes = len(set(clusters.tolist()) - {-1})
# The tree consumes the flattened signals (see tree_dataset), so size its input accordingly.
tree_input_dim = test_samples.flatten(1).shape[1]
tree = SDT(input_dim=tree_input_dim, output_dim=n_tree_classes, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
tree = tree.to(device)
optimizer = torch.optim.Adam(tree.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

In [None]:
# Per-batch loss / accuracy and per-epoch sparsity of the tree training run.
losses, accs, sparsity = [], [], []

In [None]:
# Alternate one training epoch with hard pruning of every inner node.
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Sparsity is recorded before training, i.e. right after the previous epoch's pruning.
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    # prune_tree forwards `factor` to prune_node_keep as its `keep` count.
    prune_tree(tree, factor=3)
        

In [None]:
# Per-batch training accuracy of the tree (do_epoch appends one value per batch).
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(accs, label='Accuracy vs iteration')
ax.set(xlabel='Iteration', ylabel='Accuracy', title='Soft decision tree training accuracy')
ax.legend()
plt.show()

In [None]:
# Per-batch training loss of the tree.
fig, ax = plt.subplots()
ax.plot(losses, label='Loss vs iteration')
ax.set(xlabel='Iteration', ylabel='Loss', title='Soft decision tree training loss')
ax.legend()
plt.show()

# Distribution of all inner-node weights. The red lines mark mean ± 1 std —
# the band prune_node(factor=1) would zero; the training loop itself prunes
# via prune_tree / prune_node_keep.
weights = tree.inner_nodes.weight.detach().cpu().numpy().flatten()
weights_std = np.std(weights)
weights_mean = np.mean(weights)
fig, ax = plt.subplots()
ax.hist(weights, bins=500)
ax.axvline(weights_mean + weights_std, color='r')
ax.axvline(weights_mean - weights_std, color='r')
ax.set_title(f"Mean: {weights_mean}   |   STD: {weights_std}")
ax.set_yscale("log")
plt.show()

# Tree Visualization

In [None]:
# Draw the learned tree. `visualize` returns a pair unpacked as (avg_height, root);
# `root` is used below to extract rules (see SDT.visualize for the exact contract).
plt.figure(dpi=80, figsize=(15, 10))
avg_height, root = tree.visualize()

# Extract Rules

## Accumulate samples in the leaves

In [None]:
n_patterns = len(root.get_leaves())  # one extracted pattern per leaf
print(f"Number of patterns: {n_patterns}")

In [None]:
method = "greedy"  # routing strategy passed to root.accumulate_samples

In [None]:
# Reset the leaves, then route every tree-training batch through the tree
# (using `method`) so each leaf accumulates the samples that reach it.
root.clear_leaves_samples()
with torch.no_grad():
    for data, _target in tree_loader:
        root.accumulate_samples(data, method)

## Tighten the leaf boundaries

In [None]:
# One attribute name per input feature of the tree (the flattened signal).
attr_names = [f"T_{i}" for i in range(test_samples.flatten(1).shape[1])]
leaves = root.get_leaves()
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves, start=1):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter} ==============")
    # A pattern's comprehensibility is the sum over the conditions on its path.
    comprehensibilities.append(sum(cond.comprehensibility for cond in conds))

if comprehensibilities:
    print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
    print(f"std comprehensibility: {np.std(comprehensibilities)}")
    print(f"var comprehensibility: {np.var(comprehensibilities)}")
    print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
    print(f"maximum comprehensibility: {np.max(comprehensibilities)}")
else:
    # np.min / np.max raise on an empty list.
    print("No leaves found: nothing to summarise.")