# Refine txt2img Prompts with Human Feedback


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/CarperAI/trlx/blob/main/examples/notebooks/trlx_simulacra.ipynb)


#### Optimize a GPT-2-based txt2img prompt generator to produce aesthetic prompts using the [Simulacra Aesthetic Captions](https://github.com/JD-P/simulacra-aesthetic-captions) dataset

Notebook by [@smellslikeml](https://github.com/smellslikeml)

---

Execute the cells below to install [trlX](https://github.com/CarperAI/trlx) in a Colab environment. The install commands are commented out; uncomment them if trlX is not already installed.

In [1]:
!ls

RL_misko_trl.ipynb  trlx  trlx_simulacra.ipynb


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
# !pip install torchtyping
# !python --version
# import deepspeed

In [4]:
# !git clone https://github.com/CarperAI/trlx.git
# !git config --global --add safe.directory /content/trlx && cd /content/trlx && pip install -e .

In [5]:
# uninstall scikit_learn + jax to avoid numpy issues
# !conda uninstall -y scikit_learn jax

In [2]:
import os

# run within repo
os.chdir('./trlx')
print(os.getcwd())

/mnt/storage-brno2/home/ahajek/Spektro/MassGenie/RLHF/trlx


In [5]:
import sqlite3
from urllib.request import urlretrieve

import trlx

url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite"
dbpath = "sac_public_2022_06_29.sqlite"

if not os.path.exists(dbpath):
    print(f"fetching {dbpath}")
    urlretrieve(url, dbpath)

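# Join each rating to its image and to the generation (prompt) that produced it,
# skipping rows without a rating.
conn = sqlite3.connect(dbpath)
c = conn.cursor()
c.execute(
    "SELECT prompt, rating FROM ratings "
    "JOIN images ON images.id=ratings.iid "
    "JOIN generations ON images.gid=generations.id "
    "WHERE rating IS NOT NULL;"
)

# Transpose the (prompt, rating) rows into two parallel lists.
prompts, ratings = tuple(map(list, zip(*c.fetchall())))
conn.close()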
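
To see what the query and the `zip(*...)` unpacking above return, here is a minimal sketch on a throwaway in-memory database. The three tables contain only the columns the query touches, which is an assumption about the real schema, and the rows and the `toy_` names are invented for illustration.

```python
import sqlite3

# Toy in-memory database; the schema is an assumption limited to the columns
# the query uses, and the rows are made up.
toy_conn = sqlite3.connect(":memory:")
toy_cur = toy_conn.cursor()
toy_cur.executescript(
    """
    CREATE TABLE generations (id INTEGER PRIMARY KEY, prompt TEXT);
    CREATE TABLE images (id INTEGER PRIMARY KEY, gid INTEGER);
    CREATE TABLE ratings (iid INTEGER, rating INTEGER);
    INSERT INTO generations VALUES (1, 'a castle at dusk'), (2, 'a red fox');
    INSERT INTO images VALUES (10, 1), (11, 2);
    INSERT INTO ratings VALUES (10, 7), (11, 4), (11, NULL);
    """
)

# Same joins as the notebook query, plus an ORDER BY so the toy output is stable.
toy_cur.execute(
    "SELECT prompt, rating FROM ratings "
    "JOIN images ON images.id=ratings.iid "
    "JOIN generations ON images.gid=generations.id "
    "WHERE rating IS NOT NULL "
    "ORDER BY ratings.iid;"
)
toy_rows = toy_cur.fetchall()

# zip(*rows) transposes a list of (prompt, rating) pairs into two sequences.
toy_prompts, toy_ratings = map(list, zip(*toy_rows))
toy_conn.close()

print(toy_prompts)
print(toy_ratings)
```

The row with a `NULL` rating is dropped by the `WHERE` clause, so both lists have the same length and stay aligned by position.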

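trlX uses [wandb](https://wandb.ai/) to log results. Set up an account beforehand and authenticate with your API token when prompted after executing the cell below. Because the call passes `samples` with matching `rewards` rather than a reward function, trlX trains offline with ILQL, as the `AccelerateILQLTrainer` object in the output confirms.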

In [6]:
import trlx
trlx.train(
    "gpt2",
    samples=prompts,
    rewards=ratings,
    eval_prompts=["Hatsune Miku, Red Dress"] * 64,
)

[RANK 0] Initializing model: gpt2


[RANK 0] Collecting rollouts
[RANK 0] Logging sample example


[RANK 0] Logging experience string statistics


[RANK 0] Starting training
[RANK 0] Evaluating model
[generation sweep 0/1 | eval batch 0/2]:   0%|          | 0/2 [00:00<?, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:24<00:00, 12.36s/it]
[RANK 0] Summarizing evaluation


[losses/loss: 5.75 | losses/loss_q: 0.28 | losses/loss_v: 0.03 | losses/loss_cql: 18.98 | losses/loss_awac: 3.54]:  10%|▉         | 99/1000 [00:12<01:42,  8.79it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 5.21 | losses/loss_q: 0.34 | losses/loss_v: 0.03 | losses/loss_cql: 17.50 | losses/loss_awac: 3.09]:  20%|█▉        | 199/1000 [00:25<01:33,  8.59it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 5.35 | losses/loss_q: 0.31 | losses/loss_v: 0.03 | losses/loss_cql: 17.29 | losses/loss_awac: 3.29]:  30%|██▉       | 299/1000 [00:38<01:20,  8.71it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.75it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 4.43 | losses/loss_q: 0.24 | losses/loss_v: 0.03 | losses/loss_cql: 15.91 | losses/loss_awac: 2.57]:  40%|███▉      | 399/1000 [00:51<01:09,  8.71it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.75it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 4.18 | losses/loss_q: 0.18 | losses/loss_v: 0.03 | losses/loss_cql: 15.92 | losses/loss_awac: 2.39]:  50%|████▉     | 499/1000 [01:04<00:56,  8.85it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 5.16 | losses/loss_q: 0.29 | losses/loss_v: 0.03 | losses/loss_cql: 16.87 | losses/loss_awac: 3.14]:  60%|█████▉    | 599/1000 [01:17<00:47,  8.47it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 4.61 | losses/loss_q: 0.47 | losses/loss_v: 0.04 | losses/loss_cql: 14.56 | losses/loss_awac: 2.65]:  70%|██████▉   | 699/1000 [01:30<00:33,  8.88it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.75it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 4.02 | losses/loss_q: 0.19 | losses/loss_v: 0.03 | losses/loss_cql: 13.80 | losses/loss_awac: 2.42]:  80%|███████▉  | 799/1000 [01:43<00:23,  8.52it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.75it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 4.74 | losses/loss_q: 0.30 | losses/loss_v: 0.03 | losses/loss_cql: 14.99 | losses/loss_awac: 2.91]:  90%|████████▉ | 899/1000 [01:56<00:11,  8.84it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 4.04 | losses/loss_q: 0.32 | losses/loss_v: 0.05 | losses/loss_cql: 13.64 | losses/loss_awac: 2.30]: 100%|█████████▉| 999/1000 [02:09<00:00,  8.48it/s][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.69it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 4.68 | losses/loss_q: 0.34 | losses/loss_v: 0.05 | losses/loss_cql: 15.29 | losses/loss_awac: 2.76]: 100%|██████████| 1000/1000 [02:20<00:00,  3.36s/it][RANK 0] Evaluating model
[generation sweep 1/1 | eval batch 2/2]: 100%|██████████| 2/2 [00:01<00:00,  1.73it/s]
[RANK 0] Summarizing evaluation


[losses/loss: 4.68 | losses/loss_q: 0.34 | losses/loss_v: 0.05 | losses/loss_cql: 15.29 | losses/loss_awac: 2.76]: 100%|██████████| 1000/1000 [02:32<00:00,  6.57it/s]


<trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer at 0x14e3c2db7c40>
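
Before launching a run like the one above, it may help to confirm that every sample has a reward and to look at the reward range. The helper `summarize_rewards` below is hypothetical and not part of trlX; it is only a sketch of such a check, shown here on made-up data.

```python
from statistics import mean


def summarize_rewards(samples, rewards):
    """Hypothetical helper: check alignment and summarize the reward values."""
    if len(samples) != len(rewards):
        raise ValueError(
            f"got {len(samples)} samples but {len(rewards)} rewards"
        )
    if not rewards:
        raise ValueError("no rewards to summarize")
    return {
        "n": len(rewards),
        "min": min(rewards),
        "max": max(rewards),
        "mean": mean(rewards),
    }


# Made-up example data, not taken from the dataset.
summary = summarize_rewards(["prompt a", "prompt b", "prompt c"], [2, 4, 9])
print(summary)
```

With the lists built earlier in this notebook, the call would be `summarize_rewards(prompts, ratings)`.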
