<a href="https://colab.research.google.com/github/mauriciogtec/llm_policies_for_text_based_rl_tutorial/blob/main/llm_policies_for_text_based_rl_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# MIT License
#
# @title Copyright (c) 2024 TAFM Workshop Authors (c) 2021 CCAI Community Authors { display-mode: "form" }
#
# Template modified from
# https://colab.research.google.com/drive/1_RPUB26HWVk7SuH3OsA5FnyLx3ThaFV-?usp=sharing
# Copyright (c) 2021 CCAI Community Authors
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# LLM Policies for Text-based Reinforcement Learning: An Interactive Tutorial

Author

[**Mauricio Tec**](https://mauriciogtec.com), *Harvard University*, [`mauriciogtec@hsph.harvard.edu`](mailto:mauriciogtec@hsph.harvard.edu)

**TL;DR**. This interactive tutorial uses LLMs as policies for text-based reinforcement learning, covering key topics such as quantization, low-rank adaptation, fine-tuning with expert demonstrations, and reinforcement learning via proximal policy optimization.

# Table of Contents


*   [Overview](#overview)
*   [Background & Prerequisites](#background-and-prereqs)
*   [Software Requirements](#software-requirements)
*   [Foundation Model and Task Description](#fm-task-description)
*   [Methodology](#methodology)
*   [Experiments & Discussion](#experiments-and-discussion)
*   [References](#references)


<a name="overview"></a>
# Overview


**Summary**. This interactive notebook demonstrates how to train a reinforcement-learning agent in text-based environments using an LLM-parameterized policy. The ability to understand and generate natural language is essential for AI systems that need to interact with humans in a variety of domains, including customer service, healthcare, and education. By training reinforcement-learning agents in text-based environments, we can develop AI systems that better understand and respond to human language, leading to more effective and versatile AI applications. The challenges faced by text-based agents are shared with other unstructured-data environments, such as images and video. Studying LLM-based agents therefore has implications beyond text, while taking advantage of highly capable text-based foundation models (LLMs) that can be fine-tuned on standard GPUs.

In this tutorial, we will specifically train an agent interacting with the [TextWorld (Côté et al. 2019)](https://arxiv.org/abs/1806.11532) environment, which was used in the [2019 TextWorld competition](https://www.microsoft.com/en-us/research/project/textworld/competition/) organized by Microsoft Research. Other text-based environments could be considered, such as the classic [Zork game (Anderson et al., 1977)](https://en.wikipedia.org/wiki/Zork) or the Jericho suite of games [(Hausknecht et al., 2019)](https://arxiv.org/pdf/1909.05398). We choose TextWorld, through its fast `textworld-express` package, because it is fast and highly parameterizable. We will use an easy-to-solve configuration for the purpose of this tutorial.

**Target Audience**. This tutorial is designed for researchers familiar with reinforcement learning but with possibly limited hands-on experience in training and fine-tuning LLMs.

**Learning Objective**. The tutorial covers key topics such as (1) parameterizing a policy using an LLM for text generation, (2) supervised fine-tuning with expert demonstrations, and (3) reinforcement learning with proximal policy optimization. The tutorial highlights strategies for efficient computation and memory management, including quantization and low-rank adaptation, which are crucial for scenarios with limited computational resources.

### Contributions

* It lowers the barrier to entry for RL researchers who want to work with LLMs.
* It highlights the key design choices and challenges in this area.
* It provides useful starter code.

The structure of training a text-based agent bears similarities to reinforcement learning from human feedback (RLHF). However, in contrast to RLHF, the environment dynamics and the reward function are determined externally by the environment [(Carta et al. 2023)](https://arxiv.org/pdf/2302.02662v3). This difference requires an outer loop that controls experience collection: the agent must repeatedly play games with its current policy before each policy update.
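The outer loop can be sketched as follows. This is a minimal, framework-agnostic sketch: `collect_rollouts` and `update_policy` are hypothetical callables standing in for the roles that `rollout_step` and `training_step` play later in this notebook.

```python
from typing import Callable, List


def train_outer_loop(
    collect_rollouts: Callable[[int], List[dict]],
    update_policy: Callable[[List[dict]], float],
    num_iterations: int,
) -> List[float]:
    """Alternate between playing games and updating the policy.

    collect_rollouts(i) (hypothetical) plays games with the current policy and
    returns transition dicts; update_policy (hypothetical) optimizes the policy
    on them and returns a scalar loss.
    """
    losses = []
    for i in range(num_iterations):
        # 1) Experience comes from the environment: the game engine, not a
        #    fixed prompt dataset or a learned reward model, decides what
        #    happens next and how much reward is given.
        transitions = collect_rollouts(i)
        # 2) Optimize the LLM policy on the freshly collected experience.
        losses.append(update_policy(transitions))
    return losses


# Toy usage with stand-in callables
losses = train_outer_loop(
    collect_rollouts=lambda i: [{"reward": float(i)}],
    update_policy=lambda trs: sum(t["reward"] for t in trs),
    num_iterations=3,
)
print(losses)  # [0.0, 1.0, 2.0]
```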


<a name="background-and-prereqs"></a>
# Background & Prerequisites



### Related Work

### Background

### Other References

<a name="software-requirements"></a>
# Software Requirements

This notebook requires Python >= 3.10

In [None]:
# print python version
import sys
print(sys.version)

3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]


Install dependencies. Pip may report a dependency-resolution error; it can be safely ignored, as it is caused by the latest version of torchrl.

In [None]:
# %pip install -q \
#     "torch==2.2.1" \
#     "torchaudio==2.2.1" \
#     "flash-attn==2.5.8" \
#     "textworld-express==1.0.4" \
#     "gymnasium==0.29.1" \
#     "transformers==4.40.1" \
#     "accelerate==0.29.3" \
#     "peft==0.10.0" \
#     "bitsandbytes==0.43.1" \
#     "tqdm==4.66.2" \
#     "IProgress==0.4" \
#     "matplotlib==3.7.1" \
#     "accelerate" \
#     "trl"

%pip install -q \
  "torch" \
  "torchaudio" \
  "textworld-express==1.0.4" \
  "gymnasium" \
  "tqdm==4.66.2" \
  "IProgress==0.4" \
  "matplotlib==3.7.1" \
  "git+https://github.com/huggingface/trl" \
  "git+https://github.com/huggingface/transformers" \
  "peft" \
  "accelerate" \
  "bitsandbytes" \
  "ipdb" # for debugging

# "optimum"
# "auto-gptq"

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


### Auxiliary functions and libraries

In [None]:
import torch
from tqdm.auto import tqdm
import numpy as np
import transformers
from datasets import Dataset, load_dataset
from trl import (
  ModelConfig,
  SFTTrainer,
  get_kbit_device_map,
  get_peft_config,
  get_quantization_config,
  SFTConfig,
  setup_chat_format,
)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig, DataCollatorForSeq2Seq
from peft import LoraConfig, TaskType, get_peft_model
from IPython.display import Markdown, display
import time
from textworld_express import TextWorldExpressEnv
from torch.utils.data import DataLoader



def printmd(string):
    display(Markdown(string))


<a name="fm-task-description"></a>
# Foundation Model and Task Description

We use a small open-source causal language model, a Falcon-RW-1B checkpoint (`Rocketknight1/falcon-rw-1b` on the Hugging Face Hub, about 1.3B parameters). It is loaded in 4-bit precision and fine-tuned with LoRA adapters so that training fits on a free Colab T4 GPU. The task is the CookingWorld game from TextWorldExpress: at every step the agent receives a textual observation and a list of valid actions, and must output exactly one action as text.

### The TextWorld RL Environment

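The goal is to solve a cooking task in the CookingWorld environment: as the gold action sequence below shows, the agent must read the cookbook, collect and prepare the listed ingredients, and then prepare and eat the meal.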

In [None]:
STEP_LIMIT = 16

def make_env(
    game_name="cookingworld",
    step_limit=STEP_LIMIT,
    numLocations=3,
    includeDoors=1,
    numDistractorItems=2,
    numIngredients=2,
    **kwargs
):
  """Make a TextWorldExpressEnv RL environment."""

  # initialize game generator
  env = TextWorldExpressEnv(envStepLimit=step_limit)

  # set default game params, then override them with any extra kwargs
  game_params = {
      "numLocations": numLocations,
      "includeDoors": includeDoors,
      "numDistractorItems": numDistractorItems,
      "numIngredients": numIngredients,
  }
  game_params.update(kwargs)
  # TextWorldExpress expects a comma-separated "key=value" string
  game_params = ",".join([f"{k}={v}" for k, v in game_params.items()])

  # load game
  env.load(
      gameName=game_name,
      gameParams=game_params,
  )

  return env

# test
env = make_env()
obs, info = env.reset(seed=123456, gameFold="train")

for key, value in info.items():
  print(f"**{key}**: {value}")

**observation**: You are in the kitchen. In one part of the room you see a stove. There is also an oven. You also see a fridge that is closed. In another part of the room you see a counter that has a red apple, and a cookbook on it. In one part of the room you see a dining table, that has nothing on it. There is also a cutlery drawer that is closed. You also see a trash can that is closed. In another part of the room you see a dishwasher that is closed. In one part of the room you see a dining chair, that has nothing on it. 
To the North you see a closed sliding patio door. To the South you see a closed frosted-glass door. 
**look**: You are in the kitchen. In one part of the room you see a stove. There is also an oven. You also see a fridge that is closed. In another part of the room you see a counter that has a red apple, and a cookbook on it. In one part of the room you see a dining table, that has nothing on it. There is also a cutlery drawer that is closed. You also see a trash ca
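The game is configured through a single comma-separated `key=value` string (`gameParams`). The sketch below reproduces that formatting with a hypothetical helper, `format_game_params`, so you can see what a different configuration (e.g. more locations) would look like. It only uses the four parameter names that appear in `make_env`; which other names CookingWorld accepts is not covered here.

```python
def format_game_params(defaults: dict, **overrides) -> str:
    """Hypothetical helper: merge overrides into the defaults and build the
    comma-separated "key=value" string passed as `gameParams`."""
    params = {**defaults, **overrides}  # overrides keep the defaults' key order
    return ",".join(f"{k}={v}" for k, v in params.items())


defaults = {"numLocations": 3, "includeDoors": 1,
            "numDistractorItems": 2, "numIngredients": 2}
print(format_game_params(defaults, numLocations=5))
# numLocations=5,includeDoors=1,numDistractorItems=2,numIngredients=2
```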

## Collect expert dataset for SFT

In [None]:
obs, info = env.reset(
  seed=123456, gameFold="train", generateGoldPath=True
)
env.getGoldActionSequence()

['look around',
 'take cookbook',
 'read cookbook',
 'open cutlery drawer',
 'take knife',
 'take red apple',
 'open fridge',
 'take yellow bell pepper',
 'close fridge',
 'chop yellow bell pepper',
 'slice red apple',
 'cook yellow bell pepper in stove',
 'prepare meal',
 'eat meal']

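Collect expert demonstrations for supervised fine-tuning: for many differently seeded games, we follow the gold action sequence returned by `env.getGoldActionSequence()` (with `generateGoldPath=True`) and record, at each step, the prompt the LLM would see together with the expert action.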

In [None]:
INSTRUCTION_TEMPLATE = (
  "# TASK: {}\n"
  "Choose exactly one valid action."
)

GAME_MSG_TEMPLATE = "\n# STEP {} \nObs: {}"

def expert_rollout(
  env, max_history=0, seed=None, gameFold="train"
) -> list[dict]:
  "Roll out a game following the expert (gold-path) actions and record the transitions."

  # reset the environment and obtain expert (goldPath) sequence
  obs, info = env.reset(seed=seed, gameFold=gameFold, generateGoldPath=True)
  expert_path = env.getGoldActionSequence()

  # obtain instruction from game
  instr = INSTRUCTION_TEMPLATE.format(info["taskDescription"])

  # create buffers
  messages = [instr]

  # rollout loop
  done = False
  step = 0
  transitions = []
  while not done:
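    # format observation
    game_msg = GAME_MSG_TEMPLATE.format(step, obs)

    # make prompt for the LLM: keep the last `max_history` messages
    # (messages[-0:] would keep everything, hence the explicit check)
    valid = ", ".join(info["validActions"])
    question = f"\nValid actions: {valid}\nAction: "
    history = messages[-max_history:] if max_history > 0 else []
    prompt_msgs = history + [game_msg] + [question]
    # re-insert the task instruction (messages[0]) once it leaves the window;
    # len(messages) == step + 1, so this happens when step >= max_history
    if step >= max_history:
        prompt_msgs.insert(0, messages[0])
    prompt = ''.join(prompt_msgs)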

    # get expert action
    expert_action = expert_path.pop(0)

    # step the environment
    obs, reward, done, info = env.step(expert_action)

    # add to transitions
    transitions.append(
      {"prompt": prompt, "action": expert_action, "reward": reward, "done": done}
    )
    messages.append(game_msg + "\nAction: " + expert_action)
    step += 1

  return transitions


# outer loop / create dataset
rollouts = []
num_rollouts = 900
max_history = 11
transitions = []
for i in tqdm(range(num_rollouts)):
  transitions = expert_rollout(env, seed=123456 * i, max_history=max_history)
  rollouts.extend(transitions)
dataset = Dataset.from_list(rollouts)

# get train/eval splits (sampled over transitions, so both splits may
# contain steps from the same games)
num_prompts = len(rollouts)
num_train = 500
num_eval = 100
sample = np.random.choice(num_prompts, num_train + num_eval, replace=False)
train_sample = sample[:num_train]
eval_sample = sample[num_train:]
train_dataset = dataset.select(train_sample)
eval_dataset = dataset.select(eval_sample)

  0%|          | 0/900 [00:00<?, ?it/s]
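The prompt construction inside `expert_rollout` combines a sliding window over past turns with the task instruction. The sketch below isolates that logic in a hypothetical, standalone helper `build_prompt`, so the windowing can be checked without an environment. It assumes the same layout as above: `messages[0]` is the instruction and later entries are past "observation + action" turns.

```python
def build_prompt(messages, game_msg, valid_actions, max_history):
    """Hypothetical helper reproducing the prompt layout of `expert_rollout`.

    messages[0] is the task instruction and messages[1:] are past
    "observation + action" turns, oldest first.
    """
    # Keep only the last `max_history` entries; note that messages[-0:]
    # would return the whole list, so max_history=0 is handled explicitly.
    history = messages[-max_history:] if max_history > 0 else []
    # Once the list outgrows the window, the instruction is no longer in it,
    # so put it back at the front.
    if len(messages) > max_history:
        history = [messages[0]] + history
    question = f"\nValid actions: {', '.join(valid_actions)}\nAction: "
    return "".join(history + [game_msg, question])


# Toy usage: instruction "I" plus three past turns, window of two turns
print(repr(build_prompt(["I", "a", "b", "c"], "G", ["x", "y"], max_history=2)))
```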

In [None]:
printmd("### Prompt")
print(train_dataset["prompt"][1])
printmd("### Expert action")
print(train_dataset["action"][1])

### Prompt

# TASK: You are hungry! Let's cook a delicious meal. Check the cookbook in the kitchen for the recipe. Once done, enjoy your meal!
Choose exactly one valid action.
# STEP 0 
Obs: You are in the kitchen. In one part of the room you see a stove. There is also an oven. You also see a fridge that is closed. In another part of the room you see a counter that has a knife, a banana, and a cookbook on it. In one part of the room you see a dining table, that has nothing on it. There is also a cutlery drawer that is closed. You also see a trash can that is closed. In another part of the room you see a dishwasher that is closed. In one part of the room you see a dining chair, that has nothing on it. 
To the North you see a closed patio door. To the South you see a closed plain door. 
Action: look around
# STEP 1 
Obs: You are in the kitchen. In one part of the room you see a stove. There is also an oven. You also see a fridge that is closed. In another part of the room you see a counter that has 

### Expert action

take banana


## Supervised Fine-Tuning

In [None]:
model_name = "Rocketknight1/falcon-rw-1b"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_fast=True)
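tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
tokenizer.padding_side = 'left'
tokenizer.truncation_side = 'left'  # when truncating, drop the oldest context and keep the question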
tokenizer.pad_token_id

50257

In [None]:
MAX_SEQ_LEN = 2048  # ensures things will run on the free T4 GPU

def preprocess_function(examples):
    inputs = examples['prompt']
    targets = examples['action']
    targets = [x + tokenizer.eos_token for x in targets]

    tok_inputs = tokenizer(
        inputs,
        max_length=MAX_SEQ_LEN,
        truncation=True,
    )
    tok_targets = tokenizer(targets)
    combined_tokens = [
        x + y
        for x, y in zip(tok_inputs.input_ids, tok_targets.input_ids)
    ]
    # shift the attention mask the same way as the inputs below
    attention_mask = [
        (x + y)[:-1]
        for x, y in zip(tok_inputs.attention_mask, tok_targets.attention_mask)
    ]
    target_lengths = [len(x) for x in tok_targets['input_ids']]

    # make input/target pairs for next token prediction
    input_ids = [x[:-1] for x in combined_tokens]
    labels = [x[1:] for x in combined_tokens]

    # also keep the target lengths (used to restrict the loss) and the reward
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'target_lengths': target_lengths,
        'reward': examples['reward'],
    }

tokenized_datasets = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names
)

# Convert lists to tensors using DataCollator
data_collator = DataCollatorForSeq2Seq(tokenizer)

# Data loader
train_dataloader = torch.utils.data.DataLoader(
    tokenized_datasets,
    batch_size=1,
    shuffle=True,
    collate_fn=data_collator,
)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [None]:
for k, v in next(iter(train_dataloader)).items():
  print(f"{k}: {v.shape}")

reward: torch.Size([1])
input_ids: torch.Size([1, 724])
target_lengths: torch.Size([1])
attention_mask: torch.Size([1, 724])
labels: torch.Size([1, 724])
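The shift-by-one construction in `preprocess_function` is easy to get wrong, so here it is on toy token ids (the ids are made up for illustration). After shifting, the last `target_lengths` labels are exactly the action tokens, and position `-target_length` of the inputs is the last prompt token. This is why slicing the last `L` logits trains only on the action, and why the RL code later reads the critic's value at index `-target_length`.

```python
def make_shifted_pair(prompt_ids, target_ids):
    """Build (input_ids, labels) for next-token prediction, as preprocess_function does."""
    combined = prompt_ids + target_ids
    return combined[:-1], combined[1:]


prompt_ids = [11, 12, 13]  # toy prompt token ids (made up)
target_ids = [21, 22, 2]   # toy action token ids; the last one plays the role of EOS
input_ids, labels = make_shifted_pair(prompt_ids, target_ids)
L = len(target_ids)

print(input_ids)      # [11, 12, 13, 21, 22]
print(labels)         # [12, 13, 21, 22, 2]
print(labels[-L:])    # [21, 22, 2]  (exactly the action tokens)
print(input_ids[-L])  # 13           (the last prompt token)
```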


In [None]:
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler

# dtype = torch.bfloat16  # not natively supported on the T4 (Turing) GPU, so it is very slow there
dtype = torch.float16

quantization_config = BitsAndBytesConfig(
  load_in_4bit=True,
  bnb_4bit_compute_dtype=dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    use_cache = False
)

peft_config = LoraConfig(
  task_type=TaskType.CAUSAL_LM,
  r=4,
  lora_alpha=32,
  target_modules="all-linear",
  lora_dropout=0.25,
  bias="none",
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 3,145,728 || all params: 1,314,770,944 || trainable%: 0.2393
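The trainable-parameter count printed above can be reproduced by hand. LoRA freezes each targeted weight $W \in \mathbb{R}^{d_\text{out} \times d_\text{in}}$ and learns a low-rank update $\frac{\alpha}{r} BA$ with $A \in \mathbb{R}^{r \times d_\text{in}}$ and $B \in \mathbb{R}^{d_\text{out} \times r}$, i.e. $r(d_\text{in} + d_\text{out})$ extra parameters per layer (here $r=4$, $\alpha=32$). The sketch assumes the Falcon-RW-1B layer shapes listed in its comments; these are assumptions about the checkpoint rather than values read from the notebook, but they reproduce the printed 3,145,728. It also gives a rough weight-memory estimate for 16-bit vs 4-bit storage, which ignores quantization constants and modules that stay in 16-bit (such as embeddings), so real usage is higher.

```python
# Assumed Falcon-RW-1B shapes (not read from the notebook): 24 decoder blocks,
# hidden size 2048, a fused query/key/value projection of width 3 * 2048,
# and an MLP of width 4 * 2048.
HIDDEN, LAYERS = 2048, 24
LINEAR_SHAPES = {  # assumed module name: (in_features, out_features)
    "query_key_value": (HIDDEN, 3 * HIDDEN),
    "dense": (HIDDEN, HIDDEN),
    "dense_h_to_4h": (HIDDEN, 4 * HIDDEN),
    "dense_4h_to_h": (4 * HIDDEN, HIDDEN),
}


def lora_params(d_in, d_out, r):
    """Extra parameters of one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


r = 4  # matches LoraConfig(r=4) above
trainable = LAYERS * sum(lora_params(d_in, d_out, r) for d_in, d_out in LINEAR_SHAPES.values())
print(f"LoRA trainable params: {trainable:,}")  # LoRA trainable params: 3,145,728

# Rough weight memory for the printed total parameter count
total_params = 1_314_770_944
for label, bytes_per_param in [("16-bit", 2), ("4-bit", 0.5)]:
    print(f"{label}: {total_params * bytes_per_param / 2**30:.2f} GiB")
```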


Before fine-tuning, test the model's generation on a prompt from the evaluation split.

In [None]:
def generate_answer(prompt, model, tokenizer, sample=False):
    """Generate a short action for `prompt` (greedy decoding unless `sample=True`)."""
    model.eval()

    # Tokenize the prompt
    tokens = tokenizer(prompt, return_tensors="pt", truncation=True)
    tokens = tokens.to(model.device)
    input_len = tokens.input_ids.shape[1]

    # Generate the output
    output = model.generate(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        do_sample=sample,
        num_return_sequences=1,
        max_new_tokens=15,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode the generated answer
    answer = output[0][input_len:]
    # strip whitespace so the text can match a valid action exactly
    answer = tokenizer.decode(answer, skip_special_tokens=True).strip()

    return answer

prompt = eval_dataset[1]['prompt']
printmd("### Prompt")
print(prompt)

answer = generate_answer(prompt, model, tokenizer)
printmd("### Answer without finetuning")
print(answer)

### Prompt

# TASK: You are hungry! Let's cook a delicious meal. Check the cookbook in the kitchen for the recipe. Once done, enjoy your meal!
Choose exactly one valid action.
# STEP 0 
Obs: You are in the kitchen. In one part of the room you see a stove. There is also an oven. You also see a fridge that is closed. In another part of the room you see a counter that has a knife, and a cookbook on it. In one part of the room you see a dining table, that has nothing on it. There is also a cutlery drawer that is closed. You also see a trash can that is closed. In another part of the room you see a dishwasher that is closed. In one part of the room you see a dining chair, that has nothing on it. 
To the North you see a closed sliding patio door. To the East you see a closed plain door. 
Action: look around
# STEP 1 
Obs: You are in the kitchen. In one part of the room you see a stove. There is also an oven. You also see a fridge that is closed. In another part of the room you see a counter that has a k

### Answer without finetuning


Valid actions: look around, look at knife, look at counter,


The next cell takes about 10 minutes for 3 epochs; by default it runs a single epoch (`num_epochs = 1`).

In [None]:
from collections import defaultdict
from torch.nn import functional as F
import numpy as np
from accelerate import Accelerator
import gc

# Num epochs
num_epochs = 1
gradient_accumulation_steps = 8

# Optimizer (8 bit variant)
opt = bnb.optim.PagedAdamW8bit(model.parameters(), lr=5e-5, weight_decay=0.01)

# Scheduler
num_training_steps = len(train_dataloader) * num_epochs
scheduler = get_scheduler(
    name="linear",
    optimizer=opt,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

# Loss fn
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# Setup the accelerator to handle devices and gradient accumulation
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, opt, train_dataloader, scheduler = accelerator.prepare(
  model, opt, train_dataloader, scheduler
)

# Training step function
def train_step(batch, model, opt, scheduler, accelerator):
    with accelerator.accumulate(model):
        # Forward pass
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask']
        )

        # Compute the cross-entropy loss only on the completion (action)
        # tokens, not on the prompt.
        losses = []
        batch_size = len(batch['input_ids'])
        for b in range(batch_size):
            L = batch['target_lengths'][b]
            loss = loss_fn(outputs.logits[b][-L:], batch['labels'][b][-L:])
            losses.append(loss)
        loss = sum(losses) / len(losses)

        # Backward pass
        accelerator.backward(loss)
        opt.step()
        scheduler.step()
        opt.zero_grad()

        return loss.item()

# Training loop
for epoch in range(num_epochs):
    printmd(f"### Epoch {epoch}")

    model.train()
    epoch_losses = []

    for batch in tqdm(train_dataloader, total=len(train_dataloader)):
        loss = train_step(batch, model, opt, scheduler, accelerator)
        epoch_losses.append(loss)

    printmd(f"**Loss**: {np.mean(epoch_losses):.4f}")


print("Training completed")
model.save_pretrained("ckpt_sft.pt")

# Clean up
torch.cuda.empty_cache()
gc.collect()

### Epoch 0

  0%|          | 0/500 [00:00<?, ?it/s]

**Loss**: 0.9739

Training completed


265
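`train_step` restricts the loss to the completion by slicing the last `target_lengths` positions. A common alternative is to set the prompt labels to -100, which `torch.nn.CrossEntropyLoss` ignores by default (`ignore_index=-100`). The NumPy sketch below, on random toy logits, checks that the two choices give the same loss for a single sequence. With several sequences per batch they differ slightly, because `train_step` averages per-sequence means while masking averages over all completion tokens.

```python
import numpy as np


def cross_entropy(logits, labels, ignore_index=-100):
    """Mean token-level cross-entropy, skipping positions whose label is ignore_index."""
    z = logits - logits.max(axis=-1, keepdims=True)           # numerical stability
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # log-softmax
    keep = labels != ignore_index
    rows = np.arange(len(labels))[keep]
    return -logp[rows, labels[keep]].mean()


rng = np.random.default_rng(0)
T, V, L = 7, 11, 3                   # sequence length, vocabulary size, action length
logits = rng.normal(size=(T, V))     # toy logits for one sequence
labels = rng.integers(0, V, size=T)  # toy next-token labels

# Option 1 (what train_step does): keep only the last L positions
loss_slice = cross_entropy(logits[-L:], labels[-L:])

# Option 2: label every prompt position with -100 and use the whole sequence
masked = labels.copy()
masked[:-L] = -100
loss_mask = cross_entropy(logits, masked)

print(np.isclose(loss_slice, loss_mask))  # True
```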

In [None]:
prompt = train_dataset[0]['prompt']
printmd("### Prompt")
print(prompt)

answer = generate_answer(prompt, model, tokenizer)
printmd("### Answer with finetuning")
print(answer)

### Prompt

# TASK: You are hungry! Let's cook a delicious meal. Check the cookbook in the kitchen for the recipe. Once done, enjoy your meal!
Choose exactly one valid action.
# STEP 4 
Obs: You open the cutlery drawer. The cutlery drawer contains a knife.
Action: take knife
# STEP 5 
Obs: You take the knife.
Action: take red apple
# STEP 6 
Obs: You take the red apple.
Action: open fridge
# STEP 7 
Obs: You open the fridge. It's empty inside.
Action: close fridge
# STEP 8 
Obs: You close the fridge.
Action: open kitchen cupboard
# STEP 9 
Obs: You open the kitchen cupboard. It's empty inside.
Action: close kitchen cupboard
# STEP 10 
Obs: You close the kitchen cupboard.
Action: open door to north
# STEP 11 
Obs: You open the plain door, revealing the pantry. 
Action: move north
# STEP 12 
Obs: You are in the pantry. In one part of the room you see a folding chair, that has nothing on it. There is also a shelf that has some salt on it. 
Through an open plain door, to the South you see the kitchen.

### Answer with finetuning

move south


The training loop works: after fine-tuning, the 1B-parameter LLM generates a sensible action.

# Reinforcement Learning

<a name="methodology"></a>
# Methodology


### This section's content:

1. [Initial evaluation of pure LLM agents](#initial-eval)
   - 1.1. [Prompt design and test](#prompt-design)
   - 1.2. [Rollout evaluation](#rollout-eval)
2. [Improvement via RL](#rl-improvement)
   - 2.1. [Policy/Value networks and LoRA](#arch)
   - 2.2. [Gym wrapper](#gym-wrapper)
   - 2.3. [PPO Training](#ppo)
3. Combining RL with RAG

<a name="initial-eval"></a>
## 1. Initial evaluation of pure LLM agents


<a name="prompt-design"></a>
### 1.1 Simple prompt design


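This section is deprecated: the prompt design is now defined above, in `INSTRUCTION_TEMPLATE`, `GAME_MSG_TEMPLATE`, and the prompt construction inside `expert_rollout`.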

<!-- Now test using an LLM to select one action. We need to **create a prompt** to do so. The ingredients for the prompt are:
1. The overall instruction.
2. A function `get_state` to transform the agent's past history and actions into a textual representation. (We will initially set this to concatenation.)
3. A function `make_prompt` that takes the state and the environment's allowed actions to create a prompt for the LLM.

We provide an initial basic implementation of these functions that will be refined throughout the tutorial. -->

<a name="rollout-eval"></a>
### 1.2 Rollout Evaluation

We will write a quick script for evaluating the models on a number of games.
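A minimal sketch of such an evaluation script, assuming a hypothetical `evaluate` helper that takes any function mapping a seed to a list of transition dicts (the format returned by `expert_rollout` and by `model_rollout` defined below) and reports simple return statistics.

```python
from statistics import mean
from typing import Callable, Iterable


def evaluate(rollout_fn: Callable[[int], list], seeds: Iterable[int]) -> dict:
    """Play one game per seed and summarize the episode returns (hypothetical helper).

    rollout_fn(seed) must return a list of transition dicts with a "reward" key,
    the format produced by expert_rollout and model_rollout in this notebook.
    """
    returns, lengths = [], []
    for seed in seeds:
        transitions = rollout_fn(seed)
        returns.append(sum(t["reward"] for t in transitions))
        lengths.append(len(transitions))
    return {
        "mean_return": mean(returns),
        "max_return": max(returns),
        "mean_length": mean(lengths),
    }


# Toy usage with a fake two-step game whose final reward equals the seed
fake_rollout = lambda s: [{"reward": 0.5, "done": False}, {"reward": float(s), "done": True}]
print(evaluate(fake_rollout, seeds=[0, 1, 2]))
```

For example, `evaluate(lambda s: model_rollout(env, model, tokenizer, max_history=11, seed=s, gameFold="test"), range(20))` would score the fine-tuned model on 20 test games; the number of games and `max_history` here are arbitrary choices.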

<a name="rl-improvement"></a>
## 2. RL Improvement

We will now use PPO and LoRA to improve the underperforming LLM.

<!-- We will need the following ingredients:

1. A definition of policy and value networks.
2. A wrapper of the TextWorld environment for compatibility with the `gymnasium` and `stable-baselines` frameworks. -->
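For reference, the quantities computed in `rollout_step` and `training_step` below follow standard PPO with generalized advantage estimation (GAE). As a summary rather than a derivation: $\epsilon$ is `clip_ratio`, $c$ is `clip_value`, $\gamma$ is `gamma`, $\lambda$ is `gae_lambda`, $d_t$ is the `done` flag, $r_t$ is the (optionally whitened) reward, and $G_t$ is the discounted return `disc_return`. The log-probability of an action is the sum of the log-probabilities of its tokens $y_{t,1}, \dots, y_{t,L_t}$.

$$
\log \pi_\theta(a_t \mid s_t) = \sum_{k=1}^{L_t} \log \pi_\theta(y_{t,k} \mid s_t, y_{t,<k}),
\qquad
\rho_t = \exp\!\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid s_t)\big)
$$

$$
\delta_t = r_t + \gamma (1 - d_t) V(s_{t+1}) - V(s_t),
\qquad
\hat{A}_t = \delta_t + \gamma \lambda (1 - d_t) \hat{A}_{t+1}
$$

$$
\mathcal{L}_{\text{policy}} = \mathbb{E}_t\Big[\max\big(-\rho_t \hat{A}_t,\; -\operatorname{clip}(\rho_t, 1 - \epsilon, 1 + \epsilon)\, \hat{A}_t\big)\Big]
$$

$$
\mathcal{L}_{\text{value}} = \tfrac{1}{2}\, \mathbb{E}_t\Big[\max\big((V_\theta(s_t) - G_t)^2,\; (\operatorname{clip}(V_\theta(s_t), V_{\text{old}}(s_t) - c, V_{\text{old}}(s_t) + c) - G_t)^2\big)\Big]
$$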

<a name="arch"></a>
### 2.1 Policy and Value Networks and LoRA

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class ActorCritic(nn.Module):
  """Agent class modeled after trl.AutoModelForCausalLMWithValueHead: a causal LM plus a linear value head."""
  transformers_parent_class = AutoModelForCausalLM

  def __init__(self, model, tokenizer):
    super().__init__()
    self.model = model
    self.tokenizer = tokenizer

    # value function related stuff
    self.hidden_size = self.model.config.hidden_size

    # define and initialize critic head
    self.v_head = nn.Linear(
      self.hidden_size, 1, dtype=self.model.dtype
    ).to(self.model.device)

  def forward(self, input_ids, attention_mask):
    outputs = self.model(
      input_ids=input_ids,
      attention_mask=attention_mask,
      output_hidden_states=True,
    )

    # hidden_states[-1] is the final layer's hidden state for *every* token.
    # The value head is applied per token, and the caller selects the position
    # of interest (the last prompt token, right before the action).
    last_hidden_state = outputs.hidden_states[-1]
    lm_logits = outputs.logits
    loss = outputs.loss  # None here, since no labels are passed
    value = self.v_head(last_hidden_state).squeeze(-1)

    return (lm_logits, loss, value)

We will now perform the main RL training loop. Expect roughly 5 minutes per epoch, i.e. about 100 minutes for the 20 epochs configured below; reduce `epochs` for a quicker demonstration. This run is for demonstration only: more training is needed for good results.
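The advantage computation inside `rollout_step` can be isolated into a small NumPy sketch. Because rollouts from several games are concatenated, the recursion must not bootstrap across episode boundaries: at a transition with `done=True`, both the next state's value and the running advantage are reset. The function name `gae` and its signature are illustrative, not part of the notebook.

```python
import numpy as np


def gae(rewards, values, dones, gamma=1.0, lam=0.95):
    """Generalized advantage estimation over episodes concatenated in time order.

    values[t] is V(s_t) for the state in which action t was taken; dones[t] is
    True if transition t ended its episode (nothing is bootstrapped across it).
    Returns (advantages, discounted_returns).
    """
    T = len(rewards)
    adv = np.zeros(T)
    ret = np.zeros(T)
    next_value, next_adv, next_ret = 0.0, 0.0, 0.0
    for t in reversed(range(T)):
        mask = 1.0 - float(dones[t])  # 0 at episode ends
        delta = rewards[t] + gamma * next_value * mask - values[t]
        next_adv = delta + gamma * lam * mask * next_adv
        next_ret = rewards[t] + gamma * mask * next_ret
        adv[t], ret[t] = next_adv, next_ret
        next_value = values[t]  # V(s_t) is the "next value" of step t - 1
    return adv, ret


# Two episodes (lengths 2 and 3), each ending with reward 1
adv, ret = gae([0, 1, 0, 0, 1], [0.2, 0.5, 0.1, 0.3, 0.4],
               [False, True, False, False, True], gamma=1.0, lam=1.0)
print(ret, adv)
```

With $\gamma = \lambda = 1$ the advantage reduces to the Monte Carlo return minus the value, which makes a handy sanity check for the loop in `rollout_step`.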

In [None]:
#  epoch breakdown ~  5 min per epoch. 20 epochs ~ 100 minutes
from argparse import Namespace
import ipdb

# Hyperparameters
cfg = {
  # Optimization parameters
  "epochs": 20,
  "lr": 1.41e-5,  # from hugging face PPO trainer
  "clip_grad_norm": None,
  "clip_value": 0.2,
  # PPO hyperparams
  "gamma": 1.0,
  "clip_ratio": 0.2,  # clip ratio
  "wd": 1e-8,
  "whiten_rewards": False,
  "entropy_coef": 0.0,
  "vf_coef": 0.1,
  "gae_lambda": 0.95,
  "kl_coef": 0.2,
  # Data collection loop setup
  "num_rollouts": 20,
  "max_history": STEP_LIMIT - 1,
  "num_training_samples": 320,
  # Optimization params
  "batch_size" : 1,
  "gradient_accumulation_steps": 8,
  "inner_optimization_epochs": 1,
}
cfg = Namespace(**cfg)


# =============================
# ==== Auxiliary functions ====
# =============================

def logprobs_from_logits(logits, labels):
  logp = F.log_softmax(logits, dim=-1)
  logpy = torch.gather(logp, -1, labels.unsqueeze(-1)).squeeze(-1)
  return logpy


def scale_rewards(rewards, whiten_rewards=True):
  if whiten_rewards:
      r_mu, r_sig = np.mean(rewards), np.std(rewards) + 1e-2
      scaled_rewards = [(r - r_mu) / r_sig for r in rewards]
  else:
      scaled_rewards = rewards
  return scaled_rewards

# we will define model_rollout, rollout_step, training_step, evaluation_step

def model_rollout(
  env,
  model,
  tokenizer,
  max_history=0,
  seed=None,
  gameFold="test",
  # eps: float = 0.0,
  sample=False,
) -> list[dict]:
  "Rollout a game using the provided model."

  # reset the environment (no expert path is needed here)
  obs, info = env.reset(seed=seed, gameFold=gameFold)

  # obtain instruction from game
  instr = INSTRUCTION_TEMPLATE.format(info["taskDescription"])

  # create buffers
  messages = [instr]

  # rollout loop
  done = False
  step = 0
  transitions = []
  while not done:
    # format observation
    game_msg = GAME_MSG_TEMPLATE.format(step, obs)

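    # make prompt for the LLM: keep the last `max_history` messages
    # (messages[-0:] would keep everything, hence the explicit check)
    valid = ", ".join(info["validActions"])
    question = f"\nValid actions: {valid}\nAction: "
    history = messages[-max_history:] if max_history > 0 else []
    prompt_msgs = history + [game_msg] + [question]
    # re-insert the task instruction (messages[0]) once it leaves the window;
    # len(messages) == step + 1, so this happens when step >= max_history
    if step >= max_history:
        prompt_msgs.insert(0, messages[0])
    prompt = ''.join(prompt_msgs)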

    # Get answer from LLM
    action = generate_answer(prompt, model, tokenizer, sample=sample)

    # step the environment
    obs, reward, done, info = env.step(action)

    # add to transitions
    transitions.append(
      {"prompt": prompt, "action": action, "reward": reward, "done": done}
    )
    messages.append(game_msg + "\nAction: " + action)
    step += 1

  return transitions


@torch.no_grad()
def rollout_step(env, agent, tokenizer, ref_model, cfg):
  agent.eval()
  transitions = []
  returns = []
  for i in tqdm(range(cfg.num_rollouts)):
    # == Collect data ===
    out = model_rollout(
      env,
      agent.model,
      tokenizer,
      max_history=cfg.max_history,
      gameFold="train",
      sample=True,
    )
    transitions.extend(out)
    returns.append(sum([t["reward"] for t in out]))

  # Precompute scaling constants
  r_mu, r_sig = 0, 1
  rewards = [t["reward"] for t in transitions]
  if cfg.whiten_rewards:
    r_mu, r_sig = np.mean(rewards), np.std(rewards) + 1e-2

  # Compute discounted returns, values, and advantages, etc.
  disc_return = 0
  next_value = 0
  advantage = 0
  buffer = []
  for i in reversed(range(len(transitions))):
    tr = transitions[i]

    # compute discounted return
    scaled_reward = (tr['reward'] - r_mu) / r_sig
    disc_return = scaled_reward + cfg.gamma * disc_return * (1 - tr["done"])

    # Prepare tokens for model evaluation, needed for values and log probs
    prompt_tokens = tokenizer(
      tr["prompt"],
      return_tensors="pt",
      truncation=True,
      max_length=MAX_SEQ_LEN
    ).input_ids
    target_tokens = tokenizer(
        tr["action"] + tokenizer.eos_token, return_tensors="pt"
    ).input_ids
    cat = torch.cat([prompt_tokens, target_tokens], dim=-1)
    input_ids = cat[0, :-1]
    labels = cat[0, 1:]
    attention_mask = torch.ones_like(input_ids)
    target_length = len(target_tokens[0])

    # == Obtain value, log probs and advantage ==

    # Prepare inputs; the models expect a leading batch dimension
    batch_input_ids = input_ids[None].to(accelerator.device)
    batch_attention_mask = attention_mask[None].to(accelerator.device)
    target_labels = labels[None, -target_length:].to(accelerator.device)

    # Obtain value
    logits, _, value = agent(batch_input_ids, batch_attention_mask)

    # Value output is for each token, take value at last token of prompt
    # before the completion/action by offsetting by target length
    value = value[:, -target_length].item()

    # Obtain advantage with GAE(lambda); at episode ends (done=True) neither
    # the next value nor the running advantage is bootstrapped
    not_done = 1 - tr["done"]
    delta = scaled_reward + cfg.gamma * next_value * not_done - value
    advantage = delta + cfg.gamma * cfg.gae_lambda * advantage * not_done
    next_value = value  # V(s_t) is the "next value" of transition i - 1

    # Obtain log prob of the response (sum over its tokens)
    log_prob = logprobs_from_logits(
      logits[:, -target_length:], target_labels
    ).sum()

    # Obtain the reference log prob of the same response (for a KL penalty)
    ref_logits = ref_model(
      input_ids=batch_input_ids, attention_mask=batch_attention_mask
    ).logits
    ref_log_prob = logprobs_from_logits(
      ref_logits[:, -target_length:], target_labels
    ).sum()

    buffer.append(
      {
        'input_ids': input_ids.tolist(),
        'labels': labels.tolist(),
        'attention_mask': attention_mask.tolist(),
        'target_lengths': target_length,
        'log_prob': log_prob.item(),
        'ref_log_prob': ref_log_prob.item(),
        'value': value,
        'scaled_reward': scaled_reward,
        'advantage': advantage,
        'disc_return': disc_return,
      }
    )

  # Make buffer for training
  buffer = Dataset.from_list(buffer)
  subsample = min(len(buffer), cfg.num_training_samples)
  buffer = buffer.shuffle().select(range(subsample))
  loader = DataLoader(
    buffer, batch_size=cfg.batch_size, collate_fn=data_collator, shuffle=True
  )

  return loader, returns


def training_step(agent, loader, opt, accelerator, cfg):
  # Prepare model, loader, etc with accelerate
  agent.train()
  agent, loader, opt = accelerator.prepare(agent, loader, opt)

  # Buffers for loss from each batch
  policy_losses = []
  value_losses = []
  entropy_losses = []
  kl_losses = []
  total_losses = []

  # Get dtype from model
  dtype = agent.model.dtype

  for epoch in tqdm(range(cfg.inner_optimization_epochs)):
    for batch in loader:
      with accelerator.accumulate(agent):
        new_value = []
        entropy_loss = []
        policy_loss = []
        kl_loss = []

        # Iterate over the actual batch size: the last batch may be smaller than cfg.batch_size
        batch_size = batch.input_ids.shape[0]
        for j in range(batch_size):
          # Compute state value from prompt + action tokens
          target_length = batch.target_lengths[j]
          logits, _, value = agent(
            batch.input_ids[j, None], batch.attention_mask[j, None]
          )
          logits = logits[:, -target_length:]
          labels = batch.labels[j, None, -target_length:]
          new_value.append(value[:, -target_length])

          # Entropy bonus: H = logsumexp(logits) - sum(p * logits) is the (positive)
          # entropy, so the loss uses -H and minimizing it increases entropy
          pd = F.softmax(logits, -1)
          entropy = torch.logsumexp(logits, -1) - (pd * logits).sum(-1)
          entropy_loss.append(-entropy.mean())

          # Policy loss (clipped PPO surrogate) using this sample's own advantage
          new_log_prob = logprobs_from_logits(logits, labels).sum()
          old_log_prob = batch.log_prob[j].to(dtype)
          adv = batch.advantage[j].to(dtype)
          prob_ratio = (new_log_prob - old_log_prob).clamp(-10, 10).exp()
          surr1 = - prob_ratio * adv
          surr2 = - prob_ratio.clamp(1 - cfg.clip_ratio, 1 + cfg.clip_ratio) * adv
          policy_loss.append(torch.max(surr1, surr2).mean())

          # Approximate KL between the current and the rollout (old) policy
          kl_loss.append(0.5 * (new_log_prob - old_log_prob)**2)

        # Clipped value loss, as in Hugging Face TRL's PPO
        disc_return = batch.disc_return.to(dtype)
        old_value = batch.value.to(dtype)

        new_value = torch.cat(new_value)
        new_value_clipped = new_value.clamp(
          old_value - cfg.clip_value,
          old_value + cfg.clip_value,
        )
        value_loss_1 = (new_value - disc_return) ** 2
        value_loss_2 = (new_value_clipped - disc_return) ** 2
        value_loss = 0.5 * torch.max(value_loss_1, value_loss_2).mean()

        # Total loss (averaged over the samples in this batch)
        kl_loss = sum(kl_loss) / batch_size
        policy_loss = sum(policy_loss) / batch_size
        entropy_loss = sum(entropy_loss) / batch_size
        loss = (
          policy_loss
          + cfg.vf_coef * value_loss
          + cfg.entropy_coef * entropy_loss
          + cfg.kl_coef * kl_loss
        )

        # Optimize the model
        accelerator.backward(loss)
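        # Clip only when gradients are synchronized, i.e. on the last accumulation step
        if cfg.clip_grad_norm is not None and accelerator.sync_gradients:
          accelerator.clip_grad_norm_(agent.parameters(), cfg.clip_grad_norm)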
        opt.step()
        opt.zero_grad()

        # Update metrics
        policy_losses.append(policy_loss.item())
        value_losses.append(value_loss.item())
        entropy_losses.append(entropy_loss.mean().item())
        kl_losses.append(kl_loss.mean().item())
        total_losses.append(loss.item())

  return {
    "policy_loss": policy_losses,
    "value_loss": value_losses,
    "entropy_loss": entropy_losses,
    "kl_loss": kl_losses,
    "total_loss" : total_losses
  }


@torch.no_grad()
def eval_step(agent, env, tokenizer, cfg):
  returns = []
  agent.eval()
  for i in tqdm(range(cfg.num_rollouts)):
    transitions = model_rollout(
      env,
      agent.model,
      tokenizer,
      max_history=cfg.max_history,
      gameFold="test",
      sample=False
    )
    returns.append(sum([t["reward"] for t in transitions]))

  return returns

# ===================================
# ==== Minimal PPO training loop ====
# ===================================

# Agent
agent = ActorCritic(model, tokenizer)

# Environment
env = make_env(step_limit=STEP_LIMIT)  # short games

# Optimizer
opt = bnb.optim.PagedAdamW8bit(agent.parameters(), lr=cfg.lr, weight_decay=cfg.wd)

# Accelerator
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# Clean memory before starting to reduce the chance of out-of-memory errors
gc.collect()
torch.cuda.empty_cache()

# Best model checkpointing
best_eval_return = -np.inf

# Results
eval_mean_returns = []
train_mean_returns = []

# Training loop
for epoch in range(cfg.epochs):
  printmd(f"### Epoch {epoch + 1}")

  print("Collecting rollouts...")
  loader, train_returns = rollout_step(env, agent, tokenizer, cfg)

  print("Training...")
  losses = training_step(agent, loader, opt, accelerator, cfg)

  print("Evaluating...")
  eval_returns = eval_step(agent, env, tokenizer, cfg)

  # Checkpoints
  eval_mean_returns.append(np.mean(eval_returns))
  train_mean_returns.append(np.mean(train_returns))

  if eval_mean_returns[-1] > best_eval_return:
    best_eval_return = eval_mean_returns[-1]
    agent.model.save_pretrained("ckpt_best_rl.pt")

  # Print metrics
  losses["mean_train_return"] = train_mean_returns[-1]
  losses["mean_eval_return"] = eval_mean_returns[-1]
  msg = ', '.join([f"**{k}**: {np.mean(v):.3f}" for k, v in losses.items()])
  printmd(msg)

# Clean up memory
gc.collect()
torch.cuda.empty_cache()

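For reference, the advantage lines in `rollout_step` implement the GAE($\lambda$) recursion $\delta_t = r_t + \gamma V(s_{t+1})(1 - d_t) - V(s_t)$ and $A_t = \delta_t + \gamma\lambda A_{t+1}$. This recursion only works when transitions are visited from last to first. Below is a standalone sketch for one trajectory using plain Python lists. The helper `gae_advantages` is hypothetical (it is not defined elsewhere in this notebook), and it assumes the rewards are already scaled. Unlike the lines above, it also resets the carried-over advantage at terminal steps, a common convention when one buffer holds several episodes.

```python
from typing import List


def gae_advantages(
    rewards: List[float],
    values: List[float],
    dones: List[bool],
    last_value: float,
    gamma: float = 0.99,
    gae_lambda: float = 0.95,
) -> List[float]:
    """Hypothetical helper: GAE(lambda) advantages for a single trajectory.

    rewards[t], values[t] and dones[t] describe step t; last_value is the
    critic's estimate for the state after the final step (ignored if terminal).
    """
    advantages = [0.0] * len(rewards)
    advantage = 0.0
    next_value = last_value
    # Walk backwards so each step can reuse the advantage of the step after it
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        # No bootstrapping past a terminal step
        next_value = next_value * not_done
        delta = rewards[t] + gamma * next_value - values[t]
        # Also drop the carried-over advantage at episode boundaries
        advantage = delta + gamma * gae_lambda * not_done * advantage
        advantages[t] = advantage
        next_value = values[t]
    return advantages
```

With $\lambda = 1$ and zero value estimates, the advantages reduce to discounted returns. For example, `gae_advantages([1, 1], [0, 0], [False, True], last_value=5.0, gamma=0.5, gae_lambda=1.0)` returns `[1.5, 1.0]`; the bootstrap value 5.0 is ignored because the last step is terminal. Setting $\lambda = 0$ instead gives one-step TD errors.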

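The per-sample loop in `training_step` builds the clipped PPO surrogate and the clipped value loss one sequence at a time. A vectorized sketch for a whole batch is shown below. It assumes the log-probs are already summed over each action's tokens, and that all inputs are 1-D tensors of shape `(batch,)`. `ppo_losses` is a hypothetical helper for illustration, not a TRL API.

```python
import torch


def ppo_losses(
    new_log_prob: torch.Tensor,
    old_log_prob: torch.Tensor,
    advantage: torch.Tensor,
    new_value: torch.Tensor,
    old_value: torch.Tensor,
    disc_return: torch.Tensor,
    clip_ratio: float = 0.2,
    clip_value: float = 0.2,
):
    """Hypothetical helper: clipped PPO policy loss and clipped value loss."""
    # Per-sample probability ratio pi_new(a|s) / pi_old(a|s)
    ratio = (new_log_prob - old_log_prob).exp()
    # Pessimistic bound: the max of the negated unclipped and clipped surrogates
    surr1 = -ratio * advantage
    surr2 = -ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * advantage
    policy_loss = torch.max(surr1, surr2).mean()

    # Keep the new value prediction within clip_value of the rollout-time value
    new_value_clipped = torch.max(
        torch.min(new_value, old_value + clip_value), old_value - clip_value
    )
    value_loss = 0.5 * torch.max(
        (new_value - disc_return) ** 2, (new_value_clipped - disc_return) ** 2
    ).mean()
    return policy_loss, value_loss
```

Take a probability ratio of 2, an advantage of 1 and `clip_ratio=0.2`. The clipped term wins and the policy loss is -1.2, so the update gets no extra credit for pushing the ratio past 1.2. With a negative advantage the unclipped term is kept instead, so a harmful action is still fully penalized.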
In [None]:
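# Eval best model. eval_step expects an agent exposing `.model`, so wrap the
# reloaded model in ActorCritic (its value head is unused during evaluation)
best_model = AutoModelForCausalLM.from_pretrained("ckpt_best_rl.pt", device_map="auto")
eval_returns = eval_step(ActorCritic(best_model, tokenizer), env, tokenizer, cfg)
print(np.mean(eval_returns))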

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame({'train': train_mean_returns, 'eval': eval_mean_returns})
ax = df.plot(marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("Mean return")
plt.title('RL results')

Now use a quick loop to evaluate the improved agent's performance.

### 😡 Why does the evaluation policy do so much worse than during training-time rollout collection? (We tried both `sample=True` and `sample=False`.)

<a name="experiments-and-discussion"></a>
# Experiments & Discussion

In this section, describe your experiments and results. Briefly describe the performance metrics used and the significance of the results.
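The performance metric used throughout this notebook is the mean episode return over the rollouts returned by `eval_step`. One way to report it together with its uncertainty is sketched below. The helper `summarize_returns` is hypothetical. It assumes episode returns are independent and uses a normal approximation for the 95% interval, which is rough when `cfg.num_rollouts` is small.

```python
import math
from typing import Dict, List


def summarize_returns(returns: List[float]) -> Dict[str, float]:
    """Hypothetical helper: mean return, standard error and a normal-approx 95% CI."""
    n = len(returns)
    if n == 0:
        raise ValueError("need at least one episode return")
    mean = sum(returns) / n
    # Sample variance with an (n - 1) denominator; zero when there is one episode
    var = sum((r - mean) ** 2 for r in returns) / (n - 1) if n > 1 else 0.0
    se = math.sqrt(var / n)
    return {
        "mean": mean,
        "se": se,
        "ci95_low": mean - 1.96 * se,
        "ci95_high": mean + 1.96 * se,
    }
```

For instance, `summarize_returns(eval_returns)` could be printed next to `np.mean(eval_returns)`. Overlapping intervals for two checkpoints suggest the difference between them may be noise.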

In [None]:
# Insert code here. Feel free to break this up into several code
# cells, interleaved with explanatory text.

Finally, include a discussion on the limitations and important takeaways from the exercise.

## Limitations
*  Reflect on the potential biases or problems in the analysis presented in your tutorial, including its potential societal impact, and discuss how you might go about addressing this challenge.

## Next Steps
*   What do you imagine would be the next steps for your readers after finishing your tutorial?

<a name="references"></a>
# References

*   EarthCube Notebook Template: https://github.com/earthcube/NotebookTemplates
*   Earth Engine Community Tutorials Style Guide: https://developers.google.com/earth-engine/tutorials/community/styleguide#colab
*   Google Cloud Community Tutorial Style Guide: https://cloud.google.com/community/tutorials/styleguide
*   Rule A, Birmingham A, Zuniga C, Altintas I, Huang S-C, Knight R, et al. (2019) Ten simple rules for writing and sharing computational analyses in Jupyter Notebooks. PLoS Comput Biol 15(7): e1007007. https://doi.org/10.1371/journal.pcbi.1007007




# Submitting the Tutorial *(Please remove this section from your submission.)*

If you are using Google Colab, make sure to change the permissions by clicking "Share" (upper right corner of the notebook) >> Change permissions to "Anyone on the internet with this link can comment".

See our website for additional instructions:
https://sites.google.com/view/tafm

For additional questions, please feel free to contact:
*   tafm.rlc@gmail.com