## Setup
Downloads the project's helper modules from GitHub and imports them.

Trouble running this? Try restarting the Colab session ("Runtime" --> "Restart Session").

If you're working on your own branch of the GitHub repo, change the branch name (`main`) in the download URL below so the modules are fetched from that branch.

In [2]:
# Download the project's helper modules from GitHub into ./modules and import them.
# To test code pushed to your own branch, change BRANCH below.
import importlib
import sys
import urllib.request
from pathlib import Path

import numpy as np

BRANCH = "main"
base_url = f"https://raw.githubusercontent.com/cathat00/NMA_B-T_Project/{BRANCH}/modules/"
files = ["configs.py", "plots.py", "network.py", "tasks.py", "bci.py"]

module_dir = Path("modules")
module_dir.mkdir(exist_ok=True)
for file in files:
    # urlopen raises HTTPError on a bad URL/branch. `wget -q -O` would instead
    # truncate the target file up front and silently leave it empty, so the
    # failure would only show up later as a confusing import/attribute error.
    with urllib.request.urlopen(f"{base_url}{file}", timeout=30) as response:
        (module_dir / file).write_bytes(response.read())

# Drop any previously imported copies so re-running this cell actually uses the
# freshly downloaded code (otherwise Python keeps serving the cached modules
# from sys.modules until the session is restarted).
for name in [m for m in sys.modules if m == "modules" or m.startswith("modules.")]:
    del sys.modules[name]
importlib.invalidate_caches()

# Project imports must come after the download/purge above.
import modules.configs as configs
import modules.plots as plots

## Train and Evaluate the Model

In [3]:
# ====================
# == Initialization ==
# ====================
# Build every experiment component from the config object.
# NOTE(review): the saved output of this cell ends in
#   AttributeError: 'BasicReachingTask' object has no attribute 'target'
# This cell only ever accesses `task.targets`, so the failing `.target` access is
# presumably inside one of the downloaded modules (bci.train / task.get_loss /
# cfg.save?) or the output is stale — TODO confirm after a fresh session run.

SHOW_PLOTS = False   # set True to draw the manifold trajectory / scree plots
MANIFOLD_SEED = 2    # seeds NumPy's global RNG right before the manifold computation

cfg = configs.BasicExperimentConfig()
task = cfg.task          # the task the RNN will learn
rnn = cfg.rnn            # recurrent neural network
bci = cfg.bci            # brain-computer interface
feedback = cfg.feedback  # mathematically optimal feedback


# =======================
# == Simulate Learning ==
# =======================

# -- Train the RNN
# NOTE(review): no seed is set in this cell before training, so training is only
# reproducible if cfg/rnn seed their own RNG — confirm if exact re-runs matter.
print(f"Training the RNN for {cfg.ntrials} trials...")
rnn.relearn(
    ntrials=cfg.ntrials,
    ext=task.stimuli,
    ntstart=task.stim_length,
    decoder=bci.decoder,
    feedback=feedback,
    target=task.targets,
)

# -- Compute Manifold
print(f"Computing manifold over {cfg.ntrials_manifold} trials...")
np.random.seed(MANIFOLD_SEED)
manifold_data = rnn.get_manifold(task.stimuli, task.stim_length, ntrials=cfg.ntrials_manifold)
manifold_activity = manifold_data['proj_reshaped']
targets_by_trial = manifold_data['order']

# -- Train the BCI
# Targets are sliced at index `stim_length` along axis 1 (assumes axis 1 is
# time — TODO confirm against tasks.py).
bci.train(
    manifold_activity,
    manifold_data['eigenvecs'],
    task.targets[:, task.stim_length, :],
    targets_by_trial,
)

# -- Calculate loss retrospectively using trained BCI
# Map manifold activity through the freshly trained decoder, then score it.
pred_coords = manifold_activity @ bci.decoder.T
final_loss = task.get_loss(pred_coords, targets_by_trial)
print(f"Final loss: {final_loss}")


# =======================
# == Visualize Results ==
# =======================
# -- Plot Manifold and Scree Plots
if SHOW_PLOTS:
    plots.plot_traj(manifold_activity, manifold_data['eigenvals']).show()


# ==================
# == Save Results ==
# ==================
cfg.save("result.pkl", manifold_data)


Training the RNN for 80 trials...


100%|██████████| 80/80 [07:13<00:00,  5.42s/it]


Computing manifold over 50 trials...


100%|██████████| 50/50 [00:06<00:00,  7.87it/s]


AttributeError: 'BasicReachingTask' object has no attribute 'target'