In [1]:
!pip install transformers==4.6.1



In [2]:
from transformers import AutoModelWithLMHead, AutoTokenizer
import torch
import torch.nn.functional as F  # F.softmax is used in top_k_top_p_filtering

In [3]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Source:
            https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
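The sketch below restates the same two filtering rules in plain Python, so you can check them without a model. `filter_logits` and `softmax` are hypothetical helper names, and the toy logits are made up. Unlike `top_k_top_p_filtering`, this sketch returns a copy instead of modifying its argument in place.

The loop in the top-p branch has the same effect as the "shift right" step. A token is removed only when the tokens ranked above it already hold more than `top_p` of the probability mass. As a result, the token that crosses the threshold is still kept.

```python
import math

NEG_INF = float("-inf")

def softmax(logits):
    """Numerically stable softmax; entries equal to -inf get probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def filter_logits(logits, top_k=0, top_p=0.0):
    """Return a filtered copy of `logits`; removed entries are set to -inf."""
    logits = list(logits)  # copy: unlike the torch version, the input is not mutated
    top_k = min(top_k, len(logits))
    if top_k > 0:
        # Drop everything strictly below the k-th largest logit (ties are kept)
        kth_largest = sorted(logits, reverse=True)[top_k - 1]
        logits = [x if x >= kth_largest else NEG_INF for x in logits]
    if top_p > 0.0:
        # Visit tokens from most to least likely
        order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        sorted_probs = softmax([logits[i] for i in order])
        cumulative = 0.0
        for i, p in zip(order, sorted_probs):
            # Same rule as the shifted mask: drop a token only if the tokens ranked
            # above it already hold more than top_p; the top token is always kept
            if cumulative > top_p:
                logits[i] = NEG_INF
            cumulative += p
    return logits

toy = [2.0, 1.0, 0.5, -1.0]  # softmax ≈ [0.609, 0.224, 0.136, 0.030]
print(filter_logits(toy, top_k=2))    # [2.0, 1.0, -inf, -inf]
print(filter_logits(toy, top_p=0.9))  # [2.0, 1.0, 0.5, -inf]
print(filter_logits(toy, top_p=0.5))  # [2.0, -inf, -inf, -inf]
```

With `top_p=0.9`, the third token survives because the two tokens above it hold only about 0.833 of the mass. The fourth token is dropped because about 0.970 is already covered.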

In [4]:
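def sample_token(logits):
    # Scores for the last position only: shape (vocab_size,) for batch size 1
    next_token_logits = logits[..., -1, :].squeeze(0)
    next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=10)
    # softmax gives probabilities (not log-probabilities), which is what torch.multinomial expects
    probs = torch.softmax(next_token_logits, dim=-1)
    # Draw one token id; [0] turns the shape-(1,) result into a 0-dim tensor
    token = torch.multinomial(probs, num_samples=1)[0]

    return token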
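`sample_token` first turns the filtered logits into probabilities, then draws a single token id. Tokens set to `-inf` by the filter end up with probability exactly 0. The dependency-free sketch below mimics that step with Python's `random.choices` in place of `torch.multinomial`. `sample_from_logits` is a hypothetical name, and the filtered logits are a made-up example.

```python
import math
import random

def sample_from_logits(logits, rng):
    """Draw one index i with probability softmax(logits)[i]."""
    m = max(logits)  # subtract the max for numerical stability
    # exp(-inf) == 0.0, so filtered-out tokens get zero weight and are never drawn
    weights = [math.exp(x - m) for x in logits]
    # random.choices normalises the weights itself, as torch.multinomial does
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(0)  # seeded so the draws are reproducible
filtered = [2.0, 1.0, float("-inf"), float("-inf")]  # e.g. top-k output with k=2
draws = [sample_from_logits(filtered, rng) for _ in range(1000)]
# Roughly 73% zeros and 27% ones, since softmax([2, 1]) ≈ [0.731, 0.269]
print(draws.count(0), draws.count(1))
```

Sampling this way, rather than always taking the argmax, is what makes each run of the loops below produce different text.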

## Transformer-XL

In [5]:
tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103')



In [6]:
generated = tokenizer.encode("On our way to the beach")
context = torch.tensor([generated])
past = None

In [7]:
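with torch.no_grad():  # inference only: no autograd graph is built across steps
    for i in range(100):
        # After the first step only the newest token is fed; `mems` carries the earlier context
        output = model(context, mems=past)
        token = sample_token(output.prediction_scores)

        generated.append(token.item())
        context = token.view(1, -1)  # shape (batch=1, seq_len=1)
        past = output.mems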
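In Transformer-XL, `mems` is not a key/value cache. It stores each layer's hidden states from earlier segments, and the current segment attends to them as extra context. The Hugging Face implementation appends the new hidden states and keeps only the most recent `mem_len` positions, where `mem_len` is a `TransfoXLConfig` setting. Context older than that is forgotten.

The sketch below shows that update rule for one layer. It assumes a small, hypothetical `mem_len` of 3, and the strings stand in for hidden-state vectors.

```python
def update_memory(memory, new_hidden, mem_len):
    """Concatenate old memory with this step's hidden states; keep the newest mem_len."""
    assert mem_len > 0, "this sketch assumes a positive memory length"
    return (memory + new_hidden)[-mem_len:]

mem_len = 3  # hypothetical; the real value comes from the model config
memory = []
# First call: a 4-token prompt produces 4 hidden states
memory = update_memory(memory, ["h0", "h1", "h2", "h3"], mem_len)
print(memory)  # ['h1', 'h2', 'h3']
# Each later call feeds one token, so one new hidden state arrives
memory = update_memory(memory, ["h4"], mem_len)
print(memory)  # ['h2', 'h3', 'h4']
```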

In [8]:
print(tokenizer.decode(generated))

On our way to the beach on the morning of the weekend of his, <eos> the morning of the Friday the 13th is a very quiet day, and the afternoon of Saturday of the 13th is quite good. "<eos> The day after Monday morning, Saturday afternoon, the afternoon of the 14th, a very quiet day of the 13th, is celebrated with a fireworks display and fireworks. <eos> The day after Saturday, Thursday afternoon, the night of the 14th, the afternoon of Friday is devoted to the celebration. The day is also dedicated to


## GPT-2

In [9]:
tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
model = AutoModelWithLMHead.from_pretrained('gpt2-large')

In [10]:
generated = tokenizer.encode("On our way to the beach")
context = torch.tensor([generated])
past = None

In [11]:
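with torch.no_grad():  # inference only: no autograd graph is built across steps
    for i in range(100):
        # After the first step only the newest token is fed; `past_key_values` caches the rest
        output = model(context, past_key_values=past)
        token = sample_token(output.logits)

        generated.append(token.item())
        context = token.view(1, -1)  # shape (batch=1, seq_len=1), like the prompt tensor
        past = output.past_key_values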
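GPT-2's `past_key_values` holds, for every layer, the attention keys and values of all positions processed so far. Attention is causal, so those keys and values do not change when new tokens are appended. Each step therefore only has to project the one new token and attend over the cached entries, which is why the loop feeds just the sampled token as `context`.

The toy below checks that equivalence. It assumes one layer, one head, scalar features, hypothetical weights `W_Q`, `W_K`, `W_V`, and no positional encoding. It compares the cached result with a full recomputation of the prefix.

```python
import math

# Hypothetical scalar weights for a 1-layer, 1-head, 1-dimensional toy model
W_Q, W_K, W_V = 0.5, 1.5, 2.0

def attend(query, keys, values):
    """Softmax(query * key)-weighted average of the values."""
    scores = [query * k for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def last_output_without_cache(xs):
    """Re-project every position in the prefix, then attend from the last one."""
    keys = [W_K * x for x in xs]
    values = [W_V * x for x in xs]
    return attend(W_Q * xs[-1], keys, values)

class KVCache:
    """Stores keys/values of past positions so each step projects only the new one."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x):
        self.keys.append(W_K * x)
        self.values.append(W_V * x)
        return attend(W_Q * x, self.keys, self.values)

xs = [0.1, -0.4, 0.9, 0.3]  # toy "embeddings" of four tokens
cache = KVCache()
cached_outputs = [cache.step(x) for x in xs]
full_outputs = [last_output_without_cache(xs[: t + 1]) for t in range(len(xs))]
print(all(math.isclose(c, f) for c, f in zip(cached_outputs, full_outputs)))  # True
```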

In [12]:
print(tokenizer.decode(generated))

On our way to the beach we found a few more people who were looking at the same thing that we were.

I think we got about a thousand people in one night, and the first night was about a million.

When we arrived back at the house, a few of those same people had joined our group.

I was pretty sure we'd had the first big crowd.

I was pretty sure that the first person I'd ever been close to was someone I'd met in the past.


## XLM

In [13]:
tokenizer = AutoTokenizer.from_pretrained('xlm-clm-enfr-1024')
model = AutoModelWithLMHead.from_pretrained('xlm-clm-enfr-1024')

Some weights of XLMWithLMHeadModel were not initialized from the model checkpoint at xlm-clm-enfr-1024 and are newly initialized: ['transformer.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
generated = [0] # start with just <s>
context = torch.tensor([generated])
lang = tokenizer.lang2id['en']  # English (id 0 for this checkpoint)

In [15]:
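with torch.no_grad():  # inference only: no autograd graph is built
    for i in range(100):
        # XLM takes one language id per input position
        langs = torch.full_like(context, lang)
        output = model(context, langs=langs)
        token = sample_token(output.logits)

        generated.append(token.item())
        # No cache is passed, so the whole sequence is re-encoded on every step
        context = torch.tensor([generated])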
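Unlike the Transformer-XL and GPT-2 loops, this loop passes no cache. It rebuilds `context` and `langs` from the full `generated` list, so every call re-encodes the whole sequence. Starting from the one-token `<s>` prompt, 100 steps process 1 + 2 + ... + 100 = 5050 positions. Feeding only the newest token to a cached model would process 100.

The hypothetical helper below does that count. It counts token positions only, not the attention cost inside each call.

```python
def positions_processed(prompt_len, steps, use_cache):
    """Total token positions fed to the model over `steps` generation calls."""
    if use_cache:
        # The first call encodes the prompt; every later call feeds only the newest token
        return prompt_len + (steps - 1)
    # Without a cache, call i re-encodes the prompt plus the i tokens generated so far
    return sum(prompt_len + i for i in range(steps))

print(positions_processed(1, 100, use_cache=False))  # 5050
print(positions_processed(1, 100, use_cache=True))   # 100
```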

In [16]:
print(tokenizer.decode(generated))

<s>, and is a key driver of our financial results. " </s>" But we have made the decision to take it very carefully and that will be based on the best available evidence, " Mr Hunt said. " </s>" We are looking at it. </s>It's been a difficult decision. </s>It's not about a lack of resources, " he said. </s>The fact is that we're going to take the right decision. </s>He's the right time in his situation. </s>" He's


In [17]:
generated = [0] # start with just <s>
context = torch.tensor([generated])
lang = tokenizer.lang2id['fr']  # French (id 1 for this checkpoint)

In [18]:
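with torch.no_grad():  # inference only: no autograd graph is built
    for i in range(100):
        # XLM takes one language id per input position
        langs = torch.full_like(context, lang)
        output = model(context, langs=langs)
        token = sample_token(output.logits)

        generated.append(token.item())
        # No cache is passed, so the whole sequence is re-encoded on every step
        context = torch.tensor([generated])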

In [19]:
print(tokenizer.decode(generated))

<s>for the U.S. market. </s>C' est un peu comme les années précédentes, il y a des années. </s>" Il fallait faire du théâtre. </s>Les prix. </s>La situation est différente. </s>Il a également été un peu plus difficile pour les banques ". </s>Le comité : pour les banques. </s>La commune, le maire, le syndicat des commerçants... </s>A la maison de retraite, les services municipaux... le conseil municipal, la mairie, c' étaient les services de la mairie, qui
