In [2]:
import numpy as np
import typing
import abc
import copy

In [3]:
# define some utils

# softmax along last dimension
def softmax(arr: np.ndarray) -> np.ndarray:
  exp = np.exp(arr)
  norm = np.sum(exp, axis=-1)
  norm = np.reshape(norm, (*arr.shape[0:-1], 1))
  norm = np.repeat(norm, arr.shape[-1], axis=-1)
  return exp / norm

Reward = np.ndarray

def l1_norm(reward: Reward) -> float:
  return np.abs(reward).sum()

def l2_norm(reward: Reward) -> float:
  out = np.power(reward, 2).sum()
  return np.sqrt(out)

def l_infty_norm(reward: Reward) -> float:
  return np.abs(reward).max()

def random_dist(n):
  logits = np.random.random(n)
  exp = np.exp(logits)
  return exp / sum(exp)

def random_reward(n_s, n_a):
  return np.random.random((n_s, n_a, n_s))

def random_transition_dist(n_s, n_a):
  return softmax(np.random.random((n_s, n_a, n_s)))

In [4]:
class RewardDistance():
  def __init__(self, discount=0.9) -> None:
    self.discount = discount

  @abc.abstractmethod
  def canonicalize(self, reward: Reward) -> Reward:
    pass

  def normalize(self, reward: Reward) -> float:
    return l2_norm(reward)

  def distance(self, r1: Reward, r2: Reward) -> float:
    return l2_norm(r1 - r2)

  def __call__(self, r1: Reward, r2: Reward) -> float:
    can1 = self.canonicalize(r1)
    can2 = self.canonicalize(r2)

    standard1 = can1 / self.normalize(can1)
    standard2 = can2 / self.normalize(can2)

    return self.distance(standard1, standard2)

In [5]:
class Epic(RewardDistance):
  def __init__(self, state_dist: np.ndarray, action_dist: np.ndarray):
    super().__init__()
    self.state_dist = state_dist
    self.action_dist = action_dist
    
  def canonicalize(self, reward: Reward) -> Reward:
    result = copy.deepcopy(reward)

    # memoize all the expected values
    expected_term1 = [] # depends on s_prime
    for s_prime in range(len(self.state_dist)):
      ev1 = 0

      for A, A_prob in enumerate(self.action_dist):
        for S_prime, S_prime_prob in enumerate(self.state_dist):
          prob = A_prob * S_prime_prob
          ev1 += prob * self.discount * reward[s_prime, A, S_prime]

      expected_term1.append(ev1)
  
    expected_term2 = [] # depends on s
    for s in range(len(self.state_dist)):
      ev2 = 0

      for A, A_prob in enumerate(self.action_dist):
        for S_prime, S_prime_prob in enumerate(self.state_dist):
          prob = A_prob * S_prime_prob
          ev2 += prob * self.discount * reward[s, A, S_prime]

      expected_term2.append(ev2)

    expected_term3 = 0 # constant across all state values
    for S, S_prob in enumerate(self.state_dist):
      for A, A_prob in enumerate(self.action_dist):
        for S_prime, S_prime_prob in enumerate(self.state_dist):
          prob = S_prob * A_prob * S_prime_prob
          expected_term3 += prob * self.discount * reward[S, A, S_prime]


    for s in range(len(self.state_dist)):
      for a in range(len(self.action_dist)):
        for s_prime in range(len(self.state_dist)):
          terms = expected_term1[s_prime] - expected_term2[s] - expected_term3
          result[s, a, s_prime] += terms
        
    return result
  

# test EPIC
n_states = 200
n_actions = 10
state_dist = random_dist(n_states)
action_dist = random_dist(n_actions)
r1 = random_reward(n_states, n_actions)
r2 = random_reward(n_states, n_actions)
Epic(state_dist, action_dist)(r1, r2)

1.3935490871336487

In [7]:
class Dard(RewardDistance):
  def __init__(
    self,
    state_dist: np.ndarray,
    action_dist: np.ndarray,
    transition_dist: np.ndarray,
  ):
    super().__init__()
    self.state_dist = state_dist
    self.action_dist = action_dist
    self.transition_dist = transition_dist
    
  def canonicalize(self, reward: Reward) -> Reward:
    result = copy.deepcopy(reward)

    # memoize all the expected values
    expected_term1 = [] # depends on s_prime
    for s_prime in range(len(self.state_dist)):
      ev1 = 0

      for A, A_prob in enumerate(self.action_dist):
        for S_double, S_double_prob in enumerate(self.transition_dist[s_prime, A, :]):
          prob = A_prob * S_double_prob
          ev1 += prob * self.discount * reward[s_prime, A, S_double]

      expected_term1.append(ev1)
  
    expected_term2 = [] # depends on s
    for s in range(len(self.state_dist)):
      ev2 = 0

      for A, A_prob in enumerate(self.action_dist):
        for S_prime, S_prime_prob in enumerate(self.state_dist):
          prob = A_prob * S_prime_prob
          ev2 += prob * self.discount * reward[s, A, S_prime]

      expected_term2.append(ev2)

    expected_term3 = 0 # constant across all state values
    for A, A_prob in enumerate(self.action_dist):
      for S_prime, S_prime_prob in enumerate(self.transition_dist[s, A, :]):
        for S_double, S_double_prob in enumerate(self.transition_dist[s_prime, A, :]):
          prob = S_prime_prob * A_prob * S_double_prob
          expected_term3 += prob * self.discount * reward[S_prime, A, S_double]


    for s in range(len(self.state_dist)):
      for a in range(len(self.action_dist)):
        for s_prime in range(len(self.state_dist)):
          terms = expected_term1[s_prime] - expected_term2[s] - expected_term3
          result[s, a, s_prime] += terms
        
    return result

    
# test DARD
n_states = 200
n_actions = 10
state_dist = random_dist(n_states)
action_dist = random_dist(n_actions)
transition_dist = random_transition_dist(n_states, n_actions)
r1 = random_reward(n_states, n_actions)
r2 = random_reward(n_states, n_actions)
Dard(state_dist, action_dist, transition_dist)(r1, r2)

1.393991041812696

In [12]:
# compare them
n_states = 100
n_actions = 10
state_dist = random_dist(n_states)
action_dist = random_dist(n_actions)
transition_dist = random_transition_dist(n_states, n_actions)

epic = Epic(state_dist, action_dist) 
dard = Dard(state_dist, action_dist, transition_dist)

for _ in range(10):
  r1 = random_reward(n_states, n_actions)
  r2 = random_reward(n_states, n_actions)
  epic_result = epic(r1, r2)
  dard_result = dard(r1, r2)
  diff = abs(epic_result-dard_result)
  print(f"EPIC={epic_result:.5f}, DARD={dard_result:.5f}, difference={diff:.5f}")

EPIC=1.39686, DARD=1.39719, difference=0.00033
EPIC=1.39438, DARD=1.39478, difference=0.00040
EPIC=1.39713, DARD=1.39715, difference=0.00001
EPIC=1.39500, DARD=1.39482, difference=0.00018
EPIC=1.38952, DARD=1.38961, difference=0.00009
EPIC=1.39349, DARD=1.39341, difference=0.00008
EPIC=1.39231, DARD=1.39207, difference=0.00024
EPIC=1.39334, DARD=1.39352, difference=0.00018
EPIC=1.39521, DARD=1.39541, difference=0.00019
EPIC=1.39460, DARD=1.39438, difference=0.00022
