## Setup
Downloads the project modules from GitHub and imports them, along with the other required libraries.

Trouble running this? Try restarting the Colab session ("Runtime" → "Restart session"), then run the cells again from the top.

If you're working on your own branch of the GitHub repo, change the branch name in the download URLs in the cell below so they point to that branch.

In [1]:
# Download the project's helper modules from GitHub into ./modules.
# Pure Python instead of `!wget -q ... -O file`: wget's -O creates/truncates the
# target file even when the download fails, and `!` commands never stop the
# notebook on error, so a bad URL would only surface later as a confusing
# import error. urlretrieve raises instead (e.g. HTTPError on a 404).
import urllib.request
from pathlib import Path

# To pull the modules from your own branch of the repo, change BRANCH.
REPO_RAW_URL = "https://raw.githubusercontent.com/cathat00/NMA_B-T_Project"
BRANCH = "calebs_branch"
MODULE_NAMES = ("network", "tasks", "bci")

MODULES_DIR = Path("modules")
MODULES_DIR.mkdir(parents=True, exist_ok=True)  # same as `mkdir -p modules`
for module_name in MODULE_NAMES:
    module_url = f"{REPO_RAW_URL}/{BRANCH}/modules/{module_name}.py"
    urllib.request.urlretrieve(module_url, MODULES_DIR / f"{module_name}.py")
# Modules that were already imported stay cached in sys.modules, so after
# re-downloading, restart the session before re-running the imports below.

import numpy as np
import matplotlib.pyplot as plt # Matplotlib for visualization
import plotly.express as px # Plotly library for visualization

import modules.tasks as tasks
import pandas as pd
from modules.network import RNN
from modules.bci import BCI

## Plotting Utils
**TODO:** These helpers could probably be moved into a new Python module in the GitHub repo.

In [7]:
def get_manifold_plot(manifold, title="Neural Manifold Trajectory (3D PCA)", marker_size=1):
  """Build a 3-D scatter plot of a trajectory in principal-component space.

  Args:
    manifold: array-like of shape (time, n_components) with n_components >= 3.
      Only the first three columns are plotted, as PC1, PC2 and PC3.
    title: figure title.
    marker_size: size of each scatter marker.

  Returns:
    A plotly Figure whose points are colored by their row index ("time").

  Raises:
    ValueError: if `manifold` is not 2-D or has fewer than three columns.
  """
  manifold = np.asarray(manifold)
  # Fail early with a readable message instead of an IndexError from the
  # column slicing below.
  if manifold.ndim != 2 or manifold.shape[1] < 3:
    raise ValueError(
        "manifold must have shape (time, n_components) with n_components >= 3, "
        f"got shape {manifold.shape}"
    )

  # Row index stands in for time; it is only used to color the trajectory.
  t = np.arange(manifold.shape[0])

  df = pd.DataFrame({
      "PC1": manifold[:, 0],
      "PC2": manifold[:, 1],
      "PC3": manifold[:, 2],
      "time": t,
  })

  fig = px.scatter_3d(
      df,
      x="PC1", y="PC2", z="PC3",
      color="time",
      color_continuous_scale="Viridis",
      title=title,
  )
  fig.update_traces(marker=dict(size=marker_size))
  return fig


def get_scree_plot(eigenvalues, title='Scree Plot'):
  """Build a scree plot: eigenvalue vs. principal-component number.

  Args:
    eigenvalues: 1-D array-like with one eigenvalue per principal component.
      Values are plotted in the given order, numbered from 1.
    title: figure title.

  Returns:
    A plotly line Figure with a marker at each component.

  Raises:
    ValueError: if `eigenvalues` is not one-dimensional.
  """
  eigenvalues = np.asarray(eigenvalues)
  if eigenvalues.ndim != 1:
    raise ValueError(f"eigenvalues must be 1-D, got shape {eigenvalues.shape}")

  df_eigenvals = pd.DataFrame({
      'Principal Component Number': np.arange(1, eigenvalues.size + 1),
      'Eigenvalue': eigenvalues,
  })

  # px.line already labels each axis with its column name, so no separate
  # layout update is needed for the axis titles.
  return px.line(
      df_eigenvals,
      x='Principal Component Number',
      y='Eigenvalue',
      markers=True,
      title=title,
  )

## Train and Evaluate the Model

In [3]:
# Build the task, RNN and BCI; train the RNN; compute its neural manifold;
# then plot the manifold trajectory and a scree plot of its eigenvalues.
# NOTE(review): this cell's saved execution count (In [3]) is lower than the
# plotting-utils cell's (In [7]), so the saved outputs may depend on stale
# kernel state. Confirm with Restart -> Run All.

# ====================
# == Initialization ==
# ====================

# Seeds NumPy's global RNG. NOTE(review): presumably RNN/BCI construction and
# training draw from it (confirm in modules/). If so, keep this line first and
# don't reorder the statements below, or the results will change.
np.random.seed(2) # Random seed for this simulation

task = tasks.BasicReachingTask() # The task the RNN will learn
# Input dimension is set to the task's number of targets.
rnn = RNN(N_in=task.ntargets, verbosity=1) # Recurrent Neural Network
bci = BCI(rnn, task.target_max) # Brain computer interface
# Moore-Penrose pseudoinverse of the BCI decoder, passed to relearn() as the
# feedback matrix.
feedback = np.linalg.pinv(bci.decoder) # Mathematically optimal feedback


# =======================
# == Simulate Learning ==
# =======================
ntrials = 80 # Number of trials to train the RNN for
# NOTE(review): ntrials_manifold is only used in the print below. It is NOT
# passed to rnn.get_manifold(), so the number of trials actually used is
# get_manifold's own choice (the saved progress bar shows 50). Pass it
# explicitly once get_manifold's parameter name is confirmed.
ntrials_manifold = 50 # Number of trials for manifold computation

# -- Train the RNN
# Expensive: the saved run took ~4.9 s per trial (~6.5 min for 80 trials).
print(f"Training the RNN for {ntrials} trials...")
rnn.relearn(
    ntrials=ntrials,
    ext=task.stimuli,
    ntstart=task.stim_length,
    decoder=bci.decoder,
    feedback=feedback,
    target=task.targets,
    # NOTE(review): unnamed constant; its meaning is defined by RNN.relearn
    # in modules/network.py. Consider a named constant.
    delta=20.,
)

# -- Compute Manifold
# NOTE(review): the saved output below reads "80 trials", which does not match
# ntrials_manifold (50). That output is from an earlier version of this cell.
print(f"Computing manifold over {ntrials_manifold} trials...")
# Returns a dict. This cell reads 'activity_proj' (plotted by
# get_manifold_plot, which uses its first three columns) and 'eigenvals'.
manifold_data = rnn.get_manifold(task.stimuli, task.stim_length)

# -- Train the BCI
# Disabled for now. The MSE evaluation below depends on a trained BCI.
#bci.train()


# =======================
# == Evaluate Learning ==
# =======================
# -- Calculate MSE
# MSE must be computed retrospectively using the trained BCI.
# ... TODO
# -- Calculate Manifold Surface Area
# ... TODO


# =======================
# == Visualize Results ==
# =======================
# -- Plot MSE
# ...TODO
# -- Plot Manifold
get_manifold_plot(manifold_data['activity_proj']).show()
# -- Scree Plot
get_scree_plot(manifold_data['eigenvals']).show()
# -- TODO (Other Plots?)

Training the RNN for 80 trials...


100%|██████████| 80/80 [06:33<00:00,  4.91s/it]


Computing manifold over 80 trials...


100%|██████████| 50/50 [00:03<00:00, 13.58it/s]
