# **sEMG-force-based Gesture Recognition with Spiking Neural Networks**

## Dependencies and functions

### Import Dependencies

In [None]:
try:
  # google.colab is only importable inside a Colab runtime, so a successful
  # import is used as the "running in Colab" signal.
  import google.colab
  from google.colab import runtime
  colab_env = True  # Running in Colab
except ImportError:
  # Not in Colab (ModuleNotFoundError is a subclass of ImportError). Catching only
  # import errors keeps KeyboardInterrupt/SystemExit and unrelated bugs visible.
  colab_env = False

# Standard library imports
import os
import sys
import shutil
import time
import re
import json
import hashlib
import natsort
from datetime import datetime
from natsort import natsorted
from tqdm import tqdm
import pandas as pd
# Pandas display options: never truncate columns, rows, console width or cell
# contents when printing DataFrames (None removes each limit).
for _pd_option in ('display.max_columns', 'display.max_rows', 'display.width', 'display.max_colwidth'):
  pd.set_option(_pd_option, None)
del _pd_option

# Basic scientific computing libraries
import numpy as np
import matplotlib.pyplot as plt

# Advanced scientific computing and statistics
from scipy.signal import butter, filtfilt, resample

# Machine learning and data processing
from sklearn.utils import shuffle

# PyTorch for deep learning
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
if colab_env:
  !pip install snntorch &> /dev/null
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils

try:
  from google.colab import drive
except ImportError:
  pass  # Not running in Colab: there is no Google Drive to mount
else:
  try:
    drive.mount('/content/gdrive/')
    gdrive_path = '/content/gdrive/MyDrive/'  # Google Drive root once mounted
  except Exception as err:
    # Mounting stays best-effort (the notebook can continue without Drive),
    # but the failure is reported instead of being silently swallowed.
    print(f"Google Drive could not be mounted: {err}")

if colab_env:
  !pip install wfdb &> /dev/null
import wfdb

### Function definitions

In [None]:
def get_adam_optimizer(learning_rate, net=None):
  """
  Creates and returns an Adam optimizer over a model's parameters.

  Args:
  learning_rate (float): The learning rate to be used for the optimizer.
  net (torch.nn.Module, optional): Model whose parameters are optimized. When omitted,
    the globally defined `model` is used.

  Globals:
  model: Read only when `net` is None.

  Returns:
  An instance of torch.optim.Adam configured with the specified learning rate.
  """
  if net is None:
    net = model  # fall back to the notebook-level model
  return torch.optim.Adam(net.parameters(), lr=learning_rate)



def hash_dictionary(dictionary, length=None):
  """
  Generates an MD5 hex digest for a given dictionary.

  The dictionary is serialized to JSON with sorted keys, so dictionaries with the same
  content hash identically regardless of key insertion order.

  Parameters:
  dictionary: The dictionary to hash. It must be JSON-serializable.
  length (int, optional): If given, only the first `length` hex characters are returned.
    Defaults to None, which returns the full 32-character digest.

  Returns:
  str: The MD5 hex digest (32 characters), or its first `length` characters.
  """
  dict_string = json.dumps(dictionary, sort_keys=True)
  digest = hashlib.md5(dict_string.encode()).hexdigest()
  return digest if length is None else digest[:length]



def find_session_byhash(path, hash):
  """
  Locate the session folder whose 'session_hash.txt' contains the given hash.

  Every directory under `path` (walked with os.walk, `path` itself included) is checked for a
  file named 'session_hash.txt'. Its whitespace-stripped content is compared with `hash`. On the
  first match, the name of that directory (its last path component, not the full path) is returned.

  Parameters:
  path (str): Root directory to search.
  hash (str): Session hash to look for.

  Returns:
  str or None: Name of the matching directory, or None when no directory matches.
  """
  marker_name = 'session_hash.txt'
  for directory, _subdirs, filenames in os.walk(path):
    if marker_name not in filenames:
      continue
    with open(os.path.join(directory, marker_name), 'r') as handle:
      stored_hash = handle.read().strip()
    if stored_hash == hash:
      return os.path.basename(directory)
  return None



def save_json(data, file_path, savedatetime=True):
  """
  Write a dictionary to `file_path` as JSON indented by 4 spaces.

  Args:
  - data (dict): Dictionary to serialize. When `savedatetime` is True and `file_path` contains
    'session_results.json', a 'datetime' key with the current local time
    ("%Y-%m-%d %H:%M:%S") is added to `data` IN PLACE before writing.
  - file_path (str): Destination path; an existing file is overwritten.
  - savedatetime (bool, optional): Enables the timestamp described above. Defaults to True.

  Errors from opening/writing the file or from JSON serialization are not caught.
  """
  stamp_results = savedatetime and 'session_results.json' in file_path
  if stamp_results:
    data['datetime'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  with open(file_path, 'w') as out_file:
    json.dump(data, out_file, indent=4)



def load_json(file_path):
  """
  Read a JSON file and return its parsed content.

  Args:
  - file_path (str): Path of the JSON file to read.

  Returns:
  - The deserialized JSON value (a dict when the file holds a JSON object).

  Errors from opening the file or parsing its content propagate to the caller.
  """
  with open(file_path, 'r') as in_file:
    return json.load(in_file)



def normalize_json_files(path, dict_name):
  """
  Normalize the JSON files found under a directory and gather them into one pandas DataFrame.

  Every file named `dict_name` below `path` (walked with os.walk) is flattened with
  pd.json_normalize. The session hash stored in the 'session_hash.txt' file of the same
  folder is inserted as the first column, and all rows are concatenated.

  Parameters:
  path (str): Root directory to search.
  dict_name (str): File name of the JSON files to look for (e.g. 'session_parameters.json').

  Returns:
  pd.DataFrame: The rows produced by each matching JSON file, with 'session_hash' as the
    first column. An empty DataFrame is returned when no file matches.

  Raises:
  FileNotFoundError: if a folder containing `dict_name` has no 'session_hash.txt'.
  """
  frames = []
  for root, _dirs, files in os.walk(path):
    if dict_name not in files:
      continue

    # Read and flatten the JSON file (nested keys become dotted column names)
    with open(os.path.join(root, dict_name), 'r') as json_file:
      normalized_data = pd.json_normalize(json.load(json_file))

    with open(os.path.join(root, 'session_hash.txt'), 'r') as hash_file:
      session_hash = hash_file.read().strip()

    # Insert 'session_hash' as the first column
    normalized_data.insert(0, 'session_hash', session_hash)
    frames.append(normalized_data)

  if not frames:
    return pd.DataFrame()
  # A single concat avoids the quadratic cost of growing a DataFrame inside the loop
  return pd.concat(frames, ignore_index=True, sort=False)



def get_sorted_dbfiles(directory):
  """
  Collect every file under `directory` (recursively) as a path with its extension removed,
  drop duplicates, and return the result in natural sort order.

  Files that share a stem in the same folder (e.g. 'rec.hea' and 'rec.dat') collapse into a
  single entry ('<folder>/rec').

  Args:
  directory (str): Root directory to scan.

  Returns:
  list: Naturally sorted, de-duplicated paths (folder joined with file name) without extensions.
  """
  stems = {
    os.path.join(root, os.path.splitext(name)[0])
    for root, _dirs, names in os.walk(directory)
    for name in names
  }
  return natsort.natsorted(stems)



def extract_file_info(string):
  """
  Parse subject, session, dataset type, data type, object number and sample number from a
  dataset file path.

  The path must contain
  'subject<N>_session<N>/<dataset_type>_<data_type>_<label><object>_sample<N>'
  (e.g. 'subject01_session1/1dof_raw_finger1_sample2'). Anything after the sample number,
  such as a file extension, is ignored.

  :param string: The input string (usually a file path) to parse.
  :return: A dict with integer 'subject', 'session', 'object' and 'sample', string
    'dataset_type' and 'data_type', and 'modified_string'. 'modified_string' is the input
    with only the data_type field replaced by 'datatype', so files that differ only in
    data_type map to the same string. Returns the string "No match found" when the pattern
    is absent.
  """
  # `\w*?` is non-greedy so that a multi-digit object number (e.g. 'combination12') is
  # captured whole; a greedy `\w*` would leave only its last digit to `(\d+)`.
  pattern = r"subject(\d+)_session(\d+)/(\w+)_([a-z]+)_\w*?(\d+)_sample(\d+)"

  match = re.search(pattern, string)
  if not match:
    return "No match found"

  # Replace only the matched data_type field, not every '_<lowercase>_' run in the path
  start, end = match.span(4)
  modified_string = string[:start] + 'datatype' + string[end:]

  return {
    "subject": int(match.group(1)),
    "session": int(match.group(2)),
    "dataset_type": match.group(3),
    "data_type": match.group(4),
    "object": int(match.group(5)),
    "sample": int(match.group(6)),
    "modified_string": modified_string
    }



def signal_downsample(signal, orig_fs, target_fs):
  """
  Downsample a multi-channel signal from `orig_fs` to `target_fs`.

  Each channel is low-pass filtered with a zero-phase (filtfilt) 5th-order Butterworth filter
  whose cutoff is the target Nyquist frequency (target_fs / 2). It is then FFT-resampled with
  scipy.signal.resample to int(n_samples / orig_fs * target_fs) samples.

  :param signal: 2D numpy array; rows are time samples, columns are channels.
  :param orig_fs: Original sampling frequency in Hz.
  :param target_fs: Target sampling frequency in Hz. It must be below orig_fs, otherwise
    scipy's butter rejects the cutoff.
  :return: float64 array of shape (int(duration * target_fs), n_channels).
  """
  # Output length from the signal duration, so the time span is preserved
  n_out = int(signal.shape[0] / orig_fs * target_fs)

  # Anti-aliasing low-pass at the target Nyquist frequency
  b, a = butter(N=5, Wn=target_fs / 2, btype='low', fs=orig_fs)

  out = np.zeros((n_out, signal.shape[1]))
  for ch, column in enumerate(signal.T):
    out[:, ch] = resample(filtfilt(b, a, column), n_out)
  return out



def signal_upsample(signal, orig_fs, target_fs):
  """
  Upsample a multi-channel signal from `orig_fs` to `target_fs`.

  Each channel is first FFT-resampled with scipy.signal.resample to
  int(n_samples / orig_fs * target_fs) samples. It is then smoothed with a zero-phase
  (filtfilt) 5th-order Butterworth low-pass whose cutoff is the original Nyquist frequency
  (orig_fs / 2).

  :param signal: 2D numpy array; rows are time samples, columns are channels.
  :param orig_fs: Original sampling frequency in Hz.
  :param target_fs: Target sampling frequency in Hz. It must be above orig_fs, otherwise
    scipy's butter rejects the cutoff.
  :return: float64 array of shape (int(duration * target_fs), n_channels).
  """
  # Output length from the signal duration, so the time span is preserved
  n_out = int(signal.shape[0] / orig_fs * target_fs)

  # Low-pass at the original Nyquist frequency, designed at the new sampling rate
  b, a = butter(N=5, Wn=orig_fs / 2, btype='low', fs=target_fs)

  out = np.zeros((n_out, signal.shape[1]))
  for ch, column in enumerate(signal.T):
    out[:, ch] = filtfilt(b, a, resample(column, n_out))
  return out



def lowpass_filter(signal, cutoff_frequency, sampling_rate, order=5):
  """
  Applies a zero-phase (filtfilt) low-pass Butterworth filter to each channel of a signal.

  Parameters:
  - signal (numpy.ndarray): 2D input array, one channel per column.
  - cutoff_frequency (int or float): The cutoff frequency of the low-pass filter in Hz.
  - sampling_rate (int or float): The sampling rate of the signal in Hz.
  - order (int, optional): The order of the Butterworth filter. Default is 5.

  Returns:
  - numpy.ndarray: The filtered signal, with the same shape as the input signal.
    Floating-point and complex inputs keep their dtype. Integer (or other non-inexact)
    inputs yield float64, so filtered values are not truncated.
  """
  b, a = butter(N=order, Wn=cutoff_frequency, btype='low', fs=sampling_rate)

  # np.zeros_like(signal) would inherit an integer dtype and silently truncate the filter output
  out_dtype = signal.dtype if np.issubdtype(signal.dtype, np.inexact) else np.float64
  filtered_signal = np.zeros(signal.shape, dtype=out_dtype)
  for i in range(signal.shape[1]):
    filtered_signal[:, i] = filtfilt(b, a, signal[:, i])

  return filtered_signal



def process_data(array, max_derivative_order, delta_value, zc_enable=False, zc_max_derivative_order = 0, zc_value = None): # , ch_start=None, ch_end=None
  """
  Compute finite-difference derivatives of each input column and encode them as spike trains.

  For every column, derivatives up to the needed order are computed with the 4-tap kernel
  [-1, -2, +2, +1]. Each order is computed from the previous one. Each derivative signal up to
  `max_derivative_order` is then delta-encoded into two spike channels (UP / DOWN). Optionally,
  threshold-crossing ("zero crossing") spikes are added for derivative orders up to
  `zc_max_derivative_order`.

  Args:
    array (np.ndarray): 2D array; rows are time samples, columns are variables/channels.
    max_derivative_order (int): Highest derivative order to delta-encode (0 = the input only).
    delta_value (list or np.ndarray): Delta thresholds indexed by derivative order. It needs
      at least `max_derivative_order + 1` entries.
    zc_enable (bool, optional): Enables zero-crossing spikes. Defaults to False.
    zc_max_derivative_order (int, optional): Highest derivative order used for zero-crossing
      spikes. Only used when `zc_enable` is True.
    zc_value (list or np.ndarray, optional): Half-width of the dead band around zero, indexed
      by derivative order. It needs at least `zc_max_derivative_order + 1` entries. Only used
      when `zc_enable` is True.

  Returns:
    tuple (derivative_array, spike_array). Let C be the number of input columns, T the number
    of input rows, G the `global_max_derivative_order` (the max of `max_derivative_order` and
    `zc_max_derivative_order` when zero crossing is enabled, otherwise `max_derivative_order`),
    and T' = T - 4 * G.
      - derivative_array: float array of shape (T', C * (G + 1)). Columns [n*C : (n+1)*C] hold
        the n-th derivative of the C input columns.
      - spike_array: int8 array with T' rows. Column 2*(n*C + i) is the UP spike and
        2*(n*C + i) + 1 the DOWN spike of derivative n of input column i. When zero crossing
        is enabled, C * (zc_max_derivative_order + 1) zero-crossing columns follow after all
        delta columns (column offset C * (max_derivative_order + 1) * 2 + n*C + i).

  Note:
    Computing order n fills rows 2n .. T-2n-1 only, so both arrays keep the T - 4G rows that
    the highest order fills.
  """

  old_dim_size = array.shape[1]  # C: number of input columns
  # Highest derivative order that must be computed (zero crossing may need more orders than delta encoding)
  global_max_derivative_order = max(max_derivative_order, zc_max_derivative_order) if zc_enable else max_derivative_order

  # Rows: 2 samples lost at each end per derivative order. Columns: an UP/DOWN pair per
  # (order, input column) for delta encoding, then one column per (order, input column) for zero crossings.
  spike_array = np.zeros((array.shape[0] - global_max_derivative_order * 4, (old_dim_size * (max_derivative_order + 1) * 2) + (old_dim_size * (zc_max_derivative_order + 1) if zc_enable else 0)), dtype=np.int8)
  # Block n (columns n*C .. n*C + C - 1) holds the n-th derivative; block 0 is the input itself
  derivative_array = np.zeros((array.shape[0], old_dim_size * (global_max_derivative_order + 1)))
  derivative_array[:, :old_dim_size] = array

  if global_max_derivative_order:
    for n in range(1, global_max_derivative_order + 1):
      for i in range(old_dim_size):
        # Order-n value at row j + 2n = -x[j] - 2x[j+1] + 2x[j+2] + x[j+3], where x is order n-1.
        # NOTE(review): for n >= 2 this reads order n-1 rows j..j+3 starting at j = 0, but order
        # n-1 was only written from row 2(n-1) on (earlier rows are still 0), and the window sits
        # 2(n-1) rows before the output row. Confirm this zero padding / offset is intended.
        for j in range(array[:, i].shape[0] - n * 4):
          derivative_array[j + n * 2, old_dim_size * n + i] = - derivative_array[j, old_dim_size * (n - 1) + i] - 2 * derivative_array[j + 1, old_dim_size * (n - 1) + i] + 2 * derivative_array[j + 2, old_dim_size * (n - 1) + i] + derivative_array[j + 3, old_dim_size * (n - 1) + i]

    # Trim the 2*G rows at each end that the highest order does not fill.
    # (Only inside the `if`: with G == 0 the slice [0 : -0] would be empty.)
    derivative_array = derivative_array[global_max_derivative_order * 2 : - global_max_derivative_order * 2, :]

  # Delta encoding: an UP spike when the value rises more than delta[n] above the reference
  # level, a DOWN spike when it falls more than delta[n] below it. The reference level then
  # moves to the current value; otherwise it is kept.
  for n in range(max_derivative_order + 1):
    for i in range(old_dim_size):
      dc_val = derivative_array[0, old_dim_size * n + i] # initial reference level: first sample (alternative: 0)
      # k is the row index, j the derivative value at that row
      for k, j in enumerate(derivative_array[:, old_dim_size * n + i]):
        if j > dc_val + delta_value[n]:
          dc_val = j
          spike_array[k, (old_dim_size * n + i) * 2] = 1
          spike_array[k, (old_dim_size * n + i) * 2 + 1] = 0
        elif j < dc_val - delta_value[n]:
          dc_val = j
          spike_array[k, (old_dim_size * n + i) * 2] = 0
          spike_array[k, (old_dim_size * n + i) * 2 + 1] = 1
        else:
          spike_array[k, (old_dim_size * n + i) * 2] = 0
          spike_array[k, (old_dim_size * n + i) * 2 + 1] = 0

  if zc_enable:
    # Zero-crossing encoding with a dead band [-zc_value[n], +zc_value[n]]. The state is 'above'
    # or 'below' the band, and values inside the band keep the previous state. A spike is emitted
    # whenever the state switches, including the first exit when the first value lies inside the band.
    for n in range(zc_max_derivative_order + 1):
      for i in range(old_dim_size):
        zc_state = 'above' if derivative_array[0, old_dim_size * n + i] > zc_value[n] else 'below' if derivative_array[0, old_dim_size * n + i] < -zc_value[n] else 'inside'
        for k, j in enumerate(derivative_array[:, old_dim_size * n + i]):
          current_state = 'above' if j > zc_value[n] else 'below' if j < -zc_value[n] else zc_state
          if current_state != zc_state and current_state != 'inside':
            spike_array[k, (old_dim_size * (max_derivative_order + 1) * 2) + (old_dim_size * n + i)] = 1
            zc_state = current_state

  return derivative_array, spike_array

def safe_format(value, fmt=".5f"):
  """
  Format `value` with the format spec `fmt`, or return "N/D" when `value` is None.

  Args:
  value: Value to format; None is allowed.
  fmt: Format specification applied to `value` (default ".5f").

  Returns:
  The formatted string, or "N/D" (not available) when `value` is None.
  """
  if value is None:
    return "N/D"
  return format(value, fmt)

def progress_bar(iteration, total, start_time, initial_progress=0, length=25, unit='', prefix='', suffix='', init=False):
  """
  Print an in-place (carriage-return) text progress bar to stdout.

  Args:
  iteration (int or None): Current iteration number; None is treated as 0.
  total (int): Total number of iterations.
  start_time (float): Start time of the process (e.g. from time.time()).
  initial_progress (int, optional): Iterations already done before `start_time`. They are
    excluded from the per-iteration rate. Default is 0.
  length (int, optional): Width of the bar in characters. Default is 25.
  unit (str, optional): Unit shown after the iteration counter. Default is ''.
  prefix (str, optional): Text printed before the bar. Default is ''.
  suffix (str, optional): Text printed after the bar. Default is ''.
  init (bool, optional): If True, print only the bar and counter, without timing info.
    Default is False.

  Returns:
  None. The bar is drawn with full blocks plus one partial block (1/8 resolution). Unless
  `init` is True, it is followed by the percentage, elapsed time, estimated remaining time
  and seconds per iteration.
  """
  done = 0 if iteration is None else iteration

  unit_txt = ' ' + unit if unit else unit
  prefix_txt = prefix + ' ' if prefix else prefix
  suffix_txt = ' ' + suffix if suffix else suffix

  if done == total:
    # Last iteration: draw a completely filled bar
    bar, percent = '█' * length, 100
  else:
    percent = 100 * (done / float(total))
    eighths = [' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉']
    whole = int(length * done // total)
    frac_idx = int((length * done % total) / (total / len(eighths)))
    bar = '█' * whole + eighths[frac_idx] + ' ' * (length - whole - 1)

  if init:
    line = f'\r{prefix_txt}|{bar}| {done}/{total}{unit_txt}{suffix_txt}'
  else:
    elapsed = time.time() - start_time
    per_iter = elapsed / max(1, done - initial_progress)
    remaining = per_iter * (total - done)
    elapsed_txt = time.strftime('%H:%M:%S', time.gmtime(elapsed))
    remaining_txt = time.strftime('%H:%M:%S', time.gmtime(remaining))
    line = (f'\r{prefix_txt}|{bar}| {done}/{total}{unit_txt} '
            f'[{percent:.1f}%, {elapsed_txt}<{remaining_txt}, {per_iter:.2f}s/it]{suffix_txt}'
            + ' ' * 20)

  sys.stdout.write(line)
  sys.stdout.flush()

In [None]:
def load_mvc_wfdb(path, subject, session, finger, direction):
  """
  Load the signals of one MVC force recording with wfdb.

  The record read is
  '{path}/mvc_dataset/subject{subject:02d}_session{session}/mvc_force_finger{finger}_{direction}'.

  Args:
  path (str): Dataset root directory.
  subject (int): Subject number (zero-padded to two digits in the record name).
  session (int): Session number.
  finger (int): Finger number used in the record name.
  direction (str): Direction suffix of the record name (callers in this file pass
    'flexion' or 'extension').

  Returns:
  numpy.ndarray: The record's physical signal array (`record.p_signal`).

  Raises:
  FileNotFoundError: if the record cannot be opened. The message contains the record name,
    and the original error is kept as the cause.
  """
  record_name = f"{path}/mvc_dataset/subject{subject:02d}_session{session}/mvc_force_finger{finger}_{direction}"
  try:
    record = wfdb.rdrecord(record_name)
  except FileNotFoundError as err:
    raise FileNotFoundError(f"Failed to open: {record_name}") from err

  return record.p_signal

def get_mvc_finger(path, subject, session, finger):
  """
  Estimate the MVC force level of one finger for flexion and for extension.

  For each direction, column `finger - 1` of the MVC record loaded by load_mvc_wfdb is taken.
  The mean of its 200 largest absolute values is used as the estimate (all samples are used
  when the record is shorter).

  Args:
  path (str): Dataset root directory.
  subject (int): Subject number.
  session (int): Session number.
  finger (int): Finger number (1-based).

  Returns:
  list: [flexion_level, extension_level].
  """
  def _peak_level(direction):
    channel = load_mvc_wfdb(path, subject, session, finger, direction)[:, finger - 1]
    return np.mean(np.sort(np.abs(channel))[-200:])

  return [_peak_level(direction) for direction in ('flexion', 'extension')]

def get_mvc(path, subject, session):
  """
  Collect the flexion/extension MVC estimates of fingers 1 to 5 via get_mvc_finger.

  Returns:
  list: Five [flexion_level, extension_level] pairs; index k holds finger k + 1.
  """
  return [get_mvc_finger(path, subject, session, finger) for finger in range(1, 6)]

def mvc_norm(signal, mvc_values):
  """
  Normalize force channels by their direction-specific MVC values.

  For channel i, positive samples are divided by mvc_values[i][1] (extension) and negative
  samples by mvc_values[i][0] (flexion). Zeros (and NaNs) are left unchanged.

  Args:
  signal (np.ndarray): 2D array, one force channel per column.
  mvc_values: Sequence of [flexion, extension] pairs, one per channel to normalize
    (e.g. the 5 pairs returned by get_mvc).

  Returns:
  np.ndarray: Array with the same shape and dtype as `signal`. Columns beyond
  len(mvc_values) are copied unchanged.
  """
  # Start from a copy so channels without an MVC entry keep their values instead of the
  # uninitialized memory that np.empty_like would leave there.
  result = signal.copy()

  for i, pair in enumerate(mvc_values):
    flexion, extension = pair[0], pair[1]
    channel = signal[:, i]
    scaled = np.where(channel > 0, channel / extension, channel)
    result[:, i] = np.where(channel < 0, channel / flexion, scaled)

  return result

def mvc_norm_set(data, mvc_path, mvc_session):
  """
  MVC-normalize every entry of `data` with mvc_norm.

  Entries are assumed to be grouped by subject, five consecutive entries per subject
  (entry k belongs to subject k // 5 + 1). MVC values are loaded with get_mvc once per
  subject and reused for that subject's entries.

  Args:
  data (np.ndarray): Stack of 2D force arrays, indexed along the first axis.
  mvc_path (str): Root path passed to get_mvc.
  mvc_session (int): Session whose MVC recordings are used.

  Returns:
  np.ndarray: Array shaped like `data` (np.empty_like) with every entry normalized.
  """
  normalized = np.empty_like(data)
  cached_subject, subject_mvc = 0, None

  for idx in range(len(data)):
    subject = idx // 5 + 1
    if subject != cached_subject:
      subject_mvc = get_mvc(mvc_path, subject, mvc_session)
      cached_subject = subject
    normalized[idx] = mvc_norm(data[idx], subject_mvc)

  return normalized

### Class definitions

In [None]:
class HDsEMG_Dataset(torch.utils.data.Dataset):
  """
  Dataset of (spike input, target) sequences stored time-major.

  Both constructor arrays are indexed (sample, channel, timestep). They are cast to float32
  and stored as contiguous tensors of shape (timesteps, num_samples, channels):
  `self.features` from `spike_in` and `self.labels` from `target`.
  """

  def __init__(self, target, spike_in):
    """
    Args:
      target (np.ndarray): Targets, shape (num_samples, target_channels, timesteps).
      spike_in (np.ndarray): Spike inputs, shape (num_samples, input_channels, timesteps).
        Only its first `num_samples` (= target.shape[0]) samples are used.
    """
    self.num_samples = target.shape[0]

    target = target.astype(np.float32)
    spike_in = spike_in[:self.num_samples].astype(np.float32)

    # (N, C, T) -> (T, N, C); ascontiguousarray makes the resulting tensors contiguous
    self.labels = torch.from_numpy(np.ascontiguousarray(np.transpose(target, (2, 0, 1))))
    self.features = torch.from_numpy(np.ascontiguousarray(np.transpose(spike_in, (2, 0, 1))))

  def __len__(self):
    """Number of samples."""
    return self.num_samples

  def __getitem__(self, idx):
    """Return (features, labels) of sample `idx`, each shaped (timesteps, channels)."""
    return self.features[:, idx, :], self.labels[:, idx, :]

class Net(torch.nn.Module):
  """
  Four-layer feed-forward spiking network built from snntorch Leaky neurons.

  Layer widths are read from the global `session_results['network']['ofs']` and neuron
  settings from the global `session_parameters['network']`. Both globals must exist before
  the class is instantiated.
  """

  def __init__(self, timesteps, input = 1):
    """
    Initializes the spiking neural network.

    Parameters:
    timesteps (int): Number of time steps for simulating the network.
    input (int, optional): Number of input features. Default is 1.

    The network consists of four fully connected layers (`fc_0` ... `fc_3`), each followed
    by an snntorch Leaky neuron (`lif_0`, `lif_1`, `lif_2`, `li_3`). Layer widths come from
    `session_results['network']['ofs']`. Thresholds, beta decays, bias, trainability flags
    and reset mechanism come from `session_parameters['network']`.

    `threshold_perlayer` / `beta_decay_perlayer` True: one scalar value shared by each layer.
    False: a tensor with one value per neuron, all initialized to the layer's configured value.
    With `beta_decay_perlayer` False, `beta_decay_outsingle` True still gives the output
    layer a single scalar beta.

    The output neuron `li_3` always uses threshold=1 and reset_mechanism='none'. Only its
    membrane potential is used by `forward`.
    """

    super().__init__()

    self.timesteps = timesteps # number of time steps to simulate the network
    spike_grad = surrogate.fast_sigmoid() # surrogate gradient function

    # Thresholds: a scalar per layer, or a per-neuron tensor sized by the layer width
    if session_parameters['network']['threshold_perlayer']:
      thr_0 = session_parameters['network']['threshold'][0]
      thr_1 = session_parameters['network']['threshold'][1]
      thr_2 = session_parameters['network']['threshold'][2]
      thr_3 = session_parameters['network']['threshold'][3]
    else:
      thr_0 = torch.ones(session_results['network']['ofs'][0]) * session_parameters['network']['threshold'][0]
      thr_1 = torch.ones(session_results['network']['ofs'][1]) * session_parameters['network']['threshold'][1]
      thr_2 = torch.ones(session_results['network']['ofs'][2]) * session_parameters['network']['threshold'][2]
      thr_3 = torch.ones(session_results['network']['ofs'][3]) * session_parameters['network']['threshold'][3]
    # NOTE(review): thr_3 is computed in both branches but never used; li_3 below is built with
    # threshold=1. Confirm whether session_parameters['network']['threshold'][3] should apply.

    # Beta (membrane decay): a scalar per layer, or a per-neuron tensor sized by the layer width
    if session_parameters['network']['beta_decay_perlayer']:
      beta_0 = session_parameters['network']['beta_decay'][0]
      beta_1 = session_parameters['network']['beta_decay'][1]
      beta_2 = session_parameters['network']['beta_decay'][2]
      beta_3 = session_parameters['network']['beta_decay'][3]
    else:
      beta_0 = torch.ones(session_results['network']['ofs'][0]) * session_parameters['network']['beta_decay'][0]
      beta_1 = torch.ones(session_results['network']['ofs'][1]) * session_parameters['network']['beta_decay'][1]
      beta_2 = torch.ones(session_results['network']['ofs'][2]) * session_parameters['network']['beta_decay'][2]
      # The output layer may still use a single shared beta
      if session_parameters['network']['beta_decay_outsingle']:
        beta_3 = session_parameters['network']['beta_decay'][3]
      else:
        beta_3 = torch.ones(session_results['network']['ofs'][3]) * session_parameters['network']['beta_decay'][3]

    # Layer 0: input features -> ofs[0] neurons
    self.fc_0 = torch.nn.Linear(in_features=input, out_features=session_results['network']['ofs'][0], bias = session_parameters['network']['bias'])
    self.lif_0 = snn.Leaky(beta=beta_0, threshold=thr_0, learn_beta=session_parameters['network']['beta_decay_train'], spike_grad=spike_grad, learn_threshold=session_parameters['network']['threshold_train'], reset_mechanism=session_parameters['network']['reset'])

    # Layer 1: ofs[0] -> ofs[1]
    self.fc_1 = torch.nn.Linear(in_features=session_results['network']['ofs'][0], out_features=session_results['network']['ofs'][1], bias = session_parameters['network']['bias'])
    self.lif_1 = snn.Leaky(beta=beta_1, threshold=thr_1, learn_beta=session_parameters['network']['beta_decay_train'], spike_grad=spike_grad, learn_threshold=session_parameters['network']['threshold_train'], reset_mechanism=session_parameters['network']['reset'])

    # Layer 2: ofs[1] -> ofs[2]
    self.fc_2 = torch.nn.Linear(in_features=session_results['network']['ofs'][1], out_features=session_results['network']['ofs'][2], bias = session_parameters['network']['bias'])
    self.lif_2 = snn.Leaky(beta=beta_2, threshold=thr_2, learn_beta=session_parameters['network']['beta_decay_train'], spike_grad=spike_grad, learn_threshold=session_parameters['network']['threshold_train'], reset_mechanism=session_parameters['network']['reset'])

    # Output layer: ofs[2] -> ofs[3]. No reset, fixed threshold; its membrane potential is the output
    self.fc_3 = torch.nn.Linear(in_features=session_results['network']['ofs'][2], out_features=session_results['network']['ofs'][3], bias = session_parameters['network']['bias'])
    self.li_3 = snn.Leaky(beta=beta_3, threshold=1, learn_beta=session_parameters['network']['beta_decay_train'], spike_grad=spike_grad, reset_mechanism='none')

  def forward(self, x):
    """
    Forward pass for processing input data over multiple time steps.

    Parameters:
    x (Tensor): Input tensor with shape corresponding to [timesteps, batch_size, features].
      Only the first `self.timesteps` entries along dim 0 are processed.

    Returns:
    Tensor: Output-layer membrane potentials stacked over time, shape [timesteps, ...].

    For each time step, the method computes the current (`cur_`) and spike (`spk_`) of each
    layer and updates the membrane potentials (`mem_`). The membrane potential of the output
    layer (`mem_3`) is recorded at each step; its spikes are discarded.
    """

    # Initialize membrane potentials
    mem_0 = self.lif_0.init_leaky()
    mem_1 = self.lif_1.init_leaky()
    mem_2 = self.lif_2.init_leaky()
    mem_3 = self.li_3.init_leaky()

    # Empty lists to record outputs
    mem_3_rec = []

    # Loop over time steps
    for step in range(self.timesteps):
      x_timestep = x[step, :, :]

      cur_0 = self.fc_0(x_timestep)
      spk_0, mem_0 = self.lif_0(cur_0, mem_0)

      cur_1 = self.fc_1(spk_0)
      spk_1, mem_1 = self.lif_1(cur_1, mem_1)

      cur_2 = self.fc_2(spk_1)
      spk_2, mem_2 = self.lif_2(cur_2, mem_2)

      cur_3 = self.fc_3(spk_2)
      _, mem_3 = self.li_3(cur_3, mem_3)

      mem_3_rec.append(mem_3)

    return torch.stack(mem_3_rec)

class Net_(torch.nn.Module):
  """
  Four-layer spiking network configured from the global `train_setting` dict.

  Three fully connected + Leaky (spiking) layers are followed by a fully
  connected layer driving a non-resetting Leaky read-out neuron, whose membrane
  potential is the network output.

  Args:
  timesteps (int): Number of simulation steps run by `forward`.
  input (int): Number of input features (name kept for caller compatibility,
    although it shadows the builtin).
  target_variables (int): Number of output neurons.

  Globals:
  train_setting (dict): Reads 'hyperparams' (widths of layers 1-3), '1thr4layer',
    'thr_init', '1beta4layer', 'bet_init', 'beta_out_single', 'bias',
    'decay_train', 'thr_train' and 'reset'.
  """

  def __init__(self, timesteps, input = 1, target_variables = 1):
    super().__init__()

    cfg = train_setting
    width_in, width_hidden, width_hidden2 = cfg['hyperparams'][0], cfg['hyperparams'][1], cfg['hyperparams'][2]

    self.timesteps = timesteps  # simulation steps performed by forward()
    spike_grad = surrogate.fast_sigmoid()  # surrogate gradient shared by all layers

    widths = (width_in, width_hidden, width_hidden2)

    if cfg['1thr4layer']:
      # A single scalar threshold per layer
      thr_in = thr_hidden = thr_hidden2 = cfg['thr_init']
    else:
      # One threshold per neuron
      thr_in, thr_hidden, thr_hidden2 = (torch.ones(width) * cfg['thr_init'] for width in widths)

    if cfg['1beta4layer']:
      # A single scalar decay per layer (output layer included)
      beta_in = beta_hidden = beta_hidden2 = beta_out = cfg['bet_init']
    else:
      # One decay per neuron; the output layer may optionally share a single scalar
      beta_in, beta_hidden, beta_hidden2 = (torch.ones(width) * cfg['bet_init'] for width in widths)
      beta_out = torch.ones(target_variables) * cfg['bet_init']
      if cfg['beta_out_single']:
        beta_out = cfg['bet_init']

    print("in_features, out_features", input, width_in)

    use_bias = cfg['bias']

    def leaky(beta, threshold):
      # Spiking Leaky layer sharing the training flags and reset mechanism from train_setting
      return snn.Leaky(beta=beta, threshold=threshold, learn_beta=cfg['decay_train'], spike_grad=spike_grad, learn_threshold=cfg['thr_train'], reset_mechanism=cfg['reset'])

    # Layer 1
    self.fc_in = torch.nn.Linear(in_features=input, out_features=width_in, bias=use_bias)
    self.lif_in = leaky(beta_in, thr_in)

    # Layer 2
    self.fc_hidden = torch.nn.Linear(in_features=width_in, out_features=width_hidden, bias=use_bias)
    self.lif_hidden = leaky(beta_hidden, thr_hidden)

    # Layer 3
    self.fc_hidden2 = torch.nn.Linear(in_features=width_hidden, out_features=width_hidden2, bias=use_bias)
    self.lif_hidden2 = leaky(beta_hidden2, thr_hidden2)

    # Layer 4: leaky integrator read-out with reset disabled; its spikes are ignored in forward()
    self.fc_out = torch.nn.Linear(in_features=width_hidden2, out_features=target_variables, bias=use_bias)
    self.li_out = snn.Leaky(beta=beta_out, threshold=1.0, learn_beta=cfg['decay_train'], spike_grad=spike_grad, reset_mechanism="none")


  def forward(self, x):
    """
    Run `self.timesteps` simulation steps and return the read-out membrane trace.

    x is indexed as x[step, :, :]; presumably [timesteps, batch_size, features]
    (TODO confirm). The return value stacks the output membrane potential of
    every step along a new leading dimension.
    """
    mem_in = self.lif_in.init_leaky()
    mem_hidden = self.lif_hidden.init_leaky()
    mem_hidden2 = self.lif_hidden2.init_leaky()
    mem_out = self.li_out.init_leaky()

    out_trace = []
    for t in range(self.timesteps):
      spk_in, mem_in = self.lif_in(self.fc_in(x[t, :, :]), mem_in)
      spk_hidden, mem_hidden = self.lif_hidden(self.fc_hidden(spk_in), mem_hidden)
      spk_hidden2, mem_hidden2 = self.lif_hidden2(self.fc_hidden2(spk_hidden), mem_hidden2)
      _, mem_out = self.li_out(self.fc_out(spk_hidden2), mem_out)
      out_trace.append(mem_out)

    return torch.stack(out_trace)

## Load or create session

In [None]:
# User-facing session settings. They are saved to session_info.json but, unlike
# session_parameters, they are not part of the session hash.
session_info = {
  'project_name' : 'semg-force-gesturerecognition-snn', # Project name; also the GitHub repository cloned when running in Colab.
  'session_loadbyhash' : '',                      # Hash of an existing session to load. Leave empty to look the session up by the hash of session_parameters.
  'session_note' : '',                            # Free-text session note (e.g. 'lrmod3_20').
  'session_output_path' : 'output/',              # Directory in which session folders are created and searched.

  'overwrite' : False,                            # If True, delete and recreate the folder of the current session hash. Forced to False when session_loadbyhash is set.
  'github_support' : True,                        # If True and running in Colab, clone the project repository from GitHub.
  'repository_branch' : None,                     # Repository branch to clone; None means 'main'.
  'lazygit_support' : False,                      # If True (and GitHub support is active in Colab), install LazyGit.

  'dataset' : {
    'dataset_path' : '../datasets/emg/',          # Root directory of the datasets.
    'plot_exampledata' : True                     # If True, plot the force signals of the first test example.
  },

  'training' : {
    'epoch_max' : 800,                            # Maximum number of training epochs. If set to None, there is no fixed maximum.
    'bestmodel_lrchange' : True,                  # If True, the model reverts to the best model state whenever the learning rate changes.
    'auto_unassign' : False,                      # If True, automatically unassign resources after training (presumably the Colab runtime).
  }
}

### Existing session

#### Existing session info

In [None]:
# One row per saved session, built from every 'session_info.json' under the output directory
df = normalize_json_files(session_info['session_output_path'], 'session_info.json')

# Columns hidden from the overview (none for this table)
columns_to_ignore = []

# Keep the remaining columns in their original order
columns_to_keep = list(filter(lambda column: column not in columns_to_ignore, df.columns))
df_filtered = df.loc[:, columns_to_keep]

# Cell output: the filtered overview table
pd.DataFrame(df_filtered)

#### Existing session parameters

In [None]:
# One row per saved session, built from every 'session_parameters.json' under the output directory
df = normalize_json_files(session_info['session_output_path'], 'session_parameters.json')

# Columns hidden from the overview (none for this table)
columns_to_ignore = []

# Keep the remaining columns in their original order
columns_to_keep = list(filter(lambda column: column not in columns_to_ignore, df.columns))
df_filtered = df.loc[:, columns_to_keep]

# Cell output: the filtered overview table
pd.DataFrame(df_filtered)

#### Existing session results

In [None]:
# One row per saved session, built from every 'session_results.json' under the output directory
df = normalize_json_files(session_info['session_output_path'], 'session_results.json')

# Columns hidden from the overview
columns_to_ignore = ['hash_name', 'colab_env', 'device', 'training.optimizer', 'training.train_loss', 'training.valid_loss', 'training.learning_rate']

# Keep the remaining columns in their original order
columns_to_keep = list(filter(lambda column: column not in columns_to_ignore, df.columns))
df_filtered = df.loc[:, columns_to_keep]

# Cell output: the filtered overview table
pd.DataFrame(df_filtered)

### Session parameters

In [None]:
# Experiment parameters. After normalization this dictionary is hashed to name the
# session folder, so changing any value starts a new session (presumably key order
# matters to the hash as well; keep it stable).
session_parameters = {
  'hash_seed' : None,                             # Hash seed; part of the hashed dictionary, so presumably changing it yields a different session hash. Default None.
  'dataset' : {
    'dataset_subpath' : 'hd-semg/2.0.0/1dof_dataset/', # Subpath of the dataset inside dataset_path. NOTE(review): previously listed as tested: {'ninapro/db1/', 'ninapro/db2/', 'ninapro/db5/'}.
    'dataset_format' : 'wfdb',                    # Record format. Currently supported {'wfdb', 'numpy'} ('numpy' reads '<record>.npy').
    'in_data_freq_hz' : 2048,                     # Original sampling frequency of the sEMG (input) data in Hz.
    'out_data_freq_hz' : 100,                     # Original sampling frequency of the force (output) data in Hz. Required if output filtering or resampling is requested.
    'out_data_filterfreq_hz' : 5,                 # Cutoff frequency in Hz of the low-pass filter applied to the force data. None disables filtering.
    'new_in_data_freq_hz' : 1000,                 # Target sEMG sampling frequency in Hz. None (or equal to in_data_freq_hz) disables input resampling.
    'new_out_data_freq_hz' : 1000,                # Target force sampling frequency in Hz. None disables output resampling.
    'subjects' : [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19],                              # Subject IDs to include. An empty list includes all subjects.
    'sessions' : [2],                             # Session IDs to include. An empty list includes all sessions.
    'objects_file' : [],                          # Object IDs whose files are included. An empty list includes all objects.
    'objects_class' : [],                         # 1-based force-channel (finger) indices kept as targets. An empty list keeps all channels.
    'samples' : [],                               # Sample IDs to include. An empty list includes all samples.
    'channel_start' : None,                       # Starting index of the sEMG channels to keep (None for the first channel).
    'channel_end' : None,                         # Ending (exclusive) index of the sEMG channels to keep (None for the last channel).
    'channels_downsampling' : 2,                  # Keep every N-th sEMG channel. None or 1 keeps all channels.
    'max_derivative_order' : 0,                   # Maximum order of derivative to calculate for the data.
    'delta_value' : 0.0165*2,                     # Value of delta for delta encoding. A scalar is repeated once per derivative order (0..max_derivative_order); set a list for different values. Tested: 0.0165 at 100 Hz
    'zc_enable' : False,                          # Passed to process_data; presumably enables zero-crossing (zc) encoding.
    'zc_max_derivative_order' : 0,                # Maximum derivative order used by the zero-crossing encoding.
    'zc_value' : None,                            # Zero-crossing value(s); a scalar is repeated once per zc derivative order. Example: [0.0175, 0.06]
    'force_gain' : 1000,                          # Multiplicative gain applied to the force targets. 0 or None means 1.
    'mvc_norm' : True,                            # If True, normalize the force signals with the subject's MVC (presumably maximum voluntary contraction).

    'train_samples' : [2, 3],                     # Sample IDs used for training.
    'valid_samples' : [1],                        # Sample IDs used for validation.
    'test_samples' : [1],                         # Sample IDs used for testing (may overlap with validation).

    'random_seed' : 0,                            # Seed for shuffling the training set, for reproducibility.
  },

  'network' : {
    'ofs' : [64, 128, 64, None],                  # Output feature sizes of the network layers. The last one (number of outputs) is calculated automatically if set to None.
    'bias' : False,                               # If True, the linear layers include a bias term.
    'threshold' : 0.1,                            # Firing threshold. A scalar is repeated once per layer; set a list for different values per layer.
    'threshold_train' : False,                    # If True, thresholds are learned during training.
    'threshold_perlayer' : True,                  # True: a single threshold per layer; False: one threshold per neuron. Forced to True when threshold_train is False.
    'beta_decay' : 0.9,                           # Membrane decay (beta). A scalar is repeated once per layer; set a list for different values per layer.
    'beta_decay_train' : True,                    # If True, beta decay values are learned during training.
    'beta_decay_perlayer' : False,                # True: a single beta per layer; False: one beta per neuron. Forced to True when beta_decay_train is False.
    'beta_decay_outsingle' : False,               # True: a single beta for all neurons of the last layer. Forced to True when beta_decay_perlayer is True.
    'reset' : 'subtract',                         # Neuron reset mechanism passed to snn.Leaky: 'subtract', 'zero' or 'none'.
  },

  'training' : {
    'pretrained_model' : None,                    # Path to a pretrained-model folder, if needed.
    # 'optimizer' : 'adam',                       # Optimizer to use for training, currently supported {'adam'}.
    'best_loss' : True,                           # True: select the best model by the lowest training-set loss; False: by the lowest validation-set loss.
    'batch_size' : 32,                            # Batch size for training.
    'learning_rate_first' : 0.001,                # Initial learning rate for the optimizer.
    'learning_rate_decay' : 1/3,                  # Factor by which the learning rate decays.
    'learning_rate_decay_steps' : 4,              # Number of times the learning_rate_decay is applied.
    'patience' : 30                               # Number of epochs without improvement before changing the learning rate or stopping training.
  }
}

# Values computed at run time and saved to session_results.json (not user-editable).
session_results = {
  'hash_name' : None,                # UNEDITABLE # First 8 characters of the hash of the session_parameters dictionary.
  'colab_env' : None,                # UNEDITABLE # True when running in Google Colab.
  'device' : None,                   # UNEDITABLE # Compute device used for the session.
  'datetime' : None,                 # UNEDITABLE # Last modification.

  'network': {
    'ofs': None                      # UNEDITABLE # Output feature sizes actually used for the network layers (presumably with the last size resolved from the data).
  }
}

# Maps an optimizer name to the function that builds it from a learning rate.
optimizer_functions = {
  'adam': get_adam_optimizer
}

def _normalize_session_parameters(params):
  """
  Validate and normalize a session-parameters dictionary in place.

  The normalized dictionary is the one that gets hashed into the session name, so
  equivalent configurations (e.g. a target rate equal to the original rate)
  collapse to the same values.

  Args:
  params (dict): Dictionary with 'dataset' and 'network' sub-dictionaries, laid out like `session_parameters`.

  Raises:
  ValueError: If output filtering or resampling is requested but 'out_data_freq_hz' is not set.
  """
  dataset = params['dataset']
  network = params['network']

  # Validated FIRST: the output normalization below clears every output-frequency field
  # whenever 'out_data_freq_hz' is missing, so running this check afterwards (as before)
  # made it unreachable. It also used a bare `raise` with no active exception.
  if not dataset['out_data_freq_hz'] and (dataset['out_data_filterfreq_hz'] or dataset['new_out_data_freq_hz']):
    raise ValueError('"out_data_freq_hz" value is missing')

  # Input resampling is disabled (both fields None) unless both rates are set and differ.
  if not dataset['in_data_freq_hz'] or not dataset['new_in_data_freq_hz'] or dataset['in_data_freq_hz'] == dataset['new_in_data_freq_hz']:
    dataset['in_data_freq_hz'] = None
    dataset['new_in_data_freq_hz'] = None

  # Output processing is disabled (all three fields None) when there is no original rate,
  # when neither filtering nor resampling is requested, or when there is no filter and
  # the target rate equals the original rate.
  if (not dataset['out_data_freq_hz']
      or (not dataset['new_out_data_freq_hz'] and not dataset['out_data_filterfreq_hz'])
      or (not dataset['out_data_filterfreq_hz'] and dataset['out_data_freq_hz'] == dataset['new_out_data_freq_hz'])):
    dataset['out_data_freq_hz'] = None
    dataset['out_data_filterfreq_hz'] = None
    dataset['new_out_data_freq_hz'] = None

  # Scalars are repeated once per derivative order (0..max order).
  if not isinstance(dataset['delta_value'], list):
    dataset['delta_value'] = [dataset['delta_value']] * (dataset['max_derivative_order'] + 1)
  if not isinstance(dataset['zc_value'], list):
    dataset['zc_value'] = [dataset['zc_value']] * (dataset['zc_max_derivative_order'] + 1)

  # Scalars are repeated once per network layer.
  if not isinstance(network['threshold'], list):
    network['threshold'] = [network['threshold']] * len(network['ofs'])
  if not isinstance(network['beta_decay'], list):
    network['beta_decay'] = [network['beta_decay']] * len(network['ofs'])

  # Values that are not trained are forced to a single value per layer.
  if not network['threshold_train']:
    network['threshold_perlayer'] = True
  if not network['beta_decay_train']:
    network['beta_decay_perlayer'] = True
  if network['beta_decay_perlayer']:
    network['beta_decay_outsingle'] = True

  # A channel step of 0/None/1 means "keep every channel".
  if not dataset['channels_downsampling'] or dataset['channels_downsampling'] == 1:
    dataset['channels_downsampling'] = None
  else:
    dataset['channels_downsampling'] = int(dataset['channels_downsampling'])

  # A missing or zero gain means "no gain".
  if not dataset['force_gain']:
    dataset['force_gain'] = 1


_normalize_session_parameters(session_parameters)

# Session identity: the folder name is the first 8 characters of the parameters' hash.
session_results['hash_name'] = hash_dictionary(session_parameters)[:8]
session_results['colab_env'] = colab_env
session_results['device'] =  'cpu' # str(torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')))

session_folder_name = session_results['hash_name'] # if session_info['hash_rename'] else session_info['session_name']
session_loaded = False

# Loading an explicitly requested session never deletes anything.
if session_info['session_loadbyhash']:
  session_info['overwrite'] = False

if session_info['overwrite']:
  # Best-effort removal of a previous run with the same hash (the folder may not exist yet).
  shutil.rmtree(os.path.join(session_info['session_output_path'], session_folder_name), ignore_errors=True)
  os.makedirs(os.path.join(session_info['session_output_path'], session_folder_name), exist_ok=True)

  print('Session overwriting')

else:
  # Look up the explicitly requested hash, or else the hash of the current parameters.
  session_hash = session_info['session_loadbyhash'] or session_results['hash_name']
  existing_session_folder_name = find_session_byhash(session_info['session_output_path'], session_hash)
  if existing_session_folder_name:
    try:
      session_info_loaded = load_json(os.path.join(session_info['session_output_path'], existing_session_folder_name, 'session_info.json'))
      session_parameters_loaded = load_json(os.path.join(session_info['session_output_path'], existing_session_folder_name, 'session_parameters.json'))
      session_results_loaded = load_json(os.path.join(session_info['session_output_path'], existing_session_folder_name, 'session_results.json'))

      # Per-run preferences always come from the current notebook, not from the saved session.
      session_info_loaded['overwrite'] = session_info['overwrite']
      session_info_loaded['lazygit_support'] = session_info['lazygit_support']
      session_info_loaded['dataset']['plot_exampledata'] = session_info['dataset']['plot_exampledata']
      session_info_loaded['training']['epoch_max'] = session_info['training']['epoch_max']
      session_info_loaded['training']['auto_unassign'] = session_info['training']['auto_unassign']

      # NOTE(review): hash_name is replaced by the hash of the *current* session_parameters,
      # which differs from the loaded session's hash when loading via session_loadbyhash.
      session_results_loaded['hash_name'] = session_results['hash_name']
      session_results_loaded['colab_env'] = session_results['colab_env']
      session_results_loaded['device'] = session_results['device']
    except Exception:
      # Report the hash actually looked up (session_loadbyhash is empty when loading by the computed hash).
      print('Error loading session with hash:', session_hash)
      raise

    session_info = session_info_loaded
    session_parameters = session_parameters_loaded
    session_results = session_results_loaded

    session_folder_name = existing_session_folder_name

    session_loaded = True
    print('Existing session loaded:', os.path.join(session_info['session_output_path'], existing_session_folder_name))

if not session_loaded:
  if session_info['session_loadbyhash']:
    # Previously a bare `raise` with no active exception ("RuntimeError: No active exception to reraise").
    raise FileNotFoundError(f"No session found with hash: {session_info['session_loadbyhash']}")

  os.makedirs(os.path.join(session_info['session_output_path'], session_folder_name), exist_ok=True)
  print('New session created:', os.path.join(session_info['session_output_path'], session_folder_name))

  save_json(session_info, os.path.join(session_info['session_output_path'], session_folder_name, 'session_info.json'))
  save_json(session_parameters, os.path.join(session_info['session_output_path'], session_folder_name, 'session_parameters.json'))
  save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'), savedatetime=False)

  with open(os.path.join(session_info['session_output_path'], session_folder_name, 'session_hash.txt'), 'w') as file:
    file.write(session_results['hash_name'])

# A session counts as trained once a (truthy) last epoch has been recorded.
try:
  session_trained = bool(session_results['training']['epoch_last'])
except (KeyError, TypeError):
  # No 'training' results yet (missing key, or a non-dict placeholder)
  session_trained = False

## Session environment

In [None]:
if session_info['github_support'] and session_results['colab_env']:
  # Import userdata module from google.colab
  from google.colab import userdata

  # Remove the default sample_data directory in Colab
  %rm -rf /content/sample_data/

  # Change current working directory to /content
  %cd /content/

  # Repository name on GitHub
  repository = session_info['project_name']

  # Retrieve Git credentials stored in userdata
  github_json = userdata.get('github_json')
  github_json = json.loads(github_json)

  # Configure global Git settings with the retrieved credentials
  !git config --global user.name {github_json['name']}
  !git config --global user.email {github_json['email']}
  !git config --global user.password {github_json['password']}

  # Clone the GitHub repository using Git token for authentication
  !git clone -b {session_info['repository_branch'] if session_info['repository_branch'] else 'main'} https://{github_json['token']}@github.com/{github_json['name']}/{repository}

  # Change directory to the cloned repository's directory
  %cd /content/{repository}/

  if 'hd-semg-1dof-e100-f100/2.0.0/1dof_dataset/' in session_parameters['dataset']['dataset_subpath']:
    os.makedirs(f'/content/{repository}/dataset/', exist_ok=True)
    !git clone -b {session_info['repository_branch'] if session_info['repository_branch'] else 'main'} https://{github_json['token']}@github.com/{github_json['name']}/hd-semg-1dof-e100-f100 dataset/hd-semg-1dof-e100-f100

  if 'hd-semg-1dof-e400-f400/2.0.0/1dof_dataset/' in session_parameters['dataset']['dataset_subpath']:
    os.makedirs(f'/content/{repository}/dataset/', exist_ok=True)
    !git clone -b {session_info['repository_branch'] if session_info['repository_branch'] else 'main'} https://{github_json['token']}@github.com/{github_json['name']}/hd-semg-1dof-e400-f400 dataset/hd-semg-1dof-e400-f400

  if session_info['lazygit_support']:
    LAZYGIT_VERSION = !echo $(curl -s "https://api.github.com/repos/jesseduffield/lazygit/releases/latest" | grep -Po '"tag_name": "v\K[^"]*')
    LAZYGIT_SOURCE = f"https://github.com/jesseduffield/lazygit/releases/latest/download/lazygit_{LAZYGIT_VERSION[0]}_Linux_x86_64.tar.gz"
    !curl -Lo lazygit.tar.gz {LAZYGIT_SOURCE}
    !tar xf lazygit.tar.gz lazygit
    !install lazygit /usr/local/bin
    %rm lazygit.tar.gz
    %rm lazygit

else:
  # Change directory to main directory
  #%cd ...
  pass

## Dataset Processing and Loader


In [None]:
# Directory holding the records of the configured dataset
dataset_dir = os.path.join(session_info['dataset']['dataset_path'], session_parameters['dataset']['dataset_subpath'])
all_files = get_sorted_dbfiles(dataset_dir)
selected_files = set()
#npsave_path = os.path.join(session_info['dataset']['dataset_path'], f'hd-semg-1dof-e{session_parameters["dataset"]["new_in_data_freq_hz"]}-f{session_parameters["dataset"]["new_out_data_freq_hz"]}/2.0.0/1dof_dataset/')


def _passes_filter(info, key, allowed):
  """Return True when `allowed` is empty (no filtering) or `info[key]` is one of its values."""
  return not allowed or info[key] in allowed


# Keep only records matching the subject / session / object / sample filters
for file in all_files:
  extracted_info = extract_file_info(file.replace(dataset_dir, ''))
  if extracted_info == 'No match found':
    continue
  try:
    if (
      _passes_filter(extracted_info, 'subject', session_parameters['dataset']['subjects']) and
      _passes_filter(extracted_info, 'session', session_parameters['dataset']['sessions']) and
      _passes_filter(extracted_info, 'object', session_parameters['dataset']['objects_file']) and
      _passes_filter(extracted_info, 'sample', session_parameters['dataset']['samples'])
    ):
      selected_files.add(extracted_info['modified_string'])
  except (KeyError, TypeError):
    # Records whose parsed info lacks one of the filtered fields are skipped
    pass

# Set of object (class) identifiers encountered while loading
classes = set()
selected_files = natsort.natsorted(selected_files)

# Per-split accumulators (converted to numpy arrays after the loop)
x_train, x_valid, x_test = [], [], []
y_train, y_valid, y_test = [], [], []
x_orig_train, x_orig_valid, x_orig_test = [], [], []
y_orig_train, y_orig_valid, y_orig_test = [], [], []
patient_train, patient_valid, patient_test = [], [], []
all_data_maxsample = []
all_data = []
s_all_data_maxsample = []
s_all_data = []

# Lists each record is appended to, per split. A sample may belong to several splits
# (e.g. the same sample for validation and testing), so membership is not exclusive.
split_lists = {
  'train': {'x': x_train, 'x_orig': x_orig_train, 'y': y_train, 'y_orig': y_orig_train, 'patient': patient_train},
  'valid': {'x': x_valid, 'x_orig': x_orig_valid, 'y': y_valid, 'y_orig': y_orig_valid, 'patient': patient_valid},
  'test':  {'x': x_test,  'x_orig': x_orig_test,  'y': y_test,  'y_orig': y_orig_test,  'patient': patient_test},
}


def _read_record(name):
  """Load one record (path relative to `dataset_dir`) in the configured dataset format."""
  path = os.path.join(dataset_dir, name)
  dataset_format = session_parameters['dataset']['dataset_format']
  if dataset_format == 'wfdb':
    return np.array(wfdb.rdrecord(path).p_signal)
  if dataset_format == 'numpy':
    return np.load(path + '.npy')
  # Previously an unknown format silently stopped loading (`break`)
  raise ValueError(f'Unsupported dataset_format: {dataset_format!r}')


# NOTE(review): debugging cap kept from the original code - only the first 20 selected
# records are loaded. Set to None to load every selected record.
max_files_to_load = 20

# NOTE(review): hard-coded location of the MVC data, independent of dataset_path.
mvc_dataset_path = '../datasets/emg/hd-semg/2.0.0'

for f in tqdm(selected_files[:max_files_to_load], desc='Processing and loading Files', ncols=80):
  extracted_info = extract_file_info(f.replace(session_parameters['dataset']['dataset_subpath'], ''))
  splits = [name for name in split_lists if extracted_info['sample'] in session_parameters['dataset'][f'{name}_samples']]

  # sEMG: read, select channels, resample, spike-encode
  f_raw = f.replace('datatype', 'raw')
  signal = _read_record(f_raw)

  if session_parameters['dataset']['channel_start'] or session_parameters['dataset']['channel_end']:
    signal = signal[:, session_parameters['dataset']['channel_start'] : session_parameters['dataset']['channel_end']]

  if session_parameters['dataset']['channels_downsampling']:
    signal = signal[:, ::session_parameters['dataset']['channels_downsampling']]

  # Channel-selected sEMG before resampling (not used further in this cell)
  signal_maxsample = signal

  if session_parameters['dataset']['new_in_data_freq_hz']:
    signal = signal_downsample(signal, session_parameters['dataset']['in_data_freq_hz'], session_parameters['dataset']['new_in_data_freq_hz'])
  #os.makedirs(os.path.dirname(os.path.join(npsave_path, f_raw) + '.npy'), exist_ok=True)
  #np.save(os.path.join(npsave_path, f_raw) + '.npy', signal)
  _, spike_signal = process_data(signal, session_parameters['dataset']['max_derivative_order'], session_parameters['dataset']['delta_value'], session_parameters['dataset']['zc_enable'], session_parameters['dataset']['zc_max_derivative_order'], session_parameters['dataset']['zc_value'])

  for name in splits:
    split_lists[name]['x'].append(spike_signal)
    split_lists[name]['patient'].append(extracted_info['subject'])
    split_lists[name]['x_orig'].append(signal)

  # Force: read, MVC-normalize, select fingers, resample, filter, trim, scale
  f_force = f.replace('datatype', 'force')
  signal = _read_record(f_force)

  if session_parameters['dataset']['mvc_norm']:
    force_finger_mvc = get_mvc(mvc_dataset_path, extracted_info['subject'], 1)
    signal = mvc_norm(signal, force_finger_mvc)

  if session_parameters['dataset']['objects_class']:
    if extracted_info['object'] in session_parameters['dataset']['objects_class']:
      classes.add(extracted_info['object'])
    # objects_class holds 1-based channel indices
    signal = signal[:, [x - 1 for x in session_parameters['dataset']['objects_class']]]
  else:
    classes.add(extracted_info['object'])

  if session_parameters['dataset']['new_out_data_freq_hz']:
    signal = signal_upsample(signal, session_parameters['dataset']['out_data_freq_hz'], session_parameters['dataset']['new_out_data_freq_hz'])
  # Force target before filtering, trimming and gain. Assigned unconditionally: it used to be
  # set only when output resampling was enabled, leaving it undefined otherwise.
  signal_orig = signal

  if session_parameters['dataset']['out_data_filterfreq_hz']:
    # Filter at the rate the force signal currently has (target rate if it was resampled)
    filter_fs = session_parameters['dataset']['new_out_data_freq_hz'] or session_parameters['dataset']['out_data_freq_hz']
    signal = lowpass_filter(signal, session_parameters['dataset']['out_data_filterfreq_hz'], filter_fs)
  #os.makedirs(os.path.dirname(os.path.join(npsave_path, f_force) + '.npy'), exist_ok=True)
  #np.save(os.path.join(npsave_path, f_force) + '.npy', signal)

  # Trim 2*order samples from each end; presumably to stay aligned with the derivative-based
  # encoding of the sEMG (TODO confirm against process_data).
  if session_parameters['dataset']['zc_enable']:
    trim = max(session_parameters['dataset']['max_derivative_order'], session_parameters['dataset']['zc_max_derivative_order'])
  else:
    trim = session_parameters['dataset']['max_derivative_order']
  if trim:
    signal = signal[trim * 2 : -trim * 2, :]

  if session_parameters['dataset']['force_gain'] != 1:
    signal = signal * session_parameters['dataset']['force_gain']

  for name in splits:
    split_lists[name]['y'].append(signal)
    split_lists[name]['y_orig'].append(signal_orig)

# Convert lists to numpy arrays
x_train, x_valid, x_test = map(np.array, [x_train, x_valid, x_test])
y_train, y_valid, y_test = map(np.array, [y_train, y_valid, y_test])
x_orig_train, x_orig_valid, x_orig_test = map(np.array, [x_orig_train, x_orig_valid, x_orig_test])
y_orig_train, y_orig_valid, y_orig_test = map(np.array, [y_orig_train, y_orig_valid, y_orig_test])
patient_train, patient_valid, patient_test = map(np.array, [patient_train, patient_valid, patient_test])

# Swap axes 1 and 2 of every dataset: (records, samples, channels) -> (records, channels, samples)
x_train, x_valid, x_test = [np.transpose(x, (0, 2, 1)) for x in [x_train, x_valid, x_test]]
y_train, y_valid, y_test = [np.transpose(y, (0, 2, 1)) for y in [y_train, y_valid, y_test]]
x_orig_train, x_orig_valid, x_orig_test = [np.transpose(x, (0, 2, 1)) for x in [x_orig_train, x_orig_valid, x_orig_test]]
y_orig_train, y_orig_valid, y_orig_test = [np.transpose(y, (0, 2, 1)) for y in [y_orig_train, y_orig_valid, y_orig_test]]

# Check shapes of resulting datasets
x_train.shape, x_valid.shape, x_test.shape, y_train.shape, y_valid.shape, y_test.shape

In [None]:
if session_info['dataset']['plot_exampledata']:
  fig, ax1 = plt.subplots(figsize=(12, 6))

  # Finger name for each force channel index (0 = thumb)
  labels = ["thumb", "index", "middle", "ring", "little"]

  if not session_parameters['dataset']['objects_class']:
    # All force channels were kept: plot each channel of the first test example with its
    # finger name. zip() stops at the shorter of channels/labels without a bare `except`
    # that would also hide unrelated plotting errors.
    for channel_force, label in zip(y_test[0], labels):
      ax1.plot(channel_force, label=label)
  else:
    # Only the channels listed in objects_class (1-based) were kept, in that order
    for i, index in enumerate([x - 1 for x in session_parameters['dataset']['objects_class']]):
      ax1.plot(y_test[0][i], label=labels[index])

  ax1.legend(loc='upper left')
  ax1.set_ylabel('forcesignal')

  plt.show()

In [None]:
# Seed for a reproducible shuffle of the training set
random_seed = session_parameters['dataset']['random_seed']

# Shuffle every training array in ONE call so they all share the same permutation.
# Shuffling (x, y, patient) and then (x_orig, y_orig, patient) separately applied the
# permutation to patient_train twice, misaligning it with x_train and x_orig_train.
# Validation and test sets keep their loading order.
x_train, y_train, x_orig_train, y_orig_train, patient_train = shuffle(
  x_train, y_train, x_orig_train, y_orig_train, patient_train, random_state=random_seed)

# Datasets over the spike-encoded inputs
train_dataset = HDsEMG_Dataset(y_train, x_train)
valid_dataset = HDsEMG_Dataset(y_valid, x_valid)
test_dataset = HDsEMG_Dataset(y_test, x_test)

# Datasets over the non-encoded (resampled) inputs
train_orig_dataset = HDsEMG_Dataset(y_orig_train, x_orig_train)
valid_orig_dataset = HDsEMG_Dataset(y_orig_valid, x_orig_valid)
test_orig_dataset = HDsEMG_Dataset(y_orig_test, x_orig_test)

# DataLoaders (no shuffling here: the training set was already shuffled above)
batch_size = session_parameters['training']['batch_size']
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
test_dataloader_1 = DataLoader(test_dataset, batch_size=1, shuffle=False)  # one test example per batch

train_orig_dataloader = DataLoader(train_orig_dataset, batch_size=batch_size, shuffle=False)
valid_orig_dataloader = DataLoader(valid_orig_dataset, batch_size=batch_size, shuffle=False)
test_orig_dataloader = DataLoader(test_orig_dataset, batch_size=batch_size, shuffle=False)

# Number of batches per split
len(train_dataloader), len(valid_dataloader), len(test_dataloader)

## Network


In [None]:
def _session_file(filename):
  """Return the path of `filename` inside the current session output folder."""
  return os.path.join(session_info['session_output_path'], session_folder_name, filename)


def _build_model():
  """Instantiate Net sized from x_train's axes 2 (timesteps) and 1 (input) on the session device."""
  return Net(timesteps = x_train.shape[2], input = x_train.shape[1]).to(session_results['device'])


def _load_session_model(filename):
  """Build a fresh Net and load the state dict saved as `filename` in the session folder.

  Returns:
    The model with the checkpoint weights loaded.

  Raises:
    FileNotFoundError: the checkpoint file does not exist.
    KeyError: a key needed to locate or load the checkpoint is missing.
    Exception: any other failure while building or loading the model.
  Each error is chained to its original cause so the full traceback is kept.
  """
  try:
    loaded_model = _build_model()
    loaded_model.load_state_dict(torch.load(_session_file(filename), map_location=torch.device(session_results['device'])))
  except FileNotFoundError as e:
    raise FileNotFoundError(f"Model file not found in {session_info['session_output_path']}: {e}") from e
  except KeyError as e:
    raise KeyError(f"Missing key in session_results: {e}") from e
  except Exception as e:
    raise Exception(f"Error loading the model: {e}") from e
  return loaded_model


if session_loaded and session_trained:
  # A previous training run is finished if it stopped early or already reached the last allowed epoch.
  if session_results['training']['early_stopped'] or (session_info['training']['epoch_max'] and (session_results['training']['epoch_last'] >= (session_info['training']['epoch_max'] - 1))):
    model = _load_session_model('model_best.pt')
    print(f'Model correctly loaded: {_session_file("model_best.pt")}\nThe training will not be performed as it has already been completed in the previous training or maximum epoch have already been reached')

    session_results['training']['done'] = True
    training_enable = False
  else:
    # Resume an unfinished run from the last checkpoint.
    model = _load_session_model('model_last.pt')
    print(f'Model correctly loaded: {_session_file("model_last.pt")}\nThe training will be concluded')

    session_results['training']['done'] = False
    training_enable = True

else:
  # New session: record the output feature sizes. A trailing None is replaced by the number of classes.
  session_results['network']['ofs'] = session_parameters['network']['ofs']
  if session_results['network']['ofs'][-1] is None:
    session_results['network']['ofs'][-1] = len(classes)
  save_json(session_results, _session_file('session_results.json'))

  model = _build_model()
  pretrained_model_path = session_parameters['training']['pretrained_model']
  if pretrained_model_path:
    model.load_state_dict(torch.load(pretrained_model_path, map_location=torch.device(session_results['device'])))
    print(f'Pretrained model correctly loaded: {pretrained_model_path}\nA new training will be performed')
  else:
    print('Model correctly created\nA new training will be performed')

  training_enable = True

## Training

### Training settings

In [None]:
# Loss used for both the training and the validation passes.
error = torch.nn.MSELoss()

if training_enable:
  if not (session_loaded and session_trained):
    # Fresh run: begin at the configured initial learning rate and reset all training bookkeeping.
    learning_rate = session_parameters['training']['learning_rate_first']
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

    # Key names (including the 'epoch_no_imporve' spelling) must stay as-is: they are persisted
    # to session_results.json and read back by the training loop.
    session_results['training'] = {
      'optimizer': str(optimizer).replace("\n","").replace("    ",", "),  # Training optimizer description.
      'done': False,                   # Whether the training process is complete.
      'early_stopped': False,          # Whether training stopped for lack of improvement.
      'epoch_best': None,              # Epoch of the best checkpoint.
      'epoch_last': None,              # Last completed epoch.
      'epoch_no_imporve': 0,           # Consecutive epochs without improvement.
      'train_loss_min': float('inf'),  # Lowest training loss seen so far.
      'valid_loss_min': float('inf'),  # Lowest validation loss seen so far.
      'learning_rate_best': None,      # Learning rate in use at the best checkpoint.
      'train_loss_best': None,         # Training loss of model_best.pt.
      'valid_loss_best': None,         # Validation loss of model_best.pt.
      'learning_rate': [],             # Per-epoch learning-rate history.
      'train_loss': [],                # Per-epoch training-loss history.
      'valid_loss': [],                # Per-epoch validation-loss history.
      'train_acc': [],                 # Per-epoch training-accuracy history.
      'valid_acc': [],                 # Per-epoch validation-accuracy history.
    }

    session_results['post_training'] = {
      'test_acc': None,                # Test accuracy after training.
      'test_acc_fakequant': None,      # Test accuracy with fake quantization after training.
    }

    save_json(session_results, os.path.join(session_info['session_output_path'], session_folder_name, 'session_results.json'))
  else:
    # Resumed run: continue with the most recent learning rate in the saved history.
    # NOTE(review): a new Adam instance is created here, so optimizer moments restart on resume.
    learning_rate = session_results['training']['learning_rate'][-1]
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

### Training loop

In [None]:
if training_enable:
  # Loop-invariant settings and session-folder paths, computed once instead of every epoch.
  train_state = session_results['training']
  epoch_max = session_info['training']['epoch_max']
  ckpt_dir = os.path.join(session_info['session_output_path'], session_folder_name)
  ckpt_results_path = os.path.join(ckpt_dir, 'session_results.json')
  ckpt_last_path = os.path.join(ckpt_dir, 'model_last.pt')
  ckpt_best_path = os.path.join(ckpt_dir, 'model_best.pt')

  def _run_epoch(dataloader, train):
    """Run one pass over `dataloader` and return the mean per-batch loss.

    train=True: the optimizer updates the model after every batch.
    train=False: the pass runs in eval mode inside torch.no_grad(), so no autograd graph is
    built for validation. The model is put back in train mode afterwards in both cases.
    """
    model.train(train)
    batch_losses = []
    with torch.set_grad_enabled(train):
      for inputs, labels in dataloader:
        # Swap axes 0 and 1 of each batch before the forward pass, as the model expects
        # (exact layout is defined by HDsEMG_Dataset / Net — not visible here).
        inputs = torch.swapaxes(inputs, 0, 1).to(session_results['device'])
        labels = torch.swapaxes(labels, 0, 1).to(session_results['device'])

        outputs = model(inputs)
        loss = error(outputs, labels)
        if train:
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
        batch_losses.append(loss.item())
    model.train()
    return sum(batch_losses) / len(batch_losses)

  if epoch_max:
    start_time = time.time()
    start_iter = 0 if train_state['epoch_last'] is None else train_state['epoch_last'] + 1
    progress_bar(start_iter, epoch_max, start_time, start_iter, prefix='Epoch:', init=True)

  model.train()

  while True:
    # Stop once epoch_max epochs are done. `is not None` is required: epoch_last == 0 is a
    # valid but falsy epoch index, and a truthiness test let epoch_max == 1 train an extra epoch.
    if epoch_max and train_state['epoch_last'] is not None and (train_state['epoch_last'] + 1) >= epoch_max:
      break

    train_loss_value = _run_epoch(train_dataloader, train=True)
    valid_loss_value = _run_epoch(valid_dataloader, train=False)

    train_state['epoch_last'] = 0 if train_state['epoch_last'] is None else train_state['epoch_last'] + 1

    # Session checkpoint of the current weights (reloaded when an unfinished session is resumed)
    torch.save(model.state_dict(), ckpt_last_path)

    train_state['learning_rate'].append(learning_rate)
    train_state['train_loss'].append(train_loss_value)
    train_state['valid_loss'].append(valid_loss_value)
    train_best_loss = train_loss_value < train_state['train_loss_min']
    if train_best_loss:
      train_state['train_loss_min'] = train_loss_value
    valid_best_loss = valid_loss_value < train_state['valid_loss_min']
    if valid_best_loss:
      train_state['valid_loss_min'] = valid_loss_value

    save_json(session_results, ckpt_results_path)

    # "current (best so far)" summary reused by every progress-bar update of this epoch
    loss_summary = (f'TL: {safe_format(train_loss_value)} ({safe_format(train_state["train_loss_min"])}), '
                    f'VL: {safe_format(valid_loss_value)} ({safe_format(train_state["valid_loss_min"])}).')

    if epoch_max:
      progress_bar(train_state['epoch_last'] + 1, epoch_max, start_time, start_iter, prefix='Epoch:', suffix=loss_summary)
    else:
      print(
        f'\rEpoch {train_state["epoch_last"]}.   ' +                                                      # Epoch
        f'Train loss: {safe_format(train_loss_value)} ({safe_format(train_state["train_loss_min"])}), ' + # Training loss
        f'Valid loss: {safe_format(valid_loss_value)} ({safe_format(train_state["valid_loss_min"])}). ',  # Validation loss
        end=''
      )

    # 'best_loss' selects the criterion for model_best.pt: training loss if truthy, else validation loss.
    improved = train_best_loss if session_parameters['training']['best_loss'] else valid_best_loss
    if improved:
      torch.save(model.state_dict(), ckpt_best_path)

      train_state['learning_rate_best'] = learning_rate
      train_state['train_loss_best'] = train_loss_value
      train_state['valid_loss_best'] = valid_loss_value
      train_state['epoch_best'] = train_state['epoch_last']

      save_json(session_results, ckpt_results_path)

      train_state['epoch_no_imporve'] = 0
      continue

    train_state['epoch_no_imporve'] += 1

    if train_state['epoch_no_imporve'] < session_parameters['training']['patience']:
      if epoch_max:
        progress_bar(train_state['epoch_last'] + 1, epoch_max, start_time, start_iter, prefix='Epoch:', suffix=f'{loss_summary} Patience: {train_state["epoch_no_imporve"]}.')
      else:
        print(f'Patience: {train_state["epoch_no_imporve"]}.', end='')
      continue

    # Patience exhausted: decay the learning rate while it is still >= the floor, otherwise stop.
    # NOTE(review): because the test is `>=`, with exact arithmetic the rate can be decayed
    # learning_rate_decay_steps + 1 times before stopping — confirm this is intended.
    lr_decay = session_parameters['training']['learning_rate_decay']
    lr_floor = round((lr_decay ** session_parameters['training']['learning_rate_decay_steps']) * session_parameters['training']['learning_rate_first'], 16)
    if learning_rate >= lr_floor:
      learning_rate = round(learning_rate * lr_decay, 16)

      if session_info['training']['bestmodel_lrchange']:
        # Best effort: restart the lower-rate phase from the best weights saved so far.
        try:
          model.load_state_dict(torch.load(ckpt_best_path, map_location=torch.device(session_results['device'])))
        except Exception as e:
          print(f'\nCould not reload {ckpt_best_path}, continuing with the current weights: {e}')

      save_json(session_results, ckpt_results_path)
      for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate
      train_state['epoch_no_imporve'] = 0
      if epoch_max:
        progress_bar(train_state['epoch_last'] + 1, epoch_max, start_time, start_iter, prefix='Epoch:', suffix=f'{loss_summary} Reducing learning rate: {learning_rate}\n')
      else:
        print('Reducing learning rate:', learning_rate)
    else:
      train_state['early_stopped'] = True
      save_json(session_results, ckpt_results_path)
      if epoch_max:
        progress_bar(train_state['epoch_last'] + 1, epoch_max, start_time, start_iter, prefix='Epoch:', suffix=f'{loss_summary} Early stopping triggered after {train_state["epoch_last"]+1} epochs!')
      else:
        print(f'Early stopping triggered after {train_state["epoch_last"]+1} epochs!')
      break

  train_state['done'] = True
  save_json(session_results, ckpt_results_path)
  # Leave the best checkpoint loaded for the evaluation cells that follow.
  model.load_state_dict(torch.load(ckpt_best_path, map_location=torch.device(session_results['device'])))

  if colab_env and session_info['training']['auto_unassign']:
    time.sleep(30)
    runtime.unassign()

## Test

In [None]:
# Put the model in evaluation mode (nn.Module.eval) before the test cells below.
model.eval()

### Loss

In [None]:
# Plot the per-epoch training and validation loss histories stored in session_results,
# one curve per history, on a single wide figure.
plt.figure(figsize=(16, 4))

for history_key, curve_label in (('train_loss', 'Training Loss'), ('valid_loss', 'Validation Loss')):
  plt.plot(session_results['training'][history_key], label=curve_label)

plt.title('Training and Validation Loss Over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()