In [74]:
import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import random

# Width of the policy network's hidden layer.
HIDDEN = 128
# Elite cut-off used by filter_batch: the threshold is the PERCENTILE-th highest
# episode score (a rank/count, not a percentage).
PERCENTILE = 30
# Number of finished episodes per training batch.
BATCH_SIZE = 100
# Per-step discount: an episode's score is multiplied by DISCOUNT ** episode_length.
DISCOUNT = 0.95
# Multiplier applied to the score of every elite episode carried to the next iteration.
DECAY = 0.99

In [46]:
class OneHotObsWrapper(gym.ObservationWrapper):
    """Expose a discrete observation index as a one-hot float vector.

    The wrapped env's observation space must provide ``.n`` (number of discrete
    states). The wrapper's observation space becomes a ``Box`` of shape ``(n,)``
    with values in ``[0, 1]`` and dtype float32.
    """

    def __init__(self, env):
        super().__init__(env)
        self.obs_size = env.observation_space.n
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=(self.obs_size,), dtype=np.float32)

    def observation(self, obs):
        """Return a vector of zeros with a single 1.0 at index ``obs``."""
        # Use the declared Box dtype (float32); np.zeros would default to float64,
        # making observations disagree with self.observation_space.
        res = np.zeros(self.obs_size, dtype=self.observation_space.dtype)
        res[obs] = 1.0
        return res

In [47]:
class Net(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Linear, returning raw (unnormalised) scores.

    Args:
        in_size: size of the last dimension of the input.
        hidden: width of the hidden layer.
        out_size: number of output scores (one per action in this notebook).
    """

    def __init__(self, in_size, hidden, out_size):
        super().__init__()
        layers = [
            nn.Linear(in_size, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_size),
        ]
        # Kept under the attribute name ``net`` so state_dict keys stay "net.<i>.*".
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.net(x)
        return scores

In [48]:
def get_batch(env, net, batch_size, discount=None):
    """Endlessly play episodes with ``net`` as a stochastic policy, yielding batches.

    At each step the observation is passed through ``net``; its outputs are turned
    into a probability distribution with a softmax, and an action is sampled from it.

    Args:
        env: environment with ``reset() -> obs`` and the 4-tuple
            ``step(action) -> (obs, reward, done, info)`` API.
        net: module mapping a ``(1, obs_dim)`` float tensor to ``(1, n_actions)`` scores.
        batch_size: number of finished episodes in each yielded batch.
        discount: per-step discount factor; ``None`` uses the module-level ``DISCOUNT``.

    Yields:
        A list of ``batch_size`` tuples ``(score, steps)``. ``score`` is the episode's
        total reward times ``discount ** len(steps)``, so shorter rewarded episodes
        score higher. ``steps`` is the list of ``(observation, action)`` pairs taken.
    """
    if discount is None:
        discount = DISCOUNT
    batch = []
    steps = []
    episode_reward = 0.0
    obs = env.reset()
    while True:
        obs_t = torch.as_tensor(np.asarray(obs), dtype=torch.float32).unsqueeze(0)
        # Rollouts only sample actions; no autograd graph is needed.
        with torch.no_grad():
            probs = torch.softmax(net(obs_t), dim=1)[0].numpy()
        # float32 softmax output need not sum to exactly 1; renormalise in float64
        # before handing it to np.random.choice.
        probs = probs.astype(np.float64)
        probs /= probs.sum()
        action = np.random.choice(len(probs), p=probs)
        next_obs, rew, done, _ = env.step(action)
        episode_reward += rew
        steps.append((obs, action))
        if done:
            # Score with the whole-episode return, not just the final step's reward.
            batch.append((episode_reward * discount ** len(steps), steps))
            steps = []
            episode_reward = 0.0
            next_obs = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs

In [85]:
def filter_batch(batch, saved, percentile=None, decay=None):
    """Select elite episodes from a fresh batch plus previously saved elites.

    Args:
        batch: fresh episodes as ``(score, steps)`` tuples, where ``steps`` is a list
            of ``(observation, action)`` pairs (the format ``get_batch`` yields).
            The list is not modified.
        saved: elite episodes returned by the previous call, in the same format.
        percentile: rank of the cut-off. The threshold is the ``percentile``-th
            highest score, clamped to the number of episodes available. Despite the
            name it is a count, not a percentage. ``None`` uses ``PERCENTILE``.
        decay: factor applied to the score of every kept episode, so carried-over
            elites gradually lose priority. ``None`` uses ``DECAY``.

    Returns:
        ``(obs, acts, elites, threshold, n_success)``:

        - ``obs``: float32 tensor ``(n_steps, *obs_shape)`` of elite observations.
        - ``acts``: int64 tensor ``(n_steps,)`` of the matching actions.
        - ``elites``: kept episodes with decayed scores, to pass back as ``saved``.
        - ``threshold``: the score cut-off.
        - ``n_success``: number of episodes in the fresh ``batch`` (saved ones
          excluded) with a positive score.

        When no episode is kept, ``obs``/``acts`` are empty but keep their trailing
        dimensions.
    """
    if percentile is None:
        percentile = PERCENTILE
    if decay is None:
        decay = DECAY

    # Counted on the fresh batch only, before saved elites are merged in.
    n_success = sum(1 for score, _ in batch if score > 0)

    # Build a new list rather than extending the caller's batch in place.
    episodes = list(batch) + list(saved)
    scores = sorted(score for score, _ in episodes)
    # k-th highest score; clamp k so small batches don't raise IndexError.
    threshold = scores[-min(percentile, len(scores))] if scores else 0.0

    obs_list, act_list, elites = [], [], []
    for score, steps in episodes:
        # Zero-score episodes are never imitated, even if the threshold is 0.
        if score == 0.0 or score < threshold:
            continue
        obs_list.extend(obs for obs, _ in steps)
        act_list.extend(act for _, act in steps)
        elites.append((score * decay, steps))

    if obs_list:
        obs_arr = np.asarray(obs_list, dtype=np.float32)
    else:
        # Preserve the observation shape so an empty result still fits the net.
        first_obs = next((obs for _, steps in episodes for obs, _ in steps), None)
        obs_shape = tuple(np.shape(first_obs)) if first_obs is not None else ()
        obs_arr = np.empty((0,) + obs_shape, dtype=np.float32)
    acts_arr = np.asarray(act_list, dtype=np.int64)

    return torch.from_numpy(obs_arr), torch.from_numpy(acts_arr), elites, threshold, n_success

In [88]:
if __name__ == "__main__":
    # Seed every RNG involved (action sampling uses np.random) for reproducible runs.
    SEED = 12345
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # For a deterministic variant, wrap FrozenLakeEnv(is_slippery=False) in
    # gym.wrappers.TimeLimit(max_episode_steps=100) instead of gym.make.
    env = OneHotObsWrapper(gym.make("FrozenLake-v0"))
    if hasattr(env, "seed"):
        env.seed(SEED)
    net = Net(env.observation_space.shape[0], HIDDEN, env.action_space.n)
    CELoss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.01)

    # Requiring every episode in a batch to succeed is likely unreachable on the
    # stochastic (slippery) map. Use the env's registered reward threshold when
    # available, falling back to the all-episodes criterion.
    spec = getattr(env, "spec", None)
    solve_ratio = getattr(spec, "reward_threshold", None) or 1.0
    max_iterations = 1000  # safety cap so the loop always terminates

    saved = []
    for i, batch in enumerate(get_batch(env, net, BATCH_SIZE)):
        # `mean` is the number of successful (positive-score) episodes in this batch.
        obss, acts, saved, threshold, mean = filter_batch(batch, saved)
        if acts.numel() == 0:
            # No elite episodes yet: nothing to imitate, so skip the update
            # (an empty batch would break the forward pass / give a NaN loss).
            loss_value = float("nan")
        else:
            optimizer.zero_grad()
            action_pred = net(obss)
            loss = CELoss(action_pred, acts)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()

        print("%d saved = %d, loss = %.4f, threshold = %.4f, mean = %.4f"%(i, len(saved), loss_value, threshold, mean))

        if mean >= solve_ratio * BATCH_SIZE:
            print("Solved")
            break
        if i + 1 >= max_iterations:
            print("Stopped after %d iterations without solving" % max_iterations)
            break

0 saved = 1, loss = 1.3778, threshold = 0.0000, mean = 1.0000
1 saved = 4, loss = 1.3353, threshold = 0.0000, mean = 3.0000
2 saved = 4, loss = 1.2860, threshold = 0.0000, mean = 0.0000
3 saved = 9, loss = 1.2782, threshold = 0.0000, mean = 5.0000
4 saved = 18, loss = 1.2349, threshold = 0.0000, mean = 9.0000
5 saved = 27, loss = 1.2026, threshold = 0.0000, mean = 9.0000
6 saved = 30, loss = 1.1173, threshold = 0.4638, mean = 13.0000
7 saved = 30, loss = 1.0544, threshold = 0.5810, mean = 11.0000
8 saved = 32, loss = 0.9869, threshold = 0.5987, mean = 21.0000
9 saved = 30, loss = 0.8948, threshold = 0.6373, mean = 27.0000
10 saved = 32, loss = 0.8266, threshold = 0.6634, mean = 36.0000
11 saved = 31, loss = 0.6856, threshold = 0.6914, mean = 32.0000
12 saved = 38, loss = 0.6057, threshold = 0.6983, mean = 40.0000
13 saved = 30, loss = 0.5243, threshold = 0.7133, mean = 41.0000
14 saved = 36, loss = 0.4664, threshold = 0.7205, mean = 50.0000
15 saved = 38, loss = 0.4137, threshold = 0.7