<a href="https://colab.research.google.com/github/le-pigeon/fantastic-winner/blob/main/SAR_DeepSpeck.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Colab test

In [None]:
# empty for debug/admin

In [None]:
# This script trains a deep unrolling network for SAR despeckling.
# Steps:
# 0. Pay for premium GPU :)
# 1. Set up dataset (simulated speckle noise on optical images)
# 2. Define model
# 3. Train with Charbonnier + Total Variation loss
# 4. Test on real SAR images/test noisy images

In [None]:
import glob
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Subset, DataLoader
import os
import random
import torch.nn.functional as F
import gc
from tqdm import tqdm  # For progress bar
import torch.cuda.memory as cuda_mem



In [None]:
# Run this cell to free cached memory if needed (RAM usage sometimes stays high even afterwards).
# Collect unreachable Python objects first so their CUDA tensors are freed,
# then hand PyTorch's now-unused cached GPU blocks back to the driver.
collected = gc.collect()
torch.cuda.empty_cache()

In [None]:
# This block loads and unzip learning dataset from G Drive
# Allow all permissions
# You should see /content/drive folder

# Step 1: Mount Google Drive (skipped when it is already mounted)
from google.colab import drive
import zipfile
import os

if not os.path.isdir("/content/drive/MyDrive"):
    drive.mount('/content/drive')  # Connect to your Google Drive

# Step 2: Define paths for the zip file and where to extract it
zip_path = "/content/drive/MyDrive/SAR_Project/Dataset/SAR_paired.zip"
extract_path = "/content/test_data/SAR_Dataset"  # Where to extract

# Folder for your own real SAR test images
folder_name = "/content/imported_images"
os.makedirs(folder_name, exist_ok=True)  # Create if it doesn't exist

# Step 3: Extract ZIP
if not os.path.isfile(zip_path):
    # Fail with a clear message instead of zipfile's bare "No such file" error.
    raise FileNotFoundError(f"Dataset archive not found: {zip_path} (is Google Drive mounted?)")
print("Extracting dataset... this may take a moment!")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Extraction complete!")

# Folder for saving / loading model weights
os.makedirs("/content/model", exist_ok=True)

# Select CUDA device (GPU is recommended); fall back to CPU if no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Load training data and validation data

In [None]:
###############################################################################
############################## PAIRED DATASET #################################
###############################################################################

# Load training
# Accumulators for the training pairs, filled by the loading loop further down
train_noisy_tensors, train_clean_tensors = [], []

# Source folders for the paired training data: speckled input and ground truth
noisy_folder = "/content/test_data/SAR_Dataset/SAR despeckling filters dataset/Main folder/Noisy"
clean_folder = "/content/test_data/SAR_Dataset/SAR despeckling filters dataset/Main folder/GTruth"

# Remove too black images
def is_black_image(filepath, brightness_threshold=0.1, black_ratio_threshold=0.8):
    """
    Report whether the image at ``filepath`` is dominated by dark pixels.

    The file is read as 8-bit grayscale and scaled to [0, 1]. A pixel counts as
    dark when its scaled value is below ``brightness_threshold``.

    Parameters:
        filepath (str): Path to the image file.
        brightness_threshold (float): Darkness cut-off on the 0-1 scale.
        black_ratio_threshold (float): Fraction of dark pixels above which the
            image is reported as black.

    Returns:
        bool: True if the dark fraction exceeds ``black_ratio_threshold``, and
        also True when the file cannot be read.
    """
    gray = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        # Unreadable files are treated like black ones so callers skip them.
        return True

    dark_mask = (gray / 255.0) < brightness_threshold
    dark_fraction = np.mean(dark_mask)
    return dark_fraction > black_ratio_threshold

def is_mostly_black_edge_case(filepath, brightness_threshold=0.1, black_ratio_threshold=0.9, center_crop_ratio=0.5):
    """
    Like a plain darkness test, but only the central region of the image is examined.

    Bright stripes along the borders therefore do not prevent an otherwise dark
    image from being flagged.

    Args:
        filepath: Path to image.
        brightness_threshold: Pixel value (0-1 scale) below which a pixel is dark.
        black_ratio_threshold: Fraction of dark pixels in the crop needed to flag it.
        center_crop_ratio: Side length of the centered crop relative to the image
            (e.g. 0.5 keeps the middle half in each direction).

    Returns:
        True if the crop is mostly dark (or the file is unreadable), False otherwise.
    """
    gray = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        return True

    scaled = gray / 255.0
    rows, cols = scaled.shape

    # Centered crop of the requested relative size
    crop_rows = int(rows * center_crop_ratio)
    crop_cols = int(cols * center_crop_ratio)
    top = (rows - crop_rows) // 2
    left = (cols - crop_cols) // 2
    center = scaled[top:top + crop_rows, left:left + crop_cols]

    return np.mean(center < brightness_threshold) > black_ratio_threshold

#Remove pair from training if clean image is too blurry
def is_blurry(image, threshold=100):
    """Return True when the variance of the image's Laplacian is below ``threshold``.

    A low Laplacian variance means little high-frequency detail, which is used
    here as a blur indicator. Color (BGR) inputs are converted to grayscale first.
    """
    if len(image.shape) == 2:
        gray = image
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return cv2.Laplacian(gray, cv2.CV_64F).var() < threshold

# Load all image paths
# Load all image paths; pairs are matched by sorted position in the two folders
noisy_paths = sorted(glob.glob(os.path.join(noisy_folder, "*.tiff")))
clean_paths = sorted(glob.glob(os.path.join(clean_folder, "*.tiff")))

# Positional pairing is only valid when both folders hold the same number of files;
# otherwise zip() would silently truncate and mis-pair images.
if len(noisy_paths) != len(clean_paths):
    raise ValueError(
        f"Mismatch in training pair count! ({len(noisy_paths)} noisy vs {len(clean_paths)} clean)"
    )

# Load the images to CPU RAM
for noisy_path, clean_path in zip(noisy_paths, clean_paths):
    # Optional filters (disabled):
    # if is_black_image(noisy_path) or is_black_image(clean_path):
    # if is_mostly_black_edge_case(noisy_path) or is_mostly_black_edge_case(clean_path):
    #     continue  # Skip black images

    noisy = cv2.imread(noisy_path, cv2.IMREAD_GRAYSCALE)
    clean = cv2.imread(clean_path, cv2.IMREAD_GRAYSCALE)
    if noisy is None or clean is None:
        # cv2.imread returns None instead of raising on unreadable files.
        print(f"Skipping unreadable pair: {noisy_path}, {clean_path}")
        continue

    # if is_blurry(clean):  # Only check CLEAN SAR image
    #     continue  # Skip both noisy and clean images if clean is blurry

    # Resize to save memory, then scale to [0, 1]
    noisy = cv2.resize(noisy, (256, 256), interpolation=cv2.INTER_AREA) / 255.0
    clean = cv2.resize(clean, (256, 256), interpolation=cv2.INTER_AREA) / 255.0

    # Store as float32 (1, H, W) tensors in CPU RAM. The dataset casts to float32
    # anyway, so storing float64 would only double memory use.
    train_noisy_tensors.append(torch.tensor(noisy, dtype=torch.float32).unsqueeze(0))
    train_clean_tensors.append(torch.tensor(clean, dtype=torch.float32).unsqueeze(0))

print(f"Loaded {len(train_clean_tensors)} noisy-clean SAR image pairs to CPU RAM!")



class PairedSARDataset(Dataset):
    """Index-aligned (noisy, clean) image tensor pairs.

    The tensors stay wherever the caller created them (CPU RAM in this notebook).
    Fetching an item only casts both tensors to float32; moving them to the GPU
    is left to the training loop.
    """

    def __init__(self, noisy_images, clean_images):
        self.noisy_images = noisy_images
        self.clean_images = clean_images

    def __len__(self):
        # The noisy list defines the dataset length.
        return len(self.noisy_images)

    def __getitem__(self, index):
        pair = (self.noisy_images[index], self.clean_images[index])
        return tuple(t.float() for t in pair)



In [None]:
# This loads training data

SEED = 42  # Fixed seed for reproducible subset sampling and shuffling
SUBSET_SIZE = 800  # Number of distinct training pairs used per run

torch.cuda.empty_cache()
# cuDNN autotuning: faster convolutions, but the chosen kernels can vary between
# runs, so results are only approximately reproducible even with fixed seeds.
torch.backends.cudnn.benchmark = True

# Seed every RNG this notebook draws from (random, NumPy, PyTorch)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create Paired Dataset
train_dataset = PairedSARDataset(train_noisy_tensors, train_clean_tensors)

# Sample WITHOUT replacement so the subset contains no duplicate pairs.
# Cap at the dataset size so a small dataset does not make np.random.choice fail.
subset_indices = np.random.choice(len(train_dataset),
                                  min(SUBSET_SIZE, len(train_dataset)),
                                  replace=False)
subset = Subset(train_dataset, subset_indices)

# To train on the full dataset instead, use:
# train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True,
#                           num_workers=2, pin_memory=False)

# Loader over the subset only (lighter on GPU memory per epoch)
train_loader = DataLoader(subset,
                          batch_size=16,
                          shuffle=True,
                          num_workers=8,
                          pin_memory=False)  # no page-locked host staging buffers

print(f"Ready to train on {len(subset)} images!")



In [None]:
# This loads validation dataset
val_noisy_tensors = []
val_clean_tensors = []

# Load validation data set
val_noisy_dir = "/content/test_data/SAR_Dataset/SAR despeckling filters dataset/Main folder/Noisy_val"
val_clean_dir = "/content/test_data/SAR_Dataset/SAR despeckling filters dataset/Main folder/GTruth_val"

# List of filenames to exclude
exclude_files = {'5120_0.tiff', '5632_0.tiff'}


def _list_val_tiffs(folder):
    """Return the sorted .tiff paths in *folder*, minus the excluded file names."""
    return sorted(
        f for f in glob.glob(os.path.join(folder, "*.tiff"))
        if os.path.basename(f) not in exclude_files
    )


val_noisy_files = _list_val_tiffs(val_noisy_dir)
val_clean_files = _list_val_tiffs(val_clean_dir)

# Sanity check: pairs are matched by sorted position
if len(val_noisy_files) != len(val_clean_files):
    raise ValueError("Mismatch in validation pair count!")

for noisy_path, clean_path in zip(val_noisy_files, val_clean_files):
    noisy = cv2.imread(noisy_path, cv2.IMREAD_GRAYSCALE)
    clean = cv2.imread(clean_path, cv2.IMREAD_GRAYSCALE)
    if noisy is None or clean is None:
        print(f"Skipping unreadable validation pair: {noisy_path}, {clean_path}")
        continue

    # Normalize ONCE to [0, 1], the same scale as the training data. The previous
    # version divided by 255 a second time, leaving validation data in [0, 1/255].
    val_noisy_tensors.append(torch.tensor(noisy / 255.0, dtype=torch.float32).unsqueeze(0))
    val_clean_tensors.append(torch.tensor(clean / 255.0, dtype=torch.float32).unsqueeze(0))

# Pass the tensors
val_dataset = PairedSARDataset(val_noisy_tensors, val_clean_tensors)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=8)

print(f"Ready to validate on {len(val_dataset)} images!")

In [None]:
# Peek at training images OwO

import matplotlib.pyplot as plt
import random


# Function to show images
def show_random_images(noisy_images, clean_images, num_samples=15):
    """Display randomly chosen (noisy, clean) pairs side by side.

    Args:
        noisy_images: sequence of noisy images, e.g. the (1, H, W) CPU tensors built
            by the loading loop; anything np.asarray() accepts works.
        clean_images: clean images aligned index-for-index with ``noisy_images``.
        num_samples: how many pairs to show. It is capped at the number of pairs
            available, because random.sample raises ValueError otherwise.
    """
    num_samples = min(num_samples, len(noisy_images))
    indices = random.sample(range(len(noisy_images)), num_samples)

    for idx in indices:
        # Use the images that were passed in, not the global path lists, so the
        # pairs stay correct even if some files were skipped while loading.
        noisy = np.asarray(noisy_images[idx]).squeeze()
        clean = np.asarray(clean_images[idx]).squeeze()

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1)
        plt.imshow(noisy, cmap='gray')
        plt.title("Noisy SAR Image")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(clean, cmap='gray')
        plt.title("Clean SAR Image")
        plt.axis("off")

        plt.show()

# Show 15 random pairs from the training set
show_random_images(train_noisy_tensors, train_clean_tensors)

# Load autoencoder and model

In [None]:
class ResUNet_512(nn.Module):
    """Residual U-Net used as the inner denoiser of SAR_DeepSpeck.

    Encoder: four conv blocks (64, 128, 256, 256 channels) with 2x max-pooling
    between them. Decoder: three stride-2 transposed-conv blocks whose outputs are
    *added* to the matching encoder features (additive skips), then a 32-channel
    conv block and a 3x3 output conv followed by ReLU. The output therefore has one
    channel and is non-negative.

    Spatial size: the three pooling stages need H and W to be multiples of 8. Other
    sizes are replicate-padded on the bottom/right before the encoder and the
    output is cropped back, so the output always has the input's H x W.

    Args:
        channels: number of input channels.
    """

    def __init__(self, channels=1):
        super().__init__()
        # Encoder path with BatchNorm
        self.enc1 = self.conv_block(channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 256)

        # Decoder path; each upconv doubles H and W
        self.dec4 = self.upconv_block(256, 256)
        self.dec3 = self.upconv_block(256, 128)
        self.dec2 = self.upconv_block(128, 64)
        self.dec1 = self.conv_block(64, 32)

        # Final output layer
        self.final_conv = nn.Conv2d(32, 1, kernel_size=3, padding=1)

    def conv_block(self, in_channels, out_channels):
        """Two (3x3 conv -> BatchNorm -> LeakyReLU(0.1)) stages; spatial size preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1)
        )

    def upconv_block(self, in_channels, out_channels):
        """Stride-2 transposed conv (exactly doubles H and W) -> BatchNorm -> LeakyReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1)
        )

    def crop(self, enc_feat, dec_feat):
        """Crop encoder features (top-left) to the decoder features' spatial size."""
        _, _, h, w = dec_feat.size()
        return enc_feat[:, :, :h, :w]

    def forward(self, x):
        """Map a (B, channels, H, W) input to a (B, 1, H, W) non-negative output."""
        height, width = x.shape[-2:]
        # Pad to multiples of 2**3 so every pool/upconv pair lines up exactly;
        # without this the output would be smaller than the input.
        pad_h = (-height) % 8
        pad_w = (-width) % 8
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool2d(e1, 2))
        e3 = self.enc3(F.max_pool2d(e2, 2))
        e4 = self.enc4(F.max_pool2d(e3, 2))

        # Decoder with additive skip connections
        d4_temp = self.dec4(e4)
        d4 = d4_temp + self.crop(e3, d4_temp)

        d3_temp = self.dec3(d4)
        d3 = d3_temp + self.crop(e2, d3_temp)

        d2_temp = self.dec2(d3)
        d2 = d2_temp + self.crop(e1, d2_temp)

        d1 = self.dec1(d2)

        # Output conv + ReLU (non-negative), then drop any padding added above
        out = F.relu(self.final_conv(d1))
        return out[..., :height, :width]


In [None]:
# Debug to check resUnet output (was broken a few times before hand)
test_input = torch.randn(1, 1, 512, 512)  # Simulated SAR patch
model = ResUNet_512()
# Shape check only: no_grad skips building the (multi-GB at 512x512) autograd graph.
with torch.no_grad():
    output = model(test_input)

print("ResUNet Output Shape:", output.shape)  # Should match input e.g. (1, 1, 512, 512)
assert output.shape == test_input.shape, "ResUNet output shape does not match its input"



In [None]:
# ========== SAR DeepSpeck model ==========

class SAR_DeepSpeck(nn.Module):
    """Unrolled despeckling network built around one shared ResUNet_512.

    The network runs ``num_layers`` unrolled stages, each doing the following:
    - Run the shared ResUNet on the current non-negative estimate.
    - Pass the result through that stage's own small conv head.
    - Subtract ``delta * eta * head(...)`` from the running estimate. Both
      ``delta`` and ``eta`` are trainable scalars.

    After the last stage, one more ResUNet pass on the clamped estimate is
    subtracted from the running estimate to give the output.
    """

    def __init__(self, num_layers=8):
        super(SAR_DeepSpeck, self).__init__()
        self.num_layers = num_layers
        # A single U-Net shared by every stage and by the final pass
        self.resunet = ResUNet_512()

        def _make_step_head():
            # Per-stage head: 1 -> 8 -> 1 channels, spatial size preserved
            return nn.Sequential(
                nn.Conv2d(1, 8, kernel_size=3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(),
                nn.Conv2d(8, 1, kernel_size=3, padding=1)
            )

        self.gradient_steps = nn.ModuleList(
            [_make_step_head() for _ in range(num_layers)]
        )

        # Trainable step size (delta, starts at 0.01) and scale (eta, starts at 1.0)
        self.delta = nn.Parameter(torch.tensor(0.01, dtype=torch.float32))
        self.eta = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def forward(self, x):
        estimate = x
        clamped = estimate  # the first stage sees the raw input
        for stage in range(self.num_layers):
            head = self.gradient_steps[stage]
            noise_est = self.resunet(clamped)
            estimate = estimate - self.delta * (self.eta * head(noise_est))
            # Keep the U-Net's next input non-negative
            clamped = torch.relu(estimate)

        return estimate - self.resunet(clamped)
        # return x + v


# ========== Loss Function & Optimizer ==========
def charbonnier_loss(x, y, epsilon=1e-5):
    """Mean Charbonnier penalty: mean(sqrt((x - y)^2 + epsilon)).

    ``epsilon`` is added inside the square root (it is not squared), which keeps
    the gradient finite where x == y.
    """
    squared_error = (x - y) ** 2
    return squared_error.add(epsilon).sqrt().mean()

def tv_loss(img):
    """Anisotropic total variation of a (B, C, H, W) batch.

    Returns the mean absolute difference between vertically adjacent pixels plus
    the mean absolute difference between horizontally adjacent pixels.
    """
    vertical = torch.diff(img, dim=2).abs().mean()
    horizontal = torch.diff(img, dim=3).abs().mean()
    return vertical + horizontal


def laplacian_loss(img):
    """Mean absolute 4-neighbour Laplacian of a (B, C, H, W) batch.

    Each channel is filtered separately (grouped conv). Reflect padding keeps the
    output the same size as the input. The kernel is created on the input's own
    device and dtype, so the loss works on CPU or GPU, under autocast, and
    without relying on a global ``device``.
    """
    channels = img.shape[1]
    kernel = torch.tensor([[0.0, 1.0, 0.0],
                           [1.0, -4.0, 1.0],
                           [0.0, 1.0, 0.0]], dtype=img.dtype, device=img.device)
    # One copy of the kernel per channel, shape (C, 1, 3, 3), for groups=C
    kernel = kernel.reshape(1, 1, 3, 3).repeat(channels, 1, 1, 1)

    padded = F.pad(img, (1, 1, 1, 1), mode='reflect')  # Avoid size mismatch
    laplacian_img = F.conv2d(padded, kernel, padding=0, groups=channels)
    return torch.mean(torch.abs(laplacian_img))  # L1 on the Laplacian


# SSIM helpers for checking validation quality. (When SSIM was tried as the validation loss, it kept going up instead of down, so it is currently commented out.)

def gaussian_kernel(window_size=11, sigma=1.5, channels=1):
    """Normalized 2-D Gaussian window of shape (channels, 1, window_size, window_size).

    The 1-D profile is normalized to sum to 1, so each 2-D window also sums to 1.
    The result is an ``expand``-ed view; all channels share the same memory.
    """
    offsets = torch.arange(window_size).float() - window_size // 2
    profile = torch.exp(-(offsets ** 2) / (2 * sigma ** 2))
    profile = profile / profile.sum()
    window_2d = torch.outer(profile, profile)
    return window_2d.expand(channels, 1, window_size, window_size)


def ssim_torch(img1, img2, window_size=11, window=None, C1=0.01**2, C2=0.03**2):
    """Simplified mean SSIM between two equally-shaped (B, C, H, W) batches.

    Local means, variances and covariance come from a depthwise (groups=C)
    convolution with ``window``. The borders are zero-padded. If ``window`` is
    not given, a Gaussian window (sigma 1.5) is built with gaussian_kernel.
    C1 and C2 are the usual stabilizers for data in [0, 1]. Returns a 0-dim tensor.
    """
    assert img1.size() == img2.size(), "Input images must have the same size"
    _, n_ch, _, _ = img1.shape

    if window is None:
        window = gaussian_kernel(window_size=window_size, sigma=1.5, channels=n_ch).to(img1.device)

    half = window_size // 2

    def local_mean(t):
        return F.conv2d(t, window, padding=half, groups=n_ch)

    mean_a = local_mean(img1)
    mean_b = local_mean(img2)
    mean_a_sq, mean_b_sq, mean_ab = mean_a.pow(2), mean_b.pow(2), mean_a * mean_b

    var_a = local_mean(img1 * img1) - mean_a_sq
    var_b = local_mean(img2 * img2) - mean_b_sq
    cov_ab = local_mean(img1 * img2) - mean_ab

    numerator = (2 * mean_ab + C1) * (2 * cov_ab + C2)
    denominator = (mean_a_sq + mean_b_sq + C1) * (var_a + var_b + C2)
    return (numerator / denominator).mean()

# ========== Early stopping ==========
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    On every call, the tracker does one of the following:
    - It saves the model's state_dict to ``path`` when the loss improves by more
      than ``delta``.
    - Otherwise it counts a "no improvement" epoch, and sets ``early_stop``
      after ``patience`` such epochs in a row.
    - If ``min_loss`` is given and the loss drops below it, it saves and sets
      ``early_stop`` immediately.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', min_loss=None):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.min_loss = min_loss  # User-defined absolute threshold
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        # Absolute floor reached: keep this model and stop right away
        if self.min_loss is not None and val_loss < self.min_loss:
            if self.verbose:
                print(f"Validation loss {val_loss:.6f} < threshold {self.min_loss:.6f} → Early stopping now.")
            self.save_checkpoint(val_loss, model)
            self.early_stop = True
            return

        score = -val_loss
        # First call, or an improvement larger than delta (written as "not worse",
        # so a NaN score is treated the same way as before)
        if self.best_score is None or not (score < self.best_score + self.delta):
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"No improvement. EarlyStopping counter: {self.counter}/{self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember ``val_loss``."""
        if self.verbose:
            print(f"Saving model ... (val_loss: {val_loss:.6f})")
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


# Training loop

In [None]:
# =================== Training loop ===========================

from torch.amp import autocast, GradScaler

device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
use_cuda = device_type == 'cuda'

# Loss scaling is only needed for float16 autocast on CUDA. On CPU, autocast runs in
# bfloat16 and the disabled scaler turns scale()/step()/update() into pass-throughs.
scaler = GradScaler(device_type, enabled=use_cuda)

# Set to True to stop on a validation plateau. The best weights are then kept at
# early_stopping.path and restored after training, instead of the last epoch's weights.
USE_EARLY_STOPPING = False
checkpoint_path = "/content/model/SAR_DeepSpeck.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
early_stopping = EarlyStopping(patience=5,
                               verbose=True,
                               delta=0.00001,
                               path=checkpoint_path,
                               min_loss=0.005)

# Clear GPU mem
gc.collect()
torch.cuda.empty_cache()

model = SAR_DeepSpeck().to(device)  # Move model to CPU or GPU

num_epochs = 10  # Set the number of epochs (14 is the magic number?)

# Fixed loss weights (a gradient-difference term "w2" was tried and dropped)
w1 = torch.tensor(0.9)   # Charbonnier
w3 = torch.tensor(0.05)  # Laplacian
w4 = torch.tensor(0.08)  # Total variation

optimizer = torch.optim.AdamW(model.parameters(), lr=20e-6, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.75, patience=3)

# Ensure float32 default tensors (non-deprecated form of set_default_tensor_type)
torch.set_default_dtype(torch.float32)


def _weighted_loss(output, clean):
    """Return (total, charbonnier, laplacian, tv) for one batch; total is a scalar."""
    charbonnier = charbonnier_loss(output, clean)
    laplacian = laplacian_loss(output)
    tv = tv_loss(output)
    total = w1 * charbonnier + w3 * laplacian + w4 * tv
    return total, charbonnier, laplacian, tv


train_losses = []
val_losses = []

for epoch in range(num_epochs):
    torch.cuda.empty_cache()  # Free up unused GPU memory (no-op on CPU)

    total_loss = 0.0
    progress_bar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

    for noisy, clean in progress_bar:
        noisy = noisy.to(device, dtype=torch.float32, non_blocking=True)
        clean = clean.to(device, dtype=torch.float32, non_blocking=True)

        optimizer.zero_grad()

        with autocast(device_type=device_type):
            output = model(noisy)
            loss, charbonnier, laplacian, tv = _weighted_loss(output, clean)

        # Backprop with (possibly) scaled loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        # memory_allocated() only accepts CUDA devices, so report 0 on CPU
        gpu_mem = torch.cuda.memory_allocated(device) / 1024**3 if use_cuda else 0.0
        lr = optimizer.param_groups[0]['lr']

        progress_bar.set_postfix({
            "charbonnier": f"{charbonnier.item():.6f}",
            "lap_loss": f"{laplacian.item():.6f}",
            "tv_loss": f"{tv.item():.6f}",
            "loss": f"{loss.item():.4f}",
            "gpu": f"{gpu_mem:.2f} GB",
            "lr": f"{lr:.2e}"
        })

    # Validation (same weighted loss, full precision, running BatchNorm statistics)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for noisy, clean in val_loader:
            noisy, clean = noisy.to(device), clean.to(device)
            output = model(noisy)
            loss, _, _, _ = _weighted_loss(output, clean)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    print(f"Validation Loss: {val_loss:.6f}")
    train_losses.append(total_loss / len(train_loader))
    val_losses.append(val_loss)

    # ReduceLROnPlateau watches the validation loss
    scheduler.step(val_loss)

    model.train()  # Switch back to training mode

    if USE_EARLY_STOPPING:
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Ran out of patience, early stopping!!")
            break


# ========== 5. Save Model ==========
if USE_EARLY_STOPPING:
    # The best weights are already on disk; restore them instead of overwriting
    # the checkpoint with the (possibly worse) last-epoch weights.
    model.load_state_dict(torch.load(early_stopping.path, map_location=device, weights_only=True))
    print(f"Best model kept at {early_stopping.path}")
else:
    torch.save(model.state_dict(), checkpoint_path)
    print("Model saved!")

Took about 16 GB of VRAM for 800 pairs with batch size 8 and num_workers = 4 (about 25 seconds per epoch). \

Took about 32 GB of VRAM for 800 pairs with batch size 16 and num_workers = 8 (also about 25 seconds per epoch).

In [None]:
# Try to plot the training loss
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label="Training Loss", marker='o')
ax.plot(val_losses, label="Validation Loss", marker='s')
# Log y-axis: the losses span several orders of magnitude over training
ax.set(title="Training vs Validation Loss",
       xlabel="Epoch",
       ylabel="Loss (log scale)",
       yscale="log")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# ========== 5. Save Model (but manually) ==========
# os.makedirs("/content/model", exist_ok=True)  # Create folder if it doesn't exist
# torch.save(model.state_dict(), "/content/model/SAR_DeepSpeck.pth")
# print("Model saved!")

In [None]:
# Check what is eating up my gpu
# !nvidia-smi


# Load previous model (if any)

In [None]:
##########################
# To load other models downloaded before
#########################

# Load the trained SAR-DURNet model
# Rebuild the SAR_DeepSpeck architecture and load previously saved weights into it
model = SAR_DeepSpeck().to(device)

checkpoint_path = "/content/model/SAR_DeepSpeck.pth"
# The file holds a plain state_dict, so weights_only=True is sufficient. It also
# refuses to unpickle arbitrary objects from a downloaded file.
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

# Inference mode: BatchNorm uses running statistics
model.eval()

print("Model successfully loaded!")

# Test on validation dataset

In [None]:
# ============= To test on validation images ==============================

import glob
import cv2
import torch
import os
import matplotlib.pyplot as plt
import numpy as np

# Define paths
# Validation data to run the trained model on
noisy_folder = "/content/test_data/SAR_Dataset/SAR despeckling filters dataset/Main folder/Noisy_val"
clean_folder = "/content/test_data/SAR_Dataset/SAR despeckling filters dataset/Main folder/GTruth_val"
output_folder = "/content/despeckled_val_results"  # Despeckled outputs are written here
os.makedirs(output_folder, exist_ok=True)

# Inference only: BatchNorm uses its running statistics
model.eval()

# Noisy and clean files are paired by sorted filename order
noisy_images, clean_images = (
    sorted(glob.glob(folder + "/*.tiff")) for folder in (noisy_folder, clean_folder)
)
assert len(noisy_images) == len(clean_images), "Mismatch in number of noisy and clean images!"

# (noisy, despeckled, clean) triples collected for plotting later
results = []

# Resize to multiple of 32 for U-Net compatibility
def resize_to_multiple_of_32(image, multiple=32):
    """Resize a 2-D image so both sides are multiples of ``multiple`` (default 32).

    Each side is rounded down to the nearest multiple, but never below one
    multiple. Images smaller than ``multiple`` are therefore upscaled instead of
    being resized to 0 pixels, which cv2.resize rejects. Images whose sides are
    already multiples are returned unchanged (cv2.resize would only copy them).
    """
    h, w = image.shape
    new_h = max(multiple, (h // multiple) * multiple)
    new_w = max(multiple, (w // multiple) * multiple)
    if (new_h, new_w) == (h, w):
        return image
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)

# Process each image
for noisy_path, clean_path in zip(noisy_images, clean_images):
    # Grayscale, scaled to [0, 1]
    noisy, clean = (cv2.imread(p, cv2.IMREAD_GRAYSCALE) / 255.0 for p in (noisy_path, clean_path))

    # Resize each side to a multiple of 32 for U-Net compatibility
    noisy_resized, clean_resized = map(resize_to_multiple_of_32, (noisy, clean))

    # Add batch and channel dimensions: (1, 1, H, W)
    noisy_tensor = torch.tensor(noisy_resized, dtype=torch.float32)[None, None].to(device)

    with torch.no_grad():
        prediction = model(noisy_tensor)
    despeckled = np.clip(prediction.squeeze().cpu().numpy(), 0, 1)

    # Save as 8-bit grayscale under the noisy file's name
    output_path = os.path.join(output_folder, os.path.basename(noisy_path))
    cv2.imwrite(output_path, (despeckled * 255).astype(np.uint8))

    # Keep everything needed for the comparison plot
    results.append((noisy_resized, despeckled, clean_resized))

print("ALL VALIDATION IMAGES PROCESSED!")




In [None]:
# Show 15 random samples at the end
num_samples = min(15, len(results))
sample_indices = np.random.choice(len(results), num_samples, replace=False)

plt.figure(figsize=(5, 30))
# One row per sample: noisy input | model output | ground truth
for row, idx in enumerate(sample_indices):
    noisy, despeckled, clean = results[idx]
    panels = ((noisy, "Noisy Input"), (despeckled, "Despeckled Output"), (clean, "Ground Truth"))
    for col, (image, title) in enumerate(panels, start=1):
        plt.subplot(num_samples, 3, row * 3 + col)
        plt.imshow(image, cmap="gray")
        plt.title(title)
        plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# zip output folder to download ezpzggwp

import shutil
val_results_dir = "/content/despeckled_val_results"
shutil.make_archive(val_results_dir, 'zip', val_results_dir)  # -> /content/despeckled_val_results.zip

# To use model on real SAR images

To use model on real SAR images

In [None]:
# =============== To test on real SAR images ================================

import os
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

def extract_overlapping_patches_with_padding(image, patch_size=256, stride=128):
    """Split a 2-D image into overlapping square patches, padding as needed.

    The image is reflect-padded on the bottom/right so that each padded side equals
    ``patch_size + k * stride`` for some k >= 0. The patches then tile the padded
    image exactly at the given stride. Images smaller than one patch are padded up
    to a single full patch. (The previous formula could pad them to less than one
    patch, yielding no patches and an all-zero merge.)

    Returns:
        patches: list of (patch_size, patch_size) views into the padded image.
        positions: list of (y, x) top-left corners, aligned with ``patches``.
        padded_shape: shape of the padded image.
        original_shape: (h, w) of the input, for cropping after merging.
    """
    h, w = image.shape

    def _pad_amount(size):
        # Smallest pad with size + pad == patch_size + k * stride, k >= 0
        steps = -(-max(size - patch_size, 0) // stride)  # integer ceil division
        return steps * stride + patch_size - size

    pad_h = _pad_amount(h)
    pad_w = _pad_amount(w)

    padded_image = np.pad(image, ((0, pad_h), (0, pad_w)), mode='reflect')
    padded_h, padded_w = padded_image.shape

    patches = []
    positions = []

    for y in range(0, padded_h - patch_size + 1, stride):
        for x in range(0, padded_w - patch_size + 1, stride):
            patches.append(padded_image[y:y + patch_size, x:x + patch_size])
            positions.append((y, x))

    return patches, positions, padded_image.shape, (h, w)


def merge_overlapping_patches(patches, positions, padded_shape, original_shape, patch_size=256):
    """Average overlapping patches back into one image (uniform weights).

    Each pixel of the padded canvas is the mean of all patches covering it. The
    result is cropped to ``original_shape``. A pixel covered by no patch would
    divide by zero.
    """
    h_pad, w_pad = padded_shape
    rows, cols = original_shape

    accum = np.zeros((h_pad, w_pad), dtype=np.float32)
    coverage = np.zeros((h_pad, w_pad), dtype=np.float32)

    for tile, (top, left) in zip(patches, positions):
        region = np.s_[top:top + patch_size, left:left + patch_size]
        accum[region] += tile
        coverage[region] += 1.0

    accum /= coverage
    return accum[:rows, :cols]

def create_gaussian_weight(patch_size=256, sigma=0.125):
    """Separable 2-D Gaussian blending mask, peak-normalized to a maximum of 1.

    Coordinates run from -1 to 1 across the patch, so ``sigma`` is relative to
    the half-width of the patch.
    """
    coords = np.linspace(-1, 1, patch_size)
    profile = np.exp(-0.5 * np.square(coords / sigma))
    mask = profile[:, None] * profile[None, :]
    return mask / mask.max()


def merge_patches_soft(patches, positions, padded_shape, original_shape, patch_size=256):
    """Blend overlapping patches with the create_trimmed_weight mask.

    Every patch is multiplied by the mask before it is accumulated. The canvas is
    then divided by the accumulated mask weight (plus 1e-8 to avoid division by
    zero) and cropped to ``original_shape``.
    """
    h_pad, w_pad = padded_shape
    rows, cols = original_shape

    numer = np.zeros((h_pad, w_pad), dtype=np.float32)
    denom = np.zeros((h_pad, w_pad), dtype=np.float32)
    blend = create_trimmed_weight(patch_size)  # Soft blending mask

    for tile, (top, left) in zip(patches, positions):
        region = np.s_[top:top + patch_size, left:left + patch_size]
        numer[region] += tile * blend
        denom[region] += blend

    merged = numer / (denom + 1e-8)
    return merged[:rows, :cols]

def create_trimmed_weight(patch_size=256, inner_ratio=0.7, sigma=0.3):
    """Blending mask: flat 1.0 in the central ``inner_ratio`` square, Gaussian falloff outside.

    The mask is the element-wise maximum of a centered square of ones and a
    separable Gaussian (coordinates in [-1, 1]), divided by its maximum.
    """
    coords = np.linspace(-1, 1, patch_size)
    profile = np.exp(-0.5 * (coords / sigma) ** 2)
    falloff = profile[:, None] * profile[None, :]

    margin = int(patch_size * ((1 - inner_ratio) / 2))
    plateau = np.zeros_like(falloff)
    plateau[margin:patch_size - margin, margin:patch_size - margin] = 1.0

    combined = np.maximum(plateau, falloff)
    return combined / combined.max()


# Process a single image through the model using patching
def process_image_with_overlap_padded(image_path, model, device, patch_size=256, stride=128):
    """Despeckle one image file patch-by-patch using overlapping tiles.

    The steps are:
    1. Read the image as 8-bit grayscale and scale it to [0, 1] (float32).
    2. Cut it into padded, overlapping patches.
    3. Run each patch through ``model`` in eval mode without gradients, clipping
       the output to [0, 1].
    4. Blend the patches back together with merge_patches_soft.

    Returns:
        (image, despeckled): the scaled input and the blended output, both with
        the original height and width.

    Raises:
        FileNotFoundError: if the file cannot be read as an image.
    """
    raw = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = (raw / 255.0).astype(np.float32)

    patches, positions, padded_shape, original_shape = extract_overlapping_patches_with_padding(
        image, patch_size, stride)

    output_patches = []

    model.eval()
    with torch.no_grad():
        for patch in patches:
            tensor = torch.tensor(patch).unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, P, P)
            out = model(tensor).squeeze().cpu().numpy()
            output_patches.append(np.clip(out, 0, 1))

    despeckled = merge_patches_soft(output_patches, positions, padded_shape, original_shape, patch_size)
    return image, despeckled


# ========== Run it on all images in your folder ==========

input_folder = "/content/imported_images/"
output_folder = "/content/despeckled_results_patched/"
os.makedirs(output_folder, exist_ok=True)

# Every supported image in the input folder, in sorted order
image_paths = sorted(
    os.path.join(input_folder, fname)
    for fname in os.listdir(input_folder)
    if fname.lower().endswith((".png", ".jpg", ".tif", ".tiff"))
)

for img_path in image_paths:
    print(f"Processing {img_path}")
    original, despeckled = process_image_with_overlap_padded(img_path, model, device)

    # Save as 8-bit grayscale under the same file name
    output_path = os.path.join(output_folder, os.path.basename(img_path))
    cv2.imwrite(output_path, (despeckled * 255).astype(np.uint8))
    print(f"Saved to {output_path}")

    # Optional: side-by-side comparison
    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 5))
    for ax, img, title in ((ax_in, original, "Original SAR"), (ax_out, despeckled, "Despeckled")):
        ax.imshow(img, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.show()


In [None]:
# zip output folder to download ezpzggwp

import shutil
patched_results_dir = "/content/despeckled_results_patched"
shutil.make_archive(patched_results_dir, 'zip', patched_results_dir)  # -> /content/despeckled_results_patched.zip
