# Initialisation

In [None]:
import os
import random

import gym
import numpy as np
import torch
from torch import nn
from torch.distributions.categorical import Categorical
from torch.utils.data import DataLoader as TorchDataLoader
from tqdm.notebook import tqdm

from model import *
from utils import *
from replay_buffer import *

This notebook uses Comet.ml for experiment tracking. If you don't have an account, create one for free at https://www.comet.ml/site/, then supply your API key and workspace for the code cell below.

In [None]:
# Comet.ml experiment tracking.
# Credentials are read from the environment so they are never saved inside the
# notebook: set COMET_API_KEY (and optionally COMET_WORKSPACE) before starting
# Jupyter. Unset variables give None, exactly as before, which leaves it to
# comet_ml's own configuration lookup.
# NOTE(review): `Experiment` is not imported explicitly in the imports cell;
# presumably it reaches this namespace from comet_ml via one of the star
# imports — confirm.
comet_api_key = os.environ.get("COMET_API_KEY")
comet_project_name = "udrl"
comet_workspace = os.environ.get("COMET_WORKSPACE")
experiment = Experiment(api_key=comet_api_key, project_name=comet_project_name, workspace=comet_workspace)


For more details, see the Comet quick-start guide: https://www.comet.ml/docs/quick-start/#quick-start-for-python

# Hyperparameters

In [None]:
# general hyperparams
NUM_WARMUP_EPISODES   = 10    # No. of random-policy warm-up episodes collected before training
REPLAY_SIZE           = 300   # Max size of the replay buffer (in episodes)
RETURN_SCALE          = 0.01  # Scaling factor for the desired-return command input (reward)
HORIZON_SCALE         = 0.01  # Scaling factor for the desired-horizon command input (steps)

# training hyperparams
BATCH_SIZE            = 512   # No. of (input, target) pairs per batch when training
                              # the behavior function
NUM_UPDATES_PER_ITER  = 100   # No. of gradient-based updates of the behavior
                              # function per step of UDRL training
LEARNING_RATE         = 1e-3  # Learning rate for the Adam optimizer

# generating episodes hyperparams
NUM_EPISODES_PER_ITER = 10    # No. of exploratory episodes generated per step of
                              # UDRL training
LAST_FEW              = 25    # No. of episodes passed to replay_buffer.top_episodes()
                              # to derive exploratory commands (presumably the
                              # highest-return episodes — confirm in replay_buffer)

# Solving `Sparse Lunar Lander`

## Step 1 - Initialize the replay buffer and warm it up with a random policy

In [None]:
training_step = 0   # gradient updates performed so far (step axis for "batch_loss")
playing_step = 0    # episodes played so far (step axis for per-episode metrics)

replay_buffer = ReplayBuffer(REPLAY_SIZE)   # init replay buffer
env = gym.make("LunarLander-v2")            # init gym env

# Fill the buffer with episodes from a uniformly random policy so the first
# UDRL iteration has data to train on.
for _ in tqdm(range(NUM_WARMUP_EPISODES)):

    episode = {
        'states': [],
        'actions': [],
        'rewards': [],
        'next_states': []
    }
    episode_reward = 0

    # NOTE(review): this uses the classic gym API (reset() -> obs,
    # step() -> 4-tuple); gym >= 0.26 changed both signatures.
    state = env.reset()
    done = False
    while not done:
        episode['states'].append(state)
        action = env.action_space.sample()
        state, reward, done, info = env.step(action)
        episode_reward += reward
        episode['actions'].append(action)
        episode['next_states'].append(state)
        # 'Sparse' Lunar Lander: every step stores 0 except the final one,
        # which stores the total episode return.
        episode['rewards'].append(episode_reward if done else 0)

    playing_step += 1
    experiment.log_metric("episode_reward", episode_reward, step=playing_step)

    # add episode data to the replay buffer.
    # np.float / np.int were removed in NumPy 1.24; the builtins `float` and
    # `int` are the dtypes those aliases stood for and work on every version.
    replay_buffer.add_episode(
        np.array(episode['states'], dtype=float),
        np.array(episode['actions'], dtype=int),
        np.array(episode['rewards'], dtype=float),
        np.array(episode['next_states'], dtype=float),
    )

## Step 2 - Initialize policy network (BehaviorNet) and optimizer

In [None]:
# Sizes taken from the environment: the number of discrete actions and the
# length of the raw observation vector.
action_dim = env.action_space.n
state_dim = env.observation_space.shape[0]

# NOTE(review): BehaviorNet receives the raw observation size; presumably it
# reserves extra inputs for the (horizon, return) command itself — confirm in model.py.
policy = BehaviorNet(state_dim, action_dim)

# nn.NLLLoss consumes log-probabilities, so BehaviorNet is assumed to end in a
# log-softmax; the main loop also feeds its outputs to Categorical(logits=...).
loss_func = nn.NLLLoss()
optimizer = torch.optim.Adam(params=policy.parameters(), lr=LEARNING_RATE)

## Step 3 - Main learning loop

1. Train the policy network on behavior segments sampled from the replay buffer.
2. Sample exploratory commands (desired horizon and desired return) for further exploration.
3. Use the latest policy network and the sampled commands to generate new trajectories, and add them to the replay buffer.
4. Repeat from step 1.

In [None]:
# Number of episodes drawn from the replay buffer to build each training dataset.
NUM_EPISODES_PER_DATASET = 5
# Upper bound applied to the desired-horizon command.
MAX_COMMAND_HORIZON = 200


def generate_command(tgt_horizon, tgt_reward_mean, tgt_reward_std):
    """Sample one exploratory (desired horizon, desired return) command.

    Args:
        tgt_horizon: target horizon in steps; capped at MAX_COMMAND_HORIZON.
        tgt_reward_mean: lower end of the desired-return sampling range.
        tgt_reward_std: width of the desired-return sampling range.

    Returns:
        (horizon, reward) where reward is drawn uniformly from
        [mean, mean + std) and rounded to a whole number.
    """
    tgt_horizon = min(tgt_horizon, MAX_COMMAND_HORIZON)
    tgt_reward = round(np.random.random_sample() * tgt_reward_std + tgt_reward_mean, 0)
    return tgt_horizon, tgt_reward


try:
    while True:  # run until the kernel is interrupted

        # 1 - Train Policy Network on behaviors sampled from the buffer

        episodes_to_train = replay_buffer.sample_episodes(NUM_EPISODES_PER_DATASET)
        train_dset = BehaviorDataset(episodes_to_train,
                                     size=BATCH_SIZE * NUM_UPDATES_PER_ITER,
                                     horizon_scale=HORIZON_SCALE,
                                     return_scale=RETURN_SCALE)
        training_behaviors = TorchDataLoader(train_dset,
                                             batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
        policy.train()
        # Presumably NUM_UPDATES_PER_ITER batches, if len(train_dset) == size — confirm in BehaviorDataset.
        for behavior_batch in training_behaviors:
            optimizer.zero_grad()
            logprobs = policy(behavior_batch['features'])
            loss = loss_func(logprobs, behavior_batch['label'])
            loss.backward()
            optimizer.step()

            training_step += 1
            experiment.log_metric("batch_loss", loss.item(), step=training_step)

        # 2 - Sample exploratory target commands from the selected episodes

        top_episodes = replay_buffer.top_episodes(LAST_FEW)  # [(S, A, R, S_), ...]
        tgt_horizon = int(np.mean([ep[0].shape[0] for ep in top_episodes]))
        top_returns = [np.sum(ep[2]) for ep in top_episodes]
        tgt_reward_mean = np.mean(top_returns)
        tgt_reward_std = np.std(top_returns)

        experiment.log_metric("tgt_reward_mean", tgt_reward_mean, step=playing_step)

        # 3 - Generate new trajectories using latest policy network and generated commands

        # Act in inference mode (matters if BehaviorNet has dropout / batch-norm).
        policy.eval()
        for _ in range(NUM_EPISODES_PER_ITER):
            episode = {
                'states': [],
                'actions': [],
                'rewards': [],
                'next_states': []
            }
            episode_reward = 0
            # start interactions (classic gym API: reset() -> obs, step() -> 4-tuple)
            state = env.reset()
            done = False
            command_horizon, command_reward = generate_command(tgt_horizon,
                                                               tgt_reward_mean,
                                                               tgt_reward_std)

            experiment.log_metric("command_horizon", command_horizon, step=playing_step)
            experiment.log_metric("command_reward", command_reward, step=playing_step)
            while not done:
                episode['states'].append(state)
                state_ = augment_state(state,
                                       command=(command_horizon, command_reward),
                                       command_scale=(HORIZON_SCALE, RETURN_SCALE))
                state_ = torch.tensor(state_, dtype=torch.float)
                with torch.no_grad():
                    action_logprobs = policy(state_)
                action = Categorical(logits=action_logprobs).sample().item()
                state, reward, done, info = env.step(action)
                episode_reward += reward
                # Sparse lunar lander: 0 on every step, full return on the last one.
                sparse_reward = episode_reward if done else 0
                episode['actions'].append(action)
                episode['next_states'].append(state)
                episode['rewards'].append(sparse_reward)
                # Update the command with the time and (sparse) reward just consumed.
                command_horizon = max(1, command_horizon - 1)
                command_reward -= sparse_reward

            playing_step += 1
            experiment.log_metric("episode_reward", episode_reward, step=playing_step)

            # np.float / np.int were removed in NumPy 1.24; use the builtins.
            replay_buffer.add_episode(
                np.array(episode['states'], dtype=float),
                np.array(episode['actions'], dtype=int),
                np.array(episode['rewards'], dtype=float),
                np.array(episode['next_states'], dtype=float),
            )
except KeyboardInterrupt:
    # Interrupting the kernel is the intended way to stop training; any other
    # exception now propagates with its traceback instead of being swallowed.
    print("Terminated.")
finally:
    env.close()
    experiment.end()