<a href="https://colab.research.google.com/github/francescomontagna/MaxEnt-IRL/blob/main/MaxEntIRL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import random
import numpy as np
import matplotlib.pyplot as plt

In [20]:
class GridWorld:
  """Deterministic ``grid_size x grid_size`` grid world with four actions.

  A state is referred to either as coordinates ``(row, col)`` or as a point,
  i.e. the flat index ``row * grid_size + col``. A move that would leave the
  grid keeps the agent where it is.
  """

  # (row delta, col delta) for each action: [up, right, down, left]
  _MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

  def __init__(self, p_initial=None, grid_size=5, terminal_state=None):
    """
    Args:
      p_initial: probability of starting in each point (length
        ``grid_size**2``), e.g. estimated from expert trajectories.
        ``None`` means a uniform start distribution.
      grid_size: number of rows (and columns) of the grid.
      terminal_state: ``(row, col)`` of the goal state. Defaults to the
        bottom-right corner.

    Raises:
      ValueError: if ``terminal_state`` lies outside the grid.
    """
    self.actions = [0, 1, 2, 3]  # [up, right, down, left]
    self.grid_size = grid_size

    # NOTE(review): the goal is assumed to be the bottom-right corner — confirm.
    if terminal_state is None:
      terminal_state = (grid_size - 1, grid_size - 1)
    terminal_state = tuple(terminal_state)
    if not all(0 <= c < grid_size for c in terminal_state):
      raise ValueError(f"terminal_state {terminal_state} is outside the grid")
    self.terminal_state = terminal_state

    n_states = grid_size ** 2
    if p_initial is None:
      p_initial = np.full(n_states, 1.0 / n_states)
    self.p_initial = p_initial  # should be computed from expert trajectories

    self.reset()

  def state_coord_to_point(self, state_coord):
    """Return the flat index of the ``(row, col)`` coordinates."""
    row, col = state_coord
    return row * self.grid_size + col

  def point_to_state_coord(self, point):
    """Return the ``(row, col)`` coordinates of a flat index."""
    return divmod(point, self.grid_size)

  def simulate_step(self, state, action):
    """Return the point reached by taking ``action`` from point ``state``.

    The environment is not modified. Moves into a wall leave the agent in
    place.

    Raises:
      ValueError: if ``action`` is not one of ``self.actions``.
    """
    try:
      d_row, d_col = self._MOVES[action]
    except KeyError:
      raise ValueError(
          f"unknown action {action!r}; expected one of {self.actions}") from None

    row, col = self.point_to_state_coord(state)
    last = self.grid_size - 1
    next_row = min(max(row + d_row, 0), last)
    next_col = min(max(col + d_col, 0), last)
    return self.state_coord_to_point((next_row, next_col))

  def step(self, state, action):
    """Move the agent from ``state`` with ``action`` and return the new point.

    Updates ``self.state`` (point) and ``self.state_coord`` (coordinates).
    """
    next_state = self.simulate_step(state, action)
    self.state = next_state
    self.state_coord = self.point_to_state_coord(next_state)
    return next_state

  def reset(self):
    """Sample a start point from ``p_initial``, place the agent there, and return it."""
    n_states = self.grid_size ** 2
    self.state = int(np.random.choice(n_states, p=self.p_initial))
    self.state_coord = self.point_to_state_coord(self.state)
    return self.state
    

In [None]:
class MaxCausalEntropy:
  # We'll consider the case of a deterministic policy
  def __init__(self, trajectories, lr, decay):
    """
    Attributes:
      trajectories: list of the sampeld trajectories
      lr: learning rate for SGA
      decay: weight decay for SGA
      world: The environment modeled as GridWorld
      features: feature to represrent each state. Simply encoded as coordinates in the grid
      theta: learnable parameters, one vector for each possible state. Same dimensionality of features
      V: (n_states x n_states) matrix with value associated to each state
      D: (n_states x n_states) matrix with visitation frequencies
      gamma: discount (Forse dev'essere ritornato da simulate_step, decido poi)
    """
    self.trajectories = trajectories
    self.lr = lr
    self.decay = decay

    self.world = GridWorld()
    self.n_states = self.world.grid_size**2
    self.features = np.zeros((self.n_states, 2))
    for s in range(self.n_states):
      i, j = self.world.point_to_state_coord(s)
      self.features[s, 0] = i
      self.features[s, 1] = j

    self.theta = np.zeros((self.n_states, 2)) # np.matmul(self.theta[i, j], self.features[i, j]) = R(s=[i,j])
    self.V = np.zeros((self.n_states))
    self.D = np.zeros((self.n_states))

    self.gamma = 0.9
    

  def R(self, state):
    """
    Return reward associated to given state (point)
    """
    return np.matmul(self.theta[s], self.features[s])

  # def Q(self, state, action):
  #   """
  #   Return action value associated to given state (point)
  #   Maximize over action to get Q(S, A)
  #   """
  #   next_s = self.world.simulate_step(state, action)
  #   return self.R(state) + self.gamma * self.V[next_s]

  def policy(self, state, action):
    return np.exp(self.Q(state, action) - self.V[state])
    
  def feature_exp_from_trajectories(self):
    # Compute feature_exp, as a representation of expert behaviour
    _, features_dim = self.features.shape
    f_exp = np.zeros((features_dim))

    for trajectory in self.trajectories:
      for step in trajectory:
        state, action = step
        f_exp += self.features[state] # TODO: check if it's point or coord

    return f_exp / len(self.trajectories)


  def pseudo_gpi(self):
    """
    Algorithm to update the value function
    """
    # NOTE: state = point in range(self.num_states)
    eps = 1e-3
    
    # For each state, define set V to -inf
    self.V = np.ones((self.n_states))*np.min

    # Set V_prime to -inf except for termnal state which is set to 0
    V_prime = np.ones((self.n_states))*np.min
    v_prime[self.world.state_coord_to_point(self.terminal_state)] = 0

    delta = np.ones((self.num_states))*np.max # parameter to monitor convergence
    while np.max(delta) > eps:
      # Update V for each state
      for s in range(self.num_states):
        for a in self.world.actions:
          next_s = self.world.simulate_step(s, a)
          V_prime[s] = np.log(np.exp(V_prime[s]) + np.exp(self.R(s) + self.V[next_s]))

        delta[s] = np.abs(self.V[s] - V_prime[s])
        self.V[s] = V_prime[s]

    # return self.V 

  def expected_svf(self):
    eps = 1e-5

    self.D = np.zeros((n_states))
    D_prime = world.p_initial # TODO Handle copy (shallow, deep, .. boh)

    delta = np.ones((self.num_states))*np.max
    while np.max(delta) > eps
      for s in range(self.num_states):
        for a in self.world.actions:
          next_s = self.world.simulate_step(s, a)
          D_prime[s] += self.D[s]*self.policy(next_s, a)

        delta[s] = np.abs(self.D[s] = D_prime[s])
        self.D[s] = D_prime[s]
           
    # return self.D[s]

  def sga(self):
    eps = 1e-4
    delta = np.zeros(self.features.shape)
    while np.max(delta) > eps
    pass