## Prepare Dataset

In [1]:
#|export
from random import randint
from itertools import chain
from functools import partial
from transformers import AutoTokenizer
from fastcore.script import call_parse

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#|export
# Llama-2 tokenizer shared by template_dataset() (EOS suffix) and generate() (tokenization).
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
# Reuse the EOS token for padding — presumably because this tokenizer defines
# no pad token of its own (TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

In [3]:
#|export
def format_dolly(sample):
    """Render one Dolly record as a sectioned prompt string.

    The result is an ``### Instruction`` section, an optional ``### Context``
    section (only when ``sample["context"]`` is non-empty) and an
    ``### Answer`` section, separated by blank lines.

    Args:
        sample: mapping with ``instruction``, ``context`` and ``response`` keys.

    Returns:
        The assembled prompt text.
    """
    sections = [f"### Instruction\n{sample['instruction']}"]
    if len(sample["context"]) > 0:
        sections.append(f"### Context\n{sample['context']}")
    sections.append(f"### Answer\n{sample['response']}")
    return "\n\n".join(sections)


def template_dataset(sample):
    """Attach a ``text`` field holding the formatted prompt followed by the EOS token.

    ``sample`` is modified in place and also returned, so the function can be
    used directly as the ``dataset.map`` callback in ``get_data``.
    """
    prompt = format_dolly(sample)
    eos = tokenizer.eos_token
    sample["text"] = f"{prompt}{eos}"
    return sample


def get_data():
    """Load the Dolly-15k train split and reduce it to a single ``text`` column.

    Every original column is dropped after ``template_dataset`` has built the
    prompt text from it.
    """
    from datasets import load_dataset

    raw = load_dataset("databricks/databricks-dolly-15k", split="train")
    original_columns = list(raw.features)
    return raw.map(template_dataset, remove_columns=original_columns)


def chunk(sample, chunk_length=2048):
    """Pack a batch of tokenized examples into fixed-length chunks.

    Intended as a ``batched=True`` map callback. The per-example token lists
    of every column are concatenated, after any tokens carried over from the
    previous batch. The result is cut into as many full ``chunk_length``
    pieces as fit. Tokens that do not fill a whole chunk are stored in the
    module-level ``remainder`` dict and prepended to the next batch. Whatever
    is still left after the final batch is never emitted.

    Args:
        sample: batch dict mapping each column name to a list of per-example
            token lists. Must contain ``input_ids``.
        chunk_length: tokens per output chunk; must be positive.

    Returns:
        Dict with the same columns as ``sample`` plus ``labels`` (a copy of the
        ``input_ids`` chunks). Each value is a list of ``chunk_length``-long
        lists. The lists are empty when the batch plus carry-over holds fewer
        than ``chunk_length`` tokens; in that case everything is carried over.

    Raises:
        ValueError: if ``chunk_length`` is not positive.
    """
    global remainder
    if chunk_length <= 0:
        raise ValueError(f"chunk_length must be positive, got {chunk_length}")

    # Carry-over from the previous batch. Tolerate the global not being set
    # yet, and columns it does not list (missing keys start empty).
    carry = globals().get("remainder") or {}
    concatenated_examples = {
        k: list(carry.get(k, [])) + list(chain(*sample[k])) for k in sample.keys()
    }

    # All columns have the same total length, so measuring the first one suffices.
    batch_total_length = len(next(iter(concatenated_examples.values()), []))

    # Largest multiple of chunk_length that fits; 0 when the batch is too short.
    batch_chunk_length = (batch_total_length // chunk_length) * chunk_length

    result = {
        k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)]
        for k, t in concatenated_examples.items()
    }
    # Keep the tail that did not fill a whole chunk for the next batch.
    remainder = {k: t[batch_chunk_length:] for k, t in concatenated_examples.items()}
    # Labels mirror the packed input ids.
    result["labels"] = result["input_ids"].copy()
    return result

In [4]:
#|export
# Tokens carried between batches by chunk(); reset at the start of every generate() call.
remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

def generate(seq_len, n=1024, dsname=None):
    """Build a packed causal-LM dataset of ``seq_len``-token examples from Dolly-15k.

    Args:
        seq_len: tokens per packed example (passed to ``chunk`` as ``chunk_length``).
        n: number of packed examples to keep (the first ``n`` rows). Presumably
            this fails if fewer than ``n`` chunks are produced — TODO confirm
            the exception ``Dataset.select`` raises.
        dsname: if truthy, directory the result is written to via ``save_to_disk``.

    Returns:
        The selected dataset with the tokenizer's output columns plus ``labels``.
    """
    global remainder
    # chunk() keeps its carry-over tokens in this global. Without a reset,
    # leftover tokens from a previous call (possibly with a different seq_len)
    # would be prepended to the first chunk of this dataset.
    remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

    dataset = get_data()
    tokenized = dataset.map(
        lambda sample: tokenizer(sample["text"]),
        batched=True,
        remove_columns=list(dataset.features),
    )
    lm_dataset = tokenized.map(partial(chunk, chunk_length=seq_len), batched=True).select(range(n))
    if dsname:
        lm_dataset.save_to_disk(dsname)
    return lm_dataset

In [5]:
lm_dataset = generate(2048, n=256)
assert len(lm_dataset) == 256
# Check every packed row, not only the first: each chunk must be exactly
# seq_len tokens and its labels must equal its input ids.
assert all(
    len(row['input_ids']) == 2048 and row['labels'] == row['input_ids']
    for row in lm_dataset
)

In [6]:
#|export
# One packed dataset per sequence length, saved to the directory named by dsname.
# For seq_len >= 256 each config totals 409,600 tokens (seq_len * n). The
# 64-token set totals 192,000 — NOTE(review): confirm that is intentional.
data_configs = [
    {'seq_len': 64, 'n': 3000, 'dsname': 'data_64'},
    {'seq_len': 256, 'n': 1600, 'dsname': 'data_256'},
    {'seq_len': 512, 'n': 800, 'dsname': 'data_512'},
    {'seq_len': 1024, 'n': 400, 'dsname': 'data_1024'},
    {'seq_len': 2048, 'n': 200, 'dsname': 'data_2048'}
]

# This cell is exported. Only build the datasets when run in the notebook
# kernel or as a script, not whenever the exported module is imported: each
# generate() call loads, tokenizes, chunks and writes a dataset to disk.
if __name__ == "__main__":
    for config in data_configs:
        generate(**config)

Saving the dataset (1/1 shards): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:00<00:00, 457211.29 examples/s]
Saving the dataset (1/1 shards): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1600/1600 [00:00<00:00, 76589.93 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15011/15011 [00:01<00:00, 7568.56 examples/s]
Saving the dataset (1/1 shards): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 800/800 [00:00<00:00, 66829.52 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15011/15011 [00:01<00:00, 7648.20 examples/s]
Saving the data

## Prepare Config

In [7]:
# Total number of tokens in each configured dataset (seq_len * n).
for config in data_configs:
    total_tokens = config['seq_len'] * config['n']
    print(total_tokens)

192000
409600
409600
409600
409600
