In [1]:
# Import Libraries
import os
import torch, cv2
import open_clip
from PIL import Image
from utils.config import *
from utils.common_utils import read_json_data
from utils.dataset import CaptionInContext
from utils.text_utils import get_text_metadata
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim as optim

from torch.utils.data import DataLoader
from utils.dataset_utils import PadCollate1
import torch.nn.functional as F
import matplotlib.pyplot as pltc

import torchvision
import numpy as np
from tqdm import tqdm

import os.path
import torch
import itertools as it
from torch.utils.data import Dataset
from utils.dataset_utils import modify_caption_replace_entities
from utils.common_utils import read_json_data
from utils.config import num_boxes, DATA_DIR
from utils.custom_transforms.data_aug import *

2023-07-26 20:20:09.910125: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-26 20:20:20.936981: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1956] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [2]:
transform3 = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
transform2 = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), transforms.ColorJitter(hue=.2, saturation=.2), transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
transform1 = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), transforms.RandomHorizontalFlip(p=0.5), transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])

In [3]:

def load_images(imgPath, bboxes):
    images = []
    image = Image.open(imgPath)
    for j, bbox in enumerate(bboxes):
        img1 = image.crop((int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])))
        if img1.size[0] > 0 and img1.size[1] > 0:
            images.append(img1)
    return images

In [4]:

class AugDataset(Dataset):
    """Custom dataset class for Out-of-Context Detection"""                                         ################################

    def __init__(self, metadata_file, mode, transforms, text_field=None, max_samples=None, slice_start=None, slice_end=None):
        """
            Initializes the dataset object
            Args:
                metadata_file (string): Path to the json file with annotations.
                mode (string): train, val, test.
                transforms (callable): Transform to be applied on a sample.
                text_field (torchtext.Field): Vocab object for applying text transformations (used for Glove and Fasttext embeddings only)
            Returns:
                None
        """
        self.data = read_json_data(metadata_file)
        self.mode = mode
        self.transforms = transforms
        self.text_field = text_field
        self.max_samples = max_samples

        self.flip_rotate_transform = Sequence(
            [RandomHorizontalFlip(0.8), RandomScale(0.2, diff=True), RandomRotate(10)])

        if max_samples is not None:
            self.data = self.data[:max_samples]
        #########################################
        if slice_start is not None or slice_end is not None:
        ###########################################
            self.data = self.data[slice_start:slice_end]    
    
    
    def __getitem__(self, index):
        """
            Returns sample corresponding to the index `index`
        """
        img_data = self.data[index]
        img_path = os.path.join(DATA_DIR, img_data['img_local_path'])
        bboxes = img_data['maskrcnn_bboxes'][:10]

        img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)

        try:
            img_aug, bboxes_aug = self.flip_rotate_transform(img, np.array(bboxes))
            bboxes_aug = bboxes_aug.tolist()
            bboxes = list(it.islice(it.cycle(bboxes_aug), num_boxes - 1))
            img = img_aug
        except:
            pass

        if self.mode == 'test':
            idx1 = random.randint(0, 1)
            cap_key1 = 'caption1' if idx1 == 0 else 'caption2'
            caption1 = img_data[cap_key1]
            caption1 = modify_caption_replace_entities(caption1)

            while True:
                idx2 = random.randint(0, 1)
                cap_key2 = 'caption1' if idx2 == 0 else 'caption2'
                tgt_index = random.randint(0, len(self.data) - 1)
                caption2 = self.data[tgt_index][cap_key2]
                caption2 = modify_caption_replace_entities(caption2)
                if caption1 != caption2:
                    break
        else:
            src_captions = img_data['articles']
            caption1 = src_captions[random.randint(0, len(src_captions) - 1)]['caption_modified']

            while True:
                tgt_index = random.randint(0, len(self.data) - 1)
                tgt_captions = self.data[tgt_index]['articles']
                caption2 = tgt_captions[random.randint(0, len(tgt_captions) - 1)]['caption_modified']
                if caption1 != caption2:
                    break
        text = caption1
        
        return text, bboxes, img_path
        # return img_path, img_aug1, img_aug2, images[0], images[1], images[2], images[3], images[4], images[5], images[6], images[7], images[8], images[9], bboxes, text

    def __len__(self):
        """
            Returns length of the dataset
        """
        return len(self.data)

In [5]:
# Word Embeddings
text_field, word_embeddings, vocab_size = get_text_metadata()
train_dataset = AugDataset(metadata_file=os.path.join(DATA_DIR, 'annotations', 'train_data.json'),
                                 transforms=transform3, mode='train', text_field=text_field, max_samples=5000)

val_dataset  = AugDataset(metadata_file=os.path.join(DATA_DIR, 'annotations', 'val_data.json'),
                               transforms=transform3, mode='val', text_field=text_field, max_samples=1000)

In [6]:
# Creating data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, num_workers=6, shuffle=True, collate_fn=PadCollate1())

val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, num_workers=6, shuffle=False, collate_fn=PadCollate1())

In [8]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k', device=device, jit=False)
tokenizer = open_clip.get_tokenizer('ViT-B-32')

In [9]:

class CLIPModel(nn.Module):
    def __init__(self, model):
        super(CLIPModel, self).__init__()

        # Preprocessing layers
        self.model = model
        self.fc = nn.Linear(512, 512)

        # Set requires_grad to True for the parameters
        for param in self.model.parameters():
            param.requires_grad = False

        for param in self.model.transformer.resblocks[-1].mlp.parameters():
            param.requires_grad_(True)

        for param in self.model.visual.transformer.resblocks[-1].mlp.parameters():
            param.requires_grad_(True)

        for param in self.model.ln_final.parameters():
           param.requires_grad = True

        for param in self.fc.parameters():
            param.requires_grad = True

    def forward(self, orgImg, img_aug1, img_aug2, images, tokenisedCaption):

        # Create a list to store the encoded features for each image in the batch
        feature_list = []
        
        encoded_orgImg = self.fc(self.model.encode_image(orgImg.to(device)))
        encoded_orgImg = encoded_orgImg / encoded_orgImg.norm(dim=-1, keepdim=True)
        feature_list.append(encoded_orgImg)

        encoded_img_aug1 = self.fc(self.model.encode_image(img_aug1.to(device)))
        encoded_img_aug1 = encoded_img_aug1 / encoded_img_aug1.norm(dim=-1, keepdim=True)
        feature_list.append(encoded_img_aug1)

        encoded_img_aug2 = self.fc(self.model.encode_image(img_aug2.to(device)))
        encoded_img_aug2 = encoded_img_aug2 / encoded_img_aug2.norm(dim=-1, keepdim=True)
        feature_list.append(encoded_img_aug2)

        for image in images:
            encoded_image = self.fc(self.model.encode_image(image.to(device)))
            encoded_image = encoded_image / encoded_image.norm(dim=-1, keepdim=True)
            feature_list.append(encoded_image)

        encoded_text = self.fc(self.model.encode_text(tokenisedCaption.to(device)))
        encoded_text = encoded_text / encoded_text.norm(dim=-1, keepdim=True)
        feature_list.append(encoded_text)

        return feature_list

clipModel = CLIPModel(model)
model_name = "ClipModelResUnfreezeSCL_5k"

In [10]:
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        # print(self.temperature)
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature
        # print(self.base_temperature)

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        # device = (torch.device('cuda:1')
        #           if features.is_cuda
        #           else torch.device('cuda:0'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        # print(logits_max)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        # print(mean_log_prob_pos)
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

In [11]:
optimizer = optim.Adam([p for p in clipModel.parameters() if p.requires_grad],lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=True)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=5, verbose=True)

print("Total Params", sum(p.numel() for p in clipModel.parameters() if p.requires_grad))
print("Img Model", sum(p.numel() for p in clipModel.model.parameters() if p.requires_grad))

Total Params 7085824
Img Model 6823168


In [12]:
temp = 0.05
criterion = SupConLoss(temperature=temp)

In [13]:
def train_model(epoch):

    train_loss = 0.
    clipModel.to(device)
    optimizer.zero_grad()
    clipModel.train()
    clipModel.zero_grad()

    # Training loop
    for batch_idx, (caption, bboxes, img_path) in enumerate(tqdm(train_loader)):
        batch = len(bboxes)
        encoded_images_lists = [[] for _ in range(14)]
        with torch.set_grad_enabled(True):      
            for i, imgPath in enumerate(img_path):
                tokenisedCaption = tokenizer([caption[i][:77]])
                
                imgPath = os.path.join(DATA_DIR, imgPath)
                img = Image.open(imgPath)

                img_aug1 = transform1(preprocess(img)).unsqueeze(0)
                img_aug2 = transform2(preprocess(img)).unsqueeze(0)
                orgImg = transform3(preprocess(img)).unsqueeze(0)
                images = load_images(imgPath, bboxes[i])
                images = [transform3(preprocess(imgs)).unsqueeze(0) for imgs in images]

                features = clipModel(orgImg, img_aug1, img_aug2, images, tokenisedCaption)
                for i, feature in enumerate(features):
                    encoded_images_lists[i].append(feature)
             
            encoded_images_lists = [torch.stack(encoded_images_list) for encoded_images_list in encoded_images_lists]
            encoded_images_lists = torch.stack(encoded_images_lists, dim=1)
            loss = criterion(encoded_images_lists)
            train_loss += float(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            clipModel.to(device)

            torch.cuda.empty_cache()
            torch.cuda.synchronize() 
            del encoded_images_lists

            print('For Batch: {}, Total Loss: {:.4f}, Current Loss: {:.4f}'.format(int(batch_idx), train_loss / len(train_loader), loss))

    print(' Train Epoch: {} Loss: {:.4f}'.format(epoch, train_loss / len(train_loader)))


In [14]:
train_model(5)

  2%|▎         | 1/40 [00:32<21:18, 32.77s/it]

For Batch: 0.0000, Total Loss: 0.1641, Current Loss: 6.5625


  5%|▌         | 2/40 [01:01<19:09, 30.25s/it]

For Batch: 1.0000, Total Loss: 0.3048, Current Loss: 5.6300


  8%|▊         | 3/40 [01:30<18:22, 29.81s/it]

For Batch: 2.0000, Total Loss: 0.4387, Current Loss: 5.3543


 10%|█         | 4/40 [02:02<18:26, 30.72s/it]

For Batch: 3.0000, Total Loss: 0.5874, Current Loss: 5.9484


 12%|█▎        | 5/40 [02:32<17:38, 30.24s/it]

For Batch: 4.0000, Total Loss: 0.7208, Current Loss: 5.3386


 15%|█▌        | 6/40 [03:00<16:45, 29.56s/it]

For Batch: 5.0000, Total Loss: 0.8546, Current Loss: 5.3497


 18%|█▊        | 7/40 [03:28<16:03, 29.19s/it]

For Batch: 6.0000, Total Loss: 0.9882, Current Loss: 5.3455


 20%|██        | 8/40 [03:56<15:24, 28.89s/it]

For Batch: 7.0000, Total Loss: 1.1218, Current Loss: 5.3421


 22%|██▎       | 9/40 [04:29<15:28, 29.97s/it]

For Batch: 8.0000, Total Loss: 1.2550, Current Loss: 5.3272


 25%|██▌       | 10/40 [05:01<15:22, 30.75s/it]

For Batch: 9.0000, Total Loss: 1.3917, Current Loss: 5.4684


 28%|██▊       | 11/40 [05:33<15:01, 31.10s/it]

For Batch: 10.0000, Total Loss: 1.5251, Current Loss: 5.3380


 30%|███       | 12/40 [06:04<14:24, 30.89s/it]

For Batch: 11.0000, Total Loss: 1.6587, Current Loss: 5.3452


 32%|███▎      | 13/40 [06:36<14:05, 31.30s/it]

For Batch: 12.0000, Total Loss: 1.7924, Current Loss: 5.3474


 35%|███▌      | 14/40 [07:05<13:14, 30.57s/it]

For Batch: 13.0000, Total Loss: 1.9261, Current Loss: 5.3469


 38%|███▊      | 15/40 [07:33<12:30, 30.01s/it]

For Batch: 14.0000, Total Loss: 2.0598, Current Loss: 5.3484


 40%|████      | 16/40 [08:02<11:51, 29.65s/it]

For Batch: 15.0000, Total Loss: 2.1935, Current Loss: 5.3489


 42%|████▎     | 17/40 [08:34<11:39, 30.41s/it]

For Batch: 16.0000, Total Loss: 2.3272, Current Loss: 5.3481


 45%|████▌     | 18/40 [09:07<11:20, 30.93s/it]

For Batch: 17.0000, Total Loss: 2.4610, Current Loss: 5.3487


 48%|████▊     | 19/40 [09:35<10:32, 30.14s/it]

For Batch: 18.0000, Total Loss: 2.5946, Current Loss: 5.3472


 50%|█████     | 20/40 [10:03<09:50, 29.53s/it]

For Batch: 19.0000, Total Loss: 2.7283, Current Loss: 5.3480


 52%|█████▎    | 21/40 [10:31<09:14, 29.18s/it]

For Batch: 20.0000, Total Loss: 2.8620, Current Loss: 5.3478


 55%|█████▌    | 22/40 [10:59<08:39, 28.87s/it]

For Batch: 21.0000, Total Loss: 2.9957, Current Loss: 5.3477


 57%|█████▊    | 23/40 [11:28<08:07, 28.68s/it]

For Batch: 22.0000, Total Loss: 3.1294, Current Loss: 5.3460


 60%|██████    | 24/40 [11:56<07:37, 28.58s/it]

For Batch: 23.0000, Total Loss: 3.2630, Current Loss: 5.3456


 62%|██████▎   | 25/40 [12:25<07:08, 28.55s/it]

For Batch: 24.0000, Total Loss: 3.3966, Current Loss: 5.3448


 65%|██████▌   | 26/40 [12:53<06:38, 28.47s/it]

For Batch: 25.0000, Total Loss: 3.5302, Current Loss: 5.3442


 68%|██████▊   | 27/40 [13:21<06:09, 28.42s/it]

For Batch: 26.0000, Total Loss: 3.6638, Current Loss: 5.3412


 70%|███████   | 28/40 [13:49<05:40, 28.36s/it]

For Batch: 27.0000, Total Loss: 3.7972, Current Loss: 5.3355


 72%|███████▎  | 29/40 [14:17<05:10, 28.27s/it]

For Batch: 28.0000, Total Loss: 3.9306, Current Loss: 5.3375


 75%|███████▌  | 30/40 [14:46<04:42, 28.29s/it]

For Batch: 29.0000, Total Loss: 4.0641, Current Loss: 5.3416


 78%|███████▊  | 31/40 [15:14<04:14, 28.30s/it]

For Batch: 30.0000, Total Loss: 4.1973, Current Loss: 5.3266


 80%|████████  | 32/40 [15:42<03:46, 28.31s/it]

For Batch: 31.0000, Total Loss: 4.3304, Current Loss: 5.3250


 82%|████████▎ | 33/40 [16:15<03:26, 29.55s/it]

For Batch: 32.0000, Total Loss: 4.4633, Current Loss: 5.3132


 85%|████████▌ | 34/40 [16:43<02:55, 29.27s/it]

For Batch: 33.0000, Total Loss: 4.5963, Current Loss: 5.3199


 88%|████████▊ | 35/40 [17:12<02:25, 29.12s/it]

For Batch: 34.0000, Total Loss: 4.7290, Current Loss: 5.3118


 90%|█████████ | 36/40 [17:41<01:55, 28.87s/it]

For Batch: 35.0000, Total Loss: 4.8615, Current Loss: 5.2989


 92%|█████████▎| 37/40 [18:09<01:25, 28.67s/it]

For Batch: 36.0000, Total Loss: 4.9939, Current Loss: 5.2963


 95%|█████████▌| 38/40 [18:38<00:57, 28.93s/it]

For Batch: 37.0000, Total Loss: 5.1260, Current Loss: 5.2826


 98%|█████████▊| 39/40 [19:06<00:28, 28.72s/it]

For Batch: 38.0000, Total Loss: 5.2584, Current Loss: 5.2958


100%|██████████| 40/40 [19:08<00:00, 28.72s/it]

For Batch: 39.0000, Total Loss: 5.3403, Current Loss: 3.2746
 Train Epoch: 5 Loss: 5.3403



