# Hypothesis test: does joint sparsity bring any benefit?
In this notebook we examine whether joint sparsity has any benefit over standard sparse coding techniques for classification. The rough idea is that a joint sparse forward pass learns a code, or representation, for a data point that takes into account other examples of the same class. In effect, then, we are trying to map all members of a given class onto a particular subspace. To assess the benefit we consider two tests:
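- classification: for each class we take the set of sparse representations of its examples and perform an SVD, so as to identify the top few singular vectors (the subspace dimension is set by `l` in the code below). These span a linear subspace that, in some sense, represents the given class label in the encoder space. We benchmark JIHT against standard IHT and PCA.
- reconstruction / decoding: we compare the reconstruction error of JIHT with that of IHT and PCA. We expect IHT to achieve a lower reconstruction error than JIHT, because a set of codes that share one common support (joint sparsity) is a special case of codes that are each sparse independently, so JIHT fits the data under a stricter constraint.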
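
The two forward passes compared in these tests can be sketched as follows. This is a minimal NumPy illustration, not the notebook's trained model: it assumes a fixed dictionary `D` (atoms as columns), data `Y` with one example per column, sparsity `k`, and the standard IHT update with step size 1/‖D‖₂². The joint variant swaps per-column hard thresholding for keeping the `k` rows of largest norm, so every column shares one support. All function names here are hypothetical.

```python
import numpy as np


def hard_threshold(X, k):
    """IHT projection: keep the k largest-magnitude entries in each column of X."""
    out = np.zeros_like(X)
    rows = np.argsort(-np.abs(X), axis=0)[:k]  # (k, n_cols) row indices per column
    cols = np.arange(X.shape[1])
    out[rows, cols] = X[rows, cols]
    return out


def joint_hard_threshold(X, k):
    """JIHT projection: keep the k rows of X with the largest l2 norm (shared support)."""
    out = np.zeros_like(X)
    rows = np.argsort(-np.linalg.norm(X, axis=1))[:k]
    out[rows] = X[rows]
    return out


def iht(D, Y, k, n_iter=100, joint=False):
    """Gradient step on 0.5 * ||Y - D X||_F^2 followed by a sparsity projection."""
    step = 1.0 / np.linalg.norm(D, 2) ** 2  # 1 / (largest singular value of D)^2
    project = joint_hard_threshold if joint else hard_threshold
    X = np.zeros((D.shape[1], Y.shape[1]))
    for _ in range(n_iter):
        X = project(X + step * D.T @ (Y - D @ X), k)
    return X


# Toy data: a random unit-norm dictionary and 20 examples standing in for one class.
rng = np.random.default_rng(0)
D = rng.standard_normal((64, 128))
D /= np.linalg.norm(D, axis=0)
Y = rng.standard_normal((64, 20))

X_iht = iht(D, Y, k=5)
X_jiht = iht(D, Y, k=5, joint=True)


def rel_error(X):
    """Relative reconstruction error ||Y - D X||_F / ||Y||_F."""
    return np.linalg.norm(Y - D @ X) / np.linalg.norm(Y)


print(rel_error(X_iht), rel_error(X_jiht))
```

Since IHT is non-convex, the expectation that IHT reconstructs better than JIHT is a tendency rather than a guarantee for any single run.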

For IHT and JIHT:
- we train a model on MNIST.

In [None]:
import numpy as np
from numpy import linalg as LA
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA 

import random
import os
import yaml

import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler

from skimage import data, color
from skimage.transform import rescale, resize, downscale_local_mean

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

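# Batch sizes: the full 60,000-image MNIST training set is loaded as a single batch,
# and the 10,000-image test set is split into two batches of 5,000
rep_batch_size = 60000
test_batch_size = 5000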

# Dimension of class subspace
l = 10

# Sparsity value for pca
numb_atoms = 500
K=50

# Load MNIST
root = './data'
download = True  # download MNIST dataset or not

# Access the MNIST dataset and define the preprocessing transforms:
# convert images to tensors, then normalize with the MNIST mean (0.1307) and std (0.3081)
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# trans = transforms.Compose([transforms.ToTensor()])
train_set = dsets.MNIST(root=root, train=True, transform=trans, download=download)
test_set = dsets.MNIST(root=root, train=False, transform=trans, download=download)

train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=rep_batch_size,
    sampler=None,
    shuffle=True)

test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=test_batch_size,
    shuffle=True)
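
The classification test described at the top of the notebook can be sketched as a nearest-subspace classifier. This is a minimal NumPy version under assumptions the notebook does not state: codes are stored one per column, each class subspace is spanned by the top `dim` left singular vectors of that class's code matrix (playing the role of `l`), and a test code is assigned to the class whose subspace leaves the smallest projection residual. The function names are hypothetical, and the same routine applies unchanged to IHT, JIHT or PCA features.

```python
import numpy as np


def fit_class_subspaces(Z, labels, dim):
    """Return {class: orthonormal basis of shape (n_features, dim)}.

    Z holds one code per column; each basis is the top `dim` left singular
    vectors of the codes belonging to that class.
    """
    bases = {}
    for c in np.unique(labels):
        U, _, _ = np.linalg.svd(Z[:, labels == c], full_matrices=False)
        bases[c] = U[:, :dim]
    return bases


def predict_nearest_subspace(Z, bases):
    """Assign each column of Z to the class whose subspace reconstructs it best."""
    classes = np.asarray(list(bases))
    residuals = np.stack(
        [np.linalg.norm(Z - U @ (U.T @ Z), axis=0) for U in bases.values()]
    )  # shape (n_classes, n_samples)
    return classes[np.argmin(residuals, axis=0)]


# Toy check: two random 3-dimensional subspaces of R^20 stand in for two classes.
rng = np.random.default_rng(0)
n_features, dim = 20, 3
true_bases = [np.linalg.qr(rng.standard_normal((n_features, dim)))[0] for _ in range(2)]


def sample(c, n):
    """Draw n noise-free points from the subspace of class c."""
    return true_bases[c] @ rng.standard_normal((dim, n))


Z_train = np.hstack([sample(0, 50), sample(1, 50)])
y_train = np.repeat([0, 1], 50)
Z_test = np.hstack([sample(0, 10), sample(1, 10)])
y_test = np.repeat([0, 1], 10)

bases = fit_class_subspaces(Z_train, y_train, dim)
accuracy = np.mean(predict_nearest_subspace(Z_test, bases) == y_test)
print(accuracy)
```

Noise-free toy classes like these are separable by construction, so a perfect score here says nothing about JIHT. On MNIST one would fit the subspaces on the training codes with `dim = l` and report accuracy on the test codes.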