# Super-Resolution

## Data Prep

In [None]:
import os
import torch
import rasterio
from torch.utils.data import Dataset, DataLoader
from skimage.transform import resize
import matplotlib.pyplot as plt
from transformers import Swin2SRImageProcessor, Swin2SRForImageSuperResolution
from torch.optim import Adam
from tqdm import tqdm
from PIL import Image
import numpy as np

# Nominal high-resolution edge length in pixels.
# NOTE(review): nothing in the visible cells reads this module-level value;
# SRDataset is given its own hr_size tuple instead.
hr_size = 128 

class SRDataset(Dataset):
    """Paired low-/high-resolution samples derived from single image files.

    Every item is built by loading one image as RGB and resizing that same
    image twice: once to ``lr_size`` and once to ``hr_size``.

    Args:
        data_paths: Sequence of image file paths (strings).
        lr_size: Output shape handed to ``skimage.transform.resize`` for the
            low-resolution sample.
        hr_size: Output shape handed to ``skimage.transform.resize`` for the
            high-resolution target.
    """

    def __init__(self, data_paths, lr_size=(64, 64), hr_size=(128, 128)):
        self.data_paths = data_paths
        self.lr_size = lr_size
        self.hr_size = hr_size

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.data_paths)

    @staticmethod
    def _image_stem(image_path):
        """Return the file name of ``image_path`` up to its first dot.

        Both the backslash and the forward slash are treated as directory
        separators, so the result is the same for Windows- and POSIX-style
        paths. Splitting on the backslash alone would return the whole
        directory prefix for a POSIX path.
        """
        file_name = image_path.replace("\\", "/").rsplit("/", 1)[-1]
        return file_name.split(".")[0]

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(lr_tensor, hr_tensor, image_id)``.

        Returns:
            lr_tensor: float32 tensor in channel-first layout, scaled by 1/255.
            hr_tensor: float32 tensor in channel-first layout, scaled by 1/255.
            image_id: File name of the source image up to its first dot.
        """
        image_path = self.data_paths[idx]

        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the NumPy array.
        with Image.open(image_path) as handle:
            img = np.array(handle.convert('RGB'))  # force 3-channel RGB

        # preserve_range keeps the 0-255 scale so the division below maps to [0, 1].
        lr_img = resize(img, self.lr_size, anti_aliasing=True, preserve_range=True)
        hr_img = resize(img, self.hr_size, anti_aliasing=True, preserve_range=True)

        lr_img = lr_img / 255.0
        hr_img = hr_img / 255.0

        # Reorder (H, W, C) -> (C, H, W), the layout the model is fed with.
        lr_tensor = torch.tensor(lr_img, dtype=torch.float32).permute(2, 0, 1)
        hr_tensor = torch.tensor(hr_img, dtype=torch.float32).permute(2, 0, 1)

        return lr_tensor, hr_tensor, self._image_stem(image_path)


def filter_120x120_images(datasetPath):
    """Collect the paths of all 120x120 rasters found below ``datasetPath``.

    The directory tree is walked recursively and every file is opened with
    rasterio. Files whose width and height are both 120 are kept; any other
    size is reported on stdout and skipped.

    Args:
        datasetPath: Root directory to scan.

    Returns:
        List of file paths, in the order ``os.walk`` visits them.
    """
    kept = []
    for folder, _subdirs, names in os.walk(datasetPath):
        for name in names:
            candidate = os.path.join(folder, name)
            with rasterio.open(candidate) as raster:
                if (raster.width, raster.height) != (120, 120):
                    print(f"Skipping image {name} with size {raster.width}x{raster.height}")
                    continue
                kept.append(candidate)

    return kept

# Root folder that is scanned recursively for candidate images.
datasetPath = r"SampleDataset"
# Keep only the files that rasterio reports as 120x120.
imagePaths = filter_120x120_images(datasetPath)

# Each kept image is resampled into a 64x64 input and a 128x128 target.
dataset = SRDataset(imagePaths, lr_size=(64, 64), hr_size=(128, 128)) 
# Batches of two samples, reshuffled on every pass over the loader.
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Number of images that passed the size filter.
print(len(dataset))


## Training

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import filters
from transformers import Swin2SRImageProcessor, Swin2SRForImageSuperResolution
from tqdm import tqdm

# Load the image processor and the model from the
# "caidas/swin2SR-classical-sr-x2-64" checkpoint.
# NOTE(review): `processor` is not used by any of the visible cells; tensors
# are handed to the model directly as `pixel_values`.
processor = Swin2SRImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Define Gaussian blur with a 3x3 kernel.
# NOTE(review): `T` (torchvision.transforms) is not imported in this cell; it
# is only imported by later cells, so one of those has to run first.
# NOTE(review): per the torchvision docs a (min, max) sigma tuple is sampled
# uniformly at random on every call, so this blur is not deterministic —
# confirm that is intended at inference time.
blur = T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))

# Function to detect keypoints and descriptors using SIFT
def detect_edges(image_tensor):
    """Render the SIFT keypoints of an image and return them as a tensor.

    NOTE(review): a second function named ``detect_edges`` is defined further
    down in this cell, so this version is replaced once the cell has run.
    ``cv2`` is also not imported in this cell.

    Args:
        image_tensor: Image tensor that is permuted from channel-first to
            channel-last before processing. Values are multiplied by 255, so
            a [0, 1] range is presumably expected — TODO confirm.

    Returns:
        float32 tensor holding the image produced by ``cv2.drawKeypoints``
        with one leading dimension added.
    """
    channels_last = image_tensor.cpu().permute(1, 2, 0).numpy()
    # Channel average as grayscale, scaled to 8-bit for SIFT.
    gray_u8 = (channels_last.mean(axis=2) * 255).astype('uint8')

    detector = cv2.SIFT_create()
    # Descriptors are computed as well but are not used here.
    found_keypoints, _descriptors = detector.detectAndCompute(gray_u8, None)

    # Draw the keypoints onto a fresh output image for visualisation.
    rendered = cv2.drawKeypoints(gray_u8, found_keypoints, None)

    return torch.tensor(rendered, dtype=torch.float32).unsqueeze(0)


from skimage.feature import corner_harris

# Function to detect corners
def detect_edges(image_tensor):
    """Compute a Harris corner response map for an image tensor.

    Despite its name, the result comes from ``skimage.feature.corner_harris``
    (a corner measure), not from an edge detector.

    Args:
        image_tensor: Image tensor that is permuted from channel-first to
            channel-last before the channels are averaged.

    Returns:
        float32 tensor of the corner response with a leading dimension
        added, i.e. shape [1, H, W].
    """
    channels_last = image_tensor.cpu().permute(1, 2, 0).numpy()
    # Channel average serves as the grayscale image.
    grayscale = channels_last.mean(axis=2)
    response = corner_harris(grayscale)
    return torch.tensor(response, dtype=torch.float32).unsqueeze(0)


# Super-resolution guided by edge detection
def feature_guided_super_resolution(lr_tensor):
    """Super-resolve the first image of a batch after blending in a feature map.

    Uses the module-level ``model``, ``device`` and ``blur`` objects together
    with ``detect_edges`` and ``combine_image_with_features``.

    Args:
        lr_tensor: Batch of low-resolution images; only ``lr_tensor[0]`` is
            processed.

    Returns:
        Tuple ``(sr_image, edge_tensor)``: the blurred model reconstruction
        and the feature map that was blended into the input.
    """
    first_image = lr_tensor[0]

    # Feature map of the low-resolution image, moved next to the model.
    edge_tensor = detect_edges(first_image).to(device)

    # Blend the feature map into the image before it is fed to the model.
    blended = combine_image_with_features(first_image, edge_tensor)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        batch = blended.unsqueeze(0).to(device)
        reconstruction = model(pixel_values=batch).reconstruction

    # Smooth the reconstruction to soften the highlighted features.
    return blur(reconstruction), edge_tensor

# Visualization function
def visualize_images_with_features(lr_image, final_sr_image, edge_tensor):
    """Show the input image, its feature map and the SR output in one row.

    Args:
        lr_image: Channel-first low-resolution image tensor.
        final_sr_image: Batched model output; only element 0 is shown.
        edge_tensor: Feature map tensor, displayed in grayscale.
    """
    panels = (
        (lr_image.permute(1, 2, 0).cpu().numpy(),
         "Low-Resolution Input (64x64)", {}),
        (edge_tensor.cpu().squeeze(),
         "Detected Edges", {"cmap": 'gray'}),
        (final_sr_image[0].permute(1, 2, 0).cpu().detach().numpy(),
         "Super-Resolved Output (128x128) [Original]", {}),
    )

    plt.figure(figsize=(12, 6))
    for position, (pixels, caption, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(pixels, **imshow_kwargs)
        plt.title(caption)

    plt.show()

# Apply 2x SR with edge detection
# Only the first image of the first batch is processed: the helper reads
# lr_tensor[0] and the loop is left after one iteration.
for batch_idx, (lr_tensor, hr_tensor, img_path) in enumerate(dataloader):
    lr_tensor = lr_tensor.to(device)
    
    # 64x64 to 128x128 super-resolution
    final_sr_image, edge_tensor = feature_guided_super_resolution(lr_tensor)

    # Visualize the 64x64 input and 128x128 output with edges
    visualize_images_with_features(lr_tensor[0], final_sr_image, edge_tensor)

    break  # Stop after one batch for demonstration

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as T
import cv2
import numpy as np

# Define Gaussian blur with a 3x3 kernel.
# NOTE(review): per the torchvision docs a (min, max) sigma tuple is sampled
# uniformly at random on every call, so this blur is not deterministic —
# confirm that is intended at inference time.
blur = T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))

def extract_sift_features(image_tensor):
    """Build a feature map from the SIFT keypoints of an image.

    Args:
        image_tensor: Channel-first RGB image tensor. Values are multiplied
            by 255 and cast to uint8, so a [0, 1] range is presumably
            expected — TODO confirm.

    Returns:
        Tuple ``(feature_tensor, keypoints, descriptors)``: a float32 tensor
        of shape [1, H, W] with a filled disc drawn at every keypoint, plus
        the raw keypoints and descriptors returned by OpenCV.
    """
    rgb_u8 = np.uint8(image_tensor.cpu().permute(1, 2, 0).numpy() * 255)
    gray_img = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2GRAY)

    keypoints, descriptors = cv2.SIFT_create().detectAndCompute(gray_img, None)

    # Blank float canvas with the same height and width as the image.
    feature_map = np.zeros_like(gray_img, dtype=np.float32)

    for keypoint in keypoints:
        center = (int(keypoint.pt[0]), int(keypoint.pt[1]))
        # The keypoint size is passed as the disc radius; thickness -1 fills it.
        # NOTE(review): OpenCV documents KeyPoint.size as a diameter — confirm
        # that using it as a radius is intended.
        cv2.circle(feature_map, center, int(keypoint.size), 1.0, -1)

    # The epsilon keeps the division defined when no keypoint was drawn.
    feature_map = feature_map / (feature_map.max() + 1e-8)
    feature_tensor = torch.tensor(feature_map, dtype=torch.float32).unsqueeze(0)

    return feature_tensor, keypoints, descriptors

def combine_image_with_features(image_tensor, feature_tensor, weight=0.3):
    """Blend a feature map into an image and rescale the result.

    Both inputs are min-max normalised independently, the feature map is
    added to the image scaled by ``weight``, and the sum is divided by its
    maximum.

    A constant input (for example a feature map without any detections) has
    a zero value range. Dividing by that range gave 0/0 = NaN for every
    pixel, so the range is clamped to a small positive value and a constant
    input now normalises to all zeros.

    Args:
        image_tensor: Image tensor.
        feature_tensor: Feature map, broadcastable against ``image_tensor``.
        weight: Scale applied to the normalised feature map.

    Returns:
        The blended tensor, divided by its maximum value.
    """
    def _min_max(tensor):
        low = tensor.min()
        # Clamp guards the degenerate case max == min.
        span = (tensor.max() - low).clamp(min=1e-8)
        return (tensor - low) / span

    combined_tensor = _min_max(image_tensor) + (weight * _min_max(feature_tensor))
    combined_tensor = combined_tensor / (combined_tensor.max() + 1e-8)

    return combined_tensor

def sift_guided_super_resolution(lr_tensor):
    """Super-resolve the first image of a batch after blending in SIFT features.

    Uses the module-level ``model``, ``device`` and ``blur`` objects.

    Args:
        lr_tensor: Batch of low-resolution images; only ``lr_tensor[0]`` is
            processed.

    Returns:
        Tuple ``(sr_image, sift_tensor, keypoints)``: the blurred model
        reconstruction, the SIFT feature map and the detected keypoints.
    """
    first_image = lr_tensor[0]

    # The descriptors are returned by the extractor but not needed here.
    sift_tensor, keypoints, _descriptors = extract_sift_features(first_image)
    sift_tensor = sift_tensor.to(device)

    blended = combine_image_with_features(first_image, sift_tensor)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        batch = blended.unsqueeze(0).to(device)
        reconstruction = model(pixel_values=batch).reconstruction

    # Smooth the reconstruction.
    return blur(reconstruction), sift_tensor, keypoints

def visualize_images_with_sift(lr_image, final_sr_image, sift_tensor, keypoints):
    """Show the input image, its SIFT feature map and the SR output in one row.

    Args:
        lr_image: Channel-first low-resolution image tensor.
        final_sr_image: Batched model output; only element 0 is shown.
        sift_tensor: SIFT feature map, displayed with the 'hot' colormap.
        keypoints: Not used for drawing; kept so existing callers keep working.
    """
    plt.figure(figsize=(15, 5))

    # Display LR input
    plt.subplot(1, 3, 1)
    lr_img = lr_image.permute(1, 2, 0).cpu().numpy()
    plt.imshow(lr_img)
    plt.title("Low-Resolution Input (64x64)")

    # Display SIFT features (an unused copy of the LR image was dropped here)
    plt.subplot(1, 3, 2)
    plt.imshow(sift_tensor.cpu().squeeze(), cmap='hot')
    plt.title("SIFT Feature Map")

    # Display SR output
    plt.subplot(1, 3, 3)
    final_sr_img = final_sr_image[0].permute(1, 2, 0).cpu().detach().numpy()
    plt.imshow(final_sr_img)
    plt.title("Super-Resolved Output (128x128)")

    plt.tight_layout()
    plt.show()

# Apply 2x SR with SIFT features
# Only the first image of the first batch is processed: the helper reads
# lr_tensor[0] and the loop is left after one iteration.
for batch_idx, (lr_tensor, hr_tensor, img_path) in enumerate(dataloader):
    lr_tensor = lr_tensor.to(device)
    
    # 64x64 to 128x128 super-resolution with SIFT guidance
    final_sr_image, sift_tensor, keypoints = sift_guided_super_resolution(lr_tensor)

    # Visualize results
    visualize_images_with_sift(lr_tensor[0], final_sr_image, sift_tensor, keypoints)
    
    break  # Stop after one batch for demonstration

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as T
import cv2
import numpy as np

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define Gaussian blur with a 3x3 kernel.
# NOTE(review): per the torchvision docs a (min, max) sigma tuple is sampled
# uniformly at random on every call, so this blur is not deterministic —
# confirm that is intended at inference time.
blur = T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))

def detect_satellite_edges(image_tensor):
    """Detect edges with Canny and return them as a [0, 1] tensor.

    The Canny thresholds are derived from the median of the blurred
    grayscale image (median -/+ 33 %, limited to 0..255).

    Args:
        image_tensor: Channel-first RGB image tensor. Values are multiplied
            by 255 and cast to uint8, so a [0, 1] range is presumably
            expected — TODO confirm.

    Returns:
        float32 tensor of shape [1, H, W] holding the dilated edge map
        divided by 255.
    """
    rgb_u8 = np.uint8(image_tensor.cpu().permute(1, 2, 0).numpy() * 255)
    gray_img = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2GRAY)

    # Light smoothing so noise does not trigger spurious edges.
    smoothed = cv2.GaussianBlur(gray_img, (3, 3), 0)

    # Thresholds follow the image brightness instead of being fixed.
    spread = 0.33
    midpoint = np.median(smoothed)
    low_threshold = int(max(0, (1.0 - spread) * midpoint))
    high_threshold = int(min(255, (1.0 + spread) * midpoint))
    edge_map = cv2.Canny(smoothed, low_threshold, high_threshold)

    # One dilation pass with a 2x2 kernel thickens the edges slightly.
    edge_map = cv2.dilate(edge_map, np.ones((2, 2), np.uint8), iterations=1)

    # Scale the edge map into [0, 1].
    return torch.tensor(edge_map, dtype=torch.float32).unsqueeze(0) / 255.0

def enhance_edges(edge_tensor):
    """Clean up an edge map with a morphological close followed by an open.

    Args:
        edge_tensor: Edge map tensor; it is moved to the CPU and squeezed
            before being handed to OpenCV.

    Returns:
        float32 CPU tensor of the processed edge map with a leading
        dimension added.
    """
    edge_map = edge_tensor.to('cpu').squeeze().numpy()

    structuring = np.ones((3, 3), np.uint8)
    # Closing bridges small gaps first; opening then removes small specks.
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        edge_map = cv2.morphologyEx(edge_map, operation, structuring)

    return torch.tensor(edge_map, dtype=torch.float32).unsqueeze(0)

def combine_image_with_features(image_tensor, feature_tensor):
    """Blend an edge map into an image with weight 0.4 and rescale the result.

    Both inputs are min-max normalised independently, moved to the
    module-level ``device``, summed (feature map scaled by 0.4) and divided
    by the maximum of the sum.

    A constant input (for example an edge map without any edges) has a zero
    value range. Dividing by that range gave 0/0 = NaN for every pixel, so
    the range is clamped to a small positive value and a constant input now
    normalises to all zeros.

    Args:
        image_tensor: Image tensor.
        feature_tensor: Feature map, broadcastable against ``image_tensor``.

    Returns:
        The blended tensor on ``device``, divided by its maximum value.
    """
    def _min_max(tensor):
        low = tensor.min()
        # Clamp guards the degenerate case max == min.
        span = (tensor.max() - low).clamp(min=1e-8)
        return (tensor - low) / span

    image_tensor = _min_max(image_tensor)
    feature_tensor = _min_max(feature_tensor)

    # Add edges to the image with an increased weight for satellite imagery.
    weight = 0.4
    image_tensor = image_tensor.to(device)
    feature_tensor = feature_tensor.to(device)
    combined_tensor = image_tensor + (weight * feature_tensor)
    combined_tensor = combined_tensor / (combined_tensor.max() + 1e-8)

    return combined_tensor

def edge_guided_super_resolution(lr_tensor):
    """Super-resolve the first image of a batch after blending in Canny edges.

    Uses the module-level ``model``, ``device`` and ``blur`` objects.

    Args:
        lr_tensor: Batch of low-resolution images; only ``lr_tensor[0]`` is
            processed.

    Returns:
        Tuple ``(sr_image, edge_tensor)``: the blurred model reconstruction
        and the enhanced edge map as returned by ``enhance_edges``.
    """
    first_image = lr_tensor[0]

    # The edge map is passed straight to enhance_edges. That function moves
    # its input to the CPU anyway, so copying it to `device` first would only
    # add a round trip.
    edge_tensor = enhance_edges(detect_satellite_edges(first_image))

    # Combine the LR image with edges
    combined_tensor = combine_image_with_features(first_image, edge_tensor)

    # Perform super-resolution (inference only, no gradients)
    with torch.no_grad():
        inputs = {'pixel_values': combined_tensor.unsqueeze(0).to(device)}
        outputs = model(**inputs)
        sr_image = outputs.reconstruction

    # Apply Gaussian blur to smooth any edge artifacts
    sr_image = blur(sr_image)

    return sr_image, edge_tensor

def visualize_satellite_images(lr_image, final_sr_image, edge_tensor):
    """Show the input image, the detected edges and the SR output in one row.

    Args:
        lr_image: Channel-first low-resolution image tensor.
        final_sr_image: Batched model output; only element 0 is shown.
        edge_tensor: Edge map tensor, displayed in grayscale.
    """
    plt.figure(figsize=(15, 5))

    panels = (
        (lr_image.permute(1, 2, 0).cpu().numpy(),
         "Low-Resolution Satellite Image", {}),
        (edge_tensor.cpu().squeeze(),
         "Detected Edges", {"cmap": 'gray'}),
        (final_sr_image[0].permute(1, 2, 0).cpu().detach().numpy(),
         "Super-Resolved Output", {}),
    )
    for position, (pixels, caption, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(pixels, **imshow_kwargs)
        plt.title(caption)

    plt.tight_layout()
    plt.show()

# Apply SR with edge detection
# Only the first image of the first batch is processed: the helper reads
# lr_tensor[0] and the loop is left after one iteration.
for batch_idx, (lr_tensor, hr_tensor, img_path) in enumerate(dataloader):
    lr_tensor = lr_tensor.to(device)
    
    # Perform edge-guided super-resolution
    final_sr_image, edge_tensor = edge_guided_super_resolution(lr_tensor)

    # Visualize results
    visualize_satellite_images(lr_tensor[0], final_sr_image, edge_tensor)
    
    break  # Stop after one batch for demonstration

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as T
import cv2
import numpy as np

# Define Gaussian blur with a 3x3 kernel.
# NOTE(review): per the torchvision docs a (min, max) sigma tuple is sampled
# uniformly at random on every call, so this blur is not deterministic —
# confirm that is intended at inference time.
blur = T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))

def detect_harris_corners(image_tensor):
    """Compute a dilated, min-max scaled Harris corner response.

    Args:
        image_tensor: Channel-first RGB image tensor. Values are multiplied
            by 255 and cast to uint8, so a [0, 1] range is presumably
            expected — TODO confirm.

    Returns:
        float32 tensor of shape [1, H, W] with the dilated response.
    """
    # Channel-last 8-bit copy of the image, then grayscale.
    rgb_u8 = np.uint8(image_tensor.cpu().permute(1, 2, 0).numpy() * 255)
    gray_img = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2GRAY)

    response = cv2.cornerHarris(gray_img, blockSize=2, ksize=3, k=0.04)

    # Rescale the raw response into [0, 1].
    response = cv2.normalize(response, None, 0, 1, cv2.NORM_MINMAX)

    # A 3x3 dilation spreads each response over its neighbourhood.
    spread = cv2.dilate(response, np.ones((3, 3), np.uint8))

    return torch.tensor(spread, dtype=torch.float32).unsqueeze(0)

def combine_image_with_features(image_tensor, feature_tensor):
    """
    Blend Harris corner features into an image with weight 0.3.

    Both inputs are min-max normalised independently, the feature map is
    added to the image scaled by 0.3, and the sum is divided by its maximum.

    A constant input (for example a flat feature map) has a zero value
    range. Dividing by that range gave 0/0 = NaN for every pixel, so the
    range is clamped to a small positive value and a constant input now
    normalises to all zeros.

    Args:
        image_tensor: Image tensor.
        feature_tensor: Feature map, broadcastable against ``image_tensor``.

    Returns:
        The blended tensor, divided by its maximum value.
    """
    def _min_max(tensor):
        low = tensor.min()
        # Clamp guards the degenerate case max == min.
        span = (tensor.max() - low).clamp(min=1e-8)
        return (tensor - low) / span

    weight = 0.3
    combined = _min_max(image_tensor) + (weight * _min_max(feature_tensor))
    combined = combined / (combined.max() + 1e-8)

    return combined

def harris_guided_super_resolution(lr_tensor):
    """
    Super-resolve the first image of a batch after blending in Harris features.

    Uses the module-level ``model``, ``device`` and ``blur`` objects.

    Args:
        lr_tensor: Batch of low-resolution images; only ``lr_tensor[0]`` is
            processed.

    Returns:
        Tuple ``(sr_image, feature_tensor)``: the blurred model
        reconstruction and the Harris feature map on ``device``.
    """
    first_image = lr_tensor[0]

    feature_tensor = detect_harris_corners(first_image).to(device)
    blended = combine_image_with_features(first_image, feature_tensor)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        batch = blended.unsqueeze(0).to(device)
        reconstruction = model(pixel_values=batch).reconstruction

    # Smooth the reconstruction to reduce artifacts.
    return blur(reconstruction), feature_tensor

def visualize_results(lr_image, sr_image, feature_tensor):
    """
    Show the input image, its Harris features and the SR output in one row.

    Args:
        lr_image: Channel-first input image tensor.
        sr_image: Batched model output; only element 0 is shown.
        feature_tensor: Harris feature map, displayed with the 'hot' colormap.
    """
    panels = (
        (lr_image.permute(1, 2, 0).cpu().numpy(),
         "Original Satellite Image", {}),
        (feature_tensor.cpu().squeeze(),
         "Harris Corner Features", {"cmap": 'hot'}),
        (sr_image[0].permute(1, 2, 0).cpu().detach().numpy(),
         "Super-Resolved Result", {}),
    )

    plt.figure(figsize=(15, 5))
    for position, (pixels, caption, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(pixels, **imshow_kwargs)
        plt.title(caption)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Process images
# Only the first image of the first batch is processed: the helper reads
# lr_tensor[0] and the loop is left after one iteration.
for batch_idx, (lr_tensor, hr_tensor, img_path) in enumerate(dataloader):
    lr_tensor = lr_tensor.to(device)
    
    # Apply Harris-guided super-resolution
    sr_image, feature_tensor = harris_guided_super_resolution(lr_tensor)
    
    # Visualize results
    visualize_results(lr_tensor[0], sr_image, feature_tensor)
    
    break  # Process one batch for demonstration

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from skimage.metrics import peak_signal_noise_ratio, structural_similarity, mean_squared_error
import torchvision.transforms as T

# Define Gaussian blur with a 3x3 kernel.
# NOTE(review): per the torchvision docs a (min, max) sigma tuple is sampled
# uniformly at random on every call, so the blur — and therefore the metrics
# computed on its output below — is not deterministic; confirm that is intended.
blur = T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0))

def detect_harris_corners(image_tensor):
    """
    Return a dilated, min-max scaled Harris corner response for an image.

    Args:
        image_tensor: Channel-first RGB image tensor. Values are multiplied
            by 255 and cast to uint8, so a [0, 1] range is presumably
            expected — TODO confirm.

    Returns:
        float32 tensor of shape [1, H, W] with the dilated response.
    """
    channels_last = image_tensor.cpu().permute(1, 2, 0).numpy()

    # 8-bit grayscale version of the image.
    gray_u8 = cv2.cvtColor(np.uint8(channels_last * 255), cv2.COLOR_RGB2GRAY)

    raw_response = cv2.cornerHarris(gray_u8, blockSize=2, ksize=3, k=0.04)

    # Rescale the raw response into [0, 1].
    scaled_response = cv2.normalize(raw_response, None, 0, 1, cv2.NORM_MINMAX)

    # A 3x3 dilation spreads each response over its neighbourhood.
    structuring = np.ones((3, 3), np.uint8)
    dilated_response = cv2.dilate(scaled_response, structuring)

    return torch.tensor(dilated_response, dtype=torch.float32).unsqueeze(0)

def combine_image_with_features(image_tensor, feature_tensor):
    """
    Blend Harris corner features into an image with weight 0.3.

    Both inputs are min-max normalised independently, the feature map is
    added to the image scaled by 0.3, and the sum is divided by its maximum.

    A constant input (for example a flat feature map) has a zero value
    range. Dividing by that range gave 0/0 = NaN for every pixel, which
    would also turn the evaluation metrics into NaN. The range is therefore
    clamped to a small positive value and a constant input now normalises
    to all zeros.

    Args:
        image_tensor: Image tensor.
        feature_tensor: Feature map, broadcastable against ``image_tensor``.

    Returns:
        The blended tensor, divided by its maximum value.
    """
    def _min_max(tensor):
        low = tensor.min()
        # Clamp guards the degenerate case max == min.
        span = (tensor.max() - low).clamp(min=1e-8)
        return (tensor - low) / span

    weight = 0.3
    combined = _min_max(image_tensor) + (weight * _min_max(feature_tensor))
    combined = combined / (combined.max() + 1e-8)

    return combined

def _show_sr_comparison(hr_img_np, feature_tensor, sr_img_np, psnr_value):
    """Plot the HR reference, the Harris feature map and the SR result in one row."""
    panels = (
        (hr_img_np,
         f"Original HR Image\nSize: {hr_img_np.shape[:2]}", {}),
        (feature_tensor.cpu().squeeze(),
         "Harris Corner Features", {"cmap": 'hot'}),
        (sr_img_np,
         f"Super-Resolved Image\nSize: {sr_img_np.shape[:2]}\nPSNR: {psnr_value:.2f}", {}),
    )

    plt.figure(figsize=(15, 5))
    for position, (pixels, caption, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(pixels, **imshow_kwargs)
        plt.title(caption)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def super_resolve_images(dataloader, model, device):
    """Run Harris-guided super-resolution over a dataloader and score it.

    For every image the Harris feature map is blended into the low-resolution
    input, the model reconstruction is blurred with the module-level ``blur``
    and compared against the high-resolution target using PSNR, SSIM and MSE.
    Each image is also plotted.

    NOTE(review): ``blur`` is built with a sigma range, so the scores can
    differ between runs — confirm whether a fixed sigma is wanted.

    Args:
        dataloader: Iterable yielding ``(lr_batch, hr_batch, names)`` tuples.
        model: Model called with ``pixel_values``; its output must expose
            ``reconstruction``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(hr_img_np, sr_img_np, avg_psnr, avg_ssim, avg_mse)``. The two
        arrays belong to the LAST processed image only. When the dataloader
        yields no images the arrays are ``None`` and the averages are NaN;
        that case used to end in an UnboundLocalError at the return.
    """
    model.eval()
    psnr_scores = []
    ssim_scores = []
    mse_scores = []
    hr_img_np = sr_img_np = None

    for lr_batch, hr_batch, _names in tqdm(dataloader):
        lr_batch = lr_batch.to(device)
        hr_batch = hr_batch.to(device)

        for lr_image, hr_image in zip(lr_batch, hr_batch):
            # Detect Harris corners and blend them into the input image.
            feature_tensor = detect_harris_corners(lr_image).to(device)
            combined_tensor = combine_image_with_features(lr_image, feature_tensor)

            # Super-resolution pass (inference only, no gradients).
            with torch.no_grad():
                outputs = model(pixel_values=combined_tensor.unsqueeze(0))
                sr_image = outputs.reconstruction

            # Blur the reconstruction to reduce artifacts.
            sr_image = blur(sr_image)

            # Bring the HR target to the spatial size of the SR output so the
            # metrics compare arrays of identical shape.
            resized_hr_image = F.interpolate(
                hr_image.unsqueeze(0),
                size=(sr_image.shape[2], sr_image.shape[3]),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0)

            # Channel-last NumPy arrays clipped to [0, 1], matching data_range=1.0.
            hr_img_np = np.clip(resized_hr_image.cpu().permute(1, 2, 0).numpy(), 0, 1)
            sr_img_np = np.clip(sr_image[0].cpu().permute(1, 2, 0).detach().numpy(), 0, 1)

            psnr_value = peak_signal_noise_ratio(hr_img_np, sr_img_np, data_range=1.0)
            ssim_value = structural_similarity(
                hr_img_np, sr_img_np, channel_axis=2, data_range=1.0, win_size=7
            )
            mse_value = mean_squared_error(hr_img_np, sr_img_np)

            psnr_scores.append(psnr_value)
            ssim_scores.append(ssim_value)
            mse_scores.append(mse_value)

            _show_sr_comparison(hr_img_np, feature_tensor, sr_img_np, psnr_value)

    if not psnr_scores:
        # Nothing was evaluated: avoid np.mean([]) warnings and unbound results.
        print("No images were evaluated; metrics are undefined.")
        undefined = float("nan")
        return None, None, undefined, undefined, undefined

    avg_psnr = np.mean(psnr_scores)
    avg_ssim = np.mean(ssim_scores)
    avg_mse = np.mean(mse_scores)

    print(f"Average PSNR: {avg_psnr:.4f}")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average MSE: {avg_mse:.8f}")

    return hr_img_np, sr_img_np, avg_psnr, avg_ssim, avg_mse

# Run the super-resolution process over the whole dataloader.
# hr_img / sr_img are the arrays the function returns after its loop, i.e.
# those of the last processed image; the avg_* values are dataset averages.
hr_img, sr_img, avg_psnr, avg_ssim, avg_mse = super_resolve_images(dataloader, model, device)

In [None]:
# Display the average PSNR returned by super_resolve_images.
avg_psnr

In [None]:
# Display the average SSIM returned by super_resolve_images.
avg_ssim

In [None]:
# Display the average MSE returned by super_resolve_images.
avg_mse