In [None]:
# start coding here
from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset
import pandas as pd
import logging
import openai

# Create a logger that appends to the log file Snakemake assigned to this rule.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)  # Set the logging level
file_handler = logging.FileHandler(snakemake.log.log, mode="a")  # Append mode
# Re-running this cell in a live kernel would otherwise stack another handler
# for the same file on the same logger and duplicate every message, so any
# handler already writing to that file is detached and closed first.
for stale_handler in list(logger.handlers):
    if getattr(stale_handler, "baseFilename", None) == file_handler.baseFilename:
        logger.removeHandler(stale_handler)
        stale_handler.close()
logger.addHandler(file_handler)

In [None]:
# Read the human CellMarker table and pluralise every cell name by appending "s".
marker_df = pd.read_excel(snakemake.input.cellmarker2_human)
marker_df["cell_name"] = marker_df["cell_name"].map(lambda name: name + "s")
marker_df.head()

In [None]:
# Write the distinct tissue types (one per line) to the log file.
distinct_tissue_types = marker_df["tissue_type"].unique()
logger.info("\n".join(distinct_tissue_types))

In [None]:
# Display the distinct tissue classes present in the marker table.
marker_df["tissue_class"].drop_duplicates()

In [None]:
# Restrict the marker table to the rows relevant for the dataset being
# evaluated (selected through the Snakemake `dataset` wildcard).
# Every filter keeps only entries whose cell_type is "Normal cell".
is_normal_cell = marker_df.cell_type == "Normal cell"
pmid_as_int = marker_df.PMID.fillna(-1).astype(int)  # missing PMIDs become -1

# PMIDs of the evaluation datasets, filtered out to avoid 'overfitting'.
pancreas_eval_pmids = [27345837, 27693023, 27864352, 27667365, 27667665, 27667667]
# PMID excluded by the tabula_sapiens filters (there are actually no markers
# in the CellMarker database from this dataset...).
tabula_sapiens_eval_pmids = [35549404]
not_tabula_sapiens_publication = ~pmid_as_int.isin(tabula_sapiens_eval_pmids)

tabula_sapiens_tissue_classes = [
    "Bladder",
    "Blood",
    "Bone marrow",
    "Eye",
    "Heart",
    "Kidney",
    "Liver",
    "Lung",
    "Lymph node",
    "Mammary gland",
    "Muscle",
    "Ovary",
    "Pancreas",
    "Prostate",
    "Salivary gland",
    "Skin",
    "Spleen",
    "Stomach",
    "Testis",
    "Thymus",
    "Tongue",
    "Trachea",
    "Uterus",
]
well_studied_tissue_classes = ["Blood", "Liver", "Lung"]

filters = {
    "immgen": (marker_df.tissue_type == "Blood") & is_normal_cell,
    "pancreas": (
        (marker_df.tissue_type == "Pancreas")
        & is_normal_cell
        & ~pmid_as_int.isin(pancreas_eval_pmids)
    ),
    "tabula_sapiens": (
        marker_df.tissue_class.isin(tabula_sapiens_tissue_classes)
        & is_normal_cell
        & not_tabula_sapiens_publication
    ),
    "tabula_sapiens_well_studied_celltypes": (
        marker_df.tissue_class.isin(well_studied_tissue_classes)
        & is_normal_cell
        & not_tabula_sapiens_publication
    ),
    "aida": (marker_df.tissue_class == "Blood") & is_normal_cell,
}

# Apply the mask belonging to the current dataset and preview the remaining names.
marker_df = marker_df.loc[filters[snakemake.wildcards.dataset]]
marker_df.cell_name.head()

In [None]:
# Load the evaluation data for the current dataset through the project helper.
# The result is expected to expose `.obs.celltype`, which a later cell iterates over.
eval_adata = load_and_preprocess_dataset(
    read_count_table_path=snakemake.input.eval_data,
    dataset_name=snakemake.wildcards.dataset,
)

In [None]:
# OpenAI client; the API key is supplied through the Snakemake params.
client = openai.OpenAI(api_key=snakemake.params.openai_api_key)

# Prompt template: the '{}' placeholder is filled in later (via str.format) with
# the evaluation cell type, while the candidate list is fixed here.
candidate_list_text = ", ".join(marker_df.cell_name.drop_duplicates().values)
prompt = (
    "Assign the cell type '{}' to one of the following candidates: "
    + candidate_list_text
    + ".\n\n Only print the name of a single cell type, nothing else."
)

In [None]:
# Map every cell type of the evaluation dataset to one cell name of the marker
# table. Exact name matches are taken as-is; all others are resolved by asking
# the chat model, and answers that are not among the candidates become "none".
predictions = {}

# Loop-invariant set of candidate names (hoisted out of the loop; set lookup
# replaces the repeated drop_duplicates() + array scan).
known_cell_names = set(marker_df.cell_name.drop_duplicates().values)

for eval_cell_type in eval_adata.obs.celltype.drop_duplicates().values:
    if eval_cell_type in known_cell_names:
        logger.info(
            f"Skipping {eval_cell_type} as it is already present in the marker dataset"
        )
        predictions[eval_cell_type] = eval_cell_type
        continue

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt.format(eval_cell_type),
            }
        ],
        model=snakemake.params.model,
        temperature=0.0,
    )
    # The reply text may be None, and models frequently pad their answer with
    # whitespace or a trailing newline; normalise before comparing so that an
    # otherwise valid answer is not discarded as "none".
    match = (chat_completion.choices[0].message.content or "").strip()
    if match not in known_cell_names:
        logger.info(
            f"Match for {eval_cell_type} was not in the candidates ({match}). Set to 'none'"
        )
        match = "none"
    else:
        logger.info(f"Match for {eval_cell_type} was {match}")

    predictions[eval_cell_type] = match
    file_handler.flush()

# %%%
# Collect, for every evaluation cell type, the marker rows of its matched cell
# name. A "none" match selects no rows and therefore contributes nothing.
per_celltype_frames = []
for eval_celltype, marker_celltype in predictions.items():
    # .assign returns a new frame, so the column is not written into a slice
    # of marker_df (avoids SettingWithCopyWarning / chained assignment).
    matched_rows = marker_df.loc[marker_df.cell_name == marker_celltype].assign(
        eval_cell_type=eval_celltype
    )
    per_celltype_frames.append(matched_rows)

filtered_marker_df = pd.concat(per_celltype_frames)

In [None]:
# Display the mapping from evaluation cell type to the matched marker cell name
# ("none" where the model's answer was not among the candidates).
predictions

In [None]:
# Build a binary marker x evaluation-cell-type matrix: 1 if the marker is listed
# at least once for the cell type matched to that evaluation cell type, else 0.
marker_counts = filtered_marker_df.pivot_table(
    index="marker", columns="eval_cell_type", aggfunc="size", fill_value=0
)
# Vectorised binarisation; replaces DataFrame.applymap, which is deprecated
# since pandas 2.1 and evaluates a Python lambda per element.
assigment_matrix = marker_counts.gt(0).astype("int64")
assigment_matrix

In [None]:
# Persist the binary assignment matrix as CSV at the rule's output path.
assigment_matrix.to_csv(path_or_buf=snakemake.output.prepared_markers)