In [1]:
%%capture
import matplotlib.pyplot as plt

from IPython import display

from torch.optim.lr_scheduler import StepLR
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from tqdm.notebook import tqdm

%matplotlib inline


In [2]:
seed = 543 # Do not change this

def fix(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [3]:
%%capture
import gym
import random
seed = 543
fix(seed)
env = gym.make('LunarLander-v2')

In [4]:
print(env.observation_space)

Box([-1.5       -1.5       -5.        -5.        -3.1415927 -5.
 -0.        -0.       ], [1.5       1.5       5.        5.        3.1415927 5.        1.
 1.       ], (8,), float32)


In [5]:
print(env.action_space)

Discrete(4)


In [6]:
initial_state = env.reset()
print(initial_state)

(array([-0.00566216,  1.4077348 , -0.5735394 , -0.14158951,  0.00656792,
        0.12991525,  0.        ,  0.        ], dtype=float32), {})


In [7]:
random_action = env.action_space.sample()
print(random_action)
observation, reward, done, info, _ = env.step(random_action)

0


In [8]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Sequential(
            nn.Linear(8,32),
            nn.Tanh(),
            nn.Linear(32,32),
        )
        self.net2 = nn.Sequential(
            nn.Linear(32,16),
            nn.ReLU(),
            nn.Linear(16,4),
        )

    def forward(self, state):
        hid = torch.tanh(self.net1(state))
        return F.softmax(self.net2(hid), dim=-1), hid
    
class RewardNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(32,16),
            nn.Tanh(),
            nn.Linear(16,1)
        )

    def forward(self, hid):
        return self.net(hid)

In [9]:
class PolicyGradientAgent():
    
    def __init__(self, network, rewardnet, epoch):
        self.network = network
        self.optimizerN = optim.SGD(self.network.parameters(), lr=0.001)
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1)
        self.rewardnet = rewardnet
        self.optimizerR = optim.SGD(self.rewardnet.parameters(), lr=0.001)
        torch.nn.utils.clip_grad_norm_(self.rewardnet.parameters(), 1)
        self.loss_a = []
        self.loss_r = []
         
    def forward(self, state, Act=True):
        action, hid = self.network(state)
        reward = self.rewardnet(hid.detach())
        return action, hid, reward
        
    def learn(self, log_probs, rewards, pd_rewards, pd_rewards1, epoch):
        At = rewards + pd_rewards1 - pd_rewards
        loss = (-log_probs * At).sum()
        self.loss_a.append(loss.item())        
        self.optimizerN.zero_grad()
        loss.backward(retain_graph=True)
        self.optimizerN.step()

        reward_loss = torch.nn.MSELoss()(rewards.clone(), pd_rewards.clone())
        self.loss_r.append(reward_loss.item())
        self.optimizerR.zero_grad()
        reward_loss.backward()
        self.optimizerR.step()
        
            
        
    def sample(self, state):
        action_prob, _, reward = self.forward(torch.FloatTensor(state))
        action_dist = Categorical(action_prob)
        action = action_dist.sample()
        log_prob = action_dist.log_prob(action)
        return action.item(), log_prob, reward

In [10]:
EPISODE_PER_BATCH = 5  # 每蒐集 5 個 episodes 更新一次 agent
NUM_BATCH = 1000        # 總共更新 400 次
lamb = 0.99

network = Network()
reward_net = RewardNet()
agent = PolicyGradientAgent(network,reward_net,NUM_BATCH)
agent.network.train()
agent.rewardnet.train()

avg_total_rewards, avg_final_rewards = [], []

In [11]:
prg_bar = tqdm(range(NUM_BATCH))
best = 0
for batch in prg_bar:

    log_probs, rewards, pd_rewards, pd_rewards1 = [], [], [], []
    total_rewards, final_rewards = [], []
    
    # 蒐集訓練資料
    for episode in range(EPISODE_PER_BATCH):
        
        state = env.reset()[0]
        total_reward, total_step = 0, 0
        seq_rewards = []
        pd = []
        while True:
            action, log_prob, pd_reward = agent.sample(state) # at , log(at|st)
            
            next_state, reward, done, info, _= env.step(action)
            log_probs.append(log_prob) # [log(a1|s1), log(a2|s2), ...., log(at|st)]
            seq_rewards.append(reward)
            pd.append(pd_reward)
            total_step += 1
            
            state = next_state
            
            prg_bar.set_postfix({"current_step":total_step})
            if done or total_step == 2000:
                seq_rewards = np.array(seq_rewards)

                # 计算幂次项的数组
                l = len(seq_rewards)
                powers = np.power(lamb, np.arange(l))
                # 使用向量化操作计算reward_li
                reward_li = [np.sum(seq_rewards[i:]*powers[:l-i]) for i in range(l)]
                rewards += list(reward_li)
                pd_rewards += pd
                _, _, pd_reward1 = agent.sample(state)
                pd_rewards1 += pd[1:] + [pd_reward1]
                final_rewards.append(seq_rewards[-1])
                total_rewards.append(sum(seq_rewards))
                break

    print(f"rewards looks like ", np.shape(rewards))  
    print(f"log_probs looks like ", len(log_probs))     
    # 紀錄訓練過程
    avg_total_reward = sum(total_rewards) / len(total_rewards)
    avg_final_reward = sum(final_rewards) / len(final_rewards)
    avg_total_rewards.append(avg_total_reward)
    avg_final_rewards.append(avg_final_reward)
    prg_bar.set_description(f"Total: {avg_total_reward: 4.1f}, Final: {avg_final_reward: 4.1f}")
    if avg_total_reward > best:
        print("save_best:",batch)
        torch.save(network.state_dict(),"best_net.bin")
        torch.save(reward_net.state_dict(),"best_reward_net.bin")

    # 更新網路
    # rewards = np.concatenate(rewards, axis=0)
    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-9)  # 將 reward 正規標準化
    agent.learn(torch.stack(log_probs), torch.from_numpy(rewards).float(), torch.cat(pd_rewards), torch.cat(pd_rewards1), NUM_BATCH)

  0%|          | 0/1000 [00:00<?, ?it/s]

rewards looks like  (440,)
log_probs looks like  440
rewards looks like  (534,)
log_probs looks like  534
rewards looks like  (499,)
log_probs looks like  499
rewards looks like  (546,)
log_probs looks like  546
rewards looks like  (446,)
log_probs looks like  446
rewards looks like  (497,)
log_probs looks like  497
rewards looks like  (574,)
log_probs looks like  574
rewards looks like  (528,)
log_probs looks like  528
rewards looks like  (545,)
log_probs looks like  545
rewards looks like  (528,)
log_probs looks like  528
rewards looks like  (498,)
log_probs looks like  498
rewards looks like  (437,)
log_probs looks like  437
rewards looks like  (492,)
log_probs looks like  492
rewards looks like  (641,)
log_probs looks like  641
rewards looks like  (520,)
log_probs looks like  520
rewards looks like  (500,)
log_probs looks like  500
rewards looks like  (532,)
log_probs looks like  532
rewards looks like  (593,)
log_probs looks like  593
rewards looks like  (449,)
log_probs looks lik

rewards looks like  (459,)
log_probs looks like  459
rewards looks like  (476,)
log_probs looks like  476
rewards looks like  (568,)
log_probs looks like  568
rewards looks like  (486,)
log_probs looks like  486
rewards looks like  (468,)
log_probs looks like  468
rewards looks like  (494,)
log_probs looks like  494
rewards looks like  (540,)
log_probs looks like  540
rewards looks like  (662,)
log_probs looks like  662
rewards looks like  (436,)
log_probs looks like  436
rewards looks like  (495,)
log_probs looks like  495
rewards looks like  (454,)
log_probs looks like  454
rewards looks like  (395,)
log_probs looks like  395
rewards looks like  (497,)
log_probs looks like  497
rewards looks like  (484,)
log_probs looks like  484
rewards looks like  (399,)
log_probs looks like  399
rewards looks like  (425,)
log_probs looks like  425
rewards looks like  (445,)
log_probs looks like  445
rewards looks like  (405,)
log_probs looks like  405
rewards looks like  (525,)
log_probs looks lik

rewards looks like  (557,)
log_probs looks like  557
rewards looks like  (625,)
log_probs looks like  625
rewards looks like  (537,)
log_probs looks like  537
rewards looks like  (547,)
log_probs looks like  547
rewards looks like  (585,)
log_probs looks like  585
rewards looks like  (597,)
log_probs looks like  597
rewards looks like  (652,)
log_probs looks like  652
rewards looks like  (474,)
log_probs looks like  474
rewards looks like  (604,)
log_probs looks like  604
rewards looks like  (487,)
log_probs looks like  487
rewards looks like  (656,)
log_probs looks like  656
rewards looks like  (690,)
log_probs looks like  690
rewards looks like  (666,)
log_probs looks like  666
rewards looks like  (618,)
log_probs looks like  618
rewards looks like  (551,)
log_probs looks like  551
rewards looks like  (724,)
log_probs looks like  724
rewards looks like  (726,)
log_probs looks like  726
rewards looks like  (480,)
log_probs looks like  480
rewards looks like  (619,)
log_probs looks lik

rewards looks like  (383,)
log_probs looks like  383
rewards looks like  (406,)
log_probs looks like  406
rewards looks like  (387,)
log_probs looks like  387
rewards looks like  (436,)
log_probs looks like  436
rewards looks like  (442,)
log_probs looks like  442
rewards looks like  (398,)
log_probs looks like  398
rewards looks like  (423,)
log_probs looks like  423
rewards looks like  (455,)
log_probs looks like  455
rewards looks like  (415,)
log_probs looks like  415
rewards looks like  (415,)
log_probs looks like  415
rewards looks like  (394,)
log_probs looks like  394
rewards looks like  (368,)
log_probs looks like  368
rewards looks like  (410,)
log_probs looks like  410
rewards looks like  (382,)
log_probs looks like  382
rewards looks like  (403,)
log_probs looks like  403
rewards looks like  (481,)
log_probs looks like  481
rewards looks like  (444,)
log_probs looks like  444
rewards looks like  (384,)
log_probs looks like  384
rewards looks like  (363,)
log_probs looks lik

rewards looks like  (602,)
log_probs looks like  602
save_best: 606
rewards looks like  (2555,)
log_probs looks like  2555
save_best: 607
rewards looks like  (4347,)
log_probs looks like  4347
save_best: 608
rewards looks like  (570,)
log_probs looks like  570
rewards looks like  (622,)
log_probs looks like  622
rewards looks like  (570,)
log_probs looks like  570
rewards looks like  (597,)
log_probs looks like  597
rewards looks like  (603,)
log_probs looks like  603
rewards looks like  (476,)
log_probs looks like  476
rewards looks like  (566,)
log_probs looks like  566
save_best: 615
rewards looks like  (502,)
log_probs looks like  502
save_best: 616
rewards looks like  (510,)
log_probs looks like  510
rewards looks like  (561,)
log_probs looks like  561
save_best: 618
rewards looks like  (493,)
log_probs looks like  493
rewards looks like  (573,)
log_probs looks like  573
save_best: 620
rewards looks like  (621,)
log_probs looks like  621
save_best: 621
rewards looks like  (544,)
l

rewards looks like  (393,)
log_probs looks like  393
rewards looks like  (374,)
log_probs looks like  374
rewards looks like  (354,)
log_probs looks like  354
rewards looks like  (370,)
log_probs looks like  370
rewards looks like  (432,)
log_probs looks like  432
rewards looks like  (346,)
log_probs looks like  346
rewards looks like  (406,)
log_probs looks like  406
rewards looks like  (343,)
log_probs looks like  343
rewards looks like  (377,)
log_probs looks like  377
rewards looks like  (408,)
log_probs looks like  408
rewards looks like  (421,)
log_probs looks like  421
rewards looks like  (410,)
log_probs looks like  410
rewards looks like  (367,)
log_probs looks like  367
rewards looks like  (462,)
log_probs looks like  462
rewards looks like  (448,)
log_probs looks like  448
rewards looks like  (436,)
log_probs looks like  436
rewards looks like  (447,)
log_probs looks like  447
rewards looks like  (396,)
log_probs looks like  396
rewards looks like  (463,)
log_probs looks lik

KeyboardInterrupt: 

In [None]:
import time
import math
end = time.time()
plt.plot(avg_total_rewards)
plt.title("Total Rewards")
plt.show()

In [None]:
plt.plot(avg_final_rewards)
plt.title("Final Rewards")
plt.show()

In [None]:
list = [float(i) for i in agent.loss_a]
plt.subplot(2, 1, 1)
plt.plot(list,"r")
plt.title("action_loss")
plt.show()
plt.subplot(2, 1, 2)
list = [float(i) for i in agent.loss_r]
plt.plot(list,"b")
plt.title("reward_loss")
plt.show()

In [None]:
fix(seed)
agent.network.load_state_dict(torch.load("best_net.bin"))
agent.network.eval()  # 測試前先將 network 切換為 evaluation 模式
NUM_OF_TEST = 20 # Do not revise it !!!!!
test_total_reward = []
action_list = []
for i in range(NUM_OF_TEST):
    actions = []
    state = env.reset()[0]

    total_reward = 0

    done = False
    while not done:
        action, _, _ = agent.sample(state)
        actions.append(action)
        state, reward, done, _, _ = env.step(action)

        total_reward += reward

        #img.set_data(env.render(mode='rgb_array'))
        #display.display(plt.gcf())
        #display.clear_output(wait=True)
    print(total_reward)
    test_total_reward.append(total_reward)

    action_list.append(actions) #儲存你測試的結果
    print("length of actions is ", len(actions))

In [None]:
print(f"Your final reward is : %.2f"%np.mean(test_total_reward))

In [None]:
print("Action list looks like ", action_list)
print("Action list's shape looks like ", np.shape(action_list))

In [None]:
distribution = {}
for actions in action_list:
  for action in actions:
    if action not in distribution.keys():
      distribution[action] = 1
    else:
      distribution[action] += 1
print(distribution)

In [None]:
PATH = "Action_List_test.npy" # 可以改成你想取的名字或路徑
np.save(PATH ,np.array(action_list)) 