In [1]:
from logikon.backends.chat_models_with_grammar import create_logits_model

In [2]:
import dotenv
import os

# Read environment variables from the .env file one directory above the notebook.
dotenv.load_dotenv("../.env")
# Token presence check: the membership test evaluates to True (and is shown as
# the cell output) when the variable is set; otherwise `or` falls through to
# print(), which emits the reminder below.
"HUGGINGFACEHUB_API_TOKEN" in os.environ or print("Please set HUGGINGFACEHUB_API_TOKEN in .env")

True

In [10]:
# Settings forwarded verbatim to create_logits_model().
#
# Alternative model/endpoint pair, kept for reference:
#   "model_id": "meta-llama/Meta-Llama-3.1-70B-Instruct"
#   "inference_server_url": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3.1-70B-Instruct"
kwargs = {
    "model_id": "HuggingFaceH4/zephyr-7b-beta",
    "inference_server_url": "https://px0zqc1h7zw38b0b.us-east-1.aws.endpoints.huggingface.cloud",
    "llm_backend": "HFChat",
    # None when the variable is unset (no exception is raised here).
    "api_key": os.getenv("HUGGINGFACEHUB_API_TOKEN"),
    "temperature": 0.7,
}

# Model object reused by every cell below.
hf_chat = create_logits_model(**kwargs)


## Test Logits

In [11]:
from langchain_core.messages import HumanMessage
res = await hf_chat.get_labelprobs(
    [HumanMessage(content="What is the capital of Canada?\n(A) Paris\n(B) Lyon\n(C) Quebec")],
    labels=["A", "B", "C"],
    top_logprobs=5
)

In [12]:
# Show the value returned by hf_chat.get_labelprobs(...) for labels A, B and C.
res

{'A': 0.043252036708846336, 'B': 0.11174987599215074, 'C': 0.8449980872990028}

## Test Grammar

In [13]:
# REGEX

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser


# Prompt template with a single {country} placeholder.
prompt = ChatPromptTemplate.from_template("What is the capital of {country}?")

# Generation arguments bound to the model: a sampling temperature and a regex
# whose only matches are the three alternatives listed.
# NOTE(review): presumably the backend restricts decoding to this pattern -
# confirm against the logikon backend implementation.
regex = r"(Washington|Ontario|London)"
gen_args = {"temperature": 0.5, "regex": regex}

# prompt -> model bound with gen_args (wrapped by with_retry) -> plain string
chain = prompt | hf_chat.bind(**gen_args).with_retry() | StrOutputParser()

# One input mapping per country to query.
inputs = [{"country": name} for name in ("Canada", "France")]

results = chain.batch(inputs)

results

['Ontario', 'London']

In [9]:
# JSON

from pydantic import BaseModel

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Pydantic model whose JSON schema (via model_json_schema()) is passed to the
# chat model as the "json_schema" generation argument further down.
# NOTE: documented with `#` comments on purpose - pydantic copies a class
# docstring into the generated schema as "description", which would change
# the schema sent with the request.
class CapitalModel(BaseModel):
    # Both fields are plain strings and, having no defaults, are listed as
    # required in the generated schema.
    country: str
    capital: str    

# Prompt template with a single {country} placeholder.
prompt = ChatPromptTemplate.from_template("What is the capital of {country}?")

# JSON schema generated from CapitalModel; printed so the exact schema handed
# to the model is visible in the cell output.
guided_json = CapitalModel.model_json_schema()
print(guided_json)

# Generation arguments bound to the model: temperature plus the schema.
gen_args = {
    "temperature": 0.4,
    "json_schema": guided_json,
}

# prompt -> model bound with gen_args (wrapped by with_retry) -> plain string
chain = prompt | hf_chat.bind(**gen_args).with_retry() | StrOutputParser()

# One input mapping per country to query.
inputs = [{"country": name} for name in ("Canada", "France")]

results = chain.batch(inputs)

results


{'properties': {'country': {'title': 'Country', 'type': 'string'}, 'capital': {'title': 'Capital', 'type': 'string'}}, 'required': ['country', 'capital'], 'title': 'CapitalModel', 'type': 'object'}


['{  \n  "capital": "Paris"\n , \n  "country": "France"\n}',
 '{  \n  "capital": "Paris"\n , \n  "country": "France"\n}']