In [1]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
env = gym.make('CartPole-v1')

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if gpu is to be used
device = torch.device("cpu")

  deprecation(
  deprecation(


In [3]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'terminated'))

class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([],maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [4]:
class DQN(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(DQN, self).__init__()
        self.layers = nn.ParameterList()
        self.layers.append( nn.Linear(input_dim, hidden_dim, bias = False))
        
        for i in range(num_layers -1): 
            self.layers.append(nn.Linear(hidden_dim,hidden_dim,False))
        self.layers.append(nn.Linear(hidden_dim, output_dim, False ))

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = torch.tensor(x)
        x = x.to(device)
        for layer in self.layers:
            x = layer(x)
        return x

In [5]:
batchSize = 256
gamma = torch.tensor(0.99)
gamma.to(device)
epsilon = 1
EPS_END = 0.1
EPS_DECAY = 0.999
TARGET_UPDATE = 10

In [6]:
n_actions = env.action_space.n

policy_net = DQN(4, 16, 2, 2).to(device)
target_net = DQN(4, 16, 2, 2).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

DQN(
  (layers): ParameterList(
      (0): Object of type: Linear
      (1): Object of type: Linear
      (2): Object of type: Linear
    (0): Linear(in_features=4, out_features=16, bias=False)
    (1): Linear(in_features=16, out_features=16, bias=False)
    (2): Linear(in_features=16, out_features=2, bias=False)
  )
)

In [7]:
optimizer = optim.RMSprop(policy_net.parameters(), lr=0.0001)
memory = ReplayMemory(1000)

In [8]:
def policy(x, epsilon):
    r = random.random()
    if r < epsilon:
        return torch.tensor(random.choice([0,1]))
    else:
        return x.max(0)[-1]

In [9]:
observation = env.reset()
policy_net(observation)

tensor([0.0048, 0.0088], grad_fn=<SqueezeBackward3>)

In [10]:
def estimatePerformance():  
    counters = []
    epsilon = -1
    trials = 10
    for i in range(trials):
        counter = 0
        terminated = False
        observation = env.reset()
        while not terminated:
            action = policy_net(observation).max(0)[-1].item()
            observation, reward, terminated,  info = env.step(action)
            counter += reward
        counters.append(counter)

    summ =  0
    for i in range(trials):
        summ += counters[i]
    return summ/trials

In [11]:
from copy import deepcopy


def updateTargetNet():
    target_net.load_state_dict(deepcopy(policy_net.state_dict()))
    for param in target_net.parameters():
        param.requires_grad = False

In [12]:
counter = 0
while True:
    observation = env.reset()
    memory.memory = deque([])
    updateTargetNet()
    for i in range(1000):
        
        x = policy_net(observation)
        action = policy(x, epsilon)

        next_state, reward, terminated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)

        memory.push(observation, action, next_state, reward, terminated)
        observation = next_state

        if terminated:
            observation = env.reset()
    
    transitions = memory.sample(batchSize)

    lossSum = 0
    for transition in transitions:
        observation, action, next_state, reward, terminated = transition
        Q = policy_net(observation)[action]
        if not terminated:
            reward += gamma * torch.max(target_net(next_state))
        loss = (reward  - Q) **2 

        lossSum += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(lossSum/batchSize)
    counter += 1
    epsilon *= EPS_DECAY
    if counter == 1000:
        counter = 0
        print(estimatePerformance())
        


1.0365410828962922
1.0500424291240051
1.0308894489426166
1.0038711030501872
1.0141302738338709
1.031707827700302
1.0285460576415062
1.0302825577091426
1.011889803223312
1.0192569885402918
0.9971798376645893
1.0251846788451076
1.0428422240074724
1.0484467844944447
1.0768452645279467
1.0714365734020248
1.0640810339245945
1.0782502167858183
1.0566187442746013
1.0390002431813627
1.054258632240817
1.0528594081988558
1.0269658197648823
1.046956132631749
1.040708698797971
1.0343933681142516
1.0770933215972036
1.0846043804194778
1.0718664329033345
1.050376680213958
1.0406228852807544
1.0739874604623765
1.064968428981956
1.064720056252554
1.0066599887795746
1.0312593519920483
1.0776775390841067
1.0907018245106883
1.0420221531530842
1.0649891570792533
1.0525080480001634
1.0654351211924222
1.1235920674807858
1.0964448005252052
1.052441248379182
1.0596391810686328
1.0432926667563152
1.0793741340376073
1.060702912130182
1.1027054497226345
1.1237135650299024
1.0767034485470504
1.0671697163052158
1.0

KeyboardInterrupt: 

In [16]:
estimatePerformance()

265.7

In [None]:
(( expected_state_action_values - state_action_values) ** 2)[0]