# RAG with Source Highlighting Using Structured Generation

**Structured generation** is a method that forces the LLM output to follow certain constraints, for example, a specific pattern.

Use cases of structured generation:
* output a dictionary with specific keys
* make sure the output will be longer than N characters
* force the output to follow a certain regex pattern for downstream processing
* highlight sources supporting the answer in RAG
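Without constrained decoding, constraints like the first three can only be checked *after* the text has been generated. Here is a minimal sketch of such post-hoc checks. The helper name `check_output`, the example keys and the regex are illustrative assumptions, not part of any library.

```python
import json
import re


def check_output(text: str, required_keys: set[str], min_length: int, pattern: str) -> dict[str, bool]:
    """Check an already-generated string against three example constraints."""
    results = {}
    # 1. Is the output a JSON dictionary that contains all the required keys?
    try:
        parsed = json.loads(text)
        results["has_keys"] = isinstance(parsed, dict) and required_keys.issubset(parsed)
    except json.JSONDecodeError:
        results["has_keys"] = False
    # 2. Is the output longer than N characters?
    results["long_enough"] = len(text) > min_length
    # 3. Does the whole output match the regex pattern?
    results["matches_pattern"] = re.fullmatch(pattern, text, flags=re.DOTALL) is not None
    return results


good = '{"answer": "Pass stop_sequence", "confidence_score": 0.9}'
print(check_output(good, {"answer", "confidence_score"}, 20, r"\{.*\}"))
print(check_output("not json", {"answer"}, 20, r"\{.*\}"))
```

If a check fails, the only options left are to discard the output or to generate again, which is what constrained decoding avoids.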

In this example, we will build a RAG system that not only provides an answer, but also highlights the supporting snippets that this answer is based on.

We will first apply a naive approach to structured generation via prompting and highlight its limits, then demonstrate constrained decoding for more reliable structured generation.

We will first use the Hugging Face Inference API, and then a local pipeline.

## Setup

In [None]:
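# `json` is part of the Python standard library, so it is not installed with pip.
# outlines is pinned below 1.0 because this notebook uses the `outlines.generate` API.
!pip install -qU pandas huggingface_hub pydantic "outlines<1.0" accelerate transformers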

In [2]:
import pandas as pd
import json
from huggingface_hub import InferenceClient

pd.set_option('display.max_colwidth', None)

In [3]:
repo_id = 'meta-llama/Meta-Llama-3-8B-Instruct'

llm_client = InferenceClient(model=repo_id, timeout=120)

In [None]:
# test
llm_client.text_generation(prompt="How are you today?", max_new_tokens=20)

## Prompting the model

To get structured outputs from our model, we can simply prompt a sufficiently powerful model with appropriate guidelines. We want the RAG model to generate not only an answer, but also a confidence score and some source snippets, and to return them as a JSON dictionary that we can easily parse for downstream processing.

In [None]:
RELEVANT_CONTEXT = """
Document:
The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.
"""

In [None]:
RAG_PROMPT_TEMPLATE_JSON = """
Answer the user query based on the source documents.

Here are the source documents: {context}


You should provide your answer as a JSON blob, and also provide all relevant short source snippets from the documents on which you directly based your answer, and a confidence score as a float between 0 and 1.
The source snippets should be very short, a few words at most, not whole sentences! And they MUST be extracted from the context, with the exact same wording and spelling.

Your answer should be built as follows, it must contain the "Answer:" and "End of answer." sequences.

Answer:
{{
  "answer": your_answer,
  "confidence_score": your_confidence_score,
  "source_snippets": ["snippet_1", "snippet_2", ...]
}}
End of answer.

Now begin!
Here is the user question: {user_query}.
Answer:
"""

In [None]:
USER_QUERY = "How can I define a stop sequence in Transformers?"

In [None]:
prompt = RAG_PROMPT_TEMPLATE_JSON.format(context=RELEVANT_CONTEXT, user_query=USER_QUERY)
print(prompt)

In [None]:
answer = llm_client.text_generation(
    prompt,
    max_new_tokens=1000
)

answer = answer.split('End of answer.')[0]
print(answer)

The output of the LLM is a string representation of a dictionary, so we load it as a Python dictionary using `ast.literal_eval`, which, unlike `json.loads`, also accepts Python-style literals such as single-quoted strings.

In [None]:
from ast import literal_eval

parsed_answer = literal_eval(answer)
parsed_answer
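`literal_eval` raises an exception as soon as the model strays from the expected format, for instance if it adds text after the closing brace instead of the `End of answer.` marker. Below is a minimal sketch of a more defensive parser. It assumes the answer is the span from the first `{` to the last `}` in the generated text; `parse_llm_dict` is a hypothetical helper, and `raw` is a made-up example output.

```python
import json
from ast import literal_eval


def parse_llm_dict(text: str, stop_marker: str = "End of answer."):
    """Try to recover a dictionary from raw LLM output; return None on failure."""
    # Keep only what precedes the stop marker, as in the cell above.
    text = text.split(stop_marker)[0]
    # Assumption: the dictionary spans from the first "{" to the last "}".
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        return None
    candidate = text[start:end + 1]
    # Try strict JSON first, then Python literal syntax (single quotes, True/False, ...).
    for loader in (json.loads, literal_eval):
        try:
            parsed = loader(candidate)
        except (ValueError, SyntaxError):  # JSONDecodeError is a ValueError subclass
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


raw = ' \n{"answer": "Pass stop_sequence", "confidence_score": 0.9, "source_snippets": ["stop_sequence"]}\nEnd of answer.'
print(parse_llm_dict(raw))
print(parse_llm_dict("I cannot answer that."))
```

Returning `None` instead of raising lets the caller decide whether to retry, skip, or fall back to a default answer.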

In [None]:
def highlight(s):
    return "\x1b[1;32m" + s + "\x1b[0m"


def print_results(answer, source_text, highlight_snippets):
    print("Answer:", highlight(answer))
    print('\n\n', '='*10 + ' Source documents ' + '='*10)

    for snippet in highlight_snippets:
        source_text = source_text.replace(snippet.strip(), highlight(snippet.strip()))
    print(source_text)

print_results(
    answer=parsed_answer['answer'],
    source_text=RELEVANT_CONTEXT,
    highlight_snippets=parsed_answer['source_snippets']
)
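`print_results` highlights snippets with `str.replace`, which silently does nothing when the model paraphrases a snippet instead of copying it verbatim, as the prompt requires. A small check makes such misses visible. `find_unmatched_snippets` is a hypothetical helper, and the snippet list below is a made-up example.

```python
def find_unmatched_snippets(snippets: list[str], source_text: str) -> list[str]:
    """Return the snippets that do not appear verbatim in the source text."""
    # Mirrors print_results, which also strips each snippet before replacing it.
    return [s for s in snippets if s.strip() not in source_text]


source = """
Document:
The weather is really nice in Paris today.
To define a stop sequence in Transformers, you should pass the stop_sequence argument in your pipeline or model.
"""
snippets = ["stop_sequence argument", "pass the stop sequence parameter"]
print(find_unmatched_snippets(snippets, source))  # ['pass the stop sequence parameter']
```

An empty list means every snippet will be highlighted; anything else is a snippet the model did not copy exactly.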

We can also try increasing the temperature, which makes the output more random, similar to what we might get from a less powerful model.

In [None]:
answer = llm_client.text_generation(
    prompt,
    max_new_tokens=250,
    temperature=1.6,
    return_full_text=False
)
print(answer)

At this temperature, the output is typically no longer valid JSON.

## Constrained decoding

To force a JSON output, we will have to use **constrained decoding** where we force the LLM to only output tokens that conform to a set of rules called a **grammar**.

The **grammar** can be defined using Pydantic models, JSON schema, or regular expressions. The model will then generate a response that conforms to the specified grammar.

In [None]:
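from typing import Annotated, List

from pydantic import BaseModel, Field, StringConstraints


class AnswerWithSnippets(BaseModel):
    # The answer must be between 10 and 100 characters long
    answer: Annotated[
        str,
        StringConstraints(min_length=10, max_length=100)
    ]

    # Same key as requested in the prompt, bounded to [0, 1]
    # (Field is the Pydantic v2 way to attach bounds inside Annotated)
    confidence_score: Annotated[
        float,
        Field(ge=0.0, le=1.0)
    ]

    # Each source snippet must be at most 30 characters long
    source_snippets: List[Annotated[str, StringConstraints(max_length=30)]]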

Check if this schema correctly represents our requirements:

In [None]:
AnswerWithSnippets.model_json_schema()

Now we can pass the schema as a `grammar`, either through the client's `text_generation` method or through its lower-level `post` method.

In [None]:
# Using `text_generation`
answer = llm_client.text_generation(
    prompt,
    grammar={'type': 'json', 'value': AnswerWithSnippets.model_json_schema()},
    max_new_tokens=250,
    temperature=1.6,
    return_full_text=False
)
print(answer)

In [None]:
# Using post
data = {
    'inputs': prompt,
    'parameters': {
        'temperature': 1.6,
        'return_full_text': False,
        'grammar': {'type': 'json', 'value': AnswerWithSnippets.model_json_schema()},
        'max_new_tokens': 250
    }
}
answer = json.loads(llm_client.post(json=data))[0]['generated_text']
print(answer)

The generated output now has the correct JSON format with the exact keys and types we defined in our grammar.

## Grammar on a local pipeline with Outlines

[`outlines`](https://github.com/dottxt-ai/outlines) is the library that runs under the hood on the Inference API to constrain output generation.

We can also use it locally. It works by applying a bias to the logits so that only tokens that conform to our constraint can be selected.
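To make the logit-bias idea concrete, here is a toy version of constrained greedy decoding. It assumes a made-up seven-token vocabulary and the same hand-picked logits at every step in place of a real model, and it uses the simplest possible constraint (the output must be exactly one of a few allowed strings) rather than a JSON schema. At each step, tokens that would break the constraint get a logit of `-inf`, so they can never be selected. Outlines precomputes which tokens are allowed in each state of a finite-state machine built from the schema or regex; this sketch re-checks every candidate prefix instead, which is slower but follows the same principle.

```python
import math

# Toy vocabulary and logits; a real model has a large vocabulary and
# recomputes the logits from the text generated so far at every step.
VOCAB = ["Sure", "!", " yes", "ye", "s", "no", "<eos>"]
LOGITS = [5.0, 4.0, 3.0, 2.0, 1.5, 1.0, 0.5]
EOS = "<eos>"
ALLOWED_OUTPUTS = ["yes", "no"]  # the "grammar": output must be exactly one of these


def is_allowed(text: str, finished: bool) -> bool:
    """Is `text` a complete allowed output (if finished) or a prefix of one?"""
    if finished:
        return text in ALLOWED_OUTPUTS
    return any(option.startswith(text) for option in ALLOWED_OUTPUTS)


def greedy_decode(logits: list[float], max_steps: int, constrained: bool) -> str:
    text = ""
    for _ in range(max_steps):
        scores = list(logits)
        if constrained:
            for i, token in enumerate(VOCAB):
                finished = token == EOS
                candidate = text if finished else text + token
                if not is_allowed(candidate, finished):
                    scores[i] = -math.inf  # masked: this token can never be picked
        best = max(range(len(VOCAB)), key=lambda i: scores[i])
        if scores[best] == -math.inf:
            raise RuntimeError("no token satisfies the constraint")
        if VOCAB[best] == EOS:
            break
        text += VOCAB[best]
    return text


print(greedy_decode(LOGITS, max_steps=3, constrained=False))  # SureSureSure
print(greedy_decode(LOGITS, max_steps=5, constrained=True))   # yes
```

Without the mask, greedy decoding keeps picking the highest-scoring token `"Sure"`; with it, only `"ye"` and `"no"` survive the first step, then `"s"`, then the end-of-sequence token, so the result is guaranteed to be an allowed output.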

In [None]:
import outlines

repo_id = 'mustafaaljadery/gemma-2B-10M'

model = outlines.models.transformers(repo_id)

schema_as_str = json.dumps(AnswerWithSnippets.model_json_schema())

In [None]:
generator = outlines.generate.json(model, schema_as_str)
result = generator(prompt)
print(result)