#### RNN based on Steinmetz `spks` dataset and code from [nma](https://deeplearning.neuromatch.io/projects/Neuroscience/neuro_seq_to_seq.html)

##### Imports

In [2]:
# Imports
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt

##### Device

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

##### Set random seed

In [4]:
# @markdown Executing `set_seed(seed=seed)` you are setting the seed

# For DL it's critical to set the random seed so that students have a
# baseline against which to compare their results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html

# Call `set_seed` function in the exercises to ensure reproducibility.

def set_seed(seed=None, seed_torch=True):
  """
  Function that controls randomness. NumPy and random modules must be imported.

  Args:
    seed : Integer
      A non-negative integer that defines the random state. Default is `None`.
    seed_torch : Boolean
      If `True` sets the random seed for pytorch tensors, so pytorch module
      must be imported. Default is `True`.

  Returns:
    Nothing.
  """
  if seed is None:
    seed = int(np.random.choice(2 ** 32))  # cast to a plain Python int for random.seed
  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

  print(f'Random seed {seed} has been set.')


# In case that `DataLoader` is used
def seed_worker(worker_id):
  """
  DataLoader will reseed workers following randomness in
  multi-process data loading algorithm.

  Args:
    worker_id: integer
      ID of the worker subprocess to seed, in the range [0, num_workers - 1].
      Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details

  Returns:
    Nothing
  """
  worker_seed = torch.initial_seed() % 2**32
  np.random.seed(worker_seed)
  random.seed(worker_seed)

In [5]:
SEED = 2021
set_seed(seed=SEED)

Random seed 2021 has been set.


##### Data retrieval

In [7]:
# store the dataset files in the datasets folder
import os, requests

fname = []
ds = "../datasets"
os.makedirs(ds, exist_ok=True)  # make sure the target folder exists before writing

for j in range(3):
  fname.append('steinmetz_part%d.npz'%j)
url = ["https://osf.io/agvxh/download"]
url.append("https://osf.io/uv3mw/download")
url.append("https://osf.io/ehmw2/download")

for j in range(len(url)):
  if not os.path.isfile("/".join([ds, fname[j]])):
    try:
      r = requests.get(url[j])
    except requests.ConnectionError:
      print("!!! Failed to download data !!!")
    else:
      if r.status_code != requests.codes.ok:
        print("!!! Failed to download data !!!")
      else:
        with open("/".join([ds, fname[j]]), "wb") as fid:
          fid.write(r.content)

In [8]:
# data loading
alldat = np.array([])
for j in range(len(fname)):
  alldat = np.hstack((alldat,
                      np.load("/".join([ds, 'steinmetz_part%d.npz'%j]),
                              allow_pickle=True)['dat']))

##### Dataset `spks` from recording 11

In [21]:
# select dataset spks from recording 11
dat = alldat[11]

# data are neurons X trials x time samples, but transpose to time samples X trials X neurons
X = dat['spks'].T

# response is -1, 0, or 1, one value per trial; it is one-hot encoded below
rsp = dat['response']

##### Format dataset
An RNN places specific requirements on the data fed into the network for training, testing, and prediction. In this notebook `hidden_size` is set to the number of time samples per trial, which for the `spks` dataset is 250. Note that in the `RNN` class defined below `hidden_size` is the width of the hidden state vector; the number of steps the network is unrolled over is the sequence length, which happens to be the same value here. The 250 samples of a sequence all belong to the same mouse trial and are recorded from the same 698 neurons. Further trials are appended in sequence, and the hidden state is reset at the start of each trial. Before it is split into train and test sets, the dataset is therefore 698 neurons wide and 250 samples x 340 trials long, giving a matrix of shape `(85000, 698)`.

Care must be taken to shuffle and split the dataset on a 250-entry boundary, that is, by whole trials; otherwise the sequence data will be corrupted.
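
A minimal sketch of a trial-aligned shuffle and split, using only NumPy. It assumes the rows of the flat matrix are ordered trial by trial (all samples of trial 0, then all samples of trial 1, and so on). The toy sizes, the 80/20 split, and the names `X_flat`, `labels`, `X_trials` are illustrative assumptions, not values from this notebook.

```python
import numpy as np

# Hypothetical toy sizes; the session used in this notebook has
# 250 samples, 340 trials, and 698 neurons.
n_samples, n_trials, n_neurons = 5, 8, 3

rng = np.random.default_rng(0)
# Toy data laid out trial by trial: rows 0..n_samples-1 belong to trial 0, etc.
X_flat = rng.poisson(1.0, size=(n_trials * n_samples, n_neurons))
labels = rng.integers(-1, 2, size=n_trials)  # one response per trial

# View the flat matrix as (trials, samples, neurons) so each trial is one block.
X_trials = X_flat.reshape(n_trials, n_samples, n_neurons)

# Shuffle trial indices, never individual rows.
perm = rng.permutation(n_trials)
n_train = int(0.8 * n_trials)  # assumed 80/20 split
train_idx, test_idx = perm[:n_train], perm[n_train:]

X_train, y_train = X_trials[train_idx], labels[train_idx]
X_test, y_test = X_trials[test_idx], labels[test_idx]

print(X_train.shape, X_test.shape)  # (6, 5, 3) (2, 5, 3)
```

Because the permutation acts on the trial axis only, the samples inside each trial keep their original order.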

In [None]:
# response is -1, 0, or 1; expand it to a one-hot encoding with one column per class
y = np.stack([rsp == -1, rsp == 0, rsp == 1], axis=1).astype(int)

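# reformat spks data to (trials x samples, neurons); the trial axis is moved
# first so that each trial occupies 250 consecutive rows
Xshape = X.shape  # (samples, trials, neurons)
X = X.transpose(1, 0, 2).reshape(Xshape[0] * Xshape[1], Xshape[2])

# number of hidden units, set to the number of time samples per trial
n_hidden = Xshape[0]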

##### Data augmentation

Noise plays a significant role in the recording of neural signals. By adding noise to the training set we can make the network more resilient to noise when it is encountered. By adding noise to the test set we can measure how resilient the network is.

__Hilbert transform__: For a one-dimensional signal the Hilbert transform shifts the phase of every frequency component by 90 degrees while leaving its amplitude unchanged. It resembles a derivative in the phase shift only: a true derivative also scales each component by its frequency, which is what boosts high-frequency content and emphasizes high-frequency noise. It's used here with the aim of offsetting the integration performed by the resistance of the cell membrane and the cell capacitance, restoring the neural signal somewhat.

__White noise before Hilbert__: Simulates the effect of increasing the membrane resistance of the cell.

__White noise no Hilbert__: Simulates the effect of resistance exterior to the cell.
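
The augmentation variants listed here could be sketched as follows. This is an illustration under stated assumptions, not the notebook's implementation: the helper names `add_white_noise` and `hilbert_transform`, the noise level `sigma=0.1`, and the toy cosine signal are all hypothetical. The Hilbert transform is computed with NumPy's FFT along the time axis.

```python
import numpy as np

def add_white_noise(x, sigma, rng):
    """Return x plus zero-mean Gaussian white noise with standard deviation sigma."""
    return x + rng.normal(0.0, sigma, size=x.shape)

def hilbert_transform(x, axis=0):
    """Hilbert transform of a real signal along `axis`, computed with the FFT."""
    n = x.shape[axis]
    freqs = np.fft.fftfreq(n)
    # Frequency response -i * sign(f): a 90 degree phase shift in every bin.
    h = -1j * np.sign(freqs)
    shape = [1] * x.ndim
    shape[axis] = n
    spectrum = np.fft.fft(x, axis=axis) * h.reshape(shape)
    return np.real(np.fft.ifft(spectrum, axis=axis))

rng = np.random.default_rng(0)
n_samples, n_neurons = 250, 4  # hypothetical toy sizes
t = np.arange(n_samples)
# Toy signal: the same 5-cycle cosine in every column.
clean = np.cos(2 * np.pi * 5 * t / n_samples)[:, None] * np.ones((1, n_neurons))

# "White noise before Hilbert": add noise first, then transform.
noise_then_hilbert = hilbert_transform(add_white_noise(clean, 0.1, rng))
# "White noise no Hilbert": add noise only.
noise_only = add_white_noise(clean, 0.1, rng)
```

The same two functions can be applied to the training set to augment it and to the test set to probe robustness, with `sigma` varied to control the noise level.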

##### Dataloader
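
One way to build the loader is sketched here, assuming the data have already been arranged as one tensor per trial. The toy tensors, the batch size, and the names `X_trials`, `y_trials`, and `g` are hypothetical. `seed_worker` is repeated from the seeding cell so the example runs on its own, and the `worker_init_fn` and `generator` arguments follow the PyTorch notes on data-loading randomness.

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    """Reseed NumPy and `random` inside each DataLoader worker process."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Hypothetical toy tensors standing in for the per-trial data:
# (trials, samples, neurons) inputs and one class index per trial.
n_trials, n_samples, n_neurons = 8, 5, 3
X_trials = torch.rand(n_trials, n_samples, n_neurons)
y_trials = torch.randint(0, 3, (n_trials,))

g = torch.Generator()
g.manual_seed(2021)  # same value as SEED in this notebook

loader = DataLoader(
    TensorDataset(X_trials, y_trials),
    batch_size=4,                # assumed batch size
    shuffle=True,                # shuffles whole trials, not individual samples
    num_workers=0,               # worker_init_fn only runs when num_workers > 0
    worker_init_fn=seed_worker,
    generator=g,
)

for xb, yb in loader:
    print(xb.shape, yb.shape)
```

Keeping the trial as the dataset item means shuffling can never break a sequence apart.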

##### Define RNN

In [None]:
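class RNN(nn.Module):
    """Minimal recurrent cell: one linear layer updates the hidden state, another reads out class scores."""
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        # input and previous hidden state are concatenated, then mapped to the new hidden state
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        # hidden state -> class scores
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        # input: (batch, input_size), hidden: (batch, hidden_size)
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.h2o(hidden)
        output = self.softmax(output)  # log-probabilities over the classes
        return output, hidden

    def initHidden(self):
        # zero hidden state for a batch of one, used at the start of each sequence
        return torch.zeros(1, self.hidden_size)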

# input size = number of neurons, output size = number of response classes
rnn = RNN(Xshape[2], n_hidden, y.shape[1])
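
A short sketch of how this class can be stepped through a single trial, one time sample at a time. The class is repeated so the example is self-contained. The sizes and the random `trial` tensor are hypothetical stand-ins for the real data.

```python
import torch
import torch.nn as nn

class RNN(nn.Module):
    """Same structure as the class defined in this notebook."""
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.softmax(self.h2o(hidden))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

torch.manual_seed(0)
# Hypothetical toy sizes; the notebook uses 250 samples, 698 neurons, 3 classes.
n_samples, n_neurons, n_hidden, n_classes = 5, 6, 4, 3
rnn = RNN(n_neurons, n_hidden, n_classes)

trial = torch.rand(n_samples, n_neurons)  # one toy trial: (samples, neurons)
hidden = rnn.initHidden()                 # reset the state at the start of the trial
for t in range(n_samples):
    # Feed one time sample with a batch dimension of 1.
    output, hidden = rnn(trial[t].unsqueeze(0), hidden)

# The output after the last sample holds log-probabilities over the classes.
predicted_class = output.argmax(dim=1)
```

Because the network ends in `LogSoftmax`, the final `output` pairs naturally with `nn.NLLLoss` during training.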

##### Train the RNN: setup

#### Train the RNN