In [25]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the ASR model
class ArtifactSubspaceReconstruction(nn.Module):
    def __init__(self, input_size, subspace_size):
        super(ArtifactSubspaceReconstruction, self).__init__()
        # The projection matrix is the parameter we want to update to optimize our projection
        # We initialize the projection matrix with random values
        self.projection_matrix = nn.Parameter(torch.randn(input_size, subspace_size))
        self.loss_fn = nn.MSELoss()
        self.optimizer = optim.Adam(self.parameters(), lr=0.01)

    def forward(self, x):
        # Project the input onto the artifact subspace
        #print(x.shape)
        x_reconstructed = x @ self.projection_matrix @ self.projection_matrix.T
        return x_reconstructed
    
    def step(self, x):
        self.optimizer.zero_grad()
        # Forward pass through the model
        x_reconstructed = self.forward(x)
        # Compute loss
        loss = self.loss_fn(x_reconstructed, x)
        # Backward pass
        l = loss.backward()
        # Optimization step
        self.optimizer.step()
        # Return loss
        return loss.item()

# Force seed (for consistency, remove later to make sure it works)
torch.manual_seed(42)

# Hyperparameters
input_size = 10  # Dimensionality of the input signal
subspace_size = 5  # Dimensionality of the artifact subspace

# Instantiate the model
model = ArtifactSubspaceReconstruction(input_size, subspace_size)

# Loss function (Mean Squared Error)
LossFunction = nn.MSELoss()

# Optimizer (ADAM)
optimizer = optim.Adam(model.parameters(), lr=0.01)



# Randomly generated data
input_signal = torch.randn(100, input_size)  # 100 samples of input signals

# Training loop (can adjust as desired)
num_epochs = 1000

for epoch in range(num_epochs):
    # Forward pass through the model
    output_signal = model(input_signal)

    # Compute the loss
    loss = LossFunction(output_signal, input_signal)

    # Backward pass and optimization step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the loss every 100 epochs
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# Get the projection matrix
projection_matrix = model.projection_matrix.detach().numpy()

# Project the new signal onto the artifact subspace
new_signal = torch.randn(1, input_size)
reconstructed_signal = new_signal @ torch.tensor(projection_matrix) @ torch.tensor(projection_matrix).T

print("Original Signal:")
print(new_signal)
print("Reconstructed Signal:")
print(reconstructed_signal)

Epoch [100/1000], Loss: 4.7130
Epoch [200/1000], Loss: 1.3006
Epoch [300/1000], Loss: 0.7029
Epoch [400/1000], Loss: 0.5279
Epoch [500/1000], Loss: 0.4587
Epoch [600/1000], Loss: 0.4240
Epoch [700/1000], Loss: 0.4028
Epoch [800/1000], Loss: 0.3874
Epoch [900/1000], Loss: 0.3753
Epoch [1000/1000], Loss: 0.3654
Original Signal:
tensor([[-0.7320,  0.2601, -1.4236, -1.7884, -1.2844, -0.1398, -0.9846,  0.9026,
         -0.2177, -0.8070]])
Reconstructed Signal:
tensor([[-9.9205e-01, -8.3550e-02, -1.1132e-03, -6.8121e-01, -1.3139e+00,
         -5.9872e-01, -8.5047e-01,  1.0972e+00,  1.9519e-01, -1.0746e-01]])


In [37]:
def as_tensor(samples, labels):
    """
    Returns the data as t.tensor, with the correct data type

    Parameters:
     - samples: np.ndarray, size = [s, C, T]
     - labels:  np.ndarray, size = [s]

    Returns:
     - samples: t.tensor, size = [s, C, T], dtype = float
     - labels:  t.tensor, size = [s],       dtype = long
    """

    x = torch.tensor(samples).to(dtype=torch.float)
    y = torch.tensor(labels).to(dtype=torch.long) - 1  # labels are from 1 to 4, but torch expects 0 to 3

    return x, y

In [10]:
import numpy as np
from data_importers import get_BCIcomp_data,as_data_loader,as_tensor

In [27]:
DATA_PATH = '3class_train_f0.npz'
data = np.load('3class_train_f0.npz')

X, y = as_tensor(data['X'], data['y'])

# train_samples, train_labels = get_BCIcomp_data(1, training=True,data_path=DATA_PATH)
# test_samples, test_labels = get_BCIcomp_data(1, training=False,data_path=DATA_PATH)
#print(train_samples.size())
#X, y = as_tensor(data['X'], data['y'])
asr = ArtifactSubspaceReconstruction(X.size()[2], 30)

# Training loop (can adjust as desired)
num_epochs = 1000


for epoch in range(num_epochs):
    # Perform a step of the ASR model
    loss = asr.step(X)

    # Print the loss every 100 epochs
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss:.4f}')

# Get the projection matrix
projection_matrix = asr.projection_matrix.detach().numpy()

print("Original Signal:")
print(new_signal)
print("Reconstructed Signal:")
print(reconstructed_signal)

Epoch [100/1000], Loss: 4609126.5000
Epoch [200/1000], Loss: 2116447.5000
Epoch [300/1000], Loss: 1228864.5000
Epoch [400/1000], Loss: 797817.9375
Epoch [500/1000], Loss: 554866.6250
Epoch [600/1000], Loss: 404610.9688
Epoch [700/1000], Loss: 305518.8750
Epoch [800/1000], Loss: 236997.9688
Epoch [900/1000], Loss: 187856.2812
Epoch [1000/1000], Loss: 151576.8125


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x10 and 480x30)

In [29]:
# Project the new signal onto the artifact subspace
reconstructed_signal = X @ torch.tensor(projection_matrix) @ torch.tensor(projection_matrix).T

print("Original Signal:")
print(X)
print("Reconstructed Signal:")
print(reconstructed_signal)

Original Signal:
tensor([[[ -15.,  -28.,  -12.,  ...,  -35.,    4.,   -5.],
         [ -70.,  -61.,  -35.,  ...,  -39.,   -1.,    7.],
         [ -89.,  -77.,  -56.,  ...,  -24.,    8.,   12.],
         ...,
         [ -99.,  -94.,  -84.,  ...,   -3.,   10.,   10.],
         [-115.,  -96.,  -81.,  ...,   33.,   47.,   51.],
         [ -64.,  -52.,  -44.,  ...,  -17.,  -13.,  -15.]],

        [[ -13.,   18.,   58.,  ...,  -91.,  -92.,  -62.],
         [   5.,   32.,   55.,  ...,  -69.,  -62.,  -29.],
         [  20.,   37.,   54.,  ..., -101.,  -89.,  -59.],
         ...,
         [ -11.,  -12.,   -2.,  ...,  -31.,   -1.,   14.],
         [  39.,   39.,   47.,  ...,  -27.,   -2.,   12.],
         [ -37.,  -38.,  -28.,  ...,  -18.,    2.,   10.]],

        [[ -70.,  -73.,  -91.,  ...,  -57.,  -51.,  -40.],
         [ -25.,  -29.,  -47.,  ..., -101., -117., -117.],
         [ -55.,  -59.,  -74.,  ...,  -54.,  -72.,  -72.],
         ...,
         [  21.,   33.,   31.,  ...,  -46.,  -58.,  