In [13]:
import numpy as np
import random
from typing import Sequence
from itertools import product

In [17]:
class ImageNew:
    """Class that represents a png-based image."""

    SUPPORTED_DEPTHS = {1, 2, 4, 8}

    def __init__(
        self, data: Sequence[int] | None, width: int, height: int, depth: int = 4
    ):
        """Initializes an image from binary data.

        Parameters:
        -----------
        data : bytes-like or None
            Binary raw pixel data in GBA format.
            If None the image will be empty.
        width : int
            The width of the picture.
        height : int
            The height of the picture.
        bpp : 4 or 8
            Bits per pixel (depth).
        """
        assert (
            depth in self.SUPPORTED_DEPTHS
        ), f'Invalid depth {depth}, must be in {self.SUPPORTED_DEPTHS}'
        assert width % 8 == 0, 'Width is not a multiple of 8'
        assert height % 8 == 0, 'Height is not a multiple of 8'
        self.depth = depth
        self.width = width
        self.height = height
        self.data = np.zeros((width, height), dtype=int)
        if data is not None:
            for x, y in product(range(width), range(height)):
                byte_index, bit_index = self._position_to_data_idx(x, y)
                self.data[x, y] = (data[byte_index] >> bit_index) & (2**depth - 1)

    def to_binary(self) -> bytearray:
        """Returns the raw binary data of the image in GBA tile format.

        Note that this has linear complexity in the number of pixels.

        Returns:
        --------
        binary : bytearray
            The binary representation of the picture in GBA tile format.
        """
        # Pack the data to a dense binary format
        binary = bytearray([0] * (self.width * self.height * self.depth // 8))
        for x, y in product(range(self.width), range(self.height)):
            byte_index, bit_index = self._position_to_data_idx(x, y)
            binary[byte_index] |= (self.data[x, y] & (2**self.depth - 1)) << bit_index
        return binary

    def _position_to_data_idx(self, x: int, y: int) -> tuple[int, int]:
        """Converts a position to the data index and bit index.

        Parameters:
        -----------
        x : int
            The x position.
        y : int
            The y position.

        Returns:
        --------
        data_idx : int
            Which index in the byte array the position is located.
        bit_idx : int
            At which bit in the byte the position is located.
        """
        tile_x, tile_y = x >> 3, y >> 3
        tile_idx = tile_y * (self.width >> 3) + tile_x
        tile_off_x, tile_off_y = x & 7, y & 7
        data_idx = (
            tile_idx * 8 * self.depth
            + tile_off_y * self.depth
            + tile_off_x * self.depth // 8
        )
        return data_idx, 8 - self.depth - tile_off_x * self.depth % 8

In [29]:
# Round-trip check: decoding random bytes and re-encoding them must reproduce
# the input exactly for every supported depth. Seeded so reruns are identical.
rng = np.random.default_rng(0)
w, h = 16, 24
for depth in sorted(ImageNew.SUPPORTED_DEPTHS):
    n_bytes = w * h * depth // 8
    data = bytearray(rng.integers(0, 256, size=n_bytes).tolist())
    img = ImageNew(data, w, h, depth)
    assert img.to_binary() == data, f'Round trip failed for depth {depth}'

In [46]:
class ImageOld:
    """Class that represents a png-based image."""

    def __init__(
        self, data: Sequence[int] | None, width: int, height: int, depth: int = 4
    ):
        """Initializes an image from binary data.

        Parameters:
        -----------
        data : bytes-like or None
            Binary raw pixel data in GBA format.
            If None the image will be empty.
        width : int
            The width of the picture.
        height : int
            The height of the picture.
        bpp : 4 or 8
            Bits per pixel (depth).
        """
        self.data = np.zeros((width, height), dtype=int)
        if data is not None:
            # Unpack the data
            tile_width = width // 8
            for idx in range(len(data)):
                if depth == 4:
                    tile_idx = idx // 32  # (8 * 8 / 2) = 32 bytes per tile
                    tile_pos = idx % 32
                    x = 8 * (tile_idx % tile_width)
                    y = 8 * (tile_idx // tile_width)
                    y += tile_pos // 4
                    x += 2 * (tile_pos % 4)
                    # print(f'Tile {idx} -> {x}, {y}')
                    if x < width and y < height:
                        self.data[x, y] = data[idx] & 0xF
                        self.data[x + 1, y] = data[idx] >> 4
                elif depth == 8:
                    tile_idx = idx // 64  # (8 * 8) = 64 bytes per tile
                    tile_pos = idx % 64
                    x = 8 * (tile_idx % tile_width)
                    y = 8 * (tile_idx // tile_width)
                    y += tile_pos // 8
                    x += tile_pos % 8
                    if x < width and y < height:
                        self.data[x, y] = data[idx]
                else:
                    raise RuntimeError(
                        'Invalid image depth. Only depths 4 and 8 are ' 'supported!'
                    )

        self.depth = depth
        self.width = width
        self.height = height

    def to_binary(self) -> bytearray:
        """Returns the raw binary data of the image in GBA tile format.

        Note that this has linear complexity in the number of pixels.

        Returns:
        --------
        binary : bytearray
            The binary representation of the picture in GBA tile format.
        """
        # Pack the data to a dense binary format
        binary = bytearray([0] * (self.width * self.height * self.depth // 8))
        for x, y in product(range(self.width), range(self.height)):
            tile_x, tile_y = x >> 3, y >> 3
            tile_idx = tile_y * (self.width >> 3) + tile_x
            tile_off_x, tile_off_y = x & 7, y & 7
            if self.depth == 4:
                idx = tile_idx << 5  # (8 * 8 / 2) = 32 bytes per tile
                idx += ((tile_off_y << 3) + tile_off_x) >> 1
                # print(f'{x}, {y} -> {idx} [upper={tile_off_x & 1 > 0}]')
                if tile_off_x & 1 > 0:
                    binary[idx] |= self.data[x, y] << 4
                else:
                    binary[idx] |= self.data[x, y]
            elif self.depth == 8:
                idx = tile_idx << 6  # (8 * 8) = 64 bytes per tile
                idx += (tile_off_y << 3) + tile_off_x
                binary[idx] = self.data[x, y]
            else:
                raise RuntimeError(
                    'Invalid image depth. Only depths 4 and 8 are ' 'supported!'
                )
        return binary

In [49]:
# Cross-check the legacy and new implementations at 4 and 8 bpp (the only
# depths ImageOld supports): both must reproduce random input bytes exactly.
# NOTE: round-tripping alone does not prove the two decoders assign the same
# pixel values; consider also asserting np.array_equal(img_old.data, img_new.data).
rng = np.random.default_rng(1)
w, h = 16, 32
for depth in [4, 8]:
    data = bytearray(rng.integers(0, 256, size=w * h * depth // 8).tolist())
    img_old = ImageOld(data, w, h, depth)
    img_new = ImageNew(data, w, h, depth)
    assert img_old.to_binary() == data, f'ImageOld round trip failed for depth {depth}'
    assert img_new.to_binary() == data, f'ImageNew round trip failed for depth {depth}'

In [43]:
# Scratch arithmetic: byte offset of column 4, row 2 inside tile 6, presumably
# for the 8 bpp layout (64 bytes per tile, 8 bytes per tile row).
(6 * 8 + 2) * 8 + 4

404