# Hands-On Tutorial: Chat Templates with SmolLM2 **(+ Streaming Datasets)**

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
from datasets import load_dataset
import torch

torch.set_float32_matmul_precision("high")

In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else (
        "mps"
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
        else "cpu"
    )
)
print("Using device:", device)

model_name = "google/gemma-3-1b-it"

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Using device: cuda


model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [8]:
getattr(tokenizer, "chat_template", None)

'{{ bos_token }}\n{%- if messages[0][\'role\'] == \'system\' -%}\n    {%- if messages[0][\'content\'] is string -%}\n        {%- set first_user_prefix = messages[0][\'content\'] + \'\n\n\' -%}\n    {%- else -%}\n        {%- set first_user_prefix = messages[0][\'content\'][0][\'text\'] + \'\n\n\' -%}\n    {%- endif -%}\n    {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n    {%- set first_user_prefix = "" -%}\n    {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n    {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) -%}\n        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}\n    {%- endif -%}\n    {%- if (message[\'role\'] == \'assistant\') -%}\n        {%- set role = "model" -%}\n    {%- else -%}\n        {%- set role = message[\'role\'] -%}\n    {%- endif -%}\n    {{ \'<start_of_turn>\' + role + \'\n\' + (first_user_prefix if loop.first else "") }}\n    {%- if message[\'conte

In [9]:
# Template *before* setup
print("Before setup:\n", tokenizer.chat_template)

# Enable chat template support (adds/aligns special tokens and template behavior)
# If this is not configured, tokenizer.apply_chat_template() will break
if not getattr(tokenizer, "chat_template", None):
    model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

# Template *after* setup
print("After setup:\n", tokenizer.chat_template)

Before setup:
 {{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
    {%- if messages[0]['content'] is string -%}
        {%- set first_user_prefix = messages[0]['content'] + '

' -%}
    {%- else -%}
        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '

' -%}
    {%- endif -%}
    {%- set loop_messages = messages[1:] -%}
{%- else -%}
    {%- set first_user_prefix = "" -%}
    {%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
    {%- endif -%}
    {%- if (message['role'] == 'assistant') -%}
        {%- set role = "model" -%}
    {%- else -%}
        {%- set role = message['role'] -%}
    {%- endif -%}
    {{ '<start_of_turn>' + role + '
' + (first_user_prefix if loop.first else "") }}
    {%- if message['content'] is string -%}
        {{ message['conte

In [10]:
prompt = "Explain what chat templates are in relation to large language models (LLMs)."
messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that explains things clearly.",
    },
    {"role": "user", "content": prompt},
]

formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)
print("Formatted prompt preview:\n")
print(formatted_prompt[:200])

Formatted prompt preview:

<bos><start_of_turn>user
You are a helpful assistant that explains things clearly.

Explain what chat templates are in relation to large language models (LLMs).<end_of_turn>



In [13]:
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
model.eval()
with torch.no_grad():
    outputs = model.generate(
        **inputs, max_new_tokens=120, do_sample=True, temperature=0.7, top_p=0.9
    )

print("=== Model Output (with template) ===\n")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

=== Model Output (with template) ===

user
You are a helpful assistant that explains things clearly.

Explain what chat templates are in relation to large language models (LLMs).
Okay, let's dive into chat templates.  I'm going to explain it in a way that's easy to understand.

**What are Chat Templates?**

Imagine you're talking to a very smart, but sometimes a little literal, assistant. Chat templates are like a pre-written set of phrases and instructions that you give to an LLM (like ChatGPT, Bard, etc.) to guide its responses. 

Think of it like this: you're giving the LLM a recipe. You tell it *what* you want, but you don't


In [14]:
raw_prompt = prompt
raw_inputs = tokenizer(raw_prompt, return_tensors="pt").to(device)
with torch.no_grad():
    raw_outputs = model.generate(
        **raw_inputs, max_new_tokens=120, do_sample=True, temperature=0.7, top_p=0.9
    )

print("=== Model Output (without template) ===\n")
print(tokenizer.decode(raw_outputs[0], skip_special_tokens=True))

=== Model Output (without template) ===

Explain what chat templates are in relation to large language models (LLMs).

**Chat Templates**

Chat templates are a way to structure and streamline conversations with large language models (LLMs) like GPT-3, Bard, and others. They're essentially pre-defined sets of prompts and instructions that guide the LLM towards a specific task or output.  Think of them as a blueprint for the conversation.

Here's a breakdown of key aspects:

* **Structure:** They define the *flow* of a conversation.  They dictate what the LLM should respond to and how it should respond.
* **Specificity:** They're highly targeted


In [15]:
conversation = [
    {"role": "system", "content": "You are a witty assistant that answers concisely."},
    {"role": "user", "content": "What is quantum computing?"},
    {
        "role": "assistant",
        "content": "Quantum computing uses quantum bits (qubits) that can be in multiple states at once.",
    },
    {"role": "user", "content": "Explain it like I'm 5."},
]

mt_prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
mt_inputs = tokenizer(mt_prompt, return_tensors="pt").to(device)

with torch.no_grad():
    mt_outputs = model.generate(
        **mt_inputs, max_new_tokens=100, do_sample=True, temperature=0.7, top_p=0.9
    )

print("=== Multi-Turn Output ===\n")
print(tokenizer.decode(mt_outputs[0], skip_special_tokens=True))

=== Multi-Turn Output ===

user
You are a witty assistant that answers concisely.

What is quantum computing?
model
Quantum computing uses quantum bits (qubits) that can be in multiple states at once.
user
Explain it like I'm 5.
 
Imagine you have a light switch. It can be either on OR off, right? 

Quantum computers are like having a light switch that can be *both* on and off *at the same time*!  It's a bit weird, but it lets them solve really, really complicated problems much faster than regular computers. 




In [16]:
from itertools import islice

# Stream a dataset (no full download). You can swap for other datasets.
streamed = load_dataset(
    "databricks/databricks-dolly-15k", split="train", streaming=True
)

# Peek a few rows to understand the schema
for i, row in enumerate(islice(streamed, 3)):
    print(f"[{i}] keys:", list(row.keys()))
    print("instruction:", row.get("instruction"))
    print("context:", row.get("context"))
    print("response:", (row.get("response") or "")[:120], "...")
    print("-" * 80)

[0] keys: ['instruction', 'context', 'response', 'category']
instruction: When did Virgin Australia start operating?
context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
response: Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. ...
--------------------------------------------------------------------------------
[1] keys: ['instruction', 'context', 'response', 'category']
instruction: Which is a species of fish? Tope or Rope
context: 
response: Tope ...
-----------------

In [17]:
def row_to_messages(row):
    sys_msg = "You are a helpful assistant that follows instructions carefully."
    instr = row.get("instruction") or ""
    ctx = row.get("context") or ""
    user_msg = instr if not ctx else f"{instr}\n\nContext:\n{ctx}"
    asst_msg = row.get("response") or ""

    messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": user_msg},
    ]

    # Include gold answer for inspection; for inference we usually omit it
    if asst_msg.strip():
        messages.append({"role": "assistant", "content": asst_msg})

    return messages


# Smoke test on a couple of samples
streamed = load_dataset(
    "databricks/databricks-dolly-15k", split="train", streaming=True
)
for i, row in enumerate(islice(streamed, 2)):
    msgs = row_to_messages(row)
    formatted = tokenizer.apply_chat_template(msgs, tokenize=False)
    print(f"---- Sample {i} (formatted preview) ----\n{formatted[:400]} ...\n")

---- Sample 0 (formatted preview) ----
<bos><start_of_turn>user
You are a helpful assistant that follows instructions carefully.

When did Virgin Australia start operating?

Context:
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single  ...

---- Sample 1 (formatted preview) ----
<bos><start_of_turn>user
You are a helpful assistant that follows instructions carefully.

Which is a species of fish? Tope or Rope<end_of_turn>
<start_of_turn>model
Tope<end_of_turn>
 ...



### On-the-Fly Tokenization & Inference Over the Stream

In [18]:
# Re-create the stream (iterators are single-pass)
streamed = load_dataset(
    "databricks/databricks-dolly-15k", split="train", streaming=True
)

for i, row in enumerate(islice(streamed, 3)):  # increase this number as desired
    messages = row_to_messages(row)
    # For inference, stop at the user turn (omit gold assistant response)
    messages = messages[:2]

    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=100, do_sample=True, temperature=0.7, top_p=0.9
        )

    print(f"\n===== Streamed Sample #{i} =====")
    print("User prompt:\n", messages[-1]["content"][:500])
    print(
        "\n### Model response ###:\n",
        tokenizer.decode(outputs[0], skip_special_tokens=True),
    )


===== Streamed Sample #0 =====
User prompt:
 When did Virgin Australia start operating?

Context:
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 ci

### Model response ###:
 user
You are a helpful assistant that follows instructions carefully.

When did Virgin Australia start operating?

Context:
Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airl

### Optional: PyTorch IterableDataset for Batched Streaming

In [19]:
from torch.utils.data import IterableDataset, DataLoader


class ChatStreamDataset(IterableDataset):
    def __init__(
        self, hf_builder, tokenizer, device, max_length=1024, include_answer=False
    ):
        self.hf_builder = hf_builder
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        self.include_answer = include_answer

    def __iter__(self):
        # Fresh stream per worker/iterator
        stream = self.hf_builder()
        for row in stream:
            messages = row_to_messages(row)
            if not self.include_answer:
                messages = messages[:2]  # [system, user]

            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
            batch_inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=self.max_length
            )
            # Move to device lazily
            batch_inputs = {k: v.to(self.device) for k, v in batch_inputs.items()}
            yield batch_inputs


def make_dolly_stream():
    return load_dataset(
        "databricks/databricks-dolly-15k", split="train", streaming=True
    )


iterable_ds = ChatStreamDataset(
    hf_builder=make_dolly_stream,
    tokenizer=tokenizer,
    device=device,
    max_length=256,
    include_answer=False,
)


def collate_fn(batch):
    input_ids = [item["input_ids"][0] for item in batch]
    attn_mask = [item["attention_mask"][0] for item in batch]
    padded = tokenizer.pad(
        {"input_ids": input_ids, "attention_mask": attn_mask},
        padding=True,
        return_tensors="pt",
    )
    return {k: v.to(device) for k, v in padded.items()}


loader = DataLoader(iterable_ds, batch_size=4, collate_fn=collate_fn)

model.eval()
for bi, batch in enumerate(loader):
    with torch.no_grad():
        _ = model(**batch)
    print(
        f"Processed batch {bi} with shapes:",
        {k: tuple(v.shape) for k, v in batch.items()},
    )
    if bi >= 2:
        break

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Processed batch 0 with shapes: {'input_ids': (4, 141), 'attention_mask': (4, 141)}
Processed batch 1 with shapes: {'input_ids': (4, 256), 'attention_mask': (4, 256)}
Processed batch 2 with shapes: {'input_ids': (4, 256), 'attention_mask': (4, 256)}



---

**Finished:** 2025-08-22 12:56 UTC

**Next steps:**
- Try different datasets (e.g., `OpenAssistant/oasst1`) and adapt `row_to_messages`.
- Experiment with system prompts to shift tone and style.
- Integrate the streaming `IterableDataset` with training loops (TRL or custom).

Happy building! 🚀
