In [0]:
import torch
import torch.nn as nn
import torch.optim as opt
import torch.nn.functional as F

import math

import gym

import random

import numpy as np

import collections
from collections import namedtuple

IN_SIZE = 4
HIDDEN = 64
OUT_SIZE = 2
BATCH = 128

LR = 0.01
GAMMA = 0.9

MEM_SIZE = 1000

EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 200

EPOCH = 500

TARGET_UPDATE = 10

NEG_REW = -10

In [0]:
Replay = namedtuple("Replay", ("obs", "action", "next_obs", "rew"))

class Memory():
    def __init__(self, size):
        self.saved = collections.deque(maxlen=size)
        
    def save(self, data):
        self.saved.append(data)
    
    def sample(self, size):
        if len(self.saved) < size:
            return None
        return random.sample(self.saved, size)
    
    def __len__(self):
        return len(self.saved)

In [0]:
def get_action(env, net, obs, x):
    eps = EPS_END + (EPS_START-EPS_END) * math.exp(-1*x/EPS_DECAY)
    
    if random.random() < eps:
        return env.action_space.sample()
    else:
        with torch.no_grad():
            output = net(torch.FloatTensor(obs))
            return output.numpy().argmax()

In [0]:
class QNet(nn.Module):
    def __init__(self, i= IN_SIZE, h = HIDDEN, o = OUT_SIZE):
        super(QNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(i,h),
            nn.ReLU(),
            nn.Linear(h,o)
        )
        
    def forward(self, x):
        return self.net(x)

env = gym.make("CartPole-v1")
    
policy_net = QNet()
target_net = QNet()
target_net.eval()
target_net.load_state_dict(policy_net.state_dict())
criterion = nn.MSELoss()
optim = opt.Adam(policy_net.parameters(), lr = LR)
memory = Memory(MEM_SIZE)

In [141]:
if __name__ == "__main__":
    for i in range(EPOCH):
        obs = env.reset()
        count = 0
        loss_sum = 0
        while True:
            act = get_action(env, policy_net, obs, i)
            next_obs, rew, done, _ = env.step(act)
            
            if done:
                rew = NEG_REW
            memory.save(Replay(obs, act, next_obs, rew))
            
            sample = memory.sample(BATCH)
            if sample:
                sample = Replay(*zip(*sample))
                act_ = torch.LongTensor(sample.action).reshape(-1,1)
                Qpred = policy_net(torch.FloatTensor(sample.obs)).gather(1, act_)
                Qnext = target_net(torch.FloatTensor(sample.next_obs)).max(1)[0].detach()
                
                Q = Qnext * GAMMA + torch.FloatTensor(sample.rew)
                
                #loss = F.smooth_l1_loss(Qpred, Q)
                loss = criterion(Qpred, Q)

                optim.zero_grad()
                loss.backward()
                optim.step()

                if i<10 or i>=490:
                  with torch.no_grad():
                    print("bf",loss)
                    Qpred_ = policy_net(torch.FloatTensor(sample.obs)).gather(1, act_)
                    Qnext_ = target_net(torch.FloatTensor(sample.next_obs)).max(1)[0].detach()
                    
                    Q_ = Qnext_ * GAMMA + torch.FloatTensor(sample.rew)
                    
                    loss_ = criterion(Qpred_, Q_)
                    print("af", loss_)
                    print(Qpred_, Q_)
                
                loss_sum+=loss
            
            count += 1
            obs = next_obs
            if done:
                print("epoch %d, %d %.4f"%(i, count, loss_sum/count))
                break
            
        if i%TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())

epoch 0, 12 0.0000
epoch 1, 13 0.0000
epoch 2, 25 0.0000
epoch 3, 30 0.0000
epoch 4, 30 0.0000
epoch 5, 15 0.0000
bf tensor(5.7154, grad_fn=<MseLossBackward>)
af tensor(5.6002)
tensor([[ 0.2433],
        [ 0.1609],
        [ 0.1745],
        [ 0.1436],
        [-0.0025],
        [ 0.1291],
        [ 0.1413],
        [ 0.1869],
        [ 0.1480],
        [ 0.1985],
        [ 0.2510],
        [ 0.1835],
        [ 0.1794],
        [ 0.2474],
        [ 0.1590],
        [ 0.1563],
        [ 0.3161],
        [ 0.1560],
        [ 0.0188],
        [ 0.1661],
        [ 0.1714],
        [ 0.1236],
        [ 0.1001],
        [ 0.3575],
        [ 0.1646],
        [ 0.1441],
        [ 0.1750],
        [ 0.1483],
        [ 0.1891],
        [ 0.2367],
        [ 0.2893],
        [ 0.1551],
        [ 0.2173],
        [ 0.4481],
        [ 0.1302],
        [ 0.2019],
        [ 0.1537],
        [ 0.2005],
        [ 0.3670],
        [ 0.3437],
        [ 0.0998],
        [ 0.1811],
        [ 0.2217],
      

  return F.mse_loss(input, target, reduction=self.reduction)


af tensor(6.2812)
tensor([[0.5623],
        [0.5525],
        [0.4808],
        [0.5999],
        [0.5693],
        [0.5830],
        [0.5556],
        [0.5612],
        [0.5148],
        [0.5383],
        [0.4795],
        [0.5841],
        [0.5711],
        [0.5652],
        [0.5440],
        [0.5687],
        [0.4888],
        [0.4889],
        [0.5427],
        [0.5766],
        [0.5684],
        [0.5290],
        [0.5660],
        [0.4802],
        [0.4949],
        [0.5574],
        [0.5631],
        [0.5496],
        [0.5437],
        [0.5537],
        [0.4815],
        [0.5024],
        [0.5396],
        [0.5257],
        [0.5943],
        [0.5890],
        [0.5441],
        [0.5769],
        [0.5714],
        [0.5187],
        [0.5415],
        [0.5414],
        [0.5879],
        [0.5811],
        [0.5541],
        [0.5453],
        [0.5658],
        [0.4976],
        [0.4954],
        [0.4788],
        [0.5804],
        [0.4791],
        [0.5334],
        [0.5422],
        [0