In [1]:
import sys
sys.path.append("subspaces")

import torch
import torchvision
import matplotlib.pyplot as plt
import math
from tqdm.notebook import tqdm

%matplotlib inline

In [2]:
# Load Dataset
import scipy.io as sio

dataset_path = "/home/pc-bonito/Datasets/TsukubaHandSize24x24.mat"

mat_contents = sio.loadmat(dataset_path)
dataset = torch.from_numpy(mat_contents["data"])
# [H, W, images_per_view, cameras, classes, persons]

In [3]:
# Split participants in 80%/20% splits. 
train_split = int(dataset.shape[-1] * 0.8)

# Flatten a list of lists
def flatten(l):
    return [item for sublist in l for item in sublist]

In [4]:
# Reshape train data
dataset_train = dataset[:,:,:,:,:,:train_split]
dataset_train = dataset_train.permute(4, 3, 5, 2, 0, 1)
dataset_train = torch.flatten(dataset_train, start_dim=1, end_dim=3)
dataset_train = torch.flatten(dataset_train, start_dim=2, end_dim=3)

# Generate class labels and reshape dataset_train
labels_train = [i for i in range(dataset_train.shape[0])]
labels_train = flatten([[labels_train[i]]*dataset_train.shape[1] for i in range(dataset_train.shape[0])])

dataset_train = torch.reshape(dataset_train, (dataset_train.shape[0]*dataset_train.shape[1], dataset_train.shape[2]))

print(len(labels_train))
print(dataset_train.shape)

1008000
torch.Size([1008000, 576])


In [5]:
from subspaces.vector_space import VectorSpace

dataset_eval = dataset[:, :, :, :, :, train_split:]
dataset_eval = dataset_eval.permute(4, 3, 5, 2, 0, 1)
dataset_eval = torch.flatten(dataset_eval, start_dim=2, end_dim=3)
dataset_eval = torch.flatten(dataset_eval, start_dim=3, end_dim=4)
dataset_eval = dataset_eval.permute(0, 2, 1, 3)

# Generates a vector space for each label. Each vector space contains as base vectors, the images from each view.
eval_subspace_list = []
eval_correct_labels = []

label = 0
for i in dataset_eval:
    for person_image in i:
        vspace = VectorSpace(vector_size=24*24)
        vspace.append(person_image)
        eval_subspace_list.append(vspace)
        eval_correct_labels.append(label)
    label += 1

print(len(eval_subspace_list))
print(len(eval_correct_labels))

36000
36000


In [6]:
from subspaces.vector_msm import VectorMSM
# Train model
min_energy_list = torch.linspace(0.05, 1, 20)
ratio = []

for min_energy in tqdm(min_energy_list):
    model = VectorMSM(vector_size=24*24)
    model.train(dataset_train, labels_train)

    # Eval model
    _, accuracy = model.eval(eval_subspace_list, eval_correct_labels)

    print(accuracy)
    ratio.append(accuracy)

  0%|          | 0/20 [00:00<?, ?it/s]

0.05669444444444444
0.05669444444444444


In [None]:
import matplotlib.pyplot as plt

plt.plot(min_energy_list.numpy(), ratio, 'r--')
plt.title("Correct predictions over minimum energy preserved on SVD")
plt.xlabel("minimum energy preserved on SVD")
plt.ylabel("prediction ratio")
plt.show()