In [None]:
#!/usr/bin/env python
# -*-coding:utf-8 -*-
"""
  ████
██    ██   Datature
  ██  ██   Powering Breakthrough AI
    ██

@File    :   bitmask_encoding_demo.py
@Author  :   Yong Jun Thong
@Version :   1.0
@Contact :   hello@datature.io
@License :   Apache License 2.0
@Desc    :   BitMask encoding using RLE and REE
"""

## BitMask Encoding Demo

This notebook demonstrates encoding a binary image mask (bitmask) with Run-Length Encoding (RLE) and Run-End Encoding (REE). An interactive widget displays per-image statistics and the encoded size achieved by each scheme.

### Install Prerequisites

In [None]:
%pip install numpy==1.26.4
%pip install Pillow==10.4.0
%pip install ipywidgets
%pip install ipython

### Import Libraries

In [1]:
import numpy as np
from PIL import Image as PImage
from ipywidgets import FileUpload, Output, Button, HBox
from IPython.display import display, Image, HTML
import io

### BitMask Encoding Functions

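These functions load an image and threshold it into a binary bitmask (dark pixels become 1). They then locate the positions where each run of identical bits ends. Finally, they encode those positions with a binary-tree REE scheme and estimate the size of an equivalent varint-based RLE encoding.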

In [2]:
def ingest_image(data_or_path, threshold=128):
    """
    Load an image and threshold it into a flat binary array.

    Args:
        data_or_path: Encoded image data (``bytes`` or ``memoryview``), a
            filesystem path (``str``), or an already-opened PIL image.
        threshold: Grayscale cut-off on the 0-255 scale. Pixels strictly
            darker than this become 1; all others become 0. Defaults to 128.

    Returns:
        Tuple ``(bits, width, height)``. ``bits`` is a 1-D integer numpy
        array of length ``width * height`` in row-major order.
    """
    # Convert image data to a PIL Image
    if isinstance(data_or_path, (bytes, memoryview)):
        # PIL needs a file-like object; copy memoryviews into bytes first.
        img_data = data_or_path.tobytes() if isinstance(data_or_path, memoryview) else data_or_path
        img = PImage.open(io.BytesIO(img_data))
    elif isinstance(data_or_path, str):
        img = PImage.open(data_or_path)
    else:
        img = data_or_path

    has_alpha = img.mode in ('RGBA', 'LA', 'PA')
    has_key_transparency = img.mode in ('P', 'L', 'RGB') and 'transparency' in img.info
    if has_alpha or has_key_transparency:
        # Composite onto white so fully transparent pixels read as background
        # (0) instead of whatever colour happens to sit under the alpha.
        rgba = img.convert('RGBA')
        img_gray = PImage.new("L", img.size, 255)
        img_gray.paste(rgba.convert("L"), mask=rgba.getchannel('A'))
    else:
        # Always go through convert('L'), including for mode '1'. A mode '1'
        # image becomes a *bool* numpy array, and every bool compares < 128,
        # which would turn the entire mask into 1s.
        img_gray = img.convert('L')

    # Get dimensions
    width, height = img_gray.size

    # Convert to numpy array and apply thresholding
    img_array = np.array(img_gray)
    binary = (img_array < threshold).astype(int)

    # Flatten the array to 1D (row-major)
    return binary.flatten(), width, height

def bit_to_index(iterable):
    """
    Lazily report the positions where a bit sequence changes value.

    The scan starts from an implicit previous value of 0. A sequence that
    begins with a 1 therefore reports position 0 first. After the last
    change, the total length of the sequence is reported as a final position.

    The returned object can be iterated more than once. Each pass re-reads
    ``iterable`` from the start, so a one-shot iterator yields only ``0`` on
    later passes.

    Args:
        iterable: Sequence of bit values (typically 0/1).

    Returns:
        A re-iterable object yielding integer positions in increasing order.
    """
    class _Transitions:
        def __iter__(self):
            current = 0
            length = 0
            for position, bit in enumerate(iterable):
                if bit != current:
                    yield position
                    current = bit
                length = position + 1
            # Close the final run at the end of the sequence.
            yield length

    return _Transitions()

def write_bit(arr):
    """
    Build a closure that packs bits into ``arr``, least-significant bit first.

    Each call to the returned ``writer(v)`` fills the next bit slot. Slots
    ``0..7`` land in ``arr[0]``, ``8..15`` in ``arr[1]``, and so on. A new
    zero byte is appended whenever the slot runs past the end of ``arr``.
    The byte at the slot is incremented by ``v`` shifted into position, so
    ``v`` is expected to be 0 or 1.

    Args:
        arr: A mutable ``bytearray`` that receives the packed bits.

    Returns:
        The ``writer(v)`` closure.
    """
    written = 0

    def writer(v):
        nonlocal written
        byte_pos, bit_pos = divmod(written, 8)
        if byte_pos >= len(arr):
            arr.append(0)
        arr[byte_pos] += (1 << bit_pos) * v
        written += 1

    return writer

def encode_bitmask(iterable, max_index):
    """
    Encode run-end positions as a binary-tree Run End Encoding (REE).

    Positions ``0..max_index`` are the leaves of an implicit binary tree.
    Its root spans ``2**depth`` leaves, where ``depth`` is the bit length of
    ``max_index``. Nodes are visited depth-first, left to right, and one bit
    is emitted per visited node:

    * ``1`` if the next pending position falls inside the node's span. Inner
      nodes are then expanded; a leaf consumes that position.
    * ``0`` otherwise, and the node's whole span is skipped.

    Encoding stops as soon as the scan passes ``max_index`` or the positions
    run out, so trailing subtrees may be left unwritten.

    Args:
        iterable: Run-end positions, assumed strictly increasing (as produced
            by ``bit_to_index``).
        max_index: Largest leaf position covered by the tree (callers pass
            ``len(bits) - 1``).

    Returns:
        The emitted bits, packed least-significant-bit first by ``write_bit``,
        as ``bytes``.
    """
    n = max_index + 1
    positions = iter(iterable)
    packed = bytearray()
    emit = write_bit(packed)

    pending_levels = [(n - 1).bit_length()]
    cursor = 0
    target = next(positions, None)

    while pending_levels and cursor < n and target is not None:
        level = pending_levels.pop()
        span = 1 << level

        if level == 0:
            # Leaf: flag whether this exact position ends a run.
            if target == cursor:
                emit(1)
                target = next(positions, None)
            else:
                emit(0)
            cursor += 1
        elif target < cursor + span:
            # The next end lies in this subtree: flag it and expand both
            # children (identical entries; the cursor tells them apart).
            emit(1)
            pending_levels.extend((level - 1, level - 1))
        else:
            # No end in this subtree: skip all of its leaves at once.
            emit(0)
            cursor += span

    return bytes(packed)

def get_stat(width, height, bits, indices):
    """
    Compute size and compression statistics for one bitmask.

    Args:
        width: Image width in pixels (reported as-is).
        height: Image height in pixels (reported as-is).
        bits: Flat numpy array of 0/1 pixel values (1 = black), as returned
            by ``ingest_image``.
        indices: Run-end positions as produced by ``bit_to_index``. A final
            entry equal to ``len(bits)`` is accepted but not required.

    Returns:
        Dict mapping statistic name to value, in display order.
    """
    total_pixels = len(bits)
    black_pixels = bits.sum()
    indices = list(indices)

    # Run lengths are the gaps between consecutive boundaries. bit_to_index
    # already ends with len(bits); appending it again would add a spurious
    # zero-length trailing run. A leading zero-length run is kept
    # deliberately: it records that the mask starts with a black pixel.
    boundaries = [0] + indices
    if not indices or indices[-1] != total_pixels:
        boundaries.append(total_pixels)
    run_lengths = np.diff(boundaries)
    runs = len(run_lengths)

    btree_encoded = encode_bitmask(indices, total_pixels - 1)
    # NOTE(review): runs of 64*256*256 or more are still counted as 3 bytes;
    # confirm this matches the intended varint layout.
    varint_encoded_size = sum(
        1 if length < 64 else 2 if length < 64 * 256 else 3 for length in run_lengths
    )

    return {
        "width": width,
        "height": height,
        "total_pixels": total_pixels,
        "black_pixels": black_pixels,
        "white_pixels": total_pixels - black_pixels,
        "black_ratio": black_pixels / total_pixels,
        "runs": runs,
        "run_to_pixel_ratio": runs / total_pixels,
        "max_run_length": run_lengths.max(),
        "avg_run_length": run_lengths.sum() / runs,
        "btree_encoded_size_ree": len(btree_encoded),
        "varint_encoded_size_rle": varint_encoded_size,
        "ree_bytes_per_run": len(btree_encoded) / runs,
        "rle_bytes_per_run": varint_encoded_size / runs,
    }

def format_stat(stat, HTML):
    """
    Render a statistics dict as an HTML table.

    Args:
        stat: Mapping of snake_case statistic names to values, as returned by
            ``get_stat``.
        HTML: Callable that wraps the generated markup (e.g.
            ``IPython.display.HTML``).

    Returns:
        Whatever ``HTML(markup)`` returns.
    """
    from html import escape

    def fmt(value):
        # Floats get 3 decimals; thousands separators are a no-op below 1000.
        if isinstance(value, (float, np.floating)):
            return f"{value:,.3f}"
        if isinstance(value, (int, np.integer)):
            # int() normalises numpy integers before formatting.
            return f"{int(value):,}"
        return str(value)

    def label(key):
        pretty = " ".join(word.capitalize() for word in key.split('_'))
        return pretty.replace("Rle", "RLE").replace("Ree", "REE")

    rows = "".join(
        f"<tr><td><b>{escape(label(key))}</b></td><td>{escape(fmt(value))}</td></tr>"
        for key, value in stat.items()
    )
    # Scope the CSS to this table's class so it does not restyle every other
    # table rendered in the notebook; keep <style> outside the <table>.
    markup = (
        "<style>.bitmask-stats td,.bitmask-stats th"
        "{padding:8px; border-bottom:1px solid #ddd;}</style>"
        "<table class='bitmask-stats' style='width:30%; border-collapse:collapse;'>"
        f"{rows}</table>"
    )
    return HTML(markup)

### Bitmask Rendering

Upload an image, or click **Use Sample Image**, to display it alongside its RLE and REE encoding statistics.

In [None]:
# Create widgets
SAMPLE_IMAGE_PATH = 'assets/bitmask_sample.png'

upload_widget = FileUpload(accept='image/*', multiple=False)
sample_button = Button(description='Use Sample Image')
output_widget = Output()


def _show_stats(source):
    """Threshold ``source`` (bytes, path, or PIL image) and display its encoding statistics."""
    bits, width, height = ingest_image(source)
    indices = list(bit_to_index(bits))
    display(format_stat(get_stat(width, height, bits, indices), HTML))


def _uploaded_files(value):
    """
    Normalise ``FileUpload.value`` to a list of file dicts.

    ipywidgets 8 stores a tuple of dicts; ipywidgets 7 stores a dict keyed by
    filename. In both versions each file dict holds its data under 'content'.
    """
    if isinstance(value, dict):
        return list(value.values())
    return list(value or ())


# Display handler for uploaded image
def on_upload_change(e):
    """
    Handle an image upload: show the image and its encoding statistics.
    """
    with output_widget:
        output_widget.clear_output()
        files = _uploaded_files(upload_widget.value)
        if not files:
            return
        # Content may arrive as a memoryview; normalise to bytes for both
        # IPython's Image and PIL.
        img_data = bytes(files[0]['content'])
        display(Image(data=img_data, height=300))
        _show_stats(img_data)


def on_sample_button_click(e, path=SAMPLE_IMAGE_PATH):
    """
    Handle a sample-image button click: show the sample and its statistics.
    """
    with output_widget:
        output_widget.clear_output()
        # Open the file once and reuse the image for both display and stats.
        with PImage.open(path) as sample:
            display(sample.convert("L"))
            _show_stats(sample)


# Connect the handlers
upload_widget.observe(on_upload_change, names='value')
sample_button.on_click(on_sample_button_click)

# Display widgets in a horizontal layout
display(HBox([upload_widget, sample_button]), output_widget)

HBox(children=(FileUpload(value=(), accept='image/*', description='Upload'), Button(description='Use Sample Im…

Output()