In [1]:
import t5_encoder
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer
import torch
from accelerate.utils import convert_outputs_to_fp32

from alpaca_farm import data_utils
from argparse import Namespace
import pathlib

import numpy as np
from datasets import load_metric
# Accuracy metric consumed by `compute_metrics` below.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases (the warning echoed
# in this cell's output points at this line); consider `evaluate.load('accuracy')` instead.
metric = load_metric('accuracy')

import os

# Restrict this kernel to GPU 1. Only takes effect if CUDA has not been initialised yet in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Data configuration: prompt templates, dataset directory and which annotation column is the label.
data_args = Namespace(
    prompt_dict_path=pathlib.Path('./prompts/v0_inputs_noinputs.json'),
    dataset_path='../seahorse_data/',
    classification_label_key='question4',
)

# Model configuration for loading the trained reward model.
# NOTE(review): absolute cluster path — consider making it configurable for portability.
training_args = Namespace(
    end_sequence_with_eos=False,
    reward_model_name_or_path='/mnt/nfs_csail/models/swhan/alpaca_farm/q_four_flant5/',
    transformer_cache_dir=None,
    flash_attn=False,
)

  metric = load_metric('accuracy')


In [2]:
def cast_with_native_amp(func, mixed_precision):
    """Wrap ``func`` to run under native torch autocast (almost like how HF accelerate casts ``model.forward``).

    Args:
        func: Callable to wrap, typically ``model.forward``.
        mixed_precision: ``"fp16"`` or ``"bf16"``. Any other value logs a warning and returns
            ``func`` unchanged, i.e. it keeps running in full precision.

    Returns:
        For ``"fp16"``/``"bf16"``, a wrapped callable whose outputs are converted back to fp32 with
        accelerate's ``convert_outputs_to_fp32``; otherwise ``func`` itself.
    """
    if mixed_precision not in ("fp16", "bf16"):
        # There is no module-level `logger` in this notebook; use the logging module directly
        # so this branch warns instead of raising NameError.
        import logging

        logging.getLogger(__name__).warning(
            f"Unknown mixed precision mode: {mixed_precision}, falling back to fp32."
        )
        return func

    if mixed_precision == "fp16":
        # Device-generic replacement for the deprecated `torch.cuda.amp.autocast(dtype=...)`.
        output_func = torch.autocast(device_type="cuda", dtype=torch.float16)(func)
    else:
        # bf16 autocast is also supported on CPU, so fall back to it when CUDA is unavailable.
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        output_func = torch.autocast(device_type=device_type, dtype=torch.bfloat16)(func)
    return convert_outputs_to_fp32(output_func)
    


In [11]:
from alpaca_farm.models.make_models import make_reward_model
from alpaca_farm.rl.trainer_utils import _make_padded_tokenizer
from alpaca_farm import accelerate_patch

# Load the reward model and tokenizer for inference.
# `accelerator` is passed to alpaca_farm's `make_reward_model`; later cells use `model` and `tokenizer`.
accelerator = accelerate_patch.MyAccelerator(
    gradient_accumulation_steps=1,
    mixed_precision='fp16',  # NOTE(review): differs from the 'bf16' autocast applied to model.forward below — confirm which is intended.
    even_batches=True,  # Make sure the batch size on each device is the same.
    split_batches=False,  # Don't break a batch into smaller chunks.
    step_scheduler_with_optimizer=False,  # Untie optimizer and scheduler step.
    # Value model might not use all parameters (e.g., lm-head) in the forward pass.
    # NOTE(review): no argument follows the comment above; presumably a kwargs handler
    # (e.g. DDP find_unused_parameters) was dropped here — confirm whether it is needed.
)
# is_trainable=False: the model is only evaluated in this notebook (model.eval() + no_grad later).
model = make_reward_model(training_args, accelerator, is_trainable=False)
# Tokenizer is loaded from the same checkpoint directory as the reward model.
tokenizer = _make_padded_tokenizer(training_args.reward_model_name_or_path, cache_dir=None, use_fast_tokenizer=False)
# Alternative: load the checkpoint directly with transformers instead of alpaca_farm.
# model = AutoModelForSequenceClassification.from_pretrained('/mnt/nfs_csail/models/swhan/alpaca_farm/q_four_flant5/')
# Run forward under bf16 autocast, with outputs converted back to fp32 (see cast_with_native_amp).
model.forward = cast_with_native_amp(model.forward, 'bf16')
# tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-large', model_max_length=1024)

Initializing reward model that is not lora based
loading base model /mnt/nfs_csail/models/swhan/alpaca_farm/q_four_flant5/...
Loading tokenizer from /mnt/nfs_csail/models/swhan/alpaca_farm/q_four_flant5/


In [12]:
from typing import Optional

def compute_metrics(eval_pred):
    """Compute accuracy and Pearson correlation for a binary classifier's eval predictions.

    Args:
        eval_pred: ``(logits, labels)`` pair; ``logits`` is 2-D with the class axis last
            (argmax over axis 1 gives the predicted class). ``labels`` may be ``(n,)`` or ``(n, 1)``.

    Returns:
        The dict from the global ``metric.compute`` (accuracy) plus a ``'pearson'`` entry.
        Pearson is reported as 0 when either predictions or labels are constant, because the
        correlation is undefined there (``np.corrcoef`` would return NaN).
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    metrics = metric.compute(predictions=predictions, references=labels)
    labels_flat = np.asarray(labels).reshape(-1)
    # A constant vector has zero variance on either side, so the correlation would be NaN.
    if np.unique(predictions).size < 2 or np.unique(labels_flat).size < 2:
        metrics['pearson'] = 0
    else:
        metrics['pearson'] = np.corrcoef(labels_flat, predictions)[0, 1]
    print(metrics)
    return metrics

def format_prompt(example: dict, prompt_dict: dict) -> str:
    """Formats a prompt with a prompt_dict formatter.

    Args:
        example: A dict-like object with required keys "instruction" and "input".
        prompt_dict: Dictionary containing the keys "prompt_noinputs" and "prompt_inputs" which have
            placeholders corresponding to the keys from `example`. E.g. "{instruction}".
            "prompt_noinputs" is used when ``example["input"]`` is None or empty, otherwise
            "prompt_inputs".

    Returns:
        A formatted prompt string (the template is filled verbatim, so template whitespace is kept).

    Examples
    --------
    >>> format_prompt(dict(instruction="test", input=""), prompt_dict=dict(prompt_noinputs="prompt {instruction} "))
    'prompt test '
    """
    assert "instruction" in example and "input" in example, "Internal error: example missing required keys."

    has_no_input = example["input"] is None or len(example["input"]) == 0
    template_key = "prompt_noinputs" if has_no_input else "prompt_inputs"
    return prompt_dict[template_key].format_map(example)


def format_output_word_by_word(example: dict, eos_token: Optional[str] = None, output_key="output") -> list:
    """Split ``example[output_key]`` (with an optional EOS suffix) into whitespace-separated words.

    Args:
        example: Record dict containing ``output_key``.
        eos_token: Optional suffix appended to the text before splitting. It is appended with no
            separator, so it ends up glued to the last word (e.g. ``"word</s>"``).
        output_key: Key of the target text in ``example``.

    Returns:
        A list of words as produced by ``str.split()``. The previous ``-> str`` annotation was wrong.
    """
    suffix = "" if eos_token is None else eos_token
    return f"{example[output_key]}{suffix}".split()

def format_output(example: dict, eos_token: Optional[str] = None, output_key="output") -> str:
    """Render ``example[output_key]`` as text, followed by ``eos_token`` when one is given.

    Args:
        example: Record dict containing ``output_key``.
        eos_token: Optional suffix appended with no separator; ``None`` means no suffix.
        output_key: Key of the target text in ``example``.

    Returns:
        The formatted target string.
    """
    target_text = example[output_key]
    suffix = eos_token if eos_token is not None else ""
    return f"{target_text}{suffix}"

def _get_text(example: dict, output_key: str) -> str:
    """Build the reward-model input for one record: formatted prompt, a space, then the target.

    The record's ``text`` field becomes the prompt ``input`` and ``INSTRUCTIONS['seahorse_data']``
    the ``instruction``. Relies on the notebook globals ``INSTRUCTIONS``, ``prompt_dict``,
    ``tokenizer`` and ``training_args``.

    Args:
        example: Record dict with at least ``text`` and ``output_key``.
        output_key: Key holding the target text (e.g. ``'summary'``).

    Returns:
        ``source + ' ' + target`` as a single string.
    """
    # Shallow copy so the caller's record (e.g. an entry of eval_dict_data) is not mutated
    # with 'instruction'/'input' keys on every call.
    prompt_example = {**example, 'instruction': INSTRUCTIONS['seahorse_data'], 'input': example['text']}
    source = format_prompt(prompt_example, prompt_dict=prompt_dict)
    target = format_output(
        prompt_example,
        eos_token=tokenizer.eos_token if training_args.end_sequence_with_eos else None,
        output_key=output_key,
    )
    return source + ' ' + target

def _get_text_target_word_by_word(example: dict, output_key: str) -> tuple:
    """Build prompt-plus-partial-target strings for word-by-word scoring of one record.

    Uses the same prompt construction as ``_get_text``: ``text`` becomes the ``input`` and
    ``INSTRUCTIONS['seahorse_data']`` the ``instruction``. Relies on the notebook globals
    ``INSTRUCTIONS``, ``prompt_dict``, ``tokenizer`` and ``training_args``.

    Args:
        example: Record dict with at least ``text`` and ``output_key``.
        output_key: Key holding the target text (e.g. ``'summary'``).

    Returns:
        ``(prefixes, target_words)`` where ``prefixes[t] == source + ' ' + ' '.join(target_words[:t])``
        for ``t`` in ``0 .. len(target_words) - 1``. Each prefix therefore contains the words
        *before* ``target_words[t]``, and the full target never appears as a prefix.
    """
    # Shallow copy so the caller's record is not mutated with 'instruction'/'input' keys.
    prompt_example = {**example, 'instruction': INSTRUCTIONS['seahorse_data'], 'input': example['text']}
    source = format_prompt(prompt_example, prompt_dict=prompt_dict)
    target = format_output_word_by_word(
        prompt_example,
        eos_token=tokenizer.eos_token if training_args.end_sequence_with_eos else None,
        output_key=output_key,
    )
    prefixes = [source + ' ' + ' '.join(target[:t]) for t in range(len(target))]
    return prefixes, target

In [13]:
# data_module = data_utils.make_classification_reward_modeling_data_module(
#         tokenizer=tokenizer,
#         data_args=data_args,
#         training_args=training_args,
#     )
import datasets
from alpaca_farm import utils

# Prompt templates plus the SEAHORSE train/validation splits, restricted to en-US annotations.
prompt_dict = utils.jload(data_args.prompt_dict_path)
data_files = dict(train="train.json", validation="validation.json")
dataset_json = datasets.load_dataset(data_args.dataset_path, data_files=data_files).filter(
    lambda record: record['worker_lang'] == 'en-US'
)
train_dataset, eval_dataset = dataset_json['train'], dataset_json['validation']

In [14]:
import pandas as pd
# Instruction used as the prompt's "{instruction}" field (read by _get_text at call time).
INSTRUCTIONS = {
    'seahorse_data': "Generate a one-sentence summary of this post.",
}


def _records_with_text(dataset) -> tuple:
    """Convert a dataset split to a list of record dicts, dropping rows whose 'text' is None.

    Replaces the previously copy-pasted delete-by-index loops (quadratic in the number of removals)
    with a single linear filter.

    Args:
        dataset: Any split accepted by ``pd.DataFrame`` with a ``'text'`` column.

    Returns:
        ``(records, num_removed)``: the kept records (original order preserved) and how many were dropped.
    """
    all_records = pd.DataFrame(dataset).to_dict(orient="records")
    kept_records = [record for record in all_records if record['text'] is not None]
    return kept_records, len(all_records) - len(kept_records)


eval_dict_data, num_eval_removed = _records_with_text(eval_dataset)
print(num_eval_removed)

train_dict_data, num_train_removed = _records_with_text(train_dataset)
print(num_train_removed)

0
0


In [None]:
import torch
from tqdm import trange
# Zero-shot accuracy of the reward model as a binary classifier: a positive reward predicts label 1.
num_correct = 0.
num_total = 0.
label_key = data_args.classification_label_key  # 'question4' in the config cell
model = model.cuda()
model.eval()
with torch.no_grad():
    # Iterate over `eval_dict_data` (the list actually indexed below), not `eval_dataset`:
    # rows with text=None are dropped from `eval_dict_data`, so the two lengths can differ.
    for i in trange(len(eval_dict_data)):
        record = eval_dict_data[i]
        sequence = _get_text(record, 'summary')
        input_ids = tokenizer(sequence, truncation=True, return_tensors='pt')['input_ids']
        # model.forward is already wrapped with autocast by cast_with_native_amp.
        outputs = model(input_ids.cuda())
        rewards = outputs.rewards
        label = record[label_key]
        pred = 1. if rewards > 0 else 0.
        if pred == label:
            num_correct += 1
        num_total += 1

        if i % 100 == 0:
            print(num_correct / num_total)
        if torch.isnan(rewards).any() or torch.isinf(rewards).any():
            print(rewards)

# The in-loop print only fires every 100 examples; report accuracy over the whole split.
print(f"final accuracy: {num_correct / max(num_total, 1.):.4f} ({int(num_correct)}/{int(num_total)})")

  0%|▍                                                                                                                                                    | 7/2183 [00:00<02:13, 16.36it/s]

1.0


  5%|███████▎                                                                                                                                           | 109/2183 [00:03<00:51, 40.51it/s]

0.7227722772277227


 10%|██████████████                                                                                                                                     | 209/2183 [00:09<01:02, 31.80it/s]

0.7313432835820896


 14%|████████████████████▉                                                                                                                              | 311/2183 [00:14<01:45, 17.66it/s]

0.7475083056478405


 19%|███████████████████████████▌                                                                                                                       | 409/2183 [00:17<00:41, 42.68it/s]

0.7456359102244389


 23%|██████████████████████████████████▎                                                                                                                | 509/2183 [00:23<00:50, 33.20it/s]

0.7385229540918163


 28%|████████████████████████████████████████▉                                                                                                          | 608/2183 [00:28<01:54, 13.70it/s]

0.7371048252911814


 33%|███████████████████████████████████████████████▉                                                                                                   | 711/2183 [00:31<00:30, 47.91it/s]

0.7275320970042796


 37%|██████████████████████████████████████████████████████▍                                                                                            | 808/2183 [00:36<00:44, 31.03it/s]

0.7228464419475655


 42%|█████████████████████████████████████████████████████████████▏                                                                                     | 908/2183 [00:42<01:24, 15.05it/s]

0.7125416204217536


 46%|██████████████████████████████████████████████████████████████████▉                                                                               | 1000/2183 [00:45<00:26, 44.14it/s]

0.7072927072927073


 51%|██████████████████████████████████████████████████████████████████████████▏                                                                       | 1109/2183 [00:50<00:29, 36.57it/s]

0.7057220708446866


 51%|██████████████████████████████████████████████████████████████████████████▉                                                                       | 1121/2183 [00:51<00:26, 40.57it/s]