In [11]:
import matplotlib.pyplot as plt
import torch
import gym

class PolicyNet(torch.nn.Module):
    """Two-layer softmax policy over a discrete action set.

    A state vector is passed through one hidden layer of 64 ReLU units and
    a linear output layer, then normalized into action probabilities.

    Args:
        input_size: Number of features in a state vector.
        output_size: Number of discrete actions.
    """

    def __init__(self, input_size, output_size):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(input_size, 64)
        self.fc2 = torch.nn.Linear(64, output_size)
        # Normalize over the last (action) axis so that a single state and a
        # batch of states (one per row) both yield one distribution per state.
        # For a 1-D input this is identical to dim=0.
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """Return action probabilities for ``x``.

        Args:
            x: A state, or a batch of states stacked as rows, given as a
                NumPy array, a tensor, or a nested sequence of numbers.

        Returns:
            Tensor of probabilities with ``output_size`` entries on the last
            axis; each slice along that axis sums to 1.
        """
        # as_tensor accepts ndarrays, tensors and sequences alike. The cast to
        # float32 mirrors the original `.float()` call (NumPy arrays default
        # to float64, which the linear layers' default weights do not accept).
        x = torch.as_tensor(x, dtype=torch.float32)
        x = torch.nn.functional.relu(self.fc1(x))
        return self.softmax(self.fc2(x))

    def draw_action(self, x):
        """Sample one action for a single state ``x``.

        Returns:
            (action, log_prob): the sampled action as a Python int, and the
            log-probability of that action as a tensor produced by the
            network's forward pass (so it can be differentiated through).
        """
        action_prob = self.forward(x)
        m = torch.distributions.Categorical(action_prob)
        action = m.sample()
        log_probs = m.log_prob(action)
        # .item() requires a single sampled action, i.e. an unbatched state.
        return action.item(), log_probs
    
class ValueNet(torch.nn.Module):
    """State-value estimator: ``input_size`` features -> 64 ReLU units -> one output."""

    def __init__(self, input_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, 64)
        self.fc2 = torch.nn.Linear(64, 1)

    def forward(self, x):
        """Return the value estimate for tensor ``x``.

        ``x`` must already be a tensor whose last axis has ``input_size``
        entries; the result has a last axis of size 1.
        """
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
    
def collect_trajectory():
    """Roll out one episode in the global ``env`` using the global ``policy``.

    The rollout ends when the environment reports ``done`` or once
    ``max_num_steps_per_ep`` steps have been taken, whichever comes first.

    Returns:
        (state_list, action_list, reward_list, log_prob_list): four parallel
        lists with one entry per step taken. Each stored state is the
        observation the corresponding action was sampled from.
    """
    states, actions, rewards, log_probs = [], [], [], []
    current = env.reset()
    taken = 0
    while taken < max_num_steps_per_ep:
        chosen, chosen_log_prob = policy.draw_action(current)
        # The step result is unpacked as (observation, reward, done, info).
        following, reward, done, _ = env.step(chosen)
        states.append(current)
        actions.append(chosen)
        rewards.append(reward)
        log_probs.append(chosen_log_prob)
        taken += 1
        current = following
        if done:
            break
    return states, actions, rewards, log_probs
    
# Environment, plus networks sized from its observation/action spaces.
env = gym.make('CartPole-v0')
state = env.reset()
input_size = env.observation_space.shape[0]  # length of one observation vector
output_size = env.action_space.n  # number of discrete actions
policy = PolicyNet(input_size, output_size)
value = ValueNet(input_size)

# Training settings.
num_iter = 1
num_ep_per_iter = 20
num_traj_per_iter = 10
max_num_steps_per_ep = 200  # step cap applied inside collect_trajectory()
gamma = 0.99  # presumably the discount factor; not used yet in this cell

for i in range(num_iter):
    # collect_trajectory() returns four lists (states, actions, rewards,
    # log-probs). Unpacking into only three names raises
    # "ValueError: too many values to unpack", so all four are bound here.
    state_list, action_list, reward_list, log_prob_list = collect_trajectory()
    # Remaining steps are not implemented yet:
    # Compute rewards-to-go Rt
    # Compute advantage A
    # Estimate policy gradient
    # GD step on policy
    # Estimate value function gradient
    # GD step on value function
    

# manual testing

In [10]:
# Manual check: roll out one episode with the current `policy` and display the
# returned (states, actions, rewards, log-probs) lists as the cell output.
collect_trajectory()

([array([0.00642266, 0.00365945, 0.01782137, 0.04382058]),
  array([ 0.00649585, -0.19171345,  0.01869778,  0.34207261]),
  array([0.00266158, 0.00313756, 0.02553924, 0.05534401]),
  array([ 0.00272433, -0.1923411 ,  0.02664612,  0.3559741 ]),
  array([-0.00112249, -0.38783157,  0.0337656 ,  0.65693874]),
  array([-0.00887912, -0.19319551,  0.04690437,  0.37507627]),
  array([-0.01274303,  0.00122991,  0.0544059 ,  0.09754372]),
  array([-0.01271844,  0.1955316 ,  0.05635677, -0.17749014]),
  array([-0.0088078 ,  0.38980365,  0.05280697, -0.45187545]),
  array([-0.00101173,  0.58414058,  0.04376946, -0.72745646]),
  array([ 0.01067108,  0.38844172,  0.02922033, -0.42132519]),
  array([ 0.01843991,  0.19291821,  0.02079383, -0.11957558]),
  array([ 0.02229828, -0.00249539,  0.01840232,  0.17959441]),
  array([ 0.02224837,  0.19235845,  0.0219942 , -0.1072269 ]),
  array([ 0.02609554,  0.38715843,  0.01984967, -0.39289042]),
  array([ 0.03383871,  0.1917605 ,  0.01199186, -0.09401585]),
