# Few-Shot Prompting for Data Extraction

Few-shot prompting refers to a set of techniques to control what an LLM produces by giving examples
of what it should produce. It's typically used to adapt an LLM to a domain, coerce it to follow a certain
style, or to improve the format of what it produces.

In this notebook, we'll use LangSmith to evaluate a couple few-shot prompting techniques.

The basic steps are:
- Create a "training" and evaluation dataset
- Benchmark a baseline instruction model on the eval dataset
- Create a few-shot prompt and benchmark
- Iterate and improve

Tl;dr - we find random sampling to be the best naive method without meta-prompting or example selection conditioned
on the result.

There are some great prompt tuning resources out there. [Lilian Weng's blog](https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/) is always a good place to start.

#### Create Datasets

We've got some samples from the [REBEL](https://github.com/Babelscape/rebel/tree/main) relationship extraction dataset we will use to develop our model. First, upload them to your langsmith organization.

In [2]:
from langsmith import Client

client = Client()

In [3]:
for split in ['train', 'validation', 'test']:
    name = f"Rebel-linearized-{split}"
    client.upload_csv(
        f"data/{name}.csv",
          input_keys=["context"],
          output_keys=["triplets"],
          name=name + "foo"
)

In [4]:
example_format = next(client.list_examples(dataset_name="Rebel-linearized-train"))
print(example_format.inputs)
print(example_format.outputs)

{'context': 'The feature appears in U . S . Navy aerial photographs taken in the 1960s and in imagery obtained by the NASA Earth Resources Technology Satellite ( ERTS-1 ) , 1973–74 . '}
{'triplets': '<triplet> Earth Resources Technology Satellite <subj> NASA <obj> operator'}


## Define an evaluator

Next we'll define a custom evaluator. We want to be permutation invariant and not be super strict on the exact format of each of the tail entity values, since it's hard to get completely open information extraction systems (and humans) to agree on the "ground truth".

We'll write a custom evaluator here.

In [5]:
from typing import Any, Optional

from langchain.evaluation import StringEvaluator
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import openai_functions

import json

eval_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an impartial grader tasked with measuring the accuracy of extracted entity relations."),
        ("human", "Please evaluate the following data:\n\n"
         "<INPUT>\n{input}</INPUT>\n"
         "<PREDICTED>\n{prediction}</PREDICTED>\n"
         "<GROUND_TRUTH>\n{reference}</GROUND_TRUTH>\n\n"
         "Please save your reasoning and grading by calling the commit_grade function."
         " First, enumerate all factual discrepancies in the predicted triplets relative to the ground truth."
         " Finally, score the prediction on a scale out of 100, taking into account factuality and"
         " correctness according to the ground truth."),
         
    ]
)

commit_grade_schema = {
    "name": "commit_grade",
    "description": "Commits a grade with reasoning.",
    "parameters": {
        "title": "commit_grade_parameters",
        "description": "Parameters for the commit_grade function.",
        "type": "object",
        "properties": {
            "mistakes": {
                "title": "discrepancies",
                "type": "string",
                "description": "Any discrepencies between the predicted and ground truth."
            },
            "reasoning": {
                "title": "reasoning",
                "type": "string",
                "description": "The explanation or logic behind the final grade."
            },
            "grade": {
                "title": "grade",
                "type": "number",
                "description": "The numerical value representing the grade.",
                "minimum": 0,
                "maximum": 100
            }
        },
        "required": ["reasoning", "grade", "mistakes"],
    }
}


def normalize_grade(func_args: str) -> dict:
    args = json.loads(func_args)
    return {
        "reasoning": (args.get("reasoning", "") + "\n\n" + args.get("discrepancies", "")).strip(),
        "score": args.get("grade", 0) / 100,
    }

eval_chain = (
    eval_prompt 
    | ChatOpenAI(model="gpt-4", temperature=0).bind(functions=[commit_grade_schema])
    | openai_functions.OutputFunctionsParser() 
    | normalize_grade                             
)

class EvaluateTriplets(StringEvaluator):
    """Evaluate the triplets of a predicted string."""
    
    @property
    def requires_input(self) -> bool:
        return True
    
    @property
    def requires_reference(self) -> bool:
        return True

    def _evaluate_strings(
        self,
        *,
        prediction: str,
        reference: Optional[str] = None,
        input: Optional[str] = None,
        **kwargs: Any,
    ) -> dict:
        callbacks = kwargs.pop("callbacks", None)
        return eval_chain.invoke(
            {"prediction": prediction, "reference": reference, "input": input}, 
            {"callbacks": callbacks},
        )

#### Define Baseline Chain

We first want to get a baseline measurement of an extractor that only relies on instructions.
We do so below.

In [6]:
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

In [7]:
# We will focus on an instructional prompt based on the format
# description of the dataset.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an autoregressive information extraction agent."),
        ("human", "Extract knowledge triplets from the following text:\n"
        "<TEXT>\n{context}\n<\/TEXT>"),
        ("system", "Output should be in linearized format.<triplet> marks the start of a new triplet with"
        "a new head entity, followed by the surface form"
        "of that entity in the input text. <subj> marks"
        "the end of the head entity and the start of the tail"
        "entity surface form. <obj> marks the end of the"
        "tail entity and the start of the relation between the"
        "head and tail entity, in its surface form"),
    ]
)

chain = prompt | ChatOpenAI() | StrOutputParser()

In [8]:
validation_dataset_name = "Rebel-linearized-validation"

In [9]:
from langchain.smith import RunEvalConfig

config = RunEvalConfig(
    custom_evaluators=[EvaluateTriplets()],
)

In [10]:
_ = await client.arun_on_dataset(validation_dataset_name, chain, evaluation=config)

View the evaluation results for project 'e9369e73b2e94d218aecf06a51af866a-RunnableSequence' at:
https://smith.langchain.com/projects/p/a0d94242-89d1-42af-b851-4fa6e8a3795b?eval=true


Retrying langchain.chat_models.openai.acompletion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised APIError: Bad gateway. {"error":{"code":502,"message":"Bad gateway.","param":null,"type":"cf_bad_gateway"}} 502 {'error': {'code': 502, 'message': 'Bad gateway.', 'param': None, 'type': 'cf_bad_gateway'}} <CIMultiDictProxy('Date': 'Fri, 18 Aug 2023 21:45:08 GMT', 'Content-Type': 'application/json', 'Content-Length': '84', 'Connection': 'keep-alive', 'X-Frame-Options': 'SAMEORIGIN', 'Referrer-Policy': 'same-origin', 'Cache-Control': 'private, max-age=0, no-store, no-cache, must-revalidate, post-check=0, pre-check=0', 'Expires': 'Thu, 01 Jan 1970 00:00:01 GMT', 'Server': 'cloudflare', 'CF-RAY': '7f8d48edbd8dceb5-SJC', 'alt-svc': 'h3=":443"; ma=86400')>.


## Few-Shot Examples

We'll first try with some static examples.


In [71]:
from langchain.prompts import FewShotChatMessagePromptTemplate

def create_few_shot_prompt(examples):
    formatted_examples = [{**ex.inputs, **ex.outputs} for ex in examples]
    # This is a prompt template used to format each individual example.
    example_prompt = ChatPromptTemplate.from_messages(
        [
            ("human", "Extract knowledge triplets from the following text:"
            "\n<TEXT>\n{context}</TEXT>"),
            ("ai", "{triplets}"),
        ]
    )
    return FewShotChatMessagePromptTemplate(
        example_prompt=example_prompt,
        examples=formatted_examples,
    )

def create_chain_from_examples(examples):
    few_shot_prompt = create_few_shot_prompt(examples)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an autoregressive information extraction agent."),
            few_shot_prompt,
            ("human", "Extract knowledge triplets from the following text:\n"
            "<TEXT>\n{context}\n<\/TEXT>"),
            ("system", "Output should be in linearized format.<triplet> marks the start of a new triplet with"
        "a new head entity, followed by the surface form"
        "of that entity in the input text. <subj> marks"
        "the end of the head entity and the start of the tail"
        "entity surface form. <obj> marks the end of the"
        "tail entity and the start of the relation between the"
        "head and tail entity, in its surface form"),
        ]
    )
    return prompt | ChatOpenAI() | StrOutputParser()

In [73]:
import random
random.seed(42)

train_dataset_name = "Rebel-linearized-train"

K = 5
examples = list(client.list_examples(dataset_name=train_dataset_name))

# Randomly select K examples
random.shuffle(examples)
examples_head = examples[:K]
chain_2 = create_chain_from_examples(examples_head)

In [14]:
chain_2_results = await client.arun_on_dataset(validation_dataset_name, chain_2, evaluation=config)

View the evaluation results for project 'f091482618c54db4a0726b98067e9661-RunnableSequence' at:
https://smith.langchain.com/projects/p/441695f7-5a8a-4011-b4ec-297a6e6fd13a?eval=true


That increases the score by a bit. It seems to help! Would other static techniques help?

### Select hardest examples

You may hypothesize that selecting the "hardest" examples may be a good idea, since it will help the model "learn more" from them.
Hard examples carry more information, right?

We don't want to unfairly skew our results by selecting examples from the dev set, so we'll first go and score the training set.

In [15]:
# We will run this only on the training set just to score the outputs
training_results = await client.arun_on_dataset(train_dataset_name, chain_2, evaluation=config)

View the evaluation results for project '10a8199763ed4212ad05376cb90a6373-RunnableSequence' at:
https://smith.langchain.com/projects/p/6a28a8bf-733f-43ab-b400-5194a87a9722?eval=true


In [51]:
runs = list(client.list_runs(
    filter='and(eq(feedback_key, "EvaluateTriplets"), lt(feedback_score, 0.1))',
    project_name=training_results["project_name"]
))
example_ids = {r.reference_example_id for r in runs}
# examples = [client.list_examples(example_ids=[r.reference_example_id for r in runs])]
examples = [e for e in client.list_examples(dataset_id=client.read_example(example_id=runs[0].reference_example_id).dataset_id)
 if e.id in example_ids]

In [69]:
chain_3 = create_chain_from_examples(examples[:K])

In [74]:
results = await client.arun_on_dataset(validation_dataset_name, chain_3, evaluation=config)

View the evaluation results for project 'a37aafc4a00a49c0bbf3386a2834c029-RunnableSequence' at:
https://smith.langchain.com/projects/p/5c121572-a15b-48b0-8ae4-9b637fb9bd66?eval=true


This actually **decreased** the score. You can look at the selected examples to see why that may be the case.
In our case, they are all labels that seem to leave out information. It's poorly labeled!

### Select the "easy" examples

You may then think to filter the other way around: select the "easy" examples. 

In [94]:
runs = list(client.list_runs(
    filter='and(eq(feedback_key, "EvaluateTriplets"), gte(feedback_score, 0.7))',
    project_name=training_results["project_name"]
))
example_ids = {r.reference_example_id for r in runs}
# examples = [client.list_examples(example_ids=[r.reference_example_id for r in runs])]
examples = [e for e in client.list_examples(dataset_id=client.read_example(example_id=runs[0].reference_example_id).dataset_id)
 if e.id in example_ids]

In [96]:
chain_4 = create_chain_from_examples(examples[:K])

In [97]:
results = await client.arun_on_dataset(validation_dataset_name, chain_3, evaluation=config)

View the evaluation results for project '567387171ca943e2bce7d6f08889e768-RunnableSequence' at:
https://smith.langchain.com/projects/p/d482c156-2d80-4baf-b37c-fdcdefa98755?eval=true


Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised Timeout: Request timed out: HTTPSConnectionPool(host='api.openai.com', port=443): Read timed out. (read timeout=600).


This performs even worse!