In [1]:
import torch
import numpy as np
import gym

In [2]:
from tensorboardX import SummaryWriter

In [3]:
## Cross-entropy method for RL
## Collect episodes, keep the best ones, and train the network to map state -> action on them; no accurate reward or return estimate is needed

In [4]:
env = gym.make('FrozenLake-v0')

In [5]:
obs = env.reset()

In [6]:
def one_hot_obs(obs, n_states=16):
    """Encode a discrete observation index as a one-hot float vector.

    Args:
        obs: Integer index of the active state.
        n_states: Length of the returned vector. Defaults to 16, the size
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        A 1-D ``numpy.ndarray`` of length ``n_states`` that is 1.0 at
        position ``obs`` and 0.0 everywhere else.

    Raises:
        IndexError: If ``obs`` is not a valid index into a vector of length
            ``n_states`` (raised by NumPy indexing).
    """
    obs_np = np.zeros(n_states)
    obs_np[obs] = 1
    return obs_np

In [7]:
one_hot_obs(env.reset())

array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

# 1 Neural network (policy)

In [8]:
class Net(torch.nn.Module):
    """Small fully connected network used by the cross-entropy method.

    Maps an observation vector to one raw score per action. No softmax is
    applied here; the output is the plain result of the last linear layer.
    """

    def __init__(self,obs_size,hidden_size,action_size):
        """Build the obs_size -> hidden_size -> action_size stack.

        Args:
            obs_size: Number of input features.
            hidden_size: Number of units in the single hidden layer.
            action_size: Number of output scores.
        """
        super().__init__()
        # Layers are created in the same order as they appear in the stack.
        hidden_layer = torch.nn.Linear(obs_size, hidden_size)
        activation = torch.nn.ReLU()
        output_layer = torch.nn.Linear(hidden_size, action_size)
        self.net = torch.nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self,x):
        """Return the action scores produced by the layer stack for ``x``."""
        scores = self.net(x)
        return scores

# 2 交互过程

In [9]:
from collections import namedtuple

In [10]:
# The typename passed to namedtuple matches the name it is bound to, so reprs
# read correctly and instances can be pickled (pickle looks the class up by name).
Episode = namedtuple('Episode',field_names=['reward','step']) ## one finished episode: its total reward and its list of steps; used to filter good episodes
EpisodeStep = namedtuple('EpisodeStep',field_names=['observation','action']) ## one (observation, action) pair; the training sample that maps obs to action

In [11]:
def iterate_batch(env,net,batch_size):
    """Play episodes with the current policy and yield them in batches.

    This is an endless generator: it keeps stepping ``env``, sampling each
    action from the softmax of ``net``'s output, and yields a list every time
    ``batch_size`` episodes have finished. The caller decides when to stop.

    Args:
        env: Environment whose ``reset()`` returns an observation index and
            whose ``step(action)`` returns a 4-tuple
            ``(next_obs, reward, is_done, info)``.
        net: Callable taking a ``(1, n)`` float tensor and returning a
            ``(1, n_actions)`` tensor of action scores.
        batch_size: Number of finished episodes per yielded batch.

    Yields:
        A list of ``batch_size`` ``Episode`` tuples. Each holds the summed
        reward of the episode and its list of ``EpisodeStep`` records
        (one-hot observation, chosen action).
    """
    batch = []
    episode_reward = 0.0
    episode_step = []
    obs = env.reset()
    op_sm = torch.nn.Softmax(dim=1)
    while True:
        # Encode the observation once and reuse it for the forward pass and
        # for the stored step.
        obs_vec = one_hot_obs(obs)
        # The float32 conversion copies the data, so the tensor does not
        # alias obs_vec; unsqueeze adds the batch dimension the net expects.
        obs_value = torch.as_tensor(obs_vec, dtype=torch.float32).unsqueeze(0)
        # Only sampling happens here, so skip building the autograd graph.
        with torch.no_grad():
            action_prob = op_sm(net(obs_value)).numpy()[0]
        action = np.random.choice(len(action_prob),p = action_prob)
        # Advance the environment with the sampled action.
        next_obs , reward, is_done,_ = env.step(action)
        episode_reward += reward
        episode_step.append(EpisodeStep(obs_vec,action))
        if is_done:
            batch.append(Episode(episode_reward,episode_step))
            # Start a fresh episode.
            episode_reward = 0.0
            episode_step = []
            next_obs = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []
        obs = next_obs

# 3 学习过程

In [12]:
def filter_batch(batch,percentile,discount):
    """Select the elite episodes of a batch and flatten them into samples.

    Each episode is scored as ``reward * discount ** len(step)``. The
    ``percentile``-th percentile of those scores is the cut-off, and only
    episodes scoring strictly above it are kept.

    Args:
        batch: Sequence of episodes exposing ``reward`` and ``step``; every
            step exposes ``observation`` and ``action``.
        percentile: Percentile (0-100) passed to ``np.percentile``.
        discount: Base of the per-step discount applied to the reward.

    Returns:
        Tuple ``(elite_batch, train_obs, train_act, reward_bound)``: the kept
        episodes, their observations and actions flattened in episode order,
        and the percentile cut-off.
    """
    discounted = [episode.reward*(discount**len(episode.step)) for episode in batch]
    reward_bound = np.percentile(discounted,percentile)
    # Keep good episodes: anything that is not at or below the bound.
    elite_batch = [
        episode
        for score, episode in zip(discounted, batch)
        if not (score <= reward_bound)
    ]
    train_obs = [step.observation for episode in elite_batch for step in episode.step]
    train_act = [step.action for episode in elite_batch for step in episode.step]
    return elite_batch,train_obs,train_act,reward_bound

In [13]:
HIDDEN = 128  # hidden-layer width passed to Net
PERCERTILE = 30  # percentile of discounted episode rewards used as the elite cut-off in filter_batch
DISCOUNT = 0.9  # per-step discount base applied to an episode's reward in filter_batch
BATCHS_SIZE = 100  # number of finished episodes per batch requested from iterate_batch

In [14]:
env.action_space.n

4

In [18]:
num_exp =1  # experiment index; only used to build the log directory name below
# Writer that receives the scalars logged by the training loop.
writer = SummaryWriter('runs/cross_entroy_fl/exp%d'%(num_exp))

In [19]:
import random

In [20]:
#env = gym.envs.toy_text.frozen_lake.FrozenLakeEnv(is_slippery=False)
# Training loop for the cross-entropy method:
#   1. play a batch of episodes with the current policy,
#   2. keep the elite episodes (merged with elites kept from earlier batches),
#   3. fit the network to reproduce the elite (observation -> action) pairs.
SEED = 12345
MAX_ITER = 20000  # index of the last training iteration

# Seed every RNG that is actually used, before anything random is created:
# torch initialises the network weights, NumPy samples the actions, and the
# environment draws its own transitions.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
env.seed(SEED)  # classic gym API, consistent with the 4-tuple env.step used in iterate_batch

obs = env.reset()
net = Net(16,HIDDEN,env.action_space.n)
loss = torch.nn.CrossEntropyLoss()
optimzer = torch.optim.Adam(params=net.parameters(),lr=1e-3)
full_batch = []
try:
    for iter_no,batch in enumerate(iterate_batch(env,net,BATCHS_SIZE)):
        is_last = iter_no >= MAX_ITER
        reward_mean = np.mean([episode.reward for episode in batch])
        full_batch ,train_obs,train_act,reward_bound = filter_batch(full_batch+batch,PERCERTILE,DISCOUNT)
        if 0 == len(full_batch):
            # No elite episode yet, so there is nothing to train on. Still
            # honour the iteration limit so the loop cannot run past it.
            if is_last:
                break
            continue
        # Stack the observations into one array first; building a tensor from
        # a list of ndarrays is very slow.
        obs_v = torch.as_tensor(np.array(train_obs), dtype=torch.float32)
        act_v = torch.LongTensor(train_act)
        # Cap the number of elite episodes carried over to the next iteration.
        full_batch = full_batch[-500:]

        optimzer.zero_grad()
        action_score = net(obs_v)

        loss_v = loss(action_score,act_v)
        loss_v.backward()
        optimzer.step()
        if (iter_no+1)% 100 == 0:
            print("%d :loss:%f, reward_bound:%f, reward_mean:%f,batch=%d"%(iter_no,loss_v.item(),reward_bound,reward_mean,len(full_batch)))
        # Log every scalar against the iteration index so the curves line up.
        writer.add_scalar('loss',loss_v.item(),global_step=iter_no)
        writer.add_scalar('reward_bound',reward_bound,global_step=iter_no)
        writer.add_scalar('reward_mean',reward_mean,global_step=iter_no)
        if is_last:
            print("%d :loss:%f, reward_bound:%f, reward_mean:%f,batch=%d"%(iter_no,loss_v.item(),reward_bound,reward_mean,len(full_batch)))
            break
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

99 :loss:1.295987, reward_bound:0.000000, reward_mean:0.040000,batch=204
199 :loss:1.157771, reward_bound:0.313811, reward_mean:0.030000,batch=224
299 :loss:1.034672, reward_bound:0.161254, reward_mean:0.040000,batch=227
399 :loss:1.006526, reward_bound:0.350616, reward_mean:0.080000,batch=228
499 :loss:0.977155, reward_bound:0.366422, reward_mean:0.090000,batch=229
599 :loss:0.951341, reward_bound:0.205891, reward_mean:0.070000,batch=225
699 :loss:0.944008, reward_bound:0.423771, reward_mean:0.050000,batch=230
799 :loss:0.822725, reward_bound:0.291844, reward_mean:0.120000,batch=225
899 :loss:0.794710, reward_bound:0.250233, reward_mean:0.050000,batch=230
999 :loss:0.777797, reward_bound:0.251645, reward_mean:0.100000,batch=220
1099 :loss:0.752025, reward_bound:0.307534, reward_mean:0.070000,batch=229
1199 :loss:0.754134, reward_bound:0.331245, reward_mean:0.100000,batch=221
1299 :loss:0.743693, reward_bound:0.338218, reward_mean:0.130000,batch=231
1399 :loss:0.684448, reward_bound:0.

11199 :loss:0.118165, reward_bound:0.228768, reward_mean:0.270000,batch=210
11299 :loss:0.127864, reward_bound:0.426163, reward_mean:0.410000,batch=227
11399 :loss:0.132657, reward_bound:0.324271, reward_mean:0.280000,batch=225
11499 :loss:0.131567, reward_bound:0.328533, reward_mean:0.290000,batch=227
11599 :loss:0.123440, reward_bound:0.228768, reward_mean:0.390000,batch=206
11699 :loss:0.123041, reward_bound:0.185302, reward_mean:0.320000,batch=216
11799 :loss:0.126013, reward_bound:0.430467, reward_mean:0.350000,batch=231
11899 :loss:0.128785, reward_bound:0.348678, reward_mean:0.250000,batch=225
11999 :loss:0.118884, reward_bound:0.285568, reward_mean:0.350000,batch=229
12099 :loss:0.138468, reward_bound:0.387420, reward_mean:0.260000,batch=231
12299 :loss:0.141626, reward_bound:0.334731, reward_mean:0.350000,batch=226
12399 :loss:0.147683, reward_bound:0.348678, reward_mean:0.300000,batch=227
12499 :loss:0.150032, reward_bound:0.185302, reward_mean:0.360000,batch=202
12599 :loss: