In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import gym

In [2]:
# Create the CartPole-v0 environment that the cells below query and reset.
env = gym.make("CartPole-v0")

In [3]:
# Sizes read from the environment's observation and action spaces.
state_dim = env.observation_space.shape[0]  # length of the first axis of the observation shape
action_dim = env.action_space.n  # `n` of the action space; presumably the number of discrete actions

In [17]:
class ActorCritic(tf.keras.Model):
    """Two-headed network: a shared dense trunk feeding an actor and a critic head.

    Both heads use a linear activation, so the actor output is one
    unnormalised score per action (no softmax is applied here) and the
    critic output is a single unbounded scalar per input row.
    """

    def __init__(self, action_dim):
        """Create the trunk and head layers.

        Args:
            action_dim: number of units in the actor head.
        """
        super().__init__()
        dense = tf.keras.layers.Dense

        # Shared trunk: 512 -> 128 units, ReLU after each layer.
        self.fc1 = dense(units=512, activation='relu')
        self.fc2 = dense(units=128, activation='relu')

        # Output heads, both reading the trunk's final features.
        self.critic = dense(units=1, activation='linear')
        self.actor = dense(units=action_dim, activation='linear')

    def call(self, input_data):
        """Run the forward pass.

        Args:
            input_data: input passed straight to the first dense layer.

        Returns:
            A tuple ``(actor_output, critic_output)``: the actor head's
            per-action scores and the critic head's scalar output.
        """
        features = self.fc2(self.fc1(input_data))

        policy_out = self.actor(features)
        value_out = self.critic(features)
        return policy_out, value_out

In [22]:
observation = env.reset()
# Shape the observation as a single-row batch (1, n_features). The previous
# reshape([-1, 1]) produced a column (n_features, 1), which does not match the
# recorded output of this cell and would be read by the network as
# n_features separate samples of one feature each.
observation = observation.reshape([1, -1])
observation

array([[-0.0348073 ,  0.02271736, -0.00114993, -0.04963978]],
      dtype=float32)

In [23]:
# Module-level model instance; get_action reads this name directly.
actor_critic = ActorCritic(action_dim=action_dim)

In [31]:
# Presumably the reward discount factor; it is not referenced in the visible
# cells, so confirm where it is consumed.
gamma = 0.99


### Function: `get_action`

* Arg:
    * state (`np.ndarray`) : observation (state) returned by the environment at the current time step


* Return:
    * action (`int`) : action sampled from the current policy at the current time step

In [89]:
# Gets action from actor network

def get_action(state):
    """
    Samples an action from the current policy for the given state.

    Arg:
        state (np.ndarray) : observation or state from environment; it is
            flattened into a single-row batch, so any layout holding one
            observation is accepted
    return:
        action (int) : index of the action drawn from the categorical
            distribution defined by the actor head of `actor_critic`
    """
    # The network consumes (batch, features) input; build a float32 batch of one.
    state = np.asarray(state, dtype=np.float32).reshape([1, -1])

    # The actor head has a linear activation, so its output is unnormalised
    # logits. The critic output is not needed to pick an action.
    logits, _ = actor_critic(state)

    # Passing logits lets Categorical do the normalisation itself, avoiding a
    # softmax -> numpy -> probs round trip. The default integer sample dtype is
    # kept because samples are class indices.
    dist = tfp.distributions.Categorical(logits=logits)

    # sample() has shape (1,) for the single-row batch; index the element
    # explicitly rather than relying on implicit size-1 array-to-int conversion.
    action = int(dist.sample()[0].numpy())

    return action

In [88]:
# Quick check: sample one action for the observation obtained from env.reset().
get_action(observation)

1