# FM FT API: Data Validation and \$Token Estimation

#### Usage Scenario:
This notebook accompanies the Databricks MosaicML Fine-Tuning (FT) API. It is useful when there is a risk that the training data is malformed: it validates the data up front to protect data integrity, and it helps estimate the cost of the fine-tuning run.

#### Script Purpose:
- **Not for Training**: This script is not utilized during the training process.
- **Ad-Hoc Validation**: It serves as an ad-hoc utility for users to run independently prior to starting fine-tuning.
- **Data Verification**: Its primary function is to validate the user's data before they invoke the Fine-Tuning (FT) API.
- **Cost Estimation**: Users can estimate the cost implications with this script.

#### Note on Long-Term Solution:
- **Future Development**: We are in the process of developing a long-term data preparation service, which will eventually replace this script.

#### User Defines:
- The inputs to this validation script are assumed to be the same as, or a subset of, the FT API arguments, i.e., a configuration like the one below. Is this a valid assumption?
- For reference, the FT API expects the following:
```
cfg = {
    model: str,
    train_data_path: str,
    save_folder: str,
    *,
    task_type: Optional[str] = "INSTRUCTION_FINETUNE",
    eval_data_path: Optional[str] = None,
    eval_prompts: Optional[List[str]] = None,
    custom_weights_path: Optional[str] = None,
    training_duration: Optional[str] = None,
    learning_rate: Optional[float] = None,
    context_length: Optional[int] = None,
    experiment_trackers: Optional[List[Dict]] = None,
    disable_credentials_check: Optional[bool] = None,
    timeout: Optional[float] = 10,
    future: Literal[False] = False,
}
``` 

# Installation

In [0]:
%pip uninstall -y llm-foundry

In [0]:
dbutils.library.restartPython()

In [0]:
# %pip install git+https://github.com/mosaicml/llm-foundry.git@byod/data_validation
%pip install --upgrade --no-deps git+https://github.com/XiaohanZhangCMU/llm-foundryX.git@validation 
%pip install "mosaicml>=0.17.2,<0.18"
%pip install "transformers>=4.36,<4.37"
%pip install "mosaicml-streaming>=0.7.2,<0.8"
%pip install -U datasets
%pip install omegaconf
%pip install einops
%pip install sentencepiece

In [0]:
dbutils.library.restartPython()

In [0]:
import os
import re
import json
import tempfile
import numpy as np
import pandas as pd 
from collections import defaultdict
from argparse import ArgumentParser, Namespace

import datasets 

from llmfoundry.utils import (create_om_cfg, token_counts_and_validation, token_counts, 
        check_HF_datasets, is_hf_dataset_path, is_uc_delta_table,
        pandas_processing_fn, integrity_check, convert_text_to_mds, parse_args, 
        _args_str, plot_hist, dataframe_to_mds)

import transformers
transformers.logging.set_verbosity_error()

# Instruction Fine Tuning

### Fine-Tuning API Arguments Configuration

This section sets up the parameters for the validation notebook. They should be identical to the ones you pass to the Fine-Tuning API.

**Fine-Tuning API Arguments (FT_API_args):**

- model: The model to fine-tune, e.g., 'EleutherAI/gpt-neox-20b'.
- train_data_path: The path to the training data. It can be a Hugging Face dataset, a path to a JSONL file, or a Delta table.
- task_type: The type of task, which determines the training strategy. It is either 'INSTRUCTION_FINETUNE' or 'CONTINUED_PRETRAIN'.
- training_duration: The duration of training, expressed as a number of epochs (e.g., 3).
- context_length: The context length of the model, set to 2048 in this example. It determines how many tokens the model considers for each training example.

**Temporary Data Path Configuration:**

- temporary_jsonl_data_path: Defines a filesystem path where temporary data related to the training process will be stored.
- Environment variables for Hugging Face caches (HF_DATASETS_CACHE) are set to '/tmp/', directing dataset caching to a temporary directory.

In [0]:
FT_API_args = Namespace(
    model='EleutherAI/gpt-neox-20b',
    train_data_path= 'main.streaming.random_large_table', # 'tatsu-lab/alpaca/train', # '/Volumes/main/mosaic_hackathon/managed-volume/IFT/train.jsonl',  'tatsu-lab/alpaca/train',  # 'mosaicml/dolly_hhrlhf/train', # tatsu-lab/alpaca/train',
    task_type='INSTRUCTION_FINETUNE',
    training_duration=3,
    context_length=2048,
)

temporary_jsonl_data_path = '/Volumes/main/mosaic_hackathon/managed-volume/IFT/ft_data_11Jan24_3/train'
# os.environ['HF_ASSETS_CACHE'] = '/tmp/'
# os.environ['HF_HOME'] = '/tmp/'
# os.environ['HF_HUB_CACHE'] = '/tmp/'
os.environ['HF_DATASETS_CACHE'] = '/tmp/'
os.makedirs(temporary_jsonl_data_path, exist_ok=True)

#### Data Loading

The IFT data must follow this format:
```
prompt: xxx
response or completion: yyy
```
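As an illustration of the format above, here is a minimal sketch of two records serialized as JSON Lines (one JSON object per line). The sample records and the helper names `to_jsonl` and `from_jsonl` are hypothetical and exist only for this example; they are not part of the FT API or llm-foundry.

```python
import json

# Hypothetical sample records: each has a prompt plus a response (or completion).
records = [
    {"prompt": "What is the capital of France?", "response": "Paris."},
    {"prompt": "Translate 'hello' to Spanish.", "completion": "Hola."},
]


def to_jsonl(rows):
    """Serialize a list of dicts as JSON Lines: one JSON object per line."""
    return "\n".join(json.dumps(row) for row in rows)


def from_jsonl(text):
    """Parse JSON Lines text back into a list of dicts, skipping blank lines."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]


jsonl_text = to_jsonl(records)
round_tripped = from_jsonl(jsonl_text)
print(jsonl_text)
```

Each line of the printed text is one training example, which is the layout `pd.read_json(..., lines=True)` expects.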

Based on FT_API_args.train_data_path, one of the three ingestion methods below is selected:

- Option 1: the data is a JSONL file, stored locally or in an object store supported by Composer.
- Option 2: the data is a Hugging Face dataset ID. Note that you need to provide a split as well, e.g., 'tatsu-lab/alpaca/train'.
- Option 3: the data is a Delta table.
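The choice between the three options can be pictured with the sketch below. It is not the implementation of `is_hf_dataset_path` or `is_uc_delta_table`; the heuristics are assumptions made for illustration (a Delta table is named `catalog.schema.table`, a Hugging Face path looks like `org/name/split`), and `guess_ingestion_method` and `split_hf_path` are hypothetical names.

```python
import re

# Assumption: a Unity Catalog table name has three dot-separated identifiers
# (catalog.schema.table) and contains no slashes.
_UC_TABLE = re.compile(r"^[A-Za-z_]\w*\.[A-Za-z_]\w*\.[A-Za-z_]\w*$")


def guess_ingestion_method(path: str) -> str:
    """Return 'delta_table', 'jsonl', 'hf_dataset', or 'unknown' for a data path."""
    if _UC_TABLE.match(path):
        return "delta_table"
    if path.endswith(".jsonl"):
        return "jsonl"
    parts = path.split("/")
    # Assumption: a Hugging Face path is exactly 'org/name/split', with no URI scheme.
    if len(parts) == 3 and all(parts) and ":" not in path:
        return "hf_dataset"
    return "unknown"


def split_hf_path(path: str):
    """Split 'org/name/split' into (dataset_id, split), mirroring the cell below."""
    parts = path.split("/")
    return "/".join(parts[:2]), parts[-1]


print(guess_ingestion_method("main.streaming.random_large_table"))
print(guess_ingestion_method("tatsu-lab/alpaca/train"))
print(split_hf_path("tatsu-lab/alpaca/train"))
```

Returning an explicit `'unknown'` makes it easy to fail early with a clear message rather than attempting to load an unsupported path.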

In [0]:
raw_dataset = None

if is_hf_dataset_path(FT_API_args.train_data_path):
    check_HF_datasets(FT_API_args.train_data_path)
    dataset_id, split = '/'.join(FT_API_args.train_data_path.split('/')[:2]), FT_API_args.train_data_path.split('/')[-1]    
    raw_dataset = datasets.load_dataset(dataset_id, split=split)       
else:
    if is_uc_delta_table(FT_API_args.train_data_path):    
        df = spark.read.table(FT_API_args.train_data_path).toPandas()
        df.to_json(os.path.join(temporary_jsonl_data_path, 'data.jsonl'), orient='records', lines=True)
        raw_dataset = datasets.Dataset.from_pandas(df) 
        FT_API_args.train_data_path = temporary_jsonl_data_path
    else: 
        # train_data_path is a JSONL file (local or remote)
        from composer.utils import dist, get_file, parse_uri 
        data_path = FT_API_args.train_data_path 
        backend, _, _ = parse_uri(data_path)
        if backend not in ['', None]: # It's a remote path, download before loading it
            with tempfile.TemporaryDirectory() as tmp_dir:
                destination = os.path.join(tmp_dir, 'data.jsonl')
                get_file(data_path, destination)
                df = pd.read_json(destination, orient='records', lines=True)    
        else: 
            df = pd.read_json(data_path, orient='records', lines=True)    

        raw_dataset = datasets.Dataset.from_pandas(df)
        FT_API_args.train_data_path = os.path.dirname(data_path)

if raw_dataset is None: 
    raise RuntimeError("Can't find a proper ingestion method")

#### Data Quality Checks on the Dataset


This section of the notebook performs a series of checks on the initial dataset to ensure its quality and expected format. This process ensures that the dataset adheres to the expected structure and contains the necessary keys for further processing. The checks are outlined below.

1. The total number of examples in the dataset is printed.
2. The first example from the dataset is displayed. This provides a quick glimpse into the data structure and format.
3. Data Format Validation:
   - The dataset is expected to consist of dictionary-like objects (key-value pairs). Each example in the dataset is examined for its compliance with this format.
4. Key Presence Validation:
   - Two sets of allowed keys are defined: _ALLOWED_RESPONSE_KEYS and _ALLOWED_PROMPT_KEYS.
   - The script checks for the presence of at least one prompt key and one response key in each example.
     - Prompt Validation: Each example is checked for the presence of keys defined in _ALLOWED_PROMPT_KEYS. If no valid prompt key is found, it is counted as a format error.
     - Response Validation: Similarly, each example is checked for the presence of keys defined in _ALLOWED_RESPONSE_KEYS. An absence of a valid response key is also counted as a format error.
5. Error Reporting:
   - If any format errors are found during the checks, they are reported.
   - A summary of errors is printed, categorizing them into types like data_type (non-dictionary data), missing_prompt, and missing_response.
   - If no errors are found, a congratulatory message is displayed, indicating that all checks have passed successfully.
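The checks described above can also be sketched as a small standalone function that runs on a plain Python list, which is handy for trying the rules on a few hand-written examples before loading a real dataset. The function name `count_format_errors` and the sample data are hypothetical; the key sets and error categories follow the description above. Note that, as in the notebook cell, an empty value is treated the same as a missing key.

```python
from collections import Counter

ALLOWED_PROMPT_KEYS = {"prompt"}
ALLOWED_RESPONSE_KEYS = {"response", "completion"}


def count_format_errors(examples):
    """Count format errors per category across an iterable of examples."""
    errors = Counter()
    for ex in examples:
        if not isinstance(ex, dict):
            errors["data_type"] += 1
            continue  # key checks do not apply to non-dict examples
        # A key whose value is empty or None is treated as missing.
        if not any(ex.get(key) for key in ALLOWED_PROMPT_KEYS):
            errors["missing_prompt"] += 1
        if not any(ex.get(key) for key in ALLOWED_RESPONSE_KEYS):
            errors["missing_response"] += 1
    return dict(errors)


sample = [
    {"prompt": "Hi", "response": "Hello"},    # valid
    {"prompt": "Hi", "completion": "Hello"},  # valid, alternate response key
    {"prompt": "", "response": "Hello"},      # empty prompt counts as missing
    {"question": "Hi"},                       # neither a prompt nor a response key
    "not a dict",                             # wrong data type
]
errors = count_format_errors(sample)
print(errors or "No errors found")
```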

In [0]:
# Initial dataset stats
print("Num examples:", len(raw_dataset))
print("First example:")
for ex in raw_dataset: 
    print(ex)
    print() 
    break 

_ALLOWED_RESPONSE_KEYS = {'response', 'completion'}
_ALLOWED_PROMPT_KEYS = {'prompt'}
format_errors = defaultdict(int)

for ex in raw_dataset:
    if not isinstance(ex, dict):
        format_errors["data_type"] += 1 
        continue 
    
    found = False 
    for key in _ALLOWED_PROMPT_KEYS:
        prompts = ex.get(key, None)
        if prompts:
            found = True 
    if not found: 
        format_errors["missing_prompt"] += 1

    found = False
    for key in _ALLOWED_RESPONSE_KEYS:        
        responses = ex.get(key, None)
        if responses: 
            found = True 
    if not found:
        format_errors["missing_response"] += 1
        
if format_errors:
    print("Oops! Found errors:")
    for k, v in format_errors.items():
        print(f"{k}: {v}")
else:
    print("Congratulations! No errors found")    

#### Token Estimation

Tokenize the raw dataset, compute some statistics of the token counts, and estimate the overall cost based on the default training duration.

In [0]:
n_epochs = FT_API_args.training_duration if FT_API_args.training_duration is not None else 1 
batch_tokens = token_counts(FT_API_args)
n_billing_tokens_in_dataset = sum(batch_tokens['ntokens'])

The Fine-Tuning API internally ingests the dataset and tokenizes it with the selected tokenizer.
The output dataset is a collection of samples, each of which is a sequence of token IDs represented as integers.
We generate a histogram of the per-sample token counts in the dataset.
The visualization helps identify patterns, outliers, and central tendencies in the token distribution.
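To complement the histogram, the same per-sample token counts can be condensed into a few numbers. The sketch below uses only the standard library and made-up counts; `summarize_token_counts` is a hypothetical helper, not part of llm-foundry. It also counts samples longer than the context length, without assuming anything about how the FT API treats such samples.

```python
import statistics


def summarize_token_counts(ntokens, n_epochs=1, context_length=None):
    """Summarize per-sample token counts and the total over all epochs."""
    total = sum(ntokens)
    summary = {
        "n_samples": len(ntokens),
        "total_tokens": total,
        "tokens_over_epochs": total * n_epochs,
        "min": min(ntokens),
        "median": statistics.median(ntokens),
        "max": max(ntokens),
    }
    if context_length is not None:
        # Only counts long samples; how they are handled downstream is not assumed here.
        summary["n_over_context"] = sum(1 for n in ntokens if n > context_length)
    return summary


# Made-up token counts for five samples.
toy_counts = [120, 340, 95, 2100, 410]
stats = summarize_token_counts(toy_counts, n_epochs=3, context_length=2048)
for name, value in stats.items():
    print(f"{name}: {value}")
```

With the notebook's variables, the call would look like `summarize_token_counts(batch_tokens['ntokens'], n_epochs, FT_API_args.context_length)`.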

In [0]:
print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be used by the model during training")
print(f"Assume you'll train for {n_epochs} epochs on this dataset")
print(f"Then ~{n_epochs * n_billing_tokens_in_dataset} tokens will be running through the model during training")
plot_hist(pd.Series(batch_tokens['ntokens']))

# Continued Pretrain

#### User Defines

In [0]:
FT_API_args = Namespace(
    model='EleutherAI/gpt-neox-20b',
    train_data_path= '/Volumes/main/mosaic_hackathon/managed-volume/ABT',
    task_type='CONTINUED_PRETRAIN',
    training_duration=3,
    context_length=2048,
)
temporary_mds_output_path = '/Volumes/main/mosaic_hackathon/managed-volume/mds_data_11Jan24_5'
# temporary_mds_output_path = '/tmp/CPT/mds_data_11Jan24_4'

In [0]:
!rm -rf {temporary_mds_output_path}

#### Ingestion, Tokenization and Materialization

CPT takes a folder of txt files as input. It tokenizes the text and materializes the result as a streaming dataset in MDS format.

FT API uses [llmfoundry/scripts/data_prep/convert_text_to_mds.py](https://github.com/mosaicml/llm-foundry/blob/main/scripts/data_prep/convert_text_to_mds.py) to download all the txt files and convert them to MDS. 

In this notebook, we provide two additional approaches via Spark and Dask. 

**Warning:** CPT datasets are normally much larger than IFT datasets, so tokenization and materialization can be very time-consuming.
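Before running the full conversion, a back-of-the-envelope estimate of the CPT token count can be useful. The sketch below assumes that the tokenized documents are concatenated into one stream and cut into full samples of `concat_tokens` tokens, with any leftover tokens at the end dropped; that behavior is an assumption for illustration, not a statement about `convert_text_to_mds`. The helper `count_concat_samples` and the document lengths are hypothetical. The last step, samples times `concat_tokens`, matches how this notebook computes the billed token count.

```python
def count_concat_samples(doc_token_counts, concat_tokens, extra_tokens_per_doc=0):
    """Estimate (n_samples, billing_tokens) for concatenated fixed-length samples.

    Assumption: all documents form one token stream that is cut into full
    samples of `concat_tokens` tokens; a trailing partial sample is dropped.
    `extra_tokens_per_doc` models optional BOS/EOS tokens added per document.
    """
    stream_length = sum(n + extra_tokens_per_doc for n in doc_token_counts)
    n_samples = stream_length // concat_tokens
    return n_samples, n_samples * concat_tokens


# Made-up token counts for three text files.
doc_lengths = [3000, 1500, 700]
n_samples_est, billing_tokens_est = count_concat_samples(doc_lengths, concat_tokens=2048)
print(f"~{n_samples_est} samples, ~{billing_tokens_est} tokens per epoch")
```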

In [0]:
cfg, tokenizer = create_om_cfg(FT_API_args)

input_folder = FT_API_args.train_data_path
output_folder = temporary_mds_output_path
concat_tokens = FT_API_args.context_length
tokenizer_name = FT_API_args.model

# Run convert_text_to_mds.py and write the MDS dataset to output_folder
args = parse_args(tokenizer_name, concat_tokens, output_folder, input_folder)

args.processes = 2
args.reprocess = True

n_samples = convert_text_to_mds(
    tokenizer_name=args.tokenizer,
    output_folder=args.output_folder,
    input_folder=args.input_folder,
    concat_tokens=args.concat_tokens,
    eos_text=args.eos_text,
    bos_text=args.bos_text,
    no_wrap=args.no_wrap,
    compression=args.compression,
    processes=args.processes,
    reprocess=args.reprocess,
    args_str=_args_str(args),
)

n_billing_tokens_in_dataset = n_samples * concat_tokens

#### Token Estimation

In [0]:
MAX_TOKENS_PER_EXAMPLE = FT_API_args.context_length if FT_API_args.context_length is not None else 4096
TARGET_EPOCHS = FT_API_args.training_duration if FT_API_args.training_duration is not None else 1 
n_epochs = TARGET_EPOCHS

print(f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training")
print(f"By default, you'll train for {n_epochs} epochs on this dataset")
print(f"By default, ~{n_epochs * n_billing_tokens_in_dataset} tokens will be used in training")