## Example Data Labeler

A common task is to go through a batch of images and sort them into categories. Sometimes you also need to check whether the labels produced by an image-processing pipeline are correct.

This Jupyter notebook shows you how to use ipywidgets to write a binary labeler.

In [1]:
from io import BytesIO
from typing import Optional

import numpy as np
import ipywidgets as widgets
from IPython.display import display
from PIL import Image

In [2]:
def get_image(filename: str, width: Optional[int] = None, height: Optional[int] = None) -> Image.Image:
    """
    Open the image at `filename` and optionally resize it.

    Resizing rules:
      * width and height both given: resize to exactly (width, height);
        the aspect ratio is not preserved.
      * only width given: scale to that width and scale the height
        proportionally.
      * only height given: scale to that height and scale the width
        proportionally.
      * neither given: return the image at its original size.

    A proportionally derived dimension is clamped to at least 1 pixel so
    that very elongated images do not produce a zero-sized target.

    Args:
        filename: The file to open
        width: The new width of the image
        height: The new height of the image

    Returns:
        The PIL Image representation
    """
    with open(filename, 'rb') as opened_file:
        img = Image.open(opened_file)
        # Force the pixel data to be read while the file is still open.
        img.load()

    if width and height:
        target_size = (width, height)
    elif width:
        ratio = width / img.width
        target_size = (width, max(1, int(ratio * img.height)))
    elif height:
        ratio = height / img.height
        target_size = (max(1, int(ratio * img.width)), height)
    else:
        # No target size requested: hand back the image untouched.
        return img

    # Exactly one resize per call, whichever branch was taken.
    return img.resize(target_size)


def img_to_bytes(img: Image) -> BytesIO:
    """
    Encode the image as PNG into an in-memory buffer.

    Args:
        img: The image to encode; it is written via `img.save(buffer, 'png')`

    Returns:
        A BytesIO holding the PNG data, rewound to the start so it can be
        read immediately
    """
    png_buffer = BytesIO()
    img.save(png_buffer, 'png')
    # Saving leaves the cursor at the end of the data; rewind for the reader.
    png_buffer.seek(0)
    return png_buffer


class DataCapture:
    """
    Widget-driven reviewer that shows images one at a time, in shuffled
    order, next to a label and records a 'Correct' / 'Incorrect' answer
    for each one.

    Args:
        images: Sequence of image file names
        outcomes: Sequence of labels, matched to `images` by position
        seed: Seed for the shuffle of the presentation order
        image_width: Width passed to `get_image` when rendering each image

    Raises:
        ValueError: If `images` and `outcomes` differ in length

    Attributes (set in __init__):
        idxs: Shuffled array of positions into `images` / `outcomes`
        current_idx_idx: Position within `idxs` of the item being shown
        answers: Maps an index into `images` to 'Correct' or 'Incorrect'
    """

    def __init__(self, images, outcomes, seed=12381238, image_width=300):
        if len(images) != len(outcomes):
            raise ValueError('images and outcomes must have the same length')

        self.images = images
        self.outcomes = outcomes
        self.image_width = image_width
        self.idxs = np.arange(len(images))
        self.current_idx_idx = 0

        # Shuffle the indices so we get a random sample (when we likely get bored)
        np.random.RandomState(seed).shuffle(self.idxs)

        # Widgets are created lazily by start()
        self.image = None
        self.outcome = None
        self.yes_button = None
        self.no_button = None

        self.answers = {}

    @property
    def current_idx(self):
        """ Index into `images` / `outcomes` of the item currently shown """
        # Plain int (not a numpy integer) so the keys of `answers` are
        # ordinary Python ints.
        return int(self.idxs[self.current_idx_idx])

    @property
    def _is_done(self):
        """ True once every image has been stepped past """
        return self.current_idx_idx >= len(self.images)

    def _current_image_bytes(self):
        """ PNG bytes of the current image, resized to `image_width` """
        current_file = self.images[self.current_idx]
        return img_to_bytes(get_image(current_file, self.image_width)).read()

    def _finish(self):
        """ Close whichever widgets exist and announce completion """
        for widget in (self.outcome, self.yes_button, self.no_button, self.image):
            if widget is not None:
                widget.close()
        print('All images labeled!!')

    def start(self):
        """ Start the capture routine of the widget """
        if self._is_done:
            # Nothing (left) to label, e.g. an empty image list.
            self._finish()
            return

        self.image = widgets.Image(value=self._current_image_bytes(), format='png')

        self.outcome = widgets.Label(self.outcomes[self.current_idx])

        self.yes_button = widgets.Button(description='Correct')
        self.no_button = widgets.Button(description='Fail')

        def on_yes_clicked(b):
            self.answers[self.current_idx] = 'Correct'
            self.next()

        def on_no_clicked(b):
            self.answers[self.current_idx] = 'Incorrect'
            self.next()

        self.yes_button.on_click(on_yes_clicked)
        self.no_button.on_click(on_no_clicked)

        display(self.outcome)
        display(widgets.HBox([self.yes_button, self.no_button]))
        display(self.image)

    def next(self):
        """ Advance the capture routine of the widget """
        if self._is_done:
            # Already finished; advancing again would index past `idxs`.
            return

        self.current_idx_idx += 1
        if self._is_done:
            # We're done!
            self._finish()
            return
        self.outcome.value = self.outcomes[self.current_idx]
        self.image.value = self._current_image_bytes()

In [3]:
# Image files to review
images = ['images/green1.jpg', 'images/green2.jpg', 'images/purple1.jpg', 'images/purple2.jpg']

# Label shown for each file, matched to `images` by position
outcomes = ['green', 'purple', 'purple', 'green']

In [4]:
# Build the labeler over the example files and display its widgets
cap = DataCapture(images=images, outcomes=outcomes)
cap.start()

Label(value='purple')

HBox(children=(Button(description='Correct', style=ButtonStyle()), Button(description='Fail', style=ButtonStyl…

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01,\x00\x00\x00\xc8\x08\x02\x00\x00\x00\xdd\xbdK\x0…

All images labeled!!


In [5]:
# Split the reviewed files by the answer recorded for each index in cap.answers
correct_images = [images[i] for i in cap.answers if cap.answers[i] == 'Correct']
incorrect_images = [images[i] for i in cap.answers if cap.answers[i] == 'Incorrect']

In [6]:
# Number of images whose recorded answer was 'Correct'
len(correct_images)

2

In [7]:
# Number of images whose recorded answer was 'Incorrect'
len(incorrect_images)

2