In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms


class MultiFolderSAROpticalDataset(Dataset):
    """Paired SAR / optical image dataset spread over several terrain folders.

    Expected layout::

        root_dir/
            <terrain_type>/
                s1/    SAR images
                s2/    optical images

    A SAR file is paired with the optical file whose name is obtained by
    turning the "s1" marker of the SAR file name into "s2". SAR files without
    an existing optical partner are skipped.

    Note: the same ``transform`` object is applied to the SAR and to the
    optical image in two separate calls, so a transform with random
    augmentations would NOT be applied consistently to both images of a pair.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Root directory containing all the subfolders for each terrain type.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_pairs = []

        # Sorted traversal keeps the sample order reproducible across machines
        # (os.listdir returns entries in arbitrary order).
        for terrain_type in sorted(os.listdir(root_dir)):
            sar_dir = os.path.join(root_dir, terrain_type, "s1")
            optical_dir = os.path.join(root_dir, terrain_type, "s2")

            # Stray files or folders without the s1/s2 layout (README,
            # .ipynb_checkpoints, ...) are ignored instead of crashing listdir.
            if not (os.path.isdir(sar_dir) and os.path.isdir(optical_dir)):
                continue

            for sar_img in sorted(os.listdir(sar_dir)):
                optical_path = self._find_optical_path(optical_dir, sar_img)
                if optical_path is not None:
                    sar_path = os.path.join(sar_dir, sar_img)
                    self.image_pairs.append((sar_path, optical_path))

    @staticmethod
    def _find_optical_path(optical_dir, sar_img):
        """Return the path of the optical partner of ``sar_img`` or ``None``.

        Only the LAST "s1" of the file name is treated as the sensor marker, so
        a name such as ``ROIs1970_fall_s1_8_p100.png`` maps to
        ``ROIs1970_fall_s2_8_p100.png`` rather than ``ROIs2970_...``. The
        replace-every-occurrence name is still tried afterwards as a fallback
        so previously matching pairs keep matching.
        """
        candidates = []
        head, marker, tail = sar_img.rpartition("s1")
        if marker:
            candidates.append(head + "s2" + tail)
        legacy_name = sar_img.replace("s1", "s2")
        if legacy_name not in candidates:
            candidates.append(legacy_name)

        for candidate in candidates:
            candidate_path = os.path.join(optical_dir, candidate)
            if os.path.exists(candidate_path):
                return candidate_path
        return None

    def __len__(self):
        """Number of (SAR, optical) pairs found under ``root_dir``."""
        return len(self.image_pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(sar_image, optical_image)``.

        The SAR image is converted to single-channel ("L") and the optical
        image to "RGB"; ``transform`` (if given) is then applied to each.
        """
        sar_path, optical_path = self.image_pairs[idx]

        # convert() fully loads the pixel data into a new image, so the file
        # handles can be closed right away instead of lingering until GC.
        with Image.open(sar_path) as sar_file:
            sar_image = sar_file.convert("L")  # SAR images are grayscale
        with Image.open(optical_path) as optical_file:
            optical_image = optical_file.convert("RGB")  # Optical images are RGB

        if self.transform:
            sar_image = self.transform(sar_image)
            optical_image = self.transform(optical_image)

        return sar_image, optical_image

In [None]:
# Preprocessing pipeline shared by the SAR and the optical image of each pair
transform = transforms.Compose(
    [
        # Bring every image to a fixed 256x256 resolution
        transforms.Resize(size=(256, 256)),
        # PIL image -> PyTorch tensor
        transforms.ToTensor(),
        # Mean 0.5 / std 0.5: normalize SAR and optical images to [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [None]:
# Root folder that holds one sub-directory per terrain type
root_dir = "data/v_2"

# Collect every SAR/optical pair found below the root folder
dataset = MultiFolderSAROpticalDataset(root_dir, transform=transform)

# Shuffled mini-batches of 16 pairs, loaded by 4 worker processes
train_loader = DataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

In [None]:
import torch.nn as nn


class EncoderBlock(nn.Module):
    """3x3 convolution -> BatchNorm -> ReLU, followed by 2x2 max-pooling.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` keeps the spatial size of ``x`` (used as skip connection);
        ``pooled`` is ``features`` downsampled by a factor of 2.
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x_pooled = self.pool(x)
        return x, x_pooled


class DecoderBlock(nn.Module):
    """Upsample by 2, concatenate a skip tensor, then conv -> BatchNorm -> ReLU.

    Args:
        in_channels (int): Channels of the tensor coming from the deeper layer.
        out_channels (int): Channels produced by the block.
        skip_channels (int, optional): Channels of the skip tensor that is
            concatenated after upsampling. Defaults to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, skip_channels=None):
        super(DecoderBlock, self).__init__()
        if skip_channels is None:
            skip_channels = out_channels
        self.upconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=2, stride=2
        )
        # The convolution sees the upsampled tensor AND the skip tensor, so its
        # input width is the sum of both channel counts.
        self.conv = nn.Conv2d(
            out_channels + skip_channels, out_channels, kernel_size=3, padding=1
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, skip_x):
        """Fuse ``x`` (upsampled x2) with ``skip_x``.

        ``skip_x`` must have twice the spatial size of ``x`` and
        ``skip_channels`` channels.
        """
        x = self.upconv(x)
        x = torch.cat((x, skip_x), dim=1)  # Skip connection
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class ColorizationNet(nn.Module):
    """U-Net style network mapping a 1-channel SAR image to a 3-channel image.

    Input:  ``(N, 1, H, W)`` with ``H`` and ``W`` divisible by 8 (three
            pooling stages are undone by three transposed convolutions).
    Output: ``(N, 3, H, W)`` with values in ``[-1, 1]``.
    """

    def __init__(self):
        super(ColorizationNet, self).__init__()
        # Encoder
        self.enc1 = EncoderBlock(1, 64)
        self.enc2 = EncoderBlock(64, 128)
        self.enc3 = EncoderBlock(128, 256)
        self.enc4 = EncoderBlock(256, 512)  # bottleneck; its pooled output is unused

        # Decoder: each block halves the channel count after merging the skip
        self.dec1 = DecoderBlock(512, 256)
        self.dec2 = DecoderBlock(256, 128)
        self.dec3 = DecoderBlock(128, 64)
        self.dec4 = nn.Conv2d(
            64, 3, kernel_size=3, padding=1
        )  # Output 3 channels for RGB

    def forward(self, x):
        # Encoding
        skip1, x = self.enc1(x)  # skip1: 64 ch @ H,     x: H/2
        skip2, x = self.enc2(x)  # skip2: 128 ch @ H/2,  x: H/4
        skip3, x = self.enc3(x)  # skip3: 256 ch @ H/4,  x: H/8
        # Use the UN-pooled features of the last encoder stage as bottleneck:
        # pooling a fourth time would leave the first decoder output at half
        # the resolution of skip3 and the concatenation would fail.
        x, _ = self.enc4(x)  # 512 ch @ H/8

        # Decoding
        x = self.dec1(x, skip3)  # 256 ch @ H/4
        x = self.dec2(x, skip2)  # 128 ch @ H/2
        x = self.dec3(x, skip1)  # 64 ch @ H
        x = self.dec4(x)  # No skip connection for the final layer

        # Targets are normalized with mean 0.5 / std 0.5, i.e. they live in
        # [-1, 1]; tanh covers that range (sigmoid could never reach values
        # below 0).
        return torch.tanh(x)

In [None]:
class ColorizationLoss(nn.Module):
    """Training criterion: L1 distance between predicted and reference image."""

    def __init__(self):
        super().__init__()
        # nn.L1Loss with its default settings does the actual work
        self.l1_loss = nn.L1Loss()

    def forward(self, output, target):
        """Return the L1 loss of ``output`` measured against ``target``."""
        loss_value = self.l1_loss(output, target)
        return loss_value


# Pick the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Network and criterion live on the selected device
model = ColorizationNet().to(device)
criterion = ColorizationLoss().to(device)

# Adam over all model parameters
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
num_epochs = 50  # Number of full passes over the training data

for epoch in range(num_epochs):
    model.train()  # training mode
    running_loss = 0.0  # sum of the batch losses of this epoch

    for i, (inputs, targets) in enumerate(train_loader):
        # Move the batch to the device the model lives on
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Clear the gradients of the previous step, then backpropagate and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss of the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

In [None]:
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
import numpy as np

# Image-quality metrics (PSNR / SSIM) of the model predictions.
# NOTE(review): the loop iterates over `train_loader`, so the numbers printed
# below are training-set metrics even though they are labelled "Validation".
# Use a held-out loader here once one exists.
model.eval()
psnr_list = []
ssim_list = []

with torch.no_grad():
    for i, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)

        # Targets were normalized with mean 0.5 / std 0.5, i.e. they live in
        # [-1, 1], and the model is trained to predict in that same space.
        # Map both back to [0, 1]: SSIM assumes non-negative intensities and a
        # fixed, known dynamic range keeps the scores comparable across images.
        outputs_np = (
            ((outputs + 1.0) / 2.0).clamp(0.0, 1.0).cpu().numpy().transpose(0, 2, 3, 1)
        )  # Convert to HWC format
        targets_np = (
            ((targets + 1.0) / 2.0).clamp(0.0, 1.0).cpu().numpy().transpose(0, 2, 3, 1)
        )

        for j in range(outputs_np.shape[0]):
            output_img = outputs_np[j]
            target_img = targets_np[j]

            # Calculate PSNR (data_range is the full [0, 1] span, not the
            # per-image min/max, which would be 0 for a flat image)
            psnr_value = peak_signal_noise_ratio(
                target_img, output_img, data_range=1.0
            )
            psnr_list.append(psnr_value)

            # Calculate SSIM; `channel_axis` replaces the deprecated
            # `multichannel=True` keyword (last axis holds the RGB channels)
            ssim_value = structural_similarity(
                target_img,
                output_img,
                channel_axis=-1,
                data_range=1.0,
            )
            ssim_list.append(ssim_value)

print(f"Validation PSNR: {np.mean(psnr_list):.4f}")
print(f"Validation SSIM: {np.mean(ssim_list):.4f}")

In [None]:
# Persist the learned parameters (state_dict only, not the whole module).
checkpoint_path = "models/colorization_model.pth"
# torch.save does not create missing folders; make sure the target directory
# exists so a finished training run is not lost to a FileNotFoundError.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)