In [0]:
import torch
import torch.nn as nn
import torch.optim as opt
import torch.nn.functional as F

import math

import gym

import random

import numpy as np

import collections
from collections import namedtuple

IN_SIZE = 4
HIDDEN = 64
OUT_SIZE = 2
BATCH = 256

LR = 0.0005
GAMMA = 0.95

MEM_SIZE = 3000

EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 200

EPOCH = 1000

TARGET_UPDATE = 10

NEG_REW = -1000

In [0]:
Replay = namedtuple("Replay", ("obs", "action", "next_obs", "rew"))

class Memory():
    def __init__(self, size):
        self.saved = collections.deque(maxlen=size)
        
    def save(self, data):
        self.saved.append(data)
    
    def sample(self, size):
        if len(self.saved) < size:
            return None
        return random.sample(self.saved, size)
    
    def __len__(self):
        return len(self.saved)

In [0]:
def get_action(env, net, obs, x):
    eps = EPS_END + (EPS_START-EPS_END) * math.exp(-1*x/EPS_DECAY)
    
    if random.random() < eps:
        return env.action_space.sample()
    else:
        with torch.no_grad():
            output = net(torch.FloatTensor(obs))
            return output.numpy().argmax()

In [0]:
class QNet(nn.Module):
    def __init__(self, i= IN_SIZE, h = HIDDEN, o = OUT_SIZE):
        super(QNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(i,h),
            nn.ReLU(),
            nn.Linear(h,o)
        )
        
    def forward(self, x):
        return self.net(x)

env = gym.make("CartPole-v1")
env._max_episode_steps = 5000

policy_net = QNet()
target_net = QNet()
target_net.eval()
target_net.load_state_dict(policy_net.state_dict())
criterion = nn.MSELoss()
optim = opt.Adam(policy_net.parameters(), lr = LR)
memory = Memory(MEM_SIZE)

In [0]:
if __name__ == "__main__":
    for i in range(EPOCH):
        obs = env.reset()
        count = 0
        loss_sum = 0
        while True:
            act = get_action(env, policy_net, obs, i)
            next_obs, rew, done, _ = env.step(act)
            
            if done:
                rew = NEG_REW
            memory.save(Replay(obs, act, next_obs, rew))
            
            sample = memory.sample(BATCH)
            if sample:
                sample = Replay(*zip(*sample))
                act_ = torch.LongTensor(sample.action).reshape(-1,1)
                Qpred = policy_net(torch.FloatTensor(sample.obs)).gather(1, act_)
                Qnext = target_net(torch.FloatTensor(sample.next_obs)).max(1)[0].detach()
                
                Q = Qnext * GAMMA + torch.FloatTensor(sample.rew)
                
                #loss = F.smooth_l1_loss(Qpred, Q)
                loss = criterion(Qpred, Q.unsqueeze(1))

                optim.zero_grad()
                loss.backward()
                optim.step()
                
                loss_sum+=loss
            
            count += 1
            obs = next_obs
            if done:
                print("epoch %d, %d %.4f"%(i, count, loss_sum/count))
                break
            
        if i%TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())

epoch 0, 15 0.0000
epoch 1, 23 0.0000
epoch 2, 42 0.0000
epoch 3, 19 0.0000
epoch 4, 12 0.0000
epoch 5, 31 0.0000
epoch 6, 45 0.0000
epoch 7, 11 0.0000
epoch 8, 20 0.0000
epoch 9, 10 0.0000
epoch 10, 61 20982.2363
epoch 11, 34 35791.7422
epoch 12, 11 39694.9102
epoch 13, 18 39402.1406
epoch 14, 28 39212.9805
epoch 15, 39 38773.6406
epoch 16, 31 37314.0781
epoch 17, 11 39128.4531
epoch 18, 38 37796.7461
epoch 19, 29 37046.2656
epoch 20, 15 37556.4375
epoch 21, 32 36887.5391
epoch 22, 21 36285.8867
epoch 23, 18 38825.0547
epoch 24, 21 39314.9141
epoch 25, 48 37761.4727
epoch 26, 16 37574.5742
epoch 27, 40 37302.5234
epoch 28, 39 36624.1523
epoch 29, 16 30846.3379
epoch 30, 39 33870.9883
epoch 31, 22 37978.4648
epoch 32, 44 34125.8047
epoch 33, 54 36268.5391
epoch 34, 20 33727.2109
epoch 35, 78 33932.4180
epoch 36, 32 31299.1934
epoch 37, 28 30820.0781
epoch 38, 15 28271.1211
epoch 39, 46 31965.5957
epoch 40, 54 30842.0020
epoch 41, 16 31243.0352
epoch 42, 38 31384.6250
epoch 43, 72 33242