### Init

In [1]:
import torch
from typing import Any, Dict, Iterable, List, Optional, Tuple

In [2]:
# Lower bound on the length of a generated summary.
# NOTE(review): not referenced in the cells shown here; presumably passed to
# generation later (e.g. as a minimum length) — confirm usage and units.
SUMMARY_MIN_LEN: int = 5

In [3]:
from transformers import BartForConditionalGeneration, BartTokenizer

# Single source of truth for the checkpoint: the model and its tokenizer must
# come from the same checkpoint, so never spell the name twice.
MODEL_CHECKPOINT = 'facebook/bart-base'

model = BartForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT)
tokenizer = BartTokenizer.from_pretrained(MODEL_CHECKPOINT)

In [4]:
# Sanity check: BART is an encoder-decoder (seq2seq) model, which is why the
# cells below rebuild the encoder-decoder preparation steps used for generation.
model.config.is_encoder_decoder

True

In [5]:
# Batch-encode two sentences of different length.
# `tokenizer.prepare_seq2seq_batch` is deprecated in transformers; calling the
# tokenizer directly with padding='longest' and truncation=True reproduces the
# defaults that helper used. Batch-specific names avoid clobbering the
# single-sentence `input_text` / `input_ids` defined in the next cell.
batch_texts = ['I am Marco and I am experimenting.', 'Yes it is easy!']
batch_input_ids = tokenizer(
    batch_texts, padding='longest', truncation=True, return_tensors='pt'
)['input_ids']
print(batch_input_ids)
# The shorter sentence is right-padded with <pad> (id 1) to the longest length.
for row_ids in batch_input_ids:
    print(tokenizer.convert_ids_to_tokens(row_ids))

tensor([[    0,   100,   524, 10425,     8,    38,   524, 26038,     4,     2],
        [    0,  9904,    24,    16,  1365,   328,     2,     1,     1,     1]])
['<s>', 'I', 'Ġam', 'ĠMarco', 'Ġand', 'ĠI', 'Ġam', 'Ġexperimenting', '.', '</s>']
['<s>', 'Yes', 'Ġit', 'Ġis', 'Ġeasy', '!', '</s>', '<pad>', '<pad>', '<pad>']


In [6]:
# Encode one sentence on its own: nothing to pad, so every position is a real token.
input_text = 'I am Marco and I am experimenting.'
input_ids = tokenizer(input_text, return_tensors='pt')['input_ids']
print(input_ids)
print(tokenizer.convert_ids_to_tokens(input_ids[0]))

tensor([[    0,   100,   524, 10425,     8,    38,   524, 26038,     4,     2]])
['<s>', 'I', 'Ġam', 'ĠMarco', 'Ġand', 'ĠI', 'Ġam', 'Ġexperimenting', '.', '</s>']


### Prepare the attention mask for the encoder

In [7]:
def _prepare_attention_mask_for_generation(
    input_ids: torch.Tensor, pad_token_id: int, eos_token_id: int
) -> torch.LongTensor:
    is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
    is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
        (eos_token_id is not None) and (pad_token_id != eos_token_id)
    )
    if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
        return input_ids.ne(pad_token_id).long()
    return input_ids.new_ones(input_ids.shape)

In [8]:
# Mask padding in the single encoded sentence using the model's configured special ids.
attention_mask = _prepare_attention_mask_for_generation(
    input_ids=input_ids,
    pad_token_id=model.config.pad_token_id,
    eos_token_id=model.config.eos_token_id,
)

In [9]:
# Ids and their mask, one per line; the mask is all ones because nothing is padded.
print(input_ids, attention_mask, sep='\n')

tensor([[    0,   100,   524, 10425,     8,    38,   524, 26038,     4,     2]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])


### Run the encoder once to get `encoder_outputs` (`generate` stores these in `model_kwargs`; here the helper returns them directly)

In [10]:
def _prepare_encoder_decoder_kwargs_for_generation(
    input_ids: torch.LongTensor, model_kwargs
):
    # retrieve encoder hidden states
    encoder = model.get_encoder()
    encoder_kwargs = {
        argument: value for argument, value in model_kwargs.items() if not argument.startswith("decoder_")
    }
    encoder_outputs = encoder(input_ids, return_dict=True, **encoder_kwargs)
    return encoder_outputs


In [11]:
# Forward the attention mask built above to the encoder, as `generate` does via
# model_kwargs. Without it the encoder would attend to <pad> positions whenever
# the input batch is padded (here the mask is all ones, so outputs are unchanged).
encoder_outputs = _prepare_encoder_decoder_kwargs_for_generation(
    input_ids,
    {'attention_mask': attention_mask},
)

In [12]:
# Vocabulary size followed by a blank line, then the encoder module as the cell output.
print('vocab size:', model.config.vocab_size, end='\n\n')
model.get_encoder()

vocab size: 50265



BartEncoder(
  (embed_tokens): Embedding(50265, 768, padding_idx=1)
  (embed_positions): BartLearnedPositionalEmbedding(1026, 768, padding_idx=1)
  (layers): ModuleList(
    (0): BartEncoderLayer(
      (self_attn): BartAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (1): BartEncoderLayer(
      (self_attn): BartAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=76

In [13]:
# Inspect the encoder output: one hidden vector of size d_model per input token.
print(encoder_outputs.keys())
print()
print('input_ids shape:', input_ids.size())
print('model dimension:', model.config.d_model)
print('encoder output shape:', encoder_outputs['last_hidden_state'].size())

# Pin the invariant the printout illustrates so a re-run with other inputs fails loudly.
assert encoder_outputs['last_hidden_state'].shape == (*input_ids.shape, model.config.d_model)

odict_keys(['last_hidden_state'])

input_ids shape: torch.Size([1, 10])
model dimension: 768
encoder output shape: torch.Size([1, 10, 768])


### Build the initial `decoder_input_ids`

In [14]:
def _get_decoder_start_token_id(decoder_start_token_id: Optional[int] = None, bos_token_id: Optional[int] = None) -> int:
    """Resolve the token id the decoder starts generating from.

    Candidates are tried in order and the first one that is not None wins:
      1. the `decoder_start_token_id` argument, else `model.config.decoder_start_token_id`;
      2. `model.config.decoder.decoder_start_token_id` (only if the config has a `decoder` sub-config);
      3. the `bos_token_id` argument, else `model.config.bos_token_id`;
      4. `model.config.decoder.bos_token_id` (same condition as 2).

    Reads the notebook-global `model`.

    Raises:
        ValueError: if every candidate is None.
    """
    config = model.config
    if decoder_start_token_id is None:
        decoder_start_token_id = config.decoder_start_token_id
    if bos_token_id is None:
        bos_token_id = config.bos_token_id

    # getattr on a missing/None sub-config simply yields None, which is skipped below.
    sub_config = getattr(config, "decoder", None)
    candidates = (
        decoder_start_token_id,
        getattr(sub_config, "decoder_start_token_id", None),
        bos_token_id,
        getattr(sub_config, "bos_token_id", None),
    )
    for token_id in candidates:
        if token_id is not None:
            return token_id

    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
    )

In [15]:
def _prepare_decoder_input_ids_for_generation(
    input_ids: torch.LongTensor,
    decoder_start_token_id: Optional[int] = None,
    bos_token_id: Optional[int] = None,
) -> torch.LongTensor:
    """Seed the decoder with a single start token for every sequence in the batch.

    The start id is resolved by `_get_decoder_start_token_id` from the given
    arguments and the model config.

    Returns:
        A (batch_size, 1) tensor with `input_ids`' dtype and device, filled with
        the start id.
    """
    start_id = _get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
    batch_size = input_ids.shape[0]
    # new_full inherits dtype and device from input_ids.
    return input_ids.new_full((batch_size, 1), start_id)

In [16]:
# No explicit start id is passed, so a config-level decoder_start_token_id takes
# precedence over bos_token_id (here that yields id 2, `</s>`, not `<s>`).
decoder_input_ids = _prepare_decoder_input_ids_for_generation(
    input_ids=input_ids,
    decoder_start_token_id=None,
    bos_token_id=model.config.bos_token_id,
)

In [17]:
# Encoder input, shown again for comparison with the one-token decoder input in the next cell.
input_ids

tensor([[    0,   100,   524, 10425,     8,    38,   524, 26038,     4,     2]])

In [18]:
# Convert the first (only) sequence, as the earlier cells do with `input_ids[0]`.
# Passing the 2-D tensor itself only happened to work because each row holds a
# single id; longer rows would fail, and batched rows would be silently flattened.
print(decoder_input_ids)
print(tokenizer.convert_ids_to_tokens(decoder_input_ids[0]))
decoder_input_ids.size()

tensor([[2]])
['</s>']


torch.Size([1, 1])