In [1]:
from fairseq.tasks.denoising import DenoisingTask
from fairseq.tasks.translation import TranslationTask
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
import torch
import os
from fairseq import checkpoint_utils

# Pretrained bart-large checkpoint directory; it is expected to hold
# 'dict.txt' (vocabulary) and 'model.pt' (serialized training state).
checkpoint_path = "/data/home/xwhan/fairseq-py/checkpoints/bart.large"
dict_file = os.path.join(checkpoint_path, 'dict.txt')
model_file = os.path.join(checkpoint_path, 'model.pt')

# Read the vocabulary, then the raw checkpoint with all tensors mapped to CPU.
dictionary = DenoisingTask.load_dictionary(dict_file)
state = torch.load(model_file, map_location=torch.device('cpu'))

# The checkpoint stores its hyper-parameters under state['args']; convert them
# to an omegaconf config, build the model through a DenoisingTask and load the
# pretrained weights. strict=True makes any key mismatch an error.
bart_cfg = convert_namespace_to_omegaconf(state['args'])
task = DenoisingTask(state['args'], dictionary)
bart = task.build_model(bart_cfg.model)
bart.load_state_dict(state['model'], strict=True, model_cfg=bart_cfg.model)

<All keys matched successfully>

In [4]:
# build sparse transformer
from fairseq.tasks.long_denoising import LongDenoisingTask

# model_args is the very same object as state['args'] (no copy is made), so the
# overrides below are also visible through state['args'].
model_args = state['args']
for _field, _value in (
    # attention implementation and its configuration
    ('use_xformers', True),
    ('attention_name', 'block_noglobal'),
    ('xformer_config', '{"window_size": 1024}'),
    # position limits for the encoder (source) and decoder (target)
    ('max_source_positions', 1024 * 16),
    ('max_target_positions', 1024),
    # span-corruption noise settings
    ('mean_noise_span_length', 10),
    ('noise_density', 0.05),
):
    setattr(model_args, _field, _value)


## need these steps such that the models can add sentinel tokens to its vocab
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    """Work out how many raw tokens fit into ``inputs_length`` after span corruption.

    A raw sequence of ``T`` tokens is modelled as ``round(T * noise_density)``
    noise tokens grouped into ``round(noise / mean_noise_span_length)`` spans.
    The corrupted input then has (non-noise tokens + one slot per span)
    positions and the target has (noise tokens + one slot per span) positions.
    No EOS token is counted on either side.

    Starting from ``T = inputs_length``, ``T`` is increased one token at a time
    for as long as the corrupted input of ``T + 1`` tokens still fits within
    ``inputs_length``.

    Args:
        inputs_length: maximum allowed length of the corrupted input.
        noise_density: fraction of raw tokens treated as noise.
        mean_noise_span_length: average number of noise tokens per span.

    Returns:
        ``(tokens_length, targets_length)``: the raw token count found by the
        search and the matching target length.
    """

    def _corrupted_lengths(raw_length):
        # Lengths of the corrupted input and of the target for `raw_length` raw tokens.
        noise = int(round(raw_length * noise_density))
        spans = int(round(noise / mean_noise_span_length))
        # @xwhan leave aside EOS token at this point
        return (raw_length - noise) + spans, noise + spans

    budget = inputs_length
    tokens_length = budget
    while True:
        next_inputs, _ = _corrupted_lengths(tokens_length + 1)
        if next_inputs > budget:
            break
        tokens_length += 1

    corrupted_inputs, targets_length = _corrupted_lengths(tokens_length)

    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > corrupted_inputs:
        return tokens_length - 1, targets_length - 1
    return tokens_length, targets_length


# The task reads tokens_per_sample to decide how many sentinel tokens it needs.
# Two positions are held back from the source budget before the computation and
# added back afterwards (presumably for special tokens such as BOS/EOS - confirm).
_source_budget = model_args.max_source_positions - 2
tokens_per_sample, _ = compute_input_and_target_lengths(
    _source_budget,
    model_args.noise_density,
    model_args.mean_noise_span_length,
)
model_args.tokens_per_sample = tokens_per_sample + 2
print(model_args.tokens_per_sample)



17156


In [12]:
# Display the index the dictionary assigns to the '<mask>' symbol.
dictionary.index('<mask>')

50264

In [5]:
# Build the long-input model from the overridden args; the weights of the
# pretrained `bart` model are then transplanted into it piece by piece.
task = LongDenoisingTask(model_args, dictionary)

long_cfg = convert_namespace_to_omegaconf(model_args)
long_model = task.build_model(long_cfg.model)

# Everything below is plain weight copying, so no autograd graph is needed.
with torch.no_grad():
    ##### encoder stuff #####

    # 1. embed_tokens and layernorm_embedding
    old_embed_tokens = bart.encoder.embed_tokens.weight
    vocab_size, _ = old_embed_tokens.shape
    new_vocab_size, embed_dim = long_model.encoder.embed_tokens.weight.shape
    print('old embedding matrix size', vocab_size)
    print('new embedding matrix size', new_vocab_size)
    # Rows shared with BART are copied verbatim. Every additional row (the
    # sentinel tokens) starts from BART's last embedding row, which is <mask>.
    new_embed_tokens = old_embed_tokens.new_empty(new_vocab_size, embed_dim)
    new_embed_tokens[:vocab_size] = old_embed_tokens
    new_embed_tokens[vocab_size:] = old_embed_tokens[-1]  # initialize with <mask>
    long_model.encoder.embed_tokens.weight.data = new_embed_tokens
    long_model.encoder.layernorm_embedding.load_state_dict(bart.encoder.layernorm_embedding.state_dict())

    # 2. attention layers
    # strict=False tolerates parameter-name differences between the two
    # encoders; report them instead of dropping them silently.
    encoder_load = long_model.encoder.layers.load_state_dict(bart.encoder.layers.state_dict(), strict=False)
    if encoder_load.missing_keys or encoder_load.unexpected_keys:
        print('encoder layers: missing keys', encoder_load.missing_keys)
        print('encoder layers: unexpected keys', encoder_load.unexpected_keys)

    # 3. embed_positions, longer
    old_pos_embed = bart.encoder.embed_positions.weight
    pos_limit, _ = old_pos_embed.shape
    new_pos_limit, embed_dim = long_model.encoder.embed_positions.weight.shape
    new_pos_embed = old_pos_embed.new_empty(new_pos_limit, embed_dim)
    # new_empty returns uninitialized memory and the tiling below starts at
    # row 2, so the first two rows must be filled explicitly; reuse BART's.
    # (Presumably these rows are the padding offset of the learned positions - confirm.)
    new_pos_embed[:2] = old_pos_embed[:2]
    # Tile BART's learned positions (rows 2 onward) until the longer table is
    # full; the last tile is truncated when the sizes do not divide evenly.
    step = pos_limit - 2
    for start in range(2, new_pos_limit, step):
        end = min(start + step, new_pos_limit)
        new_pos_embed[start:end] = old_pos_embed[2:2 + (end - start)]
    long_model.encoder.embed_positions.weight.data = new_pos_embed

    ##### decoder stuff #####

    # 1. embed_tokens and layernorm_embedding
    # The decoder token embedding is expected to be the same Parameter as the
    # encoder's, in which case the assignment above already updated it. If it
    # turns out to be a separate Parameter, give it the same extended matrix.
    if long_model.decoder.embed_tokens.weight is not long_model.encoder.embed_tokens.weight:
        long_model.decoder.embed_tokens.weight.data = new_embed_tokens.clone()
    long_model.decoder.layernorm_embedding.load_state_dict(bart.decoder.layernorm_embedding.state_dict())

    # 2. embed_positions (copied unchanged from BART)
    long_model.decoder.embed_positions.load_state_dict(bart.decoder.embed_positions.state_dict())

    # 3. attention layers
    long_model.decoder.layers.load_state_dict(bart.decoder.layers.state_dict(), strict=True)

    # 4. output_projection
    # no need to copy as they are tied with encoder's embeds

# Write the dictionary and the updated checkpoint next to each other.
save_path = '/data/home/xwhan/fairseq-py/checkpoints/bart.long.pretrain.block'
os.makedirs(save_path, exist_ok=True)  # the output directory may not exist yet
dictionary.save(os.path.join(save_path, 'dict.txt'))
state['args'] = model_args
state['model'] = long_model.state_dict()
torch.save(state, os.path.join(save_path, 'model.pt'))

old embedding matrix size 50265
50351
