In [None]:
# ==============================================================
# 1.  System packages (run once - nothing special for this model)
# ==============================================================

!apt-get -qq update && apt-get -qq install -y git wget   # quiet apt update, then install git & wget (handy for debugging)


# ==============================================================
# 2.  Install the *latest* Python dependencies (run once)
#    - tokenizer + generation utilities
#    - Optimum wrapper + ONNX-Runtime integration
# ==============================================================

# NOTE:
#   * Versions are deliberately unpinned: `--upgrade` pulls the newest
#     releases of `transformers`, `huggingface_hub` and `optimum`, so the
#     behaviour of this notebook can change between runs.
#   * The `optimum[onnxruntime]` extra brings in an ONNX Runtime build;
#     section 3 below decides whether a GPU build is needed instead.
#   * NOTE(review): the original note claimed optimum v1.16.0+ works with
#     the latest `transformers` - unverified. If importing Optimum fails
#     because `transformers.utils.cached_property` is missing, the
#     compatibility shim in section 4 restores that name.
!pip install -q --upgrade \
    "transformers" \
    "huggingface_hub" \
    "optimum[onnxruntime]" \
    "sentencepiece"   # needed for SentencePiece-based tokenizers


# ==============================================================
# 3.  Install the correct ONNX Runtime build (CPU or GPU)
# ==============================================================

import sys, subprocess, os, torch, numpy as np # Import necessary modules

def pip_install(pkgs):
    """Run ``pip install -q <pkgs...>`` for the interpreter running this code.

    ``pkgs`` is a list of requirement strings (extra pip flags are allowed too).
    Invoking pip as ``sys.executable -m pip`` installs into the same
    environment this kernel uses. Raises ``subprocess.CalledProcessError``
    if pip exits with a non-zero status.
    """
    base_cmd = [sys.executable, "-m", "pip", "install", "-q"]
    subprocess.check_call(base_cmd + pkgs)

# Detect whether a CUDA device is available (torch is imported above).
import importlib.metadata
import importlib.util

cuda_available = torch.cuda.is_available()  # True when torch can see a CUDA GPU
print(f"CUDA available: {cuda_available}")

# Install the matching ONNX-Runtime package.
# `onnxruntime` (CPU) and `onnxruntime-gpu` both install the same top-level
# `onnxruntime` package, so they must not coexist: a leftover CPU wheel
# (e.g. the one pulled in by `optimum[onnxruntime]`) can shadow the GPU build
# and the CUDA execution provider then silently disappears.
# If `onnxruntime` was already imported in this kernel, restart it after a
# (re)install so the new build is picked up.
if cuda_available:
    try:
        importlib.metadata.version("onnxruntime")
        cpu_ort_present = True
    except importlib.metadata.PackageNotFoundError:
        cpu_ort_present = False

    if cpu_ort_present:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "uninstall", "-y", "-q", "onnxruntime"]
        )
        # Removing the CPU wheel also removes files shared with the GPU wheel,
        # so install the GPU wheel from scratch. --no-deps leaves the
        # dependencies already installed for the CPU wheel untouched.
        pip_install(["--force-reinstall", "--no-deps", "onnxruntime-gpu"])
    else:
        pip_install(["onnxruntime-gpu"])
elif importlib.util.find_spec("onnxruntime") is None:
    # `sys.modules` only lists modules that were already *imported*;
    # find_spec checks whether the package is importable at all, so an
    # existing install is not reinstalled on every run.
    pip_install(["onnxruntime"])


# ==============================================================
# 4.  Compatibility patch for newer `transformers`
# ==============================================================

# Some `transformers` releases no longer expose the
# `transformers.utils.cached_property` helper, while Optimum (and some older
# user code) may still import it from there. When it is missing, re-export
# the standard-library implementation under the old name. This must run
# before Optimum is imported.
# NOTE(review): the original note said the helper was removed in
# transformers >= 4.36 - confirm the exact version before relying on it.
import transformers  # Import the transformers library
if not hasattr(transformers.utils, "cached_property"):  # Patch only when the helper is missing
    from functools import cached_property as _cached_property  # stdlib equivalent
    transformers.utils.cached_property = _cached_property  # re-export under the old name
    print("Patched transformers.utils.cached_property → functools.cached_property")  # confirm the patch


# ==============================================================
# 5.  Download the GPT-2 ONNX repository from the Hub
# ==============================================================

from huggingface_hub import snapshot_download # Import snapshot_download for downloading Hugging Face repos

repo_id = "onnx-community/gpt2-ONNX"  # Hugging Face Hub repo that contains the exported ONNX files

# Restrict the download to model weights, configs and tokenizer files.
# NOTE: "*.onnx" matches every exported variant (model, fp16, q4, q4f16, ...);
# narrow the pattern if you only need one of them to save bandwidth and disk.
onnx_dir = snapshot_download(  # Download a snapshot of the repository
    repo_id=repo_id,
    allow_patterns=[  # only files matching these patterns are downloaded
        "*.onnx",                 # all ONNX model files
        "*_data",                 # external weight blobs referenced by the .onnx files
        "config.json",            # model configuration
        "generation_config.json", # generation defaults
        "tokenizer_config.json",  # tokenizer configuration
        "special_tokens_map.json",
        "added_tokens.json",
        "vocab.json",             # BPE vocabulary
        "merges.txt",             # BPE merges
        "tokenizer.json",         # fast-tokenizer definition
        "tokenizer.model",        # SentencePiece vocab (if present)
    ],
    local_dir="./gpt2-ONNX",   # where the repo files are placed locally
    # NOTE(review): with `local_dir` set, recent huggingface_hub versions write
    # files straight into local_dir, so this cache folder may go largely unused.
    cache_dir="./hf_cache",
    # `resume_download=True` was dropped: it is deprecated in recent
    # huggingface_hub releases (interrupted downloads resume automatically)
    # and only produced a warning.
)
print("Repo downloaded to:", onnx_dir)  # local path of the downloaded snapshot


# ==============================================================
# 6.  Load tokenizer & ONNX-Runtime-backed model (Optimum)
# ==============================================================

from transformers import AutoTokenizer, GenerationConfig # Tokenizer auto-class and the generation-settings container
from optimum.onnxruntime import ORTModelForCausalLM # Optimum's ONNX Runtime wrapper for causal LMs

# Tokenizer files live at the repo root, so the tokenizer loads straight from `onnx_dir`.
tokenizer = AutoTokenizer.from_pretrained(onnx_dir)

# Execution providers in priority order: CUDA first when a GPU was detected.
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda_available else ["CPUExecutionProvider"]

# Load the ONNX model. `file_name="model.onnx"` selects the full-precision export;
# switch to a quantised variant (e.g. "model_q4.onnx", "model_fp16.onnx") for a smaller model.
# NOTE(review): onnx-community repos frequently keep their .onnx files in an
# `onnx/` subfolder; if the file is not found, try passing subfolder="onnx".
model = ORTModelForCausalLM.from_pretrained(
    onnx_dir,
    file_name="model.onnx",
    provider=providers[0],  # `provider` takes a single provider string
)

# Load generation defaults if the repo ships a generation_config.json.
gen_cfg_path = os.path.join(onnx_dir, "generation_config.json")
if os.path.isfile(gen_cfg_path):
    generation_config = GenerationConfig.from_pretrained(onnx_dir)
else:
    # Derive defaults (bos/eos token ids, ...) from the model config rather than
    # using an empty GenerationConfig. Otherwise generation would not know the
    # end-of-text token and would always run to max_new_tokens.
    generation_config = GenerationConfig.from_model_config(model.config)


# ==============================================================
# 7.  Simple generation wrapper
# ==============================================================

import textwrap # Import the textwrap module for text formatting

def generate_text(
    prompt: str,
    max_new_tokens: int = 50,
    temperature: float = 0.8,
    top_k: int = 50,
    stop_token: str | None = None,
    wrap_width: int | None = 80,
    **extra_kwargs,
) -> str:
    """
    Generate text using the ONNX-Runtime-backed GPT-2 model.

    Uses the module-level ``tokenizer``, ``model`` and ``generation_config``
    objects created when the model is loaded.

    Parameters
    ----------
    prompt : str
        Text to condition on.
    max_new_tokens : int
        Number of tokens to generate (excluding the prompt).
    temperature : float
        Sampling temperature. ``0`` selects greedy decoding; values > 0 enable
        sampling (1.0 = no scaling).
    top_k : int
        Keep only the top-k tokens at each sampling step (0 = keep all).
        Ignored for greedy decoding.
    stop_token : str | None
        If given, the output is cut right after the first occurrence of this
        string in the *generated continuation* (occurrences inside the prompt
        are ignored). ``None`` or ``""`` disables truncation.
    wrap_width : int | None
        If given, each line of the output is wrapped to this many characters.
        Existing line breaks are kept.
    extra_kwargs : dict
        Any additional arguments accepted by ``model.generate``
        (e.g. ``num_beams=4``, ``repetition_penalty=1.2``, ``streamer=...``).

    Returns
    -------
    str
        The full decoded text (prompt + continuation).

    Raises
    ------
    ValueError
        If ``temperature`` is negative or NaN.
    """
    if not temperature >= 0:  # the negated comparison also rejects NaN
        raise ValueError(f"temperature must be >= 0, got {temperature!r}")

    # Tokenise the prompt; keep the attention mask so generate() need not guess it.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]  # shape (1, seq_len)
    prompt_len = input_ids.shape[-1]

    # Fresh copy of the defaults so per-call overrides never leak into them.
    cfg = GenerationConfig(**generation_config.to_dict())
    cfg.max_new_tokens = max_new_tokens
    cfg.do_sample = temperature > 0  # greedy decoding when temperature == 0
    if cfg.do_sample:
        # Sampling-only knobs; setting them for greedy decoding just triggers warnings.
        cfg.temperature = temperature
        cfg.top_k = top_k
    # GPT-2 defines no pad token; reuse EOS so generate() does not have to warn and guess.
    if cfg.pad_token_id is None and tokenizer.eos_token_id is not None:
        cfg.pad_token_id = tokenizer.eos_token_id

    # Forward extra kwargs to generate(): it applies generation-config fields and
    # passes the rest (e.g. streamer, stopping_criteria) through, rather than
    # silently ignoring anything that is not a config attribute.
    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=encoded["attention_mask"],
        generation_config=cfg,
        **extra_kwargs,
    )

    # Decode everything (including the original prompt).
    sequence = output_ids[0]
    full_text = tokenizer.decode(sequence, skip_special_tokens=True)

    # Optional stop-token truncation, searched only in the continuation so a
    # stop token that already occurs in the prompt cannot truncate the prompt.
    if stop_token:
        prompt_text = tokenizer.decode(sequence[:prompt_len], skip_special_tokens=True)
        idx = full_text.find(stop_token, len(prompt_text))
        if idx != -1:
            full_text = full_text[: idx + len(stop_token)]

    # Wrap line by line: textwrap.fill on the whole string would collapse
    # every newline (and thus the paragraphs the model produced) into spaces.
    if wrap_width is not None:
        full_text = "\n".join(
            textwrap.fill(line, width=wrap_width) for line in full_text.split("\n")
        )

    return full_text


# ==============================================================
# 8.  Quick demo
# ==============================================================

prompt = "Once upon a time in a distant galaxy"  # Prompt to condition the model on
print("🖊️ Prompt:", prompt)  # Show the prompt
print("\n🤖 Generation:\n")  # Header for the generated output
print(
    generate_text(
        prompt,
        max_new_tokens=500,  # generate up to 500 new tokens
        temperature=0.9,     # sampling temperature
        top_k=50,            # sample from the 50 most likely tokens
        stop_token=None,     # e.g. set to "." to stop at the first period (disabled here)
        wrap_width=90,       # wrap output to 90 characters
    )
)