In [1]:
pip install -U transformers datasets evaluate nltk accelerate trl wandb peft rouge_score

Collecting transformers
  Obtaining dependency information for transformers from https://files.pythonhosted.org/packages/d0/a7/7eedcf6a359e1e1eff3bc204ad022485aa5d88c08e1e3e0e0aee8a2e2235/transformers-4.47.0-py3-none-any.whl.metadata
  Using cached transformers-4.47.0-py3-none-any.whl.metadata (43 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Obtaining dependency information for tokenizers<0.22,>=0.21 from https://files.pythonhosted.org/packages/22/06/69d7ce374747edaf1695a4f61b83570d91cc8bbfc51ccfecf76f56ab4aac/tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Using cached tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Note: you may need to restart the kernel to use updated packages.


In [2]:
update_interval = 10
num_token = 5
memory_len = 3

In [3]:
from transformers import AutoTokenizer, T5ForConditionalGeneration, T5Config, logging
import torch
from custom_model import RTModel, CustomTrainer
import os
import logging as sys_logging

sys_logging.captureWarnings(True)
logger = sys_logging.getLogger("transformers.trainer")
class LogFilter(sys_logging.Filter):
    def filter(self, record):
        message = record.getMessage()
        return "deprecated" not in message and "hang" not in message
logger.addFilter(LogFilter())

os.environ["TOKENIZERS_PARALLELISM"] = "true"  # Dealing with HF logging bugs

model_name = "google/flan-t5-base"

embed_dim = T5Config.from_pretrained(model_name).d_model

model = RTModel.from_pretrained(model_name, num_token=num_token)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.config.max_length = 256

2024-12-07 00:47:43.088275: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
from datasets import load_dataset
import numpy as np

train_dataset = load_dataset("Jise/ruletaker", split="train")
test_dataset = load_dataset("Jise/ruletaker", split="test")
ood_dataset = load_dataset("Jise/ruletaker", split="ood_test")

prompt = "Based on the facts and rules, first think step by step and give simple reasoning steps citing the rules, and then output whether the assertion is true by true or false."

temp = tokenizer(train_dataset["reasoning"])
print("Max output token number:", max([len(s) for s in temp["input_ids"]]))

def preprocess(examples):
    inputs = [prompt + "Assertion:" + x for x in examples["context"]]
    model_inputs = tokenizer(inputs, max_length=511, truncation=True, padding=True, return_tensors="pt")
    labels = [x + "\nThe answer is " + y for x, y in zip(examples["reasoning"], examples["flag"])]
    labels = tokenizer(text_target=labels, max_length=262, truncation=True, padding=True, return_tensors="pt")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

train_dataset = train_dataset.map(preprocess, batched=True)
test_dataset = test_dataset.map(preprocess, batched=True)
ood_dataset = ood_dataset.map(preprocess, batched=True)

print(train_dataset)

Max output token number: 257
Dataset({
    features: ['context', 'statement', 'reasoning', 'depth', 'flag', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})


In [5]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import requests
import pickle
import torch
import wandb
import nltk
import evaluate

nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
metric = evaluate.load("rouge")

run_name = "Flan-T5_RuleTaker_RT"

with open("TOKENS.pkl", "rb") as f:
    TOKENS = pickle.load(f)

WANDB_TOKEN = TOKENS["WANDB_TOKEN"]
HF_TOKEN = TOKENS["HF_TOKEN"]

wandb.login(key=WANDB_TOKEN)

def compute_metrics(eval_preds):
    preds, labels = eval_preds

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels]
 
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    result["acc"] = np.mean([("true" in label.lower() and "true" in pred.lower()) or ("false" in label.lower() and "false" in pred.lower()) for pred, label in zip(preds, labels)])
    
    return result
    
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

sft_training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-rt-ruletaker",
    eval_strategy="steps",
    eval_steps=300,
    save_strategy="steps",
    save_steps=300,
    save_total_limit=1,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    label_smoothing_factor=1e-5,
    learning_rate=1e-4,
    num_train_epochs=5,
    logging_dir="./logs",
    logging_steps=10,
    weight_decay=0.01,
    bf16=True,
    predict_with_generate=True,
    generation_max_length=256,
    push_to_hub=True,
    report_to="wandb",
    run_name=run_name,
    hub_token=HF_TOKEN,
    hub_model_id="Jise/flan-t5-ruletaker-rt",
    save_safetensors=False,
)

rt_trainer = CustomTrainer(
    update_interval=update_interval,
    memory_len=memory_len,
    model=model,
    args=sft_training_args,
    train_dataset=train_dataset,
    eval_dataset={"test": test_dataset, "ood": ood_dataset},
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjiseshen[0m ([33mjise[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/idies/.netrc
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
rt_trainer.train()

Step,Training Loss,Validation Loss,Test Loss,Test Rouge1,Test Rouge2,Test Rougel,Test Rougelsum,Test Acc,Ood Loss,Ood Rouge1,Ood Rouge2,Ood Rougel,Ood Rougelsum,Ood Acc
300,0.0033,No log,50.180489,0.230879,0.144393,0.218862,0.2245,0.052,38.53915,0.29788,0.185867,0.275606,0.288075,0.058667
600,0.0014,No log,39.190765,0.190051,0.128402,0.175926,0.189851,0.176,34.519337,0.220282,0.147778,0.195363,0.218153,0.194667


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.



TrainOutput(global_step=625, training_loss=0.05504026737213135, metrics={'train_runtime': 3899.1849, 'train_samples_per_second': 1.282, 'train_steps_per_second': 0.16, 'total_flos': 1885790039040000.0, 'train_loss': 0.05504026737213135, 'epoch': 5.0})

In [7]:
print("RT Test Results:")
rt_trainer.evaluate()

RT Test Results:


{'eval_test_loss': 39.73509216308594,
 'eval_test_rouge1': 0.1933269292162502,
 'eval_test_rouge2': 0.1294599170243153,
 'eval_test_rougeL': 0.17894140530366087,
 'eval_test_rougeLsum': 0.1925948014545385,
 'eval_test_acc': 0.16,
 'eval_test_runtime': 507.5324,
 'eval_test_samples_per_second': 0.493,
 'eval_test_steps_per_second': 0.032,
 'epoch': 5.0,
 'eval_ood_loss': 35.09615707397461,
 'eval_ood_rouge1': 0.21526193488292422,
 'eval_ood_rouge2': 0.1444443530901435,
 'eval_ood_rougeL': 0.19240451977751177,
 'eval_ood_rougeLsum': 0.2123844494270246,
 'eval_ood_acc': 0.168,
 'eval_ood_runtime': 742.4027,
 'eval_ood_samples_per_second': 0.505,
 'eval_ood_steps_per_second': 0.032}