In [40]:
import numpy as np
from scipy.stats import expon
from scipy.stats import norm

from envs.vases_grid import VasesGrid, VasesEnvState, print_state, str_to_state, state_to_str
from envs.utils import unique_perm, zeros_with_ones, printoptions
from envs.vases_spec import VasesEnvState2x3V2D3, VasesEnvSpec2x3V2D3, VasesEnvState2x3Broken, VasesEnvSpec2x3Broken

from value_iter_and_policy import vi_boltzmann, vi_boltzmann_deterministic, vi_rational_deterministic


TEMP=0


class neg_exp_distr(object):
    '''
    a wrapper to not get confused with the negative exp distribuion,
    as scipy doesn't have a good way to do *negative* exp
    '''
    def __init__(self, mode, scale=1):
        self.mode = mode
        self.scale = scale
        self.distribution = expon(loc=-mode, scale=scale)
        
    def rvs(self):
        '''sample'''
        return -self.distribution.rvs()
    
    def pdf(self, x):
        return self.distribution.pdf(-x)
    



def birl_one_s_likelihood(s_current, V, temp=1):
    return np.exp((V[np.where(s_current)])/temp)
    

def policy_walk(env, r_spec, s_current, step_size, n_samples):
    '''
    Algo inspired by one in Appendix A from https://arxiv.org/pdf/1807.05037.pdf,
    should be equivalent to the original BIRL (but for one state)
    '''
    
    i=0
    
    p = .5
    samples = []
    
    #r_prior = neg_exp_distr(mode=r_spec, scale=2)
    r_prior = norm(loc=r_spec, scale=.1)
    
    r = r_prior.rvs()
    V, Q, pi = vi_rational_deterministic(env, gamma=.97, r=env.f_matrix @ r)
    
    while True:
        
        r_prime = np.random.normal(r, step_size)        
        #r_prime[r_prime>r_spec] = r_spec[r_prime>r_spec]
        
        V_prime, Q_prime, pi_prime = vi_rational_deterministic(env, gamma=.97,
                                                               r=env.f_matrix @ r_prime,
                                                               init_V = V)
        
        p_1 = birl_one_s_likelihood(s_current, V_prime) * np.prod(r_prior.pdf(r_prime))
        
        
        # There are sometimes numerical errors since birl_one_s_likelihood is
        # unnormalized and gets very large. This is why r_spec is 10 times smaller
        # than normal.
        # TODO: Metropolis-Hastings with logprobs instead of probs
        if not np.isfinite(p_1) or np.isnan(p_1):
            print(p)
            print(p_1)

        
        if np.random.uniform()<np.amin(np.array([1, p_1/p])) and p>0:
            samples.append(r_prime)
            r = np.copy(r_prime)
            V = np.copy(V_prime)
            p = np.copy(p_1)

        if len(samples)%100==0:
            if i!=len(samples):
                i = len(samples)
                print('samples generated: ', len(samples))
        
        if len(samples)==n_samples:
            return samples


In [None]:
import warnings
warnings.filterwarnings('error')

env2x3v2d3 = VasesGrid(VasesEnvSpec2x3V2D3(), VasesEnvState2x3V2D3())
r_spec = np.array([0, 0, .1, 0, 0, 0])

s_current = np.zeros(env2x3v2d3.nS)
s_current[1582] = 1
print_state(str_to_state(env2x3v2d3.num_state[1582]))

r_samples = policy_walk(env2x3v2d3, r_spec, s_current, step_size=.02, n_samples=1000)

In [None]:
#r_samples

In [39]:
'''
Order of features:
- Number of broken vases
- Number of vases on tables
- Number of tablecloths on tables
- Number of tablecloths on floors
- Number of vases on desks
- Number of tablecloths on desks
'''
np.mean(r_samples[200::], axis=0)

array([-0.00556688,  0.02111717,  0.09687882, -0.00501262,  0.60154344,
        0.32314526])