In [None]:
import functools
import glob
import json
import math
import os

import matplotlib.pyplot as plt
import matplotlib as mpl

import PIL.Image
import PIL.ImageFilter

from tqdm import tqdm

from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, TapTool, OpenURL
from bokeh.transform import jitter

# Configure Bokeh to render plots inline in the notebook; must run before show().
output_notebook()

In [None]:
# Load the per-image inference results. Each entry is expected to carry a
# "label" key (ground truth) plus one numeric score per label name.
with open("_brownie/inference.json", encoding="utf-8") as f:
    data = json.load(f)

# A score must exceed this fraction of MAXVAL to count as a confident guess.
THRESHOLD = 0.9
LABELS = {img["label"] for img in data}
# Largest score seen for any label on any image. Taking the maximum over all
# labels (rather than over one arbitrary element of the LABELS set) keeps the
# value independent of set iteration order, which can change between runs.
MAXVAL = max(img[label] for img in data for label in LABELS)

# Model-specific: only these labels produce a right/wrong verdict, and they
# are checked in this order.
DECISIVE_LABELS = ("blank", "with_address")

cutoff = THRESHOLD * MAXVAL
for i, img in enumerate(data):
    img["i"] = i
    # First decisive label whose score clears the cutoff, if any.
    predicted = next(
        (label for label in DECISIVE_LABELS if img[label] > cutoff),
        None,
    )
    if predicted is None:
        img["guess"] = "unsure"
    elif img["label"] == predicted:
        img["guess"] = "right"
    else:
        img["guess"] = "wrong"
    # Highest score across all labels; used as the y coordinate when plotting.
    img["max"] = max(img[label] for label in LABELS)

# Materialised as a list (a bare zip() is exhausted after one pass) and built
# from sorted labels so the label/colour pairing is stable between runs.
# Only the first three labels receive a colour.
PALETTE = list(zip(sorted(LABELS), ("red", "green", "blue")))
        
# HTML hover template. "@name" placeholders are filled from the columns of the
# hovered glyph's data source; the "{}" slot receives one "label=@label/MAXVAL"
# entry per known label. The caption uses the same "file_name" column as the
# <img> tag and the tap-to-open URL, so all three refer to the same field.
TOOLTIPS = """
<div>
<img width="256" src="/files/@file_name">
<p>@file_name</p>
<p>@label ({})</p>
</div>
""".format(
    ", ".join(
        "{label}=@{label}/{MAXVAL}".format(label=label, MAXVAL=MAXVAL)
        # Sorted so the score list reads in the same order on every run.
        for label in sorted(LABELS)
    )
)

def make_cds(list_of_dicts):
    """Build a Bokeh ``ColumnDataSource`` from a list of row dictionaries.

    Column names are the keys of the first row; each column holds that key's
    value from every row, in order. A row missing one of those keys raises
    ``KeyError``. An empty (falsy) input yields a column-less source.
    """
    if list_of_dicts:
        columns = {}
        for key in list_of_dicts[0]:
            columns[key] = [row[key] for row in list_of_dicts]
        return ColumnDataSource(columns)
    return ColumnDataSource()

# Categorical scatter plot: one column per ground-truth label, one point per
# image, with the image's highest score on the y axis.
f = figure(
    title="Model Performance (model: BROWNIE)",
    tooltips=TOOLTIPS,
    sizing_mode="stretch_width",
    # Sorted so the category order on the x axis is the same on every run
    # (iteration order of a set of strings is not stable across processes).
    x_range=sorted(LABELS),
)
# Clicking a point opens the corresponding image.
f.add_tools(TapTool(callback=OpenURL(url="/files/@file_name")))
for guess, color, marker, size in (
    ("right", "green", "star", 6),
    ("wrong", "red", "circle", 10),
    ("unsure", "black", "square", 4),
):
    matching = [img for img in data if img["guess"] == guess]
    if not matching:
        # make_cds() returns a source without columns for an empty list, so a
        # glyph referring to "max"/"label" would point at missing columns.
        continue
    # scatter(marker=...) replaces the per-marker methods (star/circle/square),
    # which newer Bokeh releases deprecate.
    f.scatter(
        source=make_cds(matching),
        marker=marker,
        color=color,
        size=size,
        y="max",
        # Horizontal jitter spreads points within their category column.
        x=jitter("label", width=0.6, range=f.x_range),
        legend_label=guess,
    )
f.legend.location = "left"
show(f)
