In [1]:
import numpy as np
from scipy.stats import expon

from envs.vases_grid import VasesGrid, VasesEnvState, print_state, str_to_state, state_to_str
from envs.utils import unique_perm, zeros_with_ones, printoptions
from envs.vases_spec import VasesEnvState2x3V2D3, VasesEnvSpec2x3V2D3, VasesEnvState2x3Broken, VasesEnvSpec2x3Broken

from value_iter_and_policy import vi_boltzmann, vi_boltzmann_deterministic, vi_rational_deterministic


class neg_exp_distr(object):
    '''
    a wrapper to not get confused with the negative exp distribuion,
    as scipy doesn't have a good way to do *negative* exp
    '''
    def __init__(self, mode, scale=1):
        self.mode = mode
        self.scale = scale
        self.distribution = expon(loc=-mode, scale=scale)
        
    def sample(self):
        return -self.distribution.rvs()
    
    def pdf(self, x):
        return self.distribution.pdf(-x)


def birl_one_s_likelihood(s_current, V, temp=1):
    return np.exp(V[np.where(s_current)]/temp)
    

def policy_walk(env, r_spec, s_current, step_size, n_samples):
    '''
    Algo inspired by one in Appendix A from https://arxiv.org/pdf/1807.05037.pdf,
    should be equivalent to the original BIRL (but for one state)
    '''
    
    i=0
    
    p = .5
    samples = []
    
    r_prior = neg_exp_distr(mode=r_spec, scale=.1)
    r = r_prior.sample()
    V, Q, pi = vi_rational_deterministic(env, gamma=.99, r=env.f_matrix @ r)
    
    while True:
        
        r_prime = np.random.normal(r, step_size)        
        
        V_prime, Q_prime, pi_prime = vi_rational_deterministic(env, gamma=.99, 
                                                               r=env.f_matrix @ r_prime,
                                                               init_V = V)
        
        p_1 = birl_one_s_likelihood(s_current, V_prime) * np.prod(r_prior.pdf(r_prime))
        
        if np.random.uniform()<np.amin(np.array([1, p_1/p])):
            samples.append(r_prime)
            r = np.copy(r_prime)
            V = np.copy(V_prime)
            p = np.copy(p_1)
        
        if len(samples)%200==0:
            if i!=len(samples):
                i = len(samples)
                print('samples generated: ', len(samples))
        
        if len(samples)==n_samples:
            return samples


In [2]:
env2x3v2d3 = VasesGrid(VasesEnvSpec2x3V2D3(), VasesEnvState2x3V2D3())
r_spec = np.array([0, 0, 1, 0, 0, 0])

s_current = np.zeros(env2x3v2d3.nS)
s_current[1582] = 1

r_samples = policy_walk(env2x3v2d3, r_spec, s_current, step_size=.01, n_samples=10000)

samples generated:  200
samples generated:  400
samples generated:  600
samples generated:  800
samples generated:  1000
samples generated:  1200
samples generated:  1400
samples generated:  1600
samples generated:  1800
samples generated:  2000
samples generated:  2200
samples generated:  2400
samples generated:  2600
samples generated:  2800
samples generated:  3000
samples generated:  3200
samples generated:  3400
samples generated:  3600
samples generated:  3800
samples generated:  4000
samples generated:  4200
samples generated:  4400
samples generated:  4600
samples generated:  4800
samples generated:  5000
samples generated:  5200
samples generated:  5400
samples generated:  5600
samples generated:  5800
samples generated:  6000
samples generated:  6200
samples generated:  6400
samples generated:  6600
samples generated:  6800
samples generated:  7000
samples generated:  7200
samples generated:  7400
samples generated:  7600
samples generated:  7800
samples generated:  8000
samp

In [3]:
r_samples

[array([-0.11679481, -0.36425451,  0.8403076 , -0.0362376 , -0.06122026,
        -0.02046684]),
 array([-0.13139183, -0.36374731,  0.84825147, -0.04977498, -0.06093634,
        -0.01971753]),
 array([-0.1273514 , -0.36488781,  0.84827043, -0.0452935 , -0.07179797,
        -0.02886349]),
 array([-0.12343134, -0.35727599,  0.85353124, -0.03855832, -0.07646993,
        -0.0307025 ]),
 array([-0.12375063, -0.33134782,  0.85717204, -0.02425707, -0.07102669,
        -0.01195067]),
 array([-0.11342904, -0.33281376,  0.85975374, -0.00422914, -0.06678436,
        -0.02133648]),
 array([-0.1024723 , -0.3311146 ,  0.85455271, -0.00383398, -0.0771481 ,
        -0.02006129]),
 array([-0.10684174, -0.35086954,  0.86577297, -0.00416159, -0.08337938,
        -0.02164675]),
 array([-0.09292462, -0.33958681,  0.87797038, -0.00690627, -0.08879825,
        -0.02246249]),
 array([-9.03524045e-02, -3.57626382e-01,  8.74583719e-01, -8.39903492e-04,
        -9.30436731e-02, -2.69655105e-02]),
 array([-0.09837

In [7]:
'''
Order of features:
- Number of broken vases
- Number of vases on tables
- Number of tablecloths on tables
- Number of tablecloths on floors
- Number of vases on desks
- Number of tablecloths on desks
'''
np.mean(r_samples[2000::], axis=0)

array([-0.00852826, -0.06227133,  0.98692916, -0.08686534, -0.10129312,
       -0.05260909])