# Develop Design Matrix Classes

Written by Jess Breda

**Goal**: The goal of this notebook is to write a parent design matrix class that collects the methods that are reused across design matrix generation.

For example, the s_a and s_b columns and the prev_violation column are always present, and the train/test split functions are always needed. These methods can then be inherited (rather than copied and pasted over and over) by experiment-specific design matrices.

General breakdown of the types of design matrices (DMs) I have made:
* stable columns, sweep over sigma (ss)
* one column changes (e.g., prev_violation filter for different taus) + ss
* multiple columns change (e.g. model comparison with binary and multi)

This notebook is for testing this. I am also working on interaction terms, so I may test those here as well and will write about them in more detail if so.

In [1]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import sys
from sklearn.model_selection import train_test_split


# Make every sub-package directory under ../src importable
# (exp_filter, null_model, get_rat_data, fitting_utils, ...).
# A plain loop is used because this is done purely for its side effect.
for folder in pathlib.Path("../src/").iterdir():
    if folder.is_dir():
        sys.path.append(str(folder))

from exp_filter import ExpFilter
from null_model import NullModel
from get_rat_data import get_rat_viol_data
from fitting_utils import get_train_test_sessions, get_taus

sns.set_context("talk")
%load_ext autoreload
%autoreload 2

In [52]:
class DesignMatrixGenerator:
    """
    Parent class collecting the design-matrix helpers that are re-used
    across experiments: normalization, label encoding, exponential
    filtering of a column, and construction of the "base" design matrix.
    Experiment-specific generators are meant to inherit from this class.
    """

    # model types accepted by generate_base_matrix
    _VALID_MODEL_TYPES = ("multi", "binary")

    def __init__(self, verbose=False):
        # Todo maybe add data? or animal id?
        # Stored so subclasses can consult it; the base class itself does
        # not print anything based on it.
        self.verbose = verbose

    @staticmethod
    def normalize_column(col):
        """
        Z-score a pandas Series (or each column of a DataFrame) using the
        pandas mean and std (pandas std defaults to ddof=1).
        """
        return (col - col.mean()) / col.std()

    @staticmethod
    def one_hot_encode_labels(df):
        """
        Function to one-hot encode choice labels for each trial. In
        the case of the rat data, this is a 3-dimensional vector
        left, right or violation (C = 3). Note this function is
        flexible to the number of choice options (C).

        params
        ------
        df : pd.DataFrame
            dataframe with columns `choice` likely generated by
            get_rat_viol_data()

        returns
        -------
        Y : np.ndarray, shape (N, C), where typically C = 3
            one-hot encoded choice labels for each trial as left,
            right or violation: [[1 0 0] , [0 1 0], [0 0 1]].
            Columns follow pandas.get_dummies ordering (sorted choice
            values), with the NaN (violation) column last. Because
            dummy_na=True, the NaN column is present even if there are
            no violations. dtype may be bool depending on pandas version.
        """

        Y = pd.get_dummies(df["choice"], "choice", dummy_na=True).to_numpy(copy=True)
        return Y

    @staticmethod
    def encode_binary_lr_labels(df):
        """
        Function to encode choice labels for each trial as binary
        left or right (C = 2) and drop data for violation trials
        (rows where `choice` is NaN).

        params
        ------
        df : pd.DataFrame
            dataframe with columns `choice` likely generated by
            get_rat_viol_data() or get_rat_data()

        returns
        -------
        y : np.ndarray, shape (N_non_violation,)
            binary encoded labels with 0 for left and 1 for right
        """

        y = df["choice"].dropna().astype(int).to_numpy()
        return y

    @staticmethod
    def exp_filter_column(X, tau, column, verbose=False):
        """
        Function to apply exponential filter to a column in a dataframe
        and drop the original column

        params
        ------
        X : pd.DataFrame
            dataframe with column to be filtered
        tau : float
            time constant for exponential filter.
        column : str
            column to apply filter to
        verbose : bool (default=False)
            whether to print out progress

        returns
        -------
        X_filtered : pd.DataFrame
            dataframe with filtered column and original column dropped.
            The filtered column name is chosen by ExpFilter (observed
            output suggests f"{column}_exp_{tau}" -- confirm in ExpFilter).
        """
        X_filtered = ExpFilter(
            tau, column=column, verbose=verbose
        ).apply_filter_to_dataframe(X)

        X_filtered.drop(columns=[column], inplace=True)

        return X_filtered

    def generate_base_matrix(self, df, model_type="multi", return_labels=True):
        """
        Function to generate "base" design matrix given a dataframe
        with violations tracked. In this case base means:
            - normalized s_a, s_b columns
            - prev_violation column (multi only)
            - prev_sound_avg column
            - prev_correct column
            - prev_choice column
            - bias column
            - session number column (for merging)

        params
        ------
        df : pd.DataFrame
            dataframe with columns `s_a` `s_b` `session`, `violation`
            `correct_side` and `choice`, likely generated by
            get_rat_viol_data() or get_rat_data()
        model_type : str (default="multi")
            model design matrix will be used for. If multi, returns
            one-hot encoded labels and has a prev_violation column.
            If binary, drops violation trials and returns binary encoded
            labels w/o prev_violation
        return_labels : bool (default=True)
            whether or not to return labels with design matrix

        returns
        -------
        X : pd.DataFrame, shape (N, 8) if multi, (N_non_violation, 7) if binary
            columns in order: bias, session, prev_violation (multi only),
            s_a, s_b, prev_sound_avg, prev_correct, prev_choice.
        Y : np.ndarray, shape (N, 3) if multi, (N_non_violation,) if binary
            only when return_labels=True.

        raises
        ------
        ValueError
            if model_type is not "multi" or "binary".

        side effects
        ------------
        Stores `self.session_boundaries_mask` and `self.prev_violation_mask`
        (boolean Series aligned with df) on the instance.
        """
        # todo add check to make sure only 1 animal in df
        if model_type not in self._VALID_MODEL_TYPES:
            raise ValueError(
                f"model_type must be one of {self._VALID_MODEL_TYPES}, "
                f"got {model_type!r}"
            )

        # Initialize
        X = pd.DataFrame()
        stim_cols = ["s_a", "s_b"]
        X["session"] = df.session

        # Masks- if first trial in a session and/or previous trial
        # was a violation, "prev" variables get set to 0
        self.session_boundaries_mask = df["session"].diff() == 0
        X["prev_violation"] = (
            df["violation"].shift() * self.session_boundaries_mask
        ).fillna(0)
        self.prev_violation_mask = X["prev_violation"] == 0

        # Stimuli (s_a, s_b) get z-scored; normalize_column works
        # column-wise on a DataFrame, so both are done in one call
        X[stim_cols] = self.normalize_column(df[stim_cols])

        # Average previous stimulus (s_a, s_b) loudness
        X["prev_sound_avg"] = df[stim_cols].shift().mean(axis=1)
        X["prev_sound_avg"] = self.normalize_column(X["prev_sound_avg"])
        X["prev_sound_avg"] *= self.session_boundaries_mask * self.prev_violation_mask

        # Prev correct side (L, R) (0, 1) -> (-1, 1),
        X["prev_correct"] = (
            df.correct_side.replace({0: -1}).astype(int).shift()
            * self.session_boundaries_mask
            * self.prev_violation_mask
        )

        # prev choice regressors (L, R, V) (0, 1, Nan) -> (-1, 1, 0),
        X["prev_choice"] = (
            df.choice.replace({0: -1}).fillna(0).astype(int).shift()
            * self.session_boundaries_mask
        )

        # if binary, drop the violation trials and the prev_violation column
        if model_type == "binary":
            X = X[df["violation"] != 1].reset_index(drop=True)
            X.drop(columns=["prev_violation"], inplace=True)

        X.fillna(0, inplace=True)  # fill nans that come from shift()
        X.insert(0, "bias", 1)  # add bias column

        if not return_labels:
            return X

        if model_type == "binary":
            # make choice vector, drop nans (violations) to match X
            Y = self.encode_binary_lr_labels(df)
        else:  # "multi"
            Y = self.one_hot_encode_labels(df)
        return X, Y

In [50]:
class DesignMatrixGeneratorInteractions(DesignMatrixGenerator):
    """
    Design matrix generator that extends the base matrix by optionally
    exponentially filtering one column and then adding pairwise
    interaction (element-wise product) columns.
    """

    def __init__(self, model_type):
        super().__init__()
        # "multi" or "binary"; forwarded to the parent generate_base_matrix
        self.model_type = model_type

    def generate_base_matrix(self, df, return_labels=True):
        """Build the parent's base matrix using the model_type set in __init__."""
        return super().generate_base_matrix(df, self.model_type, return_labels)

    def exp_filter_column(self, X, tau, column, verbose=False):
        """
        Apply the parent's exponential filter to `column` of X and drop
        the original column. `verbose` is forwarded so this override keeps
        the parent's signature.
        """
        return super().exp_filter_column(X, tau, column, verbose=verbose)

    @staticmethod
    def add_interaction_terms(X, interaction_pairs):
        """
        Add interaction terms to the design matrix X.

        params
        ------
        X : pd.DataFrame
            design matrix to add interaction terms to
        interaction_pairs : list of tuples
            each tuple contains the names of two columns to interact.
            A pair may reference an interaction column created by an
            earlier pair in the list.

        returns
        -------
        X_copy : pd.DataFrame
            copy of X with one new column per pair named
            f"{col1}_x_{col2}" holding X[col1] * X[col2]

        raises
        ------
        KeyError
            if a column named in a pair is not present
        """

        X_copy = X.copy()

        for col1, col2 in interaction_pairs:
            missing = [col for col in (col1, col2) if col not in X_copy.columns]
            if missing:
                raise KeyError(
                    f"interaction column(s) {missing} not found; "
                    f"available columns: {list(X_copy.columns)}"
                )
            X_copy[f"{col1}_x_{col2}"] = X_copy[col1] * X_copy[col2]

        return X_copy

    def generate_design_matrix(self, df, tau, column, interaction_pairs=None):
        """
        Function to generate design matrix with interaction terms
        and exponential filter applied to a column.

        params
        ------
        df : pd.DataFrame
            dataframe with columns `s_a` `s_b` `session`, `violation`
            `correct_side` and `choice`, likely generated by
            get_rat_viol_data() or get_rat_data()
        tau : float
            time constant for exponential filter. if tau = 0 or None,
            no filtering is applied and column is not dropped.
        column : str
            column to apply filter to
        interaction_pairs : list of tuples or None (default=None)
            each tuple contains the names of two columns to interact.
            None or an empty list adds no interaction terms.

        returns
        -------
        X : pd.DataFrame
            base design matrix, filtered column and interaction terms
        y : np.ndarray
            labels from generate_base_matrix (one-hot if multi,
            binary if binary)
        """

        X, y = self.generate_base_matrix(df, return_labels=True)

        if tau:
            X = self.exp_filter_column(X, tau=tau, column=column)

        if interaction_pairs:
            X = self.add_interaction_terms(X, interaction_pairs)

        return X, y

In [58]:
def create_violation_interaction_pairs(tau, cols):
    """
    Pair the exponentially filtered prev_violation column for `tau` with
    each column name in `cols`.

    Returns a list of (f"prev_violation_exp_{tau}", col) tuples, one per
    entry of `cols`, in the same order.
    """
    filtered_name = f"prev_violation_exp_{tau}"
    return [(filtered_name, other) for other in cols]

In [61]:
# Build a multi-class design matrix for W075 with prev_violation
# exponentially filtered (tau=7) and interacted with s_a and s_b.
tau = 7
viol_df = get_rat_viol_data(["W075"])
interaction_pairs = create_violation_interaction_pairs(tau, cols=["s_a", "s_b"])
X, Y = DesignMatrixGeneratorInteractions(model_type="multi").generate_design_matrix(
    viol_df, tau=tau, column="prev_violation", interaction_pairs=interaction_pairs
)

returning viol data for ['W075']


In [62]:
# Display the interaction design matrix (Jupyter renders the cell's last expression).
X

Unnamed: 0,bias,session,s_a,s_b,prev_sound_avg,prev_correct,prev_choice,prev_violation_exp_7,prev_violation_exp_7_x_s_a,prev_violation_exp_7_x_s_b
0,1,1,0.000000,0.000000,0.000000,0.0,0.0,0.000000,0.000000,0.000000
1,1,1,0.000000,0.000000,0.000000,-1.0,-1.0,0.000000,0.000000,0.000000
2,1,1,0.000000,0.000000,0.000000,-1.0,1.0,0.000000,0.000000,0.000000
3,1,1,0.000000,0.000000,0.000000,-1.0,1.0,0.000000,0.000000,0.000000
4,1,2,0.000000,0.000000,0.000000,-0.0,0.0,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...
69932,1,199,-0.819730,0.006145,1.342463,1.0,-1.0,0.165893,-0.135988,0.001019
69933,1,199,1.628252,0.822674,-0.000000,-0.0,0.0,0.277834,0.452385,0.228567
69934,1,199,0.812258,0.006145,0.000000,0.0,0.0,0.374874,0.304494,0.002304
69935,1,199,1.628252,0.822674,0.448350,1.0,-1.0,0.324970,0.529133,0.267344


In [31]:
class TrainTestSplitter:
    """
    Split a design matrix and labels into train/test sets by session, so
    that all trials from a given session land in the same set.
    """

    def __init__(self, test_size=0.2, random_state=None):
        """
        Initialize the TrainTestSplitter class.

        params
        ------

        test_size : float (default=0.2)
            proportion of the unique sessions to include in the test set.
        random_state : int (default=None)
            random seed for reproducibility.
        """
        self.test_size = test_size
        self.random_state = random_state

    def get_sessions_for_split(self, df):
        """
        Compute which sessions go to training and which to testing and
        store them as attributes of the class.

        params:
        -------
        df : pd.DataFrame
            dataframe with a `session` column

        computes:
        --------
        train_sessions : np.ndarray
            sessions to use for training
        test_sessions : np.ndarray
            sessions to use for testing
        """
        unique_sessions = df["session"].unique()
        self.train_sessions, self.test_sessions = train_test_split(
            unique_sessions, test_size=self.test_size, random_state=self.random_state
        )

        return None

    def apply_session_split(self, X, Y, filter_violations=False):
        """
        Function to apply session train/test split computed by
        get_sessions_for_split() to design matrix and labels.

        params
        ------
        X : pd.DataFrame, shape (N, D + 2)
            design matrix with bias column and session column
        Y : np.ndarray, shape (N, C) or (N, )
            one-hot encoded choice labels for multi class (l, r, v) or
            binary class (l, r) respectively. Row i of Y must correspond
            to row i (by position) of X.
        filter_violations : bool (default=False)
            whether to filter out violation trials from the test set
            for the multi-class case. this is used when running
            model comparison between binary and multi. Has no effect
            on 1-D (binary) labels.

        returns
        -------
        X_train : pd.DataFrame, shape (N_train, D + 1)
            design matrix for training set (session column dropped)
        X_test : pd.DataFrame, shape (N_test, D + 1)
            design matrix for test set (session column dropped)
        Y_train : np.ndarray, shape (N_train, C) or (N_train, )
            one-hot encoded or binary encoded choice labels
            for training set
        Y_test : np.ndarray, shape (N_test, C) or (N_test, )
            one-hot encoded or binary encoded choice labels for
            test set. If filter_violations=True, violation rows are
            removed (the number of columns is unchanged).

        raises
        ------
        ValueError
            if X has no session column, the split has not been computed,
            or X and Y have different numbers of rows.
        """
        ## Checks
        if "session" not in X.columns:
            raise ValueError("session column not found in X, can't split!")

        if not hasattr(self, "train_sessions"):
            raise ValueError("train_sessions and test_sessions not defined!")

        if len(X) != len(Y):
            raise ValueError(
                f"X and Y must have the same number of rows, got {len(X)} and {len(Y)}"
            )

        # Positional boolean masks so Y (an array) is indexed by row
        # position, consistent with the rows selected from X.
        train_mask = X["session"].isin(self.train_sessions).to_numpy()
        test_mask = X["session"].isin(self.test_sessions).to_numpy()

        self.X_train = X[train_mask].drop(columns=["session"])
        self.X_test = X[test_mask].drop(columns=["session"])
        self.Y_train = Y[train_mask]
        self.Y_test = Y[test_mask]

        if filter_violations:
            self.filter_violations_from_test_set()
            return (
                self.X_train,
                self.filtered_X_test,
                self.Y_train,
                self.filtered_Y_test,
            )

        return self.X_train, self.X_test, self.Y_train, self.Y_test

    def filter_violations_from_test_set(self):
        """
        Filters out the violation trials from Y_test and X_test. For
        the multi-class case to allow for comparison with the binary
        case on only L & R trials. Results are stored in
        `filtered_X_test` and `filtered_Y_test`.

        Assumes that the violation is encoded as [0, 0, 1] in Y_test.
        1-D (binary) labels contain no violations and are kept as-is.
        """
        if np.ndim(self.Y_test) == 1:
            keep = np.ones(len(self.Y_test), dtype=bool)
        else:
            is_violation = np.all(np.asarray(self.Y_test) == np.array([0, 0, 1]), axis=1)
            keep = ~is_violation

        self.filtered_Y_test = self.Y_test[keep]
        self.filtered_X_test = self.X_test.iloc[keep]

        return None

In [65]:
# df has been filtered for correct animal & stage!
# NOTE(review): `df` is not defined in any cell visible here; it is assumed
# to be one animal's data (e.g. from get_rat_viol_data) created in an
# earlier session -- TODO confirm before re-running top to bottom.
dmg = DesignMatrixGenerator()
tts = TrainTestSplitter()
# Train/test sessions are drawn over the unique session ids in df.
tts.get_sessions_for_split(df)
X_base, y = dmg.generate_base_matrix(df, model_type="multi")
# Make custom enhancements to X_base here with a child class!
xtr, xt, ytr, yt = tts.apply_session_split(X_base, y, filter_violations=False)