# Gemma Small Model Inference

This notebook demonstrates text generation (inference) with Google's small Gemma 2 model (2B parameters), loaded through Hugging Face Transformers.

In [None]:
# Authenticate with the Hugging Face Hub before any model files are requested.
# NOTE(review): presumably required because the Gemma checkpoints are gated — confirm.
# NOTE(review): this cell runs before the dependency-install cell, so
# huggingface_hub must already be available in the kernel environment.
from huggingface_hub import login
login()

In [None]:
# Install required dependencies
!uv pip install -q transformers accelerate torch

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Prefer the GPU when PyTorch reports a usable CUDA device; otherwise run on CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# Request the GPU, but fall back to CPU when CUDA is not available so that
# later `.to(device)` calls do not fail on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Checkpoint to load: the instruction-tuned Gemma 2B model (recommended).
# Swap in "google/gemma-2-2b" to use the base model instead.
MODEL_ID = "google/gemma-2-2b-it"

print(f"Loading tokenizer for {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Loading model {MODEL_ID}...")
# Weights are requested in bfloat16; device_map="auto" leaves device placement
# to the loader instead of moving the model manually.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16)

# Switch to evaluation mode; as the last expression of the cell this also
# displays the model.
model.eval()

## Inference Function

In [None]:
def generate_response(prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, return_full=False):
    """
    Generate a completion for ``prompt`` with the loaded Gemma model.

    Relies on the notebook-level ``tokenizer`` and ``model`` objects.

    Args:
        prompt: Input text prompt.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (higher = more creative). A value
            of 0 or below selects greedy (deterministic) decoding, because
            sampling requires a strictly positive temperature.
        top_p: Nucleus sampling parameter (only used when sampling).
        return_full: If True, returns both full response and suffix.
            If False, returns only suffix.

    Returns:
        If return_full=True: (full_response, suffix), where full_response is
            the decoded prompt plus generated tokens.
        If return_full=False: suffix (generated text only).
    """
    # NOTE(review): the prompt is tokenized as plain text. For an
    # instruction-tuned checkpoint you may want tokenizer.apply_chat_template
    # instead — confirm the intended usage.

    # Send the inputs to the device the model actually lives on. The model is
    # placed by the loader, so a separately maintained global device string
    # can disagree with it (e.g. "cuda" on a CPU-only machine).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature <= 0:
        # Sampling with a non-positive temperature is rejected by generate();
        # treat it as a request for deterministic greedy decoding.
        generation_kwargs["do_sample"] = False
    else:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    sequence = outputs[0]

    # Everything after the prompt tokens is the newly generated text (suffix).
    suffix = tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True)
    if not return_full:
        return suffix

    # Decode the whole sequence only when the caller asked for it.
    full_response = tokenizer.decode(sequence, skip_special_tokens=True)
    return full_response, suffix

## Example Usage

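By default, `generate_response` returns only the generated suffix (the new text, without the prompt). Pass `return_full=True` to get a `(full_response, suffix)` tuple instead, where `full_response` is the decoded prompt followed by the generated text.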

In [None]:
# Example 1: default call — only the newly generated text is returned
prompt = "Explain what machine learning is in simple terms."
suffix = generate_response(prompt, max_new_tokens=150)

print("Prompt:", prompt)
print("\nGenerated suffix:", suffix, sep="\n")

In [None]:
# Example 2: return_full=True yields the full response and the suffix as a pair
prompt = "Write a short story about a robot learning to paint."
full_response, suffix = generate_response(
    prompt,
    max_new_tokens=200,
    temperature=0.8,
    return_full=True,
)

print("Prompt:", prompt)
print("\nFull response:", full_response, sep="\n")
print("\n" + "="*50)
print("Generated suffix only:", suffix, sep="\n")

In [None]:
# Example 3: ask for code; only the generated suffix is returned
prompt = "Create a simple Python function to calculate the factorial of a number."
suffix = generate_response(prompt, max_new_tokens=200)

print("Prompt:", prompt)
print("\nGenerated code:", suffix, sep="\n")