In this notebook, all the necessary offline steps are executed for MobileBERT for a masked language modeling task. The necessary training artifacts are generated, as well as the processed data.

## Library Imports

In [None]:
import torch
import transformers
import onnx
import onnxruntime.training.onnxblock as onnxblock
from datasets import load_dataset
import json
import random
import re
from transformers import MobileBertConfig

## Generating artifacts

This section generates the necessary training artifacts: the training version of the ONNX model, the evaluation version of the ONNX model, and the optimizer.

These are exported as ONNX files and later imported in the C# app using the C# ONNX Runtime Training API.

To generate the training ONNX model, the exported graph must contain a loss node. MobileBERT for Masked LM already computes the loss when labels are provided, so FlatModel is a workaround that passes the labels as an input in the forward pass of the model. The wrapped model is then exported, and ORT generate_artifacts is used to generate the training artifacts.

The ORT generate_artifacts method must be passed a model that already contains a loss node. The original torch model is referenced to determine which model parameters should be frozen and which model parameters require gradients.

In [None]:
# Configuration that overrides the number of hidden layers to 2.
config = MobileBertConfig(num_hidden_layers=2)
# Load the masked-LM model from the pretrained checkpoint using the custom configuration.
model = transformers.MobileBertForMaskedLM.from_pretrained('google/mobilebert-uncased', config=config)
# model = transformers.AutoModel.from_pretrained('google/mobilebert-uncased')
# Base name of the exported ONNX file (used when building the export path).
model_name = 'mobilebert-uncased'

In [None]:
# Sample masked sentence and its unmasked counterpart; these tensors serve as the example inputs for the ONNX export.
tokenizer = transformers.AutoTokenizer.from_pretrained("google/mobilebert-uncased")
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
# Keep the label token id only where the input holds the [MASK] token; every other position is set to -100.
labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)

In [None]:
class FlatModel(torch.nn.Module):
    """
    Wraps a model so that the labels are supplied as a positional forward input.

    The wrapped model receives the labels through its `labels` keyword argument.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, *local_inputs):
        """
        Runs the wrapped model on the positional inputs.

        local_inputs must be exactly four values, in this order:
        input_ids, attention_mask, token_type_ids, labels

        Returns whatever the wrapped model returns for these inputs.

        Raises ValueError if a different number of inputs is passed.
        """
        # use the values that were actually passed in, not notebook-level globals,
        # so the module does not depend on one fixed sample
        input_ids, attention_mask, token_type_ids, labels = local_inputs
        return self.model(input_ids, attention_mask, token_type_ids, labels=labels)

# Rebind `model` to the wrapper so that the wrapped module is the one handed to torch.onnx.export.
model = FlatModel(model)

In [None]:
# Export the wrapped model in training mode; the sample tensors are the example inputs.
torch.onnx.export(
    model,
    (
        inputs["input_ids"],
        inputs["attention_mask"],
        inputs["token_type_ids"],
        labels,
    ),
    f"training_artifacts/{model_name}.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids", "labels"],
    output_names=["loss", "logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "token_type_ids": {0: "batch_size", 1: "sequence_length"},
        "labels": {0: "batch_size", 1: "sequence_length"},
        # the key must match the output name exactly; a trailing space would
        # leave the logits output without dynamic axes
        "logits": {0: "batch_size", 1: "sequence_length"},
    },
    export_params=True,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
)

In [None]:
from onnxruntime.training import artifacts
import onnx

# Split the torch parameter names by whether they are trainable.
requires_grad = [name for name, param in model.named_parameters() if param.requires_grad]
frozen_params = [name for name, param in model.named_parameters() if not param.requires_grad]
# Buffer names are always added to the frozen list, after the frozen parameters.
frozen_params.extend(name for name, _ in model.named_buffers())

# From here on `model` refers to the exported ONNX model rather than the torch module.
model = onnx.load(f"training_artifacts/{model_name}.onnx")

artifacts.generate_artifacts(
    model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts/",
)


## Generating tokens

This section tokenizes the dataset and then writes it into JSON files.

Since a C# tokenizer for this task doesn't exist yet in ORT extensions, this step is done offline, and the data is then loaded into the C# program.

Since the model requires a masked input and a "label" (the corresponding unmasked input), these are artificially generated by randomly masking a word in each sequence. The label keeps the original token at each masked position and is set to -100 everywhere else.

A JSON format was chosen because it can easily be parsed with C# libraries, but you may use any data format.

In [None]:
def tokenize_function(examples, pad_to_len):
    """
    Tokenizes a batch of sentences and builds the masked inputs and their labels.

    Takes in a batch with a "text" feature (a list of strings), as well as an int for what to pad the sequences to.

    Blank sentences (empty or whitespace-only) are dropped, so the output may have fewer rows than the input batch.

    The sequences are both padded and truncated so they are all the same length.

    Outputs a batch with the following features: text, input_ids, token_type_ids, attention_mask, labels
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained("google/mobilebert-uncased")
    # filter out blank strings to remove unnecessary processing; a new list is built
    # so the caller's batch is not modified (whitespace-only strings have no word to mask)
    texts = [sent for sent in examples["text"] if sent.strip()]
    unmasked = tokenizer(texts, padding="max_length", max_length=pad_to_len, truncation=True, return_tensors="pt")
    masked_examples = [mask(sent, pad_to_len) for sent in texts]
    inputs = tokenizer(masked_examples, padding="max_length", max_length=pad_to_len, truncation=True, return_tensors="pt")
    # keep the original token id only at [MASK] positions; every other position becomes -100
    inputs["labels"] = torch.where(inputs["input_ids"] == tokenizer.mask_token_id, unmasked["input_ids"], -100)
    # return the filtered text explicitly so that every returned column has the same number of rows
    inputs["text"] = texts
    return inputs

def mask(sent, pad_to_len):
    """
    Randomly replaces a word in the sentence with "[MASK]", ignoring punctuation

    The word is picked among the first pad_to_len whitespace-separated words. Every run of
    letters/apostrophes inside the picked word is replaced, so a word with inner punctuation
    (e.g. "well-known") produces more than one "[MASK]". The words are re-joined with single spaces.

    A sentence without any word (empty or whitespace-only), or a pad_to_len below 1, is returned unchanged.
    """
    sent_words = sent.split()
    num_candidates = min(len(sent_words), pad_to_len)
    if num_candidates <= 0:
        # nothing to mask; random.randint(0, -1) would raise ValueError
        return sent
    mask_index = random.randint(0, num_candidates - 1)
    # replace random index with mask word, leaving punctuation as is
    # ... this preprocessing means that the token masked might be the <unk> word
    sent_words[mask_index] = re.sub(r"[a-zA-Z']+", "[MASK]", sent_words[mask_index])
    return ' '.join(sent_words)

def generate_tokens(corpus, pad_to_len=80):
    """
    Takes in a Dataset with a "text" feature, and optionally the length that every sequence is padded/truncated to.

    pad_to_len defaults to 80, which is shortened for demonstration purposes. It is fixed before
    the batching happens to create consistent sizes in the resulting tensor.

    Returns a Dataset with the following features: text, input_ids, token_type_ids, attention_mask, labels
    """
    return corpus.map(tokenize_function, batched=True, fn_kwargs={"pad_to_len": pad_to_len})

def generate_json_dict(token_dataset):
    """
    Takes in a Dataset with the following features: text, input_ids, token_type_ids, attention_mask, labels

    Basically changes the 2d Python lists into two fields: a shape & a flattened list, for easier conversion to OnnxValues

    Returns a dictionary with the following keys: input_ids, input_ids_shape, token_type_ids, token_type_ids_shape, attention_mask, attention_mask_shape, labels, labels_shape

    A column without any rows is reported with shape [0, 0] and an empty flattened list.
    """
    json_dict = {}
    keys_to_convert = ["input_ids", "token_type_ids", "attention_mask", "labels"]

    for key_name in keys_to_convert:
        # read the column once instead of once per use
        rows = token_dataset[key_name]
        # the row length is taken from the first row; an empty column has no first row
        row_len = len(rows[0]) if len(rows) > 0 else 0
        # add field for the shape of the tensor
        json_dict[key_name + "_shape"] = [len(rows), row_len]
        # flatten list
        json_dict[key_name] = [num for sent in rows for num in sent]

    return json_dict


In [None]:
# Name and configuration passed to load_dataset.
dataset_name = "wikitext" 
dataset_config = "wikitext-2-v1"
# corpus = type DatasetDict with three Datasets: test, train, validation
corpus = load_dataset(dataset_name, dataset_config)

In [None]:
# For each split: tokenize it, then convert the result into the shape + flattened-list dictionary.
test_tokens_dataset = generate_tokens(corpus["test"])
test_tokens = generate_json_dict(test_tokens_dataset)
train_tokens_dataset = generate_tokens(corpus["train"])
train_tokens = generate_json_dict(train_tokens_dataset)
validation_tokens_dataset = generate_tokens(corpus["validation"])
validation_tokens = generate_json_dict(validation_tokens_dataset)

In [None]:
# write all the tokens to a json file
# file_names and token_dicts are paired by position: one output file per split
file_names = ["test_tokens.json", "train_tokens.json", "validation_tokens.json"]
token_dicts = [test_tokens, train_tokens, validation_tokens]

def write_dicts_to_files(file_names, dicts):
    """
    Writes each dictionary as JSON to the file at the same position in file_names.

    file_names and dicts must be 2 lists w/ the same lengths.

    Raises ValueError if the lengths differ; this is checked before any file is written.
    """
    if len(file_names) != len(dicts):
        raise ValueError(
            f"file_names and dicts must have the same length, got {len(file_names)} and {len(dicts)}"
        )
    for file_name, token_dict in zip(file_names, dicts):
        with open(file_name, "w", encoding="utf-8") as json_file:
            json.dump(token_dict, json_file)

# Write one JSON file per split.
write_dicts_to_files(file_names, token_dicts)