In [None]:
import torch
# Limit torch's CPU threading to the thread count exposed by the Snakemake job.
# NOTE(review): `snakemake` is never imported in this notebook — presumably it is
# injected by Snakemake's notebook integration; confirm before running standalone.
torch.set_num_threads(snakemake.threads)

import json
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt 
import anndata
import re
import logging

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava import conversation as conversation_lib



from llava.train.train import LazySupervisedDataset, DataArguments, DataCollatorForSupervisedDataset
from transformers import Trainer, EvalPrediction
import transformers
from torch.nn import CrossEntropyLoss
import os



In [None]:
# Path of the LLaVA model to load, taken from the rule's `llava_model` input.
# Example layout: ".../results/llava/finetuned/Mistral-7B-Instruct-v0.2__03jujd8s/"
model_dir = snakemake.input.llava_model

In [None]:
# Echo the resolved model path into the cell output.
model_dir

In [None]:
# Load the model

# Derive the model name from the model directory path.
model_name=get_model_name_from_path(model_dir)
# Guard: the name must contain "mistral" (case-insensitive) and a "__" separator.
# Per the assertion message, LLaVA relies on "mistral" being part of the name.
assert "mistral" in model_name.lower() and "__" in model_name, "sure that you are not using a mistral model? LLaVA depends on having it in the name (if it is mistral)"

logging.info(f"Loading model {model_name}")
# Use the GPU when one is available, otherwise fall back to CPU.
# The returned tokenizer and model are used as globals by cluster_annotation.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_dir, model_base=None, model_name=model_name, device="cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Model loaded")


In [None]:
# Read the AnnData file given as the rule's `embedding_adata` input.
# Its `X` matrix is what cluster_annotation averages per cluster.
logging.info(f"Loading data")
adata = anndata.read_h5ad(snakemake.input.embedding_adata)
logging.info(f"Data loaded")

In [None]:
# Open the read-count AnnData in read-only backed mode (backed="r").
# Only its `obs` columns and `uns["cluster_fields"]` are used in this notebook.
logging.info(f"Loading read count data")
read_count_adata = anndata.read_h5ad(snakemake.input.read_count_adata, backed="r")
logging.info(f"Read count data loaded")

In [None]:
# Both objects must list the same observations in the same order: cluster columns
# taken from read_count_adata.obs are later applied positionally to adata.obs.
assert (adata.obs.orig_ids == read_count_adata.obs.index).all()

In [None]:
# Positional row number of every observation (0..n-1); cluster_annotation uses
# this column to index rows of adata.X for each cluster.
adata.obs["index_int"] = list(range(len(adata.obs)))


In [None]:
# Sanity check: number of observations assigned to leiden cluster "0".
(adata.obs.leiden == "0").sum()

In [None]:
# Set the "mistral_instruct" template as the module-level default conversation in
# llava.conversation.
# NOTE(review): nothing else in this notebook reads default_conversation directly;
# presumably llava helpers consult it internally — confirm it is still needed.
conversation_lib.default_conversation = conversation_lib.conv_templates["mistral_instruct"]



In [None]:
def cluster_annotation(cluster_values: np.ndarray) -> pd.Series:
    """Generate one free-text annotation per cluster with the loaded LLaVA model.

    The rows of ``adata.X`` are averaged within each cluster, and each mean
    vector is passed to ``model.generate`` through the ``images`` argument
    together with the prompt from ``snakemake.params["request"]``.

    Relies on notebook globals: ``adata`` (including the positional
    ``index_int`` column in ``adata.obs``), ``model``, ``tokenizer`` and
    ``snakemake``.

    Parameters
    ----------
    cluster_values
        Array-like of cluster labels, one per observation, in the same order
        as ``adata.obs`` (e.g. the ``.values`` of an obs column).

    Returns
    -------
    pd.Series
        Generated annotation text, indexed by cluster id.
    """
    # The prompt is the same for every cluster, so build it exactly once.
    # Rebuilding it inside the loop would re-wrap the image token in
    # start/end tokens on every iteration when `mm_use_im_start_end` is set.
    prompt = snakemake.params["request"]
    # Check for the actual image token: a request that merely mentions the
    # word "image" still needs the placeholder appended.
    if DEFAULT_IMAGE_TOKEN not in prompt:
        prompt = prompt + "\n" + DEFAULT_IMAGE_TOKEN
    # Should be a noop, but kept for compatibility
    if getattr(model.config, 'mm_use_im_start_end', False):
        wrapped_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, wrapped_token)

    # Mean of adata.X over the rows of each cluster (one vector per cluster id).
    grouped_embeddings = adata.obs.groupby(cluster_values, observed=True).apply(
        lambda group: adata.X[group.index_int].mean(axis=0),
        include_groups=False,
    )

    cluster_labels = {}
    for cluster_id, grouped_embedding in grouped_embeddings.items():
        # NOTE(review): bfloat16 is assumed to match the dtype the model was
        # loaded with — confirm against load_pretrained_model.
        images = torch.tensor(grouped_embedding, device=model.device, dtype=torch.bfloat16)
        image_args = {"images": images, "image_sizes": None}

        logging.info(f"Generating input_ids for cluster {cluster_id}")
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        logging.info(f"Generated input_ids cluster {cluster_id}")

        logging.info(f"Generating for cluster {cluster_id}")
        generated_tokens = model.generate(
            inputs=input_ids,
            do_sample=False,
            temperature=0.0,
            num_beams=snakemake.params.num_beams,
            top_p=1.0,
            max_new_tokens=256,
            use_cache=True,
            pad_token_id=tokenizer.eos_token_id,  # explicitly request open-ended generation (suppresses warnings)
            **image_args
        )
        logging.info(f"Generated for cluster {cluster_id}")

        # `skip_prompt` and `timeout` are TextIteratorStreamer options, not
        # tokenizer.decode options, so they are no longer passed here.
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        print(generated_text[:80] + "...")
        cluster_labels[cluster_id] = generated_text

    return pd.Series(cluster_labels)

In [None]:
# Collect the clusterings to annotate: always the "leiden" clustering from the
# embedding AnnData, plus any obs columns listed in
# read_count_adata.uns["cluster_fields"].
dfs = []
cluster_series = {"leiden": adata.obs["leiden"].values}

# "cluster_fields" is optional; a missing key just means no extra clusterings.
try:
    extra_cluster_fields = read_count_adata.uns["cluster_fields"]
except KeyError:
    extra_cluster_fields = []

for cluster_field in extra_cluster_fields:
    # Report and skip a listed field that has no matching obs column, instead of
    # silently dropping it together with every field listed after it.
    if cluster_field not in read_count_adata.obs.columns:
        logging.warning(f"Cluster field {cluster_field!r} not found in read_count_adata.obs; skipping")
        continue
    cluster_series[cluster_field] = read_count_adata.obs[cluster_field].values

# Annotate every clustering and keep one long-format frame per field.
logging.info(f"Starting cluster annotation")
for cluster_field, cluster_values in cluster_series.items():
    logging.info(f"Starting cluster annotation for {cluster_field}")
    out = cluster_annotation(cluster_values)
    logging.info(f"Finished cluster annotation for {cluster_field}")
    # Index = cluster id, single text column, plus the field the ids belong to.
    out.index.name = "cluster_values"
    out.name = "cluster_annotations"
    out = out.to_frame()
    out["cluster_field"] = cluster_field
    dfs.append(out)

In [None]:
# Stack the per-field annotation frames and write them to the rule's CSV output;
# the index (named "cluster_values") is written as the first column.
pd.concat(dfs).to_csv(snakemake.output.csv)