In [None]:

import json
from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset
from pathlib import Path

# Take the dataset wildcard and the dataset input path from the Snakemake
# rule and hand both to load_and_preprocess_dataset; the result is kept as
# `adata` for the cells below.
dataset_name = snakemake.wildcards.dataset
dataset_path = snakemake.input.dataset
adata = load_and_preprocess_dataset(dataset_name, dataset_path)

In [None]:
# Display the distinct values found in the "celltype" column of adata.obs.
adata.obs["celltype"].unique()

In [None]:
# Generate the 'true' conversations: one question/answer record per cell,
# answered with that cell's "celltype" value.
def row_to_conversation(row):
    """Build the conversation record for one row.

    The row's label (``row.name``) is used for both the ``"id"`` and the
    ``"image"`` field. The human turn is the configured question followed by
    a newline and the ``<image>`` placeholder; the gpt turn is the configured
    response prefix followed by the row's ``"celltype"`` value.

    Reads ``snakemake.params.question`` and
    ``snakemake.params.response_prefix`` from the enclosing namespace.
    """
    cell_id = row.name
    params = snakemake.params

    human_turn = {
        "from": "human",
        "value": "{}\n<image>".format(params.question),
    }
    gpt_turn = {
        "from": "gpt",
        "value": params.response_prefix + row["celltype"],
    }

    return {
        "id": cell_id,
        "image": cell_id,
        "conversations": [human_turn, gpt_turn],
    }


# Shuffle every row of adata.obs; the fixed seed makes the order repeatable.
shuffled_obs = adata.obs.sample(frac=1, random_state=42)

# Optional cap: when the parameter is set, keep only the first N shuffled
# rows of each celltype.
cells_per_type = snakemake.params.num_cells_per_celltype
if cells_per_type:
    shuffled_obs = shuffled_obs.groupby("celltype").head(cells_per_type)

# One conversation dict per remaining row, collected into a plain list.
conversation_series = shuffled_obs.apply(row_to_conversation, axis=1)
conversations = conversation_series.values.tolist()

# Preview a single record.
conversations[5]

In [None]:


# Write the conversation list as JSON to the rule's `_default` output.
default_output = snakemake.output._default
with open(default_output, "w") as out_file:
    json.dump(conversations, fp=out_file)

In [None]:

# For the _celltype output, the celltype in each response is lowercased.

# Start again from the full obs table, shuffled with the same fixed seed (42)
# that the default conversations use.
conversations = adata.obs.sample(frac=1, random_state=42)


def row_to_conversation_celltype(row):
    """Build the conversation record for one row, with a lowercased celltype.

    The row's label (``row.name``) is used for both the ``"id"`` and the
    ``"image"`` field. The human turn is the configured question followed by
    a newline and the ``<image>`` placeholder; the gpt turn is the configured
    response prefix followed by the row's ``"celltype"`` value converted to
    lowercase (the prefix itself is left untouched).

    Reads ``snakemake.params.question`` and
    ``snakemake.params.response_prefix`` from the enclosing namespace.
    """
    cell_id = row.name
    params = snakemake.params

    prompt_text = "{}\n<image>".format(params.question)
    answer_text = params.response_prefix + row["celltype"].lower()

    turns = [
        {"from": "human", "value": prompt_text},
        {"from": "gpt", "value": answer_text},
    ]
    return {"id": cell_id, "image": cell_id, "conversations": turns}


# Optional cap: when the parameter is set, keep only the first N shuffled
# rows of each celltype. The grouping key is the literal "celltype" column,
# the same column row_to_conversation_celltype reads; no `dataset_processor`
# object exists in this notebook, so it must not be referenced here.
if snakemake.params.num_cells_per_celltype:
    conversations = conversations.groupby("celltype").head(
        snakemake.params.num_cells_per_celltype
    )

# One conversation dict (lowercased celltype) per remaining row, as a list.
conversations = conversations.apply(
    row_to_conversation_celltype, axis=1
).values.tolist()

# Write the list as JSON to the rule's `_celltype` output.
with open(snakemake.output._celltype, "w") as f:
    json.dump(conversations, f)

# Preview the first record.
conversations[0]

In [None]:

# For _top50genescelltype, we reuse the conversations, but delete their "image" fields

# Remove the "image" field from every conversation in place. Using pop with
# a default (instead of del) keeps this cell safe to re-run once the field
# is already gone.
for conversation in conversations:
    conversation.pop("image", None)

# Verify every record, not only the first; this also holds for an empty list.
assert all("image" not in conversation for conversation in conversations)

# Write the image-free list as JSON to the rule's `_top50genescelltype` output.
with open(snakemake.output._top50genescelltype, "w") as f:
    json.dump(conversations, f)

# Preview the first record.
conversations[0]