# In [None]:
from methods.llava_utils import load_llava_state
from methods.blip_utils import get_phrase_embedding, load_blip_state
from tqdm import tqdm
import os
import pickle
from methods.algorithms import generate_mass_edit_pre_hook
import torch
import random

# Turn autograd off globally: nothing in this notebook calls backward().
torch.set_grad_enabled(mode=False)

# Relative paths used later (e.g. ./vl_results/) are resolved from the project root.
vl_root_dir = os.environ["VL_ROOT_DIR"]
os.chdir(vl_root_dir)

# In [None]:
# Experiment 10A: mass editing.
# For each sampled COCO image:
#   1. collect phrase embeddings of the caption words selected by the flags below
#      (CHAIR hallucinated words and/or recalled ground-truth words);
#   2. register a layer-0 pre-hook built from them by generate_mass_edit_pre_hook;
#   3. re-caption the image and re-score the new caption for hallucinations.
#
# NOTE(review): output_file, sample_size, remove_hallucinations, remove_gt,
# weight_factor and evaluator are not defined anywhere in the visible code;
# they must be set (e.g. in an earlier cell) before this cell is run.
model_type = "llava7b"
if model_type == "llava7b":
    # Load the LLaVA model
    loaded_state = load_llava_state(model_type, train=True)
elif model_type.startswith("blip"):
    loaded_state = load_blip_state(model_type, train=True)
else:
    raise ValueError(f"model type {model_type} not supported")

vocabulary = loaded_state["vocabulary"]
vocab_embeddings = loaded_state["vocab_embeddings"]
data = loaded_state["data"]
execute_model = loaded_state["execute_model"]
register_pre_hook = loaded_state["register_pre_hook"]
tokenizer = loaded_state["tokenizer"]

# Inverse of `vocabulary`: maps each value back to its key.
id_to_token = {vocabulary[word]: word for word in vocabulary}

# Resume from an earlier partial run if a checkpoint file already exists.
output_file_path = f"./vl_results/{output_file}"
# Make sure the checkpoint directory exists so the periodic torch.save cannot fail.
os.makedirs(os.path.dirname(output_file_path), exist_ok=True)
if os.path.exists(output_file_path):
    results = torch.load(output_file_path)
else:
    results = {}

# Fixed seed so every (resumed) run samples the same subset of images.
random.seed(1)
sampled_coco_img_ids = random.sample(list(data.keys()), sample_size)

img_count = 0
cache = dict()  # NOTE(review): not used in the visible code
for coco_img in tqdm(sampled_coco_img_ids, desc="Experiment 10A: mass editing"):
    img_count += 1
    # Checkpoint every 100 images so an interrupted run can resume.
    if img_count % 100 == 0:
        torch.save(results, output_file_path)

    # Already processed in a previous (checkpointed) run.
    if coco_img in results:
        continue

    # Collect the text embeddings that correspond with the words to edit out
    chair_evals = data[coco_img]["chair_evals"]
    text_embeddings = []
    if remove_hallucinations:
        for caption_word, coco_class in set(chair_evals["mscoco_hallucinated_words"]):
            text_embeddings.append(get_phrase_embedding(caption_word, vocab_embeddings, tokenizer))

    if remove_gt:
        for caption_word, coco_class in set(chair_evals["recall_words"]):
            text_embeddings.append(get_phrase_embedding(caption_word, vocab_embeddings, tokenizer))

    # Nothing to edit for this image.
    if not text_embeddings:
        print("Continuing")
        continue

    text_embeddings = torch.stack(text_embeddings, dim=0)

    # Create a hook that will make the residual embedding orthogonal to these
    # text embeddings; the edit index range and minimum_size differ per model family.
    if model_type.startswith("llava"):
        edit_embeddings_hook = generate_mass_edit_pre_hook(
            text_embeddings, start_edit_index=35, end_edit_index=-12,
            layer=0, weight=weight_factor, minimum_size=576,
        )
    else:
        edit_embeddings_hook = generate_mass_edit_pre_hook(
            text_embeddings, start_edit_index=0, end_edit_index=32,
            layer=0, weight=weight_factor, minimum_size=32,
        )
    # NOTE(review): this hook is not unregistered in the visible code. Confirm the
    # rest of the loop removes it (ideally in a `finally`); otherwise hooks from
    # earlier images presumably stay active while later images are captioned.
    hook = register_pre_hook(edit_embeddings_hook, layer=0)

    # Rerun the model with the hook enabled
    new_caption = execute_model(coco_img)

    # Compute the hallucinations of the edited caption
    new_chair_eval = evaluator.compute_hallucinations(coco_img, new_caption)
