In [1]:
import numpy as np
import torch
from rocket import Rocket
from policy import ActorCritic
import matplotlib.pyplot as plt
import utils
import os
import glob

# Compute device for the whole notebook: the first CUDA GPU when PyTorch can
# see one, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [4]:

def find_latest_checkpoint(ckpt_folder):
    """Return the path of the newest ``*.pt`` file in ``ckpt_folder``, or ``None``.

    Checkpoints are written as ``ckpt_<episode_id zero-padded to 8>.pt``, so
    lexicographic order matches episode order. ``glob.glob`` returns matches in
    arbitrary, filesystem-dependent order, so the list must be sorted before
    the last entry is taken.
    """
    ckpt_paths = sorted(glob.glob(os.path.join(ckpt_folder, '*.pt')))
    return ckpt_paths[-1] if ckpt_paths else None


if __name__ == '__main__':

    # --- configuration -------------------------------------------------------
    task = 'hover'  # 'hover' or 'landing'

    max_m_episode = 800000  # upper bound on the number of training episodes
    max_steps = 800         # per-episode step limit (also passed to the env)
    gamma = 0.999           # discount factor passed to net.update_ac
    log_every = 100         # render / plot / checkpoint when episode_id % log_every == 1

    env = Rocket(task=task, max_steps=max_steps)
    ckpt_folder = os.path.join('./', task + '_ckpt')
    os.makedirs(ckpt_folder, exist_ok=True)

    last_episode_id = 0   # episode id stored in the loaded checkpoint (0 if none)
    start_episode_id = 0  # first episode to run in this session
    REWARDS = []          # total reward of every finished episode, kept across resumes

    net = ActorCritic(input_dim=env.state_dims, output_dim=env.action_dims).to(device)

    # --- resume from the newest checkpoint, if any ----------------------------
    ckpt_path = find_latest_checkpoint(ckpt_folder)
    if ckpt_path is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # NOTE(review): torch.load unpickles; only load checkpoints you created.
        checkpoint = torch.load(ckpt_path, map_location=device)
        net.load_state_dict(checkpoint['model_G_state_dict'])
        last_episode_id = checkpoint['episode_id']
        REWARDS = checkpoint['REWARDS']
        # The saved episode already finished and its reward is in REWARDS;
        # continue with the next one instead of re-running (and double-counting) it.
        start_episode_id = last_episode_id + 1

    # --- training loop ----------------------------------------------------------
    for episode_id in range(start_episode_id, max_m_episode):

        state = env.reset()
        rewards, log_probs, values, masks = [], [], [], []
        render = episode_id % log_every == 1

        for step_id in range(max_steps):
            action, log_prob, value = net.get_action(state)
            state, reward, done, _ = env.step(action)
            rewards.append(reward)
            log_probs.append(log_prob)
            values.append(value)
            masks.append(1 - done)  # 0 once the episode has terminated
            if render:
                env.render()

            if done or step_id == max_steps - 1:
                # Value estimate of the state reached after the last step; presumably
                # the bootstrap target inside update_ac (confirm in policy.py).
                _, _, Qval = net.get_action(state)
                # update_ac receives the network explicitly as its first argument
                # (looks like a staticmethod on ActorCritic - confirm in policy.py).
                net.update_ac(net, rewards, log_probs, values, masks, Qval, gamma=gamma)
                break

        # Plain float (not np.float64) so the checkpoint holds only builtin types.
        episode_reward = float(np.sum(rewards))
        REWARDS.append(episode_reward)
        print('episode id: %d, episode reward: %.3f'
              % (episode_id, episode_reward))

        if episode_id % log_every == 1:
            # Reward curve plus a 50-episode moving average, saved next to the checkpoints.
            fig, ax = plt.subplots()
            ax.plot(REWARDS)
            ax.plot(utils.moving_avg(REWARDS, N=50))
            ax.legend(['episode reward', 'moving avg'], loc=2)
            ax.set_xlabel('m episode')
            ax.set_ylabel('reward')
            fig.savefig(os.path.join(ckpt_folder, 'rewards_' + str(episode_id).zfill(8) + '.jpg'))
            plt.close(fig)

            # Zero-padded episode id keeps lexicographic order == episode order,
            # which find_latest_checkpoint relies on.
            torch.save({'episode_id': episode_id,
                        'REWARDS': REWARDS,
                        'model_G_state_dict': net.state_dict()},
                       os.path.join(ckpt_folder, 'ckpt_' + str(episode_id).zfill(8) + '.pt'))



tensor(-2.1728, grad_fn=<LogBackward0>)
tensor([-0.0064], grad_fn=<SelectBackward0>)
8
tensor(-2.2099, grad_fn=<LogBackward0>)
tensor([-0.0166], grad_fn=<SelectBackward0>)
2
tensor(-2.2076, grad_fn=<LogBackward0>)
tensor([-0.0261], grad_fn=<SelectBackward0>)
1
tensor(-2.2038, grad_fn=<LogBackward0>)
tensor([-0.0188], grad_fn=<SelectBackward0>)
1
tensor(-2.2615, grad_fn=<LogBackward0>)
tensor([-0.0138], grad_fn=<SelectBackward0>)
0
tensor(-2.1265, grad_fn=<LogBackward0>)
tensor([-0.0059], grad_fn=<SelectBackward0>)
5
tensor(-2.2293, grad_fn=<LogBackward0>)
tensor([-0.0151], grad_fn=<SelectBackward0>)
7
tensor(-2.2165, grad_fn=<LogBackward0>)
tensor([-0.0361], grad_fn=<SelectBackward0>)
7
tensor(-2.2215, grad_fn=<LogBackward0>)
tensor([-0.0361], grad_fn=<SelectBackward0>)
2
tensor(-2.2193, grad_fn=<LogBackward0>)
tensor([-0.0433], grad_fn=<SelectBackward0>)
6
tensor(-2.2633, grad_fn=<LogBackward0>)
tensor([-0.0342], grad_fn=<SelectBackward0>)
0
tensor(-2.1224, grad_fn=<LogBackward0>)
ten

tensor(-2.1838, grad_fn=<LogBackward0>)
tensor([-0.0057], grad_fn=<SelectBackward0>)
4
tensor(-2.2724, grad_fn=<LogBackward0>)
tensor([-0.0032], grad_fn=<SelectBackward0>)
7
tensor(-2.1769, grad_fn=<LogBackward0>)
tensor([0.0058], grad_fn=<SelectBackward0>)
4
tensor(-2.1709, grad_fn=<LogBackward0>)
tensor([0.0057], grad_fn=<SelectBackward0>)
4
tensor(-2.1954, grad_fn=<LogBackward0>)
tensor([-0.0193], grad_fn=<SelectBackward0>)
6
tensor(-2.1138, grad_fn=<LogBackward0>)
tensor([-0.0179], grad_fn=<SelectBackward0>)
5
tensor(-2.1498, grad_fn=<LogBackward0>)
tensor([-0.0099], grad_fn=<SelectBackward0>)
8
tensor(-2.2576, grad_fn=<LogBackward0>)
tensor([0.0047], grad_fn=<SelectBackward0>)
0
tensor(-2.2606, grad_fn=<LogBackward0>)
tensor([0.0324], grad_fn=<SelectBackward0>)
7
tensor(-2.2002, grad_fn=<LogBackward0>)
tensor([0.0156], grad_fn=<SelectBackward0>)
1
tensor(-2.1969, grad_fn=<LogBackward0>)
tensor([0.0094], grad_fn=<SelectBackward0>)
6
tensor(-2.1115, grad_fn=<LogBackward0>)
tensor([0

tensor(-2.1598, grad_fn=<LogBackward0>)
tensor([-0.0362], grad_fn=<SelectBackward0>)
4
tensor(-2.2404, grad_fn=<LogBackward0>)
tensor([-0.0228], grad_fn=<SelectBackward0>)
6
tensor(-2.2161, grad_fn=<LogBackward0>)
tensor([0.0022], grad_fn=<SelectBackward0>)
0
tensor(-2.2140, grad_fn=<LogBackward0>)
tensor([-0.0105], grad_fn=<SelectBackward0>)
0
tensor(-2.1077, grad_fn=<LogBackward0>)
tensor([-0.0161], grad_fn=<SelectBackward0>)
5
tensor(-2.1744, grad_fn=<LogBackward0>)
tensor([-0.0178], grad_fn=<SelectBackward0>)
8
tensor(-2.1764, grad_fn=<LogBackward0>)
tensor([-0.0439], grad_fn=<SelectBackward0>)
8
tensor(-2.2617, grad_fn=<LogBackward0>)
tensor([-0.0431], grad_fn=<SelectBackward0>)
2
tensor(-2.1968, grad_fn=<LogBackward0>)
tensor([-0.0114], grad_fn=<SelectBackward0>)
4
tensor(-2.2346, grad_fn=<LogBackward0>)
tensor([0.0023], grad_fn=<SelectBackward0>)
1
tensor(-2.1127, grad_fn=<LogBackward0>)
tensor([-0.0015], grad_fn=<SelectBackward0>)
5
tensor(-2.1917, grad_fn=<LogBackward0>)
tenso

tensor([0.0565], grad_fn=<SelectBackward0>)
2
tensor(-2.2212, grad_fn=<LogBackward0>)
tensor([0.0589], grad_fn=<SelectBackward0>)
0
tensor(-2.1593, grad_fn=<LogBackward0>)
tensor([0.0513], grad_fn=<SelectBackward0>)
4
tensor(-2.2145, grad_fn=<LogBackward0>)
tensor([0.0449], grad_fn=<SelectBackward0>)
8
tensor(-2.1169, grad_fn=<LogBackward0>)
tensor([0.0675], grad_fn=<SelectBackward0>)
5
tensor(-2.1641, grad_fn=<LogBackward0>)
tensor([0.0803], grad_fn=<SelectBackward0>)
3
tensor(-2.2342, grad_fn=<LogBackward0>)
tensor([0.0705], grad_fn=<SelectBackward0>)
7
tensor(-2.2407, grad_fn=<LogBackward0>)
tensor([0.0790], grad_fn=<SelectBackward0>)
7
tensor(-2.1550, grad_fn=<LogBackward0>)
tensor([0.0706], grad_fn=<SelectBackward0>)
3
tensor(-2.1230, grad_fn=<LogBackward0>)
tensor([0.0522], grad_fn=<SelectBackward0>)
5
tensor(-2.2312, grad_fn=<LogBackward0>)
tensor([0.0595], grad_fn=<SelectBackward0>)
6
tensor(-2.2452, grad_fn=<LogBackward0>)
tensor([0.0603], grad_fn=<SelectBackward0>)
6
tensor(-

tensor(-2.2632, grad_fn=<LogBackward0>)
tensor([0.0618], grad_fn=<SelectBackward0>)
0
tensor(-2.1101, grad_fn=<LogBackward0>)
tensor([0.0654], grad_fn=<SelectBackward0>)
5
tensor(-2.2174, grad_fn=<LogBackward0>)
tensor([0.0660], grad_fn=<SelectBackward0>)
7
tensor(-2.1773, grad_fn=<LogBackward0>)
tensor([0.0603], grad_fn=<SelectBackward0>)
8
tensor(-2.2594, grad_fn=<LogBackward0>)
tensor([0.0487], grad_fn=<SelectBackward0>)
0
tensor(-2.2457, grad_fn=<LogBackward0>)
tensor([0.0500], grad_fn=<SelectBackward0>)
1
tensor(-2.2330, grad_fn=<LogBackward0>)
tensor([0.0554], grad_fn=<SelectBackward0>)
7
tensor(-2.2324, grad_fn=<LogBackward0>)
tensor([0.0528], grad_fn=<SelectBackward0>)
1
tensor(-2.1633, grad_fn=<LogBackward0>)
tensor([0.0510], grad_fn=<SelectBackward0>)
8
tensor(-2.1560, grad_fn=<LogBackward0>)
tensor([0.0465], grad_fn=<SelectBackward0>)
3
tensor(-2.1782, grad_fn=<LogBackward0>)
tensor([0.0369], grad_fn=<SelectBackward0>)
4
tensor(-2.2149, grad_fn=<LogBackward0>)
tensor([0.0228

tensor(-2.0997, grad_fn=<LogBackward0>)
tensor([0.0136], grad_fn=<SelectBackward0>)
5
tensor(-2.2371, grad_fn=<LogBackward0>)
tensor([0.0298], grad_fn=<SelectBackward0>)
2
tensor(-2.2131, grad_fn=<LogBackward0>)
tensor([0.0235], grad_fn=<SelectBackward0>)
1
tensor(-2.0959, grad_fn=<LogBackward0>)
tensor([0.0126], grad_fn=<SelectBackward0>)
5
tensor(-2.1730, grad_fn=<LogBackward0>)
tensor([0.0276], grad_fn=<SelectBackward0>)
8
tensor(-2.2475, grad_fn=<LogBackward0>)
tensor([0.0260], grad_fn=<SelectBackward0>)
7
tensor(-2.2587, grad_fn=<LogBackward0>)
tensor([0.0115], grad_fn=<SelectBackward0>)
7
tensor(-2.1846, grad_fn=<LogBackward0>)
tensor([0.0143], grad_fn=<SelectBackward0>)
4
tensor(-2.0976, grad_fn=<LogBackward0>)
tensor([0.0230], grad_fn=<SelectBackward0>)
5
tensor(-2.1049, grad_fn=<LogBackward0>)
tensor([0.0113], grad_fn=<SelectBackward0>)
5
tensor(-2.2572, grad_fn=<LogBackward0>)
tensor([0.0276], grad_fn=<SelectBackward0>)
0
tensor(-2.2050, grad_fn=<LogBackward0>)
tensor([0.0479

tensor(-2.1841, grad_fn=<LogBackward0>)
tensor([0.0404], grad_fn=<SelectBackward0>)
4
tensor(-2.2305, grad_fn=<LogBackward0>)
tensor([0.0463], grad_fn=<SelectBackward0>)
0
tensor(-2.2383, grad_fn=<LogBackward0>)
tensor([0.0678], grad_fn=<SelectBackward0>)
2
tensor(-2.2713, grad_fn=<LogBackward0>)
tensor([0.0652], grad_fn=<SelectBackward0>)
2
tensor(-2.1243, grad_fn=<LogBackward0>)
tensor([0.0574], grad_fn=<SelectBackward0>)
5
tensor(-2.2440, grad_fn=<LogBackward0>)
tensor([0.0670], grad_fn=<SelectBackward0>)
0
tensor(-2.1571, grad_fn=<LogBackward0>)
tensor([0.0627], grad_fn=<SelectBackward0>)
3
tensor(-2.1760, grad_fn=<LogBackward0>)
tensor([0.0460], grad_fn=<SelectBackward0>)
8
tensor(-2.2489, grad_fn=<LogBackward0>)
tensor([0.0661], grad_fn=<SelectBackward0>)
0
tensor(-2.2500, grad_fn=<LogBackward0>)
tensor([0.0719], grad_fn=<SelectBackward0>)
2
tensor(-2.2620, grad_fn=<LogBackward0>)
tensor([0.0605], grad_fn=<SelectBackward0>)
0
tensor(-2.2093, grad_fn=<LogBackward0>)
tensor([0.0475

KeyboardInterrupt: 