In [None]:
!pip install -q rasterio

In [None]:
!pip install -q piq

In [None]:
import rasterio
import numpy as np
import matplotlib.pyplot as plt
import cv2, os
from glob import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
# Root folders of the paired imagery: the "_s1" tree is read as SAR input and
# the "_s2" tree as the EO target by the cells below.
sar_dir = '/kaggle/input/sar-images/ROIs2017_winter_s1/ROIs2017_winter'
eo_dir = '/kaggle/input/sar-images/ROIs2017_winter_s2/ROIs2017_winter'

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare expression so the notebook displays which device was picked.
device

In [None]:
# Sorted listings so that the i-th SAR folder lines up with the i-th EO folder
# when the two lists are zipped together later.
sar_subdirs = sorted(os.listdir(sar_dir))
eo_subdirs = sorted(os.listdir(eo_dir))
# Cheap sanity check; the per-folder name comparison happens in the pairing loop.
assert len(sar_subdirs) == len(eo_subdirs), "SAR and EO folder counts do not match"

In [None]:
# Parallel lists: sar_paths[i] and eo_paths[i] are filled in together by the
# pairing loop and are meant to describe the same tile.
sar_paths = []
eo_paths = []

In [None]:
# Build the paired (SAR, EO) file lists, one scene sub-directory at a time.
# Start from empty lists so that re-running this cell does not append the same
# pairs a second time.
sar_paths.clear()
eo_paths.clear()

for s_sub, e_sub in zip(sar_subdirs, eo_subdirs):
    # Verify directory matching (remove s1/s2 prefix for comparison)
    assert s_sub.replace("s1_", "") == e_sub.replace("s2_", ""), f"Unmatched subdirs: {s_sub}, {e_sub}"

    sar_sub_path = os.path.join(sar_dir, s_sub)
    eo_sub_path = os.path.join(eo_dir, e_sub)

    sar_files = sorted(os.listdir(sar_sub_path))
    eo_files = sorted(os.listdir(eo_sub_path))
    # Set lookup keeps the existence check O(1) per file instead of scanning
    # the whole EO listing for every SAR file.
    eo_lookup = set(eo_files)

    for sar_fname in sar_files:
        # The EO file name is the SAR file name with the sensor tag swapped.
        eo_fname = sar_fname.replace('_s1_', '_s2_')  # s1 → s2 conversion

        # Only keep pairs whose EO counterpart actually exists on disk.
        if eo_fname in eo_lookup:
            sar_paths.append(os.path.join(sar_sub_path, sar_fname))
            eo_paths.append(os.path.join(eo_sub_path, eo_fname))
        else:
            print(f"Warning: No matching EO file for {sar_fname}")

In [None]:
# Report how many matched SAR/EO pairs were collected.
print(f"Length of whole dataset is {len(sar_paths)} pairs")

In [None]:
# Keep only the first 5000 pairs. The lists were built from sorted directory
# listings, so this is a prefix in that order, not a random sample.
# NOTE(review): presumably done to bound training time — confirm.
sar_paths = sar_paths[:5000]
eo_paths = eo_paths[:5000]

In [None]:
def normalize(img):
    """Min-max scale an array into roughly [0, 1] and return it as float32.

    The input is copied by the dtype conversion, so the caller's array is left
    untouched. A small epsilon in the denominator avoids division by zero for
    constant inputs (which therefore come back as all zeros).
    """
    scaled = img.astype(np.float32)
    # Shift so the smallest value becomes 0 ...
    np.subtract(scaled, scaled.min(), out=scaled)
    # ... then divide by the (shifted) maximum, padded with the epsilon.
    np.divide(scaled, scaled.max() + 1e-6, out=scaled)
    return scaled

In [None]:
class SARToEODataset(Dataset):
    """Paired SAR -> EO dataset backed by raster files read with rasterio.

    Each item is a ``(sar, eo)`` tuple of float tensors in CHW layout:

    * ``sar`` has three channels: the two SAR bands (VV, VH) after per-band
      min-max normalization, plus their element-wise ratio VV / VH.
    * ``eo`` holds the Sentinel-2 bands selected by ``output_mode``, each
      min-max normalized independently.

    Both tensors are cropped to the top-left ``patch_size`` x ``patch_size``
    window.

    Args:
        sar_paths: sequence of SAR file paths.
        eo_paths: sequence of EO file paths; ``eo_paths[i]`` must belong to
            ``sar_paths[i]``.
        patch_size: side length of the top-left crop.
        output_mode: one of ``'RGB'``, ``'NIR_SWIR'``, ``'RGB_NIR'``
            (case-insensitive).

    Raises:
        ValueError: if the two path lists differ in length or ``output_mode``
            is not a known band set.
    """

    def __init__(self, sar_paths, eo_paths, patch_size=256, output_mode='rgb'):

        # Define band indices for Sentinel-2 (0-based)
        self.bands = {
            'RGB': [3, 2, 1],      # B4, B3, B2
            'NIR_SWIR': [7, 10, 4],     # B8, B11, B5
            'RGB_NIR': [3, 2, 1, 7] # B4, B3, B2, B8
        }

        if len(sar_paths) != len(eo_paths):
            raise ValueError(
                f"sar_paths and eo_paths must have the same length "
                f"(got {len(sar_paths)} and {len(eo_paths)})"
            )

        # The band table uses upper-case keys while the default argument is
        # lower-case ('rgb'); normalize the case so the default is usable.
        mode_key = str(output_mode).upper()
        if mode_key not in self.bands:
            raise ValueError(
                f"Unknown output_mode {output_mode!r}; expected one of {sorted(self.bands)}"
            )

        self.sar_paths = sar_paths
        self.eo_paths = eo_paths
        self.patch_size = patch_size
        self.output_mode = mode_key

    def __len__(self):
        """Number of (SAR, EO) pairs."""
        return len(self.sar_paths)

    def __getitem__(self, idx):
        """Return the ``(sar, eo)`` float tensor pair at position ``idx``."""
        sar = self.read_image(self.sar_paths[idx], bands=[0, 1])  # VV, VH
        eo = self.read_image(self.eo_paths[idx], bands=self.bands[self.output_mode])

        sar = torch.from_numpy(sar).float()
        eo = torch.from_numpy(eo).float()
        return sar, eo

    def read_image(self, path, bands):
        """Read ``bands`` (0-based indices) from ``path`` as a CHW float array.

        Every band is min-max normalized on its own over the full raster. When
        ``bands`` is exactly ``[0, 1]`` the file is treated as SAR and a third
        VV / VH ratio channel is appended. The result is cropped to the
        top-left ``patch_size`` window.
        """
        with rasterio.open(path) as src:
            # rasterio band numbers start at 1, ours at 0.
            band_arrays = [normalize(src.read(b + 1)) for b in bands]

        if list(bands) == [0, 1]:
            vv, vh = band_arrays
            # NOTE(review): vv and vh are each scaled to [0, 1), so this ratio is
            # unbounded above (up to ~1e6 where vh == 0) and is not re-normalized.
            vv_vh_ratio = np.divide(vv, vh + 1e-6)
            band_arrays = [vv, vh, vv_vh_ratio]

        img = np.stack(band_arrays, axis=0)
        return img[:, :self.patch_size, :self.patch_size]

In [None]:
class ResnetBlock(nn.Module):
    """Residual block made of two reflection-padded 3x3 convolutions.

    The channel count and spatial size are preserved, and the block output is
    the input plus the result of the convolutional branch.
    """

    def __init__(self, dim, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        branch = []
        for is_final_stage in (False, True):
            branch.extend([
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0),
                norm_layer(dim),
            ])
            # Only the first conv stage is followed by an (in-place) ReLU.
            if not is_final_stage:
                branch.append(nn.ReLU(True))
        self.conv_block = nn.Sequential(*branch)

    def forward(self, x):
        """Apply the convolutional branch and add the identity skip."""
        residual = self.conv_block(x)
        return x + residual

In [None]:
class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: a 7x7 stem convolution, two stride-2 downsampling convolutions,
    ``n_blocks`` residual blocks at the bottleneck width, two transposed
    convolutions that undo the downsampling, and a 7x7 output convolution
    followed by ``tanh``.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: channel width after the stem convolution.
        n_blocks: number of residual blocks at the bottleneck.
        norm_layer: callable that builds a normalization layer from a channel count.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        n_downsampling = 2

        # Stem: full-resolution 7x7 convolution.
        stem = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each level halves the resolution and doubles the channels.
        encoder = []
        for level in range(n_downsampling):
            in_ch = ngf * 2 ** level
            out_ch = in_ch * 2
            encoder.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                norm_layer(out_ch),
                nn.ReLU(inplace=True),
            ])

        # Bottleneck: residual blocks at the widest channel count.
        bottleneck_ch = ngf * 2 ** n_downsampling
        bottleneck = [
            ResnetBlock(bottleneck_ch, norm_layer=norm_layer)
            for _ in range(n_blocks)
        ]

        # Decoder: mirror of the encoder using transposed convolutions.
        decoder = []
        for level in range(n_downsampling, 0, -1):
            in_ch = ngf * 2 ** level
            out_ch = in_ch // 2
            decoder.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
                norm_layer(out_ch),
                nn.ReLU(inplace=True),
            ])

        # Head: 7x7 convolution to the output channels, squashed by tanh.
        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*stem, *encoder, *bottleneck, *decoder, *head)

    def forward(self, x):
        """Run the input through the full stack of layers."""
        return self.model(x)

In [None]:
class NLayerDiscriminator(nn.Module):
    """Convolutional discriminator that outputs a one-channel score map.

    A first stride-2 convolution (without normalization) is followed by
    ``n_layers - 1`` stride-2 stages and one stride-1 stage, each made of
    conv + norm + LeakyReLU(0.2), with the channel multiplier doubling up to a
    cap of 8. A final convolution reduces to one output channel.

    Args:
        input_nc: number of input channels.
        ndf: channel width of the first convolution.
        n_layers: number of channel-widening stages after the first convolution.
        norm_layer: callable that builds a normalization layer from a channel count.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        kernel = 4
        pad = 1

        def stage(in_ch, out_ch, stride):
            """conv -> norm -> LeakyReLU building block."""
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=pad),
                norm_layer(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Strided stages: multiplier doubles each time, capped at 8.
        prev_mult = 1
        for n in range(1, n_layers):
            cur_mult = min(2 ** n, 8)
            layers.extend(stage(ndf * prev_mult, ndf * cur_mult, stride=2))
            prev_mult = cur_mult

        # Last widening stage keeps the resolution (stride 1).
        cur_mult = min(2 ** n_layers, 8)
        layers.extend(stage(ndf * prev_mult, ndf * cur_mult, stride=1))

        # Projection to the single-channel score map.
        layers.append(nn.Conv2d(ndf * cur_mult, 1, kernel_size=kernel, stride=1, padding=pad))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for ``x``."""
        return self.model(x)

In [None]:
# Two generators, one per translation direction. Both sides use 3 channels:
# the dataset yields SAR as VV, VH and VV/VH, and the EO target has 3 bands
# when the dataset is built with the 'RGB' band set.
G_sar2rgb = ResnetGenerator(input_nc=3, output_nc=3)
G_rgb2sar = ResnetGenerator(input_nc=3, output_nc=3)

# One discriminator per domain.
D_sar = NLayerDiscriminator(input_nc=3)
D_rgb = NLayerDiscriminator(input_nc=3)

In [None]:
def show_tensor_image(img_tensor, title='', cmap=None, bands=None):
    """
    Draw a tensor on the current matplotlib axes.

    A 3-D tensor is treated as CHW: ``bands`` (if given) selects channels first,
    a single remaining channel is shown as a 2-D image, and anything else is
    transposed to HWC. Values are min-max rescaled to [0, 1] before plotting.
    The function only draws; saving or showing the figure is up to the caller.
    """
    arr = img_tensor.detach().cpu().numpy()

    if arr.ndim == 3:
        if bands is not None:
            arr = arr[bands]
        # One channel -> plain 2-D image; otherwise channels move to the last axis.
        arr = arr[0] if arr.shape[0] == 1 else arr.transpose(1, 2, 0)

    # Normalize to [0, 1]
    lo = arr.min()
    hi = arr.max()
    arr = (arr - lo) / (hi - lo + 1e-5)

    plt.imshow(arr, cmap=cmap)
    plt.title(title)
    plt.axis('off')

In [None]:
def save_sample_images(epoch, num_samples=3, trainer=None):
    """Save input / target / generated triptychs for a few validation samples.

    For up to ``num_samples`` batches of ``trainer.val_loader`` the first item
    of the batch is translated with ``trainer.G_AB`` and written to
    ``{trainer.output_dir}/images/epoch_{epoch}_sample_{i}.png``.

    Args:
        epoch: epoch number, used only in the output file name.
        num_samples: maximum number of validation batches to visualize.
        trainer: object exposing ``G_AB``, ``val_loader``, ``device`` and
            ``output_dir``. When omitted, the notebook-level variable named
            ``trainer`` is used.

    Raises:
        RuntimeError: if no trainer is passed and no global ``trainer`` exists.
    """
    # This function used to reference an undefined ``self``; the trainer is now
    # an explicit (optional) argument so the two-argument call form keeps working.
    if trainer is None:
        trainer = globals().get('trainer')
    if trainer is None:
        raise RuntimeError(
            "save_sample_images needs a trainer: pass trainer=... or define a "
            "global variable named 'trainer' first"
        )

    # Remember the mode so sampling does not silently leave the generator in eval mode.
    was_training = trainer.G_AB.training
    trainer.G_AB.eval()
    try:
        with torch.no_grad():
            # zip with range() stops after num_samples batches or when the
            # loader runs out, whichever comes first.
            for i, (real_A, real_B) in zip(range(num_samples), trainer.val_loader):
                real_A = real_A.to(trainer.device)
                real_B = real_B.to(trainer.device)
                fake_B = trainer.G_AB(real_A)

                plt.figure(figsize=(12, 4))
                # SAR input (assume 2 or 3 channels: VV, VH, VV/VH)
                plt.subplot(1, 3, 1)
                show_tensor_image(real_A[0], 'Input SAR', cmap='gray')

                # Real EO (assume first 3 bands are RGB)
                plt.subplot(1, 3, 2)
                show_tensor_image(real_B[0], 'Real EO', bands=[0, 1, 2])

                # Fake EO
                plt.subplot(1, 3, 3)
                show_tensor_image(fake_B[0], 'Generated EO', bands=[0, 1, 2])

                plt.tight_layout()
                plt.savefig(f"{trainer.output_dir}/images/epoch_{epoch}_sample_{i}.png")
                plt.close()
    finally:
        trainer.G_AB.train(was_training)

In [None]:
import random
import torch

class ImagePool:
    """Bounded history of generated images used when training a discriminator.

    NOTE: a second class with the same name is defined further down in this
    file; once that later definition has run it replaces this one.
    """

    def __init__(self, pool_size: int):
        self.images = []
        self.pool_size = pool_size

    def query(self, images: torch.Tensor) -> torch.Tensor:
        """
        Mix a batch of freshly generated images with stored ones.

        images: batch of generated images (N,C,H,W).
        Returns a batch of the same size. While the pool is filling up, every
        image is stored and passed through. Afterwards each image is, with
        probability one half, swapped for a randomly chosen stored image
        (taking that image's place in the pool) and otherwise passed through.
        A pool size of 0 turns the buffer off.
        """
        if self.pool_size == 0:
            return images

        selected = []
        for sample in images:
            sample = sample.unsqueeze(0)

            # Pool not full yet: store and pass through.
            if len(self.images) < self.pool_size:
                self.images.append(sample)
                selected.append(sample)
                continue

            if random.random() > 0.5:
                # Hand out a stored image and put the new one in its slot.
                slot = random.randint(0, self.pool_size - 1)
                selected.append(self.images[slot].clone())
                self.images[slot] = sample
            else:
                selected.append(sample)

        return torch.cat(selected, dim=0)

In [None]:
import random

class ImagePool:
    """Image buffer that stores previously generated images to stabilize training.

    This buffer enables us to update discriminators using a history of generated
    images rather than only the most recently generated images.
    """

    def __init__(self, pool_size):
        """Create a buffer holding at most ``pool_size`` images.

        A ``pool_size`` of zero (or less) disables buffering: ``query`` then
        returns its input unchanged.
        """
        self.pool_size = pool_size
        # Created unconditionally so the object has the same attributes
        # whatever the pool size is.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch for discriminator training, drawn partly from the buffer.

        Parameters:
            images: the latest generated images from the generator (N, C, H, W)
        Returns:
            a batch of the same size, detached from the autograd graph
            (unless buffering is disabled, in which case ``images`` itself).

        While the buffer is not full, every image is stored and returned as is.
        Once it is full, each image is handled independently:
        with probability 0.5 it is returned directly, and
        with probability 0.5 a randomly chosen stored image is returned instead
        and the current image takes its place in the buffer.
        """
        if self.pool_size <= 0:  # buffering disabled
            return images

        return_images = []
        for image in images:
            # detach() replaces the legacy ``.data`` access: same values, no
            # autograd history kept alive inside the buffer.
            image = torch.unsqueeze(image.detach(), 0)
            if self.num_imgs < self.pool_size:   # buffer not full yet
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # return a previously stored image, store the new one
                    random_id = random.randint(0, self.pool_size - 1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:        # return the current image
                    return_images.append(image)

        if not return_images:
            # Empty batch: torch.cat would fail on an empty list.
            return images
        return torch.cat(return_images, 0)


In [None]:
class CycleGANTrainer:
    """Training loop for a CycleGAN with two generators and two discriminators.

    Domain A is the first tensor of every batch and domain B the second.
    ``G_AB`` translates A -> B, ``G_BA`` translates B -> A, and ``D_A`` / ``D_B``
    score images of their respective domain. Adversarial terms use an MSE
    (least-squares) objective; the cycle-consistency term is an L1 loss
    weighted by 10. No identity loss is used.

    Args:
        G_AB, G_BA: generator modules.
        D_A, D_B: discriminator modules.
        dataloaders: ``(train_loader, val_loader)`` tuple.
        optimizers: ``(generator_optimizer, discriminator_optimizer)`` tuple.
        pool_size: size of the history buffers of generated images.
        device: device the models and batches are moved to.
        output_dir: root folder for checkpoints and images (created if missing).
        img_save_epoch: stored for periodic image saving, which is currently
            disabled (see ``train``).
    """

    # Weight of the cycle-consistency term in the generator loss.
    LAMBDA_CYCLE = 10.0

    def __init__(self, G_AB, G_BA, D_A, D_B,
                 dataloaders, optimizers,
                 pool_size=50, device='cuda',
                 output_dir='./outputs', img_save_epoch=5):
        self.train_loader, self.val_loader = dataloaders
        self.opt_G, self.opt_D = optimizers
        self.device = device
        self.best_ssim = -float('inf')
        self.output_dir = output_dir
        self.img_save_epoch = img_save_epoch
        self.fake_A_pool = ImagePool(pool_size)
        self.fake_B_pool = ImagePool(pool_size)

        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/images", exist_ok=True)

        self.loss_history = {'G': [], 'D': [], 'cycle': [], 'ssim': []}
        self.G_AB = G_AB.to(self.device)
        self.G_BA = G_BA.to(self.device)
        self.D_A = D_A.to(self.device)
        self.D_B = D_B.to(self.device)

    @staticmethod
    def _set_requires_grad(nets, flag):
        """Switch gradient tracking on or off for every parameter of ``nets``."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad_(flag)

    def _checkpoint_state(self, epoch):
        """Collect everything needed to resume training after ``epoch``."""
        return {
            'epoch': epoch,
            'G_AB': self.G_AB.state_dict(),
            'G_BA': self.G_BA.state_dict(),
            'D_A': self.D_A.state_dict(),
            'D_B': self.D_B.state_dict(),
            'opt_G': self.opt_G.state_dict(),
            'opt_D': self.opt_D.state_dict(),
            'best_ssim': self.best_ssim
        }

    def _validate(self, metrics_fn):
        """Score ``G_AB`` on the first validation batch with ``metrics_fn``."""
        batch = next(iter(self.val_loader), None)
        if batch is None:
            raise RuntimeError("Validation loader is empty; cannot compute the metric")
        val_real_A, val_real_B = batch
        val_real_A, val_real_B = val_real_A.to(self.device), val_real_B.to(self.device)

        was_training = self.G_AB.training
        self.G_AB.eval()
        try:
            # No graph is needed for evaluation; this keeps memory use down.
            with torch.no_grad():
                val_fake_B = self.G_AB(val_real_A)
                return metrics_fn(val_fake_B, val_real_B)
        finally:
            self.G_AB.train(was_training)

    def train(self, n_epochs, metrics_fn):
        """Train for ``n_epochs`` epochs and return the per-epoch loss history.

        Args:
            n_epochs: number of passes over ``train_loader``.
            metrics_fn: callable ``(generated_B, real_B) -> float`` evaluated on
                the first validation batch after every epoch; a higher value is
                treated as better when choosing the best checkpoint.

        Returns:
            dict with keys ``'G'``, ``'D'``, ``'cycle'`` (epoch means) and
            ``'ssim'`` (validation metric), each a list with one entry per epoch.

        Side effects:
            writes ``checkpoints/epoch_{n}.pth`` after every epoch and
            ``checkpoints/best.pth`` whenever the validation metric improves.
        """
        discriminators = (self.D_A, self.D_B)
        for net in (self.G_AB, self.G_BA, self.D_A, self.D_B):
            net.train()

        # Avoid a ZeroDivisionError when averaging over an empty loader.
        n_batches = max(len(self.train_loader), 1)

        for epoch in range(1, n_epochs+1):
            epoch_losses = {'G': 0., 'D': 0., 'cycle': 0., 'ssim': 0.}
            pbar = tqdm(self.train_loader, desc=f"Epoch {epoch}/{n_epochs}")
            for real_A, real_B in pbar:
                real_A, real_B = real_A.to(self.device), real_B.to(self.device)

                # ------------------
                #  Train Generators
                # ------------------
                # The discriminators only provide a training signal here, so
                # their own parameter gradients are not needed.
                self._set_requires_grad(discriminators, False)
                self.opt_G.zero_grad()
                fake_B = self.G_AB(real_A)
                fake_A = self.G_BA(real_B)

                rec_A = self.G_BA(fake_B)
                rec_B = self.G_AB(fake_A)

                # GAN losses: generators want their fakes to be scored as real (1).
                pred_fake_B = self.D_B(fake_B)
                pred_fake_A = self.D_A(fake_A)
                loss_GAN_AB = F.mse_loss(pred_fake_B, torch.ones_like(pred_fake_B))
                loss_GAN_BA = F.mse_loss(pred_fake_A, torch.ones_like(pred_fake_A))

                # Cycle consistency
                loss_cycle = F.l1_loss(rec_A, real_A) + F.l1_loss(rec_B, real_B)

                # Combined generator loss (no identity term)
                loss_G = loss_GAN_AB + loss_GAN_BA + self.LAMBDA_CYCLE * loss_cycle
                loss_G.backward()
                self.opt_G.step()

                # -----------------------
                #  Train Discriminators
                # -----------------------
                self._set_requires_grad(discriminators, True)
                self.opt_D.zero_grad()
                # Fakes are detached so no gradient reaches the generators, and
                # mixed with older fakes from the history buffers.
                fake_B_ = self.fake_B_pool.query(fake_B.detach())
                fake_A_ = self.fake_A_pool.query(fake_A.detach())

                loss_D_B = self._discriminator_loss(self.D_B, real_B, fake_B_)
                loss_D_A = self._discriminator_loss(self.D_A, real_A, fake_A_)
                loss_D = loss_D_A + loss_D_B
                loss_D.backward()
                self.opt_D.step()

                # Logging
                epoch_losses['G'] += loss_G.item()
                epoch_losses['D'] += loss_D.item()
                epoch_losses['cycle'] += loss_cycle.item()
                pbar.set_postfix(G=loss_G.item(), D=loss_D.item())

            # Average losses
            for k in ['G', 'D', 'cycle']:
                epoch_losses[k] /= n_batches

            # Validation: compute the metric on a single validation batch
            ssim_val = self._validate(metrics_fn)
            epoch_losses['ssim'] = ssim_val

            # Update the best score first so both checkpoints record it.
            is_best = ssim_val > self.best_ssim
            if is_best:
                self.best_ssim = ssim_val
            state = self._checkpoint_state(epoch)

            # Save best checkpoint
            if is_best:
                torch.save(state, f"{self.output_dir}/checkpoints/best.pth")

            # Save epoch checkpoint
            torch.save(state, f"{self.output_dir}/checkpoints/epoch_{epoch}.pth")

            # Save sample images periodically
            # (disabled: this class does not define a save_sample_images method)
            # if epoch % self.img_save_epoch == 0:
            #     self.save_sample_images(epoch, num_samples=3)

            # Record loss history
            for k in ['G', 'D', 'cycle', 'ssim']:
                self.loss_history[k].append(epoch_losses[k])

            print(f"Epoch {epoch} | SSIM: {ssim_val:.4f}")

        return self.loss_history

    @staticmethod
    def _discriminator_loss(net, real, fake):
        """MSE loss pushing ``net(real)`` towards 1 and ``net(fake)`` towards 0."""
        pred_real = net(real)
        pred_fake = net(fake)
        return (F.mse_loss(pred_real, torch.ones_like(pred_real)) +
                F.mse_loss(pred_fake, torch.zeros_like(pred_fake)))
    


In [None]:
import piq
def ssim_metric(pred, target, value_range=(-1.0, 1.0)):
    """Mean SSIM between two image batches, returned as a Python float.

    Both tensors are mapped linearly from ``value_range`` onto [0, 1], clamped
    to that interval, and compared with ``piq.ssim`` using ``data_range=1.0``.

    Args:
        pred: predicted images.
        target: reference images.
        value_range: ``(low, high)`` interval the inputs are expected to span.
            The default keeps the original [-1, 1] assumption.

    Raises:
        ValueError: if ``high`` is not greater than ``low``.

    NOTE(review): SARToEODataset in this notebook min-max normalizes every band
    to [0, 1], so with the default range the targets are squeezed into
    [0.5, 1]. Passing ``value_range=(0.0, 1.0)`` would match that data — confirm
    which range is intended.
    """
    low, high = value_range
    if high <= low:
        raise ValueError(f"value_range must satisfy low < high, got {value_range!r}")
    span = high - low

    p = torch.clamp((pred - low) / span, 0.0, 1.0)
    t = torch.clamp((target - low) / span, 0.0, 1.0)
    return piq.ssim(p, t, data_range=1.0, reduction='mean').item()

In [None]:
import itertools

In [None]:
# Build the paired dataset with the 'RGB' band set (Sentinel-2 B4, B3, B2) as target.
dataset = SARToEODataset(sar_paths, eo_paths, output_mode='RGB')

In [None]:
# 80/20 train/validation split. The generator is seeded so that the same
# samples land in each subset on every run (validation scores stay comparable).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_ds, val_ds = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Training batches are shuffled and loaded by two worker processes.
train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=2)
# Validation yields one image per batch in a fixed order; the trainer scores
# only the first batch after each epoch.
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

In [None]:
# One optimizer updates both generators together, another both discriminators.
# itertools.chain returns a one-shot iterator, so each chain is handed to
# exactly one optimizer.
gen_params  = itertools.chain(G_sar2rgb.parameters(), G_rgb2sar.parameters())
disc_params = itertools.chain(D_sar.parameters(),     D_rgb.parameters())

# Same AdamW hyper-parameters for both sides.
optimizer_G = torch.optim.AdamW(gen_params,  lr=2e-4,
                                betas=(0.5, 0.999), weight_decay=1e-4)
optimizer_D = torch.optim.AdamW(disc_params, lr=2e-4,
                                betas=(0.5, 0.999), weight_decay=1e-4)

In [None]:
# Positional order is (G_AB, G_BA, D_A, D_B): domain A is SAR and domain B is
# RGB, so D_sar judges SAR images and D_rgb judges RGB images.
trainer = CycleGANTrainer(
    G_sar2rgb, G_rgb2sar, D_sar, D_rgb,
    dataloaders=(train_loader, val_loader),
    optimizers=(optimizer_G, optimizer_D),
    pool_size=50,
    device=device,
    output_dir='./runs/exp1',
    img_save_epoch=5
)

In [None]:
# Train for 15 epochs, scoring each epoch with SSIM; returns the per-epoch history dict.
loss_history = trainer.train(15,ssim_metric)