<a href="https://colab.research.google.com/github/louiezzang/next-gpt/blob/main/examples/chatgpt_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview


What is RLHF? <br>
See [this link](https://gist.github.com/JoaoLages/c6f2dfd13d2484aa8bb0b2d567fbf093).

<br>

**Example of RLHF dataset**:

A total of three datasets are needed, one for each of the three training steps (SFT, RM and PPO):
- [Example of dataset](https://github.com/nebuly-ai/nebullvm/tree/main/apps/accelerate/chatllama#dataset-preparation)
- [Example of dataset 1](https://huggingface.co/datasets/stanfordnlp/SHP)
- [Example of dataset 2](https://huggingface.co/datasets/Anthropic/hh-rlhf)

Step 1) Dataset for SFT (supervised fine-tuning):
```json
[
    {
        "prompt": "",
        "completion": ""        
    }, ...
]
```

Step 2) Dataset for RM (reward model) training: each prompt has multiple completions, ranked by human raters (a lower `ranking` value means a better completion).
```json
[
    {
        "prompt": "",
        "completion_1": "",
        "completion_2": "",
        "completion_3": "",            
        "ranking": [1, 0, 2]
    }, ...
]
```
    
Step 3) Dataset for PPO (RLHF) training: it consists of prompts only.
```json
[
    {
        "prompt": ""
    }, ...
]
```

# Environment setup

#### Installation (python>=3.8)

In [None]:
# Install the next-gpt library from a fresh clone of its GitHub repository.
# Remove any previous checkout first so the clone below starts clean.
!rm -rf ./next-gpt/
!git clone https://github.com/louiezzang/next-gpt.git
# Install from the repository root, then return to the original directory
# (`%cd` changes the kernel's working directory for all later cells).
%cd next-gpt/
!pip install .
%cd ../

# Step 1) SFT: Supervised Fine-tuning
Build a supervised fine-tuning (SFT) model that answers questions well.

- References
  - [fine tuning code_1](https://github.com/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb)
  - [fine tuning code_2](https://github.com/Beomi/KoAlpaca/blob/main/train.py)


- SFT (Supervised Fine-Tuning): fine-tune a pretrained LLM on a specific domain, or on a corpus of instructions and human demonstrations.

- Dataset example
```json
[
    {
        "prompt": "",
        "completion": ""        
    }, ...
]
```

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import yaml
import argparse

import numpy as np
import pandas as pd

import loralib as lora
import torch
import torch.distributed as dist

import transformers
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from datasets import load_dataset

from nextgpt.dataset import SupervisedDataset, DataCollatorForSupervisedDataset
from nextgpt.trainer import SFTTrainer
from nextgpt.trainer.strategies import DDPStrategy, NaiveStrategy
from nextgpt.models.bloom import BLOOMLM
from nextgpt.models.gpt import GPTLM
from nextgpt.models.opt import OPTLM

In [None]:
# Alpaca-style prompt shared by every step of this notebook; the single
# placeholder `{instruction}` is filled with str.format_map / str.format.
# NOTE(review): the preamble mentions "an input that provides further context",
# but the template has no input section -- consider the no-input Alpaca wording.
PROMPT_TEMPLATE = "".join([
    "Below is an instruction that describes a task, paired with an input that provides further context. ",
    "Write a response that appropriately completes the request.\n\n",
    "### Instruction:\n{instruction}\n\n### Response:",
])

In [None]:
# Define arguments.
def str2bool(value):
    """Parse a command-line boolean value.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True``, so any
    non-empty string would switch the flag on. Accept the usual spellings
    instead; everything else (including "false"/"0"/"no") is ``False``.
    """
    return str(value).strip().lower() in ("1", "true", "yes", "y", "on")


parser = argparse.ArgumentParser()
parser.add_argument("--strategy",
                    choices=["naive", "ddp"],
                    default="naive")
parser.add_argument("--model", choices=["gpt2", "bloom", "opt"], default="gpt2")
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--max_datasets_size", type=int, default=None)
# Whether to also save the optimizer state after training.
parser.add_argument("--need_optim_ckpt", type=str2bool, default=False)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--max_len", type=int, default=512)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank. 0 means LoRA is not applied.")
parser.add_argument("--log_interval", type=int, default=100, help="how many steps to log")
parser.add_argument("--lr", type=float, default=5e-6)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--output_dir", type=str, default="./output_1_sft")

# Parse an empty argv: inside a notebook, sys.argv belongs to the kernel.
args = parser.parse_args(args=[])

# For testing.
args.pretrain = "gpt2"
args.max_datasets_size = 10000
args.max_epochs = 1

print(args)

In [None]:
# Configure strategy: map the CLI choice to its strategy class.
strategy_classes = {"naive": NaiveStrategy, "ddp": DDPStrategy}
if args.strategy not in strategy_classes:
    raise ValueError(f"Unsupported strategy: {args.strategy}")
strategy = strategy_classes[args.strategy]()

# Configure model: build the chosen causal LM inside the strategy's init
# context and move it to the current CUDA device.
lm_classes = {"bloom": BLOOMLM, "opt": OPTLM, "gpt2": GPTLM}
with strategy.model_init_context():
    if args.model not in lm_classes:
        raise ValueError(f"Unsupported model: {args.model}")
    lm_class = lm_classes[args.model]
    model = lm_class(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())

# Configure tokenizer: each model family loads a fixed reference tokenizer.
tokenizer_loaders = {
    "gpt2": lambda: GPT2Tokenizer.from_pretrained("gpt2"),
    "bloom": lambda: BloomTokenizerFast.from_pretrained("bigscience/bloom-560m"),
    "opt": lambda: AutoTokenizer.from_pretrained("facebook/opt-350m"),
}
if args.model not in tokenizer_loaders:
    raise ValueError(f"Unsupported model: {args.model}")
tokenizer = tokenizer_loaders[args.model]()
# Every family pads with its EOS token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Configure dataset: the first 20% of the WebGPT comparisons train split.
dataset_webgpt_comp = load_dataset("openai/webgpt_comparisons", split="train[:20%]")

# Build instruction/completion records from the first answer of each comparison.
# Rows whose answer is empty (or whitespace) are skipped: an empty target would
# teach the model to stop immediately. This matches the reward-model data
# preparation, which also discards empty answers.
data_list = [
    {"instruction": row["question"]["full_text"], "completion": row["answer_0"]}
    for row in dataset_webgpt_comp
    if row["answer_0"] and row["answer_0"].strip()
]

dataset = SupervisedDataset(
    dataset=data_list,
    tokenizer=tokenizer,
    prompt_template=PROMPT_TEMPLATE,
    completion_field="completion",
    max_datasets_size=args.max_datasets_size,
    max_length=args.max_len,
    verbose=True)

# Split train and eval dataset (80/20). The generator is seeded so the split is
# reproducible across runs.
train_size = int(0.8 * len(dataset))
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset, [train_size, eval_size], generator=torch.Generator().manual_seed(42))

# Data collator.
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

In [None]:
# Train!!!
trainer = SFTTrainer(model=model,
                     strategy=strategy,
                     data_collator=data_collator,
                     train_dataset=train_dataset,
                     eval_dataset=eval_dataset,
                     batch_size=args.batch_size,
                     max_epochs=args.max_epochs,
                     gradient_accumulation_steps=args.gradient_accumulation_steps,
                     lr=args.lr)

trainer.fit()

# Save model checkpoint after fitting on only rank0.
trainer.save_model(path=args.output_dir, only_rank0=True, tokenizer=tokenizer)
# Save optimizer checkpoint on all ranks (one file per CUDA device). It goes
# next to the model in `args.output_dir`, not the current working directory,
# matching the reward-model step.
if args.need_optim_ckpt:
    os.makedirs(args.output_dir, exist_ok=True)
    strategy.save_optimizer(
        trainer.optimizer,
        os.path.join(args.output_dir, "sft_optim_checkpoint_%d.pt" % torch.cuda.current_device()),
        only_rank0=False)

In [None]:
!ls -la ./output_1_sft

In [None]:
# Inference test: sample completions from the SFT checkpoint saved in
# `args.output_dir` for the last few training records and print them next to
# the reference answers.
generator = transformers.pipeline("text-generation", model=args.output_dir, tokenizer=tokenizer)

# Decoding settings (beam search combined with sampling).
generation_args = {
    "num_beams": 4,
    "repetition_penalty": 2.0,
    "no_repeat_ngram_size": 4,
    "max_new_tokens": 64,
    "do_sample": True,
    "top_k": 30,
    "top_p": 0.95,
    "temperature": 1.9,
    "early_stopping": True,
}

test_list = data_list[-5:]
test_prompt_list = [PROMPT_TEMPLATE.format_map(record) for record in test_list]
actual_completion_list = [record["completion"] for record in test_list]

result_list = generator(test_prompt_list, **generation_args)
for _prompt, generations, expected in zip(test_prompt_list, result_list, actual_completion_list):
    print("")
    print("-" * 70)
    print("completion: %s" % generations[0]["generated_text"])
    print(f"\n### Actual answer:\n{expected}")

# Step 2) RM: Reward Model
Train a reward model that scores completions, assigning a higher reward to the better answer; in step 3 this reward steers the policy toward better answers.
- Dataset example
```json
[
    {
        "prompt": "",
        "completion_1": "",
        "completion_2": "",
        "completion_3": "",            
        "ranking": [1, 0, 2]
    }, ...
]
```
- Dataset sources
  - [Dahoas/rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
  - [openai/webgpt_comparisons](https://huggingface.co/datasets/openai/webgpt_comparisons)
  - [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)
  - [Dahoas/synthetic-instruct-gptj-pairwise](https://huggingface.co/datasets/Dahoas/synthetic-instruct-gptj-pairwise)

- References
    - [train_reward_model.py](https://github.com/hpcaitech/ColossalAI/blob/main/applications/Chat/examples/train_reward_model.py)
    - [train_prompts.py](https://github.com/hpcaitech/ColossalAI/blob/main/applications/Chat/examples/train_prompts.py)

In [None]:
import os
import json
import argparse

import torch
from torch.optim import Adam
from datasets import load_dataset
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
import loralib as lora

from nextgpt.dataset import RewardDataset
from nextgpt.models.base import RewardModel
from nextgpt.models.bloom import BLOOMRM
from nextgpt.models.gpt import GPTRM
from nextgpt.models.opt import OPTRM
from nextgpt.trainer import RewardModelTrainer
from nextgpt.trainer.strategies import DDPStrategy, NaiveStrategy

In [None]:
# Define arguments.
def str2bool(value):
    """Parse a command-line boolean value.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True``, so any
    non-empty string would switch the flag on. Accept the usual spellings
    instead; everything else (including "false"/"0"/"no") is ``False``.
    """
    return str(value).strip().lower() in ("1", "true", "yes", "y", "on")


parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", type=str, default="./output_2_rm")
parser.add_argument("--strategy",
                    type=str,
                    choices=["naive", "ddp"],
                    default="naive")
parser.add_argument("--model",
                    type=str,
                    choices=["gpt2", "bloom", "opt"],
                    default="gpt2")
parser.add_argument("--pretrain", type=str, default="gpt2")
# Optional state dict loaded into the reward model's backbone.
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--need_optim_ckpt", type=str2bool, default=False)
parser.add_argument("--max_epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--loss_fn",
                    type=str,
                    choices=["log_sig", "log_exp"],
                    default="log_sig")
parser.add_argument("--lr", type=float, default=5e-5)
parser.add_argument("--max_len", type=int, default=512)

# Parse an empty argv: inside a notebook, sys.argv belongs to the kernel.
args = parser.parse_args(args=[])

# For testing.
args.max_epochs = 3
args.pretrain = "gpt2" # pretrained initial model.
args.verbose = True

print(args)
os.makedirs(args.output_dir, exist_ok=True)

In [None]:
# Configure strategy: look up the strategy class for the CLI choice.
strategy_classes = {"naive": NaiveStrategy, "ddp": DDPStrategy}
strategy_class = strategy_classes.get(args.strategy)
if strategy_class is None:
    raise ValueError(f"Unsupported strategy: {args.strategy}")
strategy = strategy_class()

In [None]:
# Configure model.
with strategy.model_init_context():
    if args.model == "gpt2":
        model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
    elif args.model == "bloom":
        model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
    elif args.model == "opt":
        model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
    else:
        raise ValueError(f"Unsupported model: {args.model}")

    # Optionally load a (e.g. supervised fine-tuned) state dict into the backbone.
    # By default (model_path=None) the reward model starts from the initial
    # pretrained language model rather than the SFT model.
    if args.model_path is not None:
        # Deserialize on CPU so a checkpoint written from another GPU index still
        # loads; load_state_dict then copies the tensors into the model's parameters.
        state_dict = torch.load(args.model_path, map_location="cpu")
        model.model.load_state_dict(state_dict)

# This float16 or `model.half()` might cause loss NaN issue!!!
# See:
#   https://stackoverflow.com/questions/65332165/loss-is-nan-when-fine-tuning-huggingface-nli-model-both-roberta-bart
#   https://github.com/huggingface/transformers/issues/9160
# model = model.to(torch.float16)

# Configure tokenizer.
if args.model == "gpt2":
    tokenizer = AutoTokenizer.from_pretrained(
        "gpt2",
        # bos_token="<|startoftext|>",
        # eos_token="<|endoftext|>",
        # pad_token="<|pad|>",
        # padding_side="right",
        model_max_length=args.max_len,
        )
    tokenizer.pad_token = tokenizer.eos_token
    print(tokenizer)
    # model.resize_token_embeddings(len(tokenizer))
elif args.model == "bloom":
    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
    tokenizer.pad_token = tokenizer.eos_token
elif args.model == "opt":
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    tokenizer.pad_token = tokenizer.eos_token
else:
    raise ValueError(f"Unsupported model: {args.model}")

In [None]:
# Get the dataset: the first 20% of the WebGPT comparisons train split.
dataset_webgpt_comp = load_dataset(
    "openai/webgpt_comparisons",
    split="train[:20%]",
)

In [None]:
# Convert each WebGPT comparison into the ranking format: two completions plus
# a ranking list where 0 marks the higher-scored answer.
data_list_ranking = []
for example in dataset_webgpt_comp:
    question = example["question"]["full_text"]
    completions = (example["answer_0"], example["answer_1"])
    scores = (example["score_0"], example["score_1"])
    # Skip pairs with an empty answer or a tie: they carry no preference.
    if "" in completions or scores[0] == scores[1]:
        continue

    ranking = [0 if scores[0] > scores[1] else 1, 0 if scores[0] < scores[1] else 1]
    data_list_ranking.append(
        {
            "prompt": PROMPT_TEMPLATE.format_map({"instruction": question}),
            "completion_0": completions[0],
            "completion_1": completions[1],
            "ranking": ranking,
        }
    )

data_list_ranking[:2]

In [None]:
# Make ranking data into chosen/rejected pairs for the reward-model dataset.
# Every pair of completions (i, j) with i < j yields one record; the completion
# with the lower ranking value is "chosen" (ties pick completion j, as before).
# With the two-completion WebGPT data above this gives one record per example;
# data with completion_0..completion_{k-1} yields all k*(k-1)/2 pairs.
from itertools import combinations

total_data_ranking2chosen = []
for tmp in data_list_ranking:
    ranking = tmp["ranking"]
    for i, j in combinations(range(len(ranking)), 2):
        if ranking[i] < ranking[j]:
            chosen, rejected = tmp[f"completion_{i}"], tmp[f"completion_{j}"]
        else:
            chosen, rejected = tmp[f"completion_{j}"], tmp[f"completion_{i}"]
        total_data_ranking2chosen.append({
            "prompt": tmp["prompt"],
            "chosen": chosen,
            "rejected": rejected,
        })


print("before data num: %d" % (len(data_list_ranking)))
print("after data num: %d" % (len(total_data_ranking2chosen)))
print("data example: \n%s" % total_data_ranking2chosen[1])

In [None]:
# Prepare for data and dataset.
import random

# Fixed seed so the shuffle (and therefore the splits below) is reproducible.
random.seed(230319)
random.shuffle(total_data_ranking2chosen)
print(total_data_ranking2chosen[1])

# Use only small slices for a quicker demo run. For a full run, e.g.:
#   train_data = total_data_ranking2chosen[:-1000]
#   eval_data = total_data_ranking2chosen[-1000:]
num_train, num_val, num_eval = 100, 30, 30
val_end = num_train + num_val
train_data = total_data_ranking2chosen[:num_train]
val_data = total_data_ranking2chosen[num_train:val_end]
eval_data = total_data_ranking2chosen[val_end:val_end + num_eval]

train_dataset = RewardDataset(train_data, tokenizer, args.max_len)
val_dataset = RewardDataset(val_data, tokenizer, args.max_len)
eval_dataset = RewardDataset(eval_data, tokenizer, args.max_len)

# Check one training pair (index clamped so a small dataset does not raise).
idx = min(10, len(train_data) - 1)
print("#" * 70)
print("## prompt ##")
print(train_data[idx]["prompt"])
print("#" * 70)
print("## chosen ##")
print(train_data[idx]["chosen"])
print("#" * 70)
print("## rejected ##")
print(train_data[idx]["rejected"])

In [None]:
# Build the reward-model trainer from the datasets and hyper-parameters above;
# `loss_fn` selects "log_sig" or "log_exp".
trainer = RewardModelTrainer(
    model=model,
    strategy=strategy,
    train_dataset=train_dataset,
    valid_dataset=val_dataset,
    eval_dataset=eval_dataset,
    lr=args.lr,
    loss_fn=args.loss_fn,
    batch_size=args.batch_size,
    max_epochs=args.max_epochs,
)

In [None]:
# Train!!!
trainer.fit()

# Persist the reward model (rank 0 only); the PPO step loads it via `rm_path`.
rm_checkpoint_path = os.path.join(args.output_dir, "rm.pt")
trainer.save_model(path=rm_checkpoint_path, only_rank0=True)

# Optionally save the optimizer state on every rank, one file per CUDA device.
if args.need_optim_ckpt:
    optim_filename = "rm_optim_checkpoint_%d.pt" % (torch.cuda.current_device())
    strategy.save_optimizer(trainer.optimizer,
                            os.path.join(args.output_dir, optim_filename),
                            only_rank0=False)

# Step 3) PPO: Proximal Policy Optimization
Further fine-tune the SFT model from step 1 with reinforcement learning (here PPO), using the reward model from step 2 and the prompt-only dataset.

- References
    - [train_prompts.py](https://github.com/hpcaitech/ColossalAI/blob/main/applications/Chat/examples/train_prompts.py)

In [None]:
import os
import json
import argparse
from copy import deepcopy

import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler

from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from nextgpt.models.base import RewardModel
from nextgpt.models.bloom import BLOOMActor, BLOOMCritic
from nextgpt.models.gpt import GPTActor, GPTCritic
from nextgpt.models.opt import OPTActor, OPTCritic
from nextgpt.trainer import PPOTrainer
from nextgpt.trainer.strategies import DDPStrategy, NaiveStrategy
from nextgpt.dataset import PromptDataset, SupervisedDataset, DataCollatorForSupervisedDataset

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# Define arguments.
def str2bool(value):
    """Parse a command-line boolean value.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True``, so any
    non-empty string would switch the flag on. Accept the usual spellings
    instead; everything else (including "false"/"0"/"no") is ``False``.
    """
    return str(value).strip().lower() in ("1", "true", "yes", "y", "on")


parser = argparse.ArgumentParser()

parser.add_argument("--output_dir", type=str, default="./output_3_ppo")
parser.add_argument("--strategy",
                    type=str,
                    choices=["naive", "ddp"],
                    default="naive")
parser.add_argument("--model",
                    type=str,
                    choices=["gpt2", "bloom", "opt"],
                    default="gpt2")
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_model",
                    type=str,
                    choices=["gpt2", "bloom", "opt"],
                    default="gpt2")
parser.add_argument("--rm_path", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--need_optim_ckpt", type=str2bool, default=False)
parser.add_argument("--num_episodes", type=int, default=10)
parser.add_argument("--max_timesteps", type=int, default=3)
parser.add_argument("--update_timesteps", type=int, default=3)
parser.add_argument("--max_epochs", type=int, default=5)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--ptx_batch_size", type=int, default=1)
parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--kl_coef", type=float, default=0.1)
parser.add_argument("--ptx_coef", type=float, default=0.9)
# Parse an empty argv: inside a notebook, sys.argv belongs to the kernel.
args = parser.parse_args(args=[])

# For testing.
# args.pretrain = "gpt2"
args.pretrain = "./output_1_sft"  # SFT checkpoint from step 1.
args.rm_path = "./output_2_rm/rm.pt"  # RM model path
args.rm_pretrain = "gpt2"

args.num_episodes = 1
args.max_epochs = 1

print(args)
os.makedirs(args.output_dir, exist_ok=True)

In [None]:
# Configure strategy: map the CLI choice to its strategy class.
strategy_classes = {"naive": NaiveStrategy, "ddp": DDPStrategy}
try:
    strategy_class = strategy_classes[args.strategy]
except KeyError:
    raise ValueError(f"Unsupported strategy: {args.strategy}") from None
strategy = strategy_class()

In [None]:
# Load the trained reward-model weights once, on CPU. They initialize both the
# reward model and the critic below.
if args.rm_path is not None:
    rm_state_dict = torch.load(args.rm_path, map_location="cpu")

# Configure initial model (presumably the fixed reference policy for the KL
# penalty `kl_coef` -- confirm in PPOTrainer).
if args.model == "gpt2":
    initial_model = GPTActor(pretrained=args.pretrain)
elif args.model == "bloom":
    initial_model = BLOOMActor(pretrained=args.pretrain)
elif args.model == "opt":
    initial_model = OPTActor(pretrained=args.pretrain)
else:
    raise ValueError(f"Unsupported actor model: {args.model}")

# Configure reward model.
if args.rm_model == "gpt2":
    reward_model = GPTRM(pretrained=args.rm_pretrain)
elif args.rm_model == "bloom":
    reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif args.rm_model == "opt":
    reward_model = OPTRM(pretrained=args.rm_pretrain)
else:
    raise ValueError(f"Unsupported reward model: {args.rm_model}")

if args.rm_path is not None:
    reward_model.load_state_dict(rm_state_dict)

# initial_model.to(torch.float16).to(torch.cuda.current_device())
# reward_model.to(torch.float16).to(torch.cuda.current_device())
initial_model.to(torch.cuda.current_device())
reward_model.to(torch.cuda.current_device())

# Configure actor and critic.
with strategy.model_init_context():
    # Actor: starts from the same `--pretrain` weights as the initial model.
    # (The bloom branch previously read a nonexistent `args.pretrain_actor`.)
    if args.model == "gpt2":
        actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
    elif args.model == "bloom":
        actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
    elif args.model == "opt":
        actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
    else:
        raise ValueError(f"Unsupported actor model: {args.model}")

    # Critic
    if args.rm_model == "gpt2":
        critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
    elif args.rm_model == "bloom":
        critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
    elif args.rm_model == "opt":
        critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank)
    else:
        raise ValueError(f"Unsupported reward model: {args.rm_model}")

    if args.rm_path is not None:
        # NOTE(review): assumes the critic's state-dict keys match the reward
        # model's (load_state_dict is strict by default) -- confirm per model.
        critic.load_state_dict(rm_state_dict)
        # Free the CPU copy of the weights once both models are initialized.
        del rm_state_dict

# critic.to(torch.float16).to(torch.cuda.current_device())
# actor.to(torch.float16).to(torch.cuda.current_device())
critic.to(torch.cuda.current_device())
actor.to(torch.cuda.current_device())

# Configure tokenizer.
if args.model == "gpt2":
    tokenizer = GPT2Tokenizer.from_pretrained(
        "gpt2",
        # bos_token="<|startoftext|>",
        # eos_token="<|endoftext|>",
        # pad_token="<|pad|>",
        # padding_side="right",
        model_max_length=512,
        )
    tokenizer.pad_token = tokenizer.eos_token
elif args.model == "bloom":
    tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
    tokenizer.pad_token = tokenizer.eos_token
elif args.model == "opt":
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    tokenizer.pad_token = tokenizer.eos_token
else:
    raise ValueError(f"Unsupported model: {args.model}")

In [None]:
def tokenize_fn(texts, max_length=96):
    """Tokenize a batch of prompt strings for PPO rollouts.

    Uses the notebook-global ``tokenizer``. Prompts are padded and truncated to
    exactly ``max_length`` tokens, so every batch has the same shape. The
    original comment here (apparently adapted from ColossalAI, whose "Gemini"
    strategy it mentions) warns that inputs of different lengths on different
    ranks can make generation hang, since ranks would then run a different
    number of generation steps.

    NOTE(review): GPT-2 tokenizers pad on the right by default; confirm that the
    PPO generation path handles right-padded prompts.

    Args:
        texts: a string or a list of strings to tokenize.
        max_length: fixed sequence length (default 96, the previous hard-coded value).

    Returns:
        dict mapping each tokenizer output name (e.g. ``input_ids``,
        ``attention_mask``) to a tensor on the current CUDA device.
    """
    batch = tokenizer(texts, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True)
    return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}

In [None]:
# Prepare dataset: the first 20% of the WebGPT comparisons train split.
dataset_webgpt_comp = load_dataset("openai/webgpt_comparisons", split="train[:20%]")

# Build instruction/completion records from the first answer of each comparison.
# Rows with an empty (or whitespace) answer are skipped: they would be useless
# targets for the pretrain (ptx) dataloader built from this list, and the
# reward-model data preparation also discards empty answers.
data_list = [
    {"instruction": row["question"]["full_text"], "completion": row["answer_0"]}
    for row in dataset_webgpt_comp
    if row["answer_0"] and row["answer_0"].strip()
]

print(data_list[:1])

In [None]:
# Configure dataloader.
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)


def _build_loader(source, batch_size, collate_fn=None):
    """Wrap *source* in a DataLoader.

    With more than one initialized process, the data is sharded by a seeded
    DistributedSampler (which shuffles). Otherwise the DataLoader shuffles
    locally.
    """
    use_dist = dist.is_initialized() and dist.get_world_size() > 1
    sampler = DistributedSampler(source, shuffle=True, seed=42, drop_last=True) if use_dist else None
    return DataLoader(source,
                      shuffle=(sampler is None),
                      sampler=sampler,
                      batch_size=batch_size,
                      collate_fn=collate_fn)


# Prompts for experience generation.
prompt_dataset = PromptDataset(
    dataset=data_list,
    tokenizer=tokenizer,
    prompt_template=PROMPT_TEMPLATE,
    max_datasets_size=10000)
prompt_dataloader = _build_loader(prompt_dataset, args.train_batch_size)

# Prompt + completion pairs for the pretrain (ptx) loss.
pretrain_dataset = SupervisedDataset(
    dataset=data_list,
    tokenizer=tokenizer,
    prompt_template=PROMPT_TEMPLATE,
    completion_field="completion",
    max_datasets_size=10000,
    max_length=512,
    verbose=True)
pretrain_dataloader = _build_loader(pretrain_dataset, args.ptx_batch_size, collate_fn=data_collator)

In [None]:
# Configure trainer.
trainer = PPOTrainer(
    strategy,
    actor,
    critic,
    reward_model,
    initial_model,
    kl_coef=args.kl_coef,
    ptx_coef=args.ptx_coef,
    max_epochs=args.max_epochs,
    train_batch_size=args.train_batch_size,
    experience_batch_size=args.experience_batch_size,
    tokenizer=tokenize_fn,
    max_length=128,
    do_sample=True,
    temperature=1.0,
    top_k=50,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

trainer.fit(
    prompt_dataloader=prompt_dataloader,
    pretrain_dataloader=pretrain_dataloader,
    num_episodes=args.num_episodes,
    max_timesteps=args.max_timesteps,
    update_timesteps=args.update_timesteps)

# Save model checkpoint after fitting on only rank0.
trainer.save_model(os.path.join(args.output_dir, "actor.pt"), only_rank0=True, tokenizer=tokenizer)
# Save the actor optimizer checkpoint on all ranks, but only when requested via
# `--need_optim_ckpt` (as in the SFT and RM steps).
if args.need_optim_ckpt:
    strategy.save_optimizer(trainer.actor_optim,
                            os.path.join(args.output_dir, "actor_optim_checkpoint_%d.pt" % (torch.cuda.current_device())),
                            only_rank0=False)

In [None]:
# Inference test with the PPO-trained actor.
def generation(input_text):
    """Sample one completion for *input_text* from the notebook-global ``actor``.

    Encodes the text with the global ``tokenizer``, prints the decoded result
    under a separator line, and returns it.
    """
    device = torch.cuda.current_device()
    prompt_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    sampling = dict(max_length=100, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1)
    generated = actor.generate(prompt_ids, **sampling)
    # NOTE(review): assumes element 0 of the generate() result holds the batch
    # of token-id sequences -- confirm against nextgpt's Actor.generate.
    decoded = tokenizer.batch_decode(generated[0], skip_special_tokens=True)
    text = decoded[0]
    print("#" * 70)
    print(text)
    return text


# Questions to try the PPO actor on, wrapped in the shared prompt template.
test_instruction_list = [
    "Heterophobia is the irrational fear of what",
]

test_prompt_list = [
    PROMPT_TEMPLATE.format_map({"instruction": instruction})
    for instruction in test_instruction_list
]

for input_text in test_prompt_list:
    output = generation(input_text)

# Inference with the PPO actor

In [None]:
import argparse

In [None]:
# Options for standalone inference with the PPO-trained actor.
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, choices=["gpt2", "bloom", "opt"], default="gpt2")
# Base weights for the actor: a HuggingFace model name or a local directory.
parser.add_argument("--pretrain", type=str, default=None)
# Actor state dict written by the PPO step.
parser.add_argument("--model_path", type=str, default=None)
parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
parser.add_argument("--max_length", type=int, default=100)
args = parser.parse_args([])

# Use the step-1 SFT weights as the base and the step-3 PPO actor checkpoint.
args.pretrain = "./output_1_sft"
args.model_path = "./output_3_ppo/actor.pt"

In [None]:
def eval(args):
    """Load the PPO actor checkpoint and print one sampled completion.

    NOTE(review): this name shadows the builtin ``eval``; it is kept because
    the cell below calls it by this name.

    Args:
        args: namespace with ``model`` ("gpt2", "bloom" or "opt"), ``pretrain``
            (weights passed to the actor constructor), ``model_path`` (actor
            state dict saved by the PPO step), ``input`` (prompt text) and
            ``max_length`` (passed to ``actor.generate``).

    Raises:
        ValueError: if ``args.model`` is not one of the supported models.
    """
    device = torch.cuda.current_device()

    # Configure model.
    if args.model == "gpt2":
        actor = GPTActor(pretrained=args.pretrain).to(device)
    elif args.model == "bloom":
        actor = BLOOMActor(pretrained=args.pretrain).to(device)
    elif args.model == "opt":
        actor = OPTActor(pretrained=args.pretrain).to(device)
    else:
        raise ValueError(f"Unsupported model: {args.model}")

    # Deserialize on CPU so a checkpoint written from another GPU index still
    # loads; load_state_dict then copies the tensors into the actor's parameters.
    state_dict = torch.load(args.model_path, map_location="cpu")
    actor.load_state_dict(state_dict)

    # Configure tokenizer (pad with EOS for every model, as in the training cells).
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    elif args.model == "bloom":
        tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f"Unsupported model: {args.model}")
    tokenizer.pad_token = tokenizer.eos_token

    actor.eval()
    # Named `prompt` rather than `input` to avoid shadowing the builtin.
    prompt = args.input
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = actor.generate(input_ids,
                                 max_length=args.max_length,
                                 do_sample=True,
                                 top_k=10,
                                 top_p=0.95,
                                 num_return_sequences=1)
    # NOTE(review): assumes outputs[0] holds the batch of token-id sequences
    # -- confirm against nextgpt's Actor.generate.
    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)[0]
    print(output)

In [None]:
# Ask the PPO actor one templated question.
input_text = "Heterophobia is the irrational fear of what?"
args.input = PROMPT_TEMPLATE.format(instruction=input_text)
eval(args)