Colab test: SAR despeckling with the SAR_DeepSpeck model

In [None]:
# empty for debug/admin

In [None]:
# This script trains a deep unrolling network for SAR despeckling.
# Steps:
# 0. Use a GPU runtime (a premium Colab GPU is strongly recommended)
# 1. Set up dataset (simulated speckle noise on optical images)
# 2. Define model
# 3. Train with Charbonnier + Total Variation loss
# 4. Test on real SAR images and on noisy validation images

In [None]:
import glob
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Subset, DataLoader
import os
import random
import torch.nn.functional as F
import gc
from tqdm import tqdm  # For progress bar
import torch.cuda.memory as cuda_mem

In [None]:
# Uncomment to free cached memory if RAM / GPU memory usage stays high (this does not always help):
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
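# This cell selects the compute device and creates the working folders. Once the lines
# below are uncommented, it also mounts Google Drive and unzips the validation dataset from it.
# When mounting, allow all requested permissions; a /content/drive folder should then
# appear in the file browser.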

# Step 1: Mount Google Drive
from google.colab import drive
import zipfile
import os

# Pick the compute device: a CUDA GPU when one is visible (recommended), otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Folder for your own real SAR test images (created if missing).
folder_name = "/content/imported_images"
os.makedirs(folder_name, exist_ok=True)

# Folder for a pretrained model checkpoint you want to load (created if missing).
os.makedirs(name="/content/model", exist_ok=True)

# drive.mount('/content/drive')  # Connect to your Google Drive


# Uncomment below if you want to use validation dataset
# # Define paths for zip files
# zip_path = "/content/drive/MyDrive/SAR_Project/Dataset/SAR_paired.zip"
# extract_path = "/content/test_data/SAR_Dataset"  # Where to extract


# # Step 3: Extract ZIP
# print("Extracting dataset... this may take a moment!")
# with zipfile.ZipFile(zip_path, 'r') as zip_ref:
#     zip_ref.extractall(extract_path)

# print("Extraction complete!")






# Load autoencoder and model

Autoencoder (ResUNet_256 noise estimator)

In [None]:
# This was initially sized for 512 (presumably 512x512 patches), but GPU RAM ran out too fast, so 256 is used.

class ResUNet_256(nn.Module):
    """Small U-Net used as the noise estimator inside SAR_DeepSpeck.

    Encoder: four double-conv blocks (64, 128, 256, 256 channels) with a 2x
    max-pool in front of each of the last three. Decoder: three stride-2
    transposed-conv blocks back up to 64 channels, then a double-conv block to
    32 channels and a 3x3 output conv. Skip connections are *additive*: each
    decoder output is summed with the (cropped) encoder map of the same
    resolution, which is why decoder channel counts mirror the encoder's.

    Args:
        channels: number of input channels. The output conv also produces
            ``channels`` maps. It used to be hard-wired to 1, which mismatched
            multi-channel inputs. The default of 1 gives the same layers and
            parameter shapes as before, so existing checkpoints still load.

    Shape:
        (N, channels, H, W) -> (N, channels, 8*(H//8), 8*(W//8)).
        Sizes that are multiples of 8 are preserved; other sizes come out
        smaller, because the encoder pools three times and ``crop`` only trims.
        The output is passed through ReLU, so it is non-negative.
    """

    def __init__(self, channels=1):
        super().__init__()
        # Encoder Path with BatchNorm
        self.enc1 = self.conv_block(channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 256)

        # Decoder path; each stage doubles the spatial size (k=3, s=2, p=1, output_padding=1)
        self.dec4 = self.upconv_block(256, 256)
        self.dec3 = self.upconv_block(256, 128)
        self.dec2 = self.upconv_block(128, 64)
        self.dec1 = self.conv_block(64, 32)

        # Final output layer: one map per input channel
        self.final_conv = nn.Conv2d(32, channels, kernel_size=3, padding=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 conv -> BatchNorm -> LeakyReLU(0.1) stages (spatial size preserved)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1)
        )

    def upconv_block(self, in_channels, out_channels):
        """Stride-2 transposed conv (exactly doubles H and W) -> BatchNorm -> LeakyReLU(0.1)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.1)
        )

    def crop(self, enc_feat, dec_feat):
        """Crop encoder features (top-left) to the decoder's spatial size.

        Only trims; it never pads. The decoder map is never larger than the
        matching encoder map here, because pooling floors odd sizes.
        """
        _, _, h, w = dec_feat.size()
        return enc_feat[:, :, :h, :w]

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool2d(e1, 2))
        e3 = self.enc3(F.max_pool2d(e2, 2))
        e4 = self.enc4(F.max_pool2d(e3, 2))

        # Decoder with additive skip connections
        d4_temp = self.dec4(e4)
        d4 = d4_temp + self.crop(e3, d4_temp)

        d3_temp = self.dec3(d4)
        d3 = d3_temp + self.crop(e2, d3_temp)

        d2_temp = self.dec2(d3)
        d2 = d2_temp + self.crop(e1, d2_temp)

        d1 = self.dec1(d2)

        # Output head; ReLU keeps the estimate non-negative
        out = self.final_conv(d1)
        return F.relu(out)


In [None]:
# Sanity check: ResUNet_256 must return a map with the same shape as its input
# (this was broken a few times before).
test_input = torch.randn(1, 1, 256, 256)  # Simulated SAR patch
model = ResUNet_256()
model.eval()           # shape check only: use BatchNorm running stats, no training behaviour
with torch.no_grad():  # forward pass only, so don't build an autograd graph
    output = model(test_input)

print("ResUNet Output Shape:", output.shape)  # Should match input, i.e. (1, 1, 256, 256)
assert output.shape == test_input.shape, "ResUNet output shape does not match input shape"



Model: SAR_DeepSpeck (unrolled despeckling network)

In [None]:
# ========== SAR DeepSpeck model ==========

class SAR_DeepSpeck(nn.Module):
    """Unrolled despeckling network built around one shared ResUNet_256.

    Each of ``num_layers`` unrolled iterations:
      1. runs the shared ResUNet on the current estimate (the raw input on the
         first iteration, the ReLU-clamped running estimate afterwards),
      2. passes that map through the iteration's own small conv "gradient step",
      3. subtracts ``delta * eta`` times the result from the running estimate,
      4. clamps the running estimate at zero for the next iteration.
    The output is the final running estimate minus one more ResUNet pass
    applied to its clamped version.

    ``delta`` (init 0.01) and ``eta`` (init 1.0) are learnable scalars shared
    by all iterations.
    """

    def __init__(self, num_layers=8):
        super().__init__()
        self.num_layers = num_layers
        # A single noise estimator whose weights are reused by every iteration.
        self.resunet = ResUNet_256()

        def make_step():
            # Per-iteration conv head: 1 -> 8 -> 1 channels, spatial size preserved.
            return nn.Sequential(
                nn.Conv2d(1, 8, kernel_size=3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(),
                nn.Conv2d(8, 1, kernel_size=3, padding=1),
            )

        self.gradient_steps = nn.ModuleList([make_step() for _ in range(num_layers)])

        # Learnable scalar step sizes
        self.delta = nn.Parameter(torch.tensor(0.01, dtype=torch.float32))
        self.eta = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))

    def forward(self, x):
        estimate = x
        clamped = x  # first iteration sees the unclamped input
        for layer_idx in range(self.num_layers):
            noise_map = self.resunet(clamped)
            step = self.gradient_steps[layer_idx](noise_map)
            estimate = estimate - self.delta * (self.eta * step)
            # Clamp values so the next iteration never sees negative intensities
            clamped = torch.relu(estimate)

        return estimate - self.resunet(clamped)
        # return x + v

# Load previous model (if any)

In [None]:
##########################
# To load other models created before
#########################

# Put model in this directory: /content/model/

# Build the SAR_DeepSpeck architecture on the selected device. The checkpoint must be
# a state_dict saved from this same architecture, so its keys and shapes match.
model = SAR_DeepSpeck().to(device)

# Load the saved model weights
checkpoint_path = "/content/model/SAR_DeepSpeck.pth"
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(
        f"No checkpoint found at {checkpoint_path}. "
        "Upload your trained SAR_DeepSpeck.pth to /content/model/ first."
    )
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered or untrusted .pth file cannot execute arbitrary code when loaded.
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

# Set the model to evaluation mode (BatchNorm uses its running statistics)
model.eval()

print("Model successfully loaded!")

## Test on validation dataset
Run this section only if you extracted the validation dataset; otherwise skip to the real SAR images section.

In [None]:
# ============= To test on validation images ==============================

import glob
import cv2
import torch
import os
import matplotlib.pyplot as plt
import numpy as np

# Define paths
noisy_folder = "/content/test_data/SAR_Dataset/SAR despeckling filters dataset/Main folder/Noisy_val"
clean_folder = "/content/test_data/SAR_Dataset/SAR despeckling filters dataset/Main folder/GTruth_val"
output_folder = "/content/despeckled_val_results"  # Where despeckled images will be saved

# Make sure output folder exists
os.makedirs(output_folder, exist_ok=True)

# Load model in eval mode
model.eval()

# Load noisy & clean image pairs; sorting pairs them by filename order
noisy_images = sorted(glob.glob(os.path.join(noisy_folder, "*.tiff")))
clean_images = sorted(glob.glob(os.path.join(clean_folder, "*.tiff")))

# If the dataset was never extracted, both lists are empty. The length check
# below would then pass as 0 == 0, and every later cell would silently do nothing.
if not noisy_images:
    raise FileNotFoundError(
        f"No .tiff images found in {noisy_folder}; extract the validation dataset first."
    )
assert len(noisy_images) == len(clean_images), "Mismatch in number of noisy and clean images!"

# Store outputs for later visualization
results = []

# Resize to multiple of 32 for U-Net compatibility
def resize_to_multiple_of_32(image, multiple=32):
    """Resize a 2-D image so both sides are multiples of ``multiple``.

    Each side is rounded *down* to the nearest multiple, so the image is
    resampled with cv2.INTER_AREA rather than padded or cropped. Sides are
    never rounded below ``multiple`` itself. Previously, a side shorter than 32
    px became 0, and ``cv2.resize`` raised on the empty target size.

    Args:
        image: 2-D array (H, W).
        multiple: required divisor for both output sides (default 32).

    Returns:
        The resized 2-D array.
    """
    h, w = image.shape
    new_h = max(multiple, (h // multiple) * multiple)
    new_w = max(multiple, (w // multiple) * multiple)
    # cv2 expects dsize as (width, height)
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)

# Process each (noisy, clean) pair: load, resize, despeckle, save, and keep for plotting.
for noisy_path, clean_path in zip(noisy_images, clean_images):
    # Load images. cv2.imread returns None instead of raising on unreadable files,
    # which would otherwise surface as a confusing TypeError on the "/ 255.0".
    noisy_raw = cv2.imread(noisy_path, cv2.IMREAD_GRAYSCALE)
    clean_raw = cv2.imread(clean_path, cv2.IMREAD_GRAYSCALE)
    if noisy_raw is None or clean_raw is None:
        raise OSError(f"Could not read image pair: {noisy_path} / {clean_path}")
    noisy = noisy_raw / 255.0
    clean = clean_raw / 255.0

    # Resize so both sides are multiples of 32 (the network pools three times)
    noisy_resized = resize_to_multiple_of_32(noisy)
    clean_resized = resize_to_multiple_of_32(clean)

    # Convert to a (1, 1, H, W) float32 tensor on `device`
    noisy_tensor = torch.tensor(noisy_resized, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    # Run through model
    with torch.no_grad():
        despeckled = model(noisy_tensor).squeeze().cpu().numpy()
    despeckled = np.clip(despeckled, 0, 1)

    # Save as 8-bit. Round rather than truncate, so values are not biased down by ~0.5 gray level.
    output_path = os.path.join(output_folder, os.path.basename(noisy_path))
    cv2.imwrite(output_path, np.round(despeckled * 255).astype(np.uint8))

    # Store for visualization
    results.append((noisy_resized, despeckled, clean_resized))

print("ALL VALIDATION IMAGES PROCESSED!")




In [None]:
# Show up to 15 random (noisy, despeckled, ground-truth) triplets, one per row.
num_samples = min(15, len(results))
sample_indices = np.random.choice(len(results), num_samples, replace=False)

if num_samples == 0:
    print("No validation results to show - run the validation cell first.")
else:
    # ~2 inches per row reproduces the original 5x30 figure for 15 samples,
    # without stretching the rows when fewer samples are available.
    fig, axes = plt.subplots(num_samples, 3, figsize=(5, 2 * num_samples), squeeze=False)
    titles = ("Noisy Input", "Despeckled Output", "Ground Truth")
    for row_axes, idx in zip(axes, sample_indices):
        for ax, img, title in zip(row_axes, results[idx], titles):
            ax.imshow(img, cmap="gray")
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Pack the validation output folder into a single .zip so it can be downloaded in one go.
import shutil

shutil.make_archive(
    base_name="/content/despeckled_val_results",
    format='zip',
    root_dir="/content/despeckled_val_results",
)

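# Use the model on real SAR images

Upload your own SAR images (.png, .jpg, .tif or .tiff) to `/content/imported_images/`. Each image is despeckled in overlapping 256×256 patches (stride 128), and the result is saved to `/content/despeckled_results_patched/`.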

In [None]:
# =============== To test on real SAR images ================================

import os
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

def extract_overlapping_patches_with_padding(image, patch_size=256, stride=128):
    """Split a 2-D image into overlapping square patches on a regular grid.

    The image is reflect-padded on the bottom/right so that a grid of
    ``patch_size`` windows spaced ``stride`` apart covers it exactly. Each
    padded side is ``patch_size + k * stride`` for the smallest ``k >= 0``
    that reaches the original side.

    Args:
        image: 2-D array (H, W).
        patch_size: side length of each square patch.
        stride: step between neighbouring patch origins (stride < patch_size gives overlap).

    Returns:
        patches: list of (patch_size, patch_size) views into the padded image, row-major order.
        positions: list of (y, x) top-left corners matching ``patches``.
        padded_shape: shape of the padded image.
        original_shape: (H, W) of the input, for cropping after merging.
    """
    h, w = image.shape

    def _padded_extent(size):
        # Smallest patch_size + k*stride (k >= 0) that is >= size. The previous
        # float formula allowed k < 0, so sides <= patch_size - stride were
        # under-padded and the image produced no patches at all.
        steps = max(0, -(-(size - patch_size) // stride))  # integer ceil division
        return steps * stride + patch_size

    pad_h = _padded_extent(h) - h
    pad_w = _padded_extent(w) - w

    padded_image = np.pad(image, ((0, pad_h), (0, pad_w)), mode='reflect')
    padded_h, padded_w = padded_image.shape

    positions = [(y, x)
                 for y in range(0, padded_h - patch_size + 1, stride)
                 for x in range(0, padded_w - patch_size + 1, stride)]
    patches = [padded_image[y:y + patch_size, x:x + patch_size] for y, x in positions]

    return patches, positions, padded_image.shape, (h, w)


def merge_overlapping_patches(patches, positions, padded_shape, original_shape, patch_size=256):
    """Average overlapping patches back into one image, then crop off the padding.

    Every pixel becomes the plain mean of all patches covering it (uniform
    weights). A pixel covered by no patch divides 0 by 0.
    """
    rows, cols = padded_shape
    out_rows, out_cols = original_shape

    accum = np.zeros((rows, cols), dtype=np.float32)
    counts = np.zeros_like(accum)

    for (top, left), tile in zip(positions, patches):
        window = (slice(top, top + patch_size), slice(left, left + patch_size))
        accum[window] += tile
        counts[window] += 1.0

    accum /= counts
    return accum[:out_rows, :out_cols]

def create_gaussian_weight(patch_size=256, sigma=0.125):
    """Separable 2-D Gaussian blending mask on a [-1, 1] x [-1, 1] grid, peak-normalised to 1."""
    coords = np.linspace(-1, 1, patch_size)
    profile = np.exp(-0.5 * (coords / sigma) ** 2)
    mask = np.outer(profile, profile)
    peak = mask.max()
    return mask / peak


def merge_patches_soft(patches, positions, padded_shape, original_shape, patch_size=256):
    """Blend overlapping patches with a soft per-pixel weight mask, then crop to the original size.

    Each patch is multiplied by ``create_trimmed_weight(patch_size)`` and
    accumulated into a padded canvas. The canvas is then divided by the summed
    weights plus 1e-8, so pixels no patch covers come out as 0 rather than
    dividing by zero.
    """
    pad_rows, pad_cols = padded_shape
    keep_rows, keep_cols = original_shape

    mask = create_trimmed_weight(patch_size)  # Soft blending mask
    canvas = np.zeros((pad_rows, pad_cols), dtype=np.float32)
    norm = np.zeros((pad_rows, pad_cols), dtype=np.float32)

    for (top, left), tile in zip(positions, patches):
        region = np.s_[top:top + patch_size, left:left + patch_size]
        canvas[region] += tile * mask
        norm[region] += mask

    blended = canvas / (norm + 1e-8)
    return blended[:keep_rows, :keep_cols]

def create_trimmed_weight(patch_size=256, inner_ratio=0.7, sigma=0.3):
    """Blending mask: a flat central square of 1.0 with a Gaussian fall-off around it.

    The element-wise maximum is taken between (a) a square of ones spanning
    roughly ``inner_ratio`` of the patch (its margin is truncated to an int) and
    (b) the separable Gaussian ``exp(-0.5 (x/sigma)^2) * exp(-0.5 (y/sigma)^2)``
    on a [-1, 1] grid. The result is divided by its own maximum.
    """
    grid = np.linspace(-1, 1, patch_size)
    bell = np.exp(-0.5 * (grid / sigma) ** 2)
    gaussian_2d = np.outer(bell, bell)

    margin = int(patch_size * ((1 - inner_ratio) / 2))
    stop = patch_size - margin
    plateau = np.zeros_like(gaussian_2d)
    plateau[margin:stop, margin:stop] = 1.0

    combined = np.maximum(plateau, gaussian_2d)
    return combined / combined.max()


# Process a single image through the model using patching
def process_image_with_overlap_padded(image_path, model, device, patch_size=256, stride=128):
    """Despeckle one grayscale image file by running the model on overlapping patches.

    The image is scaled to [0, 1] and tiled with
    ``extract_overlapping_patches_with_padding``. Each patch goes through
    ``model`` (switched to eval mode) and is clipped to [0, 1]. The outputs are
    blended back together with ``merge_patches_soft``.

    Args:
        image_path: path of the image to read (loaded as 8-bit grayscale).
        model: network mapping a (1, 1, patch_size, patch_size) tensor to the same shape.
        device: torch device the patches are moved to.
        patch_size: square patch side length.
        stride: step between patch origins.

    Returns:
        (image, despeckled): the float32 input in [0, 1] and the blended output,
        both at the image's original height and width.

    Raises:
        OSError: if OpenCV cannot read ``image_path``.
    """
    raw = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Could not read image: {image_path}")
    image = (raw / 255.0).astype(np.float32)

    patches, positions, padded_shape, original_shape = extract_overlapping_patches_with_padding(
        image, patch_size, stride)

    output_patches = []

    model.eval()
    with torch.no_grad():
        for patch in patches:
            tensor = torch.tensor(patch).unsqueeze(0).unsqueeze(0).to(device)
            out = model(tensor).squeeze().cpu().numpy()
            output_patches.append(np.clip(out, 0, 1))

    despeckled = merge_patches_soft(output_patches, positions, padded_shape, original_shape, patch_size)
    return image, despeckled


# ========== Run it on all images in your folder ==========

input_folder = "/content/imported_images/"
output_folder = "/content/despeckled_results_patched/"
os.makedirs(output_folder, exist_ok=True)

image_paths = sorted([os.path.join(input_folder, f) for f in os.listdir(input_folder) if f.lower().endswith((".png", ".jpg", ".tif", ".tiff"))])
if not image_paths:
    print(f"No .png/.jpg/.tif/.tiff images found in {input_folder} - upload some real SAR images first.")

for img_path in image_paths:
    print(f"Processing {img_path}")
    original, despeckled = process_image_with_overlap_padded(img_path, model, device)

    # Save the result as 8-bit: round (not truncate) and clip so nothing wraps around in uint8
    out_name = os.path.basename(img_path)
    output_path = os.path.join(output_folder, out_name)
    cv2.imwrite(output_path, np.clip(np.round(despeckled * 255), 0, 255).astype(np.uint8))
    print(f"Saved to {output_path}")

    # Optional: Show comparison
    fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 5))
    ax_in.imshow(original, cmap="gray")
    ax_in.set_title("Original SAR")
    ax_in.axis("off")

    ax_out.imshow(despeckled, cmap="gray")
    ax_out.set_title("Despeckled")
    ax_out.axis("off")
    plt.show()
    plt.close(fig)  # don't accumulate open figures when processing many images


In [None]:
# Pack the real-SAR output folder into a single .zip so it can be downloaded in one go.
import shutil

shutil.make_archive(
    base_name="/content/despeckled_results_patched",
    format='zip',
    root_dir="/content/despeckled_results_patched",
)
