T5 Flan
----
This notebook follows this guide: https://huggingface.co/docs/transformers/tasks/summarization

# Setup

In [1]:
# ! pip install transformers datasets evaluate rouge_score

In [2]:
pt = True

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
if pt:
    from transformers import TFAutoModelForSeq2SeqLM, AdamWeightDecay
else:
    import tensorflow as tf
    from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
    from transformers.keras_callbacks import KerasMetricCallback
import evaluate
import numpy as np
# from google.colab import drive

Defining the model to use.

In [3]:
# google_drive = '/content/drive'
# drive.mount(google_drive)

checkpoint = 'google/flan-t5-small'
base_path = '..'
# base_path = google_drive + '/My Drive/coding/flan_t5_small'
output_path = base_path + '/out'
model_save_path = base_path + '/model'

# Data

In [4]:
billsum = load_dataset('billsum', split='ca_test').train_test_split(test_size=0.2)

# HF data objects can be indexed EITHER by obs or key: the former returns a dict, the latter a list
print(billsum['train'][0])
print(billsum['train']['summary'][:5])

{'text': 'The people of the State of California do enact as follows:\n\n\nSECTION 1.\n(a) It is the intent of the Legislature to clarify that pawnbrokers and other secondhand dealers are to report their acquisition of tangible personal property received in pledge, trade, consignment, or auction or by purchase using plain text, in descriptive language historically used in the pawn and secondhand industries when reporting to the single, statewide, and uniform electronic reporting system operated by the Department of Justice, or if not yet implemented in their respective jurisdictions, on paper forms sent to the local police chief or sheriff of the jurisdiction in which the secondhand dealer is physically located.\n(b) It is further the intent of the Legislature that by specifying this manner of reporting, it will relieve all secondhand dealers and pawnbrokers of the inherent costs and burdens imposed under existing law that requires these businesses to report their daily acquisitions of 

In [5]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [6]:
def preprocess_function(data):
    prefix = "summarize: "
    inputs = [prefix + text for text in data['text']]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
    
    labels = tokenizer(text_target=data['summary'], max_length=128, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [7]:
tokenized_billsum = billsum.map(preprocess_function, batched=True)
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=checkpoint,
    return_tensors='tf'
)

Map:   0%|          | 0/989 [00:00<?, ? examples/s]

Map:   0%|          | 0/248 [00:00<?, ? examples/s]

# Defining Model and Metrics

In [8]:
rouge = evaluate.load('rouge')
def compute_metrics(eval_pred, evalutor=rouge, tokenizer=tokenizer):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id) # not sure what this is doing
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = evalutor.compute(
        predictions=decoded_preds,
        references=decoded_labels,
        use_stemmer=True
    )
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}


## PyTorch

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
training_args = Seq2SeqTrainingArguments(
    output_dir='./model_temp',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    fp16=True
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_billsum["train"],
    eval_dataset=tokenized_billsum["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

## TensorFlow

In [9]:
optimizer = AdamWeightDecay(learning_rate=2e-5, weight_decay_rate=0.01)
model = TFAutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model.compile(optimizer=optimizer)

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at google/flan-t5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Note that HF models have built-in loss functions, so one does not need to be specified when compiling.

In [None]:
train_set = model.prepare_tf_dataset(
    tokenized_billsum['train'],
    shuffle=True,
    batch_size=16,
    collate_fn=data_collator
)
test_set = model.prepare_tf_dataset(
    tokenized_billsum['test'],
    shuffle=False,
    batch_size=16,
    collate_fn=data_collator
)

metric_callback = KerasMetricCallback(metric_fn=compute_metrics, eval_dataset=test_set)

model.fit(
    x=train_set,
    validation_data=test_set,
    epochs=3,
    callbacks=[metric_callback]
)

In [None]:
tokenizer.save_pretrained(model_save_path)
model.save_pretrained(model_save_path)

In [15]:
from datetime import datetime

model = AutoModelForSeq2SeqLM.from_pretrained(model_save_path)
tokenizer = AutoTokenizer.from_pretrained(model_save_path)
one_test_case = 'summarize: ' + billsum['test'][0]['text']
tokenized_input = tokenizer(one_test_case, return_tensors='pt').input_ids
raw_output = model.generate(tokenized_input, max_new_tokens=128, do_sample=False)
# raw_output = model.generate(**tokenized_input)
text_output = tokenizer.decode(raw_output[0], skip_special_tokens=True)

print(text_output)

save_time = datetime.now().strftime(r'%Y%m%d%H%M%S')
with open(output_path + f'/T5_billsum_example_text_{save_time}.txt', 'w', encoding='utf8') as f:
    f.write(one_test_case)
with open(output_path + f'/T5_billsum_example_output_{save_time}.txt', 'w', encoding='utf8') as f:
    f.write(text_output)

OSError: [Errno 22] Invalid argument: './out/T5_billsum_example_text_2023-07-28 17:16:18.633901.txt'