In [1]:
import os
import math
import copy
import logging

from dataclasses import dataclass, field
from transformers import TrainingArguments, HfArgumentParser
from transformers import RobertaForMaskedLM, RobertaTokenizerFast, TextDataset,\
DataCollatorForLanguageModeling, Trainer, LongformerSelfAttention

In [2]:
# Configure the root logger first so INFO-level records are emitted,
# then create the logger named after this module/notebook.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

### Set up the model to convert
#### RobertaLong
`RobertaLongForMaskedLM` is the "long" version of the RoBERTa model. It replaces `BertSelfAttention` with `RobertaLongSelfAttention`, which is a thin wrapper around `LongformerSelfAttention`.

In [3]:
class RobertaLongSelfAttention(LongformerSelfAttention):
    """Thin adapter around ``LongformerSelfAttention``.

    ``forward`` accepts a wider set of keyword arguments than it passes on:
    only ``hidden_states``, ``attention_mask`` and ``output_attentions`` are
    handed to the parent implementation.
    """

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
    ):
        """Delegate to ``LongformerSelfAttention.forward``.

        Args:
            hidden_states: Passed through positionally to the parent.
            attention_mask: Passed through to the parent as a keyword.
            head_mask: Accepted but not forwarded.
            encoder_hidden_states: Accepted but not forwarded.
            encoder_attention_mask: Accepted but not forwarded.
            output_attentions: Passed through to the parent as a keyword.

        Returns:
            Whatever the parent ``forward`` returns.
        """
        # NOTE(review): the parent `forward` signature depends on the installed
        # `transformers` version -- confirm it still accepts exactly these
        # keywords before upgrading the library.
        forwarded_kwargs = {
            "attention_mask": attention_mask,
            "output_attentions": output_attentions,
        }
        return super().forward(hidden_states, **forwarded_kwargs)


class RobertaLongForMaskedLM(RobertaForMaskedLM):
    """``RobertaForMaskedLM`` with every encoder layer's self-attention swapped
    for a ``RobertaLongSelfAttention`` instance.
    """

    def __init__(self, config):
        """Build the base model, then replace each layer's self-attention.

        Args:
            config: Model configuration; given both to the base-class
                constructor and to every ``RobertaLongSelfAttention``.
        """
        super().__init__(config)
        encoder_layers = self.roberta.encoder.layer
        for layer_index, encoder_layer in enumerate(encoder_layers):
            # Swap the self-attention module created by the base class for the
            # Longformer-based one; `layer_id` is the layer's position in the stack.
            long_attention = RobertaLongSelfAttention(config, layer_id=layer_index)
            encoder_layer.attention.self = long_attention