<a href="https://colab.research.google.com/github/franlin1860/llm/blob/main/reflection_workflow_v20240820.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Reflection Workflow for Structured Outputs

This notebook walks through setting up a `Workflow` to provide reliable structured outputs through retries and reflection on mistakes.



# Prevent disconnection

In [None]:
#@markdown <h3>← Run after entering the code to prevent disconnection</h3>
from IPython.display import Javascript, display

# Every 60 s, try to click Colab's "Connect" button (or the OK button of the
# reconnect dialog) if either is present, so an idle session is not dropped.
display(Javascript('''
function ClickConnect() {
  let btn = document.querySelector("colab-connect-button");
  if (btn != null) {
    console.log("Click colab-connect-button");
    btn.click();
  }

  btn = document.getElementById("ok");
  if (btn != null) {
    console.log("Click reconnect");
    btn.click();
  }
}

setInterval(ClickConnect, 60000);
'''))

print("Done.")

<IPython.core.display.Javascript object>

Done.


In [None]:
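// This cell is JavaScript, not Python: a Python cell cannot execute it, so
// run it from the browser's developer console instead.
function ConnectButton() {
  console.log("Connect pushed");
  document.querySelector("#connect").click();
}
setInterval(ConnectButton, 60000);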

In [None]:
!pip install -U llama-index

Collecting llama-index
  Downloading llama_index-0.10.67.post1-py3-none-any.whl.metadata (11 kB)
Collecting llama-index-agent-openai<0.3.0,>=0.1.4 (from llama-index)
  Downloading llama_index_agent_openai-0.2.9-py3-none-any.whl.metadata (729 bytes)
Collecting llama-index-cli<0.2.0,>=0.1.2 (from llama-index)
  Downloading llama_index_cli-0.1.13-py3-none-any.whl.metadata (1.5 kB)
Collecting llama-index-core<0.11.0,>=0.10.67 (from llama-index)
  Downloading llama_index_core-0.10.67-py3-none-any.whl.metadata (2.4 kB)
Collecting llama-index-embeddings-openai<0.2.0,>=0.1.5 (from llama-index)
  Downloading llama_index_embeddings_openai-0.1.11-py3-none-any.whl.metadata (655 bytes)
Collecting llama-index-indices-managed-llama-cloud>=0.2.0 (from llama-index)
  Downloading llama_index_indices_managed_llama_cloud-0.2.7-py3-none-any.whl.metadata (3.8 kB)
Collecting llama-index-legacy<0.10.0,>=0.9.48 (from llama-index)
  Downloading llama_index_legacy-0.9.48.post3-py3-none-any.whl.metadata (8.5 kB)


Since workflows are async-first, this all runs fine in a notebook, where an event loop is already running and top-level `await` works. In your own scripts, use `asyncio.run()` to start an event loop if one isn't already running:

```python
import asyncio


async def main():
    # Your async code goes here, e.g. `result = await workflow.run(...)`.
    ...


if __name__ == "__main__":
    asyncio.run(main())
```
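The distinction matters because `asyncio.run()` creates a fresh loop and raises an error if called from a thread where a loop is already running (as in a notebook). The sketch below wraps that check in a hypothetical helper, `run_sync` (not part of LlamaIndex); `fake_extract` is a stand-in for a real `await workflow.run(...)` call.

```python
import asyncio
from typing import Any, Coroutine, TypeVar

T = TypeVar("T")


async def fake_extract(passage: str) -> str:
    # Hypothetical stand-in for `await workflow.run(passage=...)`.
    await asyncio.sleep(0)
    return f"extracted from: {passage}"


def run_sync(coro: Coroutine[Any, Any, T]) -> T:
    """Run `coro` to completion from synchronous code (hypothetical helper)."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread (plain script): asyncio.run creates one,
        # runs the coroutine, and closes the loop afterwards.
        return asyncio.run(coro)
    # A loop is already running (e.g. in a notebook): asyncio.run would fail,
    # so close the unused coroutine and ask the caller to `await` it instead.
    coro.close()
    raise RuntimeError("An event loop is already running; use `await` instead.")


if __name__ == "__main__":
    print(run_sync(fake_extract("two cars")))  # extracted from: two cars
```

In a notebook you would skip the helper and simply write `await fake_extract(...)`.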

# Set up the LLM environment

In [None]:
import os

# Placeholder: replace with your own DeepSeek API key.
os.environ["DEEPSEEK_API_KEY"] = "sk-"

In [None]:
!pip install llama_index-llms-openai_like
!pip install llama_index-embeddings-huggingface

Collecting llama_index-llms-openai_like
  Downloading llama_index_llms_openai_like-0.1.3-py3-none-any.whl.metadata (753 bytes)
Downloading llama_index_llms_openai_like-0.1.3-py3-none-any.whl (3.0 kB)
Installing collected packages: llama_index-llms-openai_like
Successfully installed llama_index-llms-openai_like-0.1.3
Collecting llama_index-embeddings-huggingface
  Downloading llama_index_embeddings_huggingface-0.2.3-py3-none-any.whl.metadata (769 bytes)
Collecting sentence-transformers>=2.6.1 (from llama_index-embeddings-huggingface)
  Downloading sentence_transformers-3.0.1-py3-none-any.whl.metadata (10 kB)
Collecting minijinja>=1.0 (from huggingface-hub[inference]>=0.19.0->llama_index-embeddings-huggingface)
  Downloading minijinja-2.0.1-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.11.0->sentence-transformers>=2.6.1->llama_index-embeddings-huggingface)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.10

In [None]:
import os
import logging
import sys
from llama_index.llms.openai_like import OpenAILike
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Configure logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Define the DeepSeek model through its OpenAI-compatible API
llm = OpenAILike(
    model="deepseek-chat",
    api_base="https://api.deepseek.com/v1",
    api_key=os.environ["DEEPSEEK_API_KEY"],
    temperature=0.6,
    is_chat_model=True,
)

# Make it the global default LLM
Settings.llm = llm

# Set the embedding model. Global defaults go on `Settings`, which replaces
# the deprecated `ServiceContext`.
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-small-zh-v1.5")
Settings.embed_model = embed_model
Settings.chunk_size = 256

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.




## Designing the Workflow

To validate the structured output of an LLM, we need only two steps:
1. Generate the structured output
2. Validate that the output is proper JSON

The key thing here is that, if the output is invalid, we **loop** until it is valid, feeding the error back into the next generation.
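Stripped of any workflow machinery, the loop is: generate, try to parse, and on failure pass the bad output and the parser error into the next attempt. This is a minimal synchronous sketch; `reflect_until_valid` and the scripted `fake_llm` (one broken reply, then a fixed one) are hypothetical stand-ins for the workflow steps and a real model.

```python
import json
from typing import Callable, Optional


def reflect_until_valid(
    generate: Callable[[Optional[str]], str], max_retries: int = 3
) -> Optional[dict]:
    """Call `generate` until it returns valid JSON or retries run out.

    `generate` receives None on the first attempt and, on later attempts,
    a feedback string describing the previous failure.
    """
    feedback = None
    for _ in range(max_retries):
        output = generate(feedback)
        try:
            return json.loads(output)
        except json.JSONDecodeError as e:
            # Feed both the bad output and the parser error back to the model.
            feedback = f"Previous output:\n{output}\nJSON decode error: {e}"
    return None  # plays the role of "Max retries reached"


# Hypothetical scripted model: the first reply has a trailing comma.
replies = iter([
    '{"cars": [{"brand": "Fiat", "model": "Panda", "power": 45},]}',
    '{"cars": [{"brand": "Fiat", "model": "Panda", "power": 45}]}',
])
prompts_seen = []


def fake_llm(feedback):
    prompts_seen.append(feedback)
    return next(replies)


result = reflect_until_valid(fake_llm)
print(result["cars"][0]["model"])  # Panda
print(prompts_seen[1] is not None)  # True: the retry received feedback
```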

### The Workflow Events

To handle these steps, we need to define a few events:
1. An event to pass on the generated extraction
2. An event to give feedback when the extraction is invalid

The workflow's entry and exit points use the built-in `StartEvent` and `StopEvent` events.

In [18]:
from llama_index.core.workflow import Event


class ExtractionDone(Event):
    output: str
    passage: str


class ValidationErrorEvent(Event):
    error: str
    wrong_output: str
    passage: str

### Item to Extract

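To prompt our model, let's define the Pydantic model we want to extract.

Pydantic is a Python library for data parsing and validation. It defines data models through type annotations and provides automatic data validation and serialization.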

In [19]:
from pydantic import BaseModel


class Car(BaseModel):
    brand: str
    model: str
    power: int


class CarCollection(BaseModel):
    cars: list[Car]

### The Workflow Itself

With our events defined, we can construct our workflow and steps.

Note that the workflow validates itself using the type annotations on each step: they declare which events a step accepts and which events it can emit, so getting them right matters!

In [20]:
import json


from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)

EXTRACTION_PROMPT = """
Context information is below:
---------------------
{passage}
---------------------

Given the context information and not prior knowledge, create a JSON object from the information in the context.
The JSON object must follow the JSON schema:
{schema}

"""

REFLECTION_PROMPT = """
You already created this output previously:
---------------------
{wrong_answer}
---------------------

This caused the JSON decode error: {error}

Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.
Do not repeat the schema.
"""


class ReflectionWorkflow(Workflow):
    max_retries: int = 3

    @step(pass_context=True)
    async def extract(
        self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        current_retries = ctx.data.get("retries", 0)
        if current_retries >= self.max_retries:
            return StopEvent(result="Max retries reached")
        else:
            ctx.data["retries"] = current_retries + 1

        if isinstance(ev, StartEvent):
            passage = ev.get("passage")
            if not passage:
                return StopEvent(result="Please provide some text in input")
            reflection_prompt = ""
        elif isinstance(ev, ValidationErrorEvent):
            passage = ev.passage
            reflection_prompt = REFLECTION_PROMPT.format(
                wrong_answer=ev.wrong_output, error=ev.error
            )

        llm = Settings.llm
        prompt = EXTRACTION_PROMPT.format(
            passage=passage, schema=CarCollection.schema_json()
        )
        if reflection_prompt:
            prompt += reflection_prompt

        output = await llm.acomplete(prompt)

        print("LLM output:", output)  # Print the raw LLM output

        return ExtractionDone(output=str(output), passage=passage)

    @step()
    async def validate(
        self, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        try:
            json.loads(ev.output)
        except Exception as e:
            print("Validation failed, retrying...")
            return ValidationErrorEvent(
                error=str(e), wrong_output=ev.output, passage=ev.passage
            )

        return StopEvent(result=ev.output)

And that's it! Let's explore the workflow we wrote a bit.

- We have one entry point, `extract` (the step that accepts `StartEvent`)
- When `extract` finishes, it emits an `ExtractionDone` event
- `validate` runs and checks the extraction:
  - If it's OK, it emits `StopEvent` and halts the workflow
  - If not, it returns a `ValidationErrorEvent` with information about the error
- Any `ValidationErrorEvent` emitted will trigger the loop, and `extract` runs again!
- This continues until the structured output is valid, or until `max_retries` attempts have been made, in which case `extract` stops with "Max retries reached"

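The model may wrap its JSON in a Markdown code fence (it does in the run below), and `json.loads` rejects the raw reply in that case. The revised cell below therefore tightens the extraction prompt and changes `validate` to pull the `{...}` object out of the reply with a regular expression before parsing it.

In [21]: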
import json
import re


from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)


EXTRACTION_PROMPT = """
Context information is below:
---------------------
{passage}
---------------------

Given the context information and not prior knowledge, extract the relevant information and create a JSON object.
The JSON object must follow the JSON schema below, and should include a list of items with their corresponding attributes.
Do not include the schema in your answer, only the JSON object.

JSON schema:
{schema}
"""


REFLECTION_PROMPT = """
You already created this output previously:
---------------------
{wrong_answer}
---------------------

This caused the JSON decode error: {error}

Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.
Do not repeat the schema.
"""


class ReflectionWorkflow(Workflow):
    max_retries: int = 3

    @step(pass_context=True)
    async def extract(
        self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        current_retries = ctx.data.get("retries", 0)
        if current_retries >= self.max_retries:
            return StopEvent(result="Max retries reached")
        else:
            ctx.data["retries"] = current_retries + 1

        if isinstance(ev, StartEvent):
            passage = ev.get("passage")
            if not passage:
                return StopEvent(result="Please provide some text in input")
            reflection_prompt = ""
        elif isinstance(ev, ValidationErrorEvent):
            passage = ev.passage
            reflection_prompt = REFLECTION_PROMPT.format(
                wrong_answer=ev.wrong_output, error=ev.error
            )

        llm = Settings.llm
        prompt = EXTRACTION_PROMPT.format(
            passage=passage, schema=CarCollection.schema_json()
        )
        if reflection_prompt:
            prompt += reflection_prompt

        output = await llm.acomplete(prompt)

        print("LLM output:", output)  # Print the raw LLM output

        return ExtractionDone(output=str(output), passage=passage)

    @step()
    async def validate(
        self, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        # Use a regular expression to extract the JSON object
        # (everything from the first "{" to the last "}")
        match = re.search(r'\{.*\}', ev.output, flags=re.DOTALL)
        if match:
            json_string = match.group(0)
            try:
                json.loads(json_string)
                return StopEvent(result=json_string)
            except Exception as e:
                print("Validation failed, retrying...")
                return ValidationErrorEvent(
                    error=str(e), wrong_output=ev.output, passage=ev.passage
                )
        else:
            print("Validation failed, no JSON object found, retrying...")
            return ValidationErrorEvent(
                error="No JSON object found", wrong_output=ev.output, passage=ev.passage
            )
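The revised `validate` step's extraction logic can be checked in isolation. This sketch copies its greedy, `re.DOTALL` pattern into a hypothetical `extract_json` helper and feeds it a reply wrapped in a Markdown fence. One caveat of this design: the greedy match runs from the first `{` to the last `}`, so stray braces in surrounding prose still produce invalid JSON, which simply triggers another retry.

```python
import json
import re
from typing import Optional

# Same pattern as the revised `validate` step: first "{" through last "}".
JSON_OBJECT = re.compile(r"\{.*\}", flags=re.DOTALL)


def extract_json(output: str) -> Optional[dict]:
    """Return the parsed JSON object embedded in `output`, or None."""
    match = JSON_OBJECT.search(output)
    if match is None:
        return None  # the step reports "No JSON object found"
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None  # the step reports the decode error and retries


fenced = '```json\n{"cars": [{"brand": "Honda", "model": "Civic", "power": 330}]}\n```'
print(extract_json(fenced)["cars"][0]["brand"])  # Honda

# Plain json.loads rejects the same reply because of the fence.
try:
    json.loads(fenced)
except json.JSONDecodeError as e:
    print("json.loads failed:", e.msg)
```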


## Run the Workflow!

**NOTE:** With loops, we need to be mindful of runtime. Here, we set a timeout of 300 seconds.
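A timeout bounds the total time a run may take, regardless of how many retries happen inside it. As an analogy in plain `asyncio` (not how LlamaIndex implements its `timeout`), `asyncio.wait_for` cancels a coroutine that runs too long. Here `slow_loop` and `run_with_timeout` are hypothetical names, with `slow_loop` standing in for a run that keeps retrying.

```python
import asyncio


async def slow_loop(delay: float, attempts: int) -> str:
    # Hypothetical stand-in for a workflow that keeps retrying.
    for _ in range(attempts):
        await asyncio.sleep(delay)
    return "done"


async def run_with_timeout(timeout: float, delay: float, attempts: int) -> str:
    try:
        return await asyncio.wait_for(slow_loop(delay, attempts), timeout=timeout)
    except asyncio.TimeoutError:
        # wait_for cancelled slow_loop because it exceeded `timeout` seconds.
        return "timed out"


print(asyncio.run(run_with_timeout(timeout=1.0, delay=0.01, attempts=3)))  # done
print(asyncio.run(run_with_timeout(timeout=0.05, delay=0.1, attempts=3)))  # timed out
```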

In [22]:
w = ReflectionWorkflow(timeout=300, verbose=True)

# Run the workflow
ret = await w.run(
    passage="I own two cars: a Fiat Panda with 45Hp and a Honda Civic with 330Hp."
)

Running step extract
LLM output: ```json
{
  "cars": [
    {
      "brand": "Fiat",
      "model": "Panda",
      "power": 45
    },
    {
      "brand": "Honda",
      "model": "Civic",
      "power": 330
    }
  ]
}
```
Step extract produced event ExtractionDone
Running step validate
Step validate produced event StopEvent


In [23]:
print(ret)

{
  "cars": [
    {
      "brand": "Fiat",
      "model": "Panda",
      "power": 45
    },
    {
      "brand": "Honda",
      "model": "Civic",
      "power": 330
    }
  ]
}
