In [1]:
import pandas as pd
import torch
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, LEDTokenizer, LEDForConditionalGeneration, LongformerTokenizer, TrainerCallback, set_seed
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import evaluate

In [2]:
train_billsum = load_dataset("billsum", split="train")
test_billsum = load_dataset("billsum", split="test")

#checkpoint = "allenai/longformer-base-4096"
checkpoint = "allenai/led-base-16384"
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
tokenizer = LEDTokenizer.from_pretrained(checkpoint)
model = LEDForConditionalGeneration.from_pretrained(checkpoint)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding="longest", return_tensors="pt")

In [3]:
""" def compute_token_lengths(dataset):
    token_lengths = []
    for example in tqdm(dataset):
        tokenized_input = tokenizer(example["text"], return_tensors='pt')
        num_tokens = tokenized_input.input_ids.size(1)
        token_lengths.append(num_tokens)
    return token_lengths

train_token_lengths = compute_token_lengths(train_billsum)
test_token_lengths = compute_token_lengths(test_billsum)

plt.figure(figsize=(10, 6))
plt.hist(train_token_lengths, bins=50, alpha=0.5, label='Train', color='blue')
plt.hist(test_token_lengths, bins=50, alpha=0.5, label='Test', color='red')
plt.title('Distribution of Number of Tokens in BillSum Dataset')
plt.xlabel('Number of Tokens')
plt.ylabel('Frequency')
plt.legend()
plt.show() """

' def compute_token_lengths(dataset):\n    token_lengths = []\n    for example in tqdm(dataset):\n        tokenized_input = tokenizer(example["text"], return_tensors=\'pt\')\n        num_tokens = tokenized_input.input_ids.size(1)\n        token_lengths.append(num_tokens)\n    return token_lengths\n\ntrain_token_lengths = compute_token_lengths(train_billsum)\ntest_token_lengths = compute_token_lengths(test_billsum)\n\nplt.figure(figsize=(10, 6))\nplt.hist(train_token_lengths, bins=50, alpha=0.5, label=\'Train\', color=\'blue\')\nplt.hist(test_token_lengths, bins=50, alpha=0.5, label=\'Test\', color=\'red\')\nplt.title(\'Distribution of Number of Tokens in BillSum Dataset\')\nplt.xlabel(\'Number of Tokens\')\nplt.ylabel(\'Frequency\')\nplt.legend()\nplt.show() '

In [4]:
def tokenize_and_chunk(examples):
    max_length = 2048  # This is quite large; ensure your GPU can handle it
    summary_max_length = 1024  # Adjust as needed
    stride = 512

    input_ids = []
    attention_masks = []
    labels = []

    # Tokenize text
    for text in examples['text']:
        tokenized_text = tokenizer(
            text, 
            add_special_tokens=True, 
            return_overflowing_tokens=True, 
            max_length=max_length, 
            stride=stride, 
            return_tensors="pt", 
            truncation=True,
        )
        
        input_ids.extend(tokenized_text['input_ids'])
        attention_masks.extend(tokenized_text['attention_mask'])

    # Tokenize summaries
    for summary in examples['summary']:
        tokenized_summary = tokenizer(
            summary,
            max_length=summary_max_length, 
            truncation=True,
            return_tensors="pt"
        )
        labels.extend(tokenized_summary['input_ids'])

    return {'input_ids': input_ids, 'attention_mask': attention_masks, 'labels': labels}

# Apply tokenization to the dataset
tokenized_train = train_billsum.map(tokenize_and_chunk, batched=True, remove_columns=train_billsum.column_names)
tokenized_test = test_billsum.map(tokenize_and_chunk, batched=True, remove_columns=test_billsum.column_names)


In [5]:
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_labels = [label.replace(tokenizer.pad_token, '') for label in decoded_labels]
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    result = {key: value.mid.fmeasure for key, value in result.items()}
    return result

In [6]:
""" def validate_token_ids(examples):
    vocab_size = model.config.vocab_size
    for ids in examples['input_ids']:
        if isinstance(ids, list):
            max_id = max(ids)
        else:
            max_id = ids.max()
        if max_id >= vocab_size:
            print("Out of range token IDs found:", max_id)
    return examples

tokenized_train = tokenized_train.map(validate_token_ids, batched=True)
tokenized_test = tokenized_test.map(validate_token_ids, batched=True)
 """

' def validate_token_ids(examples):\n    vocab_size = model.config.vocab_size\n    for ids in examples[\'input_ids\']:\n        if isinstance(ids, list):\n            max_id = max(ids)\n        else:\n            max_id = ids.max()\n        if max_id >= vocab_size:\n            print("Out of range token IDs found:", max_id)\n    return examples\n\ntokenized_train = tokenized_train.map(validate_token_ids, batched=True)\ntokenized_test = tokenized_test.map(validate_token_ids, batched=True)\n '

In [7]:
def check_labels(examples):
    labels = examples['labels']
    vocab_size = model.config.vocab_size  # Retrieve the vocabulary size from the model's configuration

    # Ensuring all label ids are integers and checking their range
    max_id = max([max(label) if isinstance(label, list) else label for label in labels])
    min_id = min([min(label) if isinstance(label, list) else label for label in labels])

    # Check if any label ID is out of the expected range
    if max_id >= vocab_size or min_id < -100:
        print(f"Invalid label IDs found: max ID = {max_id}, min ID = {min_id}")

    return examples

# Apply this function to validate datasets
tokenized_train = tokenized_train.map(check_labels, batched=True)
tokenized_test = tokenized_test.map(check_labels, batched=True)


Map:   0%|          | 0/18949 [00:00<?, ? examples/s]

Map:   0%|          | 0/3269 [00:00<?, ? examples/s]

In [None]:
class MemoryManagementCallback(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        torch.cuda.empty_cache()
        print(f"End of epoch {state.epoch}. Cleared CUDA cache.")

In [8]:
model.gradient_checkpointing_enable()

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    report_to="tensorboard",
    fp16=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[MemoryManagementCallback()]
)

trainer.train()

  0%|          | 0/888 [00:00<?, ?it/s]

{'loss': 3.6215, 'grad_norm': 18.50367546081543, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.03}
{'loss': 3.3184, 'grad_norm': 5.147391319274902, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.07}
{'loss': 3.0102, 'grad_norm': 3.3919925689697266, 'learning_rate': 3e-06, 'epoch': 0.1}
{'loss': 2.7529, 'grad_norm': 2.199863910675049, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.14}
{'loss': 2.6062, 'grad_norm': 1.8474993705749512, 'learning_rate': 5e-06, 'epoch': 0.17}
{'loss': 2.4844, 'grad_norm': 1.6797367334365845, 'learning_rate': 6e-06, 'epoch': 0.2}
{'loss': 2.3787, 'grad_norm': 1.529343605041504, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.24}
{'loss': 2.2679, 'grad_norm': 1.7866604328155518, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.27}
{'loss': 2.2634, 'grad_norm': 1.5836156606674194, 'learning_rate': 9e-06, 'epoch': 0.3}


Input ids are automatically padded from 1959 to 2048 to be a multiple of `config.attention_window`: 1024


{'loss': 2.1898, 'grad_norm': 1.6300872564315796, 'learning_rate': 1e-05, 'epoch': 0.34}
{'loss': 2.125, 'grad_norm': 1.760856032371521, 'learning_rate': 1.1000000000000001e-05, 'epoch': 0.37}
{'loss': 2.0917, 'grad_norm': 1.437061071395874, 'learning_rate': 1.2e-05, 'epoch': 0.41}
{'loss': 2.0275, 'grad_norm': 1.4249582290649414, 'learning_rate': 1.3000000000000001e-05, 'epoch': 0.44}
{'loss': 2.0514, 'grad_norm': 1.4583123922348022, 'learning_rate': 1.4000000000000001e-05, 'epoch': 0.47}
{'loss': 2.0177, 'grad_norm': 1.3966107368469238, 'learning_rate': 1.5e-05, 'epoch': 0.51}
{'loss': 2.0003, 'grad_norm': 1.276999592781067, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.54}
{'loss': 1.955, 'grad_norm': 1.482304573059082, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.57}
{'loss': 1.962, 'grad_norm': 1.6317179203033447, 'learning_rate': 1.8e-05, 'epoch': 0.61}
{'loss': 1.9458, 'grad_norm': 1.3042287826538086, 'learning_rate': 1.9e-05, 'epoch': 0.64}


Input ids are automatically padded from 1668 to 2048 to be a multiple of `config.attention_window`: 1024


{'loss': 1.9288, 'grad_norm': 1.5091028213500977, 'learning_rate': 2e-05, 'epoch': 0.68}
{'loss': 1.898, 'grad_norm': 1.3693567514419556, 'learning_rate': 2.1e-05, 'epoch': 0.71}


Input ids are automatically padded from 2040 to 2048 to be a multiple of `config.attention_window`: 1024


{'loss': 1.9412, 'grad_norm': 1.4860608577728271, 'learning_rate': 2.2000000000000003e-05, 'epoch': 0.74}
{'loss': 1.9018, 'grad_norm': 1.3913681507110596, 'learning_rate': 2.3000000000000003e-05, 'epoch': 0.78}
{'loss': 1.9141, 'grad_norm': 1.4692658185958862, 'learning_rate': 2.4e-05, 'epoch': 0.81}


Input ids are automatically padded from 1912 to 2048 to be a multiple of `config.attention_window`: 1024


{'loss': 1.8433, 'grad_norm': 1.6365731954574585, 'learning_rate': 2.5e-05, 'epoch': 0.84}
{'loss': 1.8309, 'grad_norm': 1.4654594659805298, 'learning_rate': 2.6000000000000002e-05, 'epoch': 0.88}
{'loss': 1.8096, 'grad_norm': 1.2544760704040527, 'learning_rate': 2.7000000000000002e-05, 'epoch': 0.91}
{'loss': 1.8298, 'grad_norm': 1.3958159685134888, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.95}
{'loss': 1.8241, 'grad_norm': 1.4531382322311401, 'learning_rate': 2.9e-05, 'epoch': 0.98}
{'loss': 1.8092, 'grad_norm': 1.6742535829544067, 'learning_rate': 3e-05, 'epoch': 1.01}
{'loss': 1.8045, 'grad_norm': 1.3991260528564453, 'learning_rate': 3.1e-05, 'epoch': 1.05}
{'loss': 1.7342, 'grad_norm': 1.3938068151474, 'learning_rate': 3.2000000000000005e-05, 'epoch': 1.08}
{'loss': 1.7371, 'grad_norm': 1.3197624683380127, 'learning_rate': 3.3e-05, 'epoch': 1.11}
{'loss': 1.7934, 'grad_norm': 1.2046458721160889, 'learning_rate': 3.4000000000000007e-05, 'epoch': 1.15}
{'loss': 1.7713, 'gr

Input ids are automatically padded from 1934 to 2048 to be a multiple of `config.attention_window`: 1024


{'loss': 1.7373, 'grad_norm': 1.281718373298645, 'learning_rate': 4.1e-05, 'epoch': 1.38}
{'loss': 1.7248, 'grad_norm': 1.2748881578445435, 'learning_rate': 4.2e-05, 'epoch': 1.42}
{'loss': 1.6676, 'grad_norm': 1.3553487062454224, 'learning_rate': 4.3e-05, 'epoch': 1.45}
{'loss': 1.6658, 'grad_norm': 1.3881195783615112, 'learning_rate': 4.4000000000000006e-05, 'epoch': 1.49}
{'loss': 1.6925, 'grad_norm': 1.3141294717788696, 'learning_rate': 4.5e-05, 'epoch': 1.52}
{'loss': 1.681, 'grad_norm': 1.2988027334213257, 'learning_rate': 4.600000000000001e-05, 'epoch': 1.55}
{'loss': 1.6723, 'grad_norm': 1.302764892578125, 'learning_rate': 4.7e-05, 'epoch': 1.59}
{'loss': 1.7003, 'grad_norm': 1.1893130540847778, 'learning_rate': 4.8e-05, 'epoch': 1.62}
{'loss': 1.673, 'grad_norm': 1.3241840600967407, 'learning_rate': 4.9e-05, 'epoch': 1.65}
{'loss': 1.6604, 'grad_norm': 1.4069883823394775, 'learning_rate': 5e-05, 'epoch': 1.69}


  0%|          | 0/409 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 18.01 GiB. GPU 0 has a total capacity of 15.99 GiB of which 0 bytes is free. Of the allocated memory 19.95 GiB is allocated by PyTorch, and 8.15 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)