# Load Required Libraries and Model
Import necessary libraries and load the pretrained model from the saved checkpoint directory. Include error handling for model loading.

In [None]:
# Import necessary libraries
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator

# Initialize the accelerator; accelerator.device is used later to place inputs
accelerator = Accelerator()

# Define the path to the saved model checkpoint
checkpoint_dir = "./output/layer_looping_qwen/best_model"

# Load the tokenizer. Only the load itself lives in the try body so that an
# unrelated failure is never reported as a loading error.
try:
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
except Exception as e:
    # Chain explicitly so the original traceback is kept as the direct cause
    raise RuntimeError(f"Failed to load tokenizer from {checkpoint_dir}: {e}") from e

# Padding requires a pad token; fall back to EOS when the tokenizer has none
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load the model and hand it to the accelerator
try:
    model = AutoModelForCausalLM.from_pretrained(checkpoint_dir)
    model = accelerator.prepare(model)  # Prepare the model for distributed inference
except Exception as e:
    raise RuntimeError(f"Failed to load model from {checkpoint_dir}: {e}") from e

# Confirm successful loading
print("Model and tokenizer loaded successfully.")

# Setup Tokenizer and Model Configuration
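Reuse the tokenizer loaded in the previous section and configure generation parameters such as max length, temperature, and top_p. Then set up the model's layer looping parameters (n, m, max_loop_count). Run the loading cell first: this cell depends on the `tokenizer` and `model` objects it defines.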

In [1]:
# Set up generation parameters (defaults for every generate() call below)
generation_config = {
    "max_length": 128,  # Maximum length of the generated sequence
    "temperature": 0.7,  # Sampling temperature
    "top_p": 0.9,  # Nucleus sampling probability
    "do_sample": True,  # Enable sampling
    "eos_token_id": tokenizer.eos_token_id,  # End-of-sequence token
    # Passed explicitly so generate() does not have to infer a padding token
    "pad_token_id": tokenizer.pad_token_id,
}

# Configure layer looping parameters
layer_looping_config = {
    "n": 8,  # Start layer index for looping
    "m": 12,  # End layer index for looping
    "max_loop_count": 5,  # Maximum number of times to loop
}

# Add layer looping parameters to the model's configuration.
# NOTE(review): presumably the custom looping model reads n / m /
# max_loop_count from its config at forward time -- confirm against the
# model implementation.
if hasattr(model, "config"):
    model.config.update(layer_looping_config)

# Confirm configuration setup
print("Generation and layer looping configurations set up successfully.")

NameError: name 'tokenizer' is not defined

# Create Text Generation Function
Implement a helper function that handles input preprocessing, model inference, and output post-processing. Include parameters for controlling the loop count and generation settings.

In [2]:
# Define a function for text generation
def generate_text(prompt, loop_count=1, max_length=None, temperature=None, top_p=None):
    """
    Generate text using the pretrained layer looping transformer model.

    Relies on the notebook-level ``tokenizer``, ``model``, ``accelerator`` and
    ``generation_config`` objects defined in earlier cells.

    Args:
        prompt (str): The input text prompt for generation.
        loop_count (int): Number of times to loop through the specified layers.
            Stored on ``model.config.k`` before generating; the value stays set
            on the config after the call returns.
        max_length (int, optional): Maximum length of the generated sequence. Defaults to the value in generation_config.
        temperature (float, optional): Sampling temperature. Defaults to the value in generation_config.
        top_p (float, optional): Nucleus sampling probability. Defaults to the value in generation_config.

    Returns:
        str: The decoded first output sequence, with special tokens removed.
    """
    # Start from the shared defaults and overlay only the explicitly given values
    gen_params = generation_config.copy()
    overrides = {"max_length": max_length, "temperature": temperature, "top_p": top_p}
    gen_params.update({key: value for key, value in overrides.items() if value is not None})
    # Make sure generate() always receives an explicit padding token id
    gen_params.setdefault("pad_token_id", tokenizer.pad_token_id)

    # Tokenize the input prompt and move the tensors to the accelerator's device
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"].to(accelerator.device)
    attention_mask = inputs["attention_mask"].to(accelerator.device)

    # Add loop count to the model's configuration
    if hasattr(model, "config"):
        model.config.k = loop_count

    # Generate text; no gradients are needed for inference.
    # (The raw token tensor is intentionally not printed here: callers only
    # want the decoded text.)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_params
        )

    # Decode the generated tokens
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [3]:
# Quick check of the helper with its default loop count and generation settings
print(generate_text(prompt="The meaning of life is"))

NameError: name 'generation_config' is not defined

# Interactive Text Generation
Create an interactive interface for entering prompts and generating text responses. Include options to adjust generation parameters in real-time.

In [4]:
# Import necessary libraries for interactive widgets
import ipywidgets as widgets
from IPython.display import display

# Controls for single-prompt generation. The slider defaults come from
# generation_config / layer_looping_config defined in the configuration cell.

# Free-text prompt box stretched across the full cell width
prompt_input = widgets.Text(
    description="Prompt:",
    placeholder="Enter your prompt here...",
    value="",
    layout=widgets.Layout(width="100%"),
)

# Loop count, capped at the configured max_loop_count
loop_count_slider = widgets.IntSlider(
    description="Loop Count:",
    value=1, min=1, max=layer_looping_config["max_loop_count"], step=1,
    continuous_update=False,
)

# Maximum sequence length
max_length_slider = widgets.IntSlider(
    description="Max Length:",
    value=generation_config["max_length"], min=16, max=512, step=16,
    continuous_update=False,
)

# Sampling temperature
temperature_slider = widgets.FloatSlider(
    description="Temperature:",
    value=generation_config["temperature"], min=0.1, max=1.5, step=0.1,
    continuous_update=False,
)

# Nucleus sampling probability
top_p_slider = widgets.FloatSlider(
    description="Top-p:",
    value=generation_config["top_p"], min=0.1, max=1.0, step=0.1,
    continuous_update=False,
)

# Button that triggers generation
generate_button = widgets.Button(
    description="Generate",
    tooltip="Click to generate text",
    icon="rocket",
    button_style="success",
)

# Area that captures everything the generation callback prints
output_area = widgets.Output()

# Click handler for the "Generate" button
def on_generate_button_click(b):
    """Generate text for the current prompt and show it in ``output_area``.

    Reads the prompt and the slider values at click time. A blank prompt
    only prints a hint and returns without calling the model.

    Args:
        b: The button instance passed by ipywidgets; unused.
    """
    with output_area:
        # Drop whatever the previous click left behind
        output_area.clear_output()

        prompt = prompt_input.value
        if not prompt.strip():
            print("Please enter a valid prompt.")
            return

        print("Generating text...")
        generated_text = generate_text(
            prompt=prompt,
            loop_count=loop_count_slider.value,
            max_length=max_length_slider.value,
            temperature=temperature_slider.value,
            top_p=top_p_slider.value,
        )
        print(f"Generated Text:\n{generated_text}")

# Register the click handler on the button
generate_button.on_click(on_generate_button_click)

# Show the controls stacked vertically, with the output area at the bottom
display(widgets.VBox(children=[
    prompt_input,
    loop_count_slider,
    max_length_slider,
    temperature_slider,
    top_p_slider,
    generate_button,
    output_area,
]))

NameError: name 'layer_looping_config' is not defined

# Experiment with Different Loop Counts
Compare model outputs with different loop counts (k values) using the same input prompt. Analyze how the number of loops affects generation quality and speed.

In [None]:
# Experiment with Different Loop Counts

# Helper that runs the same prompt once per requested loop count
def compare_loop_counts(prompt, loop_counts, max_length=None, temperature=None, top_p=None):
    """
    Generate text for one prompt under several loop counts.

    Each loop count triggers a separate ``generate_text`` call with otherwise
    identical settings. If a loop count appears more than once, every
    occurrence is generated and the last result is the one kept.

    Args:
        prompt (str): The input text prompt for generation.
        loop_counts (list): Loop counts to try, in order.
        max_length (int, optional): Maximum length of the generated sequence. Defaults to the value in generation_config.
        temperature (float, optional): Sampling temperature. Defaults to the value in generation_config.
        top_p (float, optional): Nucleus sampling probability. Defaults to the value in generation_config.

    Returns:
        dict: Maps each loop count to the text generated with it.
    """
    # Settings that are the same for every run
    shared_settings = {
        "prompt": prompt,
        "max_length": max_length,
        "temperature": temperature,
        "top_p": top_p,
    }
    return {
        count: generate_text(loop_count=count, **shared_settings)
        for count in loop_counts
    }

# Controls for the loop-count comparison

# Prompt shared by all compared runs
compare_prompt_input = widgets.Text(
    description="Prompt:",
    placeholder="Enter your prompt here...",
    value="",
    layout=widgets.Layout(width="100%"),
)

# Comma-separated loop counts to compare
loop_counts_input = widgets.Text(
    description="Loop Counts:",
    placeholder="Enter loop counts (comma-separated)...",
    value="1,2,3",
    layout=widgets.Layout(width="100%"),
)

# Button that starts the comparison
compare_button = widgets.Button(
    description="Compare",
    tooltip="Click to compare outputs with different loop counts",
    icon="search",
    button_style="info",
)

# Area that captures everything the comparison callback prints
compare_output_area = widgets.Output()

# Helper that turns the free-text loop count field into validated integers
def _parse_loop_counts(raw, upper_bound):
    """Parse a comma-separated string into a list of unique loop counts.

    Empty items (for example from a trailing comma) are skipped. Duplicates
    are removed while keeping first-seen order, so no loop count is generated
    twice.

    Args:
        raw (str): Text such as ``"1, 2,3"``.
        upper_bound (int): Largest loop count that is accepted.

    Returns:
        list: The parsed loop counts, each within ``1..upper_bound``.

    Raises:
        ValueError: If an item is not an integer or lies outside the range.
    """
    counts = []
    for item in raw.split(","):
        item = item.strip()
        if not item:
            continue
        value = int(item)  # raises ValueError for non-integer text
        if not 1 <= value <= upper_bound:
            raise ValueError(f"loop count {value} is outside the range 1-{upper_bound}")
        if value not in counts:
            counts.append(value)
    return counts


# Define the callback function for comparison
def on_compare_button_click(b):
    """Compare outputs across loop counts and print them to ``compare_output_area``.

    Invalid entries in the loop count field are reported to the user instead
    of being silently dropped. Max length, temperature and top-p are taken
    from the sliders of the interactive generation section.

    Args:
        b: The button instance passed by ipywidgets; unused.
    """
    with compare_output_area:
        compare_output_area.clear_output()
        prompt = compare_prompt_input.value

        if not prompt.strip():
            print("Please enter a valid prompt.")
            return

        try:
            # Same upper limit as the loop count sliders
            loop_counts = _parse_loop_counts(
                loop_counts_input.value,
                layer_looping_config["max_loop_count"],
            )
        except ValueError as err:
            print("Invalid loop counts. Please enter a comma-separated list of integers.")
            print(f"Details: {err}")
            return

        if not loop_counts:
            print("Please enter at least one valid loop count.")
            return

        print("Comparing outputs with different loop counts...")
        results = compare_loop_counts(
            prompt=prompt,
            loop_counts=loop_counts,
            max_length=max_length_slider.value,
            temperature=temperature_slider.value,
            top_p=top_p_slider.value
        )

        for loop_count, text in results.items():
            print(f"\nLoop Count: {loop_count}\n{'-' * 20}\n{text}")

# Register the click handler on the compare button
compare_button.on_click(on_compare_button_click)

# Show the comparison controls stacked vertically, output area last
display(widgets.VBox(children=[
    compare_prompt_input,
    loop_counts_input,
    compare_button,
    compare_output_area,
]))

# Batch Processing for Multiple Prompts
Implement batch processing functionality to generate responses for multiple prompts simultaneously. Include performance optimization for batch processing.

In [None]:
# Batch Processing for Multiple Prompts

# Define a function for batch processing
def generate_batch(prompts, loop_count=1, max_length=None, temperature=None, top_p=None):
    """
    Generate responses for multiple prompts in a batch.

    Relies on the notebook-level ``tokenizer``, ``model``, ``accelerator`` and
    ``generation_config`` objects defined in earlier cells.

    Args:
        prompts (list): A list of input text prompts for generation. A single
            string is treated as a one-element batch.
        loop_count (int): Number of times to loop through the specified layers.
            Stored on ``model.config.k`` before generating.
        max_length (int, optional): Maximum length of the generated sequence. Defaults to the value in generation_config.
        temperature (float, optional): Sampling temperature. Defaults to the value in generation_config.
        top_p (float, optional): Nucleus sampling probability. Defaults to the value in generation_config.

    Returns:
        list: A list of generated texts corresponding to the input prompts.
            Empty when ``prompts`` is empty.
    """
    # Normalize the input; an empty batch would make the tokenizer fail
    prompts = [prompts] if isinstance(prompts, str) else list(prompts)
    if not prompts:
        return []

    # Start from the shared defaults and overlay only the explicitly given values
    gen_params = generation_config.copy()
    overrides = {"max_length": max_length, "temperature": temperature, "top_p": top_p}
    gen_params.update({key: value for key, value in overrides.items() if value is not None})
    # Padded batches need an explicit padding token id during generation
    gen_params.setdefault("pad_token_id", tokenizer.pad_token_id)

    # A causal LM continues from the last position of each row, so shorter
    # prompts must be padded on the left; with right padding the model would
    # be asked to continue from pad tokens. The tokenizer's original setting
    # is restored afterwards.
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
    finally:
        tokenizer.padding_side = previous_padding_side
    input_ids = inputs["input_ids"].to(accelerator.device)
    attention_mask = inputs["attention_mask"].to(accelerator.device)

    # Add loop count to the model's configuration
    if hasattr(model, "config"):
        model.config.k = loop_count

    # Generate text in batch; no gradients are needed for inference
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **gen_params
        )

    # Decode the generated tokens
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

# Controls for batch generation

# Multi-line box: one prompt per line
batch_prompts_input = widgets.Textarea(
    description="Prompts:",
    placeholder="Enter one prompt per line...",
    value="Prompt 1\nPrompt 2\nPrompt 3",
    layout=widgets.Layout(width="100%", height="150px"),
)

# Loop count applied to the whole batch, capped at the configured max_loop_count
batch_loop_count_slider = widgets.IntSlider(
    description="Loop Count:",
    value=1, min=1, max=layer_looping_config["max_loop_count"], step=1,
    continuous_update=False,
)

# Button that triggers batch generation
batch_generate_button = widgets.Button(
    description="Generate Batch",
    tooltip="Click to generate text for all prompts",
    icon="rocket",
    button_style="success",
)

# Area that captures everything the batch callback prints
batch_output_area = widgets.Output()

# Define the callback function for batch processing
def on_batch_generate_button_click(b):
    """Generate text for every prompt in the textarea and print the results.

    Each non-blank line is one prompt. Blank lines are skipped and surrounding
    whitespace (including a stray carriage return) is stripped, so the model
    never receives an empty prompt. Max length, temperature and top-p are
    taken from the sliders of the interactive generation section.

    Args:
        b: The button instance passed by ipywidgets; unused.
    """
    with batch_output_area:
        batch_output_area.clear_output()

        # One prompt per line, ignoring blank lines
        stripped_lines = (line.strip() for line in batch_prompts_input.value.splitlines())
        prompts = [line for line in stripped_lines if line]

        if not prompts:
            print("Please enter at least one valid prompt.")
            return

        print("Generating text for batch...")
        generated_texts = generate_batch(
            prompts=prompts,
            loop_count=batch_loop_count_slider.value,
            max_length=max_length_slider.value,
            temperature=temperature_slider.value,
            top_p=top_p_slider.value
        )

        for i, (prompt, text) in enumerate(zip(prompts, generated_texts), 1):
            print(f"\nPrompt {i}: {prompt}\n{'-' * 20}\n{text}")

# Register the click handler on the batch button
batch_generate_button.on_click(on_batch_generate_button_click)

# Show the batch controls stacked vertically; the length, temperature and
# top-p sliders are the same widget objects used by the interactive section
display(widgets.VBox(children=[
    batch_prompts_input,
    batch_loop_count_slider,
    max_length_slider,
    temperature_slider,
    top_p_slider,
    batch_generate_button,
    batch_output_area,
]))