In [None]:
import argparse
import datetime
import json
import math
import os
import random
import time
from pathlib import Path

import numpy as np
import ruamel.yaml as yaml
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torchvision import transforms
import torchvision.utils as vutils
from PIL import Image
from torchvision.io import read_image

import matplotlib.pyplot as plt


import utils
from dataset import create_dataset, create_sampler, create_loader
from dataset.utils import collect_tensor_result, grounding_eval_bbox, grounding_eval_bbox_vlue
from models.model_bbox import XVLM
from models.tokenization_bert import BertTokenizer
from models.tokenization_roberta import RobertaTokenizer
from optim import create_optimizer
from refTools.refer_python3 import REFER
from scheduler import create_scheduler
from utils.hdfs_io import hmkdir, hcopy, hexists



## Define BBox Functions

In [None]:
def xvlm_get_bbox(image, prompts, model, config, tokenizer=None, device=None):
    """Predict one bounding box per text prompt for a single image.

    Every prompt is tokenized and passed through ``model`` together with the
    image. The four values returned by the model are read as a normalized
    ``(cx, cy, w, h)`` box, converted to corner form and scaled by
    ``config['image_res']``.

    Args:
        image: Image tensor handed to ``model``; only its device is changed
            here. (The notebook passes the ``test_transform`` output with a
            leading batch dimension.)
        prompts: Sequence of strings; one box is predicted per entry.
        model: Callable as ``model(image, input_ids, attention_mask,
            target_bbox=None)`` that returns four box values per call and
            provides ``eval()``.
        config: Mapping that provides ``'image_res'``.
        tokenizer: Optional tokenizer called as
            ``tokenizer(text, padding='longest', return_tensors="pt")``.
            When omitted, the module-level ``tokenizer`` is used, which is
            what the original implementation always did.
        device: Optional device for the image and the tokenized text. When
            omitted, the device of the model's first parameter is used so
            inputs always land where the model lives.

    Returns:
        A CPU tensor of shape ``(len(prompts), 4)`` holding
        ``(xmin, ymin, xmax, ymax)`` scaled by ``config['image_res']``.
        Values are not clamped to the image bounds.

    Raises:
        NameError: If ``tokenizer`` is not given and no module-level
            ``tokenizer`` exists.
    """
    if tokenizer is None:
        # Backward-compatible fallback for callers that rely on the notebook
        # having defined a global tokenizer before calling this function.
        try:
            tokenizer = globals()["tokenizer"]
        except KeyError:
            raise NameError(
                "name 'tokenizer' is not defined; pass tokenizer=... or define it globally"
            ) from None
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    results = torch.empty((len(prompts), 4))

    # The image is the same for every prompt, so move it to the device once.
    image = image.to(device)

    with torch.no_grad():
        for i, text in enumerate(prompts):
            text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)
            outputs_coord = model(image, text_input.input_ids, text_input.attention_mask, target_bbox=None)
            # Flatten so a (1, 4) model output fits the (4,) row explicitly.
            results[i] = outputs_coord.cpu().reshape(4)

    # Convert normalized center/size boxes to xmin, ymin, xmax, ymax in pixels.
    cx, cy, w, h = results.unbind(dim=1)
    coords = torch.stack((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2), dim=1) * config['image_res']

    return coords
    
    
    
def plot_bbox_prompts(image_int, coords, prompts):
    """Show the image with one red, labelled box per prompt.

    Args:
        image_int: Channel-first image tensor used as the drawing canvas
            (the notebook passes the uint8 output of ``plot_transform``).
        coords: Box tensor forwarded to ``draw_bounding_boxes``; row ``i``
            is the box drawn for ``prompts[i]``.
        prompts: Sequence of prompt strings. Boxes are labelled
            ``Prompt <index>`` and the full prompt text is listed under the
            plot as the x-axis label.
    """
    plt.figure(figsize=(10,20))

    # Short tags go on the boxes; the full prompt text goes in the caption.
    box_tags = [f"Prompt {idx}" for idx, _ in enumerate(prompts)]
    annotated = vutils.draw_bounding_boxes(image_int, coords, width=3, labels=box_tags, colors="red")

    # draw_bounding_boxes returns channel-first data; imshow wants channels last.
    plt.imshow(annotated.permute(1,2,0))

    caption_lines = [f'Prompt {idx}: ' + text for idx, text in enumerate(prompts)]
    plt.xlabel('\n' + '\n'.join(caption_lines), fontsize=14)
    plt.show()
    
    

## Define Parameters & Transforms

In [None]:
config_file = "configs/Grounding_bbox.yaml"
# NOTE(review): yaml.Loader can construct arbitrary Python objects. That is
# acceptable only while config_file is a trusted, repository-local file; use a
# safe loader before pointing this at external input.
with open(config_file, 'r') as config_stream:  # close the handle after parsing
    config = yaml.load(config_stream, Loader=yaml.Loader)
checkpoint = "model_checkpoints/16m_base_finetune/refcoco_bbox/checkpoint_best.pth"
evaluate = False  # forwarded as is_eval to model.load_pretrained

# Preferred GPU. Fall back to the first GPU, then to the CPU, so the notebook
# still runs on machines that do not expose this device index.
GPU_INDEX = 7
if torch.cuda.is_available():
    device = torch.device("cuda", GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0)
else:
    device = torch.device("cpu")


# Transformations
# Per-channel mean / std applied to the tensor that is fed to the model.
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

# Both pipelines resize to the square resolution given by the config. The
# InterpolationMode enum replaces the PIL integer constant, which torchvision
# only accepts with a deprecation warning.
resize_to_model_res = transforms.Resize(
    (config['image_res'], config['image_res']),
    interpolation=transforms.InterpolationMode.BICUBIC,
)

# Model input: resized, converted to a float tensor, then normalized.
test_transform = transforms.Compose([
        resize_to_model_res,
        transforms.ToTensor(),
        normalize,
    ])

# Plotting canvas: same resize, but kept un-normalized and converted to uint8.
plot_transform = transforms.Compose([
    resize_to_model_res,
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.uint8),
])

## Load and Evaluate Model

In [None]:
# Build the grounding model, load the checkpoint into it and move it to the
# selected device.
print("Creating model")
model = XVLM(config=config)
model.load_pretrained(checkpoint, config, is_eval=evaluate)
model = model.to(device)
# Note: despite the label, only parameters with requires_grad=True are counted.
print("### Total Params: ", sum(param.numel() for param in model.parameters() if param.requires_grad))

# The tokenizer class follows the text encoder selected in the config.
tokenizer = (RobertaTokenizer if config['use_roberta'] else BertTokenizer).from_pretrained(config['text_encoder'])

In [None]:
# Image to ground; other samples tried with this cell are kept below for reference.
img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_002485.jpg'
# img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_001802.jpg'
# img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_001736.jpg'
# img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_001808.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize; alternative prompt sets are kept for reference.
# prompts = ["The picture on the wall", "lamp", "door in the background"]
# prompts = ["face in the middle", "face on the right", "face on the left"]
# prompts = ["left lower leg partially occludes right upper leg", "crossed legs", "left leg on top of right leg"]
prompts = ["left hand of the person on the left"]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Image to ground.
img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_001802.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize (three phrasings about the crossed legs).
prompts = [
    "left lower leg partially occludes right upper leg",
    "crossed legs of woman on left",
    "left leg on top of right leg",
]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Image to ground.
img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_001736.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize: individual body parts.
prompts = [
    "Left hand",
    "Right hand",
    "head",
    "torso",
    "right leg",
    "left leg",
]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Image to ground.
img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_001808.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize; the longer body-part list is kept for reference.
# prompts = ["Left hand", "Right hand", "head", "torso", "left leg", "right leg"]
prompts = ["Left hand", "Right hand"]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Image to ground.
img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_002801.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize; an earlier prompt set is kept for reference.
# prompts = ["Left hand of girl", "left hand of boy", "hand", "person in the middle"]
prompts = [
    "boy with striped sweater",
    "boy with blue turtleneck",
    "winnie the pooh logo",
]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Image to ground.
img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_002829.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize: clothing items and scene objects.
prompts = [
    "gray baseball hat",
    "baseball glove",
    "shoes",
    "black pants",
    "fence",
    "silver belt",
]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Image to ground.
img_path = '/mv_users/ericwtodd/datasets/PascalPart_People_subset/2008_003344.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize, including negated descriptions;
# alternative prompt sets are kept for reference.
# prompts = ["man without shoes on", "man with shoes on"]
# prompts = ["man not riding a horse", "man riding a horse"]
prompts = [
    "man without a hat on",
    "man with a hat on",
    "shoes",
    "bare feet of human",
]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)

In [None]:
# Image to ground (MS-COCO val2017 sample).
img_path = '/multiview/datasets/mscoco/images/val2017/000000020333.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize; a longer prompt set is kept for reference.
# prompts = ["right arm", "the boy's right arm resting on his knee", "right knee", "left knee", "cat in the background"]
prompts = ["left arm", "right arm"]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)



In [None]:

# Image to ground (MS-COCO train2014 sample, path relative to the notebook).
img_path = './images/coco/train2014/COCO_train2014_000000581857.jpg'

# Decode once, then derive both views from the same PIL image:
#   image_int -> plot_transform output, used as the drawing canvas
#   image     -> test_transform output with a leading batch dimension, fed to the model
pil_image = Image.open(img_path).convert('RGB')
image_int = plot_transform(pil_image)
image = test_transform(pil_image).unsqueeze(0)

# Referring expressions to localize; a prompt set from another image is kept for reference.
# prompts = ["right arm", "the boy's right arm resting on his knee", "right knee", "left knee", "cat in the background"]
prompts = [
    "all of the green and yellow bananas in the front",
    'white bird cages',
    " lady in gray in the back",
]

# Predict one box per prompt and draw the boxes on the resized image.
coords = xvlm_get_bbox(image, prompts, model, config)
plot_bbox_prompts(image_int, coords, prompts)

