In [2]:
# imports
import numpy as np
import gym

In [3]:
## Task 0. Simple Policy Function
def policy(matrix, weight):
    """
    Computes a softmax policy from linear action preferences.

    Each row of ``matrix`` is one state's feature vector. The preferences
    ``matrix @ weight`` are converted into action probabilities with a
    softmax taken independently over each row, so every output row sums to 1.

    Args:
        matrix: np.ndarray of shape (n_states, n_features)
        weight: np.ndarray of shape (n_features, n_actions)

    Returns:
        np.ndarray of shape (n_states, n_actions) holding the action
        probabilities for each state
    """
    preferences = matrix.dot(weight)
    # Subtracting the per-row maximum leaves the softmax unchanged but keeps
    # np.exp from overflowing to inf (which would produce nan probabilities).
    shifted = preferences - np.max(preferences, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    # Normalize per row (last axis); summing over the whole array would mix
    # states together whenever more than one state is passed in.
    return exp / np.sum(exp, axis=-1, keepdims=True)

In [4]:
# 0-main
# Build the inputs with np.array (the documented constructor) rather than the
# low-level np.ndarray(shape, buffer=...) form.
weight = np.array([
    [4.17022005e-01, 7.20324493e-01],
    [1.14374817e-04, 3.02332573e-01],
    [1.46755891e-01, 9.23385948e-02],
    [1.86260211e-01, 3.45560727e-01],
])
state = np.array([
    [-0.04428214, 0.01636746, 0.01196594, -0.03095031],
])

res = policy(state, weight)
# Each row is a probability distribution over actions.
assert np.allclose(res.sum(axis=1), 1.0)
print(res)

[[0.50351642 0.49648358]]


In [25]:
# Task 1. Compute the Monte-Carlo policy gradient
def policy_gradient(state, weight):
    """
    Computes the Monte-Carlo policy gradient for a single state.

    Samples an action from the softmax policy given by ``policy`` and returns
    the gradient of log pi(action | state) with respect to ``weight``.

    Args:
        state: np.ndarray of shape (1, n_features) - the current observation
            of the environment as a single row (only row 0 is used for the
            action probabilities)
        weight: np.ndarray of shape (n_features, n_actions) - policy weights

    Returns:
        The action and the gradient (in this order):
        action: index of the sampled action
        grad: np.ndarray of shape (n_features, n_actions)
    """
    probs = policy(state, weight)[0]
    # Sampled from NumPy's global RNG, so np.random.seed() fixes the draw.
    action = np.random.choice(len(probs), p=probs)

    # For a softmax policy, d log pi(a) / d preferences = one_hot(a) - pi.
    # This equals the `action` row of the softmax Jacobian divided by pi(a),
    # computed directly so a tiny pi(a) cannot blow up a division and the
    # full (n_actions x n_actions) Jacobian never has to be built.
    log_derivative = -probs
    log_derivative[action] += 1.0

    # Outer product: (n_features, 1) @ (1, n_actions) -> (n_features, n_actions)
    grad = state.T.dot(log_derivative[None, :])

    return action, grad

In [26]:
# 1-main
env = gym.make('CartPole-v1')
try:
    # NOTE(review): gym environments typically keep their own RNG, so this
    # seed may not make the reset state below reproducible — confirm for the
    # installed gym version.
    np.random.seed(1)

    weight = np.random.rand(4, 2)
    # gym >= 0.26 returns (observation, info) from reset(); older versions
    # return only the observation. Accept both so the cell runs on either.
    reset_out = env.reset()
    observation = reset_out[0] if isinstance(reset_out, tuple) else reset_out
    state = np.asarray(observation)[None, :]
    print(weight)
    print(state)

    action, grad = policy_gradient(state, weight)
    print(action)
    print(grad)
finally:
    # Always release the environment, even if a step above raises.
    env.close()

[[4.17022005e-01 7.20324493e-01]
 [1.14374817e-04 3.02332573e-01]
 [1.46755891e-01 9.23385948e-02]
 [1.86260211e-01 3.45560727e-01]]
[[ 0.04124436  0.00458376  0.0449007  -0.04867243]]
0
[[ 0.02066031 -0.02066031]
 [ 0.00229612 -0.00229612]
 [ 0.02249186 -0.02249186]
 [-0.02438121  0.02438121]]
