# Noise Statistics Estimation (Blind-spot Size Determination)

This Jupyter notebook is a step-by-step guide to determining an appropriate blind-spot size for a given dataset.

In [None]:
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from ipywidgets import interactive, FloatLogSlider

## Load an image from the dataset for estimation

In [None]:
# Path to an image in the dataset you want to denoise
image_path = "./Datasets/Au/Tif_convert0081.tif"
image = cv.imread(image_path, cv.IMREAD_UNCHANGED)

# cv.imread reports failure by returning None rather than raising, so check
# here; otherwise the problem only surfaces later as an unrelated-looking
# error inside the correlation functions.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

## Define functions for noise statistics estimation

In [None]:
def numpy_normalize_zscore_normalize(image):
    """Standardize an image to zero mean and unit standard deviation.

    Args:
        image: Array-like of pixel values; it is converted to ``float32``.

    Returns:
        A ``float32`` array with the same shape as the input, equal to
        ``(image - mean) / std`` where mean and std are taken over all
        elements. When the standard deviation is zero (e.g. a constant
        image) the centered image is returned as-is, avoiding a division
        by zero that would fill the result with NaN.
    """
    image = np.array(image, dtype=np.float32)
    flat = image.ravel()
    std = np.std(flat)
    mean = np.mean(flat)

    centered = image - mean
    if std == 0:
        # Nothing to scale: every pixel equals the mean.
        return centered
    return centered / std

def compute_correlation_for_distance(noise_img, distance):
    """Average the column-wise correlation over all shifts up to ``distance``.

    For every offset ``(dx, dy)`` with ``0 <= dx, dy <= distance`` the image
    is circularly shifted with ``np.roll`` (``dx`` along axis 0, ``dy`` along
    axis 1). Each column of the original is then correlated (Pearson) with
    the column at the same index in the shifted copy, and the per-column
    coefficients are averaged. The zero offset contributes exactly 1.0.
    The return value is the mean over all offsets.

    Columns that are constant (in the original or in the shifted copy) have
    an undefined correlation; they are left out of the per-offset average so
    that a single flat column does not turn the whole result into NaN. An
    offset for which no column is usable is skipped altogether.

    Args:
        noise_img: 2-D array-like image (indexed as ``[row, column]``).
        distance: Largest shift, in pixels, applied along each axis.

    Returns:
        The mean correlation over all evaluated offsets.
    """
    img = np.asarray(noise_img, dtype=np.float64)

    # Column statistics are computed once: np.roll only permutes rows within
    # a column and moves whole columns, so a shifted column keeps the mean
    # and norm of the column it came from.
    centered = img - img.mean(axis=0, keepdims=True)
    col_norms = np.sqrt(np.sum(centered ** 2, axis=0))
    # Exact constant-column test (robust to rounding in the mean).
    varying = np.any(img != img[:1], axis=0)

    correlations = []
    for dx in range(distance + 1):
        for dy in range(distance + 1):
            if dx == 0 and dy == 0:
                correlations.append(1.0)
                continue

            usable = varying & np.roll(varying, dy)
            if not np.any(usable):
                continue

            shifted = np.roll(centered, shift=(dx, dy), axis=(0, 1))
            covariance = np.sum(centered * shifted, axis=0)
            denom = col_norms * np.roll(col_norms, dy)
            # Clip guards against rounding pushing |r| marginally above 1.
            corr_values = np.clip(covariance[usable] / denom[usable], -1.0, 1.0)
            correlations.append(np.mean(corr_values))

    return np.mean(correlations)

def relative_distance_correlation(image_tensor, max_distance):
    """Evaluate the averaged correlation for each distance 1..max_distance.

    Args:
        image_tensor: Image passed through unchanged to
            ``compute_correlation_for_distance``.
        max_distance: Largest distance to evaluate (inclusive).

    Returns:
        A list with one correlation value per distance, in increasing
        order of distance starting at 1.
    """
    return [
        compute_correlation_for_distance(image_tensor, radius)
        for radius in range(1, max_distance + 1)
    ]

def spatial_correlation_heatmap(image_tensor, max_distance):
    """Correlate an image with circularly shifted copies of itself.

    For every offset ``(dx, dy)`` in ``[-max_distance, max_distance]`` the
    image is rolled by ``dx`` along axis 0 and ``dy`` along axis 1, and the
    Pearson correlation between the flattened original and the flattened
    shifted copy is stored at ``heatmap[dx + max_distance, dy + max_distance]``.
    The zero offset is set to 1.0 without being computed.

    Args:
        image_tensor: NumPy array holding the image.
        max_distance: Largest absolute shift along each axis.

    Returns:
        A ``(2 * max_distance + 1)`` square array of correlation values.
    """
    size = 2 * max_distance + 1
    heatmap = np.zeros((size, size))
    offsets = range(-max_distance, max_distance + 1)

    # The unshifted side of every correlation is the same, so prepare it once.
    reference = image_tensor.flatten()
    ref_centered = reference - np.mean(reference)
    ref_norm = np.sqrt(np.sum(ref_centered**2))

    for row, dx in enumerate(offsets):
        for col, dy in enumerate(offsets):
            if (dx, dy) == (0, 0):
                heatmap[row, col] = 1.0
                continue

            moved = np.roll(image_tensor, shift=(dx, dy), axis=(0, 1)).flatten()
            moved_centered = moved - np.mean(moved)
            moved_norm = np.sqrt(np.sum(moved_centered**2))
            heatmap[row, col] = np.sum(ref_centered * moved_centered) / (ref_norm * moved_norm)

    return heatmap

def compute_random_pixel_correlation(img, num_pairs=int(1e8)):
    """Correlate the values of randomly paired pixels of an image.

    Two independent sets of ``num_pairs`` pixel positions are drawn with
    ``np.random.randint`` and the Pearson correlation between the pixel
    values at the first and second positions is returned.

    Args:
        img: NumPy array; its first two dimensions are treated as
            (rows, columns).
        num_pairs: Number of pixel pairs to sample. Four index arrays of
            this length are allocated.

    Returns:
        The off-diagonal entry of ``np.corrcoef`` for the two value sets.
    """
    height, width = img.shape[:2]

    def draw(upper):
        return np.random.randint(0, upper, num_pairs)

    # Draw order (x1, y1, x2, y2) is kept fixed so results under a given
    # global NumPy seed stay reproducible.
    first_cols = draw(width)
    first_rows = draw(height)
    second_cols = draw(width)
    second_rows = draw(height)

    first_values = img[first_rows, first_cols]
    second_values = img[second_rows, second_cols]

    return np.corrcoef(first_values, second_values)[0, 1]

## Compute noise correlation and visualize the results

In [None]:
# Standardize the image (zero mean, unit std) before measuring correlations.
image = numpy_normalize_zscore_normalize(image)

# Largest pixel offset to evaluate. plot_correlation reads `max_distance` to
# label the heatmap axes, so it must match the extent the heatmap was computed
# with; derive it instead of repeating the literal so the two cannot drift.
max_distance_to_check = 15
max_distance = max_distance_to_check

heatmap = spatial_correlation_heatmap(image, max_distance_to_check)
correlations = relative_distance_correlation(image, max_distance_to_check)

def plot_correlation(vmin, vmax):
    """Draw the 2-D correlation heatmap next to the 1-D correlation curve.

    Reads the module-level ``heatmap``, ``correlations``, ``max_distance``
    and ``max_distance_to_check``.

    Args:
        vmin: Lower limit of the heatmap's logarithmic color scale.
        vmax: Upper limit of the heatmap's logarithmic color scale.
    """
    # The two sliders move independently, so vmin can be dragged above vmax;
    # LogNorm raises in that case, so put the limits in order first.
    vmin, vmax = min(vmin, vmax), max(vmin, vmax)

    fig, ax = plt.subplots(1, 2, figsize=(15, 6))

    # 1D correlation plot
    # max(1, ...) keeps the tick step valid when fewer than 5 distances are
    # plotted (np.arange fails with a zero step).
    distance_step = max(1, max_distance_to_check // 5)
    ax[1].semilogy(range(1, max_distance_to_check + 1), correlations, '-o', label='Noise Correlation')
    ax[1].set_xticks(np.arange(0, max_distance_to_check, distance_step))
    ax[1].set_xlabel('Relative Distance')
    ax[1].set_ylabel('Correlation (Log Scale)')
    ax[1].set_title('Correlation of Noise vs. Relative Distance (Log Scale)')
    ax[1].grid(True)
    ax[1].legend()

    # 2D correlation heatmap
    # The magnitude is shown because a log color scale cannot represent
    # negative correlations.
    size = 2 * max_distance + 1
    tick_step = max(1, size // 5)
    tick_positions = np.arange(0, size, tick_step)
    # Pixel index i along either axis corresponds to offset i - max_distance.
    tick_labels = np.arange(-max_distance, max_distance + 1, tick_step)

    cax = ax[0].imshow(np.abs(heatmap), cmap='Blues', origin='lower', norm=mcolors.LogNorm(vmin=vmin, vmax=vmax))
    fig.colorbar(cax, ax=ax[0], orientation='vertical')
    ax[0].set_title('2D Spatial Correlation Heatmap')
    ax[0].set_xlabel('dx')
    ax[0].set_ylabel('dy')
    ax[0].set_xticks(tick_positions)
    ax[0].set_yticks(tick_positions)
    ax[0].set_xticklabels(tick_labels)
    ax[0].set_yticklabels(tick_labels)

    plt.tight_layout()
    plt.show()

# Logarithmic sliders for the heatmap color limits. With base 10, the
# min/max arguments are exponents: vmin spans 1e-3..1 and vmax spans 1e-1..1.
vmin_slider = FloatLogSlider(
    description='vmin',
    value=1e-1,
    base=10,
    min=-3,
    max=0,
    step=0.05,
)
vmax_slider = FloatLogSlider(
    description='vmax',
    value=1,
    base=10,
    min=-1,
    max=0,
    step=0.05,
)

# Re-run plot_correlation whenever a slider moves; the bare name on the last
# line makes the notebook display the widget.
interactive_plot = interactive(plot_correlation, vmin=vmin_slider, vmax=vmax_slider)
interactive_plot