<a href="https://colab.research.google.com/github/foreverYoungGitHub/llm-sorting-hat/blob/main/sortinghat_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!uv pip install vllm

In [None]:
!huggingface-cli login --token $HF_TOKEN

In [None]:
import argparse

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

In [None]:
# Build the vLLM engine once; the generation cells below all reuse this `llm` object.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    max_model_len=1024,
)

In [None]:
# Prompt template for the Sorting Hat task. It is filled in with str.format(),
# so the literal braces of the JSON example are doubled ({{ and }}) and
# {user_input} is the only substitution field.
prompt_grammar = """You're a wise and ancient Sorting Hat. Based on the following self-introduction from a young wizard, assign them to the most appropriate Hogwarts house.

Please output only a JSON object in the following format — no explanation:
{{
    "name": "string",
    "age": "int",
    "house": "string"
}}

Here is the input information:
{user_input}
"""

def generate_output(llm: LLM, prompt: str, sampling_params: SamplingParams = None) -> str:
    """Run a single prompt through ``llm`` and return the first completion's text.

    Args:
        llm: Engine whose ``generate`` method is called.
        prompt: Prompt passed to ``generate`` as ``prompts``.
        sampling_params: Passed through unchanged as ``sampling_params``;
            ``None`` when the caller does not supply one.

    Returns:
        The ``text`` attribute of the first completion of the first result.
    """
    request_outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
    # Only one prompt is submitted, so only the first result is read.
    first_request = request_outputs[0]
    first_completion = first_request.outputs[0]
    return first_completion.text

In [None]:
user_input = """Hello, my name is Hermione Granger,
and I'm a first-year student at Hogwarts, 11 years old.
I come from a Muggle family—my parents are dentists—but ever since I received my letter, I’ve been absolutely fascinated by the magical world.
Before arriving, I read all the required textbooks, including Hogwarts: A History and Standard Book of Spells.
I'm very excited to start my magical education and hope to make a meaningful contribution to whichever house I’m sorted into!"""

# Fill the template's single placeholder with the wizard's self-introduction.
user_prompt = prompt_grammar.format(user_input=user_input)

# Baseline run WITHOUT the grammar constraint. It uses the same token budget and
# temperature as the guided run so that the grammar is the only difference
# between the two outputs. With no SamplingParams at all, vLLM falls back to its
# default max_tokens (16), which cuts the reply off before a complete JSON
# object can be produced.
sampling_params_plain = SamplingParams(max_tokens=100, temperature=0.1)

# The result name is kept as-is for compatibility; it holds the unconstrained output.
grammar_output = generate_output(llm, user_prompt, sampling_params_plain)  # without grammar
print(grammar_output)

In [None]:
# Guided decoding by Grammar
# Grammar handed to GuidedDecodingParams. What the rules describe:
#   - root: one object with exactly three entries in fixed order
#     (name, age, house); no rule allows whitespace between tokens.
#   - age_value: "0" or a digit string without a leading zero (unquoted).
#   - house_value: one of the four listed house names, wrapped in quotes.
#   - basic_string: a quoted string whose plain characters exclude the double
#     quote, the backslash and control characters 0x00-0x1F; backslash escapes
#     (including \uXXXX) are accepted via the `escape` rule.
# Written as a raw string so the backslashes are kept literally in the grammar text.
grammar = r"""root ::= "{" name_entry "," age_entry "," house_entry "}"

name_entry ::= (([\"] "name" [\"])) ":" basic_string
age_entry ::= (([\"] "age" [\"])) ":" age_value
house_entry ::= (([\"] "house" [\"])) ":" house_string

age_value ::= ("0" | [1-9] [0-9]*)
house_string ::= (([\"] house_value [\"]))
house_value ::= "Gryffindor" | "Slytherin" | "Ravenclaw" | "Hufflepuff"

basic_string ::= (([\"] basic_string_1 [\"]))
basic_string_1 ::= "" | [^"\\\x00-\x1F] basic_string_1 | "\\" escape basic_string_1
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
"""

# Wrap the grammar in guided-decoding settings, then attach them to the
# sampling parameters used by the constrained generation cell.
guided_decoding_params_grammar = GuidedDecodingParams(grammar=grammar)
sampling_params_grammar = SamplingParams(
    temperature=0.1,
    max_tokens=100,
    guided_decoding=guided_decoding_params_grammar,
)

In [None]:
user_input = """Hello, my name is Hermione Granger,
and I'm a first-year student at Hogwarts, 11 years old.
I come from a Muggle family—my parents are dentists—but ever since I received my letter, I’ve been absolutely fascinated by the magical world.
Before arriving, I read all the required textbooks, including Hogwarts: A History and Standard Book of Spells.
I'm very excited to start my magical education and hope to make a meaningful contribution to whichever house I’m sorted into!"""

# Rebuild the prompt here so this cell can be run on its own.
user_prompt = prompt_grammar.format(user_input=user_input)

# Same prompt as the baseline, this time generated with the grammar-backed sampling parameters.
grammar_output = generate_output(
    llm,
    user_prompt,
    sampling_params=sampling_params_grammar,
)
print(grammar_output)