Set Up the Model and Tokenizer

In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Seq2SeqTrainer,
    Trainer,
    TrainingArguments,
)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# bitsandbytes settings used when loading the base model:
# 4-bit NF4 weights with double quantization, computing in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# LoRA adapter settings for a sequence-to-sequence model.
lora_config = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    # Attach adapters only to the modules named "q" and "v"
    # (the attention query and value projections).
    target_modules=["q", "v"],
)

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# device_map="auto" already places the quantized weights on the available
# device(s), so the model is NOT moved again with .to(device): that call is
# redundant, is rejected or warned about for quantized / dispatched models by
# some library versions, and as a cell's last expression it prints the whole
# module tree.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google/flan-t5-base",
    quantization_config=quantization_config,
    device_map="auto",
)

# Recommended preparation step before training adapters on top of a k-bit
# quantized base model (see the PEFT quantization guide).
model = prepare_model_for_kbit_training(model)

# Seed PyTorch so the random part of the adapter initialisation is repeatable
# across "Restart & Run All".
torch.manual_seed(42)
model = get_peft_model(model, lora_config)

# Sanity check: report how many parameters are trainable vs. frozen.
model.print_trainable_parameters()

Set Up Data and Training

In [None]:
from datasets import load_dataset

dataset = load_dataset("boolq")

# For reference: preview the first five training rows. Only those rows are
# converted to pandas (instead of the whole split), and the frame is left as
# the cell's last expression so Jupyter renders it as a table.
dataset["train"].select(range(5)).to_pandas()

In [None]:
def preprocess_function(examples):
  """Tokenize a batch of examples into model inputs and labels.

  Args:
    examples: A batch (dict of lists) with the keys 'question', 'passage'
      and 'answer'. A truthy answer becomes the target text 'true', a falsy
      one becomes 'false'.

  Returns:
    The tokenizer output for the prompts (padded/truncated to 512 tokens)
    with an extra "labels" entry: the tokenized targets (padded/truncated to
    10 tokens) in which every padding position is replaced by -100.
  """
  inputs = [
      f"Question: {question}  Passage: {passage}"
      for question, passage in zip(examples['question'], examples['passage'])
  ]
  targets = ['true' if answer else 'false' for answer in examples['answer']]

  model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding='max_length')
  labels = tokenizer(targets, max_length=10, truncation=True, padding='max_length')

  # The targets are padded to a fixed length. Padding ids must be replaced by
  # -100 so the loss ignores those positions; otherwise most of the training
  # signal would come from predicting padding rather than the answer.
  pad_token_id = tokenizer.pad_token_id
  model_inputs["labels"] = [
      [token_id if token_id != pad_token_id else -100 for token_id in label_ids]
      for label_ids in labels["input_ids"]
  ]

  return model_inputs

# Tokenize every split, feeding the examples to preprocess_function in batches.
tokenized_dataset = dataset.map(
    function=preprocess_function,
    batched=True,
)

In [None]:
# Hyperparameters and bookkeeping for the fine-tuning run.
training_args = TrainingArguments(
    # Output locations.
    output_dir="./results",
    logging_dir="./logs",
    # Optimisation.
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=True,  # mixed-precision training, for efficiency
    # Logging, evaluation and checkpoint cadence.
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    # Name of the batch entry that holds the targets.
    label_names=["labels"],
)

# Train on the "train" split and evaluate on the "validation" split.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=tokenized_dataset["validation"],
    train_dataset=tokenized_dataset["train"],
)

In [None]:
# Run fine-tuning. With the training arguments configured earlier
# (eval_strategy="epoch", save_strategy="epoch") evaluation and checkpointing
# are scheduled once per epoch.
trainer.train()

In [None]:
# Evaluate the trained model on the trainer's evaluation dataset and keep the
# returned metrics so they can be inspected or reused later.
eval_results = trainer.evaluate()

print(f"Eval results: {eval_results}")

In [None]:
# Reload the fine-tuned weights from a saved checkpoint.
# The directory name ends with the global training step, so it changes with
# the dataset size, batch size, number of epochs and device count.
# NOTE(review): the checkpoint was written for a PEFT-wrapped model; confirm
# that loading it this way picks up the LoRA adapter as intended.
model = AutoModelForSeq2SeqLM.from_pretrained("./results/checkpoint-3537")
# from_pretrained leaves the model on the CPU by default; the inputs below are
# moved to `device`, so the model has to live there as well.
model = model.to(device)
model.eval()  # inference only: disable dropout
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

input_text = "Question: Is the sky blue? Passage: The sky is not blue on a clear day."

# Keep the attention mask together with the input ids and move both to the
# model's device.
encoded_input = tokenizer(input_text, return_tensors="pt").to(device)
input_ids = encoded_input.input_ids

# Generate once (the second, identical generate call was redundant).
output_ids = model.generate(
    input_ids=input_ids,
    attention_mask=encoded_input.attention_mask,
)
predicted_answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# The model was trained to emit "true" / "false", so the expected answer for
# this passage is "false".
print(f"Predicted answer: {predicted_answer}")  # should be false