In [None]:
import numpy as np
import pandas as pd

from tqdm import tqdm 
from typing import List
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib import cm

import torch

from api_neurotask import *
from utils import *

import warnings
warnings.filterwarnings("ignore")

### 1. Load and Filter a Specific Dataset

Before diving into the analysis, we load the dataset and preprocess it. In this step, we filter out unrewarded trials so the analysis focuses on the relevant data.\
Trials with the following outcomes are removed: Aborted (A), Incomplete (I), Failed (F).\
The loader also returns the bin size of the dataset (in ms).

In [None]:
# Outcomes to drop: Aborted (A), Incomplete (I), Failed (F).
excluded_outcomes = ['A', 'I', 'F']
dataset, bin_size = load_and_filter_parquet('./data/6_1_Churchland3_Maze.parquet', excluded_outcomes)
# NOTE(review): `groups` is not referenced later in the visible notebook — remove if unused.
groups = dataset.groupby(['session', 'animal'])
dataset

In [None]:
# Print every column whose name does not contain "Neuron" (i.e. the non-spiking columns).
non_neuron_columns = [name for name in dataset.columns if "Neuron" not in name]
for name in non_neuron_columns:
    print(name)

In [None]:
# Number of distinct trials recorded for each (session, animal) pair.
(
    dataset.groupby(['session', 'animal'])
    .agg(unique_trials_per_group=('trial_id', 'nunique'))
    .reset_index()
)

In [None]:
# Number of distinct trials per trial outcome, within each (session, animal) pair.
(
    dataset.groupby(['session', 'animal', 'result'])
    .agg(n_trial=('trial_id', 'nunique'))
    .reset_index()
)

In [None]:
# Number of distinct trials split by the `correct_reach` flag, within each (session, animal) pair.
(
    dataset.groupby(['session', 'animal', 'correct_reach'])
    .agg(n_trial=('trial_id', 'nunique'))
    .reset_index()
)

In [None]:
# Number of distinct trials per maze condition, within each (session, animal) pair.
(
    dataset.groupby(['session', 'animal', 'maze_conditions'])
    .agg(n_trial_per_condition=('trial_id', 'nunique'))
    .reset_index()
)

In [None]:
# Total trial count: a trial is identified by the (session, animal, trial_id) triple.
trial_keys = ['session', 'animal', 'trial_id']
n_trials = dataset.groupby(trial_keys).ngroups
n_trials

In [None]:
# Count the columns whose name starts with "Neuron".
n_neurons = sum(1 for col_name in dataset.columns if col_name.startswith('Neuron'))
n_neurons

In [None]:
# Plot spikes (via `plot_rastor`) and hand velocity for one example trial at the original bin size.
example_session, example_animal, example_trial = 3, 1, 66
plot_rastor(
    dataset,
    session_id=example_session,
    animal_id=example_animal,
    trial_id=example_trial,
    behavior_to_plot='hand_vel',
)

### 2. Rebin Data for Analysis

The `rebin` function is designed to rebin a dataset by aggregating data points into larger bins based on a specified bin size. Here's a brief overview of how to use it:

The `rebin` function takes the following parameters:
- `dataset1`: The DataFrame containing the data to be rebinned.
- `prev_bin_size`: The original bin size of the data.
- `new_bin_size`: The desired bin size to aggregate data points into. This is the new bin size you want the data to be rebinned to.
- `reset` (optional): A boolean indicating whether to reset the index of the resulting DataFrame.

When called, the function aggregates the data points within each new bin using a per-column aggregation function.\
The aggregation function is chosen from the column names of the input DataFrame: spiking columns are summed, while behavior columns are downsampled with a custom decimation function.

In [None]:
# Target bin width (ms) used by all downstream analysis.
binsize = 20

# NOTE: this overwrites `dataset`, so run the cell only once per kernel session — re-running it
# would rebin already-rebinned data while still passing the original `bin_size` as prev_bin_size.
dataset = rebin(dataset, new_bin_size=binsize, prev_bin_size=bin_size)
dataset

In [None]:
# Same example trial as the raster before rebinning (trial 66), now at the new bin size,
# so the effect of rebinning can be compared directly. (Previously trial 666 — likely a typo.)
plot_rastor(dataset, session_id=3, animal_id=1, trial_id=66, behavior_to_plot='hand_vel')

### 3. Align to specific event

Before using the `align_event` function, it's essential to understand its purpose and how it operates. This function is designed to align events within a DataFrame based on a specified start event marker. Here's a brief overview of how to use it:

The `align_event` function takes the following parameters:
- `df`: The DataFrame containing the data.
- `bin_size`: the bin size of the data in ms.
- `start_event`: The event marker (column name) to align each trial around, e.g. `EventMovement_start`.
- `offset_min` (optional): The start of the alignment window, in ms relative to `start_event` (negative values fall before the event).
- `offset_max` (optional): The end of the alignment window, in ms relative to `start_event`.


Please note that for Dataset 1, it's not possible to align events since it doesn't contain event information.\
Also note that after alignment some trials (those at the two ends of each (session, animal) group) will be trimmed, so we want to discard them.

In [None]:
# Names of all event-marker columns (prefixed with "Event").
event_cols = list(filter(lambda name: name.startswith('Event'), dataset.columns))
event_cols

In [None]:
# Event timing per trial for the example recording (session 3, animal 1).
example_session, example_animal = 3, 1
event_bins = get_event_bins(dataset, session_id=example_session, animal_id=example_animal)
event_bins

In [None]:
# Number of rows (bins) in each (session, animal, trial) after rebinning.
# `.size()` counts rows per group directly instead of materialising every sub-DataFrame
# as iterating over the groupby did; `observed=True` keeps only groups that actually occur.
trials_len = dataset.groupby(['session', 'animal', 'trial_id'], observed=True).size().tolist()

fig, ax = plt.subplots()
ax.hist(trials_len, bins=40, edgecolor='gray', alpha=0.7)
ax.set(title="Trial length distribution", xlabel="trial length (bins)", ylabel="num of trials")
plt.show()

In [None]:
# Plot the distribution of the per-trial event bins returned by `get_event_bins` above.
plot_event_bins_dist(event_bins)

In [None]:
# Alignment window around the chosen event (an Event* column name in `dataset`).
align_at = 'EventMovement_start'
offset_min = -400  # ms relative to the event (negative = before it)
offset_max = 580   # ms relative to the event

events = event_bins[align_at].values

# Window size in bins on each side of the aligned bin.
bins_before = -offset_min // binsize
bins_after = offset_max // binsize
event_bin = bins_before  # presumably the index of the aligned bin within each trial — TODO confirm against align_event

# Total aligned trial length, derived from the bin counts (plus the bin we align around) so the
# printed summary is self-consistent even if the offsets are not exact multiples of `binsize`.
n_bins = bins_before + bins_after + 1
trial_length = n_bins * binsize  # ms

# Use the current `binsize` instead of a hard-coded 20 so changing the rebin size above
# keeps the alignment consistent with the rebinned data.
dataset_aligned = align_event(dataset, align_at, bin_size=binsize, offset_min=offset_min, offset_max=offset_max)
print(f'length of aligned trials = {n_bins} bins x {binsize} ms = {trial_length} ms\n')

In [None]:
"""
Make sure all trials are the same length after aligment.
"""
len_counts = get_trials_len_count(dataset_aligned, session_id=3, animal_id=1)
len_counts

## Behavioral Analysis

In [None]:
# Trial ids present in the (rebinned) dataset, in order of first appearance.
# NOTE(review): trial ids are deduplicated across all sessions/animals — confirm they are globally unique.
trial_idx = dataset['trial_id'].drop_duplicates().values

# Split rows by reach outcome (assumes `correct_reach` is a boolean column — TODO confirm).
cor_trials = dataset[dataset['correct_reach']]
inc_trials = dataset[~dataset['correct_reach']]

cor_trials_idx = cor_trials['trial_id'].drop_duplicates().values
# Bug fix: this was taken from `cor_trials`, which made the "incorrect" ids a copy of the correct ones.
inc_trials_idx = inc_trials['trial_id'].drop_duplicates().values

In [None]:
# Aligned, binned spikes for the example recording (session 3, animal 1) as a torch tensor.
spikes = torch.tensor(get_spikes(dataset_aligned, session_id=3, animal_id=1))
spikes.shape

In [None]:
# Aligned hand velocity and position for the example recording, converted to torch tensors.
reach_tensors = {
    bhv: torch.tensor(get_reaches(dataset_aligned, session_id=3, animal_id=1, behavior=bhv))
    for bhv in ('hand_vel', 'hand_pos')
}
hand_vel, hand_pos = reach_tensors['hand_vel'], reach_tensors['hand_pos']

for reach_tensor in (hand_vel, hand_pos):
    print(reach_tensor.shape)

In [None]:
# Per-trial maze condition for the same recording as `spikes` and `hand_vel`.
# Taken from `dataset_aligned` (not `dataset`) so the trials line up one-to-one with the
# aligned behavior arrays it is paired with in `get_conds_average_reach` below; alignment
# may trim trials at the ends of each (session, animal) group.
reach_conds = get_maze_conditions(dataset_aligned, session_id=3, animal_id=1)
reach_conds = torch.tensor(reach_conds)

reach_conds.shape

In [None]:
# Average the aligned hand velocity within each maze condition in `reach_conds`.
# presumably returns per-condition mean and standard deviation — TODO confirm against utils.
avg_reaches, conds_std = get_conds_average_reach(hand_vel, reach_conds)
avg_reaches.shape

In [None]:
# One (x, y) target position per trial, restricted to the same recording (session 3, animal 1)
# and the same aligned trials as `hand_vel`, so rows pair up with the per-trial reaches.
# Previously this used every session/animal of the unaligned `dataset`.
# `query` also resolves `session`/`animal` if they are index levels rather than columns.
# NOTE(review): assumes `get_reaches` orders trials by first appearance — confirm.
active_target_pos = torch.tensor(
    dataset_aligned
    .query('session == 3 and animal == 1')
    .drop_duplicates(subset='trial_id')[['target_pos_x', 'target_pos_y']]
    .values
)
active_target_pos.shape

In [None]:
# Map maze conditions to target positions using the per-trial `reach_conds` and `active_target_pos`.
# presumably one row per condition — TODO confirm against utils.
conds_target_pos = get_conds_target_pos(reach_conds, active_target_pos)
conds_target_pos.shape

In [None]:
# Plot the distinct target positions present in `dataset`.
plot_unique_target_pos(dataset)

In [None]:
# Plot condition-averaged hand velocity, grouping trials by their maze condition in `reach_conds`.
plot_cond_avg_reaches(hand_vel, reach_conds)

In [None]:
# Plot individual reaches (hand velocity) together with their target positions.
n_single_reaches = 80  # number of individual trials to draw
plot_single_reaches(hand_vel, active_target_pos, n_trials_to_plot=n_single_reaches)

In [None]:
# Split the aligned hand-velocity reaches of one maze condition into correct and incorrect trials.
example_cond = 51
cor_reaches, inc_reaches = get_correct_incorrect_reaches_in_cond(dataset_aligned, cond=example_cond, bhv='hand_vel')

print(cor_reaches.keys())
for reach_group in (cor_reaches, inc_reaches):
    print(len(reach_group['indcs']))

In [None]:
# Plot the hand-velocity reaches of maze condition 51 alongside the condition averages.
# NOTE(review): passes the unaligned `dataset` while `avg_reaches` was computed from aligned data
# (and the cells before/after use `dataset_aligned`) — confirm this is intended.
plot_reaches_in_cond(dataset, avg_reaches, cond_to_plot=51, behavior_to_plot='hand_vel')

In [None]:
# Plot aligned hand-velocity reaches for several maze conditions.
# `bins_before` and the 'move\nonset' label presumably mark the alignment bin — TODO confirm.
plot_reaches_in_conds(
    dataset_aligned,
    hand_vel,
    binsize,
    reach_conds,
    bins_before,
    align_at='move\nonset',
    n_conds_to_plot=5,
    behavior_to_plot='hand_vel',
)