In [None]:
import json
import pandas as pd
import anndata
from pathlib import Path
import logging

In [None]:
# Load the GSVA table. After re-indexing, samples are the columns (they are looked up
# per sample id further down) and the rows are keyed by the "Unnamed: 0" column --
# presumably gene-set names left over from an exported index (TODO confirm).
# The "library" column is metadata and is dropped so only sample columns remain.
gsva_gene_sets = (
    pd.read_parquet(snakemake.input.gsva)
    .set_index("Unnamed: 0")
    .drop(columns=["library"])
)

In [None]:
# Load the processed annotations (a JSON object keyed by sample id).
annotations = json.loads(Path(snakemake.input.processed_annotations).read_text())

In [None]:
# Keep only the n-th replicate of every sample's annotation.
# NOTE: this rebinds `annotations` in place, so the cell is not idempotent --
# re-running it without reloading the annotations would index into the already
# selected replicate.
replicate_idx = snakemake.params.annotation_replicate
annotations = {
    sample_id: replicates[replicate_idx]
    for sample_id, replicates in annotations.items()
}

In [None]:
# Load gene ranks: one row per sample id, cells hold gene names.
top_genes = pd.read_parquet(snakemake.input.top_genes)

# Sanity check, shown as the cell output: the first sample is a B cell cancer cell line
# and should exhibit CD19 overexpression, so "CD19" should appear in its row.
first_sample_genes = top_genes.iloc[0]
top_genes.columns[first_sample_genes == "CD19"]

In [None]:
# Ensure sample overlap: gene ranks and annotations must cover exactly the same samples,
# while GSVA only needs to be a subset of the annotated samples.
annotated_sample_ids = set(annotations.keys())
assert not (set(top_genes.index) ^ annotated_sample_ids), "annotations and gene_ranks should be the same"
assert set(gsva_gene_sets.columns) <= annotated_sample_ids, "All samples in GSVA should be present in annotations"

In [None]:
# Prompt template with `{annotation}`, `{top_gene_sets}` and `{top_genes}` placeholders.
request_template = Path(snakemake.input.request_template).read_text()


def prep_request(annotation, top_gene_sets, top_genes, **kwargs):
    """Build a single "user" chat message from the request template.

    Parameters
    ----------
    annotation
        Annotation of the sample, inserted into the template as-is.
    top_gene_sets
        Sequence of gene-set names, best first; truncated to
        ``snakemake.params.top_n_gene_sets`` entries and comma-joined.
    top_genes
        Sequence of gene names, best first; truncated to
        ``snakemake.params.top_n_genes`` entries and comma-joined.
        (Shadows the module-level ``top_genes`` frame inside this function.)
    **kwargs
        Ignored. Lets callers splat dicts that carry extra keys
        (e.g. ``sample_id`` in the few-shot prompt files).

    Returns
    -------
    dict
        ``{"role": "user", "content": <formatted template>}``.
    """
    gene_set_limit = snakemake.params.top_n_gene_sets
    gene_limit = snakemake.params.top_n_genes
    content = request_template.format(
        annotation=annotation,
        top_gene_sets=", ".join(top_gene_sets[:gene_set_limit]),
        top_genes=", ".join(top_genes[:gene_limit]),
    )
    return {"role": "user", "content": content}

In [None]:
# Prepare the few shot prompts: an alternating user/assistant message list built from
# paired prompt and response files, plus the sample ids used as examples.
few_shot_messages = []
few_shot_sample_ids = []
few_shot_pairs = zip(snakemake.input.few_shot_prompts, snakemake.input.few_shot_responses)
for prompt_file, response_file in few_shot_pairs:
    example = json.loads(Path(prompt_file).read_text())
    few_shot_messages.append(prep_request(**example))

    # Round-trip the response through json so it is passed as a compact string,
    # without the newlines and indentation of the file on disk.
    compact_response = json.dumps(json.loads(Path(response_file).read_text()))
    few_shot_messages.append({"role": "assistant", "content": compact_response})

    few_shot_sample_ids.append(example["sample_id"])

# Persist the whole few-shot conversation block as JSON.
with open(snakemake.output.few_shot_block, "w") as out_file:
    json.dump(few_shot_messages, out_file)

In [None]:
def extract_sample_data(sample_id):
    """Collect the inputs `prep_request` needs for one sample.

    Returns a dict with:
      * "annotation"    -- the selected annotation of the sample,
      * "top_genes"     -- the sample's row of `top_genes` with missing entries dropped,
      * "top_gene_sets" -- all gene-set names, ordered by the sample's GSVA value
                           from highest to lowest.

    Raises KeyError if `sample_id` is missing from any of the three sources.
    """
    sample_annotation = annotations[sample_id]
    sample_genes = top_genes.loc[sample_id].dropna()
    ranked_gene_sets = gsva_gene_sets[sample_id].sort_values(ascending=False)
    return {
        "annotation": sample_annotation,
        "top_genes": sample_genes.to_list(),
        "top_gene_sets": ranked_gene_sets.index.to_list(),
    }

In [None]:
# Samples to build requests for: every GSVA sample that is not used as a few-shot
# example, in GSVA column order, skipping the first `start_from_num` of them.
# The order must be stable: the round-robin split assignment below depends on each
# sample's position, and an unordered set (string hashing is randomized per process)
# would shuffle samples between splits from one run to the next.
# (A previous follow-up line rebuilt this list from a set difference, which discarded
# both the ordering and `start_from_num`; it has been removed.)
excluded_sample_ids = set(few_shot_sample_ids)
target_sample_ids = [
    s for s in gsva_gene_sets.columns if s not in excluded_sample_ids
][snakemake.params.start_from_num:]  # preserve order

for split_fn in snakemake.output.request_splits:
    # Output files are named "<i>-of-<n>.<ext>" with a 1-based split index.
    split_i, split_n = map(int, Path(split_fn).stem.split('-of-'))
    split_i -= 1  # 0-indexing
    # take the i-th split: every split_n-th sample, starting at offset split_i
    split_requests = {
        sample_id: prep_request(**extract_sample_data(sample_id))
        for i, sample_id in enumerate(target_sample_ids)
        if i % split_n == split_i
    }

    # write the split to a file
    with open(split_fn, "w") as f:
        json.dump(split_requests, f)