In [1]:
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import random
import math

import collections
from collections import namedtuple

In [2]:
stepInfo = namedtuple('StepInfo',('cur_state', 'action','next_state', 'reward',))

class ReplayList():
    def __init__(self, size):
        self.memory = collections.deque(maxlen = size)
        
    def append(self, data):
        type(self.memory)
        self.memory.append(data)
        
    def sample(self, size):        
        if len(self.memory) < size:
            return None
        else:
            return random.sample(self.memory, size)
    
    def __len__(self):
        return len(self.memory)
    
    def __repr__(self):
        return print(self.memory)

In [3]:
class QNet(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super(QNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.ReLU(),
#             nn.Linear(hidden_size, hidden_size),
#             nn.ReLU(),
            nn.Linear(hidden_size, out_size)
        )
        
    def forward(self,x):
        return self.net(x)

In [4]:
def getOneHotTensor(pos, size):
    ret = np.zeros((1,size))
    ret[0][pos] = 1
    return torch.FloatTensor(ret)#.cuda()

In [5]:
HIDDEN_SIZE = 128
LR = 0.0002
EPOCH = 2000

MEM_SIZE = 3000

GAMMA = 0.95
EPS_START = .9
EPS_END = 0.005
EPS_DECAY = 200

BATCH = 128
TARGET_UPDATE = 10

env = gym.make("FrozenLake-v0")

policyNet = QNet(env.observation_space.n, HIDDEN_SIZE, env.action_space.n)#.cuda()
targetNet = QNet(env.observation_space.n, HIDDEN_SIZE, env.action_space.n)#.cuda()
targetNet.load_state_dict(policyNet.state_dict())
targetNet.eval()

memory = ReplayList(MEM_SIZE)

done_size = 0

optimizer = optim.Adam(policyNet.parameters(), lr = LR)
loss = nn.MSELoss(reduction="sum")

In [6]:
if __name__ == "__main__":
    obs = env.reset()
    loss_list = []
    step_count = 0
    cur_target = 0
    target_list = []
    for i in range(EPOCH):
        if i%TARGET_UPDATE == 0:
            if cur_target/TARGET_UPDATE >= 0.5:
                target_list.append((i,cur_target,targetNet.state_dict()))
            else:
                targetNet.load_state_dict(policyNet.state_dict())
            cur_target = 0
            
            print(i)
        
        while True:
            eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * step_count / EPS_DECAY)
            if random.random()< eps:
                act = env.action_space.sample()
            else:
                with torch.no_grad():
                    Q = policyNet(getOneHotTensor(obs, env.observation_space.n))
                    act = Q.data.numpy().argmax()
            
            next_obs, rew, done, _ = env.step(act)
            step_count += 1
            if done:
                rew = rew*4 - 2
            info = stepInfo(obs, act, next_obs, rew)
            
            memory.append(info)
            replay = memory.sample(BATCH)
            
            if replay:
                replay = stepInfo(*zip(*replay))
                
                states = torch.cat(list(map(lambda s:getOneHotTensor(s,env.observation_space.n),replay.cur_state)))
                next_states = torch.cat(list(map(lambda s:getOneHotTensor(s,env.observation_space.n),replay.next_state)))
                actions = torch.cat(list(map(lambda s:getOneHotTensor(s,env.action_space.n),replay.action)))
                rewards = torch.FloatTensor(replay.reward).view(-1,1)#.cuda()
                
                optimizer.zero_grad()
                Qpred = policyNet(states)
                Qtarget = (actions * rewards + GAMMA*targetNet(next_states)).detach()
                
                loss_ = F.smooth_l1_loss(Qpred, Qtarget)
                loss_.backward()
                loss_list.append(loss_.data)
                optimizer.step()
            
            obs = next_obs
            if done:
                if rew > 0:
                    done_size += 1
                    cur_target += 1
                    print("%d, steps = %d, loss = %.4f"%(i,step_count, loss_.data), "Cleared")
                obs = env.reset()
                break       

0
10
20
30
40
50
60
64, steps = 668, loss = 0.0301 Cleared
69, steps = 850, loss = 0.0409 Cleared
70
80
90
92, steps = 2672, loss = 0.0063 Cleared
100
110
120
121, steps = 4192, loss = 0.0086 Cleared
130
140
146, steps = 4483, loss = 0.0142 Cleared
147, steps = 4490, loss = 0.0154 Cleared
150
156, steps = 4558, loss = 0.0028 Cleared
157, steps = 4565, loss = 0.0032 Cleared
160
165, steps = 4638, loss = 0.0119 Cleared
169, steps = 4675, loss = 0.0031 Cleared
170
180
190
195, steps = 4812, loss = 0.0167 Cleared
200
201, steps = 4905, loss = 0.0058 Cleared
210
216, steps = 5009, loss = 0.0281 Cleared
220
225, steps = 5139, loss = 0.0171 Cleared
230
230, steps = 5268, loss = 0.0132 Cleared
234, steps = 5330, loss = 0.0293 Cleared
240
250
252, steps = 5628, loss = 0.0146 Cleared
260
270
271, steps = 6119, loss = 0.0268 Cleared
274, steps = 6193, loss = 0.0189 Cleared
280
290
294, steps = 6700, loss = 0.0388 Cleared
296, steps = 6821, loss = 0.0332 Cleared
298, steps = 6882, loss = 0.0242 Cl

1400, steps = 29178, loss = 0.0178 Cleared
1410
1410, steps = 29626, loss = 0.0225 Cleared
1420
1430
1440
1450
1457, steps = 30938, loss = 0.0123 Cleared
1460
1470
1470, steps = 31098, loss = 0.0171 Cleared
1474, steps = 31170, loss = 0.0073 Cleared
1477, steps = 31253, loss = 0.0183 Cleared
1478, steps = 31269, loss = 0.0241 Cleared
1479, steps = 31281, loss = 0.0187 Cleared
1480
1482, steps = 31369, loss = 0.0174 Cleared
1490
1500
1503, steps = 31755, loss = 0.0172 Cleared
1506, steps = 31931, loss = 0.0128 Cleared
1507, steps = 31970, loss = 0.0070 Cleared
1509, steps = 32004, loss = 0.0177 Cleared
1510
1520
1521, steps = 32220, loss = 0.0142 Cleared
1530
1535, steps = 32448, loss = 0.0237 Cleared
1539, steps = 32511, loss = 0.0145 Cleared
1540
1546, steps = 32618, loss = 0.0198 Cleared
1550
1552, steps = 32674, loss = 0.0139 Cleared
1554, steps = 32699, loss = 0.0200 Cleared
1560
1570
1570, steps = 32917, loss = 0.0170 Cleared
1580
1590
1600
1601, steps = 33292, loss = 0.0210 Clear

In [7]:
for i in range(16):
    with torch.no_grad():
        print(policyNet(getOneHotTensor(i,16)).max(1)[1])

tensor([0])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([1])
tensor([3])
tensor([3])
tensor([2])
tensor([2])
tensor([0])
tensor([2])
tensor([2])
tensor([3])


In [8]:
for i in range(16):
    with torch.no_grad():
        print(targetNet(getOneHotTensor(i,16)).max(1)[1])

tensor([0])
tensor([3])
tensor([3])
tensor([3])
tensor([0])
tensor([3])
tensor([3])
tensor([1])
tensor([3])
tensor([2])
tensor([2])
tensor([2])
tensor([0])
tensor([2])
tensor([2])
tensor([3])


In [9]:
done_size

270