In [None]:
import matplotlib.pyplot as plt
import numpy as np

import torch.nn as nn
import torch
import types

from Models.NoiseCNN import CNN

In [None]:
# Rebuild the architecture and load the trained weights into it.
MODEL_PATH = "TrainedModels/CNN_AlexNet.pt"

model = CNN()
# map_location="cpu" lets a checkpoint that was saved from a GPU run be deserialized on a
# CPU-only machine; load_state_dict then copies the tensors into the freshly built model.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

In [None]:
def get_all_layers(model_children, layer_type):
    """Lazily walk ``model_children`` looking for modules of the requested classes.

    Args:
        model_children: iterable of modules, e.g. ``list(model.children())`` or an
            ``nn.Sequential``.
        layer_type: re-iterable collection (e.g. a tuple) of classes. A module is
            yielded once for every entry that is *exactly* ``type(module)``, so
            subclasses do not match.

    Yields:
        Each matching module. For every ``nn.Sequential`` encountered, it also yields
        a nested, not-yet-started generator over that container's children rather
        than the members themselves. Pass the result through
        ``fill_list_from_generators`` to flatten it.
    """
    for module in model_children:
        exact_type = type(module)
        for wanted in layer_type:
            if exact_type == wanted:
                yield module
        if exact_type == nn.Sequential:
            # Hand back a sub-generator instead of recursing inline; the caller flattens it.
            yield get_all_layers(module, layer_type)

def fill_list_from_generators(gen, list):
    """Depth-first flatten ``gen`` into ``list``, extending it in place.

    Items that are generator objects are expanded recursively at their position.
    Every other item, including plain lists and tuples, is appended as-is. Existing
    contents of ``list`` are kept, and the function returns ``None``.

    Note: the parameter name shadows the builtin ``list`` inside this function.
    """
    for element in gen:
        if not isinstance(element, types.GeneratorType):
            list.append(element)
            continue
        fill_list_from_generators(element, list)

In [None]:
# Gather every Conv1d layer of the model (descending into nn.Sequential containers),
# plus each one's weight Parameter, in the order they are encountered.
model_children = [*model.children()]
conv_layers = []
fill_list_from_generators(
    get_all_layers(model_children, (nn.Conv1d,)),
    conv_layers,
)
model_weights = [conv_module.weight for conv_module in conv_layers]

In [None]:
# Summarise each Conv1d layer together with the shape of its weight tensor.
for layer_idx, (kernel, layer) in enumerate(zip(model_weights, conv_layers)):
    print(f"{layer_idx} CONV: {layer} ====> SHAPE: {kernel.shape}")

## Convolutional Filter Visualisation

We can plot the convolution filters of the trained network. Each filter's shape should resemble the pattern that the filter detects in its input (and whose presence it passes on to the next layer).

`LAYER_NUM` selects which convolutional layer to view (an index into the list of `Conv1d` layers printed above).

`IN_CHANNEL_NUM` selects the input channel whose kernels are shown for each filter.

In [None]:
# visualize conv layer filters
LAYER_NUM = 2       # index into conv_layers / model_weights (negative indices allowed)
IN_CHANNEL_NUM = 0  # input channel whose kernel slice is plotted for every filter
N_COLS = 8          # subplots per row

assert -len(model_weights) <= LAYER_NUM < len(model_weights), (
    f"LAYER_NUM must index one of the {len(model_weights)} conv layers"
)
# Conv1d weights are laid out as (out_channels, in_channels / groups, kernel_size).
layer_weights = model_weights[LAYER_NUM].detach().cpu()
num_filters, num_in_channels = layer_weights.shape[0], layer_weights.shape[1]
assert -num_in_channels <= IN_CHANNEL_NUM < num_in_channels, (
    f"IN_CHANNEL_NUM must be < {num_in_channels} for layer {LAYER_NUM}"
)

# Ceiling division: just enough rows of N_COLS to hold every filter.
n_rows = (num_filters + N_COLS - 1) // N_COLS
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(20, 2.5 * n_rows), squeeze=False)
fig.suptitle(f"Conv layer {LAYER_NUM} filters, input channel {IN_CHANNEL_NUM}")
for filter_idx, ax in enumerate(axes.flat):
    if filter_idx < num_filters:
        ax.plot(layer_weights[filter_idx, IN_CHANNEL_NUM, :].numpy())
        ax.set_title(f"filter {filter_idx}", fontsize=8)
    else:
        ax.axis("off")  # hide the unused cells of the last row
fig.tight_layout()
plt.show()

### Visualise intermediate signals in the CNN

In [None]:
# Onehot encoding
from torch.utils.data import Dataset, DataLoader

def generate_onehot(c):
    """Return a fresh length-4 integer one-hot vector for class label ``c``.

    Position 0 is "N", 1 is "O", 2 is "A" and 3 is "~". Any other label
    yields ``None``.
    """
    for position, label in enumerate(("N", "O", "A", "~")):
        if c == label:
            onehot = np.zeros(4, dtype=int)
            onehot[position] = 1
            return onehot
    return None

def generate_index(c):
    """Map class label ``c`` to its integer index.

    The mapping is "N" -> 0, "O" -> 1, "A" -> 2 and "~" -> 3. Any other label
    yields ``None``.
    """
    for position, label in enumerate(("N", "O", "A", "~")):
        if c == label:
            return position
    return None

# Add an integer target column by mapping each label in "class" through generate_index.
# Labels other than "N", "O", "A", "~" produce a missing value. This mutates `dataset` in place.
# NOTE(review): `dataset` is not created in any cell above -- it must already exist in the kernel.
# (One-hot alternative, unused: dataset["onehot"] = dataset["class"].map(generate_onehot))
dataset["class_index"] = dataset["class"].map(generate_index)

class Dataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over a pandas frame of labelled signals.

    Each row must provide a ``"data"`` entry, of which only element ``[0]`` is
    returned as the sample, and a ``"class_index"`` label.

    NOTE: this class shadows the ``Dataset`` name imported from ``torch.utils.data``
    in the same cell.
    """

    def __init__(self, dataset):
        """Keep a reference to ``dataset`` (no copy is made)."""
        self.dataset = dataset

    def __len__(self):
        """Number of rows in the wrapped frame."""
        return len(self.dataset.index)

    def __getitem__(self, index):
        """Return ``(X, y)`` for the row at integer position ``index``."""
        sample = self.dataset.iloc[index]
        return sample["data"][0], sample["class_index"]

# Stratified 85/15 split of the labelled rows, with a fixed seed so re-runs give the same split.
# NOTE(review): `train_test_split` (presumably sklearn.model_selection) and `dataset` are not
# imported/defined in any cell above -- this cell fails on a fresh kernel until they are.
# NOTE(review): the split is drawn here, independently of training. Unless it reproduces the split
# used to train CNN_AlexNet.pt, "test" rows may overlap the training data -- confirm before
# trusting any metrics computed from test_dataloader.
SPLIT_SEED = 42
train_dataset, test_dataset = train_test_split(
    dataset, test_size=0.15, stratify=dataset["class"], random_state=SPLIT_SEED
)

torch_dataset_train = Dataset(train_dataset)
torch_dataset_test = Dataset(test_dataset)

# Pinned host memory only helps when batches are later copied to a CUDA device.
PIN_MEMORY = torch.cuda.is_available()
train_dataloader = DataLoader(torch_dataset_train, batch_size=32, shuffle=True, pin_memory=PIN_MEMORY)
# Evaluation does not need shuffling; a fixed order keeps predictions reproducible.
test_dataloader = DataLoader(torch_dataset_test, batch_size=32, shuffle=False, pin_memory=PIN_MEMORY)

In [None]:
# Run the trained model over the test split and collect predicted vs. true class indices.
model.eval()
# Send inputs to wherever the model's parameters live. No `device` variable is defined in any
# cell above, and the model is never moved, so use the model's own device.
device = next(model.parameters()).device

true_labels = []
predictions = []

with torch.no_grad():
    for signals, labels in test_dataloader:
        # Insert a channel axis at dim 1 for the Conv1d stack
        # (presumably (batch, length) -> (batch, 1, length) -- confirm against the data).
        signals = signals.to(device).unsqueeze(1).float()
        true_labels.append(labels.numpy())

        logits = model(signals).cpu().numpy()
        predictions.append(logits.argmax(axis=-1))

assert predictions, "test_dataloader yielded no batches"
predictions = np.concatenate(predictions)
true_labels = np.concatenate(true_labels)