In [1]:
pip install -U transformers datasets evaluate nltk accelerate trl wandb peft rouge_score

Collecting transformers
  Obtaining dependency information for transformers from https://files.pythonhosted.org/packages/d0/a7/7eedcf6a359e1e1eff3bc204ad022485aa5d88c08e1e3e0e0aee8a2e2235/transformers-4.47.0-py3-none-any.whl.metadata
  Using cached transformers-4.47.0-py3-none-any.whl.metadata (43 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Obtaining dependency information for tokenizers<0.22,>=0.21 from https://files.pythonhosted.org/packages/22/06/69d7ce374747edaf1695a4f61b83570d91cc8bbfc51ccfecf76f56ab4aac/tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Using cached tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Note: you may need to restart the kernel to use updated packages.


In [2]:
update_intervals = [20, 40]
num_tokens = [1, 10]
memory_len = 3

In [3]:
from transformers import AutoTokenizer, T5ForConditionalGeneration, T5Config, logging
import torch
from custom_model import RTModel, CustomTrainer
import os
import logging as sys_logging

sys_logging.captureWarnings(True)
logger = sys_logging.getLogger("transformers.trainer")
class LogFilter(sys_logging.Filter):
    def filter(self, record):
        message = record.getMessage()
        return "deprecated" not in message and "hang" not in message
logger.addFilter(LogFilter())

os.environ["TOKENIZERS_PARALLELISM"] = "true"  # Dealing with HF logging bugs

model_name = "google/flan-t5-base"

embed_dim = T5Config.from_pretrained(model_name).d_model
tokenizer = AutoTokenizer.from_pretrained(model_name)

2024-12-07 03:56:14.131734: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
from datasets import load_dataset
import numpy as np

train_dataset = load_dataset("Jise/ruletaker", split="train")
test_dataset = load_dataset("Jise/ruletaker", split="test")
ood_dataset = load_dataset("Jise/ruletaker", split="ood_test")

prompt = "Based on the facts and rules, first think step by step and give simple reasoning steps citing the rules, and then output whether the assertion is true by true or false."

temp = tokenizer(train_dataset["reasoning"])
print("Max output token number:", max([len(s) for s in temp["input_ids"]]))

def preprocess(examples):
    inputs = [prompt + "Assertion:" + x for x in examples["context"]]
    model_inputs = tokenizer(inputs, max_length=511, truncation=True, padding=True, return_tensors="pt")
    labels = [x + "\nThe answer is " + y for x, y in zip(examples["reasoning"], examples["flag"])]
    labels = tokenizer(text_target=labels, max_length=262, truncation=True, padding=True, return_tensors="pt")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

train_dataset = train_dataset.map(preprocess, batched=True)
test_dataset = test_dataset.map(preprocess, batched=True)
ood_dataset = ood_dataset.map(preprocess, batched=True)

print(train_dataset)

Max output token number: 257
Dataset({
    features: ['context', 'statement', 'reasoning', 'depth', 'flag', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})


In [5]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq
import requests
import pickle
import torch
import wandb
import nltk
import evaluate

nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
metric = evaluate.load("rouge")

with open("TOKENS.pkl", "rb") as f:
    TOKENS = pickle.load(f)

HF_TOKEN = TOKENS["HF_TOKEN"]

def compute_metrics(eval_preds):
    preds, labels = eval_preds

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    result["acc"] = np.mean([("true" in label.lower() and "true" in pred.lower()) or ("false" in label.lower() and "false" in pred.lower()) for pred, label in zip(preds, labels)])

    return result

In [6]:
for update_interval in update_intervals:

    num_token = 5

    model = RTModel.from_pretrained(model_name, num_token=num_token)
    model.config.max_length = 256

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    sft_training_args = Seq2SeqTrainingArguments(
        output_dir=f"./flan-t5-rt-ruletaker-{num_token}-{update_interval}",
        eval_strategy="no",
        save_strategy="no",
        save_total_limit=1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        label_smoothing_factor=1e-5,
        learning_rate=1e-4,
        num_train_epochs=5,
        logging_strategy="no",
        weight_decay=0.01,
        bf16=True,
        predict_with_generate=True,
        generation_max_length=256,
        push_to_hub=False,
        report_to="none",
        save_safetensors=False,
    )

    rt_trainer = CustomTrainer(
        update_interval=update_interval,
        memory_len=memory_len,
        model=model,
        args=sft_training_args,
        train_dataset=train_dataset,
        eval_dataset={"test": test_dataset, "ood": ood_dataset},
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    rt_trainer.train()
    print(f"Test Results for update interval: {update_interval}")
    print(rt_trainer.evaluate())

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Test Results for update interval: 20


{'eval_test_loss': 61.051517486572266, 'eval_test_rouge1': 0.07391676186889429, 'eval_test_rouge2': 0.040734904287047596, 'eval_test_rougeL': 0.06712375123187533, 'eval_test_rougeLsum': 0.0728435298485233, 'eval_test_acc': 0.0, 'eval_test_runtime': 508.515, 'eval_test_samples_per_second': 0.492, 'eval_test_steps_per_second': 0.031, 'epoch': 5.0, 'eval_ood_loss': 42.98136901855469, 'eval_ood_rouge1': 0.10203888697244479, 'eval_ood_rouge2': 0.05689134788080519, 'eval_ood_rougeL': 0.09227450903251991, 'eval_ood_rougeLsum': 0.10094993481445658, 'eval_ood_acc': 0.0, 'eval_ood_runtime': 740.5145, 'eval_ood_samples_per_second': 0.506, 'eval_ood_steps_per_second': 0.032}


  super().__init__(*args, **kwargs)



Step,Training Loss


Test Results for update interval: 40


{'eval_test_loss': 32.86512756347656, 'eval_test_rouge1': 0.14756606049117407, 'eval_test_rouge2': 0.09454600518784967, 'eval_test_rougeL': 0.13247346467543225, 'eval_test_rougeLsum': 0.1389711238588532, 'eval_test_acc': 0.036, 'eval_test_runtime': 508.5258, 'eval_test_samples_per_second': 0.492, 'eval_test_steps_per_second': 0.031, 'epoch': 5.0, 'eval_ood_loss': 28.138641357421875, 'eval_ood_rouge1': 0.17954083862827658, 'eval_ood_rouge2': 0.10817440215581078, 'eval_ood_rougeL': 0.15745894128450233, 'eval_ood_rougeLsum': 0.1658277861664561, 'eval_ood_acc': 0.016, 'eval_ood_runtime': 738.8661, 'eval_ood_samples_per_second': 0.508, 'eval_ood_steps_per_second': 0.032}


In [6]:
for num_token in num_tokens:

    update_interval = 10

    model = RTModel.from_pretrained(model_name, num_token=num_token)
    model.config.max_length = 256

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    sft_training_args = Seq2SeqTrainingArguments(
        output_dir=f"./flan-t5-rt-ruletaker-{num_token}-{update_interval}",
        eval_strategy="no",
        save_strategy="no",
        save_total_limit=1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        label_smoothing_factor=1e-5,
        learning_rate=1e-4,
        num_train_epochs=5,
        logging_strategy="no",
        weight_decay=0.01,
        bf16=True,
        predict_with_generate=True,
        generation_max_length=256,
        push_to_hub=False,
        report_to="none",
        save_safetensors=False,
    )

    rt_trainer = CustomTrainer(
        update_interval=update_interval,
        memory_len=memory_len,
        model=model,
        args=sft_training_args,
        train_dataset=train_dataset,
        eval_dataset={"test": test_dataset, "ood": ood_dataset},
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    rt_trainer.train()
    print(f"Test Results for num token: {num_token}")
    print(rt_trainer.evaluate())

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Test Results for num token: 1


{'eval_test_loss': 39.05982208251953, 'eval_test_rouge1': 0.22105174131491598, 'eval_test_rouge2': 0.15474633104325503, 'eval_test_rougeL': 0.20263790745498275, 'eval_test_rougeLsum': 0.2139468104566816, 'eval_test_acc': 0.244, 'eval_test_runtime': 495.6717, 'eval_test_samples_per_second': 0.504, 'eval_test_steps_per_second': 0.032, 'epoch': 5.0, 'eval_ood_loss': 48.232879638671875, 'eval_ood_rouge1': 0.22399039439266244, 'eval_ood_rouge2': 0.1538141368139594, 'eval_ood_rougeL': 0.19464367854394313, 'eval_ood_rougeLsum': 0.21754566709298906, 'eval_ood_acc': 0.31466666666666665, 'eval_ood_runtime': 720.6752, 'eval_ood_samples_per_second': 0.52, 'eval_ood_steps_per_second': 0.033}


  super().__init__(*args, **kwargs)



Step,Training Loss


KeyboardInterrupt: 