In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Convolutional encoder mapping a 3-channel image to a 256-value embedding.

    Four 3x3 convolutions with stride 2 and padding 1 each halve the spatial
    size of an even-sized input (512 -> 256 -> 128 -> 64 -> 32), widening the
    channels 3 -> 16 -> 32 -> 64 -> 128. The flattened feature map is then
    squeezed by two linear layers (131072 -> 1024 -> 256). Because ``fc1`` is
    sized for a 128 x 32 x 32 feature map, inputs must be 512 x 512.
    Every layer is followed by ``tanh``, so the embedding lies in (-1, 1).
    """

    def __init__(self):
        super().__init__()
        # Shared geometry of every down-sampling convolution.
        halving = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(3, 16, **halving)    # 512 -> 256
        self.conv2 = nn.Conv2d(16, 32, **halving)   # 256 -> 128
        self.conv3 = nn.Conv2d(32, 64, **halving)   # 128 -> 64
        self.conv4 = nn.Conv2d(64, 128, **halving)  # 64 -> 32
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(128 * 32 * 32, 1024)  # flattened feature map -> 1024
        self.fc2 = nn.Linear(1024, 256)            # 1024 -> 256 embedding

    def forward(self, x):
        """Encode a batch of images.

        :param x: Image batch; ``fc1`` requires it to be 3 x 512 x 512 per sample.
        :return: Tensor of shape (batch, 256) with values in (-1, 1).
        """
        features = x
        # Convolutional stack: each stage halves height and width.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = torch.tanh(conv(features))
        hidden = torch.tanh(self.fc1(self.flatten(features)))
        return torch.tanh(self.fc2(hidden))  # Final compressed representation


class Decoder(nn.Module):
    """Decoder that reconstructs an image at three resolutions from a 256-value embedding.

    Two linear layers expand the embedding (256 -> 1024 -> 131072), which is
    reshaped to a 128 x 32 x 32 feature map. Four transposed convolutions
    (kernel 3, stride 2, padding 1, output_padding 1) each double the spatial
    size: 32 -> 64 -> 128 -> 256 -> 512. Besides the final 512 x 512 image,
    1x1 convolution heads project the intermediate 128 x 128 and 256 x 256
    feature maps to 3 channels.

    All three outputs pass through exactly one ``tanh`` and therefore span
    (-1, 1), matching images normalised to [-1, 1].
    """

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(256, 1024)
        self.fc2 = nn.Linear(1024, 128 * 32 * 32)
        self.unflatten = nn.Unflatten(1, (128, 32, 32))
        self.convt1 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)  # 32 -> 64
        self.convt2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)  # 64 -> 128
        self.convt3 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1)  # 128 -> 256
        self.convt4 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=3, stride=2, padding=1, output_padding=1)  # 256 -> 512

        # Additional outputs for intermediate resolutions
        self.out_128 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1)  # Output at 128x128
        self.out_256 = nn.Conv2d(in_channels=16, out_channels=3, kernel_size=1)  # Output at 256x256

    def forward(self, x):
        """Decode a batch of embeddings.

        :param x: Tensor of shape (batch, 256).
        :return: Tuple ``(out_128, out_256, out_512)`` of 3-channel images at
            128 x 128, 256 x 256 and 512 x 512, each with values in (-1, 1).
        """
        x = torch.tanh(self.fc1(x))  # Fully connected layer
        x = torch.tanh(self.fc2(x))  # Fully connected layer
        x = self.unflatten(x)
        x = torch.tanh(self.convt1(x))  # 32 -> 64
        x = torch.tanh(self.convt2(x))  # 64 -> 128

        out_128 = torch.tanh(self.out_128(x))  # 128x128 output

        x = torch.tanh(self.convt3(x))  # 128 -> 256
        out_256 = torch.tanh(self.out_256(x))  # 256x256 output

        # Apply tanh exactly once: a second tanh would cap the output at
        # |tanh(1)| ~= 0.76, so it could never reach targets in [-1, 1].
        out_512 = torch.tanh(self.convt4(x))  # 256 -> 512, 512x512 output

        return out_128, out_256, out_512


class AutoEncoderMRL(nn.Module):
    """Autoencoder that returns reconstructions at several resolutions.

    Chains an ``Encoder`` (image -> embedding) with a ``Decoder``
    (embedding -> three reconstructions) and also exposes the embedding.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()  # Downsampling: 512 -> 32
        self.decoder = Decoder()  # Upsampling: 32 -> 512

    def forward(self, x):
        """Encode ``x`` and decode the embedding.

        :param x: Input batch passed unchanged to the encoder.
        :return: Tuple ``(out_128, out_256, out_512, emb)``: the decoder's
            three outputs followed by the encoder's embedding.
        """
        latent = self.encoder(x)
        # The decoder must yield exactly three tensors, lowest resolution first.
        low_res, mid_res, full_res = self.decoder(latent)
        return low_res, mid_res, full_res, latent


In [None]:
from torchmetrics.functional import peak_signal_noise_ratio as psnr

def mrl_loss(outputs, target_128, target_256, target_512, weights, data_range=2.0):
    """
    Compute the MRL loss as a weighted sum of negative PSNR values for the
    128x128, 256x256 and 512x512 reconstructions.

    PSNR is a quality score (higher is better), so it is negated to obtain a
    quantity to minimise. The PSNR of each resolution is also printed on
    every call.

    :param outputs: Sequence of exactly three tensors
        ``(out_128, out_256, out_512)``: the reconstructions at 128x128,
        256x256 and 512x512.
    :param target_128: Ground truth image at 128x128.
    :param target_256: Ground truth image at 256x256.
    :param target_512: Ground truth image at 512x512.
    :param weights: List of weights for each resolution [w_128, w_256, w_512].
    :param data_range: Width of the value range of the images, used as the
        PSNR peak. Defaults to 2.0 because images normalised with mean 0.5 /
        std 0.5 and tanh outputs both span [-1, 1]; pass 1.0 for data in
        [0, 1]. A wrong range shifts every PSNR by a constant offset.
    :return: Weighted sum of the three negative PSNR values divided by 3.
    """
    # Compute PSNR for each resolution
    out_128, out_256, out_512 = outputs
    loss_128 = -psnr(out_128, target_128, data_range=data_range)
    loss_256 = -psnr(out_256, target_256, data_range=data_range)
    loss_512 = -psnr(out_512, target_512, data_range=data_range)

    # The losses are negated PSNRs, so negate again to report the PSNR itself.
    print(f"PSNR Loss at 128x128: {-loss_128:.4f}, 256x256: {-loss_256:.4f}, 512x512: {-loss_512:.4f}")

    # Weighted sum of losses
    total_loss = weights[0] * loss_128 + weights[1] * loss_256 + weights[2] * loss_512
    return total_loss / 3  # Average the weighted loss


In [None]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# Custom Dataset for loading images from a single folder
class ImageDataset(Dataset):
    """Dataset that serves the image files found directly inside one folder.

    Files are selected by extension (case-insensitive) and sorted by name so
    that indices are stable across runs and line up between folders that
    contain identically named files.

    :param folder_path: Directory to scan (not recursive).
    :param transform: Optional callable applied to each loaded RGB PIL image.
    """

    # Accepted file extensions, compared against the lower-cased file name.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        # Sort image files to ensure consistent order; lower-case the name
        # before matching so that e.g. "photo.JPG" is not silently skipped.
        self.image_files = sorted(
            f for f in os.listdir(folder_path)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load the image at position ``idx``.

        :param idx: Index into the sorted file list.
        :return: The transformed image, or an RGB PIL image when no
            transform was given.
        """
        img_path = os.path.join(self.folder_path, self.image_files[idx])
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to load, so the file can be closed right away. Converting
        # to RGB guarantees 3 channels even for grayscale or RGBA files.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image  # Return the transformed image


# Preprocessing applied to every image, whatever its resolution
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # Convert image to PyTorch tensor
        # (x - 0.5) / 0.5 per channel: normalize to [-1, 1]
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# One folder per resolution
folder_path_512 = 'data/train/512'
folder_path_256 = 'data/train/256'
folder_path_128 = 'data/train/128'

# One dataset per resolution, built in the order 512, 256, 128
dataset_512, dataset_256, dataset_128 = (
    ImageDataset(path, transform=transform)
    for path in (folder_path_512, folder_path_256, folder_path_128)
)

# One dataloader per resolution. Shuffling stays off so every loader walks
# its files in the same sorted order.
batch_size = 32
dataloader_512, dataloader_256, dataloader_128 = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (dataset_512, dataset_256, dataset_128)
)
