In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import cv2
import seaborn as sns
#!pip install rawpy
import rawpy
from utils.raw_utils import pack_raw
from utils.dataset_navigation import get_image_paths

In [None]:
def load_image(file_path):
    raw = rawpy.imread(file_path).raw_image_visible
    return pack_raw(raw)

In [44]:
paths = get_image_paths()
long_exp_paths = paths["long_exp"]
filtered_exp_paths = paths["filter_long_exp"]
idx = 22
image_diff = load_image(filtered_exp_paths[idx])
image_org = load_image(long_exp_paths[idx])
# Downsample the images
factor = 4

image_diff = cv2.resize(image_diff, (image_diff.shape[1] // factor, image_diff.shape[0] // factor), interpolation=cv2.INTER_LINEAR)
image_org = cv2.resize(image_org, (image_org.shape[1] // factor, image_org.shape[0] // factor), interpolation=cv2.INTER_LINEAR)
image_diff = image_diff / image_diff.max()
image_org = image_org / image_org.max()
print(image_diff.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

(357, 536, 4)
cpu


In [95]:
import torch
import torch.nn.functional as F

from variance_brightness_analysis import get_brightness_to_std_difference_splines

def fitted_func_multi_channel(x, blur_intensity):
    """
    This function takes in the pixel brightness and returns the sigma value for each channel.
    Args:
        x (torch.Tensor): Input tensor with shape (B, C, H, W)
    """
    # Calculate the sigma values for each channel using the provided splines
    splines = get_brightness_to_std_difference_splines()
    sigma = []
    for i in range(x.shape[1]):
        # Get the spline for the current channel
        spline = splines[i]
        # Calculate the sigma values using the spline
        sigma_channel = torch.tensor(spline(x[:, i, :, :].cpu().numpy()), device=x.device)
        sigma.append(sigma_channel)
    # Stack the sigma values to create a tensor with shape (B, C, H, W) 
    sigma = torch.stack(sigma, dim=1)
    return sigma * blur_intensity


def fitted_func_multi_channel_avg(x, blur_intensity):
    """
    This function takes in the pixel brightness and returns the sigma value for each channel.
    Args:
        x (torch.Tensor): Input tensor with shape (B, C, H, W)
    """
    # Calculate the sigma values for each channel using the provided splines
    splines = get_brightness_to_std_difference_splines()
    sigma = []
    for i in range(x.shape[1]):
        # Get the spline for the current channel
        spline = splines[i]
        # Calculate the sigma values using the spline
        sigma_channel = torch.tensor(spline(x[:, i, :, :].cpu().numpy()), device=x.device)
        sigma.append(sigma_channel)
    sigma = torch.stack(sigma, dim=1)

    # If use_avg is True, average the sigma values across channels
    sigma = sigma.mean(dim=1, keepdim=True)
    # Make into 4 channels again
    sigma = sigma.repeat(1, x.shape[1], 1, 1)
    return sigma * blur_intensity


def gaussian_kernel(sigma, k_size=3):
    """
    Generate Gaussian kernels for each pixel based on sigma values.

    Args:
        sigma (torch.Tensor): Sigma values with shape (B, C, H, W)
        k_size (int): Kernel size (default: 3)

    Returns:
        torch.Tensor: Gaussian kernels with shape (B, C, H, W, k_size, k_size)
    """
    device = sigma.device
    B, C, H, W = sigma.shape
    radius = k_size // 2

    # Create kernel grid
    y, x = torch.meshgrid(
        torch.arange(-radius, radius + 1, dtype=torch.float32, device=device),
        torch.arange(-radius, radius + 1, dtype=torch.float32, device=device),
        indexing='ij'
    )

    # Reshape for broadcasting
    x = x.view(1, 1, 1, 1, k_size, k_size)  # (1, 1, 1, 1, k, k)
    y = y.view(1, 1, 1, 1, k_size, k_size)
    sigma = sigma.unsqueeze(-1).unsqueeze(-1)  # (B, C, H, W, 1, 1)

    # Compute Gaussian weights
    coeff = 1.0 / (2 * torch.pi * sigma**2)
    exponent = -(x**2 + y**2) / (2 * sigma**2)
    weights = coeff * torch.exp(exponent)

    # Normalize kernels
    weights_sum = weights.sum(dim=(-2, -1), keepdim=True)
    weights = weights / weights_sum
    print(weights.shape)
    return weights

def adaptive_gaussian_conv2d(img, k_size=3, blur_intensity=10):
    """
    Perform spatially adaptive Gaussian convolution.

    Args:
        img (torch.Tensor): Input image with shape (B, C, H, W)
        k_size (int): Kernel size (default: 3)

    Returns:
        torch.Tensor: Convolved output with shape (B, C, H, W)
    """
    img = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)
    B, C, H, W = img.shape


    # sigma = fitted_func_multi_channel_avg(img, blur_intensity)
    brightness = 0.8
    print("Image mean", img.mean())
    img[img>brightness] = img[img>brightness]*1000000
    sigma = torch.where(img>brightness, 1000, 0.001)
    print("Image mean after thresholding", img.mean())
    print(sigma.mean())
    print("Sigma shape", sigma.shape)

    # Generate Gaussian kernels
    kernels = gaussian_kernel(sigma, k_size).to(img.device)  # (B, C, H, W, k, k)
    print(kernels.shape)
    # print(kernels[0, 0, 0, 0])

    # Unfold input image into patches
    pad = k_size // 2
    unfolded = F.unfold(img, kernel_size=k_size, padding=pad)  # (B, C*k*k, H*W)
    unfolded = unfolded.view(B, C, k_size*k_size, H, W)      # (B, C, k*k, H, W)
    unfolded = unfolded.permute(0, 1, 3, 4, 2)               # (B, C, H, W, k*k)

    # Reshape kernels and multiply with patches
    kernels_flat = kernels.view(B, C, H, W, -1)              # (B, C, H, W, k*k)
    output = (unfolded * kernels_flat).sum(dim=-1)           # (B, C, H, W)
    
    # Clip output to [0, 1]
    output = torch.clamp(output, 0, 1)
    return output

In [96]:
torch.cuda.empty_cache()

In [97]:

# print(blurred_image.min(), blurred_image.max())
#torch.cuda.empty_cache()

In [98]:
def postprocess(image):
    rgb = demosaic_bilinear(unpack_raw(image))
    # White balance gray world
    # rgb = rgb / np.mean(rgb, axis=(0, 1), keepdims=True)
    # channel means
    mu = rgb.mean(axis=(0,1))              # [μ_R, μ_G, μ_B]
    mu_gray = mu.mean()                  # gray reference
    scales = mu_gray / mu                # [s_R, s_G, s_B]
    # apply scales
    rgb = rgb * scales[None,None,:]

    rgb = np.clip(rgb, 0, 1)
    # Apply gamma correction
    gamma = 2.2
    rgb = np.power(rgb, 1 / gamma)
    # Normalize to [0, 1]
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min())
    return rgb

In [None]:
from utils.raw_utils import demosaic_bilinear
from utils.raw_utils import unpack_raw
torch.cuda.empty_cache()


blurred_image = adaptive_gaussian_conv2d(image_org, k_size=31, blur_intensity=100)[0].permute(1,2,0).cpu().numpy()
blurred_image_rgb = postprocess(blurred_image)

image_org_rgb = postprocess(image_org)
image_diff_rgb = postprocess(image_diff)
x_start = 0
x_end = image_org_rgb.shape[1]
y_start = 0
y_end = image_org_rgb.shape[0]

plt.figure(figsize=(18, 6))

img_max = image_org_rgb[y_start:y_end, x_start:x_end ].max()
plt.subplot(1, 4, 1)
plt.imshow(image_org_rgb[y_start:y_end, x_start:x_end ]/img_max)
plt.title("Original Image")
plt.axis("off")


plt.subplot(1, 4, 2)
plt.imshow(np.clip(blurred_image_rgb/img_max, 0, 1))
plt.title("Blurred Image")
plt.axis("off")

plt.subplot(1, 4, 3)
plt.imshow(image_diff_rgb[y_start:y_end, x_start:x_end ]/img_max)
plt.title("Filtered Image")
plt.axis("off")


plt.subplot(1, 4, 4)
plt.imshow(image_org_rgb[y_start:y_end, x_start:x_end ] - blurred_image_rgb[y_start:y_end, x_start:x_end], cmap="coolwarm")
plt.title("Difference orig, blur")
plt.axis("off")

print(image_diff.shape)
plt.tight_layout()
plt.show()

SyntaxError: closing parenthesis ')' does not match opening parenthesis '[' (2215299239.py, line 37)

In [73]:
diff = ((blurred_image - image_org[1200:2100, 1500:2400] / 255.0)[:, :, 0])
plt.imshow(diff, cmap="coolwarm")
plt.colorbar()
diff

ValueError: operands could not be broadcast together with shapes (357,536,4) (0,0,4) 