In [None]:
import anndata
import pandas as pd
from pathlib import Path

# Load the per-cluster annotation table produced upstream (Snakemake input
# `cellwhisperer_labels`). The labelling loop below reads its free-text
# "cluster_annotations" column.
llava_annotations = pd.read_csv(snakemake.input.cellwhisperer_labels)

# Fail fast, before the LLM is loaded, if the column the labelling loop reads is missing.
if "cluster_annotations" not in llava_annotations.columns:
    raise KeyError(
        "input CSV is missing required column 'cluster_annotations'; "
        f"found columns: {list(llava_annotations.columns)}"
    )


In [None]:
# Manual override for interactive debugging only (keep commented out in pipeline runs):
# snakemake.input.model = '/path/to/models/mixtral-8x7b-instruct-v0.1.Q5_K_M.gguf'

In [None]:
from llama_cpp import (
    Llama,
    LlamaGrammar,
)

# load the model
# Load the GGUF model once; the same instance is reused for every row in the labelling loop.
llm = Llama(
    model_path=snakemake.input.model,
    # Context window size in tokens; larger values need more memory.
    n_ctx=2048,
    # CPU threads for generation and for batch (prompt) processing, both taken from Snakemake.
    n_threads=snakemake.threads,
    n_threads_batch=snakemake.threads,
    # Number of layers offloaded to the GPU (author's estimate: ~1 GB VRAM per layer).
    n_gpu_layers=40,
)



In [None]:
# Display the loaded annotation table as this cell's output, for a quick visual check.
llava_annotations

In [None]:
# Logit bias that forbids newline and double-quote tokens, so each generated label
# comes back without newlines or quotes. It depends only on the tokenizer, not on the
# row, so build it once instead of re-tokenizing on every iteration.
# NOTE(review): encode() may prepend BOS and/or a leading-space piece; taking [-1]
# assumes the final token id is the character itself. A quote variant with a different
# token id could still be sampled — confirm for the tokenizer in use.
_tokenizer = llm.tokenizer()
no_newline_or_quote_bias = {
    _tokenizer.encode("\n")[-1]: float("-inf"),
    _tokenizer.encode('"')[-1]: float("-inf"),
}

# Map each row's index label -> curated label string.
results = {}
for idx, row in llava_annotations.iterrows():
    annotation = row["cluster_annotations"]
    # Missing (NaN/non-string) and blank annotations get a fixed placeholder instead of
    # an LLM call: there is no annotation text for the model to condense.
    if not isinstance(annotation, str) or not annotation.strip():
        response = "No label"
    else:
        output = llm(
            f"[INST] {snakemake.params.request}\n{annotation} [/INST]",
            # TODO: training used a max of 128 tokens; check whether labels stay within that.
            max_tokens=1024,
            stop=["</s>"],  # end-of-sequence token for Mixtral
            logit_bias=no_newline_or_quote_bias,
            echo=False,  # return only the completion, not the prompt
            seed=42,
            temperature=0.2,
            top_p=0.9,
            top_k=50,
        )
        response = output["choices"][0]["text"].strip()
    print(response)
    results[idx] = response

In [None]:
# Attach the labels as a new column. `results` is keyed by the same row index labels
# that iterrows() yielded, so the Series aligns each label with its source row.
llava_annotations["curated_labels"] = pd.Series(results)

In [None]:
# Write the table (including the new `curated_labels` column) to the Snakemake output path.
# index=False: the default integer index added by read_csv is not written out.
llava_annotations.to_csv(path_or_buf=snakemake.output.curated_labels, index=False)
