[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jayhansuh/COLAB-FILES/blob/main/12lbs/adCLIP-visualizer.ipynb)

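# Adapting the CLIP model

This notebook evaluates the SLICViT phrase-grounding model on a random subset of the Flickr30k validation split. It visualizes predicted vs. ground-truth bounding boxes and exports the sampled examples (images plus a JSON file of phrases and boxes).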

In [1]:
######## Mount the drive ########

from google.colab import drive
drive.mount('/content/drive/')
%cd /content/drive/MyDrive/adapting-CLIP
#%cd ../../adapting-CLIP

######## Install the dependencies ########
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/MyDrive/adapting-CLIP
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-qpq27men
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-qpq27men
  Resolved https://github.com/openai/CLIP.git to commit a9b1bf5920416aaeaec965c25dd9e8f98c864f16
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [27]:
#import argparse
#import os.path as osp
import json
import os
import random
import shutil

import cv2
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from models.resnet_high_res import ResNetHighRes
from models.slic_vit import SLICViT
from models.ss_baseline import SSBaseline
from utils.grounding_evaluator import GroundingEvaluator
from utils.zsg_data import FlickrDataset, VGDataset

%matplotlib inline

In [3]:
# Reference CLI invocation this cell was adapted from:
#   python eval.py --model vit14 --dataset flickr_s1_val --iou_thr 0.5 --num_samples 500

# Keyword arguments forwarded unchanged to the SLICViT constructor.
args = dict(
    model='vit14',
    alpha=0.75,
    aggregation='mean',
    n_segments=list(range(100, 601, 50)),  # 100, 150, ..., 600
    temperature=0.02,
    upsample=2,
    start_block=0,
    compactness=50,
    sigma=0,
)

# Full validation split; the next cell draws a random subset from it.
dataset_full = FlickrDataset(data_type='flickr30k_c1/val')

# Passed to GroundingEvaluator as `iou_thresh` when scoring predictions.
iou_thr = 0.5

# Instantiate the model directly (no intermediate class alias) and move it to the GPU.
model = SLICViT(**args).cuda()



In [5]:
######### Evaluate the model #########
# Seed the sampler so the same subset (and therefore the same figure and the
# same exported examples) is produced on every Restart & Run All.
SEED = 0
random.seed(SEED)

# Randomly select images. Never request more samples than the split holds,
# otherwise random.sample raises ValueError.
NUM_SAMPLES = 16
num_samples = min(NUM_SAMPLES, len(dataset_full))
idxs = random.sample(range(len(dataset_full)), num_samples)

# Create a random subset of the dataset by re-indexing its per-sample lists.
# NOTE(review): only image_paths / bboxes / phrases are subset here; if
# FlickrDataset keeps other per-sample lists, they are not re-indexed — confirm.
dataset = FlickrDataset(data_type=dataset_full.data_type)
dataset.image_paths = [dataset_full.image_paths[i] for i in idxs]
dataset.bboxes = [dataset_full.bboxes[i] for i in idxs]
dataset.phrases = [dataset_full.phrases[i] for i in idxs]

# Per-sample results, kept for the visualization and export cells below.
imgs = []
texts = []
bbox_gts = []
bbox_preds = []

# Predict the bounding boxes
for idx in tqdm(range(len(dataset))):
    # Fetch each sample exactly once; repeated __getitem__ calls are presumably
    # expensive (image loading).
    data = dataset[idx]
    im = data['image']
    text = data['phrases'][0]  # only the first phrase of each sample is grounded
    bbox_gts.append(data['bbox'])

    # The model returns a tuple; only its first element (the box predictions) is used.
    bbox_pred, _ = model(im, text)

    imgs.append(im)
    texts.append(text)
    bbox_preds.append(bbox_pred[0])

# Score the predictions against the subset's ground-truth boxes.
evaluator = GroundingEvaluator(gt_dataset=dataset, iou_thresh=iou_thr)
acc = evaluator(torch.from_numpy(np.stack(bbox_preds, axis=0)))
print('\nAcc: {}'.format(acc))

100%|██████████| 16/16 [01:24<00:00,  5.28s/it]


Acc: 0.1875





In [6]:
######## Visualize multiple images in Grid ########
# Red: predicted bounding box
# Blue: ground truth bounding box
# Boxes are drawn from (x1, y1) to (x2, y2). The colour tuples assume the images
# are RGB (as matplotlib displays them) — TODO confirm against the dataset loader.

col_num = 4
n_images = len(imgs)
# Ceiling division, with at least one row so plt.subplots always gets a valid grid.
row_num = max(n_images // col_num + (1 if n_images % col_num != 0 else 0), 1)

# squeeze=False keeps `axs` 2-D even for a single row, so no padding row is needed.
fig, axs = plt.subplots(row_num, col_num, figsize=(5 * col_num, 5 * row_num), squeeze=False)
for idx, ax in enumerate(axs.flat):  # row-major, i.e. idx == i * col_num + j
    ax.axis('off')
    if idx >= n_images:
        continue  # leave trailing grid cells empty

    # Draw on a copy: cv2.rectangle modifies its input array in place, so drawing
    # on imgs[idx] would permanently bake the boxes into the stored images.
    im = imgs[idx].copy()
    bbox_pred = bbox_preds[idx]
    bbox_gt = bbox_gts[idx]

    cv2.rectangle(im, (int(bbox_pred[0]), int(bbox_pred[1])), (int(bbox_pred[2]), int(bbox_pred[3])), (255, 50, 50), 2)
    cv2.rectangle(im, (int(bbox_gt[0]), int(bbox_gt[1])), (int(bbox_gt[2]), int(bbox_gt[3])), (50, 50, 255), 2)

    ax.imshow(im)
    ax.set_title(texts[idx])

plt.show()


Output hidden; open in https://colab.research.google.com to view.

# Export the sampled images and a JSON file of their phrases and ground-truth boxes

In [31]:
# Map image file name -> (phrases, ground-truth bbox) for every sampled example.
# NOTE(review): if two sampled entries share an image file, the later one
# overwrites the earlier here — confirm whether that can occur for this split.
flickrExamples = {
    image_path: (phrases, bbox)
    for image_path, phrases, bbox in zip(dataset.image_paths, dataset.phrases, bbox_gts)
}
flickrExamples

{'7173663269.jpg': (['two older folks'], [152, 100, 273, 228]),
 '226115048.jpg': (['a beer'], [112, 220, 152, 299]),
 '3782318456.jpg': (['equipment'], [3, 122, 491, 500]),
 '3517370470.jpg': (['one'], [152, 113, 267, 428]),
 '3329777647.jpg': (['a ball'], [363, 217, 401, 245]),
 '543603259.jpg': (['a wall'], [203, 130, 339, 425]),
 '4776990069.jpg': (['a carousel'], [146, 167, 394, 375]),
 '4756254503.jpg': (['two ice cream cones'], [147, 138, 206, 182]),
 '217108448.jpg': (['a telescope'], [113, 109, 287, 478]),
 '2534137886.jpg': (['a heavy bucket'], [188, 201, 364, 311]),
 '2620517927.jpg': (['motorcycles'], [46, 238, 288, 307]),
 '1045521051.jpg': (['the TV'], [44, 11, 171, 153]),
 '186487635.jpg': (['a wall'], [248, 11, 352, 475]),
 '3407528957.jpg': (['a table'], [155, 243, 472, 334]),
 '7573421864.jpg': (['leopard print paper'], [64, 18, 278, 487]),
 '3393035454.jpg': (['a tan rock'], [268, 291, 355, 383])}

In [32]:
# Export the dictionary to a JSON file. The destination folder is created on
# first use because open() does not create missing parent directories.
json_path = "/content/drive/MyDrive/COLAB-FILES/flickr-examples.json"
os.makedirs(os.path.dirname(json_path), exist_ok=True)
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(flickrExamples, f, indent=2)

In [33]:
# Copy the sampled example images into the flickr-examples directory.
example_dir = "/content/drive/MyDrive/COLAB-FILES/flickr-examples"
# makedirs creates missing parents and tolerates an already-existing directory.
os.makedirs(example_dir, exist_ok=True)
for impath in dataset.image_paths:
    source = os.path.join(dataset.image_dir, impath)
    target = os.path.join(example_dir, os.path.basename(impath))
    # shutil.copy avoids shell quoting problems (spaces or metacharacters in
    # paths) and raises on failure instead of silently ignoring cp's exit status.
    shutil.copy(source, target)