# Bonus-Track Assignment 3: Sequential MNIST classification task with ESN

Solve the sequential MNIST classification problem with an ESN (see the previous lab assignment files for details on this task).

In [1]:
import os, torch
from torch import Tensor, cuda
from torch.linalg import pinv

from LAB3_2.Assignment3.TorchEchoStateNetworks import LatentESN_torch

In [2]:
# Pick the compute device once: CUDA when a GPU is visible, CPU otherwise.
gpu = 'cuda' if cuda.is_available() else 'cpu'
# Make sure the cache directory exists; exist_ok avoids the check-then-create
# race and makes re-running this cell a no-op.
os.makedirs('caches', exist_ok=True)

In [3]:
from Utils.utils import compute_acc


class sMNISTEsnClassifier(LatentESN_torch):
    """
    Echo State Network used in a classification scenario (sequential MNIST).

    The reservoir comes from ``LatentESN_torch`` and is not trained here; only
    the linear readout ``Wo`` is fitted, in closed form (ridge regression), on
    the last hidden state of each input sequence.
    """

    def __init__(self, input_size: int, hidden_dim: int, omega: float, spectral_radius: float,
                 leakage_rate: float, tikhonov: float, device: str = "cpu"):
        """
        Build the untrained reservoir and an empty readout.

        :param input_size: Input dimension
        :param hidden_dim: Hidden dimension
        :param omega: Scaling factor of input matrix and bias
        :param spectral_radius: Desired spectral radius
        :param leakage_rate: Leakage rate, forwarded unchanged to ``LatentESN_torch``
        :param tikhonov: Tikhonov regularization parameter
        :param device: Device on which the tensors are created (e.g. "cpu", "cuda")
        """
        # Latent Echo state network (untrained)
        super().__init__(input_size, hidden_dim, omega, spectral_radius, leakage_rate, device)
        # Readout (trained by fit); None means "not fitted yet"
        self.Wo = None

        self.tikhonov = tikhonov  # Tikhonov regularization
        self.device = device

    def fit(self, x: Tensor, y: Tensor = None) -> Tensor:
        """
        Fit the readout on the last hidden state of each sequence.

        The readout is the ridge-regression solution
        ``Wo = (H^T H + tikhonov * I)^-1 H^T y`` where ``H`` holds the last
        hidden states.

        :param x: Input signal, passed as ``seq`` to ``reservoir_last``
        :param y: Target signal (required, despite the ``None`` default kept
                  for backward compatibility)
        :return: Mean square error between the target and the fitted output
        :raises ValueError: if ``y`` is not provided
        """
        if y is None:
            raise ValueError("fit() requires the target tensor `y`")

        # Compute the LAST hidden states
        h_last = self.reservoir_last(seq=x)  # [batch, hidden_dim]

        # Fit the readout directly (closed form, no gradient descent)
        I = torch.eye(h_last.shape[1], device=self.device)
        self.Wo = pinv(
            h_last.T @ h_last + self.tikhonov * I) @ h_last.T @ y

        y_pred = h_last @ self.Wo
        return self.MSE(y, y_pred)

    @staticmethod
    def MSE(y: Tensor, y_pred: Tensor) -> Tensor:
        """
        Mean square error
        :param y: Target
        :param y_pred: Predicted target
        :return: Mean of the element-wise squared differences
        """
        return torch.pow((y - y_pred), 2).mean()

    def predict(self, x: Tensor, y: Tensor = None):
        """
        Perform the forward pass. If the target is provided,
        the accuracy is computed as well.

        :param x: Input signal
        :param y: Target signal (optional)
        :return: ``(acc, y_pred)`` when ``y`` is given, otherwise ``y_pred``
                 alone; ``y_pred`` holds the arg-max class index per sample and
                 ``acc`` is whatever ``compute_acc(y, y_pred)`` returns
        :raises RuntimeError: if the readout has not been fitted yet
        """
        if self.Wo is None:
            raise RuntimeError("The readout is not trained: call fit() before predict()")

        # Compute the LAST hidden states
        h_last = self.reservoir_last(seq=x)  # [batch, hidden_dim]

        # Readout scores, then arg-max over the last dimension -> class index
        y_pred = (h_last @ self.Wo).argmax(-1)  # [batch]

        # The return shape depends only on whether a target was supplied
        if y is None:
            return y_pred
        return compute_acc(y, y_pred), y_pred

    def validate(self, x: Tensor, y: Tensor) -> tuple[float, Tensor]:
        """
        Evaluate the fitted model on a labelled set.

        :param x: Input signal
        :param y: Target signal
        :return: ``(acc, y_pred)`` exactly as returned by ``predict(x, y)``
        """
        # predict returns a 2-tuple when the target is given
        acc, y_pred = self.predict(x, y)
        return acc, y_pred


In [4]:
from Utils.utils import Sequential_mnist

# Raw string: the backslashes are Windows path separators, not escape
# sequences. The runtime value is unchanged, but the invalid-escape warning
# raised by "\.", "\S" and "\M" in a normal literal goes away.
MNIST_ROOT = r".\..\..\Sources\MNIST"

# Load the three splits with one-hot encoded targets.
tr_dataset = Sequential_mnist("train", root=MNIST_ROOT, one_hot_encoding=True)
dev_dataset = Sequential_mnist("dev", root=MNIST_ROOT, one_hot_encoding=True)
ts_dataset = Sequential_mnist("test", root=MNIST_ROOT, one_hot_encoding=True)

In [5]:
# Move the training split to the device selected above (`gpu`) rather than a
# hard-coded "cuda", so the notebook also runs on CPU-only machines.
# NOTE: transpose_ swaps the first two dimensions of tr_dataset.data IN PLACE,
# so re-running this cell swaps them back.
tr_x = tr_dataset.data.transpose_(0, 1).to(gpu)
tr_y = tr_dataset.target.float().to(gpu)

In [12]:
# Build the classifier on the device selected above (`gpu`) instead of a
# hard-coded "cuda", so the cell does not crash when no GPU is available.
trainer = sMNISTEsnClassifier(1, hidden_dim=30,
                              omega=0.5,
                              spectral_radius=1.3,
                              leakage_rate=0.01,
                              tikhonov=0.000001,
                              device=gpu)
# Fit the readout; the returned training MSE is displayed as the cell output.
trainer.fit(tr_x, tr_y)

tensor(0.0760, device='cuda:0')

In [13]:
# Accuracy on the training set; the predicted labels are not needed here.
acc, _ = trainer.predict(x=tr_x, y=tr_y)
print(acc)

0.45408
