# Hypothesis Test: does joint sparsity bring any benefit?
In this notebook we examine whether joint sparsity offers any benefit over standard sparse coding techniques for classification. The rough idea is that a joint sparse forward pass learns a code, or representation, for a data point that takes other examples of the same class into account. In effect, we are trying to map all members of a given class onto a particular subspace. To test this, we run two experiments:
- classification: we take the set of sparse representations of each class and perform an SVD to identify, say, the top 5 singular vectors. These span a linear subspace that, in some sense, represents the class label in the encoding space. We benchmark against standard IHT and PCA.
- reconstruction / decoding: we compare the reconstruction error of JIHT against IHT and PCA. We expect IHT to achieve a lower reconstruction error than JIHT.


## Methodology
### For classification:
For IHT/JIHT:
- train a model on MNIST with an IHT/JIHT forward pass
- run the entire MNIST training set through the model to obtain the encoding of every training point
- group the encodings by class and compute an SVD of each group to find its top j singular vectors. These j vectors span a linear subspace, which we associate with the class
- run the entire test set through the model to obtain the test encodings
- classify each test point by projecting its encoding onto each class subspace and assigning it to the class whose subspace is closest
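The classification steps above can be sketched in NumPy as follows. This is a minimal illustration, not the notebook's own implementation: the function names `fit_class_subspaces` and `classify_by_subspace` are hypothetical, the encodings are assumed to be stored one per row in a NumPy array, and "closest" is taken to mean the smallest residual norm `||x - U U^T x||` after projecting onto a class basis `U`.

```python
import numpy as np


def fit_class_subspaces(encodings, labels, j):
    """Return {class: (d, j) orthonormal basis} built from the top-j left singular vectors.

    encodings: (n, d) array, one encoding per row; labels: (n,) integer array.
    (Hypothetical helper, for illustration only.)
    """
    bases = {}
    for c in np.unique(labels):
        # Stack this class's encodings as columns: shape (d, n_c)
        A = encodings[labels == c].T
        # The left singular vectors span the column space of A
        U, _, _ = np.linalg.svd(A, full_matrices=False)
        bases[c] = U[:, :j]
    return bases


def classify_by_subspace(encodings, bases):
    """Assign each row of `encodings` to the class whose subspace is nearest."""
    classes = sorted(bases)
    # Residual norm ||x - U U^T x|| for every point (rows) and every class (columns)
    residuals = np.stack(
        [np.linalg.norm(encodings - (encodings @ bases[c]) @ bases[c].T, axis=1)
         for c in classes],
        axis=1,
    )
    return np.asarray(classes)[np.argmin(residuals, axis=1)]
```

Here `j` plays the role of the notebook parameter `l`. Given predictions `pred` and true test labels `y_test`, the accuracy in percent is `100 * np.mean(pred == y_test)`.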

PCA benchmarking approach:
- Find the j principal components of the training data for each class
- Project each test point onto each class's set of j principal components
- Assign each test point to the class for which the distance between the point and its projection is smallest (for a subspace through the origin, this is the same as the largest projection)
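A possible NumPy sketch of this per-class PCA benchmark follows. The names `fit_class_pca` and `classify_by_pca` are hypothetical; the sketch assumes the images are flattened into rows and, as is standard for PCA, centres each class on its mean, so the distance is measured to the affine subspace through the class mean.

```python
import numpy as np


def fit_class_pca(X, labels, j):
    """Per-class PCA: return {class: (mean, components)} with components of shape (j, d).

    (Hypothetical helper; centring on the class mean is an assumption.)
    """
    models = {}
    for c in np.unique(labels):
        Xc = X[labels == c]
        mu = Xc.mean(axis=0)
        # Rows of Vt are the principal directions of the centred class data
        _, _, Vt = np.linalg.svd(Xc - mu, full_matrices=False)
        models[c] = (mu, Vt[:j])
    return models


def classify_by_pca(X, models):
    """Assign each row of X to the class whose affine PCA subspace is nearest."""
    classes = sorted(models)
    dists = []
    for c in classes:
        mu, V = models[c]
        Z = X - mu
        # Distance from each point to the subspace mu + span(rows of V)
        dists.append(np.linalg.norm(Z - (Z @ V.T) @ V, axis=1))
    return np.asarray(classes)[np.argmin(np.stack(dists, axis=1), axis=1)]
```

Because the class means are subtracted, "shortest distance" and "largest projection" are no longer equivalent here, so the sketch uses the distance criterion.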

We then compare all methods by the percentage of test points each one classifies correctly.

### For reconstruction:
For IHT/JIHT:
- Simple: run a forward pass followed by a reverse (decoding) pass, and calculate the l2 distance between each decoded and original data point. Report the percentage error over the entire test set and training set, and plot some reconstructions for visual inspection
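The reconstruction error can be computed as below, assuming the original and decoded images have already been converted to NumPy arrays (for example with `.cpu().numpy()`). The function name is hypothetical, and the dataset-level percentage `100 * ||X - X_hat||_F / ||X||_F` is one reasonable definition of "percentage error", not necessarily the one used in this notebook.

```python
import numpy as np


def reconstruction_errors(X, X_hat):
    """Per-sample l2 distances and a dataset-level percentage error.

    X, X_hat: arrays of shape (n, ...) holding the originals and their reconstructions.
    The percentage is 100 * ||X - X_hat||_F / ||X||_F (an assumed definition).
    """
    # Flatten each sample so (n, 28, 28) images and (n, 784) vectors both work
    X = np.asarray(X, dtype=float).reshape(len(X), -1)
    X_hat = np.asarray(X_hat, dtype=float).reshape(len(X_hat), -1)
    per_sample = np.linalg.norm(X - X_hat, axis=1)
    percent = 100.0 * np.linalg.norm(X - X_hat) / np.linalg.norm(X)
    return per_sample, percent
```

The same function can be applied to the IHT, JIHT and PCA reconstructions so that all three are scored identically.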

PCA benchmarking approach:
- Calculate the m principal components of the training data set (these act as our atoms)
- Encode each data point using the K principal components it is most aligned with: calculate the inner product between the data point and each principal component, and select the K with the largest magnitude
- Reconstruct the data point (image) from just these K principal components
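A sketch of this PCA reconstruction benchmark, with `m` atoms and `K` kept coefficients per data point (the notebook's `numb_atoms` and `K`), is shown below. The function names are hypothetical, and subtracting the training mean before taking inner products is an assumption (standard PCA practice) rather than something stated above.

```python
import numpy as np


def fit_pca_atoms(X_train, m):
    """Return the training mean and the top-m principal components as rows, shape (m, d)."""
    mu = X_train.mean(axis=0)
    _, _, Vt = np.linalg.svd(X_train - mu, full_matrices=False)
    return mu, Vt[:m]


def pca_topk_reconstruct(X, mu, atoms, K):
    """Keep only the K largest-magnitude coefficients per data point, then reconstruct."""
    coeffs = (X - mu) @ atoms.T  # inner product with each atom, shape (n, m)
    # Column indices of the K largest |coefficients| in each row
    top = np.argsort(-np.abs(coeffs), axis=1)[:, :K]
    sparse = np.zeros_like(coeffs)
    rows = np.arange(len(X))[:, None]
    sparse[rows, top] = coeffs[rows, top]
    # Sum of the selected atoms, shifted back by the mean
    return sparse @ atoms + mu
```

For example, `mu, atoms = fit_pca_atoms(X_train, numb_atoms)` followed by `pca_topk_reconstruct(X_test, mu, atoms, K)`. Because the atoms are orthonormal, keeping the largest-magnitude coefficients is the best K-term choice for that fixed set of atoms.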

Compare the total reconstruction error of IHT, JIHT and PCA on both the test and training sets.

## Import MNIST Data
The first cell imports the required libraries and loads the MNIST training and test sets.

In [None]:
import numpy as np
from numpy import linalg as LA
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA 

import random
import os
import yaml

import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler

from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Parameters
rep_batch_size = 60000  # the whole MNIST training set in a single batch
test_batch_size = 5000

# Dimension j of each class subspace
l = 10

# PCA reconstruction benchmark: number of atoms m and sparsity K
numb_atoms = 500
K = 50

# Load MNIST
root = './data'
download = True  # download MNIST dataset or not

# Access the MNIST dataset and define the preprocessing transforms (convert to tensor, then normalise with the MNIST mean and std)
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# trans = transforms.Compose([transforms.ToTensor()])
train_set = dsets.MNIST(root=root, train=True, transform=trans, download=download)
test_set = dsets.MNIST(root=root, train=False, transform=trans)

train_loader = torch.utils.data.DataLoader(
                 dataset=train_set,
                 batch_size=rep_batch_size,
                 sampler = None,
                 shuffle=True)


test_loader = torch.utils.data.DataLoader(
                dataset=test_set,
                batch_size=test_batch_size,
                shuffle=True)

### Run data through the IHT model
First, load the IHT model and check a few examples to confirm that it is working as it should.

In [None]:
# Load model
import importlib
import aux  # assumed name of the project-local module that provides load_model

importlib.reload(aux)

# model_filename and N_TEST_IMG must be defined before running this cell
model = aux.load_model(model_filename)

# model.mask = torch.ones(N_TEST_IMG, m)

# Check that reconstructions etc. are working as they should
fig = plt.figure(figsize=(5, 2))
# original data (first row) for viewing, scaled to [0, 1]
view_data = train_set.data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor) / 255.
view_data = view_data.to(device)
decoded, encoded, nnz = model(view_data, 25)

for i in range(N_TEST_IMG):
    plt.subplot(2, N_TEST_IMG, i + 1)
    plt.imshow(np.reshape(view_data.cpu().data.numpy()[i], (28, 28)), cmap='gray')

# reconstructions (second row)
for i in range(N_TEST_IMG):
    plt.subplot(2, N_TEST_IMG, N_TEST_IMG + i + 1)
    plt.imshow(np.reshape(decoded.cpu().data.numpy()[i], (28, 28)), cmap='gray')

plt.show()

Next, we run the entire training set through the model.
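One way to collect the training encodings and their labels is sketched below. It assumes, as in the inspection cell above, that the model is called as `model(x, k)` and returns `(decoded, encoded, nnz)`; the helper name `encode_dataset` is hypothetical.

```python
import torch


@torch.no_grad()
def encode_dataset(model, loader, k, device="cpu"):
    """Collect encodings and labels for every batch in `loader` as NumPy arrays.

    Assumes model(x, k) returns (decoded, encoded, nnz) for a batch x of
    flattened 784-dimensional images (hypothetical helper).
    """
    codes, labels = [], []
    for x, y in loader:
        x = x.view(x.size(0), -1).to(device)  # flatten (B, 1, 28, 28) -> (B, 784)
        _, encoded, _ = model(x, k)
        codes.append(encoded.cpu())
        labels.append(y)
    return torch.cat(codes).numpy(), torch.cat(labels).numpy()
```

For example, `Z_train, y_train = encode_dataset(model, train_loader, 25, device)`. With `rep_batch_size = 60000` the loader yields the whole training set as one batch, so GPU memory use may be significant. Note also that `train_loader` applies the `Normalize` transform, whereas the inspection cell above feeds raw pixels scaled to [0, 1]; use whichever matches how the model was trained.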