In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 8
tree_depth = 6
batch_size = 512
device = 'cpu'
train_data_path = r'F:\Downloads\archive\mitbih_train.csv'
test_data_path = r'F:\Downloads\archive\mitbih_test.csv'

In [3]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [4]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0)
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [None]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.22037935256958 | KNN Loss: 5.639108180999756 | CLS Loss: 1.5812710523605347
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 4.512310028076172 | KNN Loss: 3.8615217208862305 | CLS Loss: 0.6507881879806519
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 3.386953353881836 | KNN Loss: 2.649182081222534 | CLS Loss: 0.737771213054657
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 3.1240243911743164 | KNN Loss: 2.560368776321411 | CLS Loss: 0.5636554956436157
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 3.0672225952148438 | KNN Loss: 2.531372547149658 | CLS Loss: 0.5358500480651855
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 3.016354560852051 | KNN Loss: 2.437026023864746 | CLS Loss: 0.5793285965919495
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 2.989715099334717 | KNN Loss: 2.495742082595825 | CLS Loss: 0.493973046541214
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 2.927281379699707 | KNN Loss: 2.4779787063598633 | CLS L

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 2.629610061645508 | KNN Loss: 2.456813335418701 | CLS Loss: 0.17279668152332306
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 2.5800623893737793 | KNN Loss: 2.4619479179382324 | CLS Loss: 0.11811457574367523
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 2.6214852333068848 | KNN Loss: 2.4368479251861572 | CLS Loss: 0.1846373826265335
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 2.652789831161499 | KNN Loss: 2.5058364868164062 | CLS Loss: 0.1469532549381256
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 2.631235361099243 | KNN Loss: 2.4475133419036865 | CLS Loss: 0.183722123503685
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 2.621601104736328 | KNN Loss: 2.4818685054779053 | CLS Loss: 0.13973253965377808
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 2.620903253555298 | KNN Loss: 2.517056465148926 | CLS Loss: 0.10384675860404968
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 2.6423487663269043 | KNN Loss: 2.483466

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 2.5354063510894775 | KNN Loss: 2.4283711910247803 | CLS Loss: 0.10703523457050323
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 2.580207347869873 | KNN Loss: 2.476850986480713 | CLS Loss: 0.10335638374090195
Epoch: 007, Loss: 2.5394, Train: 0.9768, Valid: 0.9744, Best: 0.9744
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 2.5281362533569336 | KNN Loss: 2.450173854827881 | CLS Loss: 0.0779624953866005
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 2.50212025642395 | KNN Loss: 2.4447734355926514 | CLS Loss: 0.05734681338071823
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 2.564859628677368 | KNN Loss: 2.468865394592285 | CLS Loss: 0.09599415212869644
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 2.560581684112549 | KNN Loss: 2.4350554943084717 | CLS Loss: 0.12552624940872192
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 2.508180618286133 | KNN Loss: 2.434363842010498 | CLS Loss: 0.07381674647331238
Epoch 8 / 200 | iteratio

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 2.516437292098999 | KNN Loss: 2.4461705684661865 | CLS Loss: 0.07026666402816772
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 2.5064377784729004 | KNN Loss: 2.4180960655212402 | CLS Loss: 0.08834180235862732
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 2.5585741996765137 | KNN Loss: 2.4586586952209473 | CLS Loss: 0.0999155044555664
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 2.5331742763519287 | KNN Loss: 2.415907859802246 | CLS Loss: 0.11726631969213486
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 2.494556188583374 | KNN Loss: 2.416490077972412 | CLS Loss: 0.07806612551212311
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 2.5526835918426514 | KNN Loss: 2.4261972904205322 | CLS Loss: 0.12648634612560272
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 2.509809732437134 | KNN Loss: 2.426271677017212 | CLS Loss: 0.08353807032108307
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 2.495009660720825 | KNN Loss: 

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 2.502730131149292 | KNN Loss: 2.3985047340393066 | CLS Loss: 0.10422539710998535
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 2.501847505569458 | KNN Loss: 2.415691375732422 | CLS Loss: 0.08615612238645554
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 2.4611289501190186 | KNN Loss: 2.4113149642944336 | CLS Loss: 0.04981403425335884
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 2.4937100410461426 | KNN Loss: 2.424264669418335 | CLS Loss: 0.06944538652896881
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 2.4458136558532715 | KNN Loss: 2.397916078567505 | CLS Loss: 0.04789765551686287
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 2.480311870574951 | KNN Loss: 2.4273149967193604 | CLS Loss: 0.05299697443842888
Epoch: 014, Loss: 2.4819, Train: 0.9803, Valid: 0.9776, Best: 0.9787
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 2.5016870498657227 | KNN Loss: 2.434678316116333 | CLS Loss: 0.06700881570577621
Epoch 15 

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 2.5146872997283936 | KNN Loss: 2.407222032546997 | CLS Loss: 0.10746519267559052
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 2.4523472785949707 | KNN Loss: 2.3862624168395996 | CLS Loss: 0.06608489900827408
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 2.4582390785217285 | KNN Loss: 2.399864912033081 | CLS Loss: 0.05837418511509895
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 2.457495927810669 | KNN Loss: 2.4064888954162598 | CLS Loss: 0.051007144153118134
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 2.5120275020599365 | KNN Loss: 2.434089422225952 | CLS Loss: 0.07793819159269333
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 2.496138095855713 | KNN Loss: 2.4454262256622314 | CLS Loss: 0.050711825489997864
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 2.4730207920074463 | KNN Loss: 2.4284307956695557 | CLS Loss: 0.04458993300795555
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 2.4442408084869385 | KNN Loss

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 2.434865713119507 | KNN Loss: 2.3630244731903076 | CLS Loss: 0.07184115052223206
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 2.446728467941284 | KNN Loss: 2.3893675804138184 | CLS Loss: 0.05736100301146507
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 2.5034406185150146 | KNN Loss: 2.438753366470337 | CLS Loss: 0.06468720734119415
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 2.453659772872925 | KNN Loss: 2.418555974960327 | CLS Loss: 0.035103872418403625
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 2.442380905151367 | KNN Loss: 2.373149871826172 | CLS Loss: 0.06923096626996994
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 2.4614107608795166 | KNN Loss: 2.4326438903808594 | CLS Loss: 0.028766868636012077
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 2.460458278656006 | KNN Loss: 2.3848159313201904 | CLS Loss: 0.07564227283000946
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 2.4344379901885986 | KNN L

In [None]:
test(model, test_data_iter, device)

In [None]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

In [None]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

In [None]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

In [None]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

In [None]:
res = []
for i in np.arange(0.5, 4, 0.1):
    clusters = DBSCAN(eps=i, min_samples=10).fit_predict(projections)
    print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")
    res.append(sum(clusters != -1) / len(clusters))

In [162]:
plt.figure()
plt.plot(np.arange(0.5, 4, 0.1), res)
plt.show()

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [163]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [165]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.8540952903019505


In [166]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

  fig = plt.figure()


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [167]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [168]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [169]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [177]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [178]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [179]:
losses = []
accs = []
sparsity = []

In [180]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 037 | Total loss: 3.030 | Reg loss: 0.007 | Tree loss: 3.030 | Accuracy: 0.017578 | 0.125 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 01 | Batch: 000 / 037 | Total loss: 2.942 | Reg loss: 0.004 | Tree loss: 2.942 | Accuracy: 0.210938 | 0.113 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 02 | Batch: 000 / 037 | Total loss: 2.902 | Reg loss: 0.007 | Tree loss: 2.902 | Accuracy: 0.140625 | 0.113 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 28 | Batch: 000 / 037 | Total loss: 2.187 | Reg loss: 0.035 | Tree loss: 2.187 | Accuracy: 0.255859 | 0.121 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 29 | Batch: 000 / 037 | Total loss: 2.166 | Reg loss: 0.036 | Tree loss: 2.166 | Accuracy: 0.251953 | 0.12 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 30 | Batch: 000 / 037 | Total loss: 2.127 | Reg loss: 0.036 | Tree loss: 2.127 | Accuracy: 0.271484 | 0.12 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
la

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 56 | Batch: 000 / 037 | Total loss: 1.827 | Reg loss: 0.044 | Tree loss: 1.827 | Accuracy: 0.421875 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 57 | Batch: 000 / 037 | Total loss: 1.815 | Reg loss: 0.044 | Tree loss: 1.815 | Accuracy: 0.451172 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 58 | Batch: 000 / 037 | Total loss: 1.825 | Reg loss: 0.044 | Tree loss: 1.825 | Accuracy: 0.457031 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 84 | Batch: 000 / 037 | Total loss: 1.864 | Reg loss: 0.043 | Tree loss: 1.864 | Accuracy: 0.437500 | 0.114 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 85 | Batch: 000 / 037 | Total loss: 1.757 | Reg loss: 0.043 | Tree loss: 1.757 | Accuracy: 0.488281 | 0.114 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 86 | Batch: 000 / 037 | Total loss: 1.855 | Reg loss: 0.043 | Tree loss: 1.855 | Accuracy: 0.449219 | 0.114 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 112 | Batch: 000 / 037 | Total loss: 1.746 | Reg loss: 0.042 | Tree loss: 1.746 | Accuracy: 0.476562 | 0.114 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 113 | Batch: 000 / 037 | Total loss: 1.733 | Reg loss: 0.042 | Tree loss: 1.733 | Accuracy: 0.451172 | 0.114 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 114 | Batch: 000 / 037 | Total loss: 1.753 | Reg loss: 0.042 | Tree loss: 1.753 | Accuracy: 0.486328 | 0.114 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 140 | Batch: 000 / 037 | Total loss: 1.821 | Reg loss: 0.041 | Tree loss: 1.821 | Accuracy: 0.431641 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 141 | Batch: 000 / 037 | Total loss: 1.743 | Reg loss: 0.041 | Tree loss: 1.743 | Accuracy: 0.484375 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 142 | Batch: 000 / 037 | Total loss: 1.732 | Reg loss: 0.041 | Tree loss: 1.732 | Accuracy: 0.474609 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 168 | Batch: 000 / 037 | Total loss: 1.766 | Reg loss: 0.040 | Tree loss: 1.766 | Accuracy: 0.500000 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 169 | Batch: 000 / 037 | Total loss: 1.762 | Reg loss: 0.040 | Tree loss: 1.762 | Accuracy: 0.466797 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 170 | Batch: 000 / 037 | Total loss: 1.771 | Reg loss: 0.040 | Tree loss: 1.771 | Accuracy: 0.453125 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 196 | Batch: 000 / 037 | Total loss: 1.751 | Reg loss: 0.039 | Tree loss: 1.751 | Accuracy: 0.496094 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 197 | Batch: 000 / 037 | Total loss: 1.748 | Reg loss: 0.039 | Tree loss: 1.748 | Accuracy: 0.449219 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 198 | Batch: 000 / 037 | Total loss: 1.766 | Reg loss: 0.038 | Tree loss: 1.766 | Accuracy: 0.455078 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 224 | Batch: 000 / 037 | Total loss: 1.688 | Reg loss: 0.038 | Tree loss: 1.688 | Accuracy: 0.515625 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 225 | Batch: 000 / 037 | Total loss: 1.742 | Reg loss: 0.038 | Tree loss: 1.742 | Accuracy: 0.466797 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 226 | Batch: 000 / 037 | Total loss: 1.777 | Reg loss: 0.037 | Tree loss: 1.777 | Accuracy: 0.447266 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 252 | Batch: 000 / 037 | Total loss: 1.759 | Reg loss: 0.036 | Tree loss: 1.759 | Accuracy: 0.482422 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 253 | Batch: 000 / 037 | Total loss: 1.834 | Reg loss: 0.036 | Tree loss: 1.834 | Accuracy: 0.445312 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 254 | Batch: 000 / 037 | Total loss: 1.752 | Reg loss: 0.036 | Tree loss: 1.752 | Accuracy: 0.462891 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 280 | Batch: 000 / 037 | Total loss: 1.754 | Reg loss: 0.036 | Tree loss: 1.754 | Accuracy: 0.470703 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 281 | Batch: 000 / 037 | Total loss: 1.772 | Reg loss: 0.036 | Tree loss: 1.772 | Accuracy: 0.480469 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 282 | Batch: 000 / 037 | Total loss: 1.729 | Reg loss: 0.036 | Tree loss: 1.729 | Accuracy: 0.470703 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 308 | Batch: 000 / 037 | Total loss: 1.761 | Reg loss: 0.036 | Tree loss: 1.761 | Accuracy: 0.460938 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 309 | Batch: 000 / 037 | Total loss: 1.805 | Reg loss: 0.036 | Tree loss: 1.805 | Accuracy: 0.457031 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 310 | Batch: 000 / 037 | Total loss: 1.837 | Reg loss: 0.036 | Tree loss: 1.837 | Accuracy: 0.412109 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 336 | Batch: 000 / 037 | Total loss: 1.783 | Reg loss: 0.036 | Tree loss: 1.783 | Accuracy: 0.466797 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 337 | Batch: 000 / 037 | Total loss: 1.768 | Reg loss: 0.036 | Tree loss: 1.768 | Accuracy: 0.457031 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 338 | Batch: 000 / 037 | Total loss: 1.751 | Reg loss: 0.036 | Tree loss: 1.751 | Accuracy: 0.470703 | 0.115 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 364 | Batch: 000 / 037 | Total loss: 1.762 | Reg loss: 0.036 | Tree loss: 1.762 | Accuracy: 0.466797 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 365 | Batch: 000 / 037 | Total loss: 1.751 | Reg loss: 0.036 | Tree loss: 1.751 | Accuracy: 0.462891 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 366 | Batch: 000 / 037 | Total loss: 1.766 | Reg loss: 0.036 | Tree loss: 1.766 | Accuracy: 0.443359 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 392 | Batch: 000 / 037 | Total loss: 1.769 | Reg loss: 0.036 | Tree loss: 1.769 | Accuracy: 0.447266 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 393 | Batch: 000 / 037 | Total loss: 1.743 | Reg loss: 0.036 | Tree loss: 1.743 | Accuracy: 0.464844 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 394 | Batch: 000 / 037 | Total loss: 1.727 | Reg loss: 0.036 | Tree loss: 1.727 | Accuracy: 0.507812 | 0.116 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

In [181]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [182]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [None]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 5.117647058823529


# Extract Rules

# Accumulate samples in the leaves

In [None]:
print(f"Number of patterns: {len(root.get_leaves())}")

In [94]:
method = 'greedy'

In [95]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [96]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

NameError: name 'dataset' is not defined