In [None]:
# Map the training dataset's cell-type labels onto the evaluation dataset's labels using an OpenAI chat model.

import openai
import anndata
import pandas as pd

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset

In [None]:
# Open the training data with `backed="r"` and collect the categories of its
# "cell_type" annotation; these are the labels that get mapped below.
train_adata = anndata.read_h5ad(
    snakemake.input.training_data,
    backed="r",
)
training_cell_types = train_adata.obs["cell_type"].cat.categories

# Load the evaluation dataset selected by the workflow's `dataset` wildcard.
eval_adata = load_and_preprocess_dataset(
    read_count_table_path=snakemake.input.eval_data,
    dataset_name=snakemake.wildcards.dataset,
)

In [None]:

# Prompt template for the first pass. The remaining "{}" placeholder is filled with
# one training cell type per request via `prompt.format(...)`; the candidate list is
# the set of distinct evaluation cell types.
eval_candidates = ", ".join(eval_adata.obs.celltype.drop_duplicates().values)
prompt = (
    f"Assign the cell type '{{}}' to one of the following candidates: {eval_candidates}.\n\n"
    " If there is no well-matching cell type present, assign none instead."
    " Only print the name of a single cell type (or none if none of the candidates match), nothing else."
)
prompt

In [None]:
# OpenAI client authenticated with the key passed in through the workflow params.
client = openai.OpenAI(api_key=snakemake.params.openai_api_key)

In [None]:
# First pass: map every training cell type onto one evaluation cell type, or "none".
# `predictions` ends up aligned element-wise with `training_cell_types`.
predictions = []

# The set of valid answers does not change inside the loop, so build it once
# instead of recomputing the distinct evaluation labels on every iteration.
eval_cell_types = set(eval_adata.obs.celltype)

for training_cell_type in training_cell_types:
    # Labels that already exist verbatim in the evaluation data need no model call.
    if training_cell_type in eval_cell_types:
        print(
            f"Skipping {training_cell_type} as it is already present in the evaluation dataset"
        )
        predictions.append(training_cell_type)
        continue

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt.format(training_cell_type),
            }
        ],
        model=snakemake.params.model,
        temperature=0.0,
    )

    # The reply can be None and frequently carries stray whitespace or a trailing
    # newline; normalise it so an otherwise correct answer is not rejected by the
    # exact-match comparison below.
    raw_reply = chat_completion.choices[0].message.content
    match = raw_reply.strip() if raw_reply else ""

    if match not in eval_cell_types:
        print(
            f"Match for {training_cell_type} was not in the candidates ({raw_reply}). Set to 'none'"
        )
        match = "none"
    else:
        print(f"Match for {training_cell_type} was {match}")

    predictions.append(match)

In [None]:
# One row per training cell type, paired with its first-pass evaluation label.
df = pd.DataFrame(
    dict(
        training_cell_type=training_cell_types,
        evaluation_cell_type=predictions,
    )
)

In [None]:
# Evaluation cell types that no training cell type was mapped to in the first pass.
missing_eval_celltypes = set(eval_adata.obs.celltype).difference(
    df.evaluation_cell_type
)

In [None]:
# Work on an independent copy so the first-pass table `df` stays untouched.
corrected_df = df.copy(deep=True)
corrected_df.shape

In [None]:
# Evaluation cell types that are still unassigned in the table about to be corrected.
assigned_eval_celltypes = set(corrected_df.evaluation_cell_type)
missing_eval_celltypes = set(eval_adata.obs.celltype) - assigned_eval_celltypes

In [None]:
# Second pass: for every evaluation cell type that received no training cell type,
# ask the model, in the reverse direction, which training cell type should map to it,
# and overwrite that row of `corrected_df`.

# Sampling temperatures tried in order; the next one is used only when the previous
# reply was not one of the candidates.
retry_temperatures = [0.0, 0.5, 0.8, 0.8, 0.8, 0.8, 1.1, 1.4]

# A set has no stable iteration order across interpreter runs, and each reassignment
# changes the candidates available to later iterations, so iterate in sorted order
# to keep the result reproducible.
for eval_celltype in sorted(missing_eval_celltypes, key=str):
    # Evaluation cell types currently backed by exactly one training cell type.
    # "none" is a placeholder for "unassigned", not an evaluation cell type, so a
    # lone "none" row must not be treated as protected.
    eval_type_counts = corrected_df["evaluation_cell_type"].value_counts()
    single_eval_types = eval_type_counts[
        (eval_type_counts == 1) & (eval_type_counts.index != "none")
    ].index

    # Training cell types that are the only match of some evaluation cell type;
    # reassigning them would leave that evaluation cell type without any match.
    protected_training_types = corrected_df.loc[
        corrected_df["evaluation_cell_type"].isin(single_eval_types),
        "training_cell_type",
    ].values
    possible_celltypes = training_cell_types.difference(protected_training_types)

    # Without candidates no reply can ever be accepted; skip the API calls.
    if len(possible_celltypes) == 0:
        print(f"no hope for {eval_celltype}")
        continue

    # Use a dedicated name so the first-pass `prompt` template is not clobbered,
    # which would break a re-run of the first-pass cell.
    correction_prompt = (
        f"Assign the query cell type '{eval_celltype}' to one of the following candidates: "
        f"{', '.join(possible_celltypes)}.\n\n"
        " Only print the name of a single cell type, nothing else, and don't just repeat the query cell type."
        " Make sure to return one of the candidates"
    )

    for temperature in retry_temperatures:
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": correction_prompt,
                }
            ],
            model=snakemake.params.model,
            temperature=temperature,
        )

        # The reply can be None or padded with whitespace; normalise before matching.
        raw_reply = chat_completion.choices[0].message.content
        match = raw_reply.strip() if raw_reply else ""

        if match in possible_celltypes:
            print(f"Match for {eval_celltype} was {match}")
            corrected_df.loc[
                corrected_df["training_cell_type"] == match, "evaluation_cell_type"
            ] = eval_celltype
            break

        print(
            f"Reply for {eval_celltype} at temperature {temperature} was not in the candidates ({raw_reply})"
        )
    else:
        # Every temperature was tried without getting a valid candidate back.
        print(f"no hope for {eval_celltype}")

In [None]:
# Show the evaluation labels that were actually assigned. The mask drops the rows
# whose `evaluation_cell_type` is the "none" placeholder; it is this column, not
# `training_cell_type`, that can hold "none".
assigned_eval_labels = corrected_df.evaluation_cell_type
assigned_eval_labels[assigned_eval_labels != "none"]

In [None]:
# Persist the corrected mapping table to the workflow's output path, without the row index.
transferred_labels_path = snakemake.output.transfered_labels
corrected_df.to_csv(transferred_labels_path, index=False)