### Imports

In [1]:
# Uncomment to run installation
# !pip install datasets pandas openai typing_extensions==4.9.0 google-cloud-aiplatform --upgrade transformers==4.37.2 tables pytables

from PIL import Image
import PIL
import pandas as pd
from datasets import load_dataset
import json
import traceback
import numpy as np
from tqdm import tqdm
# Register tqdm's pandas integration (adds progress-bar variants such as `progress_apply`)
tqdm.pandas()

# Root folder of the project: the JSON cache files and the Detic checkout used
# further down are resolved relative to it. Replace the placeholder before running.
PROJECT_PATH = r'PATH_TO_PROJECT_FOLDER'

# Utils

## General

In [4]:
import io
import os
import base64
import matplotlib.pyplot as plt
from IPython.display import display

def convert(o):
    """Turn a NumPy 32/64-bit integer scalar into a built-in ``int``.

    Args:
        o: value to convert.

    Returns:
        ``int(o)`` when ``o`` is an ``np.int64`` or ``np.int32``.

    Raises:
        TypeError: for any other type.
    """
    if not isinstance(o, (np.int64, np.int32)):
        raise TypeError
    return int(o)

def dump_json(file_path, data):
    """Serialize ``data`` to JSON with ``NumpyEncoder`` and write it to ``file_path``.

    The payload is encoded *before* the file is opened for writing, so an
    object that cannot be serialized raises without truncating whatever was
    already stored at ``file_path`` (the file is used as a persistent cache).

    Args:
        file_path: destination path; overwritten on success.
        data: JSON-serializable object (NumPy values are handled by ``NumpyEncoder``).

    Raises:
        TypeError: if ``data`` contains a value the encoder cannot serialize.
    """
    string_json = json.dumps(data, cls=NumpyEncoder)
    with open(file_path, 'w') as f:
        f.write(string_json)

def create_json(path):
    """Make sure a JSON file exists at ``path`` and return ``path``.

    An existing file is left untouched; otherwise a file holding an empty
    JSON object (``{}``) is created and a message is printed.
    """
    if os.path.isfile(path):
        return path

    print('Creating json ', path)
    with io.open(path, 'w+') as json_file:
        json_file.write(json.dumps({}))
    return path

def get_text_nouns(text):
    """Return the lemmatized noun tokens found in ``text``.

    Hyphens and underscores are replaced by spaces, every character that is
    neither alphanumeric nor a space is dropped, and the result is tokenized.
    Tokens are lower-cased unless ``is_capital_letters`` accepts them; tokens
    rejected by ``is_noun`` or equal to ``'None'`` are discarded, and the rest
    are lemmatized as nouns.

    NOTE: ``word_tokenize``, ``is_capital_letters``, ``is_noun`` and
    ``lemmatizer`` are expected to be defined elsewhere in the notebook.
    """
    spaced = text.replace('-', ' ').replace('_', ' ')
    cleaned = "".join(ch for ch in spaced if str.isalnum(ch) or ch == " ")
    tokens = [
        tok if is_capital_letters(tok) else tok.lower()
        for tok in word_tokenize(cleaned)
    ]
    noun_tokens = [tok for tok in tokens if is_noun(tok) and tok != 'None']
    return [lemmatizer.lemmatize(tok, pos='n') for tok in noun_tokens]

def convert_PIL_to_CV(image, image_format='RGB'):
    """Convert a PIL image into a NumPy array with its channel axis reversed.

    Args:
        image: object exposing ``convert(mode)`` (a PIL image).
        image_format: mode passed to ``image.convert``.

    Returns:
        A fresh array holding the converted pixels with the last (third) axis
        reversed, e.g. RGB -> BGR for the default mode.
    """
    pixels = np.array(image.convert(image_format))
    # Reversed slicing yields a view; copy so the caller owns a contiguous buffer
    reversed_channels = pixels[:, :, ::-1]
    return reversed_channels.copy()

def convert_CV_to_PIL(image, BGR_format=True, image_format='PNG'):
    """Convert an OpenCV image array into an RGBA PIL image.

    Args:
        image: array accepted by ``cv2.cvtColor``.
        BGR_format: True when ``image`` is in BGR channel order, False for RGB.
        image_format: value stored in the returned image's ``format`` attribute.

    Returns:
        The PIL image built from the RGBA-converted array.
    """
    if BGR_format:
        conversion = cv2.COLOR_BGR2RGBA
    else:
        conversion = cv2.COLOR_RGB2RGBA
    pil_image = Image.fromarray(cv2.cvtColor(image, conversion))
    pil_image.format = image_format
    return pil_image

def get_json(file_path, DataFrame=False):
    """Load the UTF-8 JSON file at ``file_path``.

    Args:
        file_path: path of the JSON file to read.
        DataFrame: when True, wrap the parsed content in a ``pandas.DataFrame``.

    Returns:
        The parsed JSON object, or a DataFrame built from it.
    """
    with open(os.path.join(file_path), encoding="utf8") as f:
        payload = json.load(f)
    if not DataFrame:
        return payload
    return pd.DataFrame(payload)

class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy arrays and NumPy scalars.

    * ``np.ndarray`` -> ``{'type': 'ndarray', 'data': <nested list>, 'shape': <shape>}``
    * NumPy booleans -> ``bool``
    * NumPy integers (any width) -> ``int`` (kept integral rather than being
      turned into floats such as ``3.0``)
    * NumPy floats (any width) -> ``float``

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            # Include the shape so the array can be rebuilt when loading
            return {'type': 'ndarray', 'data': obj.tolist(), 'shape': obj.shape}
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)

def use_cuda_and_freeze(model):
    """Move ``model`` to the GPU when one is available and freeze its weights.

    Args:
        model: a torch module; it is moved in place and every parameter gets
            ``requires_grad = False``.

    Returns:
        The ``torch.device`` the model was moved to ('cuda' or 'cpu').
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Model Device set to {device}.\nFreezing model's weights.")
    model.to(device)
    for weight in model.parameters():
        weight.requires_grad = False
    return device

def get_highly_compressed_image_string(image, output_size=(20, 20), quality=5):
    """Build a short string signature of ``image``.

    The image is resized to ``output_size``, converted to RGB, saved as a JPEG
    with the given ``quality`` and base64-encoded; the string form of the
    original ``image.size`` is appended to the encoded data.

    Args:
        image: PIL image.
        output_size: size passed to ``image.resize``.
        quality: JPEG quality used when saving.

    Returns:
        ``<base64 of the compressed thumbnail>`` followed by ``str(image.size)``.
    """
    size_suffix = str(image.size)
    thumbnail = image.resize(output_size).convert('RGB')
    # Keep the heavily compressed JPEG in memory only
    buffer = io.BytesIO()
    thumbnail.save(buffer, format='JPEG', quality=quality)
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded + size_suffix

def get_instance_id(example, original=False):
    """Build the ``<img_id>_<instruction>_<turn_index>`` identifier of an example.

    Args:
        example: mapping with ``'img_id'``, ``'turn_index'`` and either
            ``'instruction'`` or ``'original_instruction'``.
        original: use ``'original_instruction'`` instead of ``'instruction'``.
    """
    instruction_key = 'original_instruction' if original else 'instruction'
    return f"{example['img_id']}_{example[instruction_key]}_{example['turn_index']}"

# Create the on-disk cache files under PROJECT_PATH if they are missing,
# then load their current content into memory.
GEMINI_CACHE_PATH = create_json(PROJECT_PATH + '/gemini_prediction_cache.json')
GPT_CACHE_PATH = create_json(PROJECT_PATH + '/gpt_response_cache.json')
GEMINI_CACHE = get_json(GEMINI_CACHE_PATH)
GPT_CACHE = get_json(GPT_CACHE_PATH)
# When True, the GPT helpers serve cached responses only and raise instead of calling the API
LOCAL_RUN = False # use cache only no apis

## GPT

In [None]:
from openai import OpenAI
from io import BytesIO

# Placeholder: replace with a real OpenAI API key before running (do not commit real keys)
API_KEY = "OPEN_AI_API_KEY"
# Client instance shared by the GPT helper functions below
client = OpenAI(api_key=API_KEY)

# Default request parameters used by the GPT helpers
MAX_TOKENS = 500
TEMPERATURE = 0.000001  # Default 1

def get_gpt_cache_id(prompt, model, temperature = TEMPERATURE, max_tokens=MAX_TOKENS):
    """Build the cache key identifying a GPT request.

    The key joins the prompt, model, temperature and max-token limit with
    underscore separators; carriage returns are removed from the whole key.
    """
    raw_key = "{}___{}___{}_____{}".format(prompt, model, temperature, max_tokens)
    return raw_key.replace('\r', '')

def get_gpt_response_from_cache(prompt, model, temperature = TEMPERATURE, max_tokens=MAX_TOKENS):
    """Return the cached GPT response for this request, or ``None`` when absent."""
    cache_key = get_gpt_cache_id(prompt, model, temperature, max_tokens)
    cached = GPT_CACHE.get(cache_key)
    return cached

def save_gpt_response_in_cache(prompt, model, response, temprature=TEMPERATURE, max_tokens=MAX_TOKENS):
    """Store ``response`` in the in-memory GPT cache and persist the cache to disk.

    The entry is keyed by ``get_gpt_cache_id`` of the request parameters; the
    whole cache is then rewritten to ``GPT_CACHE_PATH``.
    """
    cache_key = get_gpt_cache_id(prompt, model, temprature, max_tokens)
    GPT_CACHE[cache_key] = response
    dump_json(GPT_CACHE_PATH, GPT_CACHE)

def _get_chat_gpt_prediction_with_retry(prompt, model, overide_cache, temprature, max_tokens):
    """Call ``get_chat_gpt_prediction``, retrying once if the first attempt fails.

    Only ``Exception`` subclasses trigger the retry, so ``KeyboardInterrupt``
    and ``SystemExit`` propagate immediately.

    Raises:
        Exception: 'GPT request failed twice', chained to the second failure.
    """
    try:
        return get_chat_gpt_prediction(prompt, model, overide_cache, temprature=temprature, max_tokens=max_tokens)
    except Exception:
        try:
            return get_chat_gpt_prediction(prompt, model, overide_cache, temprature=temprature, max_tokens=max_tokens)
        except Exception as e:
            print('GPT request failed twice' , e)
            # Chain the cause so the underlying API/cache error stays visible in tracebacks
            raise Exception('GPT request failed twice') from e


def get_chatgpt_4_prediction(prompt, overide_cache=False, model="gpt-4", temprature=TEMPERATURE, max_tokens=MAX_TOKENS):
    """Return the response to ``prompt`` from ``model`` (GPT-4 by default).

    Args:
        prompt: user message sent to the model.
        overide_cache: when True, skip the cache lookup and query the API.
        model: model name forwarded to the API.
        temprature: sampling temperature.
        max_tokens: token limit forwarded to ``get_chat_gpt_prediction``.

    Raises:
        Exception: if the request fails on both attempts.
    """
    return _get_chat_gpt_prediction_with_retry(prompt, model, overide_cache, temprature, max_tokens)


def get_chatgpt_3_prediction(prompt, overide_cache=False, model="gpt-3.5-turbo-0125", temprature=TEMPERATURE, max_tokens=MAX_TOKENS):
    """Return the response to ``prompt`` from ``model`` (GPT-3.5 Turbo by default).

    Same contract as ``get_chatgpt_4_prediction``; only the default model differs.

    Raises:
        Exception: if the request fails on both attempts.
    """
    return _get_chat_gpt_prediction_with_retry(prompt, model, overide_cache, temprature, max_tokens)

def get_chat_gpt_prediction(prompt, model, overide_cache=False, temprature=TEMPERATURE, max_tokens=MAX_TOKENS):
    """Return the model's answer to ``prompt``, using the on-disk cache when possible.

    Args:
        prompt: user message sent as the only chat message.
        model: model name passed to the chat-completions endpoint.
        overide_cache: when True, ignore any cached response and query the API.
        temprature: sampling temperature.
        max_tokens: completion token limit. It is part of the cache key and is
            now also what is sent to the API (the request previously always
            used 1000 regardless of this argument).

    Returns:
        The text content of the first completion choice.

    Raises:
        Exception: when ``LOCAL_RUN`` is set and no cached response exists.
    """
    if not overide_cache:
        cached_response = get_gpt_response_from_cache(prompt, model, temprature, max_tokens)
        if cached_response is not None:
            return cached_response
    if LOCAL_RUN:
        # Cache-only mode: never hit the API
        raise Exception('LOCAL_RUN is enabled and no cached GPT response exists for this request')
    response = client.chat.completions.create(
        model=model,
        max_tokens=max_tokens,
        temperature=temprature,
        messages=[
            {"role": "user", "content": prompt},
        ],
    )
    content = response.choices[0].message.content
    save_gpt_response_in_cache(prompt, model, content, temprature, max_tokens)
    return content

## OCR Utils

In [None]:
def get_segmentation_as_bbox(segmentation):
    """Return one ``[x_min, y_min, x_max, y_max]`` box per contour of ``segmentation``.

    Contours come from ``pred_mask_to_contours`` (defined elsewhere); each is
    reduced to its ``cv2.boundingRect`` and converted from (x, y, w, h) to
    corner coordinates.
    """
    rects = (cv2.boundingRect(contour) for contour in pred_mask_to_contours(segmentation))
    return [
        [left, top, left + width, top + height]
        for left, top, width, height in rects
    ]

### Owlv2

In [None]:
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import cv2

# Load the pretrained OWLv2 processor and object-detection model from the
# "google/owlv2-base-patch16-ensemble" checkpoint.
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
# Move the model to CUDA when available (CPU otherwise) and freeze its weights;
# `device` is reused to place the inputs at inference time.

device = use_cuda_and_freeze(model)

def display_boxes(image, boxes, bbox_color=(0, 0,255), text_color=(0, 0, 255), bbox_thickness=3, font_scale=0.5, font_thickness=1, resize=(400, 400)):
    """Draw labelled bounding boxes on a copy of ``image`` and display the result.

    Args:
        image: PIL image to draw on (it is converted, not modified).
        boxes: iterable of dicts with a ``'bbox'`` entry in corner format
            ``(x_min, y_min, x_max, y_max)`` and an optional ``'id'`` label.
        bbox_color: rectangle colour passed to OpenCV.
        text_color: label colour passed to OpenCV.
        bbox_thickness: rectangle line thickness.
        font_scale: label font scale.
        font_thickness: label stroke thickness.
        resize: size of the image shown in the notebook (the returned image
            keeps its original size).

    Returns:
        ``(boxes, display_image)`` where ``display_image`` is the annotated PIL image.
    """
    display_image = convert_PIL_to_CV(image)
    # Draw each bounding box and text
    for obj_box in boxes:
        print(obj_box)
        (x_min, y_min, x_max, y_max) = obj_box['bbox']
        label = obj_box.get('id', '')
        cv2.rectangle(display_image, (x_min, y_min), (x_max, y_max), bbox_color, bbox_thickness)
        # Calculate text size to position it inside the bounding box
        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)[0]
        text_x = x_min + 5  # 5 pixels from the left edge of the bbox
        text_y = y_min + text_size[1] + 5  # baseline one text height (+5 px) below the top edge
        # The bbox holds corner coordinates, so its bottom edge is y_max
        # (not y_min + y_max); keep the label inside short boxes.
        if text_y > y_max:
            text_y = y_max - 5
        cv2.putText(display_image, label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, font_thickness)

    display_image = convert_CV_to_PIL(display_image)
    # `display` is the IPython notebook display helper
    display(display_image.resize(resize))
    return boxes, display_image

def any_bboxes_intersect(bboxes1, bboxes2):
    """Tell whether any box of ``bboxes1`` overlaps any box of ``bboxes2``.

    Boxes are indexable as ``(x_min, y_min, x_max, y_max)``; boxes that merely
    touch on an edge or corner count as intersecting.
    """
    def overlaps(first, second):
        # Separated horizontally: first entirely left or right of second
        if first[2] < second[0] or first[0] > second[2]:
            return False
        # Separated vertically: first entirely above or below second
        if first[3] < second[1] or first[1] > second[3]:
            return False
        return True

    return any(overlaps(first, second) for first in bboxes1 for second in bboxes2)

# Expects to receive a PIL image
def get_top_boxes(image, texts, bbox_score_threshold=0.3, show=False, print_details=False, bbox_color=(255, 0, 0), bbox_thickness=2, masks=None):
    """Detect the objects described by ``texts`` in a PIL image with OWLv2.

    Args:
        image: PIL image to analyse.
        texts: list or set of text queries; double quotes are stripped and
            each query is truncated to 60 characters.
        bbox_score_threshold: minimum score for a box to be in ``boxes_to_draw``.
        show: when True, draw ``boxes_to_draw`` and display the annotated image.
        print_details: when True, print one line per detection.
        bbox_color: rectangle colour used when ``show`` is True.
        bbox_thickness: rectangle thickness used when ``show`` is True.
        masks: optional boxes ``(x_min, y_min, x_max, y_max)``; when non-empty,
            detections that intersect none of them are dropped.

    Returns:
        ``(all_boxes, boxes_to_draw, image, scores)``: every kept detection,
        the ones above ``bbox_score_threshold``, the RGB image (the annotated
        one when ``show`` is True) and the raw score tensor of all detections.
    """
    assert isinstance(texts, (list, set))
    masks = [] if masks is None else masks
    image = image.convert('RGB')
    texts = [text.replace('"', '')[:60] for text in texts]
    inputs = processor(text=texts, images=[image], return_tensors="pt")
    inputs.to(device)
    outputs = model(**inputs)

    # Target image sizes (height, width) to rescale box predictions [batch_size, 2].
    # The longest side is used for both dimensions, presumably because the
    # processor pads the image to a square -- TODO confirm.
    image_size = image.size[::-1]
    target_sizes = torch.Tensor([(max(image_size), max(image_size))])
    # Convert outputs (bounding boxes and class logits) to Pascal VOC Format (xmin, ymin, xmax, ymax)
    results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)
    # Only one image is submitted, so its predictions are at index 0
    detections = results[0]
    boxes, scores, labels = detections["boxes"], detections["scores"], detections["labels"]
    boxes_to_draw = []
    all_boxes = []
    for index, (box, score, label) in enumerate(zip(boxes, scores, labels)):
        class_label = texts[int(label)]
        box = box.to(int).tolist()

        if len(masks) > 0 and not any_bboxes_intersect([box], masks):
            continue

        confidence = round(score.item(), 3)
        detailed_box = {
            "bbox": box,
            "class": class_label,
            'id': f'Owlv2_{class_label}_{index}',
            "score": confidence,
        }
        all_boxes.append(detailed_box)
        if score > bbox_score_threshold:
            boxes_to_draw.append(detailed_box)
        if print_details:
            # `box` is already a list of ints and the label indexes `texts`
            print(f"Detected {class_label} with confidence {confidence} at location {box}")

    # Show the image if requested
    if show:
        # Keyword arguments: positionally, the thickness would land in `text_color`
        boxes_to_draw, image = display_boxes(image, boxes_to_draw, bbox_color=bbox_color, bbox_thickness=bbox_thickness)

    return all_boxes, boxes_to_draw, image, scores

### Detic - Segmentation

In [None]:
# clone and install Detic
# !git clone https://github.com/facebookresearch/Detic.git --recurse-submodules
# %cd Detic
# !pip install fasttext-wheel
# !pip install pickleshare
# remove fasttext from requirements.txt before pip install
# !pip install -r requirements.txt 
# %cd ../

In [None]:
def resize_boolean_array(array, new_shape):
    """Resize a 2-D boolean mask to ``new_shape`` by index sampling.

    Output cell ``(i, j)`` takes the value of the source cell at
    ``(floor(i * rows / new_rows), floor(j * cols / new_cols))``.

    Args:
        array: array-like mask; values are cast to ``int`` and back to ``bool``.
        new_shape: ``(rows, cols)`` of the result.

    Returns:
        Boolean NumPy array of shape ``new_shape``.
    """
    as_int = np.array(array).astype(int)
    target_rows, target_cols = new_shape[0], new_shape[1]

    # Source-to-target ratio along each axis
    row_step = as_int.shape[0] / target_rows
    col_step = as_int.shape[1] / target_cols

    # Source index feeding each target row / column
    src_rows = np.floor(np.arange(target_rows) * row_step).astype(int)
    src_cols = np.floor(np.arange(target_cols) * col_step).astype(int)

    sampled = as_int[np.ix_(src_rows, src_cols)]
    return sampled.astype(bool)

In [None]:
# Install detectron2
import torch
# "major.minor" part of the installed torch version string
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
# Text after the last "+" of the version string (the build tag, e.g. a CUDA tag);
# NOTE: when the version has no "+", this is the whole version string.
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
# Install detectron2 that matches the above pytorch version
# See https://detectron2.readthedocs.io/tutorials/install.html for instructions
# !pip install detectron2 -f "https://dl.fbaipublicfiles.com/detectron2/wheels/118/torch2.1/index.html"
# Use the below line to install detectron2 if the above one has an error
# !python -m pip install 'git+https://github.com/facebookresearch/detectron2.git@v0.6'

In [None]:
import os
import sys

# Build the paths with os.path.join so they are valid on every platform
# (hard-coded backslashes only work on Windows).
Detic_repo_path = os.path.join(PROJECT_PATH, 'Detic')
module_paths = [os.path.join(Detic_repo_path, 'third_party', 'CenterNet2'), Detic_repo_path] # put here the CenterNet and Detic path
# Make the CenterNet2 and Detic packages importable, without duplicating entries on re-run
for module_path in module_paths:
    if module_path not in sys.path:
        print('Added module: ' + module_path)
        sys.path.append(module_path)

In [None]:
import random
all_classes = ['aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna', 'apple', 'applesauce', 'apricot', 'apron', 'aquarium', 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor', 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy', 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap', 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card', 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry', 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase', 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box', 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere', 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase', 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'horned_cow', 'bulldog', 'bulldozer', 
'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn', 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car', 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf', 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)', 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower', 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier', 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard', 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine', 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock', 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach', 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw', 'coloring_material', 'combination_lock', 'pacifier', 'comic_book', 'compass', 'computer_keyboard', 'condiment', 'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie', 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn', 
'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall', 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker', 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib', 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown', 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain', 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard', 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk', 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux', 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher', 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup', 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin', 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly', 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit', 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring', 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater', 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk', 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan', 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)', 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm', 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace', 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl', 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap', 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair', 'food_processor', 'football_(American)', 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car', 
'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice', 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage', 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic', 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator', 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture', 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles', 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose', 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly', 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet', 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock', 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel', 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw', 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband', 'headboard', 'headlight', 'headscarf', 'headset', 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah', 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce', 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear', 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate', 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit', 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)', 'Lego', 'legume', 'lemon', 
'lemonade', 'lettuce', 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor', 'lizard', 'log', 'lollipop', 'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato', 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone', 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money', 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor', 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)', 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom', 'music_stool', 'musical_instrument', 'nailfile', 'napkin', 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper', 'newsstand', 'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich', 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas', 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book', 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol', 'parchment', 'parka', 'parking_meter', 'parrot', 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport', 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter', 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg', 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener', 
'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet', 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano', 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow', 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball', 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)', 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat', 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)', 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)', 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato', 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel', 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat', 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt', 'recliner', 'record_player', 'reflector', 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map', 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade', 'rolling_pin', 'root_beer', 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)', 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver', 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane', 'seashell', 'sewing_machine', 'shaker', 
'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap', 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole', 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)', 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman', 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball', 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish', 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)', 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish', 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel', 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer', 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses', 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)', 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure', 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup', 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth', 'telephone_pole', 'telephoto_lens', 'television_camera', 'television_set', 'tennis_ball', 'tennis_racket', 'tequila', 'thermometer', 
'thermos_bottle', 'thermostat', 'thimble', 'thread', 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil', 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven', 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush', 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel', 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light', 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline', 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat', 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest', 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe', 'washbasin', 'automatic_washer', 'watch', 'water_bottle', 'water_cooler', 'water_faucet', 'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski', 'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream', 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket', 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt', 'yoke_(animal_equipment)', 'zebra', 'zucchini']

def generate_rgb_colors(classes):
    """Return one RGB tuple per entry of ``classes``, in the same order.

    ``'truck'`` and ``'car_(automobile)'`` get fixed colours; every other
    class gets a random colour drawn with ``random.randint(0, 255)`` per channel.
    """
    # Fixed colours for the two classes that must stay recognisable
    fixed_colors = {
        'truck': (211, 211, 211),  # Grey
        'car_(automobile)': (144, 238, 144),  # Green
    }

    return [
        fixed_colors[name] if name in fixed_colors
        else (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for name in classes
    ]

# Generate colors for all classes: one RGB tuple per entry of all_classes,
# computed once when this cell runs (random colours change between runs).
all_classes_colors = generate_rgb_colors(classes=all_classes)

In [None]:
# Some basic setup:
# Setup detectron2 logger
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import sys
import numpy as np
import os, json, cv2, random

# import some common detectron2 utilities
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer, ColorMode, _create_text_labels, GenericMask
from detectron2.data import MetadataCatalog

# Detic libraries
from centernet.config import add_centernet_config
from detic.config import add_detic_config
from detic.modeling.utils import reset_cls_test

class DETIC_Utils:
    """Wrapper around a Detic (detectron2) predictor.

    Builds the Detic configuration, creates a ``DefaultPredictor``, binds one
    of the built-in vocabularies to it, and exposes helpers that run the
    predictor on a PIL image and draw the predicted instances.
    """

    def __init__(self, ONE_CLASS_PER_PROPOSAL=True, SCORE_THRESH_TEST=0.5, vocabulary='lvis'):
        """Configure the model and bind it to ``vocabulary``.

        Args:
            ONE_CLASS_PER_PROPOSAL: copied to
                ``cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL``.
            SCORE_THRESH_TEST: copied to
                ``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``.
            vocabulary: key into ``BUILDIN_CLASSIFIER`` and
                ``BUILDIN_METADATA_PATH``: 'lvis', 'objects365',
                'openimages' or 'coco'.
        """
        # Build the detector and download our pretrained weights
        # `Detic_repo_path` is a module-level name defined outside this class.
        self.repo_path = Detic_repo_path + '//'
        self.ONE_CLASS_PER_PROPOSAL = ONE_CLASS_PER_PROPOSAL
        self.SCORE_THRESH_TEST = SCORE_THRESH_TEST
        # Populates self.cfg; reads the two attributes stored just above.
        self.init_configuration()
        # Setup the model's vocabulary using build-in datasets
        # Classifier weight files, resolved relative to the Detic checkout.
        self.BUILDIN_CLASSIFIER = {
            'lvis': self.repo_path + 'datasets/metadata/lvis_v1_clip_a+cname.npy',
            'objects365': self.repo_path + 'datasets/metadata/o365_clip_a+cnamefix.npy',
            'openimages': self.repo_path + 'datasets/metadata/oid_clip_a+cname.npy',
            'coco': self.repo_path + 'datasets/metadata/coco_clip_a+cname.npy',
        }
        
        # Names passed to MetadataCatalog.get() for each vocabulary.
        self.BUILDIN_METADATA_PATH = {
            'lvis': 'lvis_v1_val',
            'objects365': 'objects365_v2_val',
            'openimages': 'oid_val_expanded',
            'coco': 'coco_2017_val',
        }

        self.predictor = DefaultPredictor(self.cfg)
        self.vocabulary = vocabulary # change to 'lvis', 'objects365', 'openimages', or 'coco'
        self.metadata = MetadataCatalog.get(self.BUILDIN_METADATA_PATH[self.vocabulary])
        # NOTE(review): `all_classes_colors` is built from the module-level
        # `all_classes` list regardless of `vocabulary`; presumably that list
        # matches only one vocabulary - confirm the colors line up for others.
        self.metadata.set(thing_colors=all_classes_colors)
        self.classifier = self.BUILDIN_CLASSIFIER[self.vocabulary]
        self.num_classes = len(self.metadata.thing_classes)
        reset_cls_test(self.predictor.model, self.classifier, self.num_classes)
        

    def init_configuration(self):
        """Create ``self.cfg`` for the Detic SwinB model.

        Starts from the detectron2 default config, adds the CenterNet and
        Detic config keys, prefixes several path-valued entries with
        ``self.repo_path``, merges the model YAML and then applies the
        weights URL and the thresholds stored on the instance.
        """
        self.cfg = get_cfg()
        add_centernet_config(self.cfg)
        add_detic_config(self.cfg)
        # Prefix the path-valued defaults with the repo path.
        self.cfg.DATALOADER.TARFILE_PATH = self.repo_path + self.cfg.DATALOADER.TARFILE_PATH
        self.cfg.DATALOADER.TAR_INDEX_DIR = self.repo_path + self.cfg.DATALOADER.TAR_INDEX_DIR
        self.cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = self.repo_path + self.cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH
        self.cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = self.repo_path + self.cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH 
        self.cfg.merge_from_file(self.repo_path + "configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml")
        self.cfg.MODEL.WEIGHTS = 'https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth'
        self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = self.SCORE_THRESH_TEST  # set threshold for this model
        # Overrides the prefixed ZEROSHOT_WEIGHT_PATH assigned a few lines above.
        self.cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'rand'
        self.cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = self.ONE_CLASS_PER_PROPOSAL # For better visualization purpose. Set to False for all classes.
        # cfg.MODEL.DEVICE='cpu' # uncomment this to use cpu-only mode.

    def draw_instance_predictions(self, v, predictions, alpha=0.8):
        """
        Draw instance-level prediction results on an image.

        Args:
            v (Visualizer): visualizer to draw on; its metadata, output and
                instance mode are read here.
            predictions (Instances): the output of an instance detection/segmentation
                model. Following fields will be used to draw:
                "pred_boxes", "pred_classes", "scores", "pred_masks",
                "pred_keypoints".
            alpha (float): overlay opacity. It is only honoured when per-class
                colors are used; otherwise it is replaced by 0.5, and by 0.3
                in ``ColorMode.IMAGE_BW``.

        Returns:
            output (VisImage): image object with visualizations.
        """
        boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None
        scores = predictions.scores if predictions.has("scores") else None
        classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None
        labels = _create_text_labels(classes, scores, v.metadata.get("thing_classes", None))
        keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None

        if predictions.has("pred_masks"):
            masks = np.asarray(predictions.pred_masks)
            masks = [GenericMask(x, v.output.height, v.output.width) for x in masks]
        else:
            masks = None
        # Per-class colors are only used in SEGMENTATION mode when the metadata
        # carries `thing_colors`; they are scaled from 0-255 to 0-1.
        if v._instance_mode == ColorMode.SEGMENTATION and v.metadata.get("thing_colors"):
            colors = [
                [x / 255 for x in v.metadata.thing_colors[c]] for c in classes
            ]
        else:
            colors = None
            alpha = 0.5

        if v._instance_mode == ColorMode.IMAGE_BW:
            # Replace the canvas with a grayscale image built from the union of all masks.
            v.output.reset_image(
                v._create_grayscale_image(
                    (predictions.pred_masks.any(dim=0) > 0).numpy()
                    if predictions.has("pred_masks")
                    else None
                )
            )
            alpha = 0.3
        v.overlay_instances(
            masks=masks,
            boxes=boxes,
            labels=labels,
            keypoints=keypoints,
            assigned_colors=colors,
            alpha=alpha,
        )
        return v.output
    
    def draw_bounding_boxes(self, image, instances, should_display, class_whitelist=None, alpha=0.8):
        """Draw ``instances`` on ``image`` and return the result as a PIL image.

        Args:
            image: image array; its channel order is reversed before drawing.
            instances: predicted instances (``pred_classes`` is read here).
            should_display: when truthy, the drawn image is also shown with
                ``display``.
            class_whitelist: optional collection of class labels (strings from
                ``thing_classes``) to keep. ``None`` keeps everything.
            alpha: forwarded to ``draw_instance_predictions``.

        Returns:
            The drawn image, converted with ``convert_CV_to_PIL``.
        """
        # Convert to CPU for processing
        pred_classes = instances.pred_classes.cpu().numpy()
        filtered_indices = []
        # Filter instances by class whitelist
        for idx, pred_class in enumerate(pred_classes):
            class_label = self.metadata.get("thing_classes", [])[pred_class]
            if class_whitelist is not None:
                # NOTE(review): any label containing 'car' is dropped at random
                # about half of the time, so the drawing is not deterministic -
                # confirm this is intended.
                if 'car' in class_label and random.uniform(0, 1) > 0.5:
                    continue
                elif class_label in class_whitelist:
                    filtered_indices.append(idx)
            elif class_whitelist is None:
                filtered_indices.append(idx)
        # Only keep the instances that are in the whitelist
        # NOTE(review): when nothing survives the filter, `instances` is left
        # untouched and every instance is drawn - confirm this fallback is wanted.
        if filtered_indices:
            instances = instances[filtered_indices]  # Keep only filtered instances
        # Draw bounding boxes
        v = Visualizer(image[:, :, ::-1], self.metadata, instance_mode=ColorMode.SEGMENTATION)
        out = self.draw_instance_predictions(v, instances.to("cpu"), alpha=alpha)
        image = convert_CV_to_PIL(out.get_image(), BGR_format=False)
        # Display the image if required
        if should_display:
            display(image)
        return image
        
    def get_bounding_boxes(self, image, target_size=None, should_display=False, add_bboxs=False, bbox_size_threshold=0, class_whitelist=None, alpha=0.8):
        """Run the predictor on a PIL image and collect its detections.

        Args:
            image: PIL image (``image.size`` and ``convert_PIL_to_CV`` are used).
            target_size: ``(width, height)`` the returned boxes and masks are
                scaled to; defaults to ``image.size``.
            should_display: draw the detections and show them.
            add_bboxs: draw the detections and return the drawn image.
            bbox_size_threshold: forwarded to ``is_bbox_too_small``; boxes at
                or below this fraction of the image area are skipped.
            class_whitelist: optional filter, see the NOTE(review) in the loop.
            alpha: forwarded to ``draw_bounding_boxes``.

        Returns:
            ``(image, results)``: ``image`` is the input image, or the drawn
            image when ``add_bboxs`` or ``should_display`` is set. ``results``
            is a list of dicts with keys "bbox", "class", "id", "score" and
            "pred_masks", after ``filter_dall_e_watermark`` has been applied.
        """
        target_size = target_size if target_size is not None else image.size
        # NOTE(review): `objects` and `model` are not used anywhere in this method.
        objects, model = ['segmentation_details'], f'detic_segmentation_{self.vocabulary}' # for cache

        # The predictor is always run here; no cache lookup is performed.
        cv_image = convert_PIL_to_CV(image)
        # NOTE(review): the raw predictor output is published as the
        # module-level global `output` on every call.
        global output
        output = self.predictor(cv_image)

        # Image scaling factors
        current_size = cv_image.shape[:2][::-1]  # shape is (height, width); reversed to (width, height)
        scale_x = target_size[0] / current_size[0]
        scale_y = target_size[1] / current_size[1]
        
        # collect bboxes
        pred_boxes = output['instances'].pred_boxes.tensor.cpu().numpy()  # Convert tensor to numpy array and move to CPU
        pred_classes = output['instances'].pred_classes.cpu().numpy()  # Convert tensor to numpy array and move to CPU
        pred_scores = output['instances'].scores.cpu().numpy()
        pred_masks =  output['instances'].pred_masks.cpu().numpy()
        results = []

        for index, (bbox, pred_class, scores, pred_mask) in enumerate(zip(pred_boxes, pred_classes, pred_scores, pred_masks)):
            # Skip boxes that are too small relative to the (unscaled) image
            if is_bbox_too_small([int(coord) for coord in bbox], current_size, bbox_size_threshold):
                continue
            # NOTE(review): this compares the numeric class index against the
            # whitelist, whereas draw_bounding_boxes compares the class label
            # string against the same whitelist - confirm which is intended.
            if class_whitelist is not None and pred_class not in class_whitelist:
                continue
            
            # NOTE(review): `current_size` is a tuple, so a list-valued
            # `target_size` always compares unequal and takes this branch.
            if current_size != target_size:
                # Scale the original coordinates (x, y) and dimensions (w, h) accordingly
                x_scaled = bbox[0] * scale_x
                y_scaled = bbox[1] * scale_y
                w_scaled = (bbox[2] - bbox[0]) * scale_x
                h_scaled = (bbox[3] - bbox[1]) * scale_y
                
                # Adjust the bbox to the format (x, y, x+w, y+h) after scaling
                bbox = [x_scaled, y_scaled, x_scaled + w_scaled, y_scaled + h_scaled]

                # Scale the pred_mask to the target size
                scaled_mask_width = int(current_size[0] * scale_x)
                scaled_mask_height = int(current_size[1] * scale_y)
                pred_mask = resize_boolean_array(pred_mask, (scaled_mask_width, scaled_mask_height))
            
            bbox_formatted = [int(coord) for coord in bbox] 
            class_label = self.metadata.get("thing_classes", [])[pred_class]  # Get class label from metadata
            # Append the result
            results.append({
                "bbox": bbox_formatted,
                "class": class_label,
                'id': f'DETIC_{class_label}_{index}',
                "score": scores,
                "pred_masks": pred_mask
            })

        # display
        if add_bboxs or should_display:
            image = self.draw_bounding_boxes(cv_image, output["instances"], should_display, class_whitelist=class_whitelist, alpha=alpha)
        results = filter_dall_e_watermark(results, target_size)
        return image, results

# Default detector: LVIS vocabulary with the constructor's default score threshold (0.5).
# Each DETIC_Utils instance builds its own DefaultPredictor.
detic_utils = DETIC_Utils()
# detic_utils_openimages = DETIC_Utils(vocabulary='openimages')
# detic_utils_sensitive_openimages = DETIC_Utils(ONE_CLASS_PER_PROPOSAL=False, SCORE_THRESH_TEST=0.4, vocabulary='openimages')
# Second LVIS detector with a lower score threshold (0.4), so it reports more detections.
detic_utils_sensitive_lvis = DETIC_Utils(ONE_CLASS_PER_PROPOSAL=True, SCORE_THRESH_TEST=0.4, vocabulary='lvis')

def filter_dall_e_watermark(segmentation_results, image_size=(500,500)):
    """Drop detections that look like the DALL-E watermark.

    A detection is treated as the watermark when its class is 'flag', its
    bottom edge lies within the lowest 2% of the image (with 5 px of slack
    below the image), and its top-left corner lies in the bottom-right 10% of
    the image on both axes.

    The bbox is read as ``[x1, y1, x2, y2]``, the format used by the
    detection results in this file. Horizontal coordinates are compared
    against the image width and vertical ones against the height, so the
    check also holds for non-square images.

    Args:
        segmentation_results: iterable of dicts with at least the keys
            'class' and 'bbox'.
        image_size: ``(width, height)`` of the image the boxes refer to.

    Returns:
        list with the watermark detections removed, order preserved.

    Raises:
        AssertionError: a 'flag' box starts below the image, or ends more
            than 5 px below it.
    """
    def is_not_dall_e_watermark(segmentation_instance):
        if segmentation_instance['class'] != 'flag':
            return True
        width, height = image_size[0], image_size[1]
        x1, y1, _x2, y2 = segmentation_instance['bbox']
        # Sanity check on the vertical extent of the box.
        assert y2 <= (height + 5) and y1 <= height, 'flag bbox lies outside the image'
        touches_bottom_edge = 0.98 * height <= y2 <= (height + 5)
        starts_in_corner = x1 >= 0.90 * width and y1 >= 0.90 * height
        return not (touches_bottom_edge and starts_in_corner)

    return [instance for instance in segmentation_results if is_not_dall_e_watermark(instance)]

def is_bbox_too_small(bbox, image_size, threshold=0.001):
    """Tell whether a box covers at most ``threshold`` of the image area.

    Args:
        bbox: ``[x1, y1, x2, y2]`` corner coordinates.
        image_size: ``(width, height)`` of the image.
        threshold: largest area fraction that still counts as "too small".

    Returns:
        True when ``bbox_area / image_area <= threshold``.
    """
    total_area = image_size[0] * image_size[1]  # w*h
    box_width = bbox[2] - bbox[0]
    box_height = bbox[3] - bbox[1]
    area_fraction = (box_width * box_height) / total_area
    return area_fraction <= threshold

### Mask Utils

In [None]:
import cv2
import numpy as np

def calculate_shared_area_ratio(image, input_contours, input_bboxes=None, additional_contours=None, percentage_threshold=70, debug=False):
    """
    Find comparison shapes that are mostly covered by one of the primary contours.

    Every primary contour is rasterised and intersected with every comparison
    shape (bounding boxes converted to rectangles, plus any additional
    contours). A comparison shape is reported when more than
    ``percentage_threshold`` percent of its area lies inside a primary
    contour. The same shape is reported once per primary contour that covers it.

    :param image: Input PIL image; only its size is used, to size the masks
    :param input_contours: List of primary contours to evaluate
    :param input_bboxes: Optional; List of bounding boxes, each defined as (x, y, w, h)
    :param additional_contours: Optional; List of additional contours to evaluate against the primary contours
    :param percentage_threshold: Coverage (in percent) of a comparison shape above which it is reported
    :param debug: When truthy, print the two coverage percentages for every pair
    :return: Tuple ``(found, artifacts)``. ``found`` is True when at least one
        comparison shape was reported. ``artifacts`` is a list of dicts with keys
        "contour" (the comparison shape), "covered_percentage" (share of the
        comparison shape inside the primary contour) and "covering_percentage"
        (share of the primary contour inside the comparison shape).
    """
    # Convert the image to grayscale; it only serves as a template for the masks
    image = convert_PIL_to_CV(image)
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # The comparison shapes do not depend on the primary contour, so build them once.
    comparison_contours = []
    if input_bboxes:
        for x, y, w, h in input_bboxes:
            # OpenCV contour functions only accept 32-bit integer / float32
            # points; NumPy's default integer type is 64-bit on most platforms.
            comparison_contours.append(
                np.array([[x, y], [x + w, y], [x + w, y + h], [x, y + h]], dtype=np.int32)
            )
    if additional_contours:
        comparison_contours.extend(additional_contours)

    artifcats_contours = []

    # Process each primary contour
    for contour in input_contours:
        # Rasterise the current primary contour
        mask_contour = np.zeros_like(gray_image)
        cv2.drawContours(mask_contour, [contour], -1, (255,), thickness=cv2.FILLED)
        contour_area = cv2.contourArea(contour)

        for comp_contour in comparison_contours:
            # Rasterise the comparison shape
            mask_comparison = np.zeros_like(gray_image)
            cv2.drawContours(mask_comparison, [comp_contour], -1, (255,), thickness=cv2.FILLED)

            # Pixels shared by both shapes
            intersection = cv2.bitwise_and(mask_contour, mask_comparison)
            intersection_area = cv2.countNonZero(intersection)

            comp_contour_area = cv2.contourArea(comp_contour)

            percentage_overlap = 0
            reverse_percentage_overlap = 0

            # Share of the primary contour that lies inside the comparison shape
            if contour_area > 0:  # Prevent division by zero
                percentage_overlap = (intersection_area / contour_area) * 100
                if debug:
                    print(f"Percentage of primary contour covered by comparison: {percentage_overlap:.2f}%")

            # Share of the comparison shape that lies inside the primary contour
            if comp_contour_area > 0:  # Prevent division by zero
                reverse_percentage_overlap = (intersection_area / comp_contour_area) * 100
                if debug:
                    print(f"Percentage of comparison contour covered by primary: {reverse_percentage_overlap:.2f}%")
                if reverse_percentage_overlap > percentage_threshold:
                    artifcats_contours.append({"contour": comp_contour, "covered_percentage": reverse_percentage_overlap, "covering_percentage": percentage_overlap})

    return len(artifcats_contours) > 0, artifcats_contours

In [None]:
import cv2
import numpy as np
from matplotlib import pyplot as plt

def get_contour_image_percentage(mask_img):
    """Return the share of the image covered by its masking contours.

    The contours come from ``get_masking_contours``; each one is filled on a
    blank single-channel canvas of the image's size and the non-zero pixels
    are counted.

    Args:
        mask_img: PIL image (usually a mask image).

    Returns:
        Covered area as an integer percentage of ``width * height`` (truncated).
    """
    found_contours, _ = get_masking_contours(mask_img)
    grayscale = cv2.cvtColor(convert_PIL_to_CV(mask_img), cv2.COLOR_BGR2GRAY)
    coverage = np.zeros_like(grayscale)
    for single_contour in found_contours:
        cv2.drawContours(coverage, [single_contour], -1, (255), -1)
    covered_pixels = cv2.countNonZero(coverage)
    img_width, img_height = mask_img.size[0], mask_img.size[1]
    return int(covered_pixels / (img_width * img_height) * 100)

def show_contours(image, contours, show=False):
    """Draw ``contours`` in green on an RGB copy of a PIL image.

    Args:
        image: PIL image; it is not modified.
        contours: contours to draw (2 px wide outline).
        show: when truthy, also show the drawn image with ``display``.

    Returns:
        ``(contours, image_with_polygons)``: the contours as given and the
        drawn copy as a PIL image.
    """
    canvas = np.array(image.convert('RGB'))
    cv2.drawContours(canvas, contours, -1, (0, 255, 0), 2)

    # Back to PIL so the result matches the input type
    image_with_polygons = Image.fromarray(canvas)
    if show:
        display(image_with_polygons)
    return contours, image_with_polygons

def get_masking_contours(image, show=False, add_polygons=False, minimum_size=30):
    """Extract the contours of the not-fully-opaque regions of a PIL image.

    The image is converted with ``convert_PIL_to_CV(..., 'RGBA')``, which
    reverses the channel order, so plane 0 of the converted array is the alpha
    plane. Pixels whose value there is at most 254 become foreground in the
    binary mask the contours are extracted from.

    Args:
        image: PIL image.
        show: when truthy, draw the contours and display them.
        add_polygons: has the same effect as ``show``.
        minimum_size: contours with at most this many points are discarded.
            If that would discard all of them, the contour with the most
            points is kept instead.

    Returns:
        ``(contours, converted_image)``: the retained contours and the
        converted image array. No contours are drawn on the returned array.
    """
    cv_image = convert_PIL_to_CV(image, image_format='RGBA')
    first_plane = cv_image[:, :, 0]

    # Inverse threshold just below 255: everything not fully saturated is foreground.
    _, binary_mask = cv2.threshold(first_plane, 254, 255, cv2.THRESH_BINARY_INV)
    found, _hierarchy = cv2.findContours(binary_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    every_contour = list(found)

    # Size is measured in number of contour points, not in enclosed area.
    kept = [candidate for candidate in every_contour if len(candidate) > minimum_size]
    if not kept and every_contour:
        # Nothing passed the filter: fall back to the contour with the most points.
        kept = [max(every_contour, key=len)]

    if show or add_polygons:
        show_contours(convert_CV_to_PIL(cv_image), kept, show=True)

    return kept, cv_image

def get_masking_bbox(masked_image, show=False):
    """Return a bounding box for each masking contour of ``masked_image``.

    Args:
        masked_image: PIL image passed on to ``get_masking_contours``.
        show: when truthy, show the boxes with ``display_boxes`` at 400x400.

    Returns:
        list of dicts with keys 'bbox' (``[x, y, x+w, y+h]``), 'class'
        (always 'masking_box'), 'id' ('masking_box_<index>') and 'score'
        (always 1).
    """
    outlines, _ = get_masking_contours(masked_image)
    bboxes = [
        {
            'bbox': [left, top, left + box_w, top + box_h],
            'class': 'masking_box',
            'id': 'masking_box_' + str(position),
            'score': 1,
        }
        for position, (left, top, box_w, box_h) in enumerate(map(cv2.boundingRect, outlines))
    ]
    if show:
        # Display the image with bounding boxes
        display_boxes(masked_image, bboxes, resize=(400, 400))
    return bboxes

### Bboxes intersection with masked polygon

In [None]:
import cv2
import numpy as np

def calculate_intersection_area_bbox(rect1, rect2):
    """Compute the overlap of two axis-aligned rectangles.

    Args:
        rect1: ``(x, y, x+w, y+h)`` corner coordinates of the first rectangle.
        rect2: ``(x, y, x+w, y+h)`` corner coordinates of the second rectangle.

    Returns:
        ``(intersection_area, percent_cover_rect1, percent_cover_rect2)``:
        the overlap area and the percentage of each rectangle's own area it
        represents. ``(0, 0.0, 0.0)`` when the rectangles do not overlap. A
        rectangle with zero area reports 0.0 percent instead of raising
        ``ZeroDivisionError``.
    """
    x1, y1, x1_end, y1_end = rect1
    x2, y2, x2_end, y2_end = rect2

    # Corners of the intersection rectangle
    x_left = max(x1, x2)
    y_top = max(y1, y2)
    x_right = min(x1_end, x2_end)
    y_bottom = min(y1_end, y2_end)

    # Check if there is an intersection
    if x_right < x_left or y_bottom < y_top:
        return 0, 0.0, 0.0  # No intersection

    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    area1 = (x1_end - x1) * (y1_end - y1)
    area2 = (x2_end - x2) * (y2_end - y2)

    # A degenerate (zero-area) rectangle can still pass the check above when it
    # touches the other rectangle, so guard both divisions.
    percent_cover_rect1 = (intersection_area / area1) * 100 if area1 != 0 else 0.0
    percent_cover_rect2 = (intersection_area / area2) * 100 if area2 != 0 else 0.0

    return intersection_area, percent_cover_rect1, percent_cover_rect2

def show_image(image, title='Image', backend='cv2'):
    """Display an image.

    This name used to be defined twice in a row (a matplotlib version followed
    by an OpenCV version), so the second definition silently replaced the
    first. Both behaviours are kept here behind ``backend``; the default is
    the OpenCV one, which is what callers were actually getting.

    Args:
        image: image array to show.
        title: window title (cv2) or figure title (matplotlib).
        backend: 'cv2' opens a window and blocks until a key is pressed;
            'matplotlib' renders the image with pyplot, axes hidden.

    Raises:
        ValueError: ``backend`` is neither 'cv2' nor 'matplotlib'.
    """
    if backend == 'cv2':
        cv2.imshow(title, image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    elif backend == 'matplotlib':
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.show()
    else:
        raise ValueError(f"Unknown backend {backend!r}; expected 'cv2' or 'matplotlib'")

def get_intersected_bboxes(image, contours, bboxes, contour_covered_threshold=0, bbox_covered_threshold=0, show=False, debug=False, strict_intersection=False):
    """Return the bounding boxes that overlap any of the given contours.

    Each contour and each box is rasterised onto a mask of the image's size
    and the overlap is measured in pixels.

    Args:
        image: PIL image; only used to size the masks.
        contours: contours to test against.
        bboxes: boxes to test. Each is either a ``[x1, y1, x2, y2]`` sequence
            or a dict holding that sequence under 'bbox'.
        contour_covered_threshold: the share of the contour inside the box
            (in percent) must exceed this value.
        bbox_covered_threshold: the share of the box inside the contour
            (in percent) must exceed this value.
        show: when truthy, show the box mask and the contour mask of every
            accepted pair with ``show_image``.
        debug: when truthy, print both percentages for every pair.
        strict_intersection: require both thresholds to be exceeded; otherwise
            exceeding either one is enough.

    Returns:
        list of the accepted entries of ``bboxes`` (as given), without
        duplicates, in the order they were first accepted.
    """
    image = convert_PIL_to_CV(image)
    # Grayscale copy only serves as a template for the single-channel masks
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    intersected_bboxes = []

    for contour in contours:
        contour_mask = np.zeros_like(gray_image)
        cv2.drawContours(contour_mask, [contour], -1, (255), -1)
        contour_size = cv2.countNonZero(contour_mask)

        # Check intersection with each bbox
        for original_bbox in bboxes:
            is_dict_bbox = isinstance(original_bbox, dict)
            bbox = original_bbox['bbox'] if is_dict_bbox else original_bbox
            # Expected format: x, y, x+w, y+h. Coordinates are clamped at 0
            # because a negative value would be read by the slice below as an
            # offset from the far edge and select the wrong region.
            top_left_x, top_left_y, bottom_right_x, bottom_right_y = (max(int(coord), 0) for coord in bbox)

            bbox_mask = np.zeros_like(gray_image)
            bbox_mask[top_left_y:bottom_right_y, top_left_x:bottom_right_x] = 255
            bbox_size = cv2.countNonZero(bbox_mask)

            intersection_area = cv2.countNonZero(np.bitwise_and(contour_mask, bbox_mask))
            percentage_of_contour_covered = (intersection_area / contour_size) * 100 if contour_size != 0 else 0
            percentage_of_bbox_covered = (intersection_area / bbox_size) * 100 if bbox_size != 0 else 0

            contour_criteria_met = percentage_of_contour_covered > contour_covered_threshold
            bbox_criteria_met = percentage_of_bbox_covered > bbox_covered_threshold

            if strict_intersection:
                is_criterias_met = contour_criteria_met and bbox_criteria_met
            else:
                is_criterias_met = contour_criteria_met or bbox_criteria_met

            if is_criterias_met and original_bbox not in intersected_bboxes:
                if show:
                    show_image(bbox_mask, 'mask')
                    show_image(contour_mask, 'mask')
                intersected_bboxes.append(original_bbox)
            if debug:
                # Plain coordinate sequences carry no class label.
                class_label = original_bbox['class'] if is_dict_bbox else 'N/A'
                print(f"Class: {class_label}, Intersection: {percentage_of_contour_covered}%, and bbox: {percentage_of_bbox_covered}")

    return intersected_bboxes

### Grounding DINO

##### Installation

In [None]:
# os.chdir(PROJECT_PATH)

# import wget

# !git clone https://github.com/IDEA-Research/GroundingDINO.git
# %cd GroundingDINO/

# !mkdir weights
# %cd weights
# url = 'https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth'
# destination_path = 'groundingdino_swint_ogc.pth'

# # Download the file
# wget.download(url, destination_path)

# print("File downloaded successfully.")
# %cd ..   

# !pip install -q -e .

##### Loading The Model

In [None]:
import os
# Work from inside the GroundingDINO checkout: the config and weights paths
# passed to load_model are relative. os.path.join keeps this working on both
# Windows and POSIX instead of hard-coding a backslash.
os.chdir(os.path.join(PROJECT_PATH, 'GroundingDINO'))
from groundingdino.util.inference import load_model, predict, annotate,box_convert
import groundingdino.datasets.transforms as T
import cv2
import numpy as np
import torch

def convert_PIL_to_dino_format(image):
    """Prepare a PIL image for GroundingDINO inference.

    Args:
        image: PIL image; it is converted to RGB first.

    Returns:
        ``(image_source, image_transformed)``: the RGB image as a NumPy array
        and the output of the resize / to-tensor / normalize pipeline.
    """
    rgb_image = image.convert('RGB')
    preprocessing = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    source_array = np.asarray(rgb_image)
    # The transform takes (image, target); there is no target here.
    tensor_image, _ = preprocessing(rgb_image, None)
    return source_array, tensor_image

class GroundDino:
    """Thin wrapper around a GroundingDINO model for text-prompted detection."""

    def __init__(self, BOX_TRESHOLD = 0.35, TEXT_TRESHOLD = 0.25):
        """Load the SwinT-OGC model.

        The config and weights paths are relative, so the working directory
        must be the GroundingDINO checkout when this runs.

        Args:
            BOX_TRESHOLD: passed to ``predict`` as ``box_threshold``.
            TEXT_TRESHOLD: passed to ``predict`` as ``text_threshold``.
        """
        self.model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weights/groundingdino_swint_ogc.pth")
        self.BOX_TRESHOLD = BOX_TRESHOLD
        self.TEXT_TRESHOLD = TEXT_TRESHOLD

    def transfer_to_original_size(self, bounding_boxes, image_source):
        """Convert cxcywh boxes, scaled by the image size, to integer xyxy.

        Args:
            bounding_boxes: tensor of boxes in cxcywh format; it is multiplied
                by ``[w, h, w, h]`` of ``image_source``.
            image_source: image array of shape ``(h, w, channels)``.

        Returns:
            NumPy integer array with the boxes in xyxy format.
        """
        h, w, _ = image_source.shape
        scaled_boxes = bounding_boxes * torch.Tensor([w, h, w, h])
        return box_convert(boxes=scaled_boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy().astype(int)

    def any_bboxes_intersect(self, bboxes1, bboxes2):
        """Return True when any box of ``bboxes1`` overlaps any box of ``bboxes2``.

        Boxes are ``[x1, y1, x2, y2]``; boxes that merely touch count as
        overlapping.
        """
        def do_bboxes_intersect(bbox1, bbox2):
            return not (bbox1[2] < bbox2[0] or  # bbox1 is left of bbox2
                        bbox1[0] > bbox2[2] or  # bbox1 is right of bbox2
                        bbox1[3] < bbox2[1] or  # bbox1 is above bbox2
                        bbox1[1] > bbox2[3])    # bbox1 is below bbox2

        return any(
            do_bboxes_intersect(bbox1, bbox2)
            for bbox1 in bboxes1
            for bbox2 in bboxes2
        )

    def get_top_boxes(self, original_image, objects, bbox_score_threshold=0.3, show=False, masks=None):
        """Detect ``objects`` in a PIL image.

        Args:
            original_image: PIL image to run detection on.
            objects: text prompt passed to ``predict`` as the caption.
            bbox_score_threshold: a detection is added to the "best" list when
                its rounded score exceeds this value. The parameter used to be
                ignored in favour of a hard-coded 0.3, which is still the default.
            show: when truthy, display the annotated frame.
            masks: optional boxes in ``[x1, y1, x2, y2]`` format; when given
                and non-empty, detections that overlap none of them are skipped.

        Returns:
            ``(IDC_bboxes, best_IDC_bboxes, original_image, scores)``. The two
            lists hold dicts with keys 'bbox', 'score', 'id' and 'class';
            ``scores`` holds the raw scores of the kept detections.
        """
        masks = [] if masks is None else masks
        image_source, image = convert_PIL_to_dino_format(original_image)
        boxes, logits, phrases = predict(model = self.model, image = image, caption = objects, box_threshold = self.BOX_TRESHOLD, text_threshold = self.TEXT_TRESHOLD)
        if show:
            annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
            display(convert_CV_to_PIL(annotated_frame))
        IDC_bboxes = []
        best_IDC_bboxes = []
        scores = []
        for index, (box, score, phrase) in enumerate(zip(boxes, logits, phrases)):
            box = list(self.transfer_to_original_size(box, image_source))
            if len(masks) > 0 and not self.any_bboxes_intersect([box], masks):
                continue
            scores.append(score)
            IDC_box = {'bbox': box, 'score': round(score.item(), 2), 'id': f'GroundDino_{phrase}_{index}', 'class': phrase}
            IDC_bboxes.append(IDC_box)
            if IDC_box['score'] > bbox_score_threshold:
                best_IDC_bboxes.append(IDC_box)
        return IDC_bboxes, best_IDC_bboxes, original_image, scores

# Instantiate while the working directory is still the GroundingDINO checkout
# (the config and weights paths given to load_model are relative), then return
# to the project root.
dino = GroundDino()
os.chdir(PROJECT_PATH)

## Gemini

### Gemini Utils

In [None]:
from vertexai.preview.generative_models import (
    GenerationConfig,
    GenerativeModel,
    Image as PartImage
)

# Generation settings. Together with the model name they are also embedded in
# every cache key (see get_gemini_cache_id), so changing any of them makes
# previously cached responses unreachable.
GEMINI_MAX_TOKENS = 1000 
# NOTE(review): the name is misspelled ("TEMPRATURE"); renaming it requires
# updating every reference in the file.
GEMINI_TEMPRATURE = 0
# GEMINI_MODEL = "gemini-pro-vision"
GEMINI_MODEL = "gemini-2.0-flash"

# Default generation config used by get_geimini_response.
objects_list_generation_config = GenerationConfig(
    temperature=GEMINI_TEMPRATURE,
    top_p=1.0,
    top_k=32,
    candidate_count=1,
    max_output_tokens=GEMINI_MAX_TOKENS,
)

# Define project information
PROJECT_ID = "gen-lang-client-0642013381"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}

# Initialize Vertex AI and create the single shared model handle used by
# get_original_unparsed_gemini_response.
import vertexai
vertexai.init(project=PROJECT_ID, location=LOCATION)
multimodal_model = GenerativeModel(GEMINI_MODEL)

import http.client
import IPython.display

def get_gemini_cache_id(prompt):
    """Build the cache key for ``prompt``.

    The key is the prompt followed by the model name, the max-token setting
    and the temperature, joined with ``___``.
    """
    key_parts = (prompt, GEMINI_MODEL, GEMINI_MAX_TOKENS, GEMINI_TEMPRATURE)
    return "___".join(f"{part}" for part in key_parts)

def get_gemini_response_from_cache(prompt):
    """Look up the cached response for ``prompt``.

    Carriage returns are removed from the key before the lookup.

    Returns:
        The cached value, or None when the key is not in ``GEMINI_CACHE``.
    """
    cache_key = get_gemini_cache_id(prompt).replace('\r', '')
    return GEMINI_CACHE.get(cache_key)

def save_gemini_response_in_cache(prompt, response):
    """Store ``response`` for ``prompt`` and write the whole cache to disk.

    The key is normalised by removing carriage returns, exactly as
    ``get_gemini_response_from_cache`` does when it looks a prompt up.
    Without that, a prompt containing ``\\r`` was stored under a key the
    lookup could never produce, so it was re-requested on every call.

    Args:
        prompt: the string identifying the request.
        response: the value to cache.
    """
    cache_key = get_gemini_cache_id(prompt).replace('\r', '')
    GEMINI_CACHE.update({cache_key: response})
    dump_json(GEMINI_CACHE_PATH, GEMINI_CACHE)

def get_part_image(image):
    """Convert a PIL image into the image type the Gemini API expects.

    The image is serialised in its own format, or as PNG when it has none,
    and the bytes are wrapped with ``PartImage.from_bytes``.
    """
    buffer = io.BytesIO()
    save_format = 'PNG' if image.format is None else image.format
    image.save(buffer, format=save_format)
    return PartImage.from_bytes(buffer.getvalue())

def get_original_unparsed_gemini_response(responses, contents, debug=False, generation_config=None):
    """Stream a Gemini response and collect the text of every chunk.

    Args:
        responses: list the chunk texts are appended to (mutated in place).
        contents: request contents passed to ``generate_content``.
        debug: when truthy, print the collected texts at the end.
        generation_config: optional generation config; when None the
            argument is not passed to the model at all.

    Returns:
        The same ``responses`` list. A chunk whose text cannot be read is
        reported with ``print`` and skipped.
    """
    request_kwargs = {'stream': True}
    if generation_config is not None:
        request_kwargs['generation_config'] = generation_config

    for chunk in multimodal_model.generate_content(contents, **request_kwargs):
        try:
            # Only the text is kept; the chunk's other metadata is ignored for now.
            responses.append(chunk.text)
        except Exception as e:
            print('Error reading Gemini response, Error:', e)
            print('Contents', contents)
    if debug:
        print(responses)
    return responses

def get_geimini_response(contents, debug=False, generation_config=objects_list_generation_config):
    """Return Gemini's text answer for ``contents``, using the local cache.

    Args:
        contents: list of strings and/or PIL images making up the request.
            The list is not modified; images are converted on a copy.
            Previously the images were replaced in place, so passing the same
            list a second time failed while building the cache key.
        debug: when truthy, print the request contents.
        generation_config: generation config forwarded to the model.

    Returns:
        The concatenated, stripped response text. An empty string is returned
        when both request attempts fail.

    Raises:
        Exception: the response is not cached and ``LOCAL_RUN`` is set.
    """
    if debug:
        print(contents)

    # Build the cache key first; images contribute a compressed string form.
    instance_id = ''
    for obj in contents:
        if isinstance(obj, PIL.Image.Image):
            instance_id = instance_id + get_highly_compressed_image_string(obj)
        else:
            instance_id = instance_id + obj  # should be string

    cached_response = get_gemini_response_from_cache(instance_id)
    if cached_response is not None:
        return ''.join(cached_response).strip()

    if LOCAL_RUN:
        raise Exception('No cached Gemini response for this request and LOCAL_RUN is set, so the API is not called')

    # Images are only encoded once we know the API really has to be called.
    request_contents = [
        get_part_image(obj) if isinstance(obj, PIL.Image.Image) else obj
        for obj in contents
    ]
    if debug:
        print(request_contents)

    responses = []
    last_error = None
    for _attempt in range(2):  # one retry
        try:
            responses = get_original_unparsed_gemini_response([], request_contents, debug, generation_config=generation_config)
            break
        except Exception as e:
            last_error = e
    else:
        print('Gemini Error tryed two time!\n', last_error)
        responses = []

    # NOTE(review): an empty result after two failures is cached as well, so
    # later calls return '' from the cache without retrying - confirm intended.
    save_gemini_response_in_cache(instance_id, responses)
    return ''.join(responses).strip()

## NLTK utils

In [None]:
import nltk
import inflect
from nltk import pos_tag
from nltk.corpus import wordnet as wn
from nltk.corpus import wordnet
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from pattern.text.en import singularize, pluralize

# !python -m spacy download en_core_web_sm
# Load the spaCy language model once; extract_main_objects parses sentences with it.
import spacy
nlp = spacy.load("en_core_web_sm")

# Ensure necessary NLTK resources are downloaded.
# Presumably these back pos_tag, word_tokenize and the WordNet corpus /
# lemmatizer respectively - TODO confirm against the installed NLTK version.
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
nltk.download('wordnet')

def shortest_path_distance(word1, word2):
    """Return the smallest WordNet path distance between two words.

    Every synset of ``word1`` is compared with every synset of ``word2``.

    Returns:
        The minimum distance found, or None when no pair of synsets has a
        defined distance (including when either word has no synsets).
    """
    first_synsets = wn.synsets(word1)
    second_synsets = wn.synsets(word2)

    pair_distances = (
        first.shortest_path_distance(second)
        for first in first_synsets
        for second in second_synsets
    )
    # Pairs without a connecting path yield None and are ignored.
    return min((found for found in pair_distances if found is not None), default=None)

def get_related_synonyms(word, level=3):
    """Collect lemma names related to the first noun synset of ``word``.

    Args:
        word: word to look up in WordNet (noun synsets only).
        level: 1 uses the synset's own lemmas; 2 adds the lemmas of its direct
            hypernyms (more general terms); 3 also adds the lemmas of its
            direct hyponyms (more specific terms).

    Returns:
        list of unique lemma names, or an empty list when the word has no
        noun synset.
    """
    noun_synsets = wn.synsets(word, pos=wn.NOUN)
    if len(noun_synsets) < 1:
        return []

    primary = noun_synsets[0]
    sources = [primary]
    if level >= 2:
        sources.extend(primary.hypernyms())
    if level >= 3:
        sources.extend(primary.hyponyms())

    names = set()
    for related_synset in sources:
        for lemma in related_synset.lemmas():
            names.add(lemma.name())
    return list(names)

def convert_none(distance):
    """Map a missing distance (None) to 10; return any other value unchanged."""
    if distance is None:
        return 10
    return distance

def get_word_synonyms(word):
    """Return the related words that lie within WordNet distance 3 of ``word``.

    Candidates come from ``get_related_synonyms(word, level=3)``. A candidate
    with no defined distance is treated as distance 10 and therefore dropped.
    Underscores in the kept names are replaced with spaces.

    Returns:
        set of strings.
    """
    close_words = set()
    for candidate in get_related_synonyms(word, level=3):
        if convert_none(shortest_path_distance(word, candidate)) <= 3:
            close_words.add(candidate.replace('_', ' '))
    return close_words

def is_noun(word):
    """Tell whether ``word`` is tagged as a noun.

    The input is tokenized and POS-tagged; only the tag of the first token
    is checked.

    Args:
        word: string to classify.

    Returns:
        True when the first token's tag is NN, NNS, NNP or NNPS. False for
        input that yields no tokens (e.g. an empty string), which previously
        raised ``IndexError``.
    """
    tokens = word_tokenize(word)
    if not tokens:
        return False
    tagged = pos_tag(tokens)
    return tagged[0][1] in ("NN", "NNS", "NNP", "NNPS")

def filter_strings(arr):
    """Return the elements of ``arr`` that are ``str`` instances, in order."""
    return list(filter(lambda item: isinstance(item, str), arr))
    
# Initialize the WordNet lemmatizer once at module level; it is shared by
# get_text_nouns and get_expanded_singular_and_plural.
lemmatizer = WordNetLemmatizer()

def get_wordnet_pos(treebank_tag):
    """Convert a Penn Treebank tag to the matching WordNet POS constant.

    Tags starting with J, V, N and R map to adjective, verb, noun and adverb.
    Any other tag falls back to noun.
    """
    prefix_to_pos = (
        ('J', wordnet.ADJ),
        ('V', wordnet.VERB),
        ('N', wordnet.NOUN),
        ('R', wordnet.ADV),
    )
    for prefix, wordnet_pos in prefix_to_pos:
        if treebank_tag.startswith(prefix):
            return wordnet_pos
    return wordnet.NOUN  # default to noun if no specific mapping

def extract_main_objects(sentence):
    """Return the text of every token whose dependency label contains 'obj'.

    The sentence is parsed with the module-level spaCy pipeline ``nlp``.
    """
    main_objects = []
    for token in nlp(sentence):
        if 'obj' in token.dep_:
            main_objects.append(token.text)
    return main_objects

def get_expanded_singular_and_plural(sentence):
    """Expand a sentence into a set of singular / plural / lemmatized variants.

    Four candidates are produced:
      * the ``inflect`` singular of the whole sentence,
      * the ``inflect`` plural of the whole sentence,
      * the sentence with every token lemmatized by its POS tag,
      * the sentence with every token singularized by ``pattern``.

    Non-string candidates are discarded by ``filter_strings``. The per-token
    ``pattern`` plural is computed as well but is not part of the result.

    Returns:
        set of strings.
    """
    # Whole-sentence forms from inflect
    inflector = inflect.engine()
    inflect_singular = inflector.singular_noun(sentence)
    inflect_plural = inflector.plural(sentence)

    # Per-token forms from NLTK (lemmatizer) and pattern (singularize / pluralize)
    lemma_words, singular_words, plural_words = [], [], []
    for word, treebank_tag in nltk.pos_tag(nltk.word_tokenize(sentence)):
        lemma_words.append(lemmatizer.lemmatize(word, pos=get_wordnet_pos(treebank_tag)))
        singular_words.append(singularize(word))
        plural_words.append(pluralize(word))

    lemmatized_sentence = ' '.join(lemma_words)
    singular_sentence = ' '.join(singular_words)
    _plural_sentence = ' '.join(plural_words)  # built but not returned

    candidates = [inflect_singular, inflect_plural, lemmatized_sentence, singular_sentence]
    return set(filter_strings(candidates))

def concatenate_with_previous(words):
    """Return the first word followed by every adjacent pair joined together.

    For ``[w0, w1, w2]`` the result is ``[w0, w0 + w1, w1 + w2]``. An empty
    input yields an empty list.
    """
    if not words:
        return []
    # Pair each element with its successor and glue the two together
    joined_pairs = [previous + current for previous, current in zip(words, words[1:])]
    return [words[0]] + joined_pairs

def is_capital_letters(text):
    """Return True when every element of ``text`` starts with an uppercase character.

    ``text`` is iterated element by element: for a string this checks every
    character, for a list of strings the first character of every item. An
    empty element counts as a failure; an empty ``text`` yields True.
    """
    starts_upper = []
    for element in text:
        starts_upper.append(bool(element) and element[0].isupper())
    return all(starts_upper)

def get_text_nouns(text):
    """Return the lemmatized noun tokens found in ``text``.

    Steps:
    1. Replace '-' and '_' with spaces and drop every character that is
       neither alphanumeric nor a space.
    2. Tokenize; lowercase each token unless ``is_capital_letters`` accepts it.
    3. Keep tokens that ``is_noun`` accepts, excluding the literal 'None'.
    4. Lemmatize the remaining tokens as nouns.
    """
    spaced = text.replace('-', ' ').replace('_', ' ')
    cleaned = "".join(char for char in spaced if str.isalnum(char) or char == " ")
    tokens = [
        token if is_capital_letters(token) else token.lower()
        for token in word_tokenize(cleaned)
    ]
    noun_tokens = [token for token in tokens if is_noun(token) and token != 'None']
    return [lemmatizer.lemmatize(token, pos='n') for token in noun_tokens]

def get_text_nouns_synonyms(text):
    """Return the nouns of ``text`` (see ``get_text_nouns``) as a set.

    Despite the name, no synonyms are added here.
    """
    return set(get_text_nouns(text))

# Example usage: look up the related synonyms of a sample word and print them
word = "backpack"
related_synonyms = get_related_synonyms(word)
print(f"Related synonyms for '{word}':", related_synonyms)

# Magic Brush - Pipeline


In [None]:
# Load the MagicBrush dataset and keep handles to its "train" and "dev" splits
dataset_name = "osunlp/MagicBrush"
dataset = load_dataset(dataset_name)
MB_train = dataset["train"]
MB_dev = dataset["dev"]

## Enrich Difference Caption

#### Pipeline - Grounding Instruction Utils

In [None]:
def find_intersected_set(set1_details, set2_details, set3_details):
    """Return the name of a set that belongs to the only intersecting pair.

    Each ``*_details`` argument is a ``[name, set]`` pair. The three pairwise
    intersections (1-2, 2-3, 3-1) are examined; a name is returned only when
    exactly one of those pairs shares elements:
    - only 1-2 intersect -> name of set 1
    - only 2-3 intersect -> name of set 2
    - only 3-1 intersect -> name of set 3

    Returns None when no pair, or more than one pair, intersects.
    """
    first, second, third = set1_details[1], set2_details[1], set3_details[1]
    shares_12 = not first.isdisjoint(second)
    shares_23 = not second.isdisjoint(third)
    shares_31 = not third.isdisjoint(first)

    # A name qualifies when its pair is the single intersecting one
    is_only_pair = {
        set1_details[0]: shares_12 and not (shares_31 or shares_23),
        set2_details[0]: shares_23 and not (shares_31 or shares_12),
        set3_details[0]: shares_31 and not (shares_12 or shares_23),
    }
    return next((name for name, qualifies in is_only_pair.items() if qualifies), None)

def get_number_of_intersection_by_synonyms(first_words, second_words):
    """Count the words of ``first_words`` that match a word of ``second_words``.

    Both arguments are sets of words. Two words match when either one appears
    in the other's synonym set (``get_word_synonyms`` plus the word itself).
    Each word of ``first_words`` is counted at most once.
    """
    # Synonym set (including the word itself) for every word involved
    synonyms_by_word = {
        word: get_word_synonyms(word).union({word})
        for word in first_words.union(second_words)
    }
    matched_words = {
        first_word
        for first_word in first_words
        if any(
            first_word in synonyms_by_word[second_word] or second_word in synonyms_by_word[first_word]
            for second_word in second_words
        )
    }
    return len(matched_words)

def is_edit_object_not_detected(example, object):
    """Return True when no edited object was detected for the given side.

    Args:
        example: mapping holding the ``'{object}_object'`` entry.
        object: side prefix used to build the key (the pipeline passes
            'source' or 'target').

    The object counts as not detected when the entry is ``None`` or the
    string 'none' in any letter case.
    """
    detected_object = example[f'{object}_object']
    return detected_object is None or detected_object.lower() == 'none'

def is_instruction_intersect_with_image_description(example, object, edit_instruction_nouns):
    """Return True when an instruction noun appears in the detected object or selected caption.

    The nouns of ``example['{object}_object']`` and
    ``example['{object}_selected_description']`` are pooled. The instruction
    matches when the pool contains one of ``edit_instruction_nouns`` or one of
    their pairwise concatenations (instructions often split one word in two).

    NOTE(review): ``edit_instruction_nouns`` is a set, so the order in which
    nouns are paired for concatenation is not the instruction's word order.
    """
    object_nouns = get_text_nouns_synonyms(example[f'{object}_object'])
    caption_nouns = get_text_nouns_synonyms(example[f'{object}_selected_description'])
    detected_nouns = object_nouns.union(caption_nouns)

    concatenated_nouns = set(concatenate_with_previous(list(edit_instruction_nouns)))
    direct_match = not edit_instruction_nouns.isdisjoint(detected_nouns)
    concatenated_match = not concatenated_nouns.isdisjoint(detected_nouns)
    return direct_match or concatenated_match

def select_grounded_captions(example):
    """Pick, per side, the image caption that best matches the edit instruction.

    For each of 'source' and 'target' three flags are reset on ``example``:
    ``is_instruction_{side}_object_found``, ``is_two_captions_{side}_object_found``
    and ``is_{side}_instruction_different``. A side is then left untouched when
    no edited object was detected for it, or when the currently selected
    caption already shares a noun with the instruction.

    Otherwise ``is_{side}_instruction_different`` is set and
    ``{side}_selected_description`` is replaced by:
    1. the caption (mask bbox, padded mask bbox, full image - in that priority
       on ties) sharing the most nouns/synonyms with the instruction, or
    2. if none shares any, the caption returned by ``find_intersected_set``
       over the main objects of the three captions, when there is one.

    Returns the (mutated) ``example``.
    """
    for side in ['source', 'target']:
        example[f'is_instruction_{side}_object_found'] = False
        example[f'is_two_captions_{side}_object_found'] = False
        example[f'is_{side}_instruction_different'] = False

        if is_edit_object_not_detected(example, side):
            continue

        # Nothing to do when the default caption already intersects with the instruction
        edit_instruction_nouns = set(get_text_nouns(example['instruction']))
        if is_instruction_intersect_with_image_description(example, side, edit_instruction_nouns):
            continue

        example[f'is_{side}_instruction_different'] = True

        # Candidate captions in priority order; on ties the earliest entry wins
        caption_keys = [
            f'{side}_mask_bbox_description',
            f'{side}_mask_bbox_padding_description',
            f'{side}_description',
        ]
        intersection_counts = [
            get_number_of_intersection_by_synonyms(edit_instruction_nouns, set(get_text_nouns(example[key])))
            for key in caption_keys
        ]
        maximum_intersections = max(intersection_counts)
        if maximum_intersections > 0:
            best_key = caption_keys[intersection_counts.index(maximum_intersections)]
            example[f'{side}_selected_description'] = example[best_key]
            example[f'is_instruction_{side}_object_found'] = True
            continue

        # No caption mentions an instruction noun: check whether 2 of the 3 captions
        # share the same main objects and choose from them.
        # Bound before the try so the handler below can always print it.
        intersected_set_caption = None
        try:
            intersected_set_caption = find_intersected_set(
                *[[key, set(extract_main_objects(example[key]))] for key in caption_keys]
            )
            if intersected_set_caption is not None:
                example[f'{side}_selected_description'] = example[intersected_set_caption]
                example[f'is_two_captions_{side}_object_found'] = True
                print(example['instruction'], f'is_two_captions_{side}_object_found')
        except Exception as e:
            print(intersected_set_caption)
            print(example['instruction'], 'failed verify', e)

    return example


#### Pipeline Methods

In [None]:
# Prompt template filled in by generate_edit_metadata; read once when this cell runs.
# The context manager closes the file handle as soon as the text is loaded.
with open(PROJECT_PATH + '//prompts/generate_difference_details.txt', "r") as prompt_file:
    generate_difference_details_prompt = prompt_file.read()

def initalize_valid(example):
    """Mark ``example`` as valid and return the stored flag."""
    example['valid'] = True
    stored_flag = example['valid']
    return stored_flag

def set_valid(example, is_valid):
    """Combine the stored 'valid' flag with ``is_valid`` and return the result.

    The flag can only stay truthy when both the stored value and ``is_valid``
    are truthy; once it is falsy it stays falsy.
    """
    combined = example['valid'] and is_valid
    example['valid'] = combined
    return example['valid']

def get_valid(example):
    """Return the 'valid' flag stored on ``example``."""
    is_valid = example['valid']
    return is_valid

def is_action_invalid(action):
    """Return True when ``action`` is missing.

    An action is missing when it is ``None`` or the string 'none' in any
    letter case.
    """
    if action is None:
        return True
    # The method must be called: comparing the bound method itself to a
    # string is always False.
    return action.lower() == 'none'

def is_response_action_invalid(response):
    """Return whether the action on the first line of ``response`` is invalid.

    The first line is expected to contain 'Edit Action Type:'; the text after
    that label is stripped and passed to ``is_action_invalid``.

    Raises:
        IndexError: when the first line does not contain the label.
    """
    first_line = response.split('\n')[0]
    action = first_line.split('Edit Action Type:')[1].strip()
    return is_action_invalid(action)
        
def get_source_and_target_descriptions(source_image, target_image, size, debug=False):
    """Ask Gemini to describe the source and target images, both resized to ``size``.

    Returns a ``(source_description, target_description)`` tuple; the source
    image is described first.
    """
    describe_prompt = '\nPlease describe the image.\n'
    source_description, target_description = (
        get_geimini_response([image.resize(size), describe_prompt], debug)
        for image in (source_image, target_image)
    )
    return source_description, target_description

def get_source_and_target_mask_bbox_description(edit_instruction, source_mask_bbox, source_mask_bbox_padding, target_mask_bbox, target_mask_bbox_padding, source_image, targe_image, debug=False):
    """Describe the mask-bbox crops, the padded crops and the full images.

    Each source/target pair is first described at the source image's size. If
    either description comes back as '', a message is printed and the pair is
    described again at the target image's size.

    Returns a list in this order: source bbox, target bbox, source full image,
    target full image, source padded bbox, target padded bbox.
    """
    def describe_pair(source, target, failure_message):
        # One attempt at the source size, one retry at the target size
        source_text, target_text = get_source_and_target_descriptions(source, target, source.size, debug)
        if '' in [source_text, target_text]:
            print(failure_message, edit_instruction)
            source_text, target_text = get_source_and_target_descriptions(source, target, target.size, debug)
        return source_text, target_text

    bbox_failure = 'Faild to get bboxes descriptions, changing size. Edit instruction:'
    full_image_failure = 'Faild to get full image descriptions, changing size. Edit instruction:'

    source_bbox_text, target_bbox_text = describe_pair(source_mask_bbox, target_mask_bbox, bbox_failure)
    source_padding_text, target_padding_text = describe_pair(source_mask_bbox_padding, target_mask_bbox_padding, bbox_failure)
    source_image_text, target_image_text = describe_pair(source_image, targe_image, full_image_failure)

    return [
        source_bbox_text,
        target_bbox_text,
        source_image_text,
        target_image_text,
        source_padding_text,
        target_padding_text,
    ]

def parse_generate_difference_detail_prompt_response(response):
    """Parse the eight labelled lines of the difference-details response.

    Line ``i`` of ``response`` must contain the ``i``-th label below; the text
    after the label is stripped and stored under the matching key.

    The 'Edit Action Type:' label is also parsed by
    ``is_response_action_invalid`` - keep the two in sync.

    Raises:
        IndexError: when a line is missing or does not contain its label.
    """
    labelled_fields = (
        ('action', 'Edit Action Type:'),
        ('caption', 'Short Difference Caption:'),
        ('revised_instruction', 'Short Edit Description:'),
        ('source_object', 'Source Object:'),
        ('target_object', 'Target Object:'),
        ('edit_explenation', 'Edit Explanation:'),
        ('extensive_revised_instruction', 'Extensive Edit Description:'),
        ('extensive_caption', 'Extensive Difference Caption:'),
    )
    lines = response.split('\n')
    return {
        key: lines[line_number].split(label)[1].strip()
        for line_number, (key, label) in enumerate(labelled_fields)
    }

def construct_description_dict(descriptions, selected_source_description, selected_target_description, use_bboxes_crops = False):
    """Build the description fields of an example.

    ``descriptions`` is indexed as: 0/1 source/target mask-bbox captions,
    2/3 source/target full-image captions, 4/5 source/target padded mask-bbox
    captions. The selected captions and the ``use_bbox_crop`` flag are stored
    alongside them.
    """
    return dict(
        source_mask_bbox_description=descriptions[0],
        target_mask_bbox_description=descriptions[1],
        source_description=descriptions[2],
        target_description=descriptions[3],
        source_selected_description=selected_source_description,
        target_selected_description=selected_target_description,
        source_mask_bbox_padding_description=descriptions[4],
        target_mask_bbox_padding_description=descriptions[5],
        use_bbox_crop=use_bboxes_crops,
    )

def generate_edit_metadata(edit_instruction, source_description, target_description):
    """Fill the difference-details prompt template and return the model's reply.

    The template's positional placeholders receive, in order, the edit
    instruction, the source description and the target description.
    """
    filled_prompt = generate_difference_details_prompt.format(
        edit_instruction,
        source_description,
        target_description,
    )
    return get_chatgpt_4_prediction(filled_prompt)

def get_difference_details(example, debug=False):
    """Caption the source/target images and ask the model for the edit metadata.

    Steps:
    1. Caption the mask-bbox crops, the padded crops and the full images.
    2. Use the bbox-crop captions when both are non-empty and the mask has a
       single bbox; otherwise use the full-image captions.
    3. If the chosen captions are unusable, mark the example invalid (via
       ``set_valid``) and return only the description fields.
    4. Query the model. When the answer's action is invalid, or its first line
       cannot be parsed, retry with the full-image captions and then, if bbox
       crops were in use, with the padded-crop captions.
    5. Parse the final answer.

    Returns:
        dict: parsed response fields merged with the description fields, or an
        empty dict when the final answer cannot be parsed.
    """
    def _has_invalid_action(candidate_response):
        # An unparsable first line is handled like an invalid action, so the
        # fallbacks still run and the final parse below reports the failure
        # instead of crashing here.
        try:
            return is_response_action_invalid(candidate_response)
        except (IndexError, AttributeError):
            return True

    img_id, edit_instruction, multiple_bboxes = example['img_id'], example['instruction'], example['multiple_bboxes']

    # Extract descriptions for the crops, padded crops and full images
    descriptions = get_source_and_target_mask_bbox_description(
        edit_instruction,
        example['source_masked_bbox'], example['source_masked_bbox_padding'],
        example['target_masked_bbox'], example['target_masked_bbox_padding'],
        example['source_img'], example['target_img'],
        debug,
    )

    use_bboxes_crops = '' not in [descriptions[0], descriptions[1]] and not multiple_bboxes
    if not use_bboxes_crops:
        if multiple_bboxes:
            print('Multiple masking bboxes, will use full images. Edit instruction:', edit_instruction)
        else:
            print('Faild to get bboxes descriptions from masks, will use full images. Edit instruction:', edit_instruction)

    if use_bboxes_crops:
        selected_source_description, selected_target_description = descriptions[0], descriptions[1]
    else:
        selected_source_description, selected_target_description = descriptions[2], descriptions[3]

    # Invalid only when the full-image captions were chosen and one of them is ''
    is_valid = use_bboxes_crops or '' not in [descriptions[2], descriptions[3]]
    set_valid(example, is_valid)
    if not is_valid:
        if debug:
            print('Gemini could not extract a source and target description. Skipping. Instance: ', img_id, edit_instruction)
        return construct_description_dict(descriptions, selected_source_description, selected_target_description)

    response = generate_edit_metadata(edit_instruction, selected_source_description, selected_target_description)
    if _has_invalid_action(response):
        # Retry with the full-image captions
        selected_source_description, selected_target_description = descriptions[2], descriptions[3]
        response = generate_edit_metadata(edit_instruction, selected_source_description, selected_target_description)

        # Retry with the padded-crop captions if the full images fail too
        if _has_invalid_action(response) and use_bboxes_crops:
            selected_source_description, selected_target_description = descriptions[4], descriptions[5]
            response = generate_edit_metadata(edit_instruction, selected_source_description, selected_target_description)

    if debug:
        print(response)
    try:
        descriptions_dict = construct_description_dict(descriptions, selected_source_description, selected_target_description, use_bboxes_crops=use_bboxes_crops)
        return {**parse_generate_difference_detail_prompt_response(response), **descriptions_dict}
    except Exception as e:
        print('Parsing Exception:', e)
        return dict()
    
def set_difference_details(example, details):
    """Copy the difference details onto ``example`` and return it.

    - ``original_instruction`` is set to the example's current instruction.
    - ``action`` falls back to the string "None", ``simple_caption`` to "".
    - ``extensive_caption`` falls back to the short caption, then to "".
    - Every other field is copied from ``details`` and is ``None`` when absent.
    """
    short_caption = details.get('caption')
    example['original_instruction'] = example['instruction']
    example['action'] = details.get('action') or "None"
    example['simple_caption'] = short_caption or ""
    example['extensive_caption'] = details.get('extensive_caption') or short_caption or ""

    copied_fields = (
        'source_object', 'target_object',
        'source_selected_description', 'target_selected_description',
        'edit_explenation', 'revised_instruction', 'extensive_revised_instruction',
        'source_mask_bbox_description', 'target_mask_bbox_description',
        'source_mask_bbox_padding_description', 'target_mask_bbox_padding_description',
        'source_description', 'target_description',
        'use_bbox_crop',
    )
    for field_name in copied_fields:
        example[field_name] = details.get(field_name)
    return example
    
def add_defualt_edit_metadata(example):
    """Compute the difference details for ``example``, store them on it and return it."""
    set_difference_details(example, get_difference_details(example))
    return example

def expand_bboxes(bboxes, image_width, image_height, apply_padding=True):
    """Return copies of ``bboxes`` with their boxes resized around the center.

    Each entry is a mapping whose 'bbox' is ``[x1, y1, x2, y2]``. With
    ``apply_padding`` the width and height are doubled, but never made smaller
    than 15% of the image dimension; without it the size is kept. The new
    corners are truncated to int and clamped to the image boundaries.

    The input entries are not modified: every result is a shallow copy with a
    replaced 'bbox'.
    """
    def clamp(coordinate, upper_bound):
        # Truncate to int and keep the value within [0, upper_bound]
        return max(0, min(int(coordinate), upper_bound))

    min_width = image_width * 0.15
    min_height = image_height * 0.15

    adjusted_bboxes = []
    for bbox_obj in bboxes:
        bbox = bbox_obj['bbox']
        center_x = (bbox[0] + bbox[2]) / 2
        center_y = (bbox[1] + bbox[3]) / 2
        width = bbox[2] - bbox[0]
        height = bbox[3] - bbox[1]

        # Double the size (respecting the minimum) only when padding is requested
        if apply_padding:
            new_width = max(width * 2, min_width)
            new_height = max(height * 2, min_height)
        else:
            new_width, new_height = width, height

        # Clamped corners, ordered so the smaller coordinate comes first
        left, right = sorted([
            clamp(center_x - new_width / 2, image_width),
            clamp(center_x + new_width / 2, image_width),
        ])
        top, bottom = sorted([
            clamp(center_y - new_height / 2, image_height),
            clamp(center_y + new_height / 2, image_height),
        ])

        resized_obj = bbox_obj.copy()
        resized_obj['bbox'] = [left, top, right, bottom]
        adjusted_bboxes.append(resized_obj)

    return adjusted_bboxes

def extract_bbox_details(bbox):
    """Return the four 'bbox' values of ``bbox`` followed by its 'id', as a flat list."""
    first, second, third, fourth = bbox['bbox']
    bbox_id = bbox['id']
    return [first, second, third, fourth, bbox_id]

def crop_images_and_add_masking_bbox(example, debug=False):
    """Crop the source and target images to the mask's bounding box.

    For each of 'source' and 'target' the mask image is resized to the image's
    size and its bounding boxes are extracted with ``get_masking_bbox``:
    - exactly one bbox: the image is cropped to the bbox and to its padded
      version; the crops and their JSON-encoded coordinates are stored under
      ``{object}_masked_bbox[_padding]`` and ``{object}_masked_bbox[_padding]_str``.
    - more than one bbox: the full image is stored as both crops and
      ``example['multiple_bboxes']`` is set to True.
    - no bbox, or any error: the full image is stored as both crops and the
      example is marked invalid.

    Returns the (mutated) ``example``; validity is recorded via ``set_valid``.
    """
    is_valid = True
    example['multiple_bboxes'] = False
    for object in ['source', 'target']:
        try:
            size = example[f'{object}_img'].size
            original_bboxes = get_masking_bbox(example['mask_img'].resize(size))
            # Same boxes normalized without padding, and enlarged with padding
            bboxes_original = expand_bboxes(original_bboxes, size[0], size[1], apply_padding=False)
            bboxes_with_padding = expand_bboxes(original_bboxes, size[0], size[1], apply_padding=True)
            is_valid = is_valid and len(bboxes_original) > 0
            if len(bboxes_original) > 1:
                # Several masked regions: keep the full image for both crops
                example[f'{object}_masked_bbox'] = example[f'{object}_masked_bbox_padding'] = example[f'{object}_img']
                example[f'{object}_masked_bbox_str'] = list(map(lambda box: json.dumps(extract_bbox_details(box)), original_bboxes))
                example['multiple_bboxes'] = True
            else:
                # With no bbox at all, bbox[0] below raises IndexError and the
                # except branch falls back to the full image.
                for bboxes_details in [['', bboxes_original], ['_padding', bboxes_with_padding]]:
                    name, bbox = bboxes_details
                    # w and h hold the box's x2 / y2 corner (see the slice below), t its id
                    x, y, w, h, t = extract_bbox_details(bbox[0])
                    cropped_image = convert_PIL_to_CV(example[f'{object}_img'].resize(size))[y:h, x:w]
                    example[f'{object}_masked_bbox{name}'] = convert_CV_to_PIL(cropped_image)
                    example[f'{object}_masked_bbox{name}_str'] = [json.dumps([x, y, w, h])]
                    # NOTE(review): this always displays the unpadded crop, also on the '_padding' pass
                    debug and display(example[f'{object}_masked_bbox'])
        except Exception as e:
            print(f'Error when getting {object} bboxes, using full image, marking instance as not valid. Error:', e)
            is_valid = False
            example[f'{object}_masked_bbox'] = example[f'{object}_img']
            example[f'{object}_masked_bbox_padding'] = example[f'{object}_img']
    
    set_valid(example, is_valid)
    return example
   
def run_pipeline_with_grounded_instructions(example):
    """Regenerate the edit metadata when a better-grounded caption was selected.

    ``select_grounded_captions`` is applied first. If it flagged the source or
    the target caption as different from the instruction, the metadata is
    generated again from the selected captions and stored on ``example``
    unless the new action is invalid, in which case a message is printed.

    Returns the (mutated) ``example``.
    """
    select_grounded_captions(example)

    captions_changed = example['is_source_instruction_different'] or example['is_target_instruction_different']
    if not captions_changed:
        return example

    response = generate_edit_metadata(
        example['instruction'],
        example['source_selected_description'],
        example['target_selected_description'],
    )
    parsed_response = parse_generate_difference_detail_prompt_response(response)
    if is_action_invalid(parsed_response['action']):
        print('None', example['instruction'])
    else:
        # Existing example fields are kept; parsed fields take precedence
        set_difference_details(example, {**example, **parsed_response})
    return example


#### Pipeline

In [None]:
def enrich_dataset(example, debug=False):
    """Run the full enrichment pipeline on one example and return it.

    Steps: reset the validity flag, crop the images to the mask bbox, add the
    default edit metadata (or empty metadata when the example is invalid), then
    re-run the metadata generation with grounded captions.

    ``debug`` is accepted but not used by this function.
    """
    initalize_valid(example)
    example = crop_images_and_add_masking_bbox(example)  # fill none values
    if get_valid(example):
        example = add_defualt_edit_metadata(example)
    else:
        example = set_difference_details(example, {})
    return run_pipeline_with_grounded_instructions(example)

## Enrich Is Artifacts

### Artifacts intersected with masking

In [None]:
import numpy as np
import cv2

def do_bounding_boxes_intersect(rect1s, rect2s):
    """Return True when any rectangle of ``rect1s`` overlaps any of ``rect2s``.

    Rectangles are ``(x1, y1, x2, y2)`` with (x1, y1) the top-left and
    (x2, y2) the bottom-right corner. Rectangles that only touch at an edge or
    corner count as intersecting.
    """
    for first in rect1s:
        for second in rect2s:
            left_a, top_a, right_a, bottom_a = first
            left_b, top_b, right_b, bottom_b = second

            # Overlap means neither rectangle lies fully left of / above the other
            overlaps = (
                not (right_a < left_b)
                and not (right_b < left_a)
                and not (bottom_a < top_b)
                and not (bottom_b < top_a)
            )
            if overlaps:
                return True
    return False

def pred_mask_to_contours(pred_mask):
    """Return the external contours of ``pred_mask``.

    The mask is cast to uint8 and scaled by 255 to form a binary image, then
    passed to ``cv2.findContours`` (external contours only, simple chain
    approximation).
    """
    binary_image = np.uint8(pred_mask) * 255
    external_contours, _hierarchy = cv2.findContours(
        binary_image,
        cv2.RETR_EXTERNAL,
        cv2.CHAIN_APPROX_SIMPLE,
    )
    return external_contours

def get_contour_area(contour, orignal_shape, target_shape):
    """Rasterize ``contour`` and return ``(area, mask)``.

    The contour is drawn filled with value 1 on a zeroed uint8 mask of
    ``orignal_shape``. When the shapes differ the mask is resized
    (nearest-neighbour) to ``target_shape``. ``area`` is the sum of the
    returned mask.
    """
    filled_mask = np.zeros(orignal_shape, np.uint8)
    cv2.drawContours(filled_mask, [contour], -1, (1), thickness=cv2.FILLED)
    if orignal_shape != target_shape:
        # cv2.resize expects the size as (width, height)
        target_size = (target_shape[1], target_shape[0])
        filled_mask = cv2.resize(filled_mask, target_size, interpolation=cv2.INTER_NEAREST)
    return np.sum(filled_mask), filled_mask

def calculate_intersection_area(masking_contour, masking_shape, segmentation_contours, segmentation_shape):
    """Measure how much the masking contour overlaps the segmentation contours.

    The masking contour is rasterized (and resized to ``segmentation_shape``
    when needed). Each segmentation contour is rasterized on its own, and its
    overlap with the masking area and its own area are accumulated.

    Returns:
        tuple: ``(masked_pct, segmented_pct)`` - the accumulated overlap as a
        percentage of the masking area and of the accumulated segmentation
        area. A percentage is 0 when its denominator is 0.
    """
    masked_area, masking_mask = get_contour_area(masking_contour, masking_shape, segmentation_shape)

    # Scratch mask, cleared and redrawn for every segmentation contour
    contour_mask = np.zeros(segmentation_shape, np.uint8)
    intersection_area = 0
    total_pred_contours_area = 0
    for contour in segmentation_contours:
        contour_mask.fill(0)
        cv2.drawContours(contour_mask, [contour], -1, (1), thickness=cv2.FILLED)
        intersection_area += np.sum(np.logical_and(masking_mask, contour_mask))
        total_pred_contours_area += np.sum(contour_mask)

    if masked_area:
        masked_area_intersection_percentage = (intersection_area / masked_area) * 100
    else:
        masked_area_intersection_percentage = 0
    if total_pred_contours_area:
        segmented_area_intersection_percentage = (intersection_area / total_pred_contours_area) * 100
    else:
        segmented_area_intersection_percentage = 0
    return masked_area_intersection_percentage, segmented_area_intersection_percentage

def get_intersecting_detections_by_class(bbox_details, masking_contour, masking_shape):
    """Group the detections that partially overlap the masking contour by class.

    Args:
        bbox_details: iterable of detection dicts with at least 'pred_masks'
            (an array exposing ``.shape``) and 'class'.
        masking_contour: contour of the user-marked mask.
        masking_shape: shape used to rasterize ``masking_contour``.

    A detection is kept when its overlap covers more than 0% and less than 40%
    of the masking area. Kept detections are updated in place with
    'masked_area_intersection_percentage' and
    'segmented_area_intersection_percentage'.

    Returns:
        dict: class -> list of the kept detection dicts.
    """
    intersecting_detections_by_class = {}

    for detail in bbox_details:
        pred_mask = detail['pred_masks']
        # The predicted mask's own shape is used to rasterize its contours
        shape = pred_mask.shape
        segmentation_contours = pred_mask_to_contours(pred_mask)

        masked_area_intersection_percentage, segmented_area_intersection_percentage = calculate_intersection_area(
            masking_contour, masking_shape, segmentation_contours, shape
        )
        # Keep only detections that overlap the mask partially
        if not 0 < masked_area_intersection_percentage < 40:
            continue

        detail.update({
            'masked_area_intersection_percentage': masked_area_intersection_percentage,
            'segmented_area_intersection_percentage': segmented_area_intersection_percentage,
        })
        intersecting_detections_by_class.setdefault(detail['class'], []).append(detail)
    return intersecting_detections_by_class

def get_pred_mask_size(pred_masks):
    """Return the sum of ``np.sum`` over every element of ``pred_masks`` (0 when empty)."""
    return sum(np.sum(pred_mask) for pred_mask in pred_masks)

def filter_significant_differences(source_intersecting_detections_by_class, target_intersecting_detections_by_class, difference_threshold):
    """Compare same-class detections of the source and target and report large score changes.

    Only classes present in both mappings are compared. A source/target pair
    of detections is reported when:
    - their segmentation bounding boxes intersect,
    - the absolute score difference is at least ``difference_threshold``,
    - the affected area is not small (both mask overlaps <= 2.4% together with
      at least one segmentation overlap < 2.4% counts as small), and
    - the object is not almost entirely inside the mask (both segmentation
      overlaps, rounded to 2 decimals, >= 97%).

    Returns a list of dicts describing each reported pair (scores and score
    difference are expressed in percent).
    """
    significant_differences = []

    # Walk over every class seen on either side; one-sided classes are skipped below
    all_classes = set(source_intersecting_detections_by_class.keys()) | set(target_intersecting_detections_by_class.keys())
    for cls in all_classes:
        source_details = source_intersecting_detections_by_class.get(cls)
        target_details = target_intersecting_detections_by_class.get(cls)
        if source_details is None or target_details is None:
            continue

        for detail1 in source_details:
            # One bbox per contour of the segmentation
            detail1_bbox = get_segmentation_as_bbox(detail1['pred_masks'])
            for detail2 in target_details:
                detail2_bbox = get_segmentation_as_bbox(detail2['pred_masks'])

                if not do_bounding_boxes_intersect(detail1_bbox, detail2_bbox):
                    continue
                if not (detail1 and detail2):
                    continue

                score_diff = abs(detail1['score'] - detail2['score'])

                # Small areas are not treated as artifacts (low probability + can't see them)
                small_effected_area = (detail1['masked_area_intersection_percentage'] <= 2.4 and detail2['masked_area_intersection_percentage'] <= 2.4) and (detail1['segmented_area_intersection_percentage'] < 2.4 or detail2['segmented_area_intersection_percentage'] < 2.4)

                details1_segmented_pct = round(detail1['segmented_area_intersection_percentage'], 2)
                details2_segmented_pct = round(detail2['segmented_area_intersection_percentage'], 2)
                is_object_completley_inside_masking_area = (97 <= details1_segmented_pct) and (97 <= details2_segmented_pct)

                is_significant = (
                    score_diff >= difference_threshold
                    and not small_effected_area
                    and not is_object_completley_inside_masking_area
                )
                if not is_significant:
                    continue

                significant_differences.append({
                    'class': cls,
                    'score_difference': round(score_diff * 100, 3),
                    'detail1_bbox': detail1_bbox,
                    'details1': round(detail1['score'] * 100, 3),
                    'details1_masked_area_interscetion_precentage': round(detail1['masked_area_intersection_percentage'], 2),
                    'details1_segmented_area_interscetion_precentage': details1_segmented_pct,
                    'details1_size': round(get_pred_mask_size(detail1['pred_masks']), 0),
                    'detail2_bbox': detail2_bbox,
                    'details2': round(detail2['score'] * 100, 3),
                    'details2_masked_area_interscetion_precentage': round(detail2['masked_area_intersection_percentage'], 2),
                    'details2_segmented_area_interscetion_precentage': details2_segmented_pct,
                    'details2_size': round(get_pred_mask_size(detail2['pred_masks']), 0),
                })

    return significant_differences

def add_artifacts_details(example, show=False):
    """Detect objects that changed around the mask and store them on ``example``.

    For each detector and for each of two sizes (the mask image's size and the
    source image's size) the source and target images are run through the
    detector. For every masking contour, the detections that partially overlap
    it are collected per class for both images and compared with
    ``filter_significant_differences`` (threshold 0.04). Each reported
    difference is appended, JSON-encoded, to ``example['artifacts_details']``.

    When the overlap computation fails for a contour, the example is marked
    invalid via ``set_valid`` and that contour is skipped.

    NOTE(review): ``selected_size`` looks like a PIL ``(width, height)`` size
    and is passed on as the shape used to rasterize the masking contour -
    confirm the axis order for non-square images.

    Returns the (mutated) ``example``.
    """
    example['artifacts_details'] = []
    for detctor in [detic_utils]:
        selected_sizes = [example['mask_img'].size, example['source_img'].size]
        for selected_size in selected_sizes:
            # Detect objects in the resized source and target images
            _, source_results = detctor.get_bounding_boxes(example['source_img'], should_display=show, target_size=selected_size)
            _, target_results = detctor.get_bounding_boxes(example['target_img'], should_display=show, target_size=selected_size)

            # Get the contours from the mask image and optionally display the mask
            masking_contours, image_with_polygons = get_masking_contours(example['mask_img'])
            for masking_contour in masking_contours:
                if show:
                    display(example['mask_img'].resize((500, 500)))
                try:
                    # Per class, the detections that partially overlap this contour
                    source_intersected_dections_by_class = get_intersecting_detections_by_class(source_results, masking_contour, selected_size)
                    target_intersected_dections_by_class = get_intersecting_detections_by_class(target_results, masking_contour, selected_size)
                except Exception as e:
                    # Best effort: report the failure, mark the example invalid, try the next contour
                    print(example['instruction'], 'failed to extract contour', 'Error:', e)
                    set_valid(example, False)
                    continue

                # Identify significant differences between the source and target
                artifacts_details = filter_significant_differences(source_intersected_dections_by_class, target_intersected_dections_by_class, difference_threshold=0.04)
                # Serialize the differences and add them to the example
                example['artifacts_details'] = example['artifacts_details'] + [
                    json.dumps(artifact, cls=NumpyEncoder) for artifact in artifacts_details
                ]

    return example

### Artifacts in masking

In [None]:
def get_classes_list(bboxes, replace_underscore=True):
    """Collect the distinct class names found in a list of detections.

    Args:
        bboxes: Iterable of mappings that each expose a ``'class'`` entry.
        replace_underscore: When truthy, underscores and hyphens in every
            class name are replaced by spaces; otherwise names are kept
            verbatim.

    Returns:
        A set with one entry per distinct (possibly normalised) class name.
    """
    if not replace_underscore:
        return {detection['class'] for detection in bboxes}
    return {
        detection['class'].replace('_', ' ').replace('-', ' ')
        for detection in bboxes
    }

def get_all_bboxes(image, object_classes):
    """Detect boxes for the given classes and strip DALL-E watermark hits.

    Args:
        image: Image handed unchanged to ``get_top_boxes``.
        object_classes: Sized collection of class names to look for.

    Returns:
        A pair ``(all_boxes, boxes)``: the first two values returned by
        ``get_top_boxes`` (called with ``bbox_score_threshold=0.2``), each
        passed through ``filter_dall_e_watermark``. When ``object_classes``
        is empty no detection is run and ``([], [])`` is returned.
    """
    if len(object_classes) > 0:
        every_box, kept_boxes, _annotated_image, _scores = get_top_boxes(
            image, list(object_classes), bbox_score_threshold=0.2
        )
        return filter_dall_e_watermark(every_box), filter_dall_e_watermark(kept_boxes)
    return [], []

def get_bbox_classes(bboxes, replace_underscore=False):
    """Return the set of class names carried by ``bboxes``.

    Args:
        bboxes: Iterable of mappings that each expose a ``'class'`` entry.
        replace_underscore: When truthy, underscores in each name are
            replaced by spaces (hyphens are left untouched).

    Returns:
        A set of class names.
    """
    names = set()
    for box in bboxes:
        label = box['class']
        names.add(label.replace('_', ' ') if replace_underscore else label)
    return names

def filter_out_source_object(source_object, objects_bboxes, debug=False):
    """Drop the detections whose class refers to ``source_object``.

    A detected class is treated as the source object when any of these holds:

    1. the nouns of the class name (``get_text_nouns``) overlap the nouns of
       ``source_object``;
    2. a word returned by ``get_word_synonyms`` for one of the source nouns
       appears among the class nouns;
    3. a word returned by ``get_word_synonyms`` for one of the class nouns
       equals a source noun or the full ``source_object`` string.

    Args:
        source_object: Text describing the object to exclude.
        objects_bboxes: List of detections, each with a ``'class'`` entry.
        debug: When truthy, print which rule matched each excluded class.

    Returns:
        The detections from ``objects_bboxes`` whose class did not match
        the source object.
    """
    # The source nouns do not depend on the class being examined, so they are
    # computed once instead of once per detected class.
    source_object_noun_set = set(get_text_nouns(source_object))
    non_source_object_classes = set()
    objects_bboxes_classes = get_bbox_classes(objects_bboxes, replace_underscore=True)
    for bbox_class in objects_bboxes_classes:
        bbox_noun_set = set(get_text_nouns(bbox_class))
        similar_object = len(bbox_noun_set.intersection(source_object_noun_set)) > 0
        if similar_object:
            debug and print('No need to use synonyms', bbox_noun_set, 'noun: ' + bbox_class, 'source object noun:', source_object_noun_set)
        if not similar_object:
            for word in source_object_noun_set:
                synonyms = get_word_synonyms(word)
                if len(set(synonyms).intersection(bbox_noun_set)) > 0:
                    similar_object = True
                    debug and print('Needed to use synonyms - source!', 'source', source_object, 'source noun', word, 'source noun synonyms', synonyms, 'bbox', bbox_class)
                    break
        if not similar_object:
            for word in bbox_noun_set:
                synonyms = get_word_synonyms(word)
                if len(set(synonyms).intersection(source_object_noun_set.union({source_object}))) > 0:
                    similar_object = True
                    debug and print('Needed to use synonyms - bbox!', 'bbox', bbox_class, 'bbox noun', word, 'bbox nouns syno', synonyms, 'source object', source_object, '')
                    break
        if not similar_object:
            non_source_object_classes.add(bbox_class)

    filtered_bboxes = list(filter(lambda bbox: bbox['class'].replace('_', ' ') in non_source_object_classes, objects_bboxes))
    # Sanity check: every surviving class must still have at least one detection.
    assert len(get_classes_list(filtered_bboxes)) == len(non_source_object_classes)
    return filtered_bboxes

def filter_bboxes_interescted_with_main_object(main_source_object_boxes, secondary_objects_boxes, debug=False):
    """Keep the secondary detections that do not overlap any main-object box.

    Args:
        main_source_object_boxes: Detections of the main object; each has
            ``'bbox'`` and ``'class'`` entries.
        secondary_objects_boxes: Candidate detections with the same layout.
        debug: When truthy, print the intersection figures of every pair.

    Returns:
        The candidates for which ``calculate_intersection_area_bbox``
        reported a zero intersection area against every main-object box.
    """
    survivors = []
    for candidate in secondary_objects_boxes:
        overlaps_main_object = False
        # Every pair is evaluated (no early exit) so the debug trace stays complete.
        for main_box in main_source_object_boxes:
            overlap_area, cover_first, cover_second = calculate_intersection_area_bbox(candidate['bbox'], main_box['bbox'])
            debug and print(candidate['class'], main_box['class'], overlap_area, cover_first, cover_second)
            overlaps_main_object = overlaps_main_object or overlap_area != 0
        if not overlaps_main_object:
            survivors.append(candidate)
    return survivors

def get_source_object_variations(source_object):
    """Build the set of textual variants used to search for an object.

    The result contains ``source_object`` itself, the nouns extracted from it
    by ``get_text_nouns``, and the ``lemmatizer.lemmatize`` form of each of
    those strings.
    """
    variations = {source_object}
    variations.update(get_text_nouns(source_object))
    lemmas = {lemmatizer.lemmatize(variation) for variation in variations}
    return variations | lemmas

def get_all_OD_classes(source_img, target_img, target_size=(500,500)):
    """Run the sensitive LVIS detector on both images and pool the classes.

    Args:
        source_img: Image before the edit.
        target_img: Image after the edit.
        target_size: Size forwarded to the detector as ``target_size``.

    Returns:
        A triple ``(classes, source_results, target_results)`` where
        ``classes`` is ``get_classes_list`` applied to both result lists
        concatenated.
    """
    per_image_results = []
    for image in (source_img, target_img):
        _annotated, results = detic_utils_sensitive_lvis.get_bounding_boxes(
            image, target_size=target_size, should_display=False, bbox_size_threshold=0.001
        )
        per_image_results.append(results)
    source_results, target_results = per_image_results
    return get_classes_list(source_results + target_results), source_results, target_results

def get_main_source_object_boxes(image, main_object, contours, target_size=(500,500), debug=False):
    """Find the boxes of the main object that intersect the mask contours.

    Args:
        image: Image to search; it is resized to ``target_size`` for
            detection, while the unresized image is passed to
            ``get_intersected_bboxes``.
        main_object: Object description; must not be ``None``. The literal
            string ``'None'`` skips detection.
        contours: Mask contours forwarded to ``get_intersected_bboxes``.
        target_size: Size used when resizing ``image`` for detection.
        debug: When truthy, print the object and the detected boxes.

    Returns:
        The result of ``get_intersected_bboxes`` with both coverage
        thresholds set to 0.

    Raises:
        AssertionError: If ``main_object`` is ``None``.
    """
    debug and print(main_object)
    assert main_object is not None
    source_boxes = []
    if main_object != 'None':
        variations = get_source_object_variations(main_object)
        _all_boxes, source_boxes = get_all_bboxes(image.resize(target_size), list(variations))
    debug and print(source_boxes)

    return get_intersected_bboxes(image, contours, source_boxes, contour_covered_threshold=0, bbox_covered_threshold=0)

In [None]:
# source_all_boxes_all, source_boxes_all, source_image = [], [], None
def get_unintersected_bboxes(source_bboxes, filter_bboxes, debug=False):
    """Return the source boxes that intersect none of ``filter_bboxes``.

    Args:
        source_bboxes: Candidate detections with ``'bbox'`` (and, for debug
            output, ``'id'``) entries.
        filter_bboxes: Detections to test against, same layout.
        debug: When truthy, print each comparison and each surviving box.

    Returns:
        The candidates, in input order, for which every comparison reported
        a zero intersection area.
    """
    survivors = []
    for candidate in source_bboxes:
        for blocker in filter_bboxes:
            overlap_area, cover_first, cover_second = calculate_intersection_area_bbox(candidate['bbox'], blocker['bbox'])
            debug and print(candidate['id'], blocker['id'], overlap_area, cover_first, cover_second, candidate['bbox'], blocker['bbox'])
            if overlap_area != 0:
                # The first overlap disqualifies the candidate.
                break
        else:
            debug and print('Survived', candidate['id'], candidate['bbox'])
            survivors.append(candidate)
    return survivors
        

def validate_classes_boxes(image, source_secondary_bbox, contours, target_secondary_bbox, debug=False):
    """Keep the source boxes that do not intersect any target box.

    Args:
        image: Unused; kept so existing call sites keep working.
        source_secondary_bbox: List of detections to validate.
        contours: Unused; kept so existing call sites keep working.
        target_secondary_bbox: Detections the source boxes are tested
            against with ``get_unintersected_bboxes``.
        debug: When truthy, print the class sets involved.

    Returns:
        A list of the surviving source detections. The empty-input case
        returns an empty list (it previously returned an empty set), so the
        return type no longer depends on the input and stays
        JSON-serialisable.
    """
    if len(source_secondary_bbox) == 0:
        print('Empty source bboxex to verify')
        return []
    debug and print('---')
    unintersected_masked_intersected_bbox = get_unintersected_bboxes(source_secondary_bbox, target_secondary_bbox)
    if debug:
        # These class sets are only needed for the trace, so they are built on demand.
        print('OD classes ', get_classes_list(unintersected_masked_intersected_bbox, replace_underscore=False))
        print('Real classes ', get_bbox_classes(source_secondary_bbox, replace_underscore=True))
        print('Target classes ', get_classes_list(target_secondary_bbox, replace_underscore=True))
    return unintersected_masked_intersected_bbox

def enrich_example_with_object_bboxes(example, target_size=(500, 500)):
    """Attach detected boxes, and the mask-intersecting subset, to ``example``.

    Reads ``example['mask_img']``, ``example['source_img']`` and
    ``example['target_img']`` and writes four entries:
    ``source_image_bboxes``, ``source_image_intersected_bboxes``,
    ``target_image_bboxes`` and ``target_image_intersected_bboxes``.

    Args:
        example: Mapping holding the three images; it is mutated in place.
        target_size: Size the images are resized to before detection.

    Returns:
        The same ``example`` object.
    """
    resized_mask = example['mask_img'].resize(target_size)
    contours, _image_with_polygons = get_masking_contours(resized_mask, add_polygons=False)

    # Classes seen in either image drive the per-image box search below.
    all_object_classes, _source_segmentation, _target_segmentation = get_all_OD_classes(
        example['source_img'], example['target_img'], target_size=target_size
    )
    # Only the second value returned by get_all_bboxes is stored on the example.
    _unused_source, source_bboxes = get_all_bboxes(example['source_img'].resize(target_size), all_object_classes)
    _unused_target, target_bboxes = get_all_bboxes(example['target_img'].resize(target_size), all_object_classes)

    overlap_thresholds = dict(contour_covered_threshold=0, bbox_covered_threshold=0)

    source_hits = get_intersected_bboxes(example['source_img'], contours, source_bboxes, **overlap_thresholds)
    example['source_image_bboxes'] = source_bboxes
    example['source_image_intersected_bboxes'] = source_hits

    target_hits = get_intersected_bboxes(example['target_img'], contours, target_bboxes, **overlap_thresholds)
    example['target_image_bboxes'] = target_bboxes
    example['target_image_intersected_bboxes'] = target_hits

    return example

def enrich_with_artifacts_secondary(example, target_size=(500, 500)):
    """Flag Add/Remove examples whose masked area shows secondary artifacts.

    For actions other than ``'Add'`` and ``'Remove'`` the example is marked
    as artifact-free without reading any image or bbox entry. Otherwise the
    boxes stored by ``enrich_example_with_object_bboxes`` are filtered to
    remove the edited object itself, and the boxes that do not intersect any
    box of the opposite image are kept.

    Args:
        example: Mapping with ``action``, ``mask_img``, ``source_img``,
            ``target_img``, ``source_object``, ``target_object`` and the
            four ``*_image_bboxes`` / ``*_image_intersected_bboxes``
            entries; mutated in place.
        target_size: Size used when resizing the mask and images. It is now
            honoured; previously it was overwritten with ``(500, 500)``,
            which is still the default.

    Returns:
        The same ``example`` with ``secondary_objects_source``,
        ``secondary_objects_target``, their ``*_classes`` counterparts and
        ``is_secondary_artifact`` filled in.
    """
    if example['action'] not in ('Add', 'Remove'):
        example['secondary_objects_source_classes'] = []
        example['secondary_objects_target_classes'] = []
        example['is_secondary_artifact'] = False
        return example

    contours, _image_with_polygons = get_masking_contours(example['mask_img'].resize(target_size), add_polygons=False)
    secondary_objects_source = example['source_image_intersected_bboxes']
    secondary_objects_target = example['target_image_intersected_bboxes']
    all_source_bboxes = example['source_image_bboxes']
    all_target_bboxes = example['target_image_bboxes']

    print(example['source_img'], example['source_object'])
    main_source_object_boxes = get_main_source_object_boxes(example['source_img'], example['source_object'], contours, target_size=target_size)
    secondary_objects_source = filter_out_source_object(example['source_object'], secondary_objects_source)

    main_target_object_boxes = get_main_source_object_boxes(example['target_img'], example['target_object'], contours, target_size=target_size)
    secondary_objects_target = filter_out_source_object(example['target_object'], secondary_objects_target)

    # NOTE(review): before_items / after_items are computed but never used
    # below; validation runs on the unfiltered secondary objects. Kept as-is
    # to preserve current results - confirm whether these were meant to be
    # passed to validate_classes_boxes.
    before_items = filter_bboxes_interescted_with_main_object(main_source_object_boxes, secondary_objects_source)
    after_items = filter_bboxes_interescted_with_main_object(main_target_object_boxes, secondary_objects_target)

    # Keep the boxes in the masked area that do not intersect any box of the other image.
    example['secondary_objects_source'] = validate_classes_boxes(example['source_img'].resize(target_size), secondary_objects_source, contours, all_target_bboxes, debug=False)
    example['secondary_objects_source_classes'] = get_bbox_classes(example['secondary_objects_source'])
    example['secondary_objects_target'] = validate_classes_boxes(example['target_img'].resize(target_size), secondary_objects_target, contours, all_source_bboxes, debug=False)
    example['secondary_objects_target_classes'] = get_bbox_classes(example['secondary_objects_target'])

    if example['action'] == 'Add':
        secondary_artifact_items = example['secondary_objects_source_classes']
    else:
        secondary_artifact_items = example['secondary_objects_target_classes']
    example['is_secondary_artifact'] = len(secondary_artifact_items) > 0
    return example

#### Generate Negative examples

In [None]:
import inflect
from collections import Counter
p = inflect.engine()

def count_classes(array):
    """Count how many items of ``array`` carry each ``'class'`` value.

    Args:
        array: Iterable of mappings with a ``'class'`` entry.

    Returns:
        A plain dict mapping class name to number of occurrences.
    """
    return dict(Counter(item['class'] for item in array))

def get_bbox_masking_difference(mask_size, bbox):
    """Return the absolute difference between ``mask_size`` and a box's size.

    The box size is taken as the first value returned by
    ``calculate_intersection_area_bbox`` when the box is intersected with
    itself.
    """
    box = bbox['bbox']
    own_area, _cover_first, _cover_second = calculate_intersection_area_bbox(box, box)
    return abs(mask_size - own_area)

def get_negative_remove_object(example):
    """Pick an object outside the mask to use in a negative "remove" edit.

    Candidates are the source-image detections whose class never appears
    among the mask-intersecting detections and whose score is at least 0.2.
    When there are none, any detection that is not itself one of the
    mask-intersecting detections is considered. Among the candidates, the
    one whose size is closest to the size of the first mask contour wins.

    Args:
        example: Mapping with ``source_image_bboxes``,
            ``source_image_intersected_bboxes`` and ``mask_img``.

    Returns:
        The chosen class name, pluralised with ``p.plural`` unless exactly
        one detection of that class exists in the source image; ``None``
        when no candidate exists.

    Raises:
        RuntimeError: If the mask contour area cannot be computed (for
            example when no contour was found). The original error is
            chained as the cause.
    """
    source_bboxes = example['source_image_bboxes']
    intersected_bboxes = example['source_image_intersected_bboxes']

    intersected_classes = {bbox['class'] for bbox in intersected_bboxes}
    candidates = [
        bbox for bbox in source_bboxes
        if bbox['class'] not in intersected_classes and bbox['score'] >= 0.2
    ]
    if len(candidates) == 0:
        candidates = [bbox for bbox in source_bboxes if bbox not in intersected_bboxes]
        if len(candidates) == 0:
            return None

    bbox_classes_counter = count_classes(source_bboxes)
    masking_contour, _ = get_masking_contours(example['mask_img'].resize((500,500)))
    try:
        mask_size, _mask = get_contour_area(masking_contour[0], (500,500), (500,500))
    except Exception as err:
        # Keep the real cause attached instead of raising an empty Exception.
        raise RuntimeError('Failed to compute the mask contour area') from err

    # min() returns the first minimal element, like sorted(...)[0] did.
    closest_bbox = min(candidates, key=lambda bbox: get_bbox_masking_difference(mask_size, bbox))
    negative_class = closest_bbox['class']
    return negative_class if bbox_classes_counter.get(negative_class) == 1 else p.plural(negative_class)

In [None]:
# Pairs of [negative extensive instruction, original extensive instruction]
# appended by get_negative_replace_edit.
examples_second = []

# Each prompt template is read inside a ``with`` block so its file handle is
# closed immediately instead of being left to the garbage collector.
with open(PROJECT_PATH + '//prompts/extract_negative_target_object_prompt.txt', "r") as prompt_file:
    extract_negative_source_target_object_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/generate_negative_replace_instruction_prompt.txt', "r") as prompt_file:
    generate_simple_negative_replace_instruction_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/generate_negative_change_attribute_instruction_prompt.txt', "r") as prompt_file:
    generate_simple_negative_change_attribute_instruction_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/extend_remove_difference_caption_prompt.txt', "r") as prompt_file:
    extend_remove_difference_caption_prompt = prompt_file.read()

def extract_source_example(example, debug=False):
    """Ask the LLM which source and target objects an instruction refers to.

    The response is expected to contain a ``Target:`` section followed by a
    ``Source:`` section. When ``example['revised_instruction']`` is falsy no
    request is made and both values are empty strings.

    Args:
        example: Mapping with ``revised_instruction``, ``target_object`` and
            ``source_object``.
        debug: When truthy, print the instruction, prompt, response and the
            parsed values.

    Returns:
        ``{'target': ..., 'source': ...}``. A parsed value equal to "none"
        (case-insensitive) is replaced by the example's own object.
    """
    negative_target = negative_source = ''
    instruction = example['revised_instruction']
    if instruction:
        prompt = extract_negative_source_target_object_prompt.format(instruction)
        debug and print(instruction)
        debug and print(prompt)
        response = get_chatgpt_4_prediction(prompt, model='gpt-4o-mini')
        debug and print(response)

        source_pieces = response.split('Source:')
        target_pieces = source_pieces[0].split('Target:')
        # Without a 'Target:' marker the whole text before 'Source:' is the target.
        raw_target = target_pieces[0] if len(target_pieces) == 1 else target_pieces[1]
        negative_target = raw_target.replace('\n', '').strip()

        negative_source = source_pieces[1].replace('\n', '').strip()
        debug and print('target', negative_target, 'source', negative_source)

    target = example['target_object'] if negative_target.lower() == 'none' else negative_target
    source = example['source_object'] if negative_source.lower() == 'none' else negative_source
    return {'target': target, 'source': source}

# Prompt templates for negative "add" edits; ``with`` closes each file right
# after it is read.
with open(PROJECT_PATH + '//prompts/generate_similar_object_add_prompt.txt', "r") as prompt_file:
    generate_similar_object_add_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/generate_new_instruction_object_metadata_prompt.txt', "r") as prompt_file:
    generate_add_instruction_metadata_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/generate_new_difference_caption_metada_prompt.txt', "r") as prompt_file:
    generate_difference_caption_metadata_prompt = prompt_file.read()

def extract_instructions(r):
    """Split an LLM response into its short and extensive instructions.

    Args:
        r: Response text containing ``Modified Instruction:`` followed by
            ``Modified Extensive Instruction:``.

    Returns:
        ``(instruction, extensive_instruction)``, both stripped.

    Raises:
        IndexError: If either marker is missing.
    """
    short_marker = 'Modified Instruction:'
    extensive_marker = 'Modified Extensive Instruction:'
    after_short = r.split(short_marker)[1]
    instruction = after_short.split(extensive_marker)[0].strip()
    extensive_instruction = r.split(extensive_marker)[1].strip()
    return instruction, extensive_instruction

def extract_difference_captions(r):
    """Split an LLM response into its simple and extensive difference captions.

    Args:
        r: Response text containing ``Simple Difference Caption:`` followed
            by ``Extensive Difference Caption:``.

    Returns:
        ``(simple_caption, extensive_caption)``, both stripped.

    Raises:
        IndexError: If either marker is missing.
    """
    simple_marker = 'Simple Difference Caption:'
    extensive_marker = 'Extensive Difference Caption:'
    after_simple = r.split(simple_marker)[1]
    simple_caption = after_simple.split(extensive_marker)[0].strip()
    extensive_caption = r.split(extensive_marker)[1].strip()
    return simple_caption, extensive_caption

def get_negative_add_edit(example, debug=False):
    """Build the negative counterpart of an "add" edit with three LLM calls.

    The first call (``generate_similar_object_add_prompt``) yields the old
    target, a new target and an explanation. Two follow-up calls rewrite
    the instructions and the difference captions for the new target.

    Args:
        example: Mapping with ``revised_instruction``,
            ``extensive_revised_instruction``, ``simple_caption``,
            ``extensive_caption`` and ``negative_old_target``.
        debug: When truthy, print the first prompt and its response.

    Returns:
        Dict with ``old_target``, ``target``, ``explenation``,
        ``modified_instruction``, ``extensive_modified_instruction``,
        ``simple_difference_caption`` and ``extensive_difference_caption``.

    Raises:
        IndexError: If a response lacks one of the expected section markers.
    """
    prompt = generate_similar_object_add_prompt.format(
        example['revised_instruction'],
        example['extensive_revised_instruction'],
        example['simple_caption'],
        example['extensive_caption'],
        example['negative_old_target'],
    )
    response = get_chatgpt_4_prediction(prompt, model='gpt-4o-mini')
    if debug:
        print('Prompt:')
        print(prompt)
        print('Response:')
        print(response)

    def section(start_marker, end_marker):
        # Single-line text found between two markers of the response.
        return response.split(start_marker)[1].split(end_marker)[0].replace('\n', '').strip()

    old_target = section('Old Target Object:', 'New Target Object:')
    negative_target = section('New Target Object:', 'Explanation:')
    negative_explenation = section('Explanation:', 'Modified Instruction:')

    instruction_prompt = generate_add_instruction_metadata_prompt.format(
        old_target, negative_target, example['revised_instruction'], example['extensive_revised_instruction']
    )
    caption_prompt = generate_difference_caption_metadata_prompt.format(
        old_target, negative_target, example['simple_caption'], example['extensive_caption']
    )

    instruction_response = get_chatgpt_4_prediction(instruction_prompt, model='gpt-4o-mini')
    new_instruction, extensive_new_instruction = extract_instructions(instruction_response)
    caption_response = get_chatgpt_4_prediction(caption_prompt, model='gpt-4o-mini')
    new_difference_caption, extensive_new_difference_caption = extract_difference_captions(caption_response)

    return {
        'old_target': old_target,
        'target': negative_target,
        'explenation': negative_explenation,
        'modified_instruction': new_instruction,
        'extensive_modified_instruction': extensive_new_instruction,
        'simple_difference_caption': new_difference_caption,
        'extensive_difference_caption': extensive_new_difference_caption
    }

def get_negative_replace_edit(example):
    """Build the negative counterpart of a "replace" edit with one LLM call.

    The response of ``get_chatgpt_3_prediction`` is split on its section
    markers. As a side effect, the pair ``[negative extensive instruction,
    original extensive instruction]`` is appended to the module-level
    ``examples_second`` list.

    Args:
        example: Mapping with ``revised_instruction``,
            ``extensive_revised_instruction``, ``simple_caption``,
            ``extensive_caption``, ``negative_old_source``,
            ``negative_old_target`` and ``source_object``.

    Returns:
        Dict with ``source``, ``old_target``, ``target``, ``explenation``,
        ``modified_instruction``, ``extensive_modified_instruction``,
        ``simple_difference_caption`` and ``extensive_difference_caption``.

    Raises:
        IndexError: If the response lacks one of the expected markers.
    """
    global examples_second

    prompt = generate_simple_negative_replace_instruction_prompt.format(
        example['revised_instruction'],
        example['extensive_revised_instruction'],
        example['simple_caption'],
        example['extensive_caption'],
        example['negative_old_source'],
        example['negative_old_target'],
    )
    response = get_chatgpt_3_prediction(prompt)

    def section(start_marker, end_marker=None):
        # Single-line text after start_marker, cut at end_marker when given.
        tail = response.split(start_marker)[1]
        if end_marker is not None:
            tail = tail.split(end_marker)[0]
        return tail.replace('\n', '').strip()

    old_target = section('Old Target Object:', 'New Target Object:')
    negative_target = section('New Target Object:', 'Explanation:')
    negative_explenation = section('Explanation:', 'Modified Instruction:')
    negative_modified_instruction = section('Modified Instruction:', 'Modified Extensive Instruction:')
    negative_extensive_modified_instruction = section('Modified Extensive Instruction:', 'Simple Difference Caption:')

    # Recorded before the captions are parsed, as the original flow does.
    examples_second.append([negative_extensive_modified_instruction, example['extensive_revised_instruction']])

    simple_difference_caption = section('Simple Difference Caption:', 'Extensive Difference Caption:')
    extensive_difference_caption = section('Extensive Difference Caption:')

    return {
        'source': example['source_object'],
        'old_target': old_target,
        'target': negative_target,
        'explenation': negative_explenation,
        'modified_instruction': negative_modified_instruction,
        'extensive_modified_instruction': negative_extensive_modified_instruction,
        'simple_difference_caption': simple_difference_caption,
        'extensive_difference_caption': extensive_difference_caption,
    }

# Modified Instruction: Remove both of the ostrich legs.
# Modified Extensive Instruction: Remove both of the ostrich legs, leaving the ostrich without any legs in the image.
# New Target Object: both ostrich legs
# Simple Difference caption: Both ostrich legs were removed.
# Extensive Difference caption: The pair of ostrich legs was removed leaving the ostrich without any legs in the image.
def get_negative_remove_edit(example, debug=False):
    """Build the negative counterpart of a "remove" edit.

    An object outside the mask is chosen with ``get_negative_remove_object``
    and template instructions/captions are generated for removing it. The
    extensive caption is then rewritten by the LLM using
    ``extend_remove_difference_caption_prompt``.

    Args:
        example: Mapping with ``source_description``, ``extensive_caption``
            and the entries ``get_negative_remove_object`` reads.
        debug: When truthy, print the caption prompt and its response.

    Returns:
        Dict with ``source``, ``modified_instruction``,
        ``extensive_modified_instruction``, ``simple_difference_caption``
        and ``extensive_difference_caption``; an empty dict when no
        suitable object exists.
    """
    new_souce = get_negative_remove_object(example)
    if new_souce is None:
        return dict()

    # inflect's singular_noun() returns False when the word is already
    # singular and the singular form otherwise, so comparing its result to the
    # word itself never matched and the caption always used "were".
    verb = 'was' if p.singular_noun(new_souce) is False else 'were'

    modified_instruction = f'remove the {new_souce}'
    extensive_modified_instruction = f'remove the {new_souce}'
    simple_difference_caption = f'The {new_souce} {verb} removed'
    extensive_difference_caption = f'The {new_souce} {verb} removed'

    new_diffrence_caption_prompt = extend_remove_difference_caption_prompt.format(example['source_description'], example['extensive_caption'], extensive_difference_caption)
    extensive_difference_caption = get_chatgpt_4_prediction(new_diffrence_caption_prompt, model='gpt-4o-mini', temprature=0.3)
    debug and print(new_diffrence_caption_prompt, extensive_difference_caption)
    return {
        'source': new_souce,
        'modified_instruction': modified_instruction,
        'extensive_modified_instruction': extensive_modified_instruction,
        'simple_difference_caption': simple_difference_caption,
        'extensive_difference_caption': extensive_difference_caption
    }

# Modified Instruction:  Change the man's action to blowing on the pizza as if it's too hot. 
# New Target Object: Man's action
# Short Modified Instruction: Change the man's action to blowing on the pizza
def get_negative_change_attribute_edit(example, debug=False):
    """Build the negative counterpart of a "change attribute" edit.

    One LLM call is made with
    ``generate_simple_negative_change_attribute_instruction_prompt`` and
    the response is split on its six section markers.

    Args:
        example: Mapping with ``revised_instruction``,
            ``extensive_revised_instruction``, ``simple_caption``,
            ``extensive_caption``, ``source_object`` and ``target_object``.
        debug: When truthy, print the prompt and the response.

    Returns:
        Dict with ``source``, ``target``, ``modified_instruction``,
        ``extensive_modified_instruction``, ``simple_difference_caption``
        and ``extensive_difference_caption``.

    Raises:
        IndexError: If the response lacks one of the expected markers.
    """
    prompt = generate_simple_negative_change_attribute_instruction_prompt.format(
        example['revised_instruction'],
        example['extensive_revised_instruction'],
        example['simple_caption'],
        example['extensive_caption'],
        example['source_object'],
        example['target_object'],
    )
    debug and print(prompt)
    response = get_chatgpt_4_prediction(prompt, model='gpt-4o-mini')
    debug and print(response)

    def section(start_marker, end_marker=None):
        # Single-line text after start_marker, cut at end_marker when given.
        tail = response.split(start_marker)[1]
        if end_marker is not None:
            tail = tail.split(end_marker)[0]
        return tail.replace('\n', '').strip()

    modified_instruction = section('Modified Instruction:', 'Source Object:')
    source_object = section('Source Object:', 'Modified Target Object:')
    target_object = section('Modified Target Object:', 'Modified Extensive Instruction:')
    extensive_modified_instruction = section('Modified Extensive Instruction:', 'Modified Simple Difference Caption:')
    simple_difference_caption = section('Modified Simple Difference Caption:', 'Modified Extensive Difference Caption:')
    extensive_difference_caption = section('Modified Extensive Difference Caption:')

    return {
        'source': source_object,
        'target': target_object,
        'modified_instruction': modified_instruction,
        'extensive_modified_instruction': extensive_modified_instruction,
        'simple_difference_caption': simple_difference_caption,
        'extensive_difference_caption': extensive_difference_caption
    }
    
# Modified Instruction: Change the cat's mouth to mimic a yawn, showing a wide-open mouth without visible teeth.
# New Target Object: Cat's mouth
def get_negative_edit(example, debug=False):
    """Generate the negative-edit fields for one example.

    Depending on ``example['action']`` (case-insensitive: add, remove,
    replace, change attribute) the matching ``get_negative_*_edit`` helper
    is called, and its output is copied into the ``negative_*`` entries of
    the example. ``is_negative_valid`` is true when a negative source or a
    negative target was produced.

    Args:
        example: Mapping describing the edit; mutated in place.
        debug: Forwarded to every helper that accepts it (it previously
            reached only the add and remove helpers).

    Returns:
        The same ``example``. When ``action`` is ``None`` it is returned
        untouched.

    Raises:
        Exception: If the action is not one of the four supported ones.
    """
    if example['action'] is None:
        return example
    action_type = example['action'].lower()
    if action_type not in ['add', 'remove', 'change attribute', 'replace']:
        raise Exception('Functions works only with add, remove, change att, and replace. Recive invalid action: ' +  example['action'])

    negative_source_target = extract_source_example(example, debug)
    example['negative_old_source'] = negative_source_target['source']
    example['negative_old_target'] = negative_source_target['target']

    # The four actions are mutually exclusive, so one branch at most runs.
    if action_type == 'add':
        neagtive_instruction_details = get_negative_add_edit(example, debug)
    elif action_type == 'replace':
        neagtive_instruction_details = get_negative_replace_edit(example)
    elif action_type == 'remove':
        neagtive_instruction_details = get_negative_remove_edit(example, debug)
    else:
        neagtive_instruction_details = get_negative_change_attribute_edit(example, debug)

    # Assigned first so the key keeps its position ahead of the other fields.
    example['is_negative_valid'] = False
    example['negative_instruction'] = neagtive_instruction_details.get('modified_instruction') or 'None'
    example['negative_extensive_instruction'] = neagtive_instruction_details.get('extensive_modified_instruction') or 'None'
    example['negative_simple_caption'] = neagtive_instruction_details.get('simple_difference_caption')
    example['negative_extensive_caption'] = neagtive_instruction_details.get('extensive_difference_caption')
    example['negative_source'] = neagtive_instruction_details.get('source') or 'None'
    example['negative_target'] = neagtive_instruction_details.get('target') or 'None'
    example['negative_explenation'] = neagtive_instruction_details.get('explenation') or 'None'

    has_no_source = example['negative_source'] in ['None', None]
    has_no_target = example['negative_target'] in ['None', None]
    example['is_negative_valid'] = not (has_no_source and has_no_target)

    return example

In [None]:
# Edit If we want to edit also Source Object
def turn_example_to_negative(example):
    """Return a copy of ``example`` rewritten as its negative counterpart.

    The negative instruction is used unless it equals the original
    instruction (ignoring case and full stops), in which case the extensive
    negative instruction is used instead. Objects whose value is the string
    ``'None'`` are left untouched.

    Args:
        example: Mapping with the ``negative_*`` entries,
            ``original_instruction``, ``target_object`` and
            ``source_object``. It is not modified.

    Returns:
        The rewritten copy, with ``is_negative`` set to ``True``.
    """
    def normalise(text):
        return text.lower().replace('.', '')

    if normalise(example['original_instruction']) != normalise(example['negative_instruction']):
        negative_instruction = example['negative_instruction']
    else:
        negative_instruction = example['negative_extensive_instruction']

    negative = example.copy()
    negative['instruction'] = negative_instruction
    negative['revised_instruction'] = negative_instruction
    negative['explenation'] = negative['negative_explenation']
    if negative['target_object'] != 'None':
        negative['target_object'] = negative['negative_target']
    if negative['source_object'] != 'None':
        negative['source_object'] = negative['negative_source']
    negative['extensive_revised_instruction'] = negative['negative_extensive_instruction']
    negative['simple_caption'] = negative['negative_simple_caption']
    negative['extensive_caption'] = negative['negative_extensive_caption']
    negative['is_negative'] = True
    return negative

def get_negative_dataframe(data_df):
    """Turn every row flagged ``is_negative_valid`` into its negative version.

    Args:
        data_df: DataFrame with an ``is_negative_valid`` column and the
            columns read by ``turn_example_to_negative``.

    Returns:
        The result of applying ``turn_example_to_negative`` row by row
        (with a tqdm progress bar) to a copy of the valid rows.
    """
    valid_rows = data_df.copy()[data_df['is_negative_valid'] == True]
    return valid_rows.progress_apply(lambda row: turn_example_to_negative(row), axis=1)

def move_column_to_front(df, column_name):
    """Return ``df`` with ``column_name`` moved to the first position.

    Args:
        df: DataFrame to reorder.
        column_name: Column to place first.

    Returns:
        A reordered view/copy of ``df``; when the column does not exist a
        message is printed and ``df`` itself is returned.
    """
    if column_name in df.columns:
        remaining_columns = [col for col in df.columns if col != column_name]
        return df[[column_name, *remaining_columns]]
    print(f"Column '{column_name}' not found in DataFrame.")
    return df

# The negative rows are generated from the positive rows and kept as separate
# rows next to them.
def get_balanced_dataframe(data_df):
    """Return the positive rows together with their generated negative rows.

    Args:
        data_df: DataFrame of positive examples, including the columns
            ``get_negative_dataframe`` needs.

    Returns:
        A DataFrame holding the original rows plus the negative rows,
        sorted by index, with ``is_negative`` as first column (``False``
        for rows that had no value).
    """
    positive_df = data_df.copy()
    negative_df = get_negative_dataframe(data_df)
    balanaced_df = pd.concat([positive_df, negative_df])
    if 'is_negative' in balanaced_df.columns:
        balanaced_df['is_negative'] = balanaced_df['is_negative'].fillna(False)
    else:
        # No negative row was produced and the input had no such column:
        # every row is a positive one.
        balanaced_df['is_negative'] = False
    balanaced_df = move_column_to_front(balanaced_df, 'is_negative')
    return balanaced_df.sort_index()

#### Reverse examples

In [None]:
# Load the instruction and caption prompts separately. Each file is read
# inside a ``with`` block so its handle is closed right after reading.
with open(PROJECT_PATH + '//prompts/reverse_edit_direction_instructions_pormpt - Change Attribute.txt', "r") as prompt_file:
    change_attribute_instruction_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/reverse_edit_direction_instructions_pormpt - Replace.txt', "r") as prompt_file:
    replace_instruction_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/reverse_edit_direction_instructions_pormpt - Remove.txt', "r") as prompt_file:
    remove_instruction_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/reverse_edit_direction_instructions_pormpt - Add.txt', "r") as prompt_file:
    add_instruction_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/reverse_edit_direction_captions_pormpt - Replace.txt', "r") as prompt_file:
    replace_caption_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/reverse_edit_direction_captions_pormpt - Add.txt', "r") as prompt_file:
    add_caption_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/reverse_edit_direction_captions_pormpt - Remove.txt', "r") as prompt_file:
    remove_caption_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/reverse_edit_direction_captions_pormpt - Change Attribute.txt', "r") as prompt_file:
    change_attribute_caption_prompt = prompt_file.read()
with open(PROJECT_PATH + '//prompts/self_contain_caption_prompt.txt', "r") as prompt_file:
    self_contained_prompt = prompt_file.read()

def get_relevant_prompt(example, is_instruction=False):
    """Select the reverse-edit prompt template matching the example's action.

    Args:
        example: Mapping whose ``action`` entry (case-insensitive) is one
            of replace, add, remove or change attribute.
        is_instruction: When truthy return the instruction prompt,
            otherwise the caption prompt.

    Returns:
        The prompt template string.

    Raises:
        ValueError: If the action is not one of the four supported ones.
    """
    action = example.get('action', '').lower()
    # Lookups are deferred so that only the selected template is touched.
    prompt_pickers = {
        'replace': lambda: replace_instruction_prompt if is_instruction else replace_caption_prompt,
        'add': lambda: add_instruction_prompt if is_instruction else add_caption_prompt,
        'remove': lambda: remove_instruction_prompt if is_instruction else remove_caption_prompt,
        'change attribute': lambda: change_attribute_instruction_prompt if is_instruction else change_attribute_caption_prompt,
    }
    picker = prompt_pickers.get(action)
    if picker is None:
        raise ValueError(f"Unknown action: {action}")
    return picker()

def parse_instruction_response(example, response):
    """Extract the reversed edit instruction from an LLM response.

    The text after ``Edit Instruction:`` is stored in both
    ``example['revised_instruction']`` and
    ``example['extensive_revised_instruction']``.

    ``'Edit Instruction:'`` is a substring of
    ``'Extensive Edit Instruction:'``, so the response is cut at the
    extensive marker first. Splitting on the short marker directly left a
    dangling "Extensive" at the end of the instruction whenever both
    sections were present.

    Args:
        example: Mapping updated in place.
        response: LLM response text.

    Returns:
        The parsed instruction.

    Raises:
        IndexError: If the response has no ``Edit Instruction:`` marker.
    """
    instruction_marker = 'Edit Instruction:'
    extensive_marker = 'Extensive Edit Instruction:'

    before_extensive = response.split(extensive_marker)[0]
    if instruction_marker in before_extensive:
        raw_instruction = before_extensive.split(instruction_marker)[1]
    else:
        # No short section ahead of the extensive one: keep the previous parsing.
        raw_instruction = response.split(instruction_marker)[1].split(extensive_marker)[0]

    example['revised_instruction'] = raw_instruction.strip()
    example['extensive_revised_instruction'] = example['revised_instruction']
    return example['revised_instruction']

def parse_caption_response(example, response):
    """Extract the reversed extensive caption from an LLM response.

    The text between ``New Extensive Caption:`` and ``New Source:`` (or the
    end of the response) is stripped and stored in
    ``example['extensive_caption']``.

    Args:
        example: Mapping updated in place.
        response: LLM response text.

    Returns:
        The stored caption.

    Raises:
        IndexError: If ``New Extensive Caption:`` is missing.
    """
    after_marker = response.split('New Extensive Caption:')[1]
    caption = after_marker.split('New Source:')[0].strip()
    example['extensive_caption'] = caption
    return example['extensive_caption']

def parse_reverse_response(example, response, columns):
    """Finish turning ``example`` into its reversed edit, in place.

    The current instruction is kept under ``original_instruction``, the
    revised instructions become the main ones, ``Add`` and ``Remove``
    actions are swapped, and every column whose name contains ``source`` is
    exchanged with its ``target`` counterpart using the values the example
    had on entry.

    Args:
        example: Mapping to update; it is mutated and returned.
        response: Unused.
        columns: Column names to scan for ``source`` / ``target``.

    Returns:
        The same ``example`` object.
    """
    snapshot = example.copy()
    example['original_instruction'] = example['instruction']
    example['instruction'] = example['revised_instruction']
    example['extensive_instruction'] = example['extensive_revised_instruction']

    example['is_reverse'] = True
    current_action = example['action']
    if current_action == 'Remove':
        example['action'] = 'Add'
    elif current_action == 'Add':
        example['action'] = 'Remove'

    # Both name lists include the plain 'source_object' / 'target_object' columns.
    source_columns = [name for name in columns if 'source' in name]
    target_columns = [name for name in columns if 'target' in name]
    for name in source_columns:
        example[name.replace('source', 'target')] = snapshot[name]
    for name in target_columns:
        example[name.replace('target', 'source')] = snapshot[name]
    return example
    
def get_reverse_example(example, columns=None, debug=False):
    """Create the reversed version of an edit example using the LLM.

    Three LLM calls are made: one to reverse the instruction, one to
    reverse the extensive caption, and one to rewrite that caption with
    ``self_contained_prompt``. The source/target columns are then swapped
    by ``parse_reverse_response``.

    Args:
        example: Row (pandas Series) describing the edit; a copy is
            modified, the input is left untouched.
        columns: Column names to consider for the source/target swap;
            defaults to the example's index.
        debug: When truthy, print the instruction prompt and its response.

    Returns:
        The reversed example.

    Raises:
        RuntimeError: If the source/target swap fails; the original error
            is chained as the cause.
    """
    example = example.copy()

    if columns is None:
        columns = example.index.tolist()

    instruction_prompt_to_send = get_relevant_prompt(example, is_instruction=True).format(example['revised_instruction'])
    response = get_chatgpt_4_prediction(instruction_prompt_to_send, model='gpt-4o-mini')
    parse_instruction_response(example, response)
    debug and print(instruction_prompt_to_send)
    debug and print(response,'\n')

    caption_prompt_to_send = get_relevant_prompt(example).format(example['extensive_caption'])
    response = get_chatgpt_4_prediction(caption_prompt_to_send, model='gpt-4o-mini')

    parsed_response = parse_caption_response(example, response)
    standalone_response = get_chatgpt_4_prediction(self_contained_prompt.format(parsed_response), model='gpt-4o-mini')
    example['extensive_caption'] = standalone_response
    try:
        return parse_reverse_response(example, response, columns)
    except Exception as err:
        # Keep the real cause attached instead of raising an empty Exception.
        raise RuntimeError('Failed to build the reversed example') from err

def get_augmentaion_with_reverse_rows(data_df):
    """Return ``data_df`` augmented with one reversed row per original row.

    Each row is passed to ``get_reverse_example``; the reversed rows are
    concatenated after a copy of the originals, missing ``is_reverse`` values
    are filled with False, and the result is sorted by index.
    """
    originals = data_df.copy()
    column_names = data_df.columns
    reversed_rows = data_df.progress_apply(
        lambda row: get_reverse_example(row, column_names),
        axis=1,
    )
    combined = pd.concat([originals, reversed_rows])
    # Rows that did not go through the reversal have no is_reverse value yet.
    combined['is_reverse'] = combined['is_reverse'].fillna(False)
    return combined.sort_index()



## Filter Instance with no object

In [None]:
# Configure global warning filters to ignore UserWarning and FutureWarning
import warnings
warnings.simplefilter("ignore", category=UserWarning)
warnings.simplefilter("ignore", category=FutureWarning)

BBOX_COVERED_THRESHOLD = 10 # The percentage of the object bbox that the user mask needs to cover (i.e. for cases where the pipeline confuses objects near the edit mask)

# Helpers for extracting the main words related to a specific object in text: anchor words are generated
# from a combination of tokens, concatenated adjacent terms, and shortest-path distances
def concatenate_with_next_word(text):
    """Join every pair of adjacent whitespace-separated words with "_".

    For example, "red sports car" gives ["red_sports", "sports_car"]; text
    with fewer than two words gives an empty list.
    """
    tokens = text.split()
    return [left + "_" + right for left, right in zip(tokens, tokens[1:])]

# Identifies a single main object and generates related anchor words based on word tokenization
def get_main_words(object):
    """Collect the anchor words to search for, given an object description.

    The main object is obtained by prompting ``get_chatgpt_3_prediction``.
    In addition, every token of the description whose
    ``shortest_path_distance`` to one of the adjacent-word pairs (built by
    ``concatenate_with_next_word``) is below 4 is kept.

    Returns:
        A list of the unique anchor words; it always contains the main object.
    """
    prompt_template = "What is the main object here \"{}\"? Output only the main object, and nothing else.\nThe main object is:"
    main_object = get_chatgpt_3_prediction(prompt_template.format(object))

    tokens = word_tokenize(object)
    adjacent_pairs = concatenate_with_next_word(object)
    anchors = {main_object}

    for pair in adjacent_pairs:
        for token in tokens:
            distance = shortest_path_distance(token, pair)
            # None means no distance is available for this token/pair.
            if distance is not None and distance < 4:
                anchors.add(token)
    return list(anchors)

# Detect bounding boxes of specified objects in an image using the "dino" model, 
# filtering results to only include boxes within masked regions
def get_dino_bboxes(image, mask_img, objects, bbox_score_threshold=0.3, masks=None, debug=False):
    """Detect the given objects with ``dino`` and keep boxes that intersect the mask.

    For every object, each form returned by ``get_expanded_singular_and_plural``
    (plus the object itself) is queried through ``dino.get_top_boxes``. Both
    box lists it returns are filtered by ``get_intersected_bboxes`` against
    the contours of ``mask_img`` (resized to the image size), and every kept
    box is relabelled with the requested object name.

    Args:
        image: Image to run detection on; its ``size`` is used to resize the mask.
        mask_img: Edit mask; must support ``resize``.
        objects: Iterable of object names to look for.
        bbox_score_threshold: Score threshold forwarded to ``dino.get_top_boxes``.
        masks: Forwarded to ``dino.get_top_boxes``; defaults to an empty list.
        debug: Forwarded to ``get_intersected_bboxes``.

    Returns:
        ``(all_boxes, boxes, source_image)``: the two accumulated box lists
        and the image returned by the last successful detector call (None if
        there was none). NOTE(review): ``boxes`` presumably holds the
        detections that passed ``bbox_score_threshold`` - confirm against
        ``dino.get_top_boxes``.
    """
    # A fresh list per call; a literal `[]` default is a single list object
    # shared by every call that omits `masks`.
    masks = [] if masks is None else masks
    source_all_boxes_all, source_boxes_all, source_image = [], [], None
    contours, _ = get_masking_contours(mask_img.resize(image.size))
    for object_name in objects:
        singulars_and_plurals = get_expanded_singular_and_plural(object_name)
        singulars_and_plurals.add(object_name)
        for object_form in singulars_and_plurals:
            try:
                source_all_boxes, source_boxes, source_image, scores = dino.get_top_boxes(image, object_form +'.', bbox_score_threshold, masks=masks)
                source_all_boxes = get_intersected_bboxes(image, contours, source_all_boxes, debug=debug, bbox_covered_threshold=BBOX_COVERED_THRESHOLD, strict_intersection=True)
                source_boxes = get_intersected_bboxes(image, contours, source_boxes, debug=debug, bbox_covered_threshold=BBOX_COVERED_THRESHOLD, strict_intersection=True)
                for box in (source_all_boxes + source_boxes):
                    box.update({'class': object_name}) # dino returns different class
                source_all_boxes_all.extend(source_all_boxes)
                source_boxes_all.extend(source_boxes)
            except Exception as e:
                # Best effort: a failure on one word form must not abort the others.
                print('Dino failed object: ' + object_form + '. Error: ' + str(e))
    return source_all_boxes_all, source_boxes_all, source_image

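# Detect bounding boxes of specified objects in an image using the "owl" model,
# filtering results to only include boxes within masked regions.
# (enrich_with_object_OD_details below uses the boxes from both the "owl" and "dino" models
# to add object detection (OD) details, scores and detection status to an example.)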
def get_owl_boxes(image, mask_img, objects, bbox_score_threshold=0.3, masks=None, debug=False):
    """Detect the given objects with ``get_top_boxes`` and keep boxes that intersect the mask.

    For every object, each form returned by ``get_expanded_singular_and_plural``
    (plus the object itself) is queried through ``get_top_boxes``. Both box
    lists it returns are filtered by ``get_intersected_bboxes`` against the
    contours of ``mask_img`` (resized to the image size), and every kept box
    is relabelled with the requested object name.

    Args:
        image: Image to run detection on; its ``size`` is used to resize the mask.
        mask_img: Edit mask; must support ``resize``.
        objects: Iterable of object names to look for.
        bbox_score_threshold: Score threshold forwarded to ``get_top_boxes``.
        masks: Forwarded to ``get_top_boxes``; defaults to an empty list.
        debug: Forwarded to ``get_intersected_bboxes``.

    Returns:
        ``(all_boxes, boxes, source_image)``: the two accumulated box lists
        and the image returned by the last successful detector call (None if
        there was none). NOTE(review): ``boxes`` presumably holds the
        detections that passed ``bbox_score_threshold`` - confirm against
        ``get_top_boxes``.
    """
    # A fresh list per call; a literal `[]` default is a single list object
    # shared by every call that omits `masks`.
    masks = [] if masks is None else masks
    source_all_boxes_all, source_boxes_all, source_image = [], [], None
    contours, _ = get_masking_contours(mask_img.resize(image.size))
    for object_name in objects:
        singulars_and_plurals = get_expanded_singular_and_plural(object_name)
        singulars_and_plurals.add(object_name)
        for object_form in singulars_and_plurals:
            try:
                source_all_boxes, source_boxes, source_image, scores = get_top_boxes(image, [object_form], bbox_score_threshold, masks=masks)
                source_all_boxes = get_intersected_bboxes(image, contours, source_all_boxes, debug=debug, bbox_covered_threshold=BBOX_COVERED_THRESHOLD, strict_intersection=True)
                source_boxes = get_intersected_bboxes(image, contours, source_boxes, debug=debug, bbox_covered_threshold=BBOX_COVERED_THRESHOLD, strict_intersection=True)
                for box in (source_all_boxes + source_boxes):
                    # Label each box with the requested object rather than the queried word form.
                    box.update({'class': object_name})
                source_all_boxes_all.extend(source_all_boxes)
                source_boxes_all.extend(source_boxes)
            except Exception as e:
                # Best effort: a failure on one word form must not abort the others.
                print('Owl failed object: ' + object_form + '. Error: ' + str(e))
    return source_all_boxes_all, source_boxes_all, source_image

# TODO: if the entire source phrase is detected, skip the per-word search;
# otherwise, retry with its separate words
def enrich_with_object_OD_details(example, object_type, debug=False):
    """Add object-detection statistics for one edited object to ``example``.

    ``object_type`` is 'source' or 'target'. When the edited object is
    missing (None, 'none' or 'nones') nothing is written. Otherwise its main
    words are detected with both the 'owlv2' and 'dino' backbones and, per
    backbone, the example receives: whether nothing was found (overall and
    in the second box list returned by the detector), the share of main
    words found, and the max / mean / min score of the detected boxes
    (0 when nothing was found).

    The example is mutated in place; nothing is returned.
    """
    edited_object = extract_source_example(example)[object_type]  # target or source
    if edited_object is None or edited_object in ['none', 'nones']:
        return

    source_main_words = get_main_words(edited_object)
    objects_to_detect = list(set(source_main_words))
    detectors = [(get_owl_boxes, 'owlv2'), (get_dino_bboxes, 'dino')]
    for box_func, backbone in detectors:
        masks_bboxes = [json.loads(bbox) for bbox in example[f'{object_type}_masked_bbox_str']]
        source_all_boxes, source_boxes, source_image = box_func(
            example[f'{object_type}_img'],
            example['mask_img'],
            objects_to_detect,
            bbox_score_threshold=0.3,
            masks=masks_bboxes,
            debug=debug,
        )
        classes_found = {box['class'].lower() for box in source_all_boxes}
        classes_in_top_boxes = {box['class'].lower() for box in source_boxes}
        nothing_found = len(classes_found) == 0

        example[f'{object_type}_not_found_threshold_{backbone}'] = len(classes_in_top_boxes) == 0
        example[f'{object_type}_not_found_{backbone}'] = nothing_found
        if example[f'{object_type}_object'].lower() in classes_found:
            # The full object name itself was detected.
            found_ratio = 1
        else:
            found_ratio = len(classes_found) / len(source_main_words)
        example[f'precent_of_{object_type}_found_{backbone}'] = found_ratio

        # Empty exactly when nothing was found, in which case all three stats are 0.
        scores = [box['score'] for box in source_all_boxes]
        example[f'{object_type}_max_score_{backbone}'] = 0 if nothing_found else max(scores)
        example[f'{object_type}_mean_score_{backbone}'] = 0 if nothing_found else np.mean(scores)
        example[f'{object_type}_min_{backbone}'] = 0 if nothing_found else min(scores)

# Checks if the source object was detected in the example data, based on detection outcomes from all backbones
def is_source_found(example):
    """Return True when the source object counts as detected.

    'Add' edits are always treated as found. Otherwise the source is found
    as soon as one backbone ('owlv2', then 'dino') did not flag it as missing.
    """
    if example['action'] == 'Add':
        return True
    return any(
        not example[f'source_not_found_{backbone}']  # found by this backbone
        for backbone in ('owlv2', 'dino')
    )

# Checks if the target object was detected in the example data, based on detection outcomes from all backbones
def is_target_found(example):
    """Return True when the target object counts as detected.

    'Remove' edits are always treated as found. Otherwise the target is found
    as soon as one backbone ('owlv2', then 'dino') did not flag it as missing.
    """
    if example['action'] == 'Remove':
        return True
    return any(
        not example[f'target_not_found_{backbone}']  # found by this backbone
        for backbone in ('owlv2', 'dino')
    )

# Function to assess if object detection has failed, based on conditions for source and target not being found, 
# and very low detection scores
DINO_THRESHOLD = 0.2
def is_edit_failure(example, debug=False):
    """Decide whether the object-detection results mark the edit as failed.

    An edit fails when the pipeline did not find the source or the target,
    or when the detection scores of the source (ignored for 'Add') or of the
    target (ignored for 'Remove') are very low.
    """
    def _scores_very_low(side, dino_min_cutoff):
        # Low owlv2 max score AND (low dino max score OR low dino min score).
        return example[f'{side}_max_score_owlv2'] < 0.2 and (
            example[f'{side}_max_score_dino'] < DINO_THRESHOLD
            or example[f'{side}_min_dino'] < dino_min_cutoff
        )

    # Cases where the pipeline missed Source (Remove, Replace, Change Attribute)
    source_missing = not example['pipeline_source_found']
    # Cases where the pipeline missed Target (Add, Replace, Change Attribute)
    target_missing = not example['pipeline_target_found']

    # Score columns are only read for the sides that the action involves.
    if example['action'] != 'Add':
        source_very_low = _scores_very_low('source', 0.37)
    else:
        source_very_low = False
    if example['action'] != 'Remove':
        target_very_low = _scores_very_low('target', 0.36)
    else:
        target_very_low = False

    if debug:
        print('Source not found', source_missing)
        print('Target not found', target_missing)
        print('Very low source', source_very_low)
        print('Very low target', target_very_low)
    return source_missing or target_missing or source_very_low or target_very_low

# Function to enrich example data by adding object detection (OD) details for both source and target 
# objects, updating fields for object detection success/failure
def enrich_with_example_OD_details(example, debug=False):
    """Add object-detection details for both edited objects to ``example``.

    Runs the per-object enrichment for 'source' and then 'target', then
    records ``pipeline_source_found`` / ``pipeline_target_found``. Any error
    is printed together with the example's instruction and the example is
    returned as it stands at that point.
    """
    try:
        for side in ('source', 'target'):
            enrich_with_object_OD_details(example, side, debug=debug)
        example['pipeline_source_found'] = is_source_found(example)
        example['pipeline_target_found'] = is_target_found(example)
    except Exception as err:
        # Best effort: report the failure and hand the example back anyway.
        print('Failed to enrich with OD details. Error: ', err)
        print(example['instruction'])
    return example

## Run Pipeline

#### Generate CSV Utils

In [None]:
def is_edit_accurate(example):
    """Return True when the edit has a real action and is not a negative example.

    In annotation data the action can be 'None' (such data has no
    "is_negative" column); ``is_negative`` is set by the augmentation. An
    example that is both action 'None' and negative trips the assertion.
    """
    action_is_none = example['action'] == 'None'
    flagged_negative = example.get('is_negative')
    assert not (action_is_none and flagged_negative)
    if action_is_none:
        return False
    return not flagged_negative

def get_artifacts_bboxes_data(example):
    """Extract class and bbox fields from the example's artifact details.

    Each entry of ``example['artifacts_details']`` is a JSON string; only its
    'class', 'detail1_bbox' and 'detail2_bbox' values are kept.

    Returns:
        A list with one dict per artifact, in the original order.
    """
    kept_fields = ('class', 'detail1_bbox', 'detail2_bbox')
    decoded_details = (json.loads(raw) for raw in example['artifacts_details'])
    return [
        {field: decoded[field] for field in kept_fields}
        for decoded in decoded_details
    ]

def get_mask_bboxes(bboxes, original_img):
    """Re-serialize each JSON-encoded bbox, keeping only its first four values.

    Args:
        bboxes: List of JSON strings, each decoding to a sequence. Any
            non-list value (e.g. a missing/NaN cell) yields an empty list.
        original_img: Unused; kept so existing callers keep working. The
            image size was previously unpacked here but never used, which
            only made the call fail when no valid image was supplied.

    Returns:
        A list of JSON strings, one per input bbox.
    """
    if not isinstance(bboxes, list):
        return []
    return [json.dumps(json.loads(bbox)[:4]) for bbox in bboxes]

def transform_data_to_csv_format(example):
    """Map an enriched example onto the flat record that gets exported.

    The caption and rich-instruction fields are emptied when the example's
    ``original_action`` is the string 'None'. Mask bboxes are exported only
    when ``target_masked_bbox_str`` is a list, otherwise both are None.

    Args:
        example: Enriched row supporting item access and ``get``.

    Returns:
        A dict with the exported fields.

    Raises:
        Exception: If building the record fails; the underlying error is
            printed and chained as ``__cause__``.
    """
    pipeline_failure = example.get('original_action') == 'None'
    pipeline_success = not pipeline_failure
    is_vald_masked_bbox = isinstance(example['target_masked_bbox_str'], list)
    try:
        # NOTE(review): 'original_cation' is looked up below while other code
        # in this file writes 'original_action' - possibly a typo; left as is
        # because the output key may be part of the exported schema. Confirm.
        return {
         "instruction": example['instruction'],
         "rich_instruction": example.get('extensive_revised_instruction') if pipeline_success else '',
         "is_edit_accurate": is_edit_accurate(example),
         "is_edit_contains_artifacts": len(example['artifacts_details']) > 0 or (example['is_secondary_artifact'] == True),
         "is_edit_reverse": example.get('is_reverse'),
         # These two keys used to be listed twice in this literal with the same value.
         "source_image_description": example['source_description'],
         'target_image_description': example['target_description'],
         "selected_source_image_description": example['source_selected_description'],
         'selected_target_image_description': example['target_selected_description'],
         "source_mask_bbox": get_mask_bboxes(example['source_masked_bbox_str'], example['source_img']) if is_vald_masked_bbox else None,
         "target_mask_bbox": get_mask_bboxes(example['target_masked_bbox_str'], example['target_img']) if is_vald_masked_bbox else None,
         "artifacts_bboxes": get_artifacts_bboxes_data(example),
         "extensive_caption": example['extensive_caption'] if pipeline_success else '',
         "artifcats_metadata": example['artifacts_details'],
         "original_instruction": example.get('original_instruction') or example['instruction'],
         "caption": example['simple_caption'] if pipeline_success else '',
         "source_object": example['source_object'],
         "target_object": example['target_object'],
         "is_secondary_artifact": example['is_secondary_artifact'],
         "action": example['action'],
         "original_cation": example.get('original_cation') or example['action'],
         "id": get_instance_id(example, example.get('original_cation')),
         "original_id": get_instance_id(example, example.get('original_cation')),
         "mask_img": example['mask_img'],
         "source_img": example['source_img'],
         "target_img": example['target_img'],
         "source_image_bbox_description": example['source_mask_bbox_description'],
         "target_bbox_image_description": example['target_mask_bbox_description']
        }
    except Exception as e:
        print(e)
        # Carry a message and the cause instead of raising an empty Exception().
        raise Exception(f'Failed to transform example to CSV format: {e}') from e

#### Run Pipeline

In [None]:
from tqdm import tqdm


def enrich_invalid_example(example):
    """Mark an example as invalid by setting its action to 'None'.

    An example is invalid when its extensive caption contains the word
    'changes', or when its action is missing (None / 'None') or contains a
    ',' or 'and'. The action seen before invalidation is kept under
    ``original_action`` unless one is already recorded.

    Returns:
        The same ``example`` object.
    """
    def _action_is_invalid():
        action = example['action']
        return action in [None, 'None'] or ',' in action or 'and' in action

    caption_mentions_changes = 'changes' in example['extensive_caption']
    if not (caption_mentions_changes or _action_is_invalid()):
        return example

    recorded_action = example.get('original_action')
    if recorded_action is None:
        recorded_action = example['action']
    example['original_action'] = recorded_action
    example['action'] = 'None'
    return example
    
# is_negative_valid
# Maps a split name to its dataset; run_pipeline looks the split up with splits.get(split).
# The 'test' entry is currently commented out.
splits = {'train': MB_train, 'dev': MB_dev} # ,  'test': MB_test
def run_pipeline(split, save_csv, annotation_csv, batch_size, batch_start, batch_end):
    """Run the full enrichment pipeline over an index range of one split.

    The range ``[batch_start, batch_end)`` is processed in sub-batches of
    ``batch_size`` rows. Each sub-batch is enriched with the edit details,
    filtered by the object-detection checks, enriched with artifact details,
    optionally augmented (reverse and negative examples), converted to the
    export format and optionally saved under ``batches/``. A failing
    sub-batch is reported and skipped.

    Args:
        split: Key into ``splits`` (e.g. 'train' or 'dev').
        save_csv: When truthy, write each sub-batch to ``batches/`` as JSON.
        annotation_csv: When truthy, skip the augmentation steps and prefix
            the saved file name with 'Annotation_'.
        batch_size: Number of rows per sub-batch; must divide the range size.
        batch_start: First index (inclusive).
        batch_end: Last index (exclusive).

    Returns:
        The frame most recently assigned while processing the sub-batches,
        or None if no sub-batch got far enough to produce one.
    """
    assert ((batch_end - batch_start) % batch_size) == 0
    # Defined up front so the final return cannot hit an unbound local when
    # every sub-batch fails before building its frame.
    balanced_dataset_df = None
    for batch_index, batch_range in tqdm(enumerate([[batch_start, batch_end]]), desc="Batch Progress"):
        # Starting and ending index of the batch range
        start_index, end_index = batch_range
        MB_set = splits.get(split)

        # Iterate over the batch range in steps of batch_size
        for sub_batch_start in range(start_index, end_index, batch_size):
            # Calculate the end index of the sub-batch, ensuring it does not exceed the batch range
            sub_batch_end = min(sub_batch_start + batch_size, end_index)
            print(f"Starting Batch {batch_index}, Sub-batch range: {sub_batch_start}-{sub_batch_end}")

            try:
                # Selected Range to Execute (For 32GB ram worked well for batches of 1000 instances.)
                selected_range = list(range(sub_batch_start, sub_batch_end))

                dataset_df = pd.DataFrame(MB_set.select(selected_range) if split!='test' else MB_set[sub_batch_start:sub_batch_end])
                # Enriching the dataset with the edit details: Source,Target, Edit Action Type, Simple Caption (difference caption), Extensive Caption (extensive difference caption), etc.
                dataset_df = dataset_df.progress_apply(lambda x: enrich_dataset(x), axis=1)

                # Sets valid False instances where the pipeline failed to extract the edit details
                dataset_df = dataset_df.progress_apply(lambda x: enrich_invalid_example(x), axis=1)

                # Filter Edits when edited objects not found or very low quality
                dataset_df_OD_details = dataset_df.progress_apply(lambda x: enrich_with_example_OD_details(x), axis=1)
                dataset_df_OD_details = dataset_df_OD_details[dataset_df_OD_details.progress_apply(lambda x: not is_edit_failure(x), axis=1)]

                # Adding artifacts details
                dataset_df_with_artifacts = dataset_df_OD_details.progress_apply(lambda x: add_artifacts_details(x), axis=1)
                dataset_df_with_artifacts = dataset_df_with_artifacts.progress_apply(lambda x: enrich_example_with_object_bboxes(x), axis=1)
                dataset_df_with_artifacts = dataset_df_with_artifacts.progress_apply(lambda x: enrich_with_artifacts_secondary(x), axis=1)

                # Keep the result of the re-validation: it was previously
                # discarded, so the filter below never saw rows that this
                # pass marked as invalid.
                dataset_df_with_artifacts = dataset_df_with_artifacts.progress_apply(lambda x: enrich_invalid_example(x), axis=1)
                dataset_df_with_artifacts = dataset_df_with_artifacts[dataset_df_with_artifacts.progress_apply(lambda x: x['action'] != 'None', axis=1)]
                if not annotation_csv:
                    # Reverse the edit - table change to dog -> dog change to table (2x rows) - gpt 3
                    balanced_dataset_df = get_augmentaion_with_reverse_rows(dataset_df_with_artifacts).reset_index(drop=True)

                    ### Augmentation ###
                    # Negative exampled - gpt3 + gpt4 (for extracting source and target)
                    balanced_dataset_df = balanced_dataset_df.progress_apply(lambda x: get_negative_edit(x), axis=1)
                    balanced_dataset_df = get_balanced_dataframe(balanced_dataset_df).reset_index(drop=True) # seperate the rows to negative and positive (2x Number of rows)
                else:
                    balanced_dataset_df = dataset_df_with_artifacts

                ### Transform to CSV format and save ###
                balanced_dataset_df = pd.DataFrame(list(balanced_dataset_df.progress_apply(lambda x: transform_data_to_csv_format(x), axis=1)))
                balanced_dataset_df['split'] = split
                if save_csv:
                    csv_name = f'MaNisthana_batch_{split}_{sub_batch_start}_{sub_batch_end}.json'
                    if annotation_csv:
                        csv_name = 'Annotation_' + csv_name
                    # Make sure the output folder exists so a finished
                    # sub-batch is not lost on a missing directory.
                    os.makedirs('batches', exist_ok=True)
                    balanced_dataset_df.to_json(r'batches/'+csv_name)

            except Exception as e:
                # Best effort: report the failing sub-batch and move on.
                print(e)
                traceback.print_exc()
                print(f"Failed in Batch {batch_index}, Sub-batch range: {sub_batch_start}-{sub_batch_end}")
    return balanced_dataset_df
# NOTE(review): LOCAL_RUN is not read anywhere in this cell - presumably consumed elsewhere; confirm.
LOCAL_RUN = False
# Runs the pipeline on the single 'dev' example at index 0, with augmentation
# (annotation_csv=False), and saves the result under batches/.
balanced_dataset_df = run_pipeline(split='dev', save_csv=True, annotation_csv=False, batch_size=1, batch_start=0, batch_end=1)
