In [2]:
import contextlib
import io
import os
import time

import requests
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

class LegalLLMAnswerer:
    """Answer a legal question from supplied context passages with a causal LM.

    Two back ends are supported:

    * ``mode="api"``   -- POSTs the prompt to the Hugging Face Inference API.
    * ``mode="local"`` -- loads the model and tokenizer in-process with
      ``transformers`` and generates greedily.
    """

    # Upper bound on 503 ("model is loading") retries, so a cold or broken
    # endpoint cannot make generate() loop forever.
    MAX_WARMUP_RETRIES = 5
    # Seconds to wait on a single HTTP request before requests gives up.
    REQUEST_TIMEOUT = 120

    def __init__(self, model_name: str, mode: str, hf_token: str = None, max_new_tokens: int = 300):
        """
        mode: 'api' (for Zephyr) or 'local' (for local models like Mistral)
        model_name: Hugging Face model name (e.g., HuggingFaceH4/zephyr-7b-beta)
        hf_token: required only for API mode
        max_new_tokens: generation budget passed to the back end

        Raises:
            ValueError: if ``mode`` is not 'api' or 'local', or if API mode is
                requested without a token.
        """
        self.model_name = model_name
        self.mode = mode
        self.max_new_tokens = max_new_tokens

        if mode == "api":
            # Explicit raise instead of `assert`: asserts vanish under `python -O`,
            # which would silently send "Bearer None".
            if not hf_token:
                raise ValueError("Hugging Face token is required for API mode")
            self.api_url = f"https://api-inference.huggingface.co/models/{model_name}"
            self.headers = {
                "Authorization": f"Bearer {hf_token}",
                "Content-Type": "application/json"
            }

        elif mode == "local":
            print(f"🔁 Loading local model: {model_name}")
            self.device = "mps" if torch.backends.mps.is_available() else "cpu"
            # Swallow anything the loaders print to stdout while loading.
            with contextlib.redirect_stdout(io.StringIO()):
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    device_map={"": self.device},
                    torch_dtype=torch.float32
                )
            self.model.eval()
            # No explicit `device=`: the model was already placed via device_map,
            # and forcing device=-1 (CPU) contradicts that placement when the
            # model lives on MPS.
            self.generator = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer
            )
            print(f"✅ Model loaded on {self.device}")

        else:
            # Previously an unknown mode was accepted here and only failed later,
            # inside generate(), with an unrelated AttributeError.
            raise ValueError(f"Unknown mode {mode!r}; expected 'api' or 'local'")

    def build_prompt(self, query: str, context_docs: list) -> str:
        """Render numbered context passages, the question and the instructions into one prompt."""
        context = "\n\n".join(
            f"Context {i+1}:\n{doc.strip()}" for i, doc in enumerate(context_docs)
        )
        return (
            f"{context}\n\n"
            f"Question: {query.strip()}\n\n"
            f"Instructions: Based only on the above legal context, first provide a short answer to the question, "
            f"then explain your reasoning in a separate paragraph.\n\n"
            f"Answer:"
        )

    def generate(self, query: str, context_docs: list, debug: bool = False) -> dict:
        """Generate an answer for ``query`` grounded in ``context_docs``.

        Args:
            query: the user's question.
            context_docs: passages to include in the prompt.
            debug: when True, print the full prompt before generating.

        Returns:
            ``{"answer": str, "reasoning": str}``.

        Raises:
            RuntimeError: in API mode, on any non-200 response, including a 503
                that persists after ``MAX_WARMUP_RETRIES`` retries.
        """
        prompt = self.build_prompt(query, context_docs)

        if debug:
            print("=" * 80)
            print("PROMPT:\n", prompt)
            print("=" * 80)

        if self.mode == "api":
            text = self._generate_api(prompt)
        else:  # local
            text = self._generate_local(prompt)

        return self._split_answer(text)

    def _post(self, payload: dict):
        """Send one generation request to the Inference API and return the raw response."""
        return requests.post(
            self.api_url,
            headers=self.headers,
            json=payload,
            timeout=self.REQUEST_TIMEOUT
        )

    def _generate_api(self, prompt: str) -> str:
        """Call the Inference API, retrying a bounded number of times while the model warms up."""
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": self.max_new_tokens,
                "do_sample": False,
                "temperature": 0.1,
                "return_full_text": False
            }
        }

        # A loop rather than recursion: the retry count is bounded and the stack stays flat.
        for attempt in range(self.MAX_WARMUP_RETRIES + 1):
            response = self._post(payload)
            if response.status_code != 503 or attempt == self.MAX_WARMUP_RETRIES:
                break
            try:
                wait = response.json().get("estimated_time", 10)
            except (ValueError, AttributeError):
                # Body was not JSON, or not a JSON object; fall back to the default wait.
                wait = 10
            print(f"⏳ Model warming up, retrying in {wait}s...")
            time.sleep(wait)

        if response.status_code != 200:
            raise RuntimeError(f"HF API Error {response.status_code}: {response.text}")

        return response.json()[0]["generated_text"]

    def _generate_local(self, prompt: str) -> str:
        """Greedy-decode a completion with the in-process model and return only the new text."""
        # No padding: a single prompt never needs it, and tokenizers that define
        # no pad token fail when asked to pad.
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True).to(self.device)
        prompt_token_count = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # `temperature` is omitted: it has no effect when do_sample=False and
            # only triggers a generation-config warning.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Slice by token count, not by len(prompt) characters: the decoded prompt
        # is not guaranteed to be character-identical to the input string.
        completion_ids = outputs[0][prompt_token_count:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

    @staticmethod
    def _split_answer(text: str) -> dict:
        """Split generated text into a short answer and its reasoning.

        Splits on the first "Reasoning:" marker when present, otherwise on the
        first newline. Leading/trailing whitespace is removed first so that a
        completion starting with a newline does not yield an empty answer.
        """
        text = text.strip()
        if "Reasoning:" in text:
            answer, reasoning = text.split("Reasoning:", 1)
        else:
            answer, _, reasoning = text.partition("\n")

        return {
            "answer": answer.strip(),
            "reasoning": reasoning.strip()
        }

In [4]:
# Read the Hugging Face access token from the environment instead of embedding
# it in the notebook: anything typed into a cell is saved with the .ipynb and
# leaks as soon as the file is shared or committed. The token that used to be
# hard-coded here should be treated as compromised and revoked.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    raise RuntimeError(
        "Set the HF_TOKEN environment variable to a Hugging Face access token before running this cell."
    )

llm = LegalLLMAnswerer(
    model_name="HuggingFaceH4/zephyr-7b-beta",
    mode="api",
    hf_token=hf_token
)

query = "Under what conditions can police conduct a warrantless vehicle search?"

# Hand-written stand-ins for retrieved passages, used to exercise the prompt end to end.
context_docs = [
    "The Supreme Court has ruled that if police have probable cause to believe a vehicle contains evidence, they may search it without a warrant.",
    "This principle is known as the automobile exception to the Fourth Amendment.",
    "The scope of such searches is limited to areas where the evidence might reasonably be found.",
    "Searches must still meet the standard of reasonableness under the Constitution."
]

# debug=True echoes the full prompt so the model output can be read against it.
response = llm.generate(query, context_docs, debug=True)

print("\n🧠 Answer:", response["answer"])
print("📚 Reasoning:", response["reasoning"])

PROMPT:
 Context 1:
The Supreme Court has ruled that if police have probable cause to believe a vehicle contains evidence, they may search it without a warrant.

Context 2:
This principle is known as the automobile exception to the Fourth Amendment.

Context 3:
The scope of such searches is limited to areas where the evidence might reasonably be found.

Context 4:
Searches must still meet the standard of reasonableness under the Constitution.

Question: Under what conditions can police conduct a warrantless vehicle search?

Instructions: Based only on the above legal context, first provide a short answer to the question, then explain your reasoning in a separate paragraph.

Answer:

🧠 Answer: Police can conduct a warrantless vehicle search if they have probable cause to believe that the vehicle contains evidence. This principle, known as the automobile exception to the Fourth Amendment, allows for searches of areas where evidence might reasonably be found, as long as the searches meet 