<a href="https://colab.research.google.com/github/mnaylor5/pretrain-mega-lm/blob/main/PretrainMegaLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Pretrain BERT-style LM with Mega Architecture

In this notebook, I'll quickly pretrain a masked language model with the encoder layers built using Moving-average Equipped Gated Attention (Mega, see [paper](https://arxiv.org/abs/2209.10655) and [GitHub repo](https://github.com/facebookresearch/mega)). The purpose of this effort is to produce a pretrained language model that can be used to help with the implementation of Mega into Hugging Face's transformers library. I'll reuse the same BPE tokenizer used by GPT2 and RoBERTa and pretrain for a certain number of gradient steps (ideally at least 100k, but depending on GPU limits).

Here's my plan:
* Tokenizer: reuse the RoBERTa tokenizer
* Architecture: the same specs as the `Text` LRA task in the Mega paper (depth of 4)
* Data:
  * Wikitext 103 (`wikitext` with subset `wikitext-103-v1`)
  * Preserve dataset splits where provided, otherwise custom splits
* Training details:
  * Truncate to 256 tokens, return overflowing sequences, and drop observations with fewer than 10 tokens
  * Train for 5 epochs
  * Batch observations based on sequence length to speed up training
  * Linear warmup and decay for learning rate

## Install Dependencies


First, uncomment if you want to run on a TPU:

In [4]:
# !pip install cloud-tpu-client==0.10 torch==1.13.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.13-cp38-cp38-linux_x86_64.whl

Next, the Hugging Face libraries and PyTorch Lightning

In [5]:
!pip install transformers datasets pytorch-lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Now for Mega -- clone and `pip install`

In [6]:
!git clone https://github.com/facebookresearch/mega.git && cd mega && pip install -e .

fatal: destination path 'mega' already exists and is not an empty directory.


## Environment Setup
Load libraries and define details of our process

In [7]:
import torch
from torch import nn
import torch.nn.functional as F
import sys
import numpy as np
from transformers import AutoTokenizer
from datasets import load_dataset, Dataset, DatasetDict
from torch.utils.data import DataLoader, BatchSampler, SequentialSampler, RandomSampler, Sampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR
from transformers import DataCollatorForLanguageModeling
from transformers.optimization import get_linear_schedule_with_warmup
import pytorch_lightning as pl
import random

In [8]:
# import mega code
sys.path.append('./mega')
from argparse import Namespace
from mega.fairseq.modules import MegaEncoderLayer

In [9]:
RUN_ON_TPU = False # set to False if not running on TPU

if RUN_ON_TPU:
  import torch_xla
  import torch_xla.core.xla_model as xm
  NUM_TPU_CORES = len(xm.get_xla_supported_devices())

We'll also silence a tokenizer warning that would be repeated many times during training:

```You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.```

In [10]:
import os
# silence tokenizer warning in huggingface
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

In [11]:
if torch.cuda.is_available():
  device = 'gpu'
  print("Training on GPU")
  !nvidia-smi
elif RUN_ON_TPU:
  device = 'tpu'
  print(f"Attempting to run on TPU with {NUM_TPU_CORES} cores identified")
else:
  device = None
  print("No GPU available")

Training on GPU
Wed Dec 21 14:37:25 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   56C    P0    27W /  70W |      3MiB / 15109MiB |      4%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------

Set seeds

In [12]:
pl.seed_everything(865)

INFO:lightning_lite.utilities.seed:Global seed set to 865


865

Define the tokenizer to be used and pull from the Hugging Face hub

In [13]:
TOKENIZER_NAME = "roberta-base"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

Last thing for now: let's define some constants we'll need to pass to the tokenizer when we're prepping the text, along with our actual batch sizes to be used in the data loaders

In [14]:
MAX_SEQUENCE_LENGTH = 256
RETURN_OVERFLOWING_TOKENS = True

TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 64

## Load and Prepare Data

In [15]:
wikitext = load_dataset('wikitext', 'wikitext-103-v1')
wikitext

Downloading and preparing dataset wikitext/wikitext-103-v1 to /root/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...

Dataset wikitext downloaded and prepared to /root/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1801350
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

Let's see how much of the Wikitext data is empty

In [16]:
# len() of a string counts characters, not words
char_counts = list(map(len, wikitext['train']['text']))

print("Percentiles of character count in Wikitext observations")
length_quantiles = [0, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 1.0]
for percentile, length in zip(length_quantiles, np.quantile(char_counts, length_quantiles)):
  print(f" * {percentile:.1%}: {int(length):,} characters")

Percentiles of character count in Wikitext observations
 * 0.0%: 0 characters
 * 10.0%: 0 characters
 * 25.0%: 0 characters
 * 50.0%: 33 characters
 * 75.0%: 556 characters
 * 90.0%: 921 characters
 * 99.0%: 1,601 characters
 * 100.0%: 7,064 characters


At least 25% of the observations are altogether empty strings. Let's get rid of those and set a minimum of 10 characters to get rid of tiny strings

In [17]:
wikitext_nonempty = wikitext.filter(lambda x: len(x['text']) >= 10)
wikitext_nonempty

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 2881
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1161735
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 2461
    })
})

Now map the tokenizer to our text -- we don't need to return tensors here because the `DataCollatorForLanguageModeling` will handle that for us :) 

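If `RETURN_OVERFLOWING_TOKENS` is set to `True`, this `.map` call will return a dataset with more observations than the original dataset, because any text longer than `MAX_SEQUENCE_LENGTH` tokens is split into multiple sequences instead of being cut off.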

In [18]:
tokenize_text = lambda x: tokenizer(x['text'], 
                                    max_length=MAX_SEQUENCE_LENGTH,
                                    return_overflowing_tokens=RETURN_OVERFLOWING_TOKENS,
                                    return_length=True,
                                    truncation=True)

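# note: this maps over the unfiltered `wikitext` rather than `wikitext_nonempty`;
# the token-length filter applied afterwards removes the empty and tiny rows either way
wikitext_tokenized = wikitext.map(tokenize_text,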
                                  batched=True,
                                  remove_columns=['text'])

In [19]:
print("Percentiles of token count in Wikitext observations")
length_quantiles = [0, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 1.0]
for percentile, length in zip(length_quantiles, np.quantile(wikitext_tokenized['train']['length'], length_quantiles)):
  print(f" * {percentile:.1%}: {int(length):,} tokens")


Percentiles of token count in Wikitext observations
 * 0.0%: 2 tokens
 * 10.0%: 2 tokens
 * 25.0%: 2 tokens
 * 50.0%: 14 tokens
 * 75.0%: 119 tokens
 * 90.0%: 198 tokens
 * 99.0%: 256 tokens
 * 100.0%: 256 tokens
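
The cap at 256 tokens reflects how `return_overflowing_tokens=True` treats long inputs: the text is split into several sequences rather than discarded past the limit. Here is a simplified sketch of that splitting on plain Python lists. `split_into_chunks` is a hypothetical helper, and it ignores the special tokens that the real tokenizer adds to every chunk.

```python
def split_into_chunks(token_ids, max_length):
    """Split one long list of token IDs into consecutive chunks of at most max_length."""
    return [token_ids[i:i + max_length] for i in range(0, len(token_ids), max_length)]

# a pretend document of 600 token IDs
doc = list(range(600))
chunks = split_into_chunks(doc, max_length=256)
print([len(c) for c in chunks])  # [256, 256, 88]
```

One document becomes three rows here, which is why the tokenized dataset can end up with more rows than the raw text.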


We now want to get rid of any observations with fewer than, say, 10 tokens. We'll also sort by sequence length and drop the extra column so it's ready for our collator.

In [20]:
wikitext_tokenized_final = wikitext_tokenized\
  .filter(lambda x: x['length'] >= 10)\
  .shuffle()\
  .sort('length')\
  .remove_columns(['length'])
  
wikitext_tokenized_final

DatasetDict({
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'overflow_to_sample_mapping'],
        num_rows: 2800
    })
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'overflow_to_sample_mapping'],
        num_rows: 1158900
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'overflow_to_sample_mapping'],
        num_rows: 2428
    })
})

Now that the data is sorted in ascending sequence length, we can create a `Sampler` subclass that handles batch selection by grouping observations with similar lengths. This will assign static batch IDs which will be sampled in random order for each epoch. 

Uniform-length batching minimizes the amount of "empty" padded space in each batch, and in this case it speeds up our training loop by 40-50%

In [21]:
class SequentialBatchRandomSampler(Sampler):
    def __init__(self, dataset, batch_size, drop_last=False):
        # not actually saving dataset as an attribute to save memory
        # this creates a list of batched IDs in order of appearance
        # e.g. for a batch size of 3: [[0, 1, 2], [3, 4, 5], ...]
        unordered_sampler = BatchSampler(SequentialSampler(dataset), batch_size=batch_size, drop_last=drop_last)
        self.batches = list(unordered_sampler)
        
        # save some additional attributes
        self.batch_size = batch_size
        self.dataset_length = len(dataset)
        self.drop_last=drop_last

    def __iter__(self):
        '''
        Every time we loop over the sampler, shuffle the batches in place (once)
        and yield each set of IDs as we iterate.
        '''
        random.shuffle(self.batches)
        for batch in self.batches:
            yield batch 
            
    def __len__(self) -> int:
        '''
        Return the number of batches to be returned by this sampler, dependent upon
        whether we want to drop the last (incomplete) batch
        '''
        if self.drop_last:
            return self.dataset_length // self.batch_size
        else:
            return (self.dataset_length + self.batch_size - 1) // self.batch_size
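
To get a feel for why length-grouped batches help, here is a rough sketch that measures how much of each padded batch is padding, with and without sorting by length. The sequence lengths are synthetic and `padded_fraction` is a hypothetical helper, so the numbers only illustrate the effect; they are not measurements from this notebook.

```python
import random

def padded_fraction(lengths, batch_size):
    """Fraction of positions that are padding when each batch is padded to its longest member."""
    total_slots = 0
    real_tokens = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total_slots += max(batch) * len(batch)
        real_tokens += sum(batch)
    return 1 - real_tokens / total_slots

rng = random.Random(0)
# synthetic sequence lengths between 10 and 256 tokens (made-up data)
lengths = [rng.randint(10, 256) for _ in range(10_000)]

random_order = padded_fraction(lengths, batch_size=32)
sorted_order = padded_fraction(sorted(lengths), batch_size=32)
print(f"padding with random batches: {random_order:.1%}")
print(f"padding with length-sorted batches: {sorted_order:.1%}")
```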

Create dataloaders and define our data collator function -- the `DataCollatorForLanguageModeling` class from the HF Transformers library prepares the data for masked LM using the default settings.

We'll want to set `num_workers` according to how many CPUs we have available (2 in the free Colab instance).

In [22]:
collator = DataCollatorForLanguageModeling(tokenizer)

train_dl = DataLoader(
    wikitext_tokenized_final['train'],
    batch_sampler=SequentialBatchRandomSampler(wikitext_tokenized_final['train'], TRAIN_BATCH_SIZE, drop_last=False),
    collate_fn=collator,
    num_workers=2
)

valid_dl = DataLoader(
    wikitext_tokenized_final['validation'],
    batch_sampler=SequentialBatchRandomSampler(wikitext_tokenized_final['validation'], VALID_BATCH_SIZE, drop_last=False),
    collate_fn=collator,
    num_workers=2
)
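
For intuition about what the collator does to each batch, here is a simplified, pure-Python sketch of BERT-style masking with the default settings (15% of tokens selected; of those, 80% replaced with the mask token, 10% with a random token, 10% left unchanged). `mask_tokens` is a hypothetical helper written for illustration, not the library's implementation, and the token IDs are toy values.

```python
import random

IGNORE_INDEX = -100  # label value that nn.CrossEntropyLoss skips by default

def mask_tokens(input_ids, mask_token_id, vocab_size, special_ids, mlm_probability=0.15, rng=random):
    """Return (masked_inputs, labels) for one sequence of token IDs."""
    inputs, labels = list(input_ids), []
    for i, tok in enumerate(input_ids):
        if tok in special_ids or rng.random() >= mlm_probability:
            labels.append(IGNORE_INDEX)            # not selected: no loss at this position
            continue
        labels.append(tok)                         # selected: the model must predict the original
        roll = rng.random()
        if roll < 0.8:
            inputs[i] = mask_token_id              # 80%: replace with the mask token
        elif roll < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: leave the token unchanged
    return inputs, labels

rng = random.Random(0)
# toy IDs: 0 and 2 stand in for the start/end special tokens
ids = [0] + [rng.randrange(3, 1000) for _ in range(50)] + [2]
masked, labels = mask_tokens(ids, mask_token_id=50264, vocab_size=50265, special_ids={0, 2}, rng=rng)
print(sum(l != IGNORE_INDEX for l in labels), "of", len(ids), "positions contribute to the loss")
```

Only the selected positions carry a real label, so the loss is computed on a small fraction of each sequence.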

## Model Classes

These are standard subclasses of `nn.Module` that define the process of encoding tokens and predicting for the LM task. We could have done all of this in the PyTorch Lightning `pl.LightningModule` class, but I find that it's easier to save/load with only PyTorch, so we'll do it this way to avoid requiring Lightning for loading in.

In [23]:
class MegaLM(nn.Module):
  'The base class for our Mega encoder - given input IDs, embed text and return encoder output'
  def __init__(self, mega_args, depth, vocab_size):
    super().__init__()
    self.mega_args = mega_args
    self.embedding_layer = nn.Embedding(vocab_size, self.mega_args.encoder_embed_dim)
    self.encoders = nn.ModuleList(
      [MegaEncoderLayer(self.mega_args) for _ in range(depth)
    ])
    self.depth = depth
        
  def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
    '''
    Code for a forward pass - expects input_ids and attention_mask to come
    from a Hugging Face tokenizer as PyTorch tensors, and returns the encoder
    output as a tensor of size (batch, time, embedding dim) when batch_first
    is True, or (time, batch, embedding dim) otherwise
    
    Other options:
      - batch_first: boolean indicating whether the batch dimension is first 
        in input_ids (default: True, which aligns with the HF tokenizer behavior)
      - ignore_mask_value: the value in attention_mask that identifies tokens 
        that should be ignored (default: 0, which aligns with HF tokenizer)
    '''

    # Mega expects embeddings to be (time, batch, embedding size), but 
    # Hugging Face returns tokens as (batch, time)
    if batch_first:
        input_ids = input_ids.T

    # to make things more confusing, Mega expects the attention mask to
    # be (batch, time), but with values of 0 (normal token) and 1 (ignore token)
    # which is the opposite of what HF returns
    if ignore_mask_value == 0:
        attention_mask = 1 - attention_mask

    # get token embeddings from IDs
    embeds = self.embedding_layer(input_ids)

    # pass through the Mega layers
    # input is (time, batch, encoder dim) and output is the same
    for encoder in self.encoders:
        embeds = encoder(embeds, attention_mask)
        
    # return according to the shape specified
    if batch_first:
        # (T, B, H) --> (B, T, H)
        return torch.transpose(embeds, 0, 1)
    else:
        return embeds

class MegaForMaskedLM(nn.Module):
  'A wrapper class for doing masked language modeling with Mega'
  def __init__(self, mega_args, depth, vocab_size):
    super().__init__()
    self.mega = MegaLM(mega_args, depth, vocab_size)
    self.mlm_head = nn.Linear(mega_args.encoder_embed_dim, vocab_size)
    self.dropout = nn.Dropout(p=0.1)

  def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
    """
    Perform a forward pass through the Mega encoder and the masked LM head. Returns
    logits for each vocabulary entry.

    If `batch_first` (default to align with Hugging Face tokenizer behavior), 
    output will have the shape (Batch size, Sequence length, Vocab size);
    otherwise (S, B, V)
    """
    encoder_output = self.mega(input_ids, attention_mask, batch_first, ignore_mask_value)
    return self.mlm_head(self.dropout(encoder_output))

## Training Loop Setup

Define the information about the model (depth, encoder specifics) as well as the training loop (number of epochs, maximum learning rate, weight decay, etc.)

In [24]:
# first, a couple of args that don't go in the encoder config
EPOCHS = 5
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.01

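# learning rate warmup proportion: the fraction of total gradient steps spent ramping up to LEARNING_RATE
# (by default pytorch lightning steps schedulers at the end of each epoch, so we'll step it ourselves)
LR_WARMUP_PCT = 0.1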

# depth: we'll create this many versions of the same encoder layer to stack in our model
ENCODER_DEPTH = 4

# gradient accumulation - collect gradients over this many batches before taking an optimizer step
ACCUMULATE_TRAIN_BATCHES = 16
print(f"Effective batch size = {ACCUMULATE_TRAIN_BATCHES * TRAIN_BATCH_SIZE}")

Effective batch size = 512
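
As a small illustration of why accumulation behaves like a larger batch, the sketch below uses a one-parameter least-squares model in plain Python (the data and the `mean_grad` helper are made up for the example). Summing the per-micro-batch gradients and dividing by the number of micro-batches gives the full-batch gradient; without that division, the accumulated gradient is simply that many times larger.

```python
def mean_grad(w, xs, ys):
    """Gradient with respect to w of the mean of 0.5 * (w*x - y)^2 over the batch."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]

micro_batch_size = 2
num_micro_batches = len(xs) // micro_batch_size

accumulated = 0.0
for i in range(0, len(xs), micro_batch_size):
    # each "backward pass" adds this micro-batch's gradient to the running total
    accumulated += mean_grad(w, xs[i:i + micro_batch_size], ys[i:i + micro_batch_size])

full_batch = mean_grad(w, xs, ys)
print(accumulated / num_micro_batches, full_batch)
```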


One last setup detail: let's define our encoder architecture. The `MegaEncoderLayer` class expects a `Namespace` object, and I've taken most of the arguments from the architecture specs used in the `Text` LRA task in the Mega paper.

In [25]:
# now all the encoder arguments
mega_encoder_args = Namespace(
    encoder_embed_dim=128,              # size of embeddings and encoder output (d_model in the paper)
    encoder_z_dim=64,                   # z in the paper
    encoder_hidden_dim=256,             # size of the attention values (v in the paper)
    encoder_n_dim=16,                   # size of the EMA projection (h in the paper) 
    bidirectional=True, 
    rel_pos_bias='rotary',
    encoder_chunk_size=-1,              # size of the chunks - lower is more efficient/smaller chunks, -1 for no chunking
    normalization_type='scalenorm',     # how to normalize - 'scalenorm' used for text classification in paper
    dropout=0.1,
    attention_dropout=0.1,
    hidden_dropout=0.1,
    truncation_length=None,
    max_source_positions=10_000,        
    activation_fn='silu',               # again what's used in the paper
    attention_activation_fn='softmax',  # also what's used in the paper - could explore 'laplace' as well
    
    # args for the normalized feed-forward component
    normalize_before=True,
    feature_dropout=False,
    encoder_ffn_embed_dim=256,          # size of the hidden dimension in the NFFN (d_FFN in the paper)
    activation_dropout=False
)

Print out the total gradient steps we'll run 

In [26]:
TOTAL_GRADIENT_STEPS = int(EPOCHS * np.ceil(len(train_dl) / ACCUMULATE_TRAIN_BATCHES))
WARMUP_STEPS = round(LR_WARMUP_PCT * TOTAL_GRADIENT_STEPS)
TOTAL_GRADIENT_STEPS, WARMUP_STEPS

(11320, 1132)
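
The arithmetic behind these numbers can be reproduced by hand from the training row count in the dataset summary. The sketch below also shows, for comparison, the count you get if only complete groups of accumulated batches are counted; the variable names are mine.

```python
import math

num_sequences = 1_158_900   # training rows after filtering (from the dataset summary)
train_batch_size = 32
accumulate = 16
epochs = 5
warmup_pct = 0.1

batches_per_epoch = math.ceil(num_sequences / train_batch_size)

# counting a final, partial accumulation group as a step (the ceil used above)
steps_with_partial = epochs * math.ceil(batches_per_epoch / accumulate)
# counting only complete groups of `accumulate` batches
steps_full_groups = epochs * (batches_per_epoch // accumulate)

print(batches_per_epoch, steps_with_partial, steps_full_groups)  # 36216 11320 11315
print(round(warmup_pct * steps_with_partial))                    # 1132
```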

Define our PyTorch Lightning module, which I'm essentially just using to create the training loop and handle devices, etc.

In [27]:
class LitMegaMLM(pl.LightningModule):
    def __init__(self, mega_args, depth, vocab_size):
        super().__init__()
        self.vocab_size=vocab_size
        self.mega_mlm = MegaForMaskedLM(mega_args, depth, vocab_size)
        
        # other necessary components
        self.loss_fn = nn.CrossEntropyLoss()

        # we'll do our own optimizer.step() and scheduler.step()
        self.automatic_optimization = False
        
    def forward(self, input_ids, attention_mask, batch_first=True, ignore_mask_value=0):
        return self.mega_mlm(input_ids, attention_mask, batch_first, ignore_mask_value)
    
    def training_step(self, batch, batch_idx):
        '''
        Define the training step, assuming `batch` is the batched output
        of a tokenized dataset containing the following:
          - input_ids: integer with shape (batch, time)
          - attention_mask: 0/1 with shape (batch, time); 0 = ignore, 1 = attend
          - labels: integer with shape (batch, time); the original token ID at
            masked positions and -100 elsewhere (ignored by the loss)

        Returns the loss value for this batch
        '''
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits.view(-1, self.vocab_size), batch['labels'].view(-1))
        loss.backward()
        self.log('loss', loss, prog_bar=True, sync_dist=True)

        # gradients from the backward pass above accumulate across batches; every
        # ACCUMULATE_TRAIN_BATCHES batches, step the optimizer and the LR schedule
        opt = self.optimizers()
        sched = self.lr_schedulers()
        if (batch_idx + 1) % ACCUMULATE_TRAIN_BATCHES == 0:
          opt.step()
          opt.zero_grad()
          sched.step()
        return loss
    
    def validation_step(self, batch, batch_idx):
        '''
        Same code for performing a forward pass and calculating loss as the training step
        '''
        logits = self(batch['input_ids'], batch['attention_mask'])
        loss = self.loss_fn(logits.view(-1, self.vocab_size), batch['labels'].view(-1))
        self.log('valid_loss', loss, prog_bar=True, sync_dist=True)
        return loss
    
    def configure_optimizers(self):
        optim = AdamW(self.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        scheduler = get_linear_schedule_with_warmup(optim, num_warmup_steps=WARMUP_STEPS, num_training_steps=TOTAL_GRADIENT_STEPS)
        return [optim], [scheduler]
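
The shape of the schedule returned by `get_linear_schedule_with_warmup` can be sketched as a multiplier on the peak learning rate: it climbs linearly from 0 to 1 over the warmup steps, then falls linearly back to 0 by the final step. The function below is my own restatement for illustration (`linear_warmup_decay` is a hypothetical name), using the step counts printed earlier.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """Multiplier applied to the peak learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

PEAK_LR = 0.001
WARMUP, TOTAL = 1132, 11320
for step in (0, 566, 1132, 6226, 11320):
    print(step, PEAK_LR * linear_warmup_decay(step, WARMUP, TOTAL))
```

The learning rate reaches its peak of 0.001 at step 1132 and is back to half of that at step 6226, midway through the decay.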

In [28]:
model = LitMegaMLM(mega_encoder_args, depth=ENCODER_DEPTH, vocab_size=tokenizer.vocab_size)

Illustrating the inputs and the stages of a forward pass

In [30]:
for batch in train_dl:
    break

print("Batch dictionary entries:")
for k, i in batch.items():
    print(f'  - {k}: {list(i.shape)}')
    
with torch.no_grad():
    mega_embeds = model.mega_mlm.mega.embedding_layer(batch['input_ids'])
    print(f'\nMega embeddings: {list(mega_embeds.shape)}')
    
    mega_output = model.mega_mlm.mega(batch['input_ids'], batch['attention_mask'])
    print(f'\nMega output: {list(mega_output.shape)}')

    mlm_output = model(batch['input_ids'], batch['attention_mask'])
    print(f'\nMLM output: {list(mlm_output.shape)}')

Batch dictionary entries:
  - input_ids: [32, 107]
  - attention_mask: [32, 107]
  - overflow_to_sample_mapping: [32]
  - labels: [32, 107]

Mega embeddings: [32, 107, 128]

Mega output: [32, 107, 128]

MLM output: [32, 107, 50265]


## Pretrain Mega

Start by reinstantiating the model to avoid any issues from CUDA devices, etc.

In [31]:
model = LitMegaMLM(mega_encoder_args, depth=ENCODER_DEPTH, vocab_size=tokenizer.vocab_size)

Instantiate the `Trainer` -- this would ordinarily take the gradient accumulation setting (`accumulate_grad_batches`), but we're doing that manually in the `training_step` instead

In [32]:
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator=device
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Now run the training loop

In [33]:
trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=valid_dl)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name     | Type             | Params
----------------------------------------------
0 | mega_mlm | MegaForMaskedLM  | 13.8 M
1 | loss_fn  | CrossEntropyLoss | 0     
----------------------------------------------
13.8 M    Trainable params
0         Non-trainable params
13.8 M    Total params
55.241    Total estimated model params size (MB)


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [34]:
trainer.validate(model, dataloaders=[valid_dl])

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       valid_loss            3.358445882797241
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'valid_loss': 3.358445882797241}]

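Let's print how many gradient steps this model trained for. The count can come in slightly under `TOTAL_GRADIENT_STEPS`, because `training_step` only calls `opt.step()` on every 16th batch, so the leftover batches at the end of each epoch don't trigger a step of their own: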

In [35]:
model.global_step

11315

## Save Model Components

Now that we have pretrained the model, let's save its weights, config, and tokenizer for reuse.

I'm going to save the LM encoder weights separately from the MLM task head for easier task-agnostic reuse. We can always load the MLM head weights separately if we want to continue pretraining :) 

In [36]:
import pickle as pkl 
import os

In [37]:
MODEL_PATH = "mega-wikitext-103"

# first the tokenizer
tokenizer.save_pretrained(MODEL_PATH)

# now the encoder weights
torch.save(model.mega_mlm.mega.state_dict(), os.path.join(MODEL_PATH, "encoder_weights.pt"))

# next the MLM head
torch.save(model.mega_mlm.mlm_head.state_dict(), os.path.join(MODEL_PATH, "mlm_head_weights.pt"))

We'll also want a config object to pass the arguments needed to reinstantiate this model class. I'll save it as a pickle file.

In [38]:
model_config = {
    'vocab_size':tokenizer.vocab_size,
    'depth':ENCODER_DEPTH,
    'mega_args':mega_encoder_args
}

with open(os.path.join(MODEL_PATH, "model_args.pkl"), 'wb') as f:
  pkl.dump(model_config, f)
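
Since unpickling can execute arbitrary code, a plain-text format is a common alternative when every value in the config is JSON-serializable. This is not what the notebook does, just a sketch of that option: the argument values below are a small illustrative stand-in for the full encoder config, and the file is written to a temporary directory.

```python
import json
import os
import tempfile
from argparse import Namespace

# a small stand-in for the encoder arguments (values here are illustrative)
example_args = Namespace(encoder_embed_dim=128, encoder_z_dim=64, bidirectional=True, truncation_length=None)
config = {"vocab_size": 50265, "depth": 4, "mega_args": vars(example_args)}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_args.json")
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    with open(path) as f:
        restored = json.load(f)

# rebuild the Namespace so it can be passed to the model classes as before
restored["mega_args"] = Namespace(**restored["mega_args"])
print(restored["mega_args"] == example_args)  # True
```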

## Usage Example

Here's a quick example of loading and prediction using the classes above and the saved objects

In [40]:
# first load the config
with open(os.path.join(MODEL_PATH, 'model_args.pkl'), 'rb') as f:
    loaded_config = pkl.load(f)

# then instantiate the LM class using the arguments in the config
loaded_encoder = MegaLM(**loaded_config)

# load the weights with torch.load and import them into the model architecture
print(loaded_encoder.load_state_dict(torch.load(os.path.join(MODEL_PATH, 'encoder_weights.pt'), map_location='cpu')))
loaded_encoder.eval() # switch to eval mode for prediction

# load the tokenizer with Hugging Face's auto function
loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

<All keys matched successfully>


Now let's recreate the MLM object as well

In [43]:
loaded_mlm = MegaForMaskedLM(**loaded_config)

# first load the encoder weights
print(loaded_mlm.mega.load_state_dict(torch.load(os.path.join(MODEL_PATH, 'encoder_weights.pt'), map_location='cpu')))

# then the MLM head weights
print(loaded_mlm.mlm_head.load_state_dict(torch.load(os.path.join(MODEL_PATH, 'mlm_head_weights.pt'), map_location='cpu')))

# set to eval mode
loaded_mlm = loaded_mlm.eval()

<All keys matched successfully>
<All keys matched successfully>


The last thing for now: let's do a full example of this model filling in the masked token

In [56]:
test_input = f"John was born in 1954 in the United {loaded_tokenizer.mask_token} of America."
print(test_input)

John was born in 1954 in the United <mask> of America.


In [57]:
test_tokens = loaded_tokenizer(test_input, return_tensors="pt")
mask_index = test_tokens['input_ids'] == loaded_tokenizer.mask_token_id
mask_index.sum()

tensor(1)

In [58]:
test_outputs = loaded_mlm(test_tokens['input_ids'], test_tokens['attention_mask']).squeeze(0)
# despite the name, this is the predicted token ID (the argmax over the vocabulary logits)
pred_logit = test_outputs[mask_index.squeeze()].argmax().item()
pred_logit

532

How does Mega fill in the blank?

In [59]:
loaded_tokenizer.convert_ids_to_tokens([pred_logit])

['ĠStates']
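
The `Ġ` prefix is how the GPT-2/RoBERTa byte-level BPE vocabulary marks a token that begins with a space. Here is a minimal sketch of making that readable by hand; the hypothetical `readable` helper only handles the space marker, and the tokenizer's own `decode` method is the general-purpose route.

```python
def readable(token):
    """Replace the byte-level BPE space marker with a regular space."""
    return token.replace("Ġ", " ")

print(repr(readable("ĠStates")))  # ' States'
```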

That was admittedly a silly example, but it validates that the loading and prediction code works!

I've downloaded the files from the output directory and am uploading them here:

https://huggingface.co/mnaylor/mega-wikitext-103 