In [1]:
# Make the repository root (two levels above the working directory, which in
# Jupyter is normally the notebook's folder) importable, presumably so that
# the `rlprompt` package resolves.
import os
import sys

REPO_ROOT = os.path.abspath('../../')
# Store an absolute path: a relative sys.path entry is re-resolved against the
# *current* working directory at import time and silently breaks if it changes.
# Append only once so re-running this cell does not keep growing sys.path.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [2]:
import os
import hydra
from hydra import compose, initialize
from omegaconf import DictConfig, OmegaConf
import importlib
from rlprompt.models import (LMAdaptorModelConfig, SinglePromptModelConfig,
                             make_lm_adaptor_model, make_single_prompt_model)
from rlprompt.modules import SQLModuleConfig, make_sql_module
import rlprompt.trainers
importlib.reload(rlprompt.trainers)
from rlprompt.trainers import TrainerConfig, make_trainer
from rlprompt.utils.utils import (colorful_print, compose_hydra_config_store,
                                  get_hydra_output_dir)

from fsc_helpers import (PromptedClassificationRewardConfig,
                         FewShotClassificationDatasetConfig,
                         make_prompted_classification_reward,
                         make_few_shot_classification_dataset)

In [3]:
# Compose default config
# Register every structured-config dataclass with Hydra's ConfigStore under the
# schema name 'base_fsc' (presumably referenced by fsc_config's defaults list —
# TODO confirm against fsc_config.yaml).
config_list = [
    PromptedClassificationRewardConfig,
    FewShotClassificationDatasetConfig,
    LMAdaptorModelConfig,
    SinglePromptModelConfig,
    SQLModuleConfig,
    TrainerConfig,
]
cs = compose_hydra_config_store('base_fsc', config_list)

In [4]:
# Compose the run config from fsc_config.yaml plus command-line-style overrides.
# `initialize` is used as a context manager so Hydra's global state is cleared
# on exit: a bare `initialize(...)` raises "GlobalHydra is already initialized"
# the second time it runs in the same kernel, which made this cell impossible
# to re-run.
CONFIG_OVERRIDES = [
    'dataset=sst-5',
    'dataset_seed=0',
    'prompt_length=5',
    'task_lm=distilroberta-base',
    'random_seed=7',
    'report_to_wandb=false',
]
with initialize(version_base=None, config_path="./", job_name="test_app"):
    config = compose(config_name="fsc_config", overrides=CONFIG_OVERRIDES)

In [5]:
# Echo the fully resolved run config (defaults + overrides) as YAML, in red.
config_yaml = OmegaConf.to_yaml(config)
colorful_print(config_yaml, fg='red')

[31mtask_lm: distilroberta-base
is_mask_lm: null
compute_zscore: true
incorrect_coeff: 180.0
correct_coeff: 200.0
dataset: sst-5
dataset_seed: 0
base_path: ./data
num_shots: 16
policy_lm: distilgpt2
hidden_size: 2048
logit_bias: 0.0
fluent: false
fluent_top_k: 20
max_decoding_length: 5
eos_token_id: null
prompt_length: 5
prompt_train_batch_size: 16
prompt_infer_batch_size: 1
source_str: <|endoftext|>
sql_loss_impl: v2_v2r_v3_v3r
training_mode: sql-onpolicy
mix_strategy: null
target_update_method: polyak
target_update_steps: null
target_learning_rate: 0.001
reward_shaping: true
reward_shaping_old_min: 0.0
reward_shaping_old_max: 1.0
reward_shaping_new_min: 0.0
reward_shaping_new_max: 5.0
top_k: 256
top_p: 1.0
num_beams: 1
train_batch_size: 16
train_shuffle: false
train_drop_last: true
num_train_epochs: 1
max_train_steps: 12000
do_eval: true
eval_batch_size: 16
eval_steps: 10
do_save: true
save_dir: ./outputs
save_steps: 100
learning_rate: 5.0e-05
gradient_clip: true
gradient_clip_norm:

In [6]:
# Root directory for run artifacts; config.save_dir is joined onto it in the
# few-shot batch-size cell further down.
output_dir = './outputs/'

In [7]:
# Load the few-shot splits plus the label metadata (class count, verbalizer
# words, prompt template) that the reward needs later on.
(train_dataset, val_dataset, test_dataset,
 num_classes, verbalizers, template) = make_few_shot_classification_dataset(config)

# Sanity-check split sizes and peek at the first few examples. Both splits use
# the same "<Split> Size:" label (previously the val line was missing its colon).
for split_name, split_dataset in [('Train', train_dataset), ('Val', val_dataset)]:
    print(f'{split_name} Size:', len(split_dataset))
    print('Examples:', split_dataset[:5])

Train Size: 80
Examples: {'source_texts': ["steven soderbergh 's digital video experiment is a clever and cutting , quick and dirty look at modern living and movie life .", 'a vivid , sometimes surreal , glimpse into the mysteries of human behavior .', 'an ingenious and often harrowing look at damaged people and how families can offer either despair or consolation .', 'presents a side of contemporary chinese life that many outsiders will be surprised to know exists , and does so with an artistry that also smacks of revelation .', 'can you bear the laughter ?'], 'class_labels': [3, 3, 3, 3, 3]}
Val Size 80
Examples: {'source_texts': ['not as well-written as sexy beast , not as gloriously flippant as lock , stock and two smoking barrels , but stylish and moody and exceptionally well-acted .', "like kubrick , soderbergh is n't afraid to try any genre and to do it his own way .", 'bring on the sequel .', 'an intense and effective film about loneliness and the chilly anonymity of the enviro

In [8]:
# Label words used to read predictions off the task LM (5 of them for sst-5;
# presumably index i corresponds to class label i — TODO confirm).
# 'Ġ' is the byte-level BPE marker for a leading space.
verbalizers

['Ġterrible', 'Ġbad', 'Ġokay', 'Ġgood', 'Ġgreat']

In [9]:
# Prompt template for the task LM. This cell displays nothing, which suggests
# the value is None for this dataset (a default template is presumably used
# downstream — TODO confirm).
template

In [10]:
# Build the policy: the config's policy_lm (distilgpt2 here) plus an MLP adaptor
# head sized by hidden_size (see the reprs below).
policy_model = make_lm_adaptor_model(config)

distilgpt2


In [11]:
# Inspect the underlying Hugging Face LM held by the adaptor's generator.
policy_model.generator.model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

In [12]:
# Only the MLP adaptor appears as a registered submodule; the generator LM does
# not show up here — presumably it is not trained (TODO confirm).
print(policy_model)

LMAdaptorModel(
  (mlp): Sequential(
    (0): Linear(in_features=768, out_features=2048, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2048, out_features=768, bias=True)
  )
)


In [13]:
# Configured logit bias (config.logit_bias); presumably added to the policy's
# output logits, so 0.0 means no bias — TODO confirm.
policy_model.logit_bias

0.0

In [14]:
# Wrap the policy so it produces a single prompt from a fixed source string
# (config.source_str); prompt length comes from the config — TODO confirm.
prompt_model = make_single_prompt_model(policy_model, config)

In [15]:
# The prompt model exposes the adaptor as its `_model` child.
print(prompt_model)

SinglePromptModel(
  (_model): LMAdaptorModel(
    (mlp): Sequential(
      (0): Linear(in_features=768, out_features=2048, bias=True)
      (1): ReLU()
      (2): Linear(in_features=2048, out_features=768, bias=True)
    )
  )
)


In [16]:
# Fixed source text fed to the policy when generating prompts ('<|endoftext|>',
# GPT-2's end-of-text token string).
prompt_model.source_str

'<|endoftext|>'

In [17]:
# Reward: presumably scores a prompt by how well the task LM (distilroberta-base)
# classifies the examples when the prompt is inserted, reading predictions via
# the verbalizers above.
reward = make_prompted_classification_reward(num_classes, verbalizers, template, config)

Task LM: distilroberta-base
Verbalizers: ['Ġterrible', 'Ġbad', 'Ġokay', 'Ġgood', 'Ġgreat']


In [18]:
# RL algorithm module (SQL — presumably soft Q-learning) combining the prompt
# model and reward; it also builds a target copy of the model (see repr below).
algo_module = make_sql_module(prompt_model, reward, config)

In [19]:
# Shows the online model and the target model side by side.
algo_module

SQLModule(
  (_model): SinglePromptModel(
    (_model): LMAdaptorModel(
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=2048, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2048, out_features=768, bias=True)
      )
    )
  )
  (_target_model): SinglePromptModel(
    (_model): LMAdaptorModel(
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=2048, bias=True)
        (1): ReLU()
        (2): Linear(in_features=2048, out_features=768, bias=True)
      )
    )
  )
)

In [20]:
# NOTE(review): reads a private attribute, for debugging only. Expected to
# reflect config.training_mode ('sql-onpolicy' -> SQL_ON).
algo_module._forward_modes

[<ForwardMode.SQL_ON: 'SQL_ON'>]

In [21]:
# Hack for few-shot classification - Each batch contains all examples, so a
# single training/eval step sees the whole (small) split.
# NOTE: this mutates `config` in place after earlier cells already consumed it;
# the YAML printed near the top no longer reflects these fields.
config.train_batch_size = len(train_dataset)
config.eval_batch_size = len(val_dataset)
# Prefix save_dir with output_dir only once. The unconditional join kept
# nesting the path on every re-run ('./outputs/./outputs/./outputs/...').
# os.path.join(output_dir, '') guarantees a trailing separator for the check.
if not config.save_dir.startswith(os.path.join(output_dir, '')):
    config.save_dir = os.path.join(output_dir, config.save_dir)

In [22]:
# Dev-time hot reload of the trainer package, then re-bind its names.
# NOTE(review): redundant on a fresh kernel — the import cell already does the
# same reload + import. importlib.reload re-executes only the package's
# __init__; edited submodules stay cached in sys.modules and are not refreshed.
# `%load_ext autoreload` + `%autoreload 2` in the setup cell is more reliable.
import rlprompt.trainers
importlib.reload(rlprompt.trainers)
from rlprompt.trainers import TrainerConfig, make_trainer

In [23]:
# Build the trainer and run training. Because train_batch_size was set to
# len(train_dataset) (with drop_last), there is one batch per epoch, so
# max_train_steps=12000 becomes 12000 epochs (see the log below).
# Expect this cell to be long-running.
trainer = make_trainer(algo_module, train_dataset, val_dataset, config)
trainer.train(config=config)

total_train_epochs  12000 num_batches_per_epoch  1
input length 80
source_texts ["steven soderbergh 's digital video experiment is a clever and cutting , quick and dirty look at modern living and movie life .", 'a vivid , sometimes surreal , glimpse into the mysteries of human behavior .', 'an ingenious and often harrowing look at damaged people and how families can offer either despair or consolation .', 'presents a side of contemporary chinese life that many outsiders will be surprised to know exists , and does so with an artistry that also smacks of revelation .', 'can you bear the laughter ?', 'a fascinating , dark thriller that keeps you hooked on the delicious pulpiness of its lurid fiction .', 'both garcia and jagger turn in perfectly executed and wonderfully sympathetic characters , who are alternately touching and funny .', 'great fun both for sports aficionados and for ordinary louts whose idea of exercise is climbing the steps of a stadium-seat megaplex .', 'the hook is the 

  (reward_key, torch.mean(torch.tensor(reward_vals)))
