# In [2]:
import json
import warnings
from typing import Any, Dict, List

import outlines
import torch
from fastapi import FastAPI, HTTPException
from outlines.types import JsonSchema
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# ---------------------
# Load model once
# ---------------------
# Everything here runs at import time so that all requests share one copy of
# the tokenizer and weights.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

try:
    # Prefer 4-bit weights. Quantization options are passed through
    # `quantization_config`; handing `load_in_4bit=True` directly to
    # `from_pretrained` is deprecated.
    model_hf = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    )
except Exception as exc:
    # Deliberate best-effort fallback (the 4-bit path can fail for
    # environment reasons), but report it instead of silently serving an
    # unquantized model.
    warnings.warn(
        f"4-bit load of {MODEL_NAME} failed ({exc!r}); "
        "falling back to unquantized weights.",
        RuntimeWarning,
    )
    model_hf = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        # Half precision only when CUDA is available; otherwise float32.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    )

# Wrap the Transformers model and tokenizer for use with Outlines.
model = outlines.from_transformers(model_hf, tokenizer)

# ---------------------
# FastAPI setup
# ---------------------
# Application object on which the route handlers in this file are registered.
app = FastAPI(title="TinyLlama Constrained API")

class ChatMessage(BaseModel):
    """A single conversation message, as sent in a request body."""

    # Label of the message's author; it is used verbatim when the prompt is
    # assembled, and no particular set of values is enforced here.
    role: str
    # Text of the message.
    content: str

class JsonConstrainedRequest(BaseModel):
    """Request body for the ``POST /constrained/json`` endpoint."""

    # Conversation messages; they are rendered into the prompt in list order.
    messages: List[ChatMessage]
    # JSON Schema, given as a plain dict, that the generated output is
    # constrained to.
    json_schema: Dict[str, Any]

def _build_prompt(messages: List[ChatMessage]) -> str:
    """Render *messages* into one prompt string that cues an assistant reply."""
    if getattr(tokenizer, "chat_template", None):
        # Use the chat template that ships with the tokenizer so the prompt
        # has the layout (role markers, end-of-turn tokens) the model expects.
        return tokenizer.apply_chat_template(
            [{"role": m.role, "content": m.content} for m in messages],
            tokenize=False,
            add_generation_prompt=True,
        )
    # No template available: fall back to a simple tagged layout.
    turns = "".join(f"<|{m.role}|> {m.content}\n" for m in messages)
    return f"{turns}<|assistant|> "


@app.post("/constrained/json")
def constrained_json(req: JsonConstrainedRequest):
    """Generate a reply to ``req.messages`` constrained to ``req.json_schema``.

    Args:
        req: Conversation messages plus the JSON Schema (as a dict) that the
            generated text must conform to.

    Returns:
        A dict with ``"content"`` (the generated text parsed with
        ``json.loads``) and ``"raw"`` (the generated text itself).

    Raises:
        HTTPException: With status 500 and the underlying error message as
            ``detail`` if prompt building, schema handling, generation, or
            JSON parsing fails.
    """
    try:
        prompt = _build_prompt(req.messages)
        schema = JsonSchema(req.json_schema)
        # `outlines.Generator` is the structured-generation entry point that
        # matches a model created with `outlines.from_transformers`; the
        # legacy `outlines.generate.json` belongs to the older API.
        generator = outlines.Generator(model, schema)
        # Extra keyword arguments are forwarded to the Transformers
        # `generate` call, where `temperature` only takes effect when
        # sampling is enabled.
        text = generator(prompt, do_sample=True, temperature=0.2)
        parsed = json.loads(text)
    except Exception as e:
        # Endpoint boundary: report any failure as a 500, keeping the
        # original exception chained for tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"content": parsed, "raw": text}
