In [0]:
import torch
import torch.nn as nn
import torch.optim as opt
import torch.nn.functional as F

import math

import gym

import random

import numpy as np

import collections
from collections import namedtuple

# Network shape: observation size -> hidden width -> number of actions (CartPole-v1 has 4 and 2).
IN_SIZE, HIDDEN, OUT_SIZE = 4, 64, 2
BATCH = 256             # minibatch size drawn from replay memory

# Optimisation
LR = 5e-4               # Adam learning rate
GAMMA = 0.95            # discount factor in the TD target

MEM_SIZE = 3000         # replay memory capacity; oldest transitions are evicted first

# Epsilon-greedy schedule: eps = EPS_END + (EPS_START - EPS_END) * exp(-x / EPS_DECAY)
EPS_START, EPS_END, EPS_DECAY = 0.9, 0.01, 200

EPOCH = 1000            # number of training episodes
TARGET_UPDATE = 10      # copy policy_net weights into target_net every this many episodes
NEG_REW = -1000         # penalty reward stored for a failing (episode-ending) transition

In [0]:
# One environment transition as stored in replay memory.
Replay = namedtuple("Replay", ["obs", "action", "next_obs", "rew"])


class Memory:
    """Fixed-capacity replay buffer; once full, appending evicts the oldest entry."""

    def __init__(self, size):
        # deque(iterable, maxlen): bounded FIFO storage
        self.saved = collections.deque([], size)

    def save(self, data):
        """Store one transition."""
        self.saved.append(data)

    def sample(self, size):
        """Return `size` distinct stored entries chosen uniformly at random.

        Returns None instead when fewer than `size` entries are stored.
        """
        enough = len(self.saved) >= size
        return random.sample(self.saved, size) if enough else None

    def __len__(self):
        return len(self.saved)

In [0]:
def get_action(env, net, obs, x, eps_start=None, eps_end=None, eps_decay=None):
    """Choose an action epsilon-greedily.

    Epsilon decays exponentially from ``eps_start`` toward ``eps_end`` with time
    scale ``eps_decay``, measured in units of ``x``. The training loop passes the
    episode index as ``x``.

    Args:
        env: environment whose ``action_space.sample()`` supplies the random action.
        net: Q-network; called on the observation tensor, its argmax is the greedy action.
        obs: current observation (array-like of floats).
        x: decay step for the epsilon schedule.
        eps_start, eps_end, eps_decay: schedule overrides; ``None`` means use the
            module constants EPS_START, EPS_END and EPS_DECAY.

    Returns:
        The chosen action as a Python ``int``. Both branches return the same type,
        so stored transitions are homogeneous.
    """
    # Resolve defaults at call time so callers that omit them keep the notebook-level settings.
    eps_start = EPS_START if eps_start is None else eps_start
    eps_end = EPS_END if eps_end is None else eps_end
    eps_decay = EPS_DECAY if eps_decay is None else eps_decay

    eps = eps_end + (eps_start - eps_end) * math.exp(-x / eps_decay)

    if random.random() < eps:
        return int(env.action_space.sample())

    with torch.no_grad():
        q_values = net(torch.as_tensor(obs, dtype=torch.float32))
    # argmax over the flattened output, first index on ties (same as numpy's argmax)
    return int(q_values.argmax().item())

In [0]:
class QNet(nn.Module):
    """Two-layer MLP: Linear(i, h) -> ReLU -> Linear(h, o).

    Used in this notebook as the Q-function: one output per action, selected via
    ``gather`` on the action index during training.
    """

    def __init__(self, i=IN_SIZE, h=HIDDEN, o=OUT_SIZE):
        super().__init__()
        layers = [nn.Linear(i, h), nn.ReLU(), nn.Linear(h, o)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)

env = gym.make("CartPole-v1")
# Raise the episode step cap (a private attribute, presumably of gym's TimeLimit wrapper)
# so well-trained agents are not cut off at the default limit.
env._max_episode_steps = 50000

# Online network (optimized every step) and target network (a periodically synced copy
# that supplies the bootstrap values). policy_net is constructed first, as before.
policy_net, target_net = QNet(), QNet()
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

criterion = nn.MSELoss()
optim = opt.Adam(policy_net.parameters(), lr=LR)
memory = Memory(MEM_SIZE)

In [0]:
def compute_td_targets(next_q_max, rewards, gamma, terminal_reward):
    """Bellman targets ``r + gamma * max_a' Q_target(s', a')`` that do not bootstrap past terminals.

    A transition counts as terminal when its reward equals ``terminal_reward``,
    because the training loop overwrites the reward of a failing step with that value.
    This assumes the environment never emits ``terminal_reward`` as a regular reward,
    which holds for CartPole's reward of 1 per step.

    Args:
        next_q_max: 1-D tensor of max target-network Q-values for the next observations.
        rewards: 1-D tensor of stored rewards, aligned with ``next_q_max``.
        gamma: discount factor.
        terminal_reward: reward value that marks a terminal transition.

    Returns:
        1-D tensor of TD targets, with the same shape as ``rewards``.
    """
    not_terminal = (rewards != terminal_reward).to(next_q_max.dtype)
    return rewards + gamma * next_q_max * not_terminal


def optimize_model():
    """Run one gradient step of policy_net on a random minibatch from ``memory``.

    Returns:
        The minibatch loss as a Python float, or None while memory holds fewer
        than BATCH transitions.
    """
    transitions = memory.sample(BATCH)
    if not transitions:
        return None
    batch = Replay(*zip(*transitions))

    # Stack with numpy first: building a tensor from a tuple of ndarrays is very slow.
    obs_t = torch.as_tensor(np.array(batch.obs), dtype=torch.float32)
    next_obs_t = torch.as_tensor(np.array(batch.next_obs), dtype=torch.float32)
    act_t = torch.as_tensor(np.array(batch.action), dtype=torch.long).reshape(-1, 1)
    rew_t = torch.as_tensor(np.array(batch.rew), dtype=torch.float32)

    q_pred = policy_net(obs_t).gather(1, act_t)
    with torch.no_grad():
        next_q_max = target_net(next_obs_t).max(1)[0]
    q_target = compute_td_targets(next_q_max, rew_t, GAMMA, NEG_REW)

    loss = criterion(q_pred, q_target.unsqueeze(1))
    optim.zero_grad()
    loss.backward()
    optim.step()
    # .item() detaches the value, so the autograd graph is not kept alive by the caller.
    return loss.item()


if __name__ == "__main__":
    for i in range(EPOCH):
        obs = env.reset()
        count = 0
        loss_sum = 0.0
        while True:
            act = get_action(env, policy_net, obs, i)
            next_obs, rew, done, info = env.step(act)

            # Penalize genuine failures only. A time-limit truncation is not a terminal
            # state, so it keeps its normal reward and still bootstraps.
            if done and not info.get("TimeLimit.truncated", False):
                rew = NEG_REW
            memory.save(Replay(obs, act, next_obs, rew))

            loss = optimize_model()
            if loss is not None:
                loss_sum += loss

            count += 1
            obs = next_obs
            if done:
                print("epoch %d, %d %.4f" % (i, count, loss_sum / count))
                break

        if i % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())

epoch 0, 17 0.0000
epoch 1, 18 0.0000
epoch 2, 14 0.0000
epoch 3, 51 0.0000
epoch 4, 13 0.0000
epoch 5, 20 0.0000
epoch 6, 26 0.0000
epoch 7, 23 0.0000
epoch 8, 16 0.0000
epoch 9, 16 0.0000
epoch 10, 14 0.0000
epoch 11, 15 0.0000
epoch 12, 16 11962.3604
epoch 13, 18 48816.2656
epoch 14, 52 45933.3828
epoch 15, 26 44247.6445
epoch 16, 18 42008.8281
epoch 17, 37 42003.0508
epoch 18, 76 38676.8633
epoch 19, 51 33913.4062
epoch 20, 49 34990.9180
epoch 21, 21 38442.7070
epoch 22, 26 33697.0469
epoch 23, 13 33213.0273
epoch 24, 20 33036.2656
epoch 25, 16 37341.9922
epoch 26, 12 32303.9707
epoch 27, 20 41058.7461
epoch 28, 28 38246.6133
epoch 29, 16 34987.1250
epoch 30, 21 37050.9180
epoch 31, 49 37313.2578
epoch 32, 57 37073.7578
epoch 33, 40 39163.0391
epoch 34, 39 33028.1406
epoch 35, 33 36691.3906
epoch 36, 45 35540.9219
epoch 37, 32 34135.6016
epoch 38, 23 32905.1289
epoch 39, 62 33673.9414
epoch 40, 62 31697.1914
epoch 41, 13 35517.0430
epoch 42, 22 36595.8594
epoch 43, 43 33362.9609
ep