In [None]:
import json

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import transformers
from torch.nn import CrossEntropyLoss
from transformers import Trainer, EvalPrediction

from llava import conversation as conversation_lib
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import get_model_name_from_path
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.train.train import LazySupervisedDataset, DataArguments, DataCollatorForSupervisedDataset

# Workflow inputs. `snakemake` is not defined in this notebook; presumably it is injected
# by Snakemake's notebook integration at execution time (TODO confirm).
model_dir = snakemake.input.llava_model # directory of the fine-tuned model, e.g. ".../results/llava/finetuned/Mistral-7B-Instruct-v0.2__03jujd8s/"
evaluation_dataset_fn = snakemake.input.evaluation_dataset  # JSON file with the evaluation conversations, e.g. ".../results/llava_evaluation_conversations.json"

In [None]:
# Parse the evaluation conversations (a JSON document) into memory.
with open(evaluation_dataset_fn) as dataset_file:
    eval_set = json.load(dataset_file)

# Show the first record as a quick sanity check of the file's structure.
eval_set[0]

In [None]:
# Load the fine-tuned model together with its tokenizer and image processor.

# TODO make sure load_pretrained_model is flexible enough.

model_name = get_model_name_from_path(model_dir)

# The name must contain "mistral" (case-insensitive) and a "__" separator, otherwise we stop here.
name_mentions_mistral = "mistral" in model_name.lower()
name_has_double_underscore = "__" in model_name
assert name_mentions_mistral and name_has_double_underscore, "sure that you are not using a mistral model? LLaVA depends on having it in the name (if it is mistral)"

# Full-precision load on the GPU, without quantization or flash attention.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_dir,
    model_base=None,
    model_name=model_name,
    load_8bit=False,
    load_4bit=False,
    device="cuda",
    use_flash_attn=False,
)

In [None]:
# Select the "mistral_instruct" template as the library-wide default conversation
# (presumably picked up by the dataset's prompt preprocessing -- TODO confirm).
conversation_lib.default_conversation = conversation_lib.conv_templates["mistral_instruct"]

# Build the evaluation dataset from the conversations file and the image data input.
data_args = DataArguments(image_data=snakemake.input.image_data)
eval_dataset = LazySupervisedDataset(evaluation_dataset_fn, tokenizer, data_args)

# Collator that batches dataset items for the Trainer.
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

In [None]:
def compute_perplexity(eval_pred: "EvalPrediction", num_projector_tokens=None):
    """
    Compute per-example and mean perplexity from the logits and labels of an evaluation run.

    A causal LM predicts the NEXT token at every position, so logits and labels have to be
    shifted against each other by one. On top of that, the logits carry more positions than
    the labels (presumably because the single image placeholder is expanded into
    `num_projector_tokens` positions inside the model -- TODO confirm), so the first
    `num_projector_tokens - 1` logit positions are dropped as well.

    Args:
        eval_pred: object exposing `predictions` (numpy logits, vocabulary on the last axis)
            and `label_ids` (2-D numpy integer labels, one row per example).
        num_projector_tokens: number of logit positions that correspond to the image.
            Defaults to `snakemake.params.num_projector_tokens` when omitted, which is what
            happens when the Trainer calls this function with a single argument.

    Returns:
        dict with
          - "perplexity": float, mean of the per-example perplexities (NaN entries skipped),
          - "all_perplexities": 1-D tensor with one perplexity per example. Examples without
            any supervised token are NaN.
    """
    if num_projector_tokens is None:
        num_projector_tokens = snakemake.params.num_projector_tokens

    logits, labels = eval_pred.predictions, eval_pred.label_ids

    shift_logits = torch.from_numpy(logits)[..., num_projector_tokens - 1:-1, :].contiguous()
    shift_labels = torch.from_numpy(labels)[..., 1:].contiguous()

    # Supervised positions only: excludes -100 and anything below it (e.g. -200)
    attention_mask = (shift_labels > -100)

    # CrossEntropyLoss only ignores exactly -100; other negative targets raise an error.
    # Map every masked-out label to -100 so they are ignored instead of crashing.
    safe_labels = torch.where(attention_mask, shift_labels, torch.tensor(-100))

    # CrossEntropyLoss expects raw logits, not probabilities; keep the per-token losses
    loss_fct = CrossEntropyLoss(reduction='none')
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), safe_labels.view(-1))
    loss = loss.view(shift_labels.size())

    # Zero out ignored positions explicitly before summing
    masked_loss = torch.where(attention_mask, loss, torch.tensor(0.0).to(loss.device))

    # Mean token loss per example -> perplexity per example.
    # An example with zero supervised tokens yields 0/0 = NaN (its perplexity is undefined).
    example_losses = torch.sum(masked_loss, dim=1)
    example_lengths = torch.sum(attention_mask, dim=1)
    example_perplexities = torch.exp(example_losses / example_lengths)

    # Skip undefined (NaN) examples so a single empty example does not turn the mean into NaN
    mean_perplexity = torch.nanmean(example_perplexities)

    return {"perplexity": mean_perplexity.item(), "all_perplexities": example_perplexities}

In [None]:
# Requires `model`, `tokenizer`, `eval_dataset`, `data_collator` and `compute_perplexity`
# from the cells above.
eval_args = transformers.TrainingArguments(report_to="none", output_dir="/tmp")
trainer = Trainer(
    model=model,
    args=eval_args,
    data_collator=data_collator,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_perplexity,
)

# Perplexity on the correctly matched data
correct_results = trainer.evaluate()
correct_results

In [None]:
# Background perplexity, i.e. perplexity when each data point is paired with the wrong
# transcriptome. The id -> index mapping of the dataset is rotated by an even offset
# (2, 4, 6, ...): shifting by two accounts for the duplicate data points
# (once conversation, once description).
N_MISMATCH_ROUNDS = 10

# Every shift is derived from the ORIGINAL mapping, and the mapping is restored afterwards.
# This keeps the cell idempotent and leaves the dataset matched for any later use.
original_id_to_int = eval_dataset.orig_id_to_int
n_ids = len(original_id_to_int)
if n_ids <= 2:
    # every even shift is the identity modulo 1 or 2, so no mismatch can be constructed
    raise ValueError(f"Need more than 2 ids to construct mismatches, got {n_ids}")

# The trainer configuration is the same for every round, so build it once
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_perplexity,
    args=transformers.TrainingArguments(report_to="none", output_dir="/tmp")
)

incorrect_results = []
shift = 0
try:
    while len(incorrect_results) < N_MISMATCH_ROUNDS:
        shift += 2
        if shift % n_ids == 0:
            # A full rotation maps every id onto itself, i.e. the data would be matched
            continue

        eval_dataset.orig_id_to_int = {k: (v + shift) % n_ids for k, v in original_id_to_int.items()}

        incorrect_result = trainer.evaluate()
        incorrect_results.append(incorrect_result["eval_all_perplexities"])
finally:
    eval_dataset.orig_id_to_int = original_id_to_int

In [None]:
# Display the metrics of the last mismatched evaluation round as a sanity check
incorrect_result

In [None]:
# Stack the mismatched rounds once: shape (n_rounds, n_examples)
mismatched_runs = torch.stack(incorrect_results)

# Convert to NumPy before building the frame so the columns are plain numeric columns.
# Handing torch tensors to pandas directly may produce object columns holding 0-d tensors,
# depending on the pandas version (NOTE(review): confirm for the version in use).
matched_ppl = np.asarray(correct_results["eval_all_perplexities"])

# One row per evaluation example: matched perplexity vs. mean/std over the mismatched rounds
df = pd.DataFrame({
    "matched": matched_ppl,
    "mismatched": np.asarray(mismatched_runs.mean(dim=0)),
    "mismatch_std": np.asarray(mismatched_runs.std(dim=0)),
    "id": range(len(matched_ppl)),
})

# Long format for plotting: one row per (id, matched/mismatched) pair
plot_df = df.melt(id_vars=["id"], value_vars=["matched", "mismatched"], value_name="perplexity")

In [None]:
# Point plot with one line per example (hue="id"), connecting its matched and mismatched perplexity
fig = sns.catplot(
    data=plot_df,
    x="variable",
    y="perplexity",
    hue="id",
    kind="point",
)  # , palette=["black"])

# One legend entry per example would be unreadable, so hide the legend
fig.legend.set_visible(False)
fig.ax.set_title(f"Matched vs mismatched ppl. Means: {df['matched'].mean()}|{df['mismatched'].mean()}")
fig.savefig(snakemake.output.comparison_plot)


# Reference notes on plot kinds (the `kind=` values used with sns.catplot above).
# Stored in a throwaway variable; nothing in this notebook reads it.
_ = """
Categorical scatterplots:

stripplot() (with kind="strip"; the default)

swarmplot() (with kind="swarm")

Categorical distribution plots:

boxplot() (with kind="box")

violinplot() (with kind="violin")

boxenplot() (with kind="boxen")

Categorical estimate plots:

pointplot() (with kind="point")

barplot() (with kind="bar")

countplot() (with kind="count")
"""

In [None]:
# Summary statistic: log of (mean matched perplexity / mean mismatched perplexity),
# written as plain text to the workflow output.
with open(snakemake.output.log_perplexity_ratio, "w") as ratio_file:
    log_ratio = np.log(df['matched'].mean() / df['mismatched'].mean())
    ratio_file.write(str(log_ratio))

In [None]:
# Persist the per-example table (matched, mismatched, mismatch_std, id) as CSV
df.to_csv(snakemake.output.all_perplexities)