# Requirements

In [1]:
# # Install required versions
# !pip install transformers==4.49.0 \
#             scikit-learn==1.5.2 \
#             numpy==2.1.2 \
#             openai==1.54.4 \
#             google-cloud-aiplatform==1.72.0 \
#             pandas==2.2.3 \
#             Pillow==11.0.0 \
#             opencv-python==4.10.0 \
#             requests==2.32.3 \
#             qwen-vl-utils==0.0.8

# Imports

In [3]:
from PIL import Image
import PIL
import pandas as pd
import json
import requests
import transformers
import numpy as np
import os
from tqdm import tqdm
tqdm.pandas()


PROJECT_PATH = './expirments' # Root folder for experiment artifacts (the cache JSON files are created under it)

# Utils

### General

In [None]:
import io
import os
import base64
import json
from PIL import Image
import nltk
import zlib
import math
import hashlib
from IPython.display import display, HTML

def display_title(title, font_size=24):
    """Show ``title`` as an ``<h1>`` heading of ``font_size`` pixels in the notebook output."""
    heading_markup = f"<h1 style='font-size: {font_size}px;'>{title}</h1>"
    display(HTML(heading_markup))

def convert(o):
    """Turn a NumPy ``int64``/``int32`` scalar into a plain ``int``.

    Raises:
        TypeError: for any other type.
    """
    if not isinstance(o, (np.int64, np.int32)):
        raise TypeError
    return int(o)

def dump_json(file_path, data):
    """Serialize ``data`` to JSON (using ``NumpyEncoder``) and write it to ``file_path``.

    The data is serialized *before* the file is opened: opening with ``'w'``
    truncates the file, so a serialization error would otherwise leave an
    empty file behind and destroy a previously saved cache.

    Args:
        file_path: Destination path; overwritten if it exists.
        data: JSON-serializable object (NumPy values are handled by ``NumpyEncoder``).

    Raises:
        TypeError: if ``data`` contains a value the encoder cannot serialize.
    """
    string_json = json.dumps(data, cls=NumpyEncoder)
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(file_path, 'w') as f:
        f.write(string_json)

def create_json(path):
    """Make sure a JSON file exists at ``path`` and return ``path``.

    When no file is present, one containing an empty JSON object is written.
    An existing file is left untouched.
    """
    if os.path.isfile(path):
        return path
    print('Creating json ', path)
    with io.open(path, 'w+') as new_file:
        new_file.write(json.dumps({}))
    return path

def flatten(lst):
    """Return a flat list of the items in ``lst``, recursing into nested lists and sets.

    Other containers (e.g. tuples) are kept as single items.
    """
    collected = []
    for element in lst:
        if isinstance(element, (list, set)):
            collected += flatten(element)
            continue
        collected.append(element)
    return collected

def pil_to_html(image, size=100):
    """Return an ``<img>`` tag embedding ``image`` as a base64 PNG data URI.

    Args:
        image: Object with a PIL-style ``save(fp, format=...)`` method.
        size: Value used for the tag's ``max-width`` style, in pixels.
    """
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    png_bytes = png_buffer.getvalue()
    img_str = base64.b64encode(png_bytes).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_str}" alt="Product Image" style="max-width:{size}px;">'

def display_images_in_dataframe(data_df, image_colums=('source_img', 'target_img', 'mask_img', 'source_mask_bbox_padding', 'target_mask_bbox_padding'), index_start=0, index_end=5, size=100, drop_columns=None):
    """Display rows ``index_start:index_end`` of ``data_df`` as an HTML table with inline images.

    Args:
        data_df: Sliceable table data accepted by ``pd.DataFrame``.
        image_colums: Names of columns whose values are rendered through
            ``pil_to_html``; names absent from the frame are ignored.
        index_start: First row (positional) to show.
        index_end: Row (positional) to stop before.
        size: ``max-width`` in pixels passed to ``pil_to_html``.
        drop_columns: Optional column names to hide; missing names are ignored.
    """
    # Work on an explicit copy so the caller's frame is never modified and
    # pandas does not warn about assigning into a slice.
    preview_df = pd.DataFrame(data_df[index_start:index_end]).copy()

    if drop_columns:
        preview_df = preview_df.drop(columns=drop_columns, errors='ignore')

    for column in image_colums:
        if column in preview_df.columns:
            preview_df[column] = preview_df[column].map(lambda x: pil_to_html(x, size))

    # escape=False keeps the generated <img> tags as markup
    df_html = preview_df.to_html(escape=False)

    # Wrap the DataFrame HTML in a div that has a horizontal scrollbar
    html = f'''
    <div style="overflow-x: auto; border: 1px solid #ccc; margin-bottom: 10px;">
        {df_html}
    </div>
    '''
    display(HTML(html))
    
def get_json(file_path, DataFrame=False):
    """Read the UTF-8 JSON file at ``file_path``.

    Returns:
        The parsed JSON value, or ``pd.DataFrame(parsed)`` when ``DataFrame`` is true.
    """
    with open(os.path.join(file_path), encoding="utf8") as handle:
        parsed = json.load(handle)
    return pd.DataFrame(parsed) if DataFrame else parsed
    
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that understands NumPy arrays and NumPy scalar types.

    * ``np.ndarray`` -> ``{'type': 'ndarray', 'data': <nested list>, 'shape': <shape>}``
      (the form ``decode_element`` converts back into an array).
    * NumPy integers -> ``int`` (previously they were written as floats, e.g. ``5.0``).
    * NumPy floats -> ``float``.
    * ``np.bool_`` -> ``bool``.
    """
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return {'type': 'ndarray', 'data': obj.tolist(), 'shape': obj.shape}  # Include shape information
        if isinstance(obj, np.integer):
            return int(obj)  # keep integers integral
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        # Anything else: let the base class raise TypeError
        return super().default(obj)

def decode_element(element):
    """Walk a parsed JSON value and rebuild NumPy arrays stored as ``{'type': 'ndarray', ...}``.

    Dicts and lists are rebuilt recursively; every other value is returned as is.
    """
    if isinstance(element, list):
        return [decode_element(child) for child in element]
    if not isinstance(element, dict):
        # Scalars / None need no decoding
        return element
    if element.get('type') == 'ndarray':
        # Encoded array: restore the data with its recorded shape
        return np.array(element['data']).reshape(element['shape'])
    decoded = {}
    for key, value in element.items():
        decoded[key] = decode_element(value)
    return decoded

def compress_segmentation(segmentation_array):
    """Pack a 0/1 mask into bytes, zlib-compress it and return the result as base64 text.

    The array is cast to ``int`` and flattened; each run of eight values is
    read as one binary number (first value = most significant bit).

    NOTE(review): a trailing run shorter than eight values is parsed as a
    shorter binary number, so the element count should be a multiple of 8 for
    a faithful round trip — confirm callers guarantee this.
    """
    bit_string = ''.join(str(value) for value in segmentation_array.astype(int).flatten())
    packed = bytearray()
    for start in range(0, len(bit_string), 8):
        packed.append(int(bit_string[start : start + 8], 2))
    # Compress, then make the bytes text-safe
    return base64.b64encode(zlib.compress(packed)).decode('utf-8')

def decompresse_segmentation(compressed_segmentation_array):
    """Inverse of ``compress_segmentation`` for square masks.

    Base64-decodes and zlib-decompresses the input, expands every byte into
    eight bits and reshapes the bits into a square array whose side is the
    integer square root of the bit count.
    """
    raw_bytes = zlib.decompress(base64.b64decode(compressed_segmentation_array))
    bits = []
    for byte in raw_bytes:
        bits.extend(int(ch) for ch in format(byte, '08b'))
    # The original shape is not stored; a square mask is assumed
    side = int(math.sqrt(len(bits)))
    return np.array(bits).reshape((side, side))

def get_image_id(image: Image.Image) -> str:
    """
    Return the MD5 hex digest of the image's raw pixel bytes (``image.tobytes()``).

    Only the pixel bytes are hashed, so the ID does not depend on any file
    name or metadata.
    """
    return hashlib.md5(image.tobytes()).hexdigest()

def is_pil_image(image):
    """Tell whether ``image`` is an instance of ``PIL.Image.Image``."""
    pil_image_type = Image.Image
    return isinstance(image, pil_image_type)

def get_instance_id(example, original=False):
    """Build the ``<img_id>_<instruction>_<turn_index>`` identifier of an example.

    With ``original=True`` the ``original_instruction`` field is used instead
    of ``instruction``.
    """
    instruction_key = 'original_instruction' if original else 'instruction'
    return f"{example['img_id']}_{example[instruction_key]}_{example['turn_index']}"



###  Load Cache Files

In [None]:
# Each cache is a JSON file under PROJECT_PATH; create_json writes an empty
# object when the file is missing and returns the path.
GEMINI_CACHE_PATH = create_json(PROJECT_PATH + '/merged_gemini_cache.json')
GPT_CACHE_PATH = create_json(PROJECT_PATH + '/merged_gpt_cache.json')
QWEN_CACHE_PATH = create_json(PROJECT_PATH + '/qwen_cache.json')
INTERN_VL3_CACHE_PATH  = create_json(PROJECT_PATH + '/intern_vl3_cache.json')
# The Gemini and GPT caches are loaded into memory here; the Qwen and
# InternVL3 caches are loaded in their own sections.
GEMINI_CACHE = get_json(GEMINI_CACHE_PATH)
GPT_CACHE = get_json(GPT_CACHE_PATH)

LOCAL_RUN = False # When True, only cached responses are served: a cache miss raises instead of calling an API

### Gemini

In [None]:
import os

# Set the environment variable to point to the credentials file
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = './application_default_credentials.json'

from vertexai.preview.generative_models import (
    GenerationConfig,
    GenerativeModel,
    Image as PartImage
)

# Generation limits for Gemini requests; both values are also part of the cache key.
GEMINI_MAX_TOKENS = 1000 
GEMINI_TEMPRATURE = 0
# Model identifiers handed to GenerativeModel below.
GEMINI_MODEL = "gemini-pro-vision"
GEMINI_MODE_one_half = "gemini-1.5-pro-preview-0409"

# Default generation settings used by get_geimini_response.
objects_list_generation_config = GenerationConfig(
    temperature=GEMINI_TEMPRATURE,
    top_p=1.0,
    top_k=32,
    candidate_count=1,
    max_output_tokens=GEMINI_MAX_TOKENS,
)

# Define project information
PROJECT_ID = "gen-lang-client-0642013381"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}

# Initialize Vertex AI
import vertexai
vertexai.init(project=PROJECT_ID, location=LOCATION)
# One model object per supported identifier; get_multimodal_model picks between them.
multimodal_model = GenerativeModel(GEMINI_MODEL)
multimodal_model_one_half = GenerativeModel(GEMINI_MODE_one_half)

def get_gemini_cache_id(prompt):
    """Compose the Gemini cache key: prompt, default model name, max tokens and temperature joined by ``___``."""
    key_fields = (prompt, GEMINI_MODEL, GEMINI_MAX_TOKENS, GEMINI_TEMPRATURE)
    return f"{key_fields[0]}___{key_fields[1]}___{key_fields[2]}___{key_fields[3]}"

def get_gemini_response_from_cache(prompt):
    """Look up the cached Gemini response for ``prompt``; ``None`` when absent.

    Carriage returns are removed from the key before the lookup.
    """
    cache_key = get_gemini_cache_id(prompt)
    cache_key = cache_key.replace('\r', '')
    return GEMINI_CACHE.get(cache_key)

def save_gemini_response_in_cache(prompt, response):
    """Store ``response`` for ``prompt`` in ``GEMINI_CACHE`` and write the cache to disk.

    The key has carriage returns removed, exactly as
    ``get_gemini_response_from_cache`` does on lookup. Without this, a prompt
    containing ``\\r`` was saved under a key the lookup could never produce.
    """
    cache_key = get_gemini_cache_id(prompt).replace('\r', '')
    GEMINI_CACHE[cache_key] = response
    dump_json(GEMINI_CACHE_PATH, GEMINI_CACHE)

def get_multimodal_model(mode_name):
    """Return the ``GenerativeModel`` object registered for ``mode_name``.

    Args:
        mode_name: ``GEMINI_MODEL`` or ``GEMINI_MODE_one_half``.

    Raises:
        ValueError: for any other name (a subclass of ``Exception``, so
            existing ``except Exception`` handlers keep working).
    """
    if mode_name == GEMINI_MODEL:
        return multimodal_model
    if mode_name == GEMINI_MODE_one_half:
        return multimodal_model_one_half
    raise ValueError(
        f'Unsupported Gemini model {mode_name!r}; expected {GEMINI_MODEL!r} or {GEMINI_MODE_one_half!r}'
    )

# Convert a PIL image into the image type required by the Gemini API
def get_part_image(image):
    """Serialize ``image`` (in its own format, or PNG when it has none) and wrap the bytes in a ``PartImage``."""
    target_format = 'PNG' if image.format is None else image.format
    encoded = io.BytesIO()
    image.save(encoded, format=target_format)
    return PartImage.from_bytes(encoded.getvalue())

def get_original_unparsed_gemini_response(responses, contents, model, debug=False, generation_config=None):
    """Stream a Gemini generation and append each chunk's text to ``responses``.

    Args:
        responses: List extended in place with the text of every chunk.
        contents: Request contents forwarded to ``generate_content``.
        model: Model name resolved through ``get_multimodal_model``.
        debug: Print diagnostics when true.
        generation_config: Optional config; omitted from the call when ``None``.

    Returns:
        The same ``responses`` list.
    """
    client_model = get_multimodal_model(model)
    call_kwargs = {'stream': True}
    if generation_config is not None:
        call_kwargs['generation_config'] = generation_config

    for chunk in client_model.generate_content(contents, **call_kwargs):
        try:
            # Only the text is kept; the chunk carries more metadata that is ignored
            responses.append(chunk.text)
        except Exception as e:
            if debug:
                print('Error reading Gemini response, Error:', e)
                print('Contents', contents)
    if debug:
        print(responses)
    return responses

def get_geimini_response(contents, model=GEMINI_MODEL, debug=False, generation_config=objects_list_generation_config, overide_cache=False):
    """Return Gemini's text answer for ``contents``, using the on-disk cache when possible.

    Args:
        contents: Sequence of strings and PIL images forming the request.
            The sequence is NOT modified (the previous version replaced the
            caller's PIL images with Gemini image parts in place).
        model: Model name understood by ``get_multimodal_model``.
        debug: Print diagnostics when true.
        generation_config: Generation settings forwarded to the model.
        overide_cache: When true, ignore any cached answer and query the API.

    Returns:
        The concatenated, stripped response text; ``''`` if both attempts fail.

    Raises:
        Exception: when ``LOCAL_RUN`` is set and no usable cached answer exists.
        TypeError: if an item of ``contents`` is neither a string nor a PIL image.
    """
    if debug:
        print(contents)

    # Cache key: optional model prefix, then image hashes / raw strings in order
    key_parts = ['' if model == GEMINI_MODEL else f'{model}-']
    for obj in contents:
        key_parts.append(get_image_id(obj) if isinstance(obj, PIL.Image.Image) else obj)
    instance_id = ''.join(key_parts)

    cached_response = get_gemini_response_from_cache(instance_id)
    # Empty cached answers ([] or ['']) are treated as misses and re-requested
    if cached_response is not None and not overide_cache and (cached_response != [] and cached_response != ['']):
        return ''.join(cached_response).strip()

    if LOCAL_RUN:
        raise Exception('LOCAL_RUN is enabled and no cached Gemini response exists for this request')

    # Build the API payload in a new list so the caller's list stays untouched
    request_contents = [get_part_image(obj) if isinstance(obj, PIL.Image.Image) else obj for obj in contents]
    if debug:
        print(request_contents)

    responses = []
    last_error = None
    for _attempt in range(2):  # one retry
        try:
            responses = get_original_unparsed_gemini_response([], request_contents, model, debug, generation_config=generation_config)
            save_gemini_response_in_cache(instance_id, responses)
            break
        except Exception as e:  # not a bare except: Ctrl-C must still interrupt
            last_error = e
            responses = []
    else:
        print('Gemini Error tryed two time!\n', last_error)

    return ''.join(responses).strip()

### NLTK utils

In [None]:
import nltk
from nltk import pos_tag
from nltk.corpus import wordnet as wn
from nltk.tokenize import word_tokenize
from nltk.translate.meteor_score import meteor_score
from sentence_transformers import SentenceTransformer, util
from rouge_score import rouge_scorer

# Ensure necessary NLTK resources are downloaded
# (presumably the tagger backs pos_tag, punkt backs word_tokenize and wordnet backs the wn lookups below)
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('wordnet')

# Sentence-embedding model used by evaluate_captions. It has its own name
# because the global ``model`` is rebound later in this file (InternVL3
# section), which would otherwise break evaluate_captions after that cell runs.
sentence_similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
model = sentence_similarity_model  # backward-compatible alias
def evaluate_captions(reference_caption, generated_caption):
    """Score ``generated_caption`` against ``reference_caption``.

    Returns:
        Tuple ``(meteor, rouge1_recall, similarity)``:
        METEOR on NLTK word tokens, ROUGE-1 recall (with stemming) and the
        cosine similarity of the two sentence embeddings.
    """
    reference_tokens = nltk.word_tokenize(reference_caption)
    generated_tokens = nltk.word_tokenize(generated_caption)
    meteor = meteor_score([reference_tokens], generated_tokens)
    rouge_scorer_inst = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)
    rouge1_recall = rouge_scorer_inst.score(reference_caption, generated_caption)['rouge1'].recall
    reference_embedding = sentence_similarity_model.encode(reference_caption, convert_to_tensor=True)
    generated_embedding = sentence_similarity_model.encode(generated_caption, convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(reference_embedding, generated_embedding).item()
    return meteor, rouge1_recall, similarity

def shortest_path_distance(word1, word2):
    """Smallest WordNet path distance over all synset pairs of the two words.

    Returns:
        The minimum distance, or ``None`` when no pair of synsets is connected
        (including when either word has no synsets).
    """
    first_synsets = wn.synsets(word1)
    second_synsets = wn.synsets(word2)

    pair_distances = (
        left.shortest_path_distance(right)
        for left in first_synsets
        for right in second_synsets
    )
    # Unconnected pairs yield None and are skipped
    return min((d for d in pair_distances if d is not None), default=None)

def get_related_synonyms(word, level=3):
    """Collect lemma names related to the first noun synset of ``word``.

    Args:
        word: Word to look up in WordNet (noun senses only).
        level: 1 = lemmas of the synset itself; 2 = also its direct hypernyms
            (more general terms); 3 = also its direct hyponyms (more specific terms).

    Returns:
        List of unique lemma names (order not defined); ``[]`` when the word
        has no noun synset.
    """
    noun_synsets = wn.synsets(word, pos=wn.NOUN)
    if len(noun_synsets) < 1:
        return []

    primary = noun_synsets[0]
    sources = [primary]
    if level >= 2:
        sources.extend(primary.hypernyms())
    if level >= 3:
        sources.extend(primary.hyponyms())

    return list({lemma.name() for synset in sources for lemma in synset.lemmas()})

def is_noun(word):
    """Tell whether the first token of ``word`` is POS-tagged as a noun.

    Returns:
        ``True`` for tags NN, NNS, NNP or NNPS on the first token; ``False``
        otherwise, including when ``word`` yields no tokens (previously an
        ``IndexError``).
    """
    # Tokenize, then POS-tag the tokens
    tagged = pos_tag(word_tokenize(word))
    if not tagged:
        return False
    return tagged[0][1] in ["NN", "NNS", "NNP", "NNPS"]

# Example usage
# NOTE: runs when the cell executes and defines the globals `word` and `related_synonyms`.
word = "backpack"
related_synonyms = get_related_synonyms(word)
print(f"Related synonyms for '{word}':", related_synonyms)

### GPT

In [None]:
from openai import OpenAI
from io import BytesIO

# NOTE(review): the key is blank in source — presumably to be filled in locally; never commit a real key.
API_KEY = ""
client = OpenAI(api_key=API_KEY)
# Number of responses cached so far; save_gpt_response_in_cache writes the cache to disk every 100.
gpt_counter = 0
# Params
# NOTE(review): MAX_TOKENS is part of the cache key, but the request calls below pass max_tokens=1000.
MAX_TOKENS = 500
TEMPERATURE = 0.000001  # Default 1

def get_gpt_cache_id(prompt, model, temperature = TEMPERATURE, max_tokens=MAX_TOKENS):
    """Compose the GPT cache key from prompt, model, temperature and max tokens, with carriage returns removed."""
    raw_key = f"{prompt}___{model}___{temperature}_____{max_tokens}"
    return raw_key.replace('\r', '')

def get_gpt_response_from_cache(prompt, model, temperature = TEMPERATURE, max_tokens=MAX_TOKENS):
    """Return the cached GPT response for these request parameters, or ``None`` when absent."""
    cache_key = get_gpt_cache_id(prompt, model, temperature, max_tokens)
    return GPT_CACHE.get(cache_key)

def save_gpt_response_in_cache(prompt, model, response, temprature=TEMPERATURE, max_tokens=MAX_TOKENS):
    """Record ``response`` in ``GPT_CACHE``; every 100th call also writes the cache to disk."""
    global gpt_counter
    cache_key = get_gpt_cache_id(prompt, model, temprature, max_tokens)
    GPT_CACHE[cache_key] = response
    gpt_counter += 1
    if gpt_counter % 100 != 0:
        return
    # Periodic flush: entries added since the last flush exist only in memory until now
    print(gpt_counter, 'saving cache')
    dump_json(GPT_CACHE_PATH, GPT_CACHE)

def encode_image(image):
    """Return ``image`` saved as PNG and encoded as a base64 string."""
    png_stream = BytesIO()
    image.save(png_stream, format="PNG")
    png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
    
def get_gpt_4_vision_response(contents, MODEL="gpt-4-vision-preview", debug=False):
    """Return the vision model's answer for ``contents``, using the GPT cache when possible.

    Args:
        contents: Sequence of strings and PIL images. It is NOT modified (the
            previous version replaced its items with API message parts in place).
        MODEL: OpenAI model name.
        debug: Print diagnostics when true.

    Returns:
        The text content of the first choice.

    Raises:
        Exception: when ``LOCAL_RUN`` is set and nothing is cached, or when
            the request fails twice (the last error is chained as the cause).
        TypeError: if an item of ``contents`` is neither a string nor a PIL image.
    """
    if debug:
        print(contents)

    # Cache key: image hashes and raw strings concatenated in order
    instance_id = ''.join(
        get_image_id(obj) if isinstance(obj, PIL.Image.Image) else obj for obj in contents
    )

    cached_response = get_gpt_response_from_cache(instance_id, MODEL)
    if cached_response is not None:
        return cached_response

    if LOCAL_RUN:
        raise Exception('LOCAL_RUN is enabled and no cached GPT vision response exists for this request')

    message_content = []
    for obj in contents:
        if isinstance(obj, PIL.Image.Image):
            # encode_image produces PNG bytes, so the data URL is labelled image/png
            message_content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encode_image(obj)}", "detail": "low"}})
        else:
            message_content.append({"type": "text", "text": obj})

    response = None
    last_error = None
    for _attempt in range(2):  # one retry
        try:
            # NOTE(review): max_tokens is fixed at 1000 here while the cache key uses MAX_TOKENS
            response = client.chat.completions.create(model=MODEL, messages=[{"role": "user", "content": message_content}], temperature=TEMPERATURE, max_tokens=1000)
            break
        except Exception as e:  # not a bare except: Ctrl-C must still interrupt
            last_error = e
    if response is None:
        print('GPT4 Vision request failed twice', last_error)
        raise Exception('GPT4 Vision request failed twice') from last_error

    if debug:
        print(response)
    answer = response.choices[0].message.content
    save_gpt_response_in_cache(instance_id, MODEL, answer)
    return answer

def get_chatgpt_4_prediction(prompt, overide_cache=False, temprature=TEMPERATURE, max_tokens=MAX_TOKENS):
    """Ask ``gpt-4`` for a completion of ``prompt``, retrying once on failure.

    Args:
        prompt: User message text.
        overide_cache: When true, skip the cache lookup.
        temprature: Sampling temperature (also part of the cache key).
        max_tokens: Forwarded to ``get_chat_gpt_prediction``.

    Returns:
        The response text.

    Raises:
        Exception: when both attempts fail; the last error is chained as the cause.
    """
    model = "gpt-4"
    last_error = None
    for _attempt in range(2):
        try:
            return get_chat_gpt_prediction(prompt, model, overide_cache, temprature=temprature, max_tokens=max_tokens)
        except Exception as e:  # not a bare except: Ctrl-C must still interrupt
            last_error = e
    print('GPT request failed twice', last_error)
    raise Exception('GPT request failed twice') from last_error

def get_chat_gpt_prediction(prompt, model, overide_cache=False, temprature=TEMPERATURE, max_tokens=MAX_TOKENS):
    """Return the chat completion for ``prompt``, from the cache unless ``overide_cache`` is set.

    A fresh answer is stored through ``save_gpt_response_in_cache``.

    Raises:
        Exception: when ``LOCAL_RUN`` is set and no cached answer is used.
    """
    if not overide_cache:
        cached_response = get_gpt_response_from_cache(prompt, model, temprature, max_tokens)
        if cached_response is not None:
            return cached_response
    if LOCAL_RUN:
        raise Exception()

    # NOTE(review): the request always sends max_tokens=1000; the ``max_tokens``
    # argument only affects the cache key — confirm this is intended.
    user_message = {"role": "user", "content": prompt}
    response = client.chat.completions.create(
      model=model,
      max_tokens=1000,
      temperature=temprature,
      messages=[user_message],
    )
    answer = response.choices[0].message.content
    save_gpt_response_in_cache(prompt, model, answer, temprature, max_tokens)
    return response.choices[0].message.content

## Utils

In [None]:
import warnings
import cv2
import numpy as np
from PIL import Image
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*ANTIALIAS is deprecated.*")

def get_annotated_data_with_majority():
    """Load ``./editinspector_benchmark.csv`` and parse every ``metadata_*`` column with ``eval``.

    Returns:
        pandas.DataFrame indexed by the CSV's first column.
    """
    df = pd.read_csv(r'./editinspector_benchmark.csv', index_col=0)
    for col in df.columns:
        if col.startswith('metadata_'):
            # SECURITY NOTE(review): eval executes whatever Python the CSV cells contain;
            # only use this with a trusted file. ast.literal_eval would be the safe
            # choice if the cells are plain literals — TODO confirm.
            df[col] = df[col].apply(eval)
    return df

In [None]:
def get_MB_brush_test_set():
    """Load the MagicBrush test split into a DataFrame, one row per edit turn.

    Reads ``edit_sessions.json`` from the test folder and opens the input,
    mask and output image of every edit from ``<test folder>/images/<img_id>/``.

    Returns:
        pandas.DataFrame with columns ``img_id``, ``instruction``,
        ``source_img``, ``mask_img``, ``target_img`` and ``turn_index``.
    """
    test_folder = 'PATH TO MAGICBRUSH TEST SET'
    images_folder = test_folder + '/images'
    test = get_json(test_folder + '/edit_sessions.json')

    def get_turn_index(example):
        """Turn number = last character of the input file's stem + 1; 1 when that character is not a digit."""
        try:
            return int(example['input'].split('.')[0][-1])+1
        except Exception:  # not a bare except: Ctrl-C must still interrupt
            return 1

    formated_edits = []
    for img_id, edits in test.items():
        session_folder = f'{images_folder}/{img_id}'
        for edit in edits:
            formated_edits.append({
                'img_id': img_id,
                'instruction': edit['instruction'],
                'source_img': Image.open(f'{session_folder}/{edit["input"]}'),
                'mask_img': Image.open(f'{session_folder}/{edit["mask"]}'),
                'target_img': Image.open(f'{session_folder}/{edit["output"]}'),
                'turn_index': get_turn_index(edit),
            })
    return pd.DataFrame(formated_edits)


# Built once when this cell runs; load_example_images looks examples up in this DataFrame.
MB_test = get_MB_brush_test_set()

In [None]:
import pandas as pd
from PIL import Image
from tqdm import tqdm
import math

tqdm.pandas()

def is_pil_image(image):
    """Tell whether ``image`` is an instance of ``PIL.Image.Image``."""
    pil_image_type = Image.Image
    return isinstance(image, pil_image_type)

def is_nan_or_none(value):
    """Return ``True`` when ``value`` is ``None`` or a float NaN, else ``False``."""
    if value is None:
        return True
    if not isinstance(value, float):
        return False
    return math.isnan(value)

def get_MB_example_idx(example, MB_set):
    """Find the positional index in ``MB_set`` of the row matching ``example``.

    The image id and turn index are the first and last ``_``-separated parts
    of ``example['id']``. The instruction is ``original_instruction`` when
    present (not None/NaN), otherwise ``instruction``. A row matches when its
    ``instruction``, ``img_id`` and ``turn_index`` all agree.

    Raises:
        AssertionError: unless exactly one row matches.
    """
    id_parts = example['id'].split('_')
    example_img_id = id_parts[0]
    turn_index = int(id_parts[-1])

    original_instruction = example.get('original_instruction')
    if is_nan_or_none(original_instruction):
        example_instruction = example['instruction']
    else:
        example_instruction = original_instruction

    rows = zip(MB_set['instruction'], MB_set['img_id'], MB_set['turn_index'])
    example_index = {
        position
        for position, (instruction, img_id, turn) in enumerate(rows)
        if instruction == example_instruction and img_id == example_img_id and turn == turn_index
    }
    assert len(example_index) == 1
    return example_index.pop()

def smooth_dall_e_pixels(image):
    """
    Inpaint the bottom-right corner of ``image``, where the DALL·E watermark sits.
    The region spans x from 92.1% to 100% of the width and y from 98.3% to 100% of the height.

    :param image: PIL Image object to process.
    :return: New PIL image with the region filled by ``cv2.inpaint`` (Telea method, radius 3).
    """
    width, height = image.size

    # Top-left corner of the region; it extends to the image's bottom-right corner
    left = int(width * 0.921)
    top = int(height * 0.983)

    pixels = np.array(image)

    # Mask is 255 inside the region to fill and 0 elsewhere
    region_mask = np.zeros(pixels.shape[:2], np.uint8)
    region_mask[top:height, left:width] = 255

    filled = cv2.inpaint(pixels, region_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
    return Image.fromarray(filled)

images_cache = dict()
def load_example_images(example, smooth_images=True):
    """Attach ``source_img`` and ``target_img`` to ``example`` from the MagicBrush test set.

    The image pair of each matched row is kept in ``images_cache``; the example
    receives copies. With ``smooth_images`` the target is first resized to the
    source size when they differ, and both images go through
    ``smooth_dall_e_pixels``.

    NOTE(review): any error during lookup or processing is swallowed; a failure
    then only surfaces at the final check below.

    Returns:
        The same ``example`` object.
    """
    for MB_set in [MB_test]:
        try:
            example_index = get_MB_example_idx(example, MB_set)
            if images_cache.get(example_index) is None:
                matched_row = MB_set.iloc[example_index]
                images_cache[example_index] = [matched_row['source_img'], MB_set.iloc[example_index]['target_img']]
            cached_source, cached_target = images_cache[example_index][0], images_cache[example_index][1]
            example['source_img'] = cached_source.copy()
            example['target_img'] = cached_target.copy()
            if smooth_images:
                source_img_size = example['source_img'].size
                target = example['target_img']
                if target.size != source_img_size:
                    target = target.resize(source_img_size)
                example['target_img'] = smooth_dall_e_pixels(target)
                example['source_img'] = smooth_dall_e_pixels(example['source_img'])
        except Exception:
            continue

    assert is_pil_image(example['source_img']) and is_pil_image(example['target_img'])
    return example 


In [None]:
import re
import pandas as pd

def smart_convert_yes_no_list(values):
    """
    Convert answers that start with "Yes"/"No" (any case, leading whitespace allowed) to True/False.

    * Starts with "yes" or "no contradiction" -> True.
    * Starts with "no" or "contradiction found" -> False.
    * Any other string -> a message is printed and False is returned (no exception is raised).
    * Non-string values are returned unchanged.

    Only the start of the answer is inspected, so multi-line answers such as
    "Yes\\nbecause ..." are recognized (the previous ``.*$`` suffix could not
    cross a newline and rejected them).

    Accepts a list (returns a list) or a pandas Series (returns a Series).
    """
    # "no contradiction" must be tested before the bare "no" of the No pattern
    yes_pattern = re.compile(r'^\s*(yes|no contradiction)\b', re.IGNORECASE)
    no_pattern = re.compile(r'^\s*(no|contradiction found)\b', re.IGNORECASE)

    def map_to_bool(value):
        """Map one answer to True/False; unrecognized strings count as "No"."""
        if not isinstance(value, str):
            return value  # Return original non-string values without conversion
        if yes_pattern.match(value):
            return True
        if no_pattern.match(value):
            return False
        print(f'Received invalid answer: "{value}", treating this as "No".')
        return False

    # Apply the mapping function to each value
    if isinstance(values, pd.Series):
        return values.apply(map_to_bool)
    return [map_to_bool(v) for v in values]

## Qwen2.5

In [None]:
import os
import json
import hashlib
from PIL import Image
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Load the Qwen2.5-VL model in bfloat16 and its processor; eval() puts the model in inference mode.
qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)
qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
qwen_model.eval()


# Load the Qwen response cache (cache key -> response text); an unreadable
# file falls back to an empty cache instead of failing.
if os.path.exists(QWEN_CACHE_PATH):
    try:
        with open(QWEN_CACHE_PATH, "r", encoding="utf-8") as f:
            QWEN_CACHE = json.load(f)
    except json.JSONDecodeError:
        print("Warning: Cache file was corrupt or empty. Starting with empty cache.")
        QWEN_CACHE = {}
else:
    QWEN_CACHE = {}

def manual_save_qwen_one():
    """Write the whole in-memory ``QWEN_CACHE`` to ``QWEN_CACHE_PATH`` right now."""
    with open(QWEN_CACHE_PATH, "w", encoding="utf-8") as cache_file:
        json.dump(QWEN_CACHE, cache_file, indent=2)

qwen_counter = 0
def save_qwen_cache():
    """Count one cached response; on every 10th call write ``QWEN_CACHE`` to disk."""
    global qwen_counter
    qwen_counter += 1
    if qwen_counter % 10 != 0:
        return
    with open(QWEN_CACHE_PATH, "w", encoding="utf-8") as cache_file:
        json.dump(QWEN_CACHE, cache_file, indent=2)

def make_cache_key(before_image: Image.Image, after_image: Image.Image, prompt: str) -> str:
    """Build ``<before image id>_<after image id>_<md5 of the UTF-8 prompt>``."""
    prompt_digest = hashlib.md5(prompt.encode('utf-8')).hexdigest()
    key_fields = [get_image_id(before_image), get_image_id(after_image), prompt_digest]
    return f"{key_fields[0]}_{key_fields[1]}_{key_fields[2]}"

# --- Main Function ---
def get_qwen_response(inputs):
    """Answer ``prompt`` about a before/after image pair with Qwen2.5-VL, caching by content.

    Args:
        inputs: Triple ``(before_image, after_image, prompt)``.

    Returns:
        The decoded text of the first generated sequence (from the cache when available).
    """
    before_image, after_image, prompt = inputs
    with torch.no_grad():
        cache_key = make_cache_key(before_image, after_image, prompt)

        if cache_key in QWEN_CACHE:
            print("[CACHE HIT]")
            return QWEN_CACHE[cache_key]

        print("[CACHE MISS] Running model...")

        # Single user turn: both images first, then the text prompt
        user_content = [
            {"type": "image", "image": before_image},
            {"type": "image", "image": after_image},
            {"type": "text", "text": prompt},
        ]
        messages = [{"role": "user", "content": user_content}]

        # Preparation for inference
        chat_text = qwen_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        model_inputs = qwen_processor(
            text=[chat_text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        # Inference: generate, then drop the prompt tokens from each output sequence
        output_ids = qwen_model.generate(**model_inputs, max_new_tokens=265)
        new_token_ids = [
            generated[len(prompt_ids):]
            for prompt_ids, generated in zip(model_inputs.input_ids, output_ids)
        ]
        decoded = qwen_processor.batch_decode(
            new_token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        final_response = decoded[0]
        QWEN_CACHE[cache_key] = final_response
        save_qwen_cache()
        return final_response


## InternVL3


In [None]:
import os
from PIL import Image
import requests
from io import BytesIO
import math
import numpy as np
import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoConfig, AutoTokenizer
# Generation settings passed to model.chat in get_intern_vl3_response.
# NOTE(review): do_sample=True makes outputs non-deterministic while responses
# are cached by content — confirm sampling is intended.
generation_config = dict(max_new_tokens=1024, do_sample=True)
# Per-channel mean / std used by build_transform to normalize image tensors.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """Return the torchvision pipeline: RGB conversion, bicubic resize to a square of ``input_size``, tensor conversion, normalization."""
    def to_rgb(img):
        # Leave images that are already RGB untouched
        return img.convert('RGB') if img.mode != 'RGB' else img

    steps = [
        T.Lambda(to_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid whose ratio is closest to ``aspect_ratio``.

    On an exact tie, the later candidate replaces the current choice only when
    the image area exceeds half the area of that candidate's grid
    (``image_size`` squared times cols times rows).

    Returns:
        The chosen ratio tuple; ``(1, 1)`` when ``target_ratios`` is empty.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    image_area = width * height
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        if gap == smallest_gap and image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Resize ``image`` to the best-fitting tile grid and cut it into square tiles.

    Args:
        image: PIL image.
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Side of each square tile in pixels.
        use_thumbnail: Append a whole-image thumbnail of ``image_size`` square
            when more than one tile was produced.

    Returns:
        List of PIL images: tiles in row-major order, plus the optional thumbnail.
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # All (cols, rows) grids holding between min_num and max_num tiles, smallest grids first
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Grid whose aspect ratio is closest to the image's
    grid = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
    grid_cols = grid[0]
    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    # Resize so the image is an exact multiple of the tile size, then crop each tile
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for tile_index in range(blocks):
        row, col = divmod(tile_index, grid_cols)
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images



def load_image(image, input_size=448, max_num=12):
    """Turn an image (PIL object, local path or http(s) URL) into a stacked tensor of normalized tiles.

    Args:
        image: PIL image, or a string holding a file path or an ``http(s)://`` URL.
        input_size: Tile side in pixels.
        max_num: Maximum number of tiles (a thumbnail may be added on top).

    Returns:
        ``torch.Tensor`` produced by stacking the transformed tiles.

    Raises:
        requests.HTTPError: when a URL download returns an error status.
        requests.Timeout: when a URL download does not respond within 30 seconds.
    """
    # Load image from URL or local path
    if isinstance(image, str):
        if image.startswith(('http://', 'https://')):
            # Without a timeout a stalled server would hang the run indefinitely
            response = requests.get(image, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image).convert('RGB')

    # Apply preprocessing
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)

    # Transform and stack valid images
    pixel_values = [transform(img) for img in images if isinstance(img, Image.Image)]
    return torch.stack(pixel_values)

def split_model(model_name):
    """Build a ``device_map`` that spreads the language-model layers over all visible GPUs.

    NOTE(review): ``model_name`` is not used; the config is loaded from the
    module-level ``model_path``, which must be defined before this is called.
    NOTE(review): assumes at least one CUDA device is visible — TODO confirm.

    Returns:
        dict mapping module names to GPU indices.
    """
    device_map = {}
    world_size = torch.cuda.device_count()
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_layers = config.llm_config.num_hidden_layers
    # Since the first GPU will be used for ViT, treat it as half a GPU.
    num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    num_layers_per_gpu = [num_layers_per_gpu] * world_size
    num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
    # Assign consecutive layer indices to GPUs according to the per-GPU quota.
    layer_cnt = 0
    for i, num_layer in enumerate(num_layers_per_gpu):
        for j in range(num_layer):
            device_map[f'language_model.model.layers.{layer_cnt}'] = i
            layer_cnt += 1
    # The vision tower, projector, embeddings, final norm and output head are pinned to GPU 0.
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.model.rotary_emb'] = 0
    device_map['language_model.lm_head'] = 0
    # The last layer is moved back to GPU 0 as well.
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0

    return device_map

# If you set `load_in_8bit=True`, you will need two 80GB GPUs.
# If you set `load_in_8bit=False`, you will need at least three 80GB GPUs.
model_path = 'OpenGVLab/InternVL3-8B'
# split_model reads the global model_path; the argument passed here is not used by it.
device_map = split_model('InternVL3-8B')
# NOTE(review): this rebinds the global name `model`, which the caption-evaluation
# code earlier in this file also assigns.
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    load_in_8bit=False,
    low_cpu_mem_usage=True,
    use_flash_attn=False,
    trust_remote_code=True,
    device_map=device_map).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)


# NOTE(review): this rebinds INTERN_VL3_CACHE_PATH (set earlier from PROJECT_PATH) to a
# file in the current working directory — confirm which location is intended.
INTERN_VL3_CACHE_PATH = "intern_vl3_cache.json"
# In-memory cache: cache key -> model response.
INTERN_VL3_CACHE = {}

# Manual save function
def manual_save_intern_vl3():
    """Write the whole in-memory InternVL3 cache to ``INTERN_VL3_CACHE_PATH`` as indented JSON."""
    with open(INTERN_VL3_CACHE_PATH, mode="w", encoding="utf-8") as cache_file:
        json.dump(INTERN_VL3_CACHE, cache_file, indent=2)

# Load cache if it exists
# When the cache file is present, its parsed JSON content replaces the empty
# in-memory cache; otherwise the cache keeps its initial value.
if os.path.exists(INTERN_VL3_CACHE_PATH):
    with open(INTERN_VL3_CACHE_PATH, "r", encoding="utf-8") as f:
        INTERN_VL3_CACHE = json.load(f)

# Auto-save every 10 requests
intern_vl3_counter = 0
def save_intern_vl3_cache():
    """Count one more request and persist the cache on every tenth call.

    Increments the module-level ``intern_vl3_counter``; when it reaches a
    multiple of 10 the in-memory cache is written to ``INTERN_VL3_CACHE_PATH``.
    """
    global intern_vl3_counter
    intern_vl3_counter += 1
    if intern_vl3_counter % 10 != 0:
        return
    with open(INTERN_VL3_CACHE_PATH, mode="w", encoding="utf-8") as cache_file:
        json.dump(INTERN_VL3_CACHE, cache_file, indent=2)

# Cache key generation
def make_cache_key(before_image: Image.Image, after_image: Image.Image, prompt: str) -> str:
    """Return the cache key for an (before image, after image, prompt) request.

    The key joins the two image ids from ``get_image_id`` and the MD5 hex
    digest of the UTF-8 encoded prompt with underscores.
    """
    image_ids = [get_image_id(picture) for picture in (before_image, after_image)]
    prompt_digest = hashlib.md5(prompt.encode('utf-8')).hexdigest()
    return f"{image_ids[0]}_{image_ids[1]}_{prompt_digest}"

# Prediction with caching
@torch.no_grad()
def get_intern_vl3_response(inputs):
    """Return the InternVL3 answer for ``inputs``, using the response cache when possible.

    Args:
        inputs: Sequence whose first three items are the before image, the
            after image and the prompt.

    Returns:
        The cached response if the request was seen before; otherwise the
        fresh ``model.chat`` response, which is also stored in the cache.
    """
    before_image, after_image, prompt = inputs[0], inputs[1], inputs[2]
    cache_key = make_cache_key(before_image, after_image, prompt)
    try:
        return INTERN_VL3_CACHE[cache_key]
    except KeyError:
        pass

    print("[CACHE MISS] Running InternVL3...")

    # Both images are tiled separately and their tiles concatenated along dim 0.
    tiles_per_image = [
        load_image(picture, max_num=12).to(torch.bfloat16).cuda()
        for picture in (before_image, after_image)
    ]
    pixel_values = torch.cat(tiles_per_image, dim=0)

    response = model.chat(tokenizer, pixel_values, prompt, generation_config)
    INTERN_VL3_CACHE[cache_key] = response
    save_intern_vl3_cache()
    return response

# Evaluation

In [None]:
# Model names evaluated by the loops below; get_model_response() raises for any
# name it has no branch for.
MODELS = ['gemini', 'gpt-4', 'gpt-4o', 'gpt-4-turbo', 'gemini-1.5', 'pipeline']
# MODELS = ['gpt-4', 'pipeline']

def get_task_relevant_pipeline_column(task_name):
    """Return the pipeline column that answers ``task_name``, or None for an unknown task."""
    column_by_tasks = (
        (('is_extensive_difference_caption_accurate', 'is_difference_captipn_accurate', 'is_accurate'), 'is_edit_accurate'),
        (('is_artifacts',), 'is_edit_contains_artifacts'),
        (('is_good_quality',), 'is_bad_quality_edit'),
    )
    for task_names, column in column_by_tasks:
        if task_name in task_names:
            return column
    return None

def get_human_prediction(example, task_name):
    """Return the human-written difference caption for the caption tasks.

    Raises:
        Exception: For any task other than the two difference-caption tasks.
    """
    caption_tasks = ('is_difference_captions_same_complex', 'single difference caption')
    if task_name not in caption_tasks:
        raise Exception('Human evlauation not defined for task')
    return get_selected_annotated_difference_caption(example, second_prioretiy=True)

    
def get_pipeline_prediction(example, task_name):
    """Translate the pipeline's stored columns into the answer expected for ``task_name``.

    Returns:
        ``example['extensive_caption']`` for the caption tasks, otherwise
        'Yes' or 'No' derived from the task's pipeline column (inverted for
        'is_good_quality', whose column flags a bad edit).

    Raises:
        Exception: With ``task_name`` as its argument for an unsupported task.
    """
    caption_tasks = ('is_difference_captions_same_complex', 'single difference caption')
    direct_flag_tasks = ('is_extensive_difference_caption_accurate', 'is_difference_captipn_accurate', 'is_accurate', 'is_artifacts')

    if task_name in caption_tasks:
        return example['extensive_caption']
    if task_name in direct_flag_tasks:
        flag = example[get_task_relevant_pipeline_column(task_name)]
        return 'Yes' if flag else 'No'
    if task_name == 'is_good_quality':
        is_bad = example[get_task_relevant_pipeline_column(task_name)]
        return 'No' if is_bad else 'Yes'
    raise Exception(task_name)

def get_model_response(inputs, model, example, task_name):
    """Route a request to the backend selected by the model name.

    Args:
        inputs: Payload forwarded to the model backends.
        model: Name of the backend to use (a string such as 'gpt-4o').
        example: Dataset example; used by the 'pipeline' and 'human' backends.
        task_name: Task identifier; used by the 'pipeline' and 'human' backends.

    Raises:
        Exception: If ``model`` is not one of the supported names.
    """
    # Lambdas keep every backend lookup lazy, so only the selected one is touched.
    handlers = {
        'pipeline': lambda: get_pipeline_prediction(example, task_name),
        'gemini': lambda: get_geimini_response(inputs),
        'gemini-1.5': lambda: get_geimini_response(inputs, model=GEMINI_MODE_one_half),
        'intern-vl3': lambda: get_intern_vl3_response(inputs),
        'qwen': lambda: get_qwen_response(inputs),
        'gpt-4': lambda: get_gpt_4_vision_response(inputs),
        'gpt-4o': lambda: get_gpt_4_vision_response(inputs, MODEL='gpt-4o'),
        'gpt-4-turbo': lambda: get_gpt_4_vision_response(inputs, MODEL='gpt-4-turbo'),
        'human': lambda: get_human_prediction(example, task_name),
    }
    handler = handlers.get(model)
    if handler is None:
        raise Exception('Model evaluation method not implemented')
    return handler()

### Difference Caption Evaluation Metrics & Utils

In [None]:
import pathlib

# Prompt templates are read once when this cell runs. Path.read_text() opens,
# reads and closes each file, so no file handle is left dangling.
generate_general_edit_actions_from_discription = pathlib.Path(PROJECT_PATH + "//prompts/generate_general_edit_actions_from_discription.txt").read_text()
general_detect_the_difference_prompt = pathlib.Path(PROJECT_PATH + '//prompts/detect_the_difference_prompt.txt').read_text()
get_objects_similarity_label_prompt = pathlib.Path(PROJECT_PATH + '//prompts/get_objects_similarity_label_prompt.txt').read_text()
# Placeholder that get_edit_actions_from_description() replaces with the caption text.
caption_replacer = '$REPLACE_HERE$'

# Column-name templates for the match categories; "{}" is filled with the model
# name via str.format before the columns are looked up.
identical_columns = {"exact match {}"}
almost_identical_columns = {
    "exact match different order {}", 
    "partial match both source and target different action {}"
}
completely_wrong = {"same action different objects {}", "completely different {}"}
# under_specify also contains every completely_wrong template.
under_specify = completely_wrong.union({
    "partial match either source or target same action {}", 
    "partial match either source or target different action {}", 
    "partial match either source or target different order same action {}", 
    "partial match either source or target different order different action {}"
})
# Built from a set union, so the order of this list follows set iteration order.
all_metrics_columns = list(identical_columns.union(almost_identical_columns).union(under_specify).union(completely_wrong))

def select_elements_by_indexes(my_list, indexes):
    """Return a new list with ``my_list[i]`` for every ``i`` in ``indexes``, in that order."""
    return [my_list[position] for position in indexes]

def get_selected_annotated_difference_caption(example, second_prioretiy=False):
    """Pick the annotator explanation to use as the reference difference caption.

    Only explanations from annotators whose artifact and caption-accuracy votes
    both equal the example's final labels are considered; if those labels
    cannot be read, every explanation is considered. Candidates are ranked
    longest first.

    Args:
        example: Mapping holding the annotation columns and 'extensive_caption'.
        second_prioretiy: When True, prefer the second-ranked explanation.

    Returns:
        ``example['extensive_caption']`` when there is no candidate or the
        longest one has fewer than 5 characters. Otherwise the top candidate,
        unless ``second_prioretiy`` is set or the top candidate contains
        'unexpected'; in that case the second candidate, or, if there is none,
        one of the other annotators' explanations (the top candidate itself
        when no other explanation exists).
    """
    explanations = example['metadata_annotated_difference_caption_explanation']
    # Default to all annotations; narrowed below when the labels are available.
    matching_indexes = list(range(len(explanations)))
    try:
        gorund_truth_artifacts = example['annotated_is_artifacts']
        gorund_truth_caption = example['annotated_is_extensive_difference_caption_accurate']

        artifacts_annotations = example['metadata_annotated_is_artifacts']
        caption_annotations = example['metadata_annotated_is_extensive_difference_caption_accurate']

        if len(artifacts_annotations) != len(caption_annotations):
            print(artifacts_annotations, type(artifacts_annotations))
            print(caption_annotations, type(caption_annotations))
            raise ValueError("Both arrays must have the same length")

        # Keep the annotators who agree with both final labels.
        matching_indexes = [
            index
            for index, (artifacts_vote, caption_vote) in enumerate(zip(artifacts_annotations, caption_annotations))
            if artifacts_vote == gorund_truth_artifacts and caption_vote == gorund_truth_caption
        ]
    except Exception as e:
        # Best effort: report the problem and fall back to all annotations.
        print('get_selected_annotated_difference_caption - ' + str(e))

    matching_descriptions = sorted((explanations[index] for index in matching_indexes), key=len, reverse=True)
    if not matching_descriptions or len(matching_descriptions[0]) < 5:  # In case of accurate edit
        return example['extensive_caption']

    top_description = matching_descriptions[0]
    if not second_prioretiy and 'unexpected' not in top_description:
        return top_description
    if len(matching_descriptions) > 1:
        return matching_descriptions[1]

    # No second candidate: fall back to the explanations of the other annotators.
    remaining = list(explanations)
    remaining.remove(top_description)
    if not remaining:
        return top_description
    if len(remaining) == 1:
        return remaining[0]
    return remaining[0] if len(remaining[0]) > len(remaining[1]) else remaining[1]

def get_general_difference_prompt(before_image, after_image, model, example, task):
    """Ask ``model`` to describe the differences between the two images using the general prompt."""
    model_inputs = [before_image, after_image, general_detect_the_difference_prompt]
    return get_model_response(model_inputs, model=model, example=example, task_name=task)

def get_action_id(edit_action, id_prefix, debug=False):
    """Build an action id from the prefix and the action's name, source and target objects."""
    if debug:
        print('Id fuction recived:,', edit_action)
    action_name = edit_action["action"]
    source_object = edit_action["source_object"]
    target_object = edit_action["target_object"]
    return f'{id_prefix}{action_name}_{source_object}_{target_object}'

def get_similarity_label(first_object, second_object, debug=False):
    """Ask GPT-4o for the similarity level of two object descriptions.

    Args:
        first_object: First object description inserted into the prompt.
        second_object: Second object description inserted into the prompt.
        debug: When True, print the prompt that is sent.

    Returns:
        The 'similarityLevel' value of the JSON reply, or 0 when the reply is
        not valid JSON or lacks that field.
    """
    prompt = get_objects_similarity_label_prompt.format(first_object, second_object)
    if debug:
        print(prompt)
    response = get_gpt_4_vision_response([prompt], MODEL='gpt-4o')
    # The reply may be wrapped in a Markdown code fence; strip it before parsing.
    cleaned_response = response.replace('```json', '').replace('```', '').strip()
    try:
        parsed_json = json.loads(cleaned_response)
        return parsed_json['similarityLevel']
    except (ValueError, KeyError, TypeError):
        # ValueError covers json.JSONDecodeError; KeyError/TypeError cover a reply
        # of the wrong shape. Interrupts are no longer swallowed.
        print('Recived invalid response', cleaned_response, 'original response', response)
        return 0

def get_edit_actions_from_description(description, id_prefix, debug=False):
    """Convert a free-text difference caption into a list of edit-action dicts.

    Args:
        description: Difference caption; NaN/None yields an empty list.
        id_prefix: Prefix used for the 'id' added to every action.
        debug: When True, print the prompt and the parsed response.

    Returns:
        The list of actions parsed from the model's JSON reply, each with an
        added 'id' key. An empty list when the description is missing, the
        reply is the literal 'None', or anything fails while querying or
        parsing.
    """
    if is_nan_or_none(description):
        print('Recived NaN Difference caption (This is probably caption pipleine fail).')
        print(description)
        print(id_prefix)
        return []

    prompt = generate_general_edit_actions_from_discription.replace(caption_replacer, description)
    response = None
    try:
        if debug:
            print('Edit action from description prompt: ', prompt)
        response = get_chatgpt_4_prediction(prompt)
        if response.replace('\n', '').strip() == 'None':
            parsed_response = []
        else:
            parsed_response = json.loads(response)
        if debug:
            print('Edit action from description response: ', parsed_response)
        for action in parsed_response:
            action.update({'id': get_action_id(action, id_prefix, debug=debug)})
        return parsed_response
    except Exception:
        # Deliberate best effort, but KeyboardInterrupt/SystemExit now propagate
        # so a long run can still be stopped during the model call.
        if debug:
            print(prompt, response)
        print('Failed to get actions from description - debug here')
        return []
        

# Kept so existing references to this name keep working; this function does not add to it.
gpt_nouns = set()


def _nouns_with_synonyms(object_description):
    """Return the nouns found in ``object_description`` plus their related synonyms, as a set."""
    related_words = set()
    for word in word_tokenize(object_description):
        if is_noun(word):
            related_words.add(word)
            related_words.update(get_related_synonyms(word))
    return related_words


def get_action_similarity_level(first_object, second_object, debug=False):
    """Rate how similar two object descriptions are.

    Args:
        first_object: First object description.
        second_object: Second object description.
        debug: Forwarded to ``get_similarity_label``.

    Returns:
        '2' when the descriptions are equal, '0' when their nouns (including
        synonyms) do not overlap, otherwise the value returned by
        ``get_similarity_label``.
    """
    if first_object == second_object:
        return '2'

    shared_nouns = _nouns_with_synonyms(first_object) & _nouns_with_synonyms(second_object)
    if not shared_nouns:
        return '0'
    # Only pairs with some lexical overlap are sent to the model for a label.
    return get_similarity_label(first_object, second_object, debug)
    
def get_action_similarity_details(ground_truth_edit_action, predicted_edit_action):
    """Compare a ground-truth action with a predicted action.

    Returns:
        A 5-tuple: whether the action names match ('Change Attribute' and
        'Replace' count as the same), followed by the similarity levels for
        source/source, target/target, ground source/predicted target and
        ground target/predicted source.
    """
    interchangeable = ('Change Attribute', 'Replace')
    ground_name = ground_truth_edit_action['action']
    predicted_name = predicted_edit_action['action']
    both_interchangeable = ground_name in interchangeable and predicted_name in interchangeable
    is_same_action = (ground_name == predicted_name) or both_interchangeable

    # (ground-truth field, predicted field) pairs, in the order they are returned.
    field_pairs = (
        ('source_object', 'source_object'),
        ('target_object', 'target_object'),
        ('source_object', 'target_object'),
        ('target_object', 'source_object'),
    )
    similarity_levels = [
        get_action_similarity_level(ground_truth_edit_action[ground_field], predicted_edit_action[predicted_field])
        for ground_field, predicted_field in field_pairs
    ]
    return (is_same_action, *similarity_levels)

def remove_action_by_id(actions_list, id):
    """Remove, in place, every action whose 'predicted_id' equals ``id``.

    The list is rebuilt with a slice assignment rather than by calling
    ``remove`` while iterating, which skipped the element following each
    removal and could leave a matching entry behind.

    Args:
        actions_list: List of action dicts; modified in place.
        id: The 'predicted_id' value to drop.

    Returns:
        The same ``actions_list`` object, after filtering.
    """
    actions_list[:] = [action for action in actions_list if action.get('predicted_id') != id]
    return actions_list

def is_exact_match(is_same_action, is_source_similar, is_target_similar, ground_action):
    """Same action with matching source and target; 'Add' waives the source check, 'Remove' the target check."""
    source_ok = is_source_similar or ground_action['action'] == 'Add'
    if not source_ok:
        return source_ok
    target_ok = is_target_similar or ground_action['action'] == 'Remove'
    if not target_ok:
        return target_ok
    return is_same_action

def is_exact_match_different_order(is_target_and_source_similar_different_order, is_same_action):
    """Same action with source and target matched crosswise."""
    if not is_target_and_source_similar_different_order:
        return is_target_and_source_similar_different_order
    return is_same_action

def is_partial_match_both_source_and_target_different_action(is_source_similar, is_target_similar, is_same_action, ground_action):
    """Source and target both match (with the 'Add'/'Remove' waivers) but the action differs."""
    source_ok = is_source_similar or ground_action['action'] == 'Add'
    if not source_ok:
        return source_ok
    target_ok = is_target_similar or ground_action['action'] == 'Remove'
    if not target_ok:
        return target_ok
    return not is_same_action

def is_partial_match_either_source_or_target_same_action(is_source_similar, is_target_similar, is_same_action):
    """Same action and at least one of source/target matches."""
    either_similar = is_source_similar or is_target_similar
    return is_same_action if either_similar else either_similar

def is_partial_match_either_source_or_target_different_action(is_source_similar, is_target_similar, is_same_action):
    """Different action but at least one of source/target matches."""
    either_similar = is_source_similar or is_target_similar
    return (not is_same_action) if either_similar else either_similar

def is_partial_match_either_source_or_target_different_order_same_action(is_target_or_source_similar_different_order, is_same_action):
    """Same action and a crosswise match on source or target."""
    if not is_target_or_source_similar_different_order:
        return is_target_or_source_similar_different_order
    return is_same_action

def is_partial_match_either_source_or_target_different_order_different_action(is_target_or_source_similar_different_order, is_same_action):
    """Different action but a crosswise match on source or target."""
    if not is_target_or_source_similar_different_order:
        return is_target_or_source_similar_different_order
    return not is_same_action

# Module-level accumulator; enrich_with_evaluation_data() extends it with the
# 'partial match both source and target different action' pairs it finds.
same_source_and_target_action = []
def evaluate_actions_similarity(ground_truth_actions, predicted_actions, threshold):
    """Sort every (ground truth action, predicted action) pair into a match category.

    Args:
        ground_truth_actions: Action dicts with 'id', 'action', 'source_object'
            and 'target_object'.
        predicted_actions: Action dicts with the same keys.
        threshold: Minimum similarity level (compared after int()) for two
            objects to count as similar.

    Returns:
        A tuple with one list per category, in the order the categories are
        declared below. Pair entries carry ground_* and predicted_* keys;
        entries of the last list ('completely different') carry only
        predicted_* keys.
    """
    similarity_categories = {
        'exact match': [],
        'exact match different order': [],
        'partial match either source or target same action': [],
        'partial match both source and target different action': [],
        'partial match either source or target different action': [],
        'partial match either source or target different order same action': [],
        'partial match either source or target different order different action': [],
        'same action different objects': [],
        # Starts with every predicted action; entries are removed below as soon
        # as the prediction relates to some ground truth action.
        'completely different': [
            {'predicted_id': action['id'], 'predicted_source': action['source_object'], 'predicted_target': action['target_object']}
            for action in predicted_actions
        ]
    }

    for ground_action in ground_truth_actions:
        for predicted_action in predicted_actions:
            # Record describing this pair; appended to at most one category.
            ids = [{
                'ground_id': ground_action['id'], 'predicted_id': predicted_action['id'],
                'ground_source': ground_action['source_object'], 'ground_target': ground_action['target_object'],
                'predicted_source': predicted_action['source_object'], 'predicted_target': predicted_action['target_object']
            }]
            
            is_same_action, source_similarity_level, target_similarity_level, source_ground_target_predicted_similarity_level, target_ground_source_predicted_similarity_level = get_action_similarity_details(ground_action, predicted_action)
            
            # The source is ignored for 'Add' and the target for 'Remove' ground truth actions.
            is_source_similar = int(source_similarity_level) >= threshold and ground_action['action'] != 'Add'
            is_target_similar = int(target_similarity_level) >= threshold and ground_action['action'] != 'Remove'
            # Crosswise comparison: ground source vs predicted target and vice versa.
            is_target_or_source_similar_different_order = (
                (int(source_ground_target_predicted_similarity_level) >= threshold and ground_action['action'] != 'Add') or 
                (int(target_ground_source_predicted_similarity_level) >= threshold and ground_action['action'] != 'Remove')
            )
            is_target_and_source_similar_different_order = (
                int(source_ground_target_predicted_similarity_level) >= threshold and 
                int(target_ground_source_predicted_similarity_level) >= threshold
            )

            # Any relation at all means the prediction is not 'completely different'.
            if any([is_same_action, is_source_similar, is_target_similar, is_target_or_source_similar_different_order]):
                remove_action_by_id(similarity_categories['completely different'], predicted_action['id'])
            # metadata_partial match either source or target same action gemini 
            # The first matching branch wins, so the order of this chain defines category priority.
            if is_exact_match(is_same_action, is_source_similar, is_target_similar, ground_action):
                similarity_categories['exact match'].extend(ids)
            elif is_exact_match_different_order(is_target_and_source_similar_different_order, is_same_action):
                similarity_categories['exact match different order'].extend(ids)
            elif is_partial_match_both_source_and_target_different_action(is_source_similar, is_target_similar, is_same_action, ground_action):
                similarity_categories['partial match both source and target different action'].extend(ids)
            elif is_partial_match_either_source_or_target_same_action(is_source_similar, is_target_similar, is_same_action):
                similarity_categories['partial match either source or target same action'].extend(ids)
            elif is_partial_match_either_source_or_target_different_action(is_source_similar, is_target_similar, is_same_action):
                similarity_categories['partial match either source or target different action'].extend(ids)
            elif is_partial_match_either_source_or_target_different_order_same_action(is_target_or_source_similar_different_order, is_same_action):
                similarity_categories['partial match either source or target different order same action'].extend(ids)
            elif is_partial_match_either_source_or_target_different_order_different_action(is_target_or_source_similar_different_order, is_same_action):
                similarity_categories['partial match either source or target different order different action'].extend(ids)
            elif is_same_action:
                similarity_categories['same action different objects'].extend(ids)

    # Dict insertion order fixes the order of the returned tuple.
    return tuple(similarity_categories[key] for key in similarity_categories)

def enrich_with_evaluation_data(example, ground_truth_difference_caption, predicted_difference_caption, model, threshold=1):
    """Enrich the example with evaluation data based on predicted and ground truth captions.

    Extracts edit actions from both captions (the ground-truth actions only if
    the example does not already hold them), compares them, and stores the
    predicted actions plus one '<category> <model>' entry per match category
    on the example.

    Returns:
        The same ``example``, updated in place.
    """
    ground_truth_key = 'ground truth edit actions'
    if example.get(ground_truth_key) is None:
        example[ground_truth_key] = get_edit_actions_from_description(ground_truth_difference_caption, 'ground truth_')
    ground_truth_edit_actions = example[ground_truth_key]

    predicted_edit_actions = get_edit_actions_from_description(predicted_difference_caption, f'{model}_')
    example[f'predicted edit actions {model}'] = predicted_edit_actions

    # Same order as the tuple returned by evaluate_actions_similarity.
    category_names = (
        'exact match',
        'exact match different order',
        'partial match either source or target same action',
        'partial match both source and target different action',
        'partial match either source or target different action',
        'partial match either source or target different order same action',
        'partial match either source or target different order different action',
        'same action different objects',
        'completely different',
    )
    category_results = evaluate_actions_similarity(ground_truth_edit_actions, predicted_edit_actions, threshold)
    results_by_category = dict(zip(category_names, category_results))

    # Collected globally across examples.
    same_source_and_target_action.extend(results_by_category['partial match both source and target different action'])
    for category_name, matched_pairs in results_by_category.items():
        example[f'{category_name} {model}'] = matched_pairs

    return example

    
def enrich_free_text_difference_caption_evaluation_details(example, threshold=1, boolean_evaluation_mode=False):
    """Collect every model's difference caption for an example and, optionally, its action-level evaluation.

    Args:
        example: Dataset example; its images are loaded first.
        threshold: Similarity threshold forwarded to the action comparison.
        boolean_evaluation_mode: When True, only the captions are stored and
            the action-level comparison is skipped.

    Returns:
        The enriched example.
    """
    example = load_example_images(example)
    example['eval_selected_difference_caption'] = get_selected_annotated_difference_caption(example)
    for model in MODELS:
        predicted_caption = get_general_difference_prompt(
            example['source_img'], example['target_img'], model, example, 'is_difference_captions_same_complex'
        )
        example[f'predicted_difference_caption_{model}'] = predicted_caption
        if boolean_evaluation_mode:
            continue
        enrich_with_evaluation_data(example, example['eval_selected_difference_caption'], predicted_caption, model, threshold=threshold)
    return example

#### Analysis

##### Generation Question

In [None]:
import warnings
from tqdm import tqdm
from pandas.errors import PerformanceWarning
import pathlib

# Inline template: the first slot takes the primary difference, the second the caption being judged.
is_diffrence_captions_main_diffrence_same = 'Determine if the primary difference is accurately present in the caption. Answer "Yes" or "No" only.\n\nPrimary difference:\n{}\n\nCaption:\n{}\n\nQuestion:\nIs the primary difference present in the caption? (Answer Yes/No only)'
# Prompt templates stored under PROJECT_PATH. Path.read_text() opens, reads and
# closes each file, so no file handle is left dangling.
is_diffrence_captions_same_prompt = pathlib.Path(PROJECT_PATH + '//prompts/is_diffrence_captions_same.txt').read_text()
get_main_diffrence_prompt = pathlib.Path(PROJECT_PATH + '//prompts/get_main_difference_prompt.txt').read_text()
generate_single_change_difference_caption_prompt = 'Please describe the main difference between the two images.'
get_visual_consistency_prompt = pathlib.Path(PROJECT_PATH + '//prompts/get_visual_consistency_prompt.txt').read_text()
get_visual_quality_prompt = pathlib.Path(PROJECT_PATH + '//prompts/get_visual_quality_prompt.txt').read_text()
is_feedback_same_prompt = pathlib.Path(PROJECT_PATH + '//prompts/is_feedback_same_prompt.txt').read_text()
# Silence the pandas dtype FutureWarning and PerformanceWarning.
warnings.filterwarnings("ignore", category=FutureWarning, message="Setting an item of incompatible dtype is deprecated")
warnings.simplefilter(action='ignore', category=PerformanceWarning)

def get_columns_actions_statistics(example, model, cluster_columns, debug=False):
    """Measure how much of the ground truth a group of match categories covers.

    Args:
        example: Mapping with 'ground truth edit actions',
            'predicted edit actions <model>' and one 'metadata_<column>' list
            of matched pairs per selected column.
        model: Model name substituted into each column template.
        cluster_columns: Column-name templates (containing '{}') to combine.
        debug: When True, print the intermediate id sets.

    Returns:
        A tuple ``(all_covered, ground_truth_ratio, predicted_ratio)``:
        whether every ground-truth action id appears in the selected columns,
        the fraction of ground-truth ids that appear there, and the fraction
        of predicted actions that appear there (0 when nothing was predicted).
        When the example has no ground-truth actions both ratios are 0 instead
        of raising ZeroDivisionError.
    """
    # Format the selected columns for the model
    selected_columns = [column.format(model) for column in cluster_columns]
    if debug:
        print('Selected columns -', selected_columns)

    # Get ground truth IDs
    ground_truth_ids = {action['id'] for action in example['ground truth edit actions']}

    # Collect the ground truth and predicted IDs matched in the selected columns
    matched_ground_truth_ids = set()
    matched_predicted_ids = set()
    for column in selected_columns:
        for matched_pair in example[f'metadata_{column}']:
            matched_ground_truth_ids.add(matched_pair['ground_id'])
            matched_predicted_ids.add(matched_pair['predicted_id'])

    # Ensure selected column IDs do not exceed ground truth IDs
    assert len(matched_ground_truth_ids) <= len(ground_truth_ids)
    missed_ids = ground_truth_ids - matched_ground_truth_ids
    if debug:
        print('Ground Truth Ids:', ground_truth_ids)
        print('Ids in predictions:', matched_ground_truth_ids)
        print('# of misses:', len(missed_ids))

    all_covered = len(missed_ids) == 0
    if not ground_truth_ids:
        # Nothing to cover: report zero ratios rather than dividing by zero.
        return all_covered, 0, 0

    ground_truth_ratio = len(matched_ground_truth_ids) / len(ground_truth_ids)
    predicted_count = len(example[f'predicted edit actions {model}'])
    predicted_ratio = (len(matched_predicted_ids) / predicted_count) if predicted_count != 0 else 0
    return all_covered, ground_truth_ratio, predicted_ratio

def is_over(row, model):
    """Return True when the ground truth is fully covered yet an under-specified category still has hits.

    Coverage is taken from ``get_columns_actions_statistics`` for the
    almost-identical columns or, failing that, the identical columns.
    """
    fully_covered = (
        get_columns_actions_statistics(row, model, almost_identical_columns)[0]
        or get_columns_actions_statistics(row, model, identical_columns)[0]
    )
    if not fully_covered:
        return False
    # Every under-specified count column is read before the result is decided.
    has_hits = [row[col.format(model)] > 0 for col in all_metrics_columns if col in under_specify]
    return any(has_hits)

def is_under(row, model):
    """Return True when the ground truth is not fully covered and an under-specified category has hits.

    Coverage is checked with ``get_columns_actions_statistics`` for both the
    almost-identical and the identical columns; either one covering the
    ground truth makes the result False.
    """
    if get_columns_actions_statistics(row, model, almost_identical_columns)[0]:
        return False
    if get_columns_actions_statistics(row, model, identical_columns)[0]:
        return False
    # Every under-specified count column is read before the result is decided.
    has_hits = [row[col.format(model)] > 0 for col in all_metrics_columns if col in under_specify]
    return any(has_hits)

def get_example_main_difference(example, debug=False):
    """Ask GPT-4o for the main difference, given the first three annotator explanations of the example."""
    explanations = example['metadata_annotated_difference_caption_explanation']
    first, second, third = explanations[0], explanations[1], explanations[2]
    prompt = get_main_diffrence_prompt.format(first, second, third)
    response = get_model_response([prompt], model='gpt-4o', example=example, task_name='main difference')
    if debug:
        print(prompt)
        print(response)
    return response

def generate_single_change_difference_caption(row, model, debug=False):
    """Ask ``model`` to describe the main difference between the row's source and target images."""
    model_inputs = [row['source_img'], row['target_img'], generate_single_change_difference_caption_prompt]
    return get_model_response(model_inputs, model=model, example=row, task_name='single difference caption')
    
def find_longest_string(arr):
    """Return the longest item of ``arr`` (the first one on ties), or None when ``arr`` is empty or None."""
    return max(arr, key=len) if arr else None

def get_is_visual_consistency_feedback_same(row, model, debug=False):
    """Compare the model's visual-consistency feedback with the longest annotator explanation.

    Stores the reference explanation and the model's feedback on ``row``, then
    asks GPT-4o whether the two say the same thing.

    Returns:
        A tuple of the parsed yes/no verdict and the raw GPT-4o response.
    """
    predicted_key = 'predicted_visualy_consistence_explanation_{}'.format(model)

    row['eval_visualy_consistence_explanation'] = find_longest_string(row['metadata_annotated_visualy_consistence_explanation'])
    if debug:
        print('Ground truth visual consistency response: \n', row['eval_visualy_consistence_explanation'] )

    model_inputs = [row['source_img'], row['target_img'], get_visual_consistency_prompt.format(row['Instruction'])]
    row[predicted_key] = get_model_response(model_inputs, model=model, example=row, task_name='feedback')
    if debug:
        print('Model visual consistency response: \n', row[predicted_key])

    comparison_prompt = is_feedback_same_prompt.format(row['eval_visualy_consistence_explanation'], row[predicted_key])
    verdict = get_model_response([comparison_prompt], model='gpt-4o', example=row, task_name='is_feedback_same_boolean')
    if debug:
        print(verdict, '\n')
    return smart_convert_yes_no_list([verdict])[0], verdict

def get_is_technical_precision_feedback_same(row, model, debug=False):
    """Compare the model's visual-quality feedback with the longest annotator explanation.

    Stores the reference explanation and the model's feedback on ``row``, then
    asks GPT-4o whether the two say the same thing.

    Returns:
        A tuple of the parsed yes/no verdict and the raw GPT-4o response.
    """
    predicted_key = 'predicted_good_quality_explanation_{}'.format(model)

    row['eval_good_quality_explanation'] = find_longest_string(row['metadata_annotated_good_quality_explanation'])
    if debug:
        print('Ground truth good visual response: \n', row['eval_good_quality_explanation'] )

    model_inputs = [row['source_img'], row['target_img'], get_visual_quality_prompt.format(row['Instruction'])]
    row[predicted_key] = get_model_response(model_inputs, model=model, example=row, task_name='feedback')
    if debug:
        print('Model good visual response: \n', row[predicted_key])

    comparison_prompt = is_feedback_same_prompt.format(row['eval_good_quality_explanation'], row[predicted_key])
    verdict = get_model_response([comparison_prompt], model='gpt-4o', example=row, task_name='is_feedback_same_boolean')
    if debug:
        print(verdict, '\n')
    return smart_convert_yes_no_list([verdict])[0], verdict

def get_is_difference_captions_same(row, model_difference_caption, single_change=False, debug=False):
    """Ask GPT-4o whether a model's difference caption agrees with the reference.

    Args:
        row: Example holding 'caption_primary_difference' (used when
            ``single_change`` is True) or 'eval_selected_difference_caption'.
        model_difference_caption: Caption produced by the evaluated model.
        single_change: Compare against the primary difference only.
        debug: When True, print the prompt and the response.

    Returns:
        A tuple of the parsed yes/no verdict and the raw GPT-4o response.
    """
    if single_change:
        template, reference = is_diffrence_captions_main_diffrence_same, row['caption_primary_difference']
    else:
        template, reference = is_diffrence_captions_same_prompt, row['eval_selected_difference_caption']
    prompt = template.format(reference, model_difference_caption)

    verdict = get_model_response([prompt], model='gpt-4o', example=row, task_name='is_difference_captions_same_boolean')
    if debug:
        print(prompt)
        print(verdict)
    return smart_convert_yes_no_list([verdict])[0], verdict

def initalize_model_stats(stats, model, lingustic_metrics, boolean_evaluation_mode):
    """Set every counter tracked for ``model`` to 0 in ``stats`` and return ``stats``.

    Linguistic counters are added only when ``lingustic_metrics`` is true and
    the action-level counters only when ``boolean_evaluation_mode`` is false.
    """
    counter_names = [
        'is_difference_caption_same',
        'is_ALL_difference_caption_main_difference_same',
        'is_single_change_difference_caption_same',
        'is_visualy_consistence_feedback_same',
        'is_technical_precision_feedback_same',
    ]
    if lingustic_metrics:
        counter_names += ['meteor', 'rouge1_recall', 'sentece_similarity']
    if not boolean_evaluation_mode:
        counter_names += [
            'almost_identical',
            'under',
            'over',
            'identical',
            'almost_identical_ground_truth_actions',
            'identical_ground_truth_actions',
            'almost_identical_predicted_actions',
            'identical_predicted_actions',
        ]
    for counter_name in counter_names:
        stats[f'{counter_name}_{model}_'] = 0
    return stats

def update_stats(stats, column, model, index, value, evaluation_df_parsed):
    """Add ``value`` to the '<column>_<model>_' counter and record it in the DataFrame cell for ``index``."""
    stat_key = f'{column}_{model}_'
    stats[stat_key] = stats.get(stat_key) + value
    evaluation_df_parsed.at[index, stat_key] = value

def evaluate_models(evaluation_df, boolean_evaluation_mode=False, lingustic_metrics=False, evaluate_consistency_and_precision=False, evaluate_all_difference_caption=False):
    """Evaluate every model in ``MODELS`` on the difference-caption tasks.

    Works on a copy of ``evaluation_df``. For each model and each row it loads the
    example images, stores the main difference in ``'caption_primary_difference'``,
    generates and judges a single-change difference caption, and optionally judges
    the full difference caption, the quality/consistency feedback and the
    action-category statistics. Counters are accumulated in ``stats`` under keys
    of the form ``f'{name}_{model}_'`` and turned into ``*_precentage`` entries
    once a model's rows are done.

    Args:
        evaluation_df: DataFrame with one example per row.
        boolean_evaluation_mode: when True, the per-metric length conversion and
            the action-category statistics are skipped, and no row metadata is
            appended to ``metadata``.
        lingustic_metrics: when True, linguistic counters are initialised; the
            scores themselves are only computed inside the
            ``evaluate_consistency_and_precision`` branch.
        evaluate_consistency_and_precision: when True, judge the technical
            precision / visual consistency feedback.
        evaluate_all_difference_caption: when True, judge the model's full
            predicted difference caption as well.

    Returns:
        A 4-tuple ``(summary_df, evaluation_df_parsed, metadata, stats)`` where
        ``summary_df`` has one row per model, ``evaluation_df_parsed`` is the
        annotated copy of the input, ``metadata`` is a list of per-row dicts and
        ``stats`` is the raw counter/percentage dictionary.

    NOTE: later cells of this file define other functions that are also named
    ``evaluate_models`` (taking a single ``dataframe`` argument); once those cells
    have run, this name refers to them instead.
    """
    stats = dict()
    metadata = []

    evaluation_df_parsed = evaluation_df.copy()
    if not boolean_evaluation_mode:
        for model in MODELS:
            # Keep the raw metric cells under 'metadata_*' columns and replace the
            # original columns by the length of each cell.
            model_base_metrics_columns = list(map(lambda x: x.format(model), all_metrics_columns))
            metadata_model_base_metrics_columns = list(map(lambda x: 'metadata_' + x.format(model), all_metrics_columns))
            evaluation_df_parsed[metadata_model_base_metrics_columns] = evaluation_df[model_base_metrics_columns]
            evaluation_df_parsed[model_base_metrics_columns] = evaluation_df[model_base_metrics_columns].map(len)
    
    for model in MODELS:
        stats = initalize_model_stats(stats, model, lingustic_metrics, boolean_evaluation_mode)   
        # Denominators for the two feedback percentages: number of rows on which
        # each feedback check was actually performed for this model.
        technical_precision_counter = 0   
        contextual_consistency_counter = 0    
        for index, row in tqdm(evaluation_df_parsed.iterrows(), total=len(evaluation_df_parsed), desc=f"{model} predictions"):
            row = load_example_images(row)
            evaluation_df_parsed.at[index, 'caption_primary_difference'] = row['caption_primary_difference'] = get_example_main_difference(row)
            row_metadata = dict()
            
            if evaluate_all_difference_caption:
                # Judge the model's full difference caption twice: against the
                # selected ground-truth caption and against the main difference only.
                is_difference_caption_same, response = get_is_difference_captions_same(row, row['predicted_difference_caption_{}'.format(model)])
                is_difference_caption_main_difference_same, response_main_diffrence = get_is_difference_captions_same(row, row['predicted_difference_caption_{}'.format(model)], single_change=True)
                row_metadata.update({'model': model, 'id': row['id'],'is_difference_caption_same': is_difference_caption_same, 'is_difference_caption_same_response': response,  'is_difference_caption_main_difference_same': is_difference_caption_main_difference_same, 'is_ALL_difference_caption_main_difference_same_response': response_main_diffrence})
            
                            
                # The full ("ALL") difference caption matched the main difference.
                if is_difference_caption_main_difference_same:
                    stats.update({f'is_ALL_difference_caption_main_difference_same_{model}_': stats.get(f'is_ALL_difference_caption_main_difference_same_{model}_') + 1})
                    evaluation_df_parsed.at[index, f'is_ALL_difference_caption_main_difference_same_{model}_'] = True
                    
                if is_difference_caption_same:
                    stats.update({f'is_difference_caption_same_{model}_': stats.get(f'is_difference_caption_same_{model}_') + 1})
                    evaluation_df_parsed.at[index, f'is_difference_caption_same_{model}_'] = True
                    
            # Single-change difference caption: generate one for this model and
            # compare it with the main difference.
            single_change_difference_caption = generate_single_change_difference_caption(row, model) # generate a single difference caption for a specific model
            evaluation_df_parsed.at[index, 'predicted_single_change_difference_caption_{}'.format(model)] = single_change_difference_caption
            is_single_change_difference_caption_same, response_is_single_change_difference_caption = get_is_difference_captions_same(row, single_change_difference_caption, single_change=True)
            row_metadata.update({'single_change_difference_caption': single_change_difference_caption, 'is_single_change_difference_caption_same': is_single_change_difference_caption_same, 'response_is_single_change_difference_caption': response_is_single_change_difference_caption})
            
            if evaluate_consistency_and_precision:
                # Feedback is only judged on rows annotated as failing the check,
                # and never for the 'pipeline', 'human' or finetuned models.
                if not row['annotated_is_good_quality'] and model not in ['pipeline', 'human'] and 'finetune' not in model:
                    technical_precision_counter += 1
                    is_technical_precision_feedback_same = get_is_technical_precision_feedback_same(row, model)[0]
                    row_metadata.update({'is_technical_precision_feedback_same_': is_technical_precision_feedback_same})
                    if is_technical_precision_feedback_same:
                        stats.update({f'is_technical_precision_feedback_same_{model}_': stats.get(f'is_technical_precision_feedback_same_{model}_') + 1})
                if not row['annotated_is_visualy_consistence'] and model not in ['pipeline', 'human'] and 'finetune' not in model:
                    contextual_consistency_counter += 1
                    is_visual_consistency_feedback_same = get_is_visual_consistency_feedback_same(row, model)[0]
                    row_metadata.update({'is_visualy_consistence_feedback_same': is_visual_consistency_feedback_same})
                    if is_visual_consistency_feedback_same:
                        stats.update({f'is_visualy_consistence_feedback_same_{model}_': stats.get(f'is_visualy_consistence_feedback_same_{model}_') + 1})
                
                # NOTE(review): the linguistic metrics below are nested under
                # `evaluate_consistency_and_precision`, so they are skipped when only
                # `lingustic_metrics` is enabled - confirm this is intended.
                meteor, rouge1_recall, sentece_similarity = None, None, None
                if lingustic_metrics:
                    model_predicted_difference_caption = row['predicted_difference_caption_{}'.format(model)]
                    if isinstance(model_predicted_difference_caption, str):
                        # NOTE(review): the prediction is passed as BOTH arguments, so it is
                        # scored against itself; presumably one argument should be the
                        # ground-truth caption - confirm against evaluate_captions.
                        meteor, rouge1_recall, sentece_similarity = evaluate_captions(model_predicted_difference_caption, model_predicted_difference_caption)

                # A falsy meteor (None or 0) skips the accumulation for this row.
                if lingustic_metrics and meteor and isinstance(model_predicted_difference_caption, str):
                    row_metadata.update({'meteor': meteor, 'rouge1_recall': rouge1_recall, 'sentece_similarity': sentece_similarity})
                    stats.update({f'sentece_similarity_{model}_': stats.get(f'sentece_similarity_{model}_') + sentece_similarity})
                    stats.update({f'rouge1_recall_{model}_': stats.get(f'rouge1_recall_{model}_') + rouge1_recall})
                    stats.update({f'meteor_{model}_': stats.get(f'meteor_{model}_') + meteor})
            
            # Count rows whose short (single-change) caption matched the main
            # difference.
            if is_single_change_difference_caption_same:
                stats.update({f'is_single_change_difference_caption_same_{model}_': stats.get(f'is_single_change_difference_caption_same_{model}_') + 1})
                evaluation_df_parsed.at[index, f'is_single_change_difference_caption_same_{model}_'] = True
                
            # Boolean mode stops here: no action statistics, and `row_metadata` is
            # never appended to `metadata` for this row.
            if boolean_evaluation_mode:
                continue

            identical, identical_ground_truth_precentage, identical_predicted_precentage = get_columns_actions_statistics(row, model, identical_columns)
            almost_identical, almost_identical_ground_truth_precentage, almost_identical_predicted_precentage = get_columns_actions_statistics(row, model, almost_identical_columns)
            under, over = is_under(row, model), is_over(row, model)
            
            row_metadata.update({'index': index, 'model': model, 'identical': identical, 'almost_identical': almost_identical, 'under': under, 'over': over})
            metadata.append(row_metadata)
            # Number of categories this row was assigned to.
            number_of_categories = sum([almost_identical, under, over, identical])
            # assert(number_of_categories <= 1)
            
            if identical and not over:
                stats.update({f'identical_{model}_': stats.get(f'identical_{model}_') + 1})
                evaluation_df_parsed.at[index, f'identical_{model}_'] = True
            
            if almost_identical:
                stats.update({f'almost_identical_{model}_': stats.get(f'almost_identical_{model}_') + 1})
                evaluation_df_parsed.at[index, f'almost_identical_{model}_'] = True
            
            # Rows that fall into no category at all are counted as 'under'.
            if under or number_of_categories == 0:
                stats.update({f'under_{model}_': stats.get(f'under_{model}_') + 1})
                evaluation_df_parsed.at[index, f'under_{model}_'] = True
            
            if over:
                stats.update({f'over_{model}_': stats.get(f'over_{model}_') + 1})
                evaluation_df_parsed.at[index, f'over_{model}_'] = True
            
            update_stats(stats, 'identical_ground_truth_actions', model, index, identical_ground_truth_precentage, evaluation_df_parsed)
            update_stats(stats, 'almost_identical_ground_truth_actions', model, index, almost_identical_ground_truth_precentage, evaluation_df_parsed)
            update_stats(stats, 'identical_predicted_actions', model, index, identical_predicted_precentage, evaluation_df_parsed)
            update_stats(stats, 'almost_identical_predicted_actions', model, index, almost_identical_predicted_precentage, evaluation_df_parsed)
    
        # Turn this model's counters into percentages. Iterate over a snapshot so
        # that new keys can be collected without touching the dict being read.
        original_stats = stats.copy()
        updated_stats = {}  # New dictionary to hold the updated values

        # Feedback counters are divided by the number of rows on which the check
        # ran (`or 1` avoids dividing by zero); everything else by the row count.
        # NOTE(review): only the last branch skips keys that already contain
        # 'precentage'; if a model appears twice in MODELS, the feedback branches
        # re-process the '*_precentage' keys from the earlier pass - confirm.
        for key in original_stats:
            if f'_{model}_' not in key:
                continue
            if 'visualy_consistence_feedback' in key:
                updated_stats[key + '_precentage'] = round(original_stats.get(key)/(contextual_consistency_counter or 1), 2) * 100
            elif 'technical_precision_feedback' in key:
                updated_stats[key + '_precentage'] = round(original_stats.get(key)/(technical_precision_counter or 1), 2) * 100
            elif 'precentage' not in key:
                updated_stats[key + '_precentage'] = round(original_stats.get(key)/len(evaluation_df), 2) * 100

        # Merge the new percentages back into the original stats
        stats.update(updated_stats)

    # Create a one-row DataFrame per model (keys with the `_{model}_` marker
    # removed) and concatenate them
    dfs = []
    for model in MODELS:
        model_data = {k.replace(f'_{model}_', ''): v for k, v in stats.items() if f'_{model}_' in k}
        # Drop the raw linguistic sums from the summary row; missing keys are ignored.
        for column_to_pop in ['meteor', 'rouge1_recall', 'sentece_similarity', f'almost_identical_actions', f'identical_actions']:
            model_data.pop(column_to_pop, None)
        df_model = pd.DataFrame([model_data], index=[model])
        dfs.append(df_model)
    
    return pd.concat(dfs), evaluation_df_parsed, metadata, stats

#### Run Difference Caption Tasks

In [5]:
# Run configuration for the difference-caption evaluation.
LOCAL_RUN = False # Uses cached responses only and will throw an error when trying to do an API call
BOOLEAN_EVALUATION_MODE = False # Evaluate all tasks other than the ALL difference caption task
LINGUISTIC_METRICS_MODE = False # Rouge, Meteor, etc.
EVALUATE_CONSISTENCY_AND_PRECISION = False # Evaluate Visual Consistency and Technical Precision
EVALUATE_ALL_DIFFERENCE_CAPTION = True # Evaluate ALL difference caption task

# NOTE(review): 'gpt-4o' is listed twice, so it is evaluated twice and yields two
# result rows - confirm whether a different model was meant for one of the entries.
MODELS = ['intern-vl3', 'gpt-4o', 'gpt-4o', 'gemini-1.5', 'gpt-4','gpt-4-turbo', 'qwen']
annotated_data = get_annotated_data_with_majority()

# The per-row enrichment is only applied when the full difference-caption task is
# evaluated; otherwise the annotated data is passed to evaluate_models unchanged.
enriched_difference_captions = annotated_data
if EVALUATE_ALL_DIFFERENCE_CAPTION:
    enriched_difference_captions = annotated_data.progress_apply(lambda x: enrich_free_text_difference_caption_evaluation_details(x, boolean_evaluation_mode=BOOLEAN_EVALUATION_MODE, threshold=1), axis=1)
results_df_diff, evaluation_df_parsed_diff, metadata_diff, stats_diff = evaluate_models(enriched_difference_captions, boolean_evaluation_mode=BOOLEAN_EVALUATION_MODE, lingustic_metrics=LINGUISTIC_METRICS_MODE, evaluate_consistency_and_precision=EVALUATE_CONSISTENCY_AND_PRECISION, evaluate_all_difference_caption=EVALUATE_ALL_DIFFERENCE_CAPTION)
# Bare expression so the notebook renders the per-model summary table.
results_df_diff

### Evaluation - Booleans


In [None]:
def get_introduction_prompt(example):
    """Return the shared opening sentence that quotes the example's edit instruction."""
    instruction = example["instruction"]
    return f'You are provided with before and after images of an image edit for the edit instruction "{instruction}".'

def get_extensive_difference_caption_task_prompt(example, finetune_prompt=False):
    """Build the Yes/No prompt asking whether the extensive caption is accurate.

    The finetune variant is the bare question; the default variant is prefixed
    with the shared introduction sentence.
    """
    if finetune_prompt:
        caption = example["extensive_caption"]
        return f'Does the difference caption "{caption}" describes the difference between the two images (Answer only Yes/No)?'

    preamble = get_introduction_prompt(example)
    caption = example["extensive_caption"]
    return f'{preamble} Does the difference caption "{caption}" describe the main difference between the two images reflected from the edit instruction (Answer only Yes/No)?'

def get_difference_caption_task_prompt(example, finetune_prompt=False):
    """Build the Yes/No prompt asking whether the (short) difference caption is accurate.

    Both variants quote ``example["caption"]``. The finetune variant previously
    quoted ``example["extensive_caption"]``, which made it identical to the
    extensive-caption task prompt and never evaluated the short caption.

    Args:
        example: mapping with a ``"caption"`` entry (and ``"instruction"`` for
            the non-finetune variant, which is read by the introduction helper).
        finetune_prompt: when True, return the bare question without the
            introduction sentence.

    Returns:
        The prompt string.
    """
    caption = example["caption"]
    if finetune_prompt:
        return f'Does the difference caption "{caption}" describes the difference between the two images (Answer only Yes/No)?'

    introduction = get_introduction_prompt(example)
    return f'{introduction} Does the difference caption "{caption}" describe the main difference between the two images reflected from the edit instruction (Answer only Yes/No)?'

def get_is_visualy_consistence_task_prompt(example, finetune_prompt=False):
    """Build the Yes/No prompt about visual consistency of the edited area.

    ``finetune_prompt`` is accepted for signature parity with the other prompt
    builders; the same text is produced either way.
    """
    question = (
        'Is the edited object or the area affected by the instruction consistent with the edit instruction '
        'and the image scene in terms of shape, size, brightness, shadows, texture, color, etc. (Answer only Yes/No)?'
    )
    return f'{get_introduction_prompt(example)} {question}'

def get_is_accurate_task_prompt(example, finetune_prompt=False):
    """Build the Yes/No prompt asking whether the edit instruction was executed accurately.

    The finetune variant now reads ``example["instruction"]``; it previously
    used the key ``"Instruction"`` (capital I), which does not match the
    lowercase key read by every other prompt builder.

    Args:
        example: mapping with an ``"instruction"`` entry.
        finetune_prompt: when True, return the bare question quoting the
            instruction; otherwise prefix the introduction sentence.

    Returns:
        The prompt string.
    """
    if finetune_prompt:
        return f'Did the edit instruction "{example["instruction"]}" was accurately executed and reflect the intended change (Answer only Yes/No)?'

    introduction = get_introduction_prompt(example)
    return f'{introduction} Was the instruction accurately executed, and does it reflect the intended change? Ignore any other changes that do not relate to the instruction. Disregard visual quality issues like low resolution, blur, or unexpected properties such as shape, size, and color. (Answer Yes/No only)'

def get_is_artifacts_caption_task_prompt(example, finetune_prompt=False):
    """Build the Yes/No prompt asking about unintended artifacts or alterations.

    The finetune variant quotes the instruction inside the question; the default
    variant relies on the shared introduction sentence instead.
    """
    if not finetune_prompt:
        preamble = get_introduction_prompt(example)
        return f'{preamble} Are there any artifacts or alterations in the image not intended to be affected by the edit instruction (Answer only Yes/No)?'

    instruction = example["instruction"]
    return f'Are there any artifacts or alterations in the image not intended to be affected by the edit "{instruction}" (Answer only Yes/No)?'

def get_is_good_quality_task_prompt(example, finetune_prompt=False):
    """Build the Yes/No prompt about the visual quality of the edited area.

    The prompt quotes ``example["instruction"]`` directly and does not use the
    shared introduction sentence, so the unused call that computed it has been
    removed; the returned text is unchanged. ``finetune_prompt`` is accepted for
    signature parity with the other prompt builders and does not alter the text.

    Args:
        example: mapping with an ``"instruction"`` entry.
        finetune_prompt: ignored.

    Returns:
        The prompt string.
    """
    return f'Does the edited object or the area affected by the instruction "{example["instruction"]}" maintain the image resolution, exhibit blur, show any smoothness, etc. (Answer only Yes/No)?'


# Maps each boolean task name to the function that builds its prompt. Every
# builder is called as `builder(example, finetune_prompt=...)`.
# NOTE: the misspelt key 'is_difference_captipn_accurate' is load-bearing - the
# same spelling is used for the 'annotated_is_difference_captipn_accurate' column
# and in the task list of the run cell, so it must not be "fixed" in isolation.
task_map = {
    'is_extensive_difference_caption_accurate': get_extensive_difference_caption_task_prompt,
    'is_difference_captipn_accurate': get_difference_caption_task_prompt,
    'is_accurate': get_is_accurate_task_prompt,
    'is_artifacts': get_is_artifacts_caption_task_prompt,
    'is_visualy_consistence': get_is_visualy_consistence_task_prompt,
    'is_good_quality': get_is_good_quality_task_prompt
}

# assert len(set(task_map.keys()).intersection(set(evaluation_columns))) == 5

def enrich_with_task_details(data_df, predictions_output, task_map=task_map):
    """Add one ``{model}_{task}_prediction`` column per model and task to ``data_df``.

    For every row, every task in ``task_map`` and every model in ``MODELS`` the
    task prompt is built (using the finetune wording when the model name contains
    ``'finetune'``) and sent to ``get_model_response`` together with the row's
    source and target images. When ``predictions_output`` is not None the
    enriched frame is also written to that path as CSV.

    Returns the enriched DataFrame.
    """
    def enrich_models_predictions(example):
        for task_name, build_prompt in task_map.items():
            for model in MODELS:
                is_finetuned = 'finetune' in model
                prompt = build_prompt(example, finetune_prompt=is_finetuned)
                model_inputs = [example['source_img'], example['target_img'], prompt]
                example[f'{model}_{task_name}_prediction'] = get_model_response(
                    model_inputs, model=model, example=example, task_name=task_name
                )
        return example

    enriched_df = data_df.progress_apply(enrich_models_predictions, axis=1)
    if predictions_output is not None:
        enriched_df.to_csv(predictions_output)
        print('Saved predictions to ' + predictions_output)
    return enriched_df

In [None]:
import pandas as pd
import os
from sklearn.metrics import precision_score, f1_score, recall_score, balanced_accuracy_score
from tqdm import tqdm

# Initialize tqdm for progress bar
tqdm.pandas()

# Define your evaluation function
def evaluate_models(dataframe):
    """Score every model's Yes/No predictions against the annotated labels.

    Each ``annotated_<metric>`` column is treated as ground truth and every
    column containing both ``<metric>`` and ``_prediction`` as a model's answers
    for it. Predictions are converted with ``smart_convert_yes_no_list`` and both
    sides are cast to int before scoring.

    Specificity is the true negative rate ``TN / (TN + FP)``. The previous
    formula divided by ``TN + FN``, which is the negative predictive value and
    did not match the "Specificity (True Negative Rate)" label. It is reported
    as 0 when there are no actual negatives.

    NOTE: a later cell of this file defines another ``evaluate_models(dataframe)``
    which replaces this one once it has run.

    Args:
        dataframe: DataFrame holding ``annotated_*`` and ``*_prediction`` columns.

    Returns:
        A list of ``[title, DataFrame]`` pairs (Weighted F1, Balanced Accuracy,
        F1, Precision, Recall, Specificity); each DataFrame has one row per
        metric and one column per model, with values formatted like ``"87.5%"``.
    """
    # Insertion order fixes the order of the returned list.
    tables = {
        'Weighted F1': {},
        'Balanced Accuracy': {},
        'F1': {},
        'Precision': {},
        'Recall': {},
        'Specificity (True Negative Rate)': {},
    }

    # Identify annotated columns and model predictions
    annotated_columns = [col for col in dataframe.columns if col.startswith("annotated_")]
    prediction_columns = [col for col in dataframe.columns if "_prediction" in col]

    for annotated_col in annotated_columns:
        metric = annotated_col.replace("annotated_", "")

        # Every metric gets a row in every table, even without predictions.
        for table in tables.values():
            table[metric] = {}

        relevant_prediction_cols = [col for col in prediction_columns if metric in col]
        for pred_col in relevant_prediction_cols:
            model_name = pred_col.replace(f"_{metric}_prediction", "").replace(f"{metric}_", "")

            # Plain positional arrays: the two Series may carry different indexes,
            # and index alignment would silently distort the element-wise masks.
            ground_truth = pd.Series(dataframe[annotated_col]).astype(int).to_numpy()
            predictions = pd.Series(smart_convert_yes_no_list(dataframe[pred_col])).astype(int).to_numpy()

            f1 = f1_score(ground_truth, predictions)
            f1_weighted = f1_score(ground_truth, predictions, average='weighted')
            balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
            precision = precision_score(ground_truth, predictions)
            recall = recall_score(ground_truth, predictions)

            true_negatives = ((predictions == 0) & (ground_truth == 0)).sum()
            false_positives = ((predictions == 1) & (ground_truth == 0)).sum()
            actual_negatives = true_negatives + false_positives
            specificity = true_negatives / actual_negatives if actual_negatives > 0 else 0

            scores = {
                'Weighted F1': f1_weighted,
                'Balanced Accuracy': balanced_accuracy,
                'F1': f1,
                'Precision': precision,
                'Recall': recall,
                'Specificity (True Negative Rate)': specificity,
            }
            for title, score in scores.items():
                tables[title][metric][model_name] = str(round(score * 100, 1)) + '%'

    # One DataFrame per score type: rows are metrics, columns are models.
    return [[title, pd.DataFrame.from_dict(table, orient='index')] for title, table in tables.items()]

def evaluate_batch(annotated_data, predictions_output=None, predict_only=False, end_index=1000, task_map=task_map):
    """Load images, collect model predictions for the first ``end_index`` rows and score them.

    The annotation for the misspelt ``'is_difference_captipn_accurate'`` task is
    copied from ``'annotated_is_extensive_difference_caption_accurate'``.

    Returns the prediction DataFrame alone when ``predict_only`` is True,
    otherwise the tuple ``(evaluate_models(predictions), predictions)``.
    """
    with_images = annotated_data.progress_apply(load_example_images, axis=1)
    evaluation_data = enrich_with_task_details(with_images[:end_index], predictions_output, task_map=task_map)

    extensive_labels = evaluation_data['annotated_is_extensive_difference_caption_accurate']
    evaluation_data['annotated_is_difference_captipn_accurate'] = extensive_labels.copy()

    if predict_only:
        return evaluation_data
    return evaluate_models(evaluation_data), evaluation_data

In [None]:
def calculate_metrics(ground_truth, predictions):
    """Calculate binary-classification metrics for 0/1 labels.

    Specificity is the true negative rate ``TN / (TN + FP)``. The previous
    formula divided by ``TN + FN``, which is the negative predictive value
    rather than specificity. It is reported as 0 when there are no actual
    negatives.

    Args:
        ground_truth: array-like of 0/1 reference labels.
        predictions: array-like of 0/1 predicted labels, in the same order.

    Returns:
        dict with the keys ``'f1'``, ``'f1_weighted'``, ``'balanced_accuracy'``,
        ``'precision'``, ``'recall'`` and ``'specificity'``.
    """
    # Compare positionally: two pandas Series with different indexes would be
    # index-aligned by the element-wise masks below.
    ground_truth = np.asarray(ground_truth)
    predictions = np.asarray(predictions)

    f1 = f1_score(ground_truth, predictions)
    f1_weighted = f1_score(ground_truth, predictions, average='weighted')
    balanced_accuracy = balanced_accuracy_score(ground_truth, predictions)
    precision = precision_score(ground_truth, predictions)
    recall = recall_score(ground_truth, predictions)

    true_negatives = ((predictions == 0) & (ground_truth == 0)).sum()
    false_positives = ((predictions == 1) & (ground_truth == 0)).sum()
    actual_negatives = true_negatives + false_positives
    specificity = true_negatives / actual_negatives if actual_negatives > 0 else 0

    return {
        'f1': f1,
        'f1_weighted': f1_weighted,
        'balanced_accuracy': balanced_accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity
    }

def store_metrics(results_dict, metrics, metric, model_name):
    """Write each score of ``metrics`` into its results table as a percentage string.

    Every score is multiplied by 100, rounded to one decimal and suffixed with
    ``%`` before being stored at ``results_dict[table][metric][model_name]``.
    """
    table_to_score = (
        ('f1_results', 'f1'),
        ('weighted_f1_results', 'f1_weighted'),
        ('balanced_accuracy_results', 'balanced_accuracy'),
        ('precision_results', 'precision'),
        ('recall_results', 'recall'),
        ('specificity_results', 'specificity'),
    )
    for table_key, score_key in table_to_score:
        percent = round(metrics[score_key] * 100, 1)
        results_dict[table_key][metric][model_name] = f"{percent}%"

def evaluate_models(dataframe):
    """Score every model's Yes/No predictions, plus a gpt-4o/pipeline oracle for artifact metrics.

    Each ``annotated_<metric>`` column is treated as ground truth and every
    column containing both ``<metric>`` and ``_prediction`` as a model's answers
    for it. Scores come from ``calculate_metrics`` and are stored through
    ``store_metrics``.

    For metrics whose name contains ``'artifact'``, when both a ``gpt-4o`` and a
    ``pipeline`` prediction column exist, an extra "oracle" entry is scored. The
    oracle answers correctly whenever at least one of the two models does, and
    repeats their shared answer when both are wrong. The previous implementation
    defaulted every row to 0, so rows with ground truth 0 were always counted as
    correct even when both models predicted 1, which inflated the oracle.

    Args:
        dataframe: DataFrame holding ``annotated_*`` and ``*_prediction`` columns.

    Returns:
        A list of ``[title, DataFrame]`` pairs (Weighted F1, Balanced Accuracy,
        F1, Precision, Recall, Specificity); each DataFrame has one row per
        metric and one column per model.
    """
    def to_binary_array(values):
        # Positional 0/1 array; dropping the pandas index prevents the
        # element-wise comparisons from being index-aligned.
        return pd.Series(values).astype(int).to_numpy()

    # Initialize dictionaries to store results for each metric
    results_dict = {
        'precision_results': {},
        'recall_results': {},
        'f1_results': {},
        'weighted_f1_results': {},
        'balanced_accuracy_results': {},
        'specificity_results': {}
    }

    # Identify annotated columns and model predictions
    annotated_columns = [col for col in dataframe.columns if col.startswith("annotated_")]
    prediction_columns = [col for col in dataframe.columns if "_prediction" in col]

    for annotated_col in annotated_columns:
        metric = annotated_col.replace("annotated_", "")

        # Every metric gets an entry in every table, even without predictions.
        for key in results_dict:
            results_dict[key][metric] = {}

        relevant_prediction_cols = [col for col in prediction_columns if metric in col]

        # Oracle: only for artifact metrics that have both a gpt-4o and a
        # pipeline prediction column (the first matching column of each is used).
        gpt4o_col = next((col for col in relevant_prediction_cols if 'gpt-4o' in col), None)
        pipeline_col = next((col for col in relevant_prediction_cols if 'pipeline' in col), None)
        if 'artifact' in metric and gpt4o_col is not None and pipeline_col is not None:
            gpt4o_predictions = to_binary_array(smart_convert_yes_no_list(dataframe[gpt4o_col]))
            pipeline_predictions = to_binary_array(smart_convert_yes_no_list(dataframe[pipeline_col]))
            ground_truth = to_binary_array(dataframe[annotated_col])

            either_positive = (gpt4o_predictions == 1) | (pipeline_predictions == 1)
            both_positive = (gpt4o_predictions == 1) & (pipeline_predictions == 1)
            # Positive label: right if either model said 1.
            # Negative label: wrong (1) only if both models said 1.
            oracle_predictions = np.where(ground_truth == 1, either_positive, both_positive).astype(int)

            oracle_metrics = calculate_metrics(ground_truth, oracle_predictions)
            store_metrics(results_dict, oracle_metrics, metric, 'oracle - gpt-4o & pipeline')

        # Compare each model's predictions against the ground truth
        for pred_col in relevant_prediction_cols:
            model_name = pred_col.replace(f"_{metric}_prediction", "").replace(f"{metric}_", "")
            predictions = to_binary_array(smart_convert_yes_no_list(dataframe[pred_col]))
            ground_truth = to_binary_array(dataframe[annotated_col])

            model_metrics = calculate_metrics(ground_truth, predictions)
            store_metrics(results_dict, model_metrics, metric, model_name)

    # Convert results dictionaries to DataFrames, in display order.
    titled_tables = [
        ('Weighted F1', 'weighted_f1_results'),
        ('Balanced Accuracy', 'balanced_accuracy_results'),
        ('F1', 'f1_results'),
        ('Precision', 'precision_results'),
        ('Recall', 'recall_results'),
        ('Specificity (True Negative Rate)', 'specificity_results'),
    ]
    return [
        [title, pd.DataFrame.from_dict(results_dict[table_key], orient='index')]
        for title, table_key in titled_tables
    ]


#### Run Boolean Test

In [None]:
def display_evaluation_results(results_df_details, concat=False):
    """Render evaluation tables in the notebook.

    Args:
        results_df_details: iterable of ``(metric_name, DataFrame)`` pairs whose
            DataFrame index holds string labels.
        concat: when True, show one combined table whose row labels are
            ``"<metric_name>_<Title Cased Label>"``; otherwise show one titled
            table per pair with title-cased row labels.

    The input DataFrames are never modified: both branches relabel a copy.
    Previously the non-concat branch rewrote the caller's index in place.
    """
    if concat:
        # Collect first and concatenate once instead of growing a DataFrame
        # inside the loop.
        labelled_frames = []
        for metric_name, evaluation_df in results_df_details:
            labelled = evaluation_df.copy()
            labelled.index = [f"{metric_name}_{idx.replace('_', ' ').title()}" for idx in labelled.index]
            labelled_frames.append(labelled)
        # pd.concat rejects an empty list; fall back to an empty table.
        combined_df = pd.concat(labelled_frames) if labelled_frames else pd.DataFrame()
        display(combined_df)
        return

    for metric_name, evaluation_df in results_df_details:
        display_title(metric_name)
        titled = evaluation_df.copy()
        titled.index = titled.index.str.replace('_', ' ').str.title()
        display(titled)

In [None]:
def merge_new_columns_by_index(df_a, df_b):
    """Return ``df_a`` extended with the columns of ``df_b`` that it does not already have.

    Columns present in both frames keep ``df_a``'s values. The frames are
    combined column-wise with ``pd.concat(axis=1)``.
    """
    only_in_b = df_b.columns.difference(df_a.columns)
    extra_columns = df_b[only_in_b]
    return pd.concat([df_a, extra_columns], axis=1)

In [None]:
LOCAL_RUN=False

# NOTE(review): 'gpt-4o' is listed twice, so every task queries it twice under
# the same prediction column name - confirm whether another model was intended.
MODELS = ['gpt-4o', 'gpt-4o', 'gemini-1.5', 'gpt-4','gpt-4-turbo']#, 'qwen', 'intern-vl3'] # add gpt-4o-mini

# Boolean tasks to run; each name must be a key of task_map (including the
# misspelt 'is_difference_captipn_accurate').
tasks_configuration = [
      'is_artifacts',
      'is_accurate',
      'is_visualy_consistence',
      'is_good_quality',
      'is_extensive_difference_caption_accurate',
      'is_difference_captipn_accurate'
]

annotated_data_majority = get_annotated_data_with_majority()
for task_name in tasks_configuration:
    display_title('--- ' + task_name.replace('_', ' ').title() + ' ---')
    # Evaluate one task at a time by passing a single-entry task map.
    tasks_to_evaluate = {task_name: task_map.get(task_name)}
    results_df, evaluation_data = evaluate_batch(annotated_data_majority, task_map=tasks_to_evaluate)
    # Carry the newly created columns over so they are present in later iterations.
    annotated_data_majority = merge_new_columns_by_index(annotated_data_majority, evaluation_data)    
    # First only the second entry of the results list, then every entry.
    display_evaluation_results([results_df[1]], concat=True)
    display_evaluation_results(results_df, concat=True)