In [1]:
%matplotlib widget
import torch
import torch.nn as nn
import pickle
import os
from queue import LifoQueue
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from sklearn.metrics import davies_bouldin_score
import matplotlib.pyplot as plt
import numpy as np
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from tqdm.notebook import tqdm
from sklearn.metrics import pairwise_distances
import network.cpc
from network.cpc import CDCK2
from utils.MatplotlibUtils import reduce_dims_and_plot
from utils.ClassificationUtiols import onehot_coding
from soft_decision_tree.sdt_model import SDT


# Load the model and the dataset

In [2]:
# Root folder of the experiment: checkpoints live under models/ and the pickled
# splits under data/. The default is the original author's location, so existing
# runs are unchanged; set SDT_EXPERIMENT_DIR to run the notebook on another machine.
experiment_dir = os.environ.get(
    'SDT_EXPERIMENT_DIR',
    r'C:\Users\eitan\OneDrive - Technion\Desktop\tsne\negative from sample\with_data\knn_loss_batch_512_k_32',
)
model_path = os.path.join(experiment_dir, 'models', 'epoch_40.pt')
dataset_path = os.path.join(experiment_dir, 'data', 'test_data.file')
batch_size = 32
print(f"Load the model from: {model_path}")
# NOTE(review): torch.load unpickles a complete model object, which can run arbitrary
# code -- only load checkpoints you trust. Recent PyTorch releases default to
# weights_only=True and will refuse such a checkpoint unless weights_only=False is passed.
model = torch.load(model_path, map_location='cpu')

# NOTE(review): pickle.load executes code embedded in the file; only open trusted data.
with open(dataset_path, 'rb') as fp:
    dataset = pickle.load(fp)

# shuffle=False keeps the loader order equal to the dataset order.
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

Load the model from: C:\Users\eitan\OneDrive - Technion\Desktop\tsne\negative from sample\with_data\knn_loss_batch_512_k_32\models\epoch_40.pt


# Extract representations

In [3]:
# Run the encoder over the whole loader and collect, per sample, its projection and
# the raw input window. Chunks are gathered in lists and concatenated once at the end:
# calling torch.cat inside the loop re-copies everything accumulated so far on every batch.
device = 'cpu'
model = model.eval()

projection_chunks = []
sample_chunks = []
# The context manager closes the progress bar even if a batch raises.
with torch.no_grad(), tqdm(total=len(loader.dataset)) as bar:
    for batch in loader:
        hidden = CDCK2.init_hidden(len(batch)).to(device)
        batch = batch.to(device)

        y = model.predict(batch, hidden).detach().cpu()
        projection_chunks.append(y)
        # Keep the stored inputs on the CPU so they can be concatenated regardless of `device`.
        sample_chunks.append(batch.cpu())
        bar.update(y.shape[0])

# torch.cat rejects an empty list, so fall back to the empty tensors the cell used to start from.
projections = torch.cat(projection_chunks) if projection_chunks else torch.tensor([])
samples = torch.cat(sample_chunks) if sample_chunks else torch.tensor([])

HBox(children=(FloatProgress(value=0.0, max=25367.0), HTML(value='')))

# Fit GMM and calculate indices

In [4]:
# Sweep the number of GMM components and keep the assignment with the lowest
# Davies-Bouldin score (lower is better).
# A fixed seed makes the selected clustering -- and everything trained on it further
# down -- reproducible across kernel restarts; GaussianMixture is otherwise
# initialised randomly on every call.
gmm_seed = 0

scores = []
best_score = float('inf')
clusters = None
range_ = list(range(5, 20))
for k in tqdm(range_):
    y = GaussianMixture(n_components=k, random_state=gmm_seed).fit_predict(projections)
    cur_score = davies_bouldin_score(projections, y)
    scores.append(cur_score)

    # Strict '<' keeps the smallest k when two settings tie.
    if cur_score < best_score:
        best_score = cur_score
        clusters = y

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




In [5]:
# Davies-Bouldin score for every candidate cluster count, with the minimum marked in red.
fig, ax = plt.subplots()
ax.set_xlabel('Number of clusters')
ax.set_ylabel('DB Score')
ax.plot(range_, scores)
best_k = range_[int(np.argmin(scores))]
ax.axvline(best_k, color='r')
plt.show()

# Distinct cluster ids present in the selected assignment.
labels = set(clusters)
print(labels)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}


In [6]:
# Full N x N distance matrix between projections, shown as a heat map and as a
# histogram of the non-zero entries.
distances = pairwise_distances(projections)
# distances = np.triu(distances)
# ravel() returns a view when the array is contiguous, whereas flatten() always
# allocates a second N*N copy -- several GB for the dataset size used here.
distances_f = distances.ravel()

plt.matshow(distances)
plt.colorbar()
plt.figure()
# The matrix is symmetric, so every pair is counted twice; zeros (the diagonal and
# exact duplicates) are excluded.
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.xlabel('Pairwise distance')
plt.ylabel('Count')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Visualize with T-SNE

In [7]:
perplexity = 200

# Options forwarded to the project's plotting helper; the points are coloured by
# the selected cluster assignment.
tsne_options = dict(
    y=clusters,
    title=f'perplexity: {perplexity}',
    file_name=None,
    perplexity=perplexity,
    library='Multicore-TSNE',
    perform_PCA=False,
    projected=None,
    figure_type='2d',
    show_figure=True,
    close_figure=False,
    text=None,
)
p = reduce_dims_and_plot(projections, **tsne_options)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [57]:
# Disabled experiment: de-normalise every sensor channel before building the tree inputs.
# unnormalized_samples = samples.clone()
#
# for col, sensor in enumerate(tqdm(dataset.dataset.all_signals)):
#     denormalizer = dataset.dataset.get_denormalization_for_sensor(sensor)
#     unnormalized_samples[:, col, :] = denormalizer(unnormalized_samples[:, col, :])

# Keep every 100th element along the last axis, then collapse everything after the
# first axis into a single feature vector per sample.
sampled = samples[..., ::100]
samples_f = sampled.flatten(start_dim=1)

# Pair each feature vector with its cluster id; these pairs are the tree's training set.
tree_dataset = [(features, cluster_id) for features, cluster_id in zip(samples_f, clusters)]
batch_size = 2048
tree_loader = torch.utils.data.DataLoader(tree_dataset, shuffle=True, batch_size=batch_size)

## Training configurations

In [58]:
# Optimisation hyper-parameters for the soft decision tree.
lr = 2e-4
weight_decay = 5e-4
sparsity_lamda = 8e-3   # weight of the L1 term applied to the inner-node parameters
penalty_lamda = 1e-3    # `lamda` argument handed to SDT
epochs = 1000
log_interval = 50
tree_depth = 7

# One output per class index. CrossEntropyLoss needs every target to be < output_dim,
# so size the head from the largest cluster id rather than from the number of distinct
# ids: the two differ when a mixture component received no samples. This also removes
# the dependency on `labels`, which is only defined in the plotting cell.
output_dim = int(np.max(clusters)) + 1

tree = SDT(input_dim=samples_f.shape[1],
           output_dim=output_dim,
           depth=tree_depth,
           lamda=penalty_lamda,
           use_cuda=False)
optimizer = torch.optim.Adam(tree.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [59]:
# Train the tree on (flattened sample, cluster id) pairs, recording the total loss
# and the accuracy of every mini-batch.
losses = []
accs = []
# train() switches the module in place, so one call is sufficient.
tree.train()
for epoch in range(epochs):
    # Training
    for batch_idx, (data, target) in enumerate(tree_loader):
        data = data.to(device)
        # Flatten once; the loss and the accuracy both need a 1-D vector of class indices.
        target = target.to(device).view(-1)

        output, penalty = tree.forward(data)

        # Classification loss plus the penalty returned by the tree's forward pass.
        loss_tree = criterion(output, target) + penalty

        # L1 term over all parameters of the inner-node layer.
        fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
        l1_regularization = sparsity_lamda * torch.norm(fc_params, 1)
        loss = loss_tree + l1_regularization

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        # Batch accuracy: fraction of samples whose arg-max output equals the label.
        pred = output.argmax(dim=1)
        correct = (pred == target).sum()
        batch_acc = correct.item() / data.size(0)
        accs.append(batch_acc)

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(tree_loader):03d} | Total loss: {loss.item():.3f} | L1 loss: {l1_regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {batch_acc:03f}")



Epoch: 00 | Batch: 000 / 013 | Total loss: 11.957 | L1 loss: 9.644 | Tree loss: 2.313 | Accuracy: 0.074707
Epoch: 01 | Batch: 000 / 013 | Total loss: 11.021 | L1 loss: 8.731 | Tree loss: 2.289 | Accuracy: 0.158203
Epoch: 02 | Batch: 000 / 013 | Total loss: 10.137 | L1 loss: 7.868 | Tree loss: 2.269 | Accuracy: 0.201172
Epoch: 03 | Batch: 000 / 013 | Total loss: 9.299 | L1 loss: 7.050 | Tree loss: 2.250 | Accuracy: 0.190918
Epoch: 04 | Batch: 000 / 013 | Total loss: 8.495 | L1 loss: 6.278 | Tree loss: 2.217 | Accuracy: 0.215820
Epoch: 05 | Batch: 000 / 013 | Total loss: 7.753 | L1 loss: 5.553 | Tree loss: 2.200 | Accuracy: 0.214844
Epoch: 06 | Batch: 000 / 013 | Total loss: 7.045 | L1 loss: 4.876 | Tree loss: 2.169 | Accuracy: 0.262695
Epoch: 07 | Batch: 000 / 013 | Total loss: 6.390 | L1 loss: 4.245 | Tree loss: 2.144 | Accuracy: 0.318848
Epoch: 08 | Batch: 000 / 013 | Total loss: 5.780 | L1 loss: 3.664 | Tree loss: 2.117 | Accuracy: 0.352051
Epoch: 09 | Batch: 000 / 013 | Total loss: 

Epoch: 78 | Batch: 000 / 013 | Total loss: 2.065 | L1 loss: 0.397 | Tree loss: 1.668 | Accuracy: 0.528320
Epoch: 79 | Batch: 000 / 013 | Total loss: 2.051 | L1 loss: 0.398 | Tree loss: 1.653 | Accuracy: 0.527344
Epoch: 80 | Batch: 000 / 013 | Total loss: 2.068 | L1 loss: 0.400 | Tree loss: 1.668 | Accuracy: 0.503418
Epoch: 81 | Batch: 000 / 013 | Total loss: 2.054 | L1 loss: 0.401 | Tree loss: 1.653 | Accuracy: 0.516602
Epoch: 82 | Batch: 000 / 013 | Total loss: 2.047 | L1 loss: 0.403 | Tree loss: 1.644 | Accuracy: 0.510742
Epoch: 83 | Batch: 000 / 013 | Total loss: 2.049 | L1 loss: 0.404 | Tree loss: 1.646 | Accuracy: 0.499512
Epoch: 84 | Batch: 000 / 013 | Total loss: 2.056 | L1 loss: 0.405 | Tree loss: 1.651 | Accuracy: 0.484863
Epoch: 85 | Batch: 000 / 013 | Total loss: 2.032 | L1 loss: 0.406 | Tree loss: 1.626 | Accuracy: 0.520508
Epoch: 86 | Batch: 000 / 013 | Total loss: 2.029 | L1 loss: 0.408 | Tree loss: 1.621 | Accuracy: 0.507812
Epoch: 87 | Batch: 000 / 013 | Total loss: 2.0

Epoch: 155 | Batch: 000 / 013 | Total loss: 1.899 | L1 loss: 0.425 | Tree loss: 1.474 | Accuracy: 0.550781
Epoch: 156 | Batch: 000 / 013 | Total loss: 1.931 | L1 loss: 0.426 | Tree loss: 1.505 | Accuracy: 0.509766
Epoch: 157 | Batch: 000 / 013 | Total loss: 1.903 | L1 loss: 0.425 | Tree loss: 1.477 | Accuracy: 0.560547
Epoch: 158 | Batch: 000 / 013 | Total loss: 1.906 | L1 loss: 0.426 | Tree loss: 1.480 | Accuracy: 0.558105
Epoch: 159 | Batch: 000 / 013 | Total loss: 1.916 | L1 loss: 0.427 | Tree loss: 1.489 | Accuracy: 0.536133
Epoch: 160 | Batch: 000 / 013 | Total loss: 1.911 | L1 loss: 0.427 | Tree loss: 1.484 | Accuracy: 0.535156
Epoch: 161 | Batch: 000 / 013 | Total loss: 1.897 | L1 loss: 0.426 | Tree loss: 1.471 | Accuracy: 0.555176
Epoch: 162 | Batch: 000 / 013 | Total loss: 1.907 | L1 loss: 0.427 | Tree loss: 1.480 | Accuracy: 0.549805
Epoch: 163 | Batch: 000 / 013 | Total loss: 1.917 | L1 loss: 0.427 | Tree loss: 1.490 | Accuracy: 0.533691
Epoch: 164 | Batch: 000 / 013 | Total

Epoch: 232 | Batch: 000 / 013 | Total loss: 1.810 | L1 loss: 0.454 | Tree loss: 1.355 | Accuracy: 0.627441
Epoch: 233 | Batch: 000 / 013 | Total loss: 1.795 | L1 loss: 0.454 | Tree loss: 1.341 | Accuracy: 0.623535
Epoch: 234 | Batch: 000 / 013 | Total loss: 1.782 | L1 loss: 0.456 | Tree loss: 1.325 | Accuracy: 0.622559
Epoch: 235 | Batch: 000 / 013 | Total loss: 1.807 | L1 loss: 0.458 | Tree loss: 1.349 | Accuracy: 0.625000
Epoch: 236 | Batch: 000 / 013 | Total loss: 1.783 | L1 loss: 0.459 | Tree loss: 1.324 | Accuracy: 0.629395
Epoch: 237 | Batch: 000 / 013 | Total loss: 1.796 | L1 loss: 0.460 | Tree loss: 1.336 | Accuracy: 0.631836
Epoch: 238 | Batch: 000 / 013 | Total loss: 1.781 | L1 loss: 0.460 | Tree loss: 1.321 | Accuracy: 0.636230
Epoch: 239 | Batch: 000 / 013 | Total loss: 1.799 | L1 loss: 0.461 | Tree loss: 1.338 | Accuracy: 0.626953
Epoch: 240 | Batch: 000 / 013 | Total loss: 1.793 | L1 loss: 0.463 | Tree loss: 1.330 | Accuracy: 0.639648
Epoch: 241 | Batch: 000 / 013 | Total

Epoch: 309 | Batch: 000 / 013 | Total loss: 1.709 | L1 loss: 0.493 | Tree loss: 1.216 | Accuracy: 0.682129
Epoch: 310 | Batch: 000 / 013 | Total loss: 1.699 | L1 loss: 0.494 | Tree loss: 1.205 | Accuracy: 0.682617
Epoch: 311 | Batch: 000 / 013 | Total loss: 1.706 | L1 loss: 0.494 | Tree loss: 1.212 | Accuracy: 0.693359
Epoch: 312 | Batch: 000 / 013 | Total loss: 1.683 | L1 loss: 0.494 | Tree loss: 1.189 | Accuracy: 0.701660
Epoch: 313 | Batch: 000 / 013 | Total loss: 1.679 | L1 loss: 0.494 | Tree loss: 1.185 | Accuracy: 0.680664
Epoch: 314 | Batch: 000 / 013 | Total loss: 1.693 | L1 loss: 0.494 | Tree loss: 1.199 | Accuracy: 0.689453
Epoch: 315 | Batch: 000 / 013 | Total loss: 1.698 | L1 loss: 0.495 | Tree loss: 1.204 | Accuracy: 0.675781
Epoch: 316 | Batch: 000 / 013 | Total loss: 1.677 | L1 loss: 0.494 | Tree loss: 1.182 | Accuracy: 0.698242
Epoch: 317 | Batch: 000 / 013 | Total loss: 1.687 | L1 loss: 0.495 | Tree loss: 1.192 | Accuracy: 0.685059
Epoch: 318 | Batch: 000 / 013 | Total

Epoch: 386 | Batch: 000 / 013 | Total loss: 1.609 | L1 loss: 0.503 | Tree loss: 1.106 | Accuracy: 0.760254
Epoch: 387 | Batch: 000 / 013 | Total loss: 1.640 | L1 loss: 0.503 | Tree loss: 1.137 | Accuracy: 0.735840
Epoch: 388 | Batch: 000 / 013 | Total loss: 1.610 | L1 loss: 0.503 | Tree loss: 1.108 | Accuracy: 0.752930
Epoch: 389 | Batch: 000 / 013 | Total loss: 1.614 | L1 loss: 0.502 | Tree loss: 1.112 | Accuracy: 0.745117
Epoch: 390 | Batch: 000 / 013 | Total loss: 1.625 | L1 loss: 0.502 | Tree loss: 1.123 | Accuracy: 0.749023
Epoch: 391 | Batch: 000 / 013 | Total loss: 1.621 | L1 loss: 0.503 | Tree loss: 1.118 | Accuracy: 0.752930
Epoch: 392 | Batch: 000 / 013 | Total loss: 1.601 | L1 loss: 0.503 | Tree loss: 1.098 | Accuracy: 0.750488
Epoch: 393 | Batch: 000 / 013 | Total loss: 1.608 | L1 loss: 0.503 | Tree loss: 1.106 | Accuracy: 0.752441
Epoch: 394 | Batch: 000 / 013 | Total loss: 1.609 | L1 loss: 0.503 | Tree loss: 1.107 | Accuracy: 0.753418
Epoch: 395 | Batch: 000 / 013 | Total

Epoch: 463 | Batch: 000 / 013 | Total loss: 1.551 | L1 loss: 0.508 | Tree loss: 1.043 | Accuracy: 0.795410
Epoch: 464 | Batch: 000 / 013 | Total loss: 1.556 | L1 loss: 0.508 | Tree loss: 1.049 | Accuracy: 0.775391
Epoch: 465 | Batch: 000 / 013 | Total loss: 1.578 | L1 loss: 0.508 | Tree loss: 1.070 | Accuracy: 0.783203
Epoch: 466 | Batch: 000 / 013 | Total loss: 1.555 | L1 loss: 0.508 | Tree loss: 1.047 | Accuracy: 0.796387
Epoch: 467 | Batch: 000 / 013 | Total loss: 1.571 | L1 loss: 0.508 | Tree loss: 1.063 | Accuracy: 0.790039
Epoch: 468 | Batch: 000 / 013 | Total loss: 1.566 | L1 loss: 0.508 | Tree loss: 1.058 | Accuracy: 0.794434
Epoch: 469 | Batch: 000 / 013 | Total loss: 1.573 | L1 loss: 0.508 | Tree loss: 1.066 | Accuracy: 0.772461
Epoch: 470 | Batch: 000 / 013 | Total loss: 1.538 | L1 loss: 0.508 | Tree loss: 1.030 | Accuracy: 0.807129
Epoch: 471 | Batch: 000 / 013 | Total loss: 1.559 | L1 loss: 0.508 | Tree loss: 1.051 | Accuracy: 0.775391
Epoch: 472 | Batch: 000 / 013 | Total

Epoch: 540 | Batch: 000 / 013 | Total loss: 1.506 | L1 loss: 0.511 | Tree loss: 0.995 | Accuracy: 0.810547
Epoch: 541 | Batch: 000 / 013 | Total loss: 1.512 | L1 loss: 0.511 | Tree loss: 1.001 | Accuracy: 0.811523
Epoch: 542 | Batch: 000 / 013 | Total loss: 1.500 | L1 loss: 0.510 | Tree loss: 0.990 | Accuracy: 0.809082
Epoch: 543 | Batch: 000 / 013 | Total loss: 1.516 | L1 loss: 0.511 | Tree loss: 1.005 | Accuracy: 0.810547
Epoch: 544 | Batch: 000 / 013 | Total loss: 1.515 | L1 loss: 0.510 | Tree loss: 1.004 | Accuracy: 0.813477
Epoch: 545 | Batch: 000 / 013 | Total loss: 1.512 | L1 loss: 0.510 | Tree loss: 1.002 | Accuracy: 0.803711
Epoch: 546 | Batch: 000 / 013 | Total loss: 1.507 | L1 loss: 0.510 | Tree loss: 0.997 | Accuracy: 0.801758
Epoch: 547 | Batch: 000 / 013 | Total loss: 1.520 | L1 loss: 0.510 | Tree loss: 1.010 | Accuracy: 0.805664
Epoch: 548 | Batch: 000 / 013 | Total loss: 1.524 | L1 loss: 0.510 | Tree loss: 1.014 | Accuracy: 0.790039
Epoch: 549 | Batch: 000 / 013 | Total

Epoch: 617 | Batch: 000 / 013 | Total loss: 1.446 | L1 loss: 0.512 | Tree loss: 0.934 | Accuracy: 0.822754
Epoch: 618 | Batch: 000 / 013 | Total loss: 1.455 | L1 loss: 0.512 | Tree loss: 0.943 | Accuracy: 0.810547
Epoch: 619 | Batch: 000 / 013 | Total loss: 1.455 | L1 loss: 0.512 | Tree loss: 0.942 | Accuracy: 0.822266
Epoch: 620 | Batch: 000 / 013 | Total loss: 1.462 | L1 loss: 0.512 | Tree loss: 0.950 | Accuracy: 0.809082
Epoch: 621 | Batch: 000 / 013 | Total loss: 1.449 | L1 loss: 0.512 | Tree loss: 0.937 | Accuracy: 0.815430
Epoch: 622 | Batch: 000 / 013 | Total loss: 1.450 | L1 loss: 0.512 | Tree loss: 0.938 | Accuracy: 0.805176
Epoch: 623 | Batch: 000 / 013 | Total loss: 1.449 | L1 loss: 0.512 | Tree loss: 0.937 | Accuracy: 0.821777
Epoch: 624 | Batch: 000 / 013 | Total loss: 1.465 | L1 loss: 0.512 | Tree loss: 0.953 | Accuracy: 0.812500
Epoch: 625 | Batch: 000 / 013 | Total loss: 1.447 | L1 loss: 0.512 | Tree loss: 0.935 | Accuracy: 0.812012
Epoch: 626 | Batch: 000 / 013 | Total

Epoch: 694 | Batch: 000 / 013 | Total loss: 1.393 | L1 loss: 0.513 | Tree loss: 0.880 | Accuracy: 0.824707
Epoch: 695 | Batch: 000 / 013 | Total loss: 1.424 | L1 loss: 0.513 | Tree loss: 0.911 | Accuracy: 0.813477
Epoch: 696 | Batch: 000 / 013 | Total loss: 1.414 | L1 loss: 0.513 | Tree loss: 0.901 | Accuracy: 0.807129
Epoch: 697 | Batch: 000 / 013 | Total loss: 1.425 | L1 loss: 0.513 | Tree loss: 0.912 | Accuracy: 0.809570
Epoch: 698 | Batch: 000 / 013 | Total loss: 1.424 | L1 loss: 0.513 | Tree loss: 0.911 | Accuracy: 0.813965
Epoch: 699 | Batch: 000 / 013 | Total loss: 1.428 | L1 loss: 0.513 | Tree loss: 0.915 | Accuracy: 0.818848
Epoch: 700 | Batch: 000 / 013 | Total loss: 1.421 | L1 loss: 0.513 | Tree loss: 0.908 | Accuracy: 0.816895
Epoch: 701 | Batch: 000 / 013 | Total loss: 1.425 | L1 loss: 0.513 | Tree loss: 0.911 | Accuracy: 0.819336
Epoch: 702 | Batch: 000 / 013 | Total loss: 1.422 | L1 loss: 0.514 | Tree loss: 0.909 | Accuracy: 0.815918
Epoch: 703 | Batch: 000 / 013 | Total

Epoch: 771 | Batch: 000 / 013 | Total loss: 1.387 | L1 loss: 0.515 | Tree loss: 0.872 | Accuracy: 0.820312
Epoch: 772 | Batch: 000 / 013 | Total loss: 1.373 | L1 loss: 0.514 | Tree loss: 0.859 | Accuracy: 0.829102
Epoch: 773 | Batch: 000 / 013 | Total loss: 1.375 | L1 loss: 0.514 | Tree loss: 0.861 | Accuracy: 0.840332
Epoch: 774 | Batch: 000 / 013 | Total loss: 1.384 | L1 loss: 0.514 | Tree loss: 0.870 | Accuracy: 0.832031
Epoch: 775 | Batch: 000 / 013 | Total loss: 1.380 | L1 loss: 0.514 | Tree loss: 0.866 | Accuracy: 0.824219
Epoch: 776 | Batch: 000 / 013 | Total loss: 1.385 | L1 loss: 0.515 | Tree loss: 0.870 | Accuracy: 0.826660
Epoch: 777 | Batch: 000 / 013 | Total loss: 1.364 | L1 loss: 0.515 | Tree loss: 0.850 | Accuracy: 0.833008
Epoch: 778 | Batch: 000 / 013 | Total loss: 1.376 | L1 loss: 0.514 | Tree loss: 0.862 | Accuracy: 0.835938
Epoch: 779 | Batch: 000 / 013 | Total loss: 1.399 | L1 loss: 0.514 | Tree loss: 0.885 | Accuracy: 0.808594
Epoch: 780 | Batch: 000 / 013 | Total

Epoch: 848 | Batch: 000 / 013 | Total loss: 1.357 | L1 loss: 0.512 | Tree loss: 0.845 | Accuracy: 0.820801
Epoch: 849 | Batch: 000 / 013 | Total loss: 1.352 | L1 loss: 0.512 | Tree loss: 0.840 | Accuracy: 0.815918
Epoch: 850 | Batch: 000 / 013 | Total loss: 1.354 | L1 loss: 0.512 | Tree loss: 0.842 | Accuracy: 0.812988
Epoch: 851 | Batch: 000 / 013 | Total loss: 1.316 | L1 loss: 0.512 | Tree loss: 0.804 | Accuracy: 0.836426
Epoch: 852 | Batch: 000 / 013 | Total loss: 1.337 | L1 loss: 0.512 | Tree loss: 0.825 | Accuracy: 0.820312
Epoch: 853 | Batch: 000 / 013 | Total loss: 1.339 | L1 loss: 0.511 | Tree loss: 0.827 | Accuracy: 0.821777
Epoch: 854 | Batch: 000 / 013 | Total loss: 1.356 | L1 loss: 0.511 | Tree loss: 0.845 | Accuracy: 0.805176
Epoch: 855 | Batch: 000 / 013 | Total loss: 1.338 | L1 loss: 0.512 | Tree loss: 0.827 | Accuracy: 0.827148
Epoch: 856 | Batch: 000 / 013 | Total loss: 1.356 | L1 loss: 0.512 | Tree loss: 0.844 | Accuracy: 0.812988
Epoch: 857 | Batch: 000 / 013 | Total

Epoch: 925 | Batch: 000 / 013 | Total loss: 1.314 | L1 loss: 0.507 | Tree loss: 0.807 | Accuracy: 0.835449
Epoch: 926 | Batch: 000 / 013 | Total loss: 1.309 | L1 loss: 0.507 | Tree loss: 0.802 | Accuracy: 0.812988
Epoch: 927 | Batch: 000 / 013 | Total loss: 1.290 | L1 loss: 0.506 | Tree loss: 0.783 | Accuracy: 0.833496
Epoch: 928 | Batch: 000 / 013 | Total loss: 1.313 | L1 loss: 0.506 | Tree loss: 0.807 | Accuracy: 0.810059
Epoch: 929 | Batch: 000 / 013 | Total loss: 1.311 | L1 loss: 0.506 | Tree loss: 0.805 | Accuracy: 0.824219
Epoch: 930 | Batch: 000 / 013 | Total loss: 1.309 | L1 loss: 0.506 | Tree loss: 0.803 | Accuracy: 0.833984
Epoch: 931 | Batch: 000 / 013 | Total loss: 1.315 | L1 loss: 0.506 | Tree loss: 0.810 | Accuracy: 0.819824
Epoch: 932 | Batch: 000 / 013 | Total loss: 1.321 | L1 loss: 0.506 | Tree loss: 0.815 | Accuracy: 0.827148
Epoch: 933 | Batch: 000 / 013 | Total loss: 1.300 | L1 loss: 0.506 | Tree loss: 0.794 | Accuracy: 0.833008
Epoch: 934 | Batch: 000 / 013 | Total

In [60]:
# Per-iteration training accuracy.
fig_acc, ax_acc = plt.subplots()
ax_acc.set_ylabel("Accuracy")
ax_acc.set_xlabel('Iteration')
ax_acc.plot(accs, label='Accuracy vs iteration')
plt.show()

# Per-iteration total loss on a logarithmic axis.
fig_loss, ax_loss = plt.subplots()
ax_loss.set_ylabel("Loss")
ax_loss.set_xlabel('Iteration')
ax_loss.plot(losses, label='Loss vs iteration')
ax_loss.set_yscale("log")
plt.show()

# Histogram of the inner-node weights; the red lines mark one standard deviation on
# either side of the mean. The statistics are kept because the pruning cell reuses them.
fig_w, ax_w = plt.subplots()
weights = tree.inner_nodes.weight.detach().cpu().numpy().flatten()
ax_w.hist(weights, bins=500)
weights_std = weights.std()
weights_mean = weights.mean()
for band_edge in (weights_mean + weights_std, weights_mean - weights_std):
    ax_w.axvline(band_edge, color='r')
ax_w.set_title(f"Mean: {weights_mean}   |   STD: {weights_std}")
ax_w.set_yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Prune the weights

In [61]:
# Zero every inner-node weight lying strictly within one standard deviation of the
# mean, using the statistics computed before pruning (weights_mean / weights_std).
prune_low = float(weights_mean - weights_std)
prune_high = float(weights_mean + weights_std)

# Everything happens under no_grad on a detached copy, so the masked assignment is not
# recorded by autograd and `new_weights` does not keep a graph alive.
with torch.no_grad():
    new_weights = tree.inner_nodes.weight.detach().clone()
    inside_band = (new_weights > prune_low) & (new_weights < prune_high)
    new_weights[inside_band] = 0
    tree.inner_nodes.weight.copy_(new_weights)

In [62]:
# Same weight histogram after pruning. The red lines and the title still show the
# mean and standard deviation measured before pruning.
fig_pruned, ax_pruned = plt.subplots()
weights = tree.inner_nodes.weight.detach().cpu().numpy().flatten()
ax_pruned.hist(weights, bins=500)
for band_edge in (weights_mean + weights_std, weights_mean - weights_std):
    ax_pruned.axvline(band_edge, color='r')
ax_pruned.set_title(f"Mean: {weights_mean}   |   STD: {weights_std}")
ax_pruned.set_yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## Verify that the accuracy didn't change too much

In [63]:
# Re-score the pruned tree on the full training set.
# NOTE(review): the tree is still in training mode here (tree.train() was called before
# the training loop); confirm whether SDT behaves differently after tree.eval().
correct = 0

with torch.no_grad():
    for data, target in tree_loader:
        data, target = data.to(device), target.to(device)
        output, penalty = tree.forward(data)
        pred = output.argmax(dim=1)
        # .item() keeps `correct` a plain Python int, so the division below is an
        # ordinary float instead of a tensor (or an integer division on old torch).
        correct += (pred == target.view(-1)).sum().item()

accuracy = correct / len(tree_loader.dataset)
print(f"Accuracy: {accuracy}")


Accuracy: 0.8072298765182495


# Tree Visualization

In [64]:
# Open a 10x10 inch, 80 dpi figure before asking the tree to draw itself.
# NOTE(review): presumably visualize() draws onto the current pyplot figure -- confirm in SDT.
plt.figure(figsize=(10, 10), dpi=80)
# visualize() returns a pair, unpacked here as (avg_height, root); `root` is the node
# that the rule-extraction cell walks.
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 6.226415094339623


# Extract Rules

In [47]:
# Print one rule per leaf by walking the tree returned by tree.visualize().
# NOTE(review): this cell's execution count is out of order with its neighbours; it
# needs `root` (visualisation cell) and `sampled` (dataset cell) to exist already.

# Attribute names for the linear condition of a node: a leading 'bias' entry followed,
# for each signal, by one name per retained time step (T0.<signal>, T1.<signal>, ...).
signal_names = dataset.dataset.all_signals
attr_names = ['bias']
for signal_name in signal_names:
    attr_names += [f"T{i}.{signal_name}" for i in range(sampled.shape[-1])]

# print(attr_names)
# Iterative depth-first walk. `stack` holds the nodes on the current path and
# `edge_stack` holds a label for each branch taken (' < 0' for left, ' > 0' for right).
stack = LifoQueue()
edge_stack = LifoQueue()
stack.put(root)
rule_counter = 0
# NOTE(review): presumably clears the `visited` flags used below -- confirm in the node class.
root.reset()
while not stack.empty():
    node = stack.get()
    if node.is_leaf():
        # The leaf itself has just been popped, so `stack.queue` (the underlying list)
        # contains only the nodes still on the path above it.
        print(f"============== Rule {rule_counter} ==============")
        # NOTE(review): conditions are paired with edge_stack.queue[1:], i.e. the first
        # edge label is skipped and zip() stops at the shorter sequence. Confirm this
        # offset is intended (e.g. `root` being a placeholder above the real first split).
        for stack_node, cond in zip(stack.queue, edge_stack.queue[1:]):
            print(repr(stack_node.get_condition(attr_names)) + cond)
            print()
        
        rule_counter += 1
        # Drop the label of the edge that led to this leaf.
        edge_stack.get()
        continue
          
    # Descend left first; the node is pushed back so it is revisited after the subtree.
    if node.left is not None and not node.left.visited:
        stack.put(node)
        stack.put(node.left)
        node.left.visited = True
        edge_stack.put(' < 0')
        continue
        
    # Then descend right, again keeping the node on the stack.
    if node.right is not None and not node.right.visited:
        stack.put(node)
        stack.put(node.right)
        node.right.visited = True
        edge_stack.put(' > 0')
        continue
        
    # Both children handled: leave this node and drop the edge that led into it
    # (the root has no incoming edge).
    if node is not root:
        edge_stack.get()

-0.6530124545097351 * bias + -0.08108553290367126 * T0.gyro_bias_0 + -0.08871910721063614 * T1.gyro_bias_0 + -0.08403283357620239 * T2.gyro_bias_0 + -0.09683618694543839 * T3.gyro_bias_0 + -0.07611164450645447 * T4.gyro_bias_0 + -0.09530968219041824 * T5.gyro_bias_0 + -0.07447057962417603 * T6.gyro_bias_0 + -0.07902420312166214 * T7.gyro_bias_0 + -0.09884173423051834 * T8.gyro_bias_0 + -0.08543957024812698 * T9.gyro_bias_0 + -0.08540912717580795 * T10.gyro_bias_0 + -0.12193066626787186 * T11.gyro_bias_0 + -0.11392191052436829 * T12.gyro_bias_0 + -0.12276637554168701 * T13.gyro_bias_0 + -0.13473999500274658 * T14.gyro_bias_0 + -0.14156514406204224 * T15.gyro_bias_0 + -0.15838274359703064 * T16.gyro_bias_0 + -0.15411284565925598 * T17.gyro_bias_0 + -0.1604440063238144 * T18.gyro_bias_0 + -0.14794662594795227 * T19.gyro_bias_0 + -0.10496390610933304 * T0.gyro_bias_1 + -0.1254509538412094 * T1.gyro_bias_1 + -0.11356838047504425 * T2.gyro_bias_1 + -0.13202457129955292 * T3.gyro_bias_1 + -0.