In [3]:
import os, io
import configparser
import PIL
import IPython
import numpy as np

In [4]:
class Config:
    """Settings read from ``config.txt`` layered over built-in defaults.

    Each option found in the ``[DEFAULT]`` section of the parsed configuration
    (the built-in ``DEFAULTS`` plus whatever ``config.txt`` overrides or adds)
    becomes an instance attribute of the same name. Options listed in
    ``DEFAULTS`` are converted to the type of their default value; options
    that only appear in ``config.txt`` are kept as strings.
    """

    DEFAULTS = dict(
        kernel_size=32,
        kernel_initial_complexity=0.1,
        kernel_score_threshold=0.5,
        grid_size=16,
        grid_score_threshold=0.1,
        dataset_source='dataset_source/101_ObjectCategories',
        dataset_path='dataset',
        dataset_sample_count=1000
    )

    def __init__(self):
        """Create the object and populate it via :meth:`reload`."""
        # A bare ``super()`` only builds a proxy object; actually chain up.
        super().__init__()
        self.reload()

    def reload(self):
        """(Re)read ``config.txt`` from the current working directory.

        Raises
        ------
        ValueError
            If a value in ``config.txt`` cannot be converted to the type of
            the corresponding default (e.g. ``kernel_size = abc``).
        """
        defaults = type(self).DEFAULTS

        # Hand the defaults over as strings explicitly instead of relying on
        # the parser's implicit conversion of non-string values.
        parser = configparser.ConfigParser(
            defaults={name: str(value) for name, value in defaults.items()}
        )

        # read() skips files that cannot be opened, so a missing config.txt
        # simply leaves the defaults in effect.
        parser.read('config.txt')

        section = parser['DEFAULT']

        for name in section:
            if name not in defaults:
                # No default to take the type from: keep the raw string
                # rather than failing with a KeyError.
                value = section[name]
            elif isinstance(defaults[name], bool):
                # bool('False') is True, so booleans need real parsing.
                value = section.getboolean(name)
            else:
                # Enforce the default's type (e.g. kernel_size must be int, not str).
                value = type(defaults[name])(section[name])

            setattr(self, name, value)

In [5]:
def pil_image_to_2d_bytes(image):
    """Turn a square image of mode 'L' or '1' into a 2-D numpy array.

    The image must have equal width and height and be in mode 'L' or '1';
    both conditions are enforced with assertions.

    For mode 'L' the raw bytes of the image are returned as a
    ``(width, width)`` array of ``uint8``. A mode '1' image is first
    converted to 'L' and the resulting array is divided by 255, which
    yields a float array instead.
    """
    assert image.width == image.height

    is_grayscale = image.mode == 'L'
    assert is_grayscale or image.mode == '1', f'Unsupported mode {image.mode}'

    source = image if is_grayscale else image.convert('L')
    side = source.width
    pixels = np.frombuffer(source.tobytes(), dtype=np.uint8).reshape(side, side)

    if is_grayscale:
        return pixels

    # Mode '1' input: rescale the converted 8-bit values by 255.
    return pixels / 255

In [6]:
# from https://gist.github.com/parente/691d150c934b89ce744b5d54103d7f1e
def _src_from_data(data):
    """Build a ``data:`` URI for image bytes, usable as an HTML img ``src``.

    The bytes are wrapped in an ``IPython.display.Image`` and its mime bundle
    representation is scanned: every element of that result is treated as a
    mapping of mimetype to value, and the first entry whose mimetype starts
    with ``image/`` is formatted as ``data:<mimetype>;base64,<value>``.

    Returns ``None`` when no such entry is found.
    """
    rendered = IPython.display.Image(data=data)._repr_mimebundle_()

    for part in rendered:
        for mime, encoded in part.items():
            if not mime.startswith('image/'):
                continue
            return f'data:{mime};base64,{encoded}'

    return None

def display_images(images, captions=None, row_height='auto'):
    """Shows a set of images in a gallery that flexes with the width of the notebook.

    Parameters
    ----------
    images: iterable of str, bytes or PIL.Image.Image
        URLs, image bytes, or PIL images to display. PIL images are saved as
        PNG into memory and embedded; bytes are embedded as given; anything
        else is used directly as the ``src`` of the img element.

    captions: mapping, optional
        Captions keyed by ``id(image)`` of the corresponding entry in
        ``images``. An image with no entry, or with an empty/``None`` caption,
        is shown without a caption. Caption text is inserted into the HTML
        as-is (it is not escaped).

    row_height: str
        CSS height value to assign to all images. Set to 'auto' by default to show images
        with their native dimensions. Set to a value like '250px' to make all rows
        in the gallery equal height.

    Returns
    -------
    IPython.display.HTML
        The gallery markup, ready to be displayed by the notebook.
    """
    figures = []

    for image in images:
        if isinstance(image, bytes):
            src = _src_from_data(image)
        elif isinstance(image, PIL.Image.Image):
            with io.BytesIO() as buffer:
                image.save(buffer, format='PNG')
                src = _src_from_data(buffer.getvalue())
        else:
            src = image

        caption = ''

        if captions:
            try:
                text = captions[id(image)]
            except KeyError:
                # Not every image needs a caption.
                text = None

            # Only wrap real text; a None/empty caption must not leak into
            # the markup (None would otherwise be rendered as "None").
            if text:
                caption = f'<figcaption style="font-size: 0.6em">{text}</figcaption>'

        figures.append(f'''
            <figure style="margin: 5px !important;">
              <img src="{src}" style="height: {row_height}">
              {caption}
            </figure>
        ''')
    return IPython.display.HTML(data=f'''
        <div style="display: flex; flex-flow: row wrap; text-align: center;">
        {''.join(figures)}
        </div>
    ''')