In [None]:
!pip install transformers &> /dev/null
!pip install timm &> /dev/null

In [None]:
import torch
import torchvision
import torch.nn.functional as F
from torch import nn
from typing import Dict
import os,sys
import numpy as np
from pathlib import Path
import pycocotools
import torch.utils
from torch.utils.data import ConcatDataset, DataLoader, DistributedSampler
from functools import partial
import json
from transformers import RobertaTokenizerFast
import pickle
import time
from collections import defaultdict
import pandas as pd

In [None]:
# Mount Google Drive at /content/gdrive; every path used below lives under that mount point.
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


In [None]:
# Work from the project folder on the mounted drive.
os.chdir('/content/gdrive/MyDrive/DL_systems/final_project')
# One-time setup, left disabled: clone the MDETR repository into the project folder.
#!git clone https://github.com/ashkamath/mdetr.git

In [None]:
# Run configuration: every location is derived from the single project folder on the mounted drive.
project_root = '/content/gdrive/MyDrive/DL_systems/final_project'

img_dir = f'{project_root}/data/images'    # images read by the dataset
mdetr_git_dir = f'{project_root}/mdetr'    # MDETR checkout that is appended to sys.path
output_dir = f'{project_root}/output'      # result pickle and timing CSV are written here
annotations_dir = output_dir               # the annotation JSON is read from the same folder

batch_size, pretrained_model = 10, 'mdetr_resnet101'

In [None]:
# Put the MDETR checkout on the import path so the project imports that follow
# (datasets, util, models) can be resolved — presumably those packages live in mdetr_git_dir.
sys.path.append(mdetr_git_dir)
#os.chdir(mdetr_git_dir)
import datasets.transforms as T
from datasets.coco import *
from datasets.phrasecut_utils import data_transfer
import util.misc as utils
from util.misc import targets_to
from models.mdetr import MDETR
from datasets.refexp import RefExpDetection
from util.metrics import MetricLogger
from datasets import get_coco_api_from_dataset
from datasets.flickr_eval import FlickrEvaluator
from models.postprocessors import PostProcess, PostProcessFlickr
from util import box_ops

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

cuda:0


In [None]:
def build_dataset(img_dir, ann_file, image_set, text_encoder_type):
    """Create the referring-expression detection dataset for one split.

    Args:
        img_dir: folder that holds the image files.
        ann_file: path of the annotation JSON file.
        image_set: split name handed to ``make_coco_transforms`` (e.g. 'val').
        text_encoder_type: identifier passed to ``RobertaTokenizerFast.from_pretrained``.

    Returns:
        A ``RefExpDetection`` instance built with ``return_masks=False`` and ``return_tokens=True``.
    """
    dataset_options = dict(
        tokenizer=RobertaTokenizerFast.from_pretrained(text_encoder_type),
        transforms=make_coco_transforms(image_set, cautious=True),
        return_masks=False,
        return_tokens=True,
    )
    return RefExpDetection(img_dir, ann_file, **dataset_options)

In [None]:
# Build the evaluation dataset from the masked Flickr test annotations.
test_ann_file = os.path.join(annotations_dir, 'flickr_test_masked.json')
test_dset = build_dataset(
    img_dir=img_dir,
    ann_file=test_ann_file,
    image_set='val',
    text_encoder_type="roberta-base",
)

print(f"Creating test dataset with {len(test_dset)} instances")

loading annotations into memory...
Done (t=0.54s)
creating index...
index created!
Creating test dataset with 7893 instances


In [None]:
# Iterate over the test set in its stored order and keep the last, possibly smaller, batch.
sequential_sampler = torch.utils.data.SequentialSampler(test_dset)
batch_collate = partial(utils.collate_fn, False)
test_loader = DataLoader(
    test_dset,
    batch_size=batch_size,
    sampler=sequential_sampler,
    collate_fn=batch_collate,
    drop_last=False,
)

In [None]:
# Fetch the pretrained MDETR entry point named by `pretrained_model` (config cell) from torch hub,
# so the evaluated model always matches the name embedded in the output file names.
model, postprocessor = torch.hub.load(
    'ashkamath/mdetr:main',
    pretrained_model,
    pretrained=True,
    return_postprocessor=True,
)
model = model.to(device)

Downloading: "https://github.com/ashkamath/mdetr/archive/main.zip" to /root/.cache/torch/hub/main.zip
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth


  0%|          | 0.00/171M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/478M [00:00<?, ?B/s]

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Downloading: "https://zenodo.org/record/4721981/files/pretrained_resnet101_checkpoint.pth" to /root/.cache/torch/hub/checkpoints/pretrained_resnet101_checkpoint.pth


  0%|          | 0.00/2.76G [00:00<?, ?B/s]

In [None]:
# Switch the model to evaluation mode. A `torch.no_grad()` wrapper is pointless here because eval()
# performs no tensor computation; gradient tracking is disabled around the inference loop instead.
# The trailing ';' keeps the notebook from printing the module's repr.
model.eval();

In [None]:
# Ask the MDETR helper for the COCO API object that belongs to the test dataset.
# NOTE(review): `base_ds` is not referenced by any other cell shown here — confirm it is still needed.
base_ds = get_coco_api_from_dataset(test_dset)

In [None]:
# Progress logger that wraps the evaluation loop; `header` is the prefix of each progress line.
metric_logger = MetricLogger(delimiter="  ")
header = "Test:"

In [None]:
# Load the masked Flickr test annotations; the context manager closes the file handle right away.
with open(os.path.join(annotations_dir, 'flickr_test_masked.json')) as ann_fh:
    flickr_anns = json.load(ann_fh)

In [None]:
# Peek at the first image/caption record to see which fields the annotations provide.
flickr_anns['images'][0]

{'caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the ropes.',
 'dataset_name': 'flickr',
 'file_name': '1016887272.jpg',
 'height': '500',
 'id': 0,
 'masked_caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the <mask> .',
 'masked_word': 'line',
 'orig_caption': 'Several climbers in a row are climbing the rock while the man in red watches and holds the line .',
 'orig_id': 153901,
 'orig_sentence_id': 0,
 'orig_tokens_positive_eval': [[[0, 16]],
  [[39, 47]],
  [[54, 61]],
  [[65, 68]],
  [[87, 95]]],
 'original_img_id': 1016887272,
 'pred_word': 'ropes',
 'score_k': 3,
 'score_raw': 0.06085050478577614,
 'score_scaled': 0.12291948446884929,
 'sentence_id': 0,
 'tokens_negative': [[0, 97]],
 'tokens_positive_eval': [[[0, 16]],
  [[39, 47]],
  [[54, 61]],
  [[65, 68]],
  [[91, 95]]],
 'tokens_positive_eval_idx': 4,
 'width': '333'}

In [None]:
# Number of image/caption records in the annotation file.
len(flickr_anns['images'])

7893

In [None]:
# Start offset of the first span of every evaluated phrase in the first record.
first_record = flickr_anns['images'][0]
[phrase_spans[0][0] for phrase_spans in first_record['tokens_positive_eval']]

[0, 39, 54, 65, 87]

In [None]:
# Compare two per-record index fields element-wise.
# NOTE(review): 'mask_positive_token_idx' is written by a later cell (the loop over
# flickr_anns['images']); unless the JSON file already contains that key, this cell only works
# after that cell has been run — confirm the intended cell order.
records = flickr_anns['images']
test_pos_tok_idx = [record['mask_pos_token_idx'] for record in records]
test_positive_tok_idx = [record['mask_positive_token_idx'] for record in records]
equal = np.array([left == right for left, right in zip(test_pos_tok_idx, test_positive_tok_idx)])

(array([], dtype=int64),)

In [None]:
# Inspect one more record (index 62) in full.
flickr_anns['images'][62]

{'caption': 'Two blond girls are advertising in a public space ; one is handing out signs and the other is holding balloons.',
 'dataset_name': 'flickr',
 'file_name': '4859170265.jpg',
 'height': '499',
 'id': 62,
 'mask_positive_token_idx': 2,
 'masked_caption': 'Two blond girls are advertising in a public space ; one is handing out <mask> and the other is holding balloons .',
 'masked_word': 'literature',
 'orig_caption': 'Two blond girls are advertising in a public space ; one is handing out literature and the other is holding balloons .',
 'orig_id': 153936,
 'orig_sentence_id': 0,
 'orig_tokens_positive_eval': [[[0, 15]],
  [[52, 55]],
  [[71, 81]],
  [[86, 95]],
  [[107, 115]]],
 'original_img_id': 4859170265,
 'pred_word': 'signs',
 'score_k': 5,
 'score_raw': 0.017895670607686043,
 'score_scaled': 0.030106287800877842,
 'sentence_id': 4,
 'tokens_negative': [[0, 117]],
 'tokens_positive_eval': [[[0, 15]],
  [[52, 55]],
  [[71, 81]],
  [[81, 90]],
  [[102, 110]]],
 'tokens_posi

In [None]:
# De-duplicated image file names referenced by the annotation records.
image_file_names = list({record['file_name'] for record in flickr_anns['images']})


828

In [None]:
import shutil

# Copy every referenced image from the full Flickr30k folder into the dataset's image folder.
flickr30k_dir = '/content/gdrive/MyDrive/DL_systems/final_project/data/flickr30k-images'
for file_name in image_file_names:
    source_path = os.path.join(flickr30k_dir, file_name)
    target_path = os.path.join(img_dir, file_name)
    shutil.copy(source_path, target_path)

# The image folder must hold as many files as there are referenced images.
assert len(os.listdir(img_dir)) == len(image_file_names)

In [None]:
# For every record, find the evaluated phrase whose start offset lies closest to the '<mask>'
# placeholder in the masked caption and store its position as 'mask_positive_token_idx'.
# Ties resolve to the earliest phrase. The annotation records are modified in place.
for record in flickr_anns['images']:
    mask_char_start = record['masked_caption'].find('<mask>')
    distances = [abs(spans[0][0] - mask_char_start) for spans in record['tokens_positive_eval']]
    record['mask_positive_token_idx'] = distances.index(min(distances))


In [168]:
class PostProcessFlickr(nn.Module):
    """This module converts the model's output for Flickr30k entities evaluation.
    This processor is intended for recall@k evaluation with respect to each phrase in the sentence.
    It requires a description of each phrase (as a binary mask), and returns a sorted list of boxes for each phrase.
    """

    @torch.no_grad()
    def forward(self, outputs, target_sizes, positive_map, items_per_batch_element):
        """Rank the predicted boxes of every batch element for each of its phrases.

        Args:
            outputs: raw outputs of the model; "pred_logits" and "pred_boxes" are read from it.
            target_sizes: tensor of dimension [batch_size x 2] containing the (height, width) of each image
                          of the batch.
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augment, but before padding
            positive_map: tensor [total_nbr_phrases x max_seq_len] for each phrase in the batch, contains a binary
                          mask of the tokens that correspond to that sentence. Note that this is a "collapsed" batch,
                          meaning that all the phrases of all the batch elements are stored sequentially.
            items_per_batch_element: list[int] number of phrases corresponding to each batch element.

        Returns:
            A tuple ``(predicted_boxes, phrase_ids, scores_output)``; each is a list with one entry per
            batch element:
              * ``predicted_boxes[b][p]``: all boxes of element ``b`` as ``[x0, y0, x1, y1]`` lists in
                absolute coordinates, sorted by decreasing score for phrase ``p``.
              * ``phrase_ids[b]``: ``list(range(items_per_batch_element[b]))``.
              * ``scores_output[b][p]``: the scores belonging to ``predicted_boxes[b][p]``, same order.
            The same 3-tuple, with empty per-element lists, is returned when the batch has no phrase,
            so callers can always unpack three values.
        """
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        batch_size = target_sizes.shape[0]

        # binarize the map if not already binary
        pos = positive_map > 1e-6

        predicted_boxes = [[] for _ in range(batch_size)]
        scores_output = [[] for _ in range(batch_size)]
        phrase_ids = [list(range(nb_phrases)) for nb_phrases in items_per_batch_element]

        # The collapsed batch dimension must match the number of items
        cum_sum = np.cumsum(items_per_batch_element)
        total_phrases = int(cum_sum[-1]) if len(cum_sum) > 0 else 0
        assert len(pos) == total_phrases

        # Nothing to rank: return the same structure as the regular path, before doing any box work.
        if len(pos) == 0:
            return predicted_boxes, phrase_ids, scores_output

        prob = F.softmax(out_logits, -1)

        # convert to [x0, y0, x1, y1] format
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        # and from relative [0, 1] to absolute [0, height] coordinates
        boxes = boxes * scale_fct[:, None, :]

        curr_batch_index = 0
        # if the first batch elements don't contain elements, skip them.
        while cum_sum[curr_batch_index] == 0:
            curr_batch_index += 1

        last_phrase = len(pos) - 1
        for i in range(len(pos)):
            # scores are computed by taking the max over the scores assigned to the positive tokens
            scores, _ = torch.max(pos[i].unsqueeze(0) * prob[curr_batch_index, :, :], dim=-1)
            _, indices = torch.sort(scores, descending=True)
            assert items_per_batch_element[curr_batch_index] > 0
            predicted_boxes[curr_batch_index].append(boxes[curr_batch_index][indices].to("cpu").tolist())
            scores_output[curr_batch_index].append(scores[indices].to("cpu").tolist())
            if i == last_phrase:
                break

            # check if we need to move to the next batch element
            while i >= cum_sum[curr_batch_index] - 1:
                curr_batch_index += 1
                assert curr_batch_index < len(cum_sum)

        return predicted_boxes, phrase_ids, scores_output

In [191]:
# Evaluate the model on the whole test set. For every batch: run MDETR, rank the predicted boxes
# per phrase, keep the ranking of the phrase selected by 'tokens_positive_eval_idx', and checkpoint
# the accumulated results and timings to `output_dir` so an interrupted run keeps its progress.
flickr_res_collector = []  # NOTE(review): never filled — the per-batch `flickr_res` is discarded
mask_res_collector = []
time_records = []
time_df = pd.DataFrame()

# Loop-invariant objects, built once instead of once per batch.
anns_by_image_id = {image['id']: image for image in flickr_anns['images']}
flickr_postprocessor = PostProcessFlickr()
results_path = os.path.join(output_dir, f'{pretrained_model}_{batch_size}_masked_token_results.pkl')
timing_path = os.path.join(output_dir, f'{pretrained_model}_{batch_size}_eval_time.csv')

with torch.no_grad():
    start_time = time.time()
    # batch_size is passed as log_every's second argument (presumably the print frequency).
    for i, batch_dict in enumerate(metric_logger.log_every(test_loader, batch_size, header)):
        batch_start_time = time.time()
        samples = batch_dict['samples'].to(device)
        positive_map = batch_dict["positive_map"].to(device)
        targets = batch_dict["targets"]
        captions = [t["caption"] for t in targets]
        targets = targets_to(targets, device)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        orig_target_sizes = orig_target_sizes.to(device)

        # Two-step forward pass: encode once, then decode from the cached encoding.
        scoring_start_time = time.time()
        memory_cache = model(samples, captions, encode_and_save=True)
        output = model(samples, captions, encode_and_save=False, memory_cache=memory_cache)
        scoring_end_time = time.time()

        post_process_start_time = time.time()

        image_ids = [t["image_id"] for t in targets]
        sentence_ids = [t["sentence_id"] for t in targets]
        items_per_batch_element = [t["nb_eval"] for t in targets]

        # Look the annotation of each batch element up by id, in batch order, so the masked-phrase
        # index always lines up with its image (a missing id raises KeyError instead of silently
        # shifting or truncating the results).
        # assumes each image_id is an int or a one-element tensor — TODO confirm
        mask_token_idx = [anns_by_image_id[int(im_id)]['tokens_positive_eval_idx'] for im_id in image_ids]

        positive_map_eval = batch_dict["positive_map_eval"].to(device)
        bboxes, phrase_ids, scores = flickr_postprocessor(output, orig_target_sizes, positive_map_eval,
                                                          items_per_batch_element)

        flickr_res = []
        mask_res = []

        for im_id, sent_id, boxes, phrase_id, score, mask_idx in zip(image_ids, sentence_ids, bboxes, phrase_ids, scores, mask_token_idx):
            flickr_res.append({"image_id": im_id, "sentence_id": sent_id,
                               "boxes": boxes,
                               "phrase_ids": phrase_id,
                               "scores": score})

            # Keep only the ranking that belongs to the masked phrase.
            mask_res.append({"image_id": im_id, "sentence_id": sent_id,
                             "boxes": boxes[mask_idx],
                             "scores": score[mask_idx]})
        post_process_end_time = time.time()
        batch_end_time = time.time()
        if i % 100 == 0:
            print(f"FINISHING BATCH {i}")
            print(f"batch processing time:{batch_end_time-batch_start_time}")
            print(f"total processing time:{batch_end_time-start_time}")
            print("---------------------------------------------")
            print("")

        # Checkpoint all masked-phrase results gathered so far.
        mask_res_collector.extend(mask_res)
        with open(results_path, 'wb') as pkl_file:
            pickle.dump(mask_res_collector, pkl_file)

        # Checkpoint the timings. The frame is rebuilt from the record list because
        # DataFrame.append no longer exists in pandas 2.x.
        time_records.append({'total_cumulative_time': time.time() - start_time,
                             'batch_time': batch_end_time - batch_start_time,
                             'model_scoring_time': scoring_end_time - scoring_start_time,
                             'post_processing_time': post_process_end_time - post_process_start_time})
        time_df = pd.DataFrame(time_records)
        time_df.to_csv(timing_path)


  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


FINISHING BATCH 0
batch processing time:3.6183128356933594
total processing time:4.122673273086548
---------------------------------------------

Test:  [  0/790]  eta: 0:54:35    time: 4.1465  data: 0.5043  max mem: 5590
Test:  [ 10/790]  eta: 0:45:56    time: 3.5340  data: 0.3378  max mem: 5590
Test:  [ 20/790]  eta: 0:47:34    time: 3.6851  data: 0.3377  max mem: 5590
Test:  [ 30/790]  eta: 0:46:50    time: 3.7885  data: 0.3487  max mem: 5590
Test:  [ 40/790]  eta: 0:46:20    time: 3.7083  data: 0.3436  max mem: 5590
Test:  [ 50/790]  eta: 0:45:57    time: 3.7697  data: 0.3399  max mem: 5590
Test:  [ 60/790]  eta: 0:45:36    time: 3.8343  data: 0.3426  max mem: 5590
Test:  [ 70/790]  eta: 0:44:54    time: 3.7833  data: 0.3439  max mem: 5590
Test:  [ 80/790]  eta: 0:44:11    time: 3.6905  data: 0.3228  max mem: 5590
Test:  [ 90/790]  eta: 0:43:41    time: 3.7540  data: 0.3308  max mem: 5590
FINISHING BATCH 100
batch processing time:3.199345588684082
total processing time:378.98088836

In [177]:
# Debug check on names left over from the evaluation loop.
# NOTE(review): `boxes` is the per-phrase list of the last image handled by the loop, while
# `scores` is the per-image list of the whole last batch, so the two printed lengths describe
# different things — confirm this is the comparison that was intended.
for i in range(10):
  print(len(boxes[i]),len(scores[i]))

100 4
100 14
100 6
100 6
100 6
100 6
100 6
100 6
100 5
100 5


In [187]:
#boxes[mask_idx]

In [189]:
# Scores (sorted in decreasing order) of the selected phrase for the last image handled by the
# evaluation loop; `score` and `mask_idx` are leftovers of that loop.
score[mask_idx]


[0.26629841327667236,
 0.23038122057914734,
 0.19664451479911804,
 0.19529412686824799,
 0.17925092577934265,
 0.10769129544496536,
 0.10717467218637466,
 0.09324860572814941,
 0.0784594938158989,
 0.05490195006132126,
 0.02750505320727825,
 0.01751176454126835,
 0.015931153669953346,
 0.01568690687417984,
 0.013197149150073528,
 0.012334312312304974,
 0.0096694091334939,
 0.008902233093976974,
 0.008729779161512852,
 0.007776372134685516,
 0.00600266270339489,
 0.005737207364290953,
 0.005049105267971754,
 0.0044263252057135105,
 0.004273982718586922,
 0.003466750495135784,
 0.003253511618822813,
 0.0029891962185502052,
 0.002961447462439537,
 0.002708910033106804,
 0.0026903930120170116,
 0.0019721733406186104,
 0.0018061596201732755,
 0.0017205264884978533,
 0.0016998446080833673,
 0.001629700418561697,
 0.0016285256715491414,
 0.0015725042903795838,
 0.0015720534138381481,
 0.001565179554745555,
 0.0015516512794420123,
 0.001426858827471733,
 0.0011987838661298156,
 0.0010806086938

In [None]:
# Softmax over the last dimension of the logits of the last batch, then the probability at the
# final index for every prediction of the first batch element
# (presumably the "no object" class — TODO confirm).
output["pred_logits"].cpu().softmax(-1)[0,:,-1]

tensor([0.8962, 0.6352, 0.9472, 0.8629, 0.9460, 0.8039, 0.9776, 0.9688, 0.9075,
        0.9870, 0.8838, 0.8042, 0.9286, 0.9400, 0.9300, 0.2921, 0.8956, 0.9369,
        0.2263, 0.9144, 0.9266, 0.9453, 0.9396, 0.0476, 0.9756, 0.7077, 0.0018,
        0.8721, 0.9239, 0.8237, 0.2195, 0.0189, 0.9339, 0.7479, 0.6253, 0.8864,
        0.8779, 0.9536, 0.0235, 0.9800, 0.0370, 0.9739, 0.9444, 0.5922, 0.9129,
        0.8830, 0.9645, 0.8486, 0.9612, 0.9498, 0.9276, 0.9786, 0.8191, 0.7668,
        0.0160, 0.8726, 0.0753, 0.9626, 0.9282, 0.7204, 0.9459, 0.8864, 0.8719,
        0.0142, 0.9478, 0.9737, 0.9562, 0.8434, 0.9010, 0.9767, 0.8302, 0.9512,
        0.9502, 0.9509, 0.9006, 0.9416, 0.0205, 0.9082, 0.2123, 0.9308, 0.0623,
        0.9795, 0.9123, 0.9570, 0.9133, 0.8710, 0.9665, 0.9662, 0.8521, 0.8242,
        0.0204, 0.9593, 0.7649, 0.4784, 0.8136, 0.8930, 0.9394, 0.0631, 0.7506,
        0.9035])