In [110]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [111]:
env = gym.make('CartPole-v1')

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cpu")

In [112]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'terminated'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [113]:
class DQN(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(DQN, self).__init__()
        self.layers = []
        self.layers.append(nn.Linear(input_dim, hidden_dim))
        
        for i in range(num_layers -1): 
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Linear(hidden_dim,hidden_dim))
        self.layers.append(nn.ReLU())
        self.layers.append(nn.Linear(hidden_dim, output_dim ))
        self.layers = nn.Sequential(*self.layers)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        return self.layers(torch.tensor(x, device=device, dtype=torch.float))

In [114]:
batchSize = 128
gamma = torch.tensor(0.99)
gamma.to(device)
epsilon = 1
EPS_END = 0.10
EPS_DECAY = 0.99
TARGET_UPDATE = 100

In [115]:
n_actions = env.action_space.n

policy_net = DQN(4, 128, 2, 2).to(device)
target_net = DQN(4, 128, 2, 2).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

DQN(
  (layers): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=2, bias=True)
  )
)

In [116]:
optimizer = optim.Adam(policy_net.parameters(), lr=0.0001)
memory = ReplayMemory(100000)
mse = torch.nn.MSELoss()

In [117]:
def policy(x, epsilon):
    if torch.rand(1) < epsilon:
        return torch.tensor(random.choice([0,1]))
    else:
        return x.max(0)[-1]

In [118]:
observation = env.reset()
policy_net(observation)

tensor([ 0.0371, -0.0433], grad_fn=<AddBackward0>)

In [119]:
def estimatePerformance():  
    counters = []
    epsilon = -1
    trials = 10
    for i in range(trials):
        counter = 0
        terminated = False
        observation = env.reset()
        while not terminated:
            action = policy_net(observation).max(0)[-1].item()
            observation, reward, terminated,  info = env.step(action)
            counter += reward
        counters.append(counter)

    summ =  0
    for i in range(trials):
        summ += counters[i]
    return summ/trials

In [120]:
from copy import deepcopy
def updateTargetNet():
    target_net = deepcopy(policy_net)
    for param in target_net.parameters():
        param.requires_grad = False
    return target_net

In [None]:
counter = 0
epi_count = 0
cum_reward = 0.0
upd_count = 0
loss = "untrained"
terminated = False
env.reset()

In [122]:
while True:
    if counter % TARGET_UPDATE == 0:
        target_net = updateTargetNet()
    while not terminated:
        with torch.no_grad():
            x = policy_net(observation)
        action = policy(x, epsilon)
        next_state, reward, terminated, _ = env.step(action.item())
        reward = 1
        cum_reward += reward
        memory.push(observation, action, next_state, reward, terminated)
        observation = (next_state)
    observation = env.reset()
    terminated = False
    epi_count += 1
    if len(memory) >= 10*batchSize:
        transitions = memory.sample(batchSize)
        tup = np.stack([transition for transition in transitions])
        obs = np.stack(tup[:, 0])
        action = np.stack(tup[:, 1])
        obs_next = np.stack(tup[:, 2])
        rew = torch.tensor(np.stack(tup[:, 3]), dtype=torch.float)
        term = torch.tensor(np.stack(tup[:, 4]), dtype=torch.float)
        
        with torch.no_grad():
            q_prime = target_net(obs_next).max(dim=1)[0]
        target =rew + gamma * q_prime * (1 - term)

        q_pred = policy_net(obs)[np.arange(obs.shape[0]), action]
        loss = mse(q_pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        upd_count += 1
    if len(memory) >=batchSize and (counter % 200 == 0):
        epsilon = max(EPS_END, epsilon*EPS_DECAY)
    if counter %100 == 0:
        print(f"Episode={epi_count}, reward={estimatePerformance()}, loss={loss}, epsilon={epsilon} iter={upd_count}, example _ prediction = {q_pred[0]}")
    counter += 1
    cum_reward = 0.0

Episode=23602, reward=482.8, loss=1.6255673170089722, epsilon=0.3024044356690215 iter=23549, example _ prediction = 128.5148468017578
Episode=23702, reward=497.1, loss=1.185368537902832, epsilon=0.3024044356690215 iter=23649, example _ prediction = 137.26695251464844


KeyboardInterrupt: 

In [None]:
transitions = memory.sample(30)
tup = np.stack([transition for transition in transitions])
obs = np.stack(tup[:, 0])
action = np.stack(tup[:, 1])
obs_next = np.stack(tup[:, 2])
rew = torch.tensor(np.stack(tup[:, 3]), dtype=torch.float)
term = torch.tensor(np.stack(tup[:, 4]), dtype=torch.float)

  arrays = [asanyarray(arr) for arr in arrays]


In [None]:
policy_net(obs)[np.arange(obs.shape[0]), action]

tensor([1.1189, 1.0976, 1.0937, 1.0713, 1.1059, 1.1081, 1.1093, 1.1100, 1.1089,
        1.0933, 1.1081, 1.0922, 1.1146, 1.1010, 1.1109, 1.0913, 1.1187, 1.0939,
        0.9757, 1.1005, 1.1076, 1.1043, 1.0712, 1.1040, 1.1058, 1.1105, 1.1048,
        1.0744, 1.0857, 1.1042, 1.1158, 1.1101, 1.0931, 1.1075, 1.0972, 1.1184,
        1.0836, 1.1012, 1.0906, 1.0899, 1.1055, 1.1034, 1.1147, 1.0958, 1.1247,
        1.1118, 1.1213, 1.1173, 1.1009, 1.1073, 1.0944, 1.1074, 1.0753, 1.0902,
        1.1072, 1.1060, 1.1200, 1.0952, 1.1061, 1.0688, 1.1205, 1.1081, 1.1003,
        1.0912, 1.1035, 1.1201, 1.1243, 1.1024, 1.0882, 1.0935, 1.1133, 1.0798,
        1.1008, 1.1278, 1.0944, 1.1035, 1.1089, 1.1053, 1.1141, 1.0422, 1.0914,
        1.1218, 1.0951, 1.1047, 1.0939, 1.0955, 1.1088, 1.1183, 1.0966, 1.1084,
        1.0968, 1.1060, 1.0891, 1.0951, 1.0887, 1.0677, 1.1053, 1.1069, 1.1306,
        1.1346, 1.0942, 1.1171, 1.1053, 1.0948, 1.1015, 1.0895, 1.1037, 1.0915,
        1.1215, 1.0675, 1.0847, 1.1172, 

In [None]:
action

1

In [None]:
 policy_net(obs[4])

tensor([1.0935, 1.1059], grad_fn=<AddBackward0>)

In [None]:
action= np.stack(tup[:, 1])[4]