In [31]:
import numpy as np
import scipy.io, scipy.interpolate
import pathlib
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl
from pytorch_model_summary import summary
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import scipy.ndimage
import plotly.tools as tls
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import Trainer
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

L_FREQ, H_FREQ = 40, 300 # Lower and upper filtration bounds
CHANNELS_NUM = 62        # Number of channels in ECoG data
WAVELET_NUM = 40         # Number of wavelets in the indicated frequency range, with which the convolution is performed
DOWNSAMPLE_FS = 100      # Desired sampling rate
time_delay_secs = 0.2    # Time delay hyperparameter


current_fs = DOWNSAMPLE_FS

TYPE = "train"  # Script modes: "train" and "test"
model_to_test = f"{pathlib.Path().resolve()}/checkpoints/subj3_best-corr_mean_val=0.7361.ckpt"


In [32]:
class EcogFingerflexDataset(Dataset):
    """
    The class that defines the sampling unit
    """
    def __init__(self, path_to_ecog_data: str,
                 path_to_fingerflex_data: str, sample_len: int, train = False):
        """
        paths should point to .npy files
        """
        self.ecog_data, self.fingerflex_data = np.load(path_to_ecog_data).astype('float32'),\
                                            np.load(path_to_fingerflex_data).astype('float32')
        
        self.duration = self.ecog_data.shape[2]
        self.sample_len = sample_len                                 # sample size
        self.stride = 1                                              # stride between samples
        self.ds_len = (self.duration-self.sample_len) // self.stride
        self.train = train
        
        print("Duration: ", self.duration, "Ds_len:", self.ds_len)
    def __len__(self):
        return self.ds_len
    
    def __getitem__(self, index):

        sample_start = index*self.stride
        sample_end = sample_start+self.sample_len

        ecog_sample = self.ecog_data[...,sample_start:sample_end] # x
        
        fingerflex_sample = self.fingerflex_data[...,sample_start:sample_end] # y
        
        return ecog_sample, fingerflex_sample


class EcogFingerflexDatamodule(pl.LightningDataModule):
    """
    A class that encapsulates different datasets (for training and validation) and their dataloaders
    """
    def __init__(self, sample_len: int, data_dir = "./data",
                    batch_size=128, add_name=""):
        super().__init__()
        self.data_dir = data_dir     # Path to data folder
        self.sample_len = sample_len # Sample size
        self.batch_size = batch_size # Dataloader batch size
        self.add_name = add_name     #  dataset name
        
    def setup(self, stage = None):
        if stage is None or stage == "fit":
            self.train = EcogFingerflexDataset(f"{self.data_dir}/train/ecog_data{self.add_name}.npy",
                                              f"{self.data_dir}/train/fingerflex_data{self.add_name}.npy",
                                              self.sample_len, train = True)
            
            self.val = EcogFingerflexDataset(f"{self.data_dir}/val/ecog_data{self.add_name}.npy",
                                              f"{self.data_dir}/val/fingerflex_data{self.add_name}.npy",
                                              self.sample_len)
        
        if stage is None or stage == "test":
            self.test = EcogFingerflexDataset(f"{self.data_dir}/test/ecog_data{self.add_name}.npy",
                                              f"{self.data_dir}/test/fingerflex_data{self.add_name}.npy",
                                              self.sample_len)
    
    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, num_workers=4, shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)
        

In [33]:
def correlation_metric(x, y):
    """
     Cosine distance calculation metric
    """
    cos_metric = nn.CosineSimilarity(dim=-1, eps=1e-08)

    cos_sim = torch.mean(cos_metric(x, y))

    return cos_sim

def corr_metric(x, y):
    """
    Pearson correlation calculation metric between univariate vectors
    """
    assert x.shape == y.shape  
    r = np.corrcoef(x, y)[0, 1]
    return r

class BaseEcogFingerflexModel(pl.LightningModule):
    """
        The class which encapsulates the model, its optimizer and the training process at different stages, including logging
    """
    def __init__(self, model):
        super().__init__()
        self.model = model # Pytorch model
        self.lr = 8.42e-5
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        
        y_hat = self.model(x)
        
        loss = F.mse_loss(y_hat, y)
        corr = correlation_metric(y_hat, y)

        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(f"cosine_dst_train", corr, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return 0.5*loss + 0.5*(1. - corr) # возврат значения функции потерь
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.mse_loss(y_hat, y)
        
        
        corr = correlation_metric(y_hat, y)

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("cosine_dst_val", corr, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        
        return y_hat # Return the result for the validation callback
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        
        y_hat = self.model(x)
        
        loss = F.mse_loss(y_hat, y)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-6) # set optimizer, lr and L2 regularization coeff
        return optimizer


In [34]:
"""
Here are the blocks that make up the final model + the model itself
"""

class ConvBlock(nn.Module):
    """
    Convolution block:
        - 1d conv
        - layer norm by embedding axis
        - activation
        - dropout
        - Max pooling
    """
    def __init__(self, in_channels, out_channels, kernel_size, 
                 stride=1, dilation=1, p_conv_drop=0.1):
        super(ConvBlock, self).__init__()
        
        # use it instead stride. 
        
        self.conv1d = nn.Conv1d(in_channels, out_channels, 
                                kernel_size=kernel_size, 
                                bias=False, 
                                padding='same')
        
        
        self.norm = nn.LayerNorm(out_channels)
        self.activation = nn.GELU()
        self.drop = nn.Dropout(p=p_conv_drop)

        self.downsample = nn.MaxPool1d(kernel_size=stride, stride=stride)

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        
    def forward(self, x):
        
        x = self.conv1d(x)
        
        # norm by last axis.
        x = torch.transpose(x, -2, -1) 
        x = self.norm(x)
        x = torch.transpose(x, -2, -1)
        
        x = self.activation(x)
        
        x = self.drop(x)
        
        x = self.downsample(x)

        return x
    
    
class UpConvBlock(nn.Module):
    """
    Decoder convolution block
    """
    def __init__(self, scale, **args):
        super(UpConvBlock, self).__init__()
        self.conv_block = ConvBlock(**args)
        self.upsample = nn.Upsample(scale_factor=scale, mode='linear', align_corners=False)

            
    def forward(self, x):
        
        x = self.conv_block(x)
        x = self.upsample(x)
        return x    
    
    
class AutoEncoder1D(nn.Module):
    """
    This is the final Encoder-Decoder model with skip connections
    """
    def __init__(self,
                 n_electrodes=30,   # Number of channels
                 n_freqs = 16,      # Number of wavelets
                 n_channels_out=21, # Number of fingers
                 channels = [8, 16, 32, 32],  # Number of features on each encoder layer
                 kernel_sizes=[3, 3, 3],
                 strides=[4, 4, 4],
                 dilation=[1, 1, 1]
                 ):
        
        super(AutoEncoder1D, self).__init__()
        

        self.n_electrodes = n_electrodes
        self.n_freqs = n_freqs
        self.n_inp_features = n_freqs*n_electrodes
        self.n_channels_out = n_channels_out
        
        self.model_depth = len(channels)-1
        self.spatial_reduce = ConvBlock(self.n_inp_features, channels[0], kernel_size=3) # Dimensionality reduction
        
        # Encoder part
        self.downsample_blocks = nn.ModuleList([ConvBlock(channels[i], 
                                                        channels[i+1], 
                                                        kernel_sizes[i],
                                                        stride=strides[i], 
                                                        dilation=dilation[i]) for i in range(self.model_depth)])
        

        channels = [ch for ch in channels[:-1]] + channels[-1:] # channels

        # Decoder part
        self.upsample_blocks = nn.ModuleList([UpConvBlock(scale=strides[i],
                                                          in_channels=channels[i+1] if i == self.model_depth-1 else channels[i+1]*2 ,
                                                          out_channels=channels[i],
                                                          kernel_size=kernel_sizes[i]) for i in range(self.model_depth-1, -1, -1)])
        
        
        self.conv1x1_one = nn.Conv1d(channels[0]*2, self.n_channels_out, kernel_size=1, padding='same') # final 1x1 conv
      
    def forward(self, x):

        batch, elec, n_freq, time = x.shape
        x = x.reshape(batch, -1, time)  # flatten the input
        x = self.spatial_reduce(x)
        
        skip_connection = []
        
        for i in range(self.model_depth):
            skip_connection.append(x)
            x = self.downsample_blocks[i](x)

        
        for i in range(self.model_depth):
            x = self.upsample_blocks[i](x)
            x = torch.cat((x, skip_connection[-1 - i]), # skip connections
                         dim=1)
        
        x = self.conv1x1_one(x)

        return x


In [35]:
def corr_metric(x, y):
    """ Pearson’s ρ between two 1D arrays """
    xm = x - x.mean()
    ym = y - y.mean()
    return np.sum(xm*ym) / (np.sqrt(np.sum(xm**2)) * np.sqrt(np.sum(ym**2)) + 1e-8)

class ValidationCallback(Callback):
    """
    At end of each val epoch:
     - run the entire val set through model
     - gaussian‐smooth the predictions
     - compute per‐finger corr & mean
     - log mean corr & wandb image grid
    """
    def __init__(self, val_x: np.ndarray, val_y: np.ndarray, fg_num: int):
        super().__init__()
        # expect val_x, val_y shape = (T, fg_num)
        self.val_x = val_x
        self.val_y = val_y
        self.fg_num = fg_num

    def on_validation_epoch_end(self, trainer, pl_module):
        device = pl_module.device  # cpu, mps, or cuda
 
        # build a single‐batch input: shape (1, fg_num, T)
        x = torch.from_numpy(self.val_x.T).float().to(device).unsqueeze(0)
        with torch.no_grad():
            y_hat = pl_module.model(x)      # (1, fg_num, T)
        # back to numpy: (fg_num, T) → transpose to (T, fg_num)
        y_hat = y_hat.squeeze(0).cpu().numpy().T
        
        # downsample & smooth
        step = int((DOWNSAMPLE_FS / 100))
        y_hat = y_hat[::step]
        y_true = self.val_y[::step]
        y_hat = scipy.ndimage.gaussian_filter1d(y_hat, sigma=6, axis=0)

        # plot & compute corrs
        corrs = []
        rows = int(np.ceil(self.fg_num/2))
        cols = 2
        fig, axes = plt.subplots(rows, cols, figsize=(cols*4, rows*2.5), sharex=True, sharey=True)
        axes = axes.flatten()
        for i in range(self.fg_num):
            c = corr_metric(y_hat[:,i], y_true[:,i])
            corrs.append(c)
            ax = axes[i]
            ax.plot(y_hat[:,i], label='pred')
            ax.plot(y_true[:,i], label='true')
            ax.set_title(f"Finger {i} ρ={c:.2f}")
            ax.legend(fontsize='small')
        plt.tight_layout()

        mean_corr = float(np.mean(corrs))
        pl_module.log("corr_mean_val", mean_corr, prog_bar=True)
        wandb.log({
            "val_corr_mean": mean_corr,
            "val_plots": wandb.Image(fig)
        })
        plt.close(fig)


class TestCallback:
    """
    After training, call `TestCallback(...).test(pl_module)` to:
     - run val set,
     - save prediction & truth npy,
     - produce an HTML with plotly.
    """
    def __init__(self, val_x: np.ndarray, val_y: np.ndarray, fg_num: int):
        self.val_x = val_x  # shape (T, fg_num)
        self.val_y = val_y
        self.fg_num = fg_num

    def test(self, pl_module):
        device = pl_module.device

        x = torch.from_numpy(self.val_x.T).float().to(device).unsqueeze(0)
        with torch.no_grad():
            y_hat = pl_module.model(x).squeeze(0).cpu().numpy().T  # (T, fg_num)

        step = int((DOWNSAMPLE_FS / 100))
        y_hat = y_hat[::step]
        y_true = self.val_y[::step]
        y_hat = scipy.ndimage.gaussian_filter1d(y_hat, sigma=1, axis=0)

        # save arrays
        out = pathlib.Path("res_npy")
        out.mkdir(exist_ok=True)
        np.save(out/"prediction.npy", y_hat)
        np.save(out/"true.npy", y_true)

        # static MPL + interactive HTML
        rows = int(np.ceil(self.fg_num/2))
        fig, axes = plt.subplots(rows, 2, figsize=(8, rows*2.5), sharex=True, sharey=True)
        axes = axes.flatten()
        corrs = []
        for i in range(self.fg_num):
            c = corr_metric(y_hat[:,i], y_true[:,i])
            corrs.append(c)
            ax = axes[i]
            ax.plot(y_hat[:,i], label='pred')
            ax.plot(y_true[:,i], label='true')
            ax.set_title(f"Finger {i} ρ={c:.2f}")
            ax.legend(fontsize='small')
        plt.tight_layout()

        # to HTML
        plotly_fig = tls.mpl_to_plotly(fig)
        html_out = pathlib.Path("res_html")
        html_out.mkdir(exist_ok=True)
        plotly_fig.write_html(html_out/"validation.html")

        print("Test mean corr:", np.mean(corrs))
        plt.close(fig)

In [36]:


#256 -> 256/4 * (32) -> 64/4 * (64) -> 16 * (64)

SAMPLE_LEN = 256 # Window size
finger_num = 5   # Number of fingers

hp_autoencoder = dict(channels = [32, 32, 64, 64, 128, 128], 
                        kernel_sizes=[7, 7, 5, 5, 5],
                        strides=[2, 2, 2, 2, 2],
                        dilation=[1, 1, 1, 1, 1],
                        n_electrodes = CHANNELS_NUM,
                        n_freqs = WAVELET_NUM,
                        n_channels_out = finger_num) # A set of features for the model

model = AutoEncoder1D(**hp_autoencoder)


lighning_wrapper = BaseEcogFingerflexModel(model) # Wrapping in pytorch-lightning class



dm = EcogFingerflexDatamodule(sample_len=SAMPLE_LEN, add_name="")
summary(model, torch.zeros(4, CHANNELS_NUM,WAVELET_NUM, SAMPLE_LEN),
       show_input=False) # Model structure output





In [37]:
SAVE_PATH = f"./data"

def load_data(ecog_data_path, fingerflex_data_path):
    ecog_data = np.load(ecog_data_path)
    fingerflex_data = np.load(fingerflex_data_path)
    return ecog_data, fingerflex_data

ecog_data_val, fingerflex_data_val = load_data(f"{SAVE_PATH}/train/ecog_data.npy", f"{SAVE_PATH}/train/fingerflex_data.npy")

In [38]:
### TO TRAIN ###
if TYPE == "train":
    wandb.init(project="BCI_comp") # Logger initialization
    wandb_logger = WandbLogger()

    checkpoint_callback = ModelCheckpoint( # Initializing a callback to save model checkpoints
        save_top_k=2,
        monitor="corr_mean_val",
        mode="max",
        dirpath="checkpoints",
        filename="model-{epoch:02d}-{corr_mean_val}",
    )

    # The Trainer class encapsulates the interaction of model, data and logger
    trainer = Trainer(max_epochs=20, logger=wandb_logger, callbacks=[ValidationCallback(ecog_data_val,
                                                                                                  fingerflex_data_val,
                                                                                                  finger_num),
                                                                               checkpoint_callback])
    trainer.fit(lighning_wrapper, dm) # Model training process
    wandb.finish()                    # Signal to end the logging

elif TYPE == "test":
    trained_model = BaseEcogFingerflexModel.load_from_checkpoint(
        checkpoint_path=model_to_test,
        model=AutoEncoder1D(**hp_autoencoder))
    test_callback = TestCallback(ecog_data_val, fingerflex_data_val, finger_num)
    test_callback.test(trained_model) # Testing

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.


  | Name  | Type          | Params | Mode 
------------------------------------------------
0 | model | AutoEncoder1D | 652 K  | train
------------------------------------------------
652 K     Trainable params
0         Non-trainable params
652 K     Total params
2.610     Total estimated model params size (MB)
80        Modules in train mode
0         Modules in eval mode


Duration:  23980 Ds_len: 23724
Duration:  5980 Ds_len: 5724


Sanity Checking: |          | 0/? [00:00<?, ?it/s]


The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.



RuntimeError: Given groups=1, weight of size [32, 2480, 3], expected input[1, 959200, 62] to have 2480 channels, but got 959200 channels instead