In [1]:
import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import random

# Hyper-parameters shared by the cells below.

# Width of the hidden layer; passed as `hidden` when the policy `Net` is built.
HIDDEN = 128
# Despite the name this is used as a rank, not a percentile: filter_batch takes
# the PERCENTILE-th highest episode score as its cut-off threshold.
PERCENTILE= 30
# Number of finished episodes get_batch collects before yielding a batch.
BATCH_SIZE = 100
# Per-step discount: an episode is scored as final_reward * DISCOUNT ** n_steps.
DISCOUNT = 0.95
# Multiplier applied to the score of every episode filter_batch keeps, so that
# episodes carried over between iterations gradually lose rank.
DECAY = 0.99

In [2]:
class OneHotObsWrapper(gym.ObservationWrapper):
    """Observation wrapper that replaces a discrete state index with a one-hot vector.

    The wrapped environment's observation space must expose ``n`` (the number
    of discrete states). The wrapper advertises a float32 ``Box`` of shape
    ``(n,)`` with values in ``[0, 1]``.
    """

    def __init__(self, env):
        super(OneHotObsWrapper, self).__init__(env)
        self.obs_size = env.observation_space.n
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(self.obs_size,), dtype=np.float32
        )

    def observation(self, obs):
        """Return a float32 vector of length ``obs_size`` with 1.0 at index ``obs``."""
        # Allocate with the dtype declared in observation_space; np.zeros would
        # otherwise default to float64 and contradict the advertised space.
        res = np.zeros((self.obs_size,), dtype=np.float32)
        res[obs] = 1.0
        return res

In [3]:
class Net(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Args:
        in_size: number of input features.
        hidden: number of units in the hidden layer.
        out_size: number of output values.
    """

    def __init__(self, in_size, hidden, out_size):
        super(Net, self).__init__()
        layers = [
            nn.Linear(in_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Feed ``x`` through the layer stack and return the raw output (no softmax)."""
        out = self.net(x)
        return out

In [4]:
def get_batch(env, net, batch_size, discount=None):
    """Play episodes forever with the current policy and yield them in batches.

    Actions are sampled from the softmax of the network output, so the policy
    is stochastic.

    Args:
        env: environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns ``(obs, reward, done, info)``.
        net: callable mapping a ``(1, obs_size)`` float tensor to a
            ``(1, n_actions)`` tensor of action scores.
        batch_size: number of finished episodes per yielded batch.
        discount: per-step discount factor. ``None`` (the default) uses the
            module-level ``DISCOUNT``.

    Yields:
        A list of ``batch_size`` tuples ``(score, steps)``, where ``steps`` is
        the list of ``(observation, action)`` pairs of one episode and
        ``score`` is the reward of the episode's final step multiplied by
        ``discount ** len(steps)``.
    """
    if discount is None:
        discount = DISCOUNT
    batch = []
    steps = []
    obs = env.reset()
    while True:
        # Sampling is inference only: skip building an autograd graph.
        with torch.no_grad():
            obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
            probs = torch.softmax(net(obs_t), dim=1)[0].numpy()
        action = np.random.choice(len(probs), p=probs)
        next_obs, rew, done, _ = env.step(action)
        steps.append((obs, action))
        if done:
            # Only the final step's reward is scored; shorter episodes get a
            # higher score because the discount exponent is the episode length.
            batch.append((rew * (discount ** len(steps)), steps))
            steps = []
            next_obs = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs

In [44]:
def filter_batch(batch, saved, percentile=None, decay=None):
    """Select the best episodes from the new batch plus the carried-over ones.

    Args:
        batch: list of ``(score, steps)`` episodes just played; ``steps`` is a
            list of ``(observation, action)`` pairs. Not modified.
        saved: list of episodes kept by the previous call, same layout.
        percentile: rank from the top used as cut-off, i.e. the
            ``percentile``-th highest score becomes the threshold. ``None``
            (the default) uses the module-level ``PERCENTILE``. If fewer
            episodes are available, the lowest score is used.
        decay: factor applied to the score of every kept episode before it is
            returned for carry-over. ``None`` uses the module-level ``DECAY``.

    Returns:
        Tuple ``(obss, acts, top_batch, threshold, n_success)``:
        float tensor of observations and long tensor of actions from the kept
        episodes, the kept episodes with decayed scores, the score threshold
        that was applied, and the number of episodes in ``batch`` (excluding
        ``saved``) whose score is positive.
    """
    if percentile is None:
        percentile = PERCENTILE
    if decay is None:
        decay = DECAY

    # Count successes among the freshly played episodes only.
    n_success = sum(1 for ep in batch if ep[0] > 0)

    # Build a new list instead of extending `batch`, so the caller's list
    # is left untouched.
    episodes = list(batch) + list(saved)
    scores = [ep[0] for ep in episodes]
    if scores:
        # Clamp the rank so a short episode list does not raise IndexError.
        rank = min(percentile, len(scores))
        threshold = sorted(scores)[-rank]
    else:
        threshold = 0.0

    obss = []
    acts = []
    top_batch = []
    for score, steps in episodes:
        # Zero-score episodes are never kept, even when the threshold is 0.
        if score == 0.0 or score < threshold:
            continue
        obss.extend(step[0] for step in steps)
        acts.extend(step[1] for step in steps)
        top_batch.append((score * decay, steps))

    # Go through one contiguous ndarray; building a tensor directly from a
    # list of ndarrays is slow.
    obs_t = torch.from_numpy(np.asarray(obss, dtype=np.float32))
    act_t = torch.from_numpy(np.asarray(acts, dtype=np.int64))
    return obs_t, act_t, top_batch, threshold, n_success

In [55]:
if __name__ == "__main__":
    # Seed every RNG the run draws from: actions are sampled with np.random
    # and the network weights are initialised by torch, so seeding only
    # `random` would leave the run non-reproducible.
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)

    # Deterministic (non-slippery) FrozenLake, capped at 100 steps per episode,
    # with states presented to the network as one-hot vectors.
    env = gym.envs.toy_text.frozen_lake.FrozenLakeEnv(is_slippery=False)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
    env = OneHotObsWrapper(env)
    # Alternative setup using the registered environment id:
#     env = OneHotObsWrapper(gym.make("FrozenLake-v0"))
    net = Net(env.observation_space.shape[0], HIDDEN, env.action_space.n)
    CELoss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.01)

    # Episodes kept by filter_batch are fed back into the next iteration.
    saved = []
    for i, batch in enumerate(get_batch(env, net, BATCH_SIZE)):
        # `mean` is the number of episodes in this batch with a positive score.
        obss, acts, saved, threshold, mean = filter_batch(batch, saved)
        if obss.shape[0] == 0:
            # Nothing worth imitating yet; play another batch.
            continue

        # One supervised step: imitate the actions taken in the kept episodes.
        optimizer.zero_grad()
        action_pred = net(obss)
        loss = CELoss(action_pred, acts)
        loss.backward()
        optimizer.step()

        print("%d saved = %d, loss = %.4f, threshold = %.4f, mean = %.4f"%(i, len(saved), loss.item(), threshold, mean))

        # Stop once every episode of the batch reached the goal.
        if mean >= BATCH_SIZE:
            print("Solved")
            break

1 saved = 1, loss = 1.3884, threshold = 0.0000, mean = 1.0000
2 saved = 2, loss = 1.3715, threshold = 0.0000, mean = 1.0000
3 saved = 3, loss = 1.3427, threshold = 0.0000, mean = 1.0000
4 saved = 4, loss = 1.3122, threshold = 0.0000, mean = 1.0000
5 saved = 9, loss = 1.2818, threshold = 0.0000, mean = 5.0000
6 saved = 15, loss = 1.2470, threshold = 0.0000, mean = 6.0000
7 saved = 22, loss = 1.2071, threshold = 0.0000, mean = 7.0000
8 saved = 30, loss = 1.1590, threshold = 0.4401, mean = 12.0000
9 saved = 31, loss = 1.0171, threshold = 0.5987, mean = 22.0000
10 saved = 33, loss = 0.8744, threshold = 0.6568, mean = 28.0000
11 saved = 33, loss = 0.7705, threshold = 0.6634, mean = 31.0000
12 saved = 35, loss = 0.6125, threshold = 0.6983, mean = 38.0000
13 saved = 32, loss = 0.4907, threshold = 0.7205, mean = 38.0000
14 saved = 36, loss = 0.4144, threshold = 0.7277, mean = 54.0000
15 saved = 50, loss = 0.3552, threshold = 0.7277, mean = 54.0000
16 saved = 35, loss = 0.2981, threshold = 0.73

In [56]:
# Roll out a single episode with the trained policy, rendering after each step.
# Actions are still sampled from the softmax distribution (not taken greedily).
obs = env.reset()
sm = nn.Softmax(dim = 1)
while True:
    # Inference only: no autograd graph is needed to pick an action.
    with torch.no_grad():
        obs_ = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
        action_ = sm(net(obs_))[0].numpy()
    action = np.random.choice(len(action_), p=action_)
    next_obs, rew, done, _ = env.step(action)
    env.render()
    if done:
        break
    obs = next_obs

  (Down)
SFFF
FHFH
FFFH
HFFG
  (Down)
SFFF
FHFH
FFFH
HFFG
  (Right)
SFFF
FHFH
FFFH
HFFG
  (Down)
SFFF
FHFH
FFFH
HFFG
  (Right)
SFFF
FHFH
FFFH
HFFG
  (Right)
SFFF
FHFH
FFFH
HFFG
