# BM3D

Python implementation of ["An Analysis and Implementation of the BM3D Image Denoising Method" by Marc Lebrun](https://www.ipol.im/pub/art/2012/l-bm3d/?utm_source=doi).

## Modules

In [1]:
import numpy as np
from skimage import io as skio
import heapq

In [2]:
# local
import matplotlib.pyplot as plt

## Parameters

In [3]:
# 1st step
kHard = 8 #patch size
nHard = 39 #search window size --! era pra ser 39 mas nao entendi como centralizar P
NHard = 16 #max number of similar patches kept 

sigma = 30
tauHard = 5000 if sigma > 40 else 2500

gammaHard = 0 #hard thresholding for grouping --! ??? where

## Initialization

In [4]:
im = skio.imread('./lena.tif') # original image

In [5]:
def noise(im,br):
    """ Cette fonction ajoute un bruit blanc gaussier d'ecart type br
       a l'image im et renvoie le resultat"""
    imt=np.float32(im.copy())
    sh=imt.shape
    bruit=br*np.random.randn(*sh)
    imt=imt+bruit
    return imt

In [6]:
u = noise(im, sigma) # create noisy image

## The First Denoising Step


### Grouping

In [7]:
# x,y is the top-left corner of the reference patch
# doesnt work well for even window_size --change
def get_search_window(image, x, y, patch_size=kHard, window_size=nHard):
    img_h, img_w = image.shape  # image dimensions
    
    # padded image (to handle borders)
    padded_image = np.pad(image, window_size//2, mode='reflect')
    
    # adjust coordinates
    x_padded = x + window_size//2
    y_padded = y + window_size//2

    # ensure the patch defined by (x, y) fits within the image bounds
    if x < 0 or y < 0 or x + patch_size > img_w or y + patch_size > img_h:
        raise ValueError("The specified patch defined by (x, y) exceeds image boundaries.")
    
    search_window = padded_image[
        y_padded - (window_size//2 - patch_size//2):y_padded + window_size//2 + patch_size//2 +1,
        x_padded - (window_size//2 - patch_size//2):x_padded + window_size//2 + patch_size//2 +1  
    ]
    return search_window

In [8]:
def hard_thresholding(img, threshold):
    return (img < threshold) * img

In [9]:
def distance(p,q):
    return (np.linalg.norm(p-q) ** 2) / (kHard ** 2)

In [14]:
def build_3d_group(p, window, sigma, gammaHard, tauHard, N=NHard):
    closer_N_dists = []

    # assumming square patch and window
    k = p.shape[0]
    n = window.shape[0] 

    if sigma > 40:
        p = hard_thresholding(p, gammaHard * sigma)
    
    for i in range(n-k+1):
        for j in range(n-k+1):
            # get patch Q and calculate distance to ref P
            q = window[i:k+i, j:k+j]
            if sigma > 40:
                q = hard_thresholding(q, gammaHard * sigma)
            
            dist = distance(p, q)
            if dist <= tauHard:
                dist_tuple = (-dist, (i, j))  # negate distance to use max-heap
        
                if len(closer_N_dists) < N+1: # because after we will take out the first one
                    heapq.heappush(closer_N_dists, dist_tuple)
                else:
                    if dist_tuple > closer_N_dists[0]:
                        heapq.heappushpop(closer_N_dists, dist_tuple)
                        
    closer_N_dists = [(-d, idx) for d, idx in closer_N_dists]
    closer_N_dists = sorted(closer_N_dists, key=lambda x: x[0])[1:] # take out the first one (distance 0.0)

    
    group_3d = []
    for _, (i, j) in closer_N_dists:
        patch = window[i:k+i, j:k+j]
        group_3d.append(patch)
    group_3d=np.array(group_3d)
    
    return group_3d

In [16]:
def grouping_1st_step(image, sigma, kHard, nHard, gammaHard, tauHard, NHard):
    height, width = image.shape
    all_groups = []

    # iterate through all patches in the image
    for x in range(0, height - kHard + 1, kHard):
        for y in range(0, width - kHard + 1, kHard):
            
            patch = image[x:x+kHard, y:y+kHard]
            search_window = get_search_window(image, x, y, patch_size=kHard, window_size=nHard)

            group_3d = build_3d_group(patch, search_window, sigma, gammaHard, tauHard, NHard)

            if len(group_3d) > 0:
                all_groups.append((group_3d, (x,y)))

    return all_groups

In [17]:
%%time
a=grouping_1st_step(u, sigma, kHard, nHard, gammaHard, tauHard, NHard)

CPU times: total: 2min 28s
Wall time: 2min 49s


In [22]:
print(f'quantidade de 3d groups: {len(a)}')
print(a[100][1])
print(f'exemplo da estrutura de um: {a[0][0].shape}')
a[0]

quantidade de 3d groups: 10974
(4, 16)
exemplo da estrutura de um: (16, 8, 8)


(array([[[109.52249483, 161.20902822, 161.44063125, ..., 129.84449276,
          135.51302547, 103.92265316],
         [118.31419494, 148.49671213, 103.83842967, ..., 118.42754583,
          124.81407541, 112.49848665],
         [ 77.89256804, 141.11644975, 100.76731756, ..., 179.64053832,
          149.7571895 , 171.05869221],
         ...,
         [132.66082444, 143.91198748,  69.07423657, ..., 135.68488739,
          157.1794654 ,  99.62012458],
         [107.24306215, 128.77838677, 137.92471691, ..., 106.56738519,
          108.27680172, 114.4768356 ],
         [160.86511802,  86.11862739, 152.79765814, ...,  74.8854051 ,
          122.85427623,  93.29719872]],
 
        [[157.1794654 , 135.68488739, 151.9681494 , ..., 143.91198748,
          132.66082444,  93.96422565],
         [ 62.21054684, 129.24309576, 125.38090599, ..., 112.85020935,
          134.16143946,  94.93569401],
         [135.72262543, 140.64605085, 137.92849369, ..., 128.97083054,
          160.73703906, 134.1658

### --rascunho

DÚVIDA: como lidar com patches que não formam grupos? acho que nenhum Q teve distancia abaixo da threshold. por enquanto vou só ignorar eles 

In [None]:
print(a[1].shape)
print(a[1000].shape)

rever passo em grouping_1st_step

### Collaborative Filtering

### Aggregation

## The Second Denoising Step

### Grouping

### Collaborative Filtering


### Aggregation

## Results

### PSNR: Peak Signal to Noise Ratio