In [None]:
from pathlib import Path

import yaml
import numpy as np
import imageio
from moviepy.editor import ImageSequenceClip
from moviepy.video.io.bindings import mplfig_to_npimage
from IPython.display import display
from tqdm import tqdm

import matplotlib
matplotlib.use('webagg')
import matplotlib.pyplot as plt
import plotly.offline as py
import plotly.graph_objs as go
py.init_notebook_mode(connected=True)

In [None]:
# Average forward time per image (seconds) for each model, with the bar colour to use.
forward_speeds = [('vgg16', 0.08083438105769455, 'rgba(31, 119, 180, 1)'),
                  ('resnet18', 0.010670146435228262, 'rgba(255, 127, 14, 1)'),
                  ('resnet34', 0.013862044609998438, 'rgba(44, 160, 44, 1)')]

# One trace per model so each gets its own legend entry; the declared colour is
# applied via `marker` (it was previously listed but never used).
speed_bars = [go.Bar(x=[name], y=[speed], name=name, marker=dict(color=color))
              for name, speed, color in forward_speeds]

layout = go.Layout(title='Average forward speed for each model (lower is better)',
                   font=dict(family='Roboto'),
                   xaxis=dict(title='model', ticks='outside'),
                   yaxis=dict(title='Forward speed per image in seconds'),
                   showlegend=True
                  )
fig = go.Figure(data=speed_bars, layout=layout)
py.iplot(fig)
    

In [None]:
def get_metric(file_path, metric, parts):
    """Load per-sequence statistics for one metric from an evaluation YAML file.

    The file is expected to contain a top-level ``sequence`` mapping whose
    values are indexed as ``[metric][part][0]``.

    Args:
        file_path: path to the YAML file (``str`` or ``pathlib.Path``).
        metric: metric name looked up under each sequence entry, e.g. ``'J'``.
        parts: statistic names to extract for that metric, e.g. ``['mean', 'decay']``.

    Returns:
        A tuple ``(*columns, keys)``. ``columns`` holds one tuple per entry of
        ``parts``, with the first element of that statistic for every sequence.
        ``keys`` is the list of sequence names in reverse-sorted order, aligned
        with the columns. For two parts this is ``(part_0, part_1, keys)``.
    """
    with open(file_path, 'r') as stream:
        # `yaml.load` without an explicit Loader warns on PyYAML 5.x and is a
        # TypeError on PyYAML >= 6. NOTE(review): if these files contain
        # Python-specific tags (e.g. !!python/tuple), safe_load will reject them;
        # switch to an explicit, trusted Loader in that case.
        model_eval = yaml.safe_load(stream)
    sequences = model_eval['sequence']
    # Reverse order, presumably so horizontal bar charts read A->Z top-down.
    keys = sorted(sequences.keys(), reverse=True)
    # Build one column per requested part directly. This works for any number of
    # parts and yields empty tuples (instead of an unpacking error) when there
    # are no sequences.
    columns = tuple(tuple(sequences[k][metric][p][0] for k in keys) for p in parts)
    return (*columns, keys)

# All evaluation YAML files live in the same davis-2017 tools directory.
METRICS_DIR = Path('/home/klaus/dev/davis-2017/python/tools')
J_PARTS = ['mean', 'decay']

vgg16_J_mean_online, vgg16_J_decay_online, keys = get_metric(
    METRICS_DIR / 'metrics_mine_vgg16_online.yml', 'J', J_PARTS)
resnet18_J_mean_online, resnet18_J_decay_online, resnet18_keys_online = get_metric(
    METRICS_DIR / 'metrics_mine_resnet18_0_1.yml', 'J', J_PARTS)
resnet34_J_mean_online, resnet34_J_decay_online, resnet34_keys_online = get_metric(
    METRICS_DIR / 'metrics_mine_resnet34_0_1.yml', 'J', J_PARTS)

vgg16_J_mean_offline, vgg16_J_decay_offline, vgg16_keys_offline = get_metric(
    METRICS_DIR / 'metrics_mine_vgg16_offline.yml', 'J', J_PARTS)
resnet18_J_mean_offline, resnet18_J_decay_offline, resnet18_keys_offline = get_metric(
    METRICS_DIR / 'metrics_mine_resnet18_0_offline.yml', 'J', J_PARTS)
resnet34_J_mean_offline, resnet34_J_decay_offline, resnet34_keys_offline = get_metric(
    METRICS_DIR / 'metrics_mine_resnet34_0_offline.yml', 'J', J_PARTS)

# Every series is plotted against `keys` (taken from the vgg16 online file), so all
# files must list the same sequences in the same order; otherwise bars would be
# silently attached to the wrong object labels.
assert all(list(other_keys) == list(keys)
           for other_keys in (resnet18_keys_online, resnet34_keys_online,
                              vgg16_keys_offline, resnet18_keys_offline,
                              resnet34_keys_offline)), \
    'metric files cover different sequences; per-object bars would be mislabelled'

In [None]:
def plot(keys, metric, lower_higher, data, height=2400):
    """Show an interactive horizontal bar chart of a per-object metric for several models.

    Args:
        keys: y-axis category labels, one per object/sequence.
        metric: metric name, used in the title and as the x-axis label.
        lower_higher: word inserted into the title as "(<lower_higher> is better)".
        data: list of ``(trace_name, values)`` pairs; each ``values`` sequence
            is aligned with ``keys``.
        height: figure height in pixels (defaults to the previously fixed 2400).
    """
    traces = [go.Bar(y=keys, x=values, name=name, orientation='h')
              for name, values in data]

    title = '{metric} per object for each model ({lower_higher} is better)'.format(
        metric=metric, lower_higher=lower_higher)
    layout = go.Layout(title=title,
                       font=dict(family='Roboto'),
                       xaxis=dict(title=metric, ticks='outside'),
                       yaxis=dict(title='Object'),
                       showlegend=True,
                       bargap=0.5,
                       autosize=False, height=height,
                       # Plain dict instead of the deprecated `go.Margin` alias.
                       margin=dict(l=120, r=0, b=80, t=100, pad=10)
                      )
    fig = go.Figure(data=traces, layout=layout)
    py.iplot(fig)
    

# Per-object J mean (higher is better) for every model, offline series first.
plot(keys, metric='J_mean', lower_higher='higher', data=[
    ('vgg16_offline', vgg16_J_mean_offline),
    ('resnet18_offline', resnet18_J_mean_offline),
    ('resnet34_offline', resnet34_J_mean_offline),
    ('vgg16_online', vgg16_J_mean_online),
    ('resnet18_online', resnet18_J_mean_online),
    ('resnet34_online', resnet34_J_mean_online),
])

# Per-object J decay (lower is better), same series order.
plot(keys, metric='J_decay', lower_higher='lower', data=[
    ('vgg16_offline', vgg16_J_decay_offline),
    ('resnet18_offline', resnet18_J_decay_offline),
    ('resnet34_offline', resnet34_J_decay_offline),
    ('vgg16_online', vgg16_J_decay_online),
    ('resnet18_online', resnet18_J_decay_online),
    ('resnet34_online', resnet34_J_decay_online),
])

In [None]:
def convert_to_rgb(image):
    """Return ``image`` as a three-channel array.

    Grayscale inputs, either 2-D ``(rows, cols)`` arrays or single-channel
    ``(rows, cols, 1)`` arrays, have their values replicated into three
    identical channels and are cast to ``uint8``. Any other array is returned
    unchanged (the same object), so RGB/RGBA images pass straight through.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        # Drop the singleton channel axis so it gets the same treatment as a
        # plain 2-D grayscale image.
        image = image[:, :, 0]
    if image.ndim != 2:
        return image
    # Shape is (rows, cols), i.e. (height, width). astype performs the same
    # unchecked cast to uint8 as assigning into a uint8 buffer.
    return np.repeat(image[:, :, None], 3, axis=2).astype(np.uint8)

def dir_to_images(path):
    """Read every image file in directory ``path``, ordered by filename.

    Each image is passed through ``convert_to_rgb`` so that grayscale frames
    become three-channel arrays. Subdirectories and hidden files (for example
    ``.DS_Store`` or ``.ipynb_checkpoints``) are skipped, because
    ``imageio.imread`` cannot read them.

    Returns:
        A list of image arrays.
    """
    image_paths = sorted(str(entry) for entry in Path(path).iterdir()
                         if entry.is_file() and not entry.name.startswith('.'))
    return [convert_to_rgb(imageio.imread(image_path)) for image_path in image_paths]
    
SEQUENCE = 'car-shadow'
DAVIS_DIR = Path('/home/klaus/dev/datasets/DAVIS')
RESULTS_DIR = Path('/home/klaus/dev/fast-osvos/src/results')

# (panel title, frame directory) for each panel of the comparison video, in grid order.
panel_dirs = [('Original Image', DAVIS_DIR / 'JPEGImages' / '480p' / SEQUENCE),
              ('Ground Truth', DAVIS_DIR / 'Annotations' / '480p' / SEQUENCE),
              ('vgg16', RESULTS_DIR / 'vgg16' / 'offline' / SEQUENCE),
              ('resnet18', RESULTS_DIR / 'resnet18' / '0' / '1' / SEQUENCE),
              # resnet34 has its own results directory (this previously pointed at resnet18's).
              ('resnet34', RESULTS_DIR / 'resnet34' / '0' / '1' / SEQUENCE)]
descriptions, source_dirs = zip(*panel_dirs)
images_per_source = [dir_to_images(source_dir) for source_dir in source_dirs]

n_columns = 2
# Ceiling division keeps n_rows an int (true division gave 3.5, which subplot rejects).
n_rows = (len(images_per_source) + n_columns - 1) // n_columns
# Regroup so each element holds one frame from every source, in panel order.
# NOTE: zip stops at the shortest source if the directories hold different frame counts.
sources = list(zip(*images_per_source))

def generate_image(frame_sources, n_rows, n_columns, descriptions, figsize=(8, 8)):
    """Render one composite video frame as an image array.

    Each image in ``frame_sources`` is drawn into its own cell of an
    ``n_rows`` x ``n_columns`` grid, with axes hidden and the matching entry
    of ``descriptions`` as its title.

    Args:
        frame_sources: images for this frame, one per panel, in grid order.
        n_rows: number of grid rows. A float (e.g. from true division) is
            rounded up to the next integer.
        n_columns: number of grid columns.
        descriptions: panel titles, indexed in parallel with ``frame_sources``.
        figsize: matplotlib figure size (defaults to the previously fixed (8, 8)).

    Returns:
        The rendered figure as a numpy array, as produced by ``mplfig_to_npimage``.
    """
    rows = int(np.ceil(n_rows))
    fig = plt.figure(figsize=figsize)
    try:
        for panel_index, panel_image in enumerate(frame_sources):
            ax = fig.add_subplot(rows, n_columns, panel_index + 1)
            ax.imshow(panel_image)
            ax.axis('off')
            ax.set_title(descriptions[panel_index], fontsize=24, fontname='Roboto')
        return mplfig_to_npimage(fig)
    finally:
        # Close this specific figure (not merely "the current" one), even on
        # error, so rendering many frames does not accumulate open figures.
        plt.close(fig)

def display_video(path, loop=True, fps=25):
    """Show the images in directory ``path`` as an inline video clip.

    Frames are read in filename order via ``dir_to_images``.

    Args:
        path: directory containing the frame images.
        loop: forwarded to ``ipython_display`` as the video's loop flag.
        fps: frame rate of the generated clip.

    Raises:
        ValueError: if the directory contains no image files.
    """
    # Previously this called an undefined `imread`, ignored loop/fps and showed nothing.
    frames = dir_to_images(path)
    if not frames:
        raise ValueError('no images found in {}'.format(path))
    clip = ImageSequenceClip(frames, fps=fps)
    # Inside a function the result is not the cell's last expression, so display
    # it explicitly rather than relying on notebook auto-display.
    display(clip.ipython_display(loop=loop))

# Render one composite grid image per frame, then play the frames back as a 5 fps looping clip.
frames = [
    generate_image(frame_sources=per_frame_images,
                   n_rows=n_rows,
                   n_columns=n_columns,
                   descriptions=descriptions)
    for per_frame_images in tqdm(sources)
]
clip = ImageSequenceClip(frames, fps=5)
clip.ipython_display(loop=True)