
# Hands-On Tutorial: Chat Templates with SmolLM2 **(+ Streaming Datasets)**

This end-to-end notebook teaches you how to:
1. Load **SmolLM2-135M** and enable **chat templates** with `trl.setup_chat_format`.
2. Generate text **with and without** templates to see the difference.
3. Run **multi-turn conversations** using a structured `messages` format.
4. **Stream datasets** from the Hugging Face Hub (`datasets.load_dataset(..., streaming=True)`), convert rows into chat messages, and **tokenize on the fly**.
5. (Optional) Wrap a stream into a **PyTorch `IterableDataset`** for batched inference/eval.

> Tip: You can run this on CPU/GPU. GPU is recommended for speed.


## 1) Setup & Authentication

In [None]:
# If running on a fresh environment (e.g., Colab), uncomment:
# !pip install -q transformers datasets trl huggingface_hub

from huggingface_hub import login

# login()  # Uncomment and run to authenticate with your HF token

## 2) Imports

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
from datasets import load_dataset
import torch

## 3) Device, Model & Tokenizer

In [None]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else (
        "mps"
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available()
        else "cpu"
    )
)
print("Using device:", device)

model_name = "HuggingFaceTB/SmolLM2-135M"

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Install the default ChatML template: adds its special tokens, resizes the model's
# token embeddings to match, and sets tokenizer.chat_template
model, tokenizer = setup_chat_format(model=model, tokenizer=tokenizer)

## 4) What Are Chat Templates? Format a Prompt

In [None]:
messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that explains things clearly.",
    },
    {
        "role": "user",
        "content": "Explain the importance of chat templates in language models.",
    },
]

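# add_generation_prompt=True appends the assistant header so that generation
# starts a new assistant turn instead of continuing the user message
formatted_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)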
print("Formatted prompt preview:\n")
print(formatted_prompt[:800])
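
To make the structure of the formatted prompt concrete, here is a minimal pure-Python sketch of the ChatML layout, which is the default format that `setup_chat_format` installs. The helper `render_chatml` is a hypothetical name used only for illustration; the real `tokenizer.apply_chat_template` renders the tokenizer's own Jinja template, which may differ in detail.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Render messages in ChatML layout (illustrative stand-in, not the real template)."""
    parts = []
    for message in messages:
        # Each turn is wrapped in start/end markers and labeled with its role.
        parts.append(f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n")
    if add_generation_prompt:
        # An open assistant header signals that the model should answer next.
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


demo_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
]
demo_prompt = render_chatml(demo_messages, add_generation_prompt=True)
print(demo_prompt)
```

The trailing open assistant header is what `add_generation_prompt=True` contributes when a prompt is built for generation.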

## 5) Generate with a Chat Template

In [None]:
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
model.eval()
with torch.no_grad():
    outputs = model.generate(
        **inputs, max_new_tokens=150, do_sample=True, temperature=0.7, top_p=0.9
    )

print("=== Model Output (with template) ===\n")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
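
For a causal language model, `generate` returns the prompt tokens followed by the newly generated ones, so the decoded text above repeats the system and user messages. A minimal sketch of isolating the reply is to slice off the prompt length before decoding. The helper `strip_prompt` is a hypothetical name, and plain lists stand in for the tensors here.

```python
def strip_prompt(sequence, prompt_length):
    """Return only the tokens that were generated after the prompt."""
    return sequence[prompt_length:]


# Toy token IDs standing in for inputs["input_ids"][0] and outputs[0].
prompt_ids = [101, 7, 8, 9]
output_ids = prompt_ids + [42, 43, 44]

reply_ids = strip_prompt(output_ids, len(prompt_ids))
print(reply_ids)
```

With the notebook's variables, the equivalent would be `tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)`.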

## 6) Compare: Without a Chat Template

In [None]:
raw_prompt = "Explain the importance of chat templates in language models."
raw_inputs = tokenizer(raw_prompt, return_tensors="pt").to(device)
with torch.no_grad():
    raw_outputs = model.generate(
        **raw_inputs, max_new_tokens=150, do_sample=True, temperature=0.7, top_p=0.9
    )

print("=== Model Output (without template) ===\n")
print(tokenizer.decode(raw_outputs[0], skip_special_tokens=True))

## 7) Multi-Turn Conversation

In [None]:
conversation = [
    {"role": "system", "content": "You are a witty assistant that answers concisely."},
    {"role": "user", "content": "What is quantum computing?"},
    {
        "role": "assistant",
        "content": "Quantum computing uses quantum bits (qubits) that can be in multiple states at once.",
    },
    {"role": "user", "content": "Explain it like I'm 5."},
]

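# add_generation_prompt=True opens a new assistant turn for the model to complete
mt_prompt = tokenizer.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=True
)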
mt_inputs = tokenizer(mt_prompt, return_tensors="pt").to(device)

with torch.no_grad():
    mt_outputs = model.generate(
        **mt_inputs, max_new_tokens=120, do_sample=True, temperature=0.7, top_p=0.9
    )

print("=== Multi-Turn Output ===\n")
print(tokenizer.decode(mt_outputs[0], skip_special_tokens=True))
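
In an interactive setting the `messages` list grows by one entry per turn. The sketch below shows one way to build such a history while guarding against malformed conversations. The helper `add_turn` and its validation rules are assumptions made for illustration, not part of `transformers` or `trl`.

```python
def add_turn(history, role, content):
    """Return a new history with one more message, enforcing basic turn order."""
    if role not in {"system", "user", "assistant"}:
        raise ValueError(f"Unknown role: {role!r}")
    if role == "system" and history:
        raise ValueError("A system message must come first.")
    if history and history[-1]["role"] == role:
        raise ValueError(f"Two consecutive {role!r} turns.")
    # Build a new list so that earlier snapshots of the history stay unchanged.
    return history + [{"role": role, "content": content}]


chat = []
chat = add_turn(chat, "system", "You are a witty assistant that answers concisely.")
chat = add_turn(chat, "user", "What is quantum computing?")
chat = add_turn(chat, "assistant", "Computing with qubits.")
chat = add_turn(chat, "user", "Explain it like I'm 5.")
print([message["role"] for message in chat])
```

After each model reply, the decoded text would be appended as an `assistant` turn before the next user message is added.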


## 8) Why Templates Matter (Quick Notes)

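- **Role-aware structure**: explicit `system`, `user`, and `assistant` markers let the model tell instructions apart from its own earlier answers.
- **Consistency**: formatting prompts at inference time the same way they were formatted during training keeps outputs stable; a mismatched format often degrades quality.
- **Fair evaluation**: rendering the same `messages` list through each model's template gives every model the same conversation in the format it expects.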


## 9) Stream a Dataset from the Hugging Face Hub (No Full Download)

In [None]:
from itertools import islice

# Stream a dataset (no full download). You can swap for other datasets.
streamed = load_dataset(
    "databricks/databricks-dolly-15k", split="train", streaming=True
)

# Peek a few rows to understand the schema
for i, row in enumerate(islice(streamed, 3)):
    print(f"[{i}] keys:", list(row.keys()))
    print("instruction:", row.get("instruction"))
    print("context:", row.get("context"))
    print("response:", (row.get("response") or "")[:120], "...")
    print("-" * 80)
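
The reason `islice` pairs well with streaming is that rows are produced only when they are requested. The sketch below demonstrates this with a toy generator, `fake_stream`, which is a hypothetical stand-in for the Hub stream. A real stream fetches data in chunks over the network, so the granularity differs, but the pull-based principle is the same.

```python
from itertools import islice


def fake_stream(log):
    """Yield rows one by one, recording each row that is actually produced."""
    for index in range(1_000_000):
        log.append(index)
        yield {"instruction": f"Question {index}", "context": "", "response": "..."}


produced = []
first_rows = list(islice(fake_stream(produced), 3))

# Only the rows that were requested have been generated.
print(len(first_rows), "rows taken;", len(produced), "rows produced")
```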

### Helper: Convert Streamed Rows → Chat Messages

In [None]:
def row_to_messages(row):
    sys_msg = "You are a helpful assistant that follows instructions carefully."
    instr = row.get("instruction") or ""
    ctx = row.get("context") or ""
    user_msg = instr if not ctx else f"{instr}\n\nContext:\n{ctx}"
    asst_msg = row.get("response") or ""

    messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": user_msg},
    ]

    # Include gold answer for inspection; for inference we usually omit it
    if asst_msg.strip():
        messages.append({"role": "assistant", "content": asst_msg})

    return messages


# Smoke test on a couple of samples
streamed = load_dataset(
    "databricks/databricks-dolly-15k", split="train", streaming=True
)
for i, row in enumerate(islice(streamed, 2)):
    msgs = row_to_messages(row)
    formatted = tokenizer.apply_chat_template(msgs, tokenize=False)
    print(f"---- Sample {i} (formatted preview) ----\n{formatted[:400]} ...\n")
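
As an alternative to slicing the message list later, the conversion can take a flag that decides whether the gold answer is included. The sketch below is a self-contained variant of the helper above. The name `build_messages`, its `include_answer` and `system_prompt` parameters, and the sample rows are all hypothetical and only mirror the Dolly column names.

```python
def build_messages(row, include_answer=False, system_prompt="You are a helpful assistant."):
    """Convert a Dolly-style row into chat messages, optionally with the gold answer."""
    instruction = row.get("instruction") or ""
    context = row.get("context") or ""
    # Append the context to the user turn only when the row provides one.
    user = instruction if not context else f"{instruction}\n\nContext:\n{context}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user},
    ]
    answer = (row.get("response") or "").strip()
    if include_answer and answer:
        messages.append({"role": "assistant", "content": answer})
    return messages


sample_rows = [
    {"instruction": "Name a primary color.", "context": "", "response": "Red."},
    {"instruction": "Summarize the text.", "context": "Cats sleep a lot.", "response": "Cats are sleepy."},
]
for sample in sample_rows:
    print(build_messages(sample, include_answer=True))
```

Keeping the decision in one place avoids relying on the position of the assistant message when preparing prompts for inference.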

### On-the-Fly Tokenization & Inference Over the Stream

In [None]:
# Re-create the stream for clarity; iterating a streaming dataset again would
# also restart from the first row
streamed = load_dataset(
    "databricks/databricks-dolly-15k", split="train", streaming=True
)

for i, row in enumerate(islice(streamed, 3)):  # increase this number as desired
    messages = row_to_messages(row)
    # For inference, stop at the user turn (omit gold assistant response)
    messages = messages[:2]

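    # add_generation_prompt=True opens the assistant turn that the model should fill in
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )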
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=128, do_sample=True, temperature=0.7, top_p=0.9
        )

    print(f"\n===== Streamed Sample #{i} =====")
    print("User prompt:\n", messages[-1]["content"][:500])
    print("\nModel response:\n", tokenizer.decode(outputs[0], skip_special_tokens=True))
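
The loop above handles one row at a time. To process several prompts per step without materializing the stream, rows can be grouped lazily. The following is a minimal sketch; the helper `batched` is written here for illustration and works on any iterable, with a generator of strings standing in for the formatted prompts.

```python
from itertools import islice


def batched(iterable, batch_size):
    """Yield lists of up to batch_size items from any iterable, including streams."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch


prompts = (f"prompt {i}" for i in range(7))
batches = list(batched(prompts, 3))
print([len(batch) for batch in batches])
```

Each batch of prompt strings could then be tokenized together with `tokenizer(batch, return_tensors="pt", padding=True)`.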

### Optional: PyTorch IterableDataset for Batched Streaming

In [None]:
from torch.utils.data import IterableDataset, DataLoader


class ChatStreamDataset(IterableDataset):
    def __init__(
        self, hf_builder, tokenizer, device, max_length=1024, include_answer=False
    ):
        self.hf_builder = hf_builder
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        self.include_answer = include_answer

    def __iter__(self):
        # Fresh stream per worker/iterator
        stream = self.hf_builder()
        for row in stream:
            messages = row_to_messages(row)
            if not self.include_answer:
                messages = messages[:2]  # [system, user]

            # Add the assistant header only when the prompt ends at the user turn
            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=not self.include_answer,
            )
            batch_inputs = self.tokenizer(
                prompt, return_tensors="pt", truncation=True, max_length=self.max_length
            )
            # Keep tensors on the CPU here; collate_fn moves each padded batch to the device
            yield batch_inputs


def make_dolly_stream():
    return load_dataset(
        "databricks/databricks-dolly-15k", split="train", streaming=True
    )


iterable_ds = ChatStreamDataset(
    hf_builder=make_dolly_stream,
    tokenizer=tokenizer,
    device=device,
    max_length=1024,
    include_answer=False,
)


def collate_fn(batch):
    input_ids = [item["input_ids"][0] for item in batch]
    attn_mask = [item["attention_mask"][0] for item in batch]
    padded = tokenizer.pad(
        {"input_ids": input_ids, "attention_mask": attn_mask},
        padding=True,
        return_tensors="pt",
    )
    return {k: v.to(device) for k, v in padded.items()}


loader = DataLoader(iterable_ds, batch_size=4, collate_fn=collate_fn)

model.eval()
for bi, batch in enumerate(loader):
    with torch.no_grad():
        _ = model(**batch)
    print(
        f"Processed batch {bi} with shapes:",
        {k: tuple(v.shape) for k, v in batch.items()},
    )
    if bi >= 2:
        break
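
The collate function relies on `tokenizer.pad` to bring the sequences in a batch to a common length. The sketch below shows what padding produces on toy token IDs, using a hypothetical helper `pad_batch` and an assumed pad ID of `0`. For a plain forward pass either side works as long as the attention mask is passed along; for batched `generate` with decoder-only models, left padding (`tokenizer.padding_side = "left"`) is generally preferred so that new tokens directly follow the prompt.

```python
def pad_batch(sequences, pad_id, side="right"):
    """Pad token ID lists to equal length and build matching attention masks."""
    if side not in {"left", "right"}:
        raise ValueError("side must be 'left' or 'right'")
    width = max(len(sequence) for sequence in sequences)
    input_ids, attention_mask = [], []
    for sequence in sequences:
        padding = [pad_id] * (width - len(sequence))
        real = [1] * len(sequence)   # 1 marks a real token
        ignored = [0] * len(padding)  # 0 marks padding the model should ignore
        if side == "left":
            input_ids.append(padding + list(sequence))
            attention_mask.append(ignored + real)
        else:
            input_ids.append(list(sequence) + padding)
            attention_mask.append(real + ignored)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


toy_batch = [[5, 6, 7], [8]]
right_padded = pad_batch(toy_batch, pad_id=0)
left_padded = pad_batch(toy_batch, pad_id=0, side="left")
print(right_padded)
print(left_padded)
```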


---

**Finished:** 2025-08-22 12:56 UTC

**Next steps:**
- Try different datasets (e.g., `OpenAssistant/oasst1`) and adapt `row_to_messages`.
- Experiment with system prompts to shift tone and style.
- Integrate the streaming `IterableDataset` with training loops (TRL or custom).

Happy building! 🚀
