## Setup
Downloads the project's modules from GitHub and imports them.

Having trouble running this? Try restarting the Colab session ("Runtime" → "Restart session").

If you're working on your own branch of the GitHub repo, change the `wget` commands below to point to that branch.

In [1]:
!mkdir -p modules
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/calebs_branch/modules/network.py -O modules/network.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/calebs_branch/modules/tasks.py -O modules/tasks.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/calebs_branch/modules/bci.py -O modules/bci.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/calebs_branch/modules/plots.py -O modules/plots.py

import numpy as np
import modules.tasks as tasks
from modules.network import RNN
from modules.bci import BCI
from modules.plots import plot_pca_summary

## Train and Evaluate the Model

In [5]:
# ====================
# == Initialization ==
# ====================

# Seeded once for the whole sweep: each experiment continues from the RNG
# state left by the previous ones, so an individual experiment's result is
# only reproduced when the full `experiments` list is run in the same order.
np.random.seed(2)  # Random seed for this simulation

experiments = [1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12]  # ntargets values to sweep over
pca_viz_filepath = "pca.html"  # Combined PCA figures are written here

ntrials = 80  # Number of trials to train the RNN for
ntrials_manifold = 80  # Intended number of trials for manifold computation (see NOTE below)
all_pca_figs = []


def _save_figs_html(figs, filepath):
    """Write several plotly figures into one HTML file, overwriting it.

    plotly.js is referenced from the CDN once (with the first figure), so
    viewing the file requires internet access.

    Args:
        figs: Figures providing plotly's ``to_html`` method. Assumed to be
            plotly figures because this notebook calls ``update_layout`` on
            the output of ``plot_pca_summary`` -- confirm if that changes.
        filepath: Destination path of the HTML file.
    """
    sections = [
        # Only the first figure pulls in plotly.js; later ones reuse it.
        fig.to_html(full_html=False, include_plotlyjs="cdn" if i == 0 else False)
        for i, fig in enumerate(figs)
    ]
    with open(filepath, "w", encoding="utf-8") as f:
        f.write('<html>\n<head><meta charset="utf-8"></head>\n<body>\n')
        f.write("\n".join(sections))
        f.write("\n</body>\n</html>\n")


# =====================
# == Run Experiments ==
# =====================

for ntargets in experiments:

    print(f"Running experiment with ntargets={ntargets}")
    task = tasks.BasicReachingTask(ntargets=ntargets)  # The task the RNN will learn
    rnn = RNN(N_in=task.ntargets, verbosity=1)  # Recurrent Neural Network
    bci = BCI(rnn, task.target_max)  # Brain computer interface
    feedback = np.linalg.pinv(bci.decoder)  # Mathematically optimal feedback

    # =======================
    # == Simulate Learning ==
    # =======================

    # -- Train the RNN
    print(f"Training the RNN for {ntrials} trials...")
    rnn.relearn(
        ntrials=ntrials,
        ext=task.stimuli,
        ntstart=task.stim_length,
        decoder=bci.decoder,
        feedback=feedback,
        target=task.targets,
        delta=20.,
    )

    # -- Compute Manifold
    # NOTE(review): `ntrials_manifold` is not forwarded to `get_manifold` (its
    # signature isn't visible in this notebook), and the recorded progress bar
    # shows it running 50 trials. Pass it through once the keyword name is
    # confirmed; until then, don't report a trial count here.
    print("Computing manifold...")
    manifold_data = rnn.get_manifold(task.stimuli, task.stim_length)
    manifold_activity = manifold_data['proj']
    eigenvalues = manifold_data['eigenvals']

    # -- Train the BCI
    # ... TODO (How should this work? Check Feulner's code.)

    # =======================
    # == Evaluate Learning ==
    # =======================
    # -- Calculate MSE
    # (MSE must be computed retrospectively using the trained BCI)
    # ... TODO
    # -- Calculate Manifold Surface Area / Perturbation Sensitivity
    # ... TODO

    # =======================
    # == Visualize Results ==
    # =======================
    # -- Plot Manifold and Scree Plots
    pca_fig = plot_pca_summary(manifold_activity, eigenvalues)
    pca_fig.update_layout(title_text=f"# of Targets = {ntargets}")
    all_pca_figs.append(pca_fig)
    print("Finished experiment!\n")

# ==================
# == Save Results ==
# ==================
# No `save_figs` is imported in this notebook, so the figures are written with
# the local helper above, to the path configured in `pca_viz_filepath`.
print(f"Saving visualizations to {pca_viz_filepath}")
_save_figs_html(all_pca_figs, pca_viz_filepath)
print("Finished!")

Running experiment with ntargets=1
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:46<00:00,  5.09s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:04<00:00, 11.05it/s]


Finished experiment!

Running experiment with ntargets=2
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:55<00:00,  5.19s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:05<00:00,  8.44it/s]


Finished experiment!

Running experiment with ntargets=3
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:48<00:00,  5.11s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 13.69it/s]


Finished experiment!

Running experiment with ntargets=4
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:58<00:00,  5.23s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 13.41it/s]


Finished experiment!

Running experiment with ntargets=5
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:55<00:00,  5.19s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 13.28it/s]


Finished experiment!

Running experiment with ntargets=6
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:08<00:00,  5.35s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 13.11it/s]


Finished experiment!

Running experiment with ntargets=8
Training the RNN for 80 trials...


100%|██████████| 80/80 [06:57<00:00,  5.22s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 13.12it/s]


Finished experiment!

Running experiment with ntargets=9
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:06<00:00,  5.33s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:04<00:00, 10.12it/s]


Finished experiment!

Running experiment with ntargets=10
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:02<00:00,  5.28s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 12.92it/s]


Finished experiment!

Running experiment with ntargets=11
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:23<00:00,  5.54s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 13.03it/s]


Finished experiment!

Running experiment with ntargets=12
Training the RNN for 80 trials...


100%|██████████| 80/80 [07:09<00:00,  5.37s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:05<00:00,  8.47it/s]


Finished experiment!

Saving visualizations to ./results.html
