<a href="https://colab.research.google.com/github/nathanwispinski/meta-rl/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# train.ipynb

This is a Google Colab notebook to demo model training of a recurrent neural network in a two-armed bandit task using reinforcement learning.

This is a single-threaded version of model training, so it may take a while depending on your training settings (most directly `total_training_steps`).

For more details, see the GitHub repository (https://github.com/nathanwispinski/meta-rl).

# Colab setup

In [None]:
#@title Clone GitHub repository.
!git clone https://github.com/nathanwispinski/meta-rl

In [None]:
#@title Change working directory to cloned repository (i.e., meta-rl/).
%cd meta-rl
%pwd

[Errno 2] No such file or directory: 'meta-rl'
/home/natha/meta-rl


'/home/natha/meta-rl'

In [None]:
# @title Install dependencies from `requirements.txt`.
!pip install -r requirements.txt

# Import dependencies

In [None]:
#@title Import dependencies after install.
import json
import os
import pickle
import time

import numpy as np

import modules.envs as envs
import modules.agents as agents
import modules.loggers as loggers

# Set config for training

In [None]:
#@title Import config for training.
# Build the default training config defined in `configs/bandit_config_train.py`,
# and keep a plain-JSON (dict) snapshot of it for display in the next cell.
from configs.bandit_config_train import get_config

config = get_config()
config_json_str = config.to_json_best_effort()
json_config = json.loads(config_json_str)

In [None]:
#@title Print loaded config.
# Displays the config exactly as loaded, i.e. before any overrides applied in
# the "Modify config" cell below.
json_config

In [None]:
#@title Modify config (optional).
#@markdown Add as many lines as needed in the code here.
overrides = {'random_seed': 100}
config.update(overrides)

# Re-serialize the live config (rather than reusing `json_config`) so the
# display reflects the overrides above.
json.loads(config.to_json_best_effort())

# Training setup

In [None]:
#@title Unpack config.
# Sub-configs handed to the environment and agent factories.
env_config, agent_config = config.environment, config.agent

# Top-level run settings.
random_seed = config.random_seed
# NOTE(review): `log_every_steps` is not referenced by the cells below.
log_every_steps = config.log_every_steps
params_filename = config.params_filename

# The training length lives in the agent sub-config.
total_training_steps = agent_config.total_training_steps

In [None]:
#@title Set random seed in NumPy.
# Seeds NumPy's legacy global RNG only. NOTE(review): any JAX randomness used by
# the agent must be seeded separately (presumably via agent_config) — confirm.
np.random.seed(random_seed)

In [None]:
#@title Initialize environment.
# Build the environment from its sub-config. `reset()` returns the first
# observation, which both the agent factory and the training loop use.
env = envs.create_env(env_config)
observation = env.reset()

In [None]:
#@title Initialize agent.
# Note: Jax might complain if there is no GPU/TPU found. You can run on a CPU,
# or go to Runtime -> Change Runtime Type in Colab to access a GPU or TPU.
# The first observation and the environment's action count are passed to the
# agent factory (presumably to size the network's input/output layers —
# confirm in modules/agents).
agent = agents.create_agent(
    agent_config=agent_config,
    observation=observation,
    num_actions=env.num_actions,
)



In [None]:
#@title Initialize performance logger.
# `log_to_console=True` presumably echoes log lines into this notebook's cell
# output — confirm in modules/loggers.
logger = loggers.create_logger(
    logger_name='bandit',
    config=config,
    log_to_console=True,
)

In [None]:
#@title Initialize LSTM recurrent state to zeros.
# Keep a separate handle on the initial state: the training loop resets
# `lstm_state` back to it whenever an episode ends.
lstm_state = initial_lstm_state = agent.get_initial_lstm_state()

# Training

In [None]:
#@title Main training loop (Note: this might take a while).
# Runs the agent for `total_training_steps` environment steps. Every transition
# is appended to the agent's buffer, and `agent.update` is called after every
# step; per the comment below, it updates parameters only when an episode ends
# or the buffer reaches max_unroll_steps.

start_time = time.perf_counter()
step, episode, loss = 0, 0, 0
while step < total_training_steps:

    # Get an action and step the environment with the agent's action.
    # Only the action and the next recurrent state are needed here.
    action, _, _, new_lstm_state, _ = agent.get_action(observation, lstm_state)
    next_observation, reward, done, info = env.step(action)

    # Save experience in a buffer. The stored LSTM state is the one the action
    # was computed from (i.e. the state *before* this step).
    agent.buffer.append(
        obs=observation,
        action=action,
        reward=reward,
        next_obs=next_observation,
        done=done,
        lstm_state=lstm_state,
    )

    observation = next_observation
    lstm_state = new_lstm_state

    # Log performance. `loss` is the value returned by the previous
    # `agent.update` call (initially 0), so it lags this step by one.
    logger.log_step(
        global_step=step,
        worker_step=step,
        reward=reward,
        info=info,
        loss=loss,
        entropy_coef=agent.e_loss_coef,
    )

    # Update agent parameters if an episode is done, or
    # if the agent experience buffer == max_unroll_steps.
    # Gradients and the step count returned alongside the loss are unused here.
    loss, _, _ = agent.update(done, update_params=True)
    step += 1

    # If done, reset the environment and LSTM state for the next episode.
    # (`done` needs no reset: it is reassigned by `env.step` every iteration.)
    if done:
        episode += 1
        lstm_state = initial_lstm_state
        observation = env.reset()

elapsed_minutes = (time.perf_counter() - start_time) / 60
print('Done training!')
print(f'{episode} episodes completed in {step} steps ({elapsed_minutes:.1f} minutes).')

In [None]:
#@title Save model after training is complete.
# Bundle the trained parameters with the full config so a run can be
# reconstructed from the saved file alone.
results = {
    "params": agent.params,
    "config": config.to_dict(),
}
save_path = params_filename + '.pickle'

# If `params_filename` includes a directory that does not exist yet, create it
# now; otherwise `open()` raises FileNotFoundError after training has finished.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

with open(save_path, 'wb') as fp:
    # Stored as a one-element list; keep this wrapper so the on-disk format is
    # unchanged for whatever code loads these files.
    pickle.dump([results], fp)
print("Saved parameters.")