## Setup
Downloads the project's helper modules from GitHub and imports them.

Trouble running this? Try restarting the Colab session ("Runtime" --> "Restart Session").

If you're working on your own branch of the GitHub repo, replace `main` in each wget URL below with your branch name.

In [None]:
# Fetch the project's helper modules (network, tasks, bci, plots) from the
# `main` branch of the GitHub repo into a local `modules/` directory, so the
# `import modules...` statements below can find them.
# To test your own branch, replace `main` in each URL with the branch name.
# NOTE: `-q` silences wget's output, including download errors; if an import
# below fails, rerun these lines without `-q` to see what went wrong.
!mkdir -p modules
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/network.py -O modules/network.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/tasks.py -O modules/tasks.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/bci.py -O modules/bci.py
!wget -q https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/plots.py -O modules/plots.py

import numpy as np
import modules.tasks as tasks
from modules.network import RNN
from modules.bci import BCI
from modules.plots import plot_pca_summary, save_figs, plot_num_pcs_vs_targets

## Train and Evaluate the Model

In [None]:
# ====================
# == Initialization ==
# ====================
# Sweep over task sizes (number of reach targets). For each size:
#   1. build a reaching task, a fresh RNN, and a BCI decoder for it,
#   2. train the RNN via `rnn.relearn` using pinv(decoder) as feedback,
#   3. compute the RNN's activity manifold and record its eigenvalues,
#   4. show a PCA summary figure.
# Afterwards, plot a dimensionality summary across all sizes and save every
# per-experiment PCA figure to `pca_viz_filepath`.

np.random.seed(2)  # Random seed for this simulation

# Numbers of targets to run, one experiment each.
# NOTE(review): 7 is absent from this list; confirm that is intentional.
experiments = [1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12]
pca_viz_filepath = "pca.html"  # Output file for the combined PCA figures

ntrials = 80  # Number of trials to train the RNN for
# Number of trials for manifold computation.
# NOTE(review): this value is only printed below; it is NOT passed to
# `rnn.get_manifold`, so it does not control the computation here. Check
# `RNN.get_manifold`'s signature and pass it through if that was the intent.
ntrials_manifold = 80
all_pca_figs = []      # One PCA summary figure per experiment, in run order
num_targets_list = []  # `ntargets` of each experiment, in run order
eigenval_list = []     # Manifold eigenvalues of each experiment, same order


# =====================
# == Run Experiments ==
# =====================

for ntargets in experiments:

    print(f"Running experiment with ntargets={ntargets}")
    task = tasks.BasicReachingTask(ntargets=ntargets)  # The task the RNN will learn
    rnn = RNN(N_in=task.ntargets, verbosity=1)  # Recurrent Neural Network
    bci = BCI(rnn, task.target_max)  # Brain computer interface
    # Feedback weights = Moore-Penrose pseudoinverse of the decoder
    # ("mathematically optimal" feedback).
    feedback = np.linalg.pinv(bci.decoder)

    # =======================
    # == Simulate Learning ==
    # =======================

    # -- Train the RNN
    print(f"Training the RNN for {ntrials} trials...")
    rnn.relearn(
        ntrials=ntrials,
        ext=task.stimuli,
        ntstart=task.stim_length,
        decoder=bci.decoder,
        feedback=feedback,
        target=task.targets,
        delta=20.,
    )

    # -- Compute Manifold
    print(f"Computing manifold over {ntrials_manifold} trials...")
    manifold_data = rnn.get_manifold(task.stimuli, task.stim_length)
    manifold_activity = manifold_data['proj_reshaped']
    eigenvalues = manifold_data['eigenvals']
    num_targets_list.append(ntargets)
    eigenval_list.append(eigenvalues)

    # -- Train the BCI
    # ... TODO (How should this work? Check Feulner's code.)

    # =======================
    # == Evaluate Learning ==
    # =======================
    # -- Calculate MSE
    # (MSE must be computed retrospectively using the trained BCI)
    # ... TODO
    # -- Calculate Manifold Surface Area / Perturbation Sensitivity
    # ... TODO

    # =======================
    # == Visualize Results ==
    # =======================
    # -- Plot Manifold and Scree Plots
    pca_fig = plot_pca_summary(manifold_activity, eigenvalues)
    pca_fig.update_layout(title_text=f"# of Targets = {ntargets}")
    all_pca_figs.append(pca_fig)
    pca_fig.show()
    print("Finished experiment!\n")

# =================================
# == Plot Dimensionality Summary ==
# =================================
# Presumably plots the number of PCs needed to reach `eigenval_thresh` of the
# eigenvalue mass against the number of targets; confirm in
# `plot_num_pcs_vs_targets`.
dimen_fig = plot_num_pcs_vs_targets(
    num_targets_list, eigenval_list, eigenval_thresh=0.9
)
dimen_fig.show()

# ==================
# == Save Results ==
# ==================
# Use the configured path for both the message and the write, so the two
# cannot drift apart from each other or from `pca_viz_filepath`.
print(f"Saving visualizations to {pca_viz_filepath}")
save_figs(all_pca_figs, pca_viz_filepath, overwrite=True)
print("Finished!")