In [1]:
import torch
import torch.nn as nn
import torch.optim as opt
import torch.nn.functional as F

import math

import gym

import random

import numpy as np

import collections
from collections import namedtuple

IN_SIZE = 4
HIDDEN = 64
OUT_SIZE = 2
BATCH = 64

LR = 0.0005
GAMMA = 0.9

MEM_SIZE = 1000

EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 200

EPOCH = 5000

TARGET_UPDATE = 5

NEG_REW = -1000

In [2]:
Replay = namedtuple("Replay", ("obs", "action", "next_obs", "rew"))

class Memory():
    def __init__(self, size):
        self.saved = collections.deque(maxlen=size)
        
    def save(self, data):
        self.saved.append(data)
    
    def sample(self, size):
        if len(self.saved) < size:
            return None
        return random.sample(self.saved, size)
    
    def __len__(self):
        return len(self.saved)

In [3]:
def get_action(env, net, obs, x):
    eps = EPS_END + (EPS_START-EPS_END) * math.exp(-1*x/EPS_DECAY)
    
    if random.random() < eps:
        return env.action_space.sample()
    else:
        with torch.no_grad():
            output = net(torch.FloatTensor(obs))
            return output.numpy().argmax()

In [4]:
class QNet(nn.Module):
    def __init__(self, i= IN_SIZE, h = HIDDEN, o = OUT_SIZE):
        super(QNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(i,h),
            nn.ReLU(),
            nn.Linear(h,o)
        )
        
    def forward(self, x):
        return self.net(x)

env = gym.make("CartPole-v1")
    
policy_net = QNet()
target_net = QNet()
target_net.eval()
target_net.load_state_dict(policy_net.state_dict())
#criterion = nn.MSELoss()
optim = opt.Adam(policy_net.parameters(), lr = LR)
memory = Memory(MEM_SIZE)

In [None]:
if __name__ == "__main__":
    for i in range(EPOCH):
        obs = env.reset()
        count = 0
        loss_sum = 0
        while True:
            act = get_action(env, policy_net, obs, i)
            next_obs, rew, done, _ = env.step(act)
            
            if done:
                rew = NEG_REW
            memory.save(Replay(obs, act, next_obs, rew))
            
            sample = memory.sample(BATCH)
            if sample:
                sample = Replay(*zip(*sample))
                act_ = torch.LongTensor(sample.action).reshape(-1,1)
                Qpred = policy_net(torch.FloatTensor(sample.obs)).gather(1, act_)
                Qnext = target_net(torch.FloatTensor(sample.next_obs)).max(1)[0].detach()
                
                Q = Qnext * GAMMA + torch.FloatTensor(sample.rew)
                
                loss = F.smooth_l1_loss(Qpred, Q)
                optim.zero_grad()
                loss.backward()
                optim.step()
                
                loss_sum+=loss
            
            count += 1
            obs = next_obs
            if done:
                print("epoch %d, %d %.4f"%(i, count, loss_sum/count))
                break
            
        if i%TARGET_UPDATE == 0:
            print("!")
            target_net.load_state_dict(policy_net.state_dict())

epoch 0, 16 0.0000
!
epoch 1, 27 0.0000
epoch 2, 13 0.0000




epoch 3, 22 30.8944
epoch 4, 26 43.6862
epoch 5, 13 44.8072
!
epoch 6, 25 44.8379
epoch 7, 23 38.4017
epoch 8, 17 46.2282
epoch 9, 28 47.0809
epoch 10, 21 51.4917
!
epoch 11, 12 53.8229
epoch 12, 16 48.2204
epoch 13, 13 42.3655
epoch 14, 15 58.5683
epoch 15, 14 49.3104
!
epoch 16, 14 50.6348
epoch 17, 10 64.3915
epoch 18, 20 51.0601
epoch 19, 9 64.4626
epoch 20, 12 57.5154
!
epoch 21, 11 51.5423
epoch 22, 17 57.3137
epoch 23, 30 63.7796
epoch 24, 22 54.8643
epoch 25, 10 48.5864
!
epoch 26, 31 53.1747
epoch 27, 16 41.1626
epoch 28, 17 53.4150
epoch 29, 14 50.3158
epoch 30, 14 55.8776
!
epoch 31, 23 60.6416
epoch 32, 19 60.1247
epoch 33, 25 60.6907
epoch 34, 19 50.2067
epoch 35, 12 43.0069
!
epoch 36, 33 56.4482
epoch 37, 22 48.3389
epoch 38, 15 62.5345
epoch 39, 15 55.2420
epoch 40, 12 45.6023
!
epoch 41, 24 47.6206
epoch 42, 14 51.3791
epoch 43, 12 54.7172
epoch 44, 19 49.3688
epoch 45, 26 57.1192
!
epoch 46, 40 54.3534
epoch 47, 33 53.5291
epoch 48, 22 63.2381
epoch 49, 13 52.9100
epo

epoch 374, 10 62.5284
epoch 375, 13 46.8975
!
epoch 376, 31 52.9487
epoch 377, 9 55.5821
epoch 378, 19 53.4787
epoch 379, 14 66.9957
epoch 380, 11 48.3180
!
epoch 381, 44 55.0675
epoch 382, 20 59.4012
epoch 383, 23 57.0909
epoch 384, 27 52.6859
epoch 385, 20 56.2753
!
epoch 386, 22 46.8962
epoch 387, 27 47.4755
epoch 388, 38 46.8962
epoch 389, 18 46.8962
epoch 390, 10 59.4021
!
epoch 391, 9 59.0537
epoch 392, 16 60.5737
epoch 393, 11 59.6858
epoch 394, 16 47.8731
epoch 395, 9 55.5802
!
epoch 396, 14 53.5961
epoch 397, 14 41.3144
epoch 398, 28 54.1545
epoch 399, 24 53.4094
epoch 400, 21 58.0617
!
epoch 401, 26 52.9087
epoch 402, 51 51.1874
epoch 403, 9 60.7910
epoch 404, 17 34.0230
epoch 405, 26 48.7003
!
epoch 406, 19 55.9464
epoch 407, 20 50.0229
epoch 408, 26 45.0930
epoch 409, 38 46.8966
epoch 410, 23 47.5762
!
epoch 411, 14 48.0128
epoch 412, 17 48.7350
epoch 413, 17 46.8965
epoch 414, 35 50.9160
epoch 415, 26 45.6940
!
epoch 416, 10 50.0224
epoch 417, 37 53.6557
epoch 418, 23 46.2

epoch 741, 20 64.8726
epoch 742, 24 54.0608
epoch 743, 14 66.9942
epoch 744, 11 61.1067
epoch 745, 39 60.9248
!
epoch 746, 13 52.9095
epoch 747, 9 57.3181
epoch 748, 17 60.6899
epoch 749, 28 65.3200


In [None]:
memory.sample(10)