In [2]:
pip install -U transformers==4.45.2 datasets evaluate nltk accelerate wandb rouge_score

Note: you may need to restart the kernel to use updated packages.


In [4]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "google/flan-t5-base"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.config.max_length = 256

In [5]:
from datasets import load_dataset
import numpy as np

train_dataset = load_dataset("Jise/ruletaker", split="train")
test_dataset = load_dataset("Jise/ruletaker", split="test")
ood_dataset = load_dataset("Jise/ruletaker", split="ood_test")

prompt = "Based on the facts and rules, first think step by step and give simple reasoning steps citing the rules, and then output whether the assertion is true by true or false."

temp = tokenizer(train_dataset["reasoning"])
print("Max output token number:", max([len(s) for s in temp["input_ids"]]))

def preprocess(examples):
    inputs = [prompt + "Assertion:" + x for x in examples["context"]]
    model_inputs = tokenizer(inputs, max_length=511, truncation=True, padding=True, return_tensors="pt")
    labels = [x + "\nThe answer is " + y for x, y in zip(examples["reasoning"], examples["flag"])]
    labels = tokenizer(text_target=labels, max_length=262, truncation=True, padding=True, return_tensors="pt")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

train_dataset = train_dataset.map(preprocess, batched=True)
test_dataset = test_dataset.map(preprocess, batched=True)
ood_dataset = ood_dataset.map(preprocess, batched=True)

print(train_dataset)

Max output token number: 257
Dataset({
    features: ['context', 'statement', 'reasoning', 'depth', 'flag', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})


In [4]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
from trl import DPOTrainer, DPOConfig
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
import requests
import pickle
from transformers.optimization import Adafactor, AdafactorSchedule
import wandb
import nltk
import evaluate

nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
metric = evaluate.load("rouge")

run_name = "Flan-T5_RuleTaker_SFT"

with open("TOKENS.pkl", "rb") as f:
    TOKENS = pickle.load(f)

WANDB_TOKEN = TOKENS["WANDB_TOKEN"]
HF_TOKEN = TOKENS["HF_TOKEN"]

wandb.login(key=WANDB_TOKEN)

def compute_metrics(eval_preds):
    preds, labels = eval_preds

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels]
 
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    result["acc"] = np.mean([("true" in label.lower() and "true" in pred.lower()) or ("false" in label.lower() and "false" in pred.lower()) for pred, label in zip(preds, labels)])
    
    return result

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

sft_training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-sft-ruletaker",
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    label_smoothing_factor=1e-5,
    learning_rate=1e-4,
    num_train_epochs=5,
    logging_dir="./logs",
    logging_steps=100,
    weight_decay=0.01,
    bf16=True,
    predict_with_generate=True,
    push_to_hub=True,
    report_to="wandb",
    run_name=run_name,
    hub_token=HF_TOKEN,
    hub_model_id="Jise/flan-t5-ruletaker-sft",
    save_safetensors=False,
)

sft_trainer = Seq2SeqTrainer(
    model=model,
    args=sft_training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

2024-12-05 21:48:21.767950: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelis

In [5]:
print("Zero-shot Test result:")
sft_trainer.evaluate()

Zero-shot Test result:




{'eval_loss': 36.28911209106445,
 'eval_model_preparation_time': 0.0048,
 'eval_rouge1': 0.018774714781507364,
 'eval_rouge2': 0.0,
 'eval_rougeL': 0.018716265665975895,
 'eval_rougeLsum': 0.018794678105219717,
 'eval_acc': 0.492,
 'eval_runtime': 14.8611,
 'eval_samples_per_second': 16.822,
 'eval_steps_per_second': 1.077}

In [6]:
print("Zero-shot OOD result:")
sft_trainer.evaluate(eval_dataset=ood_dataset)

Zero-shot OOD result:


{'eval_loss': 22.32184410095215,
 'eval_model_preparation_time': 0.0048,
 'eval_rouge1': 0.007528396075322917,
 'eval_rouge2': 0.0,
 'eval_rougeL': 0.007524591822213361,
 'eval_rougeLsum': 0.007532277095522594,
 'eval_acc': 0.42933333333333334,
 'eval_runtime': 20.6375,
 'eval_samples_per_second': 18.171,
 'eval_steps_per_second': 1.163}

In [None]:
sft_trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss,Validation Loss,Model Preparation Time,Rouge1,Rouge2,Rougel,Rougelsum,Acc
100,3.593,0.114167,0.0048,0.506817,0.394937,0.446455,0.490164,0.624
200,0.1114,0.074364,0.0048,0.572086,0.441646,0.491652,0.562684,0.636
300,0.0872,0.071179,0.0048,0.566286,0.431515,0.480286,0.545394,0.624
400,0.0794,0.071062,0.0048,0.546849,0.4334,0.473044,0.537601,0.624
500,0.0735,0.068899,0.0048,0.561883,0.439878,0.475096,0.543862,0.608
600,0.0728,0.068771,0.0048,0.551749,0.435579,0.470588,0.536157,0.616




In [10]:
print("SFT Test result:")
sft_trainer.evaluate()

SFT Test result:


{'eval_loss': 0.06863230466842651,
 'eval_model_preparation_time': 0.0048,
 'eval_rouge1': 0.5405360565791885,
 'eval_rouge2': 0.4278123695553129,
 'eval_rougeL': 0.4657341716918304,
 'eval_rougeLsum': 0.5253153542621036,
 'eval_acc': 0.624,
 'eval_runtime': 55.6622,
 'eval_samples_per_second': 4.491,
 'eval_steps_per_second': 0.287,
 'epoch': 5.0}

In [11]:
print("SFT OOD result:")
sft_trainer.evaluate(eval_dataset=ood_dataset)

SFT OOD result:


{'eval_loss': 0.10819105058908463,
 'eval_model_preparation_time': 0.0048,
 'eval_rouge1': 0.48410530695527554,
 'eval_rouge2': 0.4000244274186455,
 'eval_rougeL': 0.39304103134106605,
 'eval_rougeLsum': 0.4759229556808772,
 'eval_acc': 0.36533333333333334,
 'eval_runtime': 75.1572,
 'eval_samples_per_second': 4.99,
 'eval_steps_per_second': 0.319,
 'epoch': 5.0}