In [32]:
import os
# import cv2
import torch
import lpips
import time

import pandas as pd
# import numpy as np
# import torch.optim as optim
import matplotlib.pyplot as plt

from torch import manual_seed
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
# from torchvision.io import decode_image
from PIL import Image
from torchvision import transforms
from torch import nn
# from tqdm import tqdm

from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr

In [33]:
CSV_FILE_PATH = './sardata_small.csv'
IMAGE_DIR_SAR = './sardata_small/s1'
IMAGE_DIR_COL = './sardata_small/s2'
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 1
SEED = 42
manual_seed(SEED)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {DEVICE}')

device: cpu


In [34]:
data_df = pd.read_csv(CSV_FILE_PATH)

_, test_df = train_test_split(data_df, test_size=0.25, random_state=SEED, shuffle=True, stratify=data_df['type'])
# print(train_df.groupby('type').count())
print(test_df.groupby('type').count())

       s1_image  s1_image_path  s2_image  s2_image_path
type                                                   
agri        250            250       250            250
urban       250            250       250            250


In [35]:
class SarColorDataset(Dataset):
    def __init__(self, data_df, image_dir_sar, image_dir_col, transform_sar=None, transform_col=None):
        self.data_df = data_df
        self.image_dir_sar = image_dir_sar
        self.image_dir_col = image_dir_col
        self.transform_sar = transform_sar
        self.transform_col = transform_col

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, index):
        row = self.data_df.iloc[index]
        label = row['type']

        image_path_sar = os.path.join(self.image_dir_sar, row['s1_image'])
        image_path_col = os.path.join(self.image_dir_col, row['s2_image'])

        # image_sar = decode_image(image_path_sar, mode='GRAY')
        # image_col = decode_image(image_path_col, mode='RGB')

        # image_sar = cv2.imread(image_path_sar, cv2.IMREAD_GRAYSCALE)
        # image_col = cv2.imread(image_path_col, cv2.IMREAD_COLOR_RGB)

        image_sar = Image.open(image_path_sar).convert('L')
        image_col = Image.open(image_path_col).convert('RGB')

        if self.transform_sar:
            image_sar = self.transform_sar(image_sar)

        if self.transform_col:
            image_col = self.transform_col(image_col)

        return image_sar, image_col, label

In [36]:
transform_sar = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,), inplace=False),
])

transform_col = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), inplace=False),
])

In [37]:
test_dataset = SarColorDataset(
    data_df=test_df,
    image_dir_sar=IMAGE_DIR_SAR,
    image_dir_col=IMAGE_DIR_COL,
    transform_sar=transform_sar,
    transform_col=transform_col,
)

test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [38]:
class DownSample(nn.Module):
    def __init__(self, inp_c, out_c, kernel_size=4, stride=2, padding=1, use_bias=True, normalization='batch'):
        super(DownSample, self).__init__()

        self.down = nn.Sequential(
            nn.Conv2d(in_channels=inp_c, out_channels=out_c, kernel_size=kernel_size, stride=stride, padding=padding, bias=(not normalization) and use_bias),
        )

        if (normalization == 'batch'):
            self.down.append(nn.BatchNorm2d(num_features=out_c, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True))
        elif (normalization == 'instance'):
            self.down.append(nn.InstanceNorm2d(num_features=out_c, eps=1e-5, momentum=0.1, affine=False, track_running_stats=False))

        self.down.append(nn.LeakyReLU(negative_slope=0.2, inplace=False))

    def forward(self, x):
        x = self.down(x)
        return x

class UpSample(nn.Module):
    # def __init__(self, inp_c, out_c, kernel_size=4, stride=1, padding=0, use_bias = True, normalization='batch', apply_dropout=False, dropout_rate=0.5):
    def __init__(self, inp_c, out_c, kernel_size=4, stride=2, padding=1, use_bias = True, normalization='batch', apply_dropout=False, dropout_rate=0.5):
        super(UpSample, self).__init__()

        # self.up = nn.Sequential(
        #     nn.Upsample(scale_factor=2, mode='bilinear'),
        #     nn.ZeroPad2d((2,1,2,1)),
        #     nn.Conv2d(in_channels=inp_c, out_channels=out_c, kernel_size=kernel_size, stride=stride, padding=padding, bias=(not normalization) and use_bias),
        # )

        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_channels=inp_c, out_channels=out_c, kernel_size=kernel_size, stride=stride, padding=padding, bias=(not normalization) and use_bias),
        )

        if (normalization == 'batch'):
            self.up.append(nn.BatchNorm2d(num_features=out_c, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True))
        elif (normalization == 'isinstance'):
            self.up.append(nn.InstanceNorm2d(num_features=out_c, eps=1e-5, momentum=0.1, affine=False, track_running_stats=False))

        if apply_dropout:
            self.up.append(nn.Dropout(p=dropout_rate, inplace=False))

        self.up.append(nn.ReLU(inplace=False))

    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        return x

class Generator(nn.Module):
    def __init__(self, in_channels=1, out_channels=3):
        super(Generator, self).__init__()

        self.down_stack = nn.ModuleList([
            DownSample(inp_c=in_channels, out_c=64, normalization=None),
            DownSample(inp_c= 64, out_c=128),
            DownSample(inp_c=128, out_c=256),
            DownSample(inp_c=256, out_c=512),
            DownSample(inp_c=512, out_c=512),
            DownSample(inp_c=512, out_c=512),
            DownSample(inp_c=512, out_c=512),
            DownSample(inp_c=512, out_c=512, normalization=None),
            ])

        self.up_stack = nn.ModuleList([
            UpSample(inp_c= 512, out_c=512), # removed dropout layers
            UpSample(inp_c=1024, out_c=512),
            UpSample(inp_c=1024, out_c=512),
            UpSample(inp_c=1024, out_c=512),
            UpSample(inp_c=1024, out_c=256),
            UpSample(inp_c= 512, out_c=128),
            UpSample(inp_c= 256, out_c= 64),
            ])

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.ZeroPad2d((2,1,2,1)),
            nn.Conv2d(in_channels=128, out_channels=out_channels, kernel_size=4, stride=1, padding=0, bias=True),
            nn.Tanh()
        )

    def forward(self, x):

        skips = []
        for layer in self.down_stack:
            x = layer(x)
            skips.append(x)

        skips.pop()
        skips = skips[::-1]

        for layer, skip in zip(self.up_stack, skips):
            x = layer(x, skip)

        x = self.final(x)
        return x

In [39]:
generator = Generator(in_channels=1, out_channels=3).to(DEVICE)

In [40]:
CHECKPOINT_PATH = './checkpoints/checkpoint_epoch_100.pth'

checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
generator.load_state_dict(checkpoint['generator_state_dict'])
generator.eval()

Generator(
  (down_stack): ModuleList(
    (0): DownSample(
      (down): Sequential(
        (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): DownSample(
      (down): Sequential(
        (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): DownSample(
      (down): Sequential(
        (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.2)
      )
    )
    (3): DownSample(
      (down): Sequential(
        (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, aff

In [41]:
lpips_metric = lpips.LPIPS(net='alex').to(DEVICE)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: C:\Users\nafis\AppData\Roaming\Python\Python311\site-packages\lpips\weights\v0.1\alex.pth


In [42]:
OUTPUT_DIR = "./test_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
results = []
time_taken = 0.0

for idx, (sar_image, col_image, _) in enumerate(test_dataloader):

    sar_image = sar_image.to(DEVICE)
    col_image = col_image.to(DEVICE)

    with torch.no_grad():
        start = time.time()
        pred_image = generator(sar_image)
        time_taken += time.time() - start
    
    def denormalize(tensor):
        return (tensor * 0.5 + 0.5).clamp(0, 1)
    
    for batch in range(BATCH_SIZE):

        sar_image_np = denormalize(sar_image[batch]).permute(1, 2, 0).cpu().numpy().squeeze(-1)
        pred_image_np = denormalize(pred_image[batch]).permute(1, 2, 0).cpu().numpy()
        col_image_np = denormalize(col_image[batch]).permute(1, 2, 0).cpu().numpy()

        psnr_value = psnr(col_image_np, pred_image_np, data_range=1.0)
        ssim_value = ssim(col_image_np, pred_image_np, win_size=7, data_range=1.0, channel_axis=-1)
        lpips_value = lpips_metric(pred_image[batch], col_image[batch]).item()

        results.append([idx * BATCH_SIZE + batch, psnr_value, ssim_value, lpips_value])

        # Save Images
        sar_img = Image.fromarray((sar_image_np * 255).astype('uint8'))
        pred_img = Image.fromarray((pred_image_np * 255).astype('uint8'))
        col_img = Image.fromarray((col_image_np * 255).astype('uint8'))

        sar_img.save(os.path.join(OUTPUT_DIR, f"{idx * BATCH_SIZE + batch:03}_sar.png"))
        pred_img.save(os.path.join(OUTPUT_DIR, f"{idx * BATCH_SIZE + batch:03}_pred.png"))
        col_img.save(os.path.join(OUTPUT_DIR, f"{idx * BATCH_SIZE + batch:03}_color.png"))

        fig, axes = plt.subplots(1, 3, figsize=(12, 4.5))
        fig.suptitle(f"Image ID: {idx * BATCH_SIZE + batch:03} | PSNR: {psnr_value:.2f} | SSIM: {ssim_value:.2f} | LPIPS: {lpips_value:.4f}")  # Adjust y to position title above

        axes[0].imshow(sar_image_np, cmap="gray")
        axes[0].set_title("SAR Image")
        # axes[0].axis("off")
        
        axes[1].imshow(pred_image_np)
        axes[1].set_title("Predicted Image")
        # axes[1].axis("off")

        axes[2].imshow(col_image_np)
        axes[2].set_title("Ground Truth")
        # axes[2].axis("off")

        plt.tight_layout()

        plt.savefig(os.path.join(OUTPUT_DIR, f"{idx * BATCH_SIZE + batch:03}_comparison.png"))

        if idx % 5 == 0:
            plt.show()
        else:
            plt.close()

        results_df = pd.DataFrame(results, columns=["Index", "PSNR", "SSIM", "LPIPS"])
        results_df.to_csv(os.path.join(OUTPUT_DIR, "metrics.csv"), index=False)

In [44]:
avg_psnr = results_df["PSNR"].mean()
avg_ssim = results_df["SSIM"].mean()
avg_lpips = results_df["LPIPS"].mean()
avg_time = (time_taken * 1000) / (BATCH_SIZE * len(test_dataloader))

print(f"Average PSNR: {avg_psnr:.2f}")
print(f"Average SSIM: {avg_ssim:.2f}")
print(f"Average LPIPS: {avg_lpips:.4f}")
print(f"Average Prediction Time: {avg_time:.2f} ms")

Average PSNR: 21.99
Average SSIM: 0.63
Average LPIPS: 0.3059
Average Prediction Time: 311.13 ms
