<a href="https://colab.research.google.com/github/dcaled/convert_bert_to_long/blob/main/convert_bert_to_long.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# `BERT` --> `Longformer`: build a "long" version of pretrained models

This notebook replicates the procedure described in the [Longformer paper](https://arxiv.org/abs/2004.05150) to train a Longformer model starting from a BERT checkpoint. The same procedure can be applied to build the "long" version of other pretrained models. It was inspired by the AllenAI notebook that converts RoBERTa to Longformer: [convert_model_to_long.ipynb](https://colab.research.google.com/github/allenai/longformer/blob/master/scripts/convert_model_to_long.ipynb).


### Libraries and imports


In [None]:
!pip install transformers==3.0.2



In [None]:
import logging
import os
import math
import torch
import tensorflow as tf
from dataclasses import dataclass, field
from transformers import AutoModel, AutoTokenizer, BertTokenizerFast, BertForMaskedLM, BertModel
from transformers import TrainingArguments, HfArgumentParser
from transformers.modeling_longformer import LongformerSelfAttention

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

### BertLong

`BertLong` is the "long" version of the `BERT` model. It replaces `BertSelfAttention` with `BertLongSelfAttention`, a thin wrapper around `LongformerSelfAttention`.


In [None]:
class BertLongSelfAttention(LongformerSelfAttention):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
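        # accept the full argument list that BERT passes to its self-attention module,
        # but hand only hidden states, mask and the attention-output flag to Longformer
        return super().forward(hidden_states, attention_mask=attention_mask, output_attentions=output_attentions)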


class BertLong(BertModel):
    def __init__(self, config):
        super().__init__(config)
        for i, layer in enumerate(self.encoder.layer):
            # replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention`
            layer.attention.self = BertLongSelfAttention(config, layer_id=i)
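
For intuition about what the swapped-in attention does, here is a minimal pure-Python sketch of a sliding-window pattern. It assumes the window is split evenly around each token (so `attention_window=512` would mean 256 positions on each side) and ignores global attention. `local_attention_span` is a hypothetical helper for illustration, not part of `transformers`.

```python
def local_attention_span(position, seq_len, attention_window):
    """Return the (start, end) range, end exclusive, of positions a token can
    attend to under a symmetric sliding window. Illustrative helper only."""
    one_sided = attention_window // 2
    start = max(0, position - one_sided)
    end = min(seq_len, position + one_sided + 1)
    return start, end


seq_len, window = 4096, 512

# Span for a token in the middle of the sequence.
span = local_attention_span(2000, seq_len, window)

# Count attended (query, key) pairs under the window and under full attention.
attended_pairs = sum(
    end - start
    for start, end in (local_attention_span(p, seq_len, window) for p in range(seq_len))
)
full_pairs = seq_len * seq_len
print(span, attended_pairs, full_pairs)
```

Each token attends to a bounded number of neighbours, so the number of pairs grows linearly with the sequence length rather than quadratically.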

Starting from a `bert-base` checkpoint (here `bert-base-multilingual-cased`), the following function converts it into an instance of `BertLong`. It makes the following changes:

- extends the position embeddings from `512` positions to `max_pos`. Longformer uses `max_pos=4096`

- initializes the additional position embeddings by copying the embeddings of the first `512` positions. This initialization is crucial for model performance (check table 6 in [the paper](https://arxiv.org/pdf/2004.05150.pdf) for performance without this initialization)

- replaces the `modeling_bert.BertSelfAttention` objects with `modeling_longformer.LongformerSelfAttention` objects that use an attention window of size `attention_window`

The output of this function works for long documents even without pretraining. Check tables 6 and 11 in [the paper](https://arxiv.org/pdf/2004.05150.pdf) to get a sense of the expected performance of this model before pretraining.
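
The copy-initialization described in the list above amounts to mapping each new position `p` to an original row `p % 512`. Below is a minimal sketch of that mapping in plain Python (no tensors). `tiled_source_rows` is a hypothetical helper name. Unlike a fixed-size slice copy, it also shortens the last chunk when `max_pos` is not a multiple of the original size.

```python
def tiled_source_rows(current_max_pos, max_pos):
    """For every position in the extended table, return the index of the
    original embedding row that would be copied into it."""
    rows = []
    k = 0
    while k < max_pos:
        # the final chunk may be shorter than the original table
        step = min(current_max_pos, max_pos - k)
        rows.extend(range(step))
        k += step
    return rows


rows = tiled_source_rows(512, 4096)
print(len(rows), rows[511], rows[512], rows[4095])
```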

In [None]:
def create_long_model(save_model_to, attention_window, max_pos):
    model = BertModel.from_pretrained('bert-base-multilingual-cased')
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased', model_max_length=max_pos)
    config = model.config

    print(max_pos)
    # extend position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos
    current_max_pos, embed_size = model.embeddings.position_embeddings.weight.shape
    config.max_position_embeddings = max_pos
    assert max_pos > current_max_pos
    # allocate a larger position embedding matrix
    new_pos_embed = model.embeddings.position_embeddings.weight.new_empty(max_pos, embed_size)
    print(new_pos_embed.shape)
    print(model.embeddings.position_embeddings)
    # copy the original position embeddings over and over to fill the new matrix
    # (assumes max_pos is a multiple of current_max_pos, e.g. 4096 = 8 * 512)
    k = 0
    step = current_max_pos
    while k < max_pos - 1:
        new_pos_embed[k:(k + step)] = model.embeddings.position_embeddings.weight
        k += step
    print(new_pos_embed.shape)
    # build the position ids directly in torch (int64), avoiding the TensorFlow round-trip
    model.embeddings.position_ids = torch.arange(max_pos).unsqueeze(0)
    # note: Embedding.from_pretrained freezes the weights by default (freeze=True)
    model.embeddings.position_embeddings = torch.nn.Embedding.from_pretrained(new_pos_embed)
    
    # replace the `modeling_bert.BertSelfAttention` object with `LongformerSelfAttention`
    config.attention_window = [attention_window] * config.num_hidden_layers
    for i, layer in enumerate(model.encoder.layer):
        longformer_self_attn = LongformerSelfAttention(config, layer_id=i)
        longformer_self_attn.query = layer.attention.self.query
        longformer_self_attn.key = layer.attention.self.key
        longformer_self_attn.value = layer.attention.self.value

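        # the global-attention projections reuse (share) the same Linear modules
        # as the local ones; they are not independent copies
        longformer_self_attn.query_global = layer.attention.self.query
        longformer_self_attn.key_global = layer.attention.self.key
        longformer_self_attn.value_global = layer.attention.self.value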

        layer.attention.self = longformer_self_attn
    print(model.embeddings.position_ids.shape)
    logger.info(f'saving model to {save_model_to}')
    model.save_pretrained(save_model_to)
    tokenizer.save_pretrained(save_model_to)
    return model, tokenizer, new_pos_embed

**Training hyperparameters**

- Following the BERT pretraining setting, the number of tokens per batch is set to `2^18`. Changing this number might require changes to the lr, lr-scheduler, #steps and #warmup steps. Therefore, it is a good idea to keep this number constant.

- Note that: `#tokens/batch = batch_size x #gpus x gradient_accumulation x seqlen`

- In [the paper](https://arxiv.org/pdf/2004.05150.pdf), the authors train for 65k steps, but 3k is probably enough (check table 6)

- **Important note**: The lr-scheduler in [the paper](https://arxiv.org/pdf/2004.05150.pdf) is polynomial_decay with power 3 over 65k steps. To train for 3k steps, use a constant lr-scheduler (after warmup). Neither lr-scheduler is supported in the HF trainer, so at least a **constant lr-scheduler** will need to be added.

- Pretraining will take 2 days on 1 x 32GB GPU with fp32. Consider using fp16 and more GPUs to train faster (if you increase `#gpus`, reduce `gradient_accumulation` to maintain `#tokens/batch` as mentioned earlier).

- As a demonstration, this notebook trains on wikitext103, but wikitext103 is small enough that it takes 7 epochs to train for 3k steps. Consider doing a single epoch on a larger dataset (800M tokens) instead.

- Set #gpus using `CUDA_VISIBLE_DEVICES`
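
As a quick check of the `#tokens/batch` formula above, the following sketch plugs in the settings used in this notebook (`per_gpu_train_batch_size=2`, `gradient_accumulation_steps=32`, `max_pos=4096`, one GPU). It also shows how the accumulation steps would shrink if more GPUs were used. Both function names are hypothetical helpers, and the 4-GPU case is only an example.

```python
def tokens_per_batch(batch_size, num_gpus, gradient_accumulation, seqlen):
    """#tokens/batch = batch_size x #gpus x gradient_accumulation x seqlen."""
    return batch_size * num_gpus * gradient_accumulation * seqlen


def gradient_accumulation_for(target_tokens, batch_size, num_gpus, seqlen):
    """Accumulation steps needed to reach target_tokens per optimizer step."""
    per_step = batch_size * num_gpus * seqlen
    if target_tokens % per_step != 0:
        raise ValueError("target_tokens is not a multiple of tokens per forward step")
    return target_tokens // per_step


target = 2 ** 18
single_gpu = tokens_per_batch(batch_size=2, num_gpus=1, gradient_accumulation=32, seqlen=4096)
accum_4_gpus = gradient_accumulation_for(target, batch_size=2, num_gpus=4, seqlen=4096)
print(single_gpu == target, accum_4_gpus)
```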

In [None]:
@dataclass
class ModelArgs:
    attention_window: int = field(default=512, metadata={"help": "Size of attention window"})
    max_pos: int = field(default=4096, metadata={"help": "Maximum position"})

parser = HfArgumentParser((TrainingArguments, ModelArgs,))


training_args, model_args = parser.parse_args_into_dataclasses(look_for_args_file=False, args=[
    '--output_dir', 'tmp',
    '--warmup_steps', '500',
    '--learning_rate', '0.00003',
    '--weight_decay', '0.01',
    '--adam_epsilon', '1e-6',
    '--max_steps', '3000',
    '--logging_steps', '500',
    '--save_steps', '500',
    '--max_grad_norm', '5.0',
    '--per_gpu_eval_batch_size', '8',
    '--per_gpu_train_batch_size', '2',  # 32GB gpu with fp32
    '--gradient_accumulation_steps', '32',
    '--evaluate_during_training',
    '--do_train',
    '--do_eval',
])

# Choose GPU
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
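
The hyperparameter notes call for a constant lr-scheduler after warmup. A minimal sketch of such a schedule, using the `--warmup_steps 500` and `--learning_rate 0.00003` values from the arguments above, could look like this. `constant_with_warmup_factor` is a hypothetical name, and linear warmup is an assumption, since the notes do not say which warmup shape to use.

```python
def constant_with_warmup_factor(step, warmup_steps):
    """Multiplier for the base learning rate: linear ramp from 0 to 1 during
    warmup, then constant at 1."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return 1.0


base_lr = 3e-5
warmup_steps = 500
for step in (0, 250, 500, 3000):
    print(step, base_lr * constant_with_warmup_factor(step, warmup_steps))
```

A function of this shape could be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR` (with `warmup_steps` bound, for example through `functools.partial`).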

1) As described in `create_long_model`, convert a `bert-base` model into `bert-base-4096`, which is an instance of `BertLong`, then save it to disk.

In [None]:
model_path = f'{training_args.output_dir}/bert-base-{model_args.max_pos}'
if not os.path.exists(model_path):
    os.makedirs(model_path)

logger.info(f'Converting bert-base into bert-base-{model_args.max_pos}')
model, tokenizer, new_pos_embed = create_long_model(
    save_model_to=model_path, attention_window=model_args.attention_window, max_pos=model_args.max_pos)
#create_long_model(save_model_to, attention_window, max_pos)

INFO:__main__:Converting bert-base into bert-base-4096
INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json from cache at /root/.cache/torch/transformers/45629519f3117b89d89fd9c740073d8e4c1f0a70f9842476185100a8afe715d1.65df3cef028a0c91a7b059e4c404a975ebe6843c71267b67019c0e9cfa8a88f0
INFO:transformers.configuration_utils:Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_s

4096
torch.Size([4096, 768])
Embedding(512, 768)
torch.Size([4096, 768])


INFO:__main__:saving model to tmp/bert-base-4096
INFO:transformers.configuration_utils:Configuration saved in tmp/bert-base-4096/config.json


torch.Size([1, 4096])


INFO:transformers.modeling_utils:Model weights saved in tmp/bert-base-4096/pytorch_model.bin


2) Load `bert-base-4096` from the disk. This model works for long sequences even without pretraining. If you don't want to pretrain, you can stop here and start finetuning your `bert-base-4096` on downstream tasks 🎉🎉🎉

In [None]:
logger.info(f'Loading the model from {model_path}')
tokenizer = BertTokenizerFast.from_pretrained(model_path)
model = BertLong.from_pretrained(model_path)
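
One practical point when feeding inputs to the loaded model: sliding-window attention implementations commonly require the sequence length to be a multiple of the attention window. Assuming that holds for the installed `transformers` version (this is worth verifying), the target length can be computed as in the sketch below. `padded_length` is a hypothetical helper.

```python
def padded_length(seq_len, attention_window):
    """Smallest multiple of attention_window that is >= seq_len."""
    remainder = seq_len % attention_window
    if remainder == 0:
        return seq_len
    return seq_len + (attention_window - remainder)


for n in (1, 1000, 4096):
    print(n, padded_length(n, 512))
```

Inputs would then be padded with the tokenizer's pad token, and the attention mask extended with zeros, up to that length.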

INFO:__main__:Loading the model from tmp/bert-base-4096
INFO:transformers.tokenization_utils_base:Model name 'tmp/bert-base-4096' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, TurkuNLP/bert-base-finnish-cased-v1, TurkuNLP/bert-base-finnish-uncased-v1, wietsedv/bert-base-dutch-cased). Assuming 'tmp/bert-base-4096' is a path, a model identifier, or url to a directory containing tokenizer files.
INFO:transformers.tokenization_utils_base:Didn't find file tmp/bert-base-4096/added_tokens.json. We won't load it.
INFO:transformers.tokenization_utils_ba