In [32]:
import functools
import glob
import json
import math
import os

import matplotlib.pyplot as plt
import matplotlib as mpl

import PIL.Image
import PIL.ImageFilter

from tqdm import tqdm

from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, TapTool, OpenURL
from bokeh.transform import jitter

# Make Bokeh's show() render its plots inline in this notebook.
output_notebook()


In [35]:
with open("../inference-analysis.json") as json_file:
    data = json.load(json_file)

print(f"{len(data)} images to analyze.")
# Fail with a clear message instead of an opaque IndexError/ValueError below.
assert data, "inference-analysis.json contains no images"

# Every distinct ground-truth label in the dataset. Each image record also
# carries one score per label, keyed by the label name.
LABELS = set(img["label"] for img in data)
# Largest score seen for ANY label. The previous version only looked at
# list(LABELS)[0], i.e. whichever label the set happened to iterate first,
# which varies between kernels because string hashing is randomised.
# Used as the denominator shown in the hover tooltips.
MAXVAL = max(img[label] for img in data for label in LABELS)

# (rating, colour) pairs, in the order they are drawn and listed in the legend.
# Each rating name must equal one of the values assigned to img["rating"];
# the "unsure" entry was previously inverted as ("black", "unsure"), so
# unsure images were never plotted and "unsure" was passed as a colour.
RATINGS = (
    ("good", "green"),
    ("average", "blue"),
    ("critical", "red"),
    ("bad", "orange"),
    ("recoverable", "brown"),
    ("unsure", "black"),
)

# Derive the per-image fields used by the plot: a running index, the score the
# model gave to the image's true label ("correctness"), and a coarse rating.
for i, img in enumerate(data):
    img["i"] = i
    img["correctness"] = img[img["label"]]
    if img["correctness"] > .9:
        # True label scored above .9.
        img["rating"] = "good"
    elif img["correctness"] > .5:
        # True label scored above .5 but not above .9.
        img["rating"] = "average"
    elif any(img[label] >= .95 for label in LABELS):
        # True label scored <= .5 while some other label scored >= .95.
        # That miss counts as critical only for "with_address" images.
        img["rating"] = "critical" if img["label"] == "with_address" else "recoverable"
    elif any(img[label] > .5 for label in LABELS):
        # Some other label scored above .5 (but below .95).
        img["rating"] = "bad"
    else:
        # No label scored above .5.
        img["rating"] = "unsure"
    
# One colour per label, paired in sorted label order so the mapping is the same
# on every kernel run. Materialised as a list because a bare zip() is a
# one-shot iterator that is empty after its first use. Note that zip truncates:
# any label beyond the third gets no colour.
PALETTE = list(zip(sorted(LABELS), ("red", "green", "blue")))
        
# Hover tooltip: thumbnail, file name, category/rating and the per-label scores
# as "label=<score>/MAXVAL". Labels are listed in sorted order so the tooltip is
# identical on every run (set iteration order of strings varies per kernel).
# "@{name}" is Bokeh's braced field syntax, which also tolerates label names
# containing non-word characters.
TOOLTIPS = """
<div>
<img width="256" src="/files/imgroot/@file_name">
<p>@file_name</p>
<p>CATEGORY: @label, RATING: @rating</p>
<p>PREDICTIONS: {}</p>
</div>
""".format(", ".join(
    "{label}=@{{{label}}}/{maxval}".format(label=label, maxval=MAXVAL)
    for label in sorted(LABELS)
))

def make_cds(list_of_dicts):
    """Turn row-oriented records into a column-oriented ColumnDataSource.

    Columns are the keys of the first record, in that record's key order. Every
    other record must have each of those keys (a missing key raises KeyError);
    keys that appear only in later records are ignored.

    An empty input yields an empty ColumnDataSource.
    """
    if not list_of_dicts:
        # No rows: return an empty source rather than indexing list_of_dicts[0].
        return ColumnDataSource()
    first_row = list_of_dicts[0]
    columns = {}
    for key in first_row:
        columns[key] = [row[key] for row in list_of_dicts]
    return ColumnDataSource(columns)

# Strip plot: one column per true label (points jittered horizontally so they
# don't overlap), y = score given to the true label, colour = rating.
# Hovering shows TOOLTIPS; clicking a point opens the image.
f = figure(
    title="Model Performance (model: BROWNIE)",
    tooltips=TOOLTIPS,
    sizing_mode="stretch_width",
    # sorted() keeps the category order identical across kernel restarts.
    x_range=sorted(LABELS),
    x_axis_label="true label",
    y_axis_label="score of true label",
)
f.add_tools(TapTool(callback=OpenURL(url="/files/imgroot/@file_name")))
for rating, color in RATINGS:
    # scatter(marker="circle") replaces circle(size=...), which Bokeh 3.4
    # deprecated; scatter with a marker works on older Bokeh as well.
    f.scatter(
        source=make_cds([img for img in data if img["rating"] == rating]),
        marker="circle",
        color=color,
        size=10,
        x=jitter("label", width=0.5, range=f.x_range),
        y="correctness",
        legend_label=rating,
    )
f.legend.location = "left"
show(f)


1182 images to analyze.


In [36]:
# Number of images rated "critical".
sum(1 for img in data if img["rating"] == "critical")

21