# Transformer Model Text Generation Demo

This notebook loads a trained GPT model checkpoint, generates text from it with several prompts and sampling settings, and benchmarks generation speed.

In [1]:
import torch
import tiktoken
from minigpt.model.transformer import GPTModel
from minigpt.config.config import DEFAULT_CONFIG

## 1. Load the Model

First, we'll load the model from a saved checkpoint.

In [2]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model checkpoint path
checkpoint_path = "checkpoints/model_epoch_1.pth"

# Initialize the model with default configuration
config = DEFAULT_CONFIG.model
model = GPTModel(config)

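# Load the weights. weights_only=True restricts unpickling to tensors and
# plain containers (requires PyTorch 1.13 or newer).
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))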
model.to(device)
model.eval()

# Initialize the tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

print("Model loaded")

Using device: cpu
Model loaded
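
The cell above assumes the checkpoint file holds a bare state dict. Checkpoints saved by other training scripts sometimes nest the weights under a wrapper key or carry name prefixes. The sketch below shows one way to normalize such a file before calling `load_state_dict`. The helper name `extract_state_dict` and the wrapper keys it checks are assumptions for illustration, not part of `minigpt`.

```python
def extract_state_dict(checkpoint, wrapper_keys=("model_state_dict", "state_dict", "model")):
    """Return a plain name -> weight mapping from a loaded checkpoint object."""
    state = checkpoint
    # Unwrap one level if the weights sit under a wrapper key.
    # The key names here are common conventions, assumed for illustration.
    for key in wrapper_keys:
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get(key), dict):
            state = checkpoint[key]
            break

    # nn.DataParallel prefixes parameter names with "module." and
    # torch.compile prefixes them with "_orig_mod."; strip both.
    prefixes = ("_orig_mod.", "module.")
    cleaned = {}
    for name, value in state.items():
        while name.startswith(prefixes):
            name = name.split(".", 1)[1]
        cleaned[name] = value
    return cleaned
```

In this notebook it would wrap the loaded object, for example `model.load_state_dict(extract_state_dict(torch.load(checkpoint_path, map_location=device)))`.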


## 2. Text Generation Function

Let's define a function to generate text from our model.

In [3]:
def generate_text(prompt, max_new_tokens=100, temperature=0.8, top_k=40):
    """
    Generate text from the model given a prompt.
    
    Args:
        prompt (str): The text prompt to start generation
        max_new_tokens (int): Maximum number of new tokens to generate
        temperature (float): Sampling temperature (higher = more random)
        top_k (int): Number of highest-probability tokens to sample from (0 to disable)
        
    Returns:
        str: The generated text including the prompt
    """
    # Encode the prompt
    input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
    
    # Generate text
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k
        )
    
    # Decode the output
    generated_text = tokenizer.decode(output_ids[0].tolist())
    
    return generated_text
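
The notebook does not show what `model.generate` does internally. As a rough, framework-free sketch of how `temperature` and `top_k` are commonly combined when picking the next token, consider the function below. The name `sample_next_token` and the `rng` parameter are hypothetical, and the real `GPTModel.generate` may differ in its details.

```python
import math
import random

def sample_next_token(logits, temperature=0.8, top_k=40, rng=random):
    """Sample one token id from a list of raw logits (illustrative only)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # Temperature scaling: values below 1 sharpen the distribution,
    # values above 1 flatten it.
    scaled = [x / temperature for x in logits]

    # Top-k filtering: keep only the k highest-scoring token ids.
    # A top_k of 0 leaves every token in play.
    ids = list(range(len(scaled)))
    if top_k and top_k < len(scaled):
        ids = sorted(ids, key=lambda i: scaled[i], reverse=True)[:top_k]

    # Numerically stable softmax weights over the surviving candidates.
    peak = max(scaled[i] for i in ids)
    weights = [math.exp(scaled[i] - peak) for i in ids]

    # Draw one id in proportion to its weight.
    return rng.choices(ids, weights=weights, k=1)[0]
```

With `top_k=1` this reduces to greedy decoding, since only the highest-scoring token remains as a candidate.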

## 3. Generate Text Samples

Now let's generate some text with different prompts and parameters.

In [4]:
# Example 1: Basic generation with default parameters
prompt = "Once upon a time"
generated_text = generate_text(prompt)

print(f"Prompt: '{prompt}'")
print("Generated text:")
print("=" * 80)
print(generated_text)
print("=" * 80)

Prompt: 'Once upon a time'
Generated text:
Once upon a time, then the best of the most common will be more interested.

When you are trying to see me all to keep that?

And you are in the world of that in the world.

I'm a good of his own person in one of my family.

This is very a really one of you who is so many others.

And you know?

I believe them, don't be talking that?

I believe it!

I


In [5]:
# Example 2: Try with a different prompt
prompt = "The future of artificial intelligence"
generated_text = generate_text(prompt, max_new_tokens=150)

print(f"Prompt: '{prompt}'")
print("Generated text:")
print("=" * 80)
print(generated_text)
print("=" * 80)

Prompt: 'The future of artificial intelligence'
Generated text:
The future of artificial intelligence. In the present hand, the majority as the first section of the world. The results in the case of the
point of the form of the two-temperature are obtained by the two-dalo and a part of the
direction of the current-interverse and the inner.

\begin{remark}[t]
\centering
\begin{tabular}[t]
\includegraphics[width=1.6\columnwidth]{figures/0.01cm}
\caption{An=0.50\tiny){9ures/0.5.}
\end{figure}

\caption{The following of the most-triv


In [6]:
# Example 3: Try with different sampling parameters
prompt = "In a world where"
generated_text = generate_text(
    prompt, 
    max_new_tokens=200, 
    temperature=1.0,  # Higher temperature for more randomness
    top_k=100         # More candidates in top-k sampling
)

print(f"Prompt: '{prompt}'")
print("Generated text (with temperature=1.0, top_k=100):")
print("=" * 80)
print(generated_text)
print("=" * 80)

Prompt: 'In a world where'
Generated text (with temperature=1.0, top_k=100):
In a world where you have all been doing the right page used by you there for their computer for something for and you know of that they take it in you. I thought, so's help. Thank it in the way and put everything. You can make for a business and take you the right option in our face from the process or the first? We just come at the room.
The process for their car: but you have better to learn our family about you. It is fun and it is great and very for the price, no a few-stop company. If it depends on the form this kind of your own, you can be not the only more that we will put a better.
I don't definitely have a very important store that is as good unless you have to take a few of an hour. Like that you have a bit or to do like me see the entire time and get to help it out the amount.
O/4-4/22 -

I can


## 4. Interactive Text Generation

Let's create a simple interactive widget to generate text with custom parameters.

In [7]:
from ipywidgets import widgets
from IPython.display import display

# Create widgets
prompt_widget = widgets.Textarea(
    value='Write a story about',
    description='Prompt:',
    layout={'width': '100%', 'height': '80px'}
)

max_tokens_widget = widgets.IntSlider(
    value=100,
    min=10,
    max=500,
    step=10,
    description='Max Tokens:',
    layout={'width': '50%'}
)

temp_widget = widgets.FloatSlider(
    value=0.8,
    min=0.1,
    max=1.5,
    step=0.1,
    description='Temperature:',
    layout={'width': '50%'}
)

top_k_widget = widgets.IntSlider(
    value=40,
    min=0,
    max=100,
    step=5,
    description='Top-K:',
    layout={'width': '50%'}
)

output_widget = widgets.Textarea(
    description='Generated:',
    layout={'width': '100%', 'height': '300px'},
    disabled=True
)

generate_button = widgets.Button(
    description='Generate',
    button_style='primary'
)

# Define button click handler
def on_generate_button_clicked(b):
    prompt = prompt_widget.value
    max_tokens = max_tokens_widget.value
    temperature = temp_widget.value
    top_k = top_k_widget.value

    # Disable button while generating
    generate_button.disabled = True
    generate_button.description = 'Generating...'

    try:
        # Generate text
        result = generate_text(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k
        )

        # Update output
        output_widget.value = result
    finally:
        # Re-enable button even if generation raised an error
        generate_button.disabled = False
        generate_button.description = 'Generate'

# Attach handler to button
generate_button.on_click(on_generate_button_clicked)

# Display widgets
display(prompt_widget)
display(widgets.HBox([max_tokens_widget, temp_widget, top_k_widget]))
display(generate_button)
display(output_widget)


## 5. Benchmark Generation Speed

Let's measure how fast our model generates text, in tokens per second, averaged over several runs.

In [8]:
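import time
import numpy as np

def benchmark_generation(prompt, num_runs=5, tokens_per_run=100):
    """Time repeated generate_text calls and print the average throughput."""
    times = []

    for _ in range(num_runs):
        # perf_counter is a monotonic, high-resolution clock meant for timing
        start_time = time.perf_counter()
        _ = generate_text(prompt, max_new_tokens=tokens_per_run)
        end_time = time.perf_counter()
        times.append(end_time - start_time)

    avg_time = np.mean(times)
    # Assumes each run produces exactly tokens_per_run new tokens;
    # the timing also includes tokenizer encoding and decoding
    tokens_per_second = tokens_per_run / avg_time

    print(f"Average generation time for {tokens_per_run} tokens: {avg_time:.3f} seconds")
    print(f"Tokens per second: {tokens_per_second:.2f}")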
    
# Run benchmark
print("Benchmarking generation speed...")
benchmark_generation("The quick brown fox", num_runs=3, tokens_per_run=50)

Benchmarking generation speed...
Average generation time for 50 tokens: 2.623 seconds
Tokens per second: 19.06
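
A mean over three runs can be skewed by a slow first call. As a sketch of a slightly more robust measurement, the helper below discards warm-up runs and also reports the median and spread. The name `time_calls` and its parameters are assumptions for illustration, and a cheap stand-in workload replaces the model so the example runs on its own.

```python
import statistics
import time

def time_calls(fn, num_runs=5, warmup=1):
    """Time repeated calls to fn and return summary statistics in seconds."""
    # Warm-up calls absorb one-time costs (lazy initialization, caches)
    # and are not included in the timings.
    for _ in range(warmup):
        fn()

    durations = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)

    return {
        "mean": statistics.mean(durations),
        "median": statistics.median(durations),
        "stdev": statistics.stdev(durations) if len(durations) > 1 else 0.0,
        "runs": durations,
    }

# Stand-in workload so the sketch runs without a model.
stats = time_calls(lambda: sum(range(10_000)), num_runs=5, warmup=1)
print(f"median {stats['median']:.6f}s over {len(stats['runs'])} runs")
```

In this notebook the timed function could be `lambda: generate_text("The quick brown fox", max_new_tokens=50)`. CUDA kernels launch asynchronously, so a timed function that leaves work queued on the GPU would need `torch.cuda.synchronize()` before the clock is read. `generate_text` copies its output back to the CPU for decoding, which already forces that wait.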
