In [6]:
import copy
import os

import torch
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks.masked_lm import MaskedLMTask

# Load the pretrained RoBERTa-large checkpoint (vocabulary + weights), mapped onto CPU.
checkpoint_path = "/data/home/xwhan/fairseq-py/checkpoints/roberta.large"
dict_file = os.path.join(checkpoint_path, 'dict.txt')
model_file = os.path.join(checkpoint_path, 'model.pt')

dictionary = MaskedLMTask.load_dictionary(dict_file)
state = torch.load(model_file, map_location=torch.device('cpu'))

# The checkpoint stores its config as an argparse Namespace under 'args';
# convert it to the omegaconf form that build_model / load_state_dict expect.
roberta_cfg = convert_namespace_to_omegaconf(state['args'])
task = MaskedLMTask(state['args'], dictionary)
roberta = task.build_model(roberta_cfg.model)
roberta.load_state_dict(state['model'], strict=True, model_cfg=roberta_cfg.model)

# Inspect one transformer layer to see which submodules the conversion cells must copy.
first_layer = roberta.encoder.sentence_encoder.layers[0]
print(first_layer)

TransformerEncoderLayerBase(
  (self_attn): MultiheadAttention(
    (dropout_module): FairseqDropout()
    (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (dropout_module): FairseqDropout()
  (activation_dropout_module): FairseqDropout()
  (fc1): Linear(in_features=1024, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=1024, bias=True)
  (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)


In [9]:
# Show the whole encoder stack: embeddings, every transformer layer, and the LM head.
encoder_module = roberta.encoder
print(encoder_module)

RobertaEncoder(
  (sentence_encoder): TransformerEncoder(
    (dropout_module): FairseqDropout()
    (embed_tokens): Embedding(50265, 1024, padding_idx=1)
    (embed_positions): LearnedPositionalEmbedding(514, 1024, padding_idx=1)
    (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0): TransformerEncoderLayerBase(
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (dropout_module): FairseqDropout()
        (activation_dropout_module): FairseqDropout()
        (fc1): Linear(in_features=1024,

In [11]:
# Build a long-input RoBERTa using xformers block attention (window 512,
# 4096 positions) and initialise it from the roberta-large weights loaded above.
# deepcopy: state['args'] is shared with other cells; mutating it in place would
# leak the xformers settings into any cell that later reads state['args'].
model_args = copy.deepcopy(state['args'])
model_args.use_xformers = True
model_args.attention_name = 'block'
model_args.max_positions = 4096
model_args.xformer_config = '{"window_size": 512}'
task = MaskedLMTask(model_args, dictionary)
long_cfg = convert_namespace_to_omegaconf(model_args)
long_model = task.build_model(long_cfg.model)
print(long_model)

src_encoder = roberta.encoder.sentence_encoder
dst_encoder = long_model.encoder.sentence_encoder

# Copy the modules outside the attention blocks (their shapes are unchanged).
long_model.encoder.lm_head.load_state_dict(roberta.encoder.lm_head.state_dict())
dst_encoder.embed_tokens.load_state_dict(src_encoder.embed_tokens.state_dict())
dst_encoder.layernorm_embedding.load_state_dict(src_encoder.layernorm_embedding.state_dict())

# Copy the transformer layers. strict=False tolerates parameters that exist in
# only one of the two models, but those stay at their random initialisation,
# so report the mismatch instead of hiding it.
layer_load = dst_encoder.layers.load_state_dict(src_encoder.layers.state_dict(), strict=False)
print(f"layers: {len(layer_load.missing_keys)} missing keys, e.g. {layer_load.missing_keys[:5]}")
print(f"layers: {len(layer_load.unexpected_keys)} unexpected keys, e.g. {layer_load.unexpected_keys[:5]}")

# Extend the position embeddings by tiling the pretrained ones.
# Rows [0, offset) are reserved (padding_idx=1 per the printed module); real
# positions start at row `offset`. The reserved rows are copied verbatim so the
# padding row keeps its pretrained value.
src_pos = src_encoder.embed_positions.weight
dst_pos = dst_encoder.embed_positions.weight
offset = 2
pos_limit = src_pos.size(0)
new_pos_limit = dst_pos.size(0)
step = pos_limit - offset
assert step > 0, "pretrained position table has no usable rows"
with torch.no_grad():
    dst_pos[:offset] = src_pos[:offset]
    for start in range(offset, new_pos_limit, step):
        # Clamp the final chunk so lengths that are not a multiple of `step` work.
        end = min(start + step, new_pos_limit)
        dst_pos[start:end] = src_pos[offset:offset + (end - start)]

# Save the converted checkpoint next to the original one.
save_path = '/data/home/xwhan/fairseq-py/checkpoints/roberta.large.block-512'
os.makedirs(save_path, exist_ok=True)
dictionary.save(os.path.join(save_path, 'dict.txt'))
# Shallow-copy the checkpoint so `state` itself still describes roberta-large.
long_state = dict(state)
long_state['args'] = model_args
long_state['model'] = long_model.state_dict()
torch.save(long_state, os.path.join(save_path, 'model.pt'))

RobertaModel(
  (encoder): RobertaEncoder(
    (sentence_encoder): TransformerEncoder(
      (dropout_module): FairseqDropout()
      (embed_tokens): Embedding(50265, 1024, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(4098, 1024, padding_idx=1)
      (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): TransformerEncoderLayerBase(
          (self_attn): MultiheadAttention(
            (dropout_module): FairseqDropout()
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (attention): BlockAttention(
              (drop_attn): Dropout(p=0.1, inplace=False)
            )
            (multi_head): MultiHeadDispatch(
              (attenti

In [3]:
# Build a sliding-window attention model (window 128, 4096 positions) and
# initialise it from the roberta-large weights loaded above.
# deepcopy: state['args'] is shared with other cells; mutate a private copy only.
long_args = copy.deepcopy(state['args'])
long_args.attention_window = [128]
long_args.arch = 'sliding_window_large'
long_args.max_positions = 4096
task = MaskedLMTask(long_args, dictionary)
long_cfg = convert_namespace_to_omegaconf(long_args)
long_model = task.build_model(long_cfg.model)
print(long_model)

src_encoder = roberta.encoder.sentence_encoder
dst_encoder = long_model.encoder.sentence_encoder

# Copy the modules outside the attention blocks (their shapes are unchanged).
long_model.encoder.lm_head.load_state_dict(roberta.encoder.lm_head.state_dict())
dst_encoder.embed_tokens.load_state_dict(src_encoder.embed_tokens.state_dict())
dst_encoder.layernorm_embedding.load_state_dict(src_encoder.layernorm_embedding.state_dict())

# Copy the transformer layers. strict=False tolerates parameters that exist in
# only one model; report them so nothing is silently left random. The
# *_proj_global keys are expected to be missing here and are initialised below.
layer_load = dst_encoder.layers.load_state_dict(src_encoder.layers.state_dict(), strict=False)
print(f"layers: {len(layer_load.missing_keys)} missing keys, e.g. {layer_load.missing_keys[:5]}")
print(f"layers: {len(layer_load.unexpected_keys)} unexpected keys, e.g. {layer_load.unexpected_keys[:5]}")
print(dst_encoder.layers[0])

# Extend the position embeddings by tiling the pretrained ones.
# Rows [0, offset) are reserved (padding_idx=1 per the printed module); real
# positions start at row `offset`. The reserved rows are copied verbatim so the
# padding row keeps its pretrained value.
src_pos = src_encoder.embed_positions.weight
dst_pos = dst_encoder.embed_positions.weight
offset = 2
pos_limit = src_pos.size(0)
new_pos_limit = dst_pos.size(0)
step = pos_limit - offset
assert step > 0, "pretrained position table has no usable rows"
with torch.no_grad():
    dst_pos[:offset] = src_pos[:offset]
    for start in range(offset, new_pos_limit, step):
        # Clamp the final chunk so lengths that are not a multiple of `step` work.
        end = min(start + step, new_pos_limit)
        dst_pos[start:end] = src_pos[offset:offset + (end - start)]

# Initialise the global-attention projections as independent copies of the
# (just loaded) local q/k/v projections.
for layer in dst_encoder.layers:
    layer.self_attn.q_proj_global = copy.deepcopy(layer.self_attn.q_proj)
    layer.self_attn.k_proj_global = copy.deepcopy(layer.self_attn.k_proj)
    layer.self_attn.v_proj_global = copy.deepcopy(layer.self_attn.v_proj)

# Save the converted checkpoint.
save_path = '/data/home/xwhan/fairseq-py/checkpoints/roberta.large.extended'
os.makedirs(save_path, exist_ok=True)
dictionary.save(os.path.join(save_path, 'dict.txt'))
# Shallow-copy the checkpoint so `state` itself still describes roberta-large.
long_state = dict(state)
long_state['args'] = long_args
long_state['model'] = long_model.state_dict()
torch.save(long_state, os.path.join(save_path, 'model.pt'))

[128]
SlidingWindownModel(
  (encoder): SlidingWindowEncoder(
    (sentence_encoder): SWTransformerEncoder(
      (dropout_module): FairseqDropout()
      (embed_tokens): Embedding(50265, 1024, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(4098, 1024, padding_idx=1)
      (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): SWTransformerEncoderLayer(
          (self_attn): SWSelfAttention(
            (dropout_module): FairseqDropout()
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (k_proj_global): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj_global): Linear(in_features=1024, out_features=1024, bias=T