# ✎ Datasets

## Overview

This tutorial demonstrates how to interact with pre-defined datasets in fairseq2.
We use the `gsm8k_sft` (generic instruction finetuning) dataset as an example.

> Make sure that you have completed the End-to-End Fine-Tuning tutorial and the asset basics tutorial first.
> For example, your asset YAML file should contain the following lines (replace the path with the actual path on your machine):

```yaml
name: gsm8k_sft
dataset_family: generic_instruction

---

name: gsm8k_sft@user
data: "/data/gsm8k_data/sft"
```
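To make the role of the `@user` entry concrete, here is a minimal sketch in plain Python of how an environment-specific card could be layered on top of the base card. It assumes that the fields of `gsm8k_sft@user` are merged into `gsm8k_sft` and that the override cannot rename the card. `resolve_card` and the `cards` dictionary are hypothetical names used for illustration only, not fairseq2 APIs.

```python
def resolve_card(cards, name, envs=("user",)):
    """Merge the base card with any matching `name@env` overrides.

    Hypothetical helper: it only illustrates the layering idea.
    """
    merged = dict(cards[name])
    for env in envs:
        override = dict(cards.get(f"{name}@{env}", {}))
        # Assumption: an override may add or replace fields, but not the name.
        override.pop("name", None)
        merged.update(override)
    return merged


# The two YAML documents from above, written as plain dictionaries.
cards = {
    "gsm8k_sft": {"name": "gsm8k_sft", "dataset_family": "generic_instruction"},
    "gsm8k_sft@user": {"name": "gsm8k_sft@user", "data": "/data/gsm8k_data/sft"},
}

card = resolve_card(cards, "gsm8k_sft")
print(card)
```

The merged dictionary has the same shape as the `_metadata` printed for the dataset card later in this tutorial, where `data` appears next to `dataset_family`.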

## Import all necessary modules

In [1]:
```python
from fairseq2 import setup_fairseq2
from fairseq2.data.text import load_text_tokenizer
from fairseq2.datasets import Batching, LengthBatching, StaticBatching
from fairseq2.datasets.instruction import (
    GenericInstructionDataset,
    load_instruction_dataset,
)
from fairseq2.gang import FakeGang
from fairseq2.recipes.lm.instruction_finetune import _llama3_1_instruct
from fairseq2.recipes.utils.asset import retrieve_asset_card
```

## Initialization

We first need to initialize fairseq2 by calling `setup_fairseq2()`.
This will load the configuration and register the assets, which allows us to interact with pre-defined datasets and models.

In [2]:
```python
# Set up fairseq2
setup_fairseq2()

# Load the configuration
config = _llama3_1_instruct()

# Pin the dataset to the card we added in `src/fairseq2_ext/cards/datasets/gsm8k.yaml`
config.dataset = "gsm8k_sft"
```


## Prepare the assets

We will load both the dataset card and the model card. The `retrieve_asset_card` function retrieves an asset card from the asset store by name.

In [3]:
```python
# prepare the dataset
dataset_card = retrieve_asset_card(config.dataset)
for k, v in dataset_card.__dict__.items():
    print(f"{k}: {v}")
```

```text
_name: gsm8k_sft
_metadata: {'dataset_family': 'generic_instruction', '__base_path__': PosixPath('/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/yaoj/projects/fair/fairseq2-ext/src/fairseq2_ext/cards/datasets'), '__source__': 'package:fairseq2_ext.cards', 'data': '/fsx-ram/shared/fair_conference_2024/gsm8k_data/sft', 'name': 'gsm8k_sft'}
_base_card: None
_base_path: /opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/yaoj/projects/fair/fairseq2-ext/src/fairseq2_ext/cards/datasets
```


In [4]:
```python
# prepare the model
model_card = retrieve_asset_card(config.model)
for k, v in model_card.__dict__.items():
    print(f"{k}: {v}")
```

```text
_name: llama3_1_8b_instruct
_metadata: {'base': 'llama3_instruct', 'model_arch': 'llama3_1_8b', '__base_path__': PosixPath('/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/yaoj/projects/fair/fairseq2-ext/src/fairseq2_ext/cards/models'), '__source__': 'package:fairseq2_ext.cards', 'checkpoint': '/fsx-ram/shared/Meta-Llama-3.1-8B-Instruct/original/consolidated.00.pth', 'name': 'llama3_1_8b_instruct'}
_base_card: {'base': 'llama3', 'model_config': {'vocab_info': {'eos_idx': 128009}}, '__base_path__': PosixPath('/opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/yaoj/projects/fair/fairseq2-ext/.venv/lib/python3.10/site-packages/fairseq2/assets/cards/models'), '__source__': 'package:fairseq2.assets.cards', 'name': 'llama3_instruct'}
_base_path: /opt/hpcaas/.mounts/fs-0565f60d669b6a2d3/home/yaoj/projects/fair/fairseq2-ext/src/fairseq2_ext/cards/models
```
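The model card output shows a `_base_card` entry: `llama3_1_8b_instruct` is based on `llama3_instruct`, which is in turn based on `llama3`. A plausible reading is that a field missing on a card is looked up on its base. The sketch below illustrates that idea with a hypothetical `ToyCard` class. It is not the fairseq2 `AssetCard` implementation, and the metadata is trimmed to a few fields copied from the output.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class ToyCard:
    """Hypothetical stand-in for an asset card with an optional base card."""

    name: str
    metadata: Dict[str, Any]
    base: Optional["ToyCard"] = None

    def lookup(self, key: str) -> Any:
        """Return the first value found for `key`, walking up the base chain."""
        card: Optional["ToyCard"] = self
        while card is not None:
            if key in card.metadata:
                return card.metadata[key]
            card = card.base
        raise KeyError(f"{key!r} is not defined on {self.name!r} or its bases")


base = ToyCard("llama3_instruct", {"model_config": {"vocab_info": {"eos_idx": 128009}}})
card = ToyCard("llama3_1_8b_instruct", {"model_arch": "llama3_1_8b"}, base=base)

print(card.lookup("model_arch"))    # defined on the card itself
print(card.lookup("model_config"))  # found on the base card
```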


Then we can load the actual dataset and tokenizer from the asset cards by calling `load_instruction_dataset` and `load_text_tokenizer`, respectively.

In [5]:
```python
dataset = load_instruction_dataset(dataset_card)
print(dataset)
print(dataset.__dict__)
```

```text
<fairseq2.datasets.instruction.GenericInstructionDataset object at 0x7f70a94619f0>
{'_splits': {'default': ([PosixPath('/opt/hpcaas/.mounts/fs-0e3f1457c6d924fc0/shared/fair_conference_2024/gsm8k_data/sft/train.jsonl')], [1.0])}}
```


In [6]:
```python
# Load the tokenizer.
print(f"Loading {model_card.name} tokenizer.")
tokenizer = load_text_tokenizer(model_card)
print("Tokenizer loaded.")
```

```text
Loading llama3_1_8b_instruct tokenizer.
Tokenizer loaded.
```


## Create Data Reader

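To create a data reader, we also need to prepare the gang and the batching options. `FakeGang` stands in for a real process group, so the notebook can pretend to be rank 2 of 5 without launching a distributed job.
If you dig into the `create_reader` method, you will see that it implements the data pipeline covered in `notebooks/data/datapipeline.ipynb`.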
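The next cell chooses between `StaticBatching` and `LengthBatching`. The difference can be sketched in plain Python on a list of sequence lengths: static batching fixes the number of examples per batch, while length batching caps the number of (padded) tokens per batch. This is a simplified greedy illustration under that assumption, not the bucketing logic fairseq2 actually uses; `static_batches` and `length_batches` are hypothetical helpers.

```python
def static_batches(lengths, batch_size):
    """Group examples into batches with a fixed number of examples."""
    return [lengths[i : i + batch_size] for i in range(0, len(lengths), batch_size)]


def length_batches(lengths, max_num_tokens):
    """Greedily group examples so each padded batch stays within a token budget."""
    batches, current = [], []
    for n in sorted(lengths):
        # Lengths are visited in ascending order, so `n` would be the longest
        # row and the padded batch would hold (rows * n) tokens.
        if current and (len(current) + 1) * n > max_num_tokens:
            batches.append(current)
            current = []
        # A single sequence longer than the budget still gets its own batch here.
        current.append(n)
    if current:
        batches.append(current)
    return batches


lengths = [5, 3, 9, 2, 7, 4]
by_count = static_batches(lengths, batch_size=4)
by_tokens = length_batches(lengths, max_num_tokens=16)
print(by_count)
print(by_tokens)
```

With length batching, short sequences end up in larger batches and long sequences in smaller ones, which keeps the amount of padding per batch bounded.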

In [7]:
```python
# prepare the seed
seed = 42

# prepare the gang
gang = FakeGang(rank=2, size=5)

try:
    batching: Batching

    if config.batch_size is not None:
        batching = StaticBatching(config.batch_size)
    else:
        batching = LengthBatching(config.max_num_tokens)

    data_reader = dataset.create_reader(
        config.train_split,
        tokenizer,
        gang,
        config.max_seq_len,
        batching=batching,
        example_shuffle_window=config.example_shuffle_window,
        batch_shuffle_window=config.batch_shuffle_window,
        num_accumulate=config.gradient_accumulation,
        num_prefetch=config.num_prefetch,
        src_encode_mode=config.src_encode_mode,
        tgt_encode_mode=config.tgt_encode_mode,
        seed=seed,
    )
except ValueError as ex:
    raise ValueError(
        "The data reader cannot be initialized. See nested exception for details."
    ) from ex
```
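The `gang` argument tells the reader which slice of the data this process is responsible for. As a rough mental model, and assuming round-robin assignment (the exact strategy inside `create_reader` is not shown in this tutorial), rank 2 of 5 would see every fifth example starting at index 2. `shard_examples` is a hypothetical helper written for this sketch.

```python
def shard_examples(examples, rank, size):
    """Keep only the examples whose index maps to `rank` (round-robin)."""
    if not 0 <= rank < size:
        raise ValueError("rank must be in [0, size)")
    return [ex for idx, ex in enumerate(examples) if idx % size == rank]


examples = [f"example_{i}" for i in range(12)]

# Same rank and size as the FakeGang used in this tutorial.
mine = shard_examples(examples, rank=2, size=5)
print(mine)

# Across all ranks, every example is read exactly once.
all_shards = [shard_examples(examples, r, 5) for r in range(5)]
covered = sorted(ex for shard in all_shards for ex in shard)
print(covered == sorted(examples))
```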


## Iterate over the batches

Now that we have the data reader, we can iterate over the batches.
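The iteration pattern in the next cell relies on three behaviors: `next(...)` returns a list of batches (one per gradient accumulation step), `StopIteration` signals the end of the data, and `reset()` rewinds the reader. The toy class below mimics that protocol in plain Python so the control flow can be tried without any data on disk. `ToyReader` is a hypothetical stand-in, and details such as how a real reader handles a partial final group are assumptions.

```python
class ToyReader:
    """Hypothetical reader that yields lists of `num_accumulate` batches."""

    def __init__(self, batches, num_accumulate=1):
        self._batches = list(batches)
        self._num_accumulate = num_accumulate
        self._pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos >= len(self._batches):
            raise StopIteration
        group = self._batches[self._pos : self._pos + self._num_accumulate]
        self._pos += len(group)
        return group

    def reset(self):
        """Rewind to the first batch, e.g. to start another epoch."""
        self._pos = 0


reader = ToyReader(["batch_0", "batch_1", "batch_2"], num_accumulate=2)

first = next(reader)         # two batches for one optimizer step
second = next(reader)        # the remaining batch
third = next(reader, None)   # data exhausted, so the default is returned

reader.reset()
after_reset = next(reader)   # reading starts over from the first batch
print(first, second, third, after_reset)
```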

In [8]:
```python
try:
    batches = next(data_reader)
except StopIteration:
    batches = None

if batches is not None:
    for batch_nr, batch in enumerate(batches):
        print(f"===batch_nr==={batch_nr}===")
        print(batch)
        print("")
else:
    print("No more batches")
    data_reader.reset()
```

===batch_nr===0===
SequenceBatch(seqs=tensor([[128000, 128006,    882,  ...,    220,  10132, 128009],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        ...,
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0],
        [128000, 128006,    882,  ...,      0,      0,      0]]), padding_mask=<fairseq2.nn.padding.PaddingMask object at 0x7f6f630faf20>, target_mask=tensor([[False, False, False,  ...,  True,  True,  True],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]]), example={'id': [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, Non