In [1]:
from dotenv import load_dotenv

# Load key/value pairs from a local .env file into os.environ so the API keys
# read via os.environ.get(...) in later cells are available.
load_dotenv()

True

In [7]:
from typing import List, Union, Literal
import os

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_openai import ChatOpenAI
from langchain_core.exceptions import OutputParserException
from langchain_experimental.llms.ollama_functions import OllamaFunctions

In [3]:
from pydantic import (
    BaseModel,
    Field,
    ValidationError,
    ValidationInfo,
    field_validator,
)

In [4]:
from pydantic.v1 import BaseModel as BaseModelV1

In [5]:
from langchain_anthropic import ChatAnthropic

## OpenAI

In [6]:
import os
from openai import OpenAI

# Smoke test: send a single chat-completion request through the OpenAI SDK.
# Passing api_key explicitly is optional -- OPENAI_API_KEY is the client's default source.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

chat_completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say this is a test"}],
)

## Anthropic

In [7]:
import os
from anthropic import Anthropic

# Smoke test: send one message through the Anthropic SDK and print the reply blocks.
# api_key mirrors the SDK's default source (ANTHROPIC_API_KEY) and could be omitted.
# NOTE(review): base_url is read from ANTHROPIC_API_BASE, while the commented-out
# ChatAnthropic config in "Choose model" reads ANTHROPIC_BASE_URL -- TODO use one name.
client = Anthropic(
    base_url=os.environ.get("ANTHROPIC_API_BASE"),
    api_key=os.environ.get("ANTHROPIC_API_KEY"),
)

message = client.messages.create(
    model="claude-3-opus-20240229",
    max_tokens=1024,
    messages=[{"role": "user", "content": "Hello, Claude"}],
)
print(message.content)

[TextBlock(text="Hello! It's nice to meet you. How can I assist you today?", type='text')]


## Choose model

In [31]:
# Choose exactly one chat model for the experiments below; the alternatives are
# kept commented out for quick switching.

# OpenAI:
# model = ChatOpenAI(model="gpt-4-turbo", temperature=0)

# Anthropic (note: reads ANTHROPIC_BASE_URL for the endpoint):
# model = ChatAnthropic(
#     model="claude-3-opus-20240229",
#     anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY"),
#     anthropic_api_url=os.environ.get("ANTHROPIC_BASE_URL"),
# )

# Local phi3 served by Ollama, requested with format="json".
model = OllamaFunctions(
    model="phi3",
    format="json",
    temperature=0.3,
)

## Develop retry

In [8]:
MODEL_PROMPT = "Answer the user query. Remember to only respond with JSON.\n{format_instructions}\n{query}\n"
MODEL_REPROMPT = MODEL_PROMPT + "Pay special attention to the following error\n{validation_error}\n"

def get_model(
    query: str,
    model,
    result_model: Union[BaseModel, BaseModelV1],
    max_retry: int = 3,
) -> Union[BaseModel, BaseModelV1, OutputParserException]:
    """Ask ``model`` to answer ``query`` as an instance of ``result_model``, reprompting on parse errors.

    The first attempt uses ``MODEL_PROMPT``. Every later attempt uses ``MODEL_REPROMPT``,
    which additionally embeds the text of the previous ``OutputParserException``.

    Args:
        query: The user request forwarded to the model.
        model: A LangChain runnable / chat model (anything that composes with ``|``).
        result_model: The pydantic model *class* (v1 or v2) the answer is parsed into.
        max_retry: Total number of attempts, including the first. Must be >= 1.

    Returns:
        The parsed ``result_model`` instance on success. If every attempt fails to
        parse, the last ``OutputParserException`` is *returned* (not raised) so it
        can be inspected in the notebook.

    Raises:
        ValueError: If ``max_retry`` is smaller than 1.
    """
    if max_retry < 1:
        raise ValueError(f"max_retry must be at least 1, got {max_retry}")

    parser = PydanticOutputParser(pydantic_object=result_model)
    # Format instructions depend only on the schema, so compute them once.
    partials = {"format_instructions": parser.get_format_instructions()}
    first_chain = (
        PromptTemplate(
            template=MODEL_PROMPT,
            input_variables=["query"],
            partial_variables=partials,
        )
        | model
        | parser
    )
    retry_chain = (
        PromptTemplate(
            template=MODEL_REPROMPT,
            input_variables=["query", "validation_error"],
            partial_variables=partials,
        )
        | model
        | parser
    )

    # None until an attempt fails; afterwards it selects the reprompt and is fed back.
    validation_error = None
    for _ in range(max_retry):
        try:
            if validation_error is None:
                return first_chain.invoke({"query": query})
            return retry_chain.invoke({"query": query, "validation_error": str(validation_error)})
        except OutputParserException as e:
            validation_error = e
    return validation_error

### Examples

In [65]:
# Example schema for the retry loop. The `country_origin` validator only accepts
# all-caps values, so an answer like "United States" fails validation and
# (with max_retry > 1) triggers the reprompt path in `get_model`.
#
# NOTE: `BaseModel` and `field_validator` here are pydantic v2, so `Field` must be
# pydantic v2's `Field` as well (imported from `pydantic`). A
# `langchain_core.pydantic_v1.Field` would be treated by v2 as a plain default
# value: the descriptions would vanish from the schema and the fields would stop
# being required.
# No class docstring on purpose: pydantic copies it into the JSON schema, which
# is part of the prompt sent to the LLM.
class Actor(BaseModel):
    name: str = Field(description="name of an actor")
    country_origin: str = Field(description="Country they were born in")

    @field_validator('country_origin')
    @classmethod
    def country_upper(cls, v: str) -> str:
        """Reject values for which ``str.isupper()`` is False (e.g. mixed case)."""
        if not v.isupper():
            raise ValueError('Must be all caps!')
        return v

In [66]:
# Wrapper schema so the LLM returns a list of `Actor` objects under one "actors" key.
# The `Actor` class is resolved when this class is created; a later redefinition of
# `Actor` does not change it.
# No docstring on purpose: pydantic would copy it into the schema / format instructions.
class Actors(BaseModel):
    actors: List[Actor]

In [67]:
# Prompt for the Actors example: five cast members of one film.
query = "Which 5 actors played in the movie Inception?"


In [70]:
# max_retry=1 -> a single attempt with no reprompt, so a validation failure is returned as-is.
res = get_model(query=query, model=model, result_model=Actors, max_retry=1)



In [71]:
# With a single attempt, a failed parse shows up here as the returned OutputParserException.
res

langchain_core.exceptions.OutputParserException('Failed to parse Actors from completion {"actors": [{"name": "Leonardo DiCaprio", "country_origin": "United States"}, {"name": "Joseph Gordon-Levitt", "country_origin": "United States"}, {"name": "Ellen Page", "country_origin": "Canada"}, {"name": "Tom Hardy", "country_origin": "United Kingdom"}, {"name": "Ken Watanabe", "country_origin": "Japan"}]}. Got: 5 validation errors for Actors\nactors.0.country_origin\n  Value error, Must be all caps! [type=value_error, input_value=\'United States\', input_type=str]\n    For further information visit https://errors.pydantic.dev/2.7/v/value_error\nactors.1.country_origin\n  Value error, Must be all caps! [type=value_error, input_value=\'United States\', input_type=str]\n    For further information visit https://errors.pydantic.dev/2.7/v/value_error\nactors.2.country_origin\n  Value error, Must be all caps! [type=value_error, input_value=\'Canada\', input_type=str]\n    For further information visi

In [15]:
# NOTE(review): this output comes from an earlier run (lower execution count), in which
# the model returned all-caps country names; it is stale relative to the cells above.
res

Actors(actors=[Actor(name='Leonardo DiCaprio', country_origin='UNITED STATES'), Actor(name='Joseph Gordon-Levitt', country_origin='UNITED STATES'), Actor(name='Ellen Page', country_origin='CANADA'), Actor(name='Tom Hardy', country_origin='UNITED KINGDOM'), Actor(name='Cillian Murphy', country_origin='IRELAND')])

## Trying actual Vizro models

In [6]:
import vizro.models as vm

### Filter

In [73]:
# Prompt for a single vm.Filter on one column with a dropdown selector.
query = "I need a filter that filters on the column 'gdpPerCap' and uses a dropdown as selector."

In [74]:
# NOTE(review): Vizro requires model ids to be unique across the whole dashboard, so an
# earlier model with the same id makes validation fail (the error text suggests
# `from vizro import Vizro; Vizro._reset()` before re-running).
res = get_model(query=query, model=model, result_model=vm.Filter)

In [75]:
# Display the Filter result, or the last OutputParserException returned by get_model.
res

langchain_core.exceptions.OutputParserException('Failed to parse Filter from completion {"id": "1", "type": "filter", "column": "gdpPerCap", "targets": [], "selector": {"id": "2", "type": "dropdown", "options": ["Option 1", "Option 2", "Option 3"], "value": "Option 1", "title": "Select GDP Per Capita"}}. Got: 1 validation error for Filter\nselector -> Dropdown\n  Model with id=2 already exists. Models must have a unique id across the whole dashboard. If you are working from a Jupyter Notebook, please either restart the kernel, or use \'from vizro import Vizro; Vizro._reset()`. (type=value_error.duplicateid)')

### Page

In [51]:
# Experimental Page variant whose `components` are plain chart-type names instead of
# full component models, presumably to shrink the schema the LLM must fill in.
# NOTE(review): the Page cell below still passes `vm.Page`, not `PageAI`, to get_model;
# also confirm vm.Page's own component validators accept these string literals.
# No docstring on purpose: pydantic would copy it into the schema / prompt.
class PageAI(vm.Page):
    components: List[Literal["scatter", "bar", "line", "table", "pie", "map"]]

In [10]:
# Prompt for a whole vm.Page: two charts plus a dropdown filter on one column.
query = "I need a page with a bar and a scatter that filters on the column 'gdpPerCap' and uses a dropdown as selector."

In [13]:
# Full vm.Page schema (not the slimmer PageAI above), with the default max_retry=3.
res = get_model(query=query, model=model, result_model=vm.Page)



In [15]:
# print() renders the exception message with real line breaks, which is easier to
# read than the escaped repr a bare `res` would show.
print(res)

Failed to parse Page from completion {"components": [{"id": "bar1", "type": "graph", "actions": []}, {"id": "scatter1", "type": "graph", "actions": []}], "title": "Economic Data Visualization", "description": "Visualize GDP per capita data in bar and scatter plot formats.", "layout": null, "controls": [{"id": "filter1", "type": "filter", "column": "gdpPerCap", "targets": ["bar1", "scatter1"], "selector": {"id": "dropdown1", "type": "dropdown", "options": [], "value": null, "multi": true, "title": ""}}], "path": "", "actions": []}. Got: 4 validation errors for Page
components -> 0 -> Graph -> figure
  field required (type=value_error.missing)
components -> 1 -> Graph -> figure
  field required (type=value_error.missing)
controls -> 0 -> Filter -> targets -> 0
  Target bar1 not found in model_manager. (type=value_error)
controls -> 0 -> Filter -> targets -> 1
  Target scatter1 not found in model_manager. (type=value_error)


In [21]:
# Inspect the experimental PageAI class (not passed to get_model in the cells shown).
PageAI

__main__.PageAI

In [57]:
res = get_model(
    query="I want a card with some random text that starts with quack",
    model=model,
    result_model=vm.Card,
)

In [58]:
# Display the generated vm.Card, or the last OutputParserException if parsing failed.
res

Card(id='8d723104-f773-83c1-3458-a748e9bb17bc', type='card', text='quack quack quack Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.', href='')

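In summary, generating Vizro models generally works well, but the approach struggles with more complex models such as `Page`, likely because of the sheer size of the schema and the resulting large number of validation errors. (The `Filter` attempt above also failed, but on a duplicate-ID error from Vizro's model manager rather than on the schema itself.)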

## Alter code to try out with Ollama

In [32]:
# Define your desired data structure.
# NOTE: no class docstrings on these models on purpose -- pydantic copies a class
# docstring into the JSON schema, which would change what the LLM is shown.
class Joke(BaseModel):
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")

    # You can add custom validation logic easily with Pydantic.
    # NOTE(review): `validator` is the pydantic-v1 decorator (imported from
    # langchain_core.pydantic_v1). If BaseModel is pydantic v2 at this point, use
    # `field_validator` + `@classmethod` instead when re-enabling this.
    # @validator("setup")
    # def question_ends_with_question_mark(cls, field):
    #     if field[-1] != "?":
    #         raise ValueError("Badly formed question!")
    #     return field


# And a query intended to prompt a language model to populate the data structure.
joke_query = "Tell me a joke."

# NOTE(review): this redefines the earlier `Actor` class with different fields.
# `Actors` was created from the earlier class and still refers to it.
class Actor(BaseModel):
    name: str = Field(description="name of an actor")
    film_names: List[str] = Field(description="list of names of films they starred in")
    
actor_query = "Generate the info for a random actor."

In [33]:
MODEL_PROMPT = "Answer the user query. Remember to only respond with JSON and to fill in all keys.\n{query}\n"
MODEL_REPROMPT = MODEL_PROMPT + "Fix the error in your previous try:\n{validation_error}\n"

def get_model(
    query: str,
    model,
    result_model: Union[BaseModel, BaseModelV1],
    max_retry: int = 3,
    verbose: bool = True,
) -> Union[BaseModel, BaseModelV1, OutputParserException]:
    """Structured-output variant of ``get_model`` using ``model.with_structured_output``.

    This redefinition shadows the earlier ``get_model`` in this notebook. Instead of
    pasting format instructions into the prompt, it hands ``result_model`` to
    ``model.with_structured_output``. On an ``OutputParserException`` it retries with
    ``MODEL_REPROMPT``, which embeds the previous error text.

    Args:
        query: The user request forwarded to the model.
        model: A LangChain chat model exposing ``with_structured_output``.
        result_model: The pydantic model *class* the answer is parsed into.
        max_retry: Total number of attempts, including the first. Must be >= 1.
        verbose: If True (default), print the prompt template used for each attempt.

    Returns:
        The parsed ``result_model`` instance on success. If every attempt fails to
        parse, the last ``OutputParserException`` is *returned* (not raised).

    Raises:
        ValueError: If ``max_retry`` is smaller than 1.
    """
    if max_retry < 1:
        raise ValueError(f"max_retry must be at least 1, got {max_retry}")

    # Neither the structured-output wrapper nor the templates change between attempts.
    structured_model = model.with_structured_output(result_model)
    first_prompt = PromptTemplate(template=MODEL_PROMPT, input_variables=["query"])
    retry_prompt = PromptTemplate(
        template=MODEL_REPROMPT,
        input_variables=["query", "validation_error"],
    )

    # None until an attempt fails; afterwards it selects the reprompt and is fed back.
    validation_error = None
    for _ in range(max_retry):
        if validation_error is None:
            prompt, inputs = first_prompt, {"query": query}
        else:
            prompt = retry_prompt
            inputs = {"query": query, "validation_error": str(validation_error)}
        if verbose:
            print(prompt)
        try:
            return (prompt | structured_model).invoke(inputs)
        except OutputParserException as e:
            validation_error = e
    return validation_error

In [36]:
# Uses the structured-output get_model defined just above.
res = get_model(query=joke_query, model=model, result_model=Joke)

input_variables=['query'] template='Answer the user query. Remember to only respond with JSON and to fill in all keys.\n{query}\n'
input_variables=['query', 'validation_error'] template='Answer the user query. Remember to only respond with JSON and to fill in all keys.\n{query}\nFix the error in your previous try:\n{validation_error}\n'
input_variables=['query', 'validation_error'] template='Answer the user query. Remember to only respond with JSON and to fill in all keys.\n{query}\nFix the error in your previous try:\n{validation_error}\n'


In [37]:
# Display the Joke, or the last OutputParserException if every attempt failed.
res

langchain_core.exceptions.OutputParserException('Failed to parse Joke from completion {"setup": "Why was the math book sad? Because it had too many problems."}. Got: 1 validation error for Joke\npunchline\n  field required (type=value_error.missing)')