In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


# Pick the compute device in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Load the tokenizer and the seq2seq model from the same "google-t5/t5-base"
# checkpoint, then move the model onto the selected device.
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
model = model.to(device)

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader


def preprocess_data(data, max_input_length=512, max_target_length=128):
    """Tokenize article/summary pairs for T5 summarization fine-tuning.

    Uses the module-level ``tokenizer``.

    Args:
        data: Mapping with an "article" and a "highlights" entry. "article"
            may be a single string or an iterable of strings (the batched
            form used by ``Dataset.map(..., batched=True)``); "highlights"
            is passed to the tokenizer unchanged.
        max_input_length: Length every encoded article is truncated/padded to.
        max_target_length: Length every encoded summary is truncated/padded to.

    Returns:
        dict with "input_ids" and "attention_mask" for the prefixed articles
        and "labels" for the summaries. Padded label positions are set to
        -100 so the model's loss ignores them.
    """
    articles = data["article"]
    # Use the same lowercase "summarize: " task prefix for both the single
    # and the batched path (the batched path previously used "Summarize: ").
    if isinstance(articles, str):
        input_text = "summarize: " + articles
    else:
        input_text = ["summarize: " + article for article in articles]
    target_text = data["highlights"]

    inputs = tokenizer(
        input_text,
        max_length=max_input_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    targets = tokenizer(
        target_text,
        max_length=max_target_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Positions where the target attention mask is 0 are padding. Replace
    # them with -100, the index ignored by the cross-entropy loss, so padding
    # does not contribute to the training signal.
    labels = targets["input_ids"].masked_fill(targets["attention_mask"] == 0, -100)

    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "labels": labels,
    }


dataset = load_dataset("cnn_dailymail", "3.0.0")

# Keep only the first tenth of the training split; the other splits are
# left untouched.
num_samples = len(dataset["train"]) // 10
dataset["train"] = dataset["train"].select(range(num_samples))

# Tokenize every split in batches and drop the original columns so only
# the fields returned by preprocess_data remain.
tokenized = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

In [None]:
batch_size = 8

# Only the training loader shuffles; the evaluation loaders keep dataset order.
train_dataloader = DataLoader(tokenized["train"], batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(tokenized["validation"], batch_size=batch_size, shuffle=False)
# This third loader used to rebind `train_dataloader` with shuffle=False,
# which silently discarded the shuffled training loader defined above.
# It now serves the dataset's "test" split under its own name.
test_dataloader = DataLoader(tokenized["test"], batch_size=batch_size, shuffle=False)

In [None]:
from tqdm import tqdm
from torch.optim import AdamW


num_epochs = 10

optimizer = AdamW(model.parameters(), lr=5e-5)

# Switch to training mode once up front; nothing in the loop changes it.
model.train()
for epoch in range(num_epochs):
    total_loss = 0  # running sum of per-batch losses for this epoch
    for batch in tqdm(train_dataloader):
        # Each field is stacked along dim=1 and moved to the target device.
        # NOTE(review): presumably each field arrives as a sequence of
        # per-position tensors, so dim=1 yields (batch, seq) — confirm.
        input_ids, attention_mask, labels = (
            torch.stack(batch[field], dim=1).to(device)
            for field in ("input_ids", "attention_mask", "labels")
        )

        # Passing labels makes the model return the training loss directly.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss

        # Backpropagate, apply the update, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    # Report the mean per-batch loss for the finished epoch.
    print(f"Epoch {epoch + 1} Loss: {total_loss / len(train_dataloader):.4f}")