In [1]:
import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

import random
import math

import collections
from collections import namedtuple

In [2]:
stepInfo = namedtuple('StepInfo',('cur_state', 'action','next_state', 'reward',))

class ReplayList():
    def __init__(self, size):
        self.memory = collections.deque(maxlen = size)
        
    def append(self, data):
        type(self.memory)
        self.memory.append(data)
        
    def sample(self, size):        
        if len(self.memory) < size:
            return None
        else:
            return random.sample(self.memory, size)
    
    def __len__(self):
        return len(self.memory)
    
    def __repr__(self):
        return print(self.memory)

In [3]:
class QNet(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super(QNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.ReLU(),
#             nn.Linear(hidden_size, hidden_size),
#             nn.ReLU(),
            nn.Linear(hidden_size, out_size)
        )
        
    def forward(self,x):
        return self.net(x)

In [4]:
def getOneHotTensor(pos, size):
    ret = np.zeros((1,size))
    ret[0][pos] = 1
    return torch.FloatTensor(ret)#.cuda()

In [5]:
HIDDEN_SIZE = 32
LR = 0.0002
EPOCH = 2000

MEM_SIZE = 3000

GAMMA = 0.95
EPS_START = .9
EPS_END = 0.005
EPS_DECAY = 200

BATCH = 128
TARGET_UPDATE = 10

env = gym.make("FrozenLake-v0")

policyNet = QNet(env.observation_space.n, HIDDEN_SIZE, env.action_space.n)#.cuda()
targetNet = QNet(env.observation_space.n, HIDDEN_SIZE, env.action_space.n)#.cuda()
targetNet.load_state_dict(policyNet.state_dict())
targetNet.eval()

memory = ReplayList(MEM_SIZE)

done_size = 0

optimizer = optim.Adam(policyNet.parameters(), lr = LR)
loss = nn.MSELoss(reduction="sum")

In [6]:
if __name__ == "__main__":
    obs = env.reset()
    loss_list = []
    step_count = 0
    cur_target = 0
    target_list = []
    for i in range(EPOCH):
        if i%TARGET_UPDATE == 0:
            if cur_target/TARGET_UPDATE >= 0.5:
                target_list.append((i,cur_target,targetNet.state_dict()))
            else:
                targetNet.load_state_dict(policyNet.state_dict())
            cur_target = 0
            
            print(i)
        
        while True:
            eps = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * step_count / EPS_DECAY)
            if random.random()< eps:
                act = env.action_space.sample()
            else:
                with torch.no_grad():
                    Q = policyNet(getOneHotTensor(obs, env.observation_space.n))
                    act = Q.data.numpy().argmax()
            
            next_obs, rew, done, _ = env.step(act)
            step_count += 1
            if done:
                rew = rew*4 - 2
            info = stepInfo(obs, act, next_obs, rew)
            
            memory.append(info)
            replay = memory.sample(BATCH)
            
            if replay:
                replay = stepInfo(*zip(*replay))
                
                states = torch.cat(list(map(lambda s:getOneHotTensor(s,env.observation_space.n),replay.cur_state)))
                next_states = torch.cat(list(map(lambda s:getOneHotTensor(s,env.observation_space.n),replay.next_state)))
                actions = torch.cat(list(map(lambda s:getOneHotTensor(s,env.action_space.n),replay.action)))
                rewards = torch.FloatTensor(replay.reward).view(-1,1)#.cuda()
                
                optimizer.zero_grad()
                Qpred = policyNet(states)
                Qtarget = (actions * rewards + GAMMA*targetNet(next_states)).detach()
                
                loss_ = F.smooth_l1_loss(Qpred, Qtarget)
                loss_.backward()
                loss_list.append(loss_.data)
                optimizer.step()
            
            obs = next_obs
            if done:
                if rew > 0:
                    done_size += 1
                    cur_target += 1
                    print("%d, steps = %d, loss = %.4f"%(i,step_count, loss_.data), "Cleared")
                obs = env.reset()
                break       

0
10
20
30
39, steps = 617, loss = 0.0273 Cleared
40
50
60
70
80
90
100
110
120
126, steps = 4618, loss = 0.0035 Cleared
130
140
146, steps = 4733, loss = 0.0125 Cleared
150
160
170
180
190
200
210
213, steps = 5347, loss = 0.0137 Cleared
216, steps = 5443, loss = 0.0225 Cleared
220
222, steps = 5648, loss = 0.0211 Cleared
224, steps = 5689, loss = 0.0231 Cleared
230
240
240, steps = 6271, loss = 0.0240 Cleared
248, steps = 6805, loss = 0.0325 Cleared
249, steps = 6842, loss = 0.0257 Cleared
250
252, steps = 7005, loss = 0.0148 Cleared
257, steps = 7222, loss = 0.0257 Cleared
259, steps = 7328, loss = 0.0242 Cleared
260
261, steps = 7439, loss = 0.0151 Cleared
266, steps = 7700, loss = 0.0156 Cleared
268, steps = 7843, loss = 0.0176 Cleared
270
276, steps = 7964, loss = 0.0090 Cleared
280
280, steps = 8115, loss = 0.0089 Cleared
290
300
310
317, steps = 10413, loss = 0.0004 Cleared
320
324, steps = 10870, loss = 0.0032 Cleared
330
338, steps = 11208, loss = 0.0064 Cleared
340
350
355, 

1840, steps = 45781, loss = 0.0118 Cleared
1847, steps = 45865, loss = 0.0126 Cleared
1850
1859, steps = 46025, loss = 0.0135 Cleared
1860
1870
1880
1890
1900
1908, steps = 46348, loss = 0.0105 Cleared
1910
1914, steps = 46408, loss = 0.0215 Cleared
1920
1924, steps = 46484, loss = 0.0221 Cleared
1930
1940
1950
1950, steps = 46795, loss = 0.0153 Cleared
1951, steps = 46809, loss = 0.0327 Cleared
1952, steps = 46820, loss = 0.0206 Cleared
1954, steps = 46839, loss = 0.0124 Cleared
1955, steps = 46853, loss = 0.0205 Cleared
1960
1970
1980
1990


In [7]:
for i in range(16):
    with torch.no_grad():
        print(policyNet(getOneHotTensor(i,16)).max(1)[1])

tensor([3])
tensor([1])
tensor([3])
tensor([3])
tensor([0])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([2])
tensor([2])
tensor([3])


In [8]:
for i in range(16):
    with torch.no_grad():
        print(targetNet(getOneHotTensor(i,16)).max(1)[1])

tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([0])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([3])
tensor([2])
tensor([2])
tensor([3])


In [9]:
done_size

185