# Make color-prior object dataset
This notebook creates a dataset of real-world objects with **known color priors**.
Images are obtained through targeted queries in *query_google_images.ipynb* and segmented to produce binary foreground masks using *segment_outline.ipynb*.
Using these masks, foreground and background regions are colored independently, enabling controlled manipulations of object color while preserving object identity.
Recoloring is implemented via *recolor_images.py*.
Color priors are model-specific and can be queried using *model_priors.py*.

## 1. Pipeline for LLaVA-NeXT

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

from transformers import (
    BitsAndBytesConfig,
    LlavaNextProcessor,
    LlavaNextForConditionalGeneration
)
import torch
import pandas as pd
import re
import gc
import numpy as np
from tqdm import tqdm
import os
from pathlib import Path
import matplotlib.pyplot as plt

from test_MLLMs import run_vlm_evaluation
from making_color_images.model_priors import TorchColorPriors, GPTColorPriors
from making_color_images.plot_variants import collect_variants_for, show_variants_grid, plot_vlm_performance, variant_label
from making_color_images.recolor_images import generate_variants, resize_all_images_and_masks


# Ask PyTorch's CUDA caching allocator to use expandable segments.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Folder that later cells read "segmented_images.csv" from and pass to
# TorchColorPriors as `data_folder`. No cell defined DATA before its first use,
# so a fresh "Restart Kernel -> Run All" stopped with a NameError.
# TODO(review): "data" is an assumed location -- point this at the real dataset folder.
DATA = Path("data")

# NOTE(review): presumably the font size used by plotting cells further down -- confirm.
fontsize = 14

In [None]:
# Fix the seeds so stochastic steps are reproducible across runs.
SEED = 42
rng = np.random.default_rng(SEED)

# Seed PyTorch on the CPU and on every CUDA device (the CUDA call is harmless
# when no GPU is present).
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# `device` is where downstream helpers are told to put their tensors; the model
# itself is placed by `device_map="auto"` below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 4-bit NF4 quantization with bfloat16 compute. The fp32 CPU-offload flag lets
# modules that do not fit on the GPU be dispatched to the CPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True
)

# Release cached GPU memory and unreachable Python objects before loading.
torch.cuda.empty_cache()
gc.collect()

model_name = "llava-v1.6-mistral-7b-hf"
# Build the Hub repository id once so processor and model cannot drift apart.
model_id = f"llava-hf/{model_name}"

processor = LlavaNextProcessor.from_pretrained(model_id)

# Do NOT chain `.to(device)` here: with a bitsandbytes-quantized model and
# `device_map="auto"` the weights are already dispatched (possibly across GPU
# and CPU). Moving the model afterwards is redundant at best and is rejected by
# transformers/accelerate for quantized or offloaded models.
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    quantization_config=bnb_config,
)

## 1. Load segmented images and query world-knowledge priors

In [1]:
# Read the CSV listing the outline images and their masks, then render it.
segmented_csv_path = DATA / "segmented_images.csv"
df = pd.read_csv(segmented_csv_path)
display(df)

NameError: name 'pd' is not defined

In [None]:
# Instantiate TorchColorPriors with the loaded model and processor, the target
# device and the dataset folder.
priors = TorchColorPriors(model=model, processor=processor, device=device, data_folder=DATA)

In [None]:
# Generate new priors.
# Disabled by default: the following cell obtains `priors_df` through
# priors.load_model_priors() instead. Uncomment the line below to recompute the
# priors from `df` (NOTE(review): save=True presumably persists the result for
# that loader -- confirm in model_priors.py).
# priors_df = priors.get_model_color_priors(df, save=True)

In [None]:
# Checkpoint: fetch the priors through load_model_priors() and show only the
# object, ground-truth and prior columns.
priors_df = priors.load_model_priors()
display(
    priors_df.loc[:, ['object', 'correct_answer', 'dummy_priors', 'image_priors']]
)