## Setup
Downloads the project modules from GitHub and imports them.

Trouble running this? Try restarting the Colab session ("Runtime" --> "Restart Session").

If you're working on your own branch of the GitHub repo, replace `main` in `base_url` (the URL the `wget` commands download from) with your branch name.

In [None]:
# Get all scripts from the modules folder in Github
!mkdir -p modules
base_url = "https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/"
files = ["configs.py", "plots.py", "network.py", "tasks.py", "bci.py"]
for file in files:
    !wget -q "{base_url}{file}" -O "modules/{file}"

# Imports for this script
import numpy as np
from os.path import join
import modules.configs as configs
import modules.plots as plots

## Train and Evaluate the Model

In [None]:
# ====================
# == Initialization ==
# ====================

experiments = [1,2,3,4,5,6,7,8,9,10,11,12]
eigenval_list = []
root_out_dir = "results" # Root output directory

# =====================
# == Run Experiments ==
# =====================

for ntargets in experiments:

  # ======================
  # == Setup Experiment ==
  # ======================

  exp_name = f"{ntargets}Targets" # Name of this experiment
  exp_out_dir = join(root_out_dir, exp_name)

  cfg = configs.BasicExperimentConfig(ntargets=ntargets)
  task = cfg.task # The task the RNN will learn
  rnn = cfg.rnn # Recurrent Neural Network
  bci = cfg.bci # Brain computer interface
  feedback = cfg.feedback # Mathematically optimal feedback

  print(f"Starting experiment: {exp_name}")
  print(f"-- Stimulus Shape: {task.stimuli.shape}")
  print(f"-- Targets Shape: {task.targets.shape}")

  # =======================
  # == Simulate Learning ==
  # =======================

  # -- Train the RNN
  print(f"Training the RNN for {cfg.ntrials} trials...")
  rnn.relearn(
      ntrials=cfg.ntrials,
      ext=task.stimuli,
      ntstart=task.stim_length,
      decoder=bci.decoder,
      feedback=feedback,
      target=task.targets,
  )

  # -- Compute Manifold
  print(f"Computing manifold over {cfg.ntrials_manifold} trials...")
  np.random.seed(2) # Set seed for manifold calculation...
  manifold_data = rnn.get_manifold(task.stimuli, task.stim_length, ntrials=cfg.ntrials_manifold)
  manifold_activity = manifold_data['proj_reshaped']
  eigenvalues = manifold_data['eigenvals']
  eigenval_list.append(eigenvalues)

  # ===========================================
  # == Visualize Results For This Experiment ==
  # ===========================================
  pca_fig = plots.plot_pca_summary(manifold_activity, eigenvalues)

  # ======================================
  # == Save Results For This Experiment ==
  # ======================================
  print("Saving results for experiment...")
  cfg.save(join(exp_out_dir, "result.pkl"), manifold_data, overwrite=True)
  plots.save_figs([pca_fig], join(exp_out_dir, "pca.html"), overwrite=True)
  print("Experiment complete!\n")

# ==============================================
# == Save Summary Results For All Experiments ==
# ==============================================
print("Saving summary results for all experiments...")
dim_fig = plots.plot_num_pcs_vs_targets(experiments, eigenval_list)
plots.save_figs([dim_fig], join(root_out_dir, "num_pcs_vs_targets.html"))

print("Zipping results folder...")
!zip -r "./results.zip" "results"
print("Finished!")