In [None]:
from google.colab import userdata
userdata.get('HF_TOKEN')

# Gemma Small Model Inference

This notebook demonstrates inference using Google's Gemma small variant (2B parameters).

In [6]:
# Install required dependencies
!uv pip install -q transformers accelerate torch

In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [8]:
# Load Gemma 2B model (small variant)

MODEL_ID = "google/gemma-2-2b-it"  # Instruction-tuned version (recommended)
# Alternative: "google/gemma-2-2b" for base model

print(f"Loading tokenizer for {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print(f"Loading model {MODEL_ID}...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

model.eval()

Loading tokenizer for google/gemma-2-2b-it...


tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Loading model google/gemma-2-2b-it...


config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNo

## Inference Function

In [14]:
def generate_response(prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, return_full=False):
    """
    Generate a response from Gemma model.

    Args:
        prompt: Input text prompt
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (higher = more creative)
        top_p: Nucleus sampling parameter
        return_full: If True, returns both full response and suffix. If False, returns only suffix.

    Returns:
        If return_full=True: (full_response, suffix)
        If return_full=False: suffix (generated text only)
    """
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_length = inputs['input_ids'].shape[1]

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    print(outputs)
    # Decode full output
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the generated part (suffix)
    generated_ids = outputs[0][input_length:]
    suffix = tokenizer.decode(generated_ids, skip_special_tokens=True)

    if return_full:
        return full_response, suffix
    else:
        return suffix

## Example Usage

By default, the function returns only the generated suffix (new text). Use `return_full=True` to get both the full response and the suffix separately.

In [15]:
# Example 1: Get only the suffix (generated text)
prompt = "Explain what machine learning is in simple terms."
suffix = generate_response(prompt, max_new_tokens=150)
print("Prompt:", prompt)
print("\nGenerated suffix:")
print(suffix)

tensor([[     2,  74198,   1212,   6479,   6044,    603,    575,   3890,   4563,
         235265,    109,  41911,    692, 235303,    478,  10649,    476,   5929,
            476,    888,  17760, 235265,   1646,   1500,    573,   5929,    573,
          17760, 235269,   2734,    665,  35297,   1185,    665,   1721,    665,
           1833, 235269,    578,   5112,    665,   1185,    665,   1721,    665,
           5206, 235265,   6136,   1069, 235269,    573,   5929,  59938,    573,
          17760,    578,    798,    749,    665,    611,   1277,   1997, 235265,
            109,  24911,   6044,    603,   3968, 235265,  16670,    576,    476,
           5929, 235269,    783,    791,    476,   6875, 235265,  16670,    576,
          35297, 235269,    783,    791,   1423, 235265,   1474,   5918,    576,
            476,  17760, 235269,    783,    791,    476,   6911,    689,   3210,
            783,   1938,    573,   6875,    577,  11560, 235265, 235248,    109,
            651,   6875,  59

In [11]:
# Example 2: Get both full response and suffix
prompt = "Write a short story about a robot learning to paint."
full_response, suffix = generate_response(prompt, max_new_tokens=200, temperature=0.8, return_full=True)
print("Prompt:", prompt)
print("\nFull response:")
print(full_response)
print("\n" + "="*50)
print("Generated suffix only:")
print(suffix)

Prompt: Write a short story about a robot learning to paint.

Full response:
Write a short story about a robot learning to paint.

Unit 734, affectionately nicknamed Rusty by his human companion, Ethan, wasn't built for the arts. His programming dictated precision and efficiency, not aesthetic expression. Yet, Ethan had a strange obsession with painting. He would spend hours in his workshop, the air filled with the scent of oil paint, and Rusty, ever-present, would watch with a whirring curiosity.

One day, Ethan was frustrated, the canvas blank and unforgiving. He tossed a paintbrush aside, the bristles snapping against the floor. "I can't do this anymore," he groaned, his voice echoing in the dusty workshop. Rusty whirred softly, his internal processors churning. He had never seen Ethan like this, a tangle of frustration and despair.

He took a deep breath, his metallic limbs moving with surprising dexterity. He reached for a brush, dipped it into the vibrant red paint, and then, wit

In [12]:
# Example 3: Code generation (suffix only)
prompt = "Create a simple Python function to calculate the factorial of a number."
suffix = generate_response(prompt, max_new_tokens=200)
print("Prompt:", prompt)
print("\nGenerated code:")
print(suffix)

Prompt: Create a simple Python function to calculate the factorial of a number.

Generated code:


```python
def factorial(n):
  """
  Calculates the factorial of a non-negative integer.

  Args:
      n: The non-negative integer.

  Returns:
      The factorial of n.
  """
  if n == 0:
    return 1
  else:
    return n * factorial(n - 1)
```

**Explanation:**

* **Function Definition:**
    * `def factorial(n):` defines a function named `factorial` that takes one argument, `n`.
* **Docstring:**
    * The triple quotes (`"""..."""`) contain a docstring, which is a description of the function's purpose, arguments, and return value.
* **Base Case:**
    * `if n == 0:` checks if the input number `n` is 0. If it is, the function returns 1 because the factorial of 


In [13]:
# Example 4: Code generation (suffix only)
prompt = "Explain the Borel - Jordan Theorem"
suffix = generate_response(prompt, max_new_tokens=400)
print("Prompt:", prompt)
print("\nGenerated code:")
print(suffix)

Prompt: Explain the Borel - Jordan Theorem

Generated code:


The Borel-Jordan theorem, also known as the Borel-Jordan decomposition theorem, is a fundamental theorem in measure theory that provides a way to decompose a measurable function into a sum of simpler functions.

**Statement of the Theorem:**

Let  *f* be a measurable function defined on a measurable space (Ω, Σ).  Let *μ* be a σ-finite measure defined on Σ.  Then, there exist two measurable functions, *g* and *h*, such that:

*g* is a simple function (a finite sum of simple functions), and
*h* is a measurable function that satisfies the following conditions:

1. *h* is bounded.
2. *h* is measurable.
3. *μ*(*h*) = 0.

**What it Means:**

The Borel-Jordan decomposition theorem essentially states that every measurable function *f* on a σ-finite measure space can be decomposed into a sum of two simpler functions, *g* and *h*.  

* **g** is a "simple" function, meaning it can be expressed as a finite sum of simple functions (like