In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import gym

class PolicyNetwork(nn.Module):
    def __init__(self, n_inputs, n_outputs, learning_rate=1e-2):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(n_inputs, 128)
        self.fc2 = nn.Linear(128, n_outputs)
        self.softmax = nn.Softmax(dim=-1)

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x): # L ->  relu -> L -> softmax
        x = torch.relu(self.fc1(x))
        x = self.softmax(self.fc2(x))
        return x

    def update(self, rewards, log_probs):
        loss = (-torch.stack(log_probs) * rewards).sum()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


# Create environment
env = gym.make('CartPole-v0')
# Input is the obs space
n_inputs = env.observation_space.shape[0]
# Output is the action space
n_outputs = env.action_space.n


# Policy is a NN
policy = PolicyNetwork(n_inputs, n_outputs)
n_episodes = 2000
max_timesteps = 500
gamma = 0.99

In [2]:
#Initial Testing and visualization phase
import time

n_test_episodes = 10
for episode in range(n_test_episodes):
    state = env.reset()
    total_reward = 0

    for t in range(max_timesteps):
        env.render()  # Render the environment
        time.sleep(0.02)  # Slow down the visualization

        state = torch.FloatTensor(state).unsqueeze(0)
        action_probs = policy(state)
        action = torch.argmax(action_probs, 1).item()

        state, reward, done, _ = env.step(action)
        total_reward += reward

        if done:
            break

    print("Test Episode: {}, Reward: {}".format(episode + 1, total_reward))

env.close()

Test Episode: 1, Reward: 9.0
Test Episode: 2, Reward: 10.0
Test Episode: 3, Reward: 10.0
Test Episode: 4, Reward: 9.0
Test Episode: 5, Reward: 9.0
Test Episode: 6, Reward: 9.0
Test Episode: 7, Reward: 10.0
Test Episode: 8, Reward: 8.0
Test Episode: 9, Reward: 9.0
Test Episode: 10, Reward: 9.0


In [3]:


for episode in range(n_episodes):
    # Each episode resets the environment - Start from scratch
    state = env.reset()
    log_probs = []
    rewards = []

    # in each state, we run the model N timesteps
    for t in range(max_timesteps):
        state = torch.FloatTensor(state).unsqueeze(0)
        action_probs = policy(state)
        action = torch.multinomial(action_probs, 1).item()
        log_probs.append(torch.log(action_probs[0, action]))

        state, reward, done, _ = env.step(action)
        rewards.append(reward)

        if done:
            break

    returns = []
    Gt = 0
    for r in reversed(rewards):
        Gt = r + gamma * Gt   # gt = r + gamma * g_(t+1)
        returns.insert(0, Gt)

    returns = torch.FloatTensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-9)

    policy.update(returns, log_probs)

    if episode % 100 == 0:
        print("Episode: {}, Reward: {}".format(episode, sum(rewards)))



Episode: 0, Reward: 14.0
Episode: 100, Reward: 200.0
Episode: 200, Reward: 84.0
Episode: 300, Reward: 176.0
Episode: 400, Reward: 200.0
Episode: 500, Reward: 118.0
Episode: 600, Reward: 200.0
Episode: 700, Reward: 21.0
Episode: 800, Reward: 23.0
Episode: 900, Reward: 200.0
Episode: 1000, Reward: 200.0
Episode: 1100, Reward: 111.0
Episode: 1200, Reward: 124.0
Episode: 1300, Reward: 73.0
Episode: 1400, Reward: 61.0
Episode: 1500, Reward: 127.0
Episode: 1600, Reward: 200.0
Episode: 1700, Reward: 200.0
Episode: 1800, Reward: 200.0
Episode: 1900, Reward: 174.0


In [4]:
# Testing and visualization phase
import time

n_test_episodes = 10
for episode in range(n_test_episodes):
    state = env.reset()
    total_reward = 0

    for t in range(max_timesteps):
        env.render()  # Render the environment
        time.sleep(0.02)  # Slow down the visualization

        state = torch.FloatTensor(state).unsqueeze(0)
        action_probs = policy(state)
        action = torch.argmax(action_probs, 1).item()

        state, reward, done, _ = env.step(action)
        total_reward += reward

        if done:
            break

    print("Test Episode: {}, Reward: {}".format(episode + 1, total_reward))

env.close()

Test Episode: 1, Reward: 200.0
Test Episode: 2, Reward: 200.0
Test Episode: 3, Reward: 200.0
Test Episode: 4, Reward: 200.0
Test Episode: 5, Reward: 200.0
Test Episode: 6, Reward: 200.0
Test Episode: 7, Reward: 200.0
Test Episode: 8, Reward: 200.0
Test Episode: 9, Reward: 200.0
Test Episode: 10, Reward: 200.0
