In [None]:
# Render matplotlib figures directly in the notebook output.
%matplotlib inline
# Load the autoreload extension and reload imported modules before each cell
# execution (mode 2), so edits to local .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import visual_genome.local as vg

# Local Visual Genome data directory; passed as data_dir to both loaders below
# and used again later to build local image file paths.
VG_DATA_PATH = './data/visual-genome'

# Load the Visual Genome image info and the region descriptions.
# Later cells index both collections with the same position
# (all_image_data[i] / all_region_descriptions[i]).
all_image_data = vg.get_all_image_data(data_dir=VG_DATA_PATH)
all_region_descriptions = vg.get_all_region_descriptions(data_dir=VG_DATA_PATH)

In [None]:
# Quick look at what was loaded: number of image records, the element types,
# and one sample from each collection.
print('vg data size {}'.format(len(all_image_data)))
print('-----------------')
print('[all_image_data] type {}'.format(type(all_image_data[0])))
print('example')
print(all_image_data[0])
print('-----------------')
# all_region_descriptions is indexed twice here: per image first, then per region.
print('[all_region_descriptions] list of type {}'.format(type(all_region_descriptions[0][0])))
print('example')
# Prints every region description of the first image.
print(all_region_descriptions[0])

### Visualizing ground truth regions

In [None]:
import os
import re

def vg_url_to_file_path(vg_data_path, url):
    """Map a Visual Genome image URL to the matching file path under a local directory.

    The last two URL components are reused as ``<directory>/<file>.jpg``, where
    the directory part starts at the first ``VG`` found in the URL.

    Args:
        vg_data_path: Local root directory that holds the ``VG...`` image folders.
        url: Image URL ending in ``VG.../<name>.jpg``.

    Returns:
        ``os.path.join(vg_data_path, <VG directory>, <file name>)``.

    Raises:
        ValueError: If ``url`` does not end in ``VG.../<name>.jpg``.
    """
    # Raw string with an escaped dot: the URL must really end in ".jpg",
    # not in any character followed by "jpg".
    match = re.search(r'(VG.*)/(.*\.jpg)$', url)
    if match is None:
        raise ValueError(
            'not a Visual Genome image URL (expected ".../VG*/<name>.jpg"): {!r}'.format(url))

    image_dir, image_name = match.groups()
    return os.path.join(vg_data_path, image_dir, image_name)

In [None]:
# Spot-check the URL-to-path mapping on one image record (index 2);
# as the cell's last expression, the resulting path is displayed.
vg_url_to_file_path(VG_DATA_PATH, all_image_data[2].url)

In [None]:
# Ref: https://github.com/ranjaykrishna/visual_genome_python_driver/blob/master/region_visualization_demo.ipynb

import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.patches import Rectangle

def visualize_regions(image_file_path, regions):
    """Show an image with its regions drawn as labelled red boxes.

    Args:
        image_file_path: Path of the image file, opened with PIL.
        regions: Iterable of objects exposing ``x``, ``y``, ``width``,
            ``height`` and ``phrase``; one rectangle and one text label
            is drawn per object.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    fig = plt.gcf()
    fig.set_size_inches(18.5, 10.5)
    ax = fig.gca()

    # Close the image file as soon as its pixels have been handed to matplotlib.
    with Image.open(image_file_path) as img:
        ax.imshow(img)

    for region in regions:
        ax.add_patch(Rectangle((region.x, region.y),
                               region.width,
                               region.height,
                               fill=False,
                               edgecolor='red',
                               linewidth=3))
        ax.text(region.x, region.y, region.phrase, style='italic',
                bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 10})

    # These options are booleans: the string 'off' is truthy and would leave
    # the tick labels visible on current matplotlib versions.
    ax.tick_params(labelbottom=False, labelleft=False)
    plt.show()

# Image to display.
# NOTE(review): img_idx = IMG_NAME - 1 assumes the records are stored in order
# starting from image 1 — confirm against the loaded data.
IMG_NAME = 51
img_idx = IMG_NAME - 1

# The image record and its region list are looked up with the same index.
image_file_path = vg_url_to_file_path(VG_DATA_PATH, all_image_data[img_idx].url)
regions = all_region_descriptions[img_idx]

# Only the first 20 regions are drawn.
visualize_regions(image_file_path, regions[:20])

### Visualizing predicted regions

In [None]:
# Print the command-line options accepted by describe.py.
! python describe.py --help

In [None]:
# Run the describe.py script with the given model config and checkpoint
# on --img_path './image_to_describe' (batch size 2, verbose output).
# NOTE(review): --result_dir is '.', while a later cell reads './res/result.json' —
# confirm where describe.py actually writes its output.
! python describe.py --config_json './model_params/train_all_val_all_bz_2_epoch_10_inject_init/config.json' \
  --model_checkpoint './model_params/train_all_val_all_bz_2_epoch_10_inject_init.pth.tar' \
  --img_path './image_to_describe' \
  --result_dir '.' \
  --batch_size 2 --verbose

In [None]:
import json

# JSON file that the visualisation cells read the results from.
RESULT_JSON_PATH = './res/result.json'

# State the encoding explicitly: JSON text is UTF-8, and relying on the
# platform's locale default can mis-decode non-ASCII content.
with open(RESULT_JSON_PATH, 'r', encoding='utf-8') as f:
    results = json.load(f)

# The keys are later passed to visualize_result as image file paths; list them.
for file_path in results.keys():
    print(file_path)

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.patches import Rectangle

def _decode_caption(token_ids, idx_to_token):
    """Join the tokens for ``token_ids.tolist()`` with spaces, skipping <pad>/<bos>/<eos>."""
    tokens = (idx_to_token[idx] for idx in token_ids.tolist())
    return ' '.join(tok for tok in tokens if tok not in ['<pad>', '<bos>', '<eos>'])


def visualize_result(image_file_path, result, idx_to_token=None, max_boxes=6, keyword='car'):
    """Show an image with captioned boxes for the entries whose caption contains ``keyword``.

    Args:
        image_file_path: Path of the image file, opened with PIL.
        result: List of dicts. Each entry provides ``'box'`` (the first four
            values are read as x_min, y_min, x_max, y_max) and ``'cap'``.
            ``'cap'`` is either a caption string or, when ``idx_to_token`` is
            given, token ids exposing ``.tolist()``. An optional ``'view'``
            string is appended to the label.
        idx_to_token: Optional lookup from token id to token string, used to
            decode ``'cap'``. Entries are never modified, so the same
            ``result`` can be visualised repeatedly.
        max_boxes: Maximum number of boxes to draw; ``None`` means no limit.
            The default of 6 matches the previous hard-coded cut-off.
        keyword: Only entries whose caption contains this substring are drawn
            (plain substring test, so ``'car'`` also matches e.g. ``'carpet'``);
            ``None`` disables the filter.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        AssertionError: If ``result`` is not a list.
    """
    assert isinstance(result, list)

    fig = plt.gcf()
    fig.set_size_inches(18.5, 10.5)
    ax = fig.gca()

    # Close the image file as soon as its pixels have been handed to matplotlib.
    with Image.open(image_file_path) as img:
        ax.imshow(img)

    drawn = 0
    for r in result:
        if max_boxes is not None and drawn >= max_boxes:
            break

        # Decode into a local variable instead of overwriting r['cap']: the
        # caller's data stays intact and a second call does not fail on
        # already-decoded captions.
        caption = r['cap']
        if idx_to_token is not None and not isinstance(caption, str):
            caption = _decode_caption(caption, idx_to_token)

        if keyword is not None and keyword not in caption:
            continue

        drawn += 1

        box = r['box']
        x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
        ax.add_patch(Rectangle((x_min, y_min),
                               x_max - x_min,
                               y_max - y_min,
                               fill=False,
                               edgecolor='red',
                               linewidth=3))

        label = caption + (r['view'] if 'view' in r else "")
        ax.text(x_min, y_min, label, style='italic',
                bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 10})

    # These options are booleans: the string 'off' is truthy and would leave
    # the tick labels visible on current matplotlib versions.
    ax.tick_params(labelbottom=False, labelleft=False)
    plt.show()

In [None]:
# Each key of `results` is used as the image file path for visualize_result.
paths = list(results.keys())

# One figure per image, built from at most the first 15 entries of its result list.
for path in paths:
    visualize_result(path, results[path][:15])

In [None]:
from pathlib import Path
import pickle

# Number of entries of the chosen image that are handed to visualize_result below.
TO_K = 10

# Pickled lookup tables; only the 'idx_to_token' entry is used in this cell.
lut_path = Path("./data/VG-regions-dicts-lite.pkl")

# NOTE(review): pickle.load can execute arbitrary code while loading —
# only open pickle files that come from a trusted source.
with open(lut_path, 'rb') as f:
    look_up_tables = pickle.load(f)

# Indexed with a token id to obtain the token string (see the decoding loop below).
idx_to_token = look_up_tables['idx_to_token']


# img_info is keyed by image path (a key is passed to visualize_result as the
# image file); each value is sliced and iterated as a sequence of dicts with a 'cap' entry.
with open('filtered_car_data.pkl', "rb") as file:            
    img_info = pickle.load(file)            

paths = list(img_info.keys())

# Print the decoded captions of one image, dropping the <pad>/<bos>/<eos> tokens.
for r in img_info[paths[150]]:
    cap = ' '.join(idx_to_token[idx] for idx in r['cap'].tolist() if idx_to_token[idx] not in ['<pad>', '<bos>', '<eos>'])
    print(cap)

# NOTE(review): the captions above are printed for paths[150], but the figure is
# drawn for paths[155] — confirm the two indices are meant to differ.
visualize_result(paths[155], img_info[paths[155]][:TO_K], idx_to_token)