In [None]:
import argparse
import datetime
import json
import math
import os
import random
import time
from pathlib import Path

import numpy as np
import ruamel.yaml as yaml
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torchvision import transforms
import torchvision.utils as vutils
from PIL import Image
from torchvision.io import read_image

import matplotlib.pyplot as plt


import utils
from dataset import create_dataset, create_sampler, create_loader
from dataset.utils import collect_tensor_result, grounding_eval_bbox, grounding_eval_bbox_vlue
from models.model_bbox import XVLM
from models.tokenization_bert import BertTokenizer
from models.tokenization_roberta import RobertaTokenizer
from optim import create_optimizer
from refTools.refer_python3 import REFER
from scheduler import create_scheduler
from utils.hdfs_io import hmkdir, hcopy, hexists



In [None]:
# Part abbreviations (PascalParts naming) to query for each object class.
# The grounding cells below iterate dict_part_include[p_class] and build each text
# prompt as [expanded_abbrev[part]], so every abbreviation listed here must also be
# a key of `expanded_abbrev`.
dict_part_include = {'cat':['head', 'lbleg', 'lbpa', 'lear', 'lfleg', 'lfpa', 'nose', 'rbleg', 'rbpa', 'rear', 'rfleg', 'rfpa', 'tail', 'torso'],
                     'person':['hair', 'head', 'lfoot', 'lhand', 'llarm', 'llleg', 'luarm', 'luleg', 'rfoot', 'rhand', 'rlarm', 'rlleg', 'ruarm', 'ruleg', 'torso'],
                     'motorbike':['bwheel', 'fwheel', 'handlebar', 'headlight', 'saddle'],
                     'car':['backside', 'bliplate', 'door', 'fliplate', 'frontside', 'headlight', 'leftmirror', 'leftside', 'rightmirror', 'rightside', 'roofside', 'wheel', 'window'],
                     'aeroplane':['body', 'engine', 'lwing', 'rwing', 'stern', 'tail', 'wheel'],
                     'dog':['head', 'lbleg', 'lbpa', 'lear', 'lfleg', 'lfpa', 'muzzle', 'rbleg', 'rbpa', 'rear', 'rfleg', 'rfpa', 'tail', 'torso'],
                     'bus':['backside', 'bliplate', 'door', 'fliplate', 'frontside', 'headlight', 'leftmirror', 'leftside', 'rightmirror', 'rightside', 'roofside', 'wheel', 'window'],
                     'bird':['beak', 'head', 'lfoot', 'lleg', 'lwing', 'neck', 'reye', 'rfoot', 'rleg', 'rwing', 'tail', 'torso'],
                     'horse':['head', 'lbho', 'lblleg', 'lbuleg', 'lfho', 'lflleg', 'lfuleg', 'muzzle', 'neck', 'rbho', 'rblleg', 'rbuleg', 'rfho', 'rflleg', 'rfuleg', 'tail', 'torso'],
                     'pottedplant':['plant', 'pot'],
                     'cow':['head', 'lblleg', 'lbuleg', 'lflleg', 'lfuleg', 'lhorn', 'muzzle', 'rblleg', 'rbuleg', 'rflleg', 'rfuleg', 'rhorn', 'tail', 'torso'],
                     'bicycle':['bwheel', 'chainwheel', 'fwheel', 'handlebar', 'headlight', 'saddle'],
                     'sheep':['head', 'lblleg', 'lbuleg', 'lflleg', 'lfuleg', 'lhorn', 'muzzle', 'rblleg', 'rbuleg', 'rflleg', 'rfuleg', 'rhorn', 'tail', 'torso']}

# Set of object classes deliberately left out of the custom dataset.
# NOTE(review): PASCAL VOC names the table class 'diningtable'; confirm that 'table'
# matches the class labels this set is compared against.
dict_class_exclude = set(['bottle', 'chair', 'boat', 'sofa', 'tvmonitor', 'table', 'train'])

# Plain-English phrases for PascalParts part abbreviations (and for the object class
# names themselves), e.g. 'lbleg' -> 'left back leg'. The grounding cells use the
# value directly as the X-VLM text prompt ([expanded_abbrev[part]]), so every part in
# dict_part_include needs an entry here. The mapping also covers classes listed in
# dict_class_exclude (e.g. tvmonitor, train) and their parts.
expanded_abbrev = {'cat': 'cat',
                   'background': 'background',
                   'head': 'head',
                   'lbleg': 'left back leg',
                   'lbpa': 'left back paw',
                   'lear': 'left ear',
                   'leye': 'left eye',
                   'lfleg': 'left front leg',
                   'lfpa': 'left front paw',
                   'neck': 'neck',
                   'nose': 'nose',
                   'rbleg': 'right back leg',
                   'rbpa': 'right back paw',
                   'rear': 'right ear',
                   'reye': 'right eye',
                   'rfleg': 'right front leg',                
                   'rfpa': 'right front paw',
                   'tail': 'tail',
                   'torso': 'torso',
                   'person': 'person',
                   'hair': 'hair',
                   'lebrow': 'left eyebrow',
                   'lfoot': 'left foot',
                   'lhand': 'left hand',
                   'llarm': 'left lower arm',
                   'llleg': 'left lower leg',
                   'luarm': 'left upper arm',
                   'luleg': 'left upper leg',
                   'mouth': 'mouth',
                   'rebrow': 'right eyebrow',
                   'rfoot': 'right foot',
                   'rhand': 'right hand',
                   'rlarm': 'right lower arm',
                   'rlleg': 'right lower leg',
                   'ruarm': 'right upper arm',
                   'ruleg': 'right upper leg',
                   'muzzle': 'muzzle',
                   'dog': 'dog',
                   'lblleg': 'left back lower leg',
                   'lbuleg': 'left back upper leg',
                   'lflleg': 'left front lower leg',
                   'lfuleg': 'left front upper leg',
                   'lhorn': 'left horn',
                   'rblleg': 'right back lower leg',
                   'rbuleg': 'right back upper leg',
                   'rflleg': 'right front lower leg',
                   'rfuleg': 'right front upper leg',
                   'rhorn': 'right horn',
                   'cow': 'cow',
                   'beak': 'beak',
                   'lleg': 'left leg',
                   'lwing': 'left wing',
                   'rleg': 'right leg',
                   'rwing': 'right wing',
                   'bird': 'bird',
                   'body': 'body',
                   'engine': 'engine',
                   'stern': 'stern',
                   'wheel': 'wheel',
                   'aeroplane': 'aeroplane',
                   'bwheel': 'back wheel',
                   'fwheel': 'front wheel',
                   'handlebar': 'handlebar',
                   'headlight': 'headlight',
                   'saddle': 'saddle',
                   'motorbike': 'motorbike',
                   'sheep': 'sheep',
                   'pot': 'pot',
                   'plant': 'plant',
                   'lbho': 'left back hoof',
                   'lfho': 'left front hoof',
                   'rbho': 'right back hoof',
                   'rfho': 'right front hoof',
                   'backside': 'back side',
                   'bliplate': 'back license plate',
                   'door': 'door',
                   'fliplate': 'front license plate',
                   'frontside': 'front side',
                   'leftmirror': 'left mirror',
                   'leftside': 'left side',
                   'rightmirror': 'right mirror',
                   'rightside': 'right side',
                   'roofside': 'roof side',
                   'window': 'window',
                   'chainwheel': 'chain wheel',
                   'cbackside': 'coach back side',
                   'cfrontside': 'coach front side',
                   'cleftside': 'coach left side',
                   'coach': 'coach',
                   'crightside': 'coach right side',
                   'croofside': 'coach roof side',
                   'hbackside': 'head back side',
                   'hfrontside': 'head front side',
                   'hleftside': 'head left side',
                   'hrightside': 'head right side',
                   'hroofside': 'head roof side',
                   'cap': 'cap',
                   'screen': 'screen',
                   'tvmonitor': 'tv monitor',
                   'car': 'car',
                   'bus': 'bus',
                   'train': 'train',
                   'horse': 'horse',
                   'pottedplant': 'potted plant',
                   'bicycle': 'bicycle',
                   'bottle': 'bottle', 
                   'table': 'table',
                   'sofa':'sofa',
                   'chair':'chair', 
                   'boat':'boat'}

## Define BBox Functions

In [None]:
def xvlm_get_bbox(image, prompts, model, config, tokenizer=None, device=None):
    """Predict one bounding box per text prompt for a single image with X-VLM.

    Args:
        image: preprocessed image tensor passed to the model unchanged (the notebook
            passes ``test_transform(img).unsqueeze(0)``); moved to ``device`` once.
        prompts: sequence of text prompts; each one is tokenized and run separately.
        model: grounding model, called as
            ``model(image, input_ids, attention_mask, target_bbox=None)``; it is put
            into eval mode.
        config: dict providing ``'image_res'``, the side length that the predicted
            normalized coordinates are scaled by.
        tokenizer: text tokenizer. Defaults to the notebook-global ``tokenizer``.
        device: device for the image and text tensors. Defaults to the
            notebook-global ``device``.

    Returns:
        CPU float tensor of shape ``(len(prompts), 4)`` with boxes as
        ``(xmin, ymin, xmax, ymax)`` in pixels of an ``image_res`` x ``image_res`` image.
    """
    # Fall back to the notebook globals so existing 4-argument calls keep working,
    # while letting callers pass them explicitly instead of relying on hidden state.
    if tokenizer is None:
        tokenizer = globals()['tokenizer']
    if device is None:
        device = globals()['device']

    results = torch.empty((len(prompts), 4))
    model.eval()
    # The image is identical for every prompt, so transfer it only once.
    image = image.to(device)

    with torch.no_grad():
        for i, text in enumerate(prompts):
            text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
            outputs_coord = model(image, text_input.input_ids, text_input.attention_mask, target_bbox=None)
            # assumes the model returns exactly one box (4 values, e.g. shape (1, 4)) per call
            results[i] = outputs_coord.detach().cpu().reshape(4)

    # Predictions are normalized center coordinates (cx, cy, w, h); convert to
    # corner format and scale to pixels of the resized image.
    cx, cy, w, h = results.unbind(dim=1)
    coords = torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=1) * config['image_res']
    return coords
    
    
    
def plot_bbox_prompts(image_int, coords, prompts):
    """Draw one labelled box per prompt on an image and display the figure.

    Args:
        image_int: uint8 image tensor as produced by ``plot_transform``; passed to
            ``torchvision.utils.draw_bounding_boxes``.
        coords: boxes as ``(xmin, ymin, xmax, ymax)`` in pixels of ``image_int``.
            Usually the ``(N, 4)`` tensor returned by ``xvlm_get_bbox``; plain
            lists/tuples and a single flat 4-value box are accepted as well.
        prompts: the N prompts. Box i is labelled "Prompt i" and the full prompt
            texts are listed under the image.

    Raises:
        ValueError: if the number of boxes differs from the number of prompts.
    """
    # draw_bounding_boxes only accepts a tensor of shape (N, 4); normalize plain
    # sequences and a single flat box so ad-hoc calls with literal coordinates work.
    boxes = torch.as_tensor(coords, dtype=torch.float).reshape(-1, 4)
    if boxes.shape[0] != len(prompts):
        raise ValueError(f"got {boxes.shape[0]} boxes for {len(prompts)} prompts")

    labels = [f"Prompt {i}" for i in range(len(prompts))]
    bbox_image = vutils.draw_bounding_boxes(image_int, boxes,
                                            width=3,
                                            labels=labels,
                                            colors="red")

    fig, ax = plt.subplots(figsize=(10, 20))
    # draw_bounding_boxes returns channels-first; imshow expects channels-last.
    ax.imshow(bbox_image.permute(1, 2, 0))
    caption = '\n'.join(f'{label}: {prompt}' for label, prompt in zip(labels, prompts))
    ax.set_xlabel('\n' + caption, fontsize=14)
    plt.show()
    
    

## Define Parameters & Transforms

In [None]:
# Paths and options for the RefCOCO-finetuned X-VLM (16M) bbox-grounding checkpoint.
config_file = "configs/Grounding_bbox.yaml"
# `with` closes the file right after parsing instead of leaving the handle to the GC.
# The config is a trusted local file; yaml.Loader is not the safe loader.
# NOTE(review): the module-level ruamel.yaml `load()` API is deprecated/removed in newer
# ruamel.yaml releases — pin the version or migrate to `yaml.YAML(typ=...)`.
with open(config_file, 'r') as config_fh:
    config = yaml.load(config_fh, Loader=yaml.Loader)
checkpoint = "model_checkpoints/16m_base_finetune/refcoco_bbox/checkpoint_best.pth"
evaluate = False  # forwarded to model.load_pretrained(..., is_eval=evaluate)

# GPU index used on the original machine. Fall back to the CPU when that GPU does not
# exist, so the notebook still runs elsewhere instead of failing in model.to(device).
GPU_INDEX = 7
if torch.cuda.device_count() > GPU_INDEX:
    device = torch.device(f"cuda:{GPU_INDEX}")
else:
    device = torch.device("cpu")


# Transformations
# Per-channel mean/std normalization of the model input; presumably must match the
# preprocessing the checkpoint was trained with.
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

# Model input: square resize to config['image_res'], float tensor, normalized.
test_transform = transforms.Compose([
        transforms.Resize((config['image_res'], config['image_res']), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

# Display copy at the same square resolution, converted to uint8 for draw_bounding_boxes,
# so boxes scaled by config['image_res'] in xvlm_get_bbox line up with the drawn image.
plot_transform = transforms.Compose([
    transforms.Resize((config['image_res'], config['image_res']), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.uint8),
])

## Load Model and Tokenizer

In [None]:
# Build X-VLM from the grounding config and restore the checkpoint weights onto `device`.
print("Creating model")
model = XVLM(config=config)
model.load_pretrained(checkpoint, config, is_eval=evaluate)
model = model.to(device)

# Note: only parameters with requires_grad=True are counted.
n_trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print("### Total Params: ", n_trainable_params)

# config['use_roberta'] selects the tokenizer family for config['text_encoder'].
tokenizer_cls = RobertaTokenizer if config['use_roberta'] else BertTokenizer
tokenizer = tokenizer_cls.from_pretrained(config['text_encoder'])

# Qualitative Test of X-VLM (16M) Part Grounding on PASCAL-Part Classes

In [None]:
# Ground each annotated cat part in VOC2010 image 2007_000549 and plot every prediction.
p_class = 'cat'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_000549.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated bicycle part in VOC2010 image 2007_000584 and plot every prediction.
p_class = 'bicycle'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_000584.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated horse part in VOC2010 image 2007_000783 and plot every prediction.
p_class = 'horse'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_000783.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated aeroplane part in VOC2010 image 2007_001288 and plot every prediction.
p_class = 'aeroplane'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_001288.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated person part in VOC2010 image 2007_001423 and plot every prediction.
p_class = 'person'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_001423.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated cow part in VOC2010 image 2007_002088 and plot every prediction.
p_class = 'cow'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_002088.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated bird part in VOC2010 image 2007_002094 and plot every prediction.
p_class = 'bird'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_002094.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated motorbike part in VOC2010 image 2007_002260 and plot every prediction.
p_class = 'motorbike'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_002260.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated car part in VOC2010 image 2007_002281 and plot every prediction.
p_class = 'car'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_002281.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated pottedplant part in VOC2010 image 2007_002284 and plot every prediction.
p_class = 'pottedplant'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_002284.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated sheep part in VOC2010 image 2007_003190 and plot every prediction.
p_class = 'sheep'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_003190.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated bus part in VOC2010 image 2007_003711 and plot every prediction.
p_class = 'bus'
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2007_003711.jpg'

# `image_int` is the uint8 copy used for drawing; `image` is then replaced by the
# batched, normalized tensor the model consumes.
image = Image.open(img_path).convert('RGB')
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

# One single-prompt query per part, phrased with the part's expanded English name.
for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Ground each annotated person part in VOC2010 image 2008_007291 and plot every
# prediction, then overlay one hand-entered 'left hand' reference box for comparison.
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2008_007291.jpg'

image = Image.open(img_path).convert('RGB')
# PIL reports (width, height) of the original image, needed to rescale the reference box.
orig_w, orig_h = image.size
image_int = plot_transform(image)
image = test_transform(image).unsqueeze(0)

p_class = 'person'

for part in dict_part_include[p_class]:
    prompts = [expanded_abbrev[part]]
    coords = xvlm_get_bbox(image, prompts, model, config)
    plot_bbox_prompts(image_int, coords, prompts)

# Reference box for the left hand. Its 3rd/4th values (17, 14) are smaller than its
# 1st/2nd (87, 203), so it cannot be (xmin, ymin, xmax, ymax); treat it as (x, y, w, h).
# NOTE(review): assumes these are pixels of the ORIGINAL image — confirm against the
# annotation source. They are rescaled to the square image_res display image, and
# passed as an (N, 4) tensor as draw_bounding_boxes requires.
ref_x, ref_y, ref_w, ref_h = 87.0, 203.0, 17.0, 14.0
scale_x = config['image_res'] / orig_w
scale_y = config['image_res'] / orig_h
ref_box = torch.tensor([[ref_x * scale_x, ref_y * scale_y,
                         (ref_x + ref_w) * scale_x, (ref_y + ref_h) * scale_y]])
plot_bbox_prompts(image_int, ref_box, ['left hand'])

In [None]:
# Display the original (width, height) of the image at `img_path`, i.e. its size before
# plot_transform/test_transform resize it; presumably inspected to interpret the
# hand-entered box coordinates in the neighbouring cells.
Image.open(img_path).convert('RGB').size

In [None]:
# Visual check of a hand-entered 'left hand' box given as (x, y, w, h) in pixels of a
# REF_RES x REF_RES resize of the image; converted to (xmin, ymin, xmax, ymax) for drawing.
REF_RES = 384
img_path = 'images/pascalparts/VOCdevkit_2010/VOC2010/JPEGImages/2008_007291.jpg'

# Deliberately a separate name: rebinding the global `plot_transform` here would
# silently change the display resolution of every grounding cell re-run afterwards.
plot_transform_ref = transforms.Compose([
    transforms.Resize((REF_RES, REF_RES), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.uint8),
])

image = Image.open(img_path).convert('RGB')

ref_x, ref_y, ref_w, ref_h = 16.0, 152.0, 73.0, 98.0
ref_box = torch.tensor([[ref_x, ref_y, ref_x + ref_w, ref_y + ref_h]])
plot_bbox_prompts(plot_transform_ref(image), ref_box, ['left hand'])