In [6]:
import sys
sys.path.append("subspaces")

import torch
import torchvision
import matplotlib.pyplot as plt
import math
from tqdm.notebook import tqdm

%matplotlib inline

# Fetch the EMNIST "digits" split (train and eval parts) into ./emnist.
emnist_train = torchvision.datasets.EMNIST(
    './emnist', split='digits', train=True, download=True
)
emnist_eval = torchvision.datasets.EMNIST(
    './emnist', split='digits', train=False, download=True
)

# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [7]:
# Instead of flattening each image into [batch, image_data],
# treat the data as [batch, H, W].
# Also collect the list of correct labels for the train and eval sets.

vector_size = 28


def dataset_to_tensor(dataset, side):
    """Stack all images of `dataset` into one tensor and collect their labels.

    Args:
        dataset: indexable dataset whose items are (image, label, ...) sequences;
            the image must be convertible with torchvision's `to_tensor`.
        side: height and width of every image.

    Returns:
        (images, labels): `images` is a float tensor of shape
        [len(dataset), side, side]; `labels` is a list with one label per image.
    """
    images = torch.empty(len(dataset), side, side)
    labels = []
    for idx in range(len(dataset)):
        # Index the dataset once per image: every `dataset[idx]` call repeats
        # whatever work `__getitem__` does, so fetching image and label
        # separately doubles that cost.
        item = dataset[idx]
        images[idx] = torch.squeeze(torchvision.transforms.functional.to_tensor(item[0]))
        labels.append(item[1])
    return images, labels


train_data, train_correct_labels = dataset_to_tensor(emnist_train, vector_size)
eval_data, eval_correct_labels = dataset_to_tensor(emnist_eval, vector_size)

In [8]:
# Now, apply SVD on each individual image and keep the leading rows of Vh
# of every image as separate training vectors.
componens_per_image = 15  # Arbitrary choice, not tuned.

_, _, Vh_train = torch.linalg.svd(train_data)
Vh_train = Vh_train[:, 0:componens_per_image, :]
# Number of rows actually kept per image (the slice cannot exceed what SVD returned).
components_kept = Vh_train.shape[1]
# Collapse [image, component, W] into [image * component, W]: one row per kept vector.
Vh_train = torch.reshape(Vh_train, [Vh_train.shape[0]*Vh_train.shape[1], Vh_train.shape[2]]).to(device)


def flatten(l):
    """Return a flat list with the items of every sub-iterable of `l`, in order."""
    return [item for sublist in l for item in sublist]


# Repeat every image label once per kept vector so labels line up with the rows
# of Vh_train. This cell overwrites `train_correct_labels`, so on a re-run the
# list is already expanded: first recover the one-label-per-image list by taking
# every `previous_repeat`-th entry (1 on the first run), then expand again.
# Without this, running the cell twice would multiply the labels a second time.
n_images = train_data.shape[0]
previous_repeat = max(1, len(train_correct_labels) // max(1, n_images))
labels_per_image = train_correct_labels[::previous_repeat]

train_correct_labels = flatten([[label]*components_kept for label in labels_per_image])

In [9]:
from subspaces.vector_space import VectorSpace
# Per-image SVD of the evaluation set; all rows of each image's Vh are kept.
_, _, Vh_eval = torch.linalg.svd(eval_data)
Vh_eval = Vh_eval.to(device)

# One VectorSpace per evaluation image, tagged with that image's correct label.
vspace_list = []

# Wrap the tensor itself in tqdm (not the enumerate object): enumerate has no
# length, so tqdm(enumerate(...)) cannot show a total, percentage or ETA.
for i, image in enumerate(tqdm(Vh_eval)):
    vspace = VectorSpace(vector_size=vector_size)
    vspace.append(image)
    vspace.label = eval_correct_labels[i]
    vspace_list.append(vspace)

0it [00:00, ?it/s]

In [10]:
from subspaces.vector_msm import VectorMSM
from subspaces.vector_space import VectorSpace
def flatten(l):
    """Concatenate the items of every sub-iterable in `l` into one flat list.

    Only one level of nesting is removed; the order of items is preserved.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

# Sweep the `min_energy` training parameter over 20 evenly spaced values in [0.05, 1].
min_energy_list = torch.linspace(0.05, 1, 20)
# Prediction ratio returned by model.eval for each completed sweep step, in sweep order.
ratio = []

for min_energy in tqdm(min_energy_list):
    # Train with min_energy
    # A new VectorMSM is constructed for every value; `min_energy` is passed on
    # as the tensor element yielded by iterating `min_energy_list`.
    model = VectorMSM(vector_size=vector_size)
    model.train(Vh_train, train_correct_labels, min_energy=min_energy)

    # Predict
    # model.eval returns a pair; only its second element is reported and stored.
    predition_list, prediction_ratio = model.eval(vspace_list, eval_correct_labels)
    # float() is applied for printing only.
    print(float(min_energy), prediction_ratio)
    ratio.append(prediction_ratio)

  0%|          | 0/20 [00:00<?, ?it/s]

0.05000000074505806 0.102225
0.10000000149011612 0.10505
0.15000000596046448 0.1018
0.20000000298023224 0.1016
0.25 0.1011
0.30000001192092896 0.1001
0.3499999940395355 0.099975
0.4000000059604645 0.101075


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# Plot the prediction ratio against the min_energy value used for training.
# `ratio` only contains the sweep steps that finished, so pair it with the
# matching prefix of `min_energy_list`; if the sweep was interrupted, plotting
# against the full list would fail because x and y have different lengths.
completed_energies = min_energy_list.numpy()[:len(ratio)]

plt.plot(completed_energies, ratio, 'r--')
plt.title("Correct predictions over minimum energy preserved on SVD")
plt.xlabel("minimum energy preserved on SVD")
plt.ylabel("prediction ratio")
plt.show()

In [None]:
# This code doesn't make sense!