In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Pretrained encoder-decoder to fine-tune.
# Alternative: 'facebook/bart-large-cnn' (BART-large already tuned for CNN/DailyMail).
checkpoint = 'facebook/bart-base'

# Train on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


# Prepare Dataset

In [3]:
from datasets import load_dataset

# CNN/DailyMail v3.0.0: news `article` -> reference `highlights` summary.
# (Earlier experiments used the 'samsum' dialogue dataset instead.)
# NOTE(review): trust_remote_code=True allows the Hub repo to execute arbitrary
# code; drop it if your `datasets` version loads this dataset from data files.
dataset = load_dataset(
    path="cnn_dailymail",
    name="3.0.0",
    trust_remote_code=True,
)
dataset

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})

In [4]:
from tqdm import tqdm
import matplotlib.pyplot as plt

# Optional EDA: token-length distributions of articles vs. highlights, useful
# for choosing the truncation lengths in the tokenization cell. Tokenizing the
# whole train split (287k articles) is slow, so this is off by default and
# runs on a sample of the train split when enabled.
PLOT_TOKEN_LENGTHS = False
EDA_SAMPLE_SIZE = 2_000


def token_lengths(texts, encode, batch_size=1_000):
    """Return the number of tokens in each string of ``texts``.

    Args:
        texts: sequence of strings supporting ``len`` and slicing (e.g. a list).
        encode: callable mapping a list of strings to a list of token
            sequences, e.g. ``lambda b: tokenizer(b)['input_ids']``.
        batch_size: number of strings handed to ``encode`` per call; batching
            lets a fast tokenizer process many strings at once.

    Returns:
        list of ints, one per input string, in input order.
    """
    lengths = []
    for start in range(0, len(texts), batch_size):
        lengths.extend(len(ids) for ids in encode(texts[start:start + batch_size]))
    return lengths


def plot_token_length_histograms(input_lengths, output_lengths, bins=20):
    """Plot side-by-side histograms of article and highlight token lengths.

    Returns the matplotlib figure so callers can save or further adjust it.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)
    panels = [
        (input_lengths, "Article Token Length"),
        (output_lengths, "Highlights Token Length"),
    ]
    for ax, (lengths, title) in zip(axes, panels):
        ax.hist(lengths, bins=bins, color='C0', edgecolor='black')
        ax.set_title(title)
        ax.set_xlabel("Length")
    axes[0].set_ylabel("Count")
    fig.tight_layout()
    return fig


if PLOT_TOKEN_LENGTHS:
    eda_rows = dataset['train'].select(range(min(EDA_SAMPLE_SIZE, dataset['train'].num_rows)))

    def encode_batch(batch):
        # No truncation here on purpose: we want the true lengths.
        return tokenizer(batch)['input_ids']

    # Read each column once (the old per-row `dataset['train'][col][i]`
    # re-materialized the entire column on every iteration).
    input_token_len = token_lengths(list(eda_rows['article']), encode_batch)
    output_token_len = token_lengths(list(eda_rows['highlights']), encode_batch)
    plot_token_length_histograms(input_token_len, output_token_len)
    plt.show()


In [5]:
input_ = 'article'
output_ = 'highlights'

# Truncation lengths (BART supports up to 1024 positions; shorter inputs are
# cheaper to train on).
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 128


def tokenize_inputs(batch, max_input_length=MAX_INPUT_LENGTH, max_target_length=MAX_TARGET_LENGTH):
    """Tokenize a batch of examples for seq2seq fine-tuning.

    Args:
        batch: mapping of column name -> list of strings, as passed by
            ``Dataset.map(..., batched=True)``; the ``input_`` column is the
            source text and the ``output_`` column the target summary.
        max_input_length: truncation length for encoder inputs.
        max_target_length: truncation length for target summaries.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``: lists of
        token-id lists, NOT padded. Padding is left to the data collator.
    """
    start_prompt = "summary: "
    # No manual "</s>" suffix: the tokenizer already appends its EOS token, so
    # adding it to the text produced a doubled EOS on untruncated inputs.
    prompt = [start_prompt + text for text in batch[input_]]

    # No padding here. DataCollatorForSeq2Seq pads each batch dynamically and
    # pads `labels` with -100 so padded positions are ignored by the loss.
    # Padding to max_length at this stage filled the labels with pad_token_id,
    # which the collator left untouched and the loss then trained on.
    model_inputs = tokenizer(
        prompt,
        truncation=True,
        max_length=max_input_length,
    )
    labels = tokenizer(
        text_target=batch[output_],
        truncation=True,
        max_length=max_target_length,
    )

    return {
        'input_ids': model_inputs['input_ids'],
        'attention_mask': model_inputs['attention_mask'],
        'labels': labels['input_ids'],
    }


tokenized_datasets = dataset.map(tokenize_inputs, batched=True, remove_columns=dataset['train'].column_names)

Map: 100%|███████████████████████████████████████████████████████████████| 13368/13368 [00:15<00:00, 844.57 examples/s]


In [6]:
# Sanity check: (num_rows, num_columns) for each tokenized split.
for split_name in ('train', 'validation', 'test'):
    print(tokenized_datasets[split_name].shape)

(287113, 3)
(13368, 3)
(11490, 3)


In [7]:
first_train_example = tokenized_datasets['train'][0]
first_train_example.keys()

dict_keys(['input_ids', 'attention_mask', 'labels'])

In [8]:
from transformers import DataCollatorForSeq2Seq

# Batches examples for the Trainer: pads inputs per batch and pads labels with
# label_pad_token_id (-100, the library default, spelled out here). Passing the
# model lets the collator derive decoder inputs from the labels.
seq2seq_data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
)

In [9]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    'demo-bart-summary',            # output_dir for checkpoints
    warmup_steps=500,
    learning_rate=1e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    eval_strategy='epoch',          # current name of the deprecated `evaluation_strategy`
    save_strategy='epoch',
    logging_steps=10,
    report_to="none",
    # fp16 mixed precision requires CUDA. The device selection falls back to
    # CPU, and there an unconditional fp16=True makes TrainingArguments raise.
    fp16=torch.cuda.is_available(),
    per_device_train_batch_size=8,  # Adjust batch size based on GPU memory
    per_device_eval_batch_size=8,   # Same for evaluation
)

trainer = Trainer(
    model=model,
    # NOTE: recent transformers versions prefer `processing_class=`; `tokenizer=`
    # still works but emits a FutureWarning.
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=seq2seq_data_collator,
)

  trainer = Trainer(


In [10]:
# Fine-tune for num_train_epochs (1) over the whole train split (287,113
# tokenized examples), evaluating on the validation split and checkpointing
# once per epoch as configured in training_args.
# NOTE(review): the recorded run was stopped with KeyboardInterrupt before the
# first epoch finished, so with epoch-level saving presumably no checkpoint was
# written; consider step-based saving or a train subset for shorter runs.
trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 