In [None]:
# first play with examples from the pytorch-transformers repo, try to 
#  prototype a "step" decoding method

In [1]:
import transformers
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

from transformers import modeling_utils
from torch.nn import functional as F

In [2]:
# summarization

# see ``examples/summarization/bart/evaluate_cnn.py`` for a longer example
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
tokenizer = BartTokenizer.from_pretrained('bart-large-cnn')

#ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
#inputs = tokenizer.batch_encode_plus([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
#summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
#print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])


In [3]:
model.config

BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_final_layer_norm": false,
  "architectures": null,
  "attention_dropout": 0.0,
  "bad_words_ids": null,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "do_sample": false,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "finetuning_task": null,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_decoder": false,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "length_penalty": 2.0,
  "max_length": 142,
  "max_position_embeddings": 1024,
  "min_length": 56,
  "model_type": "bart",
  "no_r

In [4]:
def summarize(text, max_length=50):
    inputs = tokenizer.batch_encode_plus([text], max_length=1024, return_tensors='pt')
    # Generate Summary
#     summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=max_length, early_stopping=True)
    summary_ids = model.generate_2(inputs['input_ids'], num_beams=4, max_length=max_length, early_stopping=True)
    return [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]

In [5]:
from collections import OrderedDict
import torch


def initialize_generation(
    model,
    input_ids=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    attention_mask=None,
    decoder_start_token_id=None,
):
    # We cannot generate if the model does not have a LM head
    if model.get_output_embeddings() is None:
        raise AttributeError(
            "You tried to generate sequences with a model that does not have a LM Head."
            "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
        )

    max_length = max_length if max_length is not None else model.config.max_length
    min_length = min_length if min_length is not None else model.config.min_length
    do_sample = do_sample if do_sample is not None else model.config.do_sample
    early_stopping = early_stopping if early_stopping is not None else model.config.early_stopping
    num_beams = num_beams if num_beams is not None else model.config.num_beams
    temperature = temperature if temperature is not None else model.config.temperature
    top_k = top_k if top_k is not None else model.config.top_k
    top_p = top_p if top_p is not None else model.config.top_p
    repetition_penalty = repetition_penalty if repetition_penalty is not None else model.config.repetition_penalty
    bos_token_id = bos_token_id if bos_token_id is not None else model.config.bos_token_id
    pad_token_id = pad_token_id if pad_token_id is not None else model.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else model.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else model.config.length_penalty
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else model.config.no_repeat_ngram_size
    )
    bad_words_ids = bad_words_ids if bad_words_ids is not None else model.config.bad_words_ids
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else model.config.num_return_sequences
    )
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else model.config.decoder_start_token_id
    )

    if input_ids is not None:
        batch_size = input_ids.shape[0]  # overriden by the input batch_size
    else:
        batch_size = 1

    assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
    assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
    assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
    assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
    assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
    assert temperature > 0, "`temperature` should be strictly positive."
    assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
    assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
    assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
    assert input_ids is not None or (
        isinstance(bos_token_id, int) and bos_token_id >= 0
    ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
    assert pad_token_id is None or (
        isinstance(pad_token_id, int) and (pad_token_id >= 0)
    ), "`pad_token_id` should be a positive integer."
    assert (eos_token_id is None) or (
        isinstance(eos_token_id, int) and (eos_token_id >= 0)
    ), "`eos_token_id` should be a positive integer."
    assert length_penalty > 0, "`length_penalty` should be strictly positive."
    assert (
        isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
    ), "`no_repeat_ngram_size` should be a positive integer."
    assert (
        isinstance(num_return_sequences, int) and num_return_sequences > 0
    ), "`num_return_sequences` should be a strictly positive integer."
    assert (
        bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
    ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

    if input_ids is None:
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
            "you should either supply a context to complete as `input_ids` input "
            "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
        )
        input_ids = torch.full(
            (batch_size, 1), bos_token_id, dtype=torch.long, device=next(model.parameters()).device,
        )
    else:
        assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

    # not allow to duplicate outputs when greedy decoding
    if do_sample is False:
        if num_beams == 1:
            # no_beam_search greedy generation conditions
            assert (
                num_return_sequences == 1
            ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

        else:
            # beam_search greedy generation conditions
            assert (
                num_beams >= num_return_sequences
            ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

    # create attention mask if necessary
    # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
    if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
        attention_mask = input_ids.ne(pad_token_id).long()
    elif attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    # set pad_token_id to eos_token_id if not set. Important that this is done after
    # attention_mask is created
    if pad_token_id is None and eos_token_id is not None:
        logger.warning(
            "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
        )
        pad_token_id = eos_token_id

    # current position and vocab size
    vocab_size = model.config.vocab_size

    # set effective batch size and effective batch multiplier according to do_sample
    if do_sample:
        effective_batch_size = batch_size * num_return_sequences
        effective_batch_mult = num_return_sequences
    else:
        effective_batch_size = batch_size
        effective_batch_mult = 1

    if model.config.is_encoder_decoder:
        if decoder_start_token_id is None:
            decoder_start_token_id = bos_token_id

        assert (
            decoder_start_token_id is not None
        ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
        assert hasattr(model, "get_encoder"), "{} should have a 'get_encoder' function defined".format(model)
        assert callable(model.get_encoder), "{} should be a method".format(model.get_encoder)

        # get encoder and store encoder outputs
        encoder = model.get_encoder()

        encoder_outputs = encoder(input_ids, attention_mask=attention_mask)

    # Expand input ids if num_beams > 1 or num_return_sequences > 1
    if num_return_sequences > 1 or num_beams > 1:
        input_ids_len = input_ids.shape[-1]
        input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
        attention_mask = attention_mask.unsqueeze(1).expand(
            batch_size, effective_batch_mult * num_beams, input_ids_len
        )

        input_ids = input_ids.contiguous().view(
            effective_batch_size * num_beams, input_ids_len
        )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
        attention_mask = attention_mask.contiguous().view(
            effective_batch_size * num_beams, input_ids_len
        )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

    if model.config.is_encoder_decoder:
        # create empty decoder_input_ids
        input_ids = torch.full(
            (effective_batch_size * num_beams, 1),
            decoder_start_token_id,
            dtype=torch.long,
            device=next(model.parameters()).device,
        )
        cur_len = 1

        assert (
            batch_size == encoder_outputs[0].shape[0]
        ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

        # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
        expanded_batch_idxs = (
            torch.arange(batch_size)
            .view(-1, 1)
            .repeat(1, num_beams * effective_batch_mult)
            .view(-1)
            .to(input_ids.device)
        )
        # expand encoder_outputs
        encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])

    else:
        encoder_outputs = None
        cur_len = input_ids.shape[-1]

    # Chris: return the outputs needed for model._generate_beam_search
    return OrderedDict([
        ('model', model),
        ('input_ids', input_ids),
        ('cur_len', cur_len),
        ('max_length', max_length),
        ('min_length', min_length),
        ('do_sample', do_sample),
        ('early_stopping', early_stopping),
        ('temperature', temperature),
        ('top_k', top_k),
        ('top_p', top_p),
        ('repetition_penalty', repetition_penalty),
        ('no_repeat_ngram_size', no_repeat_ngram_size),
        ('bad_words_ids', bad_words_ids),
        ('bos_token_id', bos_token_id),
        ('pad_token_id', pad_token_id),
        ('decoder_start_token_id', decoder_start_token_id),
        ('eos_token_id', eos_token_id),
        ('batch_size', effective_batch_size),
        ('num_return_sequences', num_return_sequences),
        ('length_penalty', length_penalty),
        ('num_beams', num_beams),
        ('vocab_size', vocab_size),
        ('encoder_outputs', encoder_outputs),
        ('attention_mask', attention_mask)
    ])
    
    # Chris: return the outputs needed for model._generate_no_beam_search
#     return OrderedDict([
#         ('model', model),
#         ('input_ids', input_ids),
#         ('cur_len', cur_len),
#         ('max_length', max_length),
#         ('min_length', min_length),
#         ('do_sample', do_sample),
#         ('temperature', temperature),
#         ('top_k', top_k),
#         ('top_p', top_p),
#         ('repetition_penalty', repetition_penalty),
#         ('no_repeat_ngram_size', no_repeat_ngram_size),
#         ('bad_words_ids', bad_words_ids),
#         ('bos_token_id', bos_token_id),
#         ('pad_token_id', pad_token_id),
#         ('decoder_start_token_id', decoder_start_token_id),
#         ('eos_token_id', eos_token_id),
#         ('batch_size', effective_batch_size),
#         ('encoder_outputs', encoder_outputs),
#         ('attention_mask', attention_mask)
#     ])


In [6]:
# TODO: remember critical assumption that all models use the same output space, we need to use this during 
#  ensembling
# TODO: does model instance hold any state while decoding? I.e. is model's self.* holding any state while we are 
#  inside the decoding loop?
# TODO: remove decoding length arg 

# WORKING: wrap tokenizer and model together so that we can pass through (text, tokenizer, model)
def get_initial_decoding_state(text, model, tokenizer, decoding_hyperparams):
    """
    Get the state needed to start decoding from an instance
    """
    # convert text to tensor
    inputs = tokenizer.batch_encode_plus(
        [text],
        max_length=decoding_hyperparams['max_length'],
        return_tensors='pt'
    )
    input_ids = inputs['input_ids']
    
    return initialize_generation(
        model, input_ids,
        **decoding_hyperparams
    )


In [7]:

# this is def step() for model._generate_no_beam_search
def greedy_step(state):
    model_inputs = state['model'].prepare_inputs_for_generation(
        state['input_ids'],
        past=state['past'],
        attention_mask=state['attention_mask']
    )

    outputs = state['model'](**model_inputs)
    next_token_logits = outputs[0][:, -1, :]

    # if model has past, then set the past variable to speed up decoding
    if state['model']._do_output_past(outputs):
        state['past'] = outputs[1]
        
    # now update next_token_logits using various heuristics
    
    # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
    if state['repetition_penalty'] != 1.0:
        # Chris: note in-place modification side-effect
        state['model'].enforce_repetition_penalty_(
            next_token_logits,
            state['batch_size'], 1, state['input_ids'], state['repetition_penalty'])
    
    # TODO: WORKING: cur_len will need to be updated
    if state['no_repeat_ngram_size'] > 0:
        # calculate a list of banned tokens to prevent repetitively generating the same ngrams
        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        
        banned_tokens = modeling_utils.calc_banned_ngram_tokens(
            state['input_ids'],
            state['batch_size'],
            state['no_repeat_ngram_size'],
            state['cur_len'])
        for batch_idx in range(state['batch_size']):
            next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

    if state['bad_words_ids'] is not None:
        # calculate a list of banned tokens according to bad words
        banned_tokens = modeling_utils.calc_banned_bad_words_ids(state['input_ids'], state['bad_words_ids'])

        for batch_idx in range(state['batch_size']):
            next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

    # Chris: WORKING: note any next token logic must live outside of the step function
    # Chris: put this into codebase first before proceeding with TDD 
    
    # set eos token prob to zero if min_length is not reached
    if state['eos_token_id'] is not None and state['cur_len'] < state['min_length']:
        next_token_logits[:, state['eos_token_id']] = -float("inf")

    if state['do_sample']:
        # Temperature (higher temperature => more likely to sample low probability tokens)
        if state['temperature'] != 1.0:
            next_token_logits = next_token_logits / state['temperature']
        # Top-p/top-k filtering
        next_token_logits = \
            modeling_utils.top_k_top_p_filtering(
                next_token_logits,
                top_k=state['top_k'],
                top_p=state['top_p']
            )
        # Sample
        probs = F.softmax(next_token_logits, dim=-1)
        # Chris: TODO: note for ensembling all next token logic 
        #  needs to move outside of this function
        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
    else:
        # Greedy decoding
        # Chris: TODO: note for ensembling all next token logic needs to move outside of this function
        next_token = torch.argmax(next_token_logits, dim=-1)

    # Chris: TODO: update unfinished_sents in state 
    # update generations and finished sentences
    if state['eos_token_id'] is not None:
        # pad finished sentences if eos_token_id exist
        tokens_to_add = next_token * state['unfinished_sents'] + (state['pad_token_id']) * (1 - state['unfinished_sents'])
    else:
        tokens_to_add = next_token

    # Chris: concat whatever was generated to input ids
    # Chris: TODO: this must happen outside of individual model's step functions
    state['input_ids'] = torch.cat([state['input_ids'], tokens_to_add.unsqueeze(-1)], dim=-1)

    if state['eos_token_id'] is not None:
        eos_in_sents = tokens_to_add == state['eos_token_id']
        # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
        is_sents_unfinished_and_token_to_add_is_eos = state['unfinished_sents'].mul(eos_in_sents.long()).bool()
        state['sent_lengths'].masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, state['cur_len'] + 1)
        # unfinished_sents is set to zero if eos in sentence
        state['unfinished_sents'].mul_((~eos_in_sents).long())

    # stop when there is a </s> in each sentence, or if we exceed the maximal length
    if state['unfinished_sents'].max() == 0:
        return state

    # extend attention_mask for new generated input if only decoder
    if state['model'].config.is_encoder_decoder is False:
        state['attention_mask'] = torch.cat(
            [state['attention_mask'],
             state['attention_mask'].new_ones((state['attention_mask'].shape[0], 1))],
            dim=-1
        )

    state['cur_len'] = state['cur_len'] + 1
    
    return state

In [8]:
def beam_search_step(state):
    model_inputs = state['model'].prepare_inputs_for_generation(
        state['input_ids'],
        past=state['past'],
        attention_mask=state['attention_mask'])
    outputs = state['model'](**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
    next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

    # if model has past, then set the past variable to speed up decoding
    if state['model']._do_output_past(outputs):
        state['past'] = outputs[1]

    # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
    if state['repetition_penalty'] != 1.0:
        state['model'].enforce_repetition_penalty_(
            next_token_logits,
            state['batch_size'],
            state['num_beams'],
            state['input_ids'],
            state['repetition_penalty']
        )

    if state['temperature'] != 1.0:
        next_token_logits = next_token_logits / state['temperature']

    scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
    if state['model'].config.is_encoder_decoder and state['do_sample'] is False:
        # TODO (PVP) still a bit hacky here - there might be a better solutino
        scores = state['model'].prepare_scores_for_generation(
            scores,
            cur_len=state['cur_len'],
            max_length=state['max_length'])

    # set eos token prob to zero if min_length is not reached
    if state['eos_token_id'] is not None and state['cur_len'] < state['min_length']:
        scores[:, state['eos_token_id']] = -float("inf")

    if state['no_repeat_ngram_size'] > 0:
        # calculate a list of banned tokens to prevent repetitively generating the same ngrams
        num_batch_hypotheses = state['batch_size'] * state['num_beams']
        # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
        banned_batch_tokens = modeling_utils.calc_banned_ngram_tokens(
            state['input_ids'],
            num_batch_hypotheses,
            state['no_repeat_ngram_size'],
            state['cur_len']
        )
        for i, banned_tokens in enumerate(banned_batch_tokens):
            scores[i, banned_tokens] = -float("inf")

    if state['bad_words_ids'] is not None:
        # calculate a list of banned tokens according to bad words
        banned_tokens = modeling_utils.calc_banned_bad_words_ids(
            state['input_ids'],
            state['bad_words_ids']
        )

        for i, banned_tokens in enumerate(banned_tokens):
            scores[i, banned_tokens] = -float("inf")

    assert scores.shape == (state['batch_size'] * state['num_beams'], state['vocab_size']), "Shapes of scores: {} != {}".format(
        scores.shape, (state['batch_size'] * state['num_beams'], state['vocab_size'])
    )
    
    # Chris: ok, now we have the scores from this (model, text) pair, let's return them and ensemble before 
    #  continuing. 
    # Chris: let's create a wrapper that holds pairs of model, text
    # Chris: let's create a new type of hypothesis which stores additional metadata in the beam
    # Chris: same structure as beam, but stores arbitrary meta-data in each cell -- what is the "timestamp metatdata?"

    if state['do_sample']:
        _scores = scores + state['beam_scores'][:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
        # Top-p/top-k filtering
        # Chris: note hard-coded `min_tokens_to_keep`
        _scores = modeling_utils.top_k_top_p_filtering(
            _scores, top_k=state['top_k'], top_p=state['top_p'], min_tokens_to_keep=2
        )  # (batch_size * num_beams, vocab_size)
        # re-organize to group the beam together to sample from all beam_idxs
        _scores = _scores.contiguous().view(
            state['batch_size'], state['num_beams'] * state['vocab_size']
        )  # (batch_size, num_beams * vocab_size)

        # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
        probs = F.softmax(_scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=2 * state['num_beams'])  # (batch_size, num_beams * 2)
        # Compute next scores
        next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
        # sort the sampled vector to make sure that the first num_beams samples are the best
        next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
        next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)

    else:
        next_scores = scores + state['beam_scores'][:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

        # re-organize to group the beam together (we are keeping top hypotheses across beams)
        next_scores = next_scores.view(
            state['batch_size'], state['num_beams'] * state['vocab_size']
        )  # (batch_size, num_beams * vocab_size)

        next_scores, next_tokens = \
            torch.topk(
                next_scores,
                2 * state['num_beams'],
                dim=1,
                largest=True,
                sorted=True
            )

    assert next_scores.size() == next_tokens.size() == (state['batch_size'], 2 * state['num_beams'])

    # next batch beam content
    next_batch_beam = []

    # for each sentence
    for batch_idx in range(state['batch_size']):

        # if we are done with this sentence
        if state['done'][batch_idx]:
            assert (
                len(state['generated_hyps'][batch_idx]) >= state['num_beams']
            ), "Batch can only be done if at least {} beams have been generated".format(state['num_beams'])
            assert (
                state['eos_token_id'] is not None and state['pad_token_id'] is not None
            ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
            next_batch_beam.extend([(0, state['pad_token_id'], 0)] * state['num_beams'])  # pad the batch
            continue

        # next sentence beam content
        next_sent_beam = []

        # next tokens for this sentence from each beam
        for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
            zip(next_tokens[batch_idx], next_scores[batch_idx])
        ):
            # get beam and token IDs
            beam_id = beam_token_id // state['vocab_size']
            token_id = beam_token_id % state['vocab_size']

            effective_beam_id = batch_idx * state['num_beams'] + beam_id
            # add to generated hypotheses if end of sentence or last iteration
            if (state['eos_token_id'] is not None) and (token_id.item() == state['eos_token_id']):
                # if beam_token does not belong to top num_beams tokens, it should not be added
                is_beam_token_worse_than_top_num_beams = beam_token_rank >= state['num_beams']
                if is_beam_token_worse_than_top_num_beams:
                    continue
                # update beam hypotheses obj with finished hypothesis and score
                state['generated_hyps'][batch_idx].add(
                    state['input_ids'][effective_beam_id].clone(), beam_token_score.item(),
                )
            else:
                # add next predicted token if it is not eos_token
                next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

            # the beam for next step is now full
            if len(next_sent_beam) == state['num_beams']:
                break

        # Check if we're done so that we can save a pad step if all(done)
        state['done'][batch_idx] = state['done'][batch_idx] or state['generated_hyps'][batch_idx].is_done(
            next_scores[batch_idx].max().item(), cur_len=state['cur_len']
        )

        # update next beam content
        assert len(next_sent_beam) == state['num_beams'], "Beam should always be full after loop above"
        next_batch_beam.extend(next_sent_beam)
        assert len(next_batch_beam) == state['num_beams'] * (batch_idx + 1)

    # stop if are done with every sentence
    if all(state['done']):
        return state

    # sanity check / prepare next timestep
    assert len(next_batch_beam) == state['batch_size'] * state['num_beams']
    state['beam_scores'] = state['beam_scores'].new([x[0] for x in next_batch_beam])
    
    # re-order batch
    beam_tokens = state['input_ids'].new([x[1] for x in next_batch_beam])
    beam_idx = state['input_ids'].new([x[2] for x in next_batch_beam])

    state['input_ids'] = state['input_ids'][beam_idx, :]
    state['input_ids'] = torch.cat([state['input_ids'], beam_tokens.unsqueeze(1)], dim=-1)
    # re-order internal states
    if state['past'] is not None:
        state['past'] = state['model']._reorder_cache(state['past'], beam_idx)

    # extend attention_mask for new generated input if only decoder
    if state['model'].config.is_encoder_decoder is False:
        state['attention_mask'] = torch.cat(
            [
                state['attention_mask'],
                state['attention_mask'].new_ones((state['attention_mask'].shape[0], 1))
            ],
            dim=-1
        )

    # update current length
    state['cur_len'] = state['cur_len'] + 1
    return state


In [12]:
# generate yes beam search
# Note for BART summarization in transformers repo, beam search performs much better
#  than no beam search, but even their beam search with num_beams=1 is better, implying that something
#  is broken in the _generate_no_beam_search function

decoding_hyperparams = {
    'max_length': 75,
    'num_beams': 2
}

test_news_article = 'New Zealand says it has stopped community transmission of Covid-19, '\
 'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday ' \
 '- Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned ' \
 'against complacency, saying it does not mean a total end to new coronavirus cases. ' \
 'The news comes hours before New Zealand is set to move out of its toughest level of social restrictions. ' \
 'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. ' \
 'Most people will still be required to remain at home at all times and avoid all social interactions.'


# set up state 
decoder_state = get_initial_decoding_state(
    text=test_news_article,
    model=model,
    tokenizer=tokenizer,
    decoding_hyperparams=decoding_hyperparams
)

# generated hypotheses
decoder_state['generated_hyps'] = [
    modeling_utils.BeamHypotheses(
        decoder_state['num_beams'],
        decoder_state['max_length'],
        decoder_state['length_penalty'],
        early_stopping=decoder_state['early_stopping'])
    for _ in range(decoder_state['batch_size'])
]

# scores for each sentence in the beam
decoder_state['beam_scores'] = \
    torch.zeros((decoder_state['batch_size'], decoder_state['num_beams']),
                dtype=torch.float,
                device=decoder_state['input_ids'].device)

# for greedy decoding it is made sure that only tokens of the first beam are considered
#  to avoid sampling the exact same tokens three times
if decoder_state['do_sample'] is False:
    decoder_state['beam_scores'][:, 1:] = -1e9
decoder_state['beam_scores'] = decoder_state['beam_scores'].view(-1)  # shape (batch_size * num_beams,)

# cache compute states
decoder_state['past'] = decoder_state['encoder_outputs']  # defined for encoder-decoder models, None for decoder-only models

# done sentences
decoder_state['done'] = [False for _ in range(decoder_state['batch_size'])]

# then we wish to step through decoding
# TODO: note beam logic needs to be outside of step function 

# PLACEHOLDER -- run beam_search_step function
# ok now we are ready to start stepping
# step and decode with tokenizer at each step to visualize decoding progress
for step_idx in range(decoding_hyperparams['max_length']):
    print(f'STEP: {step_idx}')
    print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in decoder_state['input_ids']])
    decoder_state = beam_search_step(decoder_state)
    print()



STEP: 0
['', '']

STEP: 1
['', '']

STEP: 2
['New', '']

STEP: 3
['New Zealand', 'New']

STEP: 4
['New Zealand says', 'New Zealand has']

STEP: 5
['New Zealand says it', 'New Zealand has stopped']

STEP: 6
['New Zealand says it has', 'New Zealand has stopped community']

STEP: 7
['New Zealand says it has stopped', 'New Zealand has stopped community transmission']

STEP: 8
['New Zealand says it has stopped community', 'New Zealand has stopped community transmission of']

STEP: 9
['New Zealand says it has stopped community transmission', 'New Zealand has stopped community transmission of Cov']

STEP: 10
['New Zealand says it has stopped community transmission of', 'New Zealand has stopped community transmission of Covid']

STEP: 11
['New Zealand says it has stopped community transmission of Cov', 'New Zealand has stopped community transmission of Covid-']

STEP: 12
['New Zealand says it has stopped community transmission of Covid', 'New Zealand has stopped community transmission of Covid


STEP: 42
['New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned against complacency, saying it', 'New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. Officials have warned against complacency, saying it does']

STEP: 43
['New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned against complacency, saying it does', 'New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. Officials have warned against complacency, saying it does not']

STEP: 44
['New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But offici


STEP: 60
['New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned against complacency, saying it does not mean a total end to new coronavirus cases. The', 'New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned against complacency, saying it does not mean a total end to new coronavirus cases..']

STEP: 61
['New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned against complacency, saying it does not mean a total end to new coronavirus cases. The', 'New Zealand says it has stopped community transmission of Covid-19. Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned against complacency, saying it does not me

In [11]:
# TODO: The following logic runs AFTER the beam search while loop finishes
# finalize all open beam hypotheses and end to generated hypotheses
for batch_idx in range(batch_size):
    if done[batch_idx]:
        continue

    # test that beam scores match previously calculated scores if not eos and batch_idx not done
    if eos_token_id is not None and all(
        (token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
    ):
        assert torch.all(
            next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
        ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
            next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
        )

    # need to add best num_beams hypotheses to generated hyps
    for beam_id in range(num_beams):
        effective_beam_id = batch_idx * num_beams + beam_id
        final_score = beam_scores[effective_beam_id].item()
        final_tokens = input_ids[effective_beam_id]
        generated_hyps[batch_idx].add(final_tokens, final_score)

# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

# select the best hypotheses
sent_lengths = input_ids.new(output_batch_size)
best = []

# retrieve best hypotheses
for i, hypotheses in enumerate(generated_hyps):
    sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
    for j in range(output_num_return_sequences_per_batch):
        effective_batch_idx = output_num_return_sequences_per_batch * i + j
        best_hyp = sorted_hyps.pop()[1]
        sent_lengths[effective_batch_idx] = len(best_hyp)
        best.append(best_hyp)

# shorter batches are filled with pad_token
if sent_lengths.min().item() != sent_lengths.max().item():
    assert pad_token_id is not None, "`Pad_token_id` has to be defined"
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

    # fill with hypothesis and eos_token_id if necessary
    for i, hypo in enumerate(best):
        decoded[i, : sent_lengths[i]] = hypo
        if sent_lengths[i] < max_length:
            decoded[i, sent_lengths[i]] = eos_token_id
else:
    # none of the hypotheses have an eos_token
    assert (len(hypo) == max_length for hypo in best)
    decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

return decoded

NameError: name 'batch_size' is not defined

In [None]:
# generate no beam search

# then we wish to step through decoding
decoder_state = get_initial_decoding_state(
    text=test_news_article,
    model=model,
    tokenizer=tokenizer,
    decoding_hyperparams=decoding_hyperparams
)

# TODO: working here, beam search appears essential to get best performance, greedy is 
#  far worse than num_beams=4
# start decoding, we're going to go step by step, calling a function each time

# length of generated sentences / unfinished sentences
decoder_state['unfinished_sents'] = decoder_state['input_ids'].new(decoder_state['batch_size']).fill_(1)
decoder_state['sent_lengths'] = \
    decoder_state['input_ids'].new(
        decoder_state['batch_size']
    ).fill_(decoder_state['max_length'])

decoder_state['past'] = decoder_state['encoder_outputs']  # defined for encoder-decoder models, None for decoder-only models

# ok now we are ready to start stepping
# step and decode with tokenizer at each step to visualize decoding progress
for step_idx in range(decoding_hyperparams['max_length']):
    print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in decoder_state['input_ids']])
    decoder_state = step(decoder_state)

In [None]:

# once decoding is done, we need to do this:
# TODO: does logic below only matter at batch level?

# if there are different sentences lengths in the batch, some batch items have to be padded
# if sent_lengths.min().item() != sent_lengths.max().item():
#     assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
#     # finished sents are filled with pad_token
#     decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
# else:
#     decoded = input_ids

# for hypo_idx, hypo in enumerate(input_ids):
#     decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

# return decoded


In [None]:
def summarize(text, max_length=50, num_beams=4):
    inputs = tokenizer.batch_encode_plus([text], max_length=1024, return_tensors='pt')
    # Generate Summary
    summary_ids = model.generate(inputs['input_ids'], num_beams=num_beams, max_length=max_length, early_stopping=True)
#     summary_ids = model.generate_2(inputs['input_ids'], num_beams=4, max_length=max_length, early_stopping=True)
    return [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]

In [None]:
summarize(test_news_article, num_beams=1)

In [None]:
# WORKING: what subset of these args are state, what subset is global-ish hyperparameter?
# WORKING: note lack of @torch.no_grad on _generate_beam_search -- could we possibly back-prop here?

In [None]:
def _generate_beam_search(
    self,
    input_ids,
    cur_len,
    max_length,
    min_length,
    do_sample,
    early_stopping,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram_size,
    bad_words_ids,
    bos_token_id,
    pad_token_id,
    eos_token_id,
    decoder_start_token_id,
    batch_size,
    num_return_sequences,
    length_penalty,
    num_beams,
    vocab_size,
    encoder_outputs,
    attention_mask,
):
    """ Generate sequences for each example with beam search.
    """

    # generated hypotheses
    generated_hyps = [
        BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
        for _ in range(batch_size)
    ]

    # scores for each sentence in the beam
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)

    # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
    if do_sample is False:
        beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)

    # cache compute states
    past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

    # done sentences
    done = [False for _ in range(batch_size)]

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
        outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
        next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)

        # if model has past, then set the past variable to speed up decoding
        if self._do_output_past(outputs):
            past = outputs[1]

        # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(
                next_token_logits, batch_size, num_beams, input_ids, repetition_penalty,
            )

        if temperature != 1.0:
            next_token_logits = next_token_logits / temperature

        scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
        if self.config.is_encoder_decoder and do_sample is False:
            # TODO (PVP) still a bit hacky here - there might be a better solutino
            scores = self.prepare_scores_for_generation(scores, cur_len=cur_len, max_length=max_length)

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            scores[:, eos_token_id] = -float("inf")

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            num_batch_hypotheses = batch_size * num_beams
            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
            banned_batch_tokens = calc_banned_ngram_tokens(
                input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
            )
            for i, banned_tokens in enumerate(banned_batch_tokens):
                scores[i, banned_tokens] = -float("inf")

        if bad_words_ids is not None:
            # calculate a list of banned tokens according to bad words
            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

            for i, banned_tokens in enumerate(banned_tokens):
                scores[i, banned_tokens] = -float("inf")

        assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
            scores.shape, (batch_size * num_beams, vocab_size)
        )

        if do_sample:
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)
            # Top-p/top-k filtering
            _scores = top_k_top_p_filtering(
                _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
            )  # (batch_size * num_beams, vocab_size)
            # re-organize to group the beam together to sample from all beam_idxs
            _scores = _scores.contiguous().view(
                batch_size, num_beams * vocab_size
            )  # (batch_size, num_beams * vocab_size)

            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
            probs = F.softmax(_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
            # Compute next scores
            next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)
            # sort the sampled vector to make sure that the first num_beams samples are the best
            next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
            next_tokens = torch.gather(next_tokens, -1, next_scores_indices)  # (batch_size, num_beams * 2)

        else:
            next_scores = scores + beam_scores[:, None].expand_as(scores)  # (batch_size * num_beams, vocab_size)

            # re-organize to group the beam together (we are keeping top hypothesis accross beams)
            next_scores = next_scores.view(
                batch_size, num_beams * vocab_size
            )  # (batch_size, num_beams * vocab_size)

            next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

        assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)

        # next batch beam content
        next_batch_beam = []

        # for each sentence
        for batch_idx in range(batch_size):

            # if we are done with this sentence
            if done[batch_idx]:
                assert (
                    len(generated_hyps[batch_idx]) >= num_beams
                ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
                next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                continue

            # next sentence beam content
            next_sent_beam = []

            # next tokens for this sentence
            for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx])
            ):
                # get beam and token IDs
                beam_id = beam_token_id // vocab_size
                token_id = beam_token_id % vocab_size

                effective_beam_id = batch_idx * num_beams + beam_id
                # add to generated hypotheses if end of sentence or last iteration
                if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    generated_hyps[batch_idx].add(
                        input_ids[effective_beam_id].clone(), beam_token_score.item(),
                    )
                else:
                    # add next predicted token if it is not eos_token
                    next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                # the beam for next step is full
                if len(next_sent_beam) == num_beams:
                    break

            # Check if were done so that we can save a pad step if all(done)
            done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
                next_scores[batch_idx].max().item(), cur_len=cur_len
            )

            # update next beam content
            assert len(next_sent_beam) == num_beams, "Beam should always be full"
            next_batch_beam.extend(next_sent_beam)
            assert len(next_batch_beam) == num_beams * (batch_idx + 1)

        # stop when we are done with each sentence
        if all(done):
            break

        # sanity check / prepare next batch
        assert len(next_batch_beam) == batch_size * num_beams
        beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
        beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
        beam_idx = input_ids.new([x[2] for x in next_batch_beam])

        # re-order batch
        input_ids = input_ids[beam_idx, :]
        input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
        # re-order internal states
        if past is not None:
            past = self._reorder_cache(past, beam_idx)

        # extend attention_mask for new generated input if only decoder
        if self.config.is_encoder_decoder is False:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        # update current length
        cur_len = cur_len + 1

    # finalize all open beam hypotheses and end to generated hypotheses
    for batch_idx in range(batch_size):
        if done[batch_idx]:
            continue

        # test that beam scores match previously calculated scores if not eos and batch_idx not done
        if eos_token_id is not None and all(
            (token_id % vocab_size).item() is not eos_token_id for token_id in next_tokens[batch_idx]
        ):
            assert torch.all(
                next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
            ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
                next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
            )

        # need to add best num_beams hypotheses to generated hyps
        for beam_id in range(num_beams):
            effective_beam_id = batch_idx * num_beams + beam_id
            final_score = beam_scores[effective_beam_id].item()
            final_tokens = input_ids[effective_beam_id]
            generated_hyps[batch_idx].add(final_tokens, final_score)

    # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
    output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
    output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences

    # select the best hypotheses
    sent_lengths = input_ids.new(output_batch_size)
    best = []

    # retrieve best hypotheses
    for i, hypotheses in enumerate(generated_hyps):
        sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
        for j in range(output_num_return_sequences_per_batch):
            effective_batch_idx = output_num_return_sequences_per_batch * i + j
            best_hyp = sorted_hyps.pop()[1]
            sent_lengths[effective_batch_idx] = len(best_hyp)
            best.append(best_hyp)

    # shorter batches are filled with pad_token
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`Pad_token_id` has to be defined"
        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
        decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)

        # fill with hypothesis and eos_token_id if necessary
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
    else:
        # none of the hypotheses have an eos_token
        assert (len(hypo) == max_length for hypo in best)
        decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

    return decoded


In [None]:
# prototype step method on model wrapper
# use setattr (i.e. treat as mixin(?))


@torch.no_grad()
def generate(
    self,
    input_ids=None,
    max_length=None,
    min_length=None,
    do_sample=None,
    early_stopping=None,
    num_beams=None,
    temperature=None,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    bad_words_ids=None,
    bos_token_id=None,
    pad_token_id=None,
    eos_token_id=None,
    length_penalty=None,
    no_repeat_ngram_size=None,
    num_return_sequences=None,
    attention_mask=None,
    decoder_start_token_id=None,
):
    r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, 
    beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.

    Adapted in part from `Facebook's XLM beam search code`_.

    .. _`Facebook's XLM beam search code`:
       https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529


    Parameters:

        input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
            The sequence used as a prompt for the generation. If `None` the method initializes
            it as an empty `torch.LongTensor` of shape `(1,)`.

        max_length: (`optional`) int
            The max length of the sequence to be generated.  Between `min_length` and infinity. Default to 20.

        min_length: (`optional`) int
            The min length of the sequence to be generated.  Between 0 and infinity. Default to 0.

        do_sample: (`optional`) bool
            If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

        early_stopping: (`optional`) bool
            if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.

        num_beams: (`optional`) int
            Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.

        temperature: (`optional`) float
            The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.

        top_k: (`optional`) int
            The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.

        top_p: (`optional`) float
            The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.

        repetition_penalty: (`optional`) float
            The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.

        pad_token_id: (`optional`) int
            Padding token. Default to specicic model pad_token_id or None if it does not exist.

        bos_token_id: (`optional`) int
            BOS token. Defaults to `bos_token_id` as defined in the models config.

        eos_token_id: (`optional`) int
            EOS token. Defaults to `eos_token_id` as defined in the models config.

        length_penalty: (`optional`) float
            Exponential penalty to the length. Default to 1.

        no_repeat_ngram_size: (`optional`) int
            If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
        bad_words_ids: (`optional`) list of lists of int
            `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.

        num_return_sequences: (`optional`) int
            The number of independently computed returned sequences for each element in the batch. Default to 1.

        attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            Defaults to `None`.

        `What are attention masks? <../glossary.html#attention-mask>`__

        decoder_start_token_id=None: (`optional`) int
            If an encoder-decoder model starts decoding with a different token than BOS.
            Defaults to `None` and is changed to `BOS` later.

    Return:

        output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
            sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`

    Examples::

        tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
        outputs = model.generate(max_length=40)  # do greedy decoding
        print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('openai-gpt')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('openai-gpt')    # Download model and configuration from S3 and cache.
        input_context = 'The dog'
        input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
        outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)  # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
        for i in range(3): #  3 output sequences were generated
            print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('distilgpt2')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('distilgpt2')    # Download model and configuration from S3 and cache.
        input_context = 'The dog'
        input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
        outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3)  # 3 generate sequences using by sampling
        for i in range(3): #  3 output sequences were generated
            print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('ctrl')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('ctrl')    # Download model and configuration from S3 and cache.
        input_context = 'Legal My neighbor is'  # "Legal" is one of the control codes for ctrl
        input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
        outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)  # generate sequences
        print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))

        tokenizer = AutoTokenizer.from_pretrained('gpt2')   # Initialize tokenizer
        model = AutoModelWithLMHead.from_pretrained('gpt2')    # Download model and configuration from S3 and cache.
        input_context = 'My cute dog'  # "Legal" is one of the control codes for ctrl
        bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
        input_ids = tokenizer.encode(input_context, return_tensors='pt')  # encode input context
        outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)  # generate sequences without allowing bad_words to be generated
    """

    # We cannot generate if the model does not have a LM head
    if self.get_output_embeddings() is None:
        raise AttributeError(
            "You tried to generate sequences with a model that does not have a LM Head."
            "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
        )

    max_length = max_length if max_length is not None else self.config.max_length
    min_length = min_length if min_length is not None else self.config.min_length
    do_sample = do_sample if do_sample is not None else self.config.do_sample
    early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
    num_beams = num_beams if num_beams is not None else self.config.num_beams
    temperature = temperature if temperature is not None else self.config.temperature
    top_k = top_k if top_k is not None else self.config.top_k
    top_p = top_p if top_p is not None else self.config.top_p
    repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
    no_repeat_ngram_size = (
        no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
    )
    bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
    num_return_sequences = (
        num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
    )
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
    )

    if input_ids is not None:
        batch_size = input_ids.shape[0]  # overriden by the input batch_size
    else:
        batch_size = 1

    assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
    assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
    assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
    assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
    assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
    assert temperature > 0, "`temperature` should be strictly positive."
    assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
    assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
    assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
    assert input_ids is not None or (
        isinstance(bos_token_id, int) and bos_token_id >= 0
    ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
    assert pad_token_id is None or (
        isinstance(pad_token_id, int) and (pad_token_id >= 0)
    ), "`pad_token_id` should be a positive integer."
    assert (eos_token_id is None) or (
        isinstance(eos_token_id, int) and (eos_token_id >= 0)
    ), "`eos_token_id` should be a positive integer."
    assert length_penalty > 0, "`length_penalty` should be strictly positive."
    assert (
        isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
    ), "`no_repeat_ngram_size` should be a positive integer."
    assert (
        isinstance(num_return_sequences, int) and num_return_sequences > 0
    ), "`num_return_sequences` should be a strictly positive integer."
    assert (
        bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
    ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

    if input_ids is None:
        assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
            "you should either supply a context to complete as `input_ids` input "
            "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
        )
        input_ids = torch.full(
            (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
        )
    else:
        assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."

    # not allow to duplicate outputs when greedy decoding
    if do_sample is False:
        if num_beams == 1:
            # no_beam_search greedy generation conditions
            assert (
                num_return_sequences == 1
            ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"

        else:
            # beam_search greedy generation conditions
            assert (
                num_beams >= num_return_sequences
            ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

    # create attention mask if necessary
    # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
    if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
        attention_mask = input_ids.ne(pad_token_id).long()
    elif attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    # set pad_token_id to eos_token_id if not set. Important that this is done after
    # attention_mask is created
    if pad_token_id is None and eos_token_id is not None:
        logger.warning(
            "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
        )
        pad_token_id = eos_token_id

    # current position and vocab size
    vocab_size = self.config.vocab_size

    # set effective batch size and effective batch multiplier according to do_sample
    if do_sample:
        effective_batch_size = batch_size * num_return_sequences
        effective_batch_mult = num_return_sequences
    else:
        effective_batch_size = batch_size
        effective_batch_mult = 1

    if self.config.is_encoder_decoder:
        if decoder_start_token_id is None:
            decoder_start_token_id = bos_token_id

        assert (
            decoder_start_token_id is not None
        ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
        assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
        assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

        # get encoder and store encoder outputs
        encoder = self.get_encoder()

        encoder_outputs = encoder(input_ids, attention_mask=attention_mask)

    # Expand input ids if num_beams > 1 or num_return_sequences > 1
    if num_return_sequences > 1 or num_beams > 1:
        input_ids_len = input_ids.shape[-1]
        input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
        attention_mask = attention_mask.unsqueeze(1).expand(
            batch_size, effective_batch_mult * num_beams, input_ids_len
        )

        input_ids = input_ids.contiguous().view(
            effective_batch_size * num_beams, input_ids_len
        )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
        attention_mask = attention_mask.contiguous().view(
            effective_batch_size * num_beams, input_ids_len
        )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

    if self.config.is_encoder_decoder:
        # create empty decoder_input_ids
        input_ids = torch.full(
            (effective_batch_size * num_beams, 1),
            decoder_start_token_id,
            dtype=torch.long,
            device=next(self.parameters()).device,
        )
        cur_len = 1

        assert (
            batch_size == encoder_outputs[0].shape[0]
        ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "

        # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
        expanded_batch_idxs = (
            torch.arange(batch_size)
            .view(-1, 1)
            .repeat(1, num_beams * effective_batch_mult)
            .view(-1)
            .to(input_ids.device)
        )
        # expand encoder_outputs
        encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])

    else:
        encoder_outputs = None
        cur_len = input_ids.shape[-1]

    if num_beams > 1:
        output = self._generate_beam_search(
            input_ids,
            cur_len=cur_len,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            early_stopping=early_stopping,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            decoder_start_token_id=decoder_start_token_id,
            eos_token_id=eos_token_id,
            batch_size=effective_batch_size,
            num_return_sequences=num_return_sequences,
            length_penalty=length_penalty,
            num_beams=num_beams,
            vocab_size=vocab_size,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
        )
    else:
        output = self._generate_no_beam_search(
            input_ids,
            cur_len=cur_len,
            max_length=max_length,
            min_length=min_length,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            bad_words_ids=bad_words_ids,
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            decoder_start_token_id=decoder_start_token_id,
            eos_token_id=eos_token_id,
            batch_size=effective_batch_size,
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
        )

    return output


In [None]:


def _generate_no_beam_search(
    self,
    input_ids,
    cur_len,
    max_length,
    min_length,
    do_sample,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram_size,
    bad_words_ids,
    bos_token_id,
    pad_token_id,
    eos_token_id,
    decoder_start_token_id,
    batch_size,
    encoder_outputs,
    attention_mask,
):
    """ Generate sequences for each example without beam search (num_beams == 1).
        All returned sequence are generated independantly.
    """
    # length of generated sentences / unfinished sentences
    unfinished_sents = input_ids.new(batch_size).fill_(1)
    sent_lengths = input_ids.new(batch_size).fill_(max_length)

    past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)

        outputs = self(**model_inputs)
        next_token_logits = outputs[0][:, -1, :]

        # if model has past, then set the past variable to speed up decoding
        if self._do_output_past(outputs):
            past = outputs[1]

        # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
            banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
            for batch_idx in range(batch_size):
                next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

        if bad_words_ids is not None:
            # calculate a list of banned tokens according to bad words
            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

            for batch_idx in range(batch_size):
                next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            next_token_logits[:, eos_token_id] = -float("inf")

        if do_sample:
            # Temperature (higher temperature => more likely to sample low probability tokens)
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature
            # Top-p/top-k filtering
            next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            # Sample
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            # Greedy decoding
            next_token = torch.argmax(next_token_logits, dim=-1)

        # update generations and finished sentences
        if eos_token_id is not None:
            # pad finished sentences if eos_token_id exist
            tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
        else:
            tokens_to_add = next_token

        input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

        if eos_token_id is not None:
            eos_in_sents = tokens_to_add == eos_token_id
            # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
            is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
            sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
            # unfinished_sents is set to zero if eos in sentence
            unfinished_sents.mul_((~eos_in_sents).long())

        # stop when there is a </s> in each sentence, or if we exceed the maximul length
        if unfinished_sents.max() == 0:
            break

        # extend attention_mask for new generated input if only decoder
        if self.config.is_encoder_decoder is False:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        cur_len = cur_len + 1

    # if there are different sentences lengths in the batch, some batches have to be padded
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
        # finished sents are filled with pad_token
        decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
    else:
        decoded = input_ids

    for hypo_idx, hypo in enumerate(input_ids):
        decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

    return decoded


In [None]:
# dynamically bind new generate method -- note this is a hack from here:
# https://stackoverflow.com/questions/972/adding-a-method-to-an-existing-object-instance

model.generate_2 = generate.__get__(model)

In [None]:
test_news_article = 'New Zealand says it has stopped community transmission of Covid-19, '\
 'effectively eliminating the virus. With new cases in single figures for several days - one on Sunday ' \
 '- Prime Minister Jacinda Ardern said the virus was "currently" eliminated. But officials have warned ' \
 'against complacency, saying it does not mean a total end to new coronavirus cases. ' \
 'The news comes hours before New Zealand is set to move out of its toughest level of social restrictions. ' \
 'From Tuesday, some non-essential business, healthcare and education activity will be able to resume. ' \
 'Most people will still be required to remain at home at all times and avoid all social interactions.'


max_len = 75
summarize(test_news_article, max_length=max_len)

In [None]:
model.use_cache

In [None]:
model.config

In [None]:
summarize(test_news_article, max_length=max_len)

In [None]:
# slot-filling (very cool feature of BART large)



In [None]:
def _generate_no_beam_search(
    self,
    input_ids,
    cur_len,
    max_length,
    min_length,
    do_sample,
    temperature,
    top_k,
    top_p,
    repetition_penalty,
    no_repeat_ngram_size,
    bad_words_ids,
    bos_token_id,
    pad_token_id,
    eos_token_id,
    decoder_start_token_id,
    batch_size,
    encoder_outputs,
    attention_mask,
    use_cache,
):
    """ Generate sequences for each example without beam search (num_beams == 1).
        All returned sequence are generated independantly.
    """
    # length of generated sentences / unfinished sentences
    unfinished_sents = input_ids.new(batch_size).fill_(1)
    sent_lengths = input_ids.new(batch_size).fill_(max_length)

    past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
        )

        outputs = self(**model_inputs)
        next_token_logits = outputs[0][:, -1, :]

        # if model has past, then set the past variable to speed up decoding
        if self._use_cache(outputs, use_cache):
            past = outputs[1]

        # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
        if repetition_penalty != 1.0:
            self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)

        if no_repeat_ngram_size > 0:
            # calculate a list of banned tokens to prevent repetitively generating the same ngrams
            # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
            banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
            for batch_idx in range(batch_size):
                next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

        if bad_words_ids is not None:
            # calculate a list of banned tokens according to bad words
            banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)

            for batch_idx in range(batch_size):
                next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")

        # set eos token prob to zero if min_length is not reached
        if eos_token_id is not None and cur_len < min_length:
            next_token_logits[:, eos_token_id] = -float("inf")

        if do_sample:
            # Temperature (higher temperature => more likely to sample low probability tokens)
            if temperature != 1.0:
                next_token_logits = next_token_logits / temperature
            # Top-p/top-k filtering
            next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            # Sample
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            # Greedy decoding
            next_token = torch.argmax(next_token_logits, dim=-1)

        # update generations and finished sentences
        if eos_token_id is not None:
            # pad finished sentences if eos_token_id exist
            tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
        else:
            tokens_to_add = next_token

        input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)

        if eos_token_id is not None:
            eos_in_sents = tokens_to_add == eos_token_id
            # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
            is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
            sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len + 1)
            # unfinished_sents is set to zero if eos in sentence
            unfinished_sents.mul_((~eos_in_sents).long())

        # stop when there is a </s> in each sentence, or if we exceed the maximul length
        if unfinished_sents.max() == 0:
            break

        # extend attention_mask for new generated input if only decoder
        if self.config.is_encoder_decoder is False:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        cur_len = cur_len + 1

    # if there are different sentences lengths in the batch, some batches have to be padded
    if sent_lengths.min().item() != sent_lengths.max().item():
        assert pad_token_id is not None, "`Pad_token_id` has to be defined if batches have different lengths"
        # finished sents are filled with pad_token
        decoded = input_ids.new(batch_size, sent_lengths.max().item()).fill_(pad_token_id)
    else:
        decoded = input_ids

    for hypo_idx, hypo in enumerate(input_ids):
        decoded[hypo_idx, : sent_lengths[hypo_idx]] = hypo[: sent_lengths[hypo_idx]]

    return decoded


