<a href="https://colab.research.google.com/github/nathanwispinski/meta-rl/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# train.ipynb

This is a Google Colab notebook to demo model training of a recurrent neural network in a two-armed bandit task using reinforcement learning.

This is a single-threaded version of model training, and may take a while depending on your training settings (around 15 minutes in a CPU Colab instance in my experience).

For more details, see the GitHub repository (https://github.com/nathanwispinski/meta-rl).

# Colab setup

In [1]:
#@title Clone GitHub repository.
!git clone https://github.com/nathanwispinski/meta-rl

Cloning into 'meta-rl'...
remote: Enumerating objects: 54, done.[K
remote: Counting objects: 100% (54/54), done.[K
remote: Compressing objects: 100% (39/39), done.[K
remote: Total 54 (delta 20), reused 41 (delta 14), pack-reused 0[K
Unpacking objects: 100% (54/54), 143.75 KiB | 920.00 KiB/s, done.


In [2]:
#@title Change working directory to cloned repository (i.e., /content/meta-rl/).
%cd meta-rl

/content/meta-rl


'/content/meta-rl'

In [3]:
# @title Install dependencies from `requirements.txt`.
!pip install -r requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting absl_py==1.3.0
  Downloading absl_py-1.3.0-py3-none-any.whl (124 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.6/124.6 KB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting chex==0.1.5
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.3/85.3 KB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dm_haiku==0.0.9
  Downloading dm_haiku-0.0.9-py3-none-any.whl (352 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 KB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jax==0.3.25
  Downloading jax-0.3.25.tar.gz (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m64.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting jaxlib==0.3.25
  Downloading jax

# Import dependencies

In [4]:
#@title Import dependencies after install.
import json
import numpy as np
import pickle

import modules.envs as envs
import modules.agents as agents
import modules.loggers as loggers

# Set config for training

In [5]:
#@title Import config for training.
from configs.bandit_config_train import get_config

config = get_config()
json_config = json.loads(config.to_json_best_effort())

In [6]:
#@title Print loaded config.
json_config

{'phase': 'train',
 'path': './',
 'params_filename': 'train_test',
 'random_seed': 42,
 'num_workers': 5,
 'num_evaluators': 1,
 'eval_every_steps': 200000,
 'num_eval_episodes': 400,
 'log_every_steps': 20000,
 'environment': {'env_name': 'bandit',
  'steps_per_episode': 100,
  'reward_structure': 'correlated'},
 'eval_environment': {'steps_per_episode': 100,
  'reward_structure': 'correlated'},
 'agent': {'total_training_steps': 2000000,
  'random_seed': 42,
  'num_lstm_units': 48,
  'learning_rate_start': 0.0003,
  'learning_rate_end': 0.0,
  'gamma': 0.9,
  'v_loss_coef': 0.05,
  'e_loss_coef_start': 0.0,
  'e_loss_coef_end': 0.0,
  'e_loss_decay_factor': 3,
  'max_unroll_steps': 200,
  'global_norm_grad_clip': 50.0}}

In [7]:
#@title Modify config (optional).
#@markdown Add as many lines as needed in the code here.
config.update({'random_seed': 100})
config.update({'params_filename': 'my_colab_agent'})

# Print to see changes.
json.loads(config.to_json_best_effort())

{'phase': 'train',
 'path': './',
 'params_filename': 'train_test',
 'random_seed': 100,
 'num_workers': 5,
 'num_evaluators': 1,
 'eval_every_steps': 200000,
 'num_eval_episodes': 400,
 'log_every_steps': 20000,
 'environment': {'env_name': 'bandit',
  'steps_per_episode': 100,
  'reward_structure': 'correlated'},
 'eval_environment': {'steps_per_episode': 100,
  'reward_structure': 'correlated'},
 'agent': {'total_training_steps': 2000000,
  'random_seed': 42,
  'num_lstm_units': 48,
  'learning_rate_start': 0.0003,
  'learning_rate_end': 0.0,
  'gamma': 0.9,
  'v_loss_coef': 0.05,
  'e_loss_coef_start': 0.0,
  'e_loss_coef_end': 0.0,
  'e_loss_decay_factor': 3,
  'max_unroll_steps': 200,
  'global_norm_grad_clip': 50.0}}

# Training setup

In [8]:
#@title Unpack config.
env_config = config.environment
agent_config = config.agent
random_seed = config.random_seed
total_training_steps = config.agent.total_training_steps
log_every_steps = config.log_every_steps
params_filename = config.params_filename

In [9]:
#@title Set random seed in NumPy.
np.random.seed(random_seed)

In [10]:
#@title Initialize environment.
env = envs.create_env(env_config)
observation = env.reset()

In [11]:
#@title Initialize agent.
# Note: Jax might complain if there is no GPU/TPU found. You can run on a CPU,
# or go to Runtime -> Change Runtime Type in Colab to access a GPU or TPU.
agent = agents.create_agent(
    observation=observation,
    num_actions=env.num_actions,
    agent_config=agent_config)



In [12]:
#@title Initialize performance logger.
logger = loggers.create_logger(logger_name='bandit', config=config, log_to_console=True)

In [13]:
#@title Initialize LSTM recurrent state to zeros.
initial_lstm_state = agent.get_initial_lstm_state()
lstm_state = initial_lstm_state

# Training

In [14]:
#@title Main training loop (Note: this might take a while).

step, episode, loss = 0, 0, 0
while step < total_training_steps:

    # Get an action and step the environment with the agent's action
    action, _, v_out, new_lstm_state, _ = agent.get_action(observation, lstm_state)
    next_observation, reward, done, info = env.step(action)

    # Save experience in a buffer
    agent.buffer.append(
        obs=observation,
        action=action,
        reward=reward,
        next_obs=next_observation,
        done=done,
        lstm_state=lstm_state,
    )

    observation = next_observation
    lstm_state = new_lstm_state

    # Log performance
    logger.log_step(
        global_step=step,
        worker_step=step,
        reward=reward,
        info=info,
        loss=loss,
        entropy_coef=agent.e_loss_coef,
    )

    # Update agent parameters if an episode is done, or
    # if the agent experience buffer == max_unroll_steps
    loss, grads, num_steps = agent.update(done, update_params=True)
    step += 1

    # If done, reset the environment and LSTM state
    if done:
        episode += 1
        done = False
        lstm_state = initial_lstm_state
        observation = env.reset()

print('Done training!')

Global step:	0	| Worker step:	0	| T:	151.50	| Mean Reward:	0.0001	| Entropy coef:	0.0000	| Loss:	0.00000	|
Global step:	20000	| Worker step:	20000	| T:	6.80	| Mean Reward:	0.5082	| Entropy coef:	0.0000	| Loss:	0.24415	|
Global step:	40000	| Worker step:	40000	| T:	7.30	| Mean Reward:	0.5055	| Entropy coef:	0.0000	| Loss:	0.08747	|
Global step:	60000	| Worker step:	60000	| T:	5.37	| Mean Reward:	0.4942	| Entropy coef:	0.0000	| Loss:	0.08729	|
Global step:	80000	| Worker step:	80000	| T:	7.30	| Mean Reward:	0.5107	| Entropy coef:	0.0000	| Loss:	0.05006	|
Global step:	100000	| Worker step:	100000	| T:	5.40	| Mean Reward:	0.5024	| Entropy coef:	0.0000	| Loss:	-0.00209	|
Global step:	120000	| Worker step:	120000	| T:	6.57	| Mean Reward:	0.5116	| Entropy coef:	0.0000	| Loss:	0.13439	|
Global step:	140000	| Worker step:	140000	| T:	6.17	| Mean Reward:	0.5090	| Entropy coef:	0.0000	| Loss:	-0.01373	|
Global step:	160000	| Worker step:	160000	| T:	5.55	| Mean Reward:	0.5196	| Entropy coef:	0.00

In [15]:
#@title Save model after training is complete.
results = {
    "params": agent.params,
    "config": config.to_dict(),
}
with open(params_filename + '.pickle', 'wb') as fp:
    pickle.dump([results], fp)
print("Saved parameters.")

Saved parameters.
