In [1]:
import os
import getpass

os.environ['OPENAI_API_KEY'] = getpass.getpass('Enter your OpenAI API key: ')

In [2]:
os.environ['OPENAI_BASE_URL'] = 'https://api.together.xyz/v1'

In [33]:
from semantix import enhance
from semantix.llms import OpenAI
from semantix.utils import create_enum


llm = OpenAI(model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", verbose=True)

multilabel_classes = ['lists_createoradd', 'calendar_query', 'email_sendemail', 'news_query',
 'play_music', 'play_radio', 'qa_maths', 'email_query','weather_query',
 'calendar_set','iot_hue_lightdim','takeaway_query','social_post'
 'email_querycontact','qa_factoid','calendar_remove','cooking_recipe',
 'lists_query','general_quirky','alarm_query','takeaway_order',
 'iot_hue_lightup','lists_remove','qa_currency','play_game',
 'play_audiobook','qa_definition','music_query','datetime_query',
 'transport_query','iot_hue_lightoff','iot_hue_lightchange',
 'iot_hue_lighton','alarm_set','music_likeness','recommendation_movies',
 'transport_ticket','recommendation_locations','audio_volume_mute',
 'iot_wemo_on','play_podcasts','datetime_convert','audio_volume_other',
 'recommendation_events','alarm_remove','iot_coffee','music_dislikeness',
 'general_joke','social_query']

Label = create_enum("Label", {name: name for name in multilabel_classes}, "Multilabel Classes")

@enhance("", llm)
def classify(text: str) -> list[Label]: ... # type: ignore

In [34]:
classify(text="Play some music and turn on the lights")

[32m2024-10-06 22:49:00.968[0m | [1mINFO    [0m | [36msemantix.llms.base[0m:[36m__call__[0m:[36m254[0m - [1mModel Input
# Goal:  (classify)
## Output Type Definition
- list[Label]
## Type Definitions
- Multilabel Classes (Label) (Enum) -> Label.lists_createoradd, Label.calendar_query, Label.email_sendemail, Label.news_query, Label.play_music, Label.play_radio, Label.qa_maths, Label.email_query, Label.weather_query, Label.calendar_set, Label.iot_hue_lightdim, Label.takeaway_query, Label.social_postemail_querycontact, Label.qa_factoid, Label.calendar_remove, Label.cooking_recipe, Label.lists_query, Label.general_quirky, Label.alarm_query, Label.takeaway_order, Label.iot_hue_lightup, Label.lists_remove, Label.qa_currency, Label.play_game, Label.play_audiobook, Label.qa_definition, Label.music_query, Label.datetime_query, Label.transport_query, Label.iot_hue_lightoff, Label.iot_hue_lightchange, Label.iot_hue_lighton, Label.alarm_set, Label.music_likeness, Label.recommendation_mo

[<Label.play_music: 'play_music'>, <Label.iot_hue_lighton: 'iot_hue_lighton'>]

In [35]:
def validate_output(output, type_hint):
    # recursively validate the output
    if isinstance(type_hint, str):
        type_hint = eval(type_hint)
    if isinstance(type_hint, list):
        assert isinstance(output, list), f"Expected list, got {type(output)}"
        assert len(type_hint) == 1, f"Expected list of length 1, got {len(type_hint)}"
        for item in output:
            validate_output(item, type_hint[0])
    elif hasattr(type_hint, '__args__'):
        if type_hint.__origin__ == list:
            assert isinstance(output, list), f"Expected list, got {type(output)}"
            for item in output:
                validate_output(item, type_hint.__args__[0])
    else:
        assert isinstance(output, type_hint), f"Expected {type_hint}, got {type(output)}"

type_hint = "list[Label]"

validate_output(classify(text="Play some music and turn on the lights"), type_hint)

[32m2024-10-06 22:49:04.737[0m | [1mINFO    [0m | [36msemantix.llms.base[0m:[36m__call__[0m:[36m254[0m - [1mModel Input
# Goal:  (classify)
## Output Type Definition
- list[Label]
## Type Definitions
- Multilabel Classes (Label) (Enum) -> Label.lists_createoradd, Label.calendar_query, Label.email_sendemail, Label.news_query, Label.play_music, Label.play_radio, Label.qa_maths, Label.email_query, Label.weather_query, Label.calendar_set, Label.iot_hue_lightdim, Label.takeaway_query, Label.social_postemail_querycontact, Label.qa_factoid, Label.calendar_remove, Label.cooking_recipe, Label.lists_query, Label.general_quirky, Label.alarm_query, Label.takeaway_order, Label.iot_hue_lightup, Label.lists_remove, Label.qa_currency, Label.play_game, Label.play_audiobook, Label.qa_definition, Label.music_query, Label.datetime_query, Label.transport_query, Label.iot_hue_lightoff, Label.iot_hue_lightchange, Label.iot_hue_lighton, Label.alarm_set, Label.music_likeness, Label.recommendation_mo

In [33]:
from openai import OpenAI
import instructor
from pydantic import BaseModel
from enum import Enum

instructor_client = instructor.patch(OpenAI())

class MultiLabelClassification(BaseModel):
    classes: list[
        Enum("MultilabelClasses", {name: name for name in multilabel_classes})
    ]

text = "Play some music and turn on the lights"
response = instructor_client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    response_model=MultiLabelClassification,
    messages=[{"role": "user", "content": f"Classify the following text: {text}"}],
)

response

MultiLabelClassification(classes=[<MultilabelClasses.play_music: 'play_music'>, <MultilabelClasses.iot_hue_lighton: 'iot_hue_lighton'>])