In [2]:
import torch
import json
import pickle
import os
import numpy as np

# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

from tqdm import tqdm
from typing import List
from PIL import Image
from pathlib import Path
from torch.utils.data import DataLoader
import torchvision.transforms as tf
from sentence_transformers import SentenceTransformer, util

from model.densecap import densecap_resnet50_fpn, DenseCapModel

# Number of images sent through the model per forward pass (read by describe_images).
BATCH_SIZE = 1
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def load_model(model_config_path: Path, checkpoint_path: Path, return_features=False, box_per_img=50, verbose=False, token_to_idx=None):
    """Build a DenseCap model from a JSON config and load checkpoint weights into it.

    Args:
        model_config_path: JSON file providing the constructor arguments
            ``backbone_pretrained``, ``feat_size``, ``hidden_size``, ``max_len``,
            ``emb_size``, ``rnn_num_layers``, ``vocab_size`` and ``fusion_type``.
        checkpoint_path: file readable by ``torch.load`` that yields a mapping
            with the weights under ``'model'`` and, optionally, validation
            metrics under ``'results_on_val'``.
        return_features: forwarded unchanged to ``densecap_resnet50_fpn``.
        box_per_img: forwarded as ``box_detections_per_img``.
        verbose: when true *and* the checkpoint contains ``'results_on_val'``,
            print the checkpoint path and every non-dict metric.
        token_to_idx: forwarded unchanged to ``densecap_resnet50_fpn``.

    Returns:
        The constructed model with the checkpoint weights loaded. The caller is
        responsible for moving it to the target device.

    Raises:
        KeyError: if the config or the checkpoint lacks a required key.
    """
    with open(model_config_path, 'r') as f:
        model_args = json.load(f)

    model = densecap_resnet50_fpn(backbone_pretrained=model_args['backbone_pretrained'],
                                  return_features=return_features,
                                  feat_size=model_args['feat_size'],
                                  hidden_size=model_args['hidden_size'],
                                  max_len=model_args['max_len'],
                                  emb_size=model_args['emb_size'],
                                  rnn_num_layers=model_args['rnn_num_layers'],
                                  vocab_size=model_args['vocab_size'],
                                  fusion_type=model_args['fusion_type'],
                                  box_detections_per_img=box_per_img,
                                  token_to_idx=token_to_idx)

    # Map every stored tensor to the CPU: without map_location a checkpoint that
    # was saved from CUDA tensors cannot be deserialized on a CPU-only machine.
    # load_state_dict copies the values into the model's own parameters, so the
    # model's device placement is not affected by this choice.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # strict=False: keys that are missing from, or unexpected in, the checkpoint
    # are skipped instead of raising.
    model.load_state_dict(checkpoint['model'], strict=False)

    if verbose and 'results_on_val' in checkpoint:
        print('[INFO]: checkpoint {} loaded'.format(checkpoint_path))
        print('[INFO]: correspond performance on val set:')
        for k, v in checkpoint['results_on_val'].items():
            # Nested dicts (per-threshold breakdowns, etc.) are not printed.
            if not isinstance(v, dict):
                print('        {}: {:.3f}'.format(k, v))

    return model

def get_image_paths(parent_folder: Path) -> List[str]:
    """Recursively collect every non-directory entry below ``parent_folder``.

    Entries are visited in the order ``Path.iterdir`` yields them, descending
    into each sub-directory as soon as it is encountered. No filtering by file
    extension is performed.

    Returns:
        The collected paths converted to ``str``.
    """
    collected = []

    for entry in parent_folder.iterdir():
        if not entry.is_dir():
            collected.append(str(entry))
        else:
            # Depth-first: splice the sub-directory's files in at this position.
            collected += get_image_paths(entry)

    return collected


def img_to_tensor(img_list, device):
    """Load image files and return them as a list of tensors on ``device``.

    Args:
        img_list: iterable of image file paths accepted by ``PIL.Image.open``.
        device: target passed to ``Tensor.to`` for every image.

    Returns:
        One tensor per input path, in input order. Each image is converted to
        RGB before ``torchvision.transforms.ToTensor`` is applied. An empty
        ``img_list`` yields an empty list.
    """
    # The transform is stateless, so build it once instead of once per image.
    to_tensor = tf.ToTensor()
    img_tensors = []

    for img_path in img_list:
        # The context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory copy, so the tensor stays
        # valid after the file is closed.
        with Image.open(img_path) as img:
            img_tensors.append(to_tensor(img.convert("RGB")).to(device))

    return img_tensors


def describe_images(model: DenseCapModel, img_list: List[str], device: torch.device):
    """Run ``model`` in eval mode over ``img_list`` and gather its outputs on the CPU.

    Images are processed in consecutive slices of ``BATCH_SIZE`` paths. The
    model is moved to ``device`` and switched to eval mode as a side effect.

    Returns:
        A list with one dict per model output; every value has been moved to
        the CPU via ``.cpu()``.
    """
    collected = []

    with torch.no_grad():
        model.to(device)
        model.eval()

        n_images = len(img_list)
        for start in tqdm(range(0, n_images, BATCH_SIZE)):
            batch_paths = img_list[start:start + BATCH_SIZE]
            batch_outputs = model(img_to_tensor(batch_paths, device=device))

            # Detach results from the compute device right away.
            for output in batch_outputs:
                collected.append({key: value.cpu() for key, value in output.items()})

    return collected

In [4]:
def postprocess_results(results, img_paths: List[str], idx_to_token):
    """Turn raw per-image model outputs into plain Python, keyed by image path.

    Args:
        results: one mapping per image with ``'boxes'``, ``'caps'`` and
            ``'scores'`` entries that are iterated in lockstep.
        img_paths: keys for the returned dict, paired positionally with
            ``results``.
        idx_to_token: lookup from token index to token string.

    Returns:
        ``{img_path: [{'box': [...], 'score': float, 'cap': str}, ...]}`` where
        box coordinates and score are rounded to two decimals and the caption
        omits the ``<pad>``, ``<bos>`` and ``<eos>`` tokens.
    """
    special_tokens = ('<pad>', '<bos>', '<eos>')

    def decode_caption(cap):
        # Map indices to words, then drop the bookkeeping tokens.
        words = (idx_to_token[idx] for idx in cap.tolist())
        return ' '.join(word for word in words if word not in special_tokens)

    def describe_region(box, cap, score):
        return {
            'box': [round(coord, 2) for coord in box.tolist()],
            'score': round(score.item(), 2),
            'cap': decode_caption(cap),
        }

    return {
        img_path: [
            describe_region(box, cap, score)
            for box, cap, score in zip(result['boxes'], result['caps'], result['scores'])
        ]
        for img_path, result in zip(img_paths, results)
    }

In [5]:
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.patches import Rectangle

def visualize_result(image_file_path, result, idx_to_token=None, keyword="car", max_boxes=6):
    """Draw captioned region boxes on top of an image and show the figure.

    Args:
        image_file_path: path of the image to display.
        result: list of dicts, each with ``'box'`` (x0, y0, x1, y1), ``'cap'``
            and optionally ``'view'`` (a string appended to the caption text).
        idx_to_token: when given, captions that are not already strings are
            decoded through this lookup (``<pad>``/``<bos>``/``<eos>`` dropped).
            When ``None``, ``'cap'`` is expected to be a string already.
        keyword: only regions whose caption contains this substring are drawn.
            Pass ``None`` to draw regardless of caption. Defaults to ``"car"``.
        max_boxes: maximum number of regions to draw. Defaults to 6.

    The entries of ``result`` are not modified, so the function can be called
    repeatedly on the same list.

    Raises:
        AssertionError: if ``result`` is not a list.
    """
    assert isinstance(result, list)

    fig = plt.gcf()
    fig.set_size_inches(18.5, 10.5)

    # imshow copies the pixel data, so the file can be closed right afterwards.
    with Image.open(image_file_path) as img:
        plt.imshow(img)
    ax = plt.gca()

    special_tokens = ('<pad>', '<bos>', '<eos>')
    drawn = 0
    for r in result:
        if drawn >= max_boxes:
            break

        # Decode into a local instead of overwriting r['cap']: writing the string
        # back made a second call fail on `.tolist()` of a str.
        caption = r['cap']
        if idx_to_token is not None and not isinstance(caption, str):
            tokens = (idx_to_token[idx] for idx in caption.tolist())
            caption = ' '.join(tok for tok in tokens if tok not in special_tokens)

        if keyword is not None and keyword not in caption:
            continue

        drawn += 1

        x0, y0 = r['box'][0], r['box'][1]
        ax.add_patch(Rectangle((x0, y0),
                               r['box'][2] - x0,
                               r['box'][3] - y0,
                               fill=False,
                               edgecolor='red',
                               linewidth=3))
        ax.text(x0, y0, caption + (r['view'] if 'view' in r else ""), style='italic', bbox={'facecolor':'white', 'alpha':0.7, 'pad':10})

    # tick_params expects booleans; the string 'off' is truthy and therefore
    # does not hide the labels.
    plt.tick_params(labelbottom=False, labelleft=False)
    plt.show()

In [8]:
from utils.snare_dataset import SnareDataset


# Smoke test: run the captioning model's query_caption on the first SNARE
# training sample, once for each of the two candidate objects.

lut_path = Path("./data/VG-regions-dicts-lite.pkl")

# NOTE(review): pickle.load can execute arbitrary code; only load this file
# from a trusted source.
with open(lut_path, 'rb') as f:
    look_up_tables = pickle.load(f)

idx_to_token = look_up_tables['idx_to_token']
token_to_idx = look_up_tables['token_to_idx']

# Config is read from <params_path>/<model_name>/config.json and the weights
# from <params_path>/<model_name>.pth.tar.
params_path = Path("compute_model_params")
model_name = "without_aux"
model = load_model(
    params_path / model_name / "config.json", 
    params_path / (model_name + ".pth.tar"), 
    return_features=False, verbose=True, token_to_idx=token_to_idx)


dataset = SnareDataset(mode="train")
loader = DataLoader(dataset, batch_size=1)

# Passed to query_caption below; presumably one id per rendered view of an
# object (8 views) -- TODO confirm against SnareDataset.
view_ids = torch.arange(8)
# Redundant with the assignment inside the loop, which runs after model.train().
model.rpn.training = False

with torch.no_grad():
    for batch in loader:    
        (key1_imgs, key2_imgs), gt_idx, (key1, key2), annotation, is_visual = batch    
        # Drop singleton dimensions from each image tensor (batch_size is 1).
        key1_imgs = [k.squeeze() for k in key1_imgs]
        key2_imgs = [k.squeeze() for k in key2_imgs]
        annotation = annotation[0]
        # Order matters: train() sets the training flag on the whole model, then
        # only the RPN's flag is reset. Presumably query_caption needs training
        # mode to return losses while the RPN should behave as in eval -- TODO confirm.
        model.train()
        model.rpn.training = False
        
        # After this swap key1_imgs is treated as the positive ("pos") object and
        # key2_imgs as the negative one; assumes gt_idx is the 0/1 index of the
        # object the annotation refers to -- TODO confirm.
        if gt_idx > 0:
            key1_imgs, key2_imgs = key2_imgs, key1_imgs
        
        print("pos model")
        losses, results = model.query_caption(key1_imgs, [annotation], view_ids)
        print("neg model")
        # Overwrites losses/results from the positive query; only the output that
        # query_caption prints itself is kept for the positive object.
        losses, results = model.query_caption(key2_imgs, [annotation], view_ids)
        # decoded_results = postprocess_results(results, [f"{key1[0]}-{i}" for i in range(6, 14)], idx_to_token)        
        # print(decoded_results)
        # losses, results = model.query_caption(key2_imgs, [annotation], view_ids)
        # print(losses)
        # print(entry_idx)    
        print(annotation)
        # Only the first batch is inspected.
        break

[INFO]: checkpoint compute_model_params/without_aux.pth.tar loaded
[INFO]: correspond performance on val set:
        map: 0.108
        detmap: 0.264
../snare/amt/folds_adversarial/train.json
Loaded Entries. train: 39278 entries
pos model
caption loss: torch.Size([8000, 16])
box features: torch.Size([8000, 4096])
min loss: 10.089165687561035
box features at min index: torch.Size([4096])
tensor([[  1,   5,  56, 109,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0]])
neg model
caption loss: torch.Size([8000, 16])
box features: torch.Size([8000, 4096])
min loss: 10.571321487426758
box features at min index: torch.Size([4096])
tensor([[ 1,  4, 26,  8, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
wooden chair with cushions


In [9]:
# Token-id sequences copied from the tensors printed by the previous cell:
# x is the one shown under "neg model", y the one shown under "pos model".
x = [ 1,  4, 26,  8, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]
y = [  1,   5,  56, 109,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0, 0,   0,   0]

# Print the tokens of x, then of y, one per line, stopping each sequence at
# the first index below 1 (the zero padding).
for token_ids in (x, y):
    for i in token_ids:
        if i < 1:
            break
        print(idx_to_token[i])

<bos>
the
wall
is
white
<eos>
<bos>
a
wooden
bench
<eos>
