In [None]:
# start coding here
import scanpy as sc
import numpy as np
import pickle
import pandas as pd
from tqdm.notebook import tqdm
import random

import openai
import logging
import pydantic

In [None]:
# Route log records to two places at once: the progress log file declared by
# the Snakemake rule, and the default stream handler (stderr).
log_handlers = [
    logging.FileHandler(snakemake.log.progress),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=log_handlers,
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

In [None]:
# Response schema handed to the OpenAI client as `response_format`: the reply
# is parsed into an object carrying a single `variants` list of strings.
# NOTE(review): documented with comments rather than a class docstring on
# purpose -- pydantic includes a model's docstring in its generated JSON
# schema, which would presumably change the request; confirm before adding one.
class VariantList(pydantic.BaseModel):
    variants: list[str]


# Model name and API key are both supplied through the Snakemake rule's params.
model = snakemake.params.model
client = openai.OpenAI(api_key=snakemake.params.api_key)

In [None]:
# The queries are taken directly from the Snakemake rule's `params` instead of
# being read from an input file; the commented-out line below is the
# file-based alternative.
# queries = pd.read_csv(snakemake.input.queries)
queries = snakemake.params.queries  # iterated one query at a time further down

In [None]:

# Number of variants the system prompt instructs the model to return per query
# (the count itself is hardcoded in the system_prompt).
EXPECTED_VARIANT_COUNT = 5

# One record per query: {"query": <query>, "variant": [<variant>, ...]}.
res = []

for query in queries:
    # Ask the model for variants of this query, parsed into `VariantList`.
    completion = client.beta.chat.completions.parse(
        model=model,  # bound from snakemake.params.model in the setup cell
        messages=[
            {"role": "system", "content": snakemake.params.system_prompt},
            {"role": "user", "content": query},
        ],
        response_format=VariantList,
        # temperature=snakemake.params.temperature,
        seed=snakemake.params.seed,
    )
    message = completion.choices[0].message
    if message.parsed is None:
        # The client leaves `parsed` unset when the model refuses to answer.
        # Record the failure and keep going so one bad query does not abort
        # the whole run; the empty list becomes a NaN `variant` row below.
        logging.error(
            f"No parsed variants for query: {query}, refusal: {message.refusal}"
        )
        variants = []
    else:
        variants = message.parsed.variants
    if len(variants) != EXPECTED_VARIANT_COUNT:
        logging.error(f"Generated wrong number of variants: {len(variants)}")
    res.append(
        {
            "query": query,
            "variant": variants,
        }
    )
    logging.info(f"Query: {query}, Variants: {variants}")

# Explicit columns keep the frame well-formed when `queries` is empty; without
# them `explode("variant")` fails because the column does not exist.
# `explode` turns each list of variants into one row per variant, repeating
# the query (and its index label) on every row.
df = pd.DataFrame(res, columns=["query", "variant"]).explode("variant")
df

In [None]:

# Write the one-row-per-variant table to the output file declared by the rule.
output_path = snakemake.output.query_variants
df.to_csv(output_path)