## Setup
Downloads the project's helper modules from GitHub and imports them.

Having trouble running this? Try restarting the Colab session ("Runtime" → "Restart Session").

If you're working on your own branch of the GitHub repo, replace `main` in the `wget` URLs below with your branch name so the modules are downloaded from that branch.

In [None]:
# Fetch the project's helper modules from the `main` branch of the GitHub repo
# into a local `modules/` directory so they can be imported as `modules.*` below.
# `mkdir -p` does not fail if the directory already exists; `wget -q` runs quietly
# and `-O <path>` writes (overwriting) the download to that path, so re-running
# this cell refreshes the local copies.
!mkdir -p modules
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/network.py -O modules/network.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/tasks.py -O modules/tasks.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/bci.py -O modules/bci.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/plots.py -O modules/plots.py

import numpy as np
import modules.tasks as tasks
from modules.network import RNN
from modules.bci import BCI
from modules.plots import plot_pca_summary, save_figs

## Train and Evaluate the Model

In [None]:
# ====================
# == Initialization ==
# ====================
# Sweep over the number of reaching targets. For each value, this cell trains a
# fresh RNN on the task, computes its activity manifold, and plots a PCA summary.
# All PCA figures are then saved together to `pca_viz_filepath`.

np.random.seed(2) # Random seed for this simulation
# NOTE(review): the seed is set once for the whole sweep. If the task/RNN/BCI
# code draws from NumPy's global RNG, each experiment's randomness depends on the
# experiments that ran before it. Re-seed inside the loop if each ntargets value
# must be reproducible on its own.

experiments = [1,2,3,4,5,6,8,9,10,11,12] # ntargets values to sweep (7 is skipped; confirm this is intentional)
pca_viz_filepath = "pca.html" # Where the combined PCA figures are written

ntrials = 80 # Number of trials to train the RNN for
ntrials_manifold = 80 # Intended number of trials for manifold computation
# NOTE(review): ntrials_manifold is NOT passed to rnn.get_manifold below, so it
# currently has no effect (the recorded output shows a 50-iteration progress
# bar). Wire it through once get_manifold's trial-count parameter is confirmed.
all_pca_figs = [] # One PCA summary figure per experiment, saved together at the end


# =====================
# == Run Experiments ==
# =====================

for ntargets in experiments:

  print(f"Running experiment with ntargets={ntargets}")
  task = tasks.BasicReachingTask(ntargets=ntargets) # The task the RNN will learn
  rnn = RNN(N_in=task.ntargets, verbosity=1) # Recurrent Neural Network; input size set to the number of targets
  bci = BCI(rnn, task.target_max) # Brain computer interface
  # Mathematically optimal feedback: the Moore-Penrose pseudoinverse of the
  # BCI decoder matrix, passed to the RNN's learning rule below.
  feedback = np.linalg.pinv(bci.decoder)


  # =======================
  # == Simulate Learning ==
  # =======================

  # -- Train the RNN
  print(f"Training the RNN for {ntrials} trials...")
  rnn.relearn(
      ntrials=ntrials,
      ext=task.stimuli,
      ntstart=task.stim_length,
      decoder=bci.decoder,
      feedback=feedback,
      target=task.targets,
      delta=20.,
  )

  # -- Compute Manifold
  # The trial count here is decided inside get_manifold (see the NOTE on
  # ntrials_manifold above), so the log message does not claim a number.
  print("Computing manifold...")
  manifold_data = rnn.get_manifold(task.stimuli, task.stim_length)
  manifold_activity = manifold_data['proj_reshaped'] # Activity projected for the PCA plot
  eigenvalues = manifold_data['eigenvals'] # Used for the scree plot

  # -- Train the BCI
  # ... TODO (How should this work? Check Feulner's code.)


  # =======================
  # == Evaluate Learning ==
  # =======================
  # -- Calculate MSE
  # (MSE must be computed retrospectively using the trained BCI)
  # ... TODO
  # -- Calculate Manifold Surface Area / Perturbation Sensitivity
  # ... TODO


  # =======================
  # == Visualize Results ==
  # =======================
  # -- Plot Manifold and Scree Plots
  pca_fig = plot_pca_summary(manifold_activity, eigenvalues)
  pca_fig.update_layout(title_text=f"# of Targets = {ntargets}")
  all_pca_figs.append(pca_fig)
  pca_fig.show()
  print("Finished experiment!\n")

# ==================
# == Save Results ==
# ==================
# Use the configured path (previously hard-coded a second time, so editing
# pca_viz_filepath had no effect) and report the actual destination.
print(f"Saving visualizations to {pca_viz_filepath}")
save_figs(all_pca_figs, pca_viz_filepath, overwrite=True)
print("Finished!")

Running experiment with ntargets=1
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:40<00:00,  5.01s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 13.66it/s]


Finished experiment!

Running experiment with ntargets=2
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:09<00:00,  5.37s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:06<00:00,  7.75it/s]


Finished experiment!

Running experiment with ntargets=3
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:47<00:00,  5.09s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:05<00:00,  8.67it/s]


Finished experiment!

Running experiment with ntargets=4
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:20<00:00,  5.50s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:07<00:00,  6.92it/s]


Finished experiment!

Running experiment with ntargets=5
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:54<00:00,  5.19s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:05<00:00,  8.35it/s]


Finished experiment!

Running experiment with ntargets=6
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:26<00:00,  5.59s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:08<00:00,  5.87it/s]


Finished experiment!

Running experiment with ntargets=8
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:00<00:00,  5.26s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:06<00:00,  7.49it/s]


Finished experiment!

Running experiment with ntargets=9
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:30<00:00,  5.63s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:09<00:00,  5.26it/s]


Finished experiment!

Running experiment with ntargets=10
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:09<00:00,  5.37s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:09<00:00,  5.05it/s]


Finished experiment!

Running experiment with ntargets=11
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:49<00:00,  5.12s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:08<00:00,  5.86it/s]


Finished experiment!

Running experiment with ntargets=12
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:30<00:00,  5.64s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:09<00:00,  5.03it/s]


Finished experiment!

Saving visualizations to .html
Finished!
