---
layout: post
title: CRC calculations and compile time metaprogramming in Mojo
categories: [mojo]
date: "2024-05-99"
author: "Ferdinand Schenck"
draft: true
description: Speeding up CRC-32 calculations in Mojo with lookup tables built at compile time.
---

In my [last post](https://fnands.com/blog/2024/mojo-png-parsing/) on parsing PNG images in Mojo I very briefly mentioned cyclic redundancy checks, and posted a rather cryptic-looking function which I claimed was a bit inefficient. 

In this post I want to follow up on that a bit, and delve into the compile time metaprogramming side of Mojo to see how we can speed up these calculations.   

But first, let's go through a bit of background so we know what we're dealing with. 

## Cyclic redundancy checks

CRCs are error detecting codes that are often used to detect corruption of data in digital files, an example of which is PNG files. In the case of PNGs for example the CRC32 is calculated for the data of each chunk and appended to the end of the chunk, so that the person reading the file can verify whether the data they read was the same as the data that was written.  

A CRC check technically does "long division in the ring of polynomials of binary coefficients ($\Bbb{F}_2[x]$)" 😳.   

It's not as complicated as it sounds. I found the [Wikipedia article on Polynomial long division](https://en.wikipedia.org/wiki/Polynomial_long_division) to be helpful, and if you want an in-depth explanation then
[this post](https://github.com/komrad36/CRC) by [Kareem Omar](https://github.com/komrad36) does a really great job of explaining both the concept and implementation considerations. 

But what you need to know is that, for binary numbers, the subtraction step of polynomial long division (over a finite field) is just XOR, and XOR is a very efficient operation to calculate in hardware. 

The simplest example of a cyclic redundancy check is the [parity bit](https://en.wikipedia.org/wiki/Parity_bit), AKA CRC-1. The parity bit is used to detect whether an error has occurred while transmitting a byte-long message (it can be used for longer messages, but probably shouldn't be). 

In the formalism of CRC checks, it can be calculated by successively applying XOR between your message and the relevant *generator polynomial*. For larger cases the choice of generator polynomial can get quite involved, but for the CRC-1 case it is $x + 1$, expressed in binary as 11. Notice that the generator polynomial is always one order higher (i.e. it has one more bit) than the CRC. The way it is applied is by bitshifting the generator along the message and XORing it in wherever it lines up with a leading 1 bit, until only the remainder is left:

```
1+0+0+1 (mod 2) = 0
1+0+1+1 (mod 2) = 1

1001/1100 = 0101
0101/0110 = 0011
0011/0011 = 0000

1011/1100 = 0111
0111/0110 = 0001
```


In [1]:
from math.bit import bitreverse
import benchmark

fn CRC32(owned data: List[SIMD[DType.uint8, 1]]) -> SIMD[DType.uint32, 1]:
    """Compute a CRC-32 checksum one bit at a time, shifting left.

    Every input byte is bit-reversed and XORed into the top byte of the
    32-bit state, after which eight shift steps are applied using the
    polynomial 0x04c11db7.

    Args:
        data: The bytes to checksum.

    Returns:
        The final state XORed with 0xffffffff and then bit-reversed.
    """
    alias polynomial: UInt32 = 0x04c11db7
    alias top_bit: UInt32 = 0x80000000

    var state: UInt32 = 0xffffffff
    for item in data:
        var reflected = bitreverse(item[]).cast[DType.uint32]()
        state = state ^ (reflected << 24)
        for _ in range(8):
            # Remember whether a 1 is about to be shifted out of the top.
            var carry = (state & top_bit) != 0
            state = state << 1
            if carry:
                state = state ^ polynomial

    return bitreverse(state ^ 0xffffffff)

In [2]:
fn CRC32_inv(owned data: List[SIMD[DType.uint8, 1]]) -> SIMD[DType.uint32, 1]:
    """Compute a CRC-32 checksum one bit at a time, shifting right.

    Each input byte is XORed into the low byte of the 32-bit state, after
    which eight shift steps are applied using the polynomial 0xedb88320.
    No bit reversal is performed.

    Args:
        data: The bytes to checksum.

    Returns:
        The final state XORed with 0xffffffff.
    """
    alias reflected_polynomial: UInt32 = 0xedb88320

    var state: UInt32 = 0xffffffff
    for item in data:
        state = state ^ item[].cast[DType.uint32]()
        for _ in range(8):
            # Remember whether a 1 is about to be shifted out of the bottom.
            var low_bit_set = (state & 1) != 0
            state = state >> 1
            if low_bit_set:
                state = state ^ reflected_polynomial

    return state ^ 0xffffffff

In [3]:
# Seven sample bytes that are fed to the CRC functions in the following cells.
var test_list = List[SIMD[DType.uint8, 1]](5, 78, 138, 1, 54, 17, 104)

In [4]:
# Print the checksum from both bitwise implementations so they can be compared.
print(hex(CRC32(test_list)))
print(hex(CRC32_inv(test_list)))

0x89ba07cb
0x89ba07cb


In [5]:

from time import sleep




In [6]:
from random import rand

fn run_32[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32` on the `data` parameter and keep the result."""
    var checksum = CRC32(data)
    benchmark.keep(checksum)


fn run_32_inv[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32_inv` on the `data` parameter and keep the result."""
    var checksum = CRC32_inv(data)
    benchmark.keep(checksum)


fn bench():
    """Time `CRC32` against `CRC32_inv` on 2**16 random bytes and print the results.

    Prints the list length, the mean run time of each variant in
    milliseconds, and the percentage by which the first mean exceeds the
    second.
    """

    
    alias fill_size = 2**16
    # NOTE(review): `g` and `rand_list` are declared with `alias`, while the
    # `rand` fill below is an ordinary call — confirm how the two interact.
    alias g = UnsafePointer[SIMD[DType.uint8, 1]].alloc(fill_size)
    rand[DType.uint8](ptr =  g, size = fill_size)


    # Wrap the raw buffer in a List so it can be passed as a parameter below.
    alias rand_list = List[SIMD[DType.uint8,1]](data = g, size = fill_size, capacity = fill_size)


    print(len(rand_list))

    # Mean time in milliseconds for the left-shifting implementation.
    var report = benchmark.run[run_32[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report)

    # Mean time in milliseconds for the right-shifting implementation.
    var report_2 = benchmark.run[run_32_inv[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_2)

    # Percentage by which `report` exceeds `report_2`.
    print(100 * (report/report_2 -1))
    #report.print_full()
    #report_2.print_full()


bench()

65536
0.43777475492831541
0.41176855658277695
6.3157319639365106


In [7]:
# Lookup table for the right-shifting CRC-32: entry `i` is the result of
# running the eight bitwise shift steps (polynomial 0xedb88320) on byte `i`.
var little_endian_table = List[UInt32](capacity=256)

for i in range(256):

    var key = UInt8(i)
    var crc32 = key.cast[DType.uint32]()
    for _ in range(8):
        if crc32 & 1 != 0:
            crc32 = (crc32 >> 1) ^ 0xedb88320
        else:
            crc32 = crc32 >> 1

    # `capacity` only reserves storage and leaves the list empty, so the
    # entries are appended (in index order) rather than assigned by index.
    little_endian_table.append(crc32)

fn CRC32_table(owned data: List[SIMD[DType.uint8, 1]], table: List[UInt32]) -> SIMD[DType.uint32, 1]:
    """Compute a CRC-32 checksum one byte at a time using a lookup table.

    Args:
        data: The bytes to checksum.
        table: Lookup table indexed by the low byte of the state XORed with
            the incoming byte (indices 0 to 255).

    Returns:
        The final state XORed with 0xffffffff.
    """
    var state: UInt32 = 0xffffffff
    for item in data:
        var low_byte = (state ^ item[].cast[DType.uint32]()) & 0xff
        var shifted = state >> 8
        state = table[int(low_byte)] ^ shifted

    return state ^ 0xffffffff

In [8]:
# Print the checksum from the two bitwise versions and the table-driven one.
print(hex(CRC32(test_list)))
print(hex(CRC32_inv(test_list)))
print(hex(CRC32_table(test_list, little_endian_table)))

0x89ba07cb
0x89ba07cb
0x89ba07cb


In [9]:
from random import rand

fn run_32[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32` on the `data` parameter and keep the result."""
    var checksum = CRC32(data)
    benchmark.keep(checksum)


fn run_32_inv[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32_inv` on the `data` parameter and keep the result."""
    var checksum = CRC32_inv(data)
    benchmark.keep(checksum)


fn run_32_table[data: List[SIMD[DType.uint8, 1]], table: List[UInt32]]():
    """Benchmark target: run `CRC32_table` on the `data` and `table` parameters and keep the result."""
    var checksum = CRC32_table(data, table)
    benchmark.keep(checksum)


fn fill_table() -> List[UInt32]:
    """Build the 256-entry lookup table for the right-shifting CRC-32.

    Entry `i` is the result of running the eight bitwise shift steps
    (polynomial 0xedb88320) on the byte value `i`.

    Returns:
        A list of length 256.
    """

    var table = List[UInt32](capacity=256)

    for i in range(256):

        var key = UInt8(i)
        var crc32 = key.cast[DType.uint32]()
        for _ in range(8):
            if crc32 & 1 != 0:
                crc32 = (crc32 >> 1) ^ 0xedb88320
            else:
                crc32 = crc32 >> 1

        # `capacity` only reserves storage and leaves the list empty, so the
        # entries are appended (in index order) rather than assigned by index.
        table.append(crc32)
    return table

fn bench():
    """Time `CRC32`, `CRC32_inv` and `CRC32_table` on 2**16 random bytes.

    Prints the list length, the mean run time of each variant in
    milliseconds, and three percentage comparisons between those means.
    """

    
    alias fill_size = 2**16
    # NOTE(review): `g` and `rand_list` are declared with `alias`, while the
    # `rand` fill below is an ordinary call — confirm how the two interact.
    alias g = UnsafePointer[SIMD[DType.uint8, 1]].alloc(fill_size)
    rand[DType.uint8](ptr =  g, size = fill_size)


    # Wrap the raw buffer in a List so it can be passed as a parameter below.
    alias rand_list = List[SIMD[DType.uint8,1]](data = g, size = fill_size, capacity = fill_size)


    print(len(rand_list))

    # Mean time in milliseconds for the left-shifting implementation.
    var report = benchmark.run[run_32[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report)

    # Mean time in milliseconds for the right-shifting implementation.
    var report_2 = benchmark.run[run_32_inv[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_2)

 
    
    # The lookup table is bound with `alias` and passed as a parameter.
    alias little_endian_table = fill_table()

    # Mean time in milliseconds for the one-byte table implementation.
    var report_3 = benchmark.run[run_32_table[rand_list, little_endian_table]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_3)

    # Percentages by which the first mean in each pair exceeds the second.
    print(100 * (report/report_2 -1))
    print(100 * (report_2/report_3 -1))
    print(100 * (report/report_3 -1))
    #report.print_full()
    #report_2.print_full()


bench()

65536
0.4380599642356241
0.41109755614424881
0.11433471179999999
6.5586398382564193
259.55620972164712
283.13820653337592


In [10]:
# Lookup table for processing two bytes per step. Entries 0-255 are the
# one-byte table; entries 256-511 are derived from them by advancing each
# one-byte entry through one more table step.
var little_endian_table_2_byte = List[UInt32](capacity=512)

for i in range(256):

    var key = UInt8(i)
    var crc32 = key.cast[DType.uint32]()
    for _ in range(8):
        if crc32 & 1 != 0:
            crc32 = (crc32 >> 1) ^ 0xedb88320
        else:
            crc32 = crc32 >> 1

    # `capacity` only reserves storage and leaves the list empty, so the
    # entries are appended (in index order) rather than assigned by index.
    little_endian_table_2_byte.append(crc32)

for i in range(256, 512):
    var crc32 = little_endian_table_2_byte[i-256]
    var next_entry = (crc32 >> 8) ^ little_endian_table_2_byte[int(crc32.cast[DType.uint8]())]
    little_endian_table_2_byte.append(next_entry)






In [11]:
from testing import assert_true


fn CRC32_table_2_byte(owned data: List[SIMD[DType.uint8, 1]], table: List[UInt32]) -> SIMD[DType.uint32, 1]:
    """Compute a CRC-32 checksum two bytes at a time by viewing the input as UInt16 values.

    Args:
        data: The bytes to checksum. Its buffer is taken over via
            `steal_data` and reinterpreted as 16-bit values.
        table: Lookup table; indices 0-255 and 256-511 are both read.

    Returns:
        The final state XORed with 0xffffffff.
    """
    var crc32: UInt32 = 0xffffffff

    #assert_true(len(data) % 2 == 0, "List must be divisible by two for 16-bit optimization.")

    # For an odd-length input, copy the last byte aside before the buffer is
    # taken over below; it is processed on its own at the end.
    var extra = len(data) % 2
    var leftover = List[SIMD[DType.uint8, 1]](capacity = extra)
    for i in range(extra):
        leftover.append(data[-(i + 1)])

    # Number of whole 16-bit values in the input.
    var result_length = len(data)//2
    var ptr_to_int8 = data.steal_data() 
    var ptr_to_uint16 = ptr_to_int8.bitcast[UInt16]()

    # Build a List[UInt16] over the same memory by setting its fields directly.
    # NOTE(review): presumably `result` now owns the buffer and releases it — confirm.
    var result = List[UInt16]()
    result.data = ptr_to_uint16
    result.capacity = result_length
    result.size = result_length

    # NOTE(review): each 16-bit value is read in native byte order; this looks
    # like it assumes a little-endian host — confirm.
    for byte in result:
        var index = (crc32 ^ byte[].cast[DType.uint32]()) #& 0xff
        # Bits 8-15 index the first 256 entries, bits 0-7 the second 256.
        crc32 =  table[int((index >> 8).cast[DType.uint8]())] ^ table[256 + int(index.cast[DType.uint8]())] ^ (crc32 >> 16)
    
    # Fold in the trailing byte (if any) with the one-byte table step.
    for byte in leftover:
        var index = (crc32 ^ byte[].cast[DType.uint32]()) & 0xff
        crc32 = table[int(index)] ^ (crc32 >> 8)


    return crc32^0xffffffff

In [12]:
# Print the checksum from the two bitwise versions and the two-byte table version.
print(hex(CRC32(test_list)))
print(hex(CRC32_inv(test_list)))
print(hex(CRC32_table_2_byte(test_list, little_endian_table_2_byte)))

0x89ba07cb
0x89ba07cb
0x89ba07cb


In [13]:
# Scratch check: combine two 0xff bytes into a single 16-bit mask and print it.
var f: UInt32 = (0xff << 8) | 0xff
print(hex(f))

0xffff


In [14]:
fn run_32[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32` on the `data` parameter and keep the result."""
    var checksum = CRC32(data)
    benchmark.keep(checksum)


fn run_32_inv[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32_inv` on the `data` parameter and keep the result."""
    var checksum = CRC32_inv(data)
    benchmark.keep(checksum)


fn run_32_table[data: List[SIMD[DType.uint8, 1]], table: List[UInt32]]():
    """Benchmark target: run `CRC32_table` on the `data` and `table` parameters and keep the result."""
    var checksum = CRC32_table(data, table)
    benchmark.keep(checksum)


fn run_32_table_2_byte[data: List[SIMD[DType.uint8, 1]], table: List[UInt32]]():
    """Benchmark target: run `CRC32_table_2_byte` on the `data` and `table` parameters and keep the result."""
    var checksum = CRC32_table_2_byte(data, table)
    benchmark.keep(checksum)


fn fill_table() -> List[UInt32]:
    """Build the 256-entry lookup table for the right-shifting CRC-32.

    Entry `i` is the result of running the eight bitwise shift steps
    (polynomial 0xedb88320) on the byte value `i`.

    Returns:
        A list of length 256.
    """

    var table = List[UInt32](capacity=256)

    for i in range(256):

        var key = UInt8(i)
        var crc32 = key.cast[DType.uint32]()
        for _ in range(8):
            if crc32 & 1 != 0:
                crc32 = (crc32 >> 1) ^ 0xedb88320
            else:
                crc32 = crc32 >> 1

        # `capacity` only reserves storage and leaves the list empty, so the
        # entries are appended (in index order) rather than assigned by index.
        table.append(crc32)
    return table

fn fill_table_2_byte() -> List[UInt32]:
    """Build the 512-entry lookup table used to process two bytes per step.

    Entries 0-255 are the one-byte table (eight bitwise shift steps with
    polynomial 0xedb88320 applied to each byte value). Entries 256-511 are
    derived from them by advancing each one-byte entry through one more
    table step.

    Returns:
        A list of length 512.
    """

    var table = List[UInt32](capacity=512)

    for i in range(256):

        var key = UInt8(i)
        var crc32 = key.cast[DType.uint32]()
        for _ in range(8):
            if crc32 & 1 != 0:
                crc32 = (crc32 >> 1) ^ 0xedb88320
            else:
                crc32 = crc32 >> 1

        # `capacity` only reserves storage and leaves the list empty, so the
        # entries are appended (in index order) rather than assigned by index.
        table.append(crc32)

    for i in range(256, 512):
        var crc32 = table[i-256]
        var next_entry = (crc32 >> 8) ^ table[int(crc32.cast[DType.uint8]())]
        table.append(next_entry)
    return table


fn bench():
    """Time the bitwise, one-byte-table and two-byte-table CRC-32 variants on 2**16 random bytes.

    Prints the list length, the mean run time of each variant in
    milliseconds, and four percentage comparisons between those means.
    """

    
    alias fill_size = 2**16
    # NOTE(review): `g` and `rand_list` are declared with `alias`, while the
    # `rand` fill below is an ordinary call — confirm how the two interact.
    alias g = UnsafePointer[SIMD[DType.uint8, 1]].alloc(fill_size)
    rand[DType.uint8](ptr =  g, size = fill_size)


    # Wrap the raw buffer in a List so it can be passed as a parameter below.
    alias rand_list = List[SIMD[DType.uint8,1]](data = g, size = fill_size, capacity = fill_size)


    print(len(rand_list))

    # Mean time in milliseconds for the left-shifting implementation.
    var report = benchmark.run[run_32[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report)

    # Mean time in milliseconds for the right-shifting implementation.
    var report_2 = benchmark.run[run_32_inv[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_2)

 
    
    alias little_endian_table = fill_table()

    # Mean time in milliseconds for the one-byte table implementation.
    var report_3 = benchmark.run[run_32_table[rand_list, little_endian_table]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_3)

    # The two-byte routine reads indices 256-511, so it needs the 512-entry
    # table from `fill_table_2_byte`, not the 256-entry one from `fill_table`.
    alias little_endian_table_2_byte = fill_table_2_byte()

    # Mean time in milliseconds for the two-byte table implementation.
    var report_4 = benchmark.run[run_32_table_2_byte[rand_list, little_endian_table_2_byte]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_4)

    # Percentages by which the first mean in each pair exceeds the second.
    print(100 * (report/report_2 -1))
    print(100 * (report_2/report_3 -1))
    print(100 * (report/report_3 -1))
    print(100 * (report/report_4 -1))
    #report.print_full()
    #report_2.print_full()


bench()

65536
0.43897641339648175
0.41363761281480216
0.11423727145
0.079083381461675567
6.125845376886585
262.08639051383977
284.26724292746729
455.0804799731701


In [15]:
fn CRC32_table_2_byte_2(owned data: List[SIMD[DType.uint8, 1]], table: List[UInt32]) -> SIMD[DType.uint32, 1]:
    """Compute a CRC-32 checksum two bytes at a time by indexing the byte list directly.

    Args:
        data: The bytes to checksum.
        table: Lookup table; indices 0-255 and 256-511 are both read.

    Returns:
        The final state XORed with 0xffffffff.
    """
    var state: UInt32 = 0xffffffff

    var total = len(data)
    var paired_end = (total // 2) * 2

    for offset in range(0, paired_end, 2):
        # Combine two consecutive bytes, the first one in the low position.
        var low = data[offset].cast[DType.uint32]()
        var high = data[offset + 1].cast[DType.uint32]()
        var mixed = state ^ ((high << 8) | low)

        var from_upper = table[int((mixed >> 8).cast[DType.uint8]())]
        var from_lower = table[256 + int(mixed.cast[DType.uint8]())]
        state = from_upper ^ from_lower ^ (state >> 16)

    # At most one byte remains; fold it in with the one-byte table step.
    for offset in range(paired_end, total):
        var low_byte = (state ^ data[offset].cast[DType.uint32]()) & 0xff
        state = table[int(low_byte)] ^ (state >> 8)

    return state ^ 0xffffffff

In [16]:
# Print the checksum from the bitwise versions and both two-byte table versions.
print(hex(CRC32(test_list)))
print(hex(CRC32_inv(test_list)))
print(hex(CRC32_table_2_byte(test_list, little_endian_table_2_byte)))
print(hex(CRC32_table_2_byte_2(test_list, little_endian_table_2_byte)))

0x89ba07cb
0x89ba07cb
0x89ba07cb
0x89ba07cb


In [17]:
fn run_32[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32` on the `data` parameter and keep the result."""
    var checksum = CRC32(data)
    benchmark.keep(checksum)


fn run_32_inv[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32_inv` on the `data` parameter and keep the result."""
    var checksum = CRC32_inv(data)
    benchmark.keep(checksum)


fn run_32_table[data: List[SIMD[DType.uint8, 1]], table: List[UInt32]]():
    """Benchmark target: run `CRC32_table` on the `data` and `table` parameters and keep the result."""
    var checksum = CRC32_table(data, table)
    benchmark.keep(checksum)


fn run_32_table_2_byte_2[data: List[SIMD[DType.uint8, 1]], table: List[UInt32]]():
    """Benchmark target: run `CRC32_table_2_byte_2` on the `data` and `table` parameters and keep the result."""
    var checksum = CRC32_table_2_byte_2(data, table)
    benchmark.keep(checksum)


fn fill_table() -> List[UInt32]:
    """Build the 256-entry lookup table for the right-shifting CRC-32.

    Entry `i` is the result of running the eight bitwise shift steps
    (polynomial 0xedb88320) on the byte value `i`.

    Returns:
        A list of length 256.
    """

    var table = List[UInt32](capacity=256)

    for i in range(256):

        var key = UInt8(i)
        var crc32 = key.cast[DType.uint32]()
        for _ in range(8):
            if crc32 & 1 != 0:
                crc32 = (crc32 >> 1) ^ 0xedb88320
            else:
                crc32 = crc32 >> 1

        # `capacity` only reserves storage and leaves the list empty, so the
        # entries are appended (in index order) rather than assigned by index.
        table.append(crc32)
    return table

fn fill_table_2_byte() -> List[UInt32]:
    """Build the 512-entry lookup table used to process two bytes per step.

    Entries 0-255 come from eight bitwise shift steps (polynomial
    0xedb88320) on each byte value; entries 256-511 are derived from the
    entry 256 places earlier by one further table step.

    Returns:
        A list whose size is set to 512.
    """
    alias entries = 512

    var table = List[UInt32](capacity=entries)
    table.size = entries

    for slot in range(entries):
        if slot >= 256:
            # Advance the entry 256 places earlier by one more byte.
            var previous = table[slot - 256]
            table[slot] = (previous >> 8) ^ table[int(previous.cast[DType.uint8]())]
            continue

        var remainder = UInt8(slot).cast[DType.uint32]()
        for _ in range(8):
            var low_bit_set = (remainder & 1) != 0
            remainder = remainder >> 1
            if low_bit_set:
                remainder = remainder ^ 0xedb88320
        table[slot] = remainder

    return table


fn bench():
    """Time the bitwise, one-byte-table and two-byte-table CRC-32 variants on 2**16 random bytes.

    Prints the list length, the mean run time of each variant in
    milliseconds, and four percentage comparisons between those means.
    """

    
    alias fill_size = 2**16
    # NOTE(review): `g` and `rand_list` are declared with `alias`, while the
    # `rand` fill below is an ordinary call — confirm how the two interact.
    alias g = UnsafePointer[SIMD[DType.uint8, 1]].alloc(fill_size)
    rand[DType.uint8](ptr =  g, size = fill_size)


    # Wrap the raw buffer in a List so it can be passed as a parameter below.
    alias rand_list = List[SIMD[DType.uint8,1]](data = g, size = fill_size, capacity = fill_size)


    print(len(rand_list))

    # Mean time in milliseconds for the left-shifting implementation.
    var report = benchmark.run[run_32[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report)

    # Mean time in milliseconds for the right-shifting implementation.
    var report_2 = benchmark.run[run_32_inv[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_2)

 
    
    alias little_endian_table = fill_table()

    # Mean time in milliseconds for the one-byte table implementation.
    var report_3 = benchmark.run[run_32_table[rand_list, little_endian_table]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_3)

    alias little_endian_table_2_byte = fill_table_2_byte()

    # Mean time in milliseconds for the index-based two-byte table implementation.
    var report_4 = benchmark.run[run_32_table_2_byte_2[rand_list, little_endian_table_2_byte]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_4)

    # Percentages by which the first mean in each pair exceeds the second.
    print(100 * (report/report_2 -1))
    print(100 * (report_2/report_3 -1))
    print(100 * (report/report_3 -1))
    print(100 * (report/report_4 -1))
    #report.print_full()
    #report_2.print_full()


bench()

65536
0.43900643368373965
0.41288061071614351
0.11452483935
0.079099824589347931
6.3276943236159067
260.5162103343684
283.3285741114114
455.00304325941448


In [18]:
fn fill_table_n_byte[n: Int]() -> List[UInt32]:
    """Build a lookup table of `256 * n` entries for processing `n` bytes per step.

    Entries 0-255 come from eight bitwise shift steps (polynomial
    0xedb88320) on each byte value; every later entry is derived from the
    entry 256 places earlier by one further table step.

    Parameters:
        n: Number of 256-entry sections to generate.

    Returns:
        A list whose size is set to `256 * n`.
    """
    alias entries = 256 * n

    var table = List[UInt32](capacity=entries)
    table.size = entries

    for slot in range(entries):
        if slot >= 256:
            # Advance the entry 256 places earlier by one more byte.
            var previous = table[slot - 256]
            var low_byte = int(previous.cast[DType.uint8]())
            table[slot] = (previous >> 8) ^ table[low_byte]
            continue

        var remainder = UInt8(slot).cast[DType.uint32]()
        for _ in range(8):
            var low_bit_set = (remainder & 1) != 0
            remainder = remainder >> 1
            if low_bit_set:
                remainder = remainder ^ 0xedb88320
        table[slot] = remainder

    return table

In [19]:
# Bind the two-section table to an `alias` to try out the parametrised builder.
#var t = fill_table_n_byte[1]()
alias t2 = fill_table_n_byte[2]()

In [20]:
fn CRC32_table_4_byte(owned data: List[SIMD[DType.uint8, 1]], table: List[UInt32]) -> SIMD[DType.uint32, 1]:
    """Compute a CRC-32 checksum four bytes at a time using a four-section lookup table.

    Args:
        data: The bytes to checksum.
        table: Lookup table; indices 0-1023 are read in four sections of 256.

    Returns:
        The final state XORed with 0xffffffff.
    """
    alias width = 4

    var state: UInt32 = 0xffffffff
    var total = len(data)
    var aligned_end = (total // width) * width

    for offset in range(0, aligned_end, width):
        # Combine four consecutive bytes, the first one in the lowest position.
        var word: UInt32 = (
            data[offset].cast[DType.uint32]()
            | (data[offset + 1].cast[DType.uint32]() << 8)
            | (data[offset + 2].cast[DType.uint32]() << 16)
            | (data[offset + 3].cast[DType.uint32]() << 24)
        )
        var mixed = state ^ word

        # The highest byte indexes the first section, the lowest the last.
        state = (
            table[int((mixed >> 24).cast[DType.uint8]())]
            ^ table[256 + int((mixed >> 16).cast[DType.uint8]())]
            ^ table[512 + int((mixed >> 8).cast[DType.uint8]())]
            ^ table[768 + int(mixed.cast[DType.uint8]())]
        )

    # Up to three bytes remain; fold each in with the one-byte table step.
    for offset in range(aligned_end, total):
        var low_byte = (state ^ data[offset].cast[DType.uint32]()) & 0xff
        state = table[int(low_byte)] ^ (state >> 8)

    return state ^ 0xffffffff

In [21]:



# Four-section (1024-entry) table passed to CRC32_table_4_byte in the next cell.
var little_endian_table_4_byte  = fill_table_n_byte[4]()

In [22]:
# Print the checksum from every implementation so far so they can be compared.
print(hex(CRC32(test_list)))
print(hex(CRC32_inv(test_list)))
print(hex(CRC32_table_2_byte(test_list, little_endian_table_2_byte)))
print(hex(CRC32_table_2_byte_2(test_list, little_endian_table_2_byte)))
print(hex(CRC32_table_4_byte(test_list, little_endian_table_4_byte)))

0x89ba07cb
0x89ba07cb
0x89ba07cb
0x89ba07cb
0x89ba07cb


In [23]:
fn run_32[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32` on the `data` parameter and keep the result."""
    var checksum = CRC32(data)
    benchmark.keep(checksum)


fn run_32_inv[data: List[SIMD[DType.uint8, 1]] ]():
    """Benchmark target: run `CRC32_inv` on the `data` parameter and keep the result."""
    var checksum = CRC32_inv(data)
    benchmark.keep(checksum)


fn run_32_table[data: List[SIMD[DType.uint8, 1]], table: List[UInt32]]():
    """Benchmark target: run `CRC32_table` on the `data` and `table` parameters and keep the result."""
    var checksum = CRC32_table(data, table)
    benchmark.keep(checksum)


fn run_32_table_2_byte_2[data: List[SIMD[DType.uint8, 1]], table: List[UInt32]]():
    """Benchmark target: run `CRC32_table_2_byte_2` on the `data` and `table` parameters and keep the result."""
    var checksum = CRC32_table_2_byte_2(data, table)
    benchmark.keep(checksum)

fn run_32_table_4_byte[data: List[SIMD[DType.uint8, 1]], table: List[UInt32]]():
    """Benchmark target: run `CRC32_table_4_byte` on the `data` and `table` parameters and keep the result."""
    var checksum = CRC32_table_4_byte(data, table)
    benchmark.keep(checksum)



fn bench():
    """Time the bitwise and the one-, two- and four-byte table CRC-32 variants on 2**16 random bytes.

    Prints the list length, the mean run time of each variant in
    milliseconds, and five percentage comparisons between those means.
    """

    
    alias fill_size = 2**16
    # NOTE(review): `g` and `rand_list` are declared with `alias`, while the
    # `rand` fill below is an ordinary call — confirm how the two interact.
    alias g = UnsafePointer[SIMD[DType.uint8, 1]].alloc(fill_size)
    rand[DType.uint8](ptr =  g, size = fill_size)


    # Wrap the raw buffer in a List so it can be passed as a parameter below.
    alias rand_list = List[SIMD[DType.uint8,1]](data = g, size = fill_size, capacity = fill_size)


    print(len(rand_list))

    # Mean time in milliseconds for the left-shifting implementation.
    var report = benchmark.run[run_32[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report)

    # Mean time in milliseconds for the right-shifting implementation.
    var report_2 = benchmark.run[run_32_inv[rand_list]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_2)

 
    
    # Each table below is built by the same parametrised builder and bound with `alias`.
    alias little_endian_table = fill_table_n_byte[1]()

    # Mean time in milliseconds for the one-byte table implementation.
    var report_3 = benchmark.run[run_32_table[rand_list, little_endian_table]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_3)

    alias little_endian_table_2_byte = fill_table_n_byte[2]()

    # Mean time in milliseconds for the index-based two-byte table implementation.
    var report_4 = benchmark.run[run_32_table_2_byte_2[rand_list, little_endian_table_2_byte]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_4)

    alias little_endian_table_4_byte = fill_table_n_byte[4]()

    # Mean time in milliseconds for the four-byte table implementation.
    var report_5 = benchmark.run[run_32_table_4_byte[rand_list, little_endian_table_4_byte]](max_runtime_secs=5
    ).mean(benchmark.Unit.ms)
    print(report_5)

    # Percentages by which the first mean in each pair exceeds the second.
    print(100 * (report/report_2 -1))
    print(100 * (report_2/report_3 -1))
    print(100 * (report/report_3 -1))
    print(100 * (report/report_4 -1))
    print(100 * (report/report_5 -1))
    #report.print_full()
    #report_2.print_full()


bench()

65536
0.43836233177172057
0.41147309900820794
0.11464432815
0.079255872110735801
0.041411089979022667
6.5348701599995174
258.91273964276616
282.36722116611799
453.09760664704254
958.56265071452776
