In [None]:
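# Optional setup: uncomment the next line to install the extra packages into the kernel's environment.
#%pip install sentencepiece wandb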

In [None]:
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset

In [None]:
# Pull the pretrained T5 checkpoint. The tokenizer is created first because
# the model is configured with the tokenizer's padding id.
tokenizer = AutoTokenizer.from_pretrained('t5-base')
model = AutoModelForSeq2SeqLM.from_pretrained(
    't5-base',
    max_length=512,
    pad_token_id=tokenizer.pad_token_id,
)

In [None]:
# Hyper-parameters and bookkeeping options handed to the Trainer.
# NOTE(review): with the 100-example slice loaded later in this notebook and a
# train batch of 4, three epochs come to roughly 75 optimizer steps on a single
# device - fewer than warmup_steps (500) and logging_steps (100). Confirm these
# values are meant for a larger run.
training_args = TrainingArguments(
    # Output locations: checkpoints / final model, and training logs
    output_dir='./results',
    logging_dir='./logs',
    # Optimisation schedule
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # Per-device batch sizes for training and evaluation
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    # Cadence: log the loss every 100 steps, evaluate and checkpoint once per epoch
    logging_steps=100,
    evaluation_strategy='epoch',
    save_strategy='epoch',
)

In [None]:
# Fetch the first 100 training records of CNN/DailyMail (config 3.0.0);
# tokenization happens in a later cell.
dataset = load_dataset(
    path='cnn_dailymail',
    name='3.0.0',
    split='train[:100]',
)

In [None]:
# Show the first record of the loaded split to inspect its fields
dataset[0]

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of article/summary pairs for seq2seq fine-tuning.

    Args:
        examples: Batch mapping with an "article" entry (source texts) and a
            "highlights" entry (reference summaries), in the form handed over
            by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the articles (padded/truncated to 512 tokens,
        as PyTorch tensors) extended with two extra entries:

        * ``labels`` - summary token ids padded/truncated to 150 tokens, with
          every padding position replaced by -100.
        * ``decoder_attention_mask`` - the attention mask of the summaries.
    """
    model_inputs = tokenizer(
        examples["article"],
        padding='max_length',
        truncation=True,
        max_length=512,
        return_tensors='pt',
    )
    targets = tokenizer(
        examples["highlights"],
        padding='max_length',
        truncation=True,
        max_length=150,
        return_tensors='pt',
    )

    # Padding ids must not count towards the loss: -100 is the index the
    # cross-entropy loss ignores, so padded label positions are set to it.
    # masked_fill is out-of-place, so the tokenizer output itself is untouched.
    label_ids = targets['input_ids']
    model_inputs['labels'] = label_ids.masked_fill(label_ids == tokenizer.pad_token_id, -100)
    model_inputs['decoder_attention_mask'] = targets['attention_mask']
    return model_inputs

In [None]:
# Tokenize the whole split; batched=True passes lists of examples to the function
dataset = dataset.map(
    function=preprocess_function,
    batched=True,
)

In [None]:
# Instantiate the Trainer class.
# training_args requests an evaluation pass after every epoch
# (evaluation_strategy='epoch'), which needs an eval_dataset; without one the
# run fails. Load a small validation slice and tokenize it exactly like the
# training data so the per-epoch evaluation can run.
eval_dataset = load_dataset(
    'cnn_dailymail', '3.0.0', split='validation[:20]'
).map(preprocess_function, batched=True)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
)

In [None]:
# Start the training: runs the fine-tuning loop with the Trainer built above.
# As the last expression of the cell, the value returned by train() is displayed.
trainer.train()

In [None]:
# Save the trained model.
# The Trainer was created without a tokenizer, so save_model() writes only the
# model files; save the tokenizer next to them so the directory can be loaded
# on its own with from_pretrained().
model_dir = './trained_model'
trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)

print('Training complete. Model saved.')