In [None]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# --- Training schedule ---
# Number of rollout-then-train iterations of the outer loop.
num_episodes = 100
# Gradient updates performed after each rollout.
train_batches_per_episode = 100

# --- Optimisation ---
# Transitions drawn from the replay buffer for every gradient update.
batch_size = 128
# Intended number of gradient updates between target-network refreshes.
target_copy_freq = 10

In [None]:
# One seed shared by every source of randomness so that
# Restart Kernel -> Run All reproduces the same training run.
SEED = 12345

env = gym.make('CartPole-v0')
# Classic gym seeding call, consistent with the 4-tuple `env.step` API used
# by the rest of this notebook.
env.seed(SEED)

# Read the problem dimensions from the environment instead of hard-coding
# them, so the network sizes follow the environment that was actually made.
state_dim = int(env.observation_space.shape[0])
n_actions = int(env.action_space.n)

buffer = [] # Replay buffer to hold environment transitions
training_rewards = [] # Total reward of each finished training episode
rng = np.random.RandomState(SEED) # Drives exploration and replay sampling
torch.manual_seed(SEED) # Makes the network weight initialisation repeatable

In [None]:
def _build_q_network():
    """Return a freshly initialised one-hidden-layer MLP mapping a state to one Q-value per action."""
    return nn.Sequential(
        nn.Linear(state_dim, 128),
        nn.ReLU(),
        nn.Linear(128, n_actions),
    )


# Online network: the one being optimised and used for acting.
net = _build_q_network()

# Target network: same architecture, started from the online network's weights.
target_net = _build_q_network()
target_net.load_state_dict(net.state_dict())

# Squared error between predicted Q-values and bootstrapped targets.
criterion = nn.MSELoss()
# Plain SGD on the online network only; Adam with the same learning rate is
# the alternative that was tried:
#optimizer = optim.Adam(net.parameters(), lr=0.01)
optimizer = optim.SGD(net.parameters(), lr=0.01)

In [None]:
def policy(s, explore_prob=0.05):
    """Pick an action epsilon-greedily from the online Q-network ``net``.

    Args:
        s: Observation as a NumPy array; it is reshaped into a one-row batch
            before being fed to the network.
        explore_prob: Probability of returning a uniformly random action
            instead of the greedy one. Defaults to 0.05.

    Returns:
        An integer action index: the arg-max of the network's outputs on the
        greedy branch, otherwise a draw from ``rng.choice(n_actions)``.
    """
    # Draw first so the forward pass is skipped on exploratory steps. The
    # random stream is consumed in the same order as before (one uniform draw,
    # then at most one choice).
    if rng.uniform() < 1.0 - explore_prob:
        # Acting never needs gradients, so do not build an autograd graph.
        with torch.no_grad():
            action_values = net(torch.Tensor(s.reshape((1, -1))))
        best_action = action_values.max(dim=1)[1].numpy()
        return best_action[0]
    # Random action to encourage exploration
    return rng.choice(n_actions)

In [None]:
# Main loop: alternate between collecting one episode of experience and
# running a fixed number of Q-learning updates on replayed transitions.
for ep in tqdm(range(num_episodes)):
    # ---- Rollout: play one episode with the exploring policy ----
    s = env.reset()
    ep_reward = 0
    while True:
        a = policy(s)
        s_new, r, d, _ = env.step(a)
        #r = r - 0.05*((s[0]-0.3)**2)  # Optional reward shaping
        buffer.append((s, a, s_new, r, d))
        s = s_new
        ep_reward += r
        if d:
            training_rewards.append(ep_reward)
            break

    # ---- Train: skipped until the buffer can fill one whole batch ----
    if len(buffer) < batch_size:
        continue
    for train_step in range(train_batches_per_episode):
        # Sample distinct transitions uniformly from the whole buffer.
        idx = rng.choice(len(buffer), replace=False, size=batch_size)
        batch = [buffer[i] for i in idx]
        # Stack into ndarrays before converting: building a tensor directly
        # from a Python list of arrays is slow. Batch tensors get their own
        # names so they do not clobber the rollout variables s/a/r/d.
        states = torch.tensor(np.array([t[0] for t in batch]), dtype=torch.float32)
        actions = torch.tensor(np.array([t[1] for t in batch]), dtype=torch.long)
        next_states = torch.tensor(np.array([t[2] for t in batch]), dtype=torch.float32)
        rewards = torch.tensor(np.array([t[3] for t in batch]), dtype=torch.float32)
        dones = torch.tensor(np.array([t[4] for t in batch]), dtype=torch.float32)

        # Bootstrapped targets from the target network; no gradient may flow
        # into them. (1 - dones) zeroes the future value of terminal
        # transitions. No discount factor is applied.
        with torch.no_grad():
            max_next_q = target_net(next_states).max(dim=1)[0]
            targets = rewards + (1 - dones) * max_next_q

        # Q-values of the actions that were actually taken.
        q_taken = net(states)[torch.arange(batch_size), actions]

        optimizer.zero_grad()
        loss = criterion(q_taken, targets)
        loss.backward()
        optimizer.step()

        # Periodically update the target network with the new parameters.
        # The global update index is computed explicitly because `%` binds
        # tighter than `+`: an unparenthesised
        # `train_step + ep*train_batches_per_episode % target_copy_freq`
        # would take the modulus of the episode term only.
        global_step = ep * train_batches_per_episode + train_step
        if global_step % target_copy_freq == 0:
            target_net.load_state_dict(net.state_dict())

In [None]:
# Learning curve: one point per finished training episode.
# NOTE(review): the y-axis is labelled as a length but the plotted values are
# the summed rewards; presumably they coincide here because each step yields a
# reward of 1. Confirm if the optional reward shaping is enabled.
fig, ax = plt.subplots()
ax.plot(training_rewards)
ax.set_xlabel('Episodes run')
ax.set_ylabel('Episode length')

In [None]:
# Save a video to a directory
# The recording environment gets its own name so the training `env` is not
# replaced by a wrapper that is closed at the end of this cell (which would
# break re-running the training cell afterwards).
env_to_wrap = gym.make('CartPole-v0')
video_env = gym.wrappers.Monitor(env_to_wrap, 'video_output', force = True)
try:
    s = video_env.reset()
    d = False
    n_steps = 0
    # Note: `policy` still takes a random action with its default exploration
    # probability, so the recorded episode is not purely greedy.
    while not d:
        s, _, d, _ = video_env.step(policy(s))
        n_steps += 1
finally:
    # Always release the recorder and the underlying environment, even if the
    # rollout raises part-way through.
    video_env.close()
    env_to_wrap.close()
# One summary line instead of printing every observation.
print('Recorded one episode of {} steps to video_output'.format(n_steps))