# Benchmarking Structured JSON Output

Structuring the output of language models (LMs) into a machine-readable form is one of the core use cases for guidance. When an LM is part of a data processing pipeline, getting the output structure right 100% of the time is paramount to avoiding late-night on-call pages for the practitioner responsible.

We benchmark structured JSON output using LangChain's chat extraction dataset. LangChain has done good work providing an industry-relevant task; the accompanying blog post can be found [here](https://blog.langchain.dev/extraction-benchmarking/). For our purposes, we are most interested in the JSON output the task expects.

## The JSON Schema

The expected JSON output is complex yet not unreasonable for a pipeline: nested structures, conditional fields and constraints on some values. LMs still vary in how reliably they produce well-formed JSON at all, before schema validity is even considered.

```json
{
  "title": "GenerateTicket",
  "description": "Generate a ticket containing all the extracted information.",
  "type": "object",
  "properties": {
    "issue_summary": {
      "title": "Issue Summary",
      "description": "short (<10 word) summary of the issue or question",
      "type": "string"
    },
    "question": {
      "title": "Question",
      "description": "Information inferred from the question.",
      "allOf": [
        {
          "$ref": "#/definitions/QuestionCategorization"
        }
      ]
    },
    "response": {
      "title": "Response",
      "description": "Information inferred from the response.",
      "allOf": [
        {
          "$ref": "#/definitions/ResponseCategorization"
        }
      ]
    }
  },
  "required": [
    "issue_summary",
    "question",
    "response"
  ],
  "definitions": {
    "QuestionCategory": {
      "title": "QuestionCategory",
      "description": "An enumeration.",
      "enum": [
        "Implementation Issues",
        "Feature Requests",
        "Concept Explanations",
        "Code Optimization",
        "Security and Privacy Concerns",
        "Model Training and Fine-tuning",
        "Data Handling and Manipulation",
        "User Interaction Flow",
        "Technical Integration",
        "Error Handling and Logging",
        "Customization and Configuration",
        "External API and Data Source Integration",
        "Language and Localization",
        "Streaming and Real-time Processing",
        "Tool Development",
        "Function Calling",
        "LLM Integrations",
        "General Agent Question",
        "General Chit Chat",
        "Memory",
        "Debugging Help",
        "Application Design",
        "Prompt Templates",
        "Cost Tracking",
        "Other"
      ],
      "type": "string"
    },
    "Sentiment": {
      "title": "Sentiment",
      "description": "An enumeration.",
      "enum": [
        "Negative",
        "Neutral",
        "Positive"
      ],
      "type": "string"
    },
    "ProgrammingLanguage": {
      "title": "ProgrammingLanguage",
      "description": "An enumeration.",
      "enum": [
        "python",
        "javascript",
        "typescript",
        "unknown",
        "other"
      ],
      "type": "string"
    },
    "QuestionCategorization": {
      "title": "QuestionCategorization",
      "type": "object",
      "properties": {
        "question_category": {
          "$ref": "#/definitions/QuestionCategory"
        },
        "category_if_other": {
          "title": "Category If Other",
          "description": "question category if the category above is 'other'",
          "type": "string"
        },
        "is_off_topic": {
          "title": "Is Off Topic",
          "description": "If the input is general chit chat or does not pertain to technical inquiries about LangChain or building/debugging applications with LLMs/AI, it is off topic. For context, LangChain is a library and framework designed to assist in building applications with LLMs. Questions may also be about similar packages like LangServe, LangSmith, OpenAI, Anthropic, vectorstores, agents, etc.",
          "type": "boolean"
        },
        "toxicity": {
          "title": "Toxicity",
          "description": "Whether or not the input question is toxic",
          "exclusiveMaximum": 6,
          "minimum": 0,
          "type": "integer"
        },
        "sentiment": {
          "$ref": "#/definitions/Sentiment"
        },
        "programming_language": {
          "$ref": "#/definitions/ProgrammingLanguage"
        }
      },
      "required": [
        "question_category",
        "is_off_topic",
        "toxicity",
        "sentiment",
        "programming_language"
      ]
    },
    "ResponseType": {
      "title": "ResponseType",
      "description": "An enumeration.",
      "enum": [
        "resolve issue",
        "provide guidance",
        "request information",
        "give up",
        "none",
        "other"
      ],
      "type": "string"
    },
    "ResponseCategorization": {
      "title": "ResponseCategorization",
      "type": "object",
      "properties": {
        "response_type": {
          "$ref": "#/definitions/ResponseType"
        },
        "response_type_if_other": {
          "title": "Response Type If Other",
          "type": "string"
        },
        "confidence_level": {
          "title": "Confidence Level",
          "description": "The confidence of the assistant in its answer.",
          "exclusiveMaximum": 6,
          "minimum": 0,
          "type": "integer"
        },
        "followup_actions": {
          "title": "Followup Actions",
          "description": "Actions the assistant recommended the user take.",
          "type": "array",
          "items": {
            "type": "string"
          }
        }
      },
      "required": [
        "response_type",
        "confidence_level"
      ]
    }
  }
}
```

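To make the constraints concrete, the sketch below checks an already-parsed ticket against the main rules of this schema: required keys, the string enums, and integer fields bounded by `minimum: 0` and `exclusiveMaximum: 6` (so 0 to 5). It is an illustration only; the benchmark itself validates with LangChain's `GenerateTicket` model, and `check_ticket` and `bounded_int` are hypothetical names. The question categories are passed in as a list (for example the `QUESTION_CAT` list defined in `trial_runner`) rather than repeating all 25 names.

```python
# Hypothetical helper (not part of guidance or langchain_benchmarks): checks the
# required keys, enums and integer bounds of the GenerateTicket schema above.
SENTIMENTS = ("Negative", "Neutral", "Positive")
LANGUAGES = ("python", "javascript", "typescript", "unknown", "other")
RESPONSE_TYPES = ("resolve issue", "provide guidance", "request information",
                  "give up", "none", "other")


def bounded_int(value):
    # "minimum": 0 plus "exclusiveMaximum": 6 allows the integers 0..5; bools are rejected
    return isinstance(value, int) and not isinstance(value, bool) and 0 <= value < 6


def check_ticket(ticket, question_categories):
    """Return a list of problems; an empty list means the ticket passed these checks."""
    missing = [k for k in ("issue_summary", "question", "response") if k not in ticket]
    if missing:
        return [f"missing {k}" for k in missing]
    q, r = ticket["question"], ticket["response"]
    if not isinstance(q, dict) or not isinstance(r, dict):
        return ["question and response must be objects"]
    problems = []
    if not isinstance(ticket["issue_summary"], str):
        problems.append("issue_summary must be a string")
    if q.get("question_category") not in question_categories:
        problems.append("bad question_category")
    if not isinstance(q.get("is_off_topic"), bool):
        problems.append("is_off_topic must be a boolean")
    if not bounded_int(q.get("toxicity")):
        problems.append("toxicity must be an integer in [0, 6)")
    if q.get("sentiment") not in SENTIMENTS:
        problems.append("bad sentiment")
    if q.get("programming_language") not in LANGUAGES:
        problems.append("bad programming_language")
    if r.get("response_type") not in RESPONSE_TYPES:
        problems.append("bad response_type")
    if not bounded_int(r.get("confidence_level")):
        problems.append("confidence_level must be an integer in [0, 6)")
    actions = r.get("followup_actions", [])  # optional field
    if not (isinstance(actions, list) and all(isinstance(a, str) for a in actions)):
        problems.append("followup_actions must be a list of strings")
    return problems


ticket = {
    "issue_summary": "Agent does not call my tool",
    "question": {"question_category": "Function Calling", "is_off_topic": False,
                 "toxicity": 0, "sentiment": "Negative", "programming_language": "python"},
    "response": {"response_type": "provide guidance", "confidence_level": 4},
}
print(check_ticket(ticket, ["Function Calling", "Other"]))  # []
```

Tuples and lists are used for the enums so that an unexpected unhashable value (say, a nested object) is reported as a problem instead of raising.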
## Benchmark

The code below is what we ran to confirm that guidance is expressive enough to capture this structured JSON, and to gauge how many tokens we save by enforcing constraints. We focus on LMs that can run on a consumer device, to reflect the range of environments pipelines run in.

The key metrics (later found in `agg_df` as `mean_*` and `std_*` columns):
- JSON output success rate (the raw output parses as JSON with a top-level `GenerateTicket` key): `json`
- Token reduction: `token_reduction`

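Per example, the two headline metrics boil down to the helpers below, a simplified restatement of the logic in `trial_runner` (`json_metric` and `token_reduction` are hypothetical names). In the notebook the token counts come from `lm.engine.metrics.engine_output_tokens` and `lm.token_count`; each value is logged per example and then aggregated into the `mean_*` and `std_*` columns. Schema conformance is tracked separately as `json_valid`.

```python
import json


def json_metric(output_str):
    """1 if the raw output parses as JSON with a top-level "GenerateTicket" key, else 0."""
    try:
        parsed = json.loads(output_str.strip())
    except json.JSONDecodeError:
        return 0
    return int(isinstance(parsed, dict) and "GenerateTicket" in parsed)


def token_reduction(engine_output_tokens, token_count):
    """Same formula as the notebook: 1 - engine output tokens / total token count."""
    return 1 - engine_output_tokens / token_count


print(json_metric('{"GenerateTicket": {}}'))        # 1
print(json_metric('Sure! {"GenerateTicket": {}}'))  # 0: leading prose breaks json.loads
print(token_reduction(30, 40))                      # 0.25
```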
## Requirements

1. Install benchmark dependencies (i.e. `pip install guidance[bench]`).
2. Set environment variable `LANGCHAIN_API_KEY`. You will need an account with [LangChain](https://www.langchain.com/) to obtain a key.
3. Download the GGUF model files below from Hugging Face and place them at these paths, relative to the working directory:
   - ["Mistral-7B-Instruct-v0.2-GGUF/mistral-7b-instruct-v0.2.Q8_0.gguf"](https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF/tree/main)
   - ["Llama-2-7B-32K-Instruct-GGUF/llama-2-7b-32k-instruct.Q8_0.gguf"](https://huggingface.co/TheBloke/Llama-2-7B-32K-Instruct-GGUF/tree/main)
   - ["Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-fp16.gguf"](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/tree/main)
4. By default we assume CUDA is available, with enough VRAM to offload every layer (`n_gpu_layers=-1`). If needed, switch to more heavily quantized versions of the models.

In [1]:
# Initial checks, feel free to disable this once you have the benchmark working.

import os
from pathlib import Path

err_msg = "Requirements not met. Follow the instructions above."
try:
    import langchain_benchmarks
    import powerlift
except ImportError as e:
    raise ValueError(err_msg) from e
if os.getenv("LANGCHAIN_API_KEY") is None:
    raise ValueError(f"{err_msg} LANGCHAIN_API_KEY is not set.")
for model_path in [
    "Mistral-7B-Instruct-v0.2-GGUF/mistral-7b-instruct-v0.2.Q8_0.gguf",
    "Llama-2-7B-32K-Instruct-GGUF/llama-2-7b-32k-instruct.Q8_0.gguf",
    "Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-fp16.gguf",
]:
    if not Path(model_path).exists():
        raise ValueError(f"{err_msg} Missing model file: {model_path}")

In [2]:
# The JSON output benchmark is defined within the two functions below.

def trial_filter(task):
    """This function works within our benchmarking platform to declare which methods will be tested against which task.
    The method names here are used later in another function `trial_runner` for conditional execution.
    """
    
    if task.problem == "guidance/struct_decode":
        return [
            "guidance-mistral-7b-instruct",
            "base-mistral-7b-instruct",
            "guidance-phi-3-mini-4k-instruct",
            "base-phi-3-mini-4k-instruct",
            "guidance-llama2-7b-32k-instruct",
            "base-llama2-7b-32k-instruct",
        ]
    return []

def trial_runner(trial):
    """Runs a single trial. The method to be tested will be under `trial.method.name` with the task as `trial.task.name`.

    The imports and all user-defined functions are defined within. This simplifies serialization when the benchmark is run against remote machines.
    """
    import json
    import pandas as pd
    
    from guidance import models, gen, system, user, guidance, select, zero_or_more, capture
    from guidance.models.transformers import Transformers
    from guidance.models.llama_cpp import LlamaCpp
    from guidance import json as gen_json
    from time import time
    import json_stream
    import io
    import os
    
    if trial.task.name == "chat_extract":
        inputs, outputs, meta = trial.task.data(["inputs", "outputs", "meta"])
        merged_df = pd.concat([inputs.reset_index(drop=True), outputs.reset_index(drop=True)], axis=1)

        if trial.method.name.startswith("guidance"):
            QUESTION_CAT = [
                "Implementation Issues",
                "Feature Requests",
                "Concept Explanations",
                "Code Optimization",
                "Security and Privacy Concerns",
                "Model Training and Fine-tuning",
                "Data Handling and Manipulation",
                "User Interaction Flow",
                "Technical Integration",
                "Error Handling and Logging",
                "Customization and Configuration",
                "External API and Data Source Integration",
                "Language and Localization",
                "Streaming and Real-time Processing",
                "Tool Development",
                "Function Calling",
                "LLM Integrations",
                "General Agent Question",
                "General Chit Chat",
                "Memory",
                "Debugging Help",
                "Application Design",
                "Prompt Templates",
                "Cost Tracking",
                "Other"
            ]
            RESPONSE_TYPE = [
                "resolve issue",
                "provide guidance",
                "request information",
                "give up",
                "none",
                "other"
            ]
            
            @guidance(stateless=True, dedent=False)
            def guidance_list(lm):
                return lm + "[" + zero_or_more(gen(regex=r'"[\w ]+", ')) + gen(regex=r'"[\w ]+"') + "]"
            
            WORD_PAT = r'[\w ]+'
            NEW_REC = '\n            '
            DOUBLE_QUOTE = '"'
            NEW_LINE = '\n'
            @guidance(stateless=False, dedent=False)
            def gen_chat_json(lm):
                lm += f"""{{
                    "GenerateTicket": {{
                        "issue_summary": "{gen(regex=WORD_PAT, stop='"')}",
                        "question": {{
                            "question_category": "{select(QUESTION_CAT, name='question_cat')}",
                            """
                if lm['question_cat'] == 'Other':
                    lm += f""""category_if_other": "{gen(regex=WORD_PAT, stop='"')}",
                            """
                    
                lm += f""""is_off_topic": {select(["false", "true"])},
                            "toxicity": {select([0, 1, 2, 3, 4, 5])},
                            "sentiment": "{select(["Negative", "Neutral", "Positive"])}",
                            "programming_language": "{select(["python", "javascript", "typescript", "unknown", "other"])}"
                        }},
                        "response": {{
                            "response_type": "{select(RESPONSE_TYPE, name='response_type')}",
                            """
            
                if lm['response_type'] == "other":
                    lm += f""""response_type_if_other": "{gen(regex=WORD_PAT, stop='"')}",
                            """
            
                lm += f""""confidence_level": {select([0, 1, 2, 3, 4, 5])}"""
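                lm += f"""{select(['', ',' + NEW_REC + '"followup_actions":'], name='follow_up')}"""
                # Truthiness check: skip the list both when the capture is missing and when the
                # empty option ('') was selected; appending it after '' would break the JSON.
                if lm.get('follow_up'):
                    lm += f""" {guidance_list()}"""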
            
                lm += f"""
                    }}
                }}
            }}"""
                return lm

        for i, row in merged_df.iterrows():
            # Initialize LLM
            if i == 0:
                if "mistral" in trial.method.name:
                    base_lm = models.LlamaCpp(
                        "Mistral-7B-Instruct-v0.2-GGUF/mistral-7b-instruct-v0.2.Q8_0.gguf",
                        n_ctx=8192, n_gpu_layers=-1, echo=False, verbose=False
                    )
                elif "llama2-7b" in trial.method.name:
                    base_lm = models.LlamaCpp(
                        "Llama-2-7B-32K-Instruct-GGUF/llama-2-7b-32k-instruct.Q8_0.gguf", 
                        n_ctx=8192, n_gpu_layers=-1, echo=False, verbose=False
                    )
                else:
                    base_lm = models.LlamaCpp(
                        "Phi-3-mini-4k-instruct-gguf/Phi-3-mini-4k-instruct-fp16.gguf", 
                        n_ctx=8192, n_gpu_layers=-1, echo=False, verbose=False
                    )
                    
            # Execute LLM
            print(f"{trial.method.name}[{i}]")
            start_time = time()
            lm = base_lm
            lm.engine.reset_metrics()
            if "mistral" in trial.method.name:
                lm += f"""<s>[INST] {row['system_prompt']}\n{row['user_prompt']} [/INST]"""
            elif "llama" in trial.method.name:
                lm += f"""<s>[INST] <<SYS>>\n{row['system_prompt']}\n<</SYS>>\n\n{row['user_prompt']}[/INST]"""
            elif "phi" in trial.method.name:
                lm += f"""<s><|user|>{row['system_prompt']}\n{row['user_prompt']}<|end|><|assistant|>"""
            else:
                raise ValueError(f"Cannot support {trial.method.name} for system prompts")

            before_idx = len(str(lm))
            if "guidance" in trial.method.name:
                lm += gen_chat_json()
            else:
                lm += gen(max_tokens=1500)
            output_str = str(lm)[before_idx:]
            end_time = time()
            elapsed_time = end_time - start_time

            # Basic measures
            trial.log("output", output_str)
            trial.log("wall_time", elapsed_time)

            # Token statistics
            tm = {
                "input": lm.engine.metrics.engine_input_tokens,
                "output": lm.engine.metrics.engine_output_tokens,
                "token_count": lm.token_count,
            }
            tm["token_reduction"] = 1 - tm["output"] / tm["token_count"]
            trial.log("token_input", tm["input"])
            trial.log("token_output", tm["output"])
            trial.log("token_count", tm["token_count"])
            trial.log("token_reduction", tm["token_reduction"])

            # Validate JSON conformance
            json_success = False
            try:
                output_json = json.loads(output_str.strip())
                output_json = output_json['GenerateTicket']
                json_success = True
            except Exception as e:
                trial.log("json_errmsg", str(e))
            trial.log("json", json_success * 1)

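            if json_success:
                trial.log("output_json", output_json)
                trial.log("json_dirty", 0)
            else:
                # Salvage pass: try each '{' as a start position and keep the first
                # object that parses and contains a "GenerateTicket" key.
                success = False
                candidate = output_str.strip()
                # `start` rather than `i`, so the row index of the outer loop is not shadowed
                for start, ch in enumerate(candidate):
                    if ch == '{':
                        try:
                            results = json_stream.load(io.StringIO(candidate[start:]))
                            di = json_stream.to_standard_types(results)
                            di = di['GenerateTicket']
                            success = True
                            break
                        except Exception:
                            pass
                if success:
                    output_json = di
                    trial.log("output_json", output_json)
                    trial.log("json_dirty", 1)
                else:
                    # Reset so the checks below never score a previous row's JSON
                    output_json = {}
                    trial.log("output_json", output_json)
                    trial.log("json_dirty", 0)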
                

            # Validate JSON schema conformance
            from langchain_benchmarks.extraction.tasks.chat_extraction.schema import GenerateTicket
            try:
                GenerateTicket.parse_obj(output_json)
                trial.log("json_valid", 1)
                if json_success:
                    trial.log("json_valid_strict", 1)
                else:
                    trial.log("json_valid_strict", 0)
            except Exception as e:
                trial.log("json_valid", 0)
                trial.log("json_valid_strict", 0)
                trial.log("json_valid_errmsg", str(e))

            # Toxicity similarity
            expected_json = row['output']['output']
            expected = expected_json['question']['toxicity']
            try:
                pred = output_json["question"]["toxicity"]
                score = 1 - abs(expected - float(pred)) / 5
                trial.log("toxicity", score)
                trial.log("toxicity_strict", score)
            except Exception as e:
                trial.log("toxicity_strict", 0)
                trial.log("toxicity_errmsg", str(e))

            # Sentiment similarity
            expected = expected_json["question"]["sentiment"]
            ordinal_map = {
                "negative": 0,
                "neutral": 1,
                "positive": 2,
            }
            expected_score = ordinal_map.get(str(expected).lower())
            try:
                pred = output_json["question"]["sentiment"]
                pred_score = ordinal_map.get(str(pred).lower())
                score = 1 - (abs(expected_score - float(pred_score)) / 2)
                trial.log("sentiment", score)
                trial.log("sentiment_strict", score)
            except Exception as e:
                trial.log("sentiment_strict", 0)
                trial.log("sentiment_errmsg", str(e))

            # Question category similarity
            expected = expected_json["question"]["question_category"]
            try:
                pred = output_json["question"]["question_category"]
                score = int(expected == pred)
                trial.log("question_cat", score)
                trial.log("question_cat_strict", score)
            except Exception as e:
                trial.log("question_cat_strict", 0)
                trial.log("question_cat_errmsg", str(e))

            # Off-topic similarity
            expected = expected_json["question"]["is_off_topic"]
            try:
                pred = output_json["question"].get("is_off_topic")
                score = int(expected == pred)
                trial.log("offtopic", score)
                trial.log("offtopic_strict", score)
            except Exception as e:
                trial.log("offtopic_strict", 0)
                trial.log("offtopic_errmsg", str(e))

            # Programming language similarity
            expected = expected_json["question"]["programming_language"]
            try:
                pred = output_json["question"]["programming_language"]
                score = int(expected == pred)
                trial.log("programming", score)
                trial.log("programming_strict", score)
            except Exception as e:
                trial.log("programming_strict", 0)
                trial.log("programming_errmsg", str(e))

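Two pieces of post-processing in `trial_runner` are worth isolating. First, the `json_dirty` fallback recovers a `GenerateTicket` object embedded in chatty output; the benchmark does this with `json_stream`, while this stdlib sketch swaps in `json.JSONDecoder.raw_decode`, which parses one JSON value from the start of a string and ignores whatever follows. Second, toxicity and sentiment are scored as normalized ordinal distances, `1 - |expected - predicted| / (levels - 1)`, with 6 toxicity levels (0 to 5) and 3 sentiment levels. `salvage_ticket` and `ordinal_score` are hypothetical names, not part of guidance or powerlift.

```python
import json


def salvage_ticket(text):
    """Return the "GenerateTicket" payload of the first embedded JSON object that has one, else None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses a single JSON value and reports where it ended,
            # so trailing prose after the object is tolerated.
            obj, _end = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict) and "GenerateTicket" in obj:
            return obj["GenerateTicket"]
        start = text.find("{", start + 1)  # try the next candidate start position
    return None


def ordinal_score(expected, predicted, levels):
    """1.0 for an exact match, 0.0 for opposite ends of the ordered `levels`."""
    distance = abs(levels.index(expected) - levels.index(predicted))
    return 1 - distance / (len(levels) - 1)


TOXICITY_LEVELS = [0, 1, 2, 3, 4, 5]
SENTIMENT_LEVELS = ["negative", "neutral", "positive"]  # compared lowercased, as in the notebook

chatty = 'Here is the ticket:\n{"GenerateTicket": {"issue_summary": "x"}}\nHope this helps!'
print(salvage_ticket(chatty))                                 # {'issue_summary': 'x'}
print(ordinal_score(0, 2, TOXICITY_LEVELS))                   # 0.6
print(ordinal_score("negative", "neutral", SENTIMENT_LEVELS))  # 0.5
```

An unknown label makes `levels.index` raise `ValueError`; in the notebook the equivalent failure is caught by the surrounding `try`/`except`, which leaves only the `*_strict` score of 0.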
In [3]:
# Run the benchmark. This is asynchronous so it should return relatively quickly.
# By default, before each method is run on an example, it will print its name and the example index as specified in `trial_runner`.

from powerlift.bench import Benchmark, Store, populate_with_datasets
from powerlift.executors import LocalMachine
from guidance.bench import retrieve_langchain
from pathlib import Path

conn_str = f"sqlite:///{Path(Path.cwd(), 'guidance-bench.db')}"
store = Store(conn_str, force_recreate=False)
populate_with_datasets(store, retrieve_langchain(cache_dir="~/.guidance-bench/cache"), exist_ok=True)
executor = LocalMachine(store, n_cpus=1, debug_mode=False)

bench = Benchmark(store, name="local_lm_chat_extract")
bench.run(trial_runner, trial_filter, timeout=60*60, executor=executor) 

Dataset Chat Extraction already exists. Skipping.
You can access the dataset at https://smith.langchain.com/o/7953c7d1-5ab3-5e87-9b8d-ca40eef5f42d/datasets/cf39229b-5aa6-4ffa-a7a4-effe91894d12.


<powerlift.executors.localmachine.LocalMachine at 0x7f148499f460>

guidance-mistral-7b-instruct[0]
guidance-mistral-7b-instruct[1]
guidance-mistral-7b-instruct[2]
guidance-mistral-7b-instruct[3]
guidance-mistral-7b-instruct[4]
guidance-mistral-7b-instruct[5]
guidance-mistral-7b-instruct[6]
guidance-mistral-7b-instruct[7]
guidance-mistral-7b-instruct[8]
guidance-mistral-7b-instruct[9]
guidance-mistral-7b-instruct[10]
guidance-mistral-7b-instruct[11]
guidance-mistral-7b-instruct[12]
guidance-mistral-7b-instruct[13]
guidance-mistral-7b-instruct[14]
guidance-mistral-7b-instruct[15]
guidance-mistral-7b-instruct[16]
guidance-mistral-7b-instruct[17]
guidance-mistral-7b-instruct[18]
guidance-mistral-7b-instruct[19]
guidance-mistral-7b-instruct[20]
guidance-mistral-7b-instruct[21]
guidance-mistral-7b-instruct[22]
guidance-mistral-7b-instruct[23]
guidance-mistral-7b-instruct[24]
guidance-mistral-7b-instruct[25]
guidance-mistral-7b-instruct[26]
base-mistral-7b-instruct[0]
base-mistral-7b-instruct[1]
base-mistral-7b-instruct[2]
base-mistral-7b-instruct[3]
base-mi

In [4]:
# Check the status of the trial runs. Status column will be set to 'COMPLETE' for all tasks at the end.
bench.status()

|   | trial_id | replicate_num | meta | method | task | status | errmsg | create_time | start_time | end_time |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | {} | guidance-mistral-7b-instruct | chat_extract | READY |  | 2024-05-17 20:45:29.721393 |  |  |
| 1 | 2 | 0 | {} | base-mistral-7b-instruct | chat_extract | READY |  | 2024-05-17 20:45:29.721516 |  |  |
| 2 | 3 | 0 | {} | guidance-phi-3-mini-4k-instruct | chat_extract | READY |  | 2024-05-17 20:45:29.721562 |  |  |
| 3 | 4 | 0 | {} | base-phi-3-mini-4k-instruct | chat_extract | READY |  | 2024-05-17 20:45:29.721599 |  |  |
| 4 | 5 | 0 | {} | guidance-llama2-7b-32k-instruct | chat_extract | READY |  | 2024-05-17 20:45:29.721636 |  |  |
| 5 | 6 | 0 | {} | base-llama2-7b-32k-instruct | chat_extract | READY |  | 2024-05-17 20:45:29.721672 |  |  |


In [6]:
# Wait for the benchmark to complete, then aggregate the results.
# The most important measures are `mean_json` and `mean_token_reduction`.
# `mean_token_reduction` can be negative; this is expected, because token healing complicates how tokens are counted.

bench.wait_until_complete()
result_df = bench.results()
agg_df = result_df.pivot_table(index='method', columns='name', values='num_val', aggfunc=['mean', 'std'])
agg_df.columns = ["_".join(x) for x in agg_df.columns.to_flat_index()]
agg_df

| method | mean_json | mean_json_dirty | mean_json_valid | mean_json_valid_strict | mean_offtopic | mean_offtopic_strict | mean_programming | mean_programming_strict | mean_question_cat | mean_question_cat_strict | ... | std_question_cat_strict | std_sentiment | std_sentiment_strict | std_token_count | std_token_input | std_token_output | std_token_reduction | std_toxicity | std_toxicity_strict | std_wall_time |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| base-llama2-7b-32k-instruct | 0.259259 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | NaN | 0.0 | NaN | 0.0 | ... | 0.0 | NaN | 0.0 | 2.363059 | 440.643091 | 0.0 | 0.001582 | NaN | 0.0 | 6.046848 |
| base-mistral-7b-instruct | 0.0 | 0.814815 | 0.0 | 0.0 | 0.0 | 0.0 | NaN | 0.0 | 0.26087 | 0.222222 | ... | 0.423659 | 0.0 | 0.320256 | 387.916261 | 681.601341 | 387.679932 | 0.004073 | NaN | 0.0 | 10.971967 |
| base-phi-3-mini-4k-instruct | 0.0 | 0.148148 | 0.0 | 0.0 | 0.0 | 0.0 | NaN | 0.0 | NaN | 0.0 | ... | 0.0 | NaN | 0.0 | 58.3146 | 438.717735 | 0.0 | 0.048805 | NaN | 0.0 | 2.876374 |
| guidance-llama2-7b-32k-instruct | 1.0 | 0.0 | 1.0 | 1.0 | 0.888889 | 0.888889 | 0.296296 | 0.296296 | 0.074074 | 0.074074 | ... | 0.26688 | 0.0 | 0.0 | 191.509285 | 465.816197 | 191.31116 | 0.141899 | 0.0 | 0.0 | 11.480775 |
| guidance-mistral-7b-instruct | 1.0 | 0.0 | 1.0 | 1.0 | 0.888889 | 0.888889 | 0.62963 | 0.62963 | 0.111111 | 0.111111 | ... | 0.320256 | 0.13344 | 0.13344 | 5.995725 | 442.744802 | 4.798504 | 0.025946 | 0.0 | 0.0 | 0.577206 |
| guidance-phi-3-mini-4k-instruct | 1.0 | 0.0 | 1.0 | 1.0 | 0.851852 | 0.851852 | 0.259259 | 0.259259 | 0.074074 | 0.074074 | ... | 0.26688 | 0.197924 | 0.197924 | 15.927542 | 438.253725 | 11.748431 | 0.052191 | 0.03849 | 0.03849 | 0.356682 |

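The `pivot_table` call above groups the long-form results by `method` and metric `name`, then aggregates `num_val` with a mean and a standard deviation; pandas' `std` defaults to the sample standard deviation (`ddof=1`), which is undefined for a single value. A stdlib sketch of the same aggregation, assuming the results arrive as `(method, name, num_val)` tuples with `None` for non-numeric logs such as `output` (`aggregate` and the sample method names are hypothetical):

```python
import math
import statistics
from collections import defaultdict


def aggregate(records):
    """Map (method, name) to (mean, sample std) of the numeric values, like the pivot above."""
    groups = defaultdict(list)
    for method, name, num_val in records:
        if num_val is not None:  # assumed: non-numeric logs such as "output" carry no num_val
            groups[(method, name)].append(num_val)
    summary = {}
    for key, values in groups.items():
        # ddof=1, as in pandas' default std; undefined (NaN) for a single value
        std = statistics.stdev(values) if len(values) > 1 else math.nan
        summary[key] = (statistics.fmean(values), std)
    return summary


records = [
    ("guidance-x", "json", 1), ("guidance-x", "json", 1),
    ("base-x", "json", 0), ("base-x", "json", 1),
    ("base-x", "output", None),
]
for (method, name), (mean, std) in sorted(aggregate(records).items()):
    print(f"{method:12} mean_{name}={mean:.3f} std_{name}={std:.3f}")
```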