## Setup
Downloads the project modules from GitHub and imports them.

If this cell fails, try restarting the Colab session ("Runtime" > "Restart session") and running it again.

If you're working on your own branch of the GitHub repo, replace `main` in `base_url` with your branch name so the wget commands download from that branch.
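
A minimal sketch of building the download URLs from a branch name instead of editing the URL by hand. `module_urls` and `BRANCH` are hypothetical names introduced here; the repository path and file list are copied from the setup cell.

```python
# Hypothetical helper: build raw.githubusercontent.com URLs for a given branch.
REPO = "cathat00/NMA_B-T_Project"
FILES = ["configs.py", "plots.py", "network.py", "tasks.py", "bci.py"]

def module_urls(branch="main"):
    """Return the raw GitHub URL of each module file on `branch`."""
    base_url = f"https://raw.githubusercontent.com/{REPO}/{branch}/modules/"
    return [base_url + name for name in FILES]

BRANCH = "my-feature-branch"  # hypothetical: replace with your own branch name
for url in module_urls(BRANCH):
    print(url)
```

Keeping the branch in one variable means the URL, the file list, and the branch can't drift apart when switching branches.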

In [3]:
# Download every script in the repo's modules folder on GitHub
!mkdir -p modules
base_url = "https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/main/modules/"
files = ["configs.py", "plots.py", "network.py", "tasks.py", "bci.py"]
for file in files:
    !wget -q "{base_url}{file}" -O "modules/{file}"

# Imports for this script
import numpy as np
from modules.configs import BasicExperimentConfig
from modules.plots import plot_pca_summary
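
With `-O`, wget creates the output file before the download finishes, so a failed request (for example, a 404 from a mistyped branch name) can leave an empty file that only fails later, at import time. A minimal sketch of a check to run after the download loop; `check_modules` is a hypothetical helper, not part of the project's modules.

```python
import os

def check_modules(folder, files):
    """Return the names of expected module files that are missing or empty."""
    bad = []
    for name in files:
        path = os.path.join(folder, name)
        # A missing or zero-byte file usually means the download failed.
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            bad.append(name)
    return bad

problems = check_modules("modules", ["configs.py", "plots.py", "network.py", "tasks.py", "bci.py"])
if problems:
    print(f"Missing or empty module files: {problems}; check base_url and the branch name.")
```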

## Train and Evaluate the Model

In [None]:
# ====================
# == Initialization ==
# ====================

cfg = BasicExperimentConfig()
task = cfg.task # The task the RNN will learn
rnn = cfg.rnn # Recurrent Neural Network
bci = cfg.bci # Brain-computer interface (BCI)
feedback = cfg.feedback # Mathematically optimal feedback


# =======================
# == Simulate Learning ==
# =======================

# -- Train the RNN
print(f"Training the RNN for {cfg.ntrials} trials...")
rnn.relearn(
    ntrials=cfg.ntrials,
    ext=task.stimuli,
    ntstart=task.stim_length,
    decoder=bci.decoder,
    feedback=feedback,
    target=task.targets,
)

# -- Compute Manifold
print(f"Computing manifold over {cfg.ntrials_manifold} trials...")
np.random.seed(2) # Fix the seed so the manifold calculation is reproducible
manifold_data = rnn.get_manifold(task.stimuli, task.stim_length, ntrials=cfg.ntrials_manifold)
manifold_activity = manifold_data['proj_reshaped']
eigenvalues = manifold_data['eigenvals']

# -- Train the BCI
# TODO: Not implemented yet. Check Feulner's code for how BCI training should work.


# =======================
# == Visualize Results ==
# =======================
# -- Plot Manifold and Scree Plots
plot_pca_summary(manifold_activity, eigenvalues).show()


# ==================
# == Save Results ==
# ==================
cfg.save("result.pkl", manifold_data)


Training the RNN for 80 trials...
100%|██████████| 80/80 [07:12<00:00,  5.40s/it]
Computing manifold over 50 trials...
100%|██████████| 50/50 [00:06<00:00,  8.19it/s]

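The scree plot is drawn from `eigenvalues`. Assuming `manifold_data['eigenvals']` holds the PCA eigenvalues of the activity covariance (one per component, as a scree plot implies), the sketch below shows how the fraction of variance explained, and the number of components needed to reach a threshold, can be read from them. `explained_variance` and `n_components_for` are hypothetical helpers, and the random activity is only a stand-in for the RNN's data.

```python
import numpy as np

def explained_variance(eigenvalues):
    """Sort eigenvalues in descending order and return each one's share of the total variance."""
    ev = np.sort(np.asarray(eigenvalues, dtype=float))[::-1]
    return ev / ev.sum()

def n_components_for(eigenvalues, threshold=0.9):
    """Smallest number of leading components whose cumulative variance reaches `threshold`."""
    cumulative = np.cumsum(explained_variance(eigenvalues))
    # searchsorted finds the first index where the cumulative share >= threshold;
    # clip so rounding error near 1.0 can't return more components than exist.
    return int(min(np.searchsorted(cumulative, threshold) + 1, len(cumulative)))

# Stand-in data: eigenvalues of the covariance of random correlated "activity" (time x neurons).
rng = np.random.default_rng(2)
activity = rng.standard_normal((500, 10)) @ rng.standard_normal((10, 10))
eigvals = np.linalg.eigvalsh(np.cov(activity, rowvar=False))
print(explained_variance(eigvals)[:3])
print(n_components_for(eigvals, 0.9))
```

With the notebook's variables, the equivalent call would be `n_components_for(eigenvalues, 0.9)`, under the same assumption about what `eigenvals` contains.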