# Model Generation step by step: No rock left unturned

The generation of text is one of the most interesting visible parts of a model like GPT.  For the most part I have been using the OpenAI generation API, or the Huggingface APIs, but I had never dived deep into the 'behind the scenes' of the generation.

In this notebook I do that: First I recreate the current Huggingface classes mainly used for generation, and then I do my own generation, step by step.

**By the end of this notebook you will have a perfect understanding on how the Generation works, and will have created your own *Huggingface* classes to do text generation with you models!**

## Let's review first the Huggingface classes

Huggingface makes it very easy to use the models. Their classes create a good level of abstraction which allow us to do anything with the call of a single function.

For me, though, this is sometimes a cause of concern, because I can get what I want but I don't know how this is working behind the scenes.

In this case, I want to review the classes the Huggingface provides to generate text from an initial prompt. I will review 2 classes, one that has a super high level of abstraction, to the point that I only need two sentences, and then the next one which requires a bit more of code from my part, but still has a 'magic' function called 'generate' that does all the generation behind the scenes.

### High level of abstraction

In [28]:
from transformers import pipeline, set_seed
set_seed(42)

generator = pipeline('text-generation', model='gpt2')
generator("Once upon a time,", max_length=50, num_return_sequences=1)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': "Once upon a time, my father had been in that role. He was not a typical member of this kind of club. The manager was there and asked me: 'Can I join?\n\n'I said: 'It looks like you are"}]

### Medium level of abstraction

In [29]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the pre-trained tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Input text
input_text = "Once upon a time,"

# Tokenize the input text
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Generate text based on the input
output = model.generate(input_ids, max_length=50, num_return_sequences=1, no_repeat_ngram_size=2)

# Decode the generated text
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

# Print the generated text
print(generated_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Once upon a time, the world was a place of great beauty and great danger. The world of the gods was the place where the great gods were born, and where they were to live.

The world that was created was not the same


## BUT: What's going on behind the scenes?

The two classes above are great, and save us a lot of time. But, what is going on behind the scenes? 

Lets uncover this magic, leaving no rock unturned.

# Text generation Step by Step

#### Step 1: Instantiate the model and its tokenizer

We will start by instantiating the Tokenizer and a model, the GPT2, as a starting point.

The Tokenizer has its own magic and we can discuss the details of it in a different notebook.

As for the model, what we are doing here is taking the GPT2 model from Huggingface. Fortunately in Part 1 we already learned what is behind this class and we know now what 'model' includes, and how it is done. In case you have any doubt about this, [please check Part 1](https://github.com/jcolano/transformer_step_by_step.git) where this is treated in detail.

In [30]:
import torch
import torch.nn.functional as F

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the pre-trained tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

model = GPT2LMHeadModel.from_pretrained('gpt2')

#### Step 2: Define the prompt 
We will define the following prompt for this exercise

In [31]:
input_text = "Once upon a time,"

#### Step 3: Convert the prompt into token ids
Models can only take numbers, so the tokenizer takes the text, converts it into tokens, and then returns the integers that represent these tokens.

In [32]:
input_ids = tokenizer.encode(input_text, return_tensors='pt')
input_ids 

tensor([[7454, 2402,  257,  640,   11]])

#### Step 4: Forward pass
In [Part 1](https://github.com/jcolano/transformer_step_by_step.git) and [Part 2](https://github.com/jcolano/transformer_step_by_step.git) we learned how the model is built, what is the 'forward' function in the model, and how it is invoked. While doing generation we use the very same function to get the next token.

In [33]:
output = model(input_ids)

The output of the forward pass includes several components, typically: the logits, the hidden states, attention weights, etc. 

In [None]:
output 

#### Step 5: Get the logits
For the case of the generation we only care about the logits, so we get them out of the output.

In [35]:
logits = output.logits

The logits are the 'raw scores' or the 'raw probabilities' assigned to eack token in the entire vocabulary of the model for being the next position in the input sequence.

In [36]:
logits 

tensor([[[ -34.5645,  -34.4081,  -38.3079,  ...,  -41.6996,  -39.7801,
           -35.0521],
         [ -84.7256,  -82.9326,  -87.0165,  ...,  -91.6667,  -86.2354,
           -84.7094],
         [-109.0798, -105.7258, -109.9115,  ..., -114.2847, -107.6933,
          -105.3613],
         [ -57.8935,  -58.5540,  -64.7374,  ...,  -64.9437,  -62.9294,
           -60.0625],
         [ -97.2383,  -98.3015, -100.6650,  ...,  -99.0754,  -98.9562,
           -95.6314]]], grad_fn=<UnsafeViewBackward0>)

In [37]:
logits.shape 

torch.Size([1, 5, 50257])

#### Step 6: Apply softmax to the logits
We apply softmax to the raw probabilities to convert them into normalized probabilities (probabilities between 0 and 1 - all sum up to 1)

In [9]:
token_probs = F.softmax(logits[:, -1, :], dim=-1)

These are the normalized probabilities for all tokens. Each number has a value between 0 and 1, and the sum of all is 1.

In [11]:
token_probs 

tensor([[8.0201e-06, 2.7697e-06, 2.6060e-07,  ..., 1.2774e-06, 1.4391e-06,
         3.9997e-05]], grad_fn=<SoftmaxBackward0>)

In [12]:
token_probs.shape

torch.Size([1, 50257])

In [21]:
sum_of_elements = torch.sum(token_probs)
sum_of_elements 

tensor(1.0000, grad_fn=<SumBackward0>)

In [22]:
# Count the number of elements in the tensor
num_elements = token_probs.numel()
num_elements 

50257

#### Step 7: Out of the token_probs pick the one with the highest probability

In [23]:
next_token = torch.multinomial(token_probs, 1)

In [38]:
# The value of this next_token is an integer that can be converted into a token.
next_token 

tensor([[339]])

Here we convert the "next_token" tensor into a string. Every time we run the last few lines, the str_next_token will most probably be a different word.

In [25]:
str_next_token = tokenizer.decode(next_token[0])
print(str_next_token)

 he


#### Step 8: Concatenate the new token to the input_ids, and start this cycle again

Now that we have the next token, we concatenate it to the previous string. For example:

If previous string is: "Once upon a time," 
... and new token is: "there"
... then the new input string will be: "Once upon a time, there"

And with this new input string we start again.

Here we don't concatenate the next_token as a string but instead we concatenate it to the already converted to input_ids object:

In [26]:
input_ids = torch.cat([input_ids, next_token], dim=-1)

In [27]:
input_ids 

#tensor([[7454, 2402,  257,  640,   11]])

tensor([[7454, 2402,  257,  640,   11,  262,  339]])

### Putting it all together in a loop

Now that we have done a manual, step-by-step, generation of one token, all that we have to do is put it inside a loop and repeat it while the input is less than 'max_length':

In [43]:


# Initialize the input_ids with your starting text
input_text = "Once upon a time,"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Define the maximum length of the generated text
max_length = 50

# Generate text
while len(input_ids[0]) < max_length:
    # Get logits from the model
    output = model(input_ids)

    logits = output.logits

    # Apply softmax to obtain token probabilities
    token_probs = F.softmax(logits[:, -1, :], dim=-1)

    # Sample the next token based on probabilities
    next_token = torch.multinomial(token_probs, 1)

    # Append the sampled token to input_ids
    input_ids = torch.cat([input_ids, next_token], dim=-1)

# Decode the generated text
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Print the generated text
print(generated_text)


Once upon a time, in a land far, far away, they took fire, with a tiny bullet penetrating their form.

Upon a very distant thread, almost a hundred and fifty thousand years ago, in the forbidding night of years when


## AND THAT'S IT!!!!

That is how the "GENERATE" works behind the scenes. 

There are other parts to it that we will see now:

- Temperature
- TOP_K
- TOP_P
- REPETITION_PENALTY

### Lets apply temperature!
To apply temperature when sampling from the model, you can scale the token probabilities before sampling. The temperature parameter controls the level of randomness in the sampling process. A higher temperature (e.g., 1.0) makes the sampling more random, while a lower temperature (e.g., 0.5) makes it more deterministic.

*Higher Temperature (e.g., > 1.0)*: When you increase the temperature, it has a smoothing effect on the token probabilities. It makes the distribution of probabilities flatter, giving tokens with lower probabilities a better chance of being selected. This increases the randomness of the generated text. It's like injecting more randomness into the generation process.

*Lower Temperature (e.g., < 1.0)*: Conversely, when you decrease the temperature, it sharpens the token probabilities. It makes the distribution peakier, with higher probabilities assigned to the most likely tokens. This results in more deterministic and focused text generation. Lower temperature values make the model more conservative and less prone to generating diverse, unexpected text.

In [44]:
import torch
import torch.nn.functional as F

# Initialize the input_ids with your starting text
input_text = "Once upon a time, in a land far, far away,"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Define the maximum length of the generated text
max_length = 50

# Define the temperature
temperature = 0.7  # Adjust the temperature value as desired

# Generate text
while len(input_ids[0]) < max_length:
    # Get logits from the model
    output = model(input_ids)

    logits = output.logits

    # Apply softmax with temperature to obtain token probabilities
    token_probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)

    # Sample the next token based on probabilities
    next_token = torch.multinomial(token_probs, 1)

    # Append the sampled token to input_ids
    input_ids = torch.cat([input_ids, next_token], dim=-1)

# Decode the generated text
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Print the generated text
print(generated_text)


Once upon a time, in a land far, far away, there was a goddess, one that was truly known to mankind feathers. It was called the Mina Mina, and she was goddess of the sun, the moon, and earth.


### And now lets apply TOP_K

Setting top_k to a positive integer value will ensure that only the top-k most likely tokens are considered for sampling. 

In [45]:
import torch
import torch.nn.functional as F

# Initialize the input_ids with your starting text
input_text = "Once upon a time, in a land far, far away,"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Define the maximum length of the generated text
max_length = 50

# Define the temperature
temperature = 0.7  # Adjust the temperature value as desired

# Define the top-k value
top_k = 50  # Adjust the top-k value as desired

# Generate text
while len(input_ids[0]) < max_length:
    # Get logits from the model
    output = model(input_ids)

    logits = output.logits

    # Apply softmax with temperature to obtain token probabilities
    token_probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)

    # Apply top-k filtering to the token probabilities
    # After softmax, you have token probabilities for each token in the model's vocabulary.
    # top_k specifies how many of the highest probability tokens you want to consider. It selects the top k tokens with the highest probabilities.
    # filtered_token_probs contains the probabilities of the top k tokens, and top_indices contains the corresponding indices of these top tokens.
    filtered_token_probs, top_indices = token_probs.topk(top_k, dim=-1)

    # Sample the next token based on the filtered probabilities
    # torch.multinomial is used to randomly sample one token based on the probabilities in filtered_token_probs. It's essentially drawing a token from the filtered distribution.
    # The 1 passed as the second argument specifies that you want to draw one token.
    next_token = torch.multinomial(filtered_token_probs, 1)

    # Map the sampled token back to the original token space
    # Once you've sampled an index from the filtered distribution, you use top_indices to map it back to the index in the original token vocabulary.
    # This step ensures that the generated token corresponds to a valid token in the model's vocabulary.
    next_token = top_indices.gather(dim=-1, index=next_token)

    # Append the sampled token to input_ids
    input_ids = torch.cat([input_ids, next_token], dim=-1)

# Decode the generated text
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Print the generated text
print(generated_text)


Once upon a time, in a land far, far away, with its ruins, its ruins, and its ruins, in the mountains of the great lakes of this country, the gods of the land, and of the world, and of the world


### Lets apply TOP_P
You can apply the top_p parameter (also known as nucleus sampling) to control the cumulative probability mass of the tokens considered during sampling. Setting top_p to a value between 0 and 1 ensures that only the most probable tokens whose cumulative probability mass exceeds top_p are considered for sampling.

In [50]:
import torch
import torch.nn.functional as F

# Initialize the input_ids with your starting text
input_text = "Once upon a time, in a land far, far away,"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Define the maximum length of the generated text
max_length = 50

# Define the temperature
temperature = 0.7  # Adjust the temperature value as desired

# Define the top-k value
top_k = 50  # Adjust the top-k value as desired

# Define the top-p value
top_p = 0.9  # Adjust the top-p value as desired

# Generate text
while len(input_ids[0]) < max_length:
    # Get logits from the model
    output = model(input_ids)

    logits = output.logits

    # Apply softmax with temperature to obtain token probabilities
    token_probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)

    # Apply top-k filtering to the token probabilities
    filtered_token_probs, top_indices = token_probs.topk(top_k, dim=-1)

    # Apply top-p filtering to the filtered token probabilities
    sorted_filtered_probs, sorted_filtered_indices = torch.sort(filtered_token_probs, descending=True, dim=-1)
    cumulative_filtered_probs = torch.cumsum(sorted_filtered_probs, dim=-1)
    exceed_top_p = cumulative_filtered_probs > top_p
    min_filtered_exceed_index = torch.min(sorted_filtered_indices.masked_fill(exceed_top_p, sorted_filtered_indices.size(-1) - 1), dim=-1).values
    final_token_probs = torch.zeros_like(filtered_token_probs)
    final_token_probs.scatter_(dim=-1, index=min_filtered_exceed_index.unsqueeze(-1), src=filtered_token_probs.gather(dim=-1, index=min_filtered_exceed_index.unsqueeze(-1)))

    # Sample the next token based on the final filtered probabilities
    next_token = torch.multinomial(final_token_probs, 1)

    # Map the sampled token back to the original token space
    next_token = top_indices.gather(dim=-1, index=next_token)

    # Append the sampled token to input_ids
    input_ids = torch.cat([input_ids, next_token], dim=-1)

# Decode the generated text
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Print the generated text
print(generated_text)


Once upon a time, in a land far, far away, the world was a land forever changed.
Then came we.
The world was a land forever forgotten. After all here was a land forever changed.
And now, here is


### Adding repetition_penalty

Lets add a repetition penalty to the code to discourage the model from generating repeated tokens. 

When applying a repetition penalty, it's a common practice to consider not just the last token but a window of the last n tokens to prevent repetitive patterns.

In [57]:
import torch
import torch.nn.functional as F

# Initialize the input_ids with your starting text
input_text = "Once upon a time, in a land far, far away,"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Define the maximum length of the generated text
max_length = 50

# Define the temperature
temperature = 0.7  # Adjust the temperature value as desired

# Define the top-k value
top_k = 50  # Adjust the top-k value as desired

# Define the top-p value
top_p = 0.9  # Adjust the top-p value as desired

# Define the repetition penalty
repetition_penalty = 1.2  # Adjust the penalty value as desired

# Generate text
while len(input_ids[0]) < max_length:
    # Get logits from the model
    output = model(input_ids)

    logits = output.logits

    # Apply softmax with temperature to obtain token probabilities
    token_probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)

    # Apply top-k filtering to the token probabilities
    filtered_token_probs, top_indices = token_probs.topk(top_k, dim=-1)

    # Apply top-p filtering to the filtered token probabilities
    sorted_filtered_probs, sorted_filtered_indices = torch.sort(filtered_token_probs, descending=True, dim=-1)
    cumulative_filtered_probs = torch.cumsum(sorted_filtered_probs, dim=-1)
    exceed_top_p = cumulative_filtered_probs > top_p
    min_filtered_exceed_index = torch.min(sorted_filtered_indices.masked_fill(exceed_top_p, sorted_filtered_indices.size(-1) - 1), dim=-1).values
    final_token_probs = torch.zeros_like(filtered_token_probs)
    final_token_probs.scatter_(dim=-1, index=min_filtered_exceed_index.unsqueeze(-1), src=filtered_token_probs.gather(dim=-1, index=min_filtered_exceed_index.unsqueeze(-1)))

    # Calculate the repetition penalty and apply it to the token probabilities
    last_token = input_ids[:, -1]
    repetition_penalty_tensor = torch.where(last_token.unsqueeze(-1) == top_indices, repetition_penalty, torch.tensor(1.0, device=input_ids.device))
    token_probs = final_token_probs * repetition_penalty_tensor

    # Sample the next token based on the token probabilities
    next_token = torch.multinomial(token_probs, 1)

    # Map the sampled token back to the original token space
    next_token = top_indices.gather(dim=-1, index=next_token)

    # Append the sampled token to input_ids
    input_ids = torch.cat([input_ids, next_token], dim=-1)

# Decode the generated text
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

# Print the generated text
print(generated_text)


Once upon a time, in a land far, far away, the world was a land forever changed.
Then came we.
The world was a land forever forgotten. After all here was a land forever changed.
And now, here is


### Putting it all together in a more organized way

Since we have now most of the bells and whistles of the GENERATOR, we can organize it as a class.

In [61]:
import torch
import torch.nn.functional as F

class Config:
    def __init__(self, max_length=50, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.2, repetition_window=5):
        self.max_length = max_length
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.repetition_window = repetition_window


def generate(model, input_ids, config):
    while len(input_ids[0]) < max_length:
        # Get logits from the model
        output = model(input_ids)

        logits = output.logits

        # Apply softmax with temperature to obtain token probabilities
        token_probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)

        # Apply top-k filtering to the token probabilities
        filtered_token_probs, top_indices = token_probs.topk(top_k, dim=-1)

        # Apply top-p filtering to the filtered token probabilities
        sorted_filtered_probs, sorted_filtered_indices = torch.sort(filtered_token_probs, descending=True, dim=-1)
        cumulative_filtered_probs = torch.cumsum(sorted_filtered_probs, dim=-1)
        exceed_top_p = cumulative_filtered_probs > top_p
        min_filtered_exceed_index = torch.min(sorted_filtered_indices.masked_fill(exceed_top_p, sorted_filtered_indices.size(-1) - 1), dim=-1).values
        final_token_probs = torch.zeros_like(filtered_token_probs)
        final_token_probs.scatter_(dim=-1, index=min_filtered_exceed_index.unsqueeze(-1), src=filtered_token_probs.gather(dim=-1, index=min_filtered_exceed_index.unsqueeze(-1)))

        # Calculate the repetition penalty and apply it to the token probabilities
        last_token = input_ids[:, -1]
        repetition_penalty_tensor = torch.where(last_token.unsqueeze(-1) == top_indices, repetition_penalty, torch.tensor(1.0, device=input_ids.device))
        token_probs = final_token_probs * repetition_penalty_tensor

        # Sample the next token based on the token probabilities
        next_token = torch.multinomial(token_probs, 1)

        # Map the sampled token back to the original token space
        next_token = top_indices.gather(dim=-1, index=next_token)

        # Append the sampled token to input_ids
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    # Decode the generated text
    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    return generated_text 

config = Config(
    max_length=50, 
    temperature=0.7, 
    top_k=50, 
    top_p=0.9, 
    repetition_penalty=1.2, 
    repetition_window=5
    )

# Initialize the input_ids with your starting text
input_text = "Once upon a time, in a land far, far away,"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

generated_text = generate(model, input_ids, config)
print(generated_text)


Once upon a time, in a land far, far away, the world was a land forever changed.
Then came we.
The world was a land forever forgotten. After all here was a land forever changed.
And now, here is


## Create a class to emulate Huggingface

Lets take the code from above and replicate the "Huggingface Pipeline" class, with generator, config and all!

In [74]:
import torch
import torch.nn.functional as F

class MyPipeline:
    """A class for generating text using a language model.

    Args:
        model (torch.nn.Module): The language model to use for text generation.
        tokenizer: The tokenizer associated with the language model.
        config (MyPipeline.Config, optional): Configuration for text generation. Defaults to None.

        Configuration parameters for text generation.

        MyPipeline.Config:

        Args:
            max_length (int, optional): The maximum length of generated text. Defaults to 50.
            temperature (float, optional): The temperature for controlling randomness. Defaults to 0.7.
            top_k (int, optional): The top-k value for filtering tokens. Defaults to 50.
            top_p (float, optional): The top-p value for filtering tokens. Defaults to 0.9.
            repetition_penalty (float, optional): The repetition penalty value. Defaults to 1.2.
            repetition_window (int, optional): The repetition window size. Defaults to 5.

    """
    class Config:

        def __init__(self, max_length=50, temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.2, repetition_window=5):
            self.max_length = max_length
            self.temperature = temperature
            self.top_k = top_k
            self.top_p = top_p
            self.repetition_penalty = repetition_penalty
            self.repetition_window = repetition_window

    def __init__(self, model, tokenizer, config=None):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config if config is not None else MyPipeline.Config()

    def __call__(self, input_text):  # Define the __call__ method
        return self.generate(input_text)

    def generate(self, input_text):
        max_length = self.config.max_length
        temperature = self.config.temperature
        top_k = self.config.top_k
        top_p = self.config.top_p
        repetition_penalty = self.config.repetition_penalty

        input_ids = self.tokenizer.encode(input_text, return_tensors='pt')
        
        while len(input_ids[0]) < max_length:
            # Get logits from the model
            output = self.model(input_ids)

            logits = output.logits

            # Apply softmax with temperature to obtain token probabilities
            token_probs = F.softmax(logits[:, -1, :] / temperature, dim=-1)

            # Apply top-k filtering to the token probabilities
            filtered_token_probs, top_indices = token_probs.topk(top_k, dim=-1)

            # Apply top-p filtering to the filtered token probabilities
            sorted_filtered_probs, sorted_filtered_indices = torch.sort(filtered_token_probs, descending=True, dim=-1)
            cumulative_filtered_probs = torch.cumsum(sorted_filtered_probs, dim=-1)
            exceed_top_p = cumulative_filtered_probs > top_p
            min_filtered_exceed_index = torch.min(sorted_filtered_indices.masked_fill(exceed_top_p, sorted_filtered_indices.size(-1) - 1), dim=-1).values
            final_token_probs = torch.zeros_like(filtered_token_probs)
            final_token_probs.scatter_(dim=-1, index=min_filtered_exceed_index.unsqueeze(-1), src=filtered_token_probs.gather(dim=-1, index=min_filtered_exceed_index.unsqueeze(-1)))

            # Calculate the repetition penalty and apply it to the token probabilities
            last_token = input_ids[:, -1]
            repetition_penalty_tensor = torch.where(last_token.unsqueeze(-1) == top_indices, repetition_penalty, torch.tensor(1.0, device=input_ids.device))
            token_probs = final_token_probs * repetition_penalty_tensor

            # Sample the next token based on the token probabilities
            next_token = torch.multinomial(token_probs, 1)

            # Map the sampled token back to the original token space
            next_token = top_indices.gather(dim=-1, index=next_token)

            # Append the sampled token to input_ids
            input_ids = torch.cat([input_ids, next_token], dim=-1)

        # Decode the generated text
        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

        return generated_text 


### We go full circle to high level of abstraction with our own code

In [75]:
my_generator = MyPipeline(model, tokenizer)
my_generator("Once upon a time, in a land far, far away,")

'Once upon a time, in a land far, far away, the world was a land forever changed.\nThen came we.\nThe world was a land forever forgotten. After all here was a land forever changed.\nAnd now, here is'