In [None]:
!!pip install torch torchvision
!!pip install tensorly
!!pip install tqdm

In [None]:
import sys
# Show the interpreter version the notebook is running on
print(sys.version)

# Dataset

Obtención y preprocesamiento del dataset EMNIST (partición `digits`) de dígitos escritos a mano.

In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import tensorly as tl
from tqdm import tqdm
tl.set_backend('pytorch')

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing applied to every image served through the dataset objects:
# rotate by 270 degrees, mirror horizontally, smooth, then convert to a tensor
transform = transforms.Compose([
    lambda img: transforms.functional.rotate(img, 270),
    lambda img: transforms.functional.hflip(img),
    transforms.GaussianBlur(5, sigma=0.9),  # Gaussian blur to reduce noise
    transforms.ToTensor(),
])

# Download (if needed) and load the training split
train_set = datasets.EMNIST(root='./data', split='digits', train=True, transform=transform, download=True)
print(train_set.classes)

Cargando el Dataset en su **forma matricial**

In [None]:
# Build one matrix per class whose columns are the flattened images of that class
number_of_classes = len(train_set.classes)
char_matrices = []
for char in range(number_of_classes):
  char_images = train_set.data[train_set.targets == char]
  if len(char_images) == 0:
    continue  # Skip classes without samples

  # Shuffle the images of this class
  char_images = char_images[torch.randperm(len(char_images))]

  # Flatten all images with a single vectorised call instead of a Python loop,
  # then transpose so that every column holds one image: (pixels, n_images)
  char_matrix = char_images.flatten(start_dim=1).T
  char_matrices.append(char_matrix)
  print(char_matrix.shape)

Cargando el dataset en su **forma tensorial**

In [None]:
# Build one third-order tensor per class by placing its images along the last axis
char_tensors = []
for char in range(number_of_classes):
  char_images = train_set.data[train_set.targets == char]
  if len(char_images) == 0:
    continue  # Skip classes without samples

  # Shuffle the images of this class
  char_images = char_images[torch.randperm(len(char_images))]

  # Indexing already yields an (n_images, rows, cols) tensor, so no extra stacking
  # is needed; move the image axis to the end: (n, i, j) -> (i, j, n)
  char_tensors.append(char_images.permute(1, 2, 0))

# Método 1

Los datos del dataset EMNIST se interpretan como vectores: cada imagen se aplana en un vector columna y los vectores de una misma clase se apilan para formar una matriz $E_k$.

Cálculo de $SVD(E_k)$ para cada $E_k$

In [None]:
# Thin SVD of every class matrix; the matrix itself is stored next to its factors
Ek_svd = []
for class_matrix in char_matrices:
  matrix_on_device = class_matrix.float().to(device)
  factors = torch.linalg.svd(matrix_on_device, full_matrices=False)
  Ek_svd.append((matrix_on_device, *factors))

Ejecutando una prueba con el modelo alcanzado para un dígito aleatorio extraído del conjunto de entrenamiento

In [None]:
def predict_single_svd1(z, Ek_svd):
  """Classify ``z`` by the smallest least-squares residual over the class matrices.

  For each class the residual ``||A x - z||_2`` of the least-squares solution
  ``x`` is evaluated through the SVD of ``A``: ``A x`` equals the orthogonal
  projection of ``z`` onto the left singular vectors whose singular values
  are numerically non-zero.

  Args:
    z: 1-D float tensor on the same device as the SVD factors.
    Ek_svd: iterable of ``(A, U, S, Vh)`` tuples, the factors being those of
      ``torch.linalg.svd(A, full_matrices=False)``.

  Returns:
    Position (int) in ``Ek_svd`` of the class with the smallest residual.
  """
  residuals = []
  for (A, U, S, Vh) in Ek_svd:
    if S.numel() > 0:
      # Floating-point singular values of a rank-deficient matrix are tiny but
      # rarely exactly zero, so the rank is decided with a relative tolerance.
      tol = S.max() * max(A.shape) * torch.finfo(S.dtype).eps
      rank = int((S > tol).sum().item())
    else:
      rank = 0
    U_r = U[:, :rank]
    # A @ pinv(A) @ z == U_r @ U_r^T @ z; a single projection replaces the
    # per-singular-value loop and its host/device synchronisations.
    projection = U_r @ (U_r.T @ z)
    residuals.append(torch.linalg.vector_norm(projection - z, ord=2).item())
  return int(np.argmin(residuals))

import random
import matplotlib.pyplot as plt

# Test on random digit from dataset
random_index = random.randint(0, len(train_set) - 1)
random_image = train_set.data[random_index]
random_label = train_set.targets[random_index]

# The class matrices were built from the images exactly as stored in
# train_set.data (no transpose), so the query is flattened the same way.
y_pred = predict_single_svd1(random_image.to(device).flatten().float(), Ek_svd)

# The image is transposed only for display
plt.imshow(random_image.T, cmap='gray')
plt.show()
print('pred: ', y_pred, 'gt: ', random_label.item())

In [None]:
import time

def compute_accuracy_and_time_svd1(k, subset_size=0.05):
  """Average accuracy and wall-clock time of ``predict_single_svd1`` over random validation subsets.

  Args:
    k: number of repetitions; must be at least 1.
    subset_size: fraction of the validation split drawn at random for each repetition.

  Returns:
    ``(average_accuracy, average_time_taken)``; the time is in seconds and
    covers subset sampling and classification, not loading the dataset.

  Raises:
    ValueError: if ``k`` is smaller than 1 or the requested subset is empty.
  """
  if k < 1:
    raise ValueError(f'k must be at least 1, got {k}')

  # The dataset object is the same for every repetition, so it is loaded once
  validation_set = datasets.EMNIST(root='./data', split='digits', train=False, transform=transform, download=True)
  subset_length = int(len(validation_set)*subset_size)
  if subset_length < 1:
    raise ValueError(f'subset_size={subset_size} selects no validation samples')

  total_accuracy = 0
  total_time_taken = 0
  for i in range(k):
      start_time = time.time()

      # Create a subset of the validation dataset with subset_size % of the data
      subset_indices = torch.randperm(len(validation_set))[:subset_length]
      validation_subset = torch.utils.data.Subset(validation_set, subset_indices)
      validationLoader = torch.utils.data.DataLoader(validation_subset, batch_size=1, shuffle=True)

      # Test each image and compute accuracy
      total = 0
      correct = 0
      progress_bar = tqdm(validationLoader, total=len(validationLoader))
      progress_bar.set_description(f'Processing subset for {i}')
      for image, label in progress_bar:
          # Follow the notebook-wide device instead of assuming CUDA is present
          image = image[0,0,:,:].to(device)
          pred = predict_single_svd1(image.T.flatten().float(), Ek_svd)
          # Compare plain Python numbers; the label never needs to leave the host
          if pred == label.item():
              correct += 1
          total += 1

      accuracy = correct / total
      total_accuracy += accuracy

      end_time = time.time()
      time_taken = end_time - start_time
      total_time_taken += time_taken

      print(f'Iteration {i+1} - Accuracy: {accuracy:.4f}, Time taken: {time_taken:.2f}s')

  average_accuracy = total_accuracy / k
  average_time_taken = total_time_taken / k

  print('\n\n')
  print('--------------------------')
  print('|   Average Performance   |')
  print('--------------------------')
  print(f'| Accuracy | Time (in s) |')
  print('--------------------------')
  print(f'|  {average_accuracy:.4f}  |   {average_time_taken:.2f}    |')
  print('--------------------------')

  return (average_accuracy, average_time_taken)

# Compute accuracy and execution time
# k is the number of repetitions; 0.005 is the fraction of the validation split used in each one.
# NOTE: k is a notebook-level name that later cells reuse as a loop variable.
k=10
acc, time_taken = compute_accuracy_and_time_svd1(k, 0.005)
print(f'Average accuracy over {k} iterations: {acc:.4f}')
print(f'Average time taken over {k} iterations: {time_taken:.2f} seconds')

# Método 2

Cálculo de $SVD(E_k)$ para cada $E_k$

In [None]:
# Left singular vectors (first SVD factor) of every class matrix
Uk_svd = [
  torch.linalg.svd(class_matrix.float().to(device), full_matrices=False)[0]
  for class_matrix in char_matrices
]

Cálculo de $(I - U_kU_k^T)$



In [None]:
def uk_prepare(Uk_svd, k):
  """Build the matrix ``I - U_k U_k^T`` for every class.

  Args:
    Uk_svd: sequence of matrices holding singular vectors as columns.
    k: number of leading columns kept from each matrix.

  Returns:
    List with one square matrix per entry of ``Uk_svd``.
  """
  projectors = []
  for basis in Uk_svd:
    leading = basis[:, :k].to(device)
    n_rows = leading.shape[0]
    identity = torch.eye(n_rows, n_rows).to(device)
    projectors.append(identity - leading @ leading.T)
  return projectors

In [None]:
# One I - U_k U_k^T matrix per class, built from the 8 leading singular vectors
Uk_pre = uk_prepare(Uk_svd, k=8)

Ejecutando una prueba con el modelo alcanzado para un dígito aleatorio extraído del conjunto de entrenamiento

In [None]:
def predict_single_svd2(z, Uk_pre_):
  """Return the position of the matrix in ``Uk_pre_`` that shrinks ``z`` the most.

  Args:
    z: 1-D tensor to classify.
    Uk_pre_: list of matrices, one per class, each applied to ``z``.

  Returns:
    Index (int) of the matrix ``A`` minimising ``||A z||_2``.
  """
  norms = [torch.linalg.vector_norm(torch.matmul(projector, z), ord=2) for projector in Uk_pre_]
  residuals = torch.zeros(len(Uk_pre_), device=z.device)
  for position, value in enumerate(norms):
      residuals[position] = value
  return torch.argmin(residuals).item()

import random
import matplotlib.pyplot as plt

# Test on random digit from dataset
random_index = random.randint(0, len(train_set) - 1)
random_image = train_set.data[random_index]
random_label = train_set.targets[random_index]

# The query is flattened without transposing, i.e. as stored in train_set.data
y_pred = predict_single_svd2(random_image.flatten().float().to(device), Uk_pre)

# The image is transposed only for display
plt.imshow(random_image.T, cmap='gray')
plt.show()
print('pred: ', y_pred, 'gt: ', random_label.item())

In [None]:
import time
from tabulate import tabulate

# Evaluation split used by the sweep below
validation_set = datasets.EMNIST(root='./data', split='digits', train=False, transform=transform, download=True)

# Truncation ranks to evaluate (upper bound excluded)
min_lambda = 10
max_lambda = 11

acc_values = []
time_values = []
counts_values = []

for k in range(min_lambda, max_lambda):
  acc, time_taken, counts, correct_counts = compute_accuracy_and_time_for_method(
      Uk_svd,
      preproc_method=uk_prepare,
      method=predict_single_svd2,
      kappa=k,
      validation_set=validation_set,
      compute_confusion_matrix=False)
  acc_values.append(acc)
  time_values.append(time_taken)
  counts_values.append(correct_counts/counts)  # per-class accuracy

# Display accuracy table: one row per rank, one column per class
header = ["λ"] + [str(i) for i in range(number_of_classes)] + ["Precisión", "Time (in s)"]
table = [
  [rank] + per_class.tolist() + [overall, elapsed]
  for rank, per_class, overall, elapsed
  in zip(range(min_lambda, max_lambda), counts_values, acc_values, time_values)
]

print(tabulate(table, headers=header))

# Plot accuracy vs k for each class
per_class_accuracy = np.array(counts_values)
fig, ax = plt.subplots(figsize=(14, 8))
for i in range(number_of_classes):
  ax.plot(range(min_lambda, max_lambda), per_class_accuracy[:, i], label=f"Clase {i}")
ax.set_yscale('log', base=2)
ax.set_xlabel("Valor de λ")
ax.set_ylabel("Precisión")
ax.set_title("Precisión por dígito vs Valor de λ")
ax.legend()
plt.show()

# Overall accuracy and execution time as functions of the rank
for series, y_label, title in (
    (acc_values, 'Precisión', 'Precisión vs Valor de λ'),
    (time_values, 'Tiempo de ejecución (s)', 'Tiempo de ejecución vs Valor de λ'),
):
  plt.plot(range(min_lambda, max_lambda), series)
  plt.xlabel('Valor de λ')
  plt.ylabel(y_label)
  plt.title(title)
  plt.show()

In [None]:
import string

# Define a list of colors and line styles for the lines
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray',
          'olive', 'cyan', 'magenta', 'lime', 'teal', 'navy', 'maroon', 'gold',
          'indigo', 'peru', 'darkslategray', 'coral', 'royalblue', 'mediumvioletred',
          'lightseagreen', 'darkkhaki', 'darkorchid']
line_styles = ['-', '--', '-.', ':']

# Column i of the per-class results is indexed by the label value i, so the
# legend text comes from the dataset's own class names. This keeps the legend
# correct for whichever split was loaded, instead of assuming letters.
class_labels = [str(name).upper() for name in train_set.classes]

# Plot accuracy vs k for each class
per_class_accuracy = np.array(counts_values)
fig, ax = plt.subplots(figsize=(14, 8))
for i in range(number_of_classes):
    color = colors[i % len(colors)]
    line_style = line_styles[i % len(line_styles)]
    ax.plot(range(min_lambda, max_lambda), per_class_accuracy[:, i],
            color=color, linestyle=line_style, label=f"Clase {class_labels[i]}")
ax.set_yscale('log', base=2)
ax.set_xlabel("Valor de λ")
ax.set_ylabel("Precisión")
ax.set_title("Precisión por clase vs Valor de λ")
ax.legend()
plt.show()

In [None]:
# Evaluate once more with 10 leading singular vectors per class, this time plotting the confusion matrix
compute_accuracy_and_time_for_method(Uk_svd, preproc_method=uk_prepare, method=predict_single_svd2, kappa=10, validation_set=validation_set, compute_confusion_matrix=True)

# Método 3


Se forman 10 tensores $E_k$, donde $k$ representa la etiqueta de los dígitos que se hallan en el tensor.

Cálculo de $HOSVD(\mathcal{A})$

Para $n=1,\ldots,N$


*   Hallar $A_{(n)}$
*   Hallar la SVD de cada $A_{(n)}$: $A_{(n)} = U^{(n)}\Sigma^{(n)}\left[V^{(n)}\right]^T$
* Luego, $\mathcal{S} = \mathcal{A} \times_1 (U^{(1)})^T \times_2 (U^{(2)})^T \times_3 \ldots \times_N (U^{(N)})^T$

In [None]:
def hosvd(T):
  """Higher-order SVD of the tensor ``T``.

  Every mode-n unfolding of ``T`` is factorised with a thin SVD, and the core
  is accumulated by multiplying along mode n with the transpose of that
  unfolding's left singular vectors.

  Args:
    T: tensor to decompose.

  Returns:
    ``(core, [U, S, Vh])`` where ``U``, ``S`` and ``Vh`` are lists holding the
    SVD factors of each unfolding, in mode order.
  """
  left_factors, singular_values, right_factors = [], [], []
  core = T.clone()
  for mode in range(len(T.shape)):
    # SVD of the mode-n unfolding
    unfolded = tl.unfold(T, mode).to(device)
    factor, sigma, right = torch.linalg.svd(unfolded, full_matrices=False)
    left_factors.append(factor)
    singular_values.append(sigma)
    right_factors.append(right)
    # Fold this mode's factor into the running core
    core = tl.tenalg.mode_dot(core.to(device), factor.T, mode)
  return (core, [left_factors, singular_values, right_factors])

Calcular $HOSVD(E_k)$ para cada $E_k$

In [None]:
# HOSVD of every class tensor; only the core and the first two factor matrices are kept
Ek_hosvd = []
for class_tensor in char_tensors:
  core, factors = hosvd(class_tensor.float())
  left_factors = factors[0]
  Ek_hosvd.append((core, left_factors[:2]))  # U^(1) and U^(2) along with the core tensor

Cálculo de $A_v$ ($\mathcal{A} = \sum_{v=1}^{k}A_v \times_3 u_v$)

In [None]:
def hosvd_trunc(Ek_hosvd, k):
  """Expand the leading core slices of every class into basis matrices.

  For each class and each ``v < k`` the matrix
  ``A_v = core[:, :, v] x_1 U1 x_2 U2`` is built together with ``<A_v, A_v>``.

  Args:
    Ek_hosvd: sequence of ``(core, [U1, U2, ...])`` entries.
    k: number of slices along the third axis of the core to expand. Values
      larger than the number of available slices are clamped.

  Returns:
    One list per class containing ``(A_v, <A_v, A_v>)`` pairs, the inner
    product being a Python float.
  """
  Ek_trunc = []
  for entry in Ek_hosvd:
    # Move the class data to the device once instead of once per slice
    core = entry[0].to(device)
    U1 = entry[1][0].to(device)
    U2 = entry[1][1].to(device)

    A = []
    # A distinct loop name avoids shadowing, and the clamp avoids indexing past the core
    for v in range(min(k, core.shape[2])):
      basis = tl.tenalg.mode_dot(core[:, :, v], U1, 0)
      basis = tl.tenalg.mode_dot(basis, U2, 1)
      A.append((basis, tl.tenalg.inner(basis, basis, n_modes=None).item()))
    Ek_trunc.append(A)
  return Ek_trunc

Obtener el conjunto de $A_v$ para cada tensor $E_k$ cuya HOSVD ha sido calculada

In [None]:
# Expand the first 20 core slices of every class into basis matrices
Ek_trunc = hosvd_trunc(Ek_hosvd, k=20) # k is a hyper-parameter whose impact can be analyzed

Ejecutando una prueba con el modelo alcanzado para un dígito aleatorio extraído del conjunto de entrenamiento

In [None]:
def r_hosvd(Z, A):
  """Norm of what is left of ``Z`` after subtracting its component along each basis matrix.

  Args:
    Z: matrix to compare against the basis.
    A: iterable of ``(Aj, <Aj, Aj>)`` pairs.

  Returns:
    ``torch.linalg.matrix_norm`` of ``Z - sum_j (<Z, Aj> / <Aj, Aj>) Aj`` as a float.
  """
  approximation = 0
  for basis, basis_norm_sq in A:
    weight = tl.tenalg.inner(Z, basis, n_modes=None).item() / basis_norm_sq
    approximation = approximation + weight * basis
  return torch.linalg.matrix_norm(Z - approximation).item()

def predict_single_hosvd(Z, svd_):
  """Return the position of the class basis in ``svd_`` with the smallest ``r_hosvd`` residual.

  Args:
    Z: matrix to classify.
    svd_: list with one basis (as accepted by ``r_hosvd``) per class.

  Returns:
    Index (int) of the entry of ``svd_`` whose residual is minimal.
  """
  residuals = torch.zeros(len(svd_), device=Z.device)
  for position, class_basis in enumerate(svd_):
    residuals[position] = r_hosvd(Z, class_basis)
  best = torch.argmin(residuals)
  return best.item()

import random
import matplotlib.pyplot as plt

# Test on random digit from dataset
random_index = random.randint(0, len(train_set) - 1)
random_image = train_set.data[random_index]
random_label = train_set.targets[random_index]

# The image is passed as a matrix (not flattened), as stored in train_set.data
y_pred = predict_single_hosvd(random_image.to(device), Ek_trunc)

# The image is transposed only for display
plt.imshow(random_image.T, cmap='gray')
plt.show()
print('pred: ', y_pred, 'gt: ', random_label.item())

In [None]:
import time
from tabulate import tabulate

validation_set = datasets.EMNIST(root='./data', split='digits', train=False, transform=transform, download=True)

acc_values = []
time_values = []
counts_values = []

# Number of core slices to evaluate: min_lambda, min_lambda + step, ... (max_lambda excluded)
min_lambda = 10
max_lambda = 41
step=2

for k in range(min_lambda, max_lambda, step):
  acc, time_taken, counts, correct_counts = compute_accuracy_and_time_for_method(
                                                                                Ek_hosvd,
                                                                                preproc_method=hosvd_trunc,
                                                                                method=predict_single_hosvd,
                                                                                kappa=k,
                                                                                validation_set=validation_set,
                                                                                compute_confusion_matrix=False,
                                                                                do_not_flatten=True)
  acc_values.append(acc)
  time_values.append(time_taken)
  counts_values.append(correct_counts/counts)  # per-class accuracy

# Display accuracy table
header = ["λ"] + [str(i) for i in range(number_of_classes)] + ["Precisión", "Tiempo (s)"]
table = []
# enumerate keeps every row aligned with the result lists for any value of step
# (the position used to be derived with a hard-coded division by 2)
for arr_index, i in enumerate(range(min_lambda, max_lambda, step)):
  row = [i] + counts_values[arr_index].tolist() + [acc_values[arr_index], time_values[arr_index]]
  table.append(row)

print(tabulate(table, headers=header))

# Plot the accuracy vs k graph
plt.plot(range(min_lambda, max_lambda, step), acc_values)
plt.xlabel('Valor de λ')
plt.ylabel('Precisión')
plt.title('Precisión vs Valor de λ')
plt.show()

# Plot execution time vs. value of k
plt.plot(range(min_lambda, max_lambda, step), time_values)
plt.xlabel('Valor de λ')
plt.ylabel('Tiempo de ejecución (s)')
plt.title('Tiempo de ejecución vs Valor de λ')
plt.show()

In [None]:
# Plot accuracy vs k for each class
# The x values use the same `step` as the sweep that filled counts_values, so
# both axes stay the same length if the stride is changed.
per_class_accuracy = np.array(counts_values)
fig, ax = plt.subplots(figsize=(14, 8))
for i in range(number_of_classes):
  ax.plot(range(min_lambda, max_lambda, step), per_class_accuracy[:, i], label=f"Clase {i}")
ax.set_yscale('log', base=2)
ax.set_xlabel("Valor de λ")
ax.set_ylabel("Precisión")
ax.set_title("Precisión por dígito vs Valor de λ")
ax.legend()
plt.show()

# RandNLA


In [None]:
def cwt_matrix(n_rows, n_columns):
  """Sample a sparse sign matrix with exactly one non-zero entry per column.

  Every column receives a +1 or -1 (chosen at random) in a row chosen
  uniformly at random; all other entries are zero.

  Args:
    n_rows: number of rows of the matrix.
    n_columns: number of columns of the matrix.

  Returns:
    Float tensor of shape ``(n_rows, n_columns)`` placed on ``device``.
  """
  nz_positions = np.random.randint(0, n_rows, n_columns)
  values = np.random.choice([1, -1], n_columns)

  S = torch.zeros(n_rows, n_columns)
  # One indexed assignment fills every column at once, instead of one
  # element-wise write per column on the target device.
  row_index = torch.as_tensor(nz_positions, dtype=torch.long)
  column_index = torch.arange(n_columns)
  S[row_index, column_index] = torch.as_tensor(values, dtype=S.dtype)

  return S.to(device)

def clarkson_woodruff_transform(input_matrix, sketch_size):
  """Left-multiply ``input_matrix`` by a freshly sampled ``cwt_matrix``.

  Args:
    input_matrix: matrix whose rows are to be sketched.
    sketch_size: number of rows of the result.

  Returns:
    Matrix with ``sketch_size`` rows and the columns of ``input_matrix``.
  """
  n_input_rows = input_matrix.shape[0]
  sketch = cwt_matrix(sketch_size, n_input_rows)
  return torch.matmul(sketch, input_matrix)

In [None]:
def uk_prime_prepare(Uk_svd, k_t):
  """Build ``I - U_k U_k^T`` for every class and pass it through ``clarkson_woodruff_transform``.

  Args:
    Uk_svd: sequence of matrices holding singular vectors as columns.
    k_t: pair ``(k, t)``; ``k`` is the number of leading columns kept and
      ``t`` the sketch size handed to ``clarkson_woodruff_transform``.

  Returns:
    List with one sketched matrix per entry of ``Uk_svd``.
  """
  rank, sketch_rows = k_t
  sketched = []
  for basis in Uk_svd:
    leading = basis[:, :rank].to(device)
    size = leading.shape[0]
    residual_projector = torch.eye(size, size, device=device) - leading @ leading.T
    sketched.append(clarkson_woodruff_transform(residual_projector, sketch_rows))
  return sketched

In [None]:
# Left singular vectors (first SVD factor) of every class matrix
Uk_svd = [
  torch.linalg.svd(class_matrix.float().to(device), full_matrices=False)[0]
  for class_matrix in char_matrices
]

In [None]:
# k = 8 leading singular vectors per class; sketch size t = 80% of 784
Uk_prime_pre = uk_prime_prepare(Uk_svd, (8, int(784 * 0.8)))

In [None]:
def predict_single_randnla(z, Uk_pre_):
    """Classify ``z`` with the sketched per-class matrices.

    Args:
        z: 1-D tensor to classify.
        Uk_pre_: list of matrices, one per class that has samples.

    Returns:
        Class label (int) whose matrix ``A`` minimises ``||A z||_2``.
    """
    residuals = torch.zeros(len(Uk_pre_), device=z.device)

    for i, A in enumerate(Uk_pre_):
        res = torch.matmul(A, z)  # Matrix-vector product
        residuals[i] = torch.linalg.vector_norm(res, ord=2)

    # The per-class matrices are built by walking range(number_of_classes) and
    # skipping classes without samples. When nothing is skipped (digits) the
    # list position is already the label; an unconditional "+ 1" would shift
    # every prediction. The offset is derived from how many classes were skipped.
    # NOTE(review): assumes any skipped classes are the leading ones.
    label_offset = number_of_classes - len(Uk_pre_)
    return torch.argmin(residuals).item() + label_offset

import random
import matplotlib.pyplot as plt

# Test on random digit from dataset
random_index = random.randint(0, len(train_set) - 1)
random_image = train_set.data[random_index]
random_label = train_set.targets[random_index]

# The query is flattened without transposing, i.e. as stored in train_set.data
y_pred = predict_single_randnla(random_image.flatten().float().to(device), Uk_prime_pre)

# The image is transposed only for display
plt.imshow(random_image.T, cmap='gray')
plt.show()
print('pred: ', y_pred, 'gt: ', random_label.item())

In [None]:
import time
from tabulate import tabulate

# The model was fitted on the 'digits' split, so it is validated on the same
# split; other splits use label values outside range(number_of_classes).
validation_set = datasets.EMNIST(root='./data', split='digits', train=False, transform=transform, download=True)

min_lambda = 10
max_lambda = 41
step =  3
# Sketch sizes, expressed as proportions of 784
proportions = np.arange(1.0, 0.4, -0.2)
ts = [int(784 * proportion) for proportion in proportions]

acc_values = []
time_values = []
counts_values = []

for t in ts:
    acc_t_values = []
    time_t_values = []
    counts_t_values = []
    for k in range(min_lambda, max_lambda, step):
        acc, time_taken, counts, correct_counts = compute_accuracy_and_time_for_method(
                                                                                      Uk_svd,
                                                                                      preproc_method=uk_prime_prepare,
                                                                                      method=predict_single_randnla,
                                                                                      kappa=(k, t),
                                                                                      validation_set=validation_set,
                                                                                      compute_confusion_matrix=False)
        acc_t_values.append(acc)
        time_t_values.append(time_taken)
        counts_t_values.append(correct_counts / counts)
    acc_values.append(acc_t_values)
    time_values.append(time_t_values)
    counts_values.append(counts_t_values)

# Display accuracy table
header = ["λ", "t"] + [str(i) for i in range(number_of_classes)] + ["Precisión", "Tiempo (s)"]
table = []
for i in range(len(ts)):
    # Iterate over the evaluated ranks themselves so the λ column shows the
    # rank (10, 13, ...) rather than its offset from min_lambda.
    for j, rank in enumerate(range(min_lambda, max_lambda, step)):
        row = [rank, ts[i]]
        row += counts_values[i][j].tolist()
        row += [acc_values[i][j], time_values[i][j]]
        table.append(row)

print(tabulate(table, headers=header))

# Plot the accuracy vs k graph
for j in range(len(ts)):
    plt.plot(range(min_lambda, max_lambda, step), acc_values[j], label=f"t = {ts[j]}")
plt.xlabel('Valor de λ')
plt.ylabel('Precisión')
plt.title('Precisión vs Valor de λ')
plt.legend()
plt.show()

# Plot execution time vs. value of k
for j in range(len(ts)):
    plt.plot(range(min_lambda, max_lambda, step), time_values[j], label=f"t = {ts[j]}")
plt.xlabel('Valor de λ')
plt.ylabel('Tiempo de ejecución (s)')
plt.title('Tiempo de ejecución vs Valor de λ')
plt.legend()
plt.show()

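# Otros métodos

**Nota:** la celda siguiente define `compute_accuracy_and_time_for_method`, que utilizan las celdas de evaluación de los Métodos 2 y 3 y de RandNLA; debe ejecutarse antes que ellas.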


In [None]:
def compute_accuracy_and_time_for_method(data, preproc_method=None, method=None, kappa=1, validation_set=None, compute_confusion_matrix=False, do_not_flatten=False):
  """Evaluate a classifier on every sample of ``validation_set``.

  Args:
    data: model data; handed to ``preproc_method`` or, without one, to ``method``.
    preproc_method: optional callable ``(data, kappa) -> model`` run once before the evaluation.
    method: callable ``(image, model) -> predicted label``. Required.
    kappa: hyper-parameter forwarded to ``preproc_method``.
    validation_set: dataset yielding ``(image, label)`` pairs. Required.
    compute_confusion_matrix: when true, a confusion matrix is accumulated and plotted.
    do_not_flatten: when true the image is passed to ``method`` as a matrix.
      Only honoured when ``preproc_method`` produced a model; without one the
      image is always passed as a matrix.

  Returns:
    ``(accuracy, execution_time, counts, correct_counts)``: overall accuracy
    (NaN for an empty set), elapsed seconds including preprocessing, and
    per-class arrays with the number of samples and of correct predictions.

  Raises:
    ValueError: if ``method`` or ``validation_set`` is missing.
  """
  if method is None or validation_set is None:
    raise ValueError('method and validation_set are required')

  start_time = time.time()
  start_value = 0 if number_of_classes == 10 else 1

  preproc_data = None
  if preproc_method is not None:
    preproc_data = preproc_method(data, kappa)

  # Test model accuracy on validation set
  validationLoader = torch.utils.data.DataLoader(validation_set, batch_size=1, shuffle=True)

  # Test each image and compute accuracy
  total = 0
  correct = 0
  counts = np.zeros(number_of_classes, dtype=int)  # To keep count of each type of digit
  correct_counts = np.zeros(number_of_classes, dtype=int)  # To keep count of correctly classified digits

  confusion_matrix = None
  if compute_confusion_matrix:
    confusion_matrix = np.zeros((number_of_classes, number_of_classes), dtype=int)

  for image, label in tqdm(validationLoader):
    # Single-channel image of the single-sample batch, transposed before classification
    image = image[0,0,:,:]
    image = image.to(device).T.float()

    if preproc_data is not None:
      if not do_not_flatten:
        image = image.flatten()
      pred = method(image, preproc_data)
    else:
      pred = method(image, data)

    # Plain ints keep the numpy indexing below independent of tensor types and devices
    label = int(label.item())
    pred = int(pred)

    counts[label] += 1
    if pred == label:
      correct_counts[label] += 1
      correct += 1
    total += 1

    if confusion_matrix is not None:
      confusion_matrix[label, pred] += 1  # Rows: true label, columns: prediction

  accuracy = correct / total if total else float('nan')
  end_time = time.time()
  execution_time = end_time - start_time

  if confusion_matrix is not None:
    # Plot the confusion matrix
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(confusion_matrix, cmap=plt.cm.Blues)

    for i in range(start_value, confusion_matrix.shape[0]):
        for j in range(start_value, confusion_matrix.shape[1]):
            ax.text(j, i, str(confusion_matrix[i, j]), ha='center', va='center')
    ax.set_xlabel('Valor predicho')
    ax.set_ylabel('Valor real')
    plt.show()

  return (accuracy, execution_time, counts, correct_counts)