In [None]:
import openai
import tqdm
import json
import shortuuid


# Build the API client for the configured model. Model names containing
# "claude" go through the Anthropic SDK; everything else goes through the
# OpenAI SDK, optionally pointed at another endpoint via `api_base_url`.
if "claude" in snakemake.params.model.lower():
    # Imported lazily: `anthropic` was referenced but never imported, so this
    # branch always raised NameError. Importing here keeps the Anthropic SDK an
    # optional dependency that is only needed when a Claude model is selected.
    import anthropic

    client = anthropic.Anthropic(api_key=snakemake.params.api_key, max_retries=5)
else:
    client = openai.OpenAI(
        api_key=snakemake.params.api_key,
        base_url=snakemake.params.api_base_url,
    )


def gpt4_response(question: str, max_tokens: int = 1024, max_attempts: int = 3, **kwargs):
    """Send a single-turn prompt to the configured model and return its reply text.

    Model names containing "claude" are called through the Anthropic Messages
    API (``client.messages.create``). Every other model is called through the
    OpenAI chat-completions API. Sampling uses temperature 0.

    Args:
        question: User prompt, sent as the only message.
        max_tokens: Upper bound on the number of generated tokens.
        max_attempts: How many times to try the request before giving up.
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        The model's reply text.

    Raises:
        RuntimeError: If every attempt fails. The last underlying error is
            chained as ``__cause__`` so the real failure stays visible.
    """
    messages = [{"role": "user", "content": question}]
    # Same routing rule used when the client was constructed.
    use_anthropic = "claude" in snakemake.params.model.lower()
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            if use_anthropic:
                # Anthropic clients have no `.chat`; they use the Messages API,
                # which returns a list of content blocks instead of choices.
                response = client.messages.create(
                    model=snakemake.params.model,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=0.0,
                )
                return "".join(
                    block.text for block in response.content if block.type == "text"
                )
            response = client.chat.completions.create(
                model=snakemake.params.model,
                messages=messages,
                max_tokens=max_tokens,
                temperature=0.0,
            )
            return response.choices[0].message.content
        except Exception as e:  # broad on purpose: any API/network error is retried
            last_error = e
            print(f"attempt {attempt}/{max_attempts} failed: {e}")
    raise RuntimeError(
        f"model request failed after {max_attempts} attempts"
    ) from last_error

In [None]:
# Load the question set: one JSON object per line (JSONL). Lines that are
# empty or whitespace-only (e.g. a trailing blank line) are skipped instead of
# crashing json.loads. Read as UTF-8 explicitly rather than the platform default.
with open(snakemake.input.transcriptome_text_features, encoding="utf-8") as f:
    question_informations = [json.loads(line) for line in f if line.strip()]

In [None]:
# Peek at the first parsed record to sanity-check the input; later cells read
# its "text" and "question_id" keys.
question_informations[0]

In [None]:
# One model call per question, in input order, with a tqdm progress bar.
responses = list(
    gpt4_response(record["text"])
    for record in tqdm.tqdm(question_informations)
)

In [None]:
# Spot-check the first model answer before the results are written to disk.
responses[0]

In [None]:
# Write one JSON object per line (JSONL), pairing each answer with the question
# it was generated for and tagging it with a fresh shortuuid answer id.
# NOTE: the loop variables keep the names `response`/`info` on purpose; a later
# cell may read the leaked `response` value.
with open(snakemake.output.response, "w") as f:
    for response, info in zip(responses, question_informations):
        record = {
            "question_id": info["question_id"],
            "text": response,
            "answer_id": shortuuid.uuid(),
            "model_id": snakemake.params.model,
            "metadata": {},
        }
        f.write(json.dumps(record) + "\n")

In [None]:
# Show the last generated answer. Index `responses` explicitly instead of
# relying on the `response` loop variable leaked from the write cell, which
# only exists if that cell ran (hidden notebook state).
responses[-1]