# Fine-Tuning an LLM with Huggingface and WebDataset

This notebook illustrates the use of WebDataset together with Huggingface for fine-tuning large language models.

Some features of note:

- training data is loaded directly from Huggingface
- data is downloaded incrementally, as needed, and cached locally
- a custom sampler that respects shard boundaries makes remote data access more efficient

In [None]:
# Run parameters -- edit these to change the experiment.
base_model = "google/flan-t5-base"  # pretrained seq2seq checkpoint that gets a LoRA adapter
# JSON shard index consumed by wids.ShardListDataset
dataset_url = "https://huggingface.co/tmbdev/d-tokens/resolve/main/d-tokens.json?download=true"
cache_dir = "./_cache"  # local cache directory passed to wids.ShardListDataset
batch_size = 1
max_steps = 10_000  # passed to TrainingArguments; a positive value overrides `epochs`
epochs = 1
learning_rate = 3e-4

In [None]:
# imports
import string
import random
import numpy as np
import regex
import unicodedata
import logging

import torch.utils.data
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import AutoModelForSeq2SeqLM
from transformers.adapters import LoRAConfig
from transformers import TrainingArguments, AdapterTrainer, TrainerCallback

# Workaround for running this notebook inside the source tree; you usually don't
# need this. If `wids` is not importable, retry with the parent directory on the
# module search path.
try:
    import wids
except ImportError:
    # `sys` is not imported by the imports cell, so import it here.
    import sys

    sys.path.append("..")
    import wids

In [None]:
def normalize_string(s):
    """Return a normalized copy of the string ``s``.

    Typographic variants of punctuation (curly/low quotes, assorted dashes) are
    mapped to plain ASCII forms so the model does not have to learn them
    separately. Runs of punctuation used for graphical effect are collapsed to
    their first character, characters outside a small allowed set become spaces,
    runs of spaces/tabs are squeezed, and surrounding whitespace is stripped.
    """
    # Unicode compatibility normalization first, so the patterns below see
    # canonical code points.
    text = unicodedata.normalize("NFKC", s)
    # (pattern, replacement) pairs, applied in this order.
    rewrite_rules = (
        # drop "*", "@" and backtick ("*" is also degrade()'s default error marker)
        (r"[*@`]", ""),
        # single-quote variants -> '
        (r"[\u0027\u2019\u2018\u201A\u201B]", "'"),
        # double-quote variants -> "
        (r"[\u0022\u201C\u201D\u201E\u201F]", '"'),
        # dash variants -> -
        (r"[\u2013\u2014\u2012\u2015-]", "-"),
        # any run of punctuation keeps only its first character
        (r"(\p{P})\p{P}+", r"\1"),
        # anything not a letter, digit, separator, newline or allowed punctuation -> space
        (r"[^\p{L}\p{N}\p{Z}().,?!:;'\"\n-]+", " "),
        # squeeze runs of spaces/tabs
        (r"[ \t]+", " "),
    )
    for pattern, replacement in rewrite_rules:
        text = regex.sub(pattern, replacement, text)
    return text.strip()

In [None]:
# Data augmentation: more precisely, we generate a synthetic training sample from
# a clean input string.

# Candidate characters that `degrade` substitutes for "corrupted" characters:
# ASCII letters, digits, space and ASCII punctuation, excluding the "*" error
# marker. The list is sorted so its order -- and therefore the character that
# `rng.choice(replacements)` picks for a given seed -- does not depend on string
# hash randomization (PYTHONHASHSEED), which governs set iteration order.
replacements = sorted(
    set(string.ascii_letters + string.digits + " " + string.punctuation) - {"*"}
)


def degrade(s, prange=(0.05, 0.1), seed=None, special="*"):
    """Generate a synthetic training pair by degrading a clean string.

    Our model is a sequence-to-sequence model that identifies the location of OCR
    errors in a text string. It is trained on pairs of strings:

    - ``result``: the degraded text. Each character of the normalized input
      (except newlines) is replaced, with probability ``p`` drawn uniformly from
      ``prange``, by a random character from ``replacements``. The result is then
      passed through ``normalize_string`` again, so its length may differ from
      ``target``.
    - ``target``: the normalized input with every replaced character shown as
      ``special``.

    The model is trained to predict the location of the markers.

    Args:
        s: clean input text.
        prange: (low, high) bounds for the per-character error probability.
        seed: seed for the private RNG; if None, one is drawn from the global
            ``random`` module.
        special: marker character used in ``target``. Occurrences of it in the
            normalized input are skipped.

    Returns:
        ``(result, target)``. If the normalized input is shorter than two
        characters, it is returned unchanged as both elements.
    """
    seed = random.randint(0, 1000000) if seed is None else seed
    rng = random.Random(seed)
    s = normalize_string(s)
    if len(s) < 2:
        return s, s
    # With probability 0.5, delete the first k words (k random between 1 and 4).
    # We do this because otherwise the model will flag lower case letters at the
    # beginning of a string as errors. A deletion is only kept if it leaves at
    # least two characters; otherwise we retry starting from the untouched string.
    for _ in range(100):
        candidate = s
        if rng.random() < 0.5:
            k = rng.randint(1, 4)
            expr = r"^([^\p{Z}]+?\p{Z}+){%d}" % k
            candidate = regex.sub(expr, "", s, count=1)
        if len(candidate) > 1:
            s = candidate
            break
    p = rng.uniform(*prange)
    degraded = []
    marked = []
    for c in s:
        if c == special:
            continue
        if c != "\n" and rng.random() < p:
            # NOTE(review): the replacement may coincidentally equal `c`, which
            # marks an unchanged character; `replacements` also only excludes
            # "*", not a custom `special`.
            degraded.append(rng.choice(replacements))
            marked.append(special)
        else:
            degraded.append(c)
            marked.append(c)
    result = normalize_string("".join(degraded))
    target = "".join(marked)
    return result, target


degrade("Hello, world's biggest ball-of-yarn!")

In [None]:
# Flan-T5 serves as the base model here; other models might work better.
# The tokenizer is loaded from the same `base_model` checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=base_model)

In [None]:
# This is a helper function that takes a sample, unpacks it, applies the degradation,
# and then returns a dictionary with the input_ids, attention_mask and labels as
# expected by Huggingface seq2seq models.


def make_sample(
    sample, *, prange=(0.05, 0.1), seed=None, prefix="ocr-errors: ", max_length=512
):
    """Given a sample consisting of a clean text string, generate a training sample.

    Args:
        sample: a sample from the webdataset; its ".txt.gz" entry is passed to
            `normalize_string`, so it is presumably already decoded to str by wids
        prange: range of error probability, passed to `degrade`
        seed: random seed or None for random seed
        prefix: prefix (prompt) to add to the input string
        max_length: token length to which inputs and labels are truncated/padded

    Returns:
        dict with 1-D integer tensors `input_ids`, `attention_mask` and `labels`,
        each exactly `max_length` tokens long.
    """
    clean = normalize_string(sample[".txt.gz"])
    text, target = degrade(clean, prange=prange, seed=seed)
    encode_kwargs = dict(max_length=max_length, truncation=True, padding="max_length")
    source = tokenizer(prefix + text, **encode_kwargs)
    target_ids = tokenizer.encode(target, **encode_kwargs)
    # NOTE(review): padded label positions hold the pad token id, so they count
    # towards the loss (Huggingface losses ignore only -100). Switching to -100
    # also requires the example-generation cell to map -100 back before decoding.
    return dict(
        input_ids=torch.tensor(source["input_ids"]),
        # 0 marks padding so the encoder does not attend to it during training.
        attention_mask=torch.tensor(source["attention_mask"]),
        labels=torch.tensor(target_ids),
    )

In [None]:
# The WebDataset-specific part of the notebook:
# - a URL for the JSON shard index file,
# - a local cache directory,
# - a wids.ShardListDataset created with keep=True,
# - make_sample registered as the per-sample transform,
# - a custom sampler that respects shard boundaries.

dataset = wids.ShardListDataset(
    dataset_url,
    cache_dir=cache_dir,
    cache_size=10**10,  # same value as int(1e10); units per wids -- TODO confirm
    keep=True,
)
dataset.add_transform(make_sample)

# Smoke test: fetch one transformed sample. Its value is not displayed (it is not
# the cell's last expression), but accessing it presumably downloads the shard
# that holds it.
dataset[999]

# Shard-aware sampler used by the trainer.
sampler = wids.ShardedSampler(dataset)

In [None]:
# This plot illustrates the behavior of the shard sampler: it generates a sequence
# of samples from each shard in turn, and then moves on to the next shard.

import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
plt.subplot(121)
plt.plot(list(sampler)[:10000])
plt.subplot(122)
plt.plot(list(sampler)[:500]);

In [None]:
# Standard Huggingface LoRA setup, using the adapter API from `transformers.adapters`.

# start with the pretrained base model (same checkpoint as the tokenizer)
model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

# set the parameters for LoRA: rank r=8 and scaling factor alpha=16
config = LoRAConfig(
    r=8,
    alpha=16,
    # also apply LoRA to the feed-forward intermediate and output layers;
    # NOTE(review): attention-layer coverage depends on LoRAConfig defaults -- confirm
    intermediate_lora=True,
    output_lora=True,
)

# make a new adapter named "xerr" with the LoRA config above
model.add_adapter("xerr", config=config)
# enable the adapter for training (per adapter-transformers docs, this freezes the
# base model weights so only the adapter is trained -- confirm for installed version)
model.train_adapter("xerr")
# explicitly mark the adapter as active so it is used in forward passes and generation
model.set_active_adapters(["xerr"])

In [None]:
# Standard Huggingface adapter training, except for the custom sampler.

training_args = TrainingArguments(
    learning_rate=learning_rate,
    # NOTE: a positive max_steps (set below) overrides num_train_epochs.
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    logging_steps=2000,
    save_steps=5000,
    output_dir="./training_output",
    overwrite_output_dir=True,
    # pass every key produced by make_sample through to the model unchanged
    remove_unused_columns=False,
    max_steps=max_steps,
)

# create the trainer
trainer = AdapterTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# To use the shard-aware sampler, override the private `_get_train_sampler` method
# on this instance; Huggingface doesn't provide a public hook for this. Any
# arguments are accepted and ignored because some transformers versions pass the
# training dataset to this method.
trainer._get_train_sampler = lambda *args, **kwargs: sampler

In [None]:
# Run the bulk of the training. The run length and checkpointing come from the
# trainer's TrainingArguments (max_steps, save_steps, output_dir="./training_output").

trainer.train()

In [None]:
# Show some examples. This isn't really "validation": samples come from the
# training dataset, with a fresh random degradation on each access.

num_validation = 10
validation_dataset = dataset

logging.getLogger("transformers").setLevel(logging.ERROR)

# Switch off dropout for generation; after `trainer.train()` the model may still be
# in training mode.
model.eval()
# Use the device the model lives on rather than hard-coding CUDA device 0, so this
# cell also runs on CPU-only machines.
device = model.device

with torch.no_grad():
    for i in range(num_validation):
        # load the input and label (note: we get a different degradation each time)
        sample = validation_dataset[i]
        # add a batch dimension and move the tensors to the model's device
        input_ids = sample["input_ids"].unsqueeze(0).to(device)
        label_ids = sample["labels"].unsqueeze(0).to(device)
        # use the model to generate the output
        output = model.generate(input_ids, max_length=1024)
        # convert the tokens to text, dropping the task prefix from the input
        input_text = (
            tokenizer.decode(input_ids[0], skip_special_tokens=True)
            .replace("ocr-errors:", "")
            .strip()
        )
        output_text = tokenizer.decode(output[0], skip_special_tokens=True).strip()
        label_text = tokenizer.decode(label_ids[0], skip_special_tokens=True).strip()

        print(f"[{i}]")
        print("Input: ", input_text)
        print("Output:", output_text)
        print("Label: ", label_text)
        print("---")