In [None]:
import cv2
import numpy as np
from skimage import io, color
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from skimage.color import rgb2lab, lab2rgb



## Step 1: Load and Slice Images into Patches

In [None]:
def load_and_slice_image(image_path, patch_size=224):
    """Cut an image into non-overlapping square patches.

    The image is read with ``load_img`` and converted with ``img_to_array``.
    The tiling grid starts at the top-left corner and advances ``patch_size``
    pixels at a time. Tiles that would run past the bottom or right edge are
    dropped rather than padded.

    Args:
        image_path: Location of the image, passed straight to ``load_img``.
        patch_size: Side length of each square patch, in pixels.

    Returns:
        A list of array slices of shape ``(patch_size, patch_size, channels)``,
        ordered top-to-bottom and, within a row of tiles, left-to-right.
    """
    pixels = img_to_array(load_img(image_path))
    n_rows, n_cols, _channels = pixels.shape

    # Stopping each range at (extent - patch_size) means only origins whose
    # tile fits completely inside the image are ever visited.
    row_starts = range(0, n_rows - patch_size + 1, patch_size)
    col_starts = range(0, n_cols - patch_size + 1, patch_size)

    return [
        pixels[top:top + patch_size, left:left + patch_size]
        for top in row_starts
        for left in col_starts
    ]


## Step 2: Convert Patches to Lab Color Space

In [None]:
def convert_to_lab(patches):
    """Convert RGB patches to the Lab colour space.

    Each patch is divided by 255.0 before being handed to ``color.rgb2lab``.

    Args:
        patches: Iterable of RGB arrays. The division by 255 suggests values
            on a 0-255 scale (assumption - confirm against the caller).

    Returns:
        A list with one Lab array per input patch, in the same order.
    """
    return [color.rgb2lab(rgb / 255.0) for rgb in patches]

def prepare_data_for_training(lab_patches):
    """Split Lab patches into scaled lightness inputs and chroma targets.

    Channel 0 of every patch is taken as L and the remaining channels as ab.
    L is rescaled with ``(L - 50) / 50`` and ab with ``ab / 128``; for L in
    [0, 100] and ab in [-128, 128] both results lie in [-1, 1].

    Args:
        lab_patches: Sequence of arrays indexed as ``[row, col, channel]``.

    Returns:
        A tuple ``(L, ab)`` where ``L`` has a trailing singleton channel axis
        appended and ``ab`` keeps the channels after the first.
    """
    lightness = np.array([lab[:, :, 0] for lab in lab_patches])
    chroma = np.array([lab[:, :, 1:] for lab in lab_patches])

    # Centre and scale so both arrays share the same nominal range.
    lightness = (lightness - 50) / 50.0
    chroma = chroma / 128.0

    return lightness[..., np.newaxis], chroma


## Encoder

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights, DenseNet121_Weights

class EnsembleEncoder(nn.Module):
    """Dual-backbone encoder that fuses ResNet50 and DenseNet121 features.

    Four intermediate activations are tapped from each pre-trained backbone,
    projected with 1x1 convolutions to matching widths (128, 256, 512, 1024),
    concatenated pairwise along the channel axis and reduced again by a
    1x1 conv + batch-norm + ReLU fusion block.
    """

    # Positions in the truncated ResNet50 child list whose outputs are tapped.
    _RESNET_TAPS = (4, 5, 6, 7)
    # Positions in ``densenet121.features`` whose outputs are tapped. The
    # original note says these follow each dense block - verify against the
    # torchvision layer layout if the backbone version changes.
    _DENSENET_TAPS = (4, 6, 8, 11)

    def __init__(self):
        super().__init__()

        # Pre-trained backbones (default torchvision weights).
        resnet = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        densenet = models.densenet121(weights=DenseNet121_Weights.DEFAULT)

        # ResNet50: keep everything except its last two children.
        self.resnet50 = nn.Sequential(*list(resnet.children())[:-2])
        # DenseNet121: neutralise the classifier; only ``features`` is walked
        # in ``forward``.
        densenet.classifier = nn.Identity()
        self.densenet121 = densenet

        # 1x1 projections applied to the tapped ResNet50 activations.
        self.conv1x1_resnet50 = nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=1)
            for c_in, c_out in ((256, 128), (512, 256), (1024, 512), (2048, 1024))
        )

        # 1x1 projections applied to the tapped DenseNet121 activations.
        self.conv1x1_densenet121 = nn.ModuleList(
            nn.Conv2d(c_in, c_out, kernel_size=1)
            for c_in, c_out in ((256, 128), (512, 256), (1024, 512), (1024, 1024))
        )

        # One fusion block per scale; both projected inputs share a width.
        self.fusion_blocks = nn.ModuleList(
            self.fusion_block(width, width) for width in (128, 256, 512, 1024)
        )

    def fusion_block(self, in_channels_resnet, in_channels_densenet):
        """Build a 1x1 conv + BatchNorm + ReLU block for concatenated features.

        The block accepts ``in_channels_resnet + in_channels_densenet``
        channels and emits ``in_channels_resnet`` channels.
        """
        merged_width = in_channels_resnet + in_channels_densenet
        return nn.Sequential(
            nn.Conv2d(merged_width, in_channels_resnet, kernel_size=1),
            nn.BatchNorm2d(in_channels_resnet),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the four fused feature maps for ``x``.

        Args:
            x: Input batch fed unchanged to both backbones. The original note
                describes it as a grayscale image repeated across 3 channels.

        Returns:
            A list of four tensors with 128, 256, 512 and 1024 channels, in
            the order the backbone layers are visited.
        """
        # Walk the ResNet50 trunk, projecting the tapped activations.
        resnet_taps = []
        activation = x
        for position, stage in enumerate(self.resnet50.children()):
            activation = stage(activation)
            if position in self._RESNET_TAPS:
                projector = self.conv1x1_resnet50[self._RESNET_TAPS.index(position)]
                resnet_taps.append(projector(activation))

        # Same walk over the DenseNet121 feature extractor.
        densenet_taps = []
        activation = x
        for position, stage in enumerate(self.densenet121.features):
            activation = stage(activation)
            if position in self._DENSENET_TAPS:
                projector = self.conv1x1_densenet121[self._DENSENET_TAPS.index(position)]
                densenet_taps.append(projector(activation))

        # Fuse by channel concatenation followed by the learned reduction.
        # (Element-wise averaging and element-wise max were the alternatives
        # noted in earlier experiments.)
        return [
            fuse(torch.cat((res_map, dense_map), dim=1))
            for fuse, res_map, dense_map in zip(
                self.fusion_blocks, resnet_taps, densenet_taps
            )
        ]


## Decoder

In [None]:
import torch
import torch.nn as nn

class Decoder(nn.Module):
    """Upsampling decoder that turns four fused feature maps into ab channels.

    Blocks ``decode1``-``decode4`` each apply a 3x3 conv, batch-norm, ReLU and
    a 2x bilinear upsample. From the second block onward the input is the
    previous block's output concatenated with a skip feature map along the
    channel axis. ``decode5`` maps to 2 channels, squashes them with ``Tanh``
    and upsamples once more, so the output is 32x the spatial size of the
    coarsest input.

    The final activation is ``Tanh`` rather than ``ReLU``: the ab targets are
    scaled as ``ab / 128`` and are therefore signed, and a ReLU head could
    never emit the negative half of that range.
    """

    def __init__(self):
        super().__init__()

        # Block 1: coarsest fused features (1024 channels) -> 512 channels.
        self.decode1 = self._up_block(1024, 512)  # e.g. 7x7 -> 14x14

        # Block 2: block-1 output (512) + skip (512) -> 256 channels.
        self.decode2 = self._up_block(512 + 512, 256)  # e.g. 14x14 -> 28x28

        # Block 3: block-2 output (256) + skip (256) -> 128 channels.
        self.decode3 = self._up_block(256 + 256, 128)  # e.g. 28x28 -> 56x56

        # Block 4: block-3 output (128) + skip (128) -> 64 channels.
        self.decode4 = self._up_block(128 + 128, 64)  # e.g. 56x56 -> 112x112

        # Output head: 2 channels (ab), bounded to [-1, 1] by Tanh. Bilinear
        # interpolation afterwards cannot leave that interval.
        self.decode5 = nn.Sequential(
            nn.Conv2d(64, 2, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # e.g. 112x112 -> 224x224
        )

    @staticmethod
    def _up_block(in_channels, out_channels):
        """Build one conv3x3 -> BatchNorm -> ReLU -> 2x bilinear upsample stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        )

    def forward(self, features_7x7, features_14x14, features_28x28, features_56x56):
        """Decode fused features, coarsest first, into a 2-channel map.

        Args:
            features_7x7: Coarsest map, 1024 channels.
            features_14x14: Skip map with 512 channels and twice the spatial
                size of ``features_7x7``.
            features_28x28: Skip map with 256 channels, twice the size again.
            features_56x56: Skip map with 128 channels, twice the size again.

        Returns:
            Tensor with 2 channels and values in [-1, 1].

        Note:
            The arguments run coarse-to-fine, whereas ``EnsembleEncoder.forward``
            returns its list in the order 128, 256, 512, 1024 channels, so a
            caller has to reverse that list before passing it here.
        """
        x = self.decode1(features_7x7)  # Output of Decoder Block 1

        x = torch.cat([x, features_14x14], dim=1)  # Skip connection with Fusion Block 3
        x = self.decode2(x)  # Output of Decoder Block 2

        x = torch.cat([x, features_28x28], dim=1)  # Skip connection with Fusion Block 2
        x = self.decode3(x)  # Output of Decoder Block 3

        x = torch.cat([x, features_56x56], dim=1)  # Skip connection with Fusion Block 1
        x = self.decode4(x)  # Output of Decoder Block 4

        return self.decode5(x)  # Final ab prediction
