In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import cv2
import seaborn as sns
from raw_utils import demosaic_bilinear
from raw_utils import unpack_raw
#!pip install rawpy
import rawpy
from raw_utils import pack_raw
from dataset_navigation import get_image_paths
from ipywidgets import interact, FloatSlider


In [17]:

def load_image(file_path):
    """Read a camera raw file and return its packed Bayer mosaic.

    Args:
        file_path: Path to a raw file readable by ``rawpy``.

    Returns:
        The result of ``pack_raw`` applied to the visible sensor area.
    """
    # ``raw_image_visible`` is a view into memory owned by the RawPy object.
    # Copy it before the context manager closes the file and releases that
    # buffer; the ``with`` block also fixes the previously leaked file handle.
    with rawpy.imread(file_path) as raw:
        bayer = raw.raw_image_visible.copy()
    return pack_raw(bayer)


def kernel_blur(img, kernel):
    """Filter ``img`` with ``kernel`` using OpenCV, keeping the input depth.

    Note that ``cv2.filter2D`` computes a correlation; for a symmetric kernel
    this is identical to a convolution.
    """
    keep_depth = -1  # ddepth=-1: output depth equals input depth
    return cv2.filter2D(img, keep_depth, kernel)


def sum_of_gaussians_kernel(kernel_size=3, sigma1=0.01, sigma2=10, ratio=20):
    """Build a normalized 2-D kernel mixing two isotropic Gaussians.

    The kernel is ``g1 + ratio * g2`` sampled on an integer grid of side
    ``2 * (kernel_size // 2) + 1`` (an even ``kernel_size`` is therefore
    rounded up to the next odd size), then scaled so its entries sum to 1.

    Args:
        kernel_size: Requested kernel width/height.
        sigma1: Standard deviation of the first Gaussian.
        sigma2: Standard deviation of the second Gaussian.
        ratio: Weight of the second Gaussian relative to the first.

    Returns:
        tuple: ``(kernel_2d, kernel_1d)`` as float32 arrays, where
        ``kernel_1d`` is the central row of ``kernel_2d`` (for visualization).
    """
    half = kernel_size // 2
    offsets = np.arange(-half, half + 1)
    rows, cols = np.meshgrid(offsets, offsets, indexing='ij')
    dist_sq = rows**2 + cols**2

    first = np.exp(-dist_sq / (2 * sigma1**2))
    second = np.exp(-dist_sq / (2 * sigma2**2))

    mixture = first + ratio * second
    mixture = mixture / np.sum(mixture)

    centre_row = mixture[half, :]
    return mixture.astype(np.float32), centre_row.astype(np.float32)

def get_gray_world_constants(image):
    """
    Calculate gray-world white-balance gains for a packed raw image.

    The image is unpacked and bilinearly demosaiced, and the mean of each
    colour channel is taken. Each channel's gain is the ratio between the
    average of those channel means (the "gray" reference) and the channel's
    own mean. Multiplying the demosaiced channels by these gains makes all
    channel means equal.

    Args:
        image (numpy.ndarray): Packed raw image in the layout accepted by
            ``unpack_raw`` (not an RGB image).

    Returns:
        numpy.ndarray: Per-channel gains ``[s_R, s_G, s_B]`` (assuming
        ``demosaic_bilinear`` returns channels in R, G, B order). A channel
        whose mean is exactly zero gets a gain of 1.0 instead of inf/nan.
    """
    demosaiced = demosaic_bilinear(unpack_raw(image))
    mu = demosaiced.mean(axis=(0, 1))  # per-channel means [μ_R, μ_G, μ_B]
    mu_gray = mu.mean()                # gray reference
    # An all-black channel would otherwise yield inf (or nan when every
    # channel is zero), which then poisons every downstream multiplication.
    with np.errstate(divide='ignore', invalid='ignore'):
        scales = np.where(mu != 0, mu_gray / mu, 1.0)  # [s_R, s_G, s_B]
    return scales

class PostProcessor:
    """Render a packed raw image into a displayable, gamma-encoded RGB image.

    Pipeline: unpack + bilinear demosaic, per-channel gains, clip to [0, 1],
    then gamma encoding with exponent ``1 / gamma``.
    """

    def __init__(self, gray_world_constants, gamma=2.2):
        """
        Args:
            gray_world_constants: Default per-channel gains applied to the
                demosaiced image (e.g. the output of ``get_gray_world_constants``).
            gamma (float): Display gamma; pixel values are raised to ``1 / gamma``.
        """
        self.gray_world_constants = gray_world_constants
        self.gamma = gamma

    def postprocess(self, image, gray_world_constants=None):
        """Convert a packed raw image into a gamma-encoded RGB image in [0, 1].

        Args:
            image (numpy.ndarray): Packed raw image accepted by ``unpack_raw``.
            gray_world_constants: Optional gains overriding those given at
                construction. Any array-like with one entry per channel
                (list, tuple, ndarray) or a single scalar gain is accepted.

        Returns:
            numpy.ndarray: Channels-last image clipped to [0, 1] and gamma encoded.
        """
        rgb = demosaic_bilinear(unpack_raw(image))

        scales = self.gray_world_constants if gray_world_constants is None else gray_world_constants
        # Reshape to (1, 1, C) so lists/tuples/scalars broadcast over the
        # trailing channel axis exactly like a 1-D ndarray of gains.
        scales = np.asarray(scales).reshape(1, 1, -1)
        rgb = rgb * scales

        # Clip before the power: negative values would turn into NaN, and
        # highlights pushed above 1 by the gains should saturate.
        rgb = np.clip(rgb, 0, 1)
        return np.power(rgb, 1 / self.gamma)

def extend(image, limit=1, factor=10):
    """
    Boost every pixel that has a value at or above ``limit`` by ``factor``.

    For multi-channel images (``ndim >= 3``, channels last) a pixel is
    selected when *any* of its channels is ``>= limit``, and then *all* of its
    channels are multiplied by ``factor``. For 2-D (single-channel) images each
    value is tested and scaled individually.

    Args:
        image (numpy.ndarray): Input image; it is not modified.
        limit: Inclusive threshold that marks a pixel for boosting.
        factor: Multiplier applied to the selected pixels.

    Returns:
        numpy.ndarray: A new array whose dtype is ``image``'s dtype promoted
        with ``factor`` (so an integer image with a float factor no longer
        fails on the in-place multiply).
    """
    extended_image = image.astype(np.result_type(image, factor), copy=True)
    if image.ndim >= 3:
        # Whole-pixel selection: one bright channel boosts all channels.
        pixel_mask = (image >= limit).any(axis=-1)  # shape (H, W)
    else:
        # Reducing over the last axis here would collapse entire rows,
        # so test every value on its own.
        pixel_mask = image >= limit
    extended_image[pixel_mask] *= factor
    return extended_image



# Scene selection. Other indices used during exploration: 22 (handbag), 0 (house).
paths = get_image_paths()
long_exp_paths = paths["long_exp"]
filtered_exp_paths = paths["filter_long_exp"]
idx = 13

# Spatial downsampling factor applied to both packed-raw images.
factor = 4


def _downsample_and_normalize(image, downsample):
    """Shrink ``image`` by an integer factor and scale it so its maximum is 1.

    Raises:
        ValueError: If the downsampled image has no positive value; dividing
            by such a maximum is meaningless (an all-zero image gives NaN).
    """
    height, width = image.shape[:2]
    small = cv2.resize(image, (width // downsample, height // downsample),
                       interpolation=cv2.INTER_LINEAR)
    peak = small.max()
    if peak <= 0:
        raise ValueError(f"cannot normalize image with non-positive maximum {peak!r}")
    return small / peak


image_diff = _downsample_and_normalize(load_image(filtered_exp_paths[idx]), factor)
image_org = _downsample_and_normalize(load_image(long_exp_paths[idx]), factor)


In [18]:
# Render both reference images with the gray-world gains of the filtered
# image, so they share one colour correction. The gains are computed once:
# the previous explicit override recomputed exactly the same values.
# NOTE(review): confirm that white-balancing the original with the filtered
# image's gains (rather than its own) is intended.
pp = PostProcessor(get_gray_world_constants(image_diff), gamma=2.2)
image_org_rgb = pp.postprocess(image_org)
image_diff_rgb = pp.postprocess(image_diff)


In [19]:
def blur_image(image, kernel_size=1000, sigma1=0.115, sigma2=10, ratio=0.00007, extend_factor=20, limit=1):
    """Boost bright pixels, blur with a sum-of-Gaussians kernel, clip to [0, 1].

    Args:
        image (numpy.ndarray): Image to process (in this notebook, images are
            divided by their maximum, so 1 is the brightest value).
        kernel_size, sigma1, sigma2, ratio: Forwarded to
            ``sum_of_gaussians_kernel``.
        extend_factor: Multiplier that ``extend`` applies to pixels reaching
            ``limit``.
        limit: Threshold forwarded to ``extend``; defaults to 1, the value
            previously hard-coded here.

    Returns:
        numpy.ndarray: The blurred image clipped to [0, 1].
    """
    boosted = extend(image, limit=limit, factor=extend_factor)
    kernel, _ = sum_of_gaussians_kernel(kernel_size=kernel_size, sigma1=sigma1, sigma2=sigma2, ratio=ratio)
    # Boosted highlights can still exceed 1 after blurring; clip back to the
    # displayable range.
    return np.clip(kernel_blur(boosted, kernel), 0, 1)



def update(factor=1.0):
    """Plot the original, blurred and filtered images side by side.

    Intended as the ``ipywidgets.interact`` callback: ``factor`` is forwarded
    to ``blur_image`` as ``extend_factor``. The blurred result is white-balanced
    with its own gray-world gains. The other two panels are the precomputed
    globals ``image_org_rgb`` and ``image_diff_rgb``.
    """
    blurred_image = blur_image(image_org, kernel_size=1000, sigma1=0.115, sigma2=10,
                               ratio=0.00007, extend_factor=factor)
    pp = PostProcessor(get_gray_world_constants(blurred_image), gamma=2.2)
    blurred_image_rgb = pp.postprocess(blurred_image)

    panels = (
        (image_org_rgb, "Original Image"),
        (blurred_image_rgb, "Blurred Image"),
        (image_diff_rgb, "Filtered Image"),
    )
    _, axes = plt.subplots(1, len(panels), figsize=(20, 8))
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')

    plt.show()
    

# Moving the slider calls update(factor=...) and re-renders the comparison.
interact(
    update,
    factor=FloatSlider(value=20.0, min=1.0, max=200.0, step=1),
);

interactive(children=(FloatSlider(value=20.0, description='factor', max=200.0, min=1.0, step=1.0), Output()), …

In [23]:
import cv2
import numpy as np
def nothing(x):
    """No-op callback required by ``cv2.createTrackbar``; the loop polls the value."""


cv2.namedWindow('Image', cv2.WINDOW_NORMAL)
cv2.resizeWindow('Image', 1500, 400)
cv2.createTrackbar('Factor', 'Image', 10, 100, nothing)

last_factor = -1

try:
    while True:
        # OpenCV trackbars start at 0. A factor of 0 would zero the boosted
        # pixels, so clamp to 1.
        factor = max(cv2.getTrackbarPos('Factor', 'Image'), 1)

        if factor != last_factor:
            # Recompute only when the trackbar value changed.
            blurred_image = blur_image(
                image_org, kernel_size=1000, sigma1=0.115, sigma2=10,
                ratio=0.00007, extend_factor=factor
            )
            pp = PostProcessor(get_gray_world_constants(blurred_image), gamma=2.2)
            blurred_image_rgb = pp.postprocess(blurred_image)

            montage = np.hstack((image_org_rgb, blurred_image_rgb, image_diff_rgb))
            # The panels are RGB, but cv2.imshow interprets channels as BGR.
            cv2.imshow('Image', np.ascontiguousarray(montage[..., ::-1]))

            last_factor = factor

        # Exit on Esc, or once the user has closed the window.
        if cv2.waitKey(1) & 0xFF == 27:
            break
        if cv2.getWindowProperty('Image', cv2.WND_PROP_VISIBLE) < 1:
            break
finally:
    # Also runs on KeyboardInterrupt, so the window never outlives the loop.
    cv2.destroyAllWindows()

KeyboardInterrupt: 