# Generation Test

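Tests remote text generation with `model.generate()` through NDIF. It checks that the decoded output contains the prompt and that at least one new token was generated.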

**Environment Variables:**
- `MODEL_NAME`: Hugging Face model ID to test (default: `openai-community/gpt2`)
- `NDIF_API`: NDIF API key used for remote execution
- `HF_TOKEN`: Hugging Face access token

In [None]:
import os
import time

# Pick the model under test: an environment override wins, otherwise GPT-2.
MODEL_NAME = os.getenv("MODEL_NAME", "openai-community/gpt2")
print("Testing model:", MODEL_NAME)

In [None]:
# Configure NDIF (and report Hugging Face token availability)
from nnsight import CONFIG

NDIF_API = os.environ.get("NDIF_API")
if NDIF_API:
    CONFIG.set_default_api_key(NDIF_API)
    print("NDIF API key configured")
else:
    # The generation cell runs with remote=True, so surface a missing key up
    # front. A key saved by an earlier set_default_api_key() call may still
    # apply, hence a warning rather than a hard failure.
    print("WARNING: NDIF_API not set; remote execution may be rejected")

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    # HF_TOKEN was read straight from os.environ, so it is already present in
    # the process environment; writing it back would be a no-op.
    print("HF_TOKEN configured")

In [None]:
# Load model
from nnsight import LanguageModel

print(f"Loading {MODEL_NAME}...")
# perf_counter() is monotonic and high-resolution, so the measured duration
# cannot be skewed by wall-clock adjustments the way time.time() can.
start = time.perf_counter()
model = LanguageModel(MODEL_NAME, device_map="auto")
load_time = time.perf_counter() - start
print(f"Model loaded in {load_time:.1f}s")

In [None]:
# Run generation remotely through NDIF (remote=True)
prompt = "Once upon a time"
max_new_tokens = 20

print(f"Generating from: '{prompt}'")
print(f"Max new tokens: {max_new_tokens}")

# Monotonic clock for elapsed-time measurement (see time.perf_counter docs).
start = time.perf_counter()
with model.generate(prompt, max_new_tokens=max_new_tokens, remote=True):
    # .save() keeps the generator output available after the trace exits.
    # The validation cell treats output_ids[0] as prompt ids followed by the
    # generated ids — presumably a (batch, seq) id tensor; TODO confirm.
    output_ids = model.generator.output.save()

gen_time = time.perf_counter() - start
print(f"Generation completed in {gen_time:.1f}s")

In [None]:
# Decode and validate output
output_text = model.tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"\nGenerated text:\n{output_text}")

# Validate the decoded text. The case-insensitive containment check already
# covers an exact prefix match, so no separate startswith() test is needed.
assert len(output_text) > len(prompt), "Output should be longer than prompt"
assert prompt.lower() in output_text.lower(), "Output should contain the prompt"

# Check we got new tokens: the output sequence is treated as the prompt's
# token ids followed by the generated ids, so the difference is what was added.
prompt_tokens = len(model.tokenizer.encode(prompt))
output_tokens = len(output_ids[0])
new_tokens = output_tokens - prompt_tokens
print(f"\nPrompt tokens: {prompt_tokens}")
print(f"Output tokens: {output_tokens}")
print(f"New tokens: {new_tokens}")

assert new_tokens > 0, "Should have generated at least 1 new token"
# Also confirm the max_new_tokens limit passed to generate() was honored.
assert new_tokens <= max_new_tokens, (
    f"Generated {new_tokens} new tokens, exceeding max_new_tokens={max_new_tokens}"
)

print("\n" + "=" * 40)
# NOTE(review): the banner literal is split, presumably so a log grep for the
# pass message cannot match echoed notebook source — confirm before merging.
print("GENERATION " + "TEST PASSED")
print("=" * 40)