In [1]:
# Automatically reload changed Python modules before each cell executes, so
# edits to imported project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import asyncio
import csv
import itertools
from enum import Enum

from loguru import logger
from pydantic import BaseModel

from taskllm.optimizer.data import DataSet, Row
from taskllm.optimizer.methods import BanditTrainer, BayesianTrainer
from taskllm.optimizer.prompt.meta import PromptMode

# logger.remove()  # remove the default handler; otherwise it keeps printing DEBUG logs alongside the handler added below
# logger.add(sys.stdout, level="TRACE")  # add a handler that shows TRACE and above (requires `import sys`)


# Allowed rating labels, keyed by member name. The values are the raw strings
# parsed from the CSV "Rating" column: "1".."5", plus "N/A" (presumably for
# reviews without a usable star rating — confirm against the data).
Ratings = Enum(
    "Ratings",
    [
        ("ONE", "1"),
        ("TWO", "2"),
        ("THREE", "3"),
        ("FOUR", "4"),
        ("FIVE", "5"),
        ("NA", "N/A"),
    ],
)


class StarbucksReviewRating(BaseModel):
    """Structured rating for a single review.

    Used as the expected (ground-truth) output of each dataset row and as the
    ``expected_output_type`` the trainer is configured with for model outputs.
    """

    rating: Ratings  # one of the Ratings members ("1".."5" or "N/A")


def sentiment_scoring_function(
    row: Row[StarbucksReviewRating], output: StarbucksReviewRating | None
) -> float:
    """Score one model prediction against the labelled rating of its row.

    Args:
        row: Dataset row; ``row.expected_output`` holds the ground-truth rating,
            or ``None`` when the row is unlabelled.
        output: Parsed model output, or ``None`` (presumably when the model's
            response could not be parsed into ``StarbucksReviewRating``).

    Returns:
        ``-10.0`` when there is no model output. This is a heavy penalty, so
        failed generations score below wrong answers. ``0.0`` for unlabelled
        rows or a wrong rating, and ``1.0`` for an exact rating match.
    """
    if output is None:
        return -10.0
    if row.expected_output is None:
        return 0.0
    # Positional args (instead of an f-string) let loguru skip building the
    # message entirely when TRACE is disabled; this runs once per scored row.
    logger.trace(
        "Expected: {}, Output: {}", row.expected_output.rating, output.rating
    )
    return 1.0 if row.expected_output.rating == output.rating else 0.0


def load_file_as_dataset(path: str, limit: int | None = 50) -> DataSet:
    """Load the Starbucks reviews CSV into a ``DataSet`` for prompt optimization.

    Each CSV record becomes a ``Row``. Its inputs are the review text, reviewer
    name, location and date, and its expected output is the parsed rating.

    Args:
        path: Path to a CSV file with ``Review``, ``name``, ``location``,
            ``Date`` and ``Rating`` columns.
        limit: Maximum number of records to load. Defaults to 50, the size this
            notebook has always used. Pass ``None`` to load the whole file.

    Returns:
        A ``DataSet`` named ``"starbucks_reviews"`` with at most ``limit`` rows.

    Raises:
        KeyError: If a required column is missing from the CSV header.
        ValueError: If a ``Rating`` value is not one of the ``Ratings`` values,
            or if ``limit`` is negative.
    """
    rows = []
    # newline="" is required by the csv module so newlines embedded in quoted
    # review text are parsed correctly. utf-8-sig strips a leading BOM that
    # would otherwise corrupt the first header name.
    with open(path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        # Stop after `limit` records rather than parsing (and validating the
        # rating of) every record only to discard most of them.
        for record in itertools.islice(reader, limit):
            rating = StarbucksReviewRating(rating=Ratings(record["Rating"]))
            rows.append(
                Row.create(
                    input_dictionary={
                        "review": record["Review"],
                        "name": record["name"],
                        "location": record["location"],
                        "date": record["Date"],
                    },
                    output=rating,
                )
            )
    return DataSet(rows=rows, name="starbucks_reviews")



  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Load the labelled reviews, then configure a deliberately small Bayesian
# search (few iterations and candidates) so a full run stays cheap while testing.
csv_path = "./starbucks_reviews.csv"
dataset = load_file_as_dataset(csv_path)

# Same keys the loader writes into each row's input_dictionary.
input_keys = ["review", "name", "location", "date"]

# Candidate models, as "provider/model" identifiers.
candidate_models = [
    "anthropic/claude-3-haiku-20240307",
    "openai/gpt-4.1-nano-2025-04-14",
    "openai/gpt-4.1-mini-2025-04-14",
    "groq/gemma2-9b-it",
    "groq/qwen-qwq-32b",
]

trainer = BayesianTrainer(
    all_rows=dataset,
    task_guidance="determine the rating of this review",
    keys=input_keys,
    expected_output_type=StarbucksReviewRating,
    scoring_function=sentiment_scoring_function,
    num_iterations=2,  # kept small for testing; raise for a real search
    candidates_per_iteration=2,  # kept small for testing; raise for a real search
    prompt_mode=PromptMode.SIMPLE,
    models=candidate_models,
)

[32m2025-05-12 20:42:15.960[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36m__init__[0m:[36m170[0m - [1mUsing CPU for Pyro/Torch computations[0m
[32m2025-05-12 20:42:15.960[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.base[0m:[36m__init__[0m:[36m236[0m - [1mAll rows: 50[0m
[32m2025-05-12 20:42:15.960[0m | [34m[1mDEBUG   [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36m__init__[0m:[36m683[0m - [34m[1mBayesianTrainer initialized[0m


In [4]:
# Run the Bayesian prompt/model optimization configured in the cell above.
# Top-level `await` relies on IPython's autoawait support; it will not work in
# a plain Python script.
# NOTE(review): the saved output below is stale. This cell ran as In [4],
# before the trainer cell (In [5]), and its logs mention ModelsEnum.LLAMA_3_8B,
# which is not in the configured `models` list. Restart & Run All to refresh it.
await trainer.train()

[32m2025-05-12 20:25:44.708[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36mtrain[0m:[36m687[0m - [1mStarting Bayesian optimization with Pyro: 2 iterations, 2 candidates/iter.[0m
[32m2025-05-12 20:25:44.709[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36mtrain[0m:[36m693[0m - [1mPhase 1: Generating and evaluating initial candidates...[0m
[32m2025-05-12 20:25:44.709[0m | [1mINFO    [0m | [36mtaskllm.optimizer.prompt.meta[0m:[36mgenerate_spec[0m:[36m294[0m - [1mGenerating prompt content for: determine the rating of this review

Use plain language in the prompt you write[0m
[32m2025-05-12 20:25:44.709[0m | [1mINFO    [0m | [36mtaskllm.optimizer.prompt.meta[0m:[36mgenerate_spec[0m:[36m294[0m - [1mGenerating prompt content for: determine the rating of this review

Make the prompt you write as simple as possible[0m
[32m2025-05-12 20:25:59.605[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.bayesi


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m



[32m2025-05-12 20:27:55.657[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36mtrain[0m:[36m727[0m - [1mFitting initial surrogate model[0m
[32m2025-05-12 20:27:55.659[0m | [32m[1mSUCCESS [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36mfit_surrogate_model[0m:[36m229[0m - [32m[1mExtracted features for 2 prompts[0m
[32m2025-05-12 20:27:55.660[0m | [34m[1mDEBUG   [0m | [36mtaskllm.optimizer.methods.base[0m:[36mget_outputs[0m:[36m49[0m - [34m[1mGetting outputs for 40 rows using anthropic/claude-3-haiku-20240307[0m
[32m2025-05-12 20:27:55.661[0m | [34m[1mDEBUG   [0m | [36mtaskllm.optimizer.methods.base[0m:[36mget_outputs[0m:[36m49[0m - [34m[1mGetting outputs for 40 rows using groq/qwen-qwq-32b[0m
[32m2025-05-12 20:27:55.709[0m | [34m[1mDEBUG   [0m | [36mtaskllm.optimizer.methods.base[0m:[36mget_scores[0m:[36m66[0m - [34m[1mCalculated 40 scores[0m
[32m2025-05-12 20:27:55.710[0m | [1mINFO    [0m | [36


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m



[32m2025-05-12 20:28:42.985[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36mtrain[0m:[36m774[0m - [1mPrompt content: 1. Read the review carefully.  
2. Determine the review's tone: look for words indicating positive, ...[0m
[32m2025-05-12 20:28:42.986[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36mtrain[0m:[36m774[0m - [1mPrompt content: 1. Carefully analyze the supplied review text.  
2. Assess the predominant sentiment of the review b...[0m
[32m2025-05-12 20:28:42.987[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36mtrain[0m:[36m779[0m - [1mUpdating surrogate model with new data[0m
[32m2025-05-12 20:28:42.989[0m | [32m[1mSUCCESS [0m | [36mtaskllm.optimizer.methods.bayesian[0m:[36mfit_surrogate_model[0m:[36m229[0m - [32m[1mExtracted features for 4 prompts[0m
[32m2025-05-12 20:28:42.990[0m | [34m[1mDEBUG   [0m | [36mtaskllm.optimizer.methods.base[0m:[36mget_outputs[0m:


[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new[0m
LiteLLM.Info: If you need to debug this error, use `litellm._turn_on_debug()'.


[1;31mProvider List: https://docs.litellm.ai/docs/providers[0m



[32m2025-05-12 20:29:13.024[0m | [34m[1mDEBUG   [0m | [36mtaskllm.optimizer.methods.base[0m:[36mget_scores[0m:[36m66[0m - [34m[1mCalculated 10 scores[0m
[32m2025-05-12 20:29:13.025[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.base[0m:[36mcalculate_scores[0m:[36m91[0m - [1mModel ModelsEnum.LLAMA_3_8B with prompt achieved score: 3.0000, Correct: 3, Incorrect: 7, Unlabelled: 0 out of 10[0m
[32m2025-05-12 20:29:13.026[0m | [34m[1mDEBUG   [0m | [36mtaskllm.optimizer.methods.base[0m:[36mget_scores[0m:[36m66[0m - [34m[1mCalculated 10 scores[0m
[32m2025-05-12 20:29:13.027[0m | [1mINFO    [0m | [36mtaskllm.optimizer.methods.base[0m:[36mcalculate_scores[0m:[36m91[0m - [1mModel ModelsEnum.LLAMA_3_8B with prompt achieved score: 5.0000, Correct: 5, Incorrect: 5, Unlabelled: 0 out of 10[0m
[32m2025-05-12 20:29:13.029[0m | [34m[1mDEBUG   [0m | [36mtaskllm.optimizer.methods.base[0m:[36mget_outputs[0m:[36m49[0m - [34m[1mGetting out