In [None]:
# start coding here

In [None]:
# LLaVA imports 

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    get_model_name_from_path,
    KeywordsStoppingCriteria,
    tokenizer_image_token
)

# Other imports

import anndata
import matplotlib.pyplot as plt
import torch 


import requests
from io import BytesIO
import re

In [None]:
# Load the AnnData object from the path given as `input.adata`.
# NOTE(review): `snakemake` is not defined in this notebook — presumably it is
# injected by Snakemake's notebook integration; confirm before running standalone.
adata = anndata.read_h5ad(snakemake.input.adata)
# Start with an empty label column; it is filled cluster by cluster in the
# labelling loop near the end of the notebook.
adata.obs["cluster_label"] = None

In [None]:
# Get the unique leiden clusters: the distinct values of `adata.obs['leiden']`.
# One text description is generated per value further down.
leiden_clusters = adata.obs['leiden'].unique()


In [None]:
# Load model
# `disable_torch_init` comes from llava.utils; going by the TODO below it is
# expected to speed up loading, but that has not been measured here.
disable_torch_init()  # TODO test loading with and without to see speed benefit. Preferably get rid of it. With considerable speed up, use in my own inference?

# The model name is derived from the model path and is later used to pick the
# conversation template (see `_infer_conv_mode`).
model_name = get_model_name_from_path(snakemake.input.llava_model)

# Set `model_base` (currently None) for projector-only loading.
# `image_processor` and `context_len` are returned by the loader but are not
# used anywhere in the visible part of this notebook.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    snakemake.input.llava_model, model_base=None, model_name=model_name
)

In [None]:
def _infer_conv_mode(model_name):
    if "llama-2" in model_name.lower():
        return "llava_llama_2"
    elif "v1" in model_name.lower():
        return "llava_v1"
    elif "mpt" in model_name.lower():
        return "mpt"
    else:
        return "llava_v0"


def prepare_conv(qs):
    """Build a conversation whose user turn contains the image token and ``qs``.

    Reads the notebook globals ``model`` (for ``config.mm_use_im_start_end``)
    and ``model_name`` (to select the conversation template).

    Args:
        qs: The question text. If it contains ``IMAGE_PLACEHOLDER``, the
            placeholder is substituted by the image token; otherwise the image
            token followed by a newline is prepended.

    Returns:
        A copy of the selected conversation template with two messages
        appended: the user message, and an assistant message set to ``None``.
    """
    wrapped_image_token = (
        DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    )
    has_placeholder = IMAGE_PLACEHOLDER in qs
    # With `mm_use_im_start_end` the image token is wrapped in start/end tokens.
    image_token = (
        wrapped_image_token
        if model.config.mm_use_im_start_end
        else DEFAULT_IMAGE_TOKEN
    )
    if has_placeholder:
        qs = re.sub(IMAGE_PLACEHOLDER, image_token, qs)
    else:
        # NOTE (from the original author): the bare-token variant of this
        # prepend branch is the one being executed in practice.
        qs = image_token + "\n" + qs

    template_key = _infer_conv_mode(model_name)
    conv = conv_templates[template_key].copy()
    # User turn with the question, then an assistant turn without content.
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    return conv


# Build the conversation for one fixed question and render it to the prompt
# string that is tokenized for every cluster.
conv = prepare_conv("What does the sample describe?")
prompt = conv.get_prompt()
# Bare expression: shows the prompt as the cell output for inspection.
prompt

In [None]:
def features_for_cluster(cluster):
    """Build the model inputs for one leiden cluster.

    Reads the notebook globals ``adata``, ``model``, ``tokenizer`` and
    ``prompt``.

    Args:
        cluster: A value of ``adata.obs['leiden']`` selecting the rows to
            average.

    Returns:
        A tuple ``(input_ids, transcriptomes_tensor)``:
        ``input_ids`` is the tokenized ``prompt`` with a leading dimension
        added; ``transcriptomes_tensor`` is the mean over the selected rows of
        ``adata.X`` as a float16 tensor. Both are placed on ``model.device``.
    """
    # Clustering and mean over embeddings (alternative: mean over expression)
    # NOTE(review): presumably adata.X is a dense NumPy array here, since the
    # result is handed to torch.from_numpy — verify for sparse inputs.
    mean_cluster_embedding = adata.X[adata.obs['leiden'] == cluster].mean(axis=0)

    # TODO is float16 correct?
    transcriptomes_tensor = torch.from_numpy(mean_cluster_embedding).to(
        model.device, dtype=torch.float16
    )
    # Tokenize the prompt and add a leading dimension. The ids are moved to
    # the model's device (rather than a hard-coded `.cuda()`) so they always
    # end up on the same device as the embedding tensor and the model.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    return input_ids, transcriptomes_tensor
# Build inputs for the first cluster; they are reused for the single test
# prediction in the next cell.
input_ids, transcriptomes_tensor = features_for_cluster(leiden_clusters[0])

In [None]:
# Generation settings; `predict` reads these as globals.
temperature = 0  # not > 0, so `predict` passes do_sample=False
num_beams = snakemake.params.num_beams
top_p = 1.0
# Stop string: `conv.sep2` when the conversation uses SeparatorStyle.TWO,
# otherwise `conv.sep`. It is used both as a stopping keyword and to trim the
# decoded output.
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]

def predict(input_ids, transcriptomes_tensor):
    """Generate a text answer for one prompt / embedding pair.

    Relies on the notebook globals ``model``, ``tokenizer``, ``keywords``,
    ``stop_str``, ``temperature``, ``top_p`` and ``num_beams``.

    Args:
        input_ids: Prompt token ids; ``shape[1]`` is taken as the prompt
            length.
        transcriptomes_tensor: Embedding handed to ``model.generate`` via its
            ``images`` argument after a leading dimension is added.

    Returns:
        The decoded continuation of the first sequence, stripped of
        surrounding whitespace and of a trailing ``stop_str``.
    """
    keyword_stop = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    with torch.inference_mode():
        generate_kwargs = dict(
            images=transcriptomes_tensor.unsqueeze(0),
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            max_new_tokens=200,
            use_cache=True,
            stopping_criteria=[keyword_stop],
        )
        output_ids = model.generate(input_ids, **generate_kwargs)

    prompt_len = input_ids.shape[1]
    # Sanity check: the leading part of the output is compared with the prompt
    # ids and a warning is printed if any position differs.
    n_diff_input_output = (input_ids != output_ids[:, :prompt_len]).sum().item()
    if n_diff_input_output > 0:
        print(
            f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
        )

    # Decode only the tokens after the prompt, and keep the first sequence.
    decoded = tokenizer.batch_decode(
        output_ids[:, prompt_len:], skip_special_tokens=True
    )
    text = decoded[0].strip()
    if text.endswith(stop_str):
        text = text[: -len(stop_str)]
    return text.strip()

# Single test prediction for the first cluster's inputs; the returned text is
# shown as the cell output.
predict(input_ids, transcriptomes_tensor)

In [None]:
# Iterate over each cluster: build its inputs, generate a description, and
# write that text into `cluster_label` for every row of the cluster.
for cluster in leiden_clusters:
    input_ids, transcriptomes_tensor = features_for_cluster(cluster)
    outputs = predict(input_ids, transcriptomes_tensor)
    # `.loc` with a boolean mask assigns the same string to all matching rows.
    adata.obs.loc[adata.obs.leiden == cluster, "cluster_label"] = outputs

In [None]:
# Persist the AnnData object, including the new `cluster_label` column, to the
# path given as `output.adata`.
adata.write_h5ad(snakemake.output.adata)