## Training RNN network on single trial data

In [1]:
# imports
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../")

import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import mat73
from IPython.display import display, Markdown

import pyaldata as pyal
import pylab
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from tools.curbd import curbd
import pylab
import importlib

from tools.dsp.preprocessing import preprocess
from tools.rnn_and_curbd.RNN_functions import *

np.random.seed(44)

In [2]:
# data_dir = "/data/bnd-data/raw/M044/M044_2024_12_04_09_30"
# mat_file = "M044_2024_12_04_09_30_pyaldata.mat"

In [3]:
data_dir = "/Users/zosiasus/Documents/M044/"
mat_file = "M044_2024_12_04_09_30_pyaldata.mat"
fname = os.path.join(data_dir, mat_file)

df = pyal.mat2dataframe(fname, shift_idx_fields=True)
mouse = mat_file.split('_')[0]

field values_before_camera_trigger could not be converted to int.
field idx_before_camera_trigger could not be converted to int.
array field all_KSLabel could not be converted to int.


## PREPROCESS DATAFRAME

In [4]:
df_ = preprocess(df, only_trials=True)
df_ = pyal.select_trials(df_, "idx_trial_end > 30365")  # Remove first 5 minutes because the switch was off

# Separate columns of 'all_spikes' to M1 and Dls (dorsolateral striatum)
brain_areas = ["Dls_rates", "M1_rates"]
df_["M1_rates"] = [df_["all_rates"][i][:,300:] for i in range(len(df_))]
df_["Dls_rates"] = [df_["all_rates"][i][:,0:300] for i in range(len(df_))]



Combined every 3 bins
Resulting all_spikes ephys data shape is (NxT): (474, 133)


### useful variables

In [5]:
# perturbation time
perturbation_time = df_.idx_sol_on[0]
perturbation_time_seconds = perturbation_time * df_.bin_size[0]

# solenoid angles
sol_angles: list = df_.values_Sol_direction.unique()
sol_angles.sort()


## RUN RNN AND CURBD FOR EVERY SOLENOID ANGLE

In [6]:
for angle in sol_angles:
    print(f"RUNNING RNN TRAINING AND CURBD FOR TRIAL OF TYPE: {angle}")
    # CHOOSE ONE TYPE OF TRIAL
    angle_df = pyal.select_trials(df_, f"values_Sol_direction == {angle}")  

    num_trials = len(angle_df)

    # CONCATINATE IN TIME all trials 
    concat_trials = pyal.concat_trials(angle_df, signal = "all_rates")

    # PREPARE FOR RNN
    reset_points, trial_len = get_reset_points(angle_df, concat_trials, brain_areas)
    activity = np.transpose(concat_trials)
    regions = get_regions(angle_df, brain_areas)

    print(f"Building {len(regions)} region RNN network")
    print(f"Regions: {[region[0] for region in regions]}\n")

    # RNN TRAINING
    rnn_model, rnn_accurancy_figure = RNN(activity, reset_points, regions, angle_df, mouse, 
                                                        dtFactor = 1, ampInWN= 0.01, tauRNN = 0.2, nRunTrain=100 )

    # RNN ASSESMENT with PCA and CCA
    scores, variance_figure, PCA_figure, CCA_figure = PCA_and_CCA(concat_trials, rnn_model, 
                                                                  num_components = 10, trial_num = num_trials, mouse_num = mouse, 
                                                                  printing=True)
    
    # CURBD
    curbd_arr, curbd_labels = curbd.computeCURBD(rnn_model)
    n_regions = curbd_arr.shape[0]

    # PLOTTING CURBD
    all_currents, all_currents_labels = format_for_plotting(curbd_arr, curbd_labels, n_regions, trial_len, 
                                                                          num_trials)
    
    region_currents_fig = plot_region_currents(all_currents, all_currents_labels, perturbation_time_seconds, 
                                                             angle_df.bin_size[0], num_trials, mouse, )

    all_currents_fig = plot_all_currents(all_currents, all_currents_labels, perturbation_time_seconds, num_trials, 
                                                       angle_df.bin_size[0], mouse)
    
    # a bit of memory management :)
    del angle_df, num_trials, concat_trials, reset_points, trial_len , activity, regions, rnn_model, curbd_arr, curbd_labels, n_regions, 
    all_currents, all_currents_labels, region_currents_fig, all_currents_fig
    
    print(f"finished run for trial type: {angle}\n\n")
    
    

RUNNING RNN TRAINING AND CURBD FOR TRIAL OF TYPE: 0


TypeError: get_reset_points() missing 1 required positional argument: 'dtFactor'

In [None]:
print("DONE!")