In [None]:
import time, re, datetime, os, glob
from datetime import timedelta
import seaborn as sns
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import pandas as pd
from IPython import embed as shell

## INITIALIZE A FEW THINGS
# Global plot look for this notebook. Only one style call is needed:
# sns.set_style replaces the whole style, so a darkgrid call placed before it
# has no effect. Non-style keys such as 'lines.markeredgewidth' are ignored by
# set_style anyway.
sns.set_style("ticks")
sns.set_context(context="poster")

## CONNECT TO DJ
import datajoint as dj
from ibl_pipeline import reference, subject, action, acquisition, data, behavior
from ibl_pipeline.analyses.load_mouse_data_datajoint import get_water, get_weights, get_water_weight, get_behavior

# Fetch every subject whose responsible user is "valeria" (presumably the CSHL
# cohort -- confirm), skipping animals with undefined sex ("U"), e.g. example animals.
subject_query = subject.Subject() & 'responsible_user = "valeria"' & 'sex != "U"'
subjects = pd.DataFrame.from_dict(
    subject_query.fetch(order_by='subject_nickname', as_dict=True)
)
print(subjects['subject_nickname'].unique())

In [None]:
def weekend_summary(mousename):
    """Summarise Friday -> following-Monday changes for one mouse.

    Loads the mouse's weight/water log (``get_water_weight``) and its trials
    (``get_behavior``), merges them per day, expands them to every calendar day
    and returns one row per Friday with the columns:

    - ``weekend_water``: water type logged for the following Saturday
    - ``water_type``: water type logged on that Friday
    - ``weight``, ``water_administered``, ``trialcount``: value on the following
      Monday minus the value on that Friday (NaN if either day has no value)
    - ``subject``: ``mousename``
    """
    weiwa, _ = get_water_weight(mousename)  # second return value is not used here
    behav = get_behavior(mousename)
    trialcounts = behav.groupby(['date'])['choice'].count().reset_index()

    # set the same format for both date axes
    weiwa['date'] = weiwa['date'].astype('datetime64[ns]')
    trialcounts['date'] = trialcounts['date'].astype('datetime64[ns]')

    # MERGE THE TABLES TOGETHER BY DATE, RENAME AND CLEAN UP.
    # Deliberately not named `data`: that would shadow the ibl_pipeline `data` module.
    merged = (weiwa.merge(trialcounts, on='date', how='outer')
              .rename(columns={'choice': 'trialcount'})
              .drop(columns=['days'])
              .drop_duplicates(subset=['date'], keep='last')
              .sort_values('date'))

    # rename to abbreviation
    merged['water_type'] = merged['water_type'].str.replace('Citric Acid', 'CA', regex=False)

    # One row per calendar day, so that a row offset equals a day offset.
    calendar = pd.date_range(merged['date'].min(), merged['date'].max(), name='date')
    daily = merged.set_index('date').reindex(calendar).reset_index()
    daily['dayofweek'] = daily['date'].dt.day_name()

    # BEFORE WE STARTED LOGGING WATER TYPES, EVERYTHING WAS WATER!
    # (days with no log entry at all, e.g. unlogged weekend days, also count as water)
    daily['water_type'] = daily['water_type'].fillna('Water')

    # Pair each Friday with the Monday 3 days later by date, not by position:
    # the first Monday in the range may come before the first Friday.
    value_cols = ['weight', 'water_administered', 'trialcount']
    summary = pd.concat([
        daily['water_type'].shift(-1).rename('weekend_water'),  # Saturday's regime
        daily['water_type'],
        daily[value_cols].shift(-3) - daily[value_cols],        # following Monday minus Friday
    ], axis=1)

    return (summary.loc[daily['dayofweek'] == 'Friday']
            .assign(subject=mousename)
            .reset_index(drop=True))


# Rebuilt from scratch on every run, so re-executing this cell never appends
# duplicate rows to an `alldat` left over from an earlier run.
summaries = []
skipped = {}
for mousename in subjects['subject_nickname'].unique():
    print(mousename)
    try:
        summaries.append(weekend_summary(mousename))
    except Exception as err:
        # Best effort: mice with missing or odd data are skipped, but never silently.
        skipped[mousename] = err
        print('  skipped %s: %r' % (mousename, err))

if not summaries:
    raise RuntimeError('no mouse could be summarised; errors: %r' % skipped)
alldat = pd.concat(summaries, ignore_index=True, sort=False)


In [None]:
## NOW PLOT WITH SEABORN
sns.set_style("ticks")
sns.set_context(context="talk")

# One panel per (column of alldat, y-axis label, y-limits); every column holds a
# Monday-minus-Friday difference. To also plot the change in water earned in
# the task, append this entry:
#   ("water_administered", r"$\Delta$ water earned in task (mL), Monday-Friday", [-1, 1])
panels = [
    ("weight", r"$\Delta$ weight (g), Monday-Friday", [-2, 2]),
    ("trialcount", r"$\Delta$ trial numbers, Monday-Friday", [-200, 200]),
]

f, axes = plt.subplots(1, len(panels), sharex=True, squeeze=False,
                       figsize=(7.5 * len(panels), 5), constrained_layout=True)
for ax, (column, ylabel, ylim) in zip(axes[0], panels):
    # Individual data points. stripplot is axes-level; the figure-level
    # sns.catplot ignores `ax=` and would draw into a separate, unsaved figure.
    sns.stripplot(x="weekend_water", y=column, data=alldat, ax=ax, zorder=1)
    # Mean and error bar per weekend regime, drawn over the points.
    # NOTE(review): `join` is deprecated from seaborn 0.13 on (use linestyle="none").
    sns.pointplot(x="weekend_water", y=column, color="k", join=False, data=alldat, ax=ax, zorder=2)
    ax.set(ylabel=ylabel, xlabel="Weekend water regime", ylim=ylim)
    ax.axhline(y=0, color=".15", zorder=0)

sns.despine(fig=f, offset=2, trim=True)
f.savefig('citricAcid_trialCounts_DJ.pdf')
f.savefig('citricAcid_trialCounts_DJ.png')