# trlx framework overview
This example outlines the current trlx training framework, which uses a defined reward function. The objective of this example is to generate positive movie reviews by tuning a pretrained model on the IMDB dataset with a sentiment reward function.

In [1]:
import pathlib
import os
from typing import List

import torch
import yaml
from datasets import load_dataset
from transformers import pipeline

import trlx
from trlx.data.configs import TRLConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_positive_score(scores):
    """Extract the value associated with a positive sentiment from a pipeline's output.

    Args:
        scores: Iterable of dicts, one per sentiment class, each holding a
            "label" entry and a "score" entry.

    Returns:
        The "score" of the entry whose "label" is "POSITIVE".

    Raises:
        KeyError: If no entry is labelled "POSITIVE", or an entry lacks a
            "label" or "score" key.
    """
    # Read the fields by name instead of relying on the order (and number) of
    # values stored in each dict.
    return {entry["label"]: entry["score"] for entry in scores}["POSITIVE"]

In [3]:
# Load the default PPO settings from YAML. The path is relative, so it is
# resolved against the directory the kernel is running in.
with pathlib.Path('./configs/ppo_config.yml').open() as f:
    default_config = yaml.safe_load(f)


In [4]:
# Hyperparameter overrides for the values read from the config file
# (none are set here).
hparams = dict()
config = TRLConfig.update(default_config, hparams)






In [5]:
# With CUDA available, use the index given by the LOCAL_RANK environment
# variable (0 when it is unset); otherwise fall back to -1.
# NOTE(review): -1 is presumably the CPU sentinel expected by the
# transformers pipeline this value is passed to - confirm.
device = int(os.environ.get("LOCAL_RANK", 0)) if torch.cuda.is_available() else -1

  return torch._C._cuda_getDeviceCount() > 0


In [6]:
# Sentiment classifier whose output is turned into the training reward.
sentiment_fn = pipeline(
    task="sentiment-analysis",
    model="lvwerra/distilbert-imdb",
    device=device,
    batch_size=256,
    truncation=True,
    # NOTE(review): presumably requests scores for both sentiment labels,
    # which get_positive_score then indexes by "POSITIVE" - confirm.
    top_k=2,
)


In [7]:
def reward_fn(samples: List[str], **kwargs) -> List[float]:
    """Score each sample by its positive sentiment.

    Args:
        samples: Texts to evaluate with ``sentiment_fn``.
        **kwargs: Accepted but not used by this function.

    Returns:
        One positive-sentiment value per entry returned by ``sentiment_fn``,
        in the same order.
    """
    return [get_positive_score(scores) for scores in sentiment_fn(samples)]

In [8]:
# Use the first four whitespace-separated words of every movie review
# (train and test splits combined) as a prompt.
imdb = load_dataset("imdb", split="train+test")
prompts = [" ".join(words[:4]) for words in map(str.split, imdb["text"])]

Found cached dataset imdb (/home/fongsu/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


In [10]:
# Start training with the sentiment-based reward and the IMDB-derived prompts.
trlx.train(
    config=config,
    prompts=prompts,
    reward_fn=reward_fn,
    # A single evaluation prompt, repeated 64 times.
    eval_prompts=["I don't know much about Hungarian underground"] * 64,
)

Downloading (…)"pytorch_model.bin";: 100%|██████████| 334M/334M [01:39<00:00, 3.34MB/s]
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33moliversf[0m ([33mcomprehelp[0m). Use [1m`wandb login --relogin`[0m to force relogin


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


: 

: 

# Observations
1. No abstractions for reward models (Model, Trainer, Config, etc.)
2. Low customizability for training details
