In [2]:
import numpy as np
from numba import njit

In [5]:
@njit
def bilateral_filter(image, window, sigma_spatial, sigma_spectral):
    """Apply a bilateral filter to a 2-D single-channel image.

    Each output pixel is the weighted mean of a square neighbourhood around
    it. The weight of neighbour (row, col) relative to the centre (y, x) is

        exp(-((row - y)^2 + (col - x)^2) / (2 * sigma_spatial^2)
            - (image[y, x] - image[row, col])^2 / (2 * sigma_spectral^2))

    so pixels that are both close and similar in intensity dominate. This
    smooths noise while preserving edges.

    Parameters
    ----------
    image : 2-D array
        Input image. Values are converted to float for the weight maths.
    window : int
        Neighbourhood size. The neighbourhood spans ``window // 2`` pixels on
        each side of the centre, so an even ``window`` behaves like
        ``window + 1``.
    sigma_spatial : float
        Standard deviation of the spatial (distance) Gaussian; must be non-zero.
    sigma_spectral : float
        Standard deviation of the range (intensity) Gaussian; must be non-zero.

    Returns
    -------
    2-D uint8 array with the same shape as ``image``. Neighbours that fall
    outside the image are taken from the opposite edge (wrap-around boundary).
    """
    height, width = image.shape
    out_image = np.zeros((height, width), dtype=np.uint8)
    half = window // 2
    size = 2 * half + 1
    inv_two_sigma_s2 = 1.0 / (2.0 * sigma_spatial ** 2)
    inv_two_sigma_r2 = 1.0 / (2.0 * sigma_spectral ** 2)

    # The spatial factor depends only on the offset, so compute it once.
    spatial = np.empty((size, size), dtype=np.float64)
    for dy in range(-half, half + 1):
        for dx in range(-half, half + 1):
            spatial[dy + half, dx + half] = np.exp(-(dy * dy + dx * dx) * inv_two_sigma_s2)

    for y in range(height):
        for x in range(width):
            # Work in float: subtracting uint8 values would wrap around.
            center = float(image[y, x])
            acc = 0.0
            weight_sum = 0.0
            for dy in range(-half, half + 1):
                # Modulo wraps both edges, even when the window exceeds the image.
                row = (y + dy) % height
                for dx in range(-half, half + 1):
                    col = (x + dx) % width
                    value = float(image[row, col])
                    diff = center - value
                    weight = spatial[dy + half, dx + half] * np.exp(-diff * diff * inv_two_sigma_r2)
                    acc += value * weight
                    weight_sum += weight
            # weight_sum >= 1 because the centre pixel always has weight exp(0) = 1.
            mean = acc / weight_sum
            # Clamp before the uint8 store so out-of-range inputs saturate
            # instead of wrapping around.
            out_image[y, x] = int(np.rint(min(max(mean, 0.0), 255.0)))
    return out_image

In [7]:
@njit
def joint_bilateral_filter(image, guide, window, sigma_spatial, sigma_spectral):
    """Apply a joint (cross) bilateral filter to a 2-D single-channel image.

    This works like a bilateral filter, except the range (intensity) weight
    is computed from ``guide`` rather than from ``image``. The weight of
    neighbour (row, col) relative to the centre (y, x) is

        exp(-((row - y)^2 + (col - x)^2) / (2 * sigma_spatial^2)
            - (guide[y, x] - guide[row, col])^2 / (2 * sigma_spectral^2))

    and the output is the weighted mean of ``image`` over the neighbourhood.
    Edges present in ``guide`` are therefore preserved in the result.

    Parameters
    ----------
    image : 2-D array
        Image whose values are averaged.
    guide : 2-D array
        Guidance image. It must have the same shape as ``image``.
    window : int
        Neighbourhood size. The neighbourhood spans ``window // 2`` pixels on
        each side of the centre, so an even ``window`` behaves like
        ``window + 1``.
    sigma_spatial : float
        Standard deviation of the spatial (distance) Gaussian; must be non-zero.
    sigma_spectral : float
        Standard deviation of the range Gaussian on ``guide``; must be non-zero.

    Returns
    -------
    2-D uint8 array with the same shape as ``image``. Neighbours that fall
    outside the image are taken from the opposite edge (wrap-around boundary).

    Raises
    ------
    ValueError
        If ``guide`` and ``image`` have different shapes.
    """
    height, width = image.shape
    if guide.shape[0] != height or guide.shape[1] != width:
        raise ValueError("guide must have the same shape as image")
    out_image = np.zeros((height, width), dtype=np.uint8)
    half = window // 2
    size = 2 * half + 1
    inv_two_sigma_s2 = 1.0 / (2.0 * sigma_spatial ** 2)
    inv_two_sigma_r2 = 1.0 / (2.0 * sigma_spectral ** 2)

    # The spatial factor depends only on the offset, so compute it once.
    spatial = np.empty((size, size), dtype=np.float64)
    for dy in range(-half, half + 1):
        for dx in range(-half, half + 1):
            spatial[dy + half, dx + half] = np.exp(-(dy * dy + dx * dx) * inv_two_sigma_s2)

    for y in range(height):
        for x in range(width):
            # Work in float: subtracting uint8 values would wrap around.
            guide_center = float(guide[y, x])
            acc = 0.0
            weight_sum = 0.0
            for dy in range(-half, half + 1):
                # Modulo wraps both edges, even when the window exceeds the image.
                row = (y + dy) % height
                for dx in range(-half, half + 1):
                    col = (x + dx) % width
                    diff = guide_center - float(guide[row, col])
                    weight = spatial[dy + half, dx + half] * np.exp(-diff * diff * inv_two_sigma_r2)
                    acc += float(image[row, col]) * weight
                    weight_sum += weight
            # weight_sum >= 1 because the centre pixel always has weight exp(0) = 1.
            mean = acc / weight_sum
            # Clamp before the uint8 store so out-of-range inputs saturate
            # instead of wrapping around.
            out_image[y, x] = int(np.rint(min(max(mean, 0.0), 255.0)))
    return out_image