In [20]:
# Make the notebook's main container span 95% of the browser window.
from IPython.display import HTML, display
display(HTML(data="<style>.container {width:95% !important; }</style>"))

# check environment: report the active conda environment and the interpreter version
import os
from platform import python_version

# CONDA_DEFAULT_ENV only exists inside an activated conda environment; fall back to a
# placeholder instead of raising KeyError when the kernel was started some other way.
print(f"Conda Environment: {os.environ.get('CONDA_DEFAULT_ENV', '<not set>')}")
print(f'python version: {python_version()}')

Conda Environment: lickTask2
python version: 3.9.10


In [21]:
"""
Configuration for modelling 2ABT data with Q-learning algorithms.
"""

# Every tunable value of the run lives in this one nested dictionary.
params = dict(
    # location of the trial-by-trial dataframe (directory + file name are joined when loading)
    paths=dict(
        dir_data='path/to/2ABT_dataframe',
        filename='2ABT_dataframe.csv',
    ),
    hyper_params=dict(
        split=4,          # number of consecutive parts each session is cut into
        train_prop=0.75,  # train_size handed to train_test_split; values >= 1 disable the test split
        lr=0.1,           # learning rate of the SGD optimizer
        n_iter=10000,     # number of optimisation iterations per fit
        seed=42,          # seed for python / numpy / torch RNGs and for the train/test split
    ),
    descriptors=dict(
        geno='genotype',
        # which binned set of days is compared to the 'before' days
        bin_date='day4-7',
        # 'standard' (contains both alpha/zeta) or 'reduced' (combines alpha/zeta into one parameter)
        modelID='standard',
    ),
    # starting values of the fitted parameters; the key order is the order of the parameter vector
    params_init_dict=dict(
        beta=0.5,
        bias_l=0.0,
        zeta=0.75,
        alpha=0.25,
    ),
)

In [22]:

### batch_run stuff
from pathlib import Path

import sys
# directory that receives the results table and the logger CSV files
dir_save = Path('/path/to/save/results')


## standard libraries
import copy
import numpy as np
import matplotlib.pyplot as plt
import torch
import random
from tqdm import tqdm
import pandas as pd
from datetime import datetime
from sklearn.model_selection import train_test_split

## date stamp (YYYYMMDD) that prefixes every saved file name
dateToday = f'{datetime.today():%Y%m%d}'

from Q_learning_2ABT import qLearning_models as qLearning, helpers

# pull descriptors and hyperparameters out of `params` into module-level names
geno, bin_date, modelID = (
    params['descriptors'][key] for key in ('geno', 'bin_date', 'modelID')
)
lr, n_iter, seed, train_prop, split = (
    params['hyper_params'][key] for key in ('lr', 'n_iter', 'seed', 'train_prop', 'split')
)

# seed every random number generator in use and request deterministic torch algorithms
torch.use_deterministic_algorithms(True)
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# run description embedded in every saved file name
saveInfoString = (
    f'_{geno!s}_Fit_{modelID!s}'
    f'_lr{lr!s}_iter{n_iter!s}_seed{seed!s}'
    f'_train_{train_prop!s}_split_{split!s}'
)

In [23]:
## read the trial-by-trial dataframe from <dir_data>/<filename>
data = pd.read_csv(
    Path(params['paths']['dir_data'], params['paths']['filename']),
    index_col=False,
)


################################################################################################################
## Split each session into split = params['hyper_params']['split'] consecutive parts and label
## every part uniquely as '<session>_<part index>'.

# Save the actual sessions as TrueSession. Guarding on the column keeps this cell idempotent:
# on a re-run the true sessions are split again, not the already-split 'Session' labels.
if 'TrueSession' not in data.columns:
    data['TrueSession'] = data['Session']

# Start with an object column: writing strings into a float (NaN) column is a deprecated
# incompatible-dtype assignment in pandas.
data['SessionSplit'] = pd.Series(np.nan, index=data.index, dtype=object)
for session_id, session_rows in data.groupby('TrueSession'):
    # np.array_split returns `split` consecutive, nearly equal chunks of this session's row labels
    for part, row_labels in enumerate(np.array_split(session_rows.index.to_numpy(), split)):
        # f-string formatting also works when the session ids are not strings
        data.loc[row_labels, 'SessionSplit'] = f'{session_id}_{part}'

# switch SessionSplit into Session so the sub-sessions are what later cells treat as sessions
data['Session'] = data['SessionSplit']

In [24]:
 
# Fit the Q-learning model for every mouse in each epoch ('before' and `bin_date`), score the
# fitted policy on the training sub-sessions and -- when train_prop < 1 -- on held-out
# sub-sessions, then save the summary table and the per-trial loggers as CSV files.

# key of each per-session tensor -> dataframe column it is read from
_SESSION_FIELDS = {
    'Decision': 'Decision',
    'Reward': 'Reward',
    'blockPosition': 'blockTrial',
    'Target': 'Target',
    'DAB_I_HighProbSel': 'DAB_I_HighProbSel',
    'Switch': 'Switch',
    'DAB_I_flipLR_event': 'DAB_I_flipLR_event',
}
# leading column order of the saved logger files
_LOGGER_COLUMNS = ['action_l', 'prob_l', 'reward', 'Q_l', 'Q_r', 'session', 'mouse', 'epoch']
# result columns that only exist when a test split is made
_TEST_KEYS = ('nll_test', 'accuracy_test', 'greedyAcc_test', 'n_test')


def _sessions_to_tensors(trials):
    """Return one dict of float32 tensors (keys of _SESSION_FIELDS) per sub-session.

    Sub-sessions are the groups of `trials['Session']`, kept in order of first appearance
    (the same order `trials['Session'].unique()` produces).
    """
    return [
        {key: torch.as_tensor(session_trials[column].to_numpy(), dtype=torch.float32)
         for key, column in _SESSION_FIELDS.items()}
        for _, session_trials in trials.groupby('Session', sort=False)
    ]


def _logger_to_frame(session_loggers, mouse, epoch):
    """Stack per-session loggers into one dataframe tagged with session index, mouse and epoch."""
    frame = pd.concat(
        [helpers.logger_to_df(helpers.append_dict(d, 'session', [ii] * len(d['prob_l'])))
         for ii, d in enumerate(session_loggers)],
        axis=0,
    )
    frame['mouse'] = mouse
    frame['epoch'] = epoch
    return frame


def _score_policy(prob_l, decisions, with_nll=False):
    """Score the logged choice probabilities against the observed decisions.

    prob_l    : 1-D array holding the logger's 'prob_l' value of every trial.
    decisions : 1-D int array of the observed decisions, passed to helpers.idx_to_oneHot.
    with_nll  : also compute the mean negative log-likelihood of the observed decisions.

    Returns (weighted accuracy, weighted greedy accuracy, nll); nll is None unless requested.
    """
    # 2 column matrix: column 0 is 1 - prob_l, column 1 is prob_l
    probs = np.stack([1 - prob_l, prob_l], axis=1)
    decisions_oneHot = helpers.idx_to_oneHot(decisions)
    numDecisions = np.sum(decisions_oneHot, axis=0)

    def weighted_accuracy(predictions):
        # diagonal of helpers.confusion_matrix, weighted by how often each decision occurred
        confMatrix = helpers.confusion_matrix(predictions, decisions_oneHot)
        return (confMatrix.diagonal() * numDecisions).sum() / numDecisions.sum()

    accuracy = weighted_accuracy(probs)
    # greedy policy: always pick the more probable action
    greedy_accuracy = weighted_accuracy((probs > 0.5).astype(int))

    nll = None
    if with_nll:
        # NOTE(review): the dtype returned by helpers.idx_to_oneHot is not visible here; coercing
        # to bool guarantees mask semantics (an integer array would be read as fancy indices).
        chosen_log_probs = np.log(probs)[np.asarray(decisions_oneHot, dtype=bool)]
        nll = -chosen_log_probs.sum() / len(decisions)
    return accuracy, greedy_accuracy, nll


def _stack_loggers(frames):
    """Concatenate the collected logger frames once, with _LOGGER_COLUMNS leading the column order."""
    if not frames:
        return pd.DataFrame({column: [] for column in _LOGGER_COLUMNS})
    stacked = pd.concat(frames, axis=0)
    extra_columns = [column for column in stacked.columns if column not in _LOGGER_COLUMNS]
    return stacked.reindex(columns=_LOGGER_COLUMNS + extra_columns)


# a held-out test split is only made when part of the sub-sessions is left out of training
has_test_split = train_prop < 1

# Fail fast: create the output directory now rather than discovering it is missing after fitting.
Path(dir_save).mkdir(parents=True, exist_ok=True)

# initialize results / logger containers.
results = {key: [] for key in ('mouse', 'epoch', 'beta', 'bias_l', 'zeta', 'alpha', 'nll', 'n_iteration',
                               'accuracy', 'greedyAcc', 'n') + _TEST_KEYS}
# Logger frames are collected in lists and concatenated once after the loop; growing a dataframe
# with pd.concat on every iteration is quadratic and starts from a deprecated empty-frame concat.
train_loggers = []
test_loggers = []

for mouse in data['Mouse'].unique():
    print(mouse)
    for epoch in ['before', bin_date]:
        print(epoch)

        data_mouse_epoch = data.loc[(data['Mouse'] == mouse) & (data['binDate'] == epoch)]
        if data_mouse_epoch.empty:
            # nothing to fit; skipping keeps one missing epoch from aborting the whole batch
            print(f'no trials for mouse {mouse} in epoch {epoch}; skipping')
            continue

        # One row per sub-session, so the stratify labels are aligned with the sub-session ids by
        # construction, even when a true session contributes fewer than `split` sub-sessions.
        session_table = data_mouse_epoch.drop_duplicates(subset='Session')
        session_ids = session_table['Session'].to_numpy()

        # randomly select train and test sub-sessions for each condition
        if has_test_split:
            # stratifying by true session (only enforced for 4 or more parts per session) assigns
            # approx. the same proportion of each true session to train and test
            stratify_labels = session_table['TrueSession'].to_numpy() if split > 3 else None
            train_session_ids, test_session_ids = train_test_split(
                session_ids, train_size=train_prop, shuffle=True, random_state=seed, stratify=stratify_labels)
        else:
            train_session_ids = session_ids
            test_session_ids = np.empty(0, dtype=object)

        # per-session tensors of the training sub-sessions, separated out by field
        train_sessions = _sessions_to_tensors(
            data_mouse_epoch.loc[data_mouse_epoch['Session'].isin(train_session_ids)])
        decisions_emp = [s['Decision'] for s in train_sessions]
        rewards_emp = [s['Reward'] for s in train_sessions]
        blockPositions = [s['blockPosition'] for s in train_sessions]
        target = [s['Target'] for s in train_sessions]
        DAB_I_HighProbSel = [s['DAB_I_HighProbSel'] for s in train_sessions]
        switch = [s['Switch'] for s in train_sessions]
        DAB_I_flipLR_event = [s['DAB_I_flipLR_event'] for s in train_sessions]

        # parameter vector in the key order of params['params_init_dict']
        modelParams = torch.as_tensor(list(params['params_init_dict'].values()))
        modelParams.requires_grad_(True)

        # `lr` / `n_iter` are the unpacked values that are also written into the saved file names
        optimizer = torch.optim.SGD(params=[modelParams], lr=lr)
        fn_loss = torch.nn.NLLLoss()

        loss_rolling = []

        # Note: modelID must take values: 'standard', 'reduced'.
        # No convergence check is applied: every fit runs for the full n_iter iterations.
        for i_epoch in tqdm(range(n_iter)):
            logger, loss_rolling = qLearning.epoch_step_batch(optimizer, fn_loss, loss_rolling, modelParams, decisions_emp, rewards_emp, blockPositions,
                                                              target, DAB_I_HighProbSel, switch, DAB_I_flipLR_event, modelID)

        # convert logger into a df to be saved at the end.
        mouseLogger = _logger_to_frame(logger, mouse, epoch)
        train_loggers.append(mouseLogger)

        # fitted parameters as a float array that no longer shares memory with the autograd tensor
        params_detached = modelParams.detach().cpu().numpy().copy()
        param_dict = dict(zip(params['params_init_dict'].keys(), params_detached))

        # find policy accuracy on the training trials
        decisions = np.concatenate([d.numpy().astype(int) for d in decisions_emp])  # one vector of ints
        weightedAccuracy, weightedAccuracy_greedy = _score_policy(mouseLogger['prob_l'].to_numpy(), decisions)[:2]

        results['beta'].append(param_dict['beta'])
        results['bias_l'].append(param_dict['bias_l'])
        results['zeta'].append(param_dict['zeta'])
        results['alpha'].append(param_dict['alpha'])
        results['mouse'].append(mouse)
        results['epoch'].append(epoch)
        results['nll'].append(loss_rolling[-1])
        results['n_iteration'].append(i_epoch)  # zero-based index of the last iteration that ran
        results['accuracy'].append(weightedAccuracy)
        results['greedyAcc'].append(weightedAccuracy_greedy)
        results['n'].append(len(decisions))

        ###########################################################
        # Now evaluate the fitted parameters on the Test data
        if has_test_split:
            test_sessions = _sessions_to_tensors(
                data_mouse_epoch.loc[data_mouse_epoch['Session'].isin(test_session_ids)])

            testLogger = [qLearning.run_session(params=torch.as_tensor(params_detached),
                                                mode_generative=False,
                                                decisions_emp=s['Decision'],
                                                rewards_emp=s['Reward'],
                                                blockPosition=s['blockPosition'],
                                                target=s['Target'],
                                                DAB_I_HighProbSel=s['DAB_I_HighProbSel'],
                                                switch=s['Switch'],
                                                DAB_I_flipLR_event=s['DAB_I_flipLR_event'],
                                                modelID=modelID,
                                                ) for s in test_sessions]

            mouseTestLogger = _logger_to_frame(testLogger, mouse, epoch)
            test_loggers.append(mouseTestLogger)

            # separate name, so the training decisions above are not overwritten
            decisions_test = np.concatenate([s['Decision'].numpy().astype(int) for s in test_sessions])
            weightedAccuracy_test, weightedAccuracy_greedy_test, nll = _score_policy(
                mouseTestLogger['prob_l'].to_numpy(), decisions_test, with_nll=True)

            results['nll_test'].append(nll)
            results['accuracy_test'].append(weightedAccuracy_test)
            results['greedyAcc_test'].append(weightedAccuracy_greedy_test)
            results['n_test'].append(len(decisions_test))

# without a test split the test columns stay empty and are dropped from the table
if not has_test_split:
    for key in _TEST_KEYS:
        del results[key]

results = pd.DataFrame(results)
masterLogger = _stack_loggers(train_loggers)
masterTestLogger = _stack_loggers(test_loggers)

results.to_csv(str(Path(dir_save) / (dateToday + saveInfoString + '_batch.csv')), index = False)
masterLogger.to_csv(str(Path(dir_save) / (dateToday + saveInfoString + '_batch_logger_train.csv')), index = False)
masterTestLogger.to_csv(str(Path(dir_save) / (dateToday + saveInfoString + '_batch_logger_test.csv')), index = False)

KMR255
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [22:10<00:00,  7.52it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [22:11<00:00,  7.51it/s]


KMR268
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [21:00<00:00,  7.93it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [41:12<00:00,  4.04it/s]


KMR269
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [18:20<00:00,  9.08it/s]


day4-7


100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [2:22:21<00:00,  1.17it/s]


KMR274
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [21:56<00:00,  7.59it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [23:44<00:00,  7.02it/s]


KMR277
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [19:25<00:00,  8.58it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [26:21<00:00,  6.32it/s]


KMR276
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [23:55<00:00,  6.97it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [23:27<00:00,  7.10it/s]


KMR279
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [18:03<00:00,  9.23it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [22:59<00:00,  7.25it/s]


KMR287
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [15:32<00:00, 10.72it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [16:44<00:00,  9.95it/s]


KMR288
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [20:28<00:00,  8.14it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [20:56<00:00,  7.96it/s]


KMR284
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [20:28<00:00,  8.14it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [24:25<00:00,  6.82it/s]


KMR286
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [17:24<00:00,  9.57it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [21:52<00:00,  7.62it/s]


KMR292
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [18:03<00:00,  9.23it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [21:26<00:00,  7.77it/s]


KMR298
before


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [22:56<00:00,  7.27it/s]


day4-7


100%|█████████████████████████████████████████████████████████████████████████████| 10000/10000 [25:38<00:00,  6.50it/s]
