In [12]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from sklearn.utils import class_weight
from torch.utils.data import DataLoader, TensorDataset

import datasets
from prosenet import Encoder, EncoderClassifier, ProSeNet, PrototypeProjection

In [13]:
# Dataset parameters
num_classes = 5
num_features = 1

# Hardware: fall back to CPU so the notebook still runs without CUDA.
use_cuda = torch.cuda.is_available()
gpus = 1 if use_cuda else 0
device = torch.device("cuda" if use_cuda else "cpu")

# Reproducibility: the Trainers below pass deterministic=True, which controls
# algorithm determinism but does not seed the RNGs. Seed them here, before any
# model is constructed or any DataLoader shuffles.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [14]:
# Encoder (LSTM) hyper-parameters; set only the args that differ from the Encoder defaults.
new_rnn_args = {
    "input_size": num_features,
    "hidden_size": 32,
    "num_layers": 3,
    "dropout": 0.1,  # nn.LSTM applies this between stacked layers
    "bidirectional": True,
    "batch_first": True,  # inputs are (batch, seq_len, features)
}

encoder = Encoder(new_rnn_args)

# Prototype-layer hyper-parameters.
new_proto_args = {
    "K": 30,  # number of prototypes (= in_features of the ProSeNet classifier head)
    # Prototype dimensionality must match the encoder's output size, which is
    # `hidden_size` here (the EncoderClassifier head is Linear(hidden_size, ...)).
    # Derive it instead of repeating the literal so the two cannot drift apart.
    "D": new_rnn_args["hidden_size"],
    # NOTE(review): the remaining keys are presumably the diversity margin and the
    # diversity / clustering / evidence loss weights of ProSeNet — confirm in prosenet.
    "dmin": 2.0,
    "Ld": 0.01,  # alternative value: 0.1
    "Lc": 0.0,
    "Le": 1.0,
}

In [15]:
# Load the MIT-BIH arrhythmia train/test splits (normalize=False) and show a summary.
data_dir = "../dataset/data/"
data = datasets.ArrhythmiaDataset(data_dir, normalize=False)
print(data)

MIT-BIH Arrhythmia Dataset
Num classes: 5
Input shape: (187, 1)
Train, Test counts: 87554, 21892



In [16]:
def make_loader(X, y, shuffle, batch_size=128, num_workers=4):
    """Wrap feature array ``X`` and integer label array ``y`` in a DataLoader.

    Features are converted to float32 and labels to int64 tensors.
    """
    return DataLoader(
        TensorDataset(torch.FloatTensor(X), torch.LongTensor(y)),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


# Shuffle only the training split; keep evaluation order fixed.
train_gen = make_loader(data.X_train, data.y_train, shuffle=True)
test_gen = make_loader(data.X_test, data.y_test, shuffle=False)

# Per-class weight = 1 - (class share of the training labels), so rarer classes
# get larger weights; passed to the models below. `minlength` guarantees one
# weight per class even if the highest class id is absent from y_train.
class_counts = np.bincount(data.y_train, minlength=num_classes)
class_weights = torch.FloatTensor(1 - class_counts / class_counts.sum()).to(device)
class_weights

tensor([0.1723, 0.9746, 0.9339, 0.9927, 0.9265], device='cuda:0')

# First train the `encoder` together with a linear classification head

In [17]:
# Baseline model: the encoder followed by a linear head over `hidden_size` features.
head_in_features = new_rnn_args["hidden_size"]
encoder_classifier = EncoderClassifier(encoder, class_weights, head_in_features, num_classes)
print(encoder_classifier)

EncoderClassifier(
  (encoder): Encoder(
    (rnn): LSTM(1, 32, num_layers=3, batch_first=True, dropout=0.1, bidirectional=True)
  )
  (pred): Linear(in_features=32, out_features=5, bias=True)
  (named_metrics_train): ModuleDict(
    (acc): Accuracy()
    (avg_p): Precision()
    (avg_r): Recall()
  )
  (named_metrics_val): ModuleDict(
    (acc): Accuracy()
    (avg_p): Precision()
    (avg_r): Recall()
  )
)


In [18]:
# Phase 1: fit encoder + linear head; TensorBoard logs go under lightning_logs/ecg_encoder/.
tb_logger = pl.loggers.TensorBoardLogger(save_dir="lightning_logs/", name="ecg_encoder")
trainer_kwargs = dict(max_epochs=100, gpus=gpus, deterministic=True, logger=tb_logger)
trainer = pl.Trainer(**trainer_kwargs)
trainer.fit(encoder_classifier, train_gen, test_gen)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                | Type       | Params
---------------------------------------------------
0 | encoder             | Encoder    | 59.1 K
1 | pred                | Linear     | 165   
2 | named_metrics_train | ModuleDict | 0     
3 | named_metrics_val   | ModuleDict | 0     
---------------------------------------------------
59.3 K    Trainable params
0         Non-trainable params
59.3 K    Total params
0.237     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Now freeze the `encoder` and train only the prototype layer and classifier head

In [19]:
# Freeze every encoder parameter so only the prototype layer and classifier train next.
# This disables gradients only; train/eval mode (and hence LSTM dropout) is unaffected.
encoder.requires_grad_(False)

In [20]:
# Prototype network built on top of the pre-trained encoder.
pnet = ProSeNet(
    encoder=encoder,
    prototypes_args=new_proto_args,
    n_classes=num_classes,
    class_weights=class_weights,
)
pnet

ProSeNet(
  (encoder): Encoder(
    (rnn): LSTM(1, 32, num_layers=3, batch_first=True, dropout=0.1, bidirectional=True)
  )
  (prototypes_layer): Prototypes()
  (classifier): Sequential(
    (0): Linear(in_features=30, out_features=5, bias=True)
  )
  (named_metrics_train): ModuleDict(
    (acc): Accuracy()
    (avg_p): Precision()
    (avg_r): Recall()
  )
  (named_metrics_val): ModuleDict(
    (acc): Accuracy()
    (avg_p): Precision()
    (avg_r): Recall()
  )
)

In [21]:
# Phase 2 trainer. The PrototypeProjection callback receives the training loader
# and periodically reassigns prototypes from projections (see the
# "... assigned new prototypes from projections." output).
# NOTE(review): train_gen shuffles; if the callback reports iteration positions
# as dataset indices, those indices will not line up with data.X_train.
proto_logger = pl.loggers.TensorBoardLogger(
    name="ecg_prosenet", save_dir="lightning_logs/"
)
trainer = pl.Trainer(
    max_epochs=100,
    gpus=gpus,
    deterministic=True,
    gradient_clip_val=5.0,
    logger=proto_logger,  # own run directory, separate from the encoder pre-training logs
    callbacks=[PrototypeProjection(pnet, train_gen)],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [22]:
# Phase 2 fit: with the encoder frozen, only the prototype layer and classifier are trainable.
# NOTE(review): test_gen is passed as the validation loader, so validation metrics are
# computed on the test split and are not an independent held-out estimate.
trainer.fit(pnet, train_gen, test_gen)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name                | Type       | Params
---------------------------------------------------
0 | encoder             | Encoder    | 59.1 K
1 | prototypes_layer    | Prototypes | 960   
2 | classifier          | Sequential | 155   
3 | named_metrics_train | ModuleDict | 0     
4 | named_metrics_val   | ModuleDict | 0     
---------------------------------------------------
1.1 K     Trainable params
59.1 K    Non-trainable params
60.3 K    Total params
0.241     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[82865, 78051, 15729, 54602, 81318, 78051, 25169, 21758, 58402, 81318,
          1414, 81318, 81318, 81318, 81318, 81318, 81318, 54602, 81318, 76835,
         81318, 81318,  1414,  6840, 14178, 14178, 12975, 10148, 81318, 78051]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[65950, 42615, 77331, 71286, 23345, 42615, 62598, 24909, 50159, 23345,
         28716, 23345, 23345, 23345, 23345, 23345, 23345, 85326, 23345, 81865,
         23345, 54803, 79243, 12025, 49615, 49615, 74184,  1801, 23345, 42615]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[18741, 76914, 82485, 35909, 76801, 20493, 39762, 32407, 35988,   148,
          2181, 74650, 54230, 65006, 39128, 44379, 33779, 67544, 59587, 72348,
         72557, 46418, 33207, 47728, 64918, 28105, 51494, 55127, 56077, 74761]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[53112,  5080,  1676, 62148, 34037, 12228,  4321, 19058, 30326, 70061,
           484, 64009, 40630, 64521, 84156, 22708,  9022, 60128, 40694, 13289,
         15030, 29585, 17407,  5853, 61506, 55358, 29188, 47383, 10805, 28088]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[  468, 87306,  8897, 59727,  3267,  6597, 21405, 18379, 53744, 59867,
         85471, 79719, 11344, 63434, 75575, 85280, 59654,  4461, 39531, 19113,
         20742, 61444, 34913, 45360, 25255,  3179,    70, 13899, 20784,  5188]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[59854, 67240, 37751,  3079, 53485, 31853, 20201, 51409, 63328, 20194,
         23810, 13937, 42423, 63797, 66015, 67003, 70265, 59017, 19814, 41398,
         21051, 81415, 35728, 24368,  2554, 58333,  6904, 80249,  5251,   736]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[49879,  6715, 33277,   275, 86154, 79464, 60466, 22201, 74072, 28308,
         54047, 57584, 73105, 72576, 24444, 45885, 65221, 71148, 37607, 12003,
         70030, 85155, 77885, 61571, 33642, 19780, 26146, 30627, 31347, 50230]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[64538, 80075, 71899, 16038, 30644, 52456, 33509, 81485, 34343, 64392,
         24019, 43884, 78338, 72021, 22007, 33749, 30079, 40404, 27476, 17802,
         80635, 77889, 71393, 74890, 56956, 40360, 84097, 31521, 45216, 27037]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[47202, 22454, 59137, 33937, 55891, 57987,  1169, 28445, 72583, 73557,
          2214,  5621, 85116, 75744, 34218, 86647, 79195, 26753, 58347, 60408,
         71185, 24670, 29072, 12455, 71883, 32574, 31711, 54628,   974, 28337]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[55303, 58720, 56604, 78791, 20204, 21604, 31372, 10512,  8979, 69048,
          4105, 70099, 39915, 48351, 22662, 56557, 83919, 79577, 32607, 77193,
          3324, 78961, 54117, 70080, 12072, 15734, 21297, 65458, 87202, 30164]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[ 3276, 48386,    18, 77320,  2008, 68480, 26917, 41882, 52791, 15779,
         11833, 31343,   262, 35329,  8718, 48781, 75642, 72867,  2044, 37686,
         20694, 74814, 43458, 65232, 60486, 16856, 67597, 64629, 25611,  5982]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[44391,  9645, 50094, 82087, 49259, 24932, 49188,  7283, 10042, 60059,
         63317, 75219, 70948, 79283, 66769, 71144, 86970, 45197, 40813, 45126,
          2379, 87236, 53299, 50009,  2896, 40037, 78096, 86390,  2475,  5845]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[29014, 64792, 76458, 28797, 54854, 21905, 70776, 74357, 55783, 86668,
         62863, 42212, 64552,  2330, 79154, 37692, 24537, 62219,  3522, 48713,
         85793,  7893, 81102, 46964, 16915, 76355, 52183, 70385, 34675,   649]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[16605, 14005, 77083, 67463, 64400,  2967, 59897, 77894, 13499, 37170,
          4746, 52901,  4950, 50115,  1516, 24285, 22727, 57883,  9001, 28211,
         23667,  6607, 53988, 51008,   105,  3170, 70323,   543, 68210, 23875]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[40815, 82124, 26761, 50148,  4941, 65868, 62757, 61173, 31793, 50830,
         50815, 67408, 78854,  3940, 84296, 10452, 25267, 57218, 72942, 79101,
         71894, 53070,  5810, 63608, 78592, 63392, 53348, 37705, 27054, 75218]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[51363,  1369, 69186, 26601, 10627, 73199, 64656,  6551, 10998, 79164,
         67280, 82236, 40853, 77214, 86944, 19119, 39908, 42439,   117, 23560,
         75924,  3474, 16209, 55451, 52362, 36051, 17872, 70393, 11013, 61220]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[ 6442, 76899, 79647, 81723, 85184, 26852, 65790, 77583, 67107, 26977,
         72010, 23175, 11880, 47441, 80885, 86566, 27341, 34071,  2961, 26522,
           707, 61825, 27709, 16524, 30939, 25950, 36463, 15628, 33997, 67841]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[ 6930, 81145, 50795, 75480, 11730, 80664,  5799, 25224, 34064,  5073,
         10137, 64646, 69491, 48149, 61685, 84930,  4436, 56958, 31441, 65659,
          6971, 37102, 64626, 12937, 68500, 28764, 82705, 75944, 62644, 16139]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[23166, 43656, 27057, 59451, 26598, 12384, 32085, 65586, 86122, 77428,
          7803, 13085, 64994, 12165, 62759,  8304, 20119, 58598, 70969, 35018,
          6658,  8163, 41077, 80052, 35397, 36717, 25618, 21728, 61213, 76981]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[78407, 87532, 55748, 25718, 11145, 21578, 69877,   346,  6716, 84284,
         54843, 16560, 55319, 24817, 49749, 59613, 57510, 82900, 45089, 13033,
         87333, 40675,  2246,  8467,  3134, 57116, 23399,  7671, 19894, 54924]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[18001, 86787, 69782, 45085, 11796, 20280, 76102, 35316, 26563, 75997,
         61048, 63272,  2381, 70831, 24243, 31959,  5760, 45101, 21671, 78523,
          3577, 63949, 23473,  9801, 26492,  7797, 69356, 23546, 47861, 14107]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[50468, 79146, 70685, 71290, 56623, 75590, 76397,  5530, 43667,  4575,
         53470, 79030, 44063, 47632,  5619, 23228, 55634, 43838, 14103, 82254,
         39545, 10747, 32141,  6888, 64308, 46482, 40774, 66984,  1905, 48285]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[58404, 10610,  7188, 63676, 19378, 49465, 11708, 72488, 32526, 30125,
         15648, 59262, 78801, 46602, 14866, 58917, 23968, 35879, 85239, 10752,
         59345, 72638, 82561,  4511, 53000,  7185, 82593,  6433, 37264, 64467]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[27203, 21895, 81336,  5658, 65976, 15752, 36082, 61093, 49540, 23681,
          5151, 55042, 85659, 24727, 58492, 51576, 36812, 17036, 17389, 55629,
         77837, 13240, 16716,  5086, 27572, 56710, 81434, 82857, 83045, 50946]],
       device='cuda:0')
... assigned new prototypes from projections.


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

new protos tensor([[86532,  8370, 68886, 37615, 45996, 41031, 19912, 68124, 46040, 36394,
         54249, 67008, 34436, 12771, 26819, 11464, 61069, 57662, 54883, 14655,
         40396, 21156, 54925, 20992, 65537, 46327, 67828, 38997, 68195, 70891]],
       device='cuda:0')
... assigned new prototypes from projections.


# Diagnostics

In [None]:
# Encoding layer(s) output for the first training beat.
# Run on whichever device the encoder currently lives on (training and later
# cells may move it), without autograd, and in eval mode so LSTM dropout
# does not perturb the result.
encoder_device = next(pnet.encoder.parameters()).device
pnet.encoder.eval()
with torch.no_grad():
    first_encoding = pnet.encoder(
        torch.FloatTensor(data.X_train[np.newaxis, 0]).to(encoder_device)
    )
first_encoding

In [None]:
# Prototype matrix as a NumPy array (singleton axes dropped), with its value range and shape.
proto_tensor = pnet.prototypes_layer.prototypes.detach().cpu()
protos = np.squeeze(proto_tensor.numpy())
print(protos.min(), protos.max())
print(protos.shape)

In [None]:
# Heat map of the prototype matrix: one row per prototype, one column per latent dimension.
fig, ax = plt.subplots(figsize=(10, 10))
cax = ax.matshow(protos, cmap=plt.cm.RdBu)
ax.set(title="Prototype vectors", xlabel="Latent dimension", ylabel="Prototype index")
fig.colorbar(cax, ax=ax);

In [None]:
# Classifier head weights. The head is Linear(K prototypes -> num_classes), whose
# weight has shape (num_classes, K): rows are classes, columns are prototypes.
pred_weights = np.squeeze(pnet.classifier[0].weight.detach().cpu().numpy())
print(pred_weights.min(), pred_weights.max())

fig, ax = plt.subplots(figsize=(10, 10))
cax = ax.matshow(pred_weights, cmap=plt.cm.RdBu)
ax.set(title="Classifier weights", xlabel="Prototype index", ylabel="Class")
fig.colorbar(cax, ax=ax);

In [None]:
# Encode every training beat *in dataset order*, so row i of X_encoded
# corresponds to data.X_train[i]. `train_gen` was built with shuffle=True, so
# iterating it here would make the argmin indices computed below point at the
# wrong beats when used to index data.X_train.
ordered_train_loader = DataLoader(
    TensorDataset(torch.FloatTensor(data.X_train)),
    batch_size=train_gen.batch_size,
    shuffle=False,
)
encoder.to(device).eval()  # eval(): disable LSTM dropout so encodings are deterministic
with torch.no_grad():
    X_encoded = torch.cat(
        [encoder(X.to(device)).cpu() for (X,) in ordered_train_loader], dim=0
    )
# Two singleton axes (same layout as before) so the encodings broadcast
# against the prototype matrix in the distance computation.
X_encoded = X_encoded.unsqueeze(-2).unsqueeze(-2)
X_encoded

In [None]:
# Euclidean (unsquared, despite the name d2) distance from every encoded
# training beat to every prototype. Prototypes are read from the model itself
# rather than from the `protos` NumPy variable, which other cells rebind.
proto_matrix = pnet.prototypes_layer.prototypes.detach().cpu().squeeze()
proto_matrix = proto_matrix.to(X_encoded.dtype)
d2 = torch.norm(X_encoded - proto_matrix, p=2, dim=-1)

# For each prototype, the index of its nearest training beat (argmin over the sample axis).
idxs = d2.argmin(0).numpy()
idxs

In [None]:
# Raw training beats (and their labels) at the nearest-neighbour indices found above.
proto_rows = np.squeeze(idxs)
matched_protos, matched_protos_y = data.X_train[proto_rows], data.y_train[proto_rows]
matched_protos.shape

In [None]:
# Plot the training beat each prototype was matched to, titled with its class.
# Iterate over however many prototypes were matched and use the data's own
# sequence length instead of the hard-coded 30 / 187.
time_axis = np.arange(matched_protos.shape[1])
for proto_idx, (beat, label) in enumerate(zip(matched_protos, matched_protos_y)):
    fig, ax = plt.subplots()
    ax.plot(time_axis, beat[:, 0])
    ax.set(title=f"Prototype {proto_idx} (class {label})", xlabel="Sample index", ylabel="Value")
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per prototype

In [None]:
# Predictions for the last three test beats (compare with data.y_test[-3:]).
# NOTE(review): `predict(..., batch_size=...)` looks like a Keras-era call and this cell
# has never been executed; LightningModule does not, to my knowledge, provide `predict`,
# so confirm ProSeNet defines it (otherwise call `pnet(x)` or use `trainer.predict`).
pnet.predict(data.X_test[-3:, :], batch_size=1)

In [None]:
# Ground-truth labels of the last three test beats.
true_labels_last3 = data.y_test[-3:]
true_labels_last3

In [None]:
# Raw signal of the second-to-last training beat, titled with its label.
# The x-axis length comes from the data instead of a hard-coded 187.
sample_beat = data.X_train[-2]
fig, ax = plt.subplots()
ax.plot(np.arange(sample_beat.shape[0]), sample_beat)
ax.set(
    title=f"data.X_train[-2] (class {data.y_train[-2]})",
    xlabel="Sample index",
    ylabel="Value",
);

In [None]:
# Re-display the per-class weights (1 - class share of the training labels).
class_weights

In [None]:
# Classifier head weight matrix, shape (num_classes, K). nn.Sequential has no Keras-style
# `.weights`; index the Linear layer and detach before converting to NumPy.
pnet.classifier[0].weight.detach().cpu().numpy()

In [None]:
# Weights of the encoder-only model's output layer (its `pred` Linear head);
# there is no Keras-style `.layers` list on this module.
encoder_classifier.pred.weight

In [None]:
# Prototype matrix as NumPy. `.numpy()` cannot be called on a tensor that requires grad,
# so detach (and move to CPU) first; `.weights[0]` was a Keras-era accessor.
protos = np.squeeze(pnet.prototypes_layer.prototypes.detach().cpu().numpy())
protos

In [None]:
# Quick heat map of the prototype matrix (rows: prototypes, columns: latent dims).
plt.matshow(protos, cmap="RdBu")

In [None]:
# Encodings of the last two training beats. torch.nn.Module has no Keras-style
# `.predict`, so call the module directly: on its current device, in eval mode
# (no dropout), and without autograd.
encoder_device = next(pnet.encoder.parameters()).device
pnet.encoder.eval()
with torch.no_grad():
    last_two_encodings = pnet.encoder(
        torch.FloatTensor(data.X_train[-2:, :]).to(encoder_device)
    )
last_two_encodings

In [None]:
# Encodings of the first two training beats. torch.nn.Module has no Keras-style
# `.predict`, so call the module directly: on its current device, in eval mode
# (no dropout), and without autograd.
encoder_device = next(pnet.encoder.parameters()).device
pnet.encoder.eval()
with torch.no_grad():
    first_two_encodings = pnet.encoder(
        torch.FloatTensor(data.X_train[:2, :]).to(encoder_device)
    )
first_two_encodings

In [None]:
# Value of the prototype layer's (presumably) diversity regularizer.
# NOTE(review): this calls a private helper; confirm `_diversity_term` exists on the
# PyTorch Prototypes layer and takes no arguments — this cell has never been executed.
pnet.prototypes_layer._diversity_term()

In [None]:
# NOTE(review): `.losses` is a Keras `Model` attribute; torch.nn.Module does not provide it,
# so this raises AttributeError unless ProSeNet defines `losses` — confirm or remove.
pnet.losses