In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 128  # neighbour count passed to ClassificationKNNLoss(k=k)
tree_depth = 6  # not referenced in the cells shown here; presumably for the SDT model — confirm
batch_size = 512
device = 'cpu'
# Directory containing mitbih_train.csv and mitbih_test.csv.
# Replace <> with the correct path of the dataset, or set the MITBIH_DATA_DIR
# environment variable instead of editing the notebook.
data_dir = os.environ.get('MITBIH_DATA_DIR', r'<>')
train_data_path = f'{data_dir}/mitbih_train.csv'
test_data_path = f'{data_dir}/mitbih_test.csv'

In [3]:
# Pinned host memory only speeds up host->GPU copies; on a CPU-only run it is wasted work.
use_pinned_memory = str(device).startswith('cuda')

# Training loader: shuffled, and drop_last=True so every optimisation step sees a full batch.
# Note: anything iterating this loader (including accuracy evaluation) skips the final partial batch.
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=1,
                                              pin_memory=use_pinned_memory,
                                              drop_last=True)

# Evaluation loader: order does not affect accuracy, so iterate deterministically.
test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=use_pinned_memory)

In [4]:
class ConvBlock(nn.Module):
    """Residual 1-D convolution block followed by max-pooling.

    conv1 -> ReLU -> conv2, then the block input is added back (identity skip
    connection) before a second ReLU, and the result is max-pooled with
    kernel 5 / stride 2, which maps a temporal length L to (L - 5) // 2 + 1.

    Args:
        channels: number of input channels, which equals the number of output
            channels so that the skip connection can be added (default 32).
        kernel_size: odd convolution kernel size (default 5). Padding is
            ``kernel_size // 2`` so both convolutions preserve the length and
            the residual shapes match.

    Raises:
        ValueError: if ``kernel_size`` is even (length would not be preserved).
    """

    def __init__(self, channels=32, kernel_size=5):
        super(ConvBlock, self).__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd to keep residual shapes aligned, got {kernel_size}")
        padding = kernel_size // 2
        # Submodules are created in the same order as before so RNG-driven init is unchanged.
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)

    def forward(self, x):
        """Apply the residual block to ``x`` of shape (batch, channels, length)."""
        residual = x
        out = self.relu1(self.conv1(x))
        out = self.conv2(out)
        # Skip connection is added before the final activation.
        out = self.relu2(out + residual)
        return self.pool(out)


class ECGModel(nn.Module):
    """1-D residual CNN that classifies a single-channel signal into 5 classes.

    Layout: Conv1d(1 -> 32, k=5, padding=1) stem, five ConvBlocks, flatten,
    Linear(64 -> 32), ReLU, Linear(32 -> 5).
    """

    def __init__(self):
        super(ECGModel, self).__init__()
        # Stem: kernel 5 with padding 1 shortens the input length by 2 samples.
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        # Five residual stages; each roughly halves the temporal length.
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        # in_features=64 must equal 32 channels x the length left after block5
        # (2 for a 187-sample input — assumed input length, TODO confirm).
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        """Return class logits; with ``return_interm`` also return the flattened features fed to fc1."""
        features = self.conv(x)
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            features = stage(features)
        embedding = features.flatten(1)
        logits = self.fc2(self.relu(self.fc1(embedding)))
        if not return_interm:
            return logits
        return logits, embedding

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)


def train(model, loader, optimizer, device, epoch=None, epochs=None, log_every=None):
    """Run one training epoch on cross-entropy plus the KNN loss on intermediate features.

    For each batch the model is called with ``return_interm=True``; the objective is
    ``F.cross_entropy(logits, target) + knn_crt(interm, target)``. If ``knn_crt``
    raises ``ValueError`` or returns a non-finite value (inf or NaN), the KNN term is
    replaced by 0 for that batch so only the classification loss is optimised and a
    NaN cannot reach the weights.

    Args:
        model: classifier whose forward accepts ``return_interm`` and then returns
            ``(logits, interm)``.
        loader: iterable of ``(batch, target)`` pairs supporting ``len()``.
        optimizer: optimiser over ``model``'s parameters.
        device: device the batches are moved to.
        epoch, epochs: only used in the progress message. When omitted they are read
            from this namespace's globals of the same name (the training-loop cell
            defines them), matching the previous implicit behaviour.
        log_every: print progress every ``log_every`` iterations. Defaults to the
            global ``log_every`` (10 if undefined); a falsy value disables logging.

    Returns:
        Mean total loss over the epoch's batches (0.0 for an empty loader).
    """
    # Explicit arguments win; otherwise fall back to notebook-level globals so existing
    # calls of the form train(model, loader, optimizer, device) keep working.
    notebook_globals = globals()
    if epoch is None:
        epoch = notebook_globals.get('epoch')
    if epochs is None:
        epochs = notebook_globals.get('epochs')
    if log_every is None:
        log_every = notebook_globals.get('log_every', 10)

    model.train()

    num_batches = len(loader)
    total_loss = 0.0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        # F.cross_entropy already reduces to a scalar mean over the batch.
        cls_loss = F.cross_entropy(outputs, target)
        try:
            knn_loss = knn_crt(interm, target)
            if not torch.isfinite(knn_loss):
                knn_loss = interm.new_zeros(())
        except ValueError:
            knn_loss = interm.new_zeros(())
        loss = cls_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if log_every and iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {num_batches} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {cls_loss.item()}")

    return total_loss / num_batches if num_batches else 0.0

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
epochs = 200
lr = 1e-3
log_every = 10
seed = 0  # seeds torch's global RNG: weight init below and, by default, DataLoader shuffling

torch.manual_seed(seed)
model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}')

#Params: 53957


In [None]:
# NOTE(review): "valid" here is the MIT-BIH *test* split; selecting the best epoch on it
# leaks test information into model selection. Consider a held-out split of the train set.
best_valid_acc = 0.0
best_state = None  # copy of the weights that achieved best_valid_acc
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        # Clone so later optimiser steps do not overwrite the snapshot.
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
          f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.629314422607422 | KNN Loss: 5.94550085067749 | CLS Loss: 1.6838135719299316
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 6.246893882751465 | KNN Loss: 5.410148620605469 | CLS Loss: 0.8367450833320618
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 5.863901138305664 | KNN Loss: 5.224198341369629 | CLS Loss: 0.6397028565406799
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 5.8153977394104 | KNN Loss: 5.172204494476318 | CLS Loss: 0.6431933641433716
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 5.770508289337158 | KNN Loss: 5.207658290863037 | CLS Loss: 0.5628498792648315
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 5.6319355964660645 | KNN Loss: 5.1199951171875 | CLS Loss: 0.5119404792785645
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 5.606156826019287 | KNN Loss: 5.101445198059082 | CLS Loss: 0.5047116279602051
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 5.439274787902832 | KNN Loss: 5.02860164642334 | CLS Loss: 0

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 5.045559406280518 | KNN Loss: 4.834207057952881 | CLS Loss: 0.21135224401950836
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 4.956061840057373 | KNN Loss: 4.874513149261475 | CLS Loss: 0.08154855668544769
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 5.0474443435668945 | KNN Loss: 4.875492095947266 | CLS Loss: 0.17195230722427368
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 4.977890968322754 | KNN Loss: 4.8427324295043945 | CLS Loss: 0.13515834510326385
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 5.003618240356445 | KNN Loss: 4.8549323081970215 | CLS Loss: 0.14868596196174622
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 4.912004470825195 | KNN Loss: 4.81768798828125 | CLS Loss: 0.09431671351194382
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 4.937765598297119 | KNN Loss: 4.827059268951416 | CLS Loss: 0.11070647090673447
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 4.955239772796631 | KNN Loss: 4.84469461

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 4.912887096405029 | KNN Loss: 4.825107097625732 | CLS Loss: 0.08777984231710434
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 4.956603527069092 | KNN Loss: 4.823631286621094 | CLS Loss: 0.13297216594219208
Epoch: 007, Loss: 4.8936, Train: 0.9796, Valid: 0.9783, Best: 0.9783
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 4.85853910446167 | KNN Loss: 4.749813556671143 | CLS Loss: 0.10872567445039749
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 4.895758152008057 | KNN Loss: 4.828134059906006 | CLS Loss: 0.06762416660785675
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 4.900043487548828 | KNN Loss: 4.7993669509887695 | CLS Loss: 0.10067657381296158
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 4.841984748840332 | KNN Loss: 4.771867752075195 | CLS Loss: 0.07011718302965164
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 4.883410453796387 | KNN Loss: 4.786993503570557 | CLS Loss: 0.0964171290397644
Epoch 8 / 200 | iteration 50

Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 4.8051252365112305 | KNN Loss: 4.732942581176758 | CLS Loss: 0.07218284159898758
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 4.853011608123779 | KNN Loss: 4.76618766784668 | CLS Loss: 0.0868237093091011
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 4.826859474182129 | KNN Loss: 4.7362589836120605 | CLS Loss: 0.09060072153806686
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 4.8144049644470215 | KNN Loss: 4.770278453826904 | CLS Loss: 0.04412638023495674
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 4.824174880981445 | KNN Loss: 4.756758213043213 | CLS Loss: 0.06741679459810257
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 4.837057590484619 | KNN Loss: 4.7804694175720215 | CLS Loss: 0.05658822879195213
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 4.781062126159668 | KNN Loss: 4.74584436416626 | CLS Loss: 0.03521760553121567
Epoch 11 / 200 | iteration 130 / 171 | Total Loss: 4.897092342376709 | KNN Loss: 4.798

Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 4.756664752960205 | KNN Loss: 4.704980850219727 | CLS Loss: 0.051684118807315826
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 4.789742946624756 | KNN Loss: 4.739541053771973 | CLS Loss: 0.05020182952284813
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 4.834897994995117 | KNN Loss: 4.767545700073242 | CLS Loss: 0.06735212355852127
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 4.801946640014648 | KNN Loss: 4.755081653594971 | CLS Loss: 0.046865154057741165
Epoch: 014, Loss: 4.8062, Train: 0.9858, Valid: 0.9823, Best: 0.9823
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 4.761838436126709 | KNN Loss: 4.723233699798584 | CLS Loss: 0.038604702800512314
Epoch 15 / 200 | iteration 10 / 171 | Total Loss: 4.817720890045166 | KNN Loss: 4.738092422485352 | CLS Loss: 0.07962839305400848
Epoch 15 / 200 | iteration 20 / 171 | Total Loss: 4.823766231536865 | KNN Loss: 4.777785301208496 | CLS Loss: 0.045981090515851974
Epoch 15 / 200

Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 4.791929721832275 | KNN Loss: 4.718076229095459 | CLS Loss: 0.07385341823101044
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 4.797919750213623 | KNN Loss: 4.743964672088623 | CLS Loss: 0.05395513027906418
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 4.834443092346191 | KNN Loss: 4.772032737731934 | CLS Loss: 0.06241016462445259
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 4.888416767120361 | KNN Loss: 4.829061508178711 | CLS Loss: 0.059355124831199646
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 4.80781364440918 | KNN Loss: 4.748607635498047 | CLS Loss: 0.059205882251262665
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 4.81538200378418 | KNN Loss: 4.757635116577148 | CLS Loss: 0.05774672329425812
Epoch 18 / 200 | iteration 90 / 171 | Total Loss: 4.850193977355957 | KNN Loss: 4.7681565284729 | CLS Loss: 0.0820375308394432
Epoch 18 / 200 | iteration 100 / 171 | Total Loss: 4.762050151824951 | KNN Loss: 4.7272033691

Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 4.7503743171691895 | KNN Loss: 4.728378772735596 | CLS Loss: 0.02199532277882099
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 4.761358261108398 | KNN Loss: 4.707172393798828 | CLS Loss: 0.054185718297958374
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 4.7785186767578125 | KNN Loss: 4.711872577667236 | CLS Loss: 0.0666460320353508
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 4.8240156173706055 | KNN Loss: 4.780041694641113 | CLS Loss: 0.04397384822368622
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 4.765025615692139 | KNN Loss: 4.719274044036865 | CLS Loss: 0.04575154557824135
Epoch 21 / 200 | iteration 160 / 171 | Total Loss: 4.8462042808532715 | KNN Loss: 4.790517330169678 | CLS Loss: 0.05568686127662659
Epoch 21 / 200 | iteration 170 / 171 | Total Loss: 4.716973304748535 | KNN Loss: 4.686045169830322 | CLS Loss: 0.03092792071402073
Epoch: 021, Loss: 4.7774, Train: 0.9884, Valid: 0.9837, Best: 0.9837
Epoch 22 /

Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 4.755165100097656 | KNN Loss: 4.738037586212158 | CLS Loss: 0.017127374187111855
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 4.724353313446045 | KNN Loss: 4.699821472167969 | CLS Loss: 0.024531763046979904
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 4.791072368621826 | KNN Loss: 4.752983570098877 | CLS Loss: 0.03808900713920593
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 4.869054794311523 | KNN Loss: 4.810556411743164 | CLS Loss: 0.0584983229637146
Epoch 25 / 200 | iteration 40 / 171 | Total Loss: 4.77292013168335 | KNN Loss: 4.741631031036377 | CLS Loss: 0.03128919005393982
Epoch 25 / 200 | iteration 50 / 171 | Total Loss: 4.759739875793457 | KNN Loss: 4.718317985534668 | CLS Loss: 0.04142165929079056
Epoch 25 / 200 | iteration 60 / 171 | Total Loss: 4.727336406707764 | KNN Loss: 4.711234092712402 | CLS Loss: 0.016102243214845657
Epoch 25 / 200 | iteration 70 / 171 | Total Loss: 4.849032878875732 | KNN Loss: 4.79315757

Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 4.730710506439209 | KNN Loss: 4.704972743988037 | CLS Loss: 0.025737665593624115
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 4.779216289520264 | KNN Loss: 4.721531867980957 | CLS Loss: 0.05768420174717903
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 4.7538347244262695 | KNN Loss: 4.7370758056640625 | CLS Loss: 0.01675872877240181
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 4.785388469696045 | KNN Loss: 4.743839740753174 | CLS Loss: 0.04154874011874199
Epoch 28 / 200 | iteration 120 / 171 | Total Loss: 4.747068405151367 | KNN Loss: 4.722784042358398 | CLS Loss: 0.02428451180458069
Epoch 28 / 200 | iteration 130 / 171 | Total Loss: 4.732659816741943 | KNN Loss: 4.702388286590576 | CLS Loss: 0.03027172014117241
Epoch 28 / 200 | iteration 140 / 171 | Total Loss: 4.778486251831055 | KNN Loss: 4.72112512588501 | CLS Loss: 0.05736112967133522
Epoch 28 / 200 | iteration 150 / 171 | Total Loss: 4.752243518829346 | KNN Loss: 4.

Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 4.840951442718506 | KNN Loss: 4.79369592666626 | CLS Loss: 0.047255393117666245
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 4.743808269500732 | KNN Loss: 4.709822177886963 | CLS Loss: 0.03398594260215759
Epoch: 031, Loss: 4.7682, Train: 0.9895, Valid: 0.9825, Best: 0.9853
Epoch 32 / 200 | iteration 0 / 171 | Total Loss: 4.721872806549072 | KNN Loss: 4.705934524536133 | CLS Loss: 0.01593817211687565
Epoch 32 / 200 | iteration 10 / 171 | Total Loss: 4.720330715179443 | KNN Loss: 4.706740379333496 | CLS Loss: 0.01359035074710846
Epoch 32 / 200 | iteration 20 / 171 | Total Loss: 4.738834381103516 | KNN Loss: 4.711851596832275 | CLS Loss: 0.026982568204402924
Epoch 32 / 200 | iteration 30 / 171 | Total Loss: 4.750956058502197 | KNN Loss: 4.713137149810791 | CLS Loss: 0.03781897947192192
Epoch 32 / 200 | iteration 40 / 171 | Total Loss: 4.785808563232422 | KNN Loss: 4.730553150177002 | CLS Loss: 0.05525543913245201
Epoch 32 / 200 | it

Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 4.700842380523682 | KNN Loss: 4.680301666259766 | CLS Loss: 0.02054055780172348
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 4.737159729003906 | KNN Loss: 4.709205627441406 | CLS Loss: 0.02795402705669403
Epoch 35 / 200 | iteration 70 / 171 | Total Loss: 4.711860656738281 | KNN Loss: 4.683170795440674 | CLS Loss: 0.028689643368124962
Epoch 35 / 200 | iteration 80 / 171 | Total Loss: 4.7458038330078125 | KNN Loss: 4.719736576080322 | CLS Loss: 0.026067135855555534
Epoch 35 / 200 | iteration 90 / 171 | Total Loss: 4.756103038787842 | KNN Loss: 4.717801094055176 | CLS Loss: 0.03830195218324661
Epoch 35 / 200 | iteration 100 / 171 | Total Loss: 4.7588791847229 | KNN Loss: 4.689006328582764 | CLS Loss: 0.06987293064594269
Epoch 35 / 200 | iteration 110 / 171 | Total Loss: 4.735707759857178 | KNN Loss: 4.706620216369629 | CLS Loss: 0.029087338596582413
Epoch 35 / 200 | iteration 120 / 171 | Total Loss: 4.830999851226807 | KNN Loss: 4.750

Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 4.713037967681885 | KNN Loss: 4.689632892608643 | CLS Loss: 0.023405276238918304
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 4.751391887664795 | KNN Loss: 4.716283798217773 | CLS Loss: 0.03510820120573044
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 4.694932460784912 | KNN Loss: 4.67671537399292 | CLS Loss: 0.018217025324702263
Epoch 38 / 200 | iteration 150 / 171 | Total Loss: 4.746928691864014 | KNN Loss: 4.712491035461426 | CLS Loss: 0.034437861293554306
Epoch 38 / 200 | iteration 160 / 171 | Total Loss: 4.736726760864258 | KNN Loss: 4.712124347686768 | CLS Loss: 0.02460247464478016
Epoch 38 / 200 | iteration 170 / 171 | Total Loss: 4.77501106262207 | KNN Loss: 4.7244062423706055 | CLS Loss: 0.050604984164237976
Epoch: 038, Loss: 4.7470, Train: 0.9926, Valid: 0.9859, Best: 0.9867
Epoch 39 / 200 | iteration 0 / 171 | Total Loss: 4.799773693084717 | KNN Loss: 4.779982089996338 | CLS Loss: 0.019791459664702415
Epoch 39 / 2

Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 4.7694091796875 | KNN Loss: 4.7399725914001465 | CLS Loss: 0.02943648211658001
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 4.7939910888671875 | KNN Loss: 4.761270523071289 | CLS Loss: 0.03272048756480217
Epoch 42 / 200 | iteration 30 / 171 | Total Loss: 4.7398552894592285 | KNN Loss: 4.7179951667785645 | CLS Loss: 0.021860260516405106
Epoch 42 / 200 | iteration 40 / 171 | Total Loss: 4.777414798736572 | KNN Loss: 4.743305206298828 | CLS Loss: 0.03410957753658295
Epoch 42 / 200 | iteration 50 / 171 | Total Loss: 4.741176605224609 | KNN Loss: 4.699720859527588 | CLS Loss: 0.04145575314760208
Epoch 42 / 200 | iteration 60 / 171 | Total Loss: 4.72554349899292 | KNN Loss: 4.714349746704102 | CLS Loss: 0.01119384914636612
Epoch 42 / 200 | iteration 70 / 171 | Total Loss: 4.761524200439453 | KNN Loss: 4.734308242797852 | CLS Loss: 0.027216162532567978
Epoch 42 / 200 | iteration 80 / 171 | Total Loss: 4.751012325286865 | KNN Loss: 4.72414

Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 4.71081018447876 | KNN Loss: 4.675663471221924 | CLS Loss: 0.035146936774253845
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 4.6789398193359375 | KNN Loss: 4.656825542449951 | CLS Loss: 0.022114157676696777
Epoch 45 / 200 | iteration 110 / 171 | Total Loss: 4.75437068939209 | KNN Loss: 4.711879730224609 | CLS Loss: 0.042490795254707336
Epoch 45 / 200 | iteration 120 / 171 | Total Loss: 4.693910121917725 | KNN Loss: 4.659311771392822 | CLS Loss: 0.034598369151353836
Epoch 45 / 200 | iteration 130 / 171 | Total Loss: 4.746001243591309 | KNN Loss: 4.6951422691345215 | CLS Loss: 0.05085901916027069
Epoch 45 / 200 | iteration 140 / 171 | Total Loss: 4.746581077575684 | KNN Loss: 4.719237804412842 | CLS Loss: 0.027343420311808586
Epoch 45 / 200 | iteration 150 / 171 | Total Loss: 4.788656711578369 | KNN Loss: 4.731900215148926 | CLS Loss: 0.05675647035241127
Epoch 45 / 200 | iteration 160 / 171 | Total Loss: 4.721305847167969 | KNN Loss

Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 4.759403228759766 | KNN Loss: 4.7245988845825195 | CLS Loss: 0.03480418026447296
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 4.73067569732666 | KNN Loss: 4.710829257965088 | CLS Loss: 0.019846387207508087
Epoch: 048, Loss: 4.7346, Train: 0.9928, Valid: 0.9859, Best: 0.9867
Epoch 49 / 200 | iteration 0 / 171 | Total Loss: 4.797367572784424 | KNN Loss: 4.741666316986084 | CLS Loss: 0.0557011179625988
Epoch 49 / 200 | iteration 10 / 171 | Total Loss: 4.736055850982666 | KNN Loss: 4.71439266204834 | CLS Loss: 0.021662971004843712
Epoch 49 / 200 | iteration 20 / 171 | Total Loss: 4.74731969833374 | KNN Loss: 4.733874320983887 | CLS Loss: 0.013445189222693443
Epoch 49 / 200 | iteration 30 / 171 | Total Loss: 4.721357345581055 | KNN Loss: 4.709593296051025 | CLS Loss: 0.011764280498027802
Epoch 49 / 200 | iteration 40 / 171 | Total Loss: 4.727178573608398 | KNN Loss: 4.701723575592041 | CLS Loss: 0.025454789400100708
Epoch 49 / 200 | i

Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 4.663251876831055 | KNN Loss: 4.649634838104248 | CLS Loss: 0.01361693162471056
Epoch 52 / 200 | iteration 60 / 171 | Total Loss: 4.717588424682617 | KNN Loss: 4.658873558044434 | CLS Loss: 0.05871477350592613
Epoch 52 / 200 | iteration 70 / 171 | Total Loss: 4.707067489624023 | KNN Loss: 4.6723713874816895 | CLS Loss: 0.03469623997807503
Epoch 52 / 200 | iteration 80 / 171 | Total Loss: 4.768733501434326 | KNN Loss: 4.719743251800537 | CLS Loss: 0.04899027198553085
Epoch 52 / 200 | iteration 90 / 171 | Total Loss: 4.7491326332092285 | KNN Loss: 4.731719493865967 | CLS Loss: 0.017413312569260597
Epoch 52 / 200 | iteration 100 / 171 | Total Loss: 4.737704753875732 | KNN Loss: 4.721890926361084 | CLS Loss: 0.015813684090971947
Epoch 52 / 200 | iteration 110 / 171 | Total Loss: 4.7184672355651855 | KNN Loss: 4.700959205627441 | CLS Loss: 0.01750796101987362
Epoch 52 / 200 | iteration 120 / 171 | Total Loss: 4.67574405670166 | KNN Loss: 4.6

Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 4.784759044647217 | KNN Loss: 4.745648384094238 | CLS Loss: 0.03911081328988075
Epoch 55 / 200 | iteration 140 / 171 | Total Loss: 4.7013325691223145 | KNN Loss: 4.689765453338623 | CLS Loss: 0.011566933244466782
Epoch 55 / 200 | iteration 150 / 171 | Total Loss: 4.717602252960205 | KNN Loss: 4.695037841796875 | CLS Loss: 0.02256433293223381
Epoch 55 / 200 | iteration 160 / 171 | Total Loss: 4.724377632141113 | KNN Loss: 4.694237232208252 | CLS Loss: 0.030140286311507225
Epoch 55 / 200 | iteration 170 / 171 | Total Loss: 4.778329372406006 | KNN Loss: 4.714925289154053 | CLS Loss: 0.06340423226356506
Epoch: 055, Loss: 4.7325, Train: 0.9941, Valid: 0.9864, Best: 0.9867
Epoch 56 / 200 | iteration 0 / 171 | Total Loss: 4.685586452484131 | KNN Loss: 4.66071081161499 | CLS Loss: 0.02487550675868988
Epoch 56 / 200 | iteration 10 / 171 | Total Loss: 4.751164436340332 | KNN Loss: 4.712457180023193 | CLS Loss: 0.0387071892619133
Epoch 56 / 200 |

Epoch 59 / 200 | iteration 20 / 171 | Total Loss: 4.706699848175049 | KNN Loss: 4.698912143707275 | CLS Loss: 0.007787507027387619
Epoch 59 / 200 | iteration 30 / 171 | Total Loss: 4.676464557647705 | KNN Loss: 4.656978130340576 | CLS Loss: 0.019486529752612114
Epoch 59 / 200 | iteration 40 / 171 | Total Loss: 4.7313761711120605 | KNN Loss: 4.700752258300781 | CLS Loss: 0.03062368556857109
Epoch 59 / 200 | iteration 50 / 171 | Total Loss: 4.67796516418457 | KNN Loss: 4.662048816680908 | CLS Loss: 0.015916133299469948
Epoch 59 / 200 | iteration 60 / 171 | Total Loss: 4.711739540100098 | KNN Loss: 4.691953659057617 | CLS Loss: 0.019785994663834572
Epoch 59 / 200 | iteration 70 / 171 | Total Loss: 4.7136969566345215 | KNN Loss: 4.684292316436768 | CLS Loss: 0.02940448187291622
Epoch 59 / 200 | iteration 80 / 171 | Total Loss: 4.717371463775635 | KNN Loss: 4.69194221496582 | CLS Loss: 0.025429442524909973
Epoch 59 / 200 | iteration 90 / 171 | Total Loss: 4.692537307739258 | KNN Loss: 4.674

Epoch 62 / 200 | iteration 100 / 171 | Total Loss: 4.652350902557373 | KNN Loss: 4.6443305015563965 | CLS Loss: 0.008020630106329918
Epoch 62 / 200 | iteration 110 / 171 | Total Loss: 4.700518608093262 | KNN Loss: 4.686192035675049 | CLS Loss: 0.014326769858598709
Epoch 62 / 200 | iteration 120 / 171 | Total Loss: 4.693160057067871 | KNN Loss: 4.69080924987793 | CLS Loss: 0.002350822789594531
Epoch 62 / 200 | iteration 130 / 171 | Total Loss: 4.712261199951172 | KNN Loss: 4.684502601623535 | CLS Loss: 0.02775839902460575
Epoch 62 / 200 | iteration 140 / 171 | Total Loss: 4.844231605529785 | KNN Loss: 4.797901153564453 | CLS Loss: 0.046330563724040985
Epoch 62 / 200 | iteration 150 / 171 | Total Loss: 4.764029026031494 | KNN Loss: 4.737517833709717 | CLS Loss: 0.026511378586292267
Epoch 62 / 200 | iteration 160 / 171 | Total Loss: 4.782322883605957 | KNN Loss: 4.753468036651611 | CLS Loss: 0.028855066746473312
Epoch 62 / 200 | iteration 170 / 171 | Total Loss: 4.7098798751831055 | KNN L

Epoch 65 / 200 | iteration 170 / 171 | Total Loss: 4.695548057556152 | KNN Loss: 4.674094200134277 | CLS Loss: 0.02145376242697239
Epoch: 065, Loss: 4.7259, Train: 0.9950, Valid: 0.9865, Best: 0.9867
Epoch 66 / 200 | iteration 0 / 171 | Total Loss: 4.687714099884033 | KNN Loss: 4.66969633102417 | CLS Loss: 0.018017776310443878
Epoch 66 / 200 | iteration 10 / 171 | Total Loss: 4.797497749328613 | KNN Loss: 4.770165920257568 | CLS Loss: 0.027331799268722534
Epoch 66 / 200 | iteration 20 / 171 | Total Loss: 4.690598011016846 | KNN Loss: 4.671249866485596 | CLS Loss: 0.01934833452105522
Epoch 66 / 200 | iteration 30 / 171 | Total Loss: 4.665422439575195 | KNN Loss: 4.658097743988037 | CLS Loss: 0.007324494421482086
Epoch 66 / 200 | iteration 40 / 171 | Total Loss: 4.6960673332214355 | KNN Loss: 4.686069965362549 | CLS Loss: 0.009997517801821232
Epoch 66 / 200 | iteration 50 / 171 | Total Loss: 4.771283149719238 | KNN Loss: 4.7330827713012695 | CLS Loss: 0.03820024058222771
Epoch 66 / 200 |

Epoch 69 / 200 | iteration 60 / 171 | Total Loss: 4.684815883636475 | KNN Loss: 4.65388298034668 | CLS Loss: 0.030932817608118057
Epoch 69 / 200 | iteration 70 / 171 | Total Loss: 4.78237771987915 | KNN Loss: 4.7424421310424805 | CLS Loss: 0.039935722947120667
Epoch 69 / 200 | iteration 80 / 171 | Total Loss: 4.728329658508301 | KNN Loss: 4.705081939697266 | CLS Loss: 0.023247847333550453
Epoch 69 / 200 | iteration 90 / 171 | Total Loss: 4.690245628356934 | KNN Loss: 4.670650482177734 | CLS Loss: 0.019595179706811905
Epoch 69 / 200 | iteration 100 / 171 | Total Loss: 4.715877532958984 | KNN Loss: 4.700520038604736 | CLS Loss: 0.015357361175119877
Epoch 69 / 200 | iteration 110 / 171 | Total Loss: 4.7068352699279785 | KNN Loss: 4.674800395965576 | CLS Loss: 0.03203496336936951
Epoch 69 / 200 | iteration 120 / 171 | Total Loss: 4.6674041748046875 | KNN Loss: 4.643582820892334 | CLS Loss: 0.023821575567126274
Epoch 69 / 200 | iteration 130 / 171 | Total Loss: 4.791815280914307 | KNN Loss:

Epoch 72 / 200 | iteration 130 / 171 | Total Loss: 4.696310043334961 | KNN Loss: 4.674035549163818 | CLS Loss: 0.02227451279759407
Epoch 72 / 200 | iteration 140 / 171 | Total Loss: 4.690064430236816 | KNN Loss: 4.660336494445801 | CLS Loss: 0.029728049412369728
Epoch 72 / 200 | iteration 150 / 171 | Total Loss: 4.71966552734375 | KNN Loss: 4.708710670471191 | CLS Loss: 0.01095469668507576
Epoch 72 / 200 | iteration 160 / 171 | Total Loss: 4.832865238189697 | KNN Loss: 4.779080867767334 | CLS Loss: 0.05378442257642746
Epoch 72 / 200 | iteration 170 / 171 | Total Loss: 4.688311576843262 | KNN Loss: 4.677219390869141 | CLS Loss: 0.01109207421541214
Epoch: 072, Loss: 4.7303, Train: 0.9933, Valid: 0.9860, Best: 0.9870
Epoch 73 / 200 | iteration 0 / 171 | Total Loss: 4.676386833190918 | KNN Loss: 4.666842937469482 | CLS Loss: 0.009543945081532001
Epoch 73 / 200 | iteration 10 / 171 | Total Loss: 4.650279998779297 | KNN Loss: 4.643248558044434 | CLS Loss: 0.007031308487057686
Epoch 73 / 200 

Epoch 76 / 200 | iteration 20 / 171 | Total Loss: 4.688770294189453 | KNN Loss: 4.668248176574707 | CLS Loss: 0.02052224613726139
Epoch 76 / 200 | iteration 30 / 171 | Total Loss: 4.696654319763184 | KNN Loss: 4.667148113250732 | CLS Loss: 0.029506415128707886
Epoch 76 / 200 | iteration 40 / 171 | Total Loss: 4.690762519836426 | KNN Loss: 4.67844295501709 | CLS Loss: 0.012319492176175117
Epoch 76 / 200 | iteration 50 / 171 | Total Loss: 4.739765167236328 | KNN Loss: 4.727096080780029 | CLS Loss: 0.012669023126363754
Epoch 76 / 200 | iteration 60 / 171 | Total Loss: 4.729842185974121 | KNN Loss: 4.674856662750244 | CLS Loss: 0.05498557910323143
Epoch 76 / 200 | iteration 70 / 171 | Total Loss: 4.720070838928223 | KNN Loss: 4.707390785217285 | CLS Loss: 0.012680169194936752
Epoch 76 / 200 | iteration 80 / 171 | Total Loss: 4.736769676208496 | KNN Loss: 4.720137119293213 | CLS Loss: 0.016632476821541786
Epoch 76 / 200 | iteration 90 / 171 | Total Loss: 4.712795734405518 | KNN Loss: 4.6929

Epoch 79 / 200 | iteration 90 / 171 | Total Loss: 4.749201774597168 | KNN Loss: 4.71890926361084 | CLS Loss: 0.03029240481555462
Epoch 79 / 200 | iteration 100 / 171 | Total Loss: 4.674299716949463 | KNN Loss: 4.666773796081543 | CLS Loss: 0.00752576207742095
Epoch 79 / 200 | iteration 110 / 171 | Total Loss: 4.729270935058594 | KNN Loss: 4.701777458190918 | CLS Loss: 0.027493465691804886
Epoch 79 / 200 | iteration 120 / 171 | Total Loss: 4.7278151512146 | KNN Loss: 4.6826558113098145 | CLS Loss: 0.04515913128852844
Epoch 79 / 200 | iteration 130 / 171 | Total Loss: 4.664077281951904 | KNN Loss: 4.647403240203857 | CLS Loss: 0.016674185171723366
Epoch 79 / 200 | iteration 140 / 171 | Total Loss: 4.693663597106934 | KNN Loss: 4.684532642364502 | CLS Loss: 0.009130863472819328
Epoch 79 / 200 | iteration 150 / 171 | Total Loss: 4.692898750305176 | KNN Loss: 4.676730632781982 | CLS Loss: 0.016168085858225822
Epoch 79 / 200 | iteration 160 / 171 | Total Loss: 4.729942321777344 | KNN Loss: 4

Epoch 82 / 200 | iteration 160 / 171 | Total Loss: 4.694540023803711 | KNN Loss: 4.6625590324401855 | CLS Loss: 0.03198079764842987
Epoch 82 / 200 | iteration 170 / 171 | Total Loss: 4.677211284637451 | KNN Loss: 4.643377780914307 | CLS Loss: 0.03383341059088707
Epoch: 082, Loss: 4.7175, Train: 0.9944, Valid: 0.9872, Best: 0.9872
Epoch 83 / 200 | iteration 0 / 171 | Total Loss: 4.676074981689453 | KNN Loss: 4.665339469909668 | CLS Loss: 0.010735508985817432
Epoch 83 / 200 | iteration 10 / 171 | Total Loss: 4.698059558868408 | KNN Loss: 4.6681389808654785 | CLS Loss: 0.029920395463705063
Epoch 83 / 200 | iteration 20 / 171 | Total Loss: 4.665213108062744 | KNN Loss: 4.6573405265808105 | CLS Loss: 0.007872720248997211
Epoch 83 / 200 | iteration 30 / 171 | Total Loss: 4.653903007507324 | KNN Loss: 4.6487250328063965 | CLS Loss: 0.005178028251975775
Epoch 83 / 200 | iteration 40 / 171 | Total Loss: 4.6849517822265625 | KNN Loss: 4.680814266204834 | CLS Loss: 0.00413739075884223
Epoch 83 / 

Epoch 86 / 200 | iteration 50 / 171 | Total Loss: 4.665544033050537 | KNN Loss: 4.661722660064697 | CLS Loss: 0.0038213187362998724
Epoch 86 / 200 | iteration 60 / 171 | Total Loss: 4.6744303703308105 | KNN Loss: 4.663593292236328 | CLS Loss: 0.01083712000399828
Epoch 86 / 200 | iteration 70 / 171 | Total Loss: 4.649409770965576 | KNN Loss: 4.638027191162109 | CLS Loss: 0.011382469907402992
Epoch 86 / 200 | iteration 80 / 171 | Total Loss: 4.736215591430664 | KNN Loss: 4.73093318939209 | CLS Loss: 0.005282334052026272
Epoch 86 / 200 | iteration 90 / 171 | Total Loss: 4.667987823486328 | KNN Loss: 4.641734600067139 | CLS Loss: 0.02625301666557789
Epoch 86 / 200 | iteration 100 / 171 | Total Loss: 4.731301307678223 | KNN Loss: 4.70539665222168 | CLS Loss: 0.025904793292284012
Epoch 86 / 200 | iteration 110 / 171 | Total Loss: 4.7266340255737305 | KNN Loss: 4.708229064941406 | CLS Loss: 0.018404977396130562
Epoch 86 / 200 | iteration 120 / 171 | Total Loss: 4.7539567947387695 | KNN Loss: 

Epoch 89 / 200 | iteration 120 / 171 | Total Loss: 4.710534572601318 | KNN Loss: 4.706701755523682 | CLS Loss: 0.0038326869253069162
Epoch 89 / 200 | iteration 130 / 171 | Total Loss: 4.721573352813721 | KNN Loss: 4.687057971954346 | CLS Loss: 0.034515202045440674
Epoch 89 / 200 | iteration 140 / 171 | Total Loss: 4.664742946624756 | KNN Loss: 4.656628608703613 | CLS Loss: 0.008114155381917953
Epoch 89 / 200 | iteration 150 / 171 | Total Loss: 4.746030330657959 | KNN Loss: 4.723965167999268 | CLS Loss: 0.022065024822950363
Epoch 89 / 200 | iteration 160 / 171 | Total Loss: 4.761634826660156 | KNN Loss: 4.739597320556641 | CLS Loss: 0.022037314251065254
Epoch 89 / 200 | iteration 170 / 171 | Total Loss: 4.790550708770752 | KNN Loss: 4.7768659591674805 | CLS Loss: 0.01368456706404686
Epoch: 089, Loss: 4.7135, Train: 0.9946, Valid: 0.9858, Best: 0.9872
Epoch 90 / 200 | iteration 0 / 171 | Total Loss: 4.6970906257629395 | KNN Loss: 4.680536270141602 | CLS Loss: 0.01655414327979088
Epoch 90

Epoch 93 / 200 | iteration 10 / 171 | Total Loss: 4.691937446594238 | KNN Loss: 4.677088737487793 | CLS Loss: 0.01484874077141285
Epoch 93 / 200 | iteration 20 / 171 | Total Loss: 4.669137001037598 | KNN Loss: 4.64680290222168 | CLS Loss: 0.022334007546305656
Epoch 93 / 200 | iteration 30 / 171 | Total Loss: 4.840211868286133 | KNN Loss: 4.810022354125977 | CLS Loss: 0.03018931858241558
Epoch 93 / 200 | iteration 40 / 171 | Total Loss: 4.7588887214660645 | KNN Loss: 4.724839687347412 | CLS Loss: 0.03404882550239563
Epoch 93 / 200 | iteration 50 / 171 | Total Loss: 4.670864105224609 | KNN Loss: 4.652909278869629 | CLS Loss: 0.017954975366592407
Epoch 93 / 200 | iteration 60 / 171 | Total Loss: 4.672861099243164 | KNN Loss: 4.663883686065674 | CLS Loss: 0.008977356366813183
Epoch 93 / 200 | iteration 70 / 171 | Total Loss: 4.652175426483154 | KNN Loss: 4.64774751663208 | CLS Loss: 0.004427762236446142
Epoch 93 / 200 | iteration 80 / 171 | Total Loss: 4.735976696014404 | KNN Loss: 4.71878

Epoch 96 / 200 | iteration 80 / 171 | Total Loss: 4.7123565673828125 | KNN Loss: 4.669720649719238 | CLS Loss: 0.042636122554540634
Epoch 96 / 200 | iteration 90 / 171 | Total Loss: 4.744257926940918 | KNN Loss: 4.701762676239014 | CLS Loss: 0.042495302855968475
Epoch 96 / 200 | iteration 100 / 171 | Total Loss: 4.743129253387451 | KNN Loss: 4.708944797515869 | CLS Loss: 0.03418467566370964
Epoch 96 / 200 | iteration 110 / 171 | Total Loss: 4.729891300201416 | KNN Loss: 4.713491439819336 | CLS Loss: 0.016399793326854706
Epoch 96 / 200 | iteration 120 / 171 | Total Loss: 4.69455623626709 | KNN Loss: 4.666153430938721 | CLS Loss: 0.02840261161327362
Epoch 96 / 200 | iteration 130 / 171 | Total Loss: 4.66073751449585 | KNN Loss: 4.648351669311523 | CLS Loss: 0.012385695241391659
Epoch 96 / 200 | iteration 140 / 171 | Total Loss: 4.741384506225586 | KNN Loss: 4.713133335113525 | CLS Loss: 0.028251243755221367
Epoch 96 / 200 | iteration 150 / 171 | Total Loss: 4.713798522949219 | KNN Loss: 

Epoch 99 / 200 | iteration 150 / 171 | Total Loss: 4.685487270355225 | KNN Loss: 4.668654441833496 | CLS Loss: 0.016832629218697548
Epoch 99 / 200 | iteration 160 / 171 | Total Loss: 4.662221908569336 | KNN Loss: 4.657052040100098 | CLS Loss: 0.005169677548110485
Epoch 99 / 200 | iteration 170 / 171 | Total Loss: 4.832718849182129 | KNN Loss: 4.808174133300781 | CLS Loss: 0.02454492822289467
Epoch: 099, Loss: 4.7115, Train: 0.9949, Valid: 0.9858, Best: 0.9875
Epoch 100 / 200 | iteration 0 / 171 | Total Loss: 4.812679767608643 | KNN Loss: 4.8079423904418945 | CLS Loss: 0.004737257957458496
Epoch 100 / 200 | iteration 10 / 171 | Total Loss: 4.801422119140625 | KNN Loss: 4.781116008758545 | CLS Loss: 0.020306140184402466
Epoch 100 / 200 | iteration 20 / 171 | Total Loss: 4.685364246368408 | KNN Loss: 4.6790995597839355 | CLS Loss: 0.0062647550366818905
Epoch 100 / 200 | iteration 30 / 171 | Total Loss: 4.678440093994141 | KNN Loss: 4.655169486999512 | CLS Loss: 0.02327064424753189
Epoch 1

Epoch 103 / 200 | iteration 30 / 171 | Total Loss: 4.692645072937012 | KNN Loss: 4.679912567138672 | CLS Loss: 0.01273231953382492
Epoch 103 / 200 | iteration 40 / 171 | Total Loss: 4.82916259765625 | KNN Loss: 4.769058704376221 | CLS Loss: 0.06010374054312706
Epoch 103 / 200 | iteration 50 / 171 | Total Loss: 4.667762756347656 | KNN Loss: 4.663081169128418 | CLS Loss: 0.004681488964706659
Epoch 103 / 200 | iteration 60 / 171 | Total Loss: 4.683344841003418 | KNN Loss: 4.674220561981201 | CLS Loss: 0.00912412442266941
Epoch 103 / 200 | iteration 70 / 171 | Total Loss: 4.748209476470947 | KNN Loss: 4.734986305236816 | CLS Loss: 0.013223184272646904
Epoch 103 / 200 | iteration 80 / 171 | Total Loss: 4.668698310852051 | KNN Loss: 4.646714687347412 | CLS Loss: 0.02198338508605957
Epoch 103 / 200 | iteration 90 / 171 | Total Loss: 4.70866584777832 | KNN Loss: 4.679040431976318 | CLS Loss: 0.02962525002658367
Epoch 103 / 200 | iteration 100 / 171 | Total Loss: 4.72878360748291 | KNN Loss: 4.

In [None]:
# Final evaluation of the trained model on the MIT-BIH test split (test_data_iter).
# NOTE(review): `test` is defined in an earlier cell not visible here; whatever it returns
# is displayed as this cell's output.
test(model, test_data_iter, device)

In [None]:
# Encoder training loss per logged step (`losses` from the training cell above).
# NOTE: `losses` is re-initialised later for the tree; run this before that cell.
fig, ax = plt.subplots()
ax.plot(losses, label='train loss')
ax.legend()
plt.show()

In [None]:
# Train vs. validation accuracy curves recorded during encoder training.
fig, ax = plt.subplots()
for series, series_label in ((train_accs, 'train accuracy'), (val_accs, 'validation accuracy')):
    ax.plot(series, label=series_label)
ax.legend()
plt.show()

In [None]:
# Embed the whole test set with the trained model.
# - Batches are gathered in lists and concatenated once; calling torch.cat inside the loop
#   re-copies everything collected so far on each iteration (quadratic in the batch count).
# - `labels` keeps the dtype of `y`. The old empty-float-seed concatenation promoted it to float.
# NOTE(review): eval() makes dropout/batch-norm (if the model has any) behave as at
# inference time. Confirm that the training cell calls model.train() before training again.
model.eval()
sample_chunks, label_chunks, projection_chunks = [], [], []
with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        sample_chunks.append(x)
        label_chunks.append(y)
        # The second positional argument (True) makes the model also return an
        # intermediate representation, as used by the original cell.
        _, interm = model(x.to(device), True)
        projection_chunks.append(interm.detach().cpu().flatten(1))

test_samples = torch.cat(sample_chunks) if sample_chunks else torch.tensor([])
labels = torch.cat(label_chunks) if label_chunks else torch.tensor([])
projections = torch.cat(projection_chunks) if projection_chunks else torch.tensor([])

In [None]:
# Dense pairwise distance matrix between all test-set embeddings (sklearn's default
# metric, euclidean).
# NOTE(review): the matrix is N x N for N test samples, so memory grows quadratically.
# Consider subsampling `projections` if this cell runs out of memory.
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
# Histogram of the non-zero distances. Zeros (the diagonal and any duplicate embeddings)
# are dropped. Every off-diagonal pair appears twice because the matrix is symmetric.
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

In [None]:
# Self-label the embeddings with DBSCAN; points in no dense region get label -1 (noise).
dbscan = DBSCAN(min_samples=10, eps=2)
clusters = dbscan.fit_predict(projections)

In [None]:
# DBSCAN labels noise points as -1; every other point is an inlier of some cluster.
# The old message said "Number of inliers" but printed the inlier *fraction*; report both.
n_inliers = int(np.count_nonzero(clusters != -1))
n_clusters = len(set(clusters) - {-1})
print(f"Number of inliers: {n_inliers} / {len(clusters)} "
      f"({n_inliers / max(len(clusters), 1):.2%}) in {n_clusters} clusters")

In [None]:
# 2-D t-SNE view of the clustered embeddings, coloured by DBSCAN cluster id.
perplexity = 100
inlier_mask = clusters != -1  # drop DBSCAN noise points (label -1)
p = reduce_dims_and_plot(projections[inlier_mask],
                         y=clusters[inlier_mask],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)

# Train a Soft Decision Tree on the DBSCAN self-labels

## Prepare the dataset

In [None]:
# Self-labelled training set for the tree: raw (flattened) test signals paired with their
# DBSCAN cluster ids, with noise points (label -1) removed.
inlier_rows = torch.from_numpy(clusters != -1)
tree_dataset = torch.utils.data.TensorDataset(
    test_samples.flatten(1)[inlier_rows],
    torch.as_tensor(clusters, dtype=torch.long)[inlier_rows],
)
# Use a separate name so the encoder's `batch_size` from the config cell is not overwritten.
tree_batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=tree_batch_size, shuffle=True)

# Helpers for pruning inner-node weights and measuring tree sparsity

In [None]:
def prune_node(node_weights, factor=1):
    """Zero every weight lying strictly within ``factor`` standard deviations of the mean.

    The mean and std are computed with NumPy over all entries of ``node_weights``.
    The tensor is modified in place and also returned.
    """
    as_np = node_weights.cpu().detach().numpy()
    center, spread = np.mean(as_np), np.std(as_np)
    lower, upper = center - factor * spread, center + factor * spread
    near_mean = (node_weights > lower) & (node_weights < upper)
    node_weights[near_mean] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    """Keep only the ``keep`` largest-magnitude entries of a 1-D weight tensor; zero the rest.

    The tensor is modified in place and also returned. If ``keep`` is at least the number
    of entries, nothing is zeroed; ``keep == 0`` zeroes everything.

    Raises:
        ValueError: if ``keep`` is negative.
    """
    if keep < 0:
        raise ValueError(f"keep must be non-negative, got {keep}")
    magnitudes = np.abs(node_weights.cpu().detach().numpy())
    # Slice by an explicit drop count. A `[:-keep]` slice would keep everything when
    # keep == 0, because `[:-0]` is `[:0]`.
    n_drop = max(magnitudes.shape[0] - keep, 0)
    throw_idx = np.argsort(magnitudes)[:n_drop]  # indices of the smallest |w|
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    """Hard-prune each inner node of ``tree_`` to its ``factor`` largest-magnitude weights.

    Despite its name, ``factor`` is forwarded to ``prune_node_keep`` as ``keep``. It is
    therefore the number of weights retained per row of ``tree_.inner_nodes.weight``.
    The parameter is overwritten in place.
    """
    with torch.no_grad():
        # Work on a detached copy so the per-row zeroing is not recorded by autograd.
        new_weights = tree_.inner_nodes.weight.detach().clone()
        for row_idx in range(new_weights.shape[0]):
            new_weights[row_idx, :] = prune_node_keep(new_weights[row_idx, :], factor)
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    """Return the mean, over rows of the 2-D tensor ``x``, of each row's fraction of exact zeros."""
    zero_fractions = [(len(row) - torch.count_nonzero(row).item()) / len(row) for row in x]
    return np.mean(zero_fractions)

def compute_regularization_by_level(tree):
    """Depth-weighted sum of the L2 norms of the tree's inner-node weight rows.

    Row ``i`` is assigned level ``floor(log2(i + 1))``. This assumes nodes are stored
    level by level (heap order), as that formula implies. Each row's norm is scaled by
    ``2 ** -level``, so shallower nodes weigh more. Returns a scalar tensor, or the int
    0 if there are no rows.
    """
    weight = tree.inner_nodes.weight
    total_reg = 0
    for node_idx in range(weight.shape[0]):
        level = (node_idx + 1).bit_length() - 1  # == floor(log2(node_idx + 1))
        total_reg = total_reg + 2.0 ** (-level) * torch.norm(weight[node_idx].view(-1), 2)
    return total_reg

def show_sparseness(tree):
    """Print the average and per-layer sparseness of ``tree.inner_nodes.weight``.

    A row's sparseness is its fraction of exact zeros. Row ``i`` belongs to layer
    ``floor(log2(i + 1))``, assuming level-by-level (heap) node order. Every layer is
    printed, including the deepest one.

    Returns:
        The mean sparseness over all rows.
    """
    weights = tree.inner_nodes.weight
    node_sps = [(len(row) - torch.count_nonzero(row).item()) / len(row) for row in weights]
    avg_sp = np.mean(node_sps)
    print(f"Average sparseness: {avg_sp}")

    layer_sps = {}
    for node_idx, sp in enumerate(node_sps):
        layer = (node_idx + 1).bit_length() - 1  # == floor(log2(node_idx + 1))
        layer_sps.setdefault(layer, []).append(sp)
    # Printing after grouping makes sure the last layer is reported too.
    for layer, sps in layer_sps.items():
        print(f"layer {layer}: {np.mean(sps)}")

    return avg_sp

## Training configuration

In [None]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration,
             add_regularization=False):
    """Run one training epoch of the soft decision tree ``model`` over ``loader``.

    Relies on the notebook globals ``optimizer``, ``criterion``, ``sparsity_lamda``,
    ``start_time`` and ``compute_regularization_by_level``.

    Args:
        model: tree to train. Calling it on a batch returns ``(output, penalty)``.
        loader: iterable of ``(data, target)`` batches.
        device: device each batch is moved to.
        log_interval: print a status line every ``log_interval`` batches.
        losses: list; the per-batch optimised loss is appended in place.
        accs: list; the per-batch training accuracy is appended in place.
        epoch: epoch index, used only in the log line.
        iteration: global iteration counter carried across epochs.
        add_regularization: if True, add ``sparsity_lamda *
            compute_regularization_by_level(model)`` to the optimised loss. The default
            (False) keeps the previous behaviour, where the term is only reported.

    Returns:
        The updated global iteration counter.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        log_this_batch = batch_idx % log_interval == 0
        data, target = data.to(device), target.to(device)
        target = target.view(-1)

        output, penalty = model(data)

        # Classification loss plus the tree's own penalty term.
        loss_tree = criterion(output, target) + penalty

        if add_regularization:
            regularization = sparsity_lamda * compute_regularization_by_level(model)
            loss = loss_tree + regularization
        else:
            loss = loss_tree
            if log_this_batch:
                # Only needed for the log line. Compute it before the step, without
                # building an autograd graph.
                with torch.no_grad():
                    regularization = sparsity_lamda * compute_regularization_by_level(model)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_acc = (output.argmax(dim=1) == target).sum().item() / data.size(0)
        losses.append(loss.item())
        accs.append(batch_acc)

        if log_this_batch:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {batch_acc:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")

    return iteration


In [None]:
# Soft-decision-tree training hyper-parameters.
lr = 5e-3                # Adam learning rate
weight_decay = 5e-4      # Adam weight decay
# NOTE(review): do_epoch reports sparsity_lamda * compute_regularization_by_level(tree),
# but by default does not add it to the optimised loss.
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100       # print a status line every `log_interval` batches
use_cuda = (device != 'cpu')

In [None]:
# Output width = number of real DBSCAN clusters. Noise (-1) is excluded explicitly, so the
# count is also right when DBSCAN marked no point as noise. Subtracting 1 from the number
# of distinct labels would then be one class short.
n_tree_classes = len(set(clusters) - {-1})
# NOTE(review): input_dim uses shape[2], which matches the flatten(1) used for
# tree_dataset only if test_samples is (N, 1, L). Confirm the channel dimension is 1.
tree = SDT(input_dim=test_samples.shape[2], output_dim=n_tree_classes, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
# Move the model before building the optimizer (the recommended order in PyTorch).
tree = tree.to(device)
optimizer = torch.optim.Adam(tree.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

In [None]:
# Fresh per-iteration histories for the tree run (this replaces the encoder's `losses`).
losses, accs, sparsity = [], [], []

In [None]:
# Train the tree. Before each epoch the current sparsity is logged. After every
# PRUNE_EVERY epochs each inner node is hard-pruned to its PRUNE_KEEP largest-magnitude
# weights (prune_tree forwards `factor` as the number of weights kept).
PRUNE_EVERY = 1
PRUNE_KEEP = 3

start_time = time.time()
iteration = 0
for epoch in range(epochs):
    sparsity.append(show_sparseness(tree))
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)

    if epoch % PRUNE_EVERY == 0:
        prune_tree(tree, factor=PRUNE_KEEP)
        

In [None]:
# Per-batch training accuracy of the tree.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(accs, label='Accuracy vs iteration')
ax.set_xlabel('Iteration')
ax.set_ylabel("Accuracy")
ax.legend()  # without this the line label is never displayed
plt.show()

In [None]:
# Per-batch training loss of the tree.
fig, ax = plt.subplots()
ax.plot(losses, label='Loss vs iteration')
ax.set_xlabel('Iteration')
ax.set_ylabel("Loss")
ax.legend()
plt.show()

# Distribution of all inner-node weights after training/pruning, with mean ± 1 std marked.
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
weights_std = np.std(weights)
weights_mean = np.mean(weights)
fig, ax = plt.subplots()
ax.hist(weights, bins=500)
ax.axvline(weights_mean + weights_std, color='r', label='mean ± 1 std')
ax.axvline(weights_mean - weights_std, color='r')
ax.set_title(f"Mean: {weights_mean:.4g}   |   STD: {weights_std:.4g}")
ax.set_yscale("log")
ax.legend()
plt.show()

# Tree Visualization

In [None]:
# Draw the trained tree. `visualize` returns a pair: `avg_height` (presumably the average
# leaf depth; TODO confirm in sdt_model) and `root`, the node object used below to
# extract rules.
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
# Each leaf of the tree is treated as one pattern.
n_patterns = len(root.get_leaves())
print(f"Number of patterns: {n_patterns}")

In [None]:
# Strategy passed to root.accumulate_samples in the next cell.
# NOTE(review): the accepted values are defined in the tree module; confirm there.
method = 'greedy'

In [None]:
# Reset the per-leaf sample buffers, then feed every self-labelled training sample
# through the tree (presumably routing each one to a leaf). Targets are not needed here.
root.clear_leaves_samples()

with torch.no_grad():
    for data, _ in tree_loader:
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
# One attribute name per time step of the flattened input signal.
# NOTE(review): assumes test_samples is (N, 1, L), matching the tree's input_dim. Confirm.
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves, start=1):
    # Rebuild this leaf's path conditions, presumably tightened to the samples
    # accumulated at the leaf in the previous cell.
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter} ==============")
    # A pattern's comprehensibility is the sum over the conditions on its path.
    comprehensibilities.append(sum(cond.comprehensibility for cond in conds))

if comprehensibilities:
    print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
    print(f"std comprehensibility: {np.std(comprehensibilities)}")
    print(f"var comprehensibility: {np.var(comprehensibilities)}")
    print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
    print(f"maximum comprehensibility: {np.max(comprehensibilities)}")
else:
    # np.min / np.max raise on an empty list.
    print("No leaves found; comprehensibility statistics are undefined.")