In [1]:
import functools
import glob
import json
import math
import os

import matplotlib.pyplot as plt
import matplotlib as mpl

import PIL.Image
import PIL.ImageFilter

from tqdm import tqdm

from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, TapTool, OpenURL
from bokeh.transform import jitter

# Configure Bokeh so that show() renders plots inline in the notebook output.
output_notebook()


In [2]:
import db

In [40]:
# Model whose predictions are analyzed below. The alternative is kept as a
# comment instead of a dead assignment that is immediately overwritten.
# MODEL_NAME = "front-vs-back"
MODEL_NAME = "with-address-or-blank"

# Score above which a prediction counts as confident; scores below
# 1 - CONFIDENCE_THRESHOLD for the true label count as confidently wrong.
CONFIDENCE_THRESHOLD = 0.9

# One dict per distinct image: the decoded prediction JSON plus the
# "label", "absolute_path" and "sha256" keys added below.
images = []
# Number of rows dropped because their prediction could not be used.
skipped = 0

with db.Session() as session:
    model = session.query(db.Model).filter_by(name=MODEL_NAME).one()
    labels = json.loads(model.labels)
    # Only rows of this model that carry both a label and a prediction.
    q = session.query(db.Label, db.Image).join(db.Image).filter(db.Label.model_name==MODEL_NAME, db.Label.label!="", db.Label.prediction!="")
    sha_set = set()
    for label, image in q:
        # Keep only the first row seen for each sha256.
        if label.sha256 in sha_set:
            continue
        sha_set.add(label.sha256)
        try:
            img = json.loads(label.prediction)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; TypeError covers a
            # non-string prediction. Anything else should surface normally
            # rather than being swallowed by a bare except.
            skipped += 1
            continue
        if not isinstance(img, dict):
            # Valid JSON but not an object: the key assignments below would fail.
            skipped += 1
            continue
        img["label"] = label.label
        img["absolute_path"] = db.make_path(image.origin, image.path)
        img["sha256"] = label.sha256
        images.append(img)

print(f"{len(images)} images to analyze.")
if skipped:
    print(f"{skipped} labels skipped (prediction is not a JSON object).")


# Rating name -> colour of the glyphs drawn for that rating.
_RATING_COLORS = {
    "good": "green",
    "average": "blue",
    "critical": "red",
    "bad": "orange",
    "recoverable": "brown",
    "unsure": "black",
}

# (rating, colour) pairs, in the declaration order above.
RATINGS = tuple(_RATING_COLORS.items())


def analyze_rppc(img):
    """Flatten the ``file_name`` pair and rate the record by ``correctness``.

    Mutates ``img`` in place and returns ``None``:

    * ``file_name1`` / ``file_name2`` receive the first two items of
      ``img["file_name"]``.
    * ``rating`` becomes ``"good"`` above 0.9, ``"average"`` above 0.5 and
      ``"critical"`` below 0.1. For any other score (0.1 to 0.5 inclusive)
      the ``rating`` key is left exactly as it was.
    """
    names = img["file_name"]
    img["file_name1"] = names[0]
    img["file_name2"] = names[1]

    score = img["correctness"]
    if score > .9:
        verdict = "good"
    elif score > .5:
        verdict = "average"
    elif score < 0.1:
        verdict = "critical"
    else:
        # No band matched: keep whatever rating (if any) is already present.
        return
    img["rating"] = verdict

def clean_bokeh_label(label):
    """Return the column name used for ``label``: ``label_`` + label with ``-`` turned into ``_``."""
    underscored = label.replace("-", "_")
    return f"label_{underscored}"

# Annotate every image record in place with the fields the plot reads:
#   i            - position of the record in `images`
#   predicted    - label whose score exceeds CONFIDENCE_THRESHOLD, else "unsure"
#   label_<name> - copy of each label's score under its cleaned column name
#   correctness  - score the model gave to the record's own label
#   rating       - "good", "critical", "recoverable" or "unsure"
for index, record in enumerate(images):
    record["i"] = index
    record["predicted"] = "unsure"
    for name in labels:
        score = record[name]
        record[clean_bokeh_label(name)] = score
        if score > CONFIDENCE_THRESHOLD:
            record["predicted"] = name

    correctness = record[record["label"]]
    record["correctness"] = correctness

    if correctness > CONFIDENCE_THRESHOLD:
        rating = "good"
    elif correctness < 1-CONFIDENCE_THRESHOLD:
        rating = "critical"
        # For this model a confident miss only stays "critical" when the
        # record's label is "with-address"; every other label is "recoverable".
        if MODEL_NAME == "with-address-or-blank" and record["label"] != "with-address":
            rating = "recoverable"
    else:
        rating = "unsure"
    record["rating"] = rating
            
#print(data[0])
    
# (label, colour) pairs. Materialised as a tuple because a bare zip() is a
# one-shot iterator that yields nothing the second time it is iterated.
# zip() stops at the shorter input, so at most three labels get a colour.
PALETTE = tuple(zip(labels, ("red", "green", "blue")))
        
# One "name=@column" row per model label, pointing at the cleaned column name.
_prediction_rows = "\n".join(
    f"{name}=@{clean_bokeh_label(name)}" for name in labels
)

# HTML hover template: image thumbnail, then label / prediction / rating and
# the per-label rows built above. The "@field" tokens name data source columns.
TOOLTIPS = f"""
<div>
<img width="256" src="/files/root/@absolute_path">
<pre>
CATEGORY:  @label
PREDICTED: @predicted
RATING:    @rating

PREDICTIONS:
{_prediction_rows}
</pre>
</div>
"""

def make_cds(list_of_dicts):
    """Turn a list of row dicts into a column-oriented ``ColumnDataSource``.

    The columns are the keys of the first dict; each column holds that key's
    value from every dict, in list order. A falsy input (e.g. an empty list)
    yields an empty ``ColumnDataSource``. A ``KeyError`` is raised if any
    dict lacks one of the first dict's keys.
    """
    if not list_of_dicts:
        return ColumnDataSource()

    columns = {}
    for key in list_of_dicts[0]:
        columns[key] = [row[key] for row in list_of_dicts]
    return ColumnDataSource(columns)

# Jittered scatter plot: one column per label on the x axis, the score the
# model gave to the true label ("correctness") on the y axis, one glyph
# series per rating so each rating has its own colour and legend entry.
f = figure(
    # Show the model actually being analyzed instead of a fixed placeholder.
    title=f"Model Performance (model: {MODEL_NAME})",
    tooltips=TOOLTIPS,
    sizing_mode="stretch_width",
    x_range=list(labels),
)
# Tapping a point opens the URL below, with @sha256 taken from that point.
f.add_tools(TapTool(callback=OpenURL(url=f"http://localhost:5000/label/{MODEL_NAME}/@sha256")))
for rating, color in RATINGS:
    rated = [img for img in images if img["rating"] == rating]
    if not rated:
        # Ratings with no images get no glyphs and no legend entry.
        continue
    share = 100 * len(rated) / len(images)
    f.circle(
        source=make_cds(rated),
        color=color, size=10,
        y="correctness", x=jitter("label", width=0.5, range=f.x_range),
        # The closing parenthesis was missing from this label.
        legend_label=f"{rating} ({len(rated)}, {share:.2f}%)",
    )
f.legend.location = "top"
show(f)


14696 images to analyze.


In [36]:
# NOTE(review): `r` is not assigned in any cell visible here, and this cell's
# execution count (36) is lower than the cell above it (40). It appears to
# rely on leftover kernel state and would likely fail on Restart & Run All;
# confirm whether it can be removed.
r

<db.Label at 0x7fc1fd32e050>