In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.io import read_image, ImageReadMode
from PIL import Image
import os
import json
import pickle
import numpy as np

In [3]:
class RefCocoG_Dataset(Dataset):
    """
    One split of the RefCOCOg referring-expression data.

    Every sample is a list `[image, bounding_box, prompts]` where `image` is an
    RGB PIL image (or the output of `transform` when one is given),
    `bounding_box` is `[top_left_x, top_left_y, bottom_right_x, bottom_right_y]`
    and `prompts` is the list of sentences attached to the annotation.

    The annotation files are parsed once and cached on the class so that
    building the other splits from the same files does not re-read them.
    """

    # Annotations of all splits from the most recently parsed files,
    # keyed by annotation id. Shared by every instance.
    full_annotations = None
    # (root_dir, annotations_f, instances_f) the cache above was built from,
    # so that instances pointing at other files never reuse a foreign cache.
    _full_annotations_source = None

    def __init__(self, root_dir, annotations_f, instances_f, split='train', transform=None, target_transform=None) -> None:
        """
        Args:
            root_dir: directory holding the `annotations` and `images` folders.
            annotations_f: name of the pickled references file inside `annotations`.
            instances_f: name of the JSON instances file inside `annotations`.
            split: value of the `split` field of the references to keep.
            transform: optional callable applied to the loaded image.
            target_transform: stored on the instance; not applied by this class.
        """
        super().__init__()

        self.root_dir = root_dir
        self.annotations_f = annotations_f
        self.instances_f = instances_f

        self.split = split

        self.transform = transform
        self.target_transform = target_transform

        self.get_annotations()

        # Freeze the sample order once, so that __getitem__ is a plain list
        # lookup instead of rebuilding the key list on every access.
        self.annotation_ids = list(self.annotations)
        self.image_names = [
            self.annotations[annotation_id]['image']['actual_file_name']
            for annotation_id in self.annotation_ids
        ]

    def _select_split(self, annotations):
        """Return the entries of `annotations` whose image record belongs to `self.split`."""
        return {
            annotation_id: entry
            for annotation_id, entry in annotations.items()
            if entry['image']['split'] == self.split
        }

    def get_annotations(self):
        """
        Populate `self.annotations` with the entries of the requested split.

        Reuses the class-level cache when it was built from the same files;
        otherwise parses the references and instances files, joins them on the
        annotation id and refreshes the cache.
        """
        source = (self.root_dir, self.annotations_f, self.instances_f)
        cached = RefCocoG_Dataset.full_annotations
        if cached and RefCocoG_Dataset._full_annotations_source == source:
            self.annotations = self._select_split(cached)
            return

        annotations_dir = os.path.join(self.root_dir, 'annotations')

        # Load pickle data
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at reference files obtained from a trusted source.
        with open(os.path.join(annotations_dir, self.annotations_f), 'rb') as file:
            self.data = pickle.load(file)

        # Load instances
        with open(os.path.join(annotations_dir, self.instances_f), 'rb') as file:
            self.instances = json.load(file)

        # Map each image id to the file name stored in the instances file
        images_actual_file_names = {
            image['id']: image['file_name']
            for image in self.instances['images']
        }

        # Match data between the two files and build the actual dataset
        self.annotations = {}

        for image in self.data:
            entry = self.annotations.setdefault(image['ann_id'], {})
            image['actual_file_name'] = images_actual_file_names[image['image_id']]
            entry['image'] = image

        for annotation in self.instances['annotations']:
            # Instances without a matching reference are not part of the dataset
            if annotation['id'] in self.annotations:
                self.annotations[annotation['id']]['annotation'] = annotation

        # Cache every split, then keep only samples from the given split
        RefCocoG_Dataset.full_annotations = self.annotations
        RefCocoG_Dataset._full_annotations_source = source
        self.annotations = self._select_split(self.annotations)

    def __len__(self):
        # Return the number of samples of the split
        return len(self.image_names)

    def corner_size_to_corners(self, bounding_box):
        """
        Transform (top_left_x, top_left_y, width, height) bounding box representation
        into (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
        """

        return [
            bounding_box[0],
            bounding_box[1],
            bounding_box[0] + bounding_box[2],
            bounding_box[1] + bounding_box[3]
        ]

    def __getitem__(self, idx):
        """Return `[image, bounding_box, prompts]` for the sample at position `idx`."""
        image_name = self.image_names[idx]
        annotation_id = self.annotation_ids[idx]
        entry = self.annotations[annotation_id]

        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(os.path.join(self.root_dir, 'images', image_name)) as image_file:
            image = image_file.convert('RGB')

        # Get the captions for the annotation
        prompts = [prompt['sent'] for prompt in entry['image']['sentences']]

        # Get the bounding box the prompts refer to, as corner coordinates
        bounding_box = self.corner_size_to_corners(entry['annotation']['bbox'])

        # Apply the transform if given
        if self.transform:
            image = self.transform(image)

        # Return the sample as a list
        return [
            image,
            bounding_box,
            prompts,
        ]

In [4]:
# Build one dataset per split, in the order train / val / test.
# All three read the same reference and instance files.
dataset_train, dataset_val, dataset_test = [
    RefCocoG_Dataset('refcocog', 'refs(umd).p', 'instances.json', split=split_name)
    for split_name in ('train', 'val', 'test')
]

dataset_splits = [dataset_train, dataset_val, dataset_test]

In [5]:
# Sanity check of the loading step: total number of cached annotations,
# followed by the number of samples in the train, val and test splits.
(
    len(RefCocoG_Dataset.full_annotations),
    *(len(split.annotations) for split in (dataset_train, dataset_val, dataset_test)),
)

(49820, 42224, 2573, 5023)

In [6]:
# In order to be able to move lists of objects around (especially lists of
# PIL Images, rather than tensors), we need a custom collation function.
# This ensures we can feed the original images to the pipeline, rather
# than tensor-transformed (with scaling, cropping, etc.) versions.

def collate_differently_sized_prompts(batch):
    """
    Regroup a batch of `[image, bounding_box, prompts]` samples by field.

    Returns a tuple `(images, bboxes, prompts)` of three plain lists, each as
    long as the batch. Nothing is stacked or converted, so images of different
    sizes and prompt lists of different lengths can share a batch.
    """

    def field(position):
        # One pass over the batch per field, keeping the sample order
        return [sample[position] for sample in batch]

    return field(0), field(1), field(2)

def get_data(dataset_splits, batch_size=64, test_batch_size=256, num_workers=0):
    """
    Wrap the train, validation and test datasets into data loaders.

    Args:
        dataset_splits: sequence whose first three items are the training,
            validation and test datasets, in that order.
        batch_size: batch size of the training loader.
        test_batch_size: batch size of the validation and test loaders.
        num_workers: number of worker processes used by each loader.

    Returns:
        `(train_loader, val_loader, test_loader)`. The training loader is
        shuffled and drops the last incomplete batch; the other two keep the
        dataset order and every sample.
    """
    training_data, validation_data, test_data = dataset_splits[0], dataset_splits[1], dataset_splits[2]

    # Options common to the three loaders
    shared_options = dict(collate_fn=collate_differently_sized_prompts, num_workers=num_workers)

    train_loader = torch.utils.data.DataLoader(
        training_data,
        batch_size,
        shuffle=True,
        drop_last=True,
        **shared_options,
    )
    val_loader, test_loader = [
        torch.utils.data.DataLoader(evaluation_data, test_batch_size, shuffle=False, **shared_options)
        for evaluation_data in (validation_data, test_data)
    ]

    return train_loader, val_loader, test_loader

In [7]:
# Build the three loaders: `batch_size` is used for training,
# `test_batch_size` for both validation and test.
train_loader, val_loader, test_loader = get_data(dataset_splits, batch_size=128, test_batch_size=64, num_workers=0)

In [8]:
# Use the first GPU when CUDA is available; otherwise fall back to the plain
# string 'cpu', which later cells compare against with `device == 'cpu'`.
device = torch.device("cuda:0") if torch.cuda.is_available() else 'cpu'

In [9]:
import clip

# Name of the CLIP backbone handed to clip.load; it is also embedded in the
# checkpoint file names written by the training loop. The alternatives that
# were tried are left commented out.
clip_backbone = 'RN50'
# clip_backbone = 'RN50x16'
# clip_backbone = 'ViT-B/16'

# clip.load returns the model together with the image preprocessing callable
# that the training loop applies to each image before stacking the batch.
model, preprocess = clip.load(clip_backbone, device=device, jit=False)

In [10]:
def disable_grad(num_layers=0):
    """
    Freeze the leading parameters of the global `model`.

    The first `num_layers` tensors yielded by `model.parameters()` get
    `requires_grad = False`; the remaining ones are left untouched. Despite
    its name, `num_layers` counts parameter tensors, not layers.
    """
    remaining_parameters = model.parameters()
    frozen = 0

    while frozen < num_layers:
        param = next(remaining_parameters, None)
        if param is None:
            # Fewer parameters than requested: everything is frozen already
            break

        param.requires_grad = False
        frozen += 1

In [11]:
def convert_models_to_fp32(model):
    """
    Cast the parameters of `model`, and their gradients when present, to float32.

    As a side effect `disable_grad(num_layers=250)` is invoked, which freezes
    the first 250 parameter tensors of the *global* model. That call does not
    depend on the parameter being converted, so it is issued once per
    conversion instead of once per parameter.
    """
    parameters = list(model.parameters())

    # Loop-invariant work hoisted out of the loop; skipped for a model without
    # parameters, exactly as the per-parameter call used to be.
    if parameters:
        disable_grad(num_layers=250)

    for p in parameters:
        p.data = p.data.float()

        # Parameters that took no part in the backward pass have no gradient
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [12]:
# On an accelerator the weights go through CLIP's own weight conversion;
# on the CPU the whole model is cast to float32 instead.
if device != 'cpu':
    clip.model.convert_weights(model)
else:
    model.float()

In [13]:
# Number of parameter tensors in the model, i.e. the upper bound for the
# `num_layers` argument of `disable_grad`.
sum(1 for _ in model.parameters())

324

In [14]:
# Freeze the first 320 parameter tensors of the model; the count displayed
# by the previous cell is 324, so only the last few stay trainable.
disable_grad(num_layers=320)

In [14]:
import torch.nn as nn
import torch.optim as optim

# One cross-entropy criterion per direction: the training loop applies
# `loss_img` to the per-image logits and `loss_txt` to the per-text logits.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# The optimizer receives every parameter of the model, including the ones
# frozen by `disable_grad`.
optimizer = optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

In [15]:
from PIL import Image, ImageDraw, ImageFilter

class ImageAugmenter():
    """
    Builds "visual prompts": copies of an image in which the region given by
    a bounding box is kept intact and outlined, while the rest of the image is
    degraded (blurred, or blurred and grayscaled).
    """

    # Number of visual prompts generated for each bounding box. It must match
    # the number of entries built in `get_visual_prompts`. Declared on the
    # class so that it can be read before any image has been augmented.
    visual_augmentation = 2

    def __init__(self, device):#, yolo_model=None) -> None:
        """`device` is accepted for interface compatibility; it is not used."""
        super().__init__()

    def augment_images(self, images, bounding_boxes):
        """
        Generate the visual prompts of a batch.

        Args:
            images: sequence of PIL images.
            bounding_boxes: tensor indexed like `images`, holding one
                (top_left_x, top_left_y, bottom_right_x, bottom_right_y) box
                per image. The boxes returned by `get_visual_prompts` are
                written back into it.

        Returns:
            A flat list with the visual prompts of the first image, then those
            of the second image, and so on.
        """
        augmented_images = []

        for idx, image in enumerate(images):
            image_augmentations, bounding_boxes[idx] = self.get_visual_prompts(image, bounding_boxes[idx].unsqueeze(0))
            augmented_images += image_augmentations

        return augmented_images

    def _highlight_region(self, image, background, bbox, stroke_color, stroke_width):
        """
        Paste the rectangle `bbox` of `image` over `background` and outline it.

        `background` is modified in place and returned. The pasted area is the
        full rectangle, while the outline is the ellipse inscribed in it.
        """
        mask = Image.new('L', image.size, 0)
        ImageDraw.Draw(mask).rectangle(bbox, fill=255)
        background.paste(image, mask=mask)
        ImageDraw.Draw(background).ellipse(bbox, outline=stroke_color, width=stroke_width)

        return background

    def get_image_with_marker_and_blur(self, image, bbox, stroke_color='red', stroke_width=1, blur_radius=1):
        """
        Add a visual marker to the image at the position specified by the
        bounding box (bbox), which is expected to be in the format
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
        The background is then blurred.
        """

        background = image.filter(ImageFilter.GaussianBlur(radius=blur_radius))

        return self._highlight_region(image, background, bbox, stroke_color, stroke_width)

    def get_image_with_marker_and_blur_grayscale(self, image, bbox, stroke_color='red', stroke_width=1, blur_radius=1):
        """
        Add a visual marker to the image at the position specified by the
        bounding box (bbox), which is expected to be in the format
        (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
        The background is both grayscaled and blurred.
        """

        # Round trip through 'L' drops the colors but keeps a 3-channel image
        background = image.filter(ImageFilter.GaussianBlur(radius=blur_radius)).convert('L').convert('RGB')

        return self._highlight_region(image, background, bbox, stroke_color, stroke_width)

    def get_visual_prompts(self, image, bounding_boxes):
        """
        Build the visual prompts of one image.

        Args:
            image: PIL image.
            bounding_boxes: tensor with one row of four corner coordinates per
                box, or None.

        Returns:
            `(visual_prompts, bounding_boxes)`. `visual_prompts` holds
            `visual_augmentation` images per box; when there is no box at all
            it holds `visual_augmentation` references to the unmodified image.
            `bounding_boxes` contains the rows that were kept (None when None
            was given).
        """
        if bounding_boxes is None:
            return [image] * self.visual_augmentation, bounding_boxes

        visual_prompts = []
        keep_bbox = []

        # Setting the parameters for the visual markers
        stroke_color = 'red'
        stroke_width = 3
        blur_radius = 20

        for idx, bounding_box in enumerate(bounding_boxes):
            bounding_box = (bounding_box[0].item(), bounding_box[1].item(), bounding_box[2].item(), bounding_box[3].item())

            # Every box is currently kept. A size-based filter would `continue`
            # before this line, which would drop the box from the returned rows.
            keep_bbox += [idx]

            # Remember to update `visual_augmentation` to match the number
            # of visual prompts that are generated for each bounding box.
            visual_prompts.extend([
                self.get_image_with_marker_and_blur(image, bounding_box, stroke_color=stroke_color, stroke_width=stroke_width, blur_radius=blur_radius),
                self.get_image_with_marker_and_blur_grayscale(image, bounding_box, stroke_color=stroke_color, stroke_width=stroke_width, blur_radius=blur_radius),
            ])

        bounding_boxes = bounding_boxes[keep_bbox]

        if len(visual_prompts) == 0:
            # If no region proposal, return the whole image.
            # It is inserted as many times as each region would
            # be augmented to ensure consistency in the algorithm
            visual_prompts.extend([image] * self.visual_augmentation)

        return visual_prompts, bounding_boxes

In [16]:
# Augmenter used by the training loop to turn each image and its ground-truth
# bounding box into visual prompts.
image_augmenter = ImageAugmenter(device=device)

In [18]:
EPOCHS = 50

# Contrastive fine-tuning of CLIP: within a batch, the i-th visually-prompted
# image has to match the i-th referring expression and vice versa.
for epoch in range(EPOCHS):
    print(f'Epoch: {epoch}')

    for batch_idx, (images, gt_bounding_boxes, prompts) in enumerate(train_loader):
        print(f'-- Batch: {batch_idx}')
        optimizer.zero_grad()

        # For each image, get a random prompt from the list of its prompts
        prompts = [
            prompt_list[np.random.randint(len(prompt_list))]
            for prompt_list in prompts
        ]

        # Every image is expanded into `visual_augmentation` consecutive
        # visual prompts built around its ground-truth bounding box
        images = image_augmenter.augment_images(images, torch.tensor(gt_bounding_boxes))

        # Keep one randomly chosen visual prompt per source image. The group
        # size is read from the augmenter instead of being hard-coded, so the
        # sampling stays correct if the number of prompts per box changes.
        views_per_image = image_augmenter.visual_augmentation
        images = [
            images[first_view + np.random.randint(views_per_image)]
            for first_view in range(0, len(images), views_per_image)
        ]

        images = torch.stack([preprocess(image) for image in images]).to(device)
        prompts = clip.tokenize(prompts).to(device)

        logits_per_image, logits_per_text = model(images, prompts)

        # The matching text of image i sits at position i of the batch
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)

        # Symmetric loss: average of the image->text and text->image terms
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        total_loss.backward()

        if device == 'cpu':
            optimizer.step()
        else:
            # The optimizer step is taken on float32 weights and gradients,
            # after which CLIP's weight conversion is applied again
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        torch.cuda.empty_cache()

    # One checkpoint of the whole model per epoch.
    # NOTE(review): the numbering starts at 10, presumably to continue an
    # earlier run of 10 epochs — confirm before relying on the file names.
    torch.save(model, f'CLIP_{clip_backbone}_fine_tuned_{10 + epoch}.pt')

Epoch: 0
-- Batch: 0
-- Batch: 1
-- Batch: 2
-- Batch: 3
-- Batch: 4
-- Batch: 5
-- Batch: 6
-- Batch: 7
-- Batch: 8
-- Batch: 9
-- Batch: 10
-- Batch: 11
-- Batch: 12
-- Batch: 13
-- Batch: 14
-- Batch: 15
-- Batch: 16
-- Batch: 17
-- Batch: 18
-- Batch: 19
-- Batch: 20
-- Batch: 21
-- Batch: 22
-- Batch: 23
-- Batch: 24
-- Batch: 25
-- Batch: 26
-- Batch: 27
-- Batch: 28
-- Batch: 29
-- Batch: 30
-- Batch: 31
-- Batch: 32
-- Batch: 33
-- Batch: 34
-- Batch: 35
-- Batch: 36
-- Batch: 37
-- Batch: 38
-- Batch: 39
-- Batch: 40
-- Batch: 41
-- Batch: 42
-- Batch: 43
-- Batch: 44
-- Batch: 45
-- Batch: 46
-- Batch: 47
-- Batch: 48
-- Batch: 49
-- Batch: 50
-- Batch: 51
-- Batch: 52
-- Batch: 53
-- Batch: 54
-- Batch: 55
-- Batch: 56
-- Batch: 57
-- Batch: 58
-- Batch: 59
-- Batch: 60
-- Batch: 61
-- Batch: 62
-- Batch: 63
-- Batch: 64
-- Batch: 65
-- Batch: 66
-- Batch: 67
-- Batch: 68
-- Batch: 69
-- Batch: 70
-- Batch: 71
-- Batch: 72
-- Batch: 73
-- Batch: 74
-- Batch: 75
-- Batch: 76
