In [1]:
import numpy as np
import pandas as pd

In [17]:
class DataLoader:
    def __init__(self, animal_ids, data_type="new_all"):
        self.animal_ids = animal_ids
        self.data_type = data_type
        self.determine_load_function()

    def determine_load_function(self):
        """
        Given the data type requested by the user, determine the
        appropriate function to load the data.

        Options:
            - new_all: All data from the new dataset
            - new_viols_truncated: New dataset truncated to session 200 to
                closely resemble the old dataset where violations stopped
                being tracked at session 200
            - old_viols: the old, public dataset until sessions stopped being
                tracked at session 200
        Returns:
            function: The function to load the data

        """
        if self.data_type == "new_all":
            self.load_function = self.load_new_all
        elif self.data_type == "new_match_old_viols":
            self.load_function = self.load_new_match_old_viols
        elif self.data_type == "old_viols":
            self.load_function = self.load_old_viols
        else:
            raise ValueError("Invalid data type requested!")

        return self.load_function

    def load_data(self):
        print("Loading data for animal ids: ", self.animal_ids)
        return self.load_function()

    def load_new_all(self):
        data = pd.read_csv("../data/cleaned/all_animals_cleaned.csv")
        data = data[data["animal_id"].isin(self.animal_ids)]
        return data

    def load_new_match_old_viols(self):
        data = self.load_new_all()
        data = data.query("session_relative_to_old < 200")
        return data

    def load_old_viols(self):
        data = pd.read_csv("../data/cleaned/old_dataset/old_violation_data.csv")
        data = data[data["animal_id"].isin(self.animal_ids)]
        return data

In [18]:
animal_ids = [
    "W078",
]

In [20]:
dl = DataLoader(animal_ids, data_type="old_viols")
dl.load_data()

Loading data for animal ids:  ['W078']


Unnamed: 0,animal_id,session,trial,s_a,s_b,choice,correct_side,hit,delay,training_stage,violation,n_trial,training_stage_cat
511554,W078,1,1,,,1.0,0,0.0,1.0,1,False,1511189,1
511555,W078,1,2,,,0.0,0,1.0,1.0,1,False,1511190,1
511556,W078,1,3,,,0.0,0,1.0,1.0,1,False,1511191,1
511557,W078,1,4,,,0.0,0,1.0,1.0,1,False,1511192,1
511558,W078,1,5,,,1.0,1,1.0,1.0,1,False,1511193,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
564848,W078,200,269,92.0,84.0,0.0,1,0.0,3.1,4,False,1564483,4
564849,W078,200,270,68.0,76.0,0.0,0,1.0,2.1,4,False,1564484,4
564850,W078,200,271,68.0,60.0,,1,,3.1,4,True,1564485,4
564851,W078,200,272,60.0,68.0,,0,,2.1,4,True,1564486,4
