In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gym
import numpy as np
import random

# Maximum number of training episodes; also drives the epsilon decay schedule
# and (together with EPISODE_LENGTH) the pre-allocated replay-memory size.
EPISODE_COUNT = 10000
# Maximum number of environment steps per episode (training and evaluation).
EPISODE_LENGTH = 200
# Number of floats in one environment observation as stored in a transition row.
OBS_SIZE = 4
# Intended number of training steps between target-network weight refreshes.
TARGET_NET_REFRESH_RATE = 200

In [None]:
class Q(nn.Module):
    """Small fully-connected action-value network.

    The network takes one vector holding the action followed by the state
    (this is how the helpers in this notebook build the input with
    ``torch.cat((action, state))``) and returns a single Q-value per input row.

    Args:
        input_size: Width of the input vector. Defaults to 5, i.e. one action
            entry plus a 4-element observation.
        hidden_size: Width of both hidden layers. Defaults to 64.
    """

    def __init__(self, input_size=5, hidden_size=64):
        super().__init__()
        # Layers are created in the same order as before so that the default
        # construction consumes the RNG identically.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return the Q-value(s) for ``x``; output has a trailing dimension of 1."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        # No activation on the output: Q-values are unbounded.
        return self.fc3(hidden)
    
def transfer_state(src, dst):
    """Overwrite the state of module `dst` with the current state of module `src`.

    Used to synchronise the target network with the online network.
    Returns None; `dst` is modified in place.
    """
    snapshot = src.state_dict()
    dst.load_state_dict(snapshot)
    
def argmax_action(net, action_space, state):
    """Return the index of the highest-valued action for `state`.

    Each candidate action tensor is concatenated in front of `state` and fed
    through `net`; the index of the action with the largest output is returned.
    Ties are resolved in favour of the earliest action (strict `>` comparison).

    Args:
        net: Module mapping a concatenated (action, state) vector to one value.
        action_space: Non-empty sequence of 1-D action tensors.
        state: 1-D state tensor.

    Returns:
        int: Position in `action_space` of the greedy action.
    """
    # Action selection never needs gradients, so skip building autograd graphs.
    with torch.no_grad():
        best_i = 0
        best_out = net(torch.cat((action_space[0], state)))
        for i in range(1, len(action_space)):
            out = net(torch.cat((action_space[i], state)))
            if out > best_out:
                best_i, best_out = i, out
    return best_i

def max_action(net, action_space, state):
    """Return the largest value `net` assigns to any action in `state`.

    Each candidate action tensor is concatenated in front of `state` and fed
    through `net`; the largest output is returned.

    The value is computed without gradient tracking: it is used as a
    bootstrap target, and a target must be treated as a constant so that the
    loss does not back-propagate into the network that produced it.

    Args:
        net: Module mapping a concatenated (action, state) vector to one value.
        action_space: Non-empty sequence of 1-D action tensors.
        state: 1-D state tensor.

    Returns:
        torch.Tensor: The maximal network output (same shape as a single
        forward pass), detached from the autograd graph.
    """
    with torch.no_grad():
        best_out = net(torch.cat((action_space[0], state)))
        for i in range(1, len(action_space)):
            out = net(torch.cat((action_space[i], state)))
            if out > best_out:
                best_out = out
    return best_out


def empty_transition_block():
    """Return a zero-filled 1-D tensor sized to hold one replay-memory row.

    The row has room for two observations of OBS_SIZE floats each plus three
    extra scalar slots.
    """
    row_width = 2 * OBS_SIZE + 3
    return torch.zeros(row_width)

In [None]:
# Networks, loss fn, optimizer
# `q` is the online network that is trained; `q_` is the target network that
# supplies bootstrap values and is only refreshed periodically from `q`.
q = Q()
q_ = Q()
transfer_state(q, q_)

criterion = nn.MSELoss()
optimizer = optim.SGD(q.parameters(), lr=0.001)

# Environment
env = gym.make('CartPole-v0')

# Discount factor
gamma = 0.99

# Training / evaluation settings
BATCH_SIZE = 16         # transitions sampled per gradient step
MIN_REPLAY_SIZE = 500   # learning starts once more than this many transitions are stored
EVAL_INTERVAL = 50      # evaluate the greedy policy every this many training episodes
EVAL_EPISODES = 100     # number of episodes averaged in one evaluation
SOLVED_REWARD = 195     # stop once the average evaluation reward exceeds this

# Replay memory, one row per transition, laid out as
#   [action | observation (OBS_SIZE) | next observation (OBS_SIZE) | reward | continue flag]
# Pre-allocated for the worst case of every episode running to full length.
transitions = torch.zeros(EPISODE_COUNT*EPISODE_LENGTH, OBS_SIZE*2 + 3)

# List of possible actions
action_values = [torch.tensor([i], dtype=torch.float32) for i in range(env.action_space.n)]


# Replay memory counter, counter for target net update
next_i = 0
c = 0

# Train, test, visualise
try:
    for i_episode in range(EPISODE_COUNT):
        observation = torch.tensor(env.reset(), dtype=torch.float32)
        # Epsilon decays linearly with the episode index, floored at 0.01.
        eps = max(0.01,((EPISODE_COUNT - i_episode)/EPISODE_COUNT)*0.99)
        for t in range(EPISODE_LENGTH):

            # With probability epsilon select a random action a_t
            # otherwise select a_t = argmax_a Q(x_t, a, net parameters)
            if random.random() < eps:
                action = env.action_space.sample()
            else:
                # Acting does not need gradients.
                with torch.no_grad():
                    action = argmax_action(q, action_values, observation)

            # Execute action a_t in emulator and observe reward r_t and observation x_t+1
            new_observation, reward, done, info = env.step(action)
            new_observation = torch.tensor(new_observation, dtype=torch.float32)

            # Store transition (x_t, a_t, r_t, x_t+1)
            transition_block = empty_transition_block()
            transition_block[0] = action
            transition_block[1:OBS_SIZE+1] = observation
            transition_block[OBS_SIZE + 1: 2*OBS_SIZE + 1] = new_observation
            transition_block[2*OBS_SIZE + 1] = reward
            # Continue flag: 0 marks the last step of an episode (no bootstrapping), 1 otherwise.
            transition_block[2*OBS_SIZE + 2] = 0 if (done or t == EPISODE_LENGTH -1) else 1
            transitions[next_i] = transition_block
            observation = new_observation
            next_i += 1
            c += 1

            # If enough memory stored
            if next_i > MIN_REPLAY_SIZE:

                # Sample random minibatch of transitions (x_j, a_j, r_j, x_j+1).
                # random.sample on a range draws distinct indices without
                # materialising a permutation of the whole replay memory.
                batch_indices = random.sample(range(next_i), BATCH_SIZE)
                batch = transitions[torch.tensor(batch_indices)]

                # Set targets. Targets are constants for the optimisation step,
                # so they are built without gradient tracking; otherwise the
                # loss would also back-propagate into the target network.
                with torch.no_grad():
                    ys = torch.zeros(batch.size(0), 1)
                    for i in range(batch.size(0)):
                        d = batch[i][2*OBS_SIZE + 2]
                        r = batch[i][2*OBS_SIZE + 1]
                        if d == 0:
                            ys[i][0] = r
                        else:
                            ys[i][0] = r + gamma*max_action(q_, action_values, batch[i][OBS_SIZE + 1: 2*OBS_SIZE + 1])

                # Perform a gradient descent step on (y_j - Q(x_j, a_j))^2
                # The first OBS_SIZE + 1 columns of a row are exactly (a_j, x_j).
                optimizer.zero_grad()
                outputs = q(batch[:,:OBS_SIZE + 1])
                loss = criterion(outputs, ys)
                loss.backward()
                optimizer.step()

                # Reset q_ = q
                if c >= TARGET_NET_REFRESH_RATE:
                    transfer_state(q, q_)
                    c = 0

            if done:
                break

        if i_episode % EVAL_INTERVAL == 0:
            # Evaluate the purely greedy policy and report the mean episode reward.
            allrewards = 0
            with torch.no_grad():
                for j in range(EVAL_EPISODES):
                    observation = torch.tensor(env.reset(), dtype=torch.float32)
                    rewards = 0
                    for t in range(EPISODE_LENGTH):
                        # Uncomment for rendering:
                        # env.render()

                        action = argmax_action(q, action_values, observation)
                        observation, reward, done, info = env.step(action)
                        rewards += reward
                        observation = torch.tensor(observation, dtype=torch.float32)
                        if done:
                            break
                    allrewards += rewards/EVAL_EPISODES
            print("trained for %d episodes, current avg test reward: %d" % (i_episode, allrewards))
            if allrewards > SOLVED_REWARD:
                torch.save(q.state_dict(), "CPDQNSolved")
                print("Cartpole solved, saved net")
                break
finally:
    # Always release the environment, even if training is interrupted or fails.
    env.close()