# Collecting performance data using reward differences only

In [1]:
import numpy as np
import os

import model_based_agent as mba 
import worm_env as we 
import ensemble_mod_env as eme

from improc import *
import utils as ut
import tab_agents as tab
from datetime import datetime 

In [2]:
def reward_diff_method(
    collection_eps = 3,
    frac_on = 1/2,
    eval_on_list = (1/2, 1/3, 1/4),
    collection_ep_time = 600, # in seconds. Must be a multiple of worm_ep_len
    eval_ep_time = 120, # in seconds. Also must be multiple of worm_ep_len
    worm_ep_len = 120, # in seconds
    init_df = None, # with folder
):
    '''
    Alternate random-data collection with evaluation of a reward-difference policy.

    For each of the collection_eps rounds:
    1. Collects data with light on frac_on of the time and adds it to the
        data handler.
    2. Evaluates reward difference policy with various amounts of light penalty
        given by eval_on_list. For each entry a model is fit, and the sign of
        (reward_on - reward_off) is written into column 1 of the agent's Q-table
        (column 0 is zeroed) before an eval episode is run.

    Function output (all inside a fresh ./Data/Run<day-month-hour-minute>/ folder):
    collect{i}.pkl           file name handed to the collection routine for round i
    mod{i}_{eval frac ind}.pkl   pickled model fit after collection round i
    eval{i}_{eval frac ind}.pkl  file name handed to the eval episode runner

    Parameters:
    collection_eps      number of collect-then-evaluate rounds
    frac_on             probability weight for the "on" action during collection
    eval_on_list        iterable of lp_frac values, one eval episode per value
    collection_ep_time  collection time per round in seconds
    eval_ep_time        eval time per policy in seconds
    worm_ep_len         length of a single worm episode in seconds
    init_df             if not None, passed to DataHandler.load_df before starting

    Raises:
    ValueError       if collection_ep_time or eval_ep_time is not a multiple of
                     worm_ep_len (checked before anything is created).
    FileExistsError  if the run folder already exists and is not empty.
    '''
    # Local import so the function does not depend on `pickle` leaking in
    # through `from improc import *`.
    import pickle

    # Validate before touching the disk or the hardware.
    if collection_ep_time % worm_ep_len != 0:
        raise ValueError('Collection_ep_time is not a multiple of worm_ep_len')
    if eval_ep_time % worm_ep_len != 0:
        raise ValueError('eval_ep_time is not a multiple of worm_ep_len')
    n_collect = int(collection_ep_time // worm_ep_len)
    n_eval = int(eval_ep_time // worm_ep_len)

    # The folder name only has minute resolution, so two runs can collide.
    # Create missing parents, reuse an empty folder, refuse a populated one.
    folder = './Data/Run'+datetime.now().strftime('%d-%m-%H-%M')+'/'
    os.makedirs(folder, exist_ok=True)
    if os.listdir(folder):
        raise FileExistsError(f'Run folder {folder} already exists and is not empty')

    # Initialize objects
    dh = mba.DataHandler()
    if init_df is not None:
        dh.load_df(init_df)
    worm = we.ProcessedWorm(0,ep_len=worm_ep_len)

    ant = tab.Q_Alpha_Agent()
    # act_spacing here is only for eval episodes
    runner = mba.WormRunner(ant,worm,act_spacing=1)

    for ce in range(collection_eps):
        # Collecting random data
        #############################
        fname = folder+f'collect{ce}.pkl'
        print(f'Collecting randoms {ce}')
        mba.get_init_traj(fname, worm, n_collect, rand_probs=[1-frac_on,frac_on])
        dh.add_dict_to_df([fname],reward_ahead=10,timestep_gap=1,prev_act_window=3,jump_limit=100)

        # Find RDiff matrix and collect eval episodes
        #############################
        cam,task = init_instruments()
        try:
            for i,ev in enumerate(eval_on_list):
                print('Finding policy')
                mset = eme.ModelSet(1,frac=1,lp_frac=ev)
                mset.make_models(dh,{'lambda':.1,'iters':10})
                model = mset.models[0]
                # Save model
                mname = folder+f'mod{ce}_{i}.pkl'
                with open(mname,'wb') as f:
                    pickle.dump(model,f)

                # Policy: sign of the reward difference between on and off.
                rdiff = np.sign(model['reward_on'][:,:,0]-model['reward_off'][:,:,0])
                # Sized from rdiff instead of a hard-coded 144; both columns
                # must have the same length anyway.
                runner.agent.Qtab[:,0] = np.zeros(rdiff.size)
                runner.agent.Qtab[:,1] = rdiff.flatten()
                ename = folder+f'eval{ce}_{i}.pkl'
                print(f'Running eval ep {i}')
                runner.eval_ep(cam,task,ename,eval_eps=n_eval)
        finally:
            # Always release the instruments and write 0 to the task, even if
            # an eval episode raised, so the output is not left driven.
            try:
                cam.exit()
            finally:
                try:
                    task.write(0)
                finally:
                    task.close()
    # NOTE(review): this path is not prefixed with `folder`; confirm where
    # DataHandler.save_dfs writes relative paths.
    dh.save_dfs('totaldf.pkl')
    

In [4]:
# Launch one full run: three collect-then-evaluate rounds with the light on
# half the time during collection. All durations are in seconds, and both
# episode times must be whole multiples of worm_ep_len.
reward_diff_method(
    worm_ep_len=120,           # length of a single worm episode
    collection_ep_time=600,    # collection time per round
    eval_ep_time=120,          # eval time per light-penalty setting
    collection_eps=3,
    frac_on=1/2,
    eval_on_list=[1/2, 1/3, 1/4],
)

Collecting randoms 0
19 sec 		

  centers.append(np.array([np.sum(np.arange(im_sz)*sumx) / np.sum(sumx), np.sum(np.arange(im_sz)*sumy) / np.sum(sumy)]))


Finding policy
On model 0
Penalty -0.3514723660769161
Running eval ep 0
Finding policy
On model 0
Penalty -0.24156639483478126
Running eval ep 1
Finding policy
On model 0
Penalty -0.10944554070572998
Running eval ep 2
Collecting randoms 1
Finding policy
On model 0
Penalty -0.013946934423879398
Running eval ep 0
Finding policy
On model 0
Penalty 0.21265360672329292
Running eval ep 1
Finding policy
On model 0
Penalty 0.32225000150634786
Running eval ep 2
Collecting randoms 2
Finding policy
On model 0
Penalty -0.10981718058631884
Running eval ep 0
Finding policy
On model 0
Penalty 0.17212397880456454
Running eval ep 1
Finding policy
On model 0
Penalty 0.3196292657845537
Running eval ep 2
19 sec 			

In [3]:
# Emergency light shut-off
import nidaqmx

# Write 0 to the analog output channel. The context manager closes the task
# even if configuring the channel or writing fails, so the device handle is
# not left reserved after an error.
with nidaqmx.Task() as task:
    task.ao_channels.add_ao_voltage_chan("Dev1/ao0")
    task.write(0)