# Introduction

We'll be using OpenAI's embeddings for this course, so make sure that you've set an `OPENAI_API_KEY` environment variable in your shell so that the commands run out of the box.

Before starting this part, make sure that you have run the `setup.py` file so that we have a LanceDB database populated with the first 1000 entries of the ms-marco dataset. Depending on your internet connection, this might take a while, so make sure that you have completed this step before the workshop.

# N Levels of Complexity for RAG

Instead of evaluating your pipelines with a collection of randomly sampled questions, we'll show you today how to quantitatively evaluate your pipeline using a set of metrics that are easy to collect.

In [21]:
from lib.query import get_test_queries,get_ms_marco_table,get_k_relevant_chunk_ids
from lib.eval import score
from tqdm import tqdm
import pandas as pd

test_queries = get_test_queries()
table = get_ms_marco_table()

scores = []
for query in tqdm(test_queries):
    retrieved_chunk_ids = get_k_relevant_chunk_ids(table,query['query'],25)
    evaluation_metrics = score(retrieved_chunk_ids,query['selected_chunk_id'])
    scores.append(evaluation_metrics)

df = pd.DataFrame(scores)
df.mean()

100%|██████████| 196/196 [00:04<00:00, 46.05it/s]


mrr@3        0.352857
mrr@5        0.399541
mrr@10       0.426036
mrr@15       0.428352
mrr@25       0.429378
recall@3     0.515306
recall@5     0.719388
recall@10    0.903061
recall@15    0.933673
recall@25    0.954082
dtype: float64

# LanceDB

In this section, we'll show how you can create a LanceDB database, define a new table based on a Pydantic schema, ingest some data, and perform semantic search and full text search in under 40 lines of code.

In [22]:
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import numpy as np

# Connect to the database
db = lancedb.connect("./lance-db")

# Configure our Database Schema
func = get_registry().get("openai").create(name="text-embedding-3-small")
class Entry(LanceModel):
    vector: Vector(func.ndims()) = func.VectorField()
    text: str = func.SourceField()

table = db.create_table("sample_table",schema=Entry,mode="overwrite")

# Ingest data into our database
sample_data = [
    "The Capital of France is Paris",
    "How long do you need for sydney and surrounding areas",
    "Twitter is a popular web application"
]
table.add([{"text":item} for item in sample_data])
table.create_fts_index("text",replace=True)

# Vector search with a random query vector (the ranking is arbitrary)
results = table.search(np.random.random((1536))) \
    .limit(10) \
    .to_list()

for result in results:
    print(f"(Semantic) text: {result['text']}, vector: {result['vector'][:2]}, distance: {round(result['_distance'],3)}\n")

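# Search with a text query. Because the table has an embedding function, the string is
# embedded and matched by vector similarity (note the `_distance` field in the output);
# pass query_type="fts" to run a full text search against the index instead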
results = table.search("What's a good place to visit in France?") \
    .limit(10) \
    .to_list()

for result in results:
    print(f"(Full Text Search) text: {result['text']}, vector: {result['vector'][:2]}, distance: {round(result['_distance'],3)}\n")

[2024-05-28T03:11:04Z WARN  lance::dataset] No existing dataset at /Users/ivanleo/Documents/rag-ws/notebooks/lance-db/sample_table.lance, it will be created


(Semantic) text: The Capital of France is Paris, vector: [0.02606566995382309, 0.020613687112927437], distance: 509.274

(Semantic) text: Twitter is a popular web application, vector: [0.004940158687531948, -0.047189969569444656], distance: 511.712

(Semantic) text: How long do you need for sydney and surrounding areas, vector: [0.0089243920519948, 0.022282233461737633], distance: 511.903

(Full Text Search) text: The Capital of France is Paris, vector: [0.02606566995382309, 0.020613687112927437], distance: 0.892

(Full Text Search) text: How long do you need for sydney and surrounding areas, vector: [0.0089243920519948, 0.022282233461737633], distance: 1.605

(Full Text Search) text: Twitter is a popular web application, vector: [0.004940158687531948, -0.047189969569444656], distance: 1.806
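The distances for the random query (around 510) are far larger than those for the text query (around 1). This follows from vector norms rather than relevance. Here is a minimal sketch of why, assuming that `_distance` is a squared Euclidean (L2) distance and that the stored embeddings have unit length. Both are assumptions about the setup above, and the vectors below are synthetic stand-ins rather than real embeddings:

```python
import math
import random

DIM = 1536  # dimensionality used in the example above


def squared_l2(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def unit_vector(rng, dim):
    """Synthetic stand-in for an embedding: a random direction scaled to length 1."""
    raw = [rng.gauss(0, 1) for _ in range(dim)]
    norm = math.sqrt(sum(x * x for x in raw))
    return [x / norm for x in raw]


rng = random.Random(0)
stored = unit_vector(rng, DIM)                     # plays the role of a stored embedding
random_query = [rng.random() for _ in range(DIM)]  # like np.random.random(1536)
embedded_query = unit_vector(rng, DIM)             # plays the role of an embedded text query

# Each entry of the random query is uniform in [0, 1), so its squared norm is
# roughly DIM / 3 = 512, which dominates the distance to any unit-length vector.
random_distance = squared_l2(random_query, stored)

# Two unit-length vectors are never more than 4 apart in squared distance.
embedded_distance = squared_l2(embedded_query, stored)

print(round(random_distance, 1), round(embedded_distance, 3))
```

Under these assumptions, a random query vector tells us little about ranking quality; only the distances from an embedded query are comparable.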



## Exercises

Now that we've seen how easy it is to get started with LanceDB, we can focus on some of the most common questions that come up when ingesting data:

1. How do I store metadata?
2. How can I compute derived fields from my text chunks?
3. How does filtering of results work?
4. How can I deduplicate my data so that I don't store the same chunk of text twice?

In the next few examples, we'll show you how to perform these basic operations.

### Setup

Run these cells before proceeding with the exercises. The exercises should be done in order.

In [23]:
import json

with open("./tools.json", "r") as file:
    data = json.loads(file.read())

data

In [24]:
data

[{'name': 'Hammer',
  'description': 'A tool with a heavy metal head mounted at right angles at the end of a handle, used for jobs such as breaking things and driving in nails.',
  'category': 'Hand Tool'},
 {'name': 'Screwdriver',
  'description': 'A tool with a flattened, cross-shaped, or star-shaped tip that fits into the head of a screw to turn it.',
  'category': 'Hand Tool'},
 {'name': 'Electric Drill',
  'description': 'A power tool fitted with a cutting tool attachment or driving tool attachment, usually a drill bit or driver bit, used for boring holes in various materials or fastening various materials together.',
  'category': 'Power Tool'},
 {'name': 'Wrench',
  'description': 'A tool used for gripping and turning nuts, bolts, pipes, etc.',
  'category': 'Hand Tool'},
 {'name': 'Pliers',
  'description': 'A hand tool used to hold objects firmly, possibly developed from tongs used to handle hot metal in Bronze Age Europe.',
  'category': 'Hand Tool'},
 {'name': 'Circular Saw'

In [25]:
db = lancedb.connect("./lance-db")

# Delete all tables in db
for table in db.table_names():
    db.drop_table(table)

### Adding metadata

Try to create a table that stores the name, description and category of each item. Make sure to also embed the description using the OpenAI embedding function with the `text-embedding-3-small` model.

In [26]:
# Create the Pydantic Schema

func = get_registry().get("openai").create(name="text-embedding-3-small")

class Tool(LanceModel):
    vector: Vector(func.ndims()) = func.VectorField()
    description: str = func.SourceField()
    name:str
    category:str

tool_table = db.create_table("tool_v1",schema=Tool,mode="overwrite")

[2024-05-28T03:11:07Z WARN  lance::dataset] No existing dataset at /Users/ivanleo/Documents/rag-ws/notebooks/lance-db/tool_v1.lance, it will be created


In [27]:
tool_table.add(data)

In [28]:
tool_table.to_pandas()

| | vector | description | name | category |
|---|---|---|---|---|
| 0 | [0.02942855, 0.047012717, 0.004617972, -0.0125... | A tool with a heavy metal head mounted at righ... | Hammer | Hand Tool |
| 1 | [0.018697312, 0.015099513, -0.047086194, -0.01... | A tool with a flattened, cross-shaped, or star... | Screwdriver | Hand Tool |
| 2 | [-0.03568479, 0.013477032, 0.0101103475, -0.01... | A power tool fitted with a cutting tool attach... | Electric Drill | Power Tool |
| 3 | [0.02803261, 0.024098208, -0.026885075, -0.021... | A tool used for gripping and turning nuts, bol... | Wrench | Hand Tool |
| 4 | [0.02052933, 0.014397344, -0.034926675, -0.013... | A hand tool used to hold objects firmly, possi... | Pliers | Hand Tool |
| 5 | [-0.00037332025, 0.033678483, -0.03620599, -0.... | A power-saw using a toothed or abrasive disc o... | Circular Saw | Power Tool |
| 6 | [0.020676186, 0.032748703, 0.0079361405, -0.01... | A flexible ruler used to measure size or dista... | Tape Measure | Measuring Tool |
| 7 | [0.06171664, 0.04006746, -0.0073134657, -0.004... | A tool with a characteristically shaped cuttin... | Chisel | Hand Tool |
| 8 | [-0.024902405, 0.054924037, 0.0053345207, -0.0... | An instrument designed to indicate whether a s... | Level | Measuring Tool |
| 9 | [0.014727382, 0.035678905, -0.026631918, -0.03... | A power tool used for cutting arbitrary curves... | Jigsaw | Power Tool |


### Computing a field

Let's now compute a `tool_id` that uniquely identifies each description and name pair, using Python's `hashlib` library.

In [29]:
# Create the Pydantic Schema

func = get_registry().get("openai").create(name="text-embedding-3-small")

class Tool(LanceModel):
    vector: Vector(func.ndims()) = func.VectorField()
    description: str = func.SourceField()
    name:str
    category:str
    tool_id:str

tool_table = db.create_table("tool_v2",schema=Tool,mode="overwrite")

[2024-05-28T03:11:08Z WARN  lance::dataset] No existing dataset at /Users/ivanleo/Documents/rag-ws/notebooks/lance-db/tool_v2.lance, it will be created


In [30]:
import hashlib

encoded_chunks = []
for row in data:
    name_and_description=f"{row['description']}-{row['name']}"
    tool_id = hashlib.md5(name_and_description.encode()).hexdigest()
    encoded_chunks.append({**row,"tool_id":tool_id})

tool_table.add(encoded_chunks)
tool_table.to_pandas()[:10]

| | vector | description | name | category | tool_id |
|---|---|---|---|---|---|
| 0 | [0.029243259, 0.046320148, 0.0031912285, -0.00... | A tool with a heavy metal head mounted at righ... | Hammer | Hand Tool | 5127eb88579d6565b794833acf0eff6e |
| 1 | [0.018722245, 0.01513522, -0.047092352, -0.010... | A tool with a flattened, cross-shaped, or star... | Screwdriver | Hand Tool | 18ffc6ead466b6e270e8a970544510d3 |
| 2 | [-0.03569433, 0.013480635, 0.010113051, -0.019... | A power tool fitted with a cutting tool attach... | Electric Drill | Power Tool | ad1d1aac2c439006635c48d6969d2a8f |
| 3 | [0.028033728, 0.02412259, -0.026932988, -0.021... | A tool used for gripping and turning nuts, bol... | Wrench | Hand Tool | bbf94b44b5866f023b2b8b5491657dd3 |
| 4 | [0.02175587, 0.013911997, -0.03692618, -0.0119... | A hand tool used to hold objects firmly, possi... | Pliers | Hand Tool | 36d03bc35bdddd89da298675e5c0cd73 |
| 5 | [-0.0005456097, 0.033061855, -0.03701373, -0.0... | A power-saw using a toothed or abrasive disc o... | Circular Saw | Power Tool | 61051798477a715f7e2e09df8880a8ee |
| 6 | [0.020711862, 0.03277381, 0.007930988, -0.0128... | A flexible ruler used to measure size or dista... | Tape Measure | Measuring Tool | 8bd1313982c2bd9313f1ac304ac43f8b |
| 7 | [0.061667863, 0.0399984, -0.0072744344, -0.004... | A tool with a characteristically shaped cuttin... | Chisel | Hand Tool | d65a66bc9595fa8a1247bcdcbe5a0a37 |
| 8 | [-0.024902405, 0.054924037, 0.0053345207, -0.0... | An instrument designed to indicate whether a s... | Level | Measuring Tool | 356a50889baebb01c83c2b33eba1187c |
| 9 | [0.014716265, 0.03568, -0.026632737, -0.033990... | A power tool used for cutting arbitrary curves... | Jigsaw | Power Tool | dda7672a25896af487fef054e3775f2a |


### Filtering

Now that we've ingested our data into the table, let's retrieve all of the tools in the `Hand Tool` category. Note that we pass `prefilter=True`, which applies the filter before the results are limited, so that we get back the number of rows that we asked for.
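To see why the order of filtering matters, here is a conceptual sketch in plain Python. It is not LanceDB's implementation; the rows, distances and function names are hypothetical and only illustrate the difference between filtering before and after taking the top `k` results:

```python
# Toy rows: (name, category, distance to the query); lower distance = closer match.
rows = [
    ("Electric Drill", "Power Tool", 0.10),
    ("Circular Saw", "Power Tool", 0.20),
    ("Hammer", "Hand Tool", 0.30),
    ("Jigsaw", "Power Tool", 0.40),
    ("Wrench", "Hand Tool", 0.50),
    ("Pliers", "Hand Tool", 0.60),
]


def top_k(candidates, k):
    """Return the k candidates with the smallest distance."""
    return sorted(candidates, key=lambda row: row[2])[:k]


def prefilter_search(rows, category, k):
    """Filter first, then take the k nearest of what remains."""
    matching = [row for row in rows if row[1] == category]
    return top_k(matching, k)


def postfilter_search(rows, category, k):
    """Take the k nearest first, then drop rows that fail the filter."""
    return [row for row in top_k(rows, k) if row[1] == category]


pre = prefilter_search(rows, "Hand Tool", k=3)
post = postfilter_search(rows, "Hand Tool", k=3)

print([name for name, _, _ in pre])   # three hand tools
print([name for name, _, _ in post])  # only the hand tools that made the overall top 3
```

In this toy data, post-filtering returns a single row even though three were requested, because two of the three nearest rows are power tools.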

In [31]:
%%time
tool_table = db.open_table("tool_v2")
tool_table.search().select(["name","description","category"]).where("category='Hand Tool'",prefilter=True).to_pandas()

CPU times: user 4.81 ms, sys: 3.5 ms, total: 8.31 ms
Wall time: 4.81 ms


| | name | description | category |
|---|---|---|---|
| 0 | Hammer | A tool with a heavy metal head mounted at righ... | Hand Tool |
| 1 | Screwdriver | A tool with a flattened, cross-shaped, or star... | Hand Tool |
| 2 | Wrench | A tool used for gripping and turning nuts, bol... | Hand Tool |
| 3 | Pliers | A hand tool used to hold objects firmly, possi... | Hand Tool |
| 4 | Chisel | A tool with a characteristically shaped cuttin... | Hand Tool |
| 5 | Allen Wrench | A tool used to drive bolts and screws with hex... | Hand Tool |
| 6 | Socket Wrench | A wrench with a socket attached at one end, us... | Hand Tool |
| 7 | Hacksaw | A fine-toothed saw, originally and mainly made... | Hand Tool |
| 8 | Soldering Iron | A hand tool used in soldering, which supplies ... | Hand Tool |


Now that we've filtered on a single value, let's return the rows that belong to either the `Measuring Tool` or the `Hand Tool` category.

In [32]:
%%time
tool_table = db.open_table("tool_v2")
tool_table.search().select(["name","description","category"]).where("category IN ('Hand Tool','Measuring Tool')",prefilter=True).to_pandas()

CPU times: user 5.98 ms, sys: 4.15 ms, total: 10.1 ms
Wall time: 8.41 ms


| | name | description | category |
|---|---|---|---|
| 0 | Hammer | A tool with a heavy metal head mounted at righ... | Hand Tool |
| 1 | Screwdriver | A tool with a flattened, cross-shaped, or star... | Hand Tool |
| 2 | Wrench | A tool used for gripping and turning nuts, bol... | Hand Tool |
| 3 | Pliers | A hand tool used to hold objects firmly, possi... | Hand Tool |
| 4 | Tape Measure | A flexible ruler used to measure size or dista... | Measuring Tool |
| 5 | Chisel | A tool with a characteristically shaped cuttin... | Hand Tool |
| 6 | Level | An instrument designed to indicate whether a s... | Measuring Tool |
| 7 | Allen Wrench | A tool used to drive bolts and screws with hex... | Hand Tool |
| 8 | Socket Wrench | A wrench with a socket attached at one end, us... | Hand Tool |
| 9 | Hacksaw | A fine-toothed saw, originally and mainly made... | Hand Tool |


### Deduplication

Now that we've ingested the data, we need to make sure that we don't store any duplicated data. To do so, we'll use a second set of tools in `tools_2.json`, which has some overlapping entries.

In [33]:
with open("./tools_2.json", "r") as file:
    data_2 = json.loads(file.read())

In [34]:
tool_table.count_rows(),len(data_2)

(16, 7)

In [35]:
# Naively adding the new rows appends all 7 of them, overlapping entries included
tool_table = db.open_table("tool_v1")
tool_table.add(data_2)
tool_table.count_rows()

23

In [36]:
# Adding the same rows again stores them a second time
tool_table = db.open_table("tool_v1")
tool_table.add(data_2)
tool_table.count_rows()

30

In [37]:
import hashlib

encoded_chunks = []
for row in data_2:
    name_and_description=f"{row['description']}-{row['name']}"
    chunk_id = hashlib.md5(name_and_description.encode()).hexdigest()
    encoded_chunks.append({**row,"chunk_id":chunk_id})

In [38]:
# Create the Pydantic Schema

func = get_registry().get("openai").create(name="text-embedding-3-small")

class Tool(LanceModel):
    vector: Vector(func.ndims()) = func.VectorField()
    description: str = func.SourceField()
    name:str
    category:str
    chunk_id:str

tool_table = db.create_table("tool_v2",schema=Tool,mode="overwrite")

encoded_chunks = []
for row in data:
    name_and_description=f"{row['description']}-{row['name']}"
    chunk_id = hashlib.md5(name_and_description.encode()).hexdigest()
    encoded_chunks.append({**row,"chunk_id":chunk_id})

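def get_duplicate_chunk_ids(encoded_chunks):
    """Return the chunk_ids from encoded_chunks that already exist in the table."""
    ids = [item['chunk_id'] for item in encoded_chunks]
    if not ids:
        # An empty IN () list would not form a valid filter, so return early
        return set()

    tool_table = db.open_table("tool_v2")
    formatted_filter = ', '.join(f"'{chunk_id}'" for chunk_id in ids)
    return set(tool_table.to_lance().to_table(filter=f"chunk_id in ({formatted_filter})", columns=['chunk_id']).to_pandas()['chunk_id'])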

def filter_duplicate_chunks(encoded_chunks,duplicate_ids):
    return [item for item in encoded_chunks if item['chunk_id'] not in duplicate_ids]

duplicate_ids = get_duplicate_chunk_ids(encoded_chunks)
filtered_chunks = filter_duplicate_chunks(encoded_chunks,duplicate_ids)

if filtered_chunks:
    tool_table.add(filtered_chunks)
tool_table.count_rows()

16
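The cell above adds the original `data` to a freshly created table, so no duplicates are found and all 16 rows are inserted. To see the filtering itself in action, here is a minimal in-memory sketch of the same idea, with no database involved. The rows and the helper names `chunk_id_for` and `new_rows_only` are hypothetical stand-ins, not part of the workshop code:

```python
import hashlib


def chunk_id_for(row):
    """Derive a stable id from description and name, mirroring the cells above."""
    key = f"{row['description']}-{row['name']}"
    return hashlib.md5(key.encode()).hexdigest()


def new_rows_only(existing_rows, incoming_rows):
    """Keep incoming rows whose chunk_id is not already stored or repeated in the batch."""
    seen = {chunk_id_for(row) for row in existing_rows}
    fresh = []
    for row in incoming_rows:
        cid = chunk_id_for(row)
        if cid not in seen:
            seen.add(cid)
            fresh.append({**row, "chunk_id": cid})
    return fresh


# Hypothetical rows standing in for the contents of the two JSON files.
existing = [
    {"name": "Hammer", "description": "Drives nails.", "category": "Hand Tool"},
    {"name": "Wrench", "description": "Turns nuts and bolts.", "category": "Hand Tool"},
]
incoming = [
    {"name": "Hammer", "description": "Drives nails.", "category": "Hand Tool"},  # duplicate
    {"name": "Sander", "description": "Smooths surfaces.", "category": "Power Tool"},
]

fresh = new_rows_only(existing, incoming)
print([row["name"] for row in fresh])
```

Because the id is a hash of the content, the same description and name always map to the same `chunk_id`, which is what makes the membership check reliable across runs.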

Before we proceed, let's clean up the database that we used for this section and remove it from disk.

In [39]:
import shutil

shutil.rmtree("./lance-db")

**Summary**: LanceDB provides an easy way to run full text search (FTS, a simple baseline) and embedding search. It also handles batching and provides other functionality out of the box, such as filtering and integrations with DuckDB and Arrow tables.

# Metrics

Now that we've figured out how our vector db works, let's see how we can evaluate our retrieval results.


## Evaluating Our Data

Let's see an example of how we can evaluate our retrieval using the ms-marco dataset, which contains real user queries.

In [41]:
from itertools import product
from lib.eval import calculate_reciprocal_rank,calculate_recall
SIZES = [3,5,10,15,25]
metrics = {
    "mrr": calculate_reciprocal_rank,
    "recall":calculate_recall
}

In [43]:
from tqdm import tqdm

db = lancedb.connect("../lance")
table = db.open_table("ms_marco")


def load_jsonl_file(file_path):
    data = []
    with open(file_path, "r") as file:
        for line in file:
            json_obj = json.loads(line.strip())
            data.append(json_obj)
    return data

def get_k_relevant_chunk_ids(table,query,number):
    return [item['chunk_id'] for item in table.search(query,query_type="fts").limit(number).to_list()]

def score(preds,label):
    return {
        f"{fn_name}@{size}": round(metrics[fn_name](preds[:size],[label]),3)
        for fn_name,size in product(metrics.keys(),SIZES)
    }

data = load_jsonl_file("../queries.jsonl")

scores = []
for item in tqdm(data):
    query = item['query']
    selected_chunk = item['selected_passages'][0]['chunk_id']
    retrieved_chunk_ids = get_k_relevant_chunk_ids(table,query,max(SIZES))
    scores.append(score(retrieved_chunk_ids,selected_chunk))

100%|██████████| 196/196 [00:04<00:00, 47.40it/s]


In [44]:
import pandas as pd

df = pd.DataFrame(scores)

In [45]:
df.mean()

mrr@3        0.352857
mrr@5        0.399541
mrr@10       0.426036
mrr@15       0.428352
mrr@25       0.429378
recall@3     0.515306
recall@5     0.719388
recall@10    0.903061
recall@15    0.933673
recall@25    0.954082
dtype: float64

## Sample Metrics

In [46]:
db = lancedb.connect("../lance")
table = db.open_table("ms_marco")
table.create_fts_index("text",replace=True)

In [47]:
results = table.search("where are the lungs located?").limit(4).select(['chunk_id','text']).to_pandas()
results

| | chunk_id | text | _distance |
|---|---|---|---|
| 0 | 5dde98e64a2e9639975cbf032a20922f | your lungs are located on the right and left s... | 0.608745 |
| 1 | b40ab6024251c5c37f26f27adefc41c2 | In humans, the lungs are located on either sid... | 0.690447 |
| 2 | de2eb296f2ddf0f9ca5e86301c13dcc3 | © 2014 WebMD, LLC. All rights reserved. The lu... | 0.713795 |
| 3 | 3fd0aabc0e8f2f53be360044b89f3887 | 1. Get help from a doctor now ›. under ribs: T... | 0.763482 |


Given the results above, what metrics can we use to score different possible rankings of the retrieved results? Consider three candidate rankings, where each number identifies a result:

- [0,1,2,3]
- [3,2,1,0]
- [1,2,3,4]

We want `0` to be ranked first because it is the result we're looking for, so a good metric should penalise cases where `0` is ranked lower or is missing entirely.

### Reciprocal Rank

Reciprocal rank highlights the importance of quickly surfacing at least one relevant document. It matters most when we can only show the user a few items at a time.

The formula is $$\text{RR} = \frac{1}{\text{Rank of First Relevant Item}}$$

If no relevant item is retrieved, the reciprocal rank is 0. Averaging it over all queries gives the mean reciprocal rank (MRR).


In [48]:
def calculate_reciprocal_rank(predictions, labels):
    for index, prediction in enumerate(predictions):
        if prediction in labels:
            return 1 / (index + 1)
    return 0

In [49]:
predictions = [1,2,3,4,5]
labels = [2,4]

calculate_reciprocal_rank(predictions,labels) # 1/2 = 0.5 since the earliest relevant item is at rank 2 (index 1)

0.5

In [50]:
predictions = [1,2,3,4,5]
labels = [10,20]

calculate_reciprocal_rank(predictions,labels) # No Relevant Items

0
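The `mrr@k` values reported earlier come from averaging this per-query score over all test queries (the `df.mean()` call). Here is a minimal sketch of that averaging, using hypothetical predictions and labels for three queries:

```python
def reciprocal_rank(predictions, labels):
    """1 / rank of the first relevant prediction, or 0.0 if none is relevant."""
    for rank, prediction in enumerate(predictions, start=1):
        if prediction in labels:
            return 1 / rank
    return 0.0


def mean_reciprocal_rank(results):
    """Average the reciprocal rank over (predictions, labels) pairs, one per query."""
    if not results:
        return 0.0
    return sum(reciprocal_rank(preds, labels) for preds, labels in results) / len(results)


# Hypothetical results for three queries.
results = [
    ([1, 2, 3], [1]),  # relevant item at rank 1 -> 1.0
    ([1, 2, 3], [2]),  # relevant item at rank 2 -> 0.5
    ([1, 2, 3], [9]),  # no relevant item        -> 0.0
]

print(mean_reciprocal_rank(results))  # (1.0 + 0.5 + 0.0) / 3 = 0.5
```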

### Recall

Recall measures the system's capability to retrieve all relevant documents within the top K results, emphasizing the breadth of relevant information captured.

Formula is $$\frac{\text{Number of Retrieved Relevant Items}}{\text{Total Number of Relevant Items}}$$

In [51]:
def calculate_recall(predictions, labels):
    correct_predictions = sum(1 for label in labels if label in predictions)
    if labels:
        return correct_predictions / len(labels)
    return 0


In [52]:
predictions = [1,2,3,4,5]
labels = [2,4]

calculate_recall(predictions,labels) # 2/2

1.0

In [53]:
predictions = [1,2,3,4,5]
labels = [200,20]

calculate_recall(predictions,labels) # 0/2 = 0

0.0

In [54]:
predictions = [1,2,3,4,5]
labels = [2,10]

calculate_recall(predictions,labels) # 0.5

0.5

These metrics let us measure the performance of our system and quantify the effect of incremental improvements over time. There are more metrics that you can track (see our article [here](https://jxnl.co/writing/2024/02/05/when-to-lgtm-at-k/)).

In short, think of the two metrics as follows:

- Recall: How many relevant items did we surface?
- Reciprocal Rank: Where was the first relevant item that we care about?

Always remember to compute these metrics at some value of `k`, where `k` is the number of top results that we evaluate.
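To make the role of `k` concrete, the sketch below scores the three candidate rankings from the start of this section, with `0` as the only relevant result. The cut-offs `k = 1` and `k = 4` and the ranking names are chosen for illustration only:

```python
from itertools import product


def reciprocal_rank(predictions, labels):
    """1 / rank of the first relevant prediction, or 0.0 if none is relevant."""
    for rank, prediction in enumerate(predictions, start=1):
        if prediction in labels:
            return 1 / rank
    return 0.0


def recall(predictions, labels):
    """Fraction of the relevant labels that appear in the predictions."""
    if not labels:
        return 0.0
    return sum(1 for label in labels if label in predictions) / len(labels)


METRICS = {"rr": reciprocal_rank, "recall": recall}
SIZES = [1, 4]  # illustrative cut-offs


def score_at_k(predictions, labels):
    """Evaluate every metric on the top-k slice of the predictions, for each k."""
    return {
        f"{name}@{k}": METRICS[name](predictions[:k], labels)
        for name, k in product(METRICS, SIZES)
    }


rankings = {
    "best": [0, 1, 2, 3],      # relevant result ranked first
    "reversed": [3, 2, 1, 0],  # relevant result ranked last
    "missing": [1, 2, 3, 4],   # relevant result not retrieved
}
labels = [0]

scores = {name: score_at_k(ranking, labels) for name, ranking in rankings.items()}
for name, result in scores.items():
    print(name, result)
```

At `k = 4`, recall cannot tell the first two rankings apart, while reciprocal rank drops from 1.0 to 0.25. At `k = 1`, both metrics penalise the reversed ranking.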

# Cold Starting

What can we do if we have no user queries and we're just starting out? The easiest way is to generate synthetic queries from our own text chunks and use them as evaluation data.

In [58]:
import instructor
import openai
from pydantic import BaseModel,Field
from tqdm.asyncio import tqdm_asyncio as asyncio

client = instructor.from_openai(openai.AsyncOpenAI())

class QuestionAnswerPair(BaseModel):
    """
    This model represents a pair of a question generated from a text chunk, its corresponding answer,
    and the chain of thought leading to the answer. The chain of thought provides insight into how the answer
    was derived from the question.
    """

    chain_of_thought: str = Field(
        ..., description="The reasoning process leading to the answer."
    )
    question: str = Field(
        ..., description="The generated question from the text chunk."
    )
    answer: str = Field(..., description="The answer to the generated question.")

async def generate_question_batch(text_chunk_batch):
    async def generate_question(text: str):
        question = await client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "system",
                    "content": "You are a world class AI that excels at generating hypothetical search queries. You're about to be given a text snippet and asked to generate a search query which is specific to the specific text chunk that you'll be given. Make sure to use information from the text chunk.",
                },
                {"role": "user", "content": f"Here is the text chunk : {text}"},
            ],
            response_model=QuestionAnswerPair,
            max_retries=3,
        )
        return (question,text)

    coros = [
        generate_question(item) for item in text_chunk_batch
    ]
    res = await asyncio.gather(*coros)
    return [{"input": item.question, "source": text} for item,text in res]

chunks = [
    "Conversion disorder is a type of somatoform disorder where physical symptoms or signs are present that cannot be explained by a medical condition. Very importantly, unlike factitious disorders and malingering, the symptoms of somatoform disorders are not intentional or under conscious control of the patient",
    "A conifer is a tree or shrub which produces distinctive cones as part of its sexual reproduction. These woody plants are classified among the gymnosperms, and they have a wide variety of uses, from trapping carbon in the environment to providing resins which can be used in the production of solvents. Several features beyond the cones set conifers apart from other types of woody plants. A conifer is typically evergreen, although some individuals are deciduous, and almost all conifers have needle or scale-like leaves",
    "Known by multiple common names, such as humbug damselfish, three-striped damselfish and white-tailed damselfish, Dascyllus aruanus is a feisty little fish that adapts well to aquarium life. Three-striped damselfish can be pugnacious and are better introduced at the latter stages of setting up a marine fish community. Remove as many of the three-striped damselfish fry as you want to try and raise to a rearing aquarium, with an absence of adult fish and invertebrates that might look upon the young fish as tasty morsels for the taking. Dascyllus aruanus is a worthy first-time breeding project for up-and-coming marine aquarists"
]

questions = await generate_question_batch(chunks)
questions

100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


[{'input': 'What is conversion disorder and how does it differ from factitious disorders and malingering?',
  'source': 'Conversion disorder is a type of somatoform disorder where physical symptoms or signs are present that cannot be explained by a medical condition. Very importantly, unlike factitious disorders and malingering, the symptoms of somatoform disorders are not intentional or under conscious control of the patient'},
 {'input': 'What are the key features that differentiate conifers from other types of woody plants?',
  'source': 'A conifer is a tree or shrub which produces distinctive cones as part of its sexual reproduction. These woody plants are classified among the gymnosperms, and they have a wide variety of uses, from trapping carbon in the environment to providing resins which can be used in the production of solvents. Several features beyond the cones set conifers apart from other types of woody plants. A conifer is typically evergreen, although some individuals are

In [None]:
questions[0]["input"]
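Each generated question comes with the chunk it was written from, so every `input`/`source` pair can act as a labelled test query. Here is a minimal sketch of that evaluation loop. The word-overlap retriever `toy_retrieve` is a hypothetical stand-in for a real search call, and the chunks are shortened versions of the ones above:

```python
import hashlib


def chunk_id_for(text):
    """Stable id for a chunk, derived from its content."""
    return hashlib.md5(text.encode()).hexdigest()


def tokenize(text):
    """Lowercase the text and split it into a set of words, ignoring basic punctuation."""
    for mark in "?,.":
        text = text.replace(mark, " ")
    return set(text.lower().split())


def toy_retrieve(query, chunks, k):
    """Hypothetical retriever: rank chunks by the number of words shared with the query."""
    query_words = tokenize(query)
    ranked = sorted(chunks, key=lambda chunk: len(query_words & tokenize(chunk)), reverse=True)
    return [chunk_id_for(chunk) for chunk in ranked[:k]]


chunks = [
    "Conversion disorder is a type of somatoform disorder.",
    "A conifer is a tree or shrub which produces distinctive cones.",
]

# Same shape as the output of generate_question_batch above.
synthetic = [
    {"input": "What is conversion disorder?", "source": chunks[0]},
    {"input": "Which trees produce distinctive cones?", "source": chunks[1]},
]

hits = 0
for pair in synthetic:
    label = chunk_id_for(pair["source"])                  # the chunk the question came from
    retrieved = toy_retrieve(pair["input"], chunks, k=1)  # what the retriever returns
    hits += label in retrieved

recall_at_1 = hits / len(synthetic)
print(recall_at_1)
```

Since each question has exactly one source chunk, recall at `k` here is simply the fraction of questions whose source appears in the top `k` results.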

## Instructor

Instructor is a library that uses Pydantic models to get structured, validated outputs from language models.

In [None]:
import instructor
from pydantic import BaseModel
from openai import OpenAI


# Define your desired output structure
class UserInfo(BaseModel):
    name: str
    age: int


# Patch the OpenAI client
client = instructor.from_openai(OpenAI())

# Extract structured data from natural language
user_info = client.chat.completions.create(
    model="gpt-3.5-turbo",
    response_model=UserInfo,
    messages=[{"role": "user", "content": "John Doe is 30 years old."}],
)

print(user_info.name)
print(user_info.age)

### Exercises

Now that we've seen what Instructor can do, let's work through a few different exercises to get a better understanding of the library

#### Adding some docstrings

Let's try creating a Pydantic Model that has docstrings and descriptions using the `Field` object.

Modify the original `UserInfo` object to include a docstring and a description of each field

In [59]:
from pydantic import Field

In [60]:
class UserInfo(BaseModel):
    """
    This is a model which represents a single user's information
    """
    name: str = Field(...,description="This is the user's name which we have extracted")
    age: int = Field(...,description="This is the user's age which we have extracted")


In [61]:
UserInfo.model_json_schema()

{'description': "This is a model which represents a single user's information",
 'properties': {'name': {'description': "This is the user's name which we have extracted",
   'title': 'Name',
   'type': 'string'},
  'age': {'description': "This is the user's age which we have extracted",
   'title': 'Age',
   'type': 'integer'}},
 'required': ['name', 'age'],
 'title': 'UserInfo',
 'type': 'object'}

#### Using simple validation

Now that we've seen how to work with simple user fields, let's start implementing validators.

Validators are simple functions that run on the response returned from OpenAI. Using validators, we can ensure that the output is valid. To show how a simple validator might work, let's implement a model that generates between three and five categories, plus some keywords, for a given article title.

In [None]:
from pydantic import field_validator

class Metadata(BaseModel):
    """
    This is a model which represents a list of categories that we can classify the given article into
    """
    categories: list[str] = Field(..., description="This is the list of categories that we can classify the given article into")
    keywords: list[str] = Field(...,description="These are some keywords that users might search for when looking for similar articles as the given article.")

    @field_validator('categories')
    @classmethod
    def check_categories_length(cls, v):
        if not (3 <= len(v) <= 5):
            raise ValueError('categories must have at least 3 elements and at most 5 elements')
        return v


In [None]:
metadata = client.chat.completions.create(
    model="gpt-3.5-turbo",
    response_model=Metadata,
    messages=[{"role": "system", "content": "You are a World Class classification Algorithm. You are about to be given an article title to classify. Make sure to return your response in the model provided"},
             {"role": "user", "content": "Give me a sample article title for classification: The Future of Artificial Intelligence in Healthcare"}
             ],
)
metadata
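A validator becomes more useful when a failure triggers another attempt, which is the role `max_retries` plays in other calls in this notebook. The sketch below is a conceptual illustration of a validate-and-retry loop in plain Python. It is not Instructor's actual implementation, and `fake_model` is a hypothetical stand-in for the language model call:

```python
def check_categories_length(categories):
    """Same rule as the validator above: between 3 and 5 categories."""
    if not (3 <= len(categories) <= 5):
        raise ValueError("categories must have at least 3 elements and at most 5 elements")
    return categories


def fake_model(attempt, feedback):
    """Hypothetical stand-in for an LLM call: the first answer is too short."""
    if attempt == 0:
        return ["AI"]
    return ["AI", "Healthcare", "Technology"]


def generate_with_retries(max_retries=3):
    """Ask the model, validate the answer, and retry with the error message on failure."""
    feedback = None
    for attempt in range(max_retries + 1):
        candidate = fake_model(attempt, feedback)
        try:
            return check_categories_length(candidate), attempt
        except ValueError as error:
            feedback = str(error)  # would be passed back to the model on the next attempt
    raise RuntimeError(f"no valid output after {max_retries} retries: {feedback}")


categories, retries_used = generate_with_retries()
print(categories, retries_used)
```

The important design point is that the validation error is kept as feedback, so the next attempt can correct the specific problem instead of starting blind.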

#### More Complex Types

We've now seen how to use Pydantic to validate our returned types with instructor. Now let's try a more complex example

Imagine you're trying to do some query parsing and you have a set of given tools

1. Internet Search
2. Database Queries
3. Meeting Scheduler

How might we represent this in a Pydantic model? The example below models two kinds of action: an internet search and a calendar query.

In [71]:
from datetime import datetime
from typing import List,Literal,Union
from pydantic import field_validator
from openai import OpenAI
import instructor

client = instructor.from_openai(OpenAI())

class InternetSearch(BaseModel):
    """
    Model for representing an internet search query.
    
    """
    id: int = Field(..., description="Unique id of the query")
    search_query: str = Field(..., description="This is an internet search query that we will execute to identify relevant information.")
    dependencies: List[int] = Field(
        default_factory=list,
        description="List of sub questions that need to be answered before asking this question",
    )

class CalendarQuery(BaseModel):
    """
    A model that represents a query for events on a calendar within a date range
    """
    id: int = Field(..., description="Unique id of the query")
    calendar: Literal['Personal', 'Work'] = Field(..., description="The type of calendar (Personal or Work).")
    start_date: str = Field(..., description="The earliest date for events that we'd like to fetch for this calendar")
    end_date: str = Field(..., description="The latest date for events that we'd like to fetch for this calendar")
    dependencies: List[int] = Field(
        default_factory=list,
        description="List of sub questions that need to be answered before asking this question",
    )

    @field_validator("start_date", "end_date")
    @classmethod
    def validate_date_format(cls, value):
        try:
            datetime.strptime(value, "%d-%m")
        except ValueError:
            raise ValueError("Date must be in the format dd-mm")
        return value
    

class QueryModel(BaseModel):
    """
    A list of actions to execute in order to complete the user's request
    """
    actions: List[Union[InternetSearch, CalendarQuery]] = Field(..., description="A list of actions.")


def generate_actions(request: str) -> QueryModel:
    """
    Generate a list of actions to schedule an appointment based on the user's request.
    """
    return client.chat.completions.create(
        model="gpt-4o",
        response_model=QueryModel,
        messages=[
            {"role": "system", "content": "You are a scheduling assistant capable of breaking down complex user queries into actions to be executed. Do not answer the question but instead return a list of plausible steps in order to get enough information to answer the user's query"},
            {"role": "assistant", "content": "The date today is 27 May 2024, Monday. The user lives in Downtown Toronto and generally likes Japanese Food"},
            {"role": "user", "content": request}
        ],
        max_retries=3
    )

request = "I'd like to grab dinner with Daniel sometime next week. Can you help me find some time in my calendar and some potential dinner spots?"
actions = generate_actions(request)
print(actions)

actions=[CalendarQuery(id=1, calendar='Personal', start_date='03-06', end_date='09-06', dependencies=[]), InternetSearch(id=2, search_query='best Japanese restaurants in Downtown Toronto', dependencies=[])]
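In the output above both actions have empty `dependencies`, so they can run in any order. When an action does depend on others, the ids give us a graph that we can sort before executing. Here is a minimal sketch using the standard library's `graphlib`. The `Action` dataclass and the third action are hypothetical, used in place of the Pydantic models so that the example runs without an API call:

```python
from dataclasses import dataclass, field
from graphlib import TopologicalSorter


@dataclass
class Action:
    """Hypothetical stand-in for InternetSearch / CalendarQuery."""
    id: int
    description: str
    dependencies: list[int] = field(default_factory=list)


def execution_order(actions):
    """Return action ids so that every action comes after its dependencies."""
    graph = {action.id: set(action.dependencies) for action in actions}
    # static_order() raises graphlib.CycleError if the dependencies form a cycle.
    return list(TopologicalSorter(graph).static_order())


actions = [
    Action(id=3, description="Propose dinner slots", dependencies=[1, 2]),
    Action(id=1, description="Check the calendar"),
    Action(id=2, description="Search for restaurants"),
]

order = execution_order(actions)
print(order)
```

Actions 1 and 2 have no dependencies, so they may appear in either order, but action 3 always comes last.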
