In [None]:
import json
import random
import pandas as pd
import logging
import copy


from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava import conversation as conversation_lib

from snakemake.io import Namedlist as SnakemakeNamedlist

from llava.train.train import (
    LazySupervisedDataset,
    DataArguments,
    DataCollatorForSupervisedDataset,
)
import torch
from torch import nn
from transformers import Trainer, EvalPrediction
import transformers
from torch.nn import CrossEntropyLoss

# Set up logging at DEBUG level, writing to the Snakemake log file
# (snakemake.log.log) and to a StreamHandler.
# NOTE(review): `snakemake` is not defined in this file; presumably it is injected
# by Snakemake when it executes the notebook — confirm.
logging.basicConfig(
    level=logging.DEBUG,
    handlers=[
        logging.FileHandler(snakemake.log.log),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Input paths come from the Snakemake rule; the trailing comments are example values only.
model_dir = (
    snakemake.input.llava_model
)  # e.g. "/msc/home/mschae83/cellwhisperer/results/llava/finetuned/Mistral-7B-Instruct-v0.2__03jujd8s/"
evaluation_dataset_fn = (
    snakemake.input.evaluation_dataset
)  # e.g. "/msc/home/mschae83/cellwhisperer/results/llava_evaluation_conversations.json"

In [None]:
# Parse the evaluation dataset JSON and display element 0 as a sanity check.
with open(evaluation_dataset_fn) as f:
    eval_set = json.load(f)
eval_set[0]

In [None]:
# Build `top_genes` from one or several parquet files, keeping only the first
# `top_n_genes` columns of each file. Several files arrive as a Snakemake
# Namedlist and are stacked row-wise; anything else is read as a single path.
top_gene_paths = snakemake.input.top_genes  # type: ignore [reportUndefinedVariable]
if not isinstance(top_gene_paths, SnakemakeNamedlist):
    top_genes = pd.read_parquet(top_gene_paths).iloc[
        :, : snakemake.params.top_n_genes
    ]
else:
    top_genes = pd.concat(
        [
            pd.read_parquet(fn).iloc[:, : snakemake.params.top_n_genes]
            for fn in top_gene_paths
        ]
    )
top_genes.head()
# .dropna()  # TODO necessary?

In [None]:
# Load the model

# TODO make sure load_pretrained_model is flexible enough.

# The model name is derived from the checkpoint path; it is also used further
# down to choose the conversation template.
model_name = get_model_name_from_path(model_dir)
# NOTE(review): the condition accepts names containing "mistral" OR "llama", but the
# message only mentions mistral and is hard to parse — consider rewording it.
assert (
    "mistral" in model_name.lower() or "llama" in model_name.lower()
), "sure that you are not using a mistral model? LLaVA depends on having it in the name (if it is mistral)"


# Only a warning: a name without "__" is still loaded.
if "__" not in model_name:
    logger.warning("'__' not in model_name. Could lead to unforseen consequences.")

# Load without 8-bit/4-bit quantization, on CUDA, with flash attention disabled.
# `image_processor` and `context_len` are not used in the cells visible here.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_dir,
    model_base=None,
    model_name=model_name,
    load_8bit=False,
    load_4bit=False,
    device="cuda",
    use_flash_attn=False,
)


logger.info(f"Loaded model {model_name} from {model_dir}")

In [None]:
# Make sure the tokenizer has a pad token: the unk token is reused as padding.
if (
    tokenizer.pad_token is None
):  # needs to be set explicitly for some non-multimodal models
    if tokenizer.unk_token_id is None:
        # No unk token either: register "<unk>" and grow the input embedding by
        # one row so the new id has an embedding.
        tokenizer.add_special_tokens({"unk_token": "<unk>"})
        model.model.config.pad_token_id = tokenizer.unk_token_id

        orig_embed = model.model.embed_tokens
        # assumes the new unk id equals the previous number of embedding rows, i.e.
        # it indexes the appended last row — TODO confirm for padded vocabularies
        model.model.embed_tokens = nn.Embedding(
            orig_embed.weight.shape[0] + 1,
            orig_embed.weight.shape[1],
            padding_idx=tokenizer.unk_token_id,
            dtype=model.model.embed_tokens.weight.dtype,
            device=model.model.embed_tokens.weight.device,
        )
        # Copy all pre-existing rows, then zero the appended row.
        # NOTE(review): only embed_tokens is resized in this cell; no output head is touched here.
        model.model.embed_tokens.weight.data[:-1] = orig_embed.weight.data
        nn.init.zeros_(model.model.embed_tokens.weight.data[-1:, :])

    # Applies in both cases above: pad with the (possibly just-added) unk token.
    tokenizer.pad_token_id = tokenizer.unk_token_id

In [None]:
# Select the global LLaVA conversation template from the model name:
# "mistral_instruct" for names containing "mistral", otherwise "llama3_instruct".
conversation_lib.default_conversation = conversation_lib.conv_templates[
    "mistral_instruct" if "mistral" in model_name.lower() else "llama3_instruct"
]

# Build the evaluation dataset from the conversations JSON plus the image data
# and layer selector supplied by the Snakemake rule.
eval_dataset = LazySupervisedDataset(
    evaluation_dataset_fn,
    tokenizer,
    DataArguments(
        image_data=snakemake.input.image_data,
        mm_vision_select_layer=snakemake.params.model_layer_selector,
    ),
)
# Fail early on an empty dataset instead of producing an empty result table.
assert (
    len(eval_dataset) > 0
), "Something is wrong with the input data: LazySupervisedDataset sees 0 data points."

# Note: the f-string indexes eval_dataset[0], so the first item is materialised here.
logger.info(
    f"Loaded evaluation dataset with {len(eval_dataset)} data points. First one: {eval_dataset[0]}"
)

In [None]:
# Display the first dataset item as returned by indexing the LazySupervisedDataset.
eval_dataset[0]

In [None]:
# Display the first raw record, before the in-place prompt edits of the next cell.
eval_dataset.list_data_dict[0]

In [None]:
# Rewrite the raw records held in eval_dataset.list_data_dict in place:
#   * text-only runs: strip the "<image>" placeholder from the first turn
#   * if a top-genes pre-prompt is configured: fill it with this sample's genes
#     and put it in front of the conversation
# NOTE: not idempotent — running this cell twice prepends the pre-prompt twice.
for datapoint in eval_dataset.list_data_dict:
    turns = datapoint["conversations"]

    if not snakemake.params.is_multimodal:
        opening_turn = turns[0]
        without_leading_newline = opening_turn["value"].replace("\n<image>", "")
        opening_turn["value"] = without_leading_newline.replace("<image>\n", "")

    if snakemake.params.pre_prompt_topgenes:  # type: ignore [reportUndefinedVariable]
        gene_names = ", ".join(top_genes.loc[datapoint["id"]])
        # Deep copy so the shared Snakemake parameter is never formatted in place.
        pre_prompt = copy.deepcopy(snakemake.params.pre_prompt_topgenes)
        pre_prompt[0]["value"] = pre_prompt[0]["value"].format(gene_names)
        datapoint["conversations"] = pre_prompt + turns  # type: ignore [reportUndefinedVariable]

eval_dataset.list_data_dict[0]

In [None]:
# Display the raw record at index 5 after the in-place prompt edits.
eval_dataset.list_data_dict[5]

In [None]:
# Collator built on the shared tokenizer; it is handed to the Trainer as `data_collator` below.
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

In [None]:
# Smoke test: collate two dataset items (indices 5 and 0) and display the result.
data_collator([eval_dataset[5], eval_dataset[0]])

In [None]:
# Token-level cross-entropy computed from raw logits (not probabilities).
# reduction="none" keeps one loss value per position; ignore_index=-100 is the
# CrossEntropyLoss default, written out here so the masking of -100 is explicit.
loss_fct = CrossEntropyLoss(reduction="none", ignore_index=-100)


def _create_last_response_block_mask(labels):
    mask = []
    for i, row in enumerate(labels):
        row_mask = torch.zeros_like(row)
        indices = torch.nonzero(row == -100, as_tuple=False)
        max100 = indices.max()
        if max100 == len(row) - 1:
            in_block = True
            for j in range(len(row) - 1, 0, -1):
                if in_block:
                    if row[j] == -100:
                        continue
                    else:
                        in_block = False
                        row_mask[j] = 1
                else:
                    if row[j] == -100:
                        break
                    else:
                        row_mask[j] = 1
        else:
            row_mask[max100 + 1 :] = 1

        mask.append(row_mask)

    mask = torch.stack(mask).to(bool)
    return mask


def preprocess_logits_for_metrics(
    logits: torch.Tensor, labels: torch.Tensor
) -> torch.Tensor:
    """Collapse a batch of logits into one perplexity per example.

    Only the last response block of each example (as selected by
    ``_create_last_response_block_mask``) contributes. The result is
    ``exp(mean token loss over that block)``.

    Args:
        logits: model outputs; indexed as (..., position, vocab).
            Assumes shape (batch, seq, vocab) — TODO confirm.
        labels: target ids with -100 on ignored positions.
            Assumes shape (batch, seq) — TODO confirm.

    Returns:
        1-D tensor holding the per-example perplexity. An example with an empty
        block yields 0/0 and therefore NaN.
    """
    if snakemake.params.is_multimodal:
        # Drop the leading logits so logits and labels line up again after the
        # usual one-step shift.
        # presumably the projector tokens sit at the start of the sequence — verify
        first_logit = snakemake.params.num_projector_tokens - 1
    else:
        assert -200 not in labels
        first_logit = None

    # Next-token alignment: the logit at position t is scored against label t + 1.
    shift_logits = logits[..., first_logit:-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    response_mask = _create_last_response_block_mask(shift_labels)

    # Per-token loss, flattened for CrossEntropyLoss and reshaped back afterwards.
    vocab_size = shift_logits.size(-1)
    token_loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
    token_loss = token_loss.view(shift_labels.size())

    # Zero out everything outside the selected response block.
    zero = torch.tensor(0.0).to(token_loss.device)
    block_loss = torch.where(response_mask, token_loss, zero)

    # Mean loss over the block, then exponentiate to obtain the perplexity.
    loss_per_example = block_loss.sum(dim=1)
    tokens_per_example = response_mask.sum(dim=1)
    return torch.exp(loss_per_example / tokens_per_example)

In [None]:
# Evaluate the correctly matched conversations.
# Requires `model`, `tokenizer`, `eval_dataset` and `data_collator` from the cells above.
# `preprocess_logits_for_metrics` reduces each batch to per-example perplexities, so
# `eval_pred.predictions` is already the vector of perplexities.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=lambda eval_pred: {
        "perplexity": eval_pred.predictions.mean().item(),
        "all_perplexities": eval_pred.predictions,
    },
    args=transformers.TrainingArguments(
        report_to="none",
        output_dir="/tmp",
        eval_accumulation_steps=8,
        per_device_eval_batch_size=8,
    ),
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    # batch_eval_metrics=True
)
logger.info("Starting evaluation (correct)")
# Evaluate the model
correct_results = trainer.evaluate()

# Attach per-question metadata; these lists must have one entry per perplexity
# so the dict can later be turned into a DataFrame.
correct_results["response"] = [
    conv["conversations"][-1]["value"] for conv in eval_dataset.list_data_dict
]
correct_results["type"] = "correct"
correct_results["question_id"] = list(
    range(len(correct_results["eval_all_perplexities"]))
)
# Integer sentinel (was the string "-1"): background replicates use ints, so this
# keeps the "replicate" column a single integer type. The CSV text is unchanged.
correct_results["replicate"] = -1
logger.info("Finished evaluation (correct)")

In [None]:
# Preview the last-turn text of the first five records.
[conv["conversations"][-1]["value"] for conv in eval_dataset.list_data_dict][:5]

In [None]:
# Display the first raw record again, after the correct evaluation.
eval_dataset.list_data_dict[0]

In [None]:
# Display the metrics dict of the correctly matched evaluation.
correct_results

In [None]:
# NOTE: would be much more elegant to 'modify the "image" field' rather than the orig_id_to_int_stable. But it's implemented and it works

# Background ("incorrect") evaluation: for each replicate the dataset records are
# perturbed in place according to snakemake.params.background_shuffle and
# evaluated again. The records are NOT restored afterwards.
incorrect_results = []

# Snapshot of the original id -> int mapping, so every replicate remaps from the
# unmodified state rather than from the previous replicate.
orig_id_to_int_stable = eval_dataset.orig_id_to_int.copy()

true_responses = [
    conv["conversations"][-1]["value"] for conv in eval_dataset.list_data_dict
]
# De-duplicate in first-seen order. list(set(...)) depends on string hashing, which
# made the replicate -> response assignment in "llm-response" mode differ between runs.
possible_responses = list(dict.fromkeys(true_responses))

if snakemake.params.background_shuffle == "llm-response":
    # One replicate per alternative response.
    num_negatives = len(possible_responses) - 1
else:
    num_negatives = snakemake.params.num_negatives


for i in range(num_negatives):
    logger.info(f"Begin evaluation (incorrect) replicate {i}")
    # Calculate background perplexity (i.e. with the wrongly matched transcriptome)
    # Each Trainer construction re-seeds the `random`, numpy and torch RNGs from
    # args.seed (transformers.set_seed), so seed per replicate to get different draws.
    random.seed(i)

    if snakemake.params.background_shuffle == "transcriptome":
        # eval_dataset.orig_id_to_int = {k: (v+2)%len(eval_dataset.orig_id_to_int) for k, v in eval_dataset.orig_id_to_int.items()}  # <- deprecated
        for conv in eval_dataset.list_data_dict:
            # swap out the embedding (no effect for non-multimodal)
            true_id = conv["id"]
            true_response = conv["conversations"][-1]["value"]

            # Pick a sample whose response differs from the true one.
            # random.choice raises IndexError if all responses are identical.
            mismatch_response_id = random.choice(
                [
                    e["id"]
                    for e in eval_dataset.list_data_dict
                    if e["conversations"][-1]["value"] != true_response
                ]
            )
            eval_dataset.orig_id_to_int[true_id] = orig_id_to_int_stable[
                mismatch_response_id
            ]

            # swap out the top n genes in the same manner
            if snakemake.params.pre_prompt_topgenes:
                pre_prompt = copy.deepcopy(snakemake.params.pre_prompt_topgenes)
                pre_prompt[0]["value"] = pre_prompt[0]["value"].format(
                    ", ".join(top_genes.loc[mismatch_response_id])
                )
                # replace pre-prompt with wrong one
                conv["conversations"] = (
                    pre_prompt + conv["conversations"][len(pre_prompt) :]
                )

            logger.debug(
                f"Produced mismatch conversation with mismatch id {mismatch_response_id}: {conv}"
            )
    elif snakemake.params.background_shuffle == "genesshuffled":
        # Only shuffle the gene order
        assert snakemake.params.pre_prompt_topgenes, "Need top genes for this shuffle"
        for conv in eval_dataset.list_data_dict:
            true_id = conv["id"]

            # shuffle the top n genes:
            # Series.sample uses numpy's global RNG when random_state is None. That RNG
            # is not touched by random.seed(i) and is re-seeded by every Trainer
            # construction, so replicates would repeat the same permutation. Draw an
            # explicit 32-bit seed from the replicate-seeded `random` module instead.
            pre_prompt = copy.deepcopy(snakemake.params.pre_prompt_topgenes)
            pre_prompt[0]["value"] = pre_prompt[0]["value"].format(
                ", ".join(
                    top_genes.loc[true_id].sample(
                        frac=1, random_state=random.getrandbits(32)
                    )
                )
            )
            # replace pre-prompt with shuffled one
            conv["conversations"] = (
                pre_prompt + conv["conversations"][len(pre_prompt) :]
            )

    elif snakemake.params.background_shuffle == "responsepermuted":
        for true_response, conv in zip(true_responses, eval_dataset.list_data_dict):
            # Generate an incorrect (shuffled) response (retain the source embedding)
            if snakemake.wildcards.dataset.endswith("_top50genes"):
                # Shuffle the order of the response genes
                # NOTE(review): assumes the response ends with a ", "-separated gene
                # list preceded by a prefix containing a space — confirm the format.
                comma_split = true_response.rsplit(", ", maxsplit=50)

                prefix, first_gene = comma_split[0].rsplit(" ", maxsplit=1)

                genes = [first_gene] + comma_split[1:]
                shuffled = prefix + " " + ", ".join(random.sample(genes, len(genes)))
            else:
                # Generic responses: permute the space-separated words.
                response_words = true_response.rsplit(" ")
                shuffled = " ".join(random.sample(response_words, len(response_words)))
            conv["conversations"][-1]["value"] = shuffled
            logger.debug(f"Shuffled response '{true_response}' to '{shuffled}'")
    elif snakemake.params.background_shuffle == "llm-response":
        for true_response, conv in zip(true_responses, eval_dataset.list_data_dict):
            # Replicate i uses the i-th response that differs from the true one.
            conv["conversations"][-1]["value"] = [
                v for v in possible_responses if v != true_response
            ][i]
    else:
        raise ValueError(snakemake.params.background_shuffle)

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        compute_metrics=lambda eval_pred: {
            "perplexity": eval_pred.predictions.mean().item(),
            "all_perplexities": eval_pred.predictions,
        },
        args=transformers.TrainingArguments(
            report_to="none", output_dir="/tmp", eval_accumulation_steps=4
        ),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    incorrect_result = trainer.evaluate()
    incorrect_result["type"] = "incorrect"
    incorrect_result["replicate"] = i
    incorrect_result["response"] = [
        conv["conversations"][-1]["value"] for conv in eval_dataset.list_data_dict
    ]
    incorrect_result["question_id"] = list(
        range(len(incorrect_result["eval_all_perplexities"]))
    )
    incorrect_results.append(incorrect_result)
    # Release cached GPU memory between replicates.
    torch.cuda.empty_cache()

In [None]:
# Display the metrics of the first background replicate (raises IndexError if there were none).
incorrect_results[0]

In [None]:
# One DataFrame per evaluation run (the correct run first, then each background
# replicate), stacked row-wise. Per-run row indices are kept as they are.
df = pd.concat(
    [pd.DataFrame(run_metrics) for run_metrics in [correct_results, *incorrect_results]]
)

In [None]:
# Write the combined table to the Snakemake output path (the index is written too, per the to_csv default).
df.to_csv(snakemake.output.all_perplexities)