In [2]:
from config import general
%matplotlib inline
import sys
sys.path.append("./coco-caption")
import matplotlib.pyplot as plt
import skimage.io as io
import pylab
import os
pylab.rcParams['figure.figsize'] = (10.0, 8.0)
import json
from json import encoder
encoder.FLOAT_REPR = lambda o: format(o, '.3f')
from dataloader import get_dataset_configuration, load_all_captions_flickr, load_all_captions_coco
import glob
import ipywidgets as widgets
from IPython.display import display, clear_output
from collections import deque    
from ipywidgets import HBox, Output, Button
from prettytable import PrettyTable


In [3]:
def load_images_flickr(images_dir):
    """Map image ids to the paths of the ``.jpg`` files found in a directory.

    Parameters
    ----------
    images_dir: str
        Path to the directory with all images from a Flickr type dataset.
        A trailing path separator is optional.

    Returns
    -------
    all_images_mapping: dict->{image_id: path to the image}
        ``image_id`` is the file name without its ``.jpg`` extension and the
        path is the one returned by ``glob`` (i.e. prefixed with
        ``images_dir``). Files with other extensions are ignored.

    """
    all_images_mapping = dict()
    # os.path.join works whether or not images_dir ends with a separator;
    # plain string concatenation silently globbed "<dir>*.jpg" otherwise.
    for image_path in glob.glob(os.path.join(images_dir, '*.jpg')):
        image_name = os.path.basename(image_path)
        # Strip only the final extension so dotted names stay unique
        # (same rule as load_images_coco).
        image_id = os.path.splitext(image_name)[0]
        all_images_mapping[image_id] = image_path
    return all_images_mapping

def load_images_coco(configuration):
    """Map image ids to image paths using a JSON file that lists the images.

    Parameters
    ----------
    configuration: dict
        Must contain ``"images_names_file_path"`` (path to a JSON file with
        an ``"images"`` list whose entries have ``"file_path"`` and
        ``"split"`` keys) and ``"images_dir"`` (directory the ``file_path``
        values are relative to).

    Returns
    -------
    all_images_mapping: dict->{image_id: path to the image}
        ``image_id`` is the last component of ``file_path`` without its
        extension. Only images whose split is one of ``train``, ``val``,
        ``test`` or ``restval`` are included.

    """
    images_folder = configuration["images_dir"]
    # Context manager so the file handle is closed even if parsing fails.
    with open(configuration["images_names_file_path"]) as images_def_file:
        info = json.load(images_def_file)

    used_splits = ('train', 'val', 'test', 'restval')
    all_images_mapping = dict()
    for img in info['images']:
        if img['split'] not in used_splits:
            continue
        relative_path = img['file_path']
        # Drop any directory part first, then the extension, so a dot in a
        # directory name can never leak into the image id.
        image_filename = relative_path.rsplit("/", 1)[-1].rsplit(".", 1)[0]
        all_images_mapping[image_filename] = images_folder + "/" + relative_path

    return all_images_mapping

def get_data_for_split(dataset_name):
    """Load the image id -> image path mapping for the given dataset.

    Parameters
    ----------
    dataset_name: str
        Name passed to ``get_dataset_configuration`` to look up the dataset
        configuration.

    Returns
    -------
    all_images: dict->{image_id: path to the image}

    Raises
    ------
    ValueError
        If the configuration's ``"data_name"`` is not handled by any loader.

    """
    dataset_configuration = get_dataset_configuration(dataset_name)
    data_name = dataset_configuration["data_name"]
    # NOTE(review): flickr30k is routed to the JSON-based loader together
    # with the COCO datasets -- presumably its image list uses the same
    # format; confirm against the dataset configuration.
    if data_name in ["flickr30k", "coco17", "coco14"]:
        return load_images_coco(dataset_configuration)
    if data_name in ["flickr30k_polish", "flickr8k_polish", "aide", "flickr8k"]:
        return load_images_flickr(dataset_configuration["images_dir"])
    # Previously an unknown name fell through and failed with an opaque
    # UnboundLocalError on the return statement.
    raise ValueError(
        "Unsupported data_name {!r} for dataset {!r}".format(data_name, dataset_name)
    )

In [4]:
# Result files produced by the evaluation step: every *.json file in the
# configured results directory.
results_directory = general["results_directory"]
list_of_results = [x for x in os.listdir(results_directory) if x.endswith(".json")]
print(list_of_results)
if not list_of_results:
    # Fail with an actionable message instead of an IndexError below.
    raise FileNotFoundError(
        "No .json result files found in {!r}".format(results_directory)
    )

selectbox = widgets.Select(
    options=list_of_results,
    value=list_of_results[0],
    description='Name of the dataset:',
    disabled=False
)

# NOTE: the selection is read once, when this cell runs. After picking a
# different file in the widget, re-run this cell (and the cells below) to
# load it.
# Build the path from the same directory that was listed above, and close the
# file handle once the JSON is parsed.
with open(os.path.join(results_directory, selectbox.value)) as results_file:
    info = json.load(results_file)
dataset_name = info["dataset_name"]
images_ids = list(info['imgToEval'].keys())
selectbox

['mixed_flickr8k_8k_n.json', 'mixed_coco2014_coco2014.json']


Select(description='Name of the dataset:', options=('mixed_flickr8k_8k_n.json', 'mixed_coco2014_coco2014.json'…

In [5]:
# Build the image id -> image path mapping for the dataset named in the loaded results file.
all_images = get_data_for_split(dataset_name)


In [8]:
def show_image_and_captions(image_id):
    """Display one image with its captions and per-image metric scores.

    Reads the notebook-level ``info``, ``dataset_name`` and ``all_images``
    variables, so the cells defining them must have been run first.

    Parameters
    ----------
    image_id: str
        Key of the image in ``info['imgToEval']`` and in ``all_images``.

    """
    image_results = info['imgToEval'][image_id]
    print('Dataset name: {}'.format(dataset_name))

    image = io.imread(all_images[image_id])
    plt.imshow(image)
    plt.axis('off')
    plt.show()

    print("Ground truth captions")
    print(image_results['ground_truth_captions'])
    print("Predicted captions")
    print(image_results['caption'])

    print('\n===== Results =====')
    # One table per group of metrics; the names double as column headers and
    # as keys into the per-image results.
    metric_groups = (
        ("Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"),
        ("METEOR", "ROUGE_L", "CIDEr", "WMD"),
    )
    for metric_names in metric_groups:
        table = PrettyTable(metric_names)
        table.add_row(tuple(image_results[name] for name in metric_names))
        print(table)
    print()


In [9]:
# Image ids kept in a deque so that "previous"/"next" is a rotation and
# navigation wraps around at both ends; d[0] is always the image on screen.
d = deque(images_ids)
left = Button(description="<")
right = Button(description=">")

switch = [left, right]

combined = HBox(switch)
out = Output()


def _rotate_and_show(step):
    """Rotate the deque by ``step`` and render the image now at the front."""
    with out:
        clear_output()
        if not d:
            # Nothing to rotate through; avoid an IndexError on d[0].
            print("No images to display.")
            return
        d.rotate(step)
        show_image_and_captions(d[0])


def on_button_left(ex):
    """Click handler for "<": step back to the previous image."""
    _rotate_and_show(1)


def on_button_right(ex):
    """Click handler for ">": advance to the next image."""
    _rotate_and_show(-1)


# on_click returns None, so there is nothing worth keeping from these calls.
left.on_click(on_button_left)
right.on_click(on_button_right)
display(combined)
display(out)
# Render the first image straight away; otherwise it is only reachable after
# cycling through the whole deque, because every click rotates before showing.
_rotate_and_show(0)

HBox(children=(Button(description='<', style=ButtonStyle()), Button(description='>', style=ButtonStyle())))

Output()