# In [55]:
import numpy as np
from einops import rearrange
from copy import deepcopy
from distance import canon
from _types import Reward
from env import Env, RandomEnv
from coverage_dist import get_state_dist, get_action_dist
from reward import random_reward

# In [56]:
def slow_dard(reward: Reward, env: Env):
  """Loop-based reference implementation of the three DARD terms.

  Every entry is built as an explicit weighted sum, so the result can serve
  as ground truth when checking the vectorized ``dard_brr``.  Writing p(a)
  for the weights returned by ``get_action_dist(env)``, T for
  ``env.transition_dist`` and R for ``reward`` (both indexed
  ``[state, action, next_state]``):

    term1[0, 0, s'] = sum_{a, q}    p(a) T[s', a, q] * discount * R[s', a, q]
    term2[s, 0, 0]  = sum_{a, p}    p(a) T[s, a, p]  * R[s, a, p]
    term3[s, 0, s'] = sum_{a, p, q} p(a) T[s, a, p] T[s', a, q] * discount * R[p, a, q]

  term3 uses the same action ``a`` for both transitions.

  Returns:
    ``(term1, term2, term3)`` with shapes ``(1, 1, n_s)``, ``(n_s, 1, 1)``
    and ``(n_s, 1, n_s)``.
  """
  action_probs = get_action_dist(env)
  trans = env.transition_dist
  gamma = env.discount
  states = range(env.n_s)

  # Each entry is summed left-to-right over actions, then over next states,
  # i.e. in the order of a plain nested loop.
  term1 = np.zeros((1, 1, env.n_s))
  for dst in states:
    term1[0, 0, dst] = sum(
      p_a * p_nxt2 * gamma * reward[dst, a, nxt2]
      for a, p_a in enumerate(action_probs)
      for nxt2, p_nxt2 in enumerate(trans[dst, a, :])
    )

  term2 = np.zeros((env.n_s, 1, 1))
  for src in states:
    term2[src, 0, 0] = sum(
      p_a * p_nxt * reward[src, a, nxt]
      for a, p_a in enumerate(action_probs)
      for nxt, p_nxt in enumerate(trans[src, a, :])
    )

  term3 = np.zeros((env.n_s, 1, env.n_s))
  for src in states:
    for dst in states:
      # nxt is a successor of src, nxt2 a successor of dst, both under action a.
      term3[src, 0, dst] = sum(
        p_a * p_nxt * p_nxt2 * gamma * reward[nxt, a, nxt2]
        for a, p_a in enumerate(action_probs)
        for nxt, p_nxt in enumerate(trans[src, a, :])
        for nxt2, p_nxt2 in enumerate(trans[dst, a, :])
      )

  return term1, term2, term3


# In [57]:
def dard_brr(reward: Reward, env: Env):
  """Vectorized computation of the three DARD terms.

  Returns the same ``(term1, term2, term3)`` tuple as ``slow_dard`` -- shapes
  ``(1, 1, n_s)``, ``(n_s, 1, 1)`` and ``(n_s, 1, n_s)`` -- using array
  contractions instead of Python loops.  With p(a) the weights from
  ``get_action_dist(env)``, T = ``env.transition_dist`` and R = ``reward``
  (both indexed ``[state, action, next_state]``):

    term1[0, 0, s'] = discount * sum_{a, q}    p(a) T[s', a, q] R[s', a, q]
    term2[s, 0, 0]  =            sum_{a, p}    p(a) T[s, a, p]  R[s, a, p]
    term3[s, 0, s'] = discount * sum_{a, p, q} p(a) T[s, a, p] T[s', a, q] R[p, a, q]

  Args:
    reward: reward array ``R[s, a, s_next]``.
    env: environment providing ``transition_dist`` and ``discount``.

  Returns:
    ``(term1, term2, term3)`` as described above.
  """
  # asarray so that a list of action weights works too (as in slow_dard).
  action_probs = np.asarray(get_action_dist(env))
  trans = env.transition_dist

  # potential[s] = sum_{a, p} p(a) T[s, a, p] R[s, a, p]; shared by term1/term2.
  potential = (trans * reward).sum(axis=2)                       # [s, a]
  potential = (potential * action_probs[None, :]).sum(axis=1)    # [s]

  term1 = env.discount * potential[None, None, :]
  term2 = potential[:, None, None]

  # term3 is contracted in two steps rather than by materialising the dense
  # (n_s, n_s, n_s, n_a, n_s) joint-probability tensor, which costs
  # O(n_s**4 * n_a) memory.  Each einsum below is O(n_s**3 * n_a) work.
  # next_reward[s, a, q] = sum_p T[s, a, p] R[p, a, q]
  next_reward = np.einsum('sap,paq->saq', trans, reward)
  # pair_reward[s, t] = sum_{a, q} p(a) next_reward[s, a, q] T[t, a, q]
  pair_reward = np.einsum('a,saq,taq->st', action_probs, next_reward, trans)
  term3 = env.discount * pair_reward[:, None, :]

  return term1, term2, term3


# In [58]:
# Cross-check the vectorized dard_brr against the loop-based slow_dard on
# several random environments.  np.testing.assert_allclose is used instead of
# `assert np.isclose(a, b).all()` because it also fails on a shape mismatch
# (isclose silently broadcasts), reports the offending entries, and is not
# stripped under `python -O`.  rtol/atol/equal_nan match np.isclose's defaults
# so the tolerance is unchanged.
for trial in range(10):
  e = RandomEnv(8, 2)
  r = random_reward(e)
  s1, s2, s3 = slow_dard(r, e)
  b1, b2, b3 = dard_brr(r, e)
  for name, slow, fast in (('term1', s1, b1), ('term2', s2, b2), ('term3', s3, b3)):
    np.testing.assert_allclose(
      slow, fast, rtol=1e-5, atol=1e-8, equal_nan=False,
      err_msg=f'{name} mismatch on trial {trial}',
    )