In [None]:
import json
import seaborn as sns
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt 
import anndata
import re
import logging

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava import conversation as conversation_lib



from llava.train.train import LazySupervisedDataset, DataArguments, DataCollatorForSupervisedDataset
import torch
from transformers import Trainer, EvalPrediction
import transformers
from torch.nn import CrossEntropyLoss

# Directory of the fine-tuned LLaVA model, supplied by the Snakemake rule's `llava_model` input.
model_dir = snakemake.input.llava_model

In [None]:
# Echo the model directory taken from the Snakemake input (shown as this cell's output).
model_dir

In [None]:
# Load the fine-tuned LLaVA model from `model_dir`.

model_name = get_model_name_from_path(model_dir)
# LLaVA relies on "mistral" appearing in the model name to choose its Mistral code path,
# so refuse names that do not satisfy the expected naming before loading anything.
assert "mistral" in model_name.lower() and "__" in model_name, (
    f"Unexpected model name {model_name!r}: it must contain 'mistral' (LLaVA selects the "
    "Mistral model class based on the name) and '__'. Are you sure this is a Mistral-based model?"
)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_dir,
    model_base=None,
    model_name=model_name,
    load_8bit=False,
    load_4bit=False,
    device="cuda",
    use_flash_attn=False,
)

In [None]:
# Read the AnnData object; its .X rows are averaged per "leiden" cluster further down.
adata_path = snakemake.input.adata
adata = anndata.read_h5ad(adata_path)

In [None]:
# Store each cell's positional row number so a cluster's rows can be pulled from adata.X directly.
adata.obs["index_int"] = range(adata.n_obs)


def _cluster_mean_embedding(cluster_obs):
    """Return the column-wise mean of the adata.X rows belonging to one Leiden cluster."""
    return adata.X[cluster_obs.index_int].mean(axis=0)


# One mean embedding per observed Leiden cluster (unused categories are skipped).
grouped_embeddings = adata.obs.groupby("leiden", observed=True).apply(
    _cluster_mean_embedding, include_groups=False
)

In [None]:
# Sanity check: number of cells in every Leiden cluster, i.e. how many rows each mean embedding averages.
adata.obs["leiden"].value_counts().sort_index()

In [None]:
# Inspect the per-cluster mean embeddings computed above (one entry per observed Leiden cluster).
grouped_embeddings

In [None]:
# Free-text annotation request configured on the Snakemake rule.
prompt = snakemake.params["request"]

# Make the Mistral-instruct template LLaVA's module-level default conversation.
mistral_template = conv_templates["mistral_instruct"]
conversation_lib.default_conversation = mistral_template

In [None]:
# Annotate every Leiden cluster: its mean embedding is passed to LLaVA in place of an image,
# and the generated text becomes the cluster's label.

# The prompt is the same for every cluster, so prepare and tokenize it once, before the loop.
# It is built under a new name so that neither loop iterations nor re-runs of this cell
# modify `prompt`. Wrapping in place would nest the im_start/im_end tokens repeatedly.
image_prompt = prompt
if DEFAULT_IMAGE_TOKEN not in image_prompt:
    # Test for the placeholder token itself; the plain word "image" may occur in the request.
    image_prompt = image_prompt + "\n" + DEFAULT_IMAGE_TOKEN
if getattr(model.config, "mm_use_im_start_end", False):
    image_prompt = image_prompt.replace(
        DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    )

# NOTE(review): the request is tokenized verbatim. The template assigned to
# `conversation_lib.default_conversation` is never applied to it (no `get_prompt()` call).
# Confirm the request already carries the instruction format the model was fine-tuned on.
input_ids = (
    tokenizer_image_token(image_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)  # batch of one prompt
    .to(model.device)
)

cluster_labels = {}
for leiden_id, grouped_embedding in grouped_embeddings.items():
    # float16 appears to be what LLaVA expects for the image features — TODO confirm against model dtype.
    images = torch.tensor(grouped_embedding, device=model.device, dtype=torch.float16)

    generated_tokens = model.generate(
        inputs=input_ids,
        images=images,
        image_sizes=None,
        # Deterministic greedy/beam decoding. Sampling knobs such as temperature/top_p are
        # ignored when do_sample=False, so they are not passed (avoids per-call warnings).
        do_sample=False,
        num_beams=snakemake.params.num_beams,
        max_new_tokens=256,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,  # explicitly request open-ended generation (suppresses warnings)
    )
    # Decode the single returned sequence; surrounding whitespace is trimmed before storing.
    generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True).strip()
    print(generated_text[:80] + "...")
    cluster_labels[leiden_id] = generated_text

In [None]:
# Turn the cluster -> annotation mapping into a named Series (index "leiden") and write it out.
out = pd.Series(cluster_labels, name="annotation").rename_axis("leiden")
out.to_csv(snakemake.output.csv)

In [None]:
# Display the cluster -> annotation Series that was written to the output CSV.
out