In [None]:
!pip -q install -U transformers accelerate datasets evaluate sentencepiece protobuf pydantic rich

import os, json, re, time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

A1) Choose a small instruct model (3B–7B)

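Pick one whose weights fit in your Colab GPU's memory. In half precision (fp16/bf16) the weights alone take about 2 bytes per parameter, and you also need headroom for activations and the KV cache. Good defaults:

- If you usually get a T4 (16 GB): start with 3B (~6 GB of weights).

- If you have an L4 (24 GB) or an A100 (Colab Pro): you can try 7B (~15 GB of weights).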

In [None]:
# Default checkpoint: a 3B instruct model whose half-precision weights fit a free-tier T4.
# With a stronger GPU (L4 / A100) swap in the 7B sibling instead:
#   MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"

A2) Load model + tokenizer

In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# Choose the dtype to load the weights in:
#   - bf16 on GPUs with native bf16 support (compute capability 8.0+, e.g. L4 / A100).
#     bf16 keeps fp32's exponent range, so activations are far less likely to overflow
#     than in fp16, and the Qwen2.5 checkpoints are distributed in bf16.
#   - fp16 on older GPUs such as the T4 (compute capability 7.5), which lack native bf16.
#   - fp32 on CPU.
if torch.cuda.is_available():
    major_cc = torch.cuda.get_device_capability()[0]
    MODEL_DTYPE = torch.bfloat16 if major_cc >= 8 else torch.float16
else:
    MODEL_DTYPE = torch.float32

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=MODEL_DTYPE,
    device_map="auto",  # let accelerate place the weights on the available device(s)
)
model.eval()  # inference only

print("Device:", next(model.parameters()).device, "| dtype:", model.dtype)


A3) Chat template + system prompts

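Instruct models are fine-tuned on a specific chat format, in which special tokens mark the system, user, and assistant turns. We build prompts with the tokenizer’s built-in chat template so they match the format the model was trained on.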

In [None]:
SYSTEM_PROMPT = (
    "You are a helpful assistant. "
    "Be concise. If unsure, say you don't know."
)

def build_chat_prompt(user_msg: str, system_msg: str = SYSTEM_PROMPT):
    """Format a system + user exchange as a prompt string that ends at the assistant's turn.

    When the global ``tokenizer`` has a chat template, that template is used so the prompt
    carries the model's own role markers. Tokenizers without a template (typically base,
    non-instruct checkpoints) get a plain-text "User: / Assistant:" fallback instead of
    going through ``apply_chat_template``.

    Args:
        user_msg: The user turn.
        system_msg: The system prompt; defaults to ``SYSTEM_PROMPT``.

    Returns:
        The untokenized prompt string, ready to pass to ``tokenizer(...)``.
    """
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    if getattr(tokenizer, "chat_template", None):
        # tokenize=False returns a string; add_generation_prompt appends the opening of the
        # assistant turn so generation continues as the assistant.
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Plain-text fallback: no special tokens, just role labels.
    header = f"{system_msg}\n\n" if system_msg else ""
    return f"{header}User: {user_msg}\nAssistant:"


A4) Basic generation function (greedy or sampling)

In [None]:
@torch.inference_mode()
def generate_text(user_msg, max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True,
                  system_msg=None):
    """Generate one chat reply to ``user_msg`` and return only the newly generated text.

    Args:
        user_msg: The user turn to send to the model.
        max_new_tokens: Upper bound on generated tokens (the prompt is not counted).
        temperature: Sampling temperature; only forwarded when ``do_sample`` is True.
        top_p: Nucleus-sampling threshold; only forwarded when ``do_sample`` is True.
        do_sample: True for sampling, False for greedy decoding.
        system_msg: Optional system prompt. When None, ``build_chat_prompt``'s own
            default system prompt is used.

    Returns:
        The decoded assistant reply: prompt tokens excluded, special tokens removed,
        surrounding whitespace stripped.
    """
    if system_msg is None:
        prompt = build_chat_prompt(user_msg)
    else:
        prompt = build_chat_prompt(user_msg, system_msg)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        # Sampling knobs are meaningless for greedy decoding; pass None rather than a value.
        temperature=temperature if do_sample else None,
        top_p=top_p if do_sample else None,
        eos_token_id=tokenizer.eos_token_id,
    )

    # The generated sequence starts with the prompt tokens; slice them off so callers
    # get just the assistant's reply.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(gen[0, prompt_len:], skip_special_tokens=True).strip()

# Smoke test: one free-form request through the default system prompt.
print(generate_text(user_msg="Give me 5 birthday party themes for adults."))


In [None]:
from pydantic import BaseModel, ValidationError
from typing import List

class PartyIdeas(BaseModel):
    """Expected shape of the model's JSON reply: ``{"themes": ["...", ...]}``."""

    # One string per suggested party theme.
    themes: list[str]

# System prompt for structured-output requests: the reply must be bare JSON, nothing else.
JSON_SYSTEM = "You are a JSON-only assistant. Return ONLY valid JSON. No markdown. No extra text."

def extract_first_json(text: str):
    """Return the first complete JSON object embedded in ``text``, as a string.

    Tries each ``{`` in order and asks the stdlib JSON decoder to parse one value starting
    there; the first position that parses wins. This tolerates chatter before or after the
    JSON, markdown code fences, and later braces that a greedy ``{.*}`` regex would swallow.

    If no position parses, falls back to the greedy first-``{``-to-last-``}`` span so the
    caller's validator can still report what went wrong.

    Returns:
        The JSON object substring, the greedy fallback span, or None when ``text`` is
        empty/None or has no ``{...}`` span at all.
    """
    if not text:
        return None

    decoder = json.JSONDecoder()
    for brace in re.finditer(r"\{", text):
        start = brace.start()
        try:
            # raw_decode parses a single JSON value at `start` and returns the absolute end
            # index, ignoring whatever follows.
            _, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            continue
        return text[start:end]

    # Fallback: the original greedy heuristic (may include junk or span several objects).
    m = re.search(r"\{.*\}", text, flags=re.DOTALL)
    return m.group(0) if m else None

# Put the JSON-only instruction in the user turn so it reaches the model regardless of
# which system prompt generate_text() applies.
json_request = (
    f"{JSON_SYSTEM}\n\n"
    "Generate 5 adult birthday party themes. "
    "Output JSON with key 'themes' as a list of strings."
)

resp = generate_text(
    json_request,
    max_new_tokens=256,
    temperature=0.2,
    top_p=1.0,
    do_sample=True,
)

json_str = extract_first_json(resp)
print("RAW:\n", resp)
print("\nJSON CANDIDATE:\n", json_str)

if json_str is None:
    # Nothing brace-delimited in the output: report it instead of handing None to pydantic.
    print("\nVALIDATION ERROR: no JSON object found in the model output")
else:
    try:
        obj = PartyIdeas.model_validate_json(json_str)
        print("\nVALIDATED:", obj)
    except ValidationError as e:
        print("\nVALIDATION ERROR:", e)


In [None]:
from transformers import TextIteratorStreamer
import threading

def stream_generate(user_msg, max_new_tokens=256, temperature=0.7, top_p=0.9, do_sample=True):
    """Stream a chat reply to stdout as it is generated and return the full text.

    ``model.generate`` runs on a worker thread and pushes decoded text into a
    ``TextIteratorStreamer``; this (calling) thread consumes the streamer and prints
    each chunk as it arrives.

    Args:
        user_msg: The user turn to send to the model.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; only forwarded when ``do_sample`` is True.
        top_p: Nucleus-sampling threshold; only forwarded when ``do_sample`` is True.
        do_sample: True (default) for sampling, False for greedy decoding.

    Returns:
        The concatenated streamed text (prompt excluded via ``skip_prompt=True``).

    Raises:
        Any exception raised by ``model.generate`` on the worker thread is re-raised here
        after the stream has been drained.
    """
    prompt = build_chat_prompt(user_msg)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature if do_sample else None,
        top_p=top_p if do_sample else None,
        streamer=streamer,
    )

    worker_errors = []

    def _run_generate():
        # Runs on the worker thread. If generate() fails part-way, the streamer would not
        # receive its end-of-stream signal and the consumer loop below could block
        # indefinitely, so record the error and close the stream explicitly.
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_run_generate)
    thread.start()

    out = ""
    for chunk in streamer:
        out += chunk
        print(chunk, end="", flush=True)
    print()

    # generate() may still be wrapping up after the final chunk; wait for it.
    thread.join()
    if worker_errors:
        raise worker_errors[0]
    return out

# Streaming demo; the returned text is discarded (it has already been printed live).
_ = stream_generate(
    "Explain RoPE in simple terms with a short example.",
    max_new_tokens=200,
)
