In [None]:
from pynq import Overlay, allocate
from time import perf_counter_ns
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv
from typing import Final
from enum import IntEnum

In [None]:
# Which bitstream build to load; the overlay file is named Version<N>.bit.
VERSION: Final[int] = 4

overlay = Overlay(f'Version{VERSION}.bit')
# Handle to the filter IP inside the overlay, used for all register reads/writes below.
lif_ip = overlay.LinearImageFilter

In [None]:
def get_median_filter(size: int) -> np.ndarray:
    """Return a flattened ``size`` x ``size`` normalised box (mean) kernel.

    NOTE: despite its name, this builds an *averaging* kernel in which every
    tap equals ``1 / size**2``. A true median filter is non-linear and cannot
    be expressed as a convolution kernel. The name is kept so existing
    callers keep working.

    Args:
        size: Side length of the square kernel. Must be a positive odd
            integer so the kernel has a well-defined centre tap.

    Returns:
        A 1-D ``float32`` array of ``size * size`` equal weights summing to ~1.

    Raises:
        ValueError: If ``size`` is not positive, or if it is even.
    """
    # Reject non-positive sizes up front: e.g. -3 is odd and would otherwise
    # slip past the parity check and fail inside np.ones with an unrelated message.
    if size < 1:
        raise ValueError(F'Kernel size must be positive. {size} is not.')

    if size % 2 == 0:
        raise ValueError(F'Kernel size must be odd. {size} is even.')

    kernel = np.ones((size, size), dtype=np.float32) / (size * size)
    return kernel.flatten()

In [None]:
class Padding(IntEnum):
    """Border-handling modes whose value is written to the PADDING register.

    Each member is a distinct single bit.
    NOTE(review): presumably the hardware decodes the register as one-hot
    flags; confirm against the IP's HLS source.
    """

    EDGE = 1 << 0
    REFLECT = 1 << 1
    ZEROS = 1 << 2

In [None]:
# Byte offsets into the LinearImageFilter register space, listed in address order.
# NOTE(review): values presumably come from the HLS-generated driver header for
# this bitstream build; re-check them whenever VERSION changes.

CONTROL: Final[int] = 0x00      # start/status handshake, written and polled when running the IP

IMAGE_OUT: Final[int] = 0x10    # physical address of the output image buffer
IMAGE_IN: Final[int] = 0x18     # physical address of the input image buffer
ROWS: Final[int] = 0x20         # image height
COLS: Final[int] = 0x28         # image width
KERNEL: Final[int] = 0x30       # physical address of the kernel-weights buffer
KERNEL_SIZE: Final[int] = 0x38  # kernel side length
STRIDE_ROW: Final[int] = 0x40
STRIDE_COL: Final[int] = 0x48
PADDING: Final[int] = 0x50      # receives a Padding member's value

In [None]:
image_path: Final[str] = 'lena.tif'

# cv.imread reports a missing/unreadable file by returning None rather than raising;
# fail here with a clear message instead of deep inside cvtColor.
bgr_image: Final = cv.imread(image_path)
if bgr_image is None:
    raise FileNotFoundError(F'Could not read image {image_path!r}')

image: Final = cv.cvtColor(bgr_image, cv.COLOR_BGR2RGB)

In [None]:
grey_image: Final = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
# Height and width of the single-channel frame; tuple unpacking also insists
# the shape is exactly 2-D. Both are later written to the ROWS/COLS registers.
rows, cols = np.shape(grey_image)

In [None]:
# Map 8-bit intensities to float32 in [0, 1], flattened in row-major (C) order to
# match the reshape((rows, cols)) applied to the accelerator's output.
flat_image: Final = grey_image.astype(np.float32).ravel() / np.float32(255.0)

In [None]:
KERNEL_DIM: Final[int] = 5  # side length of the square kernel; get_median_filter requires it to be odd

median_kernel: Final[np.ndarray] = get_median_filter(size=KERNEL_DIM)

In [None]:
# Contiguous float32 buffers whose physical addresses are handed to the IP.
# cacheable=0 keeps the CPU view coherent without explicit flush/invalidate calls.
image_in: Final = allocate(flat_image.shape, dtype=np.float32, cacheable=0)
image_out: Final = allocate(flat_image.shape, dtype=np.float32, cacheable=0)
kernel: Final = allocate(median_kernel.shape, dtype=np.float32, cacheable=0)

In [None]:
# Stage the inputs in the device-visible buffers (allocated non-cacheable).
image_in[:] = flat_image
kernel[:] = median_kernel

# Argument registers, written in this exact order: buffer addresses,
# then image geometry, kernel size, strides and the padding mode.
register_settings = (
    (IMAGE_IN, image_in.physical_address),
    (IMAGE_OUT, image_out.physical_address),
    (KERNEL, kernel.physical_address),
    (ROWS, rows),
    (COLS, cols),
    (KERNEL_SIZE, KERNEL_DIM),
    (STRIDE_ROW, 1),
    (STRIDE_COL, 1),
    (PADDING, Padding.ZEROS.value),
)
for register_offset, register_value in register_settings:
    lif_ip.write(register_offset, register_value)

In [None]:
# Bits of the block-level control register at offset CONTROL.
# NOTE(review): assumes the standard HLS ap_ctrl_hs layout (bit 0 ap_start,
# bit 1 ap_done, bit 2 ap_idle); confirm against the IP's generated register map.
AP_START: Final[int] = 1 << 0
AP_DONE: Final[int] = 1 << 1
# Generous upper bound so a hung accelerator raises instead of blocking the kernel forever.
RUN_TIMEOUT_NS: Final[int] = 60 * 1_000_000_000

start = perf_counter_ns()

lif_ip.write(CONTROL, AP_START)
# Test only the done bit. Comparing the whole register against one exact value
# (the former `!= (1 << 2)`) never terminates if any other status bit is also set.
while not (lif_ip.read(CONTROL) & AP_DONE):
    if perf_counter_ns() - start > RUN_TIMEOUT_NS:
        raise TimeoutError(
            F'Accelerator did not signal done within {RUN_TIMEOUT_NS / 1e9} s '
            F'(control register = {lif_ip.read(CONTROL):#x})'
        )

stop = perf_counter_ns()

print(F'Elapsed time: {(stop - start) / 1e6} ms')

In [None]:
# Copy out of the device buffer so later accelerator runs cannot silently change this result.
filtered_flat_image: Final[np.ndarray] = np.array(image_out, copy=True)

# Round to nearest and clamp before narrowing to uint8. A bare astype truncates,
# darkening pixels by up to one level, and wraps values outside [0, 255].
filtered_image: Final[np.ndarray] = np.clip(
    np.rint(filtered_flat_image.reshape((rows, cols)) * 255.0), 0, 255
).astype(np.uint8)

# cv.imwrite reports failure through its return value rather than raising.
if not cv.imwrite('result.png', filtered_image):
    raise OSError("Failed to write 'result.png'")