# How to use Structured Generation

Structured generation lets you generate structured data with the vLLM deployment by forcing the LLM to adhere to a specific JSON schema or regular expression pattern.

Structured generation is supported only for the vLLM deployment at the moment. 
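How vLLM enforces the constraint depends on its guided-decoding backend. The toy sketch below is not vLLM's implementation; it only illustrates the general idea under simplified assumptions: the allowed language is compiled into a finite-state machine, and at every decoding step the tokens that would leave that language are masked out. The DFA, the tiny vocabulary and the scripted "model" rankings are all hypothetical.

```python
from __future__ import annotations

# Hypothetical hand-written DFA for the pattern [0-9]+(\.[0-9]+)?
# States: 0 = start, 1 = integer digits, 2 = just read ".", 3 = fraction digits.
EOS = "<eos>"
DIGITS = set("0123456789")
ACCEPTING = {1, 3}


def step(state: int, ch: str) -> int | None:
    """Return the next DFA state, or None if `ch` is not allowed in `state`."""
    if ch in DIGITS:
        return 3 if state in (2, 3) else 1
    if ch == "." and state == 1:
        return 2
    return None


def advance(state: int, token: str) -> int | None:
    """Feed a (possibly multi-character) token through the DFA."""
    for ch in token:
        state = step(state, ch)
        if state is None:
            return None
    return state


def allowed(state: int, token: str) -> bool:
    """EOS is allowed only in accepting states; other tokens must keep the DFA alive."""
    if token == EOS:
        return state in ACCEPTING
    return advance(state, token) is not None


def constrained_greedy(rankings: list[list[str]]) -> str:
    """Pick the model's best-ranked token that survives the mask at each step."""
    state, out = 0, []
    for ranked in rankings:
        token = next((t for t in ranked if allowed(state, t)), None)
        if token is None:
            raise ValueError(f"no allowed token in state {state}")
        if token == EOS:
            break
        state = advance(state, token)
        out.append(token)
    return "".join(out)


# Scripted "model": left to itself it would rather write prose ("Pi is about ...").
RANKINGS = [
    ["Pi", "3", "14"],
    [" is", ".", "14"],
    [" about", "14", "."],
    ["159", EOS],
    [EOS, "159"],
]

print(constrained_greedy(RANKINGS))  # prints 3.14159
```

The prose tokens are never emitted because their first character has no transition in the DFA, so the highest-ranked token that the pattern still accepts wins at each step.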

This notebook demonstrates how to use structured generation with vLLM.

In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
from aana.sdk import AanaSDK

aana = AanaSDK().connect(show_logs=True)

In this example, we use the `Phi-3-mini-4k-instruct` model, so let's deploy it first.

In [None]:
from aana.core.models.sampling import SamplingParams
from aana.core.models.types import Dtype
from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment

deployment = VLLMDeployment.options(
    num_replicas=1,
    ray_actor_options={"num_gpus": 0.5},
    user_config=VLLMConfig(
        model_id="microsoft/Phi-3-mini-4k-instruct",
        dtype=Dtype.FLOAT16,
        gpu_memory_reserved=10000,
        enforce_eager=True,
        default_sampling_params=SamplingParams(
            temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
        ),
        engine_args={
            "trust_remote_code": True,
        },
    ).model_dump(mode="json"),
)

aana.register_deployment(
    "vllm_deployment",
    deployment,
    deploy=True,
)

We create an `AanaDeploymentHandle` to interact with the deployment remotely.

In [4]:
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle

handle = await AanaDeploymentHandle.create("vllm_deployment")

To constrain the output to JSON, structured generation needs a JSON schema. We can use a Pydantic model to define it.

In [17]:
import json

from pydantic import BaseModel


class CityDescription(BaseModel):
    """City description model."""

    city: str
    country: str
    description: str


schema = json.dumps(CityDescription.model_json_schema())
print(schema)

{"properties": {"city": {"title": "City", "type": "string"}, "country": {"title": "Country", "type": "string"}, "description": {"title": "Description", "type": "string"}}, "required": ["city", "country", "description"], "title": "CityDescription", "type": "object"}
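The printed schema says that a conforming answer is a JSON object with three required string fields. The minimal sketch below makes that concrete with a hand-rolled check that only understands the keywords used in this particular schema; it is an illustration, not a general JSON Schema validator (in the notebook, `CityDescription.model_validate_json` does the real validation).

```python
import json

# The schema printed above by CityDescription.model_json_schema()
SCHEMA = {
    "properties": {
        "city": {"title": "City", "type": "string"},
        "country": {"title": "Country", "type": "string"},
        "description": {"title": "Description", "type": "string"},
    },
    "required": ["city", "country", "description"],
    "title": "CityDescription",
    "type": "object",
}


def conforms(text: str, schema: dict) -> bool:
    """Tiny, incomplete check: valid JSON object, required keys present, string fields are strings."""
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return False
    if not isinstance(data, dict):
        return False
    if any(key not in data for key in schema["required"]):
        return False
    return all(
        isinstance(data[name], str)
        for name, prop in schema["properties"].items()
        if prop.get("type") == "string" and name in data
    )


print(conforms('{"city": "Paris", "country": "France", "description": "..."}', SCHEMA))  # True
print(conforms("Paris, the capital city of France, ...", SCHEMA))  # False
```

Because the schema does not set `additionalProperties`, JSON Schema allows extra keys, so an object with an additional field still conforms.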


First, let's ask the LLM to tell us about Paris without structured generation.

In [None]:
from aana.core.models.chat import ChatDialog, ChatMessage

dialog = ChatDialog(
    messages=[
        ChatMessage(role="user", content="Tell me about Paris."),
    ]
)

response = await handle.chat(dialog)

In [7]:
print(response["message"].content)

 Paris, the capital city of France, is renowned globally for its rich history, art, fashion, and cuisine. This iconic city, situated on the River Seine in the north-central part of the country, has enthralled millions of visitors over centuries. With its countless attractions, including the famous Eiffel Tower, Louvre Museum, and Montmartre, Paris embodies elegance and charm in its essence, making it an enduring symbol of love and romance.

Paris also boasts a myriad of world-class cultural landmarks. The Louvre, home to countless treasures such as the Mona Lisa and the Venus de Milo, stands on the right bank of the Seine. The Musée d'Orsay displays its vast collection of Impressionist works of art, while the Centre Pompidou exhibits modern and contemporary pieces. The Orsay Museum, on the other hand, houses an extensive collection of decorative arts and design. Other landmarks include the Notre-Dame Cathedral, the Sainte-Chapelle, the Sacré-Cœur Basilica, and the majestic Versailles P

As you can see, the model returns a long, free-form description of Paris with no structure. Now let's use the schema we generated to get a structured response.

To enable structured generation, we pass a `SamplingParams` object to the `chat` method, with its `json_schema` parameter set to the schema we generated above.

In [None]:
from aana.core.models.chat import ChatDialog, ChatMessage

dialog = ChatDialog(
    messages=[
        ChatMessage(role="user", content="Tell me about Paris."),
    ]
)

response = await handle.chat(dialog, sampling_params=SamplingParams(json_schema=schema))

In [9]:
print(response["message"].content)

{ "city": "Paris", "country": "France", "description": "Paris is the capital city of France, renowned for its art, fashion, gastronomy, and culture. Home to iconic landmarks such as the Eiffel Tower, Louvre Museum, Palace of Versailles, and Notre-Dame Cathedral, Paris is a major global center for art, fashion, gastronomy, and culture. The city is famous for its charming streets, romantic ambiance, and world-class museums and restaurants." }


Now the response adheres to the provided schema. We can also use the Pydantic model `CityDescription` to parse the response into a Python object.

In [10]:
city = CityDescription.model_validate_json(response["message"].content)
print(f"City: {city.city}")
print(f"Country: {city.country}")
print(f"Description: {city.description}")

City: Paris
Country: France
Description: Paris is the capital city of France, renowned for its art, fashion, gastronomy, and culture. Home to iconic landmarks such as the Eiffel Tower, Louvre Museum, Palace of Versailles, and Notre-Dame Cathedral, Paris is a major global center for art, fashion, gastronomy, and culture. The city is famous for its charming streets, romantic ambiance, and world-class museums and restaurants.
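Structured generation restricts which tokens the model may produce, but generation can still stop at `max_tokens` before the closing brace, leaving a truncated string that is not valid JSON. A defensive parse, sketched below with a hypothetical `parse_city` helper, catches Pydantic's `ValidationError` instead of letting it propagate. The model is redefined so the block runs on its own.

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class CityDescription(BaseModel):
    """Same model as in the notebook."""

    city: str
    country: str
    description: str


def parse_city(raw: str) -> Optional[CityDescription]:
    """Hypothetical helper: return the parsed object, or None if the output is invalid."""
    try:
        return CityDescription.model_validate_json(raw)
    except ValidationError as err:
        # Invalid or truncated JSON, missing keys and wrong types all end up here.
        print(f"Could not parse model output: {err.error_count()} error(s)")
        return None


# Output cut off mid-string, e.g. because the token budget ran out.
truncated = '{"city": "Paris", "country": "France", "description": "Paris is the capi'
print(parse_city(truncated))  # None
```

If truncation happens in practice, raising `max_tokens` in `SamplingParams` is the first thing to try.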


We can also specify a different schema to get a list of cities and their descriptions. Notice that here we added `Return a list of dictionaries.` to the prompt. It usually helps a lot to tell the model what you expect: ideally, describe the expected structure in the prompt and include a few examples of the desired output. That's where prompt engineering comes in.
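As a minimal sketch of that advice, the hypothetical `build_prompt` helper below puts the question, the expected JSON schema and one example of the desired output into a single user message. To keep the block self-contained it uses a hand-written array schema with the same fields as `CityDescription`; in the notebook you would pass the `schema` string produced from `CityDescriptionList` instead. The Rome entry is only an illustrative example.

```python
from __future__ import annotations

import json


def build_prompt(question: str, schema: str, examples: list[dict]) -> str:
    """Hypothetical helper: question + expected JSON schema + example output."""
    return (
        f"{question} Return a list of dictionaries.\n"
        f"The output must match this JSON schema:\n{schema}\n"
        f"Example output:\n{json.dumps(examples)}"
    )


# Hand-written array schema with the same fields as CityDescription.
schema = json.dumps(
    {
        "type": "array",
        "items": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "country": {"type": "string"},
                "description": {"type": "string"},
            },
            "required": ["city", "country", "description"],
        },
    }
)

# Illustrative example output shown to the model.
examples = [{"city": "Rome", "country": "Italy", "description": "Rome is the capital of Italy."}]

content = build_prompt("Tell me about Vienna, Paris, and New York.", schema, examples)
print(content)
```

The resulting `content` would go into `ChatMessage(role="user", content=content)`, while the same schema is still passed as `SamplingParams(json_schema=schema)` so that the output is constrained, not just requested.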

In [None]:
from pydantic import RootModel

CityDescriptionList = RootModel[list[CityDescription]]
schema = json.dumps(CityDescriptionList.model_json_schema())

dialog = ChatDialog(
    messages=[
        ChatMessage(
            role="user",
            content="Tell me about Vienna, Paris, and New York. Return a list of dictionaries.",
        ),
    ]
)

response = await handle.chat(dialog, sampling_params=SamplingParams(json_schema=schema))

In [12]:
print(response["message"].content)

[{"city":"Vienna","country":"Austria","description":"Vienna is the capital city of Austria and has a rich cultural history dating back to Roman times. It is known for its classical music, imperial palaces, and the annual New Year's Concert performed by the Vienna Philharmonic."},{"city":"Paris","country":"France","description":"Paris is the capital of France and is renowned for its romantic ambiance, iconic landmarks like the Eiffel Tower and Notre-Dame Cathedral, and as a center for art, fashion, gastronomy, and culture."},{"city":"New York","country":"United States","description":"New York City, often simply called New York, is a major cultural, financial, and media hub. It is famous for its diverse cultural scenes, financial markets, historical significance in many domains, and monuments like the Statue of Liberty and Times Square. New York also hosts institutions like the Metropolitan Museum of Art, Broadway, and Central Park."}]


[36m(ServeReplica:vllm_deployment:VLLMDeployment pid=1888371)[0m INFO 2024-09-26 13:28:00,474 vllm_deployment_VLLMDeployment 3yq0wqtc 8c88b063-18f0-4ee4-9f49-a0bcb5224d01 replica.py:376 - CHAT OK 4969.3ms
Compiling FSM index for all state transitions:   0%|          | 0/9 [00:00<?, ?it/s]
Compiling FSM index for all state transitions: 100%|██████████| 9/9 [00:00<00:00, 51.53it/s]


In [13]:
cities = CityDescriptionList.model_validate_json(response["message"].content)
for city in cities.root:
    print(f"City: {city.city}")
    print(f"Country: {city.country}")
    print(f"Description: {city.description}")
    print()

City: Vienna
Country: Austria
Description: Vienna is the capital city of Austria and has a rich cultural history dating back to Roman times. It is known for its classical music, imperial palaces, and the annual New Year's Concert performed by the Vienna Philharmonic.

City: Paris
Country: France
Description: Paris is the capital of France and is renowned for its romantic ambiance, iconic landmarks like the Eiffel Tower and Notre-Dame Cathedral, and as a center for art, fashion, gastronomy, and culture.

City: New York
Country: United States
Description: New York City, often simply called New York, is a major cultural, financial, and media hub. It is famous for its diverse cultural scenes, financial markets, historical significance in many domains, and monuments like the Statue of Liberty and Times Square. New York also hosts institutions like the Metropolitan Museum of Art, Broadway, and Central Park.



We can also use a regular expression pattern to constrain the output. Here we ask the model for the digits of Pi.
To do that, we set the `regex_string` parameter of the `SamplingParams` object to the pattern the output must match.

In [None]:
dialog = ChatDialog(
    messages=[
        ChatMessage(
            role="user",
            content="What is Pi? Give me the first 15 digits. Only return the number.",
        ),
    ]
)

regex_pattern = "(-)?(0|[1-9][0-9]*)(\\.[0-9]+)?([eE][+-][0-9]+)?"

sampling_params = SamplingParams(regex_string=regex_pattern, max_tokens=32)

response = await handle.chat(dialog, sampling_params=sampling_params)

In [15]:
print(response["message"].content)

3.141592653589793238462643383279


In [16]:
import re

re.fullmatch(regex_pattern, response["message"].content)

[36m(ServeReplica:vllm_deployment:VLLMDeployment pid=1888371)[0m INFO 2024-09-26 13:28:01,461 vllm_deployment_VLLMDeployment 3yq0wqtc 4f035913-7c7e-47fe-8a66-197a00120d26 replica.py:376 - CHAT OK 923.0ms


<re.Match object; span=(0, 32), match='3.141592653589793238462643383279'>
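The pattern above constrains only the shape of the output: an optional minus sign, an integer part without leading zeros, an optional fraction and an optional exponent. It accepts any number of digits, so nothing in it enforces the "first 15 digits" part of the prompt, which is why the result is longer than requested. The sketch below checks a few strings with `re.fullmatch` and compares them against a hypothetical stricter pattern that also fixes the length: one integer digit followed by exactly 14 fractional digits.

```python
import re

# Pattern used in the notebook (written as a raw string; same regex).
JSON_NUMBER = r"(-)?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-][0-9]+)?"

# Hypothetical stricter pattern: 15 significant digits, e.g. 3.14159265358979
FIFTEEN_DIGITS = r"[0-9]\.[0-9]{14}"

candidates = [
    "3.14159265358979",                  # 15 digits
    "3.141592653589793238462643383279",  # the 32-character answer above
    "03.14",                             # leading zero
    "-1e+5",                             # signed exponent
    "1e5",                               # exponent without a sign
]

for candidate in candidates:
    print(
        f"{candidate:<34}",
        bool(re.fullmatch(JSON_NUMBER, candidate)),
        bool(re.fullmatch(FIFTEEN_DIGITS, candidate)),
    )
```

Note that this pattern requires an explicit `+` or `-` in the exponent, so it is slightly stricter than the JSON number grammar, which makes the sign optional. Passing `SamplingParams(regex_string=FIFTEEN_DIGITS, max_tokens=32)` would encode the length requirement in the constraint itself instead of relying on the prompt.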