In [None]:
import anndata
import pandas as pd
from pathlib import Path
import json
import shortuuid

In [None]:
# Top genes
# Read each input parquet file into its own frame, then stack them row-wise.
top_gene_frames = []
for parquet_path in snakemake.input.top_genes:
    top_gene_frames.append(pd.read_parquet(parquet_path))
top_genes = pd.concat(top_gene_frames)

In [None]:
# Annotations
# Collect one annotation column per input dataset and concatenate them into a
# single pandas object indexed like the datasets' observations.
annots_per_dataset = []
for fn in snakemake.input.full_data:  # ["archs4_geo", "cellxgene_census"]
    # Backed mode avoids loading the expression matrix; only the annotation
    # column is needed here.
    adata = anndata.read_h5ad(fn, backed="r")
    try:
        try:
            # Prefer replicate "1" of the replicated annotations when present ...
            annots = adata.obsm["natural_language_annotation_replicates"]["1"]
        except KeyError:
            # ... otherwise fall back to the single annotation column in obs.
            annots = adata.obs["natural_language_annotation"]
    finally:
        # Release the on-disk file handle opened by backed mode. The extracted
        # annotations are expected to live in memory (obs/obsm are not backed)
        # -- NOTE(review): confirm against the anndata version in use.
        adata.file.close()
    annots_per_dataset.append(annots)
annots_combined = pd.concat(annots_per_dataset)

In [None]:
# Build the reference text for every sample present in `top_genes`: the
# request template is filled with the sample's annotation and a
# comma-separated list of its first `top_n_genes` entries from `top_genes`.
request_template = snakemake.params.request_template
n_top = snakemake.params.top_n_genes

reference_information = {}
for sample_id in top_genes.index:
    annotation = annots_combined.loc[sample_id]
    gene_names = top_genes.loc[sample_id].iloc[:n_top].values
    reference_information[sample_id] = request_template.format(
        annotation=annotation,
        top_genes=", ".join(gene_names),
    )

In [None]:

# Load the evaluation dataset. JSON interchange files are UTF-8, so decode
# explicitly instead of depending on the platform's locale encoding.
with open(snakemake.input.evaluation_dataset, encoding="utf-8") as f:
    data = json.load(f)

In [None]:
# NOTE(review): never filled in the visible cells; kept in case a later cell reads it.
structured_questions = []


def _write_jsonl(record, handle):
    """Write ``record`` to ``handle`` as one JSON document followed by a newline."""
    json.dump(record, handle)
    handle.write("\n")


# These parameters do not change between records, so read them once.
n_top_genes = snakemake.params.top_n_genes
instruction_prompt = snakemake.params.instruction_prompt
instruction_response = snakemake.params.instruction_response
instruction_prompt_text_only = snakemake.params.instruction_prompt_text_only

# Four JSON-lines outputs are written in lockstep, one line per dataset entry:
#   qf           - the plain question
#   qf_topgenes  - [instruction prompt, instruction response with top genes, question]
#   qf_textonly  - a single text-only prompt embedding the top genes and the question
#   rf           - the reference answer taken from the dataset
with open(snakemake.output.formatted_questions, "w") as qf, open(
    snakemake.output.formatted_questions_with_top_genes, "w"
) as qf_topgenes, open(
    snakemake.output.formatted_questions_text_only, "w"
) as qf_textonly, open(
    snakemake.output.reference_responses, "w"
) as rf:
    for position, d in enumerate(data, start=1):
        image_id = d["image"]
        # 1-based position in the dataset, joined with the entry's own id.
        question_id = f"{position}_{d['id']}"
        question = (
            d["conversations"][0]["value"].replace("<image>", "").strip("\n")
        )  # stripping to adhere to llava's codebase
        reference_info_i = reference_information[image_id]
        top_genes_i = ", ".join(top_genes.loc[image_id].iloc[:n_top_genes].values)

        # The three question files share every field except "text".
        question_variants = (
            (qf, question),
            (
                qf_topgenes,
                [
                    instruction_prompt,
                    instruction_response.format(top_genes=top_genes_i),
                    question,
                ],
            ),
            (
                qf_textonly,
                instruction_prompt_text_only.format(
                    top_genes=top_genes_i, question=question
                ),
            ),
        )
        for handle, text in question_variants:
            _write_jsonl(
                {
                    "question_id": question_id,
                    "reference": reference_info_i,
                    "text": text,
                    "image": image_id,
                },
                handle,
            )

        # Reference answer: the second conversation turn, tagged with a fresh
        # random answer id.
        _write_jsonl(
            {
                "question_id": question_id,
                "text": d["conversations"][1]["value"],
                "answer_id": shortuuid.uuid(),
                "model_id": "gpt-4_with_input_text_and_curation",
                "metadata": {},
            },
            rf,
        )