### Comparison to Addition of Blue Tint and Noise

In this notebook, we compute evaluation metrics for the baseline methods, namely:

- no alteration of the digital image at all
- applying a blue filter
- applying Gaussian noise


In [None]:
# IPython magics: (re)load the autoreload extension and switch it to mode 2,
# so edits to imported project modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import autorootcwd

In [None]:
import os
from tqdm import tqdm
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset

from torchmetrics import MetricCollection
from torchmetrics.image import (
    StructuralSimilarityIndexMeasure as SSIM,
    PeakSignalNoiseRatio as PSNR,
)

from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from src.eval import PieAPP
from src.data.components import PairedDataset
import torchvision.transforms.v2 as T
from src.models import transforms as CT
from src.utils.utils import process_pair

In [None]:
# Metric collection reused by every baseline below; the keys are the names
# under which the per-image scores are stored and reported.
metrics = MetricCollection(
    dict(
        ssim=SSIM(),
        psnr=PSNR(),
        lpips=LPIPS(),
        pieapp=PieAPP(),
    )
)

In [None]:
# Paths
# RAW_DIR is the current working directory (presumably the project root set
# up by `autorootcwd` -- confirm); all datasets are looked up beneath "data".
RAW_DIR = os.getcwd()
DATA_DIR = os.path.join(RAW_DIR, "data")

## No Alteration

We compute metrics for the trivial baseline that predicts the film image to be the unaltered digital image, i.e. we compute the metrics directly over the (digital, film) pairs.

In [None]:
# The processed film and digital images live in sibling folders.
paired_root = os.path.join(DATA_DIR, "paired", "processed")
film_paired_dir = os.path.join(paired_root, "film")
digital_paired_dir = os.path.join(paired_root, "digital")

# Directories are passed as (film, digital); items are unpacked in that same
# order below.
digital_film_data = PairedDataset(image_dirs=(film_paired_dir, digital_paired_dir))

# Only the first six pairs are evaluated in this notebook.
digital_film_subset = Subset(digital_film_data, range(6))

# Keep the first pair around for the single-example cells that follow.
film_0, digital_0 = digital_film_data[0]

Let's look at the first example and compute its metrics

In [None]:
# Run the first pair through the shared preprocessing and show both shapes
# side by side as a quick sanity check.
film_0_processed, digital_0_processed = process_pair(film_0, digital_0)
print(*(tensor.shape for tensor in (film_0_processed, digital_0_processed)))

In [None]:
# Add a leading batch dimension to each image before calling the metric
# collection, then display the resulting scores.
film_0_batch = film_0_processed.unsqueeze(0)
digital_0_batch = digital_0_processed.unsqueeze(0)
metrics_0 = metrics(film_0_batch, digital_0_batch)
metrics_0

Let's now iterate over the images in the dataset (here, the six-image subset defined above).

In [None]:
# Baseline scores for the unaltered pairs: metric name -> list with one score
# per image pair.
all_metrics = {}
for film, digital in tqdm(digital_film_subset):
    # Preprocess the pair, then give each image a leading batch dimension.
    film, digital = (image.unsqueeze(0) for image in process_pair(film, digital))

    for metric in metrics:
        score = metrics[metric](film, digital)

        # Store plain Python numbers rather than tensors.
        if isinstance(score, torch.Tensor):
            score = score.item()

        all_metrics.setdefault(metric, []).append(score)

Let's now add a blue filter, Gaussian noise (grain), and a combination of the two, and see what the test metrics look like.

In [None]:
class Grain(nn.Module):
    """Add multi-scale Gaussian noise ("grain") to an RGB image tensor.

    For every (grain size, intensity) pair a single-channel Gaussian noise map
    is drawn at a reduced resolution, upsampled to the image size with
    nearest-neighbour interpolation, scaled by the intensity and accumulated.
    The summed grain is added to the image and the result is clamped to [0, 1].
    """

    def __init__(self, grain_sizes, intensities):
        """
        Initialize the Grain transform with grain sizes and their respective intensities.

        Args:
            grain_sizes (list of int): Grain sizes to apply. Each must be a positive
                                       integer; a size of ``s`` produces noise cells
                                       of roughly ``s x s`` pixels.
            intensities (list of float): Intensity for each grain size, in the same order.
                                         Should be between 0 (no grain) and 1 (full grain).

        Raises:
            AssertionError: If the two lists differ in length, an intensity lies
                            outside [0, 1], or a grain size is smaller than 1.
        """
        super().__init__()
        assert len(grain_sizes) == len(intensities), "Grain sizes and intensities lists must be of the same length"
        assert all(0 <= intensity <= 1 for intensity in intensities), "Intensities must be between 0 and 1"
        # A size of 0 would divide by zero in forward(); negative sizes give invalid shapes.
        assert all(grain_size >= 1 for grain_size in grain_sizes), "Grain sizes must be positive"
        self.grain_sizes = grain_sizes
        self.intensities = intensities

    def forward(self, img):
        """
        Apply the grain effect to the input image.

        Args:
            img (Tensor): The input image tensor to be transformed. Should be of shape
                          (B, C, H, W) or (C, H, W) with C == 3.

        Returns:
            Tensor: The image with grain added, clamped to [0, 1]. The result always
                    has a batch dimension, i.e. shape (B, C, H, W).
        """
        if img.ndimension() == 3:
            img = img.unsqueeze(0)  # Add batch dimension if not present

        assert img.ndimension() == 4 and img.size(1) == 3, "Input tensor must be of shape (B, C, H, W) with 3 channels"

        batch_size, _, height, width = img.size()
        grain = torch.zeros_like(img)

        for grain_size, intensity in zip(self.grain_sizes, self.intensities):
            # A grain coarser than the image would floor to an empty noise map,
            # which cannot be interpolated; keep at least one cell per axis.
            noise_height = max(1, height // grain_size)
            noise_width = max(1, width // grain_size)

            # One noise channel, broadcast over RGB so the grain is monochrome.
            grain_noise = torch.randn(batch_size, 1, noise_height, noise_width, device=img.device)
            grain_noise = F.interpolate(grain_noise, size=(height, width), mode='nearest')
            grain += grain_noise * intensity

        return torch.clamp(img + grain, 0.0, 1.0)


class BlueTint(nn.Module):
    """Tint an RGB image towards blue.

    Red and green are scaled by ``1 - blue_intensity`` (blue is left as is),
    then a constant offset is added to red and to green, each clamped to [0, 1].
    """

    def __init__(self, blue_intensity=0.5, red_shift=0.1, green_shift=0.1):
        """
        Store and validate the tint parameters.

        Args:
            blue_intensity (float): Strength of the tint, between 0 and 1. Red and green
                                    are multiplied by ``1 - blue_intensity``.
            red_shift (float): Offset added to the red channel after scaling, between 0 and 1.
            green_shift (float): Offset added to the green channel after scaling, between 0 and 1.
        """
        super().__init__()
        assert 0 <= blue_intensity <= 1, "Blue intensity must be between 0 and 1"
        assert 0 <= red_shift <= 1, "Red shift intensity must be between 0 and 1"
        assert 0 <= green_shift <= 1, "Green shift intensity must be between 0 and 1"

        self.blue_intensity = blue_intensity
        self.red_shift = red_shift
        self.green_shift = green_shift

    def forward(self, img):
        """
        Apply the blue tint to an image.

        Args:
            img (Tensor): Image tensor of shape (B, C, H, W) or (C, H, W) with C == 3.

        Returns:
            Tensor: A new, batched (B, C, H, W) tensor with the tint applied.
        """
        # Work on a batched view; a 3-D input gets a leading batch dimension.
        batched = img.unsqueeze(0) if img.ndimension() == 3 else img

        assert batched.ndimension() == 4 and batched.size(1) == 3, "Input tensor must be of shape (B, C, H, W) with 3 channels"

        # Per-channel scale: attenuate red and green, keep blue.
        attenuation = 1 - self.blue_intensity
        channel_scale = torch.tensor([attenuation, attenuation, 1]).view(1, 3, 1, 1).to(batched.device)

        # The multiplication allocates a new tensor, so the in-place channel
        # writes below never touch the caller's input.
        tinted = batched * channel_scale

        # Lift red (channel 0) and green (channel 1); blue is not clamped.
        for channel, shift in ((0, self.red_shift), (1, self.green_shift)):
            tinted[:, channel, :, :] = torch.clamp(tinted[:, channel, :, :] + shift, 0.0, 1.0)

        return tinted
    

class ColourFilter(nn.Module):
    """Scale the red, green and blue channels of an image by fixed factors."""

    def __init__(self, red_intensity=1.0, green_intensity=1.0, blue_intensity=1.0):
        """
        Store and validate the per-channel factors.

        Args:
            red_intensity (float): Multiplier for the red channel, between 0 (no red) and 1 (full red).
            green_intensity (float): Multiplier for the green channel, between 0 (no green) and 1 (full green).
            blue_intensity (float): Multiplier for the blue channel, between 0 (no blue) and 1 (full blue).
        """
        super().__init__()
        assert 0 <= red_intensity <= 1, "Red intensity must be between 0 and 1"
        assert 0 <= green_intensity <= 1, "Green intensity must be between 0 and 1"
        assert 0 <= blue_intensity <= 1, "Blue intensity must be between 0 and 1"

        self.red_intensity = red_intensity
        self.green_intensity = green_intensity
        self.blue_intensity = blue_intensity

    def forward(self, img):
        """
        Multiply each colour channel by its configured factor.

        Args:
            img (Tensor): Image tensor of shape (B, C, H, W) or (C, H, W) with C == 3.

        Returns:
            Tensor: A new, batched (B, C, H, W) tensor with the channels scaled.
        """
        # Work on a batched view; a 3-D input gets a leading batch dimension.
        batched = img.unsqueeze(0) if img.ndimension() == 3 else img

        assert batched.ndimension() == 4 and batched.size(1) == 3, "Input tensor must be of shape (B, C, H, W) with 3 channels"

        # Shape (1, 3, 1, 1) broadcasts the factors over batch and pixels.
        factors = [self.red_intensity, self.green_intensity, self.blue_intensity]
        channel_scale = torch.tensor(factors).view(1, 3, 1, 1).to(batched.device)

        return batched * channel_scale


In [None]:
# Grain at four scales; each coarser scale uses half the previous intensity.
grain = Grain(grain_sizes=[1, 2, 4, 8], intensities=[0.02, 0.01, 0.005, 0.0025])

# Two colour baselines: a plain per-channel filter and a blue tint.
blue_filter = ColourFilter(red_intensity=0.5, green_intensity=0.9, blue_intensity=1.0)
blue_tint = BlueTint(blue_intensity=0.8, red_shift=0.12, green_shift=0.12)

# Combined baselines: grain is applied first, then the colour transform.
noise_and_filter = T.Compose([grain, blue_filter])
noise_and_tint = T.Compose([grain, blue_tint])

In [None]:
# Grain only. The transform returns a batched tensor, so drop just the batch
# dimension (dim 0) before handing the image to CT.FromModelInput; a bare
# squeeze() would also remove any other size-1 dimension.
digital_0_noisy = grain(digital_0_processed)
CT.FromModelInput()(digital_0_noisy.squeeze(0))

In [None]:
# Colour filter only; drop the batch dimension the transform adds before
# passing the image to CT.FromModelInput.
digital_0_filter = blue_filter(digital_0_processed)
CT.FromModelInput()(digital_0_filter.squeeze(dim=0))

In [None]:
# Grain followed by the colour filter; drop the batch dimension the
# transforms add before passing the image to CT.FromModelInput.
digital_0_noise_and_filter = noise_and_filter(digital_0_processed)
CT.FromModelInput()(digital_0_noise_and_filter.squeeze(dim=0))

In [None]:
# Unaltered digital image, for visual comparison with the variants above.
CT.FromModelInput()(digital_0_processed.squeeze(dim=0))

In [None]:
# Film image of the same pair, for visual comparison.
CT.FromModelInput()(film_0_processed.squeeze(dim=0))

Let's now compute the metrics over these transformed versions.

In [None]:
# Per-metric score lists for each altered-digital baseline
# (metric name -> one score per image pair).
noisy_metrics = {}
filter_metrics = {}
noise_and_filter_metrics = {}

# The dataset was built from (film_dir, digital_dir) and is unpacked as
# (film, digital) everywhere else in this notebook. Unpacking in the same
# order here ensures the transforms are applied to the *digital* image and the
# result is scored against the film image.
for film, digital in tqdm(digital_film_subset):
    film, digital = process_pair(film, digital)
    film, digital = film.unsqueeze(0), digital.unsqueeze(0)

    # Build the three altered versions of the digital image once per pair.
    digital_noisy = grain(digital)
    digital_filtered = blue_filter(digital)
    digital_noise_and_filter = noise_and_filter(digital)

    # Pair each altered image with the dict that collects its scores.
    candidates = (
        (noisy_metrics, digital_noisy),
        (filter_metrics, digital_filtered),
        (noise_and_filter_metrics, digital_noise_and_filter),
    )

    for metric in metrics:
        for scores, candidate in candidates:
            # Same argument order as the unaltered baseline: film image first.
            score = metrics[metric](film, candidate)

            # Convert each score on its own rather than assuming all three
            # share the type of the first one.
            if isinstance(score, torch.Tensor):
                score = score.item()

            scores.setdefault(metric, []).append(score)

In [None]:
# Report, for every metric, the mean score of each altered baseline next to
# the mean of the unaltered baseline.
def _mean(scores):
    """Return the arithmetic mean of a list of scores."""
    return sum(scores) / len(scores)


for metric in metrics:
    print(f"{metric.upper()}:")
    print(f"  Noisy: {_mean(noisy_metrics[metric])}")
    print(f"  Filtered: {_mean(filter_metrics[metric])}")
    print(f"  Noise and Filtered: {_mean(noise_and_filter_metrics[metric])}")
    print(f" Baseline: {_mean(all_metrics[metric])}")
    print()

In [None]:
# Overlay, for each metric, the score histograms of all four baselines.
import matplotlib.pyplot as plt

# A 2x2 grid of axes, flattened so it can be indexed by metric position.
fig, axs = plt.subplots(2, 2, figsize=(12, 12))
axs = axs.ravel()
metric_names = list(noisy_metrics.keys())
for i, metric in enumerate(metric_names):
    ax = axs[i]
    # Drawing order is kept fixed so each baseline keeps its colour.
    for scores, label in (
        (noisy_metrics, 'Noisy'),
        (filter_metrics, 'Filtered'),
        (noise_and_filter_metrics, 'Noisy and Filtered'),
        (all_metrics, 'Original'),
    ):
        ax.hist(scores[metric], bins=20, alpha=0.5, label=label)
    ax.set_title(metric)
    ax.legend()
plt.tight_layout()
plt.show()