# Dev New Design Matrix Generator (DMG)

**Goal**: Use this notebook to develop a new, more flexible DMG class.

**Motivation**: The previous version had a "generate_base_matrix" function, and other classes then operated on that base matrix (e.g. filtering the history). The structure of the base matrix was not flexible at all and required adding lots of if/else logic. Also, adding violation features to a binary model was not possible. Now that I am entering a phase where I want to change features readily, this needs to be updated.

In [95]:
# Imports: third-party libraries first, then project (multiglm) modules.
import pandas as pd
import numpy as np
from multiglm.features.exp_filter import ExpFilter
# NOTE(review): star imports — helpers used below (e.g. remap_values, shift_n_trials_up,
# mask_prev_event, DesignMatrixGeneratorPWM, DatasetLoader) are not defined in this
# notebook and presumably come from these modules; explicit imports would make that clear.
from multiglm.features.design_matrix_generator import *
from multiglm.features.design_matrix_generator_PWM import *

from multiglm.data.dataset_loader import *

# Automatically reload edited project modules before each cell executes.
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Load the dataset for a single animal from the "new_trained" data type.
ANIMAL_IDS = ["W051"]
loader = DatasetLoader(animal_ids=ANIMAL_IDS, data_type="new_trained")
df = loader.load_data()

Loading data for animal ids:  ['W051']


In [114]:
# Feature config for the PWM design matrix. Each non-"labels" key becomes a column of X
# and maps to a callable that receives the trial DataFrame and returns that column.
# The "labels" entry describes which column to use for y.
# Each key appears exactly once: a repeated key in a dict literal silently overwrites
# the earlier entry.
config = {
    # Constant offset column.
    "bias": lambda df: add_bias_column(df),
    # Current-trial stimulus columns.
    # NOTE(review): `copy` is not imported explicitly here; presumably it arrives via a
    # star import — confirm it is copy.copy.
    "s_a": lambda df: copy(df.s_a),
    "s_b": lambda df: copy(df.s_b),
    # Previous-trial average stimulus, with trials after a violation masked
    # (per the mask_prev_violation flag).
    "prev_avg_stim": lambda df: prev_avg_stim(df, mask_prev_violation=True),
    # Violation flag shifted by one trial, using session boundaries from df.session.
    "prev_violation": lambda df: shift_n_trials_up(
        df.violation,
        df.session,
        shift_size=1,
    ),
    "prev_correct": lambda df: prev_correct_side(df),
    # Previous choice recoded 0 -> -1 (1 stays 1), shifted by one trial within session,
    # then masked via mask_prev_event using the violation column.
    "prev_choice": lambda df: mask_prev_event(
        shift_n_trials_up(
            remap_values(df.choice, {0: -1, 1: 1}),
            df.session,
            shift_size=1,
        ),
        df.violation,
        df.session,
    ),
    # Labels come from the "choice" column.
    "labels": {"column_name": "choice"},
}

dmg = DesignMatrixGeneratorPWM(df, config, verbose=True)
X, y = dmg.create()

DMG: Creating data matrix with columns: dict_keys(['bias', 's_a', 's_b', 'prev_avg_stim', 'prev_violation', 'prev_choice'])
DMG: Creating labels with column: choice.
DMG: One hot encoding labels.


correct_side
0    50621
1    48302
Name: count, dtype: int64

In [113]:
# Preview the first ten rows of the design matrix.
X.iloc[:10]

Unnamed: 0,choice,bias,s_a,s_b,prev_avg_stim,prev_violation,prev_choice
0,2,1,60.0,68.0,0.0,0.0,0.0
1,2,1,60.0,68.0,64.0,1.0,0.0
2,2,1,76.0,68.0,64.0,1.0,0.0
3,0,1,68.0,76.0,72.0,1.0,0.0
4,1,1,84.0,76.0,72.0,0.0,-1.0
5,2,1,60.0,68.0,80.0,0.0,1.0
6,1,1,76.0,68.0,64.0,1.0,0.0
7,1,1,84.0,76.0,72.0,0.0,1.0
8,1,1,60.0,68.0,80.0,0.0,1.0
9,0,1,76.0,84.0,64.0,0.0,1.0


In [85]:
# Build a second generator and display its `verbose` attribute.
dd = DesignMatrixGeneratorPWM(df, config, verbose=True)

# NOTE(review): the recorded output is False even though verbose=True was passed —
# confirm the constructor stores the flag on self.verbose (or that this output is stale).
dd.verbose

False

In [57]:
# Number of label rows returned by dmg.create().
# NOTE(review): recorded as 85,353 while df has 98,923 rows — presumably some trials are
# dropped when building labels; confirm (this output may also predate the current config).
len(y)

85353

### Labels

In [38]:


# Toy label config: read labels from "choice" and remap 2 -> 7 (0 and 1 map to
# themselves) so the remapping step is visible in the outputs below.
config_labels = dict(
    column_name="choice",
    mapping={0: 0, 1: 1, 2: 7},
)

In [39]:
# Pull the raw label column, then apply the optional value mapping if one is configured.
label_source = config_labels["column_name"]
labels_col = df[label_source]
if "mapping" in config_labels:
    labels_col = remap_values(labels_col, config_labels["mapping"])

In [41]:
# Keep only non-missing labels and cast them to int before one-hot encoding.
labels_col = labels_col[labels_col.notna()].astype(int)

In [44]:
# One-hot encode the cleaned labels (one column per distinct label value) and
# return them as a NumPy array; the recorded output shows a boolean dtype.
one_hot_labels = pd.get_dummies(labels_col)
one_hot_labels.to_numpy(copy=True)

array([[False, False,  True],
       [False, False,  True],
       [False, False,  True],
       ...,
       [False, False,  True],
       [False,  True, False],
       [ True, False, False]])

### Exp Filter

In [4]:
# Check that the ExpFilter class and the exp_filter_column function produce exactly
# the same exponentially filtered violation history (tau=4).
filt = ExpFilter(tau=4, column="violation", verbose=False).apply_filter_to_dataframe(df)
filters_match = (
    filt["violation_exp"] == exp_filter_column(df["violation"], df.session, tau=4)
).all()
# Fail loudly under Restart & Run All instead of silently displaying False.
assert filters_match, "ExpFilter and exp_filter_column disagree on the violation history"
filters_match

True

In [65]:
# Repeat the class-vs-function equality check against the `filt` frame computed above.
function_result = exp_filter_column(df["violation"], df.session, tau=4)
(filt["violation_exp"] == function_result).all()

True

### Assorted Funcs

In [20]:
# Remap choice codes 0 -> -1 and 2 -> 0; values absent from the mapping (1) stay unchanged.
# (No trailing comma: that would wrap the Series in a 1-tuple.)
remap_values(df.choice, {0: -1, 2: 0})

0        0
1        0
2        0
3       -1
4        1
        ..
98918    1
98919    1
98920    0
98921    1
98922   -1
Name: choice, Length: 98923, dtype: int64

In [27]:
# Shift the violation column by two trials, using session boundaries from the session column.
shift_n_trials_up(
    df["violation"],
    df["session"],
    shift_size=2,
)

0        0.0
1        0.0
2        1.0
3        1.0
4        1.0
        ... 
98918    0.0
98919    0.0
98920    0.0
98921    0.0
98922    1.0
Length: 98923, dtype: float64

In [74]:
# Distinct values taken by the violation flag.
df["violation"].unique()

array([1, 0])

In [79]:
# Is "s_a" one of the configured features?
# NOTE(review): `config` is reassigned by more than one cell in this notebook, so the
# answer depends on which config cell ran most recently.
"s_a" in config

False

### Shift size & Session Masking

In [25]:
# Edge case: shift_size (7) is larger than the number of trials in session 2.
# Sessions 1, 2 and 3 contain 10, 3 and 8 trials respectively. In the recorded output,
# the first 7 trials of each session are False, so all of session 2 is False.
session_ids = [1] * 10 + [2] * 3 + [3] * 8
sessions = pd.Series(session_ids)

mask = get_session_start_mask(sessions, shift_size=7)

# Print each mask value next to the session id of the same row.
n_rows = len(mask)
for row in range(n_rows):
    mask_value = mask[row]
    print(mask_value, sessions[row])

False 1
False 1
False 1
False 1
False 1
False 1
False 1
True 1
True 1
True 1
False 2
False 2
False 2
False 3
False 3
False 3
False 3
False 3
False 3
False 3
True 3


In [69]:
# Boolean mask derived from the violation column, using session boundaries from the session column.
get_prev_event_mask(df["violation"], df["session"])

0         True
1        False
2        False
3        False
4         True
         ...  
98918     True
98919     True
98920     True
98921    False
98922     True
Length: 98923, dtype: bool

In [19]:
# Minimal config with one feature and no "labels" entry, so create() builds only the
# data matrix (see the recorded message below). The feature is s_a shifted by one trial
# within session, then masked via mask_prev_event using the violation column.
# NOTE(review): this rebinds `config` and `dmg` from the main config cell.
config = {
    "s_a_shifted": lambda frame: mask_prev_event(
        shift_n_trials_up(frame.s_a, frame.session, shift_size=1),
        frame.violation,
        frame.session,
    )
}

dmg = DesignMatrixGeneratorPWM(df, config)
dmg.create()

No labels found in config, only creating data matrix.


Unnamed: 0,choice,s_a_shifted
0,2,0.0
1,2,0.0
2,2,0.0
3,0,0.0
4,1,68.0
...,...,...
98918,1,76.0
98919,1,91.0
98920,2,60.0
98921,1,0.0


Next steps:

1. Base label class
2. PWM-specific data functions
3. PWM-specific label functions