---
layout: post
title: Parsing PNG images in Mojo
categories: [mojo]
excerpt: There's currently no direct way of reading image files from Mojo. In this post I go through what's needed to parse a PNG file directly in Mojo without having to go through Python. 
---

So for the past while I've been trying to follow along with the development of Mojo, but so far I've mostly just followed along with the changelog and written some pretty trivial pieces of code. 
In my last post I said I wanted to try something a bit more substantial, so here goes. 

I was looking at the [Basalt](https://github.com/basalt-org/basalt) project, which tries to build a Machine Learning framework in pure Mojo, and realized that the only images used so far were MNIST, which come in a weird binary format anyway. Why no others, though? As Mojo does not yet support accelerators (like GPUs), ImageNet is probably impractical, but it should be fairly quick to train a CNN on CIFAR-10 on a CPU these days. The CIFAR-10 dataset is available from the [original source](https://www.cs.toronto.edu/~kriz/cifar.html) as either a pickle archive or some custom binary format. I thought about writing dataset loaders for these, but it might be more useful to write a PNG parser in Mojo, and then use the version of the dataset hosted on [Kaggle](https://www.kaggle.com/c/cifar-10/). That way the code can be used to open PNG images in general.


# A PNG parser in (pure-ish) Mojo

Don't mistake this post for a tutorial: read it as someone discovering the gory details of the PNG standard while learning a new language. If you want to read more about the PNG format, the [Wikipedia page](https://en.wikipedia.org/wiki/PNG) is pretty helpful as an overview, and the [W3C page](https://www.w3.org/TR/png/) provides a lot of detail. 

For reference, this was written with Mojo `24.2.1`, and as Mojo is still changing pretty fast a lot of what is done below might be outdated.

The goal here is not to build a tool to display PNGs, but just to read them into an array (or tensor) that can be used for ML purposes, so I will skip over a lot of the more display oriented details. 

## Reading in the data

To start, let's take a test image. This is the test image from the PIL library, a picture of the OG [Grace Hopper](https://en.wikipedia.org/wiki/Grace_Hopper). It is a relatively simple PNG, so it should be a good place to start:

![hopper](../images/hopper.png "Hopper")

Now that Mojo has implemented its version of pathlib in the stdlib, we can actually check if the file exists: 

In [1]:
from pathlib import Path
test_image = Path('../images/hopper.png')
print(test_image.exists())

True


We'll also import the image via Python so we can check that the outputs we get match what Python produces. 

In [2]:
from python import Python
var Image = Python.import_module('PIL.Image')
var np = Python.import_module('numpy')
py_array = np.array(Image.open("../images/hopper.png"))

We're going to read the raw bytes. I would have expected the data to be unsigned 8-bit integers, but Mojo reads them as **signed** 8-bit integers. There is however a [proposal to change this](https://github.com/modularml/mojo/pull/2099), so this might change soon. 

In [3]:
with open(test_image, "r") as f:
    file_contents = f.read_bytes()

print(len(file_contents))


30605


## Checking the file header

PNG files have a signature defined in the first 8 bytes, part of which is the letters PNG in ASCII. We'll define a little helper function to convert from bytes to String: 

In [4]:
fn bytes_to_string(owned list: List[Int8]) -> String:
    var word = String("")
    for letter in list:
        word += chr(int(letter[]))

    return word


To make sure we are actually dealing with a PNG, we can check bytes 1 to 3 (counting from zero):

In [5]:
png_signature = file_contents[0:8]
print(bytes_to_string(png_signature[1:4]))


PNG


Yup, it's telling us it is a PNG file. 
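
The three letters are only part of the signature. As a cross-check on the Python side, here is a minimal sketch that compares all 8 bytes against the signature defined in the PNG specification. The function name `has_png_signature` is my own and not part of any library:

```python
# The full 8-byte PNG signature: 0x89, "PNG", CR LF, 0x1A (Ctrl-Z), LF.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def has_png_signature(data: bytes) -> bool:
    """Return True if the data starts with the 8-byte PNG signature."""
    return data[:8] == PNG_SIGNATURE


if __name__ == "__main__":
    print(has_png_signature(PNG_SIGNATURE + b"rest of file"))
    print(has_png_signature(b"GIF89a"))
```

The non-letter bytes are there to catch files that were mangled in transfer, for example by line-ending conversion, so checking all 8 bytes is a bit stricter than checking for the letters alone.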

## Reading chunks

So now we read the first "chunk", which should be the header. 
Each chunk consists of four parts: the chunk length (4 bytes), the chunk type (4 bytes), the chunk data (however long the first 4 bytes said it was), and a checksum (called the CRC) computed from the chunk type and the chunk data (4 bytes). 

| **Length** | **Chunk type** | **Chunk data** | **CRC** |
|------------|----------------|----------------|---------|
|   4 bytes  |     4 bytes    | _Length_ bytes | 4 bytes |

When reading in data with `read_bytes`, the data comes as a list of signed 8-bit integers, but we would like to interpret the data as 32-bit unsigned integers. Below is a helper function to do so (thanks to [Michael Kowalski](https://github.com/mikowals) for the help).



In [6]:
from math.bit import bswap, bitreverse
from testing import assert_true


fn bytes_to_uint32_be(owned list: List[Int8]) raises -> List[UInt32]:
  assert_true(len(list) % 4 == 0, "List[Int8] length must be a multiple of 4 to convert to List[UInt32]")
  var result_length = len(list) // 4
  
  # get the data pointer with ownership.
  # This avoids copying and makes sure only one List owns a pointer to the underlying address.
  var ptr_to_int8 = list.steal_data() 
  var ptr_to_uint32 = ptr_to_int8.bitcast[UInt32]()

  var result = List[UInt32]()
  result.data = ptr_to_uint32
  result.capacity = result_length
  result.size = result_length

  # swap the bytes in each UInt32 to convert from big-endian to little-endian
  for i in range(result_length):
    result[i] = bswap(result[i])

  return result
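
For comparison, the same conversion can be sketched in Python. This is not a translation of the pointer trick above, just the same idea: reinterpret signed bytes as unsigned, then combine every four of them in big-endian order. The name `signed_bytes_to_uint32_be` is made up for this example:

```python
from typing import List


def signed_bytes_to_uint32_be(values: List[int]) -> List[int]:
    """Interpret signed 8-bit values as big-endian unsigned 32-bit integers."""
    if len(values) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    # Masking with 0xFF maps the signed range -128..127 onto 0..255.
    raw = bytes(v & 0xFF for v in values)
    return [int.from_bytes(raw[i:i + 4], "big") for i in range(0, len(raw), 4)]


if __name__ == "__main__":
    # The IHDR length field: 00 00 00 0D
    print(signed_bytes_to_uint32_be([0, 0, 0, 13]))
```

Big-endian means the first byte is the most significant one, which is why the bytes have to be swapped on a little-endian machine.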

## Reading the image header

The first chunk after the file header should always be the image header, so let's have a look at it. 

First, let's see how long it is: 

In [7]:
read_head = 8
chunk_length = bytes_to_uint32_be(file_contents[read_head:read_head+4])[0]
print(chunk_length)

13


So the first chunk is 13 bytes long. Let's see what type it is: 


In [8]:

chunk_type = file_contents[read_head+4:read_head+8]
print(bytes_to_string(chunk_type))

IHDR


IHDR, which confirms that this chunk is the image header. We can now parse the next 13 bytes of header data to get information about the image: 

In [9]:
start_header = int(read_head+8)
end_header = int(read_head+8+chunk_length)
header_data = file_contents[start_header:end_header]

The first two 4-byte fields tell us the width and height of the image respectively:

In [10]:
print("Image width: ", bytes_to_uint32_be(header_data[0:4])[0])
print("Image height: ", bytes_to_uint32_be(header_data[4:8])[0])

Image width:  128
Image height:  128


So our image is 128x128 pixels in size. 

The next five bytes tell us the bit depth of each channel, the color type, the compression method, the filter method, and whether the image is interlaced or not. 

In [11]:
print("Bit depth: ", int(header_data[8]))
print("Color type: ", int(header_data[9]))
print("Compression method: ", int(header_data[10]))
print("Filter method: ", int(header_data[11]))
print("Interlaced: ", int(header_data[12]))


Bit depth:  8
Color type:  2
Compression method:  0
Filter method:  0
Interlaced:  0


A color type of 2 means `Truecolor`, i.e. RGB, and each channel has a bit depth of 8.
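
The whole 13-byte header can also be unpacked in one go. Here is a small Python sketch using the standard `struct` module. The function name `parse_ihdr` and the dictionary keys are my own choices, and the color type names follow the PNG specification:

```python
import struct

# Color type codes defined by the PNG specification.
COLOR_TYPES = {
    0: "Grayscale",
    2: "Truecolor",
    3: "Indexed-color",
    4: "Grayscale with alpha",
    6: "Truecolor with alpha",
}


def parse_ihdr(data: bytes) -> dict:
    """Unpack the 13 bytes of IHDR chunk data into named fields."""
    if len(data) != 13:
        raise ValueError("IHDR data must be exactly 13 bytes")
    # ">" = big-endian, "I" = unsigned 32-bit, "B" = unsigned 8-bit
    width, height, depth, color, compression, filtering, interlace = struct.unpack(
        ">IIBBBBB", data
    )
    return {
        "width": width,
        "height": height,
        "bit_depth": depth,
        "color_type": COLOR_TYPES.get(color, "Unknown"),
        "compression_method": compression,
        "filter_method": filtering,
        "interlaced": bool(interlace),
    }


if __name__ == "__main__":
    # Same field values as the Hopper image above.
    print(parse_ihdr(struct.pack(">IIBBBBB", 128, 128, 8, 2, 0, 0, 0)))
```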

Interesting side note: in the [PIL PngImagePlugin](https://github.com/python-pillow/Pillow/blob/main/src/PIL/PngImagePlugin.py) there is a changelog:
```
# history:
# 1996-05-06 fl   Created (couldn't resist it)
# 1996-12-14 fl   Upgraded, added read and verify support (0.2)
# 1996-12-15 fl   Separate PNG stream parser
# 1996-12-29 fl   Added write support, added getchunks
# 1996-12-30 fl   Eliminated circular references in decoder (0.3)
# 1998-07-12 fl   Read/write 16-bit images as mode I (0.4)
# 2001-02-08 fl   Added transparency support (from Zircon) (0.5)
# 2001-04-16 fl   Don't close data source in "open" method (0.6)
# 2004-02-24 fl   Don't even pretend to support interlaced files (0.7)
# 2004-08-31 fl   Do basic sanity check on chunk identifiers (0.8)
# 2004-09-20 fl   Added PngInfo chunk container
# 2004-12-18 fl   Added DPI read support (based on code by Niki Spahiev)
# 2008-08-13 fl   Added tRNS support for RGB images
# 2009-03-06 fl   Support for preserving ICC profiles (by Florian Hoech)
# 2009-03-08 fl   Added zTXT support (from Lowell Alleman)
# 2009-03-29 fl   Read interlaced PNG files (from Conrado Porto Lopes Gouvua)
```

I like the comment from 2004, `Don't even pretend to support interlaced files`, followed by interlaced PNGs finally being supported about 13 years after PNG reading was added to PIL. 
I have a feeling I won't be dealing with interlaced files in this post...

The final part of this chunk is the CRC32 value, which is the 32-bit [cyclic redundancy check](https://en.wikipedia.org/wiki/Cyclic_redundancy_check). I won't go into too much detail, but it's basically an error-detecting code that's added to detect if the chunk data is corrupt. By checking the provided CRC32 value against one we calculate ourselves we can ensure that the data we are reading is not corrupt. 

In [12]:
start_crc = int(read_head+8+chunk_length)
end_crc = int(start_crc+4)
header_crc = bytes_to_uint32_be(file_contents[start_crc:end_crc])[0]
print("CRC: ", hex(header_crc))

CRC:  0x4c5cf69c


We need a little bit of code to calculate the CRC32 value.   
This is not the most efficient implementation, but it is simple.   
I'll probably do a follow up post where I explain what this does in more detail. 

In [13]:
fn CRC32(owned data: List[SIMD[DType.int8, 1]], value: SIMD[DType.uint32, 1] = 0xffffffff) -> SIMD[DType.uint32, 1]:
    var crc32 = value
    for byte in data:
        crc32 = (bitreverse(byte[]).cast[DType.uint32]() << 24) ^ crc32
        for i in range(8):
            
            if crc32 & 0x80000000 != 0:
                crc32 = (crc32 << 1) ^ 0x04c11db7
            else:
                crc32 = crc32 << 1

    return bitreverse(crc32^0xffffffff)

In [14]:
print(hex(CRC32(file_contents[read_head+4:end_header])))

0x4c5cf69c


Great, the two CRC values match, so we know that the data in our IHDR chunk is good. 
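
As a sanity check on the algorithm itself, here is a Python sketch of the same CRC, written in the more common "reflected" form. Instead of bit-reversing each input byte and the final result, it shifts right and uses the bit-reversed polynomial `0xEDB88320`, which gives the same answer. The function names are mine, and `zlib.crc32` from the Python standard library serves as the reference:

```python
import zlib


def crc32_bitwise(data: bytes) -> int:
    """Bit-by-bit CRC-32 in reflected form (polynomial 0xEDB88320)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            if crc & 1:
                crc = (crc >> 1) ^ 0xEDB88320
            else:
                crc >>= 1
    # Final XOR, same as the `^ 0xffffffff` at the end of the Mojo version.
    return crc ^ 0xFFFFFFFF


def agrees_with_zlib(data: bytes) -> bool:
    """Compare the bitwise implementation against zlib's CRC-32."""
    return crc32_bitwise(data) == zlib.crc32(data)


if __name__ == "__main__":
    # "123456789" is the conventional test input for CRC implementations.
    print(hex(crc32_bitwise(b"123456789")))
    print(agrees_with_zlib(b"IHDR" + bytes(13)))
```

Remember that for a PNG chunk the input is the chunk type followed by the chunk data, not the data alone.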

## Reading more chunks

Now, reading parts of each chunk will get repetitive, so let's define a struct called `Chunk` to hold the information contained in a chunk, and a function that will parse chunks for us and return the constituent parts: 

In [15]:
struct Chunk:
    var length: UInt32
    var type: String
    var data: List[Int8]
    var crc: UInt32
    var end: Int

    fn __init__(inout self, length: UInt32, chunk_type: String, data : List[Int8], crc: UInt32, end: Int):
        self.length = length
        self.type = chunk_type
        self.data = data
        self.crc = crc
        self.end = end


def parse_next_chunk(owned data: List[Int8], read_head: Int) -> Chunk:
    chunk_length = bytes_to_uint32_be(data[read_head:read_head+4])[0]
    chunk_type = bytes_to_string(data[read_head+4:read_head+8])
    start_data = int(read_head+8)
    end_data = int(start_data+chunk_length)
    chunk_data = data[start_data:end_data]
    start_crc = int(end_data)
    end_crc = int(start_crc+4)
    chunk_crc = bytes_to_uint32_be(data[start_crc:end_crc])[0]

    # Check CRC
    assert_true(CRC32(data[read_head+4:end_data]) == chunk_crc, "CRC32 does not match")
    return Chunk(length=chunk_length, chunk_type=chunk_type, data=chunk_data, crc=chunk_crc, end=end_crc)


During chunk parsing the CRC32 value for the chunk type and data is computed, and an error will be raised if it is different to what is expected. 

Let's test this to see if it parses the IHDR chunk: 

In [16]:
var header_chunk = parse_next_chunk(file_contents, 8)
print(header_chunk.type)
read_head = header_chunk.end

IHDR


The next few chunks are called "Ancillary chunks", and are not strictly necessary. They contain image attributes (like [gamma](https://en.wikipedia.org/wiki/Gamma_correction)) that may be used in rendering the image: 

In [17]:
var gamma_chunk = parse_next_chunk(file_contents, read_head)
print(gamma_chunk.type)
read_head = gamma_chunk.end

gAMA


In [18]:
var chromaticity_chunk = parse_next_chunk(file_contents, read_head)
print(chromaticity_chunk.type)
read_head = chromaticity_chunk.end

cHRM


In [19]:
var background_chunk = parse_next_chunk(file_contents, read_head)
print(background_chunk.type)
read_head = background_chunk.end

bKGD


In [20]:
var pixel_size_chunk = parse_next_chunk(file_contents, read_head)
print(pixel_size_chunk.type)
read_head = pixel_size_chunk.end

pHYs
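
The parse-and-advance pattern above is really a loop. For reference, here is how walking over all chunks might look in Python, as a sketch using only the standard library. Both `iter_chunks` and the `make_chunk` helper (used to build test input) are names I made up for this example:

```python
import struct
import zlib

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def make_chunk(chunk_type: bytes, data: bytes) -> bytes:
    """Build one chunk: length, type, data, CRC (hypothetical test helper)."""
    crc = zlib.crc32(chunk_type + data)
    return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc)


def iter_chunks(png: bytes):
    """Yield (type, data) for every chunk, verifying each CRC on the way."""
    pos = len(PNG_SIGNATURE)  # skip the 8-byte file signature
    while pos < len(png):
        (length,) = struct.unpack(">I", png[pos:pos + 4])
        chunk_type = png[pos + 4:pos + 8]
        data = png[pos + 8:pos + 8 + length]
        (crc,) = struct.unpack(">I", png[pos + 8 + length:pos + 12 + length])
        if zlib.crc32(chunk_type + data) != crc:
            raise ValueError("CRC32 does not match")
        yield chunk_type.decode("ascii"), data
        pos += 12 + length  # 4 length + 4 type + data + 4 CRC
        if chunk_type == b"IEND":
            break


if __name__ == "__main__":
    png = PNG_SIGNATURE + make_chunk(b"IHDR", bytes(13)) + make_chunk(b"IEND", b"")
    print([name for name, _ in iter_chunks(png)])
```

A loop like this also makes it easy to collect the data of several IDAT chunks, since it does not need to know the chunk order in advance.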


## The image data chunk

The IDAT chunk (there can actually be several of them per image) contains the actual image data. 



In [21]:
var image_data_chunk = parse_next_chunk(file_contents, read_head)
print(image_data_chunk.type)
read_head = image_data_chunk.end

IDAT


### Decompression

PNGs are compressed (losslessly) with the [DEFLATE](https://en.wikipedia.org/wiki/Deflate) compression algorithm. 

PNGs are first filtered, then compressed, but as we are decoding, we need to first uncompress the data and then undo the filter.

This next section is why I said "pure-ish" Mojo: I considered implementing DEFLATE myself, but that would be quite a lot of work, so I am hoping that either someone else does this, or that I might dig into this in the future. 

So for the moment, I am using the [zlib](https://en.wikipedia.org/wiki/Zlib) version of the algorithm through Mojo's foreign function interface (FFI).

I lightly adapted the following from a thread on the Mojo Discord between Ilya Lubenets and Jack Clayton:

In [22]:
from sys import ffi
alias Bytef = Scalar[DType.int8]
alias uLong = UInt64
alias zlib_type = fn(
    _out: Pointer[Bytef], 
    _out_len: Pointer[UInt64], 
    _in: Pointer[Bytef], 
    _in_len: uLong
) -> Int
fn log_zlib_result(Z_RES: Int, compressing: Bool = True) raises -> NoneType:
    var prefix: String = ''
    if not compressing:
        prefix = "un"

    if Z_RES == 0:
        print('OK ' + prefix.upper() + 'COMPRESSING: Everything ' + prefix + 'compressed fine')
    elif Z_RES == -3:
        raise Error('ERROR ' + prefix.upper() + 'COMPRESSING: Input data is corrupted or incomplete')
    elif Z_RES == -4:
        raise Error('ERROR ' + prefix.upper() + 'COMPRESSING: Not enough memory')
    elif Z_RES == -5:
        raise Error('ERROR ' + prefix.upper() + 'COMPRESSING: Not enough room in the output buffer')
    else:
        raise Error('ERROR ' + prefix.upper() + 'COMPRESSING: Unhandled zlib return code')

fn uncompress(data: List[Int8]) raises -> List[UInt8]:
    var data_memory_amount: Int = len(data)*4
    var handle = ffi.DLHandle('')
    var zlib_uncompress = handle.get_function[zlib_type]('uncompress')

    var uncompressed = Pointer[Bytef].alloc(data_memory_amount)
    var compressed = Pointer[Bytef].alloc(len(data))
    var uncompressed_len = Pointer[uLong].alloc(1)
    memset_zero(uncompressed, data_memory_amount)
    memset_zero(uncompressed_len, 1)
    uncompressed_len[0] = data_memory_amount
    for i in range(len(data)):
        compressed.store(i, data[i])

    var Z_RES = zlib_uncompress(
        uncompressed,
        uncompressed_len,
        compressed,
        len(data),
    )

    log_zlib_result(Z_RES, compressing=False)
    print('Uncompressed length: ' + str(uncompressed_len[0]))
    # Can probably do something more efficient here with pointers, but eh. 
    var res = List[UInt8]()
    for i in range(uncompressed_len[0]):
        res.append(uncompressed[i].to_int())
    return res

Drumroll... let's see if this worked:

In [23]:
uncompressed_data = uncompress(image_data_chunk.data)

OK UNCOMPRESSING: Everything uncompressed fine
Uncompressed length: 49280
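
The Hopper image has a single IDAT chunk, but when there are several, their data fields together form one zlib stream and have to be joined before decompressing. A minimal Python sketch of that step, where `decompress_idat` is a name I made up:

```python
import zlib


def decompress_idat(idat_payloads) -> bytes:
    """Join the data of all IDAT chunks (in file order) and inflate it."""
    # The chunk boundaries carry no meaning for the compressed stream,
    # so the payloads are simply concatenated first.
    return zlib.decompress(b"".join(idat_payloads))


if __name__ == "__main__":
    raw = bytes(range(256)) * 4
    stream = zlib.compress(raw)
    # Split the stream at an arbitrary point to mimic two IDAT chunks.
    parts = [stream[:10], stream[10:]]
    print(decompress_idat(parts) == raw)
```

Python's `zlib.decompress` grows its output buffer as needed, whereas the FFI version above has to guess a size up front (four times the compressed length), which may not be enough for images that compress very well.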


Now we have a list of uncompressed bytes. However, these are not pixel values yet. 
The uncompressed data has a length of 49280 bytes. We know we have an RGB image with 8-bit colour depth, so we expect $$128 \times 128 \times 3 = 49152$$ bytes worth of pixel data. Notice that $$49280 - 49152 = 128$$, and that our image has 128 rows.  
Each row (scanline) is prefixed with one extra byte that tells us which filter was used to transform the byte values in that line of pixels into something that can be compressed efficiently.  
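
The arithmetic generalises to other image shapes. Here is a short Python sketch that computes the expected size of the uncompressed data for a non-interlaced image. The function name `filtered_size` and its parameters are my own:

```python
def filtered_size(width: int, height: int, channels: int, bit_depth: int = 8) -> int:
    """Expected length of the uncompressed IDAT data (non-interlaced only)."""
    # Bytes per scanline, rounded up for bit depths below 8.
    row_bytes = (width * channels * bit_depth + 7) // 8
    # Every scanline is preceded by one filter-type byte.
    return height * (1 + row_bytes)


if __name__ == "__main__":
    print(filtered_size(128, 128, 3))  # the Hopper image: 49280
```

Knowing this number in advance would also be one way to size the output buffer for decompression instead of guessing.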

### Unfilter

The possible filter types specified by the [PNG specification](https://www.w3.org/TR/PNG-Filters.html) are: 

```
   Type    Name
   
   0       None
   1       Sub
   2       Up
   3       Average
   4       Paeth
```

There are some subtle points to pay attention to in the specification, such as the fact that these filters are applied per byte, and not per pixel value. For 8-bit colour depth this is unimportant, but at 16 bits, this means the first byte of a sample (the MSB, or most significant byte) will be computed separately from the second byte (the LSB, or least significant byte). I won't go too deep into all the details here, but you can read the details of the specification [here](https://www.w3.org/TR/PNG-Filters.html). 

I'll briefly explain the basic idea behind each filter: 

* 0: None
   * No filter is applied and each byte value is just the raw pixel value. 
* 1: Sub
   * Each byte has the corresponding byte of the pixel to its left subtracted from it. 
* 2: Up
   * Each byte has the value of the byte above it subtracted from it. 
* 3: Average
   * Each byte has the floor of the average of the bytes above and to the left of it subtracted from it. 
* 4: Paeth
   * The corresponding bytes of the three neighbouring pixels (left, above and upper left) are used to calculate a predicted value that is subtracted from the byte. It's a little more involved than the other three. 

So when decoding the filtered data, we need to reverse the above operations to regain the pixel values. 
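
To make the encode and decode directions concrete, here is a Python sketch of the Sub filter and its inverse on a single row. The names `sub_filter` and `undo_sub` are made up for this example, and `bpp` is the number of bytes per pixel (3 for 8-bit RGB):

```python
def sub_filter(row: bytes, bpp: int) -> bytes:
    """Encode: subtract the byte one pixel to the left, modulo 256."""
    out = bytearray(len(row))
    for i, value in enumerate(row):
        left = row[i - bpp] if i >= bpp else 0
        out[i] = (value - left) & 0xFF
    return bytes(out)


def undo_sub(filtered: bytes, bpp: int) -> bytes:
    """Decode: add back the already reconstructed byte to the left."""
    out = bytearray(len(filtered))
    for i, value in enumerate(filtered):
        left = out[i - bpp] if i >= bpp else 0
        out[i] = (value + left) & 0xFF
    return bytes(out)


if __name__ == "__main__":
    row = bytes([10, 20, 30, 12, 22, 33])  # two RGB pixels
    filtered = sub_filter(row, bpp=3)
    print(list(filtered))                   # [10, 20, 30, 2, 2, 3]
    print(undo_sub(filtered, bpp=3) == row)
```

Neighbouring pixels tend to be similar, so the filtered row is mostly small numbers, which is what makes it compress better. Note that the decoder has to use the reconstructed values to its left, not the filtered ones.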



Now that we know that, let's look at our first byte value: 





In [24]:
print(uncompressed_data[0])

1


So we are dealing with filter type 1 here. 
Let's decode the first row:

In [25]:
var filter_type = uncompressed_data[0]
var scanline = uncompressed_data[1:128*3+1]

# Decoded image data
var result = List[UInt8](capacity=128*3)
# take the first pixels as 0
var left: UInt8 = 0
var pixel_size: Int = 3
var offset: Int = 1

for i in range(len(scanline)):
    if i >= pixel_size:
        left = result[i-pixel_size] 

    # The specification states that the result is taken modulo 256.
    # Similar to the C implementation, we can just add the left pixel to the current pixel,
    # and the result will be modulo 256 due to UInt8 overflow
    result.append((uncompressed_data[i + offset] + left))


And let's confirm that the row we decoded matches what PIL gives us: 

In [26]:
for i in range(128):
    for j in range(3):
        assert_true(result[i*3+j] == py_array[0][i][j].__int__(), "Pixel values do not match")

Now that we have the general idea of things, let's write this more generally, and do the other filters as well. 

For an idea of how filters are chosen, read this Stack Overflow post and the resources it points to: [How do PNG encoders pick which filter to use?](https://stackoverflow.com/questions/59492926/how-do-png-encoders-pick-which-filter-to-use)

I've done these as functions that take 16-bit signed integers. This is important mostly for the case of the Paeth filter, where the standard states: 

> The calculations within the PaethPredictor function must be performed exactly, without overflow. Arithmetic modulo 256 is to be used only for the final step of subtracting the function result from the target byte value.

So basically we need to keep a higher level of precision and then cast back to bytes at the end. 
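
To see why this matters, here is a Python sketch of the Paeth predictor next to a deliberately broken variant that wraps the intermediate estimate to 8 bits. Both function names are mine, and the wrapped variant exists only to illustrate the failure:

```python
def paeth_predictor(left: int, above: int, above_left: int) -> int:
    """Pick the neighbour closest to left + above - above_left (exact math)."""
    estimate = left + above - above_left
    dist_left = abs(estimate - left)
    dist_above = abs(estimate - above)
    dist_above_left = abs(estimate - above_left)
    if dist_left <= dist_above and dist_left <= dist_above_left:
        return left
    if dist_above <= dist_above_left:
        return above
    return above_left


def paeth_predictor_wrapped(left: int, above: int, above_left: int) -> int:
    """Incorrect on purpose: the estimate overflows as an unsigned byte would."""
    estimate = (left + above - above_left) & 0xFF
    dist_left = abs(estimate - left)
    dist_above = abs(estimate - above)
    dist_above_left = abs(estimate - above_left)
    if dist_left <= dist_above and dist_left <= dist_above_left:
        return left
    if dist_above <= dist_above_left:
        return above
    return above_left


if __name__ == "__main__":
    # The exact estimate is 290, which does not fit in a byte.
    print(paeth_predictor(200, 100, 10))          # 200
    print(paeth_predictor_wrapped(200, 100, 10))  # 10
```

With the wrapped estimate the predictor picks a different neighbour, so the decoded byte would be wrong even though every individual operation looks reasonable.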


I based the implementation on the [iPXE](https://ipxe.org/) implementation of a [png decoder](https://dox.ipxe.org/png_8c_source.html) written in C. 

In [27]:
from math import abs

fn undo_trivial(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
    return current

fn undo_sub(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
    return current + left

fn undo_up(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
    return current + above

fn undo_average(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:
    return current + ((above + left) >> 1) # Bitshift is equivalent to division by 2

fn undo_paeth(current: Int16, left: Int16 = 0, above: Int16 = 0, above_left: Int16 = 0) -> Int16:

    var paeth: Int16 = left + above - above_left
    var paeth_a: Int16 = abs(paeth - left)
    var paeth_b: Int16 = abs(paeth - above)
    var paeth_c: Int16 = abs(paeth - above_left)
    if ( paeth_a <= paeth_b ) and ( paeth_a <= paeth_c ):
        return (current + left)
    elif ( paeth_b <= paeth_c ): 
        return (current + above)
    else:
        return (current + above_left)

fn undo_filter(filter_type: UInt8, current: UInt8, left: UInt8 = 0, above: UInt8 = 0, above_left: UInt8 = 0) raises -> UInt8:

    var current_int = current.cast[DType.int16]()
    var left_int = left.cast[DType.int16]()
    var above_int = above.cast[DType.int16]()
    var above_left_int = above_left.cast[DType.int16]()
    var result_int: Int16 = 0

    if filter_type == 0:
        result_int = undo_trivial(current_int, left_int, above_int, above_left_int)
    elif filter_type == 1:
        result_int = undo_sub(current_int, left_int, above_int, above_left_int)
    elif filter_type == 2:
        result_int = undo_up(current_int, left_int, above_int, above_left_int)
    elif filter_type == 3:
        result_int = undo_average(current_int, left_int, above_int, above_left_int)
    elif filter_type == 4:
        result_int = undo_paeth(current_int, left_int, above_int, above_left_int)
    else:
        raise Error("Unknown filter type")
    return result_int.cast[DType.uint8]()


For the `undo_filter` function, I was trying to add the separate filters to some kind of Tuple or List so I could just index them (hence the uniform signatures), but wasn't able to figure out how to do this in Mojo yet. 

So let's apply these to the whole image and confirm that we have the same results as we would get from Python: 

In [28]:

# Decoded image data
# take the first pixels as 0
var pixel_size: Int = 3

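# Initialize the previous scanline to zeros, one entry per byte in a row.
# (Note that List[UInt8](0*128) would create a list holding a single 0.)
var previous_result = List[UInt8](capacity=128*3)
for _ in range(128*3):
    previous_result.append(0)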

for line in range(128):
    var offset =  1 + 1*line + line * 128 * 3
    var left: UInt8 = 0
    var above_left: UInt8 = 0

    #var left: UInt8 = 0
    var result = List[UInt8](capacity=128*3)
    var scanline = uncompressed_data[offset:offset+128*3]

    var filter_type = uncompressed_data[offset - 1]

    for i in range(len(scanline)):
        if i >= pixel_size:
            left = result[i-pixel_size] 
            above_left = previous_result[i-pixel_size] 

        result.append(undo_filter(filter_type, uncompressed_data[i + offset], left, previous_result[i], above_left))


    previous_result = result
    for i in range(128):
        for j in range(3):
            assert_true(result[i*3+j] == py_array[line][i][j].__int__(), "Pixel values do not match")


    

And that's it. If the above runs, it means we've successfully parsed a PNG file, and that we at least get the same data out as we would by using Pillow. 
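
For reference, the same scanline loop can be sketched in Python with all five filter types handled inline. This assumes a non-interlaced image with a bit depth of 8. The function names and the tiny hand-built test image are my own, not taken from any library:

```python
def paeth_predictor(left: int, above: int, above_left: int) -> int:
    """Paeth predictor, computed with exact (non-wrapping) integers."""
    estimate = left + above - above_left
    dist_left = abs(estimate - left)
    dist_above = abs(estimate - above)
    dist_above_left = abs(estimate - above_left)
    if dist_left <= dist_above and dist_left <= dist_above_left:
        return left
    if dist_above <= dist_above_left:
        return above
    return above_left


def unfilter(data: bytes, width: int, height: int, bpp: int) -> list:
    """Undo the per-scanline filters and return one bytes object per row."""
    stride = width * bpp
    previous = bytearray(stride)  # the row above the first row is all zeros
    rows = []
    for line in range(height):
        start = line * (stride + 1)  # +1 for the filter-type byte of each row
        filter_type = data[start]
        filtered = data[start + 1:start + 1 + stride]
        current = bytearray(stride)
        for i, value in enumerate(filtered):
            left = current[i - bpp] if i >= bpp else 0
            above = previous[i]
            above_left = previous[i - bpp] if i >= bpp else 0
            if filter_type == 0:
                predicted = 0
            elif filter_type == 1:
                predicted = left
            elif filter_type == 2:
                predicted = above
            elif filter_type == 3:
                predicted = (left + above) >> 1
            elif filter_type == 4:
                predicted = paeth_predictor(left, above, above_left)
            else:
                raise ValueError("Unknown filter type")
            current[i] = (value + predicted) & 0xFF  # modulo 256 only here
        rows.append(bytes(current))
        previous = current
    return rows


if __name__ == "__main__":
    # A 2x4 single-channel image using Sub, Up, Average and Paeth in turn.
    data = bytes([1, 5, 2, 2, 1, 3, 3, 5, 0, 4, 2, 2])
    print([list(row) for row in unfilter(data, width=2, height=4, bpp=1)])
```

The structure mirrors the Mojo loop: keep the previously decoded row around, and treat anything outside the image as zero.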


## Creating a tensor

Now, ideally we want the above in a Tensor. 

Let's repeat the decoding loop, this time writing the image data into a Tensor. 

In [29]:
from tensor import Tensor, TensorSpec, TensorShape
from utils.index import Index
from random import rand

var height = 128
var width = 128
var channels = 3

# Declare the RGB image tensor with shape (height, width, channels).
var spec = TensorSpec(DType.uint8, height, width, channels)
var tensor_image = Tensor[DType.uint8](spec)

In [30]:
# Decoded image data
# take the first pixels as 0
var pixel_size: Int = 3

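# Initialize the previous scanline to zeros, one entry per byte in a row.
# (Note that List[UInt8](0*128) would create a list holding a single 0.)
var previous_result = List[UInt8](capacity=128*3)
for _ in range(128*3):
    previous_result.append(0)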

for line in range(128):
    var offset =  1 + 1*line + line * 128 * 3
    var left: UInt8 = 0
    var above_left: UInt8 = 0

    #var left: UInt8 = 0
    var result = List[UInt8](capacity=128*3)
    var scanline = uncompressed_data[offset:offset+128*3]

    var filter_type = uncompressed_data[offset - 1]

    for i in range(len(scanline)):
        if i >= pixel_size:
            left = result[i-pixel_size] 
            above_left = previous_result[i-pixel_size] 

        result.append(undo_filter(filter_type, uncompressed_data[i + offset], left, previous_result[i], above_left))


    previous_result = result
    for i in range(128):
        for j in range(3):
            tensor_image[Index(line, i, j)] = result[i*3+j]

I'm not entirely sure why I need to use `Index` while setting items, but when getting I can just provide indices: 

In [31]:
print(tensor_image[0,1,2])
print(py_array[0][1][2])

62
62


And there we have it. I will put it all together soon but let's finish parsing the file quickly. 

## Final chunks
There are a few more chunks at this point: text chunks which hold some comments, and an end chunk, which denotes the end of the file: 

In [32]:
var text_chunk_1 = parse_next_chunk(file_contents, read_head)
print(text_chunk_1.type)
read_head = text_chunk_1.end
print(bytes_to_string(text_chunk_1.data))

tEXt
comment


In [33]:
var text_chunk_2 = parse_next_chunk(file_contents, read_head)
print(text_chunk_2.type)
read_head = text_chunk_2.end
print(bytes_to_string(text_chunk_2.data))

tEXt
date:create


In [34]:
var text_chunk_3 = parse_next_chunk(file_contents, read_head)
print(text_chunk_3.type)
read_head = text_chunk_3.end
print(bytes_to_string(text_chunk_3.data))

tEXt
date:modify


In [35]:
var end_chunk = parse_next_chunk(file_contents, read_head)
print(end_chunk.type)
read_head = end_chunk.end

IEND
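
The output for the text chunks above shows only the keywords. According to the PNG specification, a tEXt chunk holds a keyword, a null byte as separator, and then the text itself, all encoded as Latin-1. My guess is that the printed string simply stops at the null byte. A small Python sketch that splits the two parts, with `parse_text_chunk` being a name I made up and the sample input invented for the example:

```python
def parse_text_chunk(data: bytes):
    """Split tEXt chunk data into (keyword, text)."""
    # The first null byte separates the keyword from the text.
    keyword, separator, text = data.partition(b"\x00")
    if not separator:
        raise ValueError("tEXt chunk has no null separator")
    return keyword.decode("latin-1"), text.decode("latin-1")


if __name__ == "__main__":
    # Invented sample data, not the actual contents of the Hopper image.
    print(parse_text_chunk(b"comment\x00some text"))
```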


## Putting it all together. 