## New Dataset Cleaning

**Goal**: This notebook will create functions that will clean the new datasets coming in from extract_AthenaDelayComp_dat.m. I want to manipulate the new df such that it can easily be used with existing design matrix generators.

For context, this is a script written by me and Chuck to get violation data from all sessions of a trained PWM animal (not just up to session 200). It appears that Athena encoded timeout trials as whenever the "wait_for_cpoke" Tuped after 2 minutes. In the old dataset, these were counted as violations. I want to remove them from the current dataset and reset the trial counters. I will store information though on how many timeouts in a row. Finding sessions with high timeout rates would be good- they will have a large change in total trial counts.

In [4]:
from multiglm.data.get_old_rat_data import *
import pandas as pd
import numpy as np
from multiglm.data import ANIMAL_IDS, COLUMN_RENAME
from multiglm.data.dataset_cleaner import DatasetCleaner



%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Test class developed:

In [5]:
dc = DatasetCleaner(
    "W075",
    column_rename_map=COLUMN_RENAME,
    load_path="/Users/jessbreda/Desktop/github/animal-learning/data/raw/",
    save_out=True,
)
dc.run()

** RUNNING W075 **


Unnamed: 0,animal_id,session_date,session_file_counter,rig_id,training_stage,s_a,s_b,hit,violation,trial_not_started,...,fixation_time,trial_start_wait_time,l_water_vol,r_water_vol,antibias_beta,antibias_right_prob,using_psychometric_pairs,choice,session,n_prev_trial_not_started
0,W075,2015-07-16,1,18,1,,,0.0,0,False,...,0.01,200,18.000000,18.000000,0,0.000000,0,0,1,0.0
1,W075,2015-07-16,1,18,1,,,1.0,0,False,...,0.00,200,18.000001,18.000001,0,0.000000,0,1,1,0.0
2,W075,2015-07-16,1,18,1,,,1.0,0,False,...,0.00,200,18.000004,18.000004,0,0.000000,0,1,1,0.0
3,W075,2015-07-16,1,18,1,,,1.0,0,False,...,0.00,200,18.000014,18.000014,0,0.000000,0,1,1,0.0
4,W075,2015-07-16,1,18,1,,,1.0,0,False,...,0.00,200,18.000033,18.000033,0,0.000000,0,1,1,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
276174,W075,2017-08-06,678,18,4,84.0,92.0,1.0,0,False,...,5.40,200,33.650125,33.650125,3,0.411211,1,1,678,0.0
276175,W075,2017-08-06,678,18,4,76.0,68.0,1.0,0,False,...,3.40,200,33.663689,33.663689,3,0.415209,1,1,678,0.0
276176,W075,2017-08-06,678,18,4,76.0,68.0,1.0,0,False,...,3.40,200,33.677196,33.677196,3,0.414947,1,1,678,0.0
276177,W075,2017-08-06,678,18,4,84.0,92.0,1.0,0,False,...,5.40,200,33.690646,33.690646,3,0.414695,1,1,678,0.0


---
## DEVELOPMENT CODE 

below is the code that went into making the class DatasetCleaner. First, we start by making some simulated data for testing the "trial_not_started" accounting. The goal of this analysis is for trials that were started, if there were any previous trials not started, count how many and document them. Not this is only for consecutive trials not started.


In [10]:
import pandas as pd

# Your data
data = {
    "session_date": [
        160802,
        160802,
        160802,
        160802,
        160802,
        160802,
        160802,
        160802,
        160802,
        160802,
        160802,
    ],
    "trial_not_started": [
        False,
        True,
        False,
        True,
        False,
        False,
        True,
        True,
        False,
        False,
        False,
    ],
}

# Create the DataFrame
df = pd.DataFrame(data)

# Step 1: Calculate the cumulative sum of 'trial_not_started', resetting when a trial is started
df["trial_not_started_cumsum"] = df["trial_not_started"].cumsum() - df[
    "trial_not_started"
].cumsum().where(~df["trial_not_started"]).ffill().fillna(0)

# Step 2: Shift the cumulative values down to align them with the next trial
df["n_prev_trial_not_started"] = df["trial_not_started_cumsum"].shift(fill_value=0)


# Step 3: Remove the trials where 'trial_not_started' is True
df_filtered = df[df["trial_not_started"] == False].copy()

# # Adding 'ground_truth' for comparison (not part of the solution, just for verification)
df_filtered["ground_truth"] = [0, 1, 1, 0, 2, 0, 0]

# Drop unnecessary columns
df_filtered.drop(
    ["trial_not_started", "trial_not_started_cumsum"], axis=1, inplace=True
)

# Show the filtered dataframe to verify the results
df_filtered

Unnamed: 0,session_date,n_prev_trial_not_started,ground_truth
0,160802,0.0,0
2,160802,1.0,1
4,160802,1.0,1
5,160802,0.0,0
8,160802,2.0,2
9,160802,0.0,0
10,160802,0.0,0


Okay great- this seems to work with the ground truth data. Let's see if we can do it with a dataset that has multiple sessions

In [13]:
# Create a simulated dataset
np.random.seed(42)  # For reproducible results

# Generate random session dates
session_dates = np.random.choice(
    [160802, 160803, 160804, 160805], size=100, replace=True
)

# Generate random "trial_not_started" flags and "trial_start_wait_time" values
trial_not_started = np.random.choice(
    [0, 1], size=100, p=[0.8, 0.2]
)  # 80% trials started, 20% not
trial_start_wait_time = np.where(
    trial_not_started == 1, np.random.randint(1, 60, size=100), 0
)

# Assemble the DataFrame
s = pd.DataFrame(
    {
        "session_date": session_dates,
        "trial_not_started": trial_not_started,
        "trial_start_wait_time": trial_start_wait_time,
    }
)

# Sort by session_date to mimic the user's dataset structure
s.sort_values("session_date", inplace=True)

s.head()

Unnamed: 0,session_date,trial_not_started,trial_start_wait_time
99,160802,0,0
38,160802,1,26
31,160802,0,0
30,160802,1,6
57,160802,0,0


In [14]:
def calculate_n_prev_trial_not_started(df_group):
    df_group.trial_not_started = df_group.trial_not_started.astype(bool)
    # Step 1: Calculate the cumulative sum of 'trial_not_started', resetting when a trial is started
    df_group["trial_not_started_cumsum"] = df_group[
        "trial_not_started"
    ].cumsum() - df_group["trial_not_started"].cumsum().where(
        ~df_group["trial_not_started"]
    ).ffill().fillna(
        0
    )

    # Step 2: Shift the cumulative values down to align them with the next trial
    df_group["n_prev_trial_not_started"] = df_group["trial_not_started_cumsum"].shift(
        fill_value=0
    )

    # Step 3: Remove the trials where 'trial_not_started' is True
    # df_group_filtered = df_group[df_group["trial_not_started"] == False].copy() # This is implemented in the code, but want to visualize everything in dev
    df_group_filtered = df_group.copy()

    # Drop the temporary cumulative sum column
    df_group_filtered.drop(["trial_not_started_cumsum"], axis=1, inplace=True)

    return df_group_filtered


# Group by 'session_date' and apply the function
df_grouped = (
    s.groupby("session_date")
    .apply(calculate_n_prev_trial_not_started)
    .reset_index(drop=True)
)

# Show the result

df_grouped.head(25)

Unnamed: 0,session_date,trial_not_started,trial_start_wait_time,n_prev_trial_not_started
0,160802,False,0,0.0
1,160802,True,26,0.0
2,160802,False,0,1.0
3,160802,True,6,0.0
4,160802,False,0,1.0
5,160802,False,0,0.0
6,160802,True,7,0.0
7,160802,True,16,1.0
8,160802,False,0,2.0
9,160802,False,0,0.0


Looks good- will work this into the DatasetCleaner class.



###  tasks:


[X] confirm dB nan if stimuli off


[X] map correct side were correct side is Left=0 and Right=1

[X] map of choice where choice made **by** the rat, where Left=0 and Right=1 (and Violation=NaN)

[X] make a session counter by date

[X] compute delay len (subtract go or not?)- don't need to do this

[X] drop timeouts but maintain dur (see code above)

[X] make trial counter

In [7]:
ndf = pd.read_csv("/Volumes/brody/jbreda/PWM_data_scrape/W078_trials_data.csv")

ndf.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 108475 entries, 0 to 108474
Data columns (total 29 columns):
 #   Column               Non-Null Count   Dtype  
---  ------               --------------   -----  
 0   rat_name             108475 non-null  object 
 1   session_date         108475 non-null  int64  
 2   session_counter      108475 non-null  int64  
 3   rig_id               108475 non-null  int64  
 4   training_stage       108475 non-null  int64  
 5   A1_dB                102034 non-null  float64
 6   A2_dB                102034 non-null  float64
 7   hit_history          91273 non-null   float64
 8   violation_history    108475 non-null  int64  
 9   timeout_history      108475 non-null  int64  
 10  A1_sigma             108475 non-null  float64
 11  Rule                 108475 non-null  object 
 12  ThisTrial            108475 non-null  object 
 13  violation_iti        108475 non-null  int64  
 14  error_iti            108475 non-null  int64  
 15  secondhit_delay  

In [8]:
rdf = pd.read_csv(
    "/Users/jessbreda/Desktop/github/animal-learning/data/raw/rat_behavior.csv"
)

In [9]:
rdf.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2540006 entries, 0 to 2540005
Data columns (total 10 columns):
 #   Column          Dtype  
---  ------          -----  
 0   subject_id      object 
 1   session         int64  
 2   trial           int64  
 3   s_a             float64
 4   s_b             float64
 5   choice          float64
 6   correct_side    int64  
 7   hit             float64
 8   delay           float64
 9   training_stage  int64  
dtypes: float64(5), int64(4), object(1)
memory usage: 193.8+ MB


In [11]:
df = get_rat_viol_data(animal_ids="W078")

df.info()

returning truncated viol data for W078
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 53299 entries, 0 to 53298
Data columns (total 13 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   animal_id           53299 non-null  object 
 1   session             53299 non-null  int64  
 2   trial               53299 non-null  int64  
 3   s_a                 40115 non-null  float64
 4   s_b                 40115 non-null  float64
 5   choice              43481 non-null  float64
 6   correct_side        53299 non-null  int64  
 7   hit                 43481 non-null  float64
 8   delay               53299 non-null  float64
 9   training_stage      53299 non-null  int64  
 10  violation           53299 non-null  bool   
 11  n_trial             53299 non-null  int64  
 12  training_stage_cat  53299 non-null  int64  
dtypes: bool(1), float64(5), int64(6), object(1)
memory usage: 4.9+ MB


**What is expexted by design matrix generator**

`animal_id` : str <-- `rat_name`

`session`:  int  <-- `session_counter`

`trial` : int <-- MAKE

`s_a`: float 64 <-- `A1_dB`

`s_b`: float 64 <-- `A2_dB`

`choice`<- animals choice 0 (error), 1 (hit), Nan (violation) <- switching this to 0,1,2

`correct_side`<- 0 (left), 1 (right)

In [17]:
renaming_map = {
    "rat_name": "animal_id",
    "session_date": "session_date",
    "session_counter": "session_file_counter",  # can have multiple in the same date
    "A1_dB": "s_a",
    "A2_dB": "s_b",
    "hit_history": "hit",
    "violation_history": "violation",
    "timeout_history": "trial_not_started",
    "A1_sigma": "s_a_sigma",  # not sure why this is here, dB mapping done in matlab
    "Rule": "rule",
    "ThisTrial": "correct_side",  # eventually: 1 = right, 0 = left
    "violation_iti": "violation_penalty_time",
    "error_iti": "error_penalty_time",
    "secondhit_delay": "delayed_reward_time",  # related to stg 3, doesn't tell you if animal used it though
    "PreStim_time": "pre_stim_time",
    "A1_time": "s_a_time",
    "Del_time": "delay_time",
    "A2_time": "s_b_time",
    "time_bet_aud2_gocue": "post_s_b_to_go_cue_time",
    "time_go_cue": "go_cue_time",
    "CP_duration": "fixation_time",
    "CenterLed_duration": "trial_start_wait_time",  # how much time elapses w/o activity until available trial is a "timeout"
    "Left_volume": "l_water_vol",
    "Right_volume": "r_water_vol",
    "Beta": "antibias_beta",  # higher = stronger antibias
    "RtProb": "antibias_right_prob",  # higher = more likely for a right trial to occur
    "psych_pairs": "using_psychometric_pairs",
}

In [157]:
class DatasetCleaner:
    def __init__(
        self,
        animal_id,
        column_rename_map,
        load_path="/Volumes/brody/jbreda/PWM_data_scrape/",
        save_out=False,
        save_path="../data/cleaned/by_animal/",
    ):
        self.animal_id = animal_id
        self.column_remap = column_rename_map
        self.load_path = load_path
        self.save_out = save_out
        self.save_path = save_path

    def run(self):
        self.raw_df = self.load_animal_df()
        self.rename_columns()
        self.map_correct_side_and_choice()
        self.make_session_column()
        self.drop_and_account_for_trial_non_starts()

        self.cleaned_df = self.raw_df.copy()
        return self.cleaned_df

    def load_animal_df(self):
        return pd.read_csv(self.load_path + f"{self.animal_id}_trials_data.csv")

    def rename_columns(self):
        if not hasattr(self, "raw_df"):
            self.raw_df = self.load_animal_df()

        self.raw_df.rename(columns=renaming_map, inplace=True)

        return None

    def map_correct_side_and_choice(self):
        self.raw_df["correct_side"] = self.raw_df.correct_side.map(
            {"RIGHT": 1, "LEFT": 0}
        )

        self.raw_df["choice"] = self.raw_df.apply(self.determine_animal_choice, axis=1)

        return None

    def make_session_column(self):
        # convert to date object first
        self.raw_df["session_date"] = pd.to_datetime(
            self.raw_df.session_date, format="%y%m%d"
        )

        # defining a session as all trials from a single day
        self.raw_df["session"] = (
            self.raw_df["session_date"].rank(method="dense").astype(int)
        )

        return None

    def drop_and_account_for_trial_non_starts(self):
        self.raw_df = (
            self.raw_df.groupby("session")
            .apply(self.calc_n_prev_trial_not_started)
            .reset_index(drop=True)
        )

        return None

    def add_trial_column(self):
        self.raw_df = (
            self.raw_df.groupby("session")
            .apply(self.calc_trial_counts)
            .reset_index(drop=True)
        )

        return None

    @staticmethod
    def determine_animal_choice(row):
        if row.hit == 0:
            return 0
        elif row.hit == 1:
            return 1
        elif row.violation == 1:
            return 2
        elif row.trial_not_started == 1:
            return 3
        else:  # "timeout"
            return -1

    @staticmethod
    def calc_n_prev_trial_not_started(session_group):
        # Convert dtype
        session_group.trial_not_started = session_group.trial_not_started.astype(bool)

        # Calculate the cumulative sum of 'trial_not_started',
        # resetting when a trial is started to only count consecutive non-starts
        session_group["trial_not_started_cumsum"] = session_group[
            "trial_not_started"
        ].cumsum() - session_group["trial_not_started"].cumsum().where(
            ~session_group["trial_not_started"]
        ).ffill().fillna(
            0
        )

        # Shift the cumulative values down to align with the next trial
        # in order to create  "prev history" variable
        session_group["n_prev_trial_not_started"] = session_group[
            "trial_not_started_cumsum"
        ].shift(fill_value=0)

        # Remove the trials where 'trial_not_started', so only the
        # history of them remains on valid trials
        filtered_df = session_group.query("trial_not_started != True").copy()
        # filtered_df = session_group.copy()

        # Drop the temporary cumulative sum column
        filtered_df.drop(["trial_not_started_cumsum"], axis=1, inplace=True)

        return filtered_df

    @staticmethod
    def calc_trial_counts(session_group):
        session_group["trial"] = np.arange(1, len(session_group) + 1)

        return session_group

In [158]:
dc = DatasetCleaner(
    "W078",
    column_rename_map=renaming_map,
    load_path="/Users/jessbreda/Desktop/github/animal-learning/data/raw/",
)
dc.run()

Unnamed: 0,animal_id,session_date,session_file_counter,rig_id,training_stage,s_a,s_b,hit,violation,trial_not_started,...,fixation_time,trial_start_wait_time,l_water_vol,r_water_vol,antibias_beta,antibias_right_prob,psych_pairs,choice,session,n_prev_trial_not_started
0,W078,2015-07-16,1,3,1,,,1.0,0,False,...,0.01,200,18.000000,18.000000,0,0.000000,0,1,1,0.0
1,W078,2015-07-16,1,3,1,,,0.0,0,False,...,0.00,200,18.000001,18.000001,0,0.000000,0,0,1,0.0
2,W078,2015-07-16,1,3,1,,,1.0,0,False,...,0.00,200,18.000004,18.000004,0,0.000000,0,1,1,0.0
3,W078,2015-07-16,1,3,1,,,0.0,0,False,...,0.00,200,18.000014,18.000014,0,0.000000,0,0,1,0.0
4,W078,2015-07-16,1,3,1,,,1.0,0,False,...,0.00,200,18.000033,18.000033,0,0.000000,0,1,1,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
105735,W078,2016-08-02,366,9,4,60.0,68.0,0.0,0,False,...,4.40,200,29.134359,29.134359,3,0.557581,0,0,366,0.0
105736,W078,2016-08-02,366,9,4,76.0,84.0,,1,False,...,4.40,200,29.167101,29.167101,3,0.534723,0,2,366,0.0
105737,W078,2016-08-02,366,9,4,76.0,68.0,1.0,0,False,...,4.40,200,29.199718,29.199718,3,0.534723,0,1,366,0.0
105738,W078,2016-08-02,366,9,4,92.0,84.0,1.0,0,False,...,3.40,200,29.232207,29.232207,3,0.530272,0,1,366,0.0


In [161]:
dc.cleaned_df.hit.value_counts()

hit
1.0    75160
0.0    16113
Name: count, dtype: int64

In [162]:
dc = DatasetCleaner(
    "W075",
    column_rename_map=renaming_map, 
    load_path="/Users/jessbreda/Desktop/github/animal-learning/data/raw/",
)
dc.run()

Unnamed: 0,animal_id,session_date,session_file_counter,rig_id,training_stage,s_a,s_b,hit,violation,trial_not_started,...,fixation_time,trial_start_wait_time,l_water_vol,r_water_vol,antibias_beta,antibias_right_prob,psych_pairs,choice,session,n_prev_trial_not_started
0,W075,2015-07-16,1,18,1,,,0.0,0,False,...,0.01,200,18.000000,18.000000,0,0.000000,0,0,1,0.0
1,W075,2015-07-16,1,18,1,,,1.0,0,False,...,0.00,200,18.000001,18.000001,0,0.000000,0,1,1,0.0
2,W075,2015-07-16,1,18,1,,,1.0,0,False,...,0.00,200,18.000004,18.000004,0,0.000000,0,1,1,0.0
3,W075,2015-07-16,1,18,1,,,1.0,0,False,...,0.00,200,18.000014,18.000014,0,0.000000,0,1,1,0.0
4,W075,2015-07-16,1,18,1,,,1.0,0,False,...,0.00,200,18.000033,18.000033,0,0.000000,0,1,1,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
276174,W075,2017-08-06,678,18,4,84.0,92.0,1.0,0,False,...,5.40,200,33.650125,33.650125,3,0.411211,1,1,678,0.0
276175,W075,2017-08-06,678,18,4,76.0,68.0,1.0,0,False,...,3.40,200,33.663689,33.663689,3,0.415209,1,1,678,0.0
276176,W075,2017-08-06,678,18,4,76.0,68.0,1.0,0,False,...,3.40,200,33.677196,33.677196,3,0.414947,1,1,678,0.0
276177,W075,2017-08-06,678,18,4,84.0,92.0,1.0,0,False,...,5.40,200,33.690646,33.690646,3,0.414695,1,1,678,0.0


In [139]:
dc.combined_data.plot(x="trial", )

n_prev_trial_not_started
0.0     1178
1.0      599
2.0      348
3.0      223
4.0      135
5.0       82
6.0       50
7.0       35
8.0       26
9.0       18
10.0      12
11.0       7
12.0       4
14.0       3
13.0       3
15.0       1
16.0       1
17.0       1
18.0       1
19.0       1
20.0       1
21.0       1
22.0       1
23.0       1
24.0       1
25.0       1
26.0       1
Name: count, dtype: int64