In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
# Notebook-wide configuration: hyper-parameters, device, RNG seed and data paths.
k = 128          # passed as `k` to ClassificationKNNLoss in the training cell
tree_depth = 8   # not used in the cells shown — NOTE(review): presumably for a later SDT cell
batch_size = 512
device = 'cpu'
seed = 0         # fixes weight init and DataLoader shuffling so runs are reproducible
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

# Seed every RNG used below: torch drives weight init and DataLoader shuffling,
# numpy is seeded for any numpy-based sampling in later cells.
torch.manual_seed(seed)
np.random.seed(seed)

In [3]:
# Pinned (page-locked) host memory only speeds up host->GPU copies; with a CPU
# device it is pure overhead (and newer PyTorch warns about it), so enable it
# only when training on CUDA.
use_pin_memory = torch.device(device).type == 'cuda'

# drop_last=True keeps every training batch at exactly `batch_size` samples
# (NOTE(review): presumably so the KNN loss always sees full batches — confirm).
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=use_pin_memory,
                                             drop_last=True)

# Evaluation order does not affect accuracy, so the test loader is not shuffled.
test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=1,
                                             pin_memory=use_pin_memory)

In [4]:
class ConvBlock(nn.Module):
    """Residual 1-D conv stage: conv -> ReLU -> conv, add skip, ReLU, max-pool.

    Both convolutions map ``channels -> channels`` with ``padding =
    kernel_size // 2``, which keeps the sequence length for odd kernels so the
    skip connection can be added element-wise. The pool then shrinks the
    length to ``floor((L - pool_size) / pool_stride) + 1``.

    Args:
        channels: input == output channel count (default 32, the value that
            was previously hard-coded).
        kernel_size: odd convolution width (default 5).
        pool_size: max-pool window (default 5).
        pool_stride: max-pool stride (default 2).

    Raises:
        ValueError: if ``kernel_size`` is even (the residual shapes would not match).
    """

    def __init__(self, channels=32, kernel_size=5, pool_size=5, pool_stride=2):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd to keep the residual length, got {kernel_size}")
        padding = kernel_size // 2
        # Same attribute names and creation order as before, so state_dict keys
        # and RNG consumption during init are unchanged for the defaults.
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=pool_size, stride=pool_stride)

    def forward(self, x):
        y = self.relu1(self.conv1(x))
        y = self.conv2(y)
        y = self.relu2(y + x)  # residual connection, then non-linearity
        return self.pool(y)


class ECGModel(nn.Module):
    """Residual 1-D CNN mapping a single-channel signal to 5 class logits.

    Pipeline: stem ``Conv1d(1, 32, k=5, padding=1)`` -> five ``ConvBlock``
    stages -> flatten -> ``Linear(64, 32)`` -> ReLU -> ``Linear(32, 5)``.
    ``fc1`` needs exactly 64 flattened features (32 channels x length 2).
    With the default ``ConvBlock``, an input of length 187 gives lengths
    185 -> 91 -> 44 -> 20 -> 8 -> 2.
    NOTE(review): assumes MITBIH yields (batch, 1, 187) — confirm.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        # Register block1..block5 in order so parameter order and
        # state_dict keys match the explicit-attribute version.
        for idx in range(1, 6):
            setattr(self, f'block{idx}', ConvBlock())
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        """Return logits, or ``(logits, features)`` when ``return_interm`` is set.

        ``features`` is the flattened output of the last conv stage (the
        input to ``fc1``); the training loop feeds it to the KNN loss.
        """
        h = self.conv(x)
        for stage in (self.block1, self.block2, self.block3, self.block4, self.block5):
            h = stage(h)
        features = h.flatten(1)
        logits = self.fc2(self.relu(self.fc1(features)))
        return (logits, features) if return_interm else logits

In [5]:
# Auxiliary KNN loss applied to the intermediate features (k from the config cell).
knn_crt = ClassificationKNNLoss(k=k).to(device)


def train(model, loader, optimizer, device, epoch=None, epochs=None, log_every=None):
    """Run one training epoch and return the mean per-batch total loss.

    The objective is cross-entropy on the logits plus ``knn_crt`` on the
    intermediate features returned by ``model(batch, return_interm=True)``.
    If ``knn_crt`` raises ``ValueError`` or returns +/-inf, the KNN term is
    replaced by 0 for that batch, so only the classification term is optimised.

    Args:
        model: network whose forward accepts ``return_interm=True`` and then
            returns ``(logits, features)``.
        loader: iterable of ``(batch, target)`` pairs; ``len(loader)`` is the
            averaging denominator.
        optimizer: optimizer over ``model``'s parameters.
        device: where each batch and target is moved.
        epoch, epochs, log_every: progress-logging settings. When omitted they
            are read from the notebook globals of the same name, which keeps
            the original 4-argument call working.

    Returns:
        float: sum of per-batch ``loss.item()`` divided by ``len(loader)``.
    """
    # Resolve logging settings explicitly instead of relying on hidden globals.
    g = globals()
    epoch = g['epoch'] if epoch is None else epoch
    epochs = g['epochs'] if epochs is None else epochs
    log_every = g['log_every'] if log_every is None else log_every

    model.train()

    total_loss = 0.0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        # F.cross_entropy already reduces to the batch mean (a scalar).
        cls_loss = F.cross_entropy(outputs, target)
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = cls_loss.new_zeros(())
        except ValueError:
            # NOTE(review): presumably raised by the KNN loss on degenerate batches — confirm.
            knn_loss = cls_loss.new_zeros(())
        loss = cls_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {cls_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
# Optimisation settings. `epochs` and `log_every` are also read as globals by
# train() when it prints progress.
epochs, lr, log_every = 200, 1e-3, 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Total trainable + non-trainable parameter count of the model.
num_params = sum(map(torch.numel, model.parameters()))
print(f'#Params: {num_params}')

#Params: 53957


In [None]:
# Train for `epochs` epochs, evaluating on both splits after every epoch.
# The loop variable must stay named `epoch`: train() reads it as a global for
# its progress messages.
best_valid_acc = 0
best_state = None  # cloned weights of the best-validation epoch
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        # state_dict() holds references to the live tensors, so clone them to
        # keep the weights of this epoch after further training.
        best_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
          f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.479093551635742 | KNN Loss: 5.855673789978027 | CLS Loss: 1.623420000076294
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 6.210329532623291 | KNN Loss: 5.366710186004639 | CLS Loss: 0.843619167804718
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 5.910048961639404 | KNN Loss: 5.182383060455322 | CLS Loss: 0.7276658415794373
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 5.910216808319092 | KNN Loss: 5.185675621032715 | CLS Loss: 0.7245411276817322
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 5.7453765869140625 | KNN Loss: 5.1567535400390625 | CLS Loss: 0.5886231064796448
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 5.7135725021362305 | KNN Loss: 5.165256500244141 | CLS Loss: 0.5483161807060242
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 5.619647979736328 | KNN Loss: 5.102312088012695 | CLS Loss: 0.5173361301422119
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 5.560379505157471 | KNN Loss: 5.040430068969727 | CLS L

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 4.985884189605713 | KNN Loss: 4.904341697692871 | CLS Loss: 0.08154263347387314
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 4.954887390136719 | KNN Loss: 4.829421043395996 | CLS Loss: 0.1254662275314331
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 5.042200565338135 | KNN Loss: 4.883302211761475 | CLS Loss: 0.15889844298362732
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 4.958935737609863 | KNN Loss: 4.862678050994873 | CLS Loss: 0.09625766426324844
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 4.986711025238037 | KNN Loss: 4.874623775482178 | CLS Loss: 0.11208714544773102
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 5.0000319480896 | KNN Loss: 4.879295825958252 | CLS Loss: 0.12073612213134766
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 4.962998390197754 | KNN Loss: 4.871077537536621 | CLS Loss: 0.09192091226577759
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 4.993232727050781 | KNN Loss: 4.8496975898742

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 4.893457889556885 | KNN Loss: 4.822717666625977 | CLS Loss: 0.07074011117219925
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 4.861776828765869 | KNN Loss: 4.796504974365234 | CLS Loss: 0.06527180969715118
Epoch: 007, Loss: 4.8884, Train: 0.9793, Valid: 0.9767, Best: 0.9767
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 4.8307671546936035 | KNN Loss: 4.7627339363098145 | CLS Loss: 0.06803344935178757
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 4.84659481048584 | KNN Loss: 4.772560119628906 | CLS Loss: 0.07403487712144852
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 4.882425308227539 | KNN Loss: 4.782869338989258 | CLS Loss: 0.09955620020627975
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 4.863083839416504 | KNN Loss: 4.797335147857666 | CLS Loss: 0.06574861705303192
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 4.869645118713379 | KNN Loss: 4.79896879196167 | CLS Loss: 0.07067614048719406
Epoch 8 / 200 | iteration 5

Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 4.899861812591553 | KNN Loss: 4.8376240730285645 | CLS Loss: 0.06223784387111664
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 4.78014612197876 | KNN Loss: 4.730955600738525 | CLS Loss: 0.04919062554836273
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 4.873658180236816 | KNN Loss: 4.794013977050781 | CLS Loss: 0.07964427769184113
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 4.814882755279541 | KNN Loss: 4.77776575088501 | CLS Loss: 0.03711717575788498
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 4.84725284576416 | KNN Loss: 4.7870941162109375 | CLS Loss: 0.06015869230031967
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 4.857163906097412 | KNN Loss: 4.77741813659668 | CLS Loss: 0.07974568754434586
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 4.795465469360352 | KNN Loss: 4.742465019226074 | CLS Loss: 0.053000494837760925
Epoch 11 / 200 | iteration 130 / 171 | Total Loss: 4.8629584312438965 | KNN Loss: 4.7930

Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 4.8884992599487305 | KNN Loss: 4.823034763336182 | CLS Loss: 0.06546467542648315
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 4.751100540161133 | KNN Loss: 4.7270307540893555 | CLS Loss: 0.02406957373023033
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 4.875583171844482 | KNN Loss: 4.772170543670654 | CLS Loss: 0.10341261327266693
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 4.816015720367432 | KNN Loss: 4.763940334320068 | CLS Loss: 0.052075307816267014
Epoch: 014, Loss: 4.8055, Train: 0.9863, Valid: 0.9821, Best: 0.9823
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 4.829922199249268 | KNN Loss: 4.766785144805908 | CLS Loss: 0.0631370022892952
Epoch 15 / 200 | iteration 10 / 171 | Total Loss: 4.83040189743042 | KNN Loss: 4.762335777282715 | CLS Loss: 0.06806630641222
Epoch 15 / 200 | iteration 20 / 171 | Total Loss: 4.758620738983154 | KNN Loss: 4.724384307861328 | CLS Loss: 0.03423648327589035
Epoch 15 / 200 | ite

Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 4.717538833618164 | KNN Loss: 4.7014594078063965 | CLS Loss: 0.016079504042863846
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 4.762868404388428 | KNN Loss: 4.719110012054443 | CLS Loss: 0.04375819116830826
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 4.764823913574219 | KNN Loss: 4.723001956939697 | CLS Loss: 0.041822049766778946
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 4.77712345123291 | KNN Loss: 4.742673873901367 | CLS Loss: 0.034449636936187744
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 4.717177867889404 | KNN Loss: 4.703795433044434 | CLS Loss: 0.013382225297391415
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 4.867327690124512 | KNN Loss: 4.793591499328613 | CLS Loss: 0.07373606413602829
Epoch 18 / 200 | iteration 90 / 171 | Total Loss: 4.80062198638916 | KNN Loss: 4.72520112991333 | CLS Loss: 0.07542107254266739
Epoch 18 / 200 | iteration 100 / 171 | Total Loss: 4.839138031005859 | KNN Loss: 4.75936

Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 4.7034010887146 | KNN Loss: 4.688318729400635 | CLS Loss: 0.015082130208611488
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 4.746294021606445 | KNN Loss: 4.710805892944336 | CLS Loss: 0.03548790141940117
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 4.78123664855957 | KNN Loss: 4.729268550872803 | CLS Loss: 0.051968008279800415
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 4.771313667297363 | KNN Loss: 4.721961975097656 | CLS Loss: 0.04935183748602867
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 4.804829120635986 | KNN Loss: 4.735812187194824 | CLS Loss: 0.06901676952838898
Epoch 21 / 200 | iteration 160 / 171 | Total Loss: 4.715946197509766 | KNN Loss: 4.691524982452393 | CLS Loss: 0.02442116290330887
Epoch 21 / 200 | iteration 170 / 171 | Total Loss: 4.790764331817627 | KNN Loss: 4.752243518829346 | CLS Loss: 0.03852062672376633
Epoch: 021, Loss: 4.7665, Train: 0.9877, Valid: 0.9832, Best: 0.9847
Epoch 22 / 200 

Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 4.754667282104492 | KNN Loss: 4.715234279632568 | CLS Loss: 0.03943293169140816
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 4.731079578399658 | KNN Loss: 4.697120666503906 | CLS Loss: 0.03395884856581688
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 4.701995849609375 | KNN Loss: 4.673275947570801 | CLS Loss: 0.028720084577798843
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 4.757462501525879 | KNN Loss: 4.729284286499023 | CLS Loss: 0.0281781367957592
Epoch 25 / 200 | iteration 40 / 171 | Total Loss: 4.768881797790527 | KNN Loss: 4.715585708618164 | CLS Loss: 0.05329617112874985
Epoch 25 / 200 | iteration 50 / 171 | Total Loss: 4.754007816314697 | KNN Loss: 4.720300674438477 | CLS Loss: 0.033707145601511
Epoch 25 / 200 | iteration 60 / 171 | Total Loss: 4.763588905334473 | KNN Loss: 4.693790912628174 | CLS Loss: 0.06979823112487793
Epoch 25 / 200 | iteration 70 / 171 | Total Loss: 4.706191539764404 | KNN Loss: 4.67548942565

Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 4.859772205352783 | KNN Loss: 4.802390098571777 | CLS Loss: 0.05738217011094093
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 4.7455878257751465 | KNN Loss: 4.7022881507873535 | CLS Loss: 0.04329954832792282
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 4.741067409515381 | KNN Loss: 4.72560453414917 | CLS Loss: 0.015462839975953102
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 4.756267070770264 | KNN Loss: 4.700916290283203 | CLS Loss: 0.05535084381699562
Epoch 28 / 200 | iteration 120 / 171 | Total Loss: 4.73722505569458 | KNN Loss: 4.718879699707031 | CLS Loss: 0.018345534801483154
Epoch 28 / 200 | iteration 130 / 171 | Total Loss: 4.838357448577881 | KNN Loss: 4.746140480041504 | CLS Loss: 0.09221718460321426
Epoch 28 / 200 | iteration 140 / 171 | Total Loss: 4.710324287414551 | KNN Loss: 4.68350887298584 | CLS Loss: 0.026815634220838547
Epoch 28 / 200 | iteration 150 / 171 | Total Loss: 4.743510723114014 | KNN Loss: 4.

Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 4.731513977050781 | KNN Loss: 4.692133903503418 | CLS Loss: 0.03937986493110657
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 4.703509330749512 | KNN Loss: 4.666195392608643 | CLS Loss: 0.03731381520628929
Epoch: 031, Loss: 4.7367, Train: 0.9928, Valid: 0.9853, Best: 0.9856
Epoch 32 / 200 | iteration 0 / 171 | Total Loss: 4.7314252853393555 | KNN Loss: 4.713092803955078 | CLS Loss: 0.01833263225853443
Epoch 32 / 200 | iteration 10 / 171 | Total Loss: 4.675447463989258 | KNN Loss: 4.643362045288086 | CLS Loss: 0.032085493206977844
Epoch 32 / 200 | iteration 20 / 171 | Total Loss: 4.740655899047852 | KNN Loss: 4.719870567321777 | CLS Loss: 0.02078537829220295
Epoch 32 / 200 | iteration 30 / 171 | Total Loss: 4.726932525634766 | KNN Loss: 4.716751575469971 | CLS Loss: 0.010180879384279251
Epoch 32 / 200 | iteration 40 / 171 | Total Loss: 4.781632423400879 | KNN Loss: 4.747108459472656 | CLS Loss: 0.034523993730545044
Epoch 32 / 200 |

Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 4.781567096710205 | KNN Loss: 4.727852821350098 | CLS Loss: 0.05371406674385071
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 4.851928234100342 | KNN Loss: 4.803271293640137 | CLS Loss: 0.048656996339559555
Epoch 35 / 200 | iteration 70 / 171 | Total Loss: 4.7483696937561035 | KNN Loss: 4.708577632904053 | CLS Loss: 0.0397920198738575
Epoch 35 / 200 | iteration 80 / 171 | Total Loss: 4.737277030944824 | KNN Loss: 4.700249671936035 | CLS Loss: 0.03702755644917488
Epoch 35 / 200 | iteration 90 / 171 | Total Loss: 4.757837295532227 | KNN Loss: 4.730257987976074 | CLS Loss: 0.027579303830862045
Epoch 35 / 200 | iteration 100 / 171 | Total Loss: 4.810493469238281 | KNN Loss: 4.756618976593018 | CLS Loss: 0.05387434735894203
Epoch 35 / 200 | iteration 110 / 171 | Total Loss: 4.75885009765625 | KNN Loss: 4.732407093048096 | CLS Loss: 0.026443004608154297
Epoch 35 / 200 | iteration 120 / 171 | Total Loss: 4.802587985992432 | KNN Loss: 4.770

Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 4.75571870803833 | KNN Loss: 4.711143970489502 | CLS Loss: 0.04457465559244156
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 4.776109218597412 | KNN Loss: 4.75214958190918 | CLS Loss: 0.023959530517458916
Epoch 38 / 200 | iteration 150 / 171 | Total Loss: 4.726465225219727 | KNN Loss: 4.706407070159912 | CLS Loss: 0.020058222115039825
Epoch 38 / 200 | iteration 160 / 171 | Total Loss: 4.7585344314575195 | KNN Loss: 4.743497371673584 | CLS Loss: 0.01503690704703331
Epoch 38 / 200 | iteration 170 / 171 | Total Loss: 4.750986099243164 | KNN Loss: 4.710392951965332 | CLS Loss: 0.04059309884905815
Epoch: 038, Loss: 4.7457, Train: 0.9920, Valid: 0.9856, Best: 0.9863
Epoch 39 / 200 | iteration 0 / 171 | Total Loss: 4.717318058013916 | KNN Loss: 4.703307151794434 | CLS Loss: 0.014010962098836899
Epoch 39 / 200 | iteration 10 / 171 | Total Loss: 4.743936061859131 | KNN Loss: 4.710650444030762 | CLS Loss: 0.03328558802604675
Epoch 39 / 200 

Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 4.724581241607666 | KNN Loss: 4.7038373947143555 | CLS Loss: 0.020743614062666893
Epoch 42 / 200 | iteration 30 / 171 | Total Loss: 4.723586082458496 | KNN Loss: 4.714700698852539 | CLS Loss: 0.008885176852345467
Epoch 42 / 200 | iteration 40 / 171 | Total Loss: 4.765615940093994 | KNN Loss: 4.73288631439209 | CLS Loss: 0.03272943198680878
Epoch 42 / 200 | iteration 50 / 171 | Total Loss: 4.805850028991699 | KNN Loss: 4.785758018493652 | CLS Loss: 0.020091840997338295
Epoch 42 / 200 | iteration 60 / 171 | Total Loss: 4.717689037322998 | KNN Loss: 4.691269874572754 | CLS Loss: 0.026418928056955338
Epoch 42 / 200 | iteration 70 / 171 | Total Loss: 4.703210830688477 | KNN Loss: 4.6879119873046875 | CLS Loss: 0.015299065969884396
Epoch 42 / 200 | iteration 80 / 171 | Total Loss: 4.703726291656494 | KNN Loss: 4.696319580078125 | CLS Loss: 0.007406489923596382
Epoch 42 / 200 | iteration 90 / 171 | Total Loss: 4.725312232971191 | KNN Loss: 4.7

Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 4.71621561050415 | KNN Loss: 4.695929527282715 | CLS Loss: 0.02028621733188629
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 4.752372741699219 | KNN Loss: 4.7263641357421875 | CLS Loss: 0.02600843645632267
Epoch 45 / 200 | iteration 110 / 171 | Total Loss: 4.715632915496826 | KNN Loss: 4.686986923217773 | CLS Loss: 0.0286461990326643
Epoch 45 / 200 | iteration 120 / 171 | Total Loss: 4.710268497467041 | KNN Loss: 4.689510822296143 | CLS Loss: 0.020757833495736122
Epoch 45 / 200 | iteration 130 / 171 | Total Loss: 4.788549900054932 | KNN Loss: 4.731328010559082 | CLS Loss: 0.057221926748752594
Epoch 45 / 200 | iteration 140 / 171 | Total Loss: 4.838242530822754 | KNN Loss: 4.782989501953125 | CLS Loss: 0.05525294691324234
Epoch 45 / 200 | iteration 150 / 171 | Total Loss: 4.719775676727295 | KNN Loss: 4.714876651763916 | CLS Loss: 0.00489898631349206
Epoch 45 / 200 | iteration 160 / 171 | Total Loss: 4.69456672668457 | KNN Loss: 4.6

Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 4.685040473937988 | KNN Loss: 4.676989555358887 | CLS Loss: 0.008050753735005856
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 4.730463981628418 | KNN Loss: 4.701570510864258 | CLS Loss: 0.028893589973449707
Epoch: 048, Loss: 4.7216, Train: 0.9951, Valid: 0.9862, Best: 0.9865
Epoch 49 / 200 | iteration 0 / 171 | Total Loss: 4.714487075805664 | KNN Loss: 4.700851917266846 | CLS Loss: 0.013635317794978619
Epoch 49 / 200 | iteration 10 / 171 | Total Loss: 4.7296833992004395 | KNN Loss: 4.719689846038818 | CLS Loss: 0.009993338957428932
Epoch 49 / 200 | iteration 20 / 171 | Total Loss: 4.691746234893799 | KNN Loss: 4.684889316558838 | CLS Loss: 0.006856848485767841
Epoch 49 / 200 | iteration 30 / 171 | Total Loss: 4.774263381958008 | KNN Loss: 4.735970497131348 | CLS Loss: 0.038292814046144485
Epoch 49 / 200 | iteration 40 / 171 | Total Loss: 4.696545124053955 | KNN Loss: 4.687640190124512 | CLS Loss: 0.008905141614377499
Epoch 49 / 2

Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 4.692766189575195 | KNN Loss: 4.681344509124756 | CLS Loss: 0.011421863920986652
Epoch 52 / 200 | iteration 60 / 171 | Total Loss: 4.729835510253906 | KNN Loss: 4.671848773956299 | CLS Loss: 0.057986583560705185
Epoch 52 / 200 | iteration 70 / 171 | Total Loss: 4.748166084289551 | KNN Loss: 4.743180274963379 | CLS Loss: 0.00498565100133419
Epoch 52 / 200 | iteration 80 / 171 | Total Loss: 4.693648338317871 | KNN Loss: 4.673943042755127 | CLS Loss: 0.019705500453710556
Epoch 52 / 200 | iteration 90 / 171 | Total Loss: 4.69476318359375 | KNN Loss: 4.68264102935791 | CLS Loss: 0.012122022919356823
Epoch 52 / 200 | iteration 100 / 171 | Total Loss: 4.694419860839844 | KNN Loss: 4.678998947143555 | CLS Loss: 0.015420942567288876
Epoch 52 / 200 | iteration 110 / 171 | Total Loss: 4.744701385498047 | KNN Loss: 4.728053092956543 | CLS Loss: 0.016648251563310623
Epoch 52 / 200 | iteration 120 / 171 | Total Loss: 4.684564590454102 | KNN Loss: 4.6

Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 4.728991508483887 | KNN Loss: 4.7077202796936035 | CLS Loss: 0.021271103993058205
Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 4.728135585784912 | KNN Loss: 4.71952486038208 | CLS Loss: 0.008610825054347515
Epoch 55 / 200 | iteration 140 / 171 | Total Loss: 4.692447185516357 | KNN Loss: 4.672152996063232 | CLS Loss: 0.020294370129704475
Epoch 55 / 200 | iteration 150 / 171 | Total Loss: 4.69741678237915 | KNN Loss: 4.685597896575928 | CLS Loss: 0.01181864831596613
Epoch 55 / 200 | iteration 160 / 171 | Total Loss: 4.734898567199707 | KNN Loss: 4.700370788574219 | CLS Loss: 0.0345279797911644
Epoch 55 / 200 | iteration 170 / 171 | Total Loss: 4.729795455932617 | KNN Loss: 4.71120023727417 | CLS Loss: 0.018595360219478607
Epoch: 055, Loss: 4.7154, Train: 0.9948, Valid: 0.9858, Best: 0.9866
Epoch 56 / 200 | iteration 0 / 171 | Total Loss: 4.685880184173584 | KNN Loss: 4.657973766326904 | CLS Loss: 0.027906406670808792
Epoch 56 / 200

Epoch 59 / 200 | iteration 10 / 171 | Total Loss: 4.755058288574219 | KNN Loss: 4.67966890335083 | CLS Loss: 0.07538935542106628
Epoch 59 / 200 | iteration 20 / 171 | Total Loss: 4.689681529998779 | KNN Loss: 4.65852165222168 | CLS Loss: 0.031159790232777596
Epoch 59 / 200 | iteration 30 / 171 | Total Loss: 4.6968536376953125 | KNN Loss: 4.684027194976807 | CLS Loss: 0.01282651536166668
Epoch 59 / 200 | iteration 40 / 171 | Total Loss: 4.700145721435547 | KNN Loss: 4.679486274719238 | CLS Loss: 0.02065940946340561
Epoch 59 / 200 | iteration 50 / 171 | Total Loss: 4.745163917541504 | KNN Loss: 4.726982593536377 | CLS Loss: 0.018181195482611656
Epoch 59 / 200 | iteration 60 / 171 | Total Loss: 4.687380790710449 | KNN Loss: 4.681849479675293 | CLS Loss: 0.005531215574592352
Epoch 59 / 200 | iteration 70 / 171 | Total Loss: 4.718723297119141 | KNN Loss: 4.699902057647705 | CLS Loss: 0.018821345642209053
Epoch 59 / 200 | iteration 80 / 171 | Total Loss: 4.683139801025391 | KNN Loss: 4.68101

Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 4.678022861480713 | KNN Loss: 4.660522937774658 | CLS Loss: 0.01750010997056961
Epoch 62 / 200 | iteration 90 / 171 | Total Loss: 4.8311920166015625 | KNN Loss: 4.788859844207764 | CLS Loss: 0.042332377284765244
Epoch 62 / 200 | iteration 100 / 171 | Total Loss: 4.725604057312012 | KNN Loss: 4.722499370574951 | CLS Loss: 0.003104632953181863
Epoch 62 / 200 | iteration 110 / 171 | Total Loss: 4.7709197998046875 | KNN Loss: 4.731648921966553 | CLS Loss: 0.03927066549658775
Epoch 62 / 200 | iteration 120 / 171 | Total Loss: 4.7206854820251465 | KNN Loss: 4.710465908050537 | CLS Loss: 0.010219573974609375
Epoch 62 / 200 | iteration 130 / 171 | Total Loss: 4.700364589691162 | KNN Loss: 4.696147441864014 | CLS Loss: 0.004216935019940138
Epoch 62 / 200 | iteration 140 / 171 | Total Loss: 4.746033191680908 | KNN Loss: 4.738822937011719 | CLS Loss: 0.007210229989141226
Epoch 62 / 200 | iteration 150 / 171 | Total Loss: 4.696201801300049 | KNN Lo

Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 4.731400966644287 | KNN Loss: 4.710884094238281 | CLS Loss: 0.020516861230134964
Epoch 65 / 200 | iteration 160 / 171 | Total Loss: 4.696679592132568 | KNN Loss: 4.670450687408447 | CLS Loss: 0.02622891589999199
Epoch 65 / 200 | iteration 170 / 171 | Total Loss: 4.736091613769531 | KNN Loss: 4.703755855560303 | CLS Loss: 0.03233577683568001
Epoch: 065, Loss: 4.7219, Train: 0.9928, Valid: 0.9853, Best: 0.9867
Epoch 66 / 200 | iteration 0 / 171 | Total Loss: 4.71272087097168 | KNN Loss: 4.6901326179504395 | CLS Loss: 0.022588247433304787
Epoch 66 / 200 | iteration 10 / 171 | Total Loss: 4.691913604736328 | KNN Loss: 4.67496919631958 | CLS Loss: 0.016944216564297676
Epoch 66 / 200 | iteration 20 / 171 | Total Loss: 4.72657585144043 | KNN Loss: 4.702298641204834 | CLS Loss: 0.024277029559016228
Epoch 66 / 200 | iteration 30 / 171 | Total Loss: 4.691000938415527 | KNN Loss: 4.661459445953369 | CLS Loss: 0.029541458934545517
Epoch 66 / 200 |

Epoch 69 / 200 | iteration 40 / 171 | Total Loss: 4.741040229797363 | KNN Loss: 4.739066123962402 | CLS Loss: 0.0019740709103643894
Epoch 69 / 200 | iteration 50 / 171 | Total Loss: 4.739317893981934 | KNN Loss: 4.719746112823486 | CLS Loss: 0.019571971148252487
Epoch 69 / 200 | iteration 60 / 171 | Total Loss: 4.705991268157959 | KNN Loss: 4.699800491333008 | CLS Loss: 0.006190769840031862
Epoch 69 / 200 | iteration 70 / 171 | Total Loss: 4.6821208000183105 | KNN Loss: 4.680339813232422 | CLS Loss: 0.001781064784154296
Epoch 69 / 200 | iteration 80 / 171 | Total Loss: 4.733116149902344 | KNN Loss: 4.712962627410889 | CLS Loss: 0.020153380930423737
Epoch 69 / 200 | iteration 90 / 171 | Total Loss: 4.693075180053711 | KNN Loss: 4.666686058044434 | CLS Loss: 0.026389243081212044
Epoch 69 / 200 | iteration 100 / 171 | Total Loss: 4.669792652130127 | KNN Loss: 4.6653828620910645 | CLS Loss: 0.004409815184772015
Epoch 69 / 200 | iteration 110 / 171 | Total Loss: 4.72755241394043 | KNN Loss:

Epoch 72 / 200 | iteration 110 / 171 | Total Loss: 4.736121654510498 | KNN Loss: 4.7034382820129395 | CLS Loss: 0.0326836034655571
Epoch 72 / 200 | iteration 120 / 171 | Total Loss: 4.663413047790527 | KNN Loss: 4.6573662757873535 | CLS Loss: 0.00604686513543129
Epoch 72 / 200 | iteration 130 / 171 | Total Loss: 4.715249061584473 | KNN Loss: 4.684045314788818 | CLS Loss: 0.03120359778404236
Epoch 72 / 200 | iteration 140 / 171 | Total Loss: 4.707891941070557 | KNN Loss: 4.6946940422058105 | CLS Loss: 0.013198046013712883
Epoch 72 / 200 | iteration 150 / 171 | Total Loss: 4.687875270843506 | KNN Loss: 4.678534030914307 | CLS Loss: 0.009341307915747166
Epoch 72 / 200 | iteration 160 / 171 | Total Loss: 4.689694881439209 | KNN Loss: 4.676888942718506 | CLS Loss: 0.01280574407428503
Epoch 72 / 200 | iteration 170 / 171 | Total Loss: 4.681415557861328 | KNN Loss: 4.66070032119751 | CLS Loss: 0.02071525529026985
Epoch: 072, Loss: 4.7052, Train: 0.9959, Valid: 0.9869, Best: 0.9872
Epoch 73 / 

Epoch: 075, Loss: 4.7168, Train: 0.9942, Valid: 0.9839, Best: 0.9872
Epoch 76 / 200 | iteration 0 / 171 | Total Loss: 4.727001190185547 | KNN Loss: 4.7065629959106445 | CLS Loss: 0.02043803222477436
Epoch 76 / 200 | iteration 10 / 171 | Total Loss: 4.701597213745117 | KNN Loss: 4.690100193023682 | CLS Loss: 0.011497098952531815
Epoch 76 / 200 | iteration 20 / 171 | Total Loss: 4.68812370300293 | KNN Loss: 4.663998126983643 | CLS Loss: 0.02412549778819084
Epoch 76 / 200 | iteration 30 / 171 | Total Loss: 4.715468406677246 | KNN Loss: 4.688493728637695 | CLS Loss: 0.026974670588970184
Epoch 76 / 200 | iteration 40 / 171 | Total Loss: 4.761530876159668 | KNN Loss: 4.755425453186035 | CLS Loss: 0.006105518899857998
Epoch 76 / 200 | iteration 50 / 171 | Total Loss: 4.7031989097595215 | KNN Loss: 4.699009895324707 | CLS Loss: 0.004188897088170052
Epoch 76 / 200 | iteration 60 / 171 | Total Loss: 4.687836647033691 | KNN Loss: 4.677367687225342 | CLS Loss: 0.010468875989317894
Epoch 76 / 200 |

Epoch 79 / 200 | iteration 70 / 171 | Total Loss: 4.728180885314941 | KNN Loss: 4.691195011138916 | CLS Loss: 0.03698594123125076
Epoch 79 / 200 | iteration 80 / 171 | Total Loss: 4.744700908660889 | KNN Loss: 4.724522113800049 | CLS Loss: 0.0201788991689682
Epoch 79 / 200 | iteration 90 / 171 | Total Loss: 4.811827182769775 | KNN Loss: 4.787870407104492 | CLS Loss: 0.023956604301929474
Epoch 79 / 200 | iteration 100 / 171 | Total Loss: 4.7140302658081055 | KNN Loss: 4.7020039558410645 | CLS Loss: 0.012026101350784302
Epoch 79 / 200 | iteration 110 / 171 | Total Loss: 4.77559232711792 | KNN Loss: 4.73651647567749 | CLS Loss: 0.039075929671525955
Epoch 79 / 200 | iteration 120 / 171 | Total Loss: 4.679646968841553 | KNN Loss: 4.671733856201172 | CLS Loss: 0.007913208566606045
Epoch 79 / 200 | iteration 130 / 171 | Total Loss: 4.733114242553711 | KNN Loss: 4.700106620788574 | CLS Loss: 0.0330076664686203
Epoch 79 / 200 | iteration 140 / 171 | Total Loss: 4.653071880340576 | KNN Loss: 4.6

Epoch 82 / 200 | iteration 140 / 171 | Total Loss: 4.69216775894165 | KNN Loss: 4.683448791503906 | CLS Loss: 0.00871878769248724
Epoch 82 / 200 | iteration 150 / 171 | Total Loss: 4.6821112632751465 | KNN Loss: 4.658242702484131 | CLS Loss: 0.023868370801210403
Epoch 82 / 200 | iteration 160 / 171 | Total Loss: 4.659106731414795 | KNN Loss: 4.65635871887207 | CLS Loss: 0.002748072613030672
Epoch 82 / 200 | iteration 170 / 171 | Total Loss: 4.7101945877075195 | KNN Loss: 4.683586120605469 | CLS Loss: 0.026608334854245186
Epoch: 082, Loss: 4.7005, Train: 0.9969, Valid: 0.9864, Best: 0.9872
Epoch 83 / 200 | iteration 0 / 171 | Total Loss: 4.698898792266846 | KNN Loss: 4.681894302368164 | CLS Loss: 0.01700431853532791
Epoch 83 / 200 | iteration 10 / 171 | Total Loss: 4.673853397369385 | KNN Loss: 4.646125316619873 | CLS Loss: 0.027728192508220673
Epoch 83 / 200 | iteration 20 / 171 | Total Loss: 4.693404197692871 | KNN Loss: 4.680067539215088 | CLS Loss: 0.013336624018847942
Epoch 83 / 20

Epoch 86 / 200 | iteration 30 / 171 | Total Loss: 4.678218364715576 | KNN Loss: 4.668631553649902 | CLS Loss: 0.009586694650352001
Epoch 86 / 200 | iteration 40 / 171 | Total Loss: 4.734948635101318 | KNN Loss: 4.722637176513672 | CLS Loss: 0.012311398051679134
Epoch 86 / 200 | iteration 50 / 171 | Total Loss: 4.696827411651611 | KNN Loss: 4.6686506271362305 | CLS Loss: 0.028176721185445786
Epoch 86 / 200 | iteration 60 / 171 | Total Loss: 4.713253498077393 | KNN Loss: 4.702340602874756 | CLS Loss: 0.010913092643022537
Epoch 86 / 200 | iteration 70 / 171 | Total Loss: 4.738593578338623 | KNN Loss: 4.718464374542236 | CLS Loss: 0.020129339769482613
Epoch 86 / 200 | iteration 80 / 171 | Total Loss: 4.6677680015563965 | KNN Loss: 4.6575751304626465 | CLS Loss: 0.010192830115556717
Epoch 86 / 200 | iteration 90 / 171 | Total Loss: 4.698948860168457 | KNN Loss: 4.6800432205200195 | CLS Loss: 0.018905552104115486
Epoch 86 / 200 | iteration 100 / 171 | Total Loss: 4.6785173416137695 | KNN Los

Epoch 89 / 200 | iteration 100 / 171 | Total Loss: 4.696671962738037 | KNN Loss: 4.683212757110596 | CLS Loss: 0.013459126465022564
Epoch 89 / 200 | iteration 110 / 171 | Total Loss: 4.685176849365234 | KNN Loss: 4.669690132141113 | CLS Loss: 0.01548655703663826
Epoch 89 / 200 | iteration 120 / 171 | Total Loss: 4.692789077758789 | KNN Loss: 4.681221961975098 | CLS Loss: 0.011567292734980583
Epoch 89 / 200 | iteration 130 / 171 | Total Loss: 4.678859233856201 | KNN Loss: 4.669167518615723 | CLS Loss: 0.009691828861832619
Epoch 89 / 200 | iteration 140 / 171 | Total Loss: 4.731479167938232 | KNN Loss: 4.70622444152832 | CLS Loss: 0.025254879146814346
Epoch 89 / 200 | iteration 150 / 171 | Total Loss: 4.6768083572387695 | KNN Loss: 4.667814254760742 | CLS Loss: 0.008994298987090588
Epoch 89 / 200 | iteration 160 / 171 | Total Loss: 4.726377964019775 | KNN Loss: 4.714968681335449 | CLS Loss: 0.011409463360905647
Epoch 89 / 200 | iteration 170 / 171 | Total Loss: 4.695120334625244 | KNN Lo

Epoch 92 / 200 | iteration 170 / 171 | Total Loss: 4.696181297302246 | KNN Loss: 4.677591800689697 | CLS Loss: 0.01858963444828987
Epoch: 092, Loss: 4.7105, Train: 0.9964, Valid: 0.9859, Best: 0.9872
Epoch 93 / 200 | iteration 0 / 171 | Total Loss: 4.6847028732299805 | KNN Loss: 4.680771827697754 | CLS Loss: 0.003931147512048483
Epoch 93 / 200 | iteration 10 / 171 | Total Loss: 4.655590534210205 | KNN Loss: 4.651773452758789 | CLS Loss: 0.0038171089254319668
Epoch 93 / 200 | iteration 20 / 171 | Total Loss: 4.789667129516602 | KNN Loss: 4.769301414489746 | CLS Loss: 0.020365657284855843
Epoch 93 / 200 | iteration 30 / 171 | Total Loss: 4.728115081787109 | KNN Loss: 4.700418472290039 | CLS Loss: 0.02769654430449009
Epoch 93 / 200 | iteration 40 / 171 | Total Loss: 4.697596073150635 | KNN Loss: 4.6905694007873535 | CLS Loss: 0.007026592269539833
Epoch 93 / 200 | iteration 50 / 171 | Total Loss: 4.691486835479736 | KNN Loss: 4.673923015594482 | CLS Loss: 0.017563797533512115
Epoch 93 / 20

Epoch 96 / 200 | iteration 60 / 171 | Total Loss: 4.715567588806152 | KNN Loss: 4.693706035614014 | CLS Loss: 0.021861374378204346
Epoch 96 / 200 | iteration 70 / 171 | Total Loss: 4.68407678604126 | KNN Loss: 4.6768059730529785 | CLS Loss: 0.007270941976457834
Epoch 96 / 200 | iteration 80 / 171 | Total Loss: 4.700981616973877 | KNN Loss: 4.687576770782471 | CLS Loss: 0.01340483222156763
Epoch 96 / 200 | iteration 90 / 171 | Total Loss: 4.694089889526367 | KNN Loss: 4.668557167053223 | CLS Loss: 0.02553291618824005
Epoch 96 / 200 | iteration 100 / 171 | Total Loss: 4.676794052124023 | KNN Loss: 4.664745807647705 | CLS Loss: 0.01204825658351183
Epoch 96 / 200 | iteration 110 / 171 | Total Loss: 4.658322811126709 | KNN Loss: 4.649949550628662 | CLS Loss: 0.008373203687369823
Epoch 96 / 200 | iteration 120 / 171 | Total Loss: 4.710494518280029 | KNN Loss: 4.7052693367004395 | CLS Loss: 0.005225078668445349
Epoch 96 / 200 | iteration 130 / 171 | Total Loss: 4.706723213195801 | KNN Loss: 4

Epoch 99 / 200 | iteration 130 / 171 | Total Loss: 4.736800670623779 | KNN Loss: 4.686923980712891 | CLS Loss: 0.04987671226263046
Epoch 99 / 200 | iteration 140 / 171 | Total Loss: 4.702376365661621 | KNN Loss: 4.692188739776611 | CLS Loss: 0.010187593288719654
Epoch 99 / 200 | iteration 150 / 171 | Total Loss: 4.709314346313477 | KNN Loss: 4.698884963989258 | CLS Loss: 0.01042957417666912
Epoch 99 / 200 | iteration 160 / 171 | Total Loss: 4.7336812019348145 | KNN Loss: 4.706918716430664 | CLS Loss: 0.02676249109208584
Epoch 99 / 200 | iteration 170 / 171 | Total Loss: 4.706286907196045 | KNN Loss: 4.699099063873291 | CLS Loss: 0.007187775336205959
Epoch: 099, Loss: 4.7095, Train: 0.9960, Valid: 0.9858, Best: 0.9872
Epoch 100 / 200 | iteration 0 / 171 | Total Loss: 4.689737796783447 | KNN Loss: 4.677109241485596 | CLS Loss: 0.012628618627786636
Epoch 100 / 200 | iteration 10 / 171 | Total Loss: 4.661083698272705 | KNN Loss: 4.643423557281494 | CLS Loss: 0.01766018196940422
Epoch 100 /

Epoch 103 / 200 | iteration 10 / 171 | Total Loss: 4.678947925567627 | KNN Loss: 4.6758270263671875 | CLS Loss: 0.0031210239976644516
Epoch 103 / 200 | iteration 20 / 171 | Total Loss: 4.65442419052124 | KNN Loss: 4.6470770835876465 | CLS Loss: 0.007346881553530693
Epoch 103 / 200 | iteration 30 / 171 | Total Loss: 4.6808648109436035 | KNN Loss: 4.672391891479492 | CLS Loss: 0.008472945541143417
Epoch 103 / 200 | iteration 40 / 171 | Total Loss: 4.759029388427734 | KNN Loss: 4.7434797286987305 | CLS Loss: 0.015549637377262115
Epoch 103 / 200 | iteration 50 / 171 | Total Loss: 4.728278160095215 | KNN Loss: 4.719176292419434 | CLS Loss: 0.009101921692490578
Epoch 103 / 200 | iteration 60 / 171 | Total Loss: 4.678065776824951 | KNN Loss: 4.6706109046936035 | CLS Loss: 0.007454941049218178
Epoch 103 / 200 | iteration 70 / 171 | Total Loss: 4.651780605316162 | KNN Loss: 4.64718770980835 | CLS Loss: 0.0045930808410048485
Epoch 103 / 200 | iteration 80 / 171 | Total Loss: 4.657245635986328 | 

In [None]:
# Evaluate `model` on the held-out test loader.
# NOTE(review): `test` and `model` come from earlier cells not visible in this chunk.
test(model, test_data_iter, device)

In [None]:
# Loss curve recorded while training the network above.
fig, ax = plt.subplots()
ax.plot(losses, label='train loss')
ax.legend()
plt.show()

In [None]:
# Training vs. validation accuracy curves of the network trained above.
fig, ax = plt.subplots()
for series, series_label in ((train_accs, 'train accuracy'), (val_accs, 'validation accuracy')):
    ax.plot(series, label=series_label)
ax.legend()
plt.show()

In [None]:
# Embed every test sample with the trained network, keeping the raw inputs,
# labels and flattened intermediate representations aligned row-for-row.
# Batches are gathered in lists and concatenated once: calling torch.cat inside
# the loop re-copies everything accumulated so far (quadratic cost).

# Inference mode: disables train-time behaviour such as dropout / batch-norm
# statistic updates, if the model contains such layers.
model.eval()

sample_batches, label_batches, projection_batches = [], [], []
with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        sample_batches.append(x)
        label_batches.append(y)
        _, interm = model(x.to(device), True)
        projection_batches.append(interm.cpu().flatten(1))

# Prepending an empty tensor mirrors the previous incremental accumulation
# (same dtype promotion) and keeps the result well-defined for an empty loader.
test_samples = torch.cat([torch.tensor([]), *sample_batches])
labels = torch.cat([torch.tensor([]), *label_batches])
projections = torch.cat([torch.tensor([]), *projection_batches])

In [None]:
# Pairwise distances between all test-set projections (sklearn default metric).
# NOTE(review): this is an N x N dense matrix — memory-heavy for the full test set.
distances = pairwise_distances(projections)
distances_f = distances.flatten()

# Heat-map of the matrix, then the distribution of its strictly positive entries.
plt.matshow(distances)
plt.colorbar()
plt.figure()
positive_distances = distances_f[distances_f > 0]
plt.hist(positive_distances, bins=1000)
plt.show()

In [None]:
# Cluster the learned projections with DBSCAN; label -1 marks noise points.
# Hyper-parameters are named here so they can be tuned in one place.
dbscan_eps = 2
dbscan_min_samples = 10
clusters = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples).fit_predict(projections)

In [None]:
# DBSCAN labels noise as -1; report the inlier count together with the inlier fraction.
n_inliers = int((clusters != -1).sum())
n_points = len(clusters)
print(f"Number of inliers: {n_inliers} / {n_points} ({n_inliers / max(n_points, 1):.2%})")

In [None]:
# 2-D t-SNE view of the DBSCAN inliers, with their cluster ids passed as `y`.
perplexity = 100
inlier_mask = clusters != -1
p = reduce_dims_and_plot(
    projections[inlier_mask],
    y=clusters[inlier_mask],
    title=f'perplexity: {perplexity}',
    file_name=None,
    perplexity=perplexity,
    library='Multicore-TSNE',
    perform_PCA=False,
    projected=None,
    figure_type='2d',
    show_figure=True,
    close_figure=False,
    text=None,
)

# Train a Soft Decision Tree (SDT) on the DBSCAN cluster labels (self-labels)

## Prepare the dataset

In [None]:
# Supervised dataset for the tree: flattened raw test samples paired with their
# DBSCAN cluster id. Noise points (label -1) are dropped.
inlier_mask = clusters != -1
tree_dataset = list(zip(test_samples.flatten(1)[inlier_mask], clusters[inlier_mask]))
# Own name, so the network's `batch_size` from the config cell is not overwritten.
tree_batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=tree_batch_size, shuffle=True)

# Helpers for pruning the tree's inner-node weights, measuring their sparsity, and computing a depth-weighted regularizer

In [None]:
def prune_node(node_weights, factor=1):
    """Zero every weight lying strictly inside ``mean ± factor * std``.

    The band is derived from the weights' own mean and standard deviation, so
    only weights far from the mean survive. ``node_weights`` is modified in
    place and also returned.
    """
    snapshot = node_weights.cpu().detach().numpy()
    center = np.mean(snapshot)
    spread = np.std(snapshot) * factor
    low, high = center - spread, center + spread
    inside_band = (node_weights > low) & (node_weights < high)
    node_weights[inside_band] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    """Keep only the ``keep`` largest-magnitude weights of a node; zero the rest.

    Args:
        node_weights: weights of one node (assumed 1-D); modified in place.
        keep: number of weights to retain, ranked by absolute value.
            ``keep <= 0`` zeroes every weight; ``keep >= len(node_weights)``
            leaves the node untouched.

    Returns:
        The same, in-place modified, ``node_weights`` tensor.
    """
    w = node_weights.cpu().detach().numpy()
    # Count of smallest-|w| entries to drop, computed explicitly: the slice
    # ``[:-keep]`` evaluates to ``[:0]`` when keep == 0 and would drop nothing.
    n_drop = max(len(w) - max(keep, 0), 0)
    throw_idx = np.argsort(np.abs(w))[:n_drop]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    """Sparsify every inner node of ``tree_`` in place.

    Each row of ``tree_.inner_nodes.weight`` is passed to ``prune_node_keep``.
    Despite its name, ``factor`` is therefore the number of largest-magnitude
    weights kept per node (e.g. ``factor=3`` keeps three weights per node).
    """
    with torch.no_grad():
        # Pruning is a manual weight edit, so work on a detached copy and keep
        # autograd from recording the per-row modifications.
        new_weights = tree_.inner_nodes.weight.detach().clone()
        for i in range(new_weights.shape[0]):
            new_weights[i, :] = prune_node_keep(new_weights[i, :], factor)
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    """Mean, over the rows of the 2-D tensor ``x``, of each row's fraction of exact zeros."""
    n_cols = x.shape[1]
    zero_fractions = [
        (n_cols - torch.count_nonzero(row).item()) / n_cols
        for row in x
    ]
    return np.mean(zero_fractions)

def compute_regularization_by_level(tree):
    """Depth-weighted L2 penalty over the tree's inner-node weights.

    Row ``i`` of ``tree.inner_nodes.weight`` is treated as a node at level
    ``floor(log2(i + 1))`` (assumes heap-style, level-by-level node order). Its
    L2 norm is scaled by ``2 ** -level``, so deeper nodes weigh less. Returns
    the sum: a tensor when there is at least one node, otherwise ``0``.
    """
    scaled_norms = []
    for node_idx, node_w in enumerate(tree.inner_nodes.weight):
        level = np.floor(np.log2(node_idx + 1))
        scaled_norms.append(2 ** (-level) * torch.norm(node_w.view(-1), 2))
    return sum(scaled_norms)

def show_sparseness(tree):
    """Print the average and per-layer sparseness of the tree's inner nodes.

    A node's sparseness is the fraction of its weights that are exactly zero.
    Row ``i`` of ``tree.inner_nodes.weight`` is assigned to layer
    ``floor(log2(i + 1))``. Every layer is reported, including the deepest one.

    Returns:
        Mean sparseness over all inner nodes (the same value ``sparseness``
        computes for ``tree.inner_nodes.weight``).
    """
    weights = tree.inner_nodes.weight
    node_sps = [
        (len(row) - torch.count_nonzero(row).item()) / len(row)
        for row in weights
    ]
    avg_sp = np.mean(node_sps)
    print(f"Average sparseness: {avg_sp}")

    # Group by layer first and print afterwards, so the last layer is printed too
    # (dicts keep insertion order, i.e. ascending layer order here).
    by_layer = {}
    for i, sp in enumerate(node_sps):
        by_layer.setdefault(int(np.floor(np.log2(i + 1))), []).append(sp)
    for layer, sps in by_layer.items():
        print(f"layer {layer}: {np.mean(sps)}")

    return avg_sp

## Training configurations

In [None]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    """Run one training epoch of the soft decision tree ``model``.

    For each batch: forward pass, cross-entropy plus the model's own ``penalty``
    term, one optimizer step, then the loss and batch accuracy are appended to
    ``losses`` / ``accs`` (mutated in place). A status line is printed every
    ``log_interval`` batches.

    Relies on notebook globals defined in later cells: ``criterion``,
    ``optimizer``, ``sparsity_lamda`` and ``start_time``.

    Args:
        model: the tree being trained; calling it returns ``(output, penalty)``.
        loader: iterable of ``(data, target)`` batches.
        device: device the batches are moved to.
        log_interval: print a status line every this many batches.
        losses, accs: lists extended with per-batch loss / accuracy.
        epoch: epoch index (only used for logging).
        iteration: global batch counter carried across epochs.

    Returns:
        The updated global batch counter.
    """
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        # Use the model passed in rather than the notebook-global `tree`.
        output, penalty = model(data)

        # Classification loss plus the tree's built-in penalty.
        loss_tree = criterion(output, target.view(-1)) + penalty

        log_now = batch_idx % log_interval == 0
        if log_now:
            # The depth-weighted regularizer is only reported, so evaluate it
            # (before the weight update) without building an autograd graph.
            with torch.no_grad():
                regularization = sparsity_lamda * compute_regularization_by_level(model)

        # NOTE(review): the regularizer is not part of the optimized loss; add
        # `sparsity_lamda * compute_regularization_by_level(model)` here to train with it.
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pred = output.argmax(dim=1)
        correct = (pred == target.view(-1)).sum().item()
        batch_acc = correct / data.size(0)

        losses.append(loss.item())
        accs.append(batch_acc)

        if log_now:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {float(regularization):.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {batch_acc:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")

    return iteration


In [None]:
# Soft-decision-tree training hyper-parameters.
lr = 5e-3               # Adam learning rate
weight_decay = 5e-4     # Adam weight decay
sparsity_lamda = 2e-3   # scale of the depth-weighted regularizer reported by do_epoch
epochs = 400
log_interval = 100      # print a status line every `log_interval` batches
use_cuda = not (device == 'cpu')

In [None]:
# Input size = number of flattened raw-sample features (matches `tree_dataset`).
n_tree_features = test_samples.flatten(1).shape[1]
# One class per DBSCAN cluster, excluding the noise label -1. Removing -1 from the set
# is correct whether or not any noise was found; `len(set(...)) - 1` undercounts when
# there is no noise.
n_tree_classes = len(set(clusters) - {-1})

tree = SDT(input_dim=n_tree_features, output_dim=n_tree_classes, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
# Move to the target device before handing the parameters to the optimizer.
tree = tree.to(device)
optimizer = torch.optim.Adam(tree.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

In [None]:
# Fresh logs for tree training: per-batch loss and accuracy, and per-epoch sparseness.
losses, accs, sparsity = [], [], []

In [None]:
# Train the tree. Every `prune_interval` epochs, each inner node keeps only its
# `prune_keep` largest-magnitude weights (prune_tree forwards its `factor`
# argument to prune_node_keep as `keep`).
prune_interval = 1
prune_keep = 3

start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Sparseness is recorded before this epoch's training pass.
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)

    if epoch % prune_interval == 0:
        prune_tree(tree, factor=prune_keep)
        

In [None]:
# Per-batch training accuracy of the tree.
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.legend()  # the line label is only displayed through a legend
plt.show()

In [None]:
# Per-batch training loss of the tree.
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.legend()  # the line label is only displayed through a legend
plt.show()

# Histogram (log-scaled counts) of all inner-node weights; the red lines mark
# mean ± 1 std, i.e. the band `prune_node` zeroes with factor=1.
plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.hist(weights, bins=500)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
# Draw the trained tree. `tree.visualize()` returns a pair stored as
# `avg_height` and `root`; `root` is the node object used below for rule extraction.
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
# Each leaf of the tree corresponds to one extracted pattern.
n_patterns = len(root.get_leaves())
print(f"Number of patterns: {n_patterns}")

In [None]:
# Strategy name passed to `root.accumulate_samples`.
method = "greedy"

In [None]:
# Clear previously accumulated leaf samples, then route every tree-training
# batch through the tree using the chosen `method`.
root.clear_leaves_samples()
with torch.no_grad():
    for data, _ in tree_loader:
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
# Tighten each leaf's path conditions with the samples routed to it, then
# summarise each pattern's comprehensibility (sum over its path conditions).
n_tree_features = test_samples.flatten(1).shape[1]  # same features the tree was trained on
attr_names = [f"T_{i}" for i in range(n_tree_features)]
leaves = root.get_leaves()
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    pattern_comprehensibility = sum(cond.comprehensibility for cond in conds)
    comprehensibilities.append(pattern_comprehensibility)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    print(f"comprehensibility: {pattern_comprehensibility}")

# np.min / np.max raise on an empty list, so guard the summary.
if comprehensibilities:
    print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
    print(f"std comprehensibility: {np.std(comprehensibilities)}")
    print(f"var comprehensibility: {np.var(comprehensibilities)}")
    print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
    print(f"maximum comprehensibility: {np.max(comprehensibilities)}")
else:
    print("No patterns: the tree has no leaves.")