# Alignment and variance difference
**This notebook shows how to align two images and analyze the local variance difference between them**

In [110]:
%reload_ext autoreload
%autoreload 2

import numpy as np
import rawpy
import matplotlib.pyplot as plt
from raw_utils import *
from alignment import *

## Load images and pack raw images

In [None]:
CLEAR_PATH = 'first-filter-dataset/IMG_7957.CR2'
DIFFUSED_PATH = 'first-filter-dataset/IMG_7955_diff.CR2'

# Postprocess the untouched originals for the "before" row of the figure.
# The context manager closes each RAW file as soon as its RGB array exists,
# instead of leaving the handle open for the lifetime of the kernel.
with rawpy.imread(CLEAR_PATH) as clear_raw_original:
    clear = clear_raw_original.postprocess()
with rawpy.imread(DIFFUSED_PATH) as diffused_raw_original:
    diffused = diffused_raw_original.postprocess()

clear_results, diffused_results = align_and_crop_raw_images(CLEAR_PATH, DIFFUSED_PATH)

# _raw is the rawpy object with the aligned and cropped data
# _mosaic is an np array with the aligned and cropped data, still as mosaic
# _channels is the packed version of the _mosaic so it has the 4 color channels separated.
# NOTE(review): unpacking .values() relies on the insertion order of the result
# dicts being (raw, mosaic, channels) — confirm against align_and_crop_raw_images.
clear_raw, clear_mosaic, clear_channels = clear_results.values()
diffused_raw, diffused_mosaic, diffused_channels = diffused_results.values()

img_clear_aligned_cropped = clear_raw.postprocess()
img_diffused_aligned_cropped = diffused_raw.postprocess()

# Plot images: originals on the top row, aligned & cropped versions below
fig, axs = plt.subplots(2, 2, figsize=(15, 10), sharey=True, sharex=True)
axs[0, 0].imshow(clear)
axs[0, 1].imshow(diffused)
axs[1, 0].imshow(img_clear_aligned_cropped)
axs[1, 1].imshow(img_diffused_aligned_cropped)

# Column titles
axs[0, 0].set_title("clear")
axs[0, 1].set_title("Diffused")

# Row labels, attached as y-labels to the left-hand axes of each row
axs[0, 0].set_ylabel("Original", size='large')
axs[1, 0].set_ylabel("Aligned & Cropped", rotation=90, size='large')

fig.suptitle("Postprocessed images before and after alignment", size=20)
plt.show()

## Calculate local variances and the variance difference between the images

In [84]:
def compute_local_variance_single_channel(image, kernel_size=5):
    """Compute the per-pixel local variance of a single-channel image.

    The variance is evaluated as E[x^2] - (E[x])^2, where both expectations
    are box-filter means over a ``kernel_size`` x ``kernel_size`` window.

    Args:
        image: 2-D array holding one channel. It is converted to float32
            before filtering.
        kernel_size: Side length of the square averaging window.

    Returns:
        Array with the local variance of every pixel.
    """
    image = image.astype(np.float32)
    window = (kernel_size, kernel_size)
    # cv2 blur calculates local mean using a box filter
    mean = cv2.blur(image, window)
    mean_sq = cv2.blur(image**2, window)
    # NOTE: this difference of two float32 means can come out slightly
    # negative in flat regions because of rounding.
    variance = mean_sq - mean**2
    return variance

def compute_local_variance(image, kernel_size=5):
    """Compute the local variance of every channel of a multi-channel image.

    Each slice ``image[:, :, c]`` is processed independently with
    ``compute_local_variance_single_channel``. The notebook calls this with
    4-channel packed raw data (labelled R, G1, B, G2 in the plots), but any
    number of channels along the last axis is accepted.

    Args:
        image: 3-D array indexed as (row, column, channel).
        kernel_size: Side length of the square averaging window.

    Returns:
        Array of per-channel local variances stacked along axis 2, in the
        same channel order as ``image``.
    """
    channel_variances = [
        compute_local_variance_single_channel(image[:, :, channel], kernel_size)
        for channel in range(image.shape[2])
    ]
    return np.stack(channel_variances, axis=2)


In [None]:
kernel_size = 9
clear_variance = compute_local_variance(clear_channels, kernel_size=kernel_size)
diffused_variance = compute_local_variance(diffused_channels, kernel_size=kernel_size)

# Absolute per-pixel, per-channel change in local variance between the two shots
variance_difference = np.abs(diffused_variance - clear_variance)

channels = ['R', 'G1', 'B', 'G2']
fig, axs = plt.subplots(len(channels), 3, figsize=(15, len(channels) * 3.5), sharey=True)
for channel_index, channel_name in enumerate(channels):
    # Matplotlib text rotation is given in degrees, so the previous
    # rotation=np.pi meant ~3.14 degrees. Use an explicit horizontal label,
    # right-aligned so it sits next to the axis instead of overlapping it.
    axs[channel_index, 0].set_ylabel(
        channel_name, rotation=0, ha='right', va='center', size='large'
    )
    if channel_index == 0:
        # Column titles only on the first row
        axs[channel_index, 0].set_title('Clear image variance')
        axs[channel_index, 1].set_title('Diffused image variance')
        axs[channel_index, 2].set_title('Variance difference')
    axs[channel_index, 0].imshow(clear_variance[:, :, channel_index], cmap="coolwarm")
    axs[channel_index, 1].imshow(diffused_variance[:, :, channel_index], cmap="coolwarm")
    axs[channel_index, 2].imshow(variance_difference[:, :, channel_index], cmap="coolwarm")

# Leave the top 5% of the figure free for the suptitle
fig.tight_layout(rect=[0, 0, 1, 0.95])
fig.suptitle('Variances in Clear and Diffused Photos', fontsize=16)

plt.show()


## Visualize variance per intensity

In [88]:
def create_array_per_pair(image, variance_difference):
    """Pair every image value with the matching variance-difference value.

    Both inputs are flattened and then joined column-wise.

    Args:
        image: Array of pixel values.
        variance_difference: Array with as many elements as ``image``.

    Returns:
        Array of shape (N, 2): column 0 holds the flattened image values,
        column 1 the flattened variance differences.

    Raises:
        ValueError: If the two flattened inputs differ in length.
    """
    flat_image = image.ravel()
    flat_difference = variance_difference.ravel()
    if flat_image.shape == flat_difference.shape:
        return np.stack((flat_image, flat_difference), axis=1)
    raise ValueError("Inputs must have the same shape after flattening.")

def average_y_per_x(paired_array, threshold=0):
    """Average the y values that share the same x value.

    Args:
        paired_array: Array of shape (N, 2) holding x in column 0 and y in
            column 1.
        threshold: Minimum number of samples an x value needs in order to be
            kept in the result.

    Returns:
        Array of shape (M, 2) with the kept unique x values (sorted
        ascending) in column 0 and their mean y in column 1.
    """
    xs, ys = paired_array[:, 0], paired_array[:, 1]
    # group_of_sample[i] is the position of xs[i] within distinct_xs
    distinct_xs, group_of_sample, samples_per_group = np.unique(
        xs, return_inverse=True, return_counts=True
    )
    mean_ys = np.bincount(group_of_sample, weights=ys) / samples_per_group

    # Keep only the x values backed by enough samples
    keep = samples_per_group >= threshold
    return np.stack((distinct_xs[keep], mean_ys[keep]), axis=1)

def average_y_per_x_binned(paired_array, num_bins=100, threshold=0):
    """Average y over equal-width bins of x.

    The range [x.min(), x.max()] is split into ``num_bins`` equal-width bins.
    Every bin is half-open ``[lo, hi)`` except the last one, which also
    includes ``x.max()`` so that the largest samples are not discarded.

    Args:
        paired_array: Array of shape (N, 2) holding x in column 0 and y in
            column 1.
        num_bins: Number of equal-width bins.
        threshold: Minimum number of samples a bin needs in order to be kept.
            With ``threshold <= 0`` empty bins are kept and report an average
            of 0.

    Returns:
        Array of shape (M, 2) with the centers of the kept bins in column 0
        and the mean y of each bin in column 1.

    Raises:
        ValueError: If ``paired_array`` has no rows (min/max of an empty
            array is undefined).
    """
    x = paired_array[:, 0]
    y = paired_array[:, 1]

    # Create bins
    x_min, x_max = x.min(), x.max()
    bins = np.linspace(x_min, x_max, num_bins + 1)

    # Assign each x to a bin
    bin_indices = np.digitize(x, bins) - 1  # shift to 0-based index
    # np.digitize treats bins as half-open, so samples equal to x_max land one
    # past the last bin; fold them into the last bin instead of dropping them.
    bin_indices[x == x_max] = num_bins - 1

    # Remove out-of-range values (defensive; none are expected after the fold)
    valid_mask = (bin_indices >= 0) & (bin_indices < num_bins)
    bin_indices = bin_indices[valid_mask]
    y = y[valid_mask]

    # Compute average y per bin; empty bins keep the 0 from the `out` buffer
    sum_y = np.bincount(bin_indices, weights=y, minlength=num_bins)
    count_y = np.bincount(bin_indices, minlength=num_bins)
    avg_y = np.divide(sum_y, count_y, out=np.zeros_like(sum_y), where=count_y > 0)

    # Apply threshold
    mask = count_y >= threshold
    bin_centers = (bins[:-1] + bins[1:]) / 2
    filtered_x = bin_centers[mask]
    filtered_avg_y = avg_y[mask]

    return np.stack((filtered_x, filtered_avg_y), axis=1)

In [107]:
def plot_scatter(paired_array, ax, xlabel="", ylabel="", title=""):
    """Draw an (x, y) scatter plot on the given axes.

    Args:
        paired_array: Array whose column 0 holds the x values and column 1
            the y values.
        ax: Axes-like object to draw on.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        title: Title of the axes.
    """
    # Annotate the axes first, then add the points; the steps are independent
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(True)

    xs, ys = paired_array[:, 0], paired_array[:, 1]
    # Tiny, slightly transparent markers keep dense clouds readable
    ax.scatter(xs, ys, s=1, alpha=0.8)

def plot_avg_variance_difference_by_pixel_1_channel(original, variance_difference, channel, threshold=0, ax=None, num_bins=500):
    """Plot the average variance difference against pixel intensity for one channel.

    Pixel intensities of ``original[:, :, channel]`` are paired with
    ``variance_difference[:, :, channel]``, averaged over ``num_bins``
    equal-width intensity bins and drawn as a scatter plot.

    Args:
        original: 3-D array (row, column, channel) providing the intensities.
        variance_difference: 3-D array of the same layout providing the
            variance differences.
        channel: Index of the channel to analyse. It is used for both arrays,
            so each channel is compared against its own intensities.
        threshold: Minimum number of pixels an intensity bin needs in order
            to be plotted.
        ax: Axes to draw on. A new figure is created when omitted.
        num_bins: Number of intensity bins used for averaging.
    """
    # Pair intensity and variance difference from the SAME channel
    pairs = create_array_per_pair(original[:, :, channel], variance_difference[:, :, channel])
    binned_average = average_y_per_x_binned(pairs, num_bins=num_bins, threshold=threshold)
    if ax is None:
        fig, ax = plt.subplots()
    plot_scatter(binned_average, ax, "Pixel intensity", "Avg variance difference", "Avg variance difference by pixel intensity")
    

In [None]:
# One panel per packed raw channel, all sharing the same y axis
channels = ['R', 'G1', 'B', 'G2']
fig, ax = plt.subplots(1, 4, figsize=(20, 5), sharey=True)
for channel_index, (channel_axis, channel_name) in enumerate(zip(ax, channels)):
    plot_avg_variance_difference_by_pixel_1_channel(
        clear_channels,
        variance_difference,
        channel_index,
        threshold=0,
        ax=channel_axis,
    )
    # Replace the generic title set by the helper with the channel name
    channel_axis.set_title(channel_name)

plt.show()

## The variance differences of two completely unrelated images

In [None]:
def readjpg(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    Args:
        path: Path of the image file to read.

    Returns:
        The image as loaded by ``cv2.imread``, converted from BGR to RGB.

    Raises:
        OSError: If OpenCV could not read the file (``cv2.imread`` returned
            ``None``, e.g. because the path is wrong or the format is not
            supported).
    """
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        # cv2.imread signals failure with None rather than raising; fail here
        # with a clear message instead of an opaque cvtColor assertion error.
        raise OSError(f"Could not read image: {path}")
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img1 = readjpg("first-filter-dataset/IMG_7957.JPG")
img2 = readjpg("first-dataset/first-dataset-JPG/IMG_7782_DxO.jpg")

# Largest top-left window that exists in both pictures
height, width = (min(a, b) for a, b in zip(img1.shape[:2], img2.shape[:2]))

print(img2.shape)
print(img1.shape)
print(height, width)

# Crop both images to that common window so they can be compared pixel by pixel
img1, img2 = img1[:height, :width, :], img2[:height, :width, :]


def compute_local_variance_rgb(image, kernel_size=5):
    """Compute the local variance of the first three channels of an image.

    Args:
        image: 3-D array indexed as (row, column, channel); channels 0, 1
            and 2 are used (R, G, B for the images loaded in this notebook).
        kernel_size: Side length of the square averaging window.

    Returns:
        Array with the three per-channel local variances stacked along
        axis 2, in channel order.
    """
    per_channel = tuple(
        compute_local_variance_single_channel(image[:, :, index], kernel_size)
        for index in range(3)
    )
    return np.stack(per_channel, axis=2)

# Local variance of both (unrelated) pictures with the default window size
variances1, variances2 = (compute_local_variance_rgb(picture) for picture in (img1, img2))
diff_variances = np.abs(variances1 - variances2)

# Red channel only: average variance difference per intensity of img1
plot_avg_variance_difference_by_pixel_1_channel(img1, diff_variances, 0)