In [75]:
# Imports
import tiktoken

from datasets import load_dataset # huggingface datasets
from transformers import BartTokenizer
from transformers import pipeline

# Tokenizer and summarization pipeline are both loaded from the same
# "facebook/bart-large-cnn" checkpoint so token counts match the summarizer.
tokenizer = BartTokenizer.from_pretrained(
    "facebook/bart-large-cnn",
)
summarizer = pipeline(
    task="summarization",
    model="facebook/bart-large-cnn",
)

Device set to use mps:0


In [76]:
# Import dataset
# Download CNN/DailyMail (config "3.0.0"), discard the held-out test split and
# expose the validation split under the shorter key 'val'.
split_dataset = load_dataset(
    path="abisee/cnn_dailymail",
    name="3.0.0",
)
split_dataset.pop("test", None)
split_dataset["val"] = split_dataset.pop("validation", None)

In [77]:
# Summarize summaries to make them even smaller to fit into context window
def modify_summary(example):
    """Shorten an over-long reference summary with the summarization pipeline.

    The summary in ``example["highlights"]`` is tokenized with the module-level
    ``tokenizer``. If it has 512 tokens or more, it is replaced by the output of
    the module-level ``summarizer`` (deterministic decoding, 256-512 generated
    tokens). Shorter summaries are left untouched.

    Args:
        example: A dataset row (mapping) containing a ``"highlights"`` string.

    Returns:
        The same ``example`` object, with ``"highlights"`` possibly rewritten.
    """
    long_summary = example["highlights"]
    tokens = tokenizer.tokenize(long_summary)

    if len(tokens) >= 512:
        compressed_summary = summarizer(
            long_summary,
            max_length=512,
            min_length=256,
            do_sample=False,
            # Ask the pipeline to truncate the input to the model's maximum
            # input length instead of feeding an over-long sequence to the
            # model, which would otherwise fail on very long summaries.
            truncation=True,
        )
        example["highlights"] = compressed_summary[0]["summary_text"]

    return example

# Apply the function to modify the dataset
# Run the summary-shortening pass over both splits, replacing each in place
for split_name in ("train", "val"):
    split_dataset[split_name] = split_dataset[split_name].map(modify_summary)

Map:   0%|          | 0/287113 [00:00<?, ? examples/s]

Map:   0%|          | 0/13368 [00:00<?, ? examples/s]

In [78]:
# Encoder with summary tokens
# Base GPT-2 BPE encoding, extended below with two extra special tokens.
enc = tiktoken.get_encoding("gpt2")

# Ids for the new special tokens: <|sum|> opens the summary half of a block,
# <|core|> opens the article half.
sum_tokens = {"<|sum|>": 50257, "<|core|>": 50258}

# Same pattern and merge table as GPT-2; only the special-token table grows.
custom_enc = tiktoken.Encoding(
    name="gpt2_custom",
    pat_str=enc._pat_str,
    mergeable_ranks=enc._mergeable_ranks,
    special_tokens={**enc._special_tokens, **sum_tokens},
)

sum_token, core_token = sum_tokens["<|sum|>"], sum_tokens["<|core|>"]
# Padding reuses the end-of-text id, so pad and eot are the same token.
pad_token = eot_token = custom_enc.eot_token

In [79]:
def _longest_encoded_summary(split):
    """Return ``(token_count, index)`` of the longest encoded summary in ``split``.

    Summaries are encoded with ``encode_ordinary`` so the count is measured the
    same way ``process`` tokenizes them, and so that a summary which happens to
    contain the literal text of a special token does not abort the scan.
    When several rows share the maximum length, the largest index wins
    (tuple comparison). Raises ``ValueError`` if ``split`` is empty.
    """
    return max(
        (len(custom_enc.encode_ordinary(item["highlights"])), idx)
        for idx, item in enumerate(split)
    )


# Longest summary (in tokens) and its row index, per split
max_len_in_train, argmax_in_train = _longest_encoded_summary(split_dataset["train"])
max_len_in_val, argmax_in_val = _longest_encoded_summary(split_dataset["val"])
print(f"Maximum length of encoded summary in train: {max_len_in_train}")
print(f"Maximum length of encoded summary in val: {max_len_in_val}")

Maximum length of encoded summary in train: 504
Maximum length of encoded summary in val: 498


In [80]:
# Encoding function (gpt2 bpe)
def process(example, block_size=1024):
    """Turn one article/summary pair into fixed-layout token blocks.

    Every block is laid out as
    ``<|sum|> summary pad... <|core|> article-slice <|endoftext|> pad...``.
    The first ``block_size // 2`` ids hold the summary (cut off or padded to
    fit) and the remaining ``block_size // 2`` ids hold one slice of the
    article. The identical summary half is repeated in front of every slice.

    Args:
        example: Mapping with the text fields ``"highlights"`` and ``"article"``.
        block_size: Target number of ids per block; each half receives
            ``block_size // 2`` of them.

    Returns:
        ``{"blocks": [...]}`` where each entry is ``{"ids": list, "len": int}``.
        There is one entry per article slice, so an article that encodes to
        zero tokens produces an empty list.
    """
    highlight_ids = enc.encode_ordinary(example["highlights"])
    body_ids = enc.encode_ordinary(example["article"])

    half = block_size // 2
    summary_budget = half - 1  # one slot goes to <|sum|>
    article_budget = half - 2  # two slots go to <|core|> and <|endoftext|>

    # Summary half: marker, clipped summary, then padding up to the budget
    clipped = highlight_ids[:summary_budget]
    summary_half = [sum_token] + clipped + [pad_token] * (summary_budget - len(clipped))

    article_half_len = article_budget + 2
    blocks = []
    for start in range(0, len(body_ids), article_budget):
        piece = [core_token] + body_ids[start:start + article_budget] + [eot_token]
        piece = piece[:article_half_len]
        # A non-positive multiplier yields no padding, so full pieces pass through
        piece = piece + [pad_token] * (article_half_len - len(piece))
        ids = summary_half + piece
        blocks.append({"ids": ids, "len": len(ids)})

    return {"blocks": blocks}

In [89]:
# Testing custom decoding
# Sanity check: build the blocks for the first training example and decode the
# first block back to text with the extended encoding, so the <|sum|> marker
# and the <|endoftext|> padding are visible in the output.
custom_enc.decode(process(split_dataset['train'][0])['blocks'][0]['ids'])

'<|sum|>Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday .\nYoung actor says he has no plans to fritter his cash away .\nRadcliffe\'s earnings from first five Potter films have been held in trust fund .<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|end

In [None]:
# Tokenizing dataset
# Convert every row of both splits into its list of token blocks; the raw text
# columns and the id column are dropped, leaving only what `process` returns.
tokenized = split_dataset.map(
    function=process,
    remove_columns=["article", "highlights", "id"],
    num_proc=4,
    desc="tokenizing the splits",
)

In [91]:
# Display the tokenized DatasetDict (each split should now expose only 'blocks')
tokenized

DatasetDict({
    train: Dataset({
        features: ['blocks'],
        num_rows: 287113
    })
    val: Dataset({
        features: ['blocks'],
        num_rows: 13368
    })
})

In [93]:
def flatten_blocks(batch):
    """Unnest a batch of per-example block lists into one flat batch of blocks.

    Args:
        batch: Mapping whose ``"blocks"`` entry is a list with one element per
            example, each element being a list of ``{"ids": ..., "len": ...}``
            dicts.

    Returns:
        ``{"ids": [...], "len": [...]}`` with one entry per block, in the order
        the blocks appear (example by example, block by block).
    """
    every_block = [
        block
        for example_blocks in batch["blocks"]
        for block in example_blocks
    ]
    return {
        "ids": [block["ids"] for block in every_block],
        "len": [block["len"] for block in every_block],
    }

# Apply to the dataset
# Explode the nested block lists so that each dataset row is exactly one block.
# Batched mode is what allows the mapped function to return more rows than it
# received; the nested 'blocks' column is dropped afterwards.
flattened_dataset = tokenized.map(
    function=flatten_blocks,
    batched=True,
    num_proc=4,
    remove_columns=["blocks"],
    desc="Flattening blocks",
)

In [None]:
# Show the flattened DatasetDict together with its first validation row
flattened_dataset, flattened_dataset['val'][0]

In [100]:
# List the (split name, Dataset) pairs to check per-split row counts after flattening
flattened_dataset.items()

dict_items([('train', Dataset({
    features: ['ids', 'len'],
    num_rows: 630393
})), ('val', Dataset({
    features: ['ids', 'len'],
    num_rows: 28680
}))])

In [101]:
# Persist both flattened splits to disk; the target path is relative to the
# directory the notebook kernel is running in.
flattened_dataset.save_to_disk("../datasets/test.hf")

Saving the dataset (0/11 shards):   0%|          | 0/630393 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/28680 [00:00<?, ? examples/s]