In [452]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
pd.options.mode.chained_assignment = None  # default='warn'
from util import MatfileIO, Bunch
from preprocessing import GaussianSmoothing

from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

In [342]:
# Preprocessing
def preprocess_unitTimeBin(unitTimeBin, bin_id, g_sigma, g_length, g_decimal_out): 
    utb = unitTimeBin[:, bin_id]
    sm_obj = GaussianSmoothing(utb, sigma=g_sigma, length=g_length, axis='row', decimal_out=g_decimal_out)
    conv = sm_obj.conv()
    return conv # conv.flatten()

In [296]:
def get_params():
    # Preprocessing - params
    p = Bunch()
    p.bin_width = 0.05
    p.time_win = [-1, 2]
    p.time_bins = np.arange(p.time_win[0], p.time_win[1], p.bin_width)
    p.bin_I = p.time_bins <= 1  # to include pre-reach epoch only
    p.g_sigma = 5
    p.g_length = 20
    p.g_decimal_out = 3
    p.bl_I = Bunch()
    p.bl_I['1'], p.bl_I['5'] = 'lelo', 'lelo'  # left/low
    p.bl_I['2'], p.bl_I['6'] = 'lehi', 'lehi'  # left/high
    p.bl_I['3'], p.bl_I['7'] = 'rilo', 'rilo'  # right/low
    p.bl_I['4'], p.bl_I['8'] = 'rihi', 'rihi'  # right/high
    return p

### Load and extract data

In [271]:
filepath = "/Volumes/dudmanlab/junchol/js2p0/WR40_082019/Matfiles"
io_cls = MatfileIO(filepath)
df = io_cls.extract_dataframe()
df['blockNum'] = df.blockNum.apply(lambda x: str(x)) 

In [266]:
df.head()

Unnamed: 0,blockNum,trialType,ctx,str,cg
0,1,to,[[[[[0. 1. 1. 0. 0. 0. 0. 2. 2. 0. 0. 0. 0. 0....,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....
1,1,to,[[[[[2. 2. 0. 2. 2. 5. 3. 3. 3. 3. 0. 1. 0. 0....,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....
2,1,sp,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....,[[[[[0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0....
3,1,sp,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0....
4,1,sp,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....,[[[[[0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0....,[[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0....


In [4]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 140 entries, 0 to 139
Data columns (total 5 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   blockNum   140 non-null    int8  
 1   trialType  140 non-null    object
 2   ctx        140 non-null    object
 3   str        140 non-null    object
 4   cg         140 non-null    object
dtypes: int8(1), object(4)
memory usage: 4.6+ KB


In [344]:
# Use successful trials only
success_I = df.trialType.apply(lambda x: x == 'sp')
df_s = df.loc[success_I, :]

  success_I = df.trialType.apply(lambda x: x == 'sp')


### Neural Data
- df rows correspond to trials. 
- df columns correspond to different variables. 
    - blockNum: block IDs. 
    - trialType: trial types.  
        - 'to': timeout. 
        - 'sp': successful pull. 
        - 'ps': push.  
        - 'pmpp': premature pull & push. 
- df.ctx, df.str, df.cg contain binned spike counts per trial.
    - Each np.matrix is organized as # neurons by # time bins. 
    - By default, each np.matrix spans 3 s epoch (1 s pre and 2 s post event to which time bins are aligned) with the bin width of 50 ms (60 bins).

### Preprocessing

In [254]:
# Temporal information of neural data
p = get_params()

In [345]:
df_s.loc[:, 'ctx'] = df_s.ctx.apply(lambda x: preprocess_unitTimeBin(x, p.bin_I, p.g_sigma, p.g_length, p.g_decimal_out))  
df_s.loc[:, 'str'] = df_s.str.apply(lambda x: preprocess_unitTimeBin(x, p.bin_I, p.g_sigma, p.g_length, p.g_decimal_out))
df_s.loc[:, 'cg'] = df_s.cg.apply(lambda x: preprocess_unitTimeBin(x, p.bin_I, p.g_sigma, p.g_length, p.g_decimal_out)) 

df_s['bl_type'] = df_s.blockNum.apply(lambda x: p.bl_I[x])

In [258]:
clf = RandomForestClassifier(n_estimators=1000, max_depth=100, random_state=42)

X = np.vstack(df_s.cg.to_numpy())
y = df_s.bl_type.to_numpy()

skf = StratifiedKFold(n_splits=5)

for train_idx, test_idx in skf.split(X, y):
    clf.fit(X[train_idx, :], y[train_idx])
    
    # prediction on test set
    y_pred = clf.predict(X[test_idx, :])
    print(np.around(clf.score(X[test_idx, :], y[test_idx]), 4))
    
    report = classification_report(y[test_idx], y_pred, output_dict=True)
    df_report = pd.DataFrame(report).transpose()

0.5238
              precision    recall  f1-score   support

        lehi       1.00      0.33      0.50         6
        leli       0.50      1.00      0.67         8
        rihi       1.00      0.25      0.40         4
        rili       0.00      0.00      0.00         3

    accuracy                           0.52        21
   macro avg       0.62      0.40      0.39        21
weighted avg       0.67      0.52      0.47        21

0.5714
              precision    recall  f1-score   support

        lehi       0.75      0.60      0.67         5
        leli       0.54      0.88      0.67         8
        rihi       0.00      0.00      0.00         4
        rili       0.50      0.50      0.50         4

    accuracy                           0.57        21
   macro avg       0.45      0.49      0.46        21
weighted avg       0.48      0.57      0.51        21



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


0.5238
              precision    recall  f1-score   support

        lehi       0.40      0.80      0.53         5
        leli       0.67      0.75      0.71         8
        rihi       1.00      0.25      0.40         4
        rili       0.00      0.00      0.00         4

    accuracy                           0.52        21
   macro avg       0.52      0.45      0.41        21
weighted avg       0.54      0.52      0.47        21

0.4762
              precision    recall  f1-score   support

        lehi       0.75      0.60      0.67         5
        leli       0.40      0.25      0.31         8
        rihi       0.33      0.75      0.46         4
        rili       0.67      0.50      0.57         4

    accuracy                           0.48        21
   macro avg       0.54      0.53      0.50        21
weighted avg       0.52      0.48      0.47        21

0.5238
              precision    recall  f1-score   support

        lehi       1.00      0.40      0.57         5


In [91]:
train_bl_type = df_s.iloc[train_index]['bl_type']

In [95]:
train_bl_type.describe()

count       70
unique       4
top       leli
freq        27
Name: bl_type, dtype: object

In [98]:
train_bl_type == 'leli'

2     True
3     True
4     True
5     True
6     True
      ... 
91    True
92    True
93    True
94    True
95    True
Name: bl_type, Length: 70, dtype: bool