In [1]:
!pip install datasets evaluate rouge_score



In [2]:
import evaluate
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
)


In [14]:
# One checkpoint name drives both the tokenizer and the model weights.
model_name = "t5-small"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

# Small slices of CNN/DailyMail keep the fine-tuning run short; a list of split
# specs returns one Dataset per entry, in order.
train_data, val_data = load_dataset(
    "cnn_dailymail",
    "3.0.0",
    split=["train[:5000]", "validation[:1000]"],
)


In [15]:
def preprocess_function(batch):
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=1024)
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    # batch["labels"] = [
    #     [(token if token != tokenizer.pad_token_id else -100) for token in label]
    #     for label in outputs.input_ids
    # ]
    batch["labels"] = outputs.input_ids
    return batch

    # return {
    #     'input_ids': inputs['input_ids'],
    #     'attention_mask': inputs['attention_mask'],
    #     'labels': outputs['input_ids']
    # }


In [16]:
train_data = train_data.map(preprocess_function, batched=True, remove_columns=["article", "highlights", "id"])
val_data = val_data.map(preprocess_function, batched=True, remove_columns=["article", "highlights", "id"])

# train_data = train_data.map(preprocess_function, batched=True)
# val_data = val_data.map(preprocess_function, batched=True)


Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [17]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model,padding="longest")


In [None]:
train_data[0]

In [19]:
training_args = Seq2SeqTrainingArguments(
    output_dir="flan-t5_checkpoints",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs = 1,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    load_best_model_at_end=True,
    eval_steps=100,
    generation_max_length=64,
    predict_with_generate=True,
    logging_steps=100,
    gradient_accumulation_steps=1,
    fp16=True,
    remove_unused_columns=False,
    report_to="none",
)




In [20]:
metric = evaluate.load("rouge")

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions

    if len(preds.shape) > 2:
      preds = preds.argmax(axis=-1)

    preds = preds.clip(0, len(tokenizer) - 1)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=["rouge1","rouge2","rougeL"])
    return scores


In [21]:
class TQDMProgressBar(TrainerCallback):
    """Trainer callback that mirrors training progress in a tqdm bar.

    The bar advances once per ``on_step_end`` call and shows the numeric fields
    of the most recent ``state.log_history`` entry as its postfix.
    """

    def __init__(self):
        self.pbar = None  # created in on_train_begin, closed in on_train_end

    def on_train_begin(self, args, state, control, **kwargs):
        # Use state.max_steps (the step count the Trainer computed) rather than
        # args.max_steps: training_args never sets max_steps, so the bar previously
        # had no total (note "1250step" with no denominator in the output below).
        self.pbar = tqdm(total=state.max_steps, desc="Training Progress", unit="step")

    def on_step_end(self, args, state, control, **kwargs):
        if self.pbar is None:
            return
        self.pbar.update(1)
        if state.log_history:
            latest_log = state.log_history[-1]
            desc = ", ".join(
                f"{k}: {self._format_value(v)}"
                for k, v in latest_log.items()
                if isinstance(v, (int, float))
            )
            self.pbar.set_postfix_str(desc)

    def on_train_end(self, args, state, control, **kwargs):
        if self.pbar is not None:
            self.pbar.close()
            self.pbar = None

    @staticmethod
    def _format_value(value):
        """Show integers (e.g. ``step``) as-is and everything else to 4 decimals."""
        if isinstance(value, int) and not isinstance(value, bool):
            return str(value)
        return f"{value:.4f}"

progress_bar = TQDMProgressBar()

In [22]:
# Wire the model, data pipeline, metrics and progress callback into one trainer.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=` (presumably the warning captured below this cell); confirm
# the installed version before switching.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
    callbacks=[progress_bar],
)


  trainer = Seq2SeqTrainer(


In [23]:
# Fine-tune; per-epoch evaluation runs inside training, and the returned
# TrainOutput (steps, loss, runtime) is displayed as the cell output.
trainer.train()

Training Progress: 1step [00:00,  3.84step/s]

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel
1,1.1006,0.821295,0.125445,0.05553,0.092938


Training Progress: 1250step [05:19,  3.91step/s, loss: 1.1006, grad_norm: 2.4016, learning_rate: 0.0000, epoch: 0.9600, step: 1200.0000]There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].
Training Progress: 1250step [10:55,  1.91step/s, loss: 1.1006, grad_norm: 2.4016, learning_rate: 0.0000, epoch: 0.9600, step: 1200.0000]


TrainOutput(global_step=1250, training_loss=1.5673149383544922, metrics={'train_runtime': 655.4845, 'train_samples_per_second': 7.628, 'train_steps_per_second': 1.907, 'total_flos': 1353418014720000.0, 'train_loss': 1.5673149383544922, 'epoch': 1.0})

In [24]:
# Sanity check: which token does id 0 decode to?
tokenizer.decode([0])

'<pad>'