In [1]:
import os
import pickle
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import norm

# import utils as ut
# from improc import *
# import policy_time as pt

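# The R script get_posterior.R must be running concurrently: the collection functions below
# write traj{ep}.npy for it and then block until it produces sbart{ep}.npy.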

### First worm (episode 0), then subsequent worms

In [134]:
def first_pol_collection(
    folder, #'./Data/Pol03_02_0/'
    init_episodes=1,
    ):
    """Collect the initial random trajectory (episode 0) and its BART posterior summary.

    Records a trajectory for worm 0 via ``pt.get_init_traj``, stores it as a
    dataframe (``df0.pkl``) and as the R input (``traj0.npy``), then blocks until
    the companion R script writes ``sbart0.npy``. That posterior is converted with
    ``pt.bart2pols`` and saved as ``probs0.npy`` / ``ents0.npy``; these are the
    files ``gen_pol_collection`` resumes from.

    Parameters
    ----------
    folder : str
        Output path prefix; file names are appended directly to it, so it should
        end with a separator (e.g. './Data/Pol03_02_0/').
    init_episodes : int
        Number of episodes passed to ``pt.get_init_traj``.
    """
    # Fixed parameters
    ep = 0
    params = {
        'reward_ahead': 30,
        'timestep_gap': 1,
        'prev_act_window': 3,
        'jump_limit': 100,
    }

    # Collect initial random trajectory
    worm = we.ProcessedWorm(0, ep_len=300)  # Each worm episode will be several minutes long
    fnames = [f'{folder}traj{ep}.pkl']
    # Previously referenced the undefined name `fname`, which raised NameError.
    pt.get_init_traj(fnames[0], worm, init_episodes, act_rate=3)

    # Load trajectory in a data handler object as a df. Save for R.
    dh = pt.DataHandler(params=params)
    dh.add_dict_to_df(fnames)
    dh.df = pt.change_reward_ahead(dh.df, params['reward_ahead'], jump_limit=params['jump_limit'])
    dh.save_dfs(f'{folder}df{ep}.pkl')
    pt.save_for_R(dh.df, f'{folder}traj{ep}.npy')

    # Now the R script will run and create a file that we poll for (no timeout).
    # NOTE(review): only existence is checked; if R writes the file non-atomically,
    # np.load could see a partially written file — confirm how get_posterior.R writes it.
    bart_f = f'{folder}sbart{ep}.npy'
    while not os.path.exists(bart_f):
        time.sleep(1)

    # Turn the posterior into per-state probabilities and entropies; the entropies
    # drive which states are explored more in later episodes.
    post = np.load(bart_f, allow_pickle=True)
    probs, ents = pt.bart2pols(post)
    np.save(f'{folder}probs{ep}.npy', probs, allow_pickle=True)
    np.save(f'{folder}ents{ep}.npy', ents, allow_pickle=True)
    
def gen_pol_collection(
    worm_id, # Should not be 0
    
    # All parameters below this should be same as in first_pol_collection() function.
    folder, #'./Data/Pol03_02_0/'
    init_episodes=1,
    tot_eps=5,
    base_sampling=.1, # All states will have a base sampling rate of 10% (5% light on)
    samp_states_decay=5, # tau such that perc light will be on is 1/2*[(1-base)e^(-ep/tau)+base]
    ):
    """Run ``tot_eps`` further collection episodes for worm ``worm_id`` (>= 1).

    Each episode:

    1. Computes per-state sampling probabilities with ``pt.make_sprobs_cutoff``
       from the previous entropies and a sampling proportion
       ``alpha = (1-base_sampling)*exp(-ep/samp_states_decay) + base_sampling``.
    2. Records a trajectory with ``pt.do_sampling_traj``.
    3. Puts the new rows in front of all previously collected data and saves the
       result as ``df{ep}.pkl`` and ``traj{ep}.npy`` (the R input).
    4. Blocks until the R script writes ``sbart{ep}.npy``.
    5. Saves the derived ``probs{ep}.npy`` / ``ents{ep}.npy``.

    Episode numbering continues the chain started by ``first_pol_collection``,
    which writes episode 0. Worm k uses episodes ``1+(k-1)*tot_eps`` .. ``k*tot_eps``.
    It resumes from the files of episode ``(k-1)*tot_eps``, so worms must be run
    in order with the same ``tot_eps``.

    Raises
    ------
    ValueError
        If ``worm_id`` < 1 (episode 0 belongs to ``first_pol_collection``).
    """
    if worm_id < 1:
        raise ValueError(f'worm_id must be >= 1 (worm 0 is first_pol_collection), got {worm_id}')

    # Fixed parameters
    params = {
        'reward_ahead': 30,
        'timestep_gap': 1,
        'prev_act_window': 3,
        'jump_limit': 100,
    }

    # Load entropies and dataframe from the previous run. The old formula
    # 1+worm_id*tot_eps skipped episodes 1..tot_eps, so worm 1 looked for files
    # that were never written.
    curr_ep = 1 + (worm_id - 1) * tot_eps
    dh = pt.DataHandler()
    dh.load_df(f'{folder}df{curr_ep-1}.pkl')
    ents = np.load(f'{folder}ents{curr_ep-1}.npy', allow_pickle=True)

    # One worm object per call, mirroring first_pol_collection (which uses id 0).
    worm = we.ProcessedWorm(worm_id, ep_len=300)

    for ep in range(curr_ep, curr_ep + tot_eps):
        # Create a list of states to search based on entropy and cutoff.
        # alpha is the proportion of time we are sampling. alpha*.5 is proportion of time light is on.
        alpha = (1-base_sampling)*np.exp(-ep/samp_states_decay)+base_sampling
        counts = pt.get_counts(dh.df)
        sampling_probs = pt.make_sprobs_cutoff(ents, alpha, counts, base_sampling)

        # Get new data (previously used the undefined names fnames/worm/episodes).
        fnames = [f'{folder}traj{ep}.pkl']
        pt.do_sampling_traj(sampling_probs, fnames[0], worm, init_episodes, act_rate=3)

        # Save for R in addition to old data. DataFrame.append was removed in
        # pandas 2.0; concat keeps the same row order (new data first).
        dh_n = pt.DataHandler(params=params)
        dh_n.add_dict_to_df(fnames)
        dh_n.df = pt.change_reward_ahead(dh_n.df, params['reward_ahead'], jump_limit=params['jump_limit'])
        dh_n.df = pd.concat([dh_n.df, dh.df], ignore_index=True)
        dh_n.save_dfs(f'{folder}df{ep}.pkl')
        pt.save_for_R(dh_n.df, f'{folder}traj{ep}.npy')
        dh = dh_n

        # Wait (no timeout) for the R script to write the posterior for this episode.
        # NOTE(review): only existence is checked; a non-atomic write on the R side
        # could let np.load read a partial file — confirm.
        bart_f = f'{folder}sbart{ep}.npy'
        while not os.path.exists(bart_f):
            time.sleep(1)
        # Get the probabilities and entropies; ents feeds the next episode's sampling.
        post = np.load(bart_f, allow_pickle=True)
        probs, ents = pt.bart2pols(post)
        np.save(f'{folder}probs{ep}.npy', probs, allow_pickle=True)
        np.save(f'{folder}ents{ep}.npy', ents, allow_pickle=True)