In [None]:
import cv2
import numpy as np
from skimage import io, color
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb



## Step 1: Load and Slice Images into Patches

In [None]:
def load_and_slice_image(image_path, patch_size=224):
    """Load an image and tile it into non-overlapping square patches.

    Args:
        image_path: Path handed to ``load_img``.
        patch_size: Side length, in pixels, of every returned patch.

    Returns:
        A list of array slices of shape ``(patch_size, patch_size, channels)``
        in row-major order. Tiles that would run past the right or bottom
        edge are dropped, so the list is empty when the image is smaller
        than one patch.
    """
    pixels = img_to_array(load_img(image_path))
    height, width, _channels = pixels.shape

    # Cut every candidate tile first, then keep only the complete ones.
    candidates = [
        pixels[top:top + patch_size, left:left + patch_size]
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
    return [
        tile
        for tile in candidates
        if tile.shape[0] == patch_size and tile.shape[1] == patch_size
    ]


## Step 2: Convert Patches to Lab Color Space

In [None]:
def convert_to_lab(patches):
    """Convert each RGB patch to the CIE Lab colour space.

    Args:
        patches: Iterable of RGB arrays; each is divided by 255 before the
            conversion, so 0-255 valued pixels are assumed (TODO confirm
            against the loader).

    Returns:
        A list with one Lab array per input patch, in the same order.
    """
    return [color.rgb2lab(rgb_patch / 255.0) for rgb_patch in patches]

def prepare_data_for_training(lab_patches):
    """Split Lab patches into a normalised L input and ab target.

    Args:
        lab_patches: Iterable of ``(H, W, 3)`` Lab arrays.

    Returns:
        Tuple ``(L, ab)``: ``L`` is ``(L - 50) / 50`` with a trailing
        channel axis added, ``ab`` is the last two channels divided by 128.
    """
    lightness = np.array([patch[:, :, 0] for patch in lab_patches])
    chroma = np.array([patch[:, :, 1:] for patch in lab_patches])

    # L spans 0..100 in Lab, so this centres it on zero with unit half-range.
    lightness = (lightness - 50) / 50.0
    # a/b are scaled by 128 to land roughly in [-1, 1].
    chroma = chroma / 128.0

    return lightness[..., np.newaxis], chroma


## Step 3: Combine and Load Data

In [None]:
def load_and_preprocess_image(image_path):
    """Run the full preprocessing chain for one image file.

    The image is sliced into patches (default patch size), converted to Lab
    and split into normalised arrays.

    Returns:
        The ``(L, ab)`` tuple produced by ``prepare_data_for_training``.
    """
    return prepare_data_for_training(
        convert_to_lab(load_and_slice_image(image_path))
    )

# Example: preprocess a single image and report the resulting array shapes.
# Replace the placeholder path with a real file before running this cell.
image_path = 'path_to_your_image.jpg'
L, ab = load_and_preprocess_image(image_path)

for label, array in (("L shape:", L), ("ab shape:", ab)):
    print(label, array.shape)


## Step 4: Load Data in Batches

In [None]:
import os
from tensorflow.keras.utils import Sequence

class ImageDataGenerator(Sequence):
    """Keras ``Sequence`` yielding normalised Lab patches, one group of images per batch.

    ``batch_size`` counts *images*, not patches: every image in a batch is
    sliced into ``patch_size`` x ``patch_size`` tiles, so the first dimension
    of the returned arrays is the total number of tiles in those images.

    Args:
        image_paths: Sequence of image file paths (must support slicing).
        batch_size: Number of images loaded per batch.
        patch_size: Side length of the square patches.
        **kwargs: Forwarded to the ``Sequence`` base initialiser.
    """

    def __init__(self, image_paths, batch_size=32, patch_size=224, **kwargs):
        # Run the base initialiser; newer Keras versions warn when it is skipped.
        super().__init__(**kwargs)
        self.image_paths = image_paths
        self.batch_size = batch_size
        self.patch_size = patch_size

    def __len__(self):
        """Number of batches, counting a final partial batch."""
        return int(np.ceil(len(self.image_paths) / float(self.batch_size)))

    def __getitem__(self, idx):
        """Return ``(L, ab)`` arrays for the images of batch ``idx``.

        Images that do not contain a single full patch are skipped. If no
        image in the batch yields a patch, empty arrays with the expected
        trailing dimensions are returned.
        """
        start = idx * self.batch_size
        batch_paths = self.image_paths[start:start + self.batch_size]
        L_batch = []
        ab_batch = []

        for image_path in batch_paths:
            patches = load_and_slice_image(image_path, patch_size=self.patch_size)
            if not patches:
                # An image smaller than one patch would yield arrays of the
                # wrong rank and make the concatenation below fail.
                continue
            L, ab = prepare_data_for_training(convert_to_lab(patches))
            L_batch.append(L)
            ab_batch.append(ab)

        if not L_batch:
            size = self.patch_size
            return np.zeros((0, size, size, 1)), np.zeros((0, size, size, 2))

        return np.concatenate(L_batch, axis=0), np.concatenate(ab_batch, axis=0)

# Example usage
# Replace the placeholder with a real directory before running this cell.
image_directory = 'path_to_your_image_directory'

# Keep regular files only (a sub-directory would make the image loader fail)
# and sort them so the batch order is the same on every run.
image_paths = sorted(
    os.path.join(image_directory, fname)
    for fname in os.listdir(image_directory)
    if os.path.isfile(os.path.join(image_directory, fname))
)

data_gen = ImageDataGenerator(image_paths, batch_size=32)

# Fetch the first batch and report the array shapes
L_batch, ab_batch = data_gen[0]
print("L_batch shape:", L_batch.shape)
print("ab_batch shape:", ab_batch.shape)


## Encoder

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights, DenseNet121_Weights

class EnsembleEncoder(nn.Module):
    """Encoder that fuses multi-scale features from ResNet50 and DenseNet121.

    Both backbones are built with torchvision's default pre-trained weights.
    Four intermediate activations are tapped from each backbone, projected by
    1x1 convolutions to 128/256/512/1024 channels, concatenated level by
    level, and reduced again by a conv-BN-ReLU fusion block.
    """

    # Iteration positions of the backbone stages whose outputs are tapped.
    _RESNET_TAPS = (4, 5, 6, 7)
    _DENSENET_TAPS = (4, 6, 8, 11)

    def __init__(self):
        super().__init__()

        # ResNet50 without its last two children (the pooling/classifier head).
        resnet_backbone = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.resnet50 = nn.Sequential(*list(resnet_backbone.children())[:-2])

        # DenseNet121 with its classifier replaced by a pass-through.
        densenet_backbone = models.densenet121(weights=DenseNet121_Weights.DEFAULT)
        densenet_backbone.classifier = nn.Identity()
        self.densenet121 = densenet_backbone

        # 1x1 projections applied to the tapped activations of each backbone.
        self.conv1x1_resnet50 = nn.ModuleList([
            nn.Conv2d(c_in, c_out, kernel_size=1)
            for c_in, c_out in ((256, 128), (512, 256), (1024, 512), (2048, 1024))
        ])
        self.conv1x1_densenet121 = nn.ModuleList([
            nn.Conv2d(c_in, c_out, kernel_size=1)
            for c_in, c_out in ((256, 128), (512, 256), (1024, 512), (1024, 1024))
        ])

        # One fusion block per level; both inputs share the same width.
        self.fusion_blocks = nn.ModuleList([
            self.fusion_block(width, width) for width in (128, 256, 512, 1024)
        ])

    def fusion_block(self, in_channels_resnet, in_channels_densenet):
        """Build a 1x1 conv + BatchNorm + ReLU block.

        It maps the concatenated channels back to ``in_channels_resnet``.
        """
        merged_channels = in_channels_resnet + in_channels_densenet
        return nn.Sequential(
            nn.Conv2d(merged_channels, in_channels_resnet, kernel_size=1),
            nn.BatchNorm2d(in_channels_resnet),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Return the four fused feature maps, ordered from the first tap to the last.

        ``x`` is fed unchanged to both backbones. The original notes describe
        it as a grayscale image repeated to 3 channels.
        """
        # ResNet50 branch: run stage by stage, projecting the tapped outputs.
        resnet_features = []
        hidden = x
        for position, stage in enumerate(self.resnet50.children()):
            hidden = stage(hidden)
            if position in self._RESNET_TAPS:
                projection = self.conv1x1_resnet50[position - self._RESNET_TAPS[0]]
                resnet_features.append(projection(hidden))

        # DenseNet121 branch: same idea over the `features` sub-module.
        densenet_features = []
        hidden = x
        for position, stage in enumerate(self.densenet121.features):
            hidden = stage(hidden)
            if position in self._DENSENET_TAPS:
                projection = self.conv1x1_densenet121[self._DENSENET_TAPS.index(position)]
                densenet_features.append(projection(hidden))

        # Concatenate matching levels along the channel axis and fuse them.
        fused_features = []
        for level, fuse in enumerate(self.fusion_blocks):
            stacked = torch.cat((resnet_features[level], densenet_features[level]), dim=1)
            fused_features.append(fuse(stacked))

        return fused_features


## Decoder

In [None]:
import torch
import torch.nn as nn

class Decoder(nn.Module):
    """U-Net style decoder that turns four fused feature maps into 2 ab channels.

    Every stage is conv -> (BatchNorm) -> activation -> 2x bilinear upsample.
    Stages 2-4 first concatenate the matching skip feature map along the
    channel axis. The last stage ends in ``Tanh`` so that predictions lie in
    [-1, 1]; a ``ReLU`` there would clamp all negative ab values to zero.
    ``Tanh`` has no parameters, so existing state_dict checkpoints still load.
    """

    @staticmethod
    def _up_block(in_channels, out_channels):
        """3x3 conv + BatchNorm + ReLU followed by a 2x bilinear upsample."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        )

    def __init__(self):
        super(Decoder, self).__init__()

        # Spatial sizes in the comments are the original notes for a 224x224 input.
        # Block 1: input from fusion block 4 (7x7 -> 14x14)
        self.decode1 = self._up_block(1024, 512)
        # Block 2: block 1 output + fusion block 3 (14x14 -> 28x28)
        self.decode2 = self._up_block(512 + 512, 256)
        # Block 3: block 2 output + fusion block 2 (28x28 -> 56x56)
        self.decode3 = self._up_block(256 + 256, 128)
        # Block 4: block 3 output + fusion block 1 (56x56 -> 112x112)
        self.decode4 = self._up_block(128 + 128, 64)

        # Final block: reduce to the 2 ab channels (112x112 -> 224x224)
        self.decode5 = nn.Sequential(
            nn.Conv2d(64, 2, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # bounded to [-1, 1], matching the normalised ab targets
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        )

    def forward(self, features_7x7, features_14x14, features_28x28, features_56x56):
        """Decode from the coarsest feature map, adding one skip connection per stage.

        Args:
            features_7x7: 1024-channel map (coarsest level).
            features_14x14: 512-channel map at twice that resolution.
            features_28x28: 256-channel map at four times that resolution.
            features_56x56: 128-channel map at eight times that resolution.

        Returns:
            Tensor with 2 channels at 32x the resolution of ``features_7x7``.
        """
        x = self.decode1(features_7x7)
        x = self.decode2(torch.cat([x, features_14x14], dim=1))
        x = self.decode3(torch.cat([x, features_28x28], dim=1))
        x = self.decode4(torch.cat([x, features_56x56], dim=1))
        return self.decode5(x)


## Training

In [None]:
%%time
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Freeze only the pre-trained backbones. Freezing the whole encoder would also
# freeze its randomly initialised 1x1 projection and fusion layers, which then
# could never learn.
# NOTE(review): assumes `encoder` exposes `resnet50` / `densenet121` as
# EnsembleEncoder does; any other encoder is frozen entirely, as before.
pretrained_backbones = [
    backbone
    for backbone in (getattr(encoder, "resnet50", None), getattr(encoder, "densenet121", None))
    if backbone is not None
]
for frozen_module in (pretrained_backbones or [encoder]):
    for param in frozen_module.parameters():
        param.requires_grad = False

# Define the model, loss function, and optimizer
model = ColorizationModel(encoder, decoder).to(device)
criterion = nn.MSELoss()
# Give the optimizer only the parameters that will actually receive gradients.
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.001)

# Training loop
num_epochs = 10
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # Progress bar for training
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Training)")

    for i, (L_batch, ab_batch) in enumerate(train_bar):
        L, ab = L_batch.to(device), ab_batch.to(device)
        L = L.repeat(1, 3, 1, 1)  # Repeat grayscale image to 3 channels

        # Forward pass
        optimizer.zero_grad()
        output = model(L)
        loss = criterion(output, ab)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Show the running mean loss of this epoch
        train_bar.set_postfix(loss=f"{running_loss/(i+1):.4f}")

    # Validation phase
    model.eval()
    val_loss = 0.0

    # Progress bar for validation
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Validation)")

    with torch.no_grad():
        for i, (L_batch, ab_batch) in enumerate(val_bar):
            L, ab = L_batch.to(device), ab_batch.to(device)
            L = L.repeat(1, 3, 1, 1)  # Repeat grayscale image to 3 channels
            output = model(L)
            loss = criterion(output, ab)
            val_loss += loss.item()

            # Show the running mean validation loss
            val_bar.set_postfix(loss=f"{val_loss/(i+1):.4f}")

    # Print statistics and save the best model
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {running_loss/len(train_loader):.4f}, Validation Loss: {val_loss/len(val_loader):.4f}")

    # The summed loss is comparable across epochs because val_loader is the same each time.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_colorization_model.pth')

print("Training complete.")

In [None]:
# Serialise the whole `model` object (not just its state_dict) to disk.
# NOTE(review): loading a whole-model file typically requires the model's class
# definitions to be importable at load time — confirm this is intended alongside
# the state_dict checkpoint written by the training loop.
torch.save(model, 'color_model.pth')

## Prediction

In [None]:
from torch.utils.data import Subset
# Hold out dataset positions 17000..17999 (1000 samples) as the test subset.
test_indices = [*range(17000, 18000)]
test_dataset = Subset(dataset, test_indices)

# Iterate the subset in order, 8 samples per batch.
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [None]:
# Download the `requiemonk/sentinel12-image-pairs-segregated-by-terrain` dataset archive with the Kaggle CLI.
# NOTE(review): presumably needs Kaggle API credentials configured for the kernel's user — confirm.
!kaggle datasets download -d requiemonk/sentinel12-image-pairs-segregated-by-terrain

In [None]:
# Extract the downloaded archive into the working directory, then delete the zip file.
!unzip sentinel12-image-pairs-segregated-by-terrain.zip
!rm sentinel12-image-pairs-segregated-by-terrain.zip

In [None]:
import os
import shutil
# Collect the .png files of every terrain folder under `root_dir` into two
# sorted lists: `sar` (first sub-folder) and `opt` (second sub-folder).
opt = []
sar = []
root_dir = './v_2'
for terrain in sorted(os.listdir(root_dir)):
    terrain_path = os.path.join(root_dir, terrain)
    if not os.path.isdir(terrain_path):
        continue  # ignore stray files next to the terrain folders

    # os.listdir() returns entries in arbitrary order; sort before unpacking so
    # the same sub-folder is always treated as s1 and the other as s2.
    # NOTE(review): assumes exactly two sub-folders whose names sort as
    # (SAR, optical), e.g. "s1" < "s2" — confirm against the extracted dataset.
    s1, s2 = sorted(os.listdir(terrain_path))

    for filename in os.listdir(os.path.join(terrain_path, s1)):
        if filename.endswith('.png'):
            sar.append(os.path.join(terrain_path, s1, filename))
    for filename in os.listdir(os.path.join(terrain_path, s2)):
        if filename.endswith('.png'):
            opt.append(os.path.join(terrain_path, s2, filename))
opt = sorted(opt)
sar = sorted(sar)
print(len(opt), len(sar))

In [None]:
import os
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from skimage.color import rgb2lab

class ColorizationDatasetNew(Dataset):
    """Dataset returning ``(grayscale input, ab target)`` pairs read from image files.

    Args:
        image_paths: Sequence of image file paths.
        transform: Stored for API compatibility; not applied by ``__getitem__``.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``gray_img``: float32 tensor of shape ``(1, 224, 224)`` scaled to
            [-1, 1]; ``ab_channels``: second value returned by ``rgb_to_lab``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image.
        """
        image_path = self.image_paths[idx]
        color_img = cv2.imread(image_path)
        if color_img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {image_path}")

        gray_img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE).astype("float32")
        gray_img = cv2.resize(gray_img, (224, 224))
        # Map 0..255 to [-1, 1]. (`/ 255 - 1` would only cover [-1, 0].)
        gray_img = gray_img / 127.5 - 1.0
        # Add the channel axis; ToTensor turns (H, W, 1) into (1, H, W)
        gray_img = gray_img.reshape(224, 224, 1)
        gray_img = transforms.ToTensor()(gray_img)

        color_img = cv2.resize(color_img, (224, 224))

        # NOTE(review): cv2.imread yields BGR channel order; `rgb_to_lab` is
        # defined elsewhere — confirm whether it expects RGB input.
        _, ab_channels = rgb_to_lab(color_img)

        return gray_img, ab_channels


# Transform pipeline handed to the dataset (currently only a tensor conversion).
transform = transforms.Compose([transforms.ToTensor()])

# Dataset over the first 4000 paths in `opt`.
dataset_new = ColorizationDatasetNew(image_paths=opt[:4000], transform=transform)

# Ordered loader: 32 samples per batch, two worker processes.
test_set = DataLoader(
    dataset=dataset_new,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Paths get their own names: rebinding `sar` / `opt` to strings would clobber
# the path lists built earlier in the notebook.
gray_path = 'despeckled_image.png'
sar_path = 'img.png'
opt_path = 'img_opt.png'


def _read_resized(path, size=(224, 224)):
    """Read an image with OpenCV and resize it; fail loudly if it cannot be read."""
    img = cv2.imread(path)
    if img is None:
        # cv2.imread returns None instead of raising for missing/unreadable files
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.resize(img, size)


gray_img = _read_resized(gray_path)
sar_img = _read_resized(sar_path)
opt_img = _read_resized(opt_path)

# convert to lab (rgb_to_lab is defined elsewhere in the notebook)
L_channel, ab_channels = rgb_to_lab(gray_img)
L_channel_sar, ab_channels_sar = rgb_to_lab(sar_img)
L_channel_opt, ab_channels_opt = rgb_to_lab(opt_img)

# add batch dim
L_channel = L_channel.unsqueeze(0)
L_channel_sar = L_channel_sar.unsqueeze(0)
L_channel_opt = L_channel_opt.unsqueeze(0)
ab_channels_opt = ab_channels_opt.unsqueeze(0)
ab_channels_sar = ab_channels_sar.unsqueeze(0)

# repeat L channel to 3 channels, the input format the model is fed elsewhere
L_channel = L_channel.repeat(1, 3, 1, 1)
L_channel_sar = L_channel_sar.repeat(1, 3, 1, 1)
L_channel_opt = L_channel_opt.repeat(1, 3, 1, 1)

# Load the trained weights saved by the training loop
model = ColorizationModel(encoder, decoder)
model.load_state_dict(torch.load('best_colorization_model.pth', map_location=torch.device('cpu')))

# Set the model to evaluation mode
model.eval()

# `encoder` / `decoder` may still live on the GPU from training, so run the
# input on whatever device the model's parameters are on and bring the
# prediction back to the CPU for the post-processing cells.
model_device = next(model.parameters()).device
with torch.no_grad():  # Disable gradient calculation
    predicted_ab = model(L_channel.to(model_device)).cpu()

In [None]:
!pip install PyWavelets

In [None]:
from skimage import img_as_ubyte, io, img_as_float
from skimage.restoration import (
    denoise_nl_means,
    denoise_tv_chambolle,
    denoise_wavelet,
    denoise_bilateral
)
from skimage.restoration import estimate_sigma
import numpy as np
from skimage.util import random_noise
import cv2

def denoise_image(image):
    """Despeckle a single-channel image with NLM -> total variation -> bilateral.

    Only this one chain is computed. Earlier versions also evaluated several
    other filter combinations whose results were discarded, re-running the
    expensive non-local-means step each time.

    Args:
        image: 2-D image array; converted to float with ``img_as_float``.

    Returns:
        The float image returned by ``denoise_bilateral``.
    """
    image = img_as_float(image)

    # Noise estimate that sets the NLM filtering strength
    sigma_est = np.mean(estimate_sigma(image))

    # Stage 1: non-local means (channel_axis=None treats the input as grayscale)
    nlm = denoise_nl_means(
        image,
        h=1.0 * sigma_est,
        fast_mode=True,
        patch_size=7,
        patch_distance=11,
        channel_axis=None,
    )

    # Stage 2: total-variation smoothing of the NLM result
    nlm_tv = denoise_tv_chambolle(nlm, weight=0.1)

    # Stage 3: bilateral filter as the final pass
    return denoise_bilateral(nlm_tv, sigma_color=0.05, sigma_spatial=15)

# Example: read 'img.png' as a grayscale float image and despeckle it.
image = io.imread(fname='img.png', as_gray=True)
denoised_image = denoise_image(image=image)

In [None]:
# Convert the float result to 8-bit under a new name: rescaling
# `denoised_image` in place would multiply it by 255 again on every re-run.
# Clip first so out-of-range values cannot wrap around in the uint8 cast.
# NOTE(review): assumes `denoised_image` is a float image in [0, 1] — confirm.
despeckled_uint8 = np.clip(denoised_image * 255, 0, 255).astype(np.uint8)
# save the image
cv2.imwrite('despeckled_image.png', despeckled_uint8)

In [None]:
# Take the first batch of the test loader and move both tensors to `device`.
dataiter = iter(test_loader)
L_batch, ab_batch = (tensor.to(device) for tensor in next(dataiter))

# The model is fed a 3-channel input: repeat the single L channel three times.
L_batch = L_batch.repeat(1, 3, 1, 1)

# Predict the ab channels without tracking gradients.
with torch.no_grad():
    predicted_ab = model(L_batch)

In [None]:
# Take one of the three identical L channels, restore the channel axis, and
# undo the normalisation: (x + 1) * 50 maps [-1, 1] back to [0, 100].
L_batch = (L_channel[:, 0, :, :].unsqueeze(1) + 1) * 50

# Scale predicted and reference ab channels by 110 (assumed inverse of their normalisation).
predicted_ab = predicted_ab * 110
ab_batch = ab_channels_opt * 110

# Stack L with each ab pair along the channel axis and export as NumPy arrays
# of shape (batch, 3, height, width).
real_lab = torch.cat([L_batch, ab_batch], dim=1).cpu().numpy()
predicted_lab = torch.cat([L_batch, predicted_ab], dim=1).cpu().numpy()

In [None]:
# Display channel index 1 of the first sample in `predicted_ab` for a quick visual check.
predicted_ab[0][1]

In [None]:
import numpy as np
from skimage.color import lab2rgb
import matplotlib.pyplot as plt
import cv2


# Show each reference / prediction pair side by side.
for lab_image, real_img in zip(predicted_lab, real_lab):
    # skimage expects channels last: (3, H, W) -> (H, W, 3)
    rgb_image = lab2rgb(lab_image.transpose(1, 2, 0))
    real_rgb = lab2rgb(real_img.transpose(1, 2, 0))

    panels = ((real_rgb, 'Real Image'), (rgb_image, 'Predicted Image'))
    for slot, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(picture)
        plt.axis('off')
        plt.title(caption)
    plt.show()

In [None]:
import matplotlib.pyplot as plt

# Convert Lab to RGB with OpenCV.
# The arrays hold real Lab values (L in 0..100, signed a/b). OpenCV's 8-bit Lab
# layout is different (L scaled to 0..255, a/b offset by 128), so casting to
# uint8 would wrap negative a/b values. Passing float32 keeps Lab as-is and
# yields RGB in [0, 1].
predicted_rgb_images = []
for lab_image in predicted_lab:
    lab_image_np = np.ascontiguousarray(lab_image.transpose(1, 2, 0), dtype=np.float32)
    rgb_image_np = np.clip(cv2.cvtColor(lab_image_np, cv2.COLOR_Lab2RGB), 0.0, 1.0)
    predicted_rgb_images.append(rgb_image_np)

# Display up to four images; panels without an image stay blank.
fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(20, 5))
for ax in axes:
    ax.axis('off')
for ax, img in zip(axes, predicted_rgb_images):
    ax.imshow(img)

plt.show()

## Evaluation Metrics:

Importing Libraries:

In [None]:
# Install the necessary libraries if not already installed
pip install scikit-image lpips opencv-python
import os
import numpy as np
import cv2  # for reading images
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import lpips
import torch

# Initialize LPIPS model
# Built once at module level; calculate_lpips() below calls this object.
lpips_fn = lpips.LPIPS(net='alex')  # Can also use 'vgg' or 'squeeze'

Loading DIV2K images:

In [None]:
# Dataset locations — placeholders, replace with real directories.
HR_PATH = 'path_to_DIV2K_HR'  # high-resolution reference images
SR_PATH = 'path_to_SR_images'  # super-resolved / reconstructed images

# Sorted .png / .jpg file names of each directory; pairs are later matched by position.
hr_images = sorted(f for f in os.listdir(HR_PATH) if f.endswith(('.png', '.jpg')))
sr_images = sorted(f for f in os.listdir(SR_PATH) if f.endswith(('.png', '.jpg')))

Calculate PSNR and SSIM:

In [None]:
def calculate_psnr_ssim(hr_img, sr_img, data_range=None):
    """Compute PSNR and SSIM between a reference image and a reconstruction.

    Args:
        hr_img: Reference image, ``(H, W)`` or ``(H, W, C)``.
        sr_img: Image to evaluate, same shape as ``hr_img``.
        data_range: Distance between the minimum and maximum possible pixel
            value. When omitted it is the full range of the dtype for integer
            images (255 for uint8) and ``hr_img.max() - hr_img.min()`` otherwise.

    Returns:
        Tuple ``(psnr_value, ssim_value)``.
    """
    if data_range is None:
        if np.issubdtype(hr_img.dtype, np.integer):
            # Use the dtype's range so PSNR does not depend on image content
            info = np.iinfo(hr_img.dtype)
            data_range = info.max - info.min
        else:
            data_range = hr_img.max() - hr_img.min()

    # PSNR
    psnr_value = psnr(hr_img, sr_img, data_range=data_range)

    # SSIM: `channel_axis` replaces the removed `multichannel` flag; only treat
    # the last axis as channels when the image actually has one.
    channel_axis = -1 if hr_img.ndim == 3 else None
    ssim_value = ssim(hr_img, sr_img, channel_axis=channel_axis, data_range=data_range)

    return psnr_value, ssim_value

Calculate LPIPS:

In [None]:
def calculate_lpips(hr_img, sr_img):
    """Compute the LPIPS perceptual distance between two RGB images.

    Args:
        hr_img: Reference image, ``(H, W, 3)`` array with values in [0, 255].
        sr_img: Image to evaluate, same shape and value range.

    Returns:
        The LPIPS distance as a Python float (lower means more similar).
    """
    # (H, W, C) in [0, 255] -> (1, C, H, W) in [0, 1]
    hr_tensor = torch.tensor(hr_img).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    sr_tensor = torch.tensor(sr_img).permute(2, 0, 1).unsqueeze(0).float() / 255.0

    # LPIPS expects inputs in [-1, 1]; normalize=True makes it rescale [0, 1] inputs.
    # no_grad: this is a pure metric evaluation.
    with torch.no_grad():
        lpips_value = lpips_fn(hr_tensor, sr_tensor, normalize=True)

    return lpips_value.item()

Full Process to compute metrics:

In [None]:
# Loop through each pair of images
psnr_scores = []
ssim_scores = []
lpips_scores = []

# zip() stops at the shorter list, so make a count mismatch visible.
if len(hr_images) != len(sr_images):
    print(f"Warning: {len(hr_images)} HR images vs {len(sr_images)} SR images; "
          "only the first min(count) pairs are evaluated.")

for hr_img_file, sr_img_file in zip(hr_images, sr_images):
    hr_file_path = os.path.join(HR_PATH, hr_img_file)
    sr_file_path = os.path.join(SR_PATH, sr_img_file)

    # Load images using OpenCV (returns BGR, or None when the file cannot be read)
    hr_img = cv2.imread(hr_file_path)
    sr_img = cv2.imread(sr_file_path)
    if hr_img is None:
        raise FileNotFoundError(f"Could not read image: {hr_file_path}")
    if sr_img is None:
        raise FileNotFoundError(f"Could not read image: {sr_file_path}")

    # BGR -> RGB
    hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB)
    sr_img = cv2.cvtColor(sr_img, cv2.COLOR_BGR2RGB)

    # Resize if necessary to ensure both images have the same dimensions
    if hr_img.shape != sr_img.shape:
        sr_img = cv2.resize(sr_img, (hr_img.shape[1], hr_img.shape[0]))

    # Calculate PSNR and SSIM
    psnr_value, ssim_value = calculate_psnr_ssim(hr_img, sr_img)

    # Calculate LPIPS
    lpips_value = calculate_lpips(hr_img, sr_img)

    # Store results
    psnr_scores.append(psnr_value)
    ssim_scores.append(ssim_value)
    lpips_scores.append(lpips_value)

# Calculate average scores
avg_psnr = np.mean(psnr_scores)
avg_ssim = np.mean(ssim_scores)
avg_lpips = np.mean(lpips_scores)

print(f"Average PSNR: {avg_psnr:.4f}")
print(f"Average SSIM: {avg_ssim:.4f}")
print(f"Average LPIPS: {avg_lpips:.4f}")

Load CIFAR-10 dataset:

In [None]:
# Install required libraries if not already installed
pip install torch torchvision
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
# Transform pipeline: tensor conversion only.
transform = transforms.Compose([transforms.ToTensor()])

# CIFAR-10 test split, downloaded to ./data when missing.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Ordered loader: 100 images per batch, two worker processes.
testloader = DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# Class names of CIFAR-10
classes = testset.classes

Compute MAE for Images (Pixel-wise MAE):

In [None]:
def compute_mae_images(gt_img, pred_img):
    """Return the pixel-wise mean absolute error between two image tensors.

    Args:
        gt_img: Ground-truth tensor.
        pred_img: Predicted tensor, broadcast-compatible with ``gt_img``.

    Returns:
        The mean of ``|gt_img - pred_img|`` as a Python float. Both inputs
        are cast to float first; they are not moved between devices.
    """
    difference = gt_img.float() - pred_img.float()
    return difference.abs().mean().item()

Example usage (MAE reported for each test batch):

In [None]:
# Example usage (assuming ground truth and predicted images are in testloader)
# NOTE(review): assumes `model` returns images with the same shape as its
# input — confirm before using this with a model that has a different output.
model.eval()  # use inference behaviour for layers such as BatchNorm
with torch.no_grad():  # metric evaluation only: do not build autograd graphs
    for i, data in enumerate(testloader):
        images, labels = data  # Ground truth images and labels
        pred_images = model(images)  # Assuming a model generates reconstructed images

        # Compute MAE for the batch
        batch_mae = compute_mae_images(images, pred_images)

        print(f"Batch {i+1} MAE: {batch_mae}")

Compute MAE for Labels (for classifications):

In [None]:
def compute_mae_labels(gt_labels, pred_labels):
    """Return the mean absolute difference between true and predicted labels.

    Args:
        gt_labels: Tensor of ground-truth labels.
        pred_labels: Tensor of predicted labels; cast to float before the
            subtraction.

    Returns:
        The mean of ``|gt_labels - pred_labels|`` as a Python float.
    """
    label_gap = gt_labels - pred_labels.float()
    return label_gap.abs().mean().item()

Example Usage for Label Predictions:

In [None]:
# Assuming you have a trained classification model
correct_labels = []
predicted_labels = []

# Loop through the test dataset
model.eval()  # use inference behaviour for layers such as BatchNorm / Dropout
with torch.no_grad():  # evaluation only: do not build autograd graphs
    for images, labels in testloader:
        # Predict using the model (assuming output is logits)
        outputs = model(images)

        # Get the predicted class (index of the highest score per sample)
        _, predicted = torch.max(outputs, 1)

        # Store labels and predictions
        correct_labels.append(labels)
        predicted_labels.append(predicted)

# Convert lists to tensors
correct_labels = torch.cat(correct_labels)
predicted_labels = torch.cat(predicted_labels)

# Compute MAE for labels
mae_labels = compute_mae_labels(correct_labels, predicted_labels)
print(f"MAE for labels: {mae_labels}")