In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [4]:
k = 16  # `k` passed to ClassificationKNNLoss (presumably the number of neighbours)
tree_depth = 6  # depth of the soft decision tree trained later in the notebook
batch_size = 512
# Fall back to the CPU when no CUDA device is available, so the notebook does not
# crash at the first `.to(device)` on a CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [5]:
def _mitbih_loader(csv_path, **overrides):
    """Build a DataLoader over one MIT-BIH CSV split using the shared loader settings.

    Keyword ``overrides`` are merged on top of the common options.
    """
    options = dict(batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True)
    options.update(overrides)
    return torch.utils.data.DataLoader(MITBIH(csv_path), **options)


# Training split drops the last incomplete batch; the test split keeps every sample.
train_data_iter = _mitbih_loader(train_data_path, drop_last=True)
test_data_iter = _mitbih_loader(test_data_path)

In [6]:
class ConvBlock(nn.Module):
    """Residual 1-D convolution block followed by max-pooling.

    Two 32->32 convolutions (kernel 5, stride 1, padding 2, so the sequence
    length is unchanged) with a ReLU between them. The block input is added back
    before a second ReLU, then ``MaxPool1d(kernel_size=5, stride=2)`` downsamples.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept so that state_dict keys and
        # default weight initialisation are unchanged.
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)

    def forward(self, x):
        residual = x
        out = self.conv2(self.relu1(self.conv1(x)))
        out = self.relu2(out + residual)
        return self.pool(out)


class ECGModel(nn.Module):
    """Residual 1-D CNN with a 5-way output.

    Pipeline: stem ``Conv1d(1, 32, kernel_size=5, padding=1)`` -> five
    ``ConvBlock``s -> flatten -> ``Linear(64, 32)`` -> ReLU -> ``Linear(32, 5)``.
    ``fc1`` requires exactly 64 flattened features (32 channels x length 2).
    For example, an input of length 187 reaches length 2 after the stem and the
    five pooling stages.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in the original order so default initialisation
        # consumes the RNG identically and state_dict keys are unchanged.
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        """Return the logits, or ``(logits, interm)`` when ``return_interm`` is True.

        ``interm`` is the flattened output of the last conv block, i.e. the input
        of ``fc1``.
        """
        features = self.conv(x)
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            features = block(features)
        interm = features.flatten(1)
        logits = self.fc2(self.relu(self.fc1(interm)))
        if return_interm:
            return logits, interm
        return logits

In [7]:
# Auxiliary KNN loss applied to the model's flattened intermediate features in
# `train` below.
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    """Run one training epoch with a cross-entropy + KNN feature loss.

    Args:
        model: module returning ``(logits, interm)`` when called with
            ``return_interm=True`` (e.g. ``ECGModel``).
        loader: iterable of ``(batch, target)`` pairs with a defined ``len``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        float: mean total loss (classification + KNN) over the epoch's batches,
        or 0.0 when ``loader`` is empty.

    Note:
        Reads the notebook globals ``knn_crt`` and ``log_every``. ``epoch`` and
        ``epochs`` are read only for the progress message, but they must be
        defined before calling.
    """
    model.train()

    total_loss = 0.0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        # F.cross_entropy already averages over the batch and returns a scalar,
        # so no further reduction is needed.
        cls_loss = F.cross_entropy(outputs, target)
        # The KNN loss may be undefined for a batch (ValueError) or numerically
        # degenerate (inf or nan). Fall back to a zero contribution so that a
        # single bad batch cannot poison the weights. The zero uses the CE loss's
        # dtype/device so the total keeps a floating dtype.
        zero_loss = torch.zeros((), dtype=cls_loss.dtype, device=cls_loss.device)
        try:
            knn_loss = knn_crt(interm, target)
            if not torch.isfinite(knn_loss):
                knn_loss = zero_loss
        except ValueError:
            knn_loss = zero_loss
        loss = cls_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {cls_loss.item()}")

    n_batches = len(loader)
    return total_loss / n_batches if n_batches else 0.0

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [8]:
# CNN training hyper-parameters. `epochs` and `log_every` are also read as
# globals by `train` for its progress messages.
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum(param.numel() for param in model.parameters())
print(f'#Params: {num_params}')

#Params: 53957


In [None]:
best_valid_acc = 0
losses, train_accs, val_accs = [], [], []
for epoch in range(1, epochs + 1):
    # `epoch` is also read as a global by `train` for its log lines.
    loss = train(model, train_data_iter, optimizer, device)
    losses.append(loss)

    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)

    # max() keeps the current best unless valid_acc is strictly greater.
    best_valid_acc = max(best_valid_acc, valid_acc)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
          f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.353687286376953 | KNN Loss: 5.5526628494262695 | CLS Loss: 1.8010241985321045
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 4.729332447052002 | KNN Loss: 3.8576250076293945 | CLS Loss: 0.871707558631897
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 4.0940260887146 | KNN Loss: 3.309173583984375 | CLS Loss: 0.7848523259162903
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 3.8650264739990234 | KNN Loss: 3.2776787281036377 | CLS Loss: 0.5873478651046753
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 3.8695287704467773 | KNN Loss: 3.2394747734069824 | CLS Loss: 0.6300539970397949
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 3.7488129138946533 | KNN Loss: 3.197767734527588 | CLS Loss: 0.5510451793670654
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 3.756903886795044 | KNN Loss: 3.245512008666992 | CLS Loss: 0.5113918781280518
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 3.7646336555480957 | KNN Loss: 3.2590463161468506 | 

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 3.2076327800750732 | KNN Loss: 3.086923122406006 | CLS Loss: 0.12070959806442261
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 3.250946521759033 | KNN Loss: 3.111034870147705 | CLS Loss: 0.1399116963148117
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 3.2041304111480713 | KNN Loss: 3.0966131687164307 | CLS Loss: 0.10751722007989883
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 3.210449457168579 | KNN Loss: 3.0962560176849365 | CLS Loss: 0.11419354379177094
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 3.218536376953125 | KNN Loss: 3.087359666824341 | CLS Loss: 0.13117672502994537
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 3.2078328132629395 | KNN Loss: 3.118234872817993 | CLS Loss: 0.08959782868623734
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 3.2393791675567627 | KNN Loss: 3.111973762512207 | CLS Loss: 0.12740540504455566
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 3.246518611907959 | KNN Loss: 3.10823

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 3.1732869148254395 | KNN Loss: 3.0635108947753906 | CLS Loss: 0.10977602750062943
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 3.163022994995117 | KNN Loss: 3.0696487426757812 | CLS Loss: 0.09337416291236877
Epoch: 007, Loss: 3.1678, Train: 0.9715, Valid: 0.9696, Best: 0.9729
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 3.1693410873413086 | KNN Loss: 3.1077771186828613 | CLS Loss: 0.061564043164253235
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 3.1543514728546143 | KNN Loss: 3.0508298873901367 | CLS Loss: 0.1035216823220253
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 3.1989665031433105 | KNN Loss: 3.0559260845184326 | CLS Loss: 0.14304044842720032
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 3.1524837017059326 | KNN Loss: 3.0671067237854004 | CLS Loss: 0.08537698537111282
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 3.168806552886963 | KNN Loss: 3.038391351699829 | CLS Loss: 0.13041509687900543
Epoch 8 / 200 | 

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 3.130671977996826 | KNN Loss: 3.0627493858337402 | CLS Loss: 0.06792263686656952
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 3.185950517654419 | KNN Loss: 3.0498311519622803 | CLS Loss: 0.13611933588981628
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 3.0878138542175293 | KNN Loss: 3.0351450443267822 | CLS Loss: 0.05266888067126274
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 3.1192784309387207 | KNN Loss: 3.0429582595825195 | CLS Loss: 0.07632014155387878
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 3.143902540206909 | KNN Loss: 3.044109582901001 | CLS Loss: 0.09979291260242462
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 3.086247205734253 | KNN Loss: 3.0325021743774414 | CLS Loss: 0.053745001554489136
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 3.1076831817626953 | KNN Loss: 3.050908327102661 | CLS Loss: 0.05677477642893791
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 3.09683895111084 | KNN Loss:

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 3.0749526023864746 | KNN Loss: 2.9884390830993652 | CLS Loss: 0.08651348948478699
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 3.064164638519287 | KNN Loss: 3.0101206302642822 | CLS Loss: 0.0540439635515213
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 3.1232993602752686 | KNN Loss: 3.0588908195495605 | CLS Loss: 0.06440845131874084
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 3.2020256519317627 | KNN Loss: 3.119720458984375 | CLS Loss: 0.08230522274971008
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 3.095851182937622 | KNN Loss: 3.0164060592651367 | CLS Loss: 0.07944509387016296
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 3.0927488803863525 | KNN Loss: 3.040519952774048 | CLS Loss: 0.05222897604107857
Epoch: 014, Loss: 3.1010, Train: 0.9832, Valid: 0.9806, Best: 0.9806
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 3.084994077682495 | KNN Loss: 3.068641185760498 | CLS Loss: 0.016352981328964233
Epoch 15

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 3.0767881870269775 | KNN Loss: 3.0437984466552734 | CLS Loss: 0.032989781349897385
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 3.096275806427002 | KNN Loss: 3.0341508388519287 | CLS Loss: 0.062124885618686676
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 3.0932607650756836 | KNN Loss: 3.053999185562134 | CLS Loss: 0.03926166892051697
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 3.065483808517456 | KNN Loss: 3.0170743465423584 | CLS Loss: 0.048409461975097656
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 3.0908892154693604 | KNN Loss: 3.0193703174591064 | CLS Loss: 0.07151887565851212
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 3.0691111087799072 | KNN Loss: 3.0328140258789062 | CLS Loss: 0.03629716858267784
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 3.093935251235962 | KNN Loss: 3.0395166873931885 | CLS Loss: 0.054418548941612244
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 3.05326247215271 | KNN Los

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 3.122095823287964 | KNN Loss: 3.0830843448638916 | CLS Loss: 0.03901158273220062
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 3.1017069816589355 | KNN Loss: 3.062173366546631 | CLS Loss: 0.03953363746404648
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 3.0794382095336914 | KNN Loss: 3.034637451171875 | CLS Loss: 0.044800758361816406
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 3.1042065620422363 | KNN Loss: 3.0651865005493164 | CLS Loss: 0.03902009502053261
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 3.08304762840271 | KNN Loss: 3.026240110397339 | CLS Loss: 0.056807611137628555
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 3.0790364742279053 | KNN Loss: 3.0486650466918945 | CLS Loss: 0.0303714070469141
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 3.102686882019043 | KNN Loss: 3.0604677200317383 | CLS Loss: 0.04221905395388603
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 3.101654052734375 | KNN Lo

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 3.1193597316741943 | KNN Loss: 3.0528275966644287 | CLS Loss: 0.0665321946144104
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 3.0882351398468018 | KNN Loss: 3.049344301223755 | CLS Loss: 0.038890860974788666
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 3.0932209491729736 | KNN Loss: 3.0585403442382812 | CLS Loss: 0.03468066081404686
Epoch: 024, Loss: 3.0838, Train: 0.9872, Valid: 0.9813, Best: 0.9834
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 3.0578863620758057 | KNN Loss: 3.0292866230010986 | CLS Loss: 0.028599683195352554
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 3.1232593059539795 | KNN Loss: 3.0471243858337402 | CLS Loss: 0.0761348083615303
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 3.058932065963745 | KNN Loss: 3.0318615436553955 | CLS Loss: 0.02707044780254364
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 3.0941250324249268 | KNN Loss: 3.0421202182769775 | CLS Loss: 0.052004843950271606
Epoch 

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 3.0628809928894043 | KNN Loss: 3.0260109901428223 | CLS Loss: 0.036869898438453674
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 3.0731210708618164 | KNN Loss: 3.0232818126678467 | CLS Loss: 0.049839284271001816
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 3.1361403465270996 | KNN Loss: 3.0624918937683105 | CLS Loss: 0.07364853471517563
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 3.044420003890991 | KNN Loss: 3.0236783027648926 | CLS Loss: 0.020741663873195648
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 3.0636415481567383 | KNN Loss: 3.033003330230713 | CLS Loss: 0.030638156458735466
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 3.0811164379119873 | KNN Loss: 3.052544116973877 | CLS Loss: 0.028572436422109604
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 3.0675675868988037 | KNN Loss: 3.023331880569458 | CLS Loss: 0.04423559457063675
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 3.0429530143737793 | KN

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 3.077326536178589 | KNN Loss: 3.043952703475952 | CLS Loss: 0.03337394818663597
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 3.0239179134368896 | KNN Loss: 3.0152881145477295 | CLS Loss: 0.008629754185676575
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 3.0556178092956543 | KNN Loss: 3.0303359031677246 | CLS Loss: 0.025281798094511032
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 3.118743658065796 | KNN Loss: 3.040374994277954 | CLS Loss: 0.0783686414361
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 3.041203498840332 | KNN Loss: 3.009995222091675 | CLS Loss: 0.031208310276269913
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 3.073951244354248 | KNN Loss: 3.0235228538513184 | CLS Loss: 0.05042850971221924
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 3.0446395874023438 | KNN Loss: 3.0161476135253906 | CLS Loss: 0.0284919124096632
Epoch: 031, Loss: 3.0694, Train: 0.9906, Valid: 0.9837, Best: 0.9855
Epoch 32 

Epoch: 034, Loss: 3.0667, Train: 0.9914, Valid: 0.9848, Best: 0.9855
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 3.0662121772766113 | KNN Loss: 3.0217669010162354 | CLS Loss: 0.04444516450166702
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 3.057434320449829 | KNN Loss: 3.0277421474456787 | CLS Loss: 0.02969209849834442
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 3.105956792831421 | KNN Loss: 3.0287537574768066 | CLS Loss: 0.07720306515693665
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 3.0673296451568604 | KNN Loss: 3.013756036758423 | CLS Loss: 0.0535736046731472
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 3.0966765880584717 | KNN Loss: 3.0470833778381348 | CLS Loss: 0.049593228846788406
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 3.0387303829193115 | KNN Loss: 3.026477098464966 | CLS Loss: 0.01225337851792574
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 3.0685648918151855 | KNN Loss: 3.0463814735412598 | CLS Loss: 0.022183511406183243
Epoch 35 / 

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 3.07731032371521 | KNN Loss: 3.0510125160217285 | CLS Loss: 0.02629787102341652
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 3.0772719383239746 | KNN Loss: 3.038449287414551 | CLS Loss: 0.03882255405187607
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 3.067284107208252 | KNN Loss: 3.0135114192962646 | CLS Loss: 0.05377264320850372
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 3.0744316577911377 | KNN Loss: 3.0536653995513916 | CLS Loss: 0.0207663644105196
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 3.0269172191619873 | KNN Loss: 3.016000747680664 | CLS Loss: 0.010916439816355705
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 3.076125144958496 | KNN Loss: 3.053071975708008 | CLS Loss: 0.02305307239294052
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 3.0688095092773438 | KNN Loss: 3.0349771976470947 | CLS Loss: 0.03383226692676544
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 3.09177827835083 | KNN Loss:

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 3.0545618534088135 | KNN Loss: 3.036682605743408 | CLS Loss: 0.017879270017147064
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 3.0807571411132812 | KNN Loss: 3.055466413497925 | CLS Loss: 0.025290826335549355
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 3.0504586696624756 | KNN Loss: 3.0375969409942627 | CLS Loss: 0.012861755676567554
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 3.037323236465454 | KNN Loss: 3.0189735889434814 | CLS Loss: 0.018349716439843178
Epoch: 041, Loss: 3.0644, Train: 0.9935, Valid: 0.9861, Best: 0.9861
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 3.068254232406616 | KNN Loss: 3.0339410305023193 | CLS Loss: 0.03431329131126404
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 3.017998218536377 | KNN Loss: 3.000027894973755 | CLS Loss: 0.017970234155654907
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 3.08732008934021 | KNN Loss: 3.0662381649017334 | CLS Loss: 0.021081941202282906
Epoch 

Epoch 45 / 200 | iteration 20 / 171 | Total Loss: 3.074091911315918 | KNN Loss: 3.0496726036071777 | CLS Loss: 0.024419333785772324
Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 3.0453028678894043 | KNN Loss: 3.0247385501861572 | CLS Loss: 0.020564377307891846
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 3.0892131328582764 | KNN Loss: 3.0515530109405518 | CLS Loss: 0.037660181522369385
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 3.1034021377563477 | KNN Loss: 3.0607082843780518 | CLS Loss: 0.042693767696619034
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 3.0539324283599854 | KNN Loss: 3.0289993286132812 | CLS Loss: 0.024933211505413055
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 3.053934335708618 | KNN Loss: 3.0283241271972656 | CLS Loss: 0.025610176846385002
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 3.0325820446014404 | KNN Loss: 3.0124928951263428 | CLS Loss: 0.02008924074470997
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 3.1002657413482666 | K

Epoch 48 / 200 | iteration 90 / 171 | Total Loss: 3.049514055252075 | KNN Loss: 3.0313374996185303 | CLS Loss: 0.01817660965025425
Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 3.0612144470214844 | KNN Loss: 3.016990900039673 | CLS Loss: 0.04422348365187645
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 3.06685209274292 | KNN Loss: 3.0268375873565674 | CLS Loss: 0.04001457989215851
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 3.0710856914520264 | KNN Loss: 3.0409178733825684 | CLS Loss: 0.030167710036039352
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 3.051100015640259 | KNN Loss: 3.0374083518981934 | CLS Loss: 0.01369160134345293
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 3.0074405670166016 | KNN Loss: 2.9995031356811523 | CLS Loss: 0.007937494665384293
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 3.0776443481445312 | KNN Loss: 3.0624711513519287 | CLS Loss: 0.015173304826021194
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 3.0729422569274902 | 

Epoch 51 / 200 | iteration 160 / 171 | Total Loss: 3.0424156188964844 | KNN Loss: 3.018876314163208 | CLS Loss: 0.023539185523986816
Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 3.046872138977051 | KNN Loss: 3.0347771644592285 | CLS Loss: 0.01209485623985529
Epoch: 051, Loss: 3.0575, Train: 0.9944, Valid: 0.9868, Best: 0.9868
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 3.0674281120300293 | KNN Loss: 3.050729274749756 | CLS Loss: 0.01669878326356411
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 3.0257301330566406 | KNN Loss: 3.005274534225464 | CLS Loss: 0.020455611869692802
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 3.0424771308898926 | KNN Loss: 3.0224692821502686 | CLS Loss: 0.020007802173495293
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 3.0345957279205322 | KNN Loss: 3.022479295730591 | CLS Loss: 0.012116420082747936
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 3.0512828826904297 | KNN Loss: 3.0302734375 | CLS Loss: 0.021009420976042747
Epoch 52 / 2

Epoch 55 / 200 | iteration 40 / 171 | Total Loss: 3.0876059532165527 | KNN Loss: 3.0606045722961426 | CLS Loss: 0.027001449838280678
Epoch 55 / 200 | iteration 50 / 171 | Total Loss: 3.046112060546875 | KNN Loss: 3.0245144367218018 | CLS Loss: 0.021597709506750107
Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 3.0568127632141113 | KNN Loss: 3.0344200134277344 | CLS Loss: 0.022392842918634415
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 3.018359899520874 | KNN Loss: 3.0099265575408936 | CLS Loss: 0.008433266542851925
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 3.064112901687622 | KNN Loss: 3.0317254066467285 | CLS Loss: 0.032387517392635345
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 3.0333433151245117 | KNN Loss: 3.0269806385040283 | CLS Loss: 0.006362715270370245
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 3.0289909839630127 | KNN Loss: 3.0198559761047363 | CLS Loss: 0.009134982712566853
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 3.068129777908325 | 

In [None]:
test(model=model, loader=test_data_iter, device=device)  # final accuracy on the test split

In [None]:
# Mean CNN training loss per epoch, as returned by `train`.
# NOTE: `losses` is re-bound to per-batch tree losses in the tree section below,
# so this cell only shows CNN losses when run before that section.
plt.figure()
plt.plot(range(1, len(losses) + 1), losses, label='train loss')
plt.xlabel('Epoch')
plt.ylabel('Mean loss')
plt.legend()
plt.show()

In [None]:
# `test` may return 0-dim (possibly CUDA) tensors. float() turns each value into
# a host-side Python number so matplotlib can plot it. Without this, a list of
# CUDA tensors cannot be converted to a numpy array.
train_acc_values = [float(acc) for acc in train_accs]
val_acc_values = [float(acc) for acc in val_accs]
epoch_axis = range(1, len(train_acc_values) + 1)

plt.figure()
plt.plot(epoch_axis, train_acc_values, label='train accuracy')
plt.plot(range(1, len(val_acc_values) + 1), val_acc_values, label='validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

In [None]:
# Collect the raw test inputs, their labels and the model's intermediate
# embeddings (the flattened input of fc1). Per-batch tensors are kept in lists
# and concatenated once, which avoids the quadratic cost of repeatedly
# torch.cat-ing onto a growing tensor. Labels keep the dtype produced by the
# dataset.
model.eval()
sample_batches, label_batches, projection_batches = [], [], []
with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        sample_batches.append(x)
        label_batches.append(y)
        _, interm = model(x.to(device), True)
        projection_batches.append(interm.cpu().flatten(1))

# Fall back to empty tensors when the loader yields nothing (torch.cat([]) raises).
test_samples = torch.cat(sample_batches) if sample_batches else torch.tensor([])
labels = torch.cat(label_batches) if label_batches else torch.tensor([])
projections = torch.cat(projection_batches) if projection_batches else torch.tensor([])

In [None]:
# Pairwise distances between all test embeddings (sklearn's default metric is
# Euclidean). NOTE: this is an N x N matrix, so memory grows quadratically with
# the size of the test set.
distances = pairwise_distances(projections)
# ravel() returns a view where possible instead of copying the whole N x N matrix
# as flatten() does.
distances_f = distances.ravel()

plt.matshow(distances)
plt.colorbar()
plt.figure()
# Zero distances (the diagonal and exact duplicates) are excluded from the histogram.
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.xlabel('Distance')
plt.ylabel('Count')
plt.show()

In [None]:
# Density-based clustering of the test embeddings; DBSCAN assigns -1 to noise points.
dbscan = DBSCAN(eps=2, min_samples=10)
clusters = dbscan.fit_predict(projections)

In [None]:
# DBSCAN labels noise points -1; every other label is an inlier. Report both the
# count and the fraction. The value is a fraction, so labelling it a "Number"
# alone was misleading.
n_inliers = int(np.count_nonzero(clusters != -1))
n_points = len(clusters)
inlier_frac = n_inliers / n_points if n_points else 0.0
print(f"Number of inliers: {n_inliers} / {n_points} ({inlier_frac:.2%})")

In [None]:
perplexity = 100
inlier_mask = clusters != -1  # drop DBSCAN noise points (label -1)
p = reduce_dims_and_plot(
    projections[inlier_mask],
    y=clusters[inlier_mask],
    title=f'perplexity: {perplexity}',
    file_name=None,
    perplexity=perplexity,
    library='Multicore-TSNE',
    perform_PCA=False,
    projected=None,
    figure_type='2d',
    show_figure=True,
    close_figure=False,
    text=None,
)

# Train a Soft Decision Tree (SDT) on the DBSCAN cluster labels (self-labels)

## Prepare the dataset

In [None]:
# Keep only DBSCAN inliers (label -1 marks noise). Pair each flattened test
# sample with its cluster id, which serves as the tree's training label.
is_inlier = clusters != -1
tree_dataset = list(zip(test_samples.flatten(1)[is_inlier], clusters[is_inlier]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [None]:
def prune_node(node_weights, factor=1):
    """Zero, in place, every weight strictly within ``factor`` std-devs of the mean.

    Args:
        node_weights: 1-D weight tensor of a single node (modified in place).
        factor: half-width of the zeroed band, in standard deviations.

    Returns:
        The same ``node_weights`` tensor, for convenience.
    """
    snapshot = node_weights.cpu().detach().numpy()
    center = np.mean(snapshot)
    spread = np.std(snapshot)
    lower = center - spread * factor
    upper = center + spread * factor
    in_band = (lower < node_weights) & (node_weights < upper)
    node_weights[in_band] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    """Zero, in place, all but the ``keep`` largest-magnitude weights of a node.

    Args:
        node_weights: 1-D weight tensor of a single node (modified in place).
        keep: number of weights to retain. ``keep <= 0`` zeroes every weight;
            ``keep >= len(node_weights)`` leaves the node untouched.

    Returns:
        The same ``node_weights`` tensor, for convenience.
    """
    magnitudes = np.abs(node_weights.cpu().detach().numpy())
    # Number of weights to discard, clamped to [0, len]. A plain
    # ``argsort(...)[:-keep]`` silently keeps *everything* when keep == 0,
    # because ``[:-0]`` is ``[:0]``.
    n_throw = max(magnitudes.shape[0] - max(keep, 0), 0)
    throw_idx = np.argsort(magnitudes)[:n_throw]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    """Hard-prune every inner node of ``tree_`` in place.

    ``factor`` is forwarded to ``prune_node_keep`` as its ``keep`` argument. Each
    node therefore retains only its ``factor`` largest-magnitude weights and the
    rest are set to zero. Unlike in ``prune_node``, it is *not* a std-dev multiplier.

    Args:
        tree_: model exposing ``inner_nodes.weight`` as a 2-D tensor with one
            row per inner node.
        factor: number of weights kept per node.
    """
    # Pruning is a hard, non-differentiable edit. Doing it on a detached copy
    # under no_grad avoids recording the clone and the in-place index writes in
    # the parameter's autograd graph.
    with torch.no_grad():
        new_weights = tree_.inner_nodes.weight.detach().clone()
        for i in range(new_weights.shape[0]):
            new_weights[i, :] = prune_node_keep(new_weights[i, :], factor)
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    """Average, over the rows of 2-D tensor ``x``, of the fraction of exact zeros per row."""
    def zero_fraction(row):
        n_entries = len(row)
        return (n_entries - torch.norm(row, 0).item()) / n_entries

    return np.mean([zero_fraction(x[r, :]) for r in range(x.shape[0])])

def compute_regularization_by_level(tree):
    """Level-discounted sum of the L2 norms of the tree's inner-node weights.

    Node ``i`` is assigned level ``floor(log2(i + 1))``, and its weight norm is
    scaled by ``2 ** -level``, so deeper nodes contribute less.
    """
    node_weights = tree.inner_nodes.weight
    total_reg = 0
    for node_idx in range(node_weights.shape[0]):
        depth = np.floor(np.log2(node_idx + 1))
        total_reg += 2**(-depth) * torch.norm(node_weights[node_idx].view(-1), 2)
    return total_reg

def show_sparseness(tree):
    """Print and return how sparse the tree's inner-node weights are.

    The sparseness of a node is the fraction of its weights that are exactly
    zero. Prints the average over all nodes, then the average for every level.
    Node ``i`` is placed on level ``floor(log2(i + 1))``, i.e. the rows of
    ``tree.inner_nodes.weight`` are assumed to be in heap order.

    Args:
        tree: model exposing ``inner_nodes.weight`` as a 2-D tensor with one
            row per inner node.

    Returns:
        The average node sparseness (same value ``sparseness(weight)`` gives),
        or nan if there are no inner nodes.
    """
    weights = tree.inner_nodes.weight
    node_sparsities = []
    for i in range(weights.shape[0]):
        node = weights[i, :]
        node_sparsities.append((len(node) - torch.count_nonzero(node).item()) / len(node))

    # Compute the average from the per-node values instead of a second pass.
    avg_sp = np.mean(node_sparsities)
    print(f"Average sparseness: {avg_sp}")

    # Group by level and report every level, including the deepest one, which
    # a print-on-level-change loop never reaches.
    by_level = {}
    for i, sp in enumerate(node_sparsities):
        by_level.setdefault(int(np.floor(np.log2(i + 1))), []).append(sp)
    for layer, sps in by_level.items():
        print(f"layer {layer}: {np.mean(sps)}")

    return avg_sp

## Training configurations

In [None]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration,
             use_regularization=False):
    """Train a soft decision tree for one epoch.

    Args:
        model: the tree; called as ``model(data)`` and expected to return
            ``(output, penalty)``.
        loader: iterable of ``(data, target)`` batches with a defined ``len``.
        device: device the batches are moved to.
        log_interval: print a status line every ``log_interval`` batches.
        losses: list that receives every batch's total loss (appended in place).
        accs: list that receives every batch's accuracy (appended in place).
        epoch: epoch number, used only in the status line.
        iteration: global iteration counter before this epoch.
        use_regularization: when True, add
            ``sparsity_lamda * compute_regularization_by_level(model)`` to the
            optimised loss. Defaults to False, which keeps the previous
            behaviour: the term is computed and logged but not optimised.

    Returns:
        The updated global iteration counter.

    Note:
        Reads the notebook globals ``optimizer``, ``criterion``,
        ``sparsity_lamda`` and ``start_time``.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)
        target = target.view(-1)

        # Call the `model` argument (not a global) so the function trains what
        # it is given.
        output, penalty = model(data)

        # Classification loss plus the tree's own penalty term.
        loss_tree = criterion(output, target) + penalty

        if use_regularization:
            regularization = sparsity_lamda * compute_regularization_by_level(model)
            loss = loss_tree + regularization
        else:
            # Only needed for the log line, so keep it out of the autograd graph.
            with torch.no_grad():
                regularization = sparsity_lamda * compute_regularization_by_level(model)
            loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.detach().argmax(dim=1)
        batch_acc = pred.eq(target).sum().item() / data.size(0)
        accs.append(batch_acc)

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {batch_acc:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")

    return iteration


In [None]:
# Soft-decision-tree optimisation settings.
lr = 5e-3              # Adam learning rate
weight_decay = 5e-4    # Adam weight decay
sparsity_lamda = 2e-3  # multiplier applied to compute_regularization_by_level in do_epoch
epochs = 400
log_interval = 100     # do_epoch prints a status line every `log_interval` batches
use_cuda = not (device == 'cpu')

In [None]:
# Input size = features per flattened sample, matching tree_dataset, which feeds
# test_samples.flatten(1) to the tree.
n_tree_features = test_samples.flatten(1).shape[1]
# Number of DBSCAN clusters excluding the noise label -1. `len(set(clusters)) - 1`
# under-counts by one when DBSCAN finds no noise points, which would make the
# largest cluster id an out-of-range cross-entropy target.
n_clusters = len(set(clusters.tolist()) - {-1})

tree = SDT(input_dim=n_tree_features, output_dim=n_clusters, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
# Move the model first, then build the optimizer over its (now on-device) parameters.
tree = tree.to(device)
optimizer = torch.optim.Adam(tree.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

In [None]:
# Per-batch loss/accuracy histories and per-epoch sparsity for the tree run.
losses, accs, sparsity = [], [], []

In [None]:
# `start_time` is read as a global by do_epoch for its sec/iter estimate.
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Record sparsity before this epoch's training, then train and hard-prune.
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    prune_tree(tree, factor=3)  # keeps the 3 largest-magnitude weights of every node
        

In [None]:
# Per-batch training accuracy of the soft decision tree.
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.legend()  # without a legend the label above is never displayed
plt.show()

In [None]:
# Per-batch training loss of the soft decision tree.
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.legend()  # without a legend the label above is never displayed
plt.show()

# Distribution of all inner-node weights as they stand after the training loop
# (log-scaled counts).
plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
# Red lines mark mean +/- 1 std, the band `prune_node(factor=1)` would zero. The
# training loop prunes with `prune_tree`, which keeps a fixed number of weights
# per node instead.
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.xlabel('Weight value')
plt.ylabel('Count')
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
# Draw the trained tree. `visualize` returns a pair, unpacked here as
# (avg_height, root); `root` is used below to extract rules.
plt.figure(dpi=80, figsize=(15, 10))
(avg_height, root) = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
print("Number of patterns:", len(root.get_leaves()))  # one pattern per leaf

In [None]:
method = "greedy"  # passed to root.accumulate_samples below

In [None]:
# Start from empty leaves, then route every training sample of the tree to its leaf.
root.clear_leaves_samples()

with torch.no_grad():
    for data, _ in tree_loader:
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
# One attribute name per input feature of the tree (the flattened sample).
attr_names = [f"T_{i}" for i in range(test_samples.flatten(1).shape[1])]
leaves = root.get_leaves()
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    # A pattern's comprehensibility is the sum over the conditions on its path.
    comprehensibilities.append(sum(cond.comprehensibility for cond in conds))

# np.min / np.max raise on an empty list, so guard the no-leaf case.
if comprehensibilities:
    print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
    print(f"std comprehensibility: {np.std(comprehensibilities)}")
    print(f"var comprehensibility: {np.var(comprehensibilities)}")
    print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
    print(f"maximum comprehensibility: {np.max(comprehensibilities)}")
else:
    print("No patterns found; comprehensibility statistics are undefined.")