In [None]:
## Train the optimized WavPool (and the CNN/MLP baselines) from their tuned parameters on a different dataset split, then evaluate them with confusion matrices.

from WavPool.training.train_model import TrainingLoop
from WavPool.models.wavpool import WavPool
from WavPool.models.vanillaCNN import VanillaCNN
from WavPool.models.vanillaMLP import VanillaMLP


from WavPool.data_generators.cifar_generator import CIFARGenerator
from WavPool.data_generators.mnist_generator import MNISTGenerator
from WavPool.data_generators.fashion_mnist_generator import FashionMNISTGenerator

import json 
import os 
import matplotlib.pyplot as plt 
import numpy as np

import pandas as pd
import torch 
from sklearn.metrics import confusion_matrix as confusion 

In [None]:
def get_params(model_name, dataset_name, optimize_path="../results/optimize_params"):
    """Load the optimized experiment config for a model/dataset pair.

    Searches ``optimize_path`` for a directory whose name contains both
    ``model_name`` and ``dataset_name`` (plain substring match) and returns its
    ``experiment_config.json`` parsed as a dict.

    Parameters
    ----------
    model_name : str
        Substring that must appear in the directory name.
    dataset_name : str
        Substring that must appear in the directory name.
    optimize_path : str, default "../results/optimize_params"
        Directory holding one sub-directory per optimization run.

    Returns
    -------
    dict
        The parsed JSON config.

    Raises
    ------
    FileNotFoundError
        If no directory under ``optimize_path`` matches both substrings.
    """
    # os.listdir order is arbitrary; sort so the same directory is chosen on
    # every run when several match.
    matches = sorted(
        entry for entry in os.listdir(optimize_path)
        if model_name in entry and dataset_name in entry
    )
    if not matches:
        raise FileNotFoundError(
            f"No directory in {optimize_path!r} contains both "
            f"{model_name!r} and {dataset_name!r}"
        )
    config_path = os.path.join(optimize_path, matches[0], "experiment_config.json")
    with open(config_path, 'r') as f:
        return json.load(f)

def train_model(model, model_name, dataset, dataset_name, epochs=120):
    """Train ``model`` on ``dataset`` with its optimized hyper-parameters and save the run.

    Parameters
    ----------
    model : type
        Model class handed to ``TrainingLoop`` as ``model_class``.
    model_name : str
        Passed to ``get_params`` to locate the optimized config; its lower-cased
        form also names the output directory.
    dataset : type
        Data-generator class handed to ``TrainingLoop`` as ``data_class``.
    dataset_name : str
        Passed to ``get_params`` to locate the optimized config; also names the
        output directory.
    epochs : int, default 120
        Number of training epochs passed to ``TrainingLoop``.

    Returns
    -------
    tuple
        ``(loop.model, loop.data_loader['test'])`` after training.
    """
    exp_params = get_params(model_name, dataset_name)

    loop = TrainingLoop(
        model_class=model,
        model_params=exp_params['model_kwargs'],
        data_class=dataset,
        data_params=exp_params['data_kwargs'],
        optimizer_class=torch.optim.SGD,
        optimizer_config=exp_params['optimizer_kwargs'],
        loss=torch.nn.CrossEntropyLoss,
        epochs=epochs,
    )
    loop()

    # Include the model name in the output directory. Previously every model
    # wrote to "wavpool_inference_<dataset>", so CNN/MLP runs overwrote the
    # WavPool results. For model_name="WavPool" the path is unchanged.
    save_path = f"../results/{model_name.lower()}_inference_{dataset_name}/"
    loop.save(save_path=save_path)

    return loop.model, loop.data_loader['test']

def predict(model, data_loader):
    """Run ``model`` over every batch of ``data_loader`` and collect the outputs.

    The model is switched to evaluation mode, so layers such as dropout are
    inactive, and gradients are disabled for the pass. The model's previous
    train/eval mode is restored afterwards, even if a batch raises.

    Parameters
    ----------
    model : torch.nn.Module
        Model to evaluate.
    data_loader : iterable
        Yields ``(input, label)`` pairs.

    Returns
    -------
    tuple of torch.Tensor
        ``(predictions, labels)``: model outputs and labels concatenated along
        dim 0 in loader order. Labels keep the dtype the loader yields. Both are
        empty tensors if the loader yields no batches.
    """
    batch_predictions = []
    batch_labels = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data_input, label in data_loader:
                batch_predictions.append(model(data_input))
                batch_labels.append(label)
    finally:
        model.train(was_training)

    if not batch_predictions:
        return torch.tensor([]), torch.tensor([])
    # Concatenate once at the end rather than re-copying a growing tensor per batch.
    return torch.cat(batch_predictions), torch.cat(batch_labels)

def confusion_matrix(predictions, labels, log=True):
    """Confusion matrix of arg-max predictions against true labels.

    Parameters
    ----------
    predictions : torch.Tensor
        Per-class scores indexed as ``(n_samples, n_classes)``. The predicted
        class is the arg-max over dim 1.
    labels : torch.Tensor or array-like
        True class index for each sample.
    log : bool, default True
        If True, return the natural log of each count, with empty cells set to 0.
        NOTE: this makes a count of 1 and a count of 0 both appear as 0.

    Returns
    -------
    numpy.ndarray
        ``(n_classes, n_classes)`` matrix from sklearn's ``confusion_matrix``
        (rows = true class, columns = predicted class). Integer counts if
        ``log`` is False, float log-counts otherwise.
    """
    class_ids = list(range(predictions.shape[1]))
    _, predicted_class = torch.max(predictions, 1)
    # Convert explicitly to CPU numpy so that GPU tensors, or tensors still
    # attached to an autograd graph, are handled.
    true_np = torch.as_tensor(labels).detach().cpu().numpy().ravel()
    pred_np = predicted_class.detach().cpu().numpy().ravel()
    counts = confusion(true_np, pred_np, labels=class_ids)
    if not log:
        return counts
    # Take the log only where counts > 0 and leave 0 elsewhere. This gives the
    # same result as log-then-replace-(-inf) without numpy's divide-by-zero warning.
    return np.log(counts, out=np.zeros(counts.shape, dtype=float), where=counts > 0)

In [None]:
# Train WavPool with its optimized parameters on each dataset, then predict on
# that dataset's test loader. `test_data` is rebound each time and ends up as the
# CIFAR loader.
# The leading underscore in "_MNIST" presumably stops get_params' substring match
# from also hitting FashionMNIST directories. TODO: confirm against the directory names.
mnist_model, test_data = train_model(
    model=WavPool, model_name="WavPool", dataset=MNISTGenerator, dataset_name="_MNIST"
)
mnist_pred, mnist_label = predict(model=mnist_model, data_loader=test_data)

fmnist_model, test_data = train_model(
    model=WavPool, model_name="WavPool", dataset=FashionMNISTGenerator, dataset_name="FashionMNIST"
)
fmnist_pred, fmnist_label = predict(model=fmnist_model, data_loader=test_data)

cifarmodel, test_data = train_model(
    model=WavPool, model_name="WavPool", dataset=CIFARGenerator, dataset_name="CIFAR"
)
cifar_pred, cifar_label = predict(model=cifarmodel, data_loader=test_data)

In [None]:
# WavPool: confusion matrices (log counts) on the three test sets.
c_mnist = confusion_matrix(mnist_pred, mnist_label)
c_fmnist = confusion_matrix(fmnist_pred, fmnist_label)
c_cifar = confusion_matrix(cifar_pred, cifar_label)

# (matrix, panel title, tick labels, tick-label rotation) for each panel.
panels = [
    (c_mnist, "MNIST", [str(i) for i in range(10)], 0),
    (c_fmnist, "Fashion MNIST",
     ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], 45),
    (c_cifar, "CIFAR-10",
     ['Plane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck'], 45),
]

# Share one colour scale across all panels so the single colourbar is valid
# for every panel. With independent autoscaling it only described CIFAR.
vmin = min(panel[0].min() for panel in panels)
vmax = max(panel[0].max() for panel in panels)

# Constrained layout reserves room for the colourbar. tight_layout cannot
# handle a manually placed colourbar axes and warns.
fig, subplots = plt.subplots(1, 3, figsize=(16, 6), layout="constrained")
for ax, (matrix, title, class_names, rotation) in zip(subplots, panels):
    scale = ax.imshow(matrix, vmin=vmin, vmax=vmax)
    ticks = list(range(len(class_names)))
    ax.set_xticks(ticks=ticks, labels=class_names, rotation=rotation)
    ax.set_yticks(ticks=ticks, labels=class_names, rotation=rotation)
    ax.set_title(title)

fig.suptitle("WavPool")
fig.supxlabel('Predicted')
fig.supylabel('True')
fig.colorbar(scale, ax=subplots, label="log(count)")

plt.show()

In [None]:
# Repeat the train-then-predict procedure with the VanillaCNN baseline.
# Results go into *_cnn names so the WavPool predictions above are kept.
mnist_model, test_data = train_model(
    model=VanillaCNN, model_name="CNN", dataset=MNISTGenerator, dataset_name="_MNIST"
)
mnist_pred_cnn, mnist_label_cnn = predict(model=mnist_model, data_loader=test_data)

fmnist_model, test_data = train_model(
    model=VanillaCNN, model_name="CNN", dataset=FashionMNISTGenerator, dataset_name="FashionMNIST"
)
fmnist_pred_cnn, fmnist_label_cnn = predict(model=fmnist_model, data_loader=test_data)

cifarmodel, test_data = train_model(
    model=VanillaCNN, model_name="CNN", dataset=CIFARGenerator, dataset_name="CIFAR"
)
cifar_pred_cnn, cifar_label_cnn = predict(model=cifarmodel, data_loader=test_data)


In [None]:
# VanillaCNN: confusion matrices (log counts) on the three test sets.
c_mnist = confusion_matrix(mnist_pred_cnn, mnist_label_cnn)
c_fmnist = confusion_matrix(fmnist_pred_cnn, fmnist_label_cnn)
c_cifar = confusion_matrix(cifar_pred_cnn, cifar_label_cnn)

# (matrix, panel title, tick labels, tick-label rotation) for each panel.
panels = [
    (c_mnist, "MNIST", [str(i) for i in range(10)], 0),
    (c_fmnist, "Fashion MNIST",
     ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], 45),
    (c_cifar, "CIFAR-10",
     ['Plane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck'], 45),
]

# Share one colour scale across all panels so the single colourbar is valid
# for every panel. With independent autoscaling it only described CIFAR.
vmin = min(panel[0].min() for panel in panels)
vmax = max(panel[0].max() for panel in panels)

# Constrained layout reserves room for the colourbar. tight_layout cannot
# handle a manually placed colourbar axes and warns.
fig, subplots = plt.subplots(1, 3, figsize=(16, 6), layout="constrained")
for ax, (matrix, title, class_names, rotation) in zip(subplots, panels):
    scale = ax.imshow(matrix, vmin=vmin, vmax=vmax)
    ticks = list(range(len(class_names)))
    ax.set_xticks(ticks=ticks, labels=class_names, rotation=rotation)
    ax.set_yticks(ticks=ticks, labels=class_names, rotation=rotation)
    ax.set_title(title)

fig.suptitle("VanillaCNN")
fig.supxlabel('Predicted')
fig.supylabel('True')
fig.colorbar(scale, ax=subplots, label="log(count)")

plt.show()

In [None]:
# Repeat the train-then-predict procedure with the VanillaMLP baseline.
# Results go into *_mlp names so the earlier predictions are kept.
mnist_model, test_data = train_model(
    model=VanillaMLP, model_name="VanillaMLP", dataset=MNISTGenerator, dataset_name="_MNIST"
)
mnist_pred_mlp, mnist_label_mlp = predict(model=mnist_model, data_loader=test_data)

fmnist_model, test_data = train_model(
    model=VanillaMLP, model_name="VanillaMLP", dataset=FashionMNISTGenerator, dataset_name="FashionMNIST"
)
fmnist_pred_mlp, fmnist_label_mlp = predict(model=fmnist_model, data_loader=test_data)

cifarmodel, test_data = train_model(
    model=VanillaMLP, model_name="VanillaMLP", dataset=CIFARGenerator, dataset_name="CIFAR"
)
cifar_pred_mlp, cifar_label_mlp = predict(model=cifarmodel, data_loader=test_data)


In [None]:
# VanillaMLP: confusion matrices (log counts) on the three test sets.
c_mnist = confusion_matrix(mnist_pred_mlp, mnist_label_mlp)
c_fmnist = confusion_matrix(fmnist_pred_mlp, fmnist_label_mlp)
c_cifar = confusion_matrix(cifar_pred_mlp, cifar_label_mlp)

# (matrix, panel title, tick labels, tick-label rotation) for each panel.
panels = [
    (c_mnist, "MNIST", [str(i) for i in range(10)], 0),
    (c_fmnist, "Fashion MNIST",
     ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], 45),
    (c_cifar, "CIFAR-10",
     ['Plane', 'Car', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck'], 45),
]

# Share one colour scale across all panels so the single colourbar is valid
# for every panel. With independent autoscaling it only described CIFAR.
vmin = min(panel[0].min() for panel in panels)
vmax = max(panel[0].max() for panel in panels)

# Constrained layout reserves room for the colourbar. tight_layout cannot
# handle a manually placed colourbar axes and warns.
fig, subplots = plt.subplots(1, 3, figsize=(16, 6), layout="constrained")
for ax, (matrix, title, class_names, rotation) in zip(subplots, panels):
    scale = ax.imshow(matrix, vmin=vmin, vmax=vmax)
    ticks = list(range(len(class_names)))
    ax.set_xticks(ticks=ticks, labels=class_names, rotation=rotation)
    ax.set_yticks(ticks=ticks, labels=class_names, rotation=rotation)
    ax.set_title(title)

fig.suptitle("VanillaMLP")
fig.supxlabel('Predicted')
fig.supylabel('True')
fig.colorbar(scale, ax=subplots, label="log(count)")

plt.show()
