In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.io import read_image, ImageReadMode
from PIL import Image, ImageDraw, ImageFilter
import os
import json
import pickle
import numpy as np

class RefCocoG_Dataset(Dataset):
    """
    RefCOCOg referring-expression dataset.

    Joins the pickled refs file (``annotations_f``) with the COCO-style instances
    JSON (``instances_f``), both read from ``<root_dir>/annotations``, and keeps
    only the entries whose ``split`` equals ``split``. Images are read from
    ``<root_dir>/images``.

    Each item is a list ``[image, bounding_box, prompts]``:
      * ``image``: PIL RGB image, passed through ``transform`` when given;
      * ``bounding_box``: ``[x1, y1, x2, y2]`` converted from the annotation's
        ``[x, y, width, height]`` ``bbox``, passed through ``target_transform``
        when given;
      * ``prompts``: list of the referring sentences (``'sent'`` fields).

    The merged annotations of all splits are cached on the class
    (``full_annotations``) so that building the train/val/test datasets parses
    the annotation files only once. The cache is reused only when the same root
    directory and file names are requested.
    """

    # Merged annotations of every split, keyed by ann_id; shared by all instances.
    full_annotations = None
    # (abs root_dir, annotations_f, instances_f) that `full_annotations` was built from.
    _full_annotations_source = None

    def __init__(self, root_dir, annotations_f, instances_f, split='train', transform=None, target_transform=None) -> None:
        super().__init__()

        self.root_dir = root_dir
        self.annotations_f = annotations_f
        self.instances_f = instances_f

        self.split = split

        self.transform = transform
        self.target_transform = target_transform

        self.get_annotations()

        # Fixed sample order: dataset index `idx` <-> self.ann_ids[idx]. Cached once so
        # __getitem__ does not rebuild the key list on every call.
        self.ann_ids = list(self.annotations)
        self.image_names = [
            self.annotations[ann_id]['image']['actual_file_name']
            for ann_id in self.ann_ids
        ]

    def _annotations_source(self):
        """Key identifying the annotation files this dataset reads."""
        return (os.path.abspath(self.root_dir), self.annotations_f, self.instances_f)

    def _filter_split(self, annotations):
        """Return the entries of `annotations` whose image belongs to `self.split`."""
        return {
            ann_id: entry
            for ann_id, entry in annotations.items()
            if entry['image']['split'] == self.split
        }

    def get_annotations(self):
        """
        Populate ``self.annotations`` with this split's entries.

        Each entry is ``{'image': <ref dict from the pickle, extended with
        'actual_file_name'>, 'annotation': <matching instances annotation>}``.
        Reuses the class-level cache when it was built from the same files;
        otherwise parses both files and refreshes the cache.
        """
        source = self._annotations_source()
        if (RefCocoG_Dataset.full_annotations is not None
                and RefCocoG_Dataset._full_annotations_source == source):
            self.annotations = self._filter_split(RefCocoG_Dataset.full_annotations)
            return

        annotations_dir = os.path.join(self.root_dir, 'annotations')

        # NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
        with open(os.path.join(annotations_dir, self.annotations_f), 'rb') as file:
            self.data = pickle.load(file)

        with open(os.path.join(annotations_dir, self.instances_f), 'rb') as file:
            self.instances = json.load(file)

        # COCO image id -> image file name on disk.
        images_actual_file_names = {
            image['id']: image['file_name'] for image in self.instances['images']
        }

        merged = {}
        for ref in self.data:
            entry = merged.setdefault(ref['ann_id'], {})
            entry['image'] = ref
            # Stored on the ref dict itself (mutates the loaded pickle data).
            ref['actual_file_name'] = images_actual_file_names[ref['image_id']]

        # Attach the instance annotation (bbox etc.) to every referenced ann_id.
        for annotation in self.instances['annotations']:
            if annotation['id'] in merged:
                merged[annotation['id']]['annotation'] = annotation

        RefCocoG_Dataset.full_annotations = merged
        RefCocoG_Dataset._full_annotations_source = source
        self.annotations = self._filter_split(merged)

    def __len__(self):
        """Number of referring-expression entries in this split."""
        return len(self.image_names)

    def corner_size_to_corners(self, bounding_box):
        """
        Transform (top_left_x, top_left_y, width, height) bounding box representation
        into (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
        """

        return [
            bounding_box[0],
            bounding_box[1],
            bounding_box[0] + bounding_box[2],
            bounding_box[1] + bounding_box[3]
        ]

    def __getitem__(self, idx):
        """Return ``[image, bounding_box, prompts]`` for the sample at `idx`."""
        entry = self.annotations[self.ann_ids[idx]]

        # Open inside a context manager so the underlying file handle is released;
        # convert() returns a fully loaded, independent image.
        image_path = os.path.join(self.root_dir, 'images', self.image_names[idx])
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        prompts = [sentence['sent'] for sentence in entry['image']['sentences']]

        bounding_box = self.corner_size_to_corners(entry['annotation']['bbox'])

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            bounding_box = self.target_transform(bounding_box)

        return [
            image,
            bounding_box,
            prompts,
        ]

In [None]:
# One dataset per RefCOCOg (UMD) split. The first construction parses the
# annotation files; the following ones reuse the class-level cache.
REFCOCOG_ROOT = 'refcocog'
REFCOCOG_REFS_FILE = 'refs(umd).p'
REFCOCOG_INSTANCES_FILE = 'instances.json'

dataset_train, dataset_val, dataset_test = [
    RefCocoG_Dataset(REFCOCOG_ROOT, REFCOCOG_REFS_FILE, REFCOCOG_INSTANCES_FILE, split=split_name)
    for split_name in ('train', 'val', 'test')
]

# Ordered as get_data expects: [train, val, test].
dataset_splits = [dataset_train, dataset_val, dataset_test]

In [None]:
# Sizes: all merged annotations, then the train / val / test subsets.
(len(RefCocoG_Dataset.full_annotations), *(len(split.annotations) for split in (dataset_train, dataset_val, dataset_test)))

In [None]:
def collate_differently_sized_prompts(batch):
    """
    Collate function that keeps every field of the batch as a plain list.

    Samples carry a varying number of prompts (and, without a transform, PIL
    images), which the default collation cannot stack; each field is therefore
    gathered into its own list instead.

    Returns:
        (images, bboxes, prompts): three lists with one entry per sample.
    """
    images, bboxes, prompts = ([sample[field] for sample in batch] for field in range(3))
    return images, bboxes, prompts

def get_data(dataset_splits, batch_size=64, test_batch_size=256, num_workers=0):
    """
    Build the train / validation / test DataLoaders.

    Args:
        dataset_splits: sequence whose first three items are the train,
            validation and test datasets, in that order.
        batch_size: batch size of the training loader.
        test_batch_size: batch size of the validation and test loaders.
        num_workers: worker processes used by every loader.

    Returns:
        (train_loader, val_loader, test_loader). The training loader shuffles and
        drops the last incomplete batch; the other two keep dataset order. All use
        collate_differently_sized_prompts.
    """
    shared_options = {
        'collate_fn': collate_differently_sized_prompts,
        'num_workers': num_workers,
    }

    train_split, val_split, test_split = dataset_splits[0], dataset_splits[1], dataset_splits[2]

    return (
        torch.utils.data.DataLoader(train_split, batch_size=batch_size, shuffle=True, drop_last=True, **shared_options),
        torch.utils.data.DataLoader(val_split, batch_size=test_batch_size, shuffle=False, **shared_options),
        torch.utils.data.DataLoader(test_split, batch_size=test_batch_size, shuffle=False, **shared_options),
    )

In [None]:
train_loader, val_loader, test_loader = get_data(
    dataset_splits,
    batch_size=64,        # training batches
    test_batch_size=64,   # validation / test batches
    num_workers=0,        # load data in the main process
)

In [None]:
# Device for the single-device parts of the notebook: the first GPU if any, else CPU.
# Always a torch.device (the CPU case used to be a plain string), so code can
# rely on `.type` / `.index` regardless of the backend.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# One pretrained YOLOv5-small detector per GPU (CirclesModel picks the copy that
# matches its replica's device); a single copy on `device` when there is no GPU.
yolo_devices = (
    [f'cuda:{gpu}' for gpu in range(torch.cuda.device_count())]
    if torch.cuda.is_available()
    else [device]
)
yolo_models = [
    torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True).to(target)
    for target in yolo_devices
]

In [None]:
import clip

# CLIP backbone in use. Other backbones tried: 'ViT-B/32', 'ViT-L/14',
# 'RN50x16', 'RN50x64', 'RN101'.
clip_backbone = 'ViT-L/14@336px'

# One CLIP model (with its matching preprocessing transform) per GPU, or a
# single copy on `device` when no GPU is available. models[i] pairs with
# preprocesses[i].
clip_devices = (
    [f'cuda:{gpu}' for gpu in range(torch.cuda.device_count())]
    if torch.cuda.is_available()
    else [device]
)

models, preprocesses = [], []
for clip_device in clip_devices:
    model, preprocess = clip.load(clip_backbone, device=clip_device)
    models.append(model)
    preprocesses.append(preprocess)

In [None]:
# Sanity check: device holding the first CLIP model's weights.
next(iter(models[0].parameters())).device

In [None]:
import torch

def cosine_similarity(a: torch.Tensor, b: torch.Tensor):
    """
    Pairwise cosine similarity between the rows of `b` and the rows of `a`.

    Both tensors are divided by their L2 norm along the last dimension, and the
    result is ``b_normalized @ a_normalized.T``, so entry ``[i, j]`` compares
    ``b[i]`` with ``a[j]``. Zero-norm rows yield NaN (no epsilon is added).

    The result is always returned on the CPU.
    """

    def unit_rows(t: torch.Tensor) -> torch.Tensor:
        return t / t.norm(dim=-1, keepdim=True)

    return (unit_rows(b) @ unit_rows(a).T).cpu()

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import matplotlib.patches as patches

def visualise_scores(scores: torch.Tensor, images, texts: list[str]):
    """
    Show every image once per text prompt, titled with the matching score.

    Args:
        scores: indexed as ``scores[text_index, image_index]``.
        images: images displayable with ``Axes.imshow`` (e.g. PIL images).
        texts: prompts, in the order of the first axis of `scores`.

    A new figure is created for every (text, image) pair.
    """
    for text_index, prompt_text in enumerate(texts):
        for image_index, shown_image in enumerate(images):
            _, axis = plt.subplots()
            axis.imshow(shown_image)
            axis.set_title(f'Score: {scores[text_index, image_index]} / Prompt: {prompt_text}')

In [None]:
# YOLO class id -> class name, and the CLIP prompt built for each class.
classes = dict(yolo_models[0].names.items())
class_prompts = {class_id: f'A photo of a {name}' for class_id, name in classes.items()}

with torch.no_grad():
    # One tokenized prompt per class, in class-id order.
    prompts_tensor = clip.tokenize(class_prompts.values()).to(device)

    # CLIP text embedding of every class prompt (one row per class; the width
    # depends on the CLIP backbone), L2-normalised so dot products are cosines.
    class_prompts_embeddings = models[0].encode_text(prompts_tensor)
    class_prompts_embeddings = class_prompts_embeddings / class_prompts_embeddings.norm(dim=-1, keepdim=True)
    class_prompts_embeddings = class_prompts_embeddings.to(device)

In [None]:
import torch
import torch.nn as nn
import clip
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches

class CirclesModel(nn.Module):
    """
    Zero-shot referring-expression grounding with YOLO proposals and CLIP.

    For every image, YOLO proposes bounding boxes. Each proposal becomes a
    full-size copy of the image in which that proposal is visually highlighted
    (by default: red ellipse, darkened and slightly blurred surroundings). CLIP
    scores every highlighted copy against every prompt. That score is multiplied
    by the similarity between the prompt and the class embedding of the
    proposal's YOLO class, and the best proposal's box is returned per prompt.

    ``models``, ``preprocesses`` and ``yolo_models`` are indexed by CUDA device
    ordinal (index 0 on non-CUDA devices), so the module can be wrapped in
    ``torch.nn.DataParallel``. ``class_embeddings`` is indexed by YOLO class id
    (presumably one normalised CLIP text embedding per class — confirm at caller).
    """

    def __init__(self, device=None, models=None, preprocesses=None, yolo_models=None, classes=None, class_embeddings=None) -> None:
        super().__init__()

        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

        if not models or not preprocesses:
            raise ValueError('Models and preprocesses for CLIP model should be provided')

        self.models = models
        self.preprocesses = preprocesses

        if not yolo_models:
            raise ValueError('Models for YOLO should be provided')
        self.yolo_models = yolo_models

        if classes is None or class_embeddings is None:
            raise ValueError('Classes and class embeddings should be provided')
        self.classes = classes
        self.class_embeddings = class_embeddings

        self.transform_to_tensor = T.Compose([
            transforms.ToTensor()
        ])

    @staticmethod
    def _device_index(device):
        """
        Replica index for tensors living on `device`.

        CUDA devices map to their ordinal (``cuda:12`` -> 12); any other device
        uses replica 0.
        """
        device = torch.device(device)
        if device.type == 'cuda' and device.index is not None:
            return device.index
        return 0

    @staticmethod
    def _boxes_for_prompts(detections_xyxy, max_indices, device):
        """
        Pick, for every prompt, the [x1, y1, x2, y2] box of its selected detection.

        Args:
            detections_xyxy: (n_detections, >=4) tensor whose first four columns
                are x1, y1, x2, y2.
            max_indices: iterable with one detection index per prompt.
            device: device of the returned tensors.

        Returns:
            List with one 4-element tensor per prompt. Without detections, every
            prompt gets an all-zero box with the dtype of `detections_xyxy`.
        """
        if detections_xyxy.shape[0] == 0:
            return [torch.zeros(4, dtype=detections_xyxy.dtype, device=device) for _ in max_indices]
        boxes = detections_xyxy[:, 0:4].to(device)
        return [boxes[int(max_idx)].clone() for max_idx in max_indices]

    def forward(self, indices, images, prompts_list):
        """
        Predict one bounding box per prompt.

        Args:
            indices: 1-D integer tensor of positions in `images` / `prompts_list`
                to process. Its device selects the CLIP/YOLO replica (this is the
                tensor DataParallel splits across GPUs).
            images: list of PIL images (kept on the CPU; YOLO consumes them as is).
            prompts_list: list, one per image, of lists of prompt strings.

        Returns:
            Tensor of shape (total number of prompts, 4): one [x1, y1, x2, y2] box
            per prompt, samples in `indices` order. Images without any YOLO
            detection yield all-zero boxes.
        """
        self.device = indices.device
        self.device_index = self._device_index(self.device)

        # Images remain PIL images: converting them to tensors breaks YOLO.
        images = [images[i] for i in indices]
        prompts_list = [prompts_list[i] for i in indices]
        # Other prompt rewritings available: update_prompts, update_prompts_with_circles.
        prompts_list = self.update_prompts_with_this_is(prompts_list)
        prompts_tensor = [clip.tokenize(prompts).to(self.device) for prompts in prompts_list]

        detections = self.get_bounding_boxes(images)

        # Loop-invariant lookups, hoisted out of the per-sample loop.
        clip_model = self.models[self.device_index]
        preprocess = self.preprocesses[self.device_index]
        class_embeddings = self.class_embeddings.to(self.device)

        overall_outputs = []

        with torch.no_grad():
            for idx, prompts_tokens in enumerate(prompts_tensor):
                sample_detections = detections.pred[idx]
                n_prompts = prompts_tokens.shape[0]

                if len(sample_detections) == 0:
                    # Nothing to choose from: all-zero box for every prompt.
                    overall_outputs.extend(
                        self._boxes_for_prompts(detections.xyxy[idx], range(n_prompts), self.device)
                    )
                    continue

                # One full-size highlighted copy of the image per detection
                # (get_cropped_bounding_boxes is the crop + colour-overlay alternative).
                highlighted = self.get_highlighted_bounding_boxes(images[idx], sample_detections)
                crops_batch = torch.stack([preprocess(image).to(self.device) for image in highlighted])

                crop_features = clip_model.encode_image(crops_batch)
                crop_features /= crop_features.norm(dim=-1, keepdim=True)

                # No explicit normalisation needed: cosine_similarity normalises.
                text_features = clip_model.encode_text(prompts_tokens)

                # (n_prompts, n_classes): similarity of each prompt to each class embedding.
                text_similarity = cosine_similarity(class_embeddings, text_features).float()

                # (n_prompts, n_detections): weight of each detection for each prompt,
                # i.e. the prompt's similarity to the detection's class (last column).
                class_ids = sample_detections[:, -1].long().cpu()
                weights_for_crops = text_similarity[:, class_ids].to(self.device)

                similarity = cosine_similarity(crop_features, text_features).float().to(self.device)
                similarity *= weights_for_crops

                texts_p = (100 * similarity).softmax(dim=-1)
                _, max_indices = texts_p.max(dim=1)

                overall_outputs.extend(
                    self._boxes_for_prompts(detections.xyxy[idx], max_indices, self.device)
                )

        if not overall_outputs:
            return torch.zeros((0, 4), device=self.device)
        return torch.stack(overall_outputs)

    def get_prompts(self, sample):
        """List of the referring sentences of one annotation entry."""
        return [prompt['sent'] for prompt in sample['image']['sentences']]

    def update_prompts(self, prompts):
        """
        Append ' with a red overlay' to prompts containing 'left', otherwise
        ' with a green overlay' to prompts containing 'right'; others unchanged.
        """
        def with_overlay(prompt):
            if 'left' in prompt:
                return prompt + ' with a red overlay'
            if 'right' in prompt:
                return prompt + ' with a green overlay'
            return prompt

        return [[with_overlay(prompt) for prompt in sample] for sample in prompts]

    def update_prompts_with_circles(self, prompts):
        """Append ' that is highlighted' to every prompt."""
        return [[prompt + ' that is highlighted' for prompt in sample] for sample in prompts]

    def update_prompts_with_this_is(self, prompts):
        """Prefix every prompt with 'This is '."""
        return [['This is ' + prompt for prompt in sample] for sample in prompts]

    def get_bounding_boxes(self, pil_images):
        """Run this replica's YOLO model on a list of PIL images."""
        return self.yolo_models[self.device_index](pil_images)

    @staticmethod
    def _bbox_coords(bounding_box):
        """(x1, y1, x2, y2) as Python scalars from the first four entries of a box tensor."""
        return tuple(bounding_box[i].item() for i in range(4))

    @staticmethod
    def _shape_mask(size, coords, shape='ellipse'):
        """'L' mask of `size`: 255 inside the ellipse/rectangle given by `coords`, 0 elsewhere."""
        mask = Image.new('L', size, 0)
        draw = ImageDraw.Draw(mask)
        if shape == 'rectangle':
            draw.rectangle(coords, fill=255)
        else:
            draw.ellipse(coords, fill=255)
        return mask

    @staticmethod
    def _darken_outside(base_image, image_alpha, shape_mask, alpha, blur_radius):
        """
        Alpha-composite a darkening layer over `base_image` (RGBA).

        The layer holds the pixels of `image_alpha` inside `shape_mask` and black
        elsewhere; its opacity is ``alpha * (255 - blurred mask)``, i.e. `alpha`
        outside the shape fading to 0 inside it.
        """
        layer = Image.new('RGB', image_alpha.size, (0, 0, 0))
        layer.paste(image_alpha.convert('RGB'), mask=shape_mask)
        blurred_mask = shape_mask.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        layer.putalpha(blurred_mask.point(lambda x: alpha * (255 - x)))
        return Image.alpha_composite(base_image, layer.convert('RGBA'))

    def draw_circle(self, image_alpha, bounding_box):
        """Overlay a 2-px red ellipse inscribed in `bounding_box` on an RGBA image."""
        overlay = Image.new('RGBA', image_alpha.size, (0, 0, 0, 0))
        ImageDraw.Draw(overlay).ellipse(self._bbox_coords(bounding_box), outline='red', width=2)
        return Image.alpha_composite(image_alpha, overlay)

    def draw_circle_and_darken(self, image_alpha, bounding_box):
        """Darken (alpha 0.15) outside the inscribed ellipse and outline it in red (4 px)."""
        coords = self._bbox_coords(bounding_box)
        circle_mask = self._shape_mask(image_alpha.size, coords, 'ellipse')
        darkened_image = self._darken_outside(image_alpha, image_alpha, circle_mask, alpha=0.15, blur_radius=10)
        ImageDraw.Draw(darkened_image).ellipse(coords, outline='red', width=4)
        return darkened_image

    def draw_circle_small_and_darken(self, image_alpha, bounding_box):
        """Darken (alpha 0.4) outside the inscribed ellipse and outline it in red (3 px)."""
        coords = self._bbox_coords(bounding_box)
        circle_mask = self._shape_mask(image_alpha.size, coords, 'ellipse')
        darkened_image = self._darken_outside(image_alpha, image_alpha, circle_mask, alpha=0.4, blur_radius=10)
        ImageDraw.Draw(darkened_image).ellipse(coords, outline='red', width=3)
        return darkened_image

    def draw_circle_small_and_blur_and_darken(self, image_alpha, bounding_box):
        """
        Blur (radius 1.5) and darken (alpha 0.3) outside the inscribed ellipse,
        keep the original pixels inside it, and outline it in red (2 px).
        """
        coords = self._bbox_coords(bounding_box)
        circle_mask = self._shape_mask(image_alpha.size, coords, 'ellipse')
        blurred_background = image_alpha.filter(ImageFilter.GaussianBlur(radius=1.5))
        highlighted = self._darken_outside(blurred_background, image_alpha, circle_mask, alpha=0.3, blur_radius=10)
        # Restore the sharp, undarkened pixels inside the ellipse.
        highlighted.paste(image_alpha.convert('RGB'), mask=circle_mask)
        ImageDraw.Draw(highlighted).ellipse(coords, outline='red', width=2)
        return highlighted

    def draw_rectangle_and_darken(self, image_alpha, bounding_box):
        """Darken (alpha 0.15, blur 15) outside the box and outline it in red (2 px)."""
        coords = self._bbox_coords(bounding_box)
        box_mask = self._shape_mask(image_alpha.size, coords, 'rectangle')
        darkened_image = self._darken_outside(image_alpha, image_alpha, box_mask, alpha=0.15, blur_radius=15)
        ImageDraw.Draw(darkened_image).rectangle(coords, outline='red', width=2)
        return darkened_image

    def get_highlighted_bounding_boxes(self, image, bounding_boxes):
        """
        One full-size RGBA copy of `image` per bounding box, highlighting that box.

        Bounding boxes in the form:
        [top left x, top left y, bottom right x, bottom right y, confidence, category]

        Returns ``[image]`` when there are no boxes.
        """
        image_alpha = image.convert('RGBA')
        # Alternatives: draw_circle, draw_circle_and_darken,
        # draw_circle_small_and_darken, draw_rectangle_and_darken.
        highlighted = [
            self.draw_circle_small_and_blur_and_darken(image_alpha, bounding_box)
            for bounding_box in bounding_boxes
        ]
        return highlighted or [image]

    def get_cropped_bounding_boxes(self, image, bounding_boxes):
        """
        Crop every bounding box and tint it by horizontal position.

        Bounding boxes in the form:
        [top left x, top left y, bottom right x, bottom right y, confidence, category]

        Crops whose centre lies left of the image middle get ``overlay_colors[0]``,
        right of it ``overlay_colors[1]``, exactly centred ``overlay_colors[-1]``.
        NOTE: ``overlay_colors`` is read from the notebook's module-level global.
        Returns ``[image]`` when there are no boxes.
        """
        cropped_bounding_boxes = []
        image_width = image.size[0]

        for bounding_box in bounding_boxes:
            x1, y1, x2, y2 = self._bbox_coords(bounding_box)
            cropped_img = image.crop((x1, y1, x2, y2))

            # Normalised horizontal centroid: (min + (max - min) / 2) / width.
            centroid_x = (x1 + (x2 - x1) / 2) / image_width

            if centroid_x < 0.5:
                color = overlay_colors[0]
            elif centroid_x > 0.5:
                color = overlay_colors[1]
            else:
                color = overlay_colors[-1]

            overlay = Image.new('RGBA', cropped_img.size, color)
            cropped_bounding_boxes.append(Image.alpha_composite(cropped_img.convert('RGBA'), overlay))

        return cropped_bounding_boxes or [image]

# RGBA tints read (as a global) by CirclesModel.get_cropped_bounding_boxes:
# index 0 for crops left of the image centre, 1 for crops right of it, -1 for
# exactly centred crops. Alpha 128 is roughly 50% opacity.
overlay_colors = [
    (255, 0, 0, 128),   # Red
    (0, 255, 0, 128),   # Green
    (0, 0, 255, 128),   # Blue (not selected by the left/right/centre rule)
    (0, 0, 0, 0),       # Transparent
]

circles_model = CirclesModel(
    models=models,
    preprocesses=preprocesses,
    yolo_models=yolo_models,
    classes=classes,
    class_embeddings=class_prompts_embeddings,
)

# With several GPUs, DataParallel splits `indices` across devices and each
# replica uses the CLIP/YOLO copy on its own GPU.
if torch.cuda.device_count() > 1:
    circles_model = torch.nn.DataParallel(circles_model)

In [None]:
from torchvision.ops import box_iou

def iou_metric(bounding_boxes, ground_truth_bounding_boxes):
    """
    Localization Accuracy Metric

    Intersection over Union (IoU) between each predicted box and one ground truth.

    Args:
        bounding_boxes: (N, 4) tensor of predicted [x1, y1, x2, y2] boxes.
        ground_truth_bounding_boxes: a single [x1, y1, x2, y2] box (list or tensor).

    Returns:
        (N, 1) tensor with the IoU of every predicted box against the ground truth.

    The ground truth is placed on the predictions' own device (not a global one),
    and integer predictions are promoted to float so the IoU is not truncated.
    """
    if not bounding_boxes.is_floating_point():
        bounding_boxes = bounding_boxes.float()

    ground_truth = torch.as_tensor(
        ground_truth_bounding_boxes, dtype=bounding_boxes.dtype, device=bounding_boxes.device
    ).reshape(-1, 4)

    return box_iou(bounding_boxes, ground_truth)

def cosine_similarity_metric(bounding_boxes, ground_truth_bounding_boxes):
    """
    Cosine Similarity Metric

    Cosine similarity between each predicted box and the ground-truth box, each
    [x1, y1, x2, y2] box being treated as a 4-D vector.

    Args:
        bounding_boxes: (N, 4) tensor of predicted boxes.
        ground_truth_bounding_boxes: a single [x1, y1, x2, y2] box (list or tensor).

    Returns:
        (N,) CPU tensor of similarities; all-zero predicted boxes give NaN.

    Integer predictions are promoted to float (norms are undefined for integer
    tensors), and the ground truth is placed on the predictions' device.
    """
    if not bounding_boxes.is_floating_point():
        bounding_boxes = bounding_boxes.float()

    ground_truth = torch.as_tensor(
        ground_truth_bounding_boxes, dtype=bounding_boxes.dtype, device=bounding_boxes.device
    )

    return cosine_similarity(bounding_boxes, ground_truth)

### Tests

In [None]:
# Index of the example inside the first test batch. Catalogue of examples:
#   17: man on the beach with frisbee
#   20: motorbikes
#   22: cows on a beach
#   25: three oranges and a banana
#   32: guy with a horse and two busses
#   33: luggage
#   34: man on bed in front of a window
#   35: two zebras
#   36: two horses
#   38: chairs around a table with some sweets on top
#   39: two monitors
#   42: folks playing wii
#   43: yellow vehicle and surfboard
#   44: two women playing tennis
#   45: woman with a thing of bananas
#   46: industrial kitchen stove
#   47: two guys, one has a beard
#   49: girl eating pizza
#   50: vertical fork
#   51: sandwiches
#   54: computer on the right
#   57: woman playing tennis
idx = 22 # cows on a beach

# Load the first test batch once (test_loader does not shuffle, so this is
# always the same batch) instead of re-reading it for every field.
batch_images, batch_bboxes, batch_prompts = next(iter(test_loader))
img = batch_images[idx]
bbox_gt = batch_bboxes[idx]
prompt = batch_prompts[idx]

img

In [None]:
# Referring expressions and ground-truth [x1, y1, x2, y2] box of the selected sample.
(prompt, bbox_gt)

In [None]:
# Override the dataset prompts of the selected example with a hand-written one.
# Prompts tried for other examples (uncomment the line matching `idx`):
#
# prompt = ['the man on the right with a red overlay', 'the man with a blue shirt']  # idx == 17
#     (also tried adding: 'a photo of a man who is about to throw a frisbee')
# prompt = ['the red motorcycle with a blue overlay']  # idx == 20
# prompt = ['a red & black color bike in ftont of the three guys']  # idx == 20
# prompt = ['the orange closest to the banana', 'orange with a green overlay',
#           'orange between other oranges and a banana']  # idx == 25
#     (also tried adding: 'A photo of a orange', 'A photo of a dining table', 'A photo of a banana')
# prompt = ['the orange closest to the banana with a red overlay']  # idx == 25
# prompt = ['the man with glasses']  # idx == 32
# prompt = ['near zebra with a red overlay', 'zebra eating grass with a red overlay']  # idx == 35
# prompt = ['a man with beard wearing blue shirt with his friend', 'a man with a beard']  # idx == 47
# prompt = ['the right computer in the right hand picture with a green overlay',
#           'the computer on the right in the right hand picture with a green overlay']  # idx == 54
# prompt = ['the woman on the right']  # idx == 57
#     (previously: ['the woman on the right with a green overlay',
#                   'the girl with the racket in the photo on the right with a green overlay'])

prompt = ['the smaller animal']  # idx == 22

In [None]:
# Run YOLOv5 directly on the PIL image and inspect the raw detections: one row
# per box, [x1, y1, x2, y2, confidence, class]. (Feeding a resized 224x224
# center-cropped tensor instead was tried and dropped.)
res = yolo_models[0](img)
res.pred[0].cpu().numpy()

In [None]:
# Draw a circle onto the image
from PIL import Image, ImageDraw, ImageFilter

# Preview the "circle + darken + blur" highlighting on every detection.
for detection in res.pred[0]:
    bbox = detection[0:4].cpu().numpy()

    # 255 inside the ellipse inscribed in the box, 0 elsewhere.
    circle_mask = Image.new('L', img.size, 0)
    ImageDraw.Draw(circle_mask).ellipse(bbox, fill=255)

    # Black layer showing the original pixels only inside the ellipse; its
    # opacity is alpha * (255 - blurred mask), i.e. fading out inside the ellipse.
    alpha = 0.2
    darkened_image = Image.new('RGB', img.size, (0, 0, 0))
    darkened_image.paste(img, mask=circle_mask)
    fade = circle_mask.filter(ImageFilter.GaussianBlur(radius=10))
    darkened_image.putalpha(fade.point(lambda x: alpha * (255 - x)))

    # Composite over a blurred copy, then restore sharp pixels inside the ellipse.
    blurred_background = img.convert('RGBA').filter(ImageFilter.GaussianBlur(radius=2))
    darkened_image = Image.alpha_composite(blurred_background, darkened_image.convert('RGBA'))
    darkened_image.paste(img, mask=circle_mask)

    ImageDraw.Draw(darkened_image).ellipse(bbox, outline='red', width=4)

    fig, ax = plt.subplots()
    ax.imshow(darkened_image)

In [None]:
# Run the model on the single selected example. `indices` must be a 1-D index
# tensor ([0]); the previous `torch.tensor([range(1)])` built a (1, 1) tensor.
# Its device selects which CLIP/YOLO replica is used.
indices = torch.arange(1, device=device)
outputs = circles_model(indices, [img], [prompt])

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Draw the ground truth (first, thicker line) and one predicted box per prompt.
fig, ax = plt.subplots()
ax.imshow(img)

colors = ['r', 'b', 'g', 'yellow', 'orange']

# Both on the CPU in float32: only used for plotting, and this avoids device or
# dtype mismatches between the list ground truth and the model outputs.
gt_and_outputs = torch.cat([
    torch.as_tensor(bbox_gt, dtype=torch.float32).reshape(1, 4),
    outputs.detach().float().cpu(),
])

for bbox_idx, box in enumerate(gt_and_outputs):
    x1, y1, x2, y2 = box.tolist()
    # Cycle colours so no box is silently dropped when there are > 4 prompts.
    color = colors[bbox_idx % len(colors)]

    # Parameters: (x, y), width, height
    rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2 if bbox_idx == 0 else 1, edgecolor=color, facecolor='none')
    ax.add_patch(rect)

ax.set_title(prompt)

In [None]:
# YOLO class id -> CLIP prompt used for its class embedding.
dict(class_prompts)

In [None]:
# CLIP text embeddings of the selected prompt(s), computed without autograd.
with torch.no_grad():
    prompts_tensor = clip.tokenize(prompt).to(device)
    prompts_features = models[0].encode_text(prompts_tensor)

In [None]:
import matplotlib.pyplot as plt

# Compare the input-prompt embedding with every class-prompt embedding using the
# notebook's `cosine_similarity` helper (defined earlier).
cos_sim_prompt_categories = cosine_similarity(class_prompts_embeddings, prompts_features).float()

# Scale by 100 before the softmax to sharpen the distribution over the last dimension.
prompt_categories_p = (100 * cos_sim_prompt_categories).softmax(dim=-1)
print(prompt_categories_p.shape)

max_value, max_index = cos_sim_prompt_categories.max(dim=-1)
print(max_value, max_index)

# matplotlib cannot read CUDA tensors or tensors that require grad, so plot a
# detached CPU copy.
fig, ax = plt.subplots()
heatmap = ax.imshow(prompt_categories_p.detach().cpu())
fig.colorbar(heatmap, ax=ax)
ax.set_title('Prompt vs. class-prompt probabilities');

In [None]:
# Re-encode the current prompt with the first CLIP model. Wrapped in no_grad like
# the earlier encoding cell: the features are only inspected below, so building an
# autograd graph would just waste memory.
prompts_tensor = clip.tokenize(prompt).to(device)
with torch.no_grad():
    text_features = models[0].encode_text(prompts_tensor)

Prendiamo tutte le classi di YOLO, le trasformiamo in "A photo of {category}" e ne calcoliamo l'embedding. Poi, quando dobbiamo valutare un bounding box, lo pesiamo per la cosine similarity tra il prompt dato in input e il prompt "A photo of {category}", in cui "category" è la label che YOLO assegna al bounding box.

In [None]:
# Cosine similarity of the prompt embedding(s) with themselves — a quick sanity
# check of the `cosine_similarity` helper (self-similarity should be ~1 if the
# helper normalizes its inputs — TODO confirm).
cosine_similarity(text_features, text_features)

### To compute average cosine similarity between embeddings

In [None]:
# Run the model over the whole test set and keep every batch's outputs.
# The next cell reads each element with .item(), so each is expected to be a
# single-value tensor (a cosine similarity).
overall_outputs = []

# Evaluation only: no_grad keeps autograd graphs from being retained across batches.
with torch.no_grad():
    for batch_idx, (images, gt_bounding_boxes, prompts) in enumerate(test_loader):
        print(f'-- Batch index: {batch_idx} --')

        # One index per image in the batch, as circles_model expects.
        indices = torch.arange(len(images), device=device)
        outputs = circles_model(indices, images, prompts)

        overall_outputs.append(outputs)

In [None]:
# Flatten the per-batch outputs into one NumPy array of Python scalars and
# average it, skipping NaN entries.
cos_sim_cpu = np.array([
    value.item()
    for batch_outputs in overall_outputs
    for value in batch_outputs
])
np.nanmean(cos_sim_cpu)

### To compute standard metrics

In [None]:
from torchvision.ops import boxes as box_ops

# Evaluate the model on the test set. For every prompt, the predicted box is
# scored against the ground truth with `iou_metric` and `cosine_similarity_metric`
# (both defined earlier in the notebook), and the per-prompt scores are collected.
IoUs = []
cosine_similarities = []

# Evaluation only: no autograd graphs are needed.
with torch.no_grad():
    for batch_idx, (images, gt_bounding_boxes, prompts) in enumerate(test_loader):
        print(f'-- Batch index: {batch_idx} --')

        indices = torch.arange(len(images), device=device)
        outputs = circles_model(indices, images, prompts)

        # There is one output bounding box per prompt, flattened over the batch.
        # Each sample's prompt entry is itself a list holding an arbitrary number
        # of prompts, so split `outputs` back into consecutive per-sample slices
        # of length len(prompts[i]).
        outputs_grouped_by_sample = []
        start = 0
        for sample_idx in range(len(images)):
            stop = start + len(prompts[sample_idx])
            outputs_grouped_by_sample.append(outputs[start:stop])
            start = stop

        for output_bboxes, gt_bboxes in zip(outputs_grouped_by_sample, gt_bounding_boxes):
            IoUs.extend(iou_metric(output_bboxes, gt_bboxes))
            cosine_similarities.extend(cosine_similarity_metric(output_bboxes, gt_bboxes))

In [None]:
# Reduce the collected per-prompt scores to their means. Entries that are not
# tensors are counted as 0 (presumably prompts the metric could not score — TODO
# confirm); NaN entries are ignored by nanmean.
IoUs_to_cpu = np.array([score.item() if torch.is_tensor(score) else 0 for score in IoUs])
cosine_similarities_to_cpu = np.array([score.item() if torch.is_tensor(score) else 0 for score in cosine_similarities])

mIoU = np.nanmean(IoUs_to_cpu)
m_cos_sim = np.nanmean(cosine_similarities_to_cpu)

print('\n'.join([
    '--- Metrics ---',
    f'Mean Intersection over Union (mIoU): {mIoU}',
    f'Mean Cosine Similarity: {m_cos_sim}',
]))

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Show one prediction from the last evaluated batch: sample `output_idx`, prompt
# `prompt_idx`. One variable selects both the drawn box and the title, so the two
# cannot drift apart.
output_idx = 0
prompt_idx = 1

# NOTE: this reassigns the global `img` that the single-example cells above use.
img = images[output_idx]

fig, ax = plt.subplots()
ax.imshow(img)

colors = ['r', 'b', 'g']

selected_boxes = outputs_grouped_by_sample[output_idx][prompt_idx:prompt_idx + 1]
for bbox, color in zip(selected_boxes, colors):
    # Boxes are (x0, y0, x1, y1) corners; Rectangle wants (x, y), width, height.
    x0, y0, x1, y1 = bbox[:4].tolist()
    rect = patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=1, edgecolor=color, facecolor='none')
    ax.add_patch(rect)

ax.set_title(prompts[output_idx][prompt_idx]);

In [None]:
# One torch.cuda.device handle per visible GPU (an empty list on CPU-only machines).
available_gpus = list(map(torch.cuda.device, range(torch.cuda.device_count())))
available_gpus