# Picture Picker

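Use this notebook to browse the upscaled images and pick the best model for each input image. Only images that have more than one upscaled output and no saved preference yet are listed.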

## Setup

In [None]:
import sqlite3
from contextlib import closing
from functools import lru_cache, partial
from io import BytesIO

import ipywidgets
from PIL import Image
from traitlets import dlink, link

In [None]:
# Path to the SQLite database holding the inputs, outputs, models and
# preferences tables queried throughout this notebook.
db_path = "textures.db"

In [None]:
def get_image_data(image_hash, model_hash=None):
    """Return the raw encoded image bytes stored for ``image_hash``.

    Args:
        image_hash: Key of the input image.
        model_hash: When ``None``, the original (unscaled) input image is read
            from ``inputs.input_data``. Otherwise the output produced by that
            model is read from ``outputs.output_data``.

    Returns:
        The stored image bytes.

    Raises:
        LookupError: If no matching row exists.
    """
    if model_hash is None:
        query = "SELECT input_data FROM inputs WHERE image_hash = ?"
        params = (image_hash,)
    else:
        query = """SELECT
                output_data
            FROM
                outputs
            WHERE
                model_hash = ?
                AND image_hash = ?
            """
        # Parameter order must match the placeholder order in the query above.
        params = (model_hash, image_hash)
    # sqlite3's own context manager only commits/rolls back; closing() is what
    # actually releases the connection.
    with closing(sqlite3.connect(db_path)) as conn:
        row = conn.execute(query, params).fetchone()
    if row is None:
        raise LookupError(
            f"no image data for image_hash={image_hash!r}, model_hash={model_hash!r}"
        )
    return row[0]


def get_image_from_data(image_data):
    """Decode encoded image bytes into a PIL ``Image`` via an in-memory stream."""
    return Image.open(BytesIO(image_data))


# NOTE(review): an @lru_cache decorator used to be commented out here; results
# are re-queried from the database on every call.
def get_matching_image_data(image_hash):
    """Return every upscaled output of ``image_hash`` joined with its model.

    Returns:
        A list of ``(output_data, model_name, model_hash, image_hash)`` tuples
        ordered by ``model_name``; empty if the image has no outputs.
    """
    # closing() releases the connection; sqlite3's own context manager does not.
    with closing(sqlite3.connect(db_path)) as conn:
        return conn.execute(
            """
            SELECT
                output_data,
                model_name,
                outputs.model_hash,
                image_hash
            FROM
                outputs
            JOIN
                models
            ON
                outputs.model_hash = models.model_hash
            WHERE
                outputs.image_hash = ?
            ORDER BY
                model_name
            """,
            (image_hash,),
        ).fetchall()


# NOTE(review): an @lru_cache decorator used to be commented out here; images
# are rebuilt on every call.
def get_matching_images(image_hash):
    """Decode every upscaled output of ``image_hash`` into a PIL image.

    The list follows the row order of ``get_matching_image_data``.
    """
    images = []
    for row in get_matching_image_data(image_hash):
        images.append(get_image_from_data(row[0]))
    return images


@lru_cache(None)
def get_image_data_by_index(index, model_hash=None):
    """Return (memoised) raw image bytes for the ``index``-th ``hash_images`` row.

    ``model_hash`` selects an upscaled output; ``None`` means the input image.
    """
    (image_hash, _image_name) = hash_images[index]
    return get_image_data(image_hash, model_hash=model_hash)


@lru_cache(None)
def get_image_by_index(index, model_hash=None):
    """Return the (memoised) decoded PIL image for the ``index``-th entry."""
    raw_bytes = get_image_data_by_index(index, model_hash=model_hash)
    return get_image_from_data(raw_bytes)


@lru_cache(None)
def get_width_by_index(index, model_hash=None):
    """Return the (memoised) pixel width of the ``index``-th image."""
    return get_image_by_index(index, model_hash=model_hash).width


@lru_cache(None)
def get_height_by_index(index, model_hash=None):
    """Return the (memoised) pixel height of the ``index``-th image.

    ``model_hash`` is forwarded, as in ``get_width_by_index``, so the height
    of a specific upscaled output can be queried. ``None`` means the input
    image.
    """
    return get_image_by_index(index, model_hash=model_hash).height


def get_active_preference(image_hash):
    """Return the stored preference row for ``image_hash``.

    Returns:
        A one-element ``(model_hash,)`` tuple, or ``None`` when no preference
        has been recorded.
    """
    # closing() releases the connection; sqlite3's own context manager does not.
    with closing(sqlite3.connect(db_path)) as conn:
        return conn.execute(
            "SELECT model_hash FROM preferences WHERE image_hash = ?", (image_hash,)
        ).fetchone()


def pick_model(image_hash, model_hash):
    """Record ``model_hash`` as the preferred model for ``image_hash``.

    A preference row is inserted if none exists, then its ``model_hash`` is
    updated. NOTE(review): this relies on ``preferences.image_hash`` being
    UNIQUE/PRIMARY KEY; without that, INSERT OR IGNORE would add duplicates.
    """
    args = {"image_hash": image_hash, "model_hash": model_hash}
    # closing() releases the connection; the inner ``conn`` context commits
    # both statements together, or rolls them back on error.
    with closing(sqlite3.connect(db_path)) as conn, conn:
        conn.execute(
            """INSERT OR IGNORE INTO preferences(
                image_hash, model_hash
            )
            VALUES (:image_hash, :model_hash)
            """,
            args,
        )
        conn.execute(
            """UPDATE preferences SET model_hash=:model_hash WHERE image_hash=:image_hash""",
            args,
        )

In [None]:
# Images still awaiting a decision: inputs with more than one upscaled output
# and no recorded preference, as (image_hash, image_name) rows sorted by name.
# closing() releases the connection once the rows have been fetched.
with closing(sqlite3.connect(db_path)) as conn:
    hash_images = conn.execute(
        """SELECT
            inputs.image_hash,
            image_name
        FROM
            inputs
        JOIN
            (SELECT
                image_hash,
                COUNT(image_hash) as image_count
            FROM
                outputs
            GROUP BY image_hash
            ) AS counted_outputs
        ON
            inputs.image_hash = counted_outputs.image_hash
        LEFT JOIN
            preferences
        ON
            inputs.image_hash = preferences.image_hash
        WHERE
            counted_outputs.image_count > 1
            AND preferences.model_hash IS null
        ORDER BY
            image_name
        """
    ).fetchall()
image_picker = ipywidgets.Select(
    # One entry per row: the right-aligned list position, then the image name.
    options=[f"{i:4d} {name}" for i, (_, name) in enumerate(hash_images)],
    layout=dict(height="256px"),
)


def alter_index(x, target, val):
    """Move ``target``'s selection by ``val`` positions, wrapping around.

    Args:
        x: The clicked button passed by ``Button.on_click``; unused.
        target: Object with ``index`` and ``options`` attributes (a selection
            widget).
        val: Number of positions to move; negative moves backwards.

    Behaviour at the edges:
        - With no options, this is a no-op instead of a ZeroDivisionError.
        - With nothing selected (``index is None``), stepping forward selects
          the first option and stepping backward selects the last.
    """
    count = len(target.options)
    if count == 0:
        return
    if target.index is None:
        target.index = 0 if val >= 0 else count - 1
        return
    target.index = (target.index + val) % count


# Navigation buttons step image_picker forwards/backwards via alter_index;
# on_click passes the button as alter_index's first argument.
nextbutton = ipywidgets.Button(description="Next")
nextbutton.on_click(partial(alter_index, target=image_picker, val=1))
prevbutton = ipywidgets.Button(description="Previous")
prevbutton.on_click(partial(alter_index, target=image_picker, val=-1))

# Preview of the unscaled input image.
unscaled_image_display = ipywidgets.Image()

# Wrapping, scrollable row that holds one tile per upscaled output.
carousel_layout = ipywidgets.Layout(
    overflow="scroll",
    width="1400px",
    height="",
    flex_flow="row wrap",
    display="flex",
)
output_carousel = ipywidgets.Box(children=[], layout=carousel_layout)

# Radio buttons listing the models available for the current image.
model_choices = ipywidgets.RadioButtons(layout=dict(width="600px"))


def get_image_options(index):
    """Build one carousel tile per upscaled output of the ``index``-th image.

    Each tile is a ``VBox`` holding the output image and a button labelled with
    the model name. Clicking the button selects that model in
    ``model_choices``. Each image is capped at the smaller of 4x the input
    image's size and ``max_width`` x ``max_height`` pixels.

    Returns:
        A list of ``VBox`` tiles in ``get_matching_image_data`` row order.
    """
    image_hash, _ = hash_images[index]
    # assumes outputs are 4x upscales of the input — TODO confirm
    scale = 4
    max_width = 600
    max_height = 400
    width = "{}px".format(min(max_width, get_width_by_index(index) * scale))
    height = "{}px".format(min(max_height, get_height_by_index(index) * scale))
    tiles = []
    for choice_index, (image_data, model_name, _model_hash, _image_hash) in enumerate(
        get_matching_image_data(image_hash)
    ):
        image = ipywidgets.Image(
            value=image_data, layout=dict(max_width=width, max_height=height)
        )
        button = ipywidgets.Button(description=model_name)

        # Bind choice_index as a default so each button selects its own model
        # (avoids the late-binding-closure pitfall).
        def select_model(sender, choice_index=choice_index):
            model_choices.index = choice_index

        button.on_click(select_model)
        tiles.append(ipywidgets.VBox([image, button]))
    return tiles


def get_model_options(index):
    """Return ``(model_name, model_hash)`` label/value pairs for the ``index``-th image."""
    image_hash, _image_name = hash_images[index]
    options = []
    for row in get_matching_image_data(image_hash):
        options.append((row[1], row[2]))
    return options


def click_radio(change):
    """Persist the model selected in ``model_choices`` for the current image.

    Observer for the ``index`` trait of ``model_choices``. It does nothing when
    no model or no image is selected, so a NULL preference is never written.

    NOTE(review): traitlets observers also fire on programmatic changes, e.g.
    the dlink that sets ``model_choices.index`` whenever ``image_picker``
    changes. Merely browsing to an image may therefore record a preference.
    Confirm this is intended.
    """
    model_hash = change["owner"].value
    if model_hash is None or image_picker.index is None:
        return
    image_hash, _ = hash_images[image_picker.index]
    pick_model(image_hash, model_hash)


def update_model_selection(index):
    """Return the radio position of the stored preferred model for the ``index``-th image.

    Falls back to 0 when there is no stored preference, or when the preferred
    model is not among the image's outputs.
    """
    image_hash, _image_name = hash_images[index]
    preference = get_active_preference(image_hash)
    model_hashes = [row[2] for row in get_matching_image_data(image_hash)]
    if preference is not None and preference[0] in model_hashes:
        return model_hashes.index(preference[0])
    return 0


def pick_carousel_image(change):
    """Outline the carousel tile at the newly selected model index.

    Observer for the ``index`` trait of ``model_choices``. Every other tile's
    border is cleared. When the index is ``None`` or does not match any tile
    (e.g. while options and tiles are being swapped), no tile is highlighted
    and no error is raised.
    """
    selected = change["new"]
    for position, tile in enumerate(output_carousel.children):
        tile.layout.border = "2px solid black" if position == selected else "none"


# Every change to the selected radio button is persisted by click_radio and
# highlighted in the carousel by pick_carousel_image. This includes changes
# made by the dlink below that sets model_choices.index.
model_choices.observe(click_radio, names="index")
model_choices.observe(pick_carousel_image, names="index")

# One-way links driven by image_picker.index: each transform maps the selected
# list position to a new value for the target widget trait. The list keeps the
# link objects referenced (presumably so they could be unlinked later — confirm).
imlinks = [
    # Unscaled preview: raw input bytes plus its native pixel size.
    dlink(
        (image_picker, "index"),
        (unscaled_image_display, "value"),
        partial(get_image_data_by_index, model_hash=None),
    ),
    dlink(
        (image_picker, "index"),
        (unscaled_image_display, "width"),
        partial(get_width_by_index, model_hash=None),
    ),
    dlink(
        (image_picker, "index"),
        (unscaled_image_display, "height"),
        partial(get_height_by_index, model_hash=None),
    ),
    # Cap the preview's CSS size at the input image's native size.
    dlink(
        (image_picker, "index"),
        (unscaled_image_display.layout, "max_width"),
        lambda ix: "{}px".format(get_width_by_index(ix, model_hash=None)),
    ),
    dlink(
        (image_picker, "index"),
        (unscaled_image_display.layout, "max_height"),
        lambda ix: "{}px".format(get_height_by_index(ix, model_hash=None)),
    ),
    # Upscaled outputs: carousel tiles, radio options, then the radio selection
    # restored from any stored preference. NOTE(review): the options link is
    # registered before the index link, presumably so the index refers to the
    # new options — keep this order.
    dlink((image_picker, "index"), (output_carousel, "children"), get_image_options),
    dlink((image_picker, "index"), (model_choices, "options"), get_model_options),
    dlink((image_picker, "index"), (model_choices, "index"), update_model_selection),
]

## UI

Go through the images and pick the best-looking model for each one! Each selection is saved to the `preferences` table as soon as it is made.

In [None]:
# Assemble the UI. The top row holds the navigation buttons and picker, the
# unscaled preview and the model radio buttons. Below it sit the carousel of
# upscaled outputs, then the radio buttons and navigation repeated. The VBox
# must stay the cell's final expression so Jupyter renders it.
nav_top = ipywidgets.HBox([prevbutton, nextbutton])
picker_column = ipywidgets.VBox([nav_top, image_picker])
header_row = ipywidgets.HBox([picker_column, unscaled_image_display, model_choices])
nav_bottom = ipywidgets.HBox([prevbutton, nextbutton])
ipywidgets.VBox([header_row, output_carousel, model_choices, nav_bottom])