In [1]:
import gym, random
import time
from gym.envs.registration import register
register(
    id='FrozenLakeNotSlippery4x4-v0',
    entry_point='gym.envs.toy_text:FrozenLakeEnv',
    kwargs={'map_name' : '4x4', 'is_slippery': False},
)
register(
    id='FrozenLakeNotSlippery8x8-v0',
    entry_point='gym.envs.toy_text:FrozenLakeEnv',
    kwargs={'map_name' : '8x8', 'is_slippery': False},
)

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random
from itertools import combinations

In [26]:
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor

In [31]:
class SNN(nn.Module):    # Shallow neural network
  def __init__(self, input_size, output_size):
    super(SNN, self).__init__()
    self.l1_linear = nn.Linear(input_size, output_size, bias = False)

  def forward(self,x):
    out = F.sigmoid(self.l1_linear(x))
    return out
    

In [32]:
class TransitionHistory():
  def __init__(self, max_count = 1000):
    self.transitions = []
    self.max_count = max_count
    self.loc_pointer = 0
  
  def clear(self):
    self.transitions = []
    self.loc_pointer = 0
  
  def add_transition(self, s0, a0, r, s1):
    if len(self.transitions) <= self.loc_pointer:
      self.transitions.append(None)
    self.transitions[self.loc_pointer] = (s0, a0, r, s1)
    self.loc_pointer += 1
    if self.loc_pointer >= self.max_count:
      self.loc_pointer %= self.max_count
  
  def get_sample_transition(self, batch_size):
    return random.sample(self.transitions, batch_size)

In [33]:
class SQN(): # Shallow q network
  _gamma = 0.8
  _lambda = 0.5
  _epsilon = 0.2
  transition_history = TransitionHistory()

  def __init__(self, size):
    self.Q = SNN(size, 4)
    if use_cuda:
      self.Q.cuda()
    self.size = size
    self.optimizer = torch.optim.Adagrad(self.Q.parameters(),weight_decay=1e-5)
#     self.optimizer = torch.optim.SGD(self.Q.parameters(),lr=2.0)
    self.loss_fn = nn.MSELoss()

  def index_to_onehot(self, index, total=16):
    onehot = [0] * total
    onehot[index] = 1
    return onehot
  
  def epsilon(self):
    return self._epsilon

  def pick_action(self, state):
    if random.random() < self.epsilon():
      return random.randrange(0,4,1)
    else:
      self.predict(state)
      onehot = self.index_to_onehot(state)
      s = Variable(FloatTensor([onehot]))
      action_value = self.Q(s).data.tolist()
      return action_value.index(max(action_value))

  def predict(self, state):
    onehot = self.index_to_onehot(state)
    s = Variable(FloatTensor([onehot]))
    action_value = self.Q(s).data.tolist()[0]
    return action_value

  def update_Q(self, batch):
#     print(batch)
    (state, action, reward, next_state) = tuple(zip(*batch))
    non_final_mask = torch.ByteTensor(tuple(map(lambda s:s is not None, next_state)))
    non_final_next_states = Variable(FloatTensor([self.index_to_onehot(s) for s in next_state if s is not None]), volatile=True)

    state_batch = Variable(FloatTensor([self.index_to_onehot(s) for s in state]))
    action_batch = Variable(LongTensor([a for a in action]))
    reward_batch = Variable(FloatTensor(reward))
#     print('state_batch:',state_batch)
#     print('self.Q(state_batch):',self.Q(state_batch))
#     print('action_batch:',action_batch)
    state_action_values = self.Q(state_batch).gather(1, action_batch.view(-1,1))
    
    next_state_values = Variable(torch.zeros(len(batch)).type(FloatTensor))
    next_state_values[non_final_mask] = self.Q(non_final_next_states).max(1)[0]
    
    expected_state_action_values = (next_state_values * self._gamma) + reward_batch
    expected_state_action_values = Variable(expected_state_action_values.view(-1,1).data)
    
#     print('state_action_values',state_action_values)
#     print('expected_state_action_values',expected_state_action_values)
    loss = self.loss_fn(state_action_values, expected_state_action_values)
#     print('loss', loss.data[0])
    self.optimizer.zero_grad()
    loss.backward()
    
    # gradient clipping
    for param in self.Q.parameters():
        param.grad.data.clamp_(-1, 1)

    self.optimizer.step()


  def show_policy(self):
    # greedy policy
# SFFF
# FHFH
# FFFH
# HFFG
    action_name_map = ['left','down','right','up','none']
    for i in range(4):
      policy = ""
      for j in range(4):
        if (i * 4 + j) in [5,7,11,12,15]:
          policy += "%2d: %-10s" %(state, 'None')
          continue
        state = i*4+j
        action_value = self.predict(state)
        action_index = action_value.index(max(action_value))
        policy += "%2d: %-10s" %(state, action_name_map[action_index])
      print(policy)

  def train(self, env, episode, batch_per_episode = 20, batch_size = 32):
    for i in range(episode):
      s0 = env.reset()
      a0 = self.pick_action(s0)
      episode_ended = False
      while not episode_ended:
        (s1, reward, episode_ended, info) = env.step(a0)
        if reward > 0.0:
          print('reward:', reward)
          
        if not episode_ended:
          if reward <= 0.0:
            reward = -0.1
          a1 = self.pick_action(s1)
        else:
          if reward <= 0.0:
            reward = -1.0
          a1 = None
          s1 = None
        self.transition_history.add_transition(s0,a0,reward,s1)
        s0 = s1
        a0 = a1
      for i in range(batch_per_episode):
        batch = self.transition_history.get_sample_transition(batch_size)
        self.update_Q(batch)

In [34]:
env = gym.make('FrozenLakeNotSlippery4x4-v0')
agent = SQN(16)

In [36]:
# print(agent.Q.state_dict())
# print('='*50)
agent.show_policy()
for i in range(100):
  agent.train(env,500,20)
  if (i + 1) % 10 == 0:
    print('i:',i)
    agent.show_policy()
#     print(agent.Q.state_dict())
#     print('='*50)

 0: right      1: right      2: down       3: right     
 4: down       4: None       6: up         6: None      
 8: right      9: right     10: down      10: None      
10: None      13: right     14: right     14: None      
i: 9
 0: right      1: right      2: down       3: right     
 4: down       4: None       6: down       6: None      
 8: right      9: right     10: down      10: None      
10: None      13: right     14: right     14: None      
i: 19
 0: right      1: right      2: down       3: right     
 4: down       4: None       6: up         6: None      
 8: right      9: right     10: down      10: None      
10: None      13: right     14: right     14: None      
i: 29
 0: right      1: right      2: right      3: right     
 4: down       4: None       6: up         6: None      
 8: right      9: right     10: down      10: None      
10: None      13: right     14: right     14: None      
i: 39
 0: right      1: right      2: right      3: right     
 4: down