First, make sure [🤗 Optimum Graphcore](https://github.com/huggingface/optimum-graphcore) is installed in your environment. The cell below pins a 0.5.x release:

In [None]:
%pip install "optimum-graphcore>=0.5, <0.6"

Let's print out the versions of Transformers and Optimum Graphcore:

In [None]:
import transformers
import optimum.graphcore

print(transformers.__version__)
print(optimum.graphcore.__version__)

Values for machine size and cache directories can be configured through environment variables or directly in the notebook:

In [None]:
import os

pod_type = os.getenv("GRAPHCORE_POD_TYPE", "pod4")
executable_cache_dir = os.getenv("POPLAR_EXECUTABLE_CACHE_DIR", "/tmp/exe_cache/") + "/external_model"

# Train an external language model

In this notebook, we'll see how to train a model that is not supported by Optimum Graphcore and not even in [🤗 Transformers](https://github.com/huggingface/transformers) on a language modeling task.

We will see how to load and preprocess the dataset for this task, and how to use the `IPUTrainer` API to train a model on it.

This notebook assumes you have trained a tokenizer on the corpus you are using; see the [How to train a tokenizer](https://github.com/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb) notebook.

## Preparing the dataset

We will use the Wikitext 2 dataset as an example. You can load it with the 🤗 Datasets library.

In [None]:
from datasets import load_dataset
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')

## Causal Language modeling

To tokenize all our texts with the same vocabulary that was used when training the model, we could download a pretrained tokenizer. Though we plan to define our own model, here we borrow GPT2's tokenizer. This is all done by the `AutoTokenizer` class:

In [None]:
from transformers import AutoTokenizer
    
tokenizer_checkpoint = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

We can now call the tokenizer on all our texts. This is very simple, using the [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map) method from the Datasets library. First we define a function that calls the tokenizer on our texts:

In [None]:
def tokenize_function(examples):
    return tokenizer(examples["text"])

Then we apply it to all the splits in our `datasets` object.

In [None]:
tokenized_datasets = datasets.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])

Then we set the block size, which is the maximum sequence length our model will be trained with.

In [None]:
block_size = 128

Then we write the preprocessing function that will group our texts:

In [None]:
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder. We could add padding instead if the model supported it;
    # you can customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
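
To make the effect of the grouping step concrete, here is a minimal sketch that runs the same logic on made-up token IDs. The function name `group_texts_demo`, the explicit `block_size` argument and the toy batch are assumptions introduced for illustration only; they are not part of the notebook's pipeline.

```python
def group_texts_demo(examples, block_size):
    # Same logic as group_texts, with block_size passed in explicitly.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # Drop the remainder that does not fill a whole block.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


# Made-up token IDs: three "texts" with 3, 4 and 3 tokens (10 in total).
toy_batch = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10]]}
grouped = group_texts_demo(toy_batch, block_size=4)
print(grouped["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
print(grouped["labels"])     # same content as input_ids
```

The boundaries between the original texts disappear, and the last two tokens (9 and 10) are dropped because they do not fill a complete block.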

Again we apply it to all the splits, this time in our `tokenized_datasets` object.

In [None]:
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)

Let's define a custom model, which is just a simple implementation of GPT2. Note that there is nothing IPU-specific or 🤗 Transformers-related in this model.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerModel(nn.Module):

    def __init__(self, block_size, vocab_size, d_model, nhead, dim_feedforward, nlayers, dropout=0.1, embd_pdrop=0.1):
        super(TransformerModel, self).__init__()
        self.block_size = block_size
        self.word_embeddings = nn.Embedding(vocab_size, d_model)
        self.position_embeddings = nn.Embedding(block_size, d_model)
        self.drop = nn.Dropout(embd_pdrop)
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layer, nlayers)
        self.lm_head = nn.Linear(d_model, vocab_size)

        self.init_weights()
        self.tie_weights(self.lm_head, self.word_embeddings)


    def tie_weights(self, output_embeddings, input_embeddings):
        output_embeddings.weight = input_embeddings.weight
        output_embeddings.bias.data = nn.functional.pad(
            output_embeddings.bias.data,
            (
                0,
                output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
            ),
            "constant",
            0,
        )
        output_embeddings.out_features = input_embeddings.num_embeddings

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, -10000.0).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.word_embeddings.weight, -initrange, initrange)
        nn.init.uniform_(self.position_embeddings.weight, -initrange, initrange)

    def forward(self, input_ids, attention_mask=None, labels=None):
        device = input_ids.device
        input_shape = input_ids.size()

        mask = self._generate_square_subsequent_mask(self.block_size).to(device)

        inputs_embeds = self.word_embeddings(input_ids)
        position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
        position_embeds = self.position_embeddings(position_ids)
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.drop(hidden_states)

        hidden_states = self.transformer_encoder(hidden_states, mask)
        lm_logits = self.lm_head(hidden_states)

        return lm_logits
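
The additive mask built by `_generate_square_subsequent_mask` can be illustrated without PyTorch. The sketch below uses a hypothetical helper, `causal_mask`, that reproduces the same pattern with nested lists: 0.0 where a position may attend (itself and earlier positions) and -10000.0 for later positions.

```python
def causal_mask(size, masked_value=-10000.0):
    # Row i is the query position, column j is the key position.
    # Position i may attend to j only when j <= i.
    return [
        [0.0 if j <= i else masked_value for j in range(size)]
        for i in range(size)
    ]


for row in causal_mask(4):
    print(row)
# [0.0, -10000.0, -10000.0, -10000.0]
# [0.0, 0.0, -10000.0, -10000.0]
# [0.0, 0.0, 0.0, -10000.0]
# [0.0, 0.0, 0.0, 0.0]
```

Adding this mask to the attention scores before the softmax pushes the weight of future positions close to zero, which is what makes the encoder stack behave like a causal (GPT-style) decoder.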

We then subclass the model so that it also inherits from `PipelineMixin`, which gives it the `parallelize` and `deparallelize` methods. Here we override `parallelize` to customize how the model is placed on the IPUs. If the model is small and needs no custom optimization, there is no need to override `parallelize` at all. The optimizations we apply here and later are for demonstration only; some of them are not actually necessary for such a relatively small model with a `block_size` of 128.

We also override the `forward` method. An external model usually just returns logits, but we need to respect the return format of 🤗 Transformers, so the loss is returned when labels are provided.

In [None]:
import poptorch
from optimum.graphcore.modeling_utils import PipelineMixin, get_layer_ipu, recomputation_checkpoint, register, tied_weight_model
from optimum.utils import logging
logger = logging.get_logger(__name__)


@tied_weight_model()
class IPUTransformerModel(TransformerModel, PipelineMixin):
    def parallelize(self):
        super().parallelize()
        logger.info("---------- Device Allocation -----------")
        logger.info("Embedding  --> IPU 0")
        self.word_embeddings = poptorch.BeginBlock(self.word_embeddings, "word_embeddings", ipu_id=0)
        self.position_embeddings = poptorch.BeginBlock(self.position_embeddings, "position_embeddings", ipu_id=0)

        layer_ipu = get_layer_ipu(self.ipu_config.layers_per_ipu, self.transformer_encoder.layers)
        for index, layer in enumerate(self.transformer_encoder.layers):
            if self.ipu_config.recompute_checkpoint_every_layer:
                # Put checkpoints on every encoder layer
                h = recomputation_checkpoint(layer)
                self._hooks.append(h)
            ipu = layer_ipu[index]
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")
            self.transformer_encoder.layers[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)

        logger.info("Head       --> IPU 0")
        logger.info("---------------------------------------")
        self.lm_head = poptorch.BeginBlock(self.lm_head, "lm_head", ipu_id=0)
        return self

    def forward(self, input_ids, attention_mask=None, labels=None):
        lm_logits = super().forward(input_ids, attention_mask=attention_mask, labels=labels)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n. Use roll() + ignore_index instead of slicing for better efficiency on IPUs.
            labels = torch.roll(labels, -1, 1)
            # By default the ignore_index of CrossEntropyLoss is -100
            labels[:, -1] = -100
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        output = (lm_logits,)
        return (loss,) if loss is not None else output
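
The roll-and-ignore trick in `forward` can be checked on a toy sequence. The sketch below uses plain Python lists rather than tensors, and the helper names are hypothetical; it shows that rolling the labels left by one and masking the last position yields the same (position, target) pairs as the usual slicing approach.

```python
IGNORE_INDEX = -100  # default ignore_index of nn.CrossEntropyLoss


def shift_by_roll(labels):
    # Rotate left by one, then mask the wrapped-around last position.
    rolled = labels[1:] + labels[:1]
    rolled[-1] = IGNORE_INDEX
    return rolled


def pairs_by_slicing(positions, labels):
    # The usual formulation: drop the last logit and the first label.
    return list(zip(positions[:-1], labels[1:]))


labels = [10, 11, 12, 13]
positions = [0, 1, 2, 3]

rolled = shift_by_roll(labels)
pairs_roll = [(p, t) for p, t in zip(positions, rolled) if t != IGNORE_INDEX]

print(rolled)      # [11, 12, 13, -100]
print(pairs_roll)  # [(0, 11), (1, 12), (2, 13)]
```

Both formulations train position `n` to predict token `n + 1`; the rolled version keeps the tensor shapes unchanged, which is the efficiency argument made in the code comment.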

Let's instantiate the model.

In [None]:
model = IPUTransformerModel(
    block_size=block_size,
    vocab_size=tokenizer.vocab_size,
    d_model=768,
    nhead=12,
    dim_feedforward=768 * 4,
    nlayers=12,
)

To instantiate an `IPUTrainer`, we first define the `IPUConfig`, which is a class that specifies attributes and configuration parameters to compile and put the model on the device. We usually initialize it with one config name or a path to a JSON file. We could also initialize it from a dict as we are doing here:

In [None]:
from optimum.graphcore import IPUConfig, IPUTrainer, IPUTrainingArguments

# ipu_config = IPUConfig.from_pretrained("ipu_config.json")
ipu_config_dict = {
    "embedding_serialization_factor": 2,
    "recompute_checkpoint_every_layer": True,
    "optimizer_state_offchip": True,
    "replicated_tensor_sharding": True,
    "enable_half_partials": True,
    "device_iterations": 1,      
    "inference_device_iterations": 5,
    "replication_factor": {"pod4": 1, "pod8": 2, "pod16": 4, "pod32": 8, "pod64": 16, "default": 1},
    "inference_replication_factor": {"pod4": 1, "pod8": 2, "pod16": 4, "pod32": 8, "pod64": 16, "default": 1},
    "gradient_accumulation_steps": 512,
    "executable_cache_dir": executable_cache_dir,
    "ipus_per_replica": 4,
    "layers_per_ipu": [0, 4, 4, 4],
    "matmul_proportion": [0.25, 0.25, 0.25, 0.25],
}
ipu_config = IPUConfig.from_dict(ipu_config_dict)
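
To see how `layers_per_ipu` relates to the device allocation logged in `parallelize`, here is a plain-Python sketch of the kind of mapping we would expect `get_layer_ipu` to produce. The helper `layer_to_ipu` is hypothetical and written for illustration; it is not the library's implementation.

```python
def layer_to_ipu(layers_per_ipu):
    # Expand a per-IPU layer count into one IPU index per layer.
    mapping = []
    for ipu, n_layers in enumerate(layers_per_ipu):
        mapping += [ipu] * n_layers
    return mapping


mapping = layer_to_ipu([0, 4, 4, 4])
print(mapping)  # [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]
```

With this configuration, IPU 0 holds no encoder layers, which leaves it free for the embeddings and the language modeling head, while the 12 encoder layers are spread evenly over IPUs 1 to 3.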

The other thing we need to define is the `IPUTrainingArguments`, which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model, and all other arguments are optional:

In [None]:
micro_batch_size = 1
gradient_accumulation_steps = 64

training_args = IPUTrainingArguments(
    "mymodel-wikitext2",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=micro_batch_size,
    per_device_eval_batch_size=micro_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    pod_type=pod_type,
    num_train_epochs=10,
    loss_scaling=16384,
    warmup_ratio=0.1,
    dataloader_drop_last=True,
    dataloader_num_workers=64,
    logging_steps=10,
)
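
As a rough sanity check on these settings, the number of samples contributing to one weight update is commonly computed as micro batch size × gradient accumulation steps × replication factor. The sketch below assumes that the `gradient_accumulation_steps` value passed to `IPUTrainingArguments` (64) is the one in effect; the `IPUConfig` dictionary lists 512, and which value takes precedence may depend on the library version.

```python
def global_batch_size(micro_batch_size, gradient_accumulation_steps, replication_factor):
    # Samples contributing to a single weight update.
    return micro_batch_size * gradient_accumulation_steps * replication_factor


# Values from the cells above, for the default "pod4" setting.
samples_per_update = global_batch_size(
    micro_batch_size=1, gradient_accumulation_steps=64, replication_factor=1
)
tokens_per_update = samples_per_update * 128  # block_size = 128
print(samples_per_update, tokens_per_update)  # 64 8192
```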

Finally, we pass along all of those to the `IPUTrainer` class:

In [None]:
trainer = IPUTrainer(
    model=model,
    ipu_config=ipu_config,
    args=training_args,
    train_dataset=lm_datasets["train"],
    eval_dataset=lm_datasets["validation"],
)

And we can train our model:

In [None]:
trainer.train()

Once the training is completed, we can evaluate our model and get its perplexity on the validation set like this:

In [None]:
import math
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
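
Perplexity is the exponential of the mean cross-entropy loss, so a useful reference point is a model that predicts every token uniformly at random: its perplexity equals the vocabulary size. The short sketch below illustrates this with the GPT-2 vocabulary size; the `perplexity` helper is a hypothetical name used only here.

```python
import math


def perplexity(mean_cross_entropy):
    # Perplexity is exp of the mean per-token cross-entropy (in nats).
    return math.exp(mean_cross_entropy)


vocab_size = 50257  # size of the GPT-2 vocabulary
uniform_loss = math.log(vocab_size)  # loss of a uniform prediction
print(f"{uniform_loss:.2f}")            # 10.82
print(round(perplexity(uniform_loss)))  # 50257
```

A trained model should report a perplexity well below this uniform baseline.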

The perplexity is still quite high because, for this demo, we trained on a small dataset for a small number of epochs. For real language model training, you would need a larger dataset and more epochs.

If you want to resume training from a checkpoint, pass its path to `train`:

In [None]:
trainer.train(resume_from_checkpoint='mymodel-wikitext2/checkpoint-500')
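
If you do not know the step number in advance, you can look for the most recent checkpoint folder. The sketch below uses a hypothetical helper, `latest_checkpoint`, and assumes checkpoints are saved as `checkpoint-<step>` subfolders of the output directory, as in the path above. It is demonstrated on a temporary directory rather than on real checkpoints.

```python
import os
import re
import tempfile


def latest_checkpoint(output_dir):
    # Return the checkpoint-<step> folder with the highest step, or None.
    pattern = re.compile(r"^checkpoint-(\d+)$")
    steps = []
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")


# Demonstration on empty placeholder folders.
with tempfile.TemporaryDirectory() as tmp:
    for step in (500, 1000, 1500):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    latest_name = os.path.basename(latest_checkpoint(tmp))

print(latest_name)  # checkpoint-1500
```

Steps are compared as integers, so `checkpoint-1500` is correctly ranked above `checkpoint-500`, which a plain string sort would get wrong. The returned path could then be passed as `resume_from_checkpoint`.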