In [None]:
import os

import yaml
import numpy as np
import plotly.offline as py
import plotly.graph_objs as go
py.init_notebook_mode(connected=True)

In [None]:
# Average forward time per image (seconds) for each backbone, plus its bar colour.
# NOTE(review): hard-coded measurements — record where/how they were obtained.
model_speeds = [('vgg16', 0.08083438105769455, 'rgba(31, 119, 180, 1)'),
                ('resnet18', 0.010670146435228262, 'rgba(255, 127, 14, 1)'),
                ('resnet34', 0.013862044609998438, 'rgba(44, 160, 44, 1)')]

# One single-bar trace per model so each model gets its own legend entry.
# A distinct name (instead of rebinding `model_speeds`) keeps the cell re-runnable.
speed_bars = [go.Bar(y=[speed], x=[name], name=name, marker=dict(color=color))
              for name, speed, color in model_speeds]

speed_layout = go.Layout(title='Average forward speed for each model (lower is better)',
                         font=dict(family='Roboto'),
                         xaxis=dict(title='model', ticks='outside'),
                         yaxis=dict(title='Forward speed per image in seconds'),
                         showlegend=True)
fig = go.Figure(data=speed_bars, layout=speed_layout)
py.iplot(fig)
    

In [None]:
def get_metric(file_path, metric, parts):
    """Read selected statistics of one metric, per sequence, from an evaluation YAML.

    The file is expected to hold a top-level ``sequence`` mapping shaped like
    ``{sequence_name: {metric: {part: [value, ...]}}}``. Only the first element
    of each ``part`` list is used.

    Args:
        file_path: Path to the evaluation YAML file.
        metric: Metric key inside each sequence entry, e.g. ``'J'``.
        parts: Statistic keys to extract for ``metric``, e.g. ``['mean', 'decay']``.

    Returns:
        A tuple ``(*columns, keys)``. There is one tuple of values per entry of
        ``parts``, in the same order, each aligned element-wise with ``keys``.
        ``keys`` is the list of sequence names sorted in descending order.
        With two parts this unpacks as ``part_0, part_1, keys``.
    """
    with open(file_path, 'r') as stream:
        # safe_load: yaml.load() without an explicit Loader is unsafe and raises
        # TypeError on PyYAML >= 6. If the file contains python-specific tags
        # (e.g. dumped numpy scalars), load it explicitly with yaml.Loader instead,
        # and only for trusted files.
        model_eval = yaml.safe_load(stream)
    sequences = model_eval['sequence']
    # Descending order; presumably chosen so names read alphabetically top-down
    # in the horizontal bar charts below — confirm.
    keys = sorted(sequences.keys(), reverse=True)
    stats = [[sequences[k][metric][p][0] for p in parts] for k in keys]
    if stats:
        columns = tuple(zip(*stats))
    else:
        # No sequences: still return one (empty) column per requested part so the
        # caller's unpacking keeps working.
        columns = tuple(() for _ in parts)
    return columns + (keys,)

# Directory holding the evaluation YAML files.
# NOTE(review): absolute, machine-specific path — adjust to your own checkout.
METRICS_DIR = '/home/klaus/dev/davis-2017/python/tools'
J_PARTS = ['mean', 'decay']


def metrics_path(run_name):
    """Return the path of ``metrics_mine_<run_name>.yml`` inside METRICS_DIR."""
    return os.path.join(METRICS_DIR, 'metrics_mine_{}.yml'.format(run_name))


# Online runs.
# NOTE(review): the resnet online results come from the *_0_1 files while vgg16
# uses *_online — confirm these are the intended runs.
vgg16_J_mean_online, vgg16_J_decay_online, keys = get_metric(
    metrics_path('vgg16_online'), 'J', J_PARTS)
resnet18_J_mean_online, resnet18_J_decay_online, resnet18_keys_online = get_metric(
    metrics_path('resnet18_0_1'), 'J', J_PARTS)
resnet34_J_mean_online, resnet34_J_decay_online, resnet34_keys_online = get_metric(
    metrics_path('resnet34_0_1'), 'J', J_PARTS)

# Offline runs.
vgg16_J_mean_offline, vgg16_J_decay_offline, vgg16_keys_offline = get_metric(
    metrics_path('vgg16_offline'), 'J', J_PARTS)
resnet18_J_mean_offline, resnet18_J_decay_offline, resnet18_keys_offline = get_metric(
    metrics_path('resnet18_0_offline'), 'J', J_PARTS)
resnet34_J_mean_offline, resnet34_J_decay_offline, resnet34_keys_offline = get_metric(
    metrics_path('resnet34_0_offline'), 'J', J_PARTS)

# Every series is later plotted against `keys` (from the vgg16 online file), so
# all files must list exactly the same sequences; otherwise bars get misaligned.
for other_keys in (resnet18_keys_online, resnet34_keys_online, vgg16_keys_offline,
                   resnet18_keys_offline, resnet34_keys_offline):
    assert other_keys == keys, 'metric files cover different sequences; bars would be misaligned'

In [None]:
import plotly.offline as py
import plotly.graph_objs as go
py.init_notebook_mode(connected=True)


def plot(keys, mode, metric, lower_higher, data, height=1200):
    """Draw a horizontal grouped bar chart of one per-sequence metric for several models.

    Args:
        keys: Sequence/object names used as the y-axis categories.
        mode: Label for the model variant (e.g. ``'online'``); used only in the title.
        metric: Metric name, used in the title and as the x-axis label.
        lower_higher: ``'lower'`` or ``'higher'`` (which direction is better);
            used only in the title.
        data: Iterable of ``(model_name, values)`` pairs. Each ``values`` is
            expected to be aligned element-wise with ``keys``.
        height: Figure height in pixels (default 1200, as before).
    """
    # A separate name for the traces, so the `data` argument is not rebound.
    traces = [go.Bar(y=keys, x=values, name=name, orientation='h')
              for name, values in data]

    title = '{metric} per object for each {mode} model ({lower_higher} is better)'.format(
        metric=metric, mode=mode, lower_higher=lower_higher)
    layout = go.Layout(title=title, font=dict(family='Roboto'),
                       xaxis=dict(title=metric, ticks='outside'),
                       yaxis=dict(title='Object'),
                       showlegend=True,
                       bargap=0.5,
                       autosize=False, height=height,
                       # Plain dict: go.Margin is deprecated in plotly >= 3.
                       margin=dict(l=120, r=0, b=80, t=100, pad=10))
    fig = go.Figure(data=traces, layout=layout)
    py.iplot(fig)
    

# Model order shared by every figure; each figure's series below follow it.
MODEL_NAMES = ('vgg16', 'resnet18', 'resnet34')

# (mode, metric, which direction is better, per-model series), in display order.
for fig_mode, fig_metric, fig_better, fig_series in [
        ('offline', 'J_mean', 'higher',
         (vgg16_J_mean_offline, resnet18_J_mean_offline, resnet34_J_mean_offline)),
        ('online', 'J_mean', 'higher',
         (vgg16_J_mean_online, resnet18_J_mean_online, resnet34_J_mean_online)),
        ('offline', 'J_decay', 'lower',
         (vgg16_J_decay_offline, resnet18_J_decay_offline, resnet34_J_decay_offline)),
        ('online', 'J_decay', 'lower',
         (vgg16_J_decay_online, resnet18_J_decay_online, resnet34_J_decay_online)),
]:
    plot(keys, fig_mode, fig_metric, fig_better, list(zip(MODEL_NAMES, fig_series)))