# CS 8770 Project 1

## Part 1

### Load libraries

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd import Function
import torch.optim as optim

### Model definitions

In [None]:
class MLP(nn.Module):
    """Fully connected classifier applied to flattened inputs.

    Args:
        H: list of hidden-layer widths, e.g. ``[512, 512]``. The list is
            copied, never modified, so it can be reused for several models.
        phi: non-linearity module inserted after every hidden layer. The same
            instance is reused after each layer, which is fine for
            parameter-free activations such as ``nn.ReLU``.
        n_classes: number of output logits.
        in_features: size of one flattened input sample (28*28 for MNIST).
    """

    def __init__(self, H, phi=nn.ReLU(), n_classes=10, in_features=28 * 28):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential()
        # Build the width list on a copy: inserting into H directly mutated the
        # caller's list, so a second MLP(H) got an extra 784-wide layer.
        dims = [in_features, *H]
        for d_in, d_out in zip(dims, dims[1:]):  # hidden layers
            self.layers.append(nn.Linear(d_in, d_out))
            self.layers.append(phi)
        self.layers.append(nn.Linear(dims[-1], n_classes))  # output layer

    def forward(self, x):
        """Flatten ``x`` per sample and return unnormalised class scores."""
        x = self.flatten(x)
        return self.layers(x)

In [None]:
# Single RBF Neuron
class RBFNeuron(nn.Module):
    """One Gaussian radial-basis unit: ``exp(-0.5 * (||x - mu|| / sig) ** 2)``.

    Args:
        mu: centre vector, compared against every row of the input.
        sig: width of the Gaussian.
    """

    def __init__(self, mu, sig):
        super(RBFNeuron, self).__init__()
        self.mu = nn.Parameter(mu)
        self.sig = nn.Parameter(sig)

    # Implement ``forward`` rather than overriding ``__call__``: nn.Module's
    # ``__call__`` dispatches to ``forward`` and also runs registered hooks,
    # which an overridden ``__call__`` silently bypassed.
    def forward(self, x):
        """Return one float32 activation per row of ``x``.

        ``x`` is assumed to be (batch, dim) with dim == len(mu) — the norm is
        taken over dim 1.
        """
        dist = torch.linalg.norm(x - self.mu, dim=1)
        activation = torch.exp(-0.5 * (dist / self.sig) ** 2).float()
        # The output is detached from the autograd graph, so mu and sig
        # receive no gradient and keep their initial values during training
        # (behaviour kept from the original; drop .detach() to learn them).
        return activation.detach()

# Layer of RBF Neurons
class RBFLayer(nn.Module):
    """A layer of ``nout`` independent RBFNeuron units.

    Args:
        nin: input dimension. Not used by this implementation — the input
            size is implied by the rows of ``mus``.
        nout: number of neurons; neuron ``i`` uses ``mus[i]`` and ``sigs[i]``.
        mus: per-neuron centre vectors (indexed by row).
        sigs: per-neuron widths.
    """

    def __init__(self, nin, nout, mus, sigs):
        super(RBFLayer, self).__init__()
        # NOTE(review): these two Parameters are never read in forward(), so
        # they get no gradient; they are also distinct Parameter objects from
        # the neurons' own, so parameters() yields both. Kept so existing
        # state_dict keys remain valid.
        self.mus = nn.Parameter(mus)
        self.sigs = nn.Parameter(sigs)
        self.neurons = nn.ModuleList([RBFNeuron(mus[i], sigs[i]) for i in range(nout)])

    # ``forward`` (not ``__call__``) so nn.Module hooks run as expected.
    def forward(self, x):
        """Stack each neuron's per-sample activation into a (batch, nout) tensor."""
        return torch.stack([neuron(x) for neuron in self.neurons], dim=1)

# Full RBF Network
class RBFNet(nn.Module):
    """RBF network: unit-normalised input -> RBFLayer -> linear classifier.

    Args:
        mus: centres for the K basis functions, one per row (presumably
            flattened 28x28 images); each row is scaled to unit L2 norm.
        sigs: K widths, one per basis function.
        n_classes: number of output logits.
    """

    def __init__(self, mus, sigs, n_classes=10):
        super(RBFNet, self).__init__()
        self.K = len(mus)  # number of RBFs
        # F.normalize divides each row by max(||row||, eps): identical to the
        # plain division for non-zero rows, but an all-zero centre maps to
        # zeros instead of NaN (0/0).
        mus = F.normalize(mus, p=2.0, dim=1)  # unit norm means
        self.mus = nn.Parameter(mus)
        self.sigs = nn.Parameter(sigs)
        self.flatten = nn.Flatten()
        self.layers = nn.Sequential(
            RBFLayer(28 * 28, self.K, self.mus, self.sigs),
            nn.Linear(self.K, n_classes),
        )

    def forward(self, x):
        """Flatten and unit-normalise each sample, then return class scores."""
        x = self.flatten(x)
        # Same eps-guarded normalisation: a blank (all-zero) image would
        # otherwise become NaN and turn the whole batch's loss into NaN.
        x = F.normalize(x, p=2.0, dim=1)  # unit norm x
        return self.layers(x)

In [None]:
class CNN(nn.Module):
    """One convolution block followed by a small MLP head.

    Args:
        kernel_size: side length of the square conv kernels; padding is
            ``kernel_size // 2``. (The original ignored this argument and
            always used 28x28 kernels with padding 14 — pass
            ``kernel_size=28`` to reproduce that model.)
        pool: max-pool window and stride.
        dropout: dropout probability after the conv block and hidden layer.
        n_classes: number of output logits.
        n_filters: number of convolution output channels.
        in_size: side length of the square, single-channel input image.
    """

    def __init__(self, kernel_size=3, pool=2, dropout=0.2, n_classes=10,
                 n_filters=10, in_size=28):
        super(CNN, self).__init__()
        padding = kernel_size // 2
        # Spatial size after the stride-1 conv, then after the max-pool
        # (window = stride = pool, floor mode).
        conv_out = in_size + 2 * padding - kernel_size + 1
        pooled = conv_out // pool
        self.flatten = nn.Flatten()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=n_filters,
                      kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(pool),
            nn.Dropout(dropout),
        )
        self.mlp = nn.Sequential(
            # Derived from the arguments; a hard-coded 14*14*10 only matched
            # particular kernel/pool choices.
            nn.Linear(pooled * pooled * n_filters, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):
        """Return class scores for a batch of single-channel images."""
        x = self.conv(x)
        x = self.flatten(x)
        return self.mlp(x)

### Training and evaluation loops

In [None]:
from tqdm.notebook import tqdm # status bar

In [None]:
def train(model, data, loss_fn, optimizer, epochs=3):
    """Optimise ``model`` on ``data`` and return the per-batch loss history.

    Args:
        model: module to train (updated in place).
        data: iterable of ``(samples, labels)`` batches, e.g. a DataLoader.
        loss_fn: callable ``loss_fn(prediction, labels)`` returning a scalar tensor.
        optimizer: optimiser built over ``model``'s parameters.
        epochs: number of full passes over ``data``.

    Returns:
        List of ``log10(loss)`` for every batch of every epoch, in order.
        (Previously the list was reset each epoch, so only the last epoch's
        batches were returned.)
    """
    model.train()  # ensure dropout etc. is active even if eval mode was left on
    history = []

    for epoch in range(epochs):
        loss_sum, n_batches = 0.0, 0

        for samples, labels in tqdm(data):
            # forward pass
            prediction = model(samples)
            loss = loss_fn(prediction, labels)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # record loss (log scale, used for plotting)
            history.append(loss.log10().item())
            loss_sum += loss.item()
            n_batches += 1

        # Report the epoch's mean loss rather than only the final batch's;
        # the guard also avoids a NameError when ``data`` is empty.
        if n_batches:
            print(f"epoch {epoch + 1}/{epochs}: mean loss {loss_sum / n_batches:.4f}")

    return history

In [None]:
from torchmetrics import Precision, Recall

precision = Precision(task='multiclass', num_classes=10)
recall = Recall(task='multiclass', num_classes=10)

def test(model, data, loss_fn):
    """Evaluate ``model`` on ``data``; print mean loss, precision, recall and F1.

    Args:
        model: classifier returning per-class scores.
        data: iterable of ``(samples, labels)`` batches.
        loss_fn: callable ``loss_fn(prediction, labels)`` returning a scalar tensor.

    The module-level ``precision`` / ``recall`` metrics accumulate state across
    calls, so they are reset first; otherwise evaluating a second model mixed
    in every previously evaluated model's predictions.
    """
    precision.reset()
    recall.reset()

    was_training = model.training
    model.eval()  # e.g. disable dropout in the CNN while evaluating
    loss_sum, n_batches = 0.0, 0
    try:
        with torch.no_grad():  # evaluation only: no autograd graph needed
            for samples, labels in tqdm(data):
                # forward pass
                prediction = model(samples)
                loss_sum += loss_fn(prediction, labels).item()
                n_batches += 1

                # metrics
                precision.update(prediction, labels)
                recall.update(prediction, labels)
    finally:
        model.train(was_training)  # restore the caller's mode

    if n_batches == 0:
        print("test: no batches in data")
        return

    # test loss: mean of the per-batch losses (not just the last batch)
    print(f"test loss: {loss_sum / n_batches}")
    precision_value = precision.compute()
    recall_value = recall.compute()
    denom = precision_value + recall_value
    # guard the 0/0 case when both metrics are zero
    f1_value = 2 * precision_value * recall_value / denom if denom > 0 else torch.zeros_like(denom)
    print(f"precision: {precision_value}, recall: {recall_value}, f1: {f1_value}")

In [None]:
def print_params(model):
    """Print ``name value`` for every parameter of ``model`` that requires grad."""
    trainable = ((label, p) for label, p in model.named_parameters() if p.requires_grad)
    for label, p in trainable:
        print(label, p.data)

def count_params(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters."""
    total = 0
    for p in model.parameters():
        if not p.requires_grad:
            continue
        total += p.numel()
    return total

### Load data

In [None]:
from torchvision import datasets
from torchvision.transforms import ToTensor

In [None]:
# Digit MNIST: build the train split, then the test split (downloaded into
# ./data if not already present), each converted to tensors.
train_data, test_data = (
    datasets.MNIST(root='data', train=is_train_split, download=True, transform=ToTensor())
    for is_train_split in (True, False)
)

dset = "MNIST"

In [None]:
# Fashion-MNIST: same layout as the digit cell; running this cell replaces
# train_data / test_data / dset with the Fashion-MNIST versions.
train_data, test_data = (
    datasets.FashionMNIST(root='data', train=is_train_split, download=True, transform=ToTensor())
    for is_train_split in (True, False)
)

dset = "FashionMNIST"

In [None]:
from torch.utils.data import DataLoader

batch_size = 4

# Training batches are reshuffled every epoch. The test set is read in a fixed
# order: shuffling does not change aggregate metrics and only makes runs
# harder to compare.
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=False)

### Fit models

In [None]:
loss_fn = nn.CrossEntropyLoss(reduction='mean')  # takes raw logits; every model ends in a plain Linear

In [None]:
# Shared training settings used by all three models below.
learning_rate, epochs = 1e-3, 1

### MLP

In [None]:
mlp_model = MLP([512, 512], nn.ReLU(), n_classes=10)
optimizer = optim.Adam(params=mlp_model.parameters(), lr=learning_rate)
mlp_epoch_loss = []  # per-batch losses returned by train(), accumulated across calls

count_params(mlp_model)  # displayed as the cell output

In [None]:
mlp_epoch_loss.extend(train(mlp_model, train_dl, loss_fn, optimizer, epochs=epochs))

In [None]:
test(model=mlp_model, data=test_dl, loss_fn=loss_fn)

### RBF

In [None]:
from sklearn.cluster import KMeans

# 64 k-means centres over the flattened training images (seeded for reproducibility).
kmeans = KMeans(n_clusters=64, init='k-means++', n_init='auto', random_state=0)
kmeans.fit(train_data.data.flatten(start_dim=1))
clusters = kmeans.cluster_centers_.astype(np.float64)

In [None]:
# Two choices of RBF centres; RBFNet unit-normalises whichever is passed in.
#   k-means centres (alternative):
#     mus = torch.from_numpy(clusters)
#     sigs = torch.ones(len(mus))*5e-1
#   first 64 training images (active choice):
mus = train_data.data.flatten(start_dim=1)[:64].float()
sigs = torch.full((mus.shape[0],), 0.5)

In [None]:
rbf_model = RBFNet(mus, sigs, n_classes=10)
optimizer = optim.SGD(params=rbf_model.parameters(), lr=learning_rate)
rbf_epoch_loss = []  # per-batch losses returned by train(), accumulated across calls

count_params(rbf_model)  # displayed as the cell output

In [None]:
rbf_epoch_loss.extend(train(rbf_model, train_dl, loss_fn, optimizer, epochs=epochs))

In [None]:
plt.plot(range(len(rbf_epoch_loss)), rbf_epoch_loss)  # x: batch index, y: recorded loss

In [None]:
test(model=rbf_model, data=test_dl, loss_fn=loss_fn)

### CNN

In [None]:
cnn_model = CNN(n_classes=10)
optimizer = optim.SGD(params=cnn_model.parameters(), lr=learning_rate)
cnn_epoch_loss = []  # per-batch losses returned by train(), accumulated across calls

count_params(cnn_model)  # displayed as the cell output

In [None]:
cnn_epoch_loss.extend(train(cnn_model, train_dl, loss_fn, optimizer, epochs=epochs))

In [None]:
test(model=cnn_model, data=test_dl, loss_fn=loss_fn)

In [None]:
# Show the learned conv filters on a 2x5 grid (filter index = ncol*row + col).
nrow, ncol = 2, 5

fig, axes = plt.subplots(nrows=nrow, ncols=ncol)

for i in range(nrow):
    for j in range(ncol):
        axes[i, j].set(xticks=[], yticks=[])  # hide tick marks; images only
        axes[i, j].imshow(cnn_model.conv[0].weight[ncol * i + j].detach().squeeze().numpy())

### Results

In [None]:
import seaborn as sn
import pandas as pd

In [None]:
# Digit MNIST labels are the digits themselves.
digit_mnist_classes = np.arange(10)

# Fashion-MNIST label index -> human-readable class name.
fashion_mnist_classes = dict(enumerate([
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
]))

In [None]:
C = digit_mnist_classes if dset == "MNIST" else fashion_mnist_classes

# Confusion matrix on the held-out *test* split
# (rows: true class, columns: predicted class).
test_dl_2 = DataLoader(dataset=test_data, shuffle=True, batch_size=1)

model = mlp_model

confusion_mat = torch.zeros((10, 10))
was_training = model.training
model.eval()  # e.g. disable dropout if model is the CNN
try:
    with torch.no_grad():  # inference only: don't build autograd graphs
        for samples, labels in tqdm(test_dl_2):
            # largest output per sample = the classification decision
            predictions = model(samples).argmax(dim=1)
            # .tolist() avoids the deprecated int(<1-element ndarray>) conversion
            for true_label, predicted in zip(labels.tolist(), predictions.tolist()):
                confusion_mat[true_label, predicted] += 1
finally:
    model.train(was_training)  # restore the previous mode

class_names = [C[i] for i in range(10)]
df_cm = pd.DataFrame(confusion_mat.numpy(), index=class_names, columns=class_names)
plt.figure(figsize=(10, 7))
# fmt='g' prints counts like 980 instead of the default '.2g' (9.8e+02)
sn.heatmap(df_cm, annot=True, fmt='g')
plt.show()