# 📚 Data

This notebook loads, preprocesses and batches the data used in this experiment suite.

## Setup 

In [None]:
import autorootcwd

In [None]:
from typing import List, Dict, Any

import torch
import pandas as pd
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from datasets import disable_progress_bar
from src.utils import get_dataset, get_tokenizer, non_empty_text, collate_fn

In [None]:
# Turn off the progress bars of the `datasets` library (see the import above).
disable_progress_bar()

## Load dataset

For now, we will use a tiny dataset: the `wikitext-2-raw-v1` configuration of `Salesforce/wikitext`, which consists of roughly 37K training, 3.7K validation and 4.3K test examples.

In [None]:
# Fetch the raw WikiText-2 configuration and pull out its three splits.
wiki = get_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")
train_wiki = wiki["train"]
val_wiki = wiki["validation"]
test_wiki = wiki["test"]

# Report split sizes in thousands of rows.
print(
    f"Loaded {len(train_wiki)/1e3:.1f}K training, "
    f"{len(val_wiki)/1e3:.1f}K validation and "
    f"{len(test_wiki)/1e3:.1f}K test examples."
)

In [None]:
# Print the first five raw training rows to see what a record looks like.
for example in train_wiki.take(5):
    print(example)

A single example just has a `text` field, which contains a single line of text. The lines are parsed from high-quality Wikipedia articles. We can already see that there are loads of empty lines and other artifacts like headlines.

## Preprocess dataset

We are going to remove empty lines and headlines, and finally tokenize the dataset.

In [None]:
def non_empty_text(examples: Dict[str, Any]) -> bool:
    """Return True when the row's ``text`` field holds visible content.

    Used as a row-wise ``filter`` predicate. A row is rejected when its
    ``text`` is the empty string, consists only of whitespace (e.g. a lone
    newline), or is not a string at all, so such rows never reach the
    tokenizer.

    NOTE(review): this definition shadows the ``non_empty_text`` imported
    from ``src.utils`` at the top of the notebook.

    Args:
        examples: A single dataset row with a ``"text"`` key.

    Returns:
        ``True`` if the text contains at least one non-whitespace character.
    """
    text = examples["text"]
    # Strip first so that whitespace-only lines count as empty too.
    return isinstance(text, str) and text.strip() != ""

def non_headline(examples: Dict[str, Any]) -> bool:
    """Return True for rows that are not headlines.

    A row is treated as a headline when its ``text`` begins with the
    marker ``" = "`` (space, equals sign, space).

    Args:
        examples: A single dataset row with a ``"text"`` key.

    Returns:
        ``False`` if the text starts with ``" = "``, ``True`` otherwise.
    """
    is_headline = examples["text"].startswith(" = ")
    return not is_headline

def tokenize(examples: Dict[str, Any], tokenizer: AutoTokenizer, max_length: int) -> Dict[str, Any]:
    """Tokenize the ``text`` field to a fixed length of ``max_length + 1``.

    Sequences are truncated and padded to exactly ``max_length + 1`` tokens.
    The extra position is there because ``collate_fn`` in this notebook drops
    the last token for the inputs and the first token for the labels.

    Args:
        examples: Row (or batch of rows) exposing a ``"text"`` entry.
        tokenizer: Callable tokenizer applied to the text.
        max_length: Target length of the model input after the one-token shift.

    Returns:
        Whatever ``tokenizer`` returns for the given text.
    """
    padded_length = max_length + 1
    return tokenizer(
        examples["text"],
        max_length=padded_length,
        padding="max_length",
        truncation=True,
    )

In [None]:
# Tokenizer used for every split; the EOS token doubles as the padding token.
# NOTE(review): presumably the checkpoint defines no pad token of its own — confirm.
tokenizer = get_tokenizer("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token

# Batch size for the data loaders built further down.
batch_size = 32

# Extra keyword arguments forwarded to `tokenize` by `Dataset.map`.
fn_kwargs = dict(tokenizer=tokenizer, max_length=128)

In [None]:
def _preprocess_split(split, map_batch_size: int = 32):
    """Filter and tokenize one dataset split.

    Drops the rows rejected by `non_empty_text` and `non_headline`, tokenizes
    the remaining rows in batches with `tokenize` (configured through the
    notebook-level `fn_kwargs`), and removes the raw `text` column.

    Args:
        split: Dataset split to preprocess.
        map_batch_size: Number of rows handed to `tokenize` per call.

    Returns:
        The filtered, tokenized split without the `text` column.
    """
    kept = split.filter(non_empty_text).filter(non_headline)
    tokenized = kept.map(tokenize, batched=True, batch_size=map_batch_size, fn_kwargs=fn_kwargs)
    return tokenized.remove_columns(["text"])


# Same pipeline for all three splits.
train_wiki_processed = _preprocess_split(train_wiki)
val_wiki_processed = _preprocess_split(val_wiki)
test_wiki_processed = _preprocess_split(test_wiki)

print(
    f"Processed {len(train_wiki_processed)/1e3:.1f}K training, "
    f"{len(val_wiki_processed)/1e3:.1f}K validation and "
    f"{len(test_wiki_processed)/1e3:.1f}K test examples."
)

In [None]:
# Print the first five preprocessed training rows (the `text` column was removed above).
for example in train_wiki_processed.take(5):
    print(example)

## Data Loader




In [None]:
def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Stack tokenized examples into shifted input/label tensors.

    Inputs and attention masks drop the final position of every example,
    while labels drop the first position, so ``labels[:, i]`` is the token
    that follows ``input_ids[:, i]``. All examples must have the same length,
    because the per-example tensors are combined with ``torch.stack``.

    NOTE(review): this definition shadows the ``collate_fn`` imported from
    ``src.utils`` at the top of the notebook.

    Args:
        batch: Examples, each providing ``'input_ids'`` and
            ``'attention_mask'`` sequences.

    Returns:
        Dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``
        tensors, each one position shorter than the incoming sequences.
    """
    drop_last = slice(None, -1)
    drop_first = slice(1, None)

    def stack_field(field: str, window: slice) -> torch.Tensor:
        # One tensor per example, then a new leading batch dimension.
        rows = [torch.tensor(sample[field][window]) for sample in batch]
        return torch.stack(rows)

    return {
        'input_ids': stack_field('input_ids', drop_last),
        'attention_mask': stack_field('attention_mask', drop_last),
        'labels': stack_field('input_ids', drop_first),
    }

In [None]:
# One loader per split, all sharing the same batch size and collate function.
# NOTE(review): the training loader is not shuffled — confirm this is intended.
wiki_train_loader = DataLoader(
    train_wiki_processed,
    batch_size=batch_size,
    collate_fn=collate_fn,
)
wiki_val_loader = DataLoader(
    val_wiki_processed,
    batch_size=batch_size,
    collate_fn=collate_fn,
)
wiki_test_loader = DataLoader(
    test_wiki_processed,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

# `len` of a loader is its number of batches.
print(
    f"Loaded {len(wiki_train_loader)} training, "
    f"{len(wiki_val_loader)} validation and "
    f"{len(wiki_test_loader)} test batches."
)

In [None]:
# Pull a single batch from the training loader and print the shape of each field.
batch = next(iter(wiki_train_loader))
for key, value in batch.items():
    print(key, value.shape)