In [1]:
import os
import pickle
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import norm

import utils as ut
from improc import *
import policy_time as pt

In [None]:
def bart2pols(posterior, grid_shape=(12, 12)):
    """Turn BART posterior samples of Q into per-state exploration statistics.

    Takes the posterior sampling from BART in R (size [n_samples, 288] for the
    default 12x12 grid) where the action alternates every column: even columns
    are action 0 and odd columns are action 1. After the action, the second
    state variable ascends fastest.

    For each state, Q_delta = Q(action 1) - Q(action 0) is summarised by the
    mean and (population) standard deviation of its samples and treated as
    Gaussian.

    Parameters
    ----------
    posterior : array-like, shape (n_samples, 2 * prod(grid_shape))
        Posterior samples with the two actions interleaved column by column.
    grid_shape : tuple of int, default (12, 12)
        Shape the per-state results are reshaped to.

    Returns
    -------
    probs : ndarray, shape ``grid_shape``
        P(Q_delta(state) > 0) for each state.
    ent : ndarray, shape ``grid_shape``
        H(state), the binary entropy (in nats) of ``probs``.

    Raises
    ------
    ValueError
        If ``posterior`` is not 2-D with an even number of columns.
    """
    from scipy.special import entr  # exact -p*log(p), with entr(0) == 0

    posterior = np.asarray(posterior, dtype=float)
    if posterior.ndim != 2 or posterior.shape[1] % 2:
        raise ValueError(
            'posterior must be 2-D with an even number of columns, '
            f'got shape {posterior.shape}')

    post_diff = posterior[:, 1::2] - posterior[:, ::2]
    mu = post_diff.mean(axis=0)
    sig = post_diff.std(axis=0)

    # P(X > 0) for X ~ N(mu, sig) is Phi(mu / sig); equal to 1 - Phi(-mu / sig)
    # but without cancellation when the probability is very small.
    with np.errstate(divide='ignore', invalid='ignore'):
        probs = norm.cdf(mu / sig)
    # Zero-width posterior (all samples agree) is a point mass: probability
    # 1, 0 or 1/2 by the sign of mu. Without this, mu == sig == 0 gives NaN.
    probs = np.where(sig == 0, 0.5 * (1 + np.sign(mu)), probs)
    probs = probs.reshape(grid_shape)

    # Binary entropy; exactly 0 at p in {0, 1} (no epsilon-induced negatives).
    ent = entr(probs) + entr(1 - probs)
    return probs, ent

def make_sprobs_cutoff(ents, alpha, baseline=None, counts=None):
    """Build per-state sampling probabilities from state entropies. NOT IMPLEMENTED.

    Presumably meant to choose which states to explore based on their
    entropy (e.g. the ``ents`` map from ``bart2pols``) and a cutoff, with
    ``alpha`` the proportion of time the light is on — TODO confirm the
    intended algorithm.

    Parameters
    ----------
    ents : array-like
        Per-state entropies.
    alpha : float
        Target proportion of time the light is on.
    baseline : optional
        Presumably the base sampling rate applied to every state — TODO confirm.
    counts : optional
        Presumably per-state visit counts — TODO confirm.

    Raises
    ------
    NotImplementedError
        Always, until the function is written.
    """
    # TODO: implement. Raising (instead of silently returning None) stops an
    # unfinished policy from being passed downstream as if it were valid.
    raise NotImplementedError('make_sprobs_cutoff is not implemented yet')

In [None]:
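# The R script get_posterior.R must be running concurrently: it reads the
# trajectories saved below for R and writes the BART posterior file
# (sbart{ep}.npy) that the next cell waits for.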

In [5]:
# Parameters
init_episodes = 1
tot_eps = 5  # Total episodes using the efficient exploration policy
folder = './Data/Pol03_02_0/'
fbase = 'traj'  # File-name stem; must match between RStudio and this notebook

# Policy parameters
# All states have a base sampling rate of 10% (light on 5% of the time).
base_sampling = .1
# Decay constant tau such that on episode `ep` the light is on a fraction
# 1/2 * [(1 - base) * e^(-ep / tau) + base] of the time.
samp_states_decay = 10

# Fixed parameters
params = {
    'reward_ahead': 30,
    'timestep_gap': 1,
    'prev_act_window': 3,
    'jump_limit': 100,
}


def _load_when_ready(path, poll_s=1.0, max_load_retries=10):
    """Block until `path` exists, then return ``np.load(path)``.

    The R script writes this file concurrently, so it can appear on disk
    before it is fully written. A failed load is therefore retried (up to
    ``max_load_retries`` times, ``poll_s`` seconds apart) before re-raising.
    """
    while not os.path.exists(path):
        time.sleep(poll_s)
    for attempt in range(max_load_retries):
        try:
            # NOTE(review): allow_pickle=True is only safe because this file is
            # produced by our own R script; use False if it is plain numeric.
            return np.load(path, allow_pickle=True)
        except (OSError, ValueError, EOFError):
            if attempt == max_load_retries - 1:
                raise
            time.sleep(poll_s)


# Collect the initial random trajectory.
# NOTE(review): `we` is not imported in the visible cells; presumably it comes
# from `from improc import *` — confirm.
worm = we.ProcessedWorm(0, ep_len=300)  # Each worm episode will be several minutes long
fnames = [f'{folder}{fbase}0.pkl']
pt.get_init_traj(fnames[0], worm, init_episodes, act_rate=3)

# Load the trajectory into a data-handler object as a df and save it for R.
dh = pt.DataHandler(params=params)
dh.add_dict_to_df(fnames)
dh.df = change_reward_ahead(dh.df, params['reward_ahead'], jump_limit=params['jump_limit'])
save_for_R(dh.df, f'{folder}{fbase}0.npy')

# The R script now runs and creates the posterior file; wait for it.
ep = 0
bart_f = f'{folder}sbart{ep}.npy'
post = _load_when_ready(bart_f)
# Per-state P(Q_delta > 0) and its entropy; the highest-entropy states are the
# ones to explore more in future episodes.
probs, ents = bart2pols(post)
np.save(f'{folder}probs{ep}.npy', probs, allow_pickle=True)
np.save(f'{folder}ents{ep}.npy', ents, allow_pickle=True)

for ep in np.arange(tot_eps) + 1:
    # Create a list of states to search based on entropy and cutoff.
    # alpha is the proportion of time the light is on; it decays from 1/2
    # (at ep = 0) towards base_sampling / 2 with time constant samp_states_decay.
    alpha = 1/2*((1-base_sampling)*np.exp(-ep/samp_states_decay)+base_sampling)
    # TODO(review): make_sprobs_cutoff is declared as (ents, alpha, baseline,
    # counts); pass baseline/counts here once that function is implemented.
    sampling_probs = make_sprobs_cutoff(ents, alpha)

Penalty -0.3642956292503186
