# **sEMG-based Gesture Recognition with Spiking Neural Networks**

This project uses surface electromyography (sEMG) signals for hand gesture recognition with Spiking Neural Networks (SNNs).

Training and validation use the NinaPro dataset, which contains sEMG and kinematic data from several subjects performing 52 distinct hand movements. Delta modulation is applied to the raw signal and to its first and second derivatives to convert them into spike trains.
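
As a rough illustration of the delta modulation step, the sketch below mirrors the thresholding logic of the `process_data` function defined later in this notebook, but for a single channel and in pure Python. The function name `delta_encode` and the sample values are made up for the example.

```python
def delta_encode(signal, delta):
    """Encode a 1-D sequence of samples into UP and DOWN spike trains."""
    reference = signal[0]  # running reference level, starts at the first sample
    up, down = [], []
    for sample in signal:
        if sample > reference + delta:
            # The signal rose by more than delta: emit an UP spike
            up.append(1)
            down.append(0)
            reference = sample
        elif sample < reference - delta:
            # The signal fell by more than delta: emit a DOWN spike
            up.append(0)
            down.append(1)
            reference = sample
        else:
            # Change too small: no spike, reference unchanged
            up.append(0)
            down.append(0)
    return up, down


signal = [0, 5, 20, 22, 40, 10, 9, -10]  # illustrative samples, not real sEMG
up, down = delta_encode(signal, delta=15)
print("UP:  ", up)
print("DOWN:", down)
```

Each input channel therefore produces two spike channels, which is why `process_data` allocates twice as many spike columns as signal columns.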

The network is built and trained with the SLAYER library from lava-dl. The neuron model is either Leaky Integrate-and-Fire (LIF) or Current-Based (CUBA) LIF, depending on the `current_decay` parameter.
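
To make the difference between the two neuron models concrete, here is a minimal pure-Python sketch of a discrete-time current-based neuron. It assumes the commonly described CUBA update (a decaying synaptic current feeding a decaying membrane voltage, with reset on spike); it is not the SLAYER implementation. The `voltage_decay` and `threshold` values are hypothetical and do not come from this notebook.

```python
def simulate_cuba(inputs, current_decay, voltage_decay, threshold):
    """Simulate one current-based neuron and return its spike train."""
    current, voltage = 0.0, 0.0
    spikes = []
    for x in inputs:
        # Synaptic current: decays, then integrates the new input
        current = (1.0 - current_decay) * current + x
        # Membrane voltage: decays, then integrates the current
        voltage = (1.0 - voltage_decay) * voltage + current
        if voltage >= threshold:
            spikes.append(1)
            voltage = 0.0  # reset after a spike
        else:
            spikes.append(0)
    return spikes


inputs = [0.5, 0.5, 0.5, 0.0, 0.0, 0.5]  # illustrative input current

# current_decay = 1: the current has no memory, so the neuron behaves as a plain LIF
lif_spikes = simulate_cuba(inputs, current_decay=1.0, voltage_decay=0.5, threshold=0.75)

# current_decay = 0.25: the current persists across steps (CUBA behaviour)
cuba_spikes = simulate_cuba(inputs, current_decay=0.25, voltage_decay=0.5, threshold=0.75)

print("LIF: ", lif_spikes)
print("CUBA:", cuba_spikes)
```

With `current_decay` set to 1 the current state is overwritten by the input at every step, which is why the same routine covers both models in this sketch.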

## LAVA framework in Colab
Lava is an open-source software framework developed by Intel for building and simulating Spiking Neural Networks (SNNs) and other neuromorphic computing applications. It is closely tied to Intel's Loihi, a neuromorphic chip designed to mimic the functioning of biological brains, offering energy-efficient processing and real-time learning capabilities.

In [None]:
try:
  # Import google.colab to check if running in Colab
  import google.colab
  from google.colab import drive
  drive.mount('/content/gdrive/')
  !pip install /content/gdrive/MyDrive/Library/lava-dl-main.zip -q
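  import os
  # Kill the current process so Colab restarts the runtime and picks up the newly installed package
  os.kill(os.getpid(), 9)

except ImportError:
  # Not running in Colab: skip the installation
  pass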

## Dependencies and functions
This section imports the required libraries and defines the helper functions used throughout the notebook.

### Import Dependencies

In [None]:
# Standard library imports
import os
import shutil
import time
import re
import json
import hashlib

# Third-party utilities
from tqdm import tqdm
import pandas as pd
# Set Pandas display options
pd.set_option('display.max_columns', None)        # Show all columns
pd.set_option('display.max_rows', None)           # Show all rows
pd.set_option('display.width', None)              # Show full width of the DataFrame
pd.set_option('display.max_colwidth', None)       # Show full content of each cell

# Basic scientific computing libraries
import numpy as np
import matplotlib.pyplot as plt

# Advanced scientific computing and statistics
from scipy.io import loadmat
from scipy import stats
from scipy.stats import mode, skew, kurtosis

# Machine learning and data processing
from sklearn.utils import shuffle

# PyTorch for deep learning
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision

# Other libraries
import h5py

try:
  # Import google.colab to check if running in Colab
  import google.colab
  from google.colab import runtime
  colab_env = True  # Set flag to True for Colab environment
except ImportError:
  # Set flag to False if not in Colab
  colab_env = False

try:
  from google.colab import drive
  drive.mount('/content/gdrive/')
  gdrive_path = '/content/gdrive/MyDrive/'  # Define Google Drive path in Colab
except ImportError:
  # Outside Colab, fall back to paths relative to the working directory
  # so that later uses of gdrive_path do not raise a NameError
  gdrive_path = ''

# LAVA framework
try:
  import lava.lib.dl.slayer as slayer
except ImportError:
  print('lava-dl (SLAYER) is not installed')


### Function definitions

In [None]:
def get_adam_optimizer(learning_rate):
  """
  Creates and returns an Adam optimizer for a globally defined model with a specified learning rate.

  Args:
  learning_rate (float): The learning rate to be used for the optimizer.

  Globals:
  model: The globally defined model for which the optimizer is to be created.

  Returns:
  An instance of torch.optim.Adam configured with the specified learning rate.
  """
  return torch.optim.Adam(model.parameters(), lr=learning_rate)



def hash_dictionary(dictionary):
  """
  Generates a truncated MD5 hash for a given dictionary.

  This function serializes the dictionary into a JSON string, ensuring the keys are sorted,
  and then computes an MD5 hash of this string. The hash is then truncated to 8 characters.

  Parameters:
  dictionary: The dictionary to hash. It should be capable of being serialized into JSON.

  Returns:
  str: An 8-character string representing the truncated MD5 hash of the dictionary.

  Example:
  >>> hash_dictionary({'key1': 'value1', 'key2': 'value2'})
  'e201065d'
  """

  dict_string = json.dumps(dictionary, sort_keys=True)
  # Truncate to 8 characters, as documented above
  return hashlib.md5(dict_string.encode()).hexdigest()[:8]



def find_session_byhash(path, target_hash):
  """
  Finds the folder of a session from its hash.

  This function walks the subdirectories of the given path looking for a file named 'session_hash.txt'. It compares the content of each such file, which should be the session's hash, with the requested hash. On the first match it returns the name of the folder that contains the file (the last component of its path).

  Parameters:
  path (str): The base path to start searching for session directories.
  target_hash (str): The hash of the session to be found.

  Returns:
  str or None: The name of the folder containing the matching 'session_hash.txt', if found; otherwise, None.
  """

  # Walk the tree and inspect every 'session_hash.txt' file
  for root, dirs, files in os.walk(path):
    for file_name in files:
      if file_name == 'session_hash.txt':
        file_path = os.path.join(root, file_name)

        with open(file_path, 'r') as f:
          stored_hash = f.read().strip()

        if stored_hash == target_hash:
          return os.path.basename(root)

  return None



def save_json(data, file_path):
  """
  Saves a given dictionary as a JSON file at the specified file path.

  Args:
  - data (dict): The dictionary to be saved as JSON.
  - file_path (str): The file path where the JSON file will be saved.

  This function writes the dictionary as an indented JSON file at the given path.
  File-system errors (for example, a missing parent directory) are not caught and
  propagate to the caller.
  """

  with open(file_path, 'w') as file:
    json.dump(data, file, indent=4)



def load_json(file_path):
  """
  Loads data from a JSON file into a Python dictionary.

  Args:
  - file_path (str): The file path of the JSON file to be loaded.

  Returns:
  - dict: The data loaded from the JSON file.

  This function reads the JSON file at the given path into a Python dictionary.
  File-system and parsing errors are not caught and propagate to the caller.
  """

  with open(file_path, 'r') as file:
    data = json.load(file)
  return data



def normalize_json_files(path, dict_name):
  """
  Normalizes the JSON files in a given directory and compiles their contents into a Pandas DataFrame.

  This function searches the given directory tree for JSON files with the specified name. It normalizes each JSON file into a Pandas DataFrame, reads the session hash from the 'session_hash.txt' file in the folder where each file was found, and compiles all the data into a single DataFrame.

  Parameters:
  path (str): The directory path where the JSON files are located.
  dict_name (str): The name of the JSON files to search for (e.g., 'session_parameters.json').

  Returns:
  pd.DataFrame: A DataFrame containing the normalized data from each JSON file, with 'session_hash' as the first column, giving the session hash of the folder where each JSON file was found.
  """

  # Create an empty DataFrame
  aggregated_df = pd.DataFrame()

  # Iterate through all files named `dict_name` under the specified path
  for root, dirs, files in os.walk(path):
    for file_name in files:
      if file_name == dict_name:
        file_path = os.path.join(root, file_name)

        # Read and normalize the JSON file
        with open(file_path, 'r') as f:
          data = json.load(f)
        normalized_data = pd.json_normalize(data)

        # Read the session hash stored next to the JSON file
        with open(os.path.join(root, 'session_hash.txt'), 'r') as f:
          session_hash = f.read().strip()

        # Insert 'session_hash' as the first column
        normalized_data.insert(0, 'session_hash', session_hash)

        # Add the normalized data to the aggregated DataFrame
        aggregated_df = pd.concat([aggregated_df, normalized_data], ignore_index=True, sort=False)

  return aggregated_df



def print_unique_successive_elements(data_array):
  """
  Prints unique successive elements from a given array.

  Args:
  - data_array: Array from which unique successive elements are extracted.

  This function iterates through the array, identifying and collecting elements that differ
  from their immediate predecessors. It then prints the list of these unique elements.
  """

  prev = None
  unique_successive_elements = []

  for elem in data_array:
    elem = int(np.squeeze(elem))  # Convert to standard Python int
    if elem != prev:
      unique_successive_elements.append(elem)
      prev = elem

  print("Unique successive elements:", unique_successive_elements)




def load_single_file(file_path, debug=False):
  """
  Loads data from a single .mat file and returns EMG, stimulus, and repetition data.

  Args:
  - file_path: Path to the .mat file to be loaded.
  - debug: (Optional) If True, prints the keys and a sample of the file contents.

  The function extracts 'emg', 'restimulus', and 'rerepetition' arrays from the .mat file.

  Returns:
  - Tuple of Numpy arrays: EMG data, stimulus data, repetition data.
  """

  mat_data = loadmat(file_path)

  if debug:
    print(mat_data.keys())
    for key in mat_data.keys():
      if key.startswith('__'):
        continue  # Skip the metadata entries added by loadmat
      data = mat_data[key]
      sample_data = data[:5] if len(data) >= 5 else data
      print(f"Key: {key}, Length: {len(data)}, Sample Data: {sample_data}")

    print_unique_successive_elements(mat_data['stimulus'])
    print_unique_successive_elements(mat_data['rerepetition'])

  emg = np.array(mat_data['emg'])
  stimulus = np.array(mat_data['restimulus'])
  repetition = np.array(mat_data['rerepetition'])

  return emg, stimulus, repetition



def natural_sort_key(s):
  """
  Generates a sorting key for natural (human-friendly) sorting of strings.

  Args:
  - s: String to be sorted.

  The function splits the input string into a list of integers and non-integer substrings,
  facilitating natural sorting where numerical parts are sorted numerically.

  Returns:
  - List of mixed integer and string parts of the input.
  """

  return [int(text) if text.isdigit() else text.lower() for text in re.split('([0-9]+)', s)]



def load_all_data(base_folder, selected_subjects, selected_exercises):
  """
  Loads all data from a specified base folder, processing each subject's data
  and filtering based on selected exercises.

  Args:
  - base_folder: Directory containing subject folders.
  - selected_subjects: List of selected subjects for data loading.
  - selected_exercises: List of selected exercises for data loading.

  The function iterates through each subject folder, loading data for predefined exercises
  from .mat files. It handles sorting and directory traversal to compile a dictionary
  of all subjects' data, organized by subject and exercise.

  Returns:
  - Dictionary containing data for all subjects, organized by subject and exercise.
  """

  all_data = {}
  subjects = os.listdir(base_folder)
  subjects = sorted(subjects, key=natural_sort_key)
  if '.DS_Store' in subjects:
    subjects.remove('.DS_Store')

  for s, subject in enumerate(subjects):
    if not selected_subjects or s in (selected_subjects):
      subject_folder = os.path.join(base_folder, subject)
      subfolders = [d for d in os.listdir(subject_folder) if os.path.isdir(os.path.join(subject_folder, d))]
      if subfolders:
        subject_folder = os.path.join(subject_folder, subfolders[0])

      subject_data = {}

      exercises = os.listdir(subject_folder) ##
      exercises = sorted(exercises, key=natural_sort_key)
      if '.DS_Store' in exercises:
        exercises.remove('.DS_Store')
      for file_name in exercises:
        for exercise in selected_exercises:
          if exercise in file_name:
            file_path = os.path.join(subject_folder, file_name)
            emg_data, stimulus_data, repetition_data = load_single_file(file_path)
            subject_data[exercise] = {'emg': emg_data, 'stimulus': stimulus_data, 'repetition': repetition_data}

      all_data[f'{subject}'] = subject_data

  return all_data



def update_set_with_list(unique_set, new_list):
    """
    Update a given set with elements from a new list.

    This function adds each element from the provided list to the set.
    Since a set only holds unique elements, any duplicates in the list
    will not be added if they are already present in the set.

    Parameters:
    unique_set (set): The set to be updated with new elements.
    new_list (list): The list of new elements to add to the set.

    Returns:
    None: The function updates the set in-place and does not return a value.

    Example:
    >>> my_set = {1, 2, 3}
    >>> update_set_with_list(my_set, [3, 4, 5])
    >>> print(my_set)
    {1, 2, 3, 4, 5}
    """
    for element in new_list:
        unique_set.add(element)



def process_data(array, max_derivative_order, delta_value, ch_start=None, ch_end=None):
  """
  Processes an array of data by expanding it based on derivative order and generating a spike array.

  Args:
  - array: Numpy array containing the data to be processed.
  - max_derivative_order: Maximum order of derivative to be used in processing.
  - delta_value: Sequence of delta thresholds used for determining spikes, one per derivative order (index 0 applies to the raw signal).
  - ch_start: (Optional) Starting index for channel slicing.
  - ch_end: (Optional) Ending index for channel slicing.

  The function first slices the array based on provided channel indices. It then expands the array
  to include derivatives up to the specified order. A spike array is generated based on the
  delta value, where spikes represent significant changes in the data.

  Returns:
  - Tuple containing the expanded array and the spike array.
  """

  array = array[:, ch_start:ch_end]
  old_dim_size = array.shape[1]
  new_dim_size = old_dim_size * (max_derivative_order + 1)

  spike_array = np.zeros((array.shape[0] - max_derivative_order * 4, new_dim_size * 2), dtype=np.int8)
  expanded_array = np.zeros((array.shape[0], new_dim_size)) #, dtype=np.int16
  expanded_array[:, :old_dim_size] = array

  for n in range(1, max_derivative_order + 1):
    for i in range(old_dim_size):
      for j in range(array[:, i].shape[0] - n * 4):
        expanded_array[j + n * 2, old_dim_size * n + i] = - expanded_array[j, old_dim_size * (n - 1) + i] - 2 * expanded_array[j + 1, old_dim_size * (n - 1) + i] + 2 * expanded_array[j + 2, old_dim_size * (n - 1) + i] + expanded_array[j + 3, old_dim_size * (n - 1) + i]

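  # Trim the rows at both ends that lack a valid derivative estimate.
  # An explicit end index is used because a `-0` end would yield an empty slice when max_derivative_order is 0.
  trim = max_derivative_order * 2
  expanded_array = expanded_array[trim : expanded_array.shape[0] - trim, :]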

  for n in range(max_derivative_order + 1):
    for i in range(old_dim_size):
      dc_val = expanded_array[0, old_dim_size * n + i] # or = 0
      for k, j in enumerate(expanded_array[:, old_dim_size * n + i]):
        if j > dc_val + delta_value[n]:
          dc_val = j
          spike_array[k, (old_dim_size * n + i) * 2] = 1
          spike_array[k, (old_dim_size * n + i) * 2 + 1] = 0
        elif j < dc_val - delta_value[n]:
          dc_val = j
          spike_array[k, (old_dim_size * n + i) * 2] = 0
          spike_array[k, (old_dim_size * n + i) * 2 + 1] = 1
        else:
          spike_array[k, (old_dim_size * n + i) * 2] = 0
          spike_array[k, (old_dim_size * n + i) * 2 + 1] = 0

  return expanded_array, spike_array



def plot_emg(data, subject, exercise, figsize=(10, 6), alpha=0.3):
  """
  Plots EMG data for a given subject and exercise with stimulus and repetition annotations.

  Args:
  - data: Dictionary containing EMG, stimulus, and repetition data.
  - subject: The subject identifier whose data is to be plotted.
  - exercise: The exercise identifier for the specific data to plot.
  - figsize: (Optional) Size of the figure for the plot.
  - alpha: (Optional) Transparency level of the plot.

  This function visualizes EMG data for the specified subject and exercise. It includes overlay plots
  for stimulus and repetition. Stimulus data is represented as colored regions, and repetition data
  as a dashed line. The plot includes customization options for size and transparency.

  The function handles multiple channels of EMG data, assigns different colors to each, and
  represents different stimulus conditions with unique colors. It also sets up the plot with
  appropriate labels, legends, and axes.

  No return value, as the function's purpose is to display a plot.
  """

  emg_data = data[subject][exercise]['emg'][:,:16]
  stimulus = data[subject][exercise]['stimulus'].flatten()
  repetition = data[subject][exercise]['repetition'].flatten()

  # Extended color list for channels
  colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k', 'orange', 'purple', 'pink', 'brown', 'grey', 'lime', 'navy', 'teal', 'maroon']

  # Real labels for stimulus
  stimulus_real_labels = {
      0: 'Rest',
      1: 'Index flexion',
      2: 'Index extension',
      3: 'Middle flexion',
      4: 'Middle extension',
      5: 'Ring flexion',
      6: 'Ring extension',
      7: 'Little finger flexion',
      8: 'Little finger extension',
      9: 'Thumb adduction',
      10: 'Thumb abduction',
      11: 'Thumb flexion',
      12: 'Thumb extension'
  }

  # Extended color list with real labels for stimulus
  stimulus_colors = {
      0: 'white',
      1: 'green',
      2: 'red',
      3: 'blue',
      4: 'yellow',
      5: 'purple',
      6: 'orange',
      7: 'pink',
      8: 'brown',
      9: 'grey',
      10: 'lime',
      11: 'navy',
      12: 'teal'
  }

  fig, ax = plt.subplots(figsize=figsize)

  # Plot each channel
  for i in range(emg_data.shape[1]):
      ax.plot(emg_data[:, i], alpha=alpha, color=colors[i])

  # Add colored regions for stimulus labels
  legend_handles = []
  for label, label_name in stimulus_real_labels.items():
      color = stimulus_colors[label]
      ax.fill_between(np.arange(len(stimulus)), np.min(emg_data), np.max(emg_data),
                      where=(stimulus==label), facecolor=color, alpha=alpha)
      legend_handles.append(plt.Rectangle((0,0),1,1, color=color, label=label_name))

  # Add labels and legend for stimulus labels
  ax.set_xlabel('Time (samples)')
  ax.set_ylabel('Amplitude')
  ax.legend(handles=legend_handles, title='Stimulus Labels')

  # Create a secondary axis for 'repetition'
  ax2 = ax.twinx()
  ax2.plot(repetition, 'k--', label='Repetition', alpha=alpha)
  ax2.set_ylabel('Repetition')
  ax2.legend(loc='upper right')


  # Set x-axis ticks every `x_step` samples
  x_step = 1000
  x_ticks = np.arange(0, len(stimulus), x_step)
  ax.set_xticks(x_ticks)

  plt.tight_layout()
  plt.show()



def get_possible_labels(tensor):
  """
  Extracts unique elements from each row of a tensor and returns them as a list of lists.

  Args:
  - tensor: A NumPy array or tensor from which unique elements are to be extracted.

  This function iterates through each row of the provided tensor. For each row, it finds unique
  elements and appends them as a list to `label_lists`. Each sublist in `label_lists` corresponds
  to the unique elements found in each row of the tensor.

  Returns:
  - label_lists: A list of lists, where each sublist contains unique elements from the respective row of the tensor.
  """

  label_lists = []
  for i in range(tensor.shape[0]):
    label_lists.append(np.unique(tensor[i, :]).tolist())
  return label_lists



def perform_voting(tensor):
  """
  Performs a voting mechanism across rows of a tensor to determine the most frequent element (mode).

  Args:
  - tensor: A NumPy array or tensor on which voting is to be performed.

  This function iterates over each row of the given tensor. For each row, it calculates the mode,
  which is the most frequently occurring element in that row. The mode is determined using the
  `scipy.stats.mode` function. If the mode result is a scalar, it is used directly; otherwise,
  the first element of the mode result is used. This is necessary because, depending on the SciPy
  version, `scipy.stats.mode` returns either a scalar or a one-element array. In case of ties it
  returns only one of the tied values.

  The function creates an array `winners` to store the mode of each row.

  Returns:
  - winners: A NumPy array of the same length as the number of rows in `tensor`,
    containing the mode of each row.
  """
  winners = np.zeros(tensor.shape[0], dtype=np.int8)
  for i in range(tensor.shape[0]):
    mode_result = mode(tensor[i, :])
    if np.isscalar(mode_result.mode):
      winners[i] = mode_result.mode
    else:
      winners[i] = mode_result.mode[0]
  return winners



def safe_format(value, fmt=".5f"):
  """
  Safely format a given value, handling None cases.

  Args:
  value: The value to be formatted. Can be None or any value that supports formatting.
  fmt: The format string (default is ".5f").

  Returns:
  A formatted string according to 'fmt' if 'value' is not None, otherwise the placeholder string "N/D" (not available).
  """

  return f"{value:{fmt}}" if value is not None else "N/D"
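

The session hash produced by `hash_dictionary` depends only on the content of the parameter dictionary, not on the order in which its keys were written, because the JSON serialization sorts the keys. The sketch below illustrates this with a standalone helper; the name `short_hash` and the example dictionaries are made up for the illustration.

```python
import hashlib
import json


def short_hash(params, length=8):
    """Return the first `length` hex characters of the MD5 of a dictionary."""
    # sort_keys=True makes the serialization independent of key order,
    # including inside nested dictionaries
    serialized = json.dumps(params, sort_keys=True)
    return hashlib.md5(serialized.encode()).hexdigest()[:length]


# Same content, different key order
a = {'dataset': {'win_size_s': 0.5, 'win_shift_s': 0.1}, 'batch_size': 32}
b = {'batch_size': 32, 'dataset': {'win_shift_s': 0.1, 'win_size_s': 0.5}}
# One value changed
c = {'batch_size': 64, 'dataset': {'win_shift_s': 0.1, 'win_size_s': 0.5}}

print(short_hash(a) == short_hash(b))  # same parameters give the same session hash
print(short_hash(a) == short_hash(c))  # a changed parameter gives a different session hash
```

This is what lets the notebook reuse an existing output folder whenever the same `session_parameters` are submitted again.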


## Load or create session
This section of the notebook outlines the setup and initialization process for a machine learning session, focusing on configuring session parameters and managing session results.

The only editable variables in the notebook are:
- `session_info`: It contains general information about the session, including the project name, session name, and settings related to the handling of results.
- `session_parameters`: This details the specific parameters for the session, covering dataset details, neural network configurations, and training settings.
- `optimizer_functions`: A dictionary that maps optimizer names to the functions that build them.

Finally, existing sessions can be reviewed through two key files: `session_parameters.json` and `session_results.json`. These JSON files are normalized into DataFrames and then filtered by column to keep the relevant information, giving an overview of each session's parameters and outcomes.
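
Normalizing a nested session dictionary means flattening it into one row whose column names join the nested keys with a dot (for example `training.valid_loss`). The pure-Python sketch below mimics that flattening so the resulting column names are easier to anticipate; the `flatten` helper and the example values are illustrative only, and the notebook itself relies on `pd.json_normalize`.

```python
def flatten(nested, parent_key='', sep='.'):
    """Flatten nested dictionaries into a single level with dotted keys."""
    flat = {}
    for key, value in nested.items():
        full_key = f'{parent_key}{sep}{key}' if parent_key else key
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, carrying the key prefix along
            flat.update(flatten(value, full_key, sep))
        else:
            # Lists and scalars are kept as cell values
            flat[full_key] = value
    return flat


# Made-up session results, for illustration only
results = {
    'hash_name': 'abcd1234',
    'network': {'ofs': [64, 128, 64, 13]},
    'training': {'epoch_last': 42},
}

flat = flatten(results)
for column, value in flat.items():
    print(column, '->', value)
```

The dotted names are the ones listed in `columns_to_ignore` when the results table is filtered.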


### Session setup

In [None]:
session_info = {
  'project_name' : 'semg-gesturerecognition-snn', # Project name.
  'session_loadbyhash' : 'c96a2b5b',                      # Complete this field to load a session using its hash.
  'session_note' : '',               # Session note.
                                                  # Set the path to the project's output directory
  'session_output_path' : os.path.join(gdrive_path, f'Projects/semg-gesturerecognition-snn/output/'),

  #'hash_rename' : True,                          # If active, the folder containing the results will be renamed with the hash of the session_parameters dictionary.
  'overwrite' : False,                            # Flag to indicate if existing files should be overwritten.
  'github_support' : True,                        # Indicates if GitHub support is enabled.
  'repository_branch' : None,                     # Branch of the repository to use, default is 'main'.
  'lazygit_support' : False,                      # Indicates if LazyGit support is enabled.

  'dataset' : {
    'dataset_path' : 'dataset/',                  # Path to the dataset directory.
    'ninapro_db2_download' : False,               # Downloading the NinaPro DB2 dataset.
    'plot_exampledata' : False                    # Flag to indicate if example data from the dataset should be plotted.
  },

  'training' : {
    'epoch_max' : None,                           # Maximum number of training epochs. If set to None, there will be no precise number of maximum epochs.
    'bestmodel_lrchange' : True,                  # If set to True, the model will revert to the best model state when a learning rate change occurs.
    'auto_unassign' : False,                      # Flag to indicate automatic unassignment of resources post-training.
  }
}

session_parameters = {
  'dataset' : {
    'dataset_subpath' : 'ninapro/db5/',           # Subpath to the dataset directory. Currently tested {'ninapro/db1/', 'ninapro/db2/', 'ninapro/db5/'}.
    'data_freq_hz' : 200,                         # Frequency of data sampling in Hertz.
    'subjects' : [],                              # List of subject indices (zero-based positions in the naturally sorted folder list) to include in the dataset. Set an empty list to include all subjects.
    'exercises' : ['E1'],                         # List of exercise IDs to include in the dataset. Specify specific exercises or set an empty list to include all exercises.
    'channel_start' : None,                       # Starting index of channels to consider (None for the first possible channel).
    'channel_end' : None,                         # Ending index of channels to consider (None for the last possible channel).
    'max_derivative_order' : 2,                   # Maximum order of derivative to calculate for the data.
    'delta_value' : 15,                           # Value of delta for delta encoding. Set a list of values if you want a different value per derivative order (raw signal first).

    'train_repetition' : [1, 4, 6],               # Repetitions to be used for training.
    'valid_repetition' : [3],                     # Repetitions to be used for validation.
    'test_repetition' : [2, 5],                   # Repetitions to be used for testing.

    'exercise_win_ratio' : 0.5,                   # Min ratio of the window to be a specific repetition for consideration in exercise.
    'exercise_min_ratio' : 0.8,                   # Min ratio of 'exercise' labels in a window to be considered an exercise.
    'rest_min_ratio' : 1,                         # Min ratio of 'rest' labels in a window to be considered rest.

    'win_size_s' : 0.5,                           # Size of each window in seconds.
    'win_shift_s' : 0.1,                          # Shift of the sliding window in seconds.
    'start_delay_s' : 2,                          # Delay in seconds before starting data collection.
    'rest_delay_s' : 0,                           # Delay in seconds after each exercise session before starting data collection.

    'random_seed' : 0,                            # Seed for random number generation to ensure reproducibility.
  },

  'network' : {
    'current_decay' : 1,                          # Decay rate of the current; typical values are 1 for LIF and 0.25 for CUBA neuron models.
    'ofs' : [64, 128, 64, None],                  # Output feature values for network layers, the last feature output value (number of classes) is calculated automatically if set to None.
    'max_delay' : 62                              # Maximum delay allowed in the network's synaptic connections. Set a list of values if you want different values per layer.
  },

  'training' : {
    'pretrained_model' : None,                    # Set a pretrained-folder path if needed.
    'optimizer' : 'adam',                         # Optimizer to use for training, currently supported {'adam'}.
    'best_loss' : True,                           # If set to True, the best model will be selected during training based on the lowest loss on the validation set. If set to False, the selection will be based on the highest accuracy.
    'batch_size' : 32,                            # Batch size for training.
    'true_rate' : 0.2,                            # The firing rate set for positive class outputs during training. This rate is used to define the desired level of neuron activity for correct classifications.
    'false_rate' : 0.03,                          # The firing rate set for negative class outputs during training. This rate determines the neuron activity level for incorrect classifications.
    'learning_rate_first' : 0.001,                # Initial learning rate for the optimizer.
    'learning_rate_decay' : 0.1,                  # Factor by which the learning rate decays.
    'learning_rate_decay_steps' : 3,              # Number of times the learning_rate_decay is applied.
    'patience' : 30                               # Number of epochs to wait for improvement before stopping training or changing the learning rate.
  }
}

session_results = {
  'hash_name' : None,                # UNEDITABLE # Hash of the session_parameters dictionary.
  'colab_env' : None,                # UNEDITABLE # Google Colab support.
  'device' : None,                   # UNEDITABLE # Hardware support.

  'network': {
    'ofs': None                      # UNEDITABLE # Final output feature sizes determined for the network layers.
  }
}

# Dictionary serving as a repository for various optimizer functions.
optimizer_functions = {
  'adam': get_adam_optimizer
}

if not isinstance(session_parameters['dataset']['delta_value'], list):
  session_parameters['dataset']['delta_value'] = [session_parameters['dataset']['delta_value']] * (session_parameters['dataset']['max_derivative_order'] + 1) # len(session_parameters['network']['ofs'])

if not isinstance(session_parameters['network']['max_delay'], list):
  session_parameters['network']['max_delay'] = [session_parameters['network']['max_delay']] * len(session_parameters['network']['ofs'])
  session_parameters['network']['max_delay'][-1] = None

session_results['hash_name'] = hash_dictionary(session_parameters)[:8]
session_results['colab_env'] = colab_env
session_results['device'] = str(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

session_folder_name = session_results['hash_name'] # if session_info['hash_rename'] else session_info['session_name']
session_loaded = False

if session_info['overwrite']:
  # Remove any previous results for this session; a missing folder is not an error
  shutil.rmtree(os.path.join(session_info['session_output_path'], session_folder_name), ignore_errors=True)
  os.makedirs(os.path.join(session_info['session_output_path'], session_folder_name), exist_ok=True)

  print('New session overwritten:', os.path.join(session_info['session_output_path'], session_folder_name))

else:
  for session_hash in [session_info['session_loadbyhash']] if session_info['session_loadbyhash'] else [session_results['hash_name']]:
    existing_session_folder_name = find_session_byhash(session_info['session_output_path'], session_hash)
    if existing_session_folder_name:
      try:
        session_info_loaded = load_json(os.path.join(session_info['session_output_path'], existing_session_folder_name, 'session_info.json'))
        session_parameters_loaded = load_json(os.path.join(session_info['session_output_path'], existing_session_folder_name, 'session_parameters.json'))
        session_results_loaded = load_json(os.path.join(session_info['session_output_path'], existing_session_folder_name, 'session_results.json'))

        session_info_loaded['overwrite'] = session_info['overwrite']
        session_info_loaded['lazygit_support'] = session_info['lazygit_support']
        session_info_loaded['dataset']['plot_exampledata'] = session_info['dataset']['plot_exampledata']
        session_info_loaded['training']['epoch_max'] = session_info['training']['epoch_max']
        session_info_loaded['training']['auto_unassign'] = session_info['training']['auto_unassign']

        session_results_loaded['hash_name'] = session_results['hash_name']
        session_results_loaded['colab_env'] = session_results['colab_env']
        session_results_loaded['device'] = session_results['device']

        session_info = session_info_loaded
        session_parameters = session_parameters_loaded
        session_results = session_results_loaded

        session_folder_name = existing_session_folder_name

        session_loaded = True
        print('Existing session loaded:', os.path.join(session_info['session_output_path'], existing_session_folder_name))
        break
      except:
        print('Error loading session with hash:', session_hash)
        raise

if not session_loaded:
  if session_info['session_loadbyhash']:
    # Stop here: the requested session does not exist
    raise FileNotFoundError(f"No session found with hash: {session_info['session_loadbyhash']}")

  else:
    os.makedirs(os.path.join(session_info['session_output_path'], session_folder_name), exist_ok=True)
    print('New session created:', os.path.join(session_info['session_output_path'], session_folder_name))

    save_json(session_info, os.path.join(session_info['session_output_path'], session_folder_name, 'session_info.json'))
    save_json(session_parameters, os.path.join(session_info['session_output_path'], session_folder_name, 'session_parameters.json'))
    save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))

    with open(os.path.join(session_info['session_output_path'], session_folder_name, 'session_hash.txt'), 'w') as file:
      file.write(session_results['hash_name'])

try:
  # Truthiness check: a missing, None, or 0 value for epoch_last counts as not trained
  session_trained = bool(session_results['training']['epoch_last'])
except (KeyError, TypeError):
  session_trained = False

### Existing sessions

In [None]:
# Normalize the JSON files and create a DataFrame
df = normalize_json_files(session_info['session_output_path'], 'session_parameters.json')

# List of columns to ignore
columns_to_ignore = []

# Select all columns except those to ignore
columns_to_keep = [col for col in df.columns if col not in columns_to_ignore]
df_filtered = df[columns_to_keep]

# Display the filtered DataFrame
pd.DataFrame(df_filtered)

In [None]:
# Normalize the JSON files and create a DataFrame
df = normalize_json_files(session_info['session_output_path'], 'session_results.json')

# List of columns to ignore
columns_to_ignore = ['hash_name', 'colab_env', 'device', 'training.optimizer', 'training.train_loss', 'training.valid_loss', 'training.train_acc', 'training.valid_acc']

# Select all columns except those to ignore
columns_to_keep = [col for col in df.columns if col not in columns_to_ignore]
df_filtered = df[columns_to_keep]

# Display the filtered DataFrame
pd.DataFrame(df_filtered)

## Session environment
This cell is designed for setting up the development environment and repository access, particularly for sessions using GitHub and Colab.

- **Google Colab User Data:** Imports user data from Google Colab for authentication purposes.
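- **Repository Cloning:** If `session_info['github_support']` is true and the session runs in a Colab environment (`session_results['colab_env']`), the cell configures Git with the stored credentials, clones the requested branch of the repository (`main` by default), and changes into the cloned directory.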

- **Generic Setup:** For sessions not using Colab or GitHub, an `else` block is included to organize any alternative setup required.

- **LazyGit Installation:** If `session_info['lazygit_support']` is true, the cell automates the installation of LazyGit, a simple terminal UI for Git commands. It fetches the latest version, installs it, and cleans up the installation files.

In [None]:
if session_info['github_support'] and session_results['colab_env']:
  # Import userdata module from google.colab
  from google.colab import userdata

  # Remove the default sample_data directory in Colab
  %rm -rf /content/sample_data

  # Change current working directory to /content
  %cd /content

  # Repository name on GitHub
  repository = session_info['project_name']

  # Retrieve Git credentials stored in userdata
  github_json = userdata.get('github_json')
  github_json = json.loads(github_json)

  # Configure global Git settings with the retrieved credentials
  !git config --global user.name {github_json['name']}
  !git config --global user.email {github_json['email']}
  !git config --global user.password {github_json['password']}

  # Clone the GitHub repository using Git token for authentication
  !git clone -b {session_info['repository_branch'] if session_info['repository_branch'] else 'main'} https://{github_json['token']}@github.com/{github_json['name']}/{repository}

  # Change directory to the cloned repository's directory
  %cd /content/{repository}

else:
  # Change directory to main directory
  #%cd ...
  pass

if session_info['lazygit_support']:
  LAZYGIT_VERSION = !echo $(curl -s "https://api.github.com/repos/jesseduffield/lazygit/releases/latest" | grep -Po '"tag_name": "v\K[^"]*')
  LAZYGIT_SOURCE = f"https://github.com/jesseduffield/lazygit/releases/latest/download/lazygit_{LAZYGIT_VERSION[0]}_Linux_x86_64.tar.gz"
  !curl -Lo lazygit.tar.gz {LAZYGIT_SOURCE}
  !tar xf lazygit.tar.gz lazygit
  !install lazygit /usr/local/bin
  %rm lazygit.tar.gz
  %rm lazygit

### Downloading NinaPro DB2
If `ninapro_db2_download` is enabled and the session runs in Colab, this cell downloads the NinaPro DB2 dataset from Kaggle. The dataset is downloaded as a compressed archive and then extracted to `dataset/ninapro/db2` for further processing and analysis.

In [None]:
if session_results['colab_env'] and session_info['dataset']['ninapro_db2_download']:
  # userdata is only imported earlier when GitHub support is enabled, so import it here as well
  from google.colab import userdata

  # Retrieve Kaggle token from userdata and parse the JSON
  kaggle_json = userdata.get('kaggle_json')
  kaggle_json = json.loads(kaggle_json)

  # Create a kaggle.json file with the token
  with open('kaggle.json', 'w') as file:
      json.dump(kaggle_json, file, indent=4)

  # Print the fields stored in kaggle.json and their lengths (never the secret values)
  for fn in kaggle_json.keys():
      print('kaggle.json field "{name}" with length {length}'.format(
          name=fn, length=len(str(kaggle_json[fn]))))

  # Set up Kaggle API credentials file and permissions
  !mkdir -p ~/.kaggle/ && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json

  # Download the specified Kaggle dataset
  !kaggle datasets download "mengliumei/ninaprodb2"

  # Clean up by removing the kaggle.json file
  %rm kaggle.json

  # Create a directory for the dataset and unzip the downloaded dataset into it
  !mkdir -p dataset/ninapro/db2
  !unzip ninaprodb2.zip -d dataset/ninapro/db2

  # Remove the downloaded zip file after extraction
  %rm ninaprodb2.zip

## Dataset Processing

The code loads the selected NinaPro recordings, appends the signal derivatives, and converts the result to spikes by delta modulation, according to the following parameters:

- `dataset_subpath`: Specifies the particular subpath within the dataset directory.
- `data_freq_hz`: Sets the data sampling frequency.
- `subjects`: Determines the subject IDs to include in the dataset.
- `exercises`: Identifies specific exercises for focused analysis.
- `channel_start`, `channel_end`: Defines the range of channels to consider in the dataset.
- `max_derivative_order`: Establishes the highest order of derivatives to be calculated.
- `delta_value`: Determines the threshold for generating a spike array, highlighting significant data changes.
- `plot_exampledata`: Optionally activates plotting of EMG data for selected sessions and exercises.
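
To make the role of `delta_value` concrete, the following is a minimal sketch of delta modulation on a single channel. The helper name `delta_modulate` and its exact update rule are assumptions for illustration only; the notebook's own `process_data` may differ in details such as the spike encoding or how the reference level is updated.

```python
import numpy as np

def delta_modulate(signal, delta):
    """Hypothetical helper: emit +1/-1 spikes when the signal moves by `delta` from a reference level."""
    signal = np.asarray(signal, dtype=float)
    spikes = np.zeros(signal.shape, dtype=np.int8)
    reference = signal[0]  # the first sample sets the initial reference level
    for t in range(1, len(signal)):
        if signal[t] - reference >= delta:
            spikes[t] = 1        # upward change of at least delta
            reference += delta
        elif reference - signal[t] >= delta:
            spikes[t] = -1       # downward change of at least delta
            reference -= delta
    return spikes

# Toy signal, not NinaPro data
signal = np.array([0.0, 0.4, 1.1, 1.2, 0.1, -1.0])
spikes = delta_modulate(signal, delta=1.0)
print(spikes)
```

Under the same assumption, the derivative channels could be encoded by passing `np.gradient(signal)` through the same function. A smaller `delta` produces more spikes, while a larger one keeps only the larger signal changes.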

In [None]:
all_data = load_all_data(os.path.join(session_info['dataset']['dataset_path'], session_parameters['dataset']['dataset_subpath']), session_parameters['dataset']['subjects'], session_parameters['dataset']['exercises'])

# Define the max_derivative_order variable
max_derivative_order = session_parameters['dataset']['max_derivative_order']
delta_value = session_parameters['dataset']['delta_value']

# Original shapes
print('EMG signal data samples:')
print({subject: {exercise: data['emg'].shape for exercise, data in subject_data.items()} for subject, subject_data in all_data.items()}, '\n')

# Expand 'emg' arrays in 'all_data' based on max_derivative_order
for i, (subject, subject_data) in tqdm(enumerate(all_data.items()), ncols=80, desc='Data subject loading', total=len(all_data)):
  for exercise, exercise_data in subject_data.items():
    all_data[subject][exercise]['emg'], all_data[subject][exercise]['emg_spike'] = process_data(exercise_data['emg'], session_parameters['dataset']['max_derivative_order'], session_parameters['dataset']['delta_value'], ch_start=session_parameters['dataset']['channel_start'], ch_end=session_parameters['dataset']['channel_end'])

# Verify the shapes after expansion
print('\n\nEMG signal and derivative data samples:')
print({subject: {exercise: data['emg'].shape for exercise, data in subject_data.items()} for subject, subject_data in all_data.items()})

# Verify the shapes after delta modulation process
print('\nEMG spiking data samples:')
print({subject: {exercise: data['emg_spike'].shape for exercise, data in subject_data.items()} for subject, subject_data in all_data.items()})

In [None]:
if session_info['dataset']['plot_exampledata']:
  plot_emg(all_data, 's1', 'E1', figsize=(180, 8))

## Data Loader

The code splits each recording into fixed-length sliding windows and assigns every window to the training, validation, or test set. The primary settings include:

- `train_repetition`, `valid_repetition`, `test_repetition`: Designates separate repetitions for training, validation, and testing phases to ensure distinct data segmentation.
- `exercise_win_ratio`, `exercise_min_ratio`, `rest_min_ratio`: Establishes parameters for window analysis, such as the minimum ratio of a window for a specific repetition, and minimum proportions of exercise and rest labels within a window.
- `win_size_s`, `win_shift_s`: Defines the size and shift in seconds of the window for detailed data sampling.
- `start_delay_s`, `rest_delay_s`: Implements start and post-exercise session delays in seconds for data collection stabilization.
- `random_seed`: Sets a random seed to ensure reproducibility of results.
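
The windowing described above can be sketched as follows. This is a simplified illustration, not the notebook's loader: the helper names `window_starts` and `exercise_ratio` are hypothetical, the 2000 Hz sampling rate and window lengths are example values, and the rest-delay and repetition checks are left out.

```python
import numpy as np

def window_starts(n_samples, win_size, win_shift, start_delay=0):
    """Return the start index of every full window that fits in the recording."""
    starts = []
    start = start_delay
    while start + win_size <= n_samples:
        starts.append(start)
        start += win_shift
    return starts

def exercise_ratio(labels):
    """Fraction of samples in a window labeled as a movement (non-zero stimulus)."""
    labels = np.asarray(labels)
    return np.count_nonzero(labels) / labels.size

data_freq_hz = 2000                   # example sampling rate
win_size = int(0.2 * data_freq_hz)    # 0.2 s window  -> 400 samples
win_shift = int(0.1 * data_freq_hz)   # 0.1 s shift   -> 200 samples

# Toy label track: 600 samples of rest followed by 600 samples of movement 3
stimulus = np.concatenate([np.zeros(600, dtype=int), np.full(600, 3)])

starts = window_starts(len(stimulus), win_size, win_shift)
ratios = [exercise_ratio(stimulus[s:s + win_size]) for s in starts]
print(starts)
print(ratios)
```

A window would then be kept as an exercise sample when its ratio reaches `exercise_min_ratio`, and as a rest sample when the complementary ratio reaches `rest_min_ratio`.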



In [None]:
# Define which repetition numbers are used for training, validation, and testing
train_repetition = session_parameters['dataset']['train_repetition']
valid_repetition = session_parameters['dataset']['valid_repetition']
test_repetition = session_parameters['dataset']['test_repetition']

# Define the minimum ratios for considering a window as exercise or rest
exercise_win_ratio = session_parameters['dataset']['exercise_win_ratio']  # Minimum ratio of the window that must be a specific repetition for it to be considered
exercise_min_ratio = session_parameters['dataset']['exercise_min_ratio']  # Minimum ratio of 'exercise' labels in the window for it to be considered exercise
rest_min_ratio = session_parameters['dataset']['rest_min_ratio']          # Minimum ratio of 'rest' labels in the window for it to be considered rest

# Define the frequency of the data
data_freq_hz = session_parameters['dataset']['data_freq_hz']

# Define the window size, window shift, and delays in seconds
win_size_s = session_parameters['dataset']['win_size_s']        # Size of each window in seconds
win_shift_s = session_parameters['dataset']['win_shift_s']      # Shift for the sliding window in seconds
start_delay_s = session_parameters['dataset']['start_delay_s']  # Initial delay in seconds before starting the windowing
rest_delay_s = session_parameters['dataset']['rest_delay_s']    # Delay in seconds after each exercise before considering rest

# Convert the window size, window shift, and delays into samples
win_size_sample = int(win_size_s * data_freq_hz)        # Window size in samples
win_shift_sample = int(win_shift_s * data_freq_hz)      # Window shift in samples
start_delay_sample = int(start_delay_s * data_freq_hz)  # Initial delay in samples
rest_delay_sample = int(rest_delay_s * data_freq_hz)    # Delay after exercise in samples

In [None]:
# Initialize an empty set to store the classes found in the dataset
classes = set()

# Initialize empty lists to collect data for training, validation, and testing sets
x_train_orig = []
x_valid_orig = []
x_test_orig = []

x_train = []
x_valid = []
x_test = []

y_train = []
y_valid = []
y_test = []

# Loop over all subjects and exercises in the all_data dictionary
#for subject, subject_data in all_data.items():
for i, (subject, subject_data) in tqdm(enumerate(all_data.items()), ncols=80, desc='Data subject loading', total=len(all_data)):
  for exercise, exercise_data in subject_data.items():
    emg_data_orig = exercise_data['emg']
    emg_data = exercise_data['emg_spike']
    stimulus = exercise_data['stimulus'].flatten()
    repetition = exercise_data['repetition'].flatten()

    update_set_with_list(classes, stimulus)

    # Initialize starting index for windowing
    start_idx = start_delay_sample
    exercise_idx = - rest_delay_sample

    while start_idx + win_size_sample <= len(emg_data):
      # Extract the window from the EMG data, stimulus, and repetition
      win_emg_orig = emg_data_orig[start_idx:start_idx + win_size_sample, :]
      win_emg = emg_data[start_idx:start_idx + win_size_sample, :]
      win_stimulus = stimulus[start_idx:start_idx + win_size_sample]
      win_repetition = repetition[start_idx:start_idx + win_size_sample]

      # Calculate ratios for the current window
      exercise_ratio = np.sum(win_stimulus != 0) / win_size_sample
      rest_ratio = np.sum(win_stimulus == 0) / win_size_sample

      # Check if the window is mostly one of the training repetitions
      is_train_repetition = np.sum(np.isin(win_repetition, train_repetition)) >= win_size_sample * exercise_win_ratio
      is_valid_repetition = np.sum(np.isin(win_repetition, valid_repetition)) >= win_size_sample * exercise_win_ratio
      is_test_repetition = np.sum(np.isin(win_repetition, test_repetition)) >= win_size_sample * exercise_win_ratio

      # Categorize the window based on its characteristics
      if is_train_repetition and exercise_ratio >= exercise_min_ratio and start_idx % 20 == 0:
        x_train_orig.append(win_emg_orig)
        x_train.append(win_emg)
        y_train.append(win_stimulus)
        exercise_idx = start_idx

      elif is_train_repetition and rest_ratio >= rest_min_ratio and start_idx % 20 == 0:
        if exercise_idx + win_size_sample + rest_delay_sample <= start_idx:
          x_train_orig.append(win_emg_orig)
          x_train.append(win_emg)
          y_train.append(win_stimulus)

      elif is_valid_repetition and exercise_ratio >= exercise_min_ratio and start_idx % 20 == 0:
        x_valid_orig.append(win_emg_orig)
        x_valid.append(win_emg)
        y_valid.append(win_stimulus)
        exercise_idx = start_idx

      elif is_valid_repetition and rest_ratio >= rest_min_ratio and start_idx % 20 == 0:
        if exercise_idx + win_size_sample + rest_delay_sample <= start_idx:
          x_valid_orig.append(win_emg_orig)
          x_valid.append(win_emg)
          y_valid.append(win_stimulus)

      elif is_test_repetition and exercise_ratio >= exercise_min_ratio:
        x_test_orig.append(win_emg_orig)
        x_test.append(win_emg)
        y_test.append(win_stimulus)
        exercise_idx = start_idx

      elif is_test_repetition and rest_ratio >= rest_min_ratio:
        if exercise_idx + win_size_sample + rest_delay_sample <= start_idx:
          x_test_orig.append(win_emg_orig)
          x_test.append(win_emg)
          y_test.append(win_stimulus)

      # Shift the starting index for the next window
      start_idx += win_shift_sample

# Convert the original (non-spiking) window lists to numpy arrays
x_train_orig = np.array(x_train_orig)
x_valid_orig = np.array(x_valid_orig)
x_test_orig = np.array(x_test_orig)

# Convert the spiking window lists and the label lists to numpy arrays
x_train = np.array(x_train)
x_valid = np.array(x_valid)
x_test = np.array(x_test)

y_train = np.array(y_train)
y_valid = np.array(y_valid)
y_test = np.array(y_test)

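# Transpose the original windows from (window, time, channel) to (window, channel, time)
x_train_orig = np.transpose(x_train_orig, (0, 2, 1))
x_valid_orig = np.transpose(x_valid_orig, (0, 2, 1))
x_test_orig = np.transpose(x_test_orig, (0, 2, 1))

# Transpose the spiking windows from (window, time, channel) to (window, channel, time)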
x_train = np.transpose(x_train, (0, 2, 1))
x_valid = np.transpose(x_valid, (0, 2, 1))
x_test = np.transpose(x_test, (0, 2, 1))

# Check shapes of resulting datasets
x_train.shape, x_valid.shape, x_test.shape, y_train.shape, y_valid.shape, y_test.shape

In [None]:
x_test_orig

In [None]:
# Apply get_possible_labels to y_train, y_valid, and y_test
y_train_multi = get_possible_labels(y_train)
y_valid_multi = get_possible_labels(y_valid)
y_test_multi = get_possible_labels(y_test)

# Apply perform_voting to y_train, y_valid, and y_test
y_train = perform_voting(y_train)
y_valid = perform_voting(y_valid)
y_test = perform_voting(y_test)

# Check shapes of resulting datasets
x_train.shape, x_valid.shape, x_test.shape, y_train.shape, y_valid.shape, y_test.shape

In [None]:
class EMG_Dataset(Dataset):
  """
  A custom dataset class for Electromyography (EMG) data compatible with PyTorch's Dataset interface.

  This class is designed to handle datasets for EMG analysis tasks, providing necessary methods
  to integrate with PyTorch's data loading utilities.

  Args:
  - x: Input data (features), typically EMG signals in a NumPy array or similar format.
  - y: Target data (labels), usually in a NumPy array format.

  Methods:
  - __init__: Initializes the dataset with input and target data.
  - __getitem__: Retrieves a single sample (input-target pair) from the dataset at the specified index.
  - __len__: Returns the total number of samples in the dataset.
  """

  def __init__(self, x, y):
    """
    Initializes the EMG_Dataset instance with input and target data.
    """

    super(EMG_Dataset, self).__init__()
    self.input = x
    self.target = y

  def __getitem__(self, idx):
    """
    Retrieves the input-target pair at the specified index in the dataset.

    Args:
    - idx: Index of the sample to retrieve.

    Returns:
    - A tuple containing the input data (as a float32 tensor) and the target label (as a long tensor) for the specified index.
    """

    y_out = torch.tensor(self.target[idx]).long()  # Convert to long tensor
    return (
      torch.tensor(self.input[idx].astype(np.float32)),  # Convert input to float32 tensor
      y_out
    )

  def __len__(self):
    """
    Returns the total number of samples in the dataset.
    """

    return len(self.target)  # Return the length of the target data


In [None]:
# Set a random seed for reproducibility of shuffling
random_seed = session_parameters['dataset']['random_seed']

# Shuffle the training set (features and labels) with the specified random seed; the validation and test sets keep their original order
x_train_orig, _ = shuffle(x_train_orig, y_train, random_state=random_seed)
x_train, y_train = shuffle(x_train, y_train, random_state=random_seed)
#x_valid, y_valid = shuffle(x_valid, y_valid, random_state=random_seed)
#x_test, y_test = shuffle(x_test, y_test, random_state=random_seed)

if False:
  # Convert the shuffled numpy arrays to PyTorch tensors with dtype=torch.int8
  x_train_tensor = torch.tensor(x_train, dtype=torch.int8)
  y_train_tensor = torch.tensor(y_train, dtype=torch.int8)
  x_valid_tensor = torch.tensor(x_valid, dtype=torch.int8)
  y_valid_tensor = torch.tensor(y_valid, dtype=torch.int8)
  x_test_tensor = torch.tensor(x_test, dtype=torch.int8)
  y_test_tensor = torch.tensor(y_test, dtype=torch.int8)

# Create the EMG_Dataset instances
train_dataset = EMG_Dataset(x_train, y_train)
valid_dataset = EMG_Dataset(x_valid, y_valid)
test_dataset = EMG_Dataset(x_test, y_test)

# Create DataLoaders
batch_size = session_parameters['training']['batch_size']
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
debugtest_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Print out some DataLoader details to confirm
len(train_dataloader), len(valid_dataloader), len(test_dataloader)

## Network

The network is a custom PyTorch module built from SLAYER CUBA dense blocks for spike-based input. Its main parameters are:

- `current_decay`: Configurable decay rate of the current, pivotal in defining the dynamics of spiking neurons. The value is adaptable to different neuron models, like LIF and CUBA.
- `ofs`: List of output feature sizes, one per layer; the layers are built dynamically from it. If the last entry is `None`, it is set to the number of classes found in the dataset.
- `max_delay`: The maximum delay allowed in the network's synaptic connections, crucial for managing the timing of spike propagation.

After initial setup, the code includes logic to manage model loading and training control based on the session's state:

- Model Loading: The code determines whether to load an existing model or create a new one based on the session's training status. It handles scenarios of both early stopping and ongoing training.
- Training Enablement: A check is performed to decide if further training is necessary, setting the `training_enable` flag accordingly. This decision is based on the availability and status of the pre-trained or last saved model.
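
To illustrate how `current_decay` shapes the neuron dynamics, here is a simplified floating-point sketch of a single current-based neuron. It is an assumption-laden approximation for intuition only: the function name `cuba_neuron` is hypothetical, and the actual SLAYER blocks use their own implementation with trainable, quantized state.

```python
def cuba_neuron(inputs, current_decay, voltage_decay, threshold):
    """Hypothetical sketch: simulate one current-based neuron, return its spike train."""
    current, voltage = 0.0, 0.0
    spikes = []
    for x in inputs:
        # The synaptic current leaks by current_decay, then integrates the input
        current = (1.0 - current_decay) * current + x
        # The membrane voltage leaks by voltage_decay, then integrates the current
        voltage = (1.0 - voltage_decay) * voltage + current
        if voltage >= threshold:
            spikes.append(1)
            voltage = 0.0  # reset after a spike
        else:
            spikes.append(0)
    return spikes

inputs = [1.0, 0.0, 0.0, 0.0]

# Slowly decaying current: the input keeps charging the membrane after it ends
slow = cuba_neuron(inputs, current_decay=0.5, voltage_decay=0.0, threshold=1.25)

# current_decay = 1 removes the current memory, which gives LIF-like behavior
fast = cuba_neuron(inputs, current_decay=1.0, voltage_decay=0.0, threshold=1.25)

print(slow, fast)
```

In this sketch, a single input pulse below threshold still triggers a spike one step later when the current decays slowly, and never does when the current is instantaneous.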


In [None]:
class Network(torch.nn.Module):
  """
  A custom neural network module implemented using PyTorch, designed for processing spike-based data.

  This network consists of multiple dense layers with specific neuron parameters, including dropout,
  threshold, current decay, and voltage decay. It is designed for spiking neural network (SNN) applications.

  Attributes:
  - blocks (torch.nn.ModuleList): A list of dense layers constituting the network.

  Methods:
  - forward: Defines the forward pass of the network.
  - grad_flow: Monitors and visualizes the gradient flow in the network.
  - export_hdf5: Exports the network configuration and parameters to an HDF5 file.
  """

  def __init__(self):
    """
    Initializes the Network instance, setting up the layers and neuron parameters.
    """

    super(Network, self).__init__()

    neuron_params = {
        'threshold'     : 1.25,
        'current_decay' : session_parameters['network']['current_decay'],
        'voltage_decay' : 0.03,
        'tau_grad'      : 0.03,
        'scale_grad'    : 3,
        'requires_grad' : True,
      }
    neuron_params_drop = {**neuron_params, 'dropout' : slayer.neuron.Dropout(p=0.05),}

    neuron_params_1 = {
        'threshold'     : 1.25,
        'current_decay' : 0.0,
        'voltage_decay' : 0.0,
        'tau_grad'      : 0.03,
        'scale_grad'    : 3,
        'requires_grad' : True,
      }
    neuron_params_drop_1 = {**neuron_params_1, 'dropout' : slayer.neuron.Dropout(p=0.05),}

    ofs = session_results['network']['ofs']
    max_delay = session_parameters['network']['max_delay']

    layers = []
    for i in range(len(ofs)):
      # The first layer reads the input features; each later layer reads the previous layer's output
      input_size = ofs[i-1] if i > 0 else network_if
      if i == len(ofs) - 1:
        # Output layer: no dropout and no delay
        layer = slayer.block.cuba.Dense(neuron_params, input_size, ofs[i], weight_norm=True)
      else:
        # Hidden layer: dropout, plus a delay when a positive integer max_delay is configured
        use_delay = isinstance(max_delay[i], int) and max_delay[i] > 0
        layer = slayer.block.cuba.Dense(neuron_params_drop, input_size, ofs[i], weight_norm=True, delay=use_delay)
      layers.append(layer)
    self.blocks = torch.nn.ModuleList(layers)

  def forward(self, spike):
    """
    Defines the forward pass of the neural network.

    Args:
    - spike: Input spike data for the network.

    Returns:
    - Output spike after passing through the network layers.
    """

    for i in range(len(session_results['network']['ofs'])):
      spike = self.blocks[i](spike)
    return spike

  def grad_flow(self, path):
    """
    Monitors and plots the gradient flow through the network layers.

    Args:
    - path: File path where the gradient flow plot will be saved.

    Returns:
    - A list containing gradient norms for each synapse in the network layers.
    """

    grad = [b.synapse.grad_norm for b in self.blocks if hasattr(b, 'synapse')]

    plt.figure()
    plt.semilogy(grad)
    plt.savefig(path + 'gradFlow.png')
    plt.close()

    return grad

  def export_hdf5(self, filename):
    """
    Exports the network configuration and parameters to an HDF5 file.

    Args:
    - filename: Path of the HDF5 file where the network configuration will be saved.
    """

    # Use a context manager so the HDF5 file is flushed and closed after the export
    with h5py.File(filename, 'w') as h:
      layer = h.create_group('layer')
      for i, b in enumerate(self.blocks):
        b.export_hdf5(layer.create_group(f'{i}'))



network_if = x_train.shape[1]

if session_loaded and session_trained:
  if session_results['training']['early_stopped']:
    try:
      model = torch.load(os.path.join(session_info['session_output_path'], session_folder_name, 'model_best.pt'), map_location=torch.device(session_results['device']))
      print(f'Model correctly loaded: {os.path.join(session_info["session_output_path"], session_folder_name, "model_best.pt")}\nThe training will not be performed as it has already been completed in the previous training')
    except FileNotFoundError as e:
      raise FileNotFoundError(f"Model file not found in {session_info['session_output_path']}: {e}")
    except KeyError as e:
      raise KeyError(f"Missing key in session_results: {e}")
    except Exception as e:
      raise Exception(f"Error loading the model: {e}")

    training_enable = False
  else:
    try:
      model = torch.load(os.path.join(session_info['session_output_path'], session_folder_name, 'model_last.pt'), map_location=torch.device(session_results['device']))
      print(f'Model correctly loaded: {os.path.join(session_info["session_output_path"], session_folder_name, "model_last.pt")}\nThe training will be concluded')
    except FileNotFoundError as e:
      raise FileNotFoundError(f"Model file not found in {session_info['session_output_path']}: {e}")
    except KeyError as e:
      raise KeyError(f"Missing key in session_results: {e}")
    except Exception as e:
      raise Exception(f"Error loading the model: {e}")

    training_enable = True

else:
  session_results['network']['ofs'] = session_parameters['network']['ofs']
  if session_results['network']['ofs'][-1] is None:
    session_results['network']['ofs'][-1] = len(classes)
  save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))

  if session_parameters['training']['pretrained_model']:
    model = torch.load(session_parameters['training']['pretrained_model'], map_location=torch.device(session_results['device']))
    print(f'Pretrained model correctly loaded: {session_parameters["training"]["pretrained_model"]}\nA new training will be performed')
  else:
    model = Network().to(session_results['device'])
    print('Model correctly created\nA new training will be performed')

  for i, max_delay in enumerate(session_parameters['network']['max_delay']):
    if max_delay:
      model.blocks[i].delay.max_delay = max_delay

  training_enable = True

## Training

Configurations for the training process are based on the following parameters:

- `epoch_max`: Maximum number of epochs for training; 'None' for no set limit.
- `auto_unassign`: Indicates if resources should be automatically unassigned after training.
- `pretrained_model`: Path to a pre-trained model, if used.
- `optimizer`: Type of optimizer, like 'adam', for training.
- `best_loss`: Criterion for selecting the best model: if true, the lowest validation loss; otherwise, the highest validation accuracy.
- `batch_size`: Size of data batches during training.
- `true_rate` and `false_rate`: Firing rates for positive and negative class outputs in training.
- `learning_rate_first`: Initial learning rate for the optimizer.
- `learning_rate_decay` and `learning_rate_decay_steps`: Parameters for adjusting the learning rate over time.
- `patience`: Epochs to wait for improvement before altering the training approach.
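
The effect of `true_rate` and `false_rate` can be sketched with a small NumPy approximation of a spike-rate loss. This is an illustration under assumptions, not the SLAYER implementation: the helper names are hypothetical, the output is assumed to have shape `(batch, classes, time)`, and the loss is taken as the summed squared error between the mean firing rates and the target rates.

```python
import numpy as np

def spike_rate_targets(labels, n_classes, true_rate, false_rate):
    """Target firing rate per class: true_rate for the labeled class, false_rate elsewhere."""
    one_hot = np.eye(n_classes)[np.asarray(labels)]
    return true_rate * one_hot + false_rate * (1.0 - one_hot)

def spike_rate_loss(output_spikes, labels, true_rate, false_rate):
    """Summed squared error between mean output rates and the target rates."""
    rates = output_spikes.mean(axis=-1)  # (batch, classes)
    targets = spike_rate_targets(labels, output_spikes.shape[1], true_rate, false_rate)
    return float(np.sum((rates - targets) ** 2))

# One sample, two output neurons, four time steps (toy values)
output_spikes = np.array([[[1, 1, 0, 0],
                           [0, 0, 0, 0]]], dtype=float)

loss_correct = spike_rate_loss(output_spikes, [0], true_rate=0.5, false_rate=0.0)
loss_wrong = spike_rate_loss(output_spikes, [1], true_rate=0.5, false_rate=0.0)

# A rate-based classifier picks the neuron with the highest firing rate
predicted = output_spikes.mean(axis=-1).argmax(axis=1)
print(loss_correct, loss_wrong, predicted)
```

In this toy case the loss is zero when the labeled neuron fires at exactly `true_rate` and the other stays at `false_rate`, and it grows when the rates are swapped.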

### Training settings

In [None]:
if training_enable:
  error = slayer.loss.SpikeRate(true_rate=session_parameters['training']['true_rate'], false_rate=session_parameters['training']['false_rate'], reduction='sum').to(session_results['device'])
  stats = slayer.utils.LearningStats()

  if session_loaded and session_trained:
    if session_parameters['training']['optimizer'] in optimizer_functions:
      optimizer = optimizer_functions[session_parameters['training']['optimizer']](session_results['training']['learning_rate'][-1])
    else:
      raise ValueError(f"Optimizer '{session_parameters['training']['optimizer']}' not supported")

    assistant = slayer.utils.Assistant(model, error, optimizer, stats, classifier=slayer.classifier.Rate.predict)
    stats.training.min_loss = session_results['training']['train_loss_min']
    stats.testing.min_loss = session_results['training']['valid_loss_min']
    stats.training.max_accuracy = session_results['training']['train_accuracy_max']
    stats.testing.max_accuracy = session_results['training']['valid_accuracy_max']

    learning_rate = session_results['training']['learning_rate'][-1]

  else:
    if session_parameters['training']['optimizer'] in optimizer_functions:
      optimizer = optimizer_functions[session_parameters['training']['optimizer']](session_parameters['training']['learning_rate_first'])
    else:
      raise ValueError(f"Optimizer '{session_parameters['training']['optimizer']}' not supported")

    assistant = slayer.utils.Assistant(model, error, optimizer, stats, classifier=slayer.classifier.Rate.predict)

    session_results['training'] = {
      'optimizer' : None,              # UNEDITABLE # Training optimizer.
      'done': False,                   # UNEDITABLE # Flag to indicate if the training process is complete.
      'early_stopped': False,          # UNEDITABLE # Indicates if training was stopped early due to lack of improvement.
      'epoch_best': None,              # UNEDITABLE # Epoch number where the best performance was achieved.
      'epoch_last': None,              # UNEDITABLE # Last epoch number before training concluded.
      'epoch_no_imporve': None,        # UNEDITABLE # Number of consecutive epochs without improvement on the validation set.
      'train_loss_min': None,          # UNEDITABLE # Best training loss achieved.
      'valid_loss_min': None,          # UNEDITABLE # Best validation loss achieved.
      'train_accuracy_max': None,      # UNEDITABLE # Best training accuracy achieved.
      'valid_accuracy_max': None,      # UNEDITABLE # Best validation accuracy achieved.
      'learning_rate_best': None,      # UNEDITABLE # Best learning rate achieved during training.
      'train_loss_best': None,         # UNEDITABLE # Best training loss achieved in the model_best.pt.
      'valid_loss_best': None,         # UNEDITABLE # Best validation loss achieved in the model_best.pt.
      'train_accuracy_best': None,     # UNEDITABLE # Best training accuracy achieved in the model_best.pt.
      'valid_accuracy_best': None,     # UNEDITABLE # Best validation accuracy achieved in the model_best.pt.
      'learning_rate': [],             # UNEDITABLE # History of learning rates used throughout training.
      'train_loss': [],                # UNEDITABLE # History of training losses per epoch.
      'valid_loss': [],                # UNEDITABLE # History of validation losses per epoch.
      'train_acc': [],                 # UNEDITABLE # History of training accuracies per epoch.
      'valid_acc': []                  # UNEDITABLE # History of validation accuracies per epoch.
    }

    session_results['post_training'] = {
      'test_acc': None,                # UNEDITABLE # Test accuracy achieved after training is complete.
      'test_acc_fakequant': None       # UNEDITABLE # Test accuracy achieved with fake quantization after training.
    }

    # Configuring and saving session information
    session_results['training']['optimizer'] = str(optimizer).replace("\n","").replace("    ",", ")
    session_results['training']['epoch_no_imporve'] = 0

    save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))

    # Setting learning rate parameter
    learning_rate = session_parameters['training']['learning_rate_first']

### Training loop

In [None]:
if training_enable:
  while True:
    if session_info['training']['epoch_max'] and session_results['training']['epoch_last'] and (session_results['training']['epoch_last'] + 1) >= session_info['training']['epoch_max']:
      break

    # training loop
    for i, (input, label) in enumerate(train_dataloader):
      output = assistant.train(input, label)

    # validation loop
    for i, (input, label) in enumerate(valid_dataloader):
      output = assistant.test(input, label)

    if session_results['training']['epoch_last'] is None:
      session_results['training']['epoch_last'] = 0
    else:
      session_results['training']['epoch_last'] += 1

    # session checkpoint
    torch.save(model, os.path.join(session_info['session_output_path'], session_folder_name, 'model_last.pt'))

    session_results['training']['learning_rate'].append(learning_rate)
    session_results['training']['train_loss'].append(stats.training.loss)
    session_results['training']['valid_loss'].append(stats.testing.loss)
    session_results['training']['train_acc'].append(stats.training.accuracy)
    session_results['training']['valid_acc'].append(stats.testing.accuracy)
    session_results['training']['train_loss_min'] = stats.training.min_loss
    session_results['training']['valid_loss_min'] = stats.testing.min_loss
    session_results['training']['train_accuracy_max'] = stats.training.max_accuracy
    session_results['training']['valid_accuracy_max'] = stats.testing.max_accuracy

    save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))

    print(
      f'\rE {session_results["training"]["epoch_last"]}.   '                                        # Epoch
      f'TL: {safe_format(stats.training.loss)} ({safe_format(stats.training.min_loss)}), '          # Training loss
      f'VL: {safe_format(stats.testing.loss)} ({safe_format(stats.testing.min_loss)}).   '          # Validation loss
      f'TA: {safe_format(stats.training.accuracy)} ({safe_format(stats.training.max_accuracy)}), '  # Training accuracy
      f'VA: {safe_format(stats.testing.accuracy)} ({safe_format(stats.testing.max_accuracy)}).   ', # Validation accuracy
      end=''
    )

    # check best model
    if (session_parameters['training']['best_loss'] and stats.testing.best_loss) or (not session_parameters['training']['best_loss'] and stats.testing.best_accuracy):
      torch.save(model, os.path.join(session_info['session_output_path'], session_folder_name, 'model_best.pt'))

      session_results['training']['learning_rate_best'] = learning_rate
      session_results['training']['train_loss_best'] = stats.training.loss
      session_results['training']['valid_loss_best'] = stats.testing.loss
      session_results['training']['train_accuracy_best'] = stats.training.accuracy
      session_results['training']['valid_accuracy_best'] = stats.testing.accuracy
      session_results['training']['epoch_best'] = session_results['training']['epoch_last']

      save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))

      session_results['training']['epoch_no_imporve'] = 0

    else:
      session_results['training']['epoch_no_imporve'] += 1

      if session_results['training']['epoch_no_imporve'] >= session_parameters['training']['patience']:
        # Decay the learning rate while it is still at or above
        # learning_rate_first * learning_rate_decay ** learning_rate_decay_steps; otherwise stop early.
        if learning_rate >= round((session_parameters['training']['learning_rate_decay'] ** session_parameters['training']['learning_rate_decay_steps']) * session_parameters['training']['learning_rate_first'], 16):
          learning_rate *= session_parameters['training']['learning_rate_decay']

          if session_info['training']['bestmodel_lrchange']:
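            try:
              # Copy the best weights into the existing model object instead of rebinding `model`,
              # so the optimizer (and anything else holding a reference to the model) keeps training these parameters.
              best_model = torch.load(os.path.join(session_info['session_output_path'], session_folder_name, 'model_best.pt'), map_location=torch.device(session_results['device']))
              model.load_state_dict(best_model.state_dict())
            except Exception:
              # No usable best checkpoint yet: continue with the current weights.
              pass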

          save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))
          for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate
          session_results['training']['epoch_no_imporve'] = 0
          print(f'Reducing learning rate: {learning_rate}')
        else:
          session_results['training']['early_stopped'] = True
          save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))
          print(f'Early stopping triggered after {session_results["training"]["epoch_last"]+1} epochs!')
          break
      else:
        print(f'Patience: {session_results["training"]["epoch_no_imporve"]}.', end='')

    stats.update()

  session_results['training']['done'] = True
  save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))

  if session_info['training']['auto_unassign']:
    time.sleep(30)
    runtime.unassign()
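
Once training has finished, the per-epoch histories stored under `session_results['training']` can be inspected offline. The snippet below is a minimal sketch that recovers the best validation epoch from those lists. It assumes only the `valid_loss` and `valid_acc` history keys written by the loop above. The helper `best_epoch` is a hypothetical name, and its result need not coincide with `epoch_best` if the statistics object defines "best" differently.

```python
def best_epoch(training, use_loss=True):
    """Return (epoch_index, value) of the best validation epoch.

    `training` is a dict shaped like session_results['training'].
    With use_loss=True the lowest validation loss wins; otherwise the
    highest validation accuracy wins. Ties resolve to the earliest epoch.
    Returns None when the chosen history is empty.
    """
    values = training["valid_loss"] if use_loss else training["valid_acc"]
    if not values:
        return None
    pick = min if use_loss else max
    index = pick(range(len(values)), key=values.__getitem__)
    return index, values[index]


if __name__ == "__main__":
    # Illustrative numbers only, not results from the notebook.
    example = {"valid_loss": [0.9, 0.5, 0.7], "valid_acc": [0.4, 0.8, 0.8]}
    print(best_epoch(example, use_loss=True))
    print(best_epoch(example, use_loss=False))
```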