# CUSTOM CNN FOR SUPER-RESOLUTION (SR) OF REMOTE SENSING IMAGES

## UCMERCED

## Importing Libraries

We'll import the necessary libraries for this project.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from sklearn.model_selection import train_test_split
import torch.nn.functional as F

## Splitting the UCMERCED Dataset

To work with the UCMERCED dataset, we split it into training and testing (or validation) sets, so that the model is trained on one subset and its performance is evaluated on images it has never seen. The cell below does this and then builds the PyTorch datasets and data loaders:

## Step 1: Define a Split Function

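First, we define a function that splits the dataset indices into training and testing subsets using `sklearn.model_selection.train_test_split`. The same cell then defines the `UCMERCEDDataset` class, which pairs each high-resolution (HR) image with a low-resolution (LR) copy created by bicubic downsampling by `scale_factor`. It also sets up the train/test transforms and creates the data loaders: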

In [None]:
def split_dataset(indices, test_size=0.2, random_state=None, use_subset=False, subset_size=100):
    """
    Splits dataset indices into training and testing subsets.

    Parameters:
    - indices (list or array): List of dataset indices.
    - test_size (float): Proportion of the dataset to include in the test split.
    - random_state (int, optional): Seed for the random number generator; pass an
      int to make the split reproducible.
    - use_subset (bool): If True, truncate both splits to at most `subset_size`
      indices (useful for quick debugging runs).
    - subset_size (int): Maximum number of indices kept per split when
      `use_subset` is True.

    Returns:
    - train_indices: Indices for the training set (same sequence type as
      returned by `train_test_split` for `indices`, e.g. a list for a list input).
    - test_indices: Indices for the testing set.
    """
    train_indices, test_indices = train_test_split(
        indices, test_size=test_size, random_state=random_state
    )
    if use_subset:
        # Slicing never fails on short inputs; splits smaller than
        # `subset_size` are returned unchanged.
        train_indices = train_indices[:subset_size]
        test_indices = test_indices[:subset_size]
    return train_indices, test_indices


class UCMERCEDDataset(Dataset):
    """Pairs of (low-resolution, high-resolution) images from a class-per-folder tree.

    `image_dir` is expected to contain one sub-directory per class, each holding the
    image files. Each item is a ``(lr_image, hr_image)`` tuple, where ``lr_image``
    is produced on the fly by bicubic downsampling of the (transformed) HR tensor.
    """

    def __init__(self, image_dir, indices=None, transform=None, scale_factor=2,
                 extensions=('.tif',)):
        """
        Initializes the dataset.

        Parameters:
        - image_dir (str): Directory with one sub-directory of images per class.
        - indices (list or array, optional): Indices to select a subset of the data.
          They refer to positions in the sorted file list, so they are stable
          across separately constructed instances.
        - transform (callable, optional): A function/transform to apply to the images.
        - scale_factor (int, optional): Factor by which to scale down the high-resolution images.
        - extensions (str or tuple of str, optional): File suffixes (matched
          case-insensitively) treated as images. Change this for other datasets.
        """
        self.image_dir = image_dir
        self.transform = transform
        self.scale_factor = scale_factor
        self.indices = indices

        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        # os.listdir returns entries in arbitrary order. Sort both levels so the
        # file list, and therefore the meaning of `indices`, is identical every
        # time the dataset is built.
        self.image_files = []
        for class_name in sorted(os.listdir(image_dir)):
            class_dir = os.path.join(image_dir, class_name)
            if not os.path.isdir(class_dir):
                continue  # skip stray files such as .DS_Store
            for file_name in sorted(os.listdir(class_dir)):
                if file_name.lower().endswith(suffixes):
                    self.image_files.append(os.path.join(class_name, file_name))

        # If indices are provided, keep only those files (in the given order).
        if self.indices is not None:
            self.image_files = [self.image_files[i] for i in self.indices]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        # Close the file handle immediately; convert() returns a new, fully
        # loaded image that stays valid after the file is closed.
        with Image.open(img_name) as img:
            hr_image = img.convert('RGB')

        # Apply transforms if provided
        if self.transform:
            hr_image = self.transform(hr_image)

        # F.interpolate needs a tensor. If no transform produced one, convert the
        # PIL image like ToTensor() would: float32, channels-first, in [0, 1].
        if not torch.is_tensor(hr_image):
            array = np.asarray(hr_image, dtype=np.float32) / 255.0
            hr_image = torch.from_numpy(array).permute(2, 0, 1).contiguous()

        # Create low-resolution image tensor by bicubic downsampling
        lr_image = F.interpolate(
            hr_image.unsqueeze(0),  # interpolate expects a batch dimension
            scale_factor=1 / self.scale_factor,
            mode='bicubic',
            align_corners=False
        ).squeeze(0)  # Remove batch dimension

        return lr_image, hr_image
    


# Training pipeline with augmentation. Resize(256) scales the shorter side to
# 256 px (keeping the aspect ratio); a random 224x224 crop is then taken.
train_transform = transforms.Compose([
    transforms.Resize(256),             # shorter side -> 256 px
    transforms.RandomCrop(224),         # random 224x224 crop (no resizing)
    transforms.RandomHorizontalFlip(),  # Random horizontal flip
    transforms.RandomVerticalFlip(),    # Random vertical flip
    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # -> [-1, 1]
])
# Deterministic evaluation pipeline: same resize, fixed center crop, no flips.
test_transform = transforms.Compose([
    transforms.Resize(256),             # shorter side -> 256 px
    transforms.CenterCrop(224),         # central 224x224 crop
    transforms.ToTensor(),              # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # -> [-1, 1]
])


image_dir = './UCMerced_LandUse/Images'
# Fixed seed so the train/test split is the same on every run.
split_seed = 42
# This instance only supplies the number of files; the split indices are
# positions in its file list and are reused by the two datasets below.
dataset = UCMERCEDDataset(image_dir=image_dir)
train_indices, test_indices = split_dataset(
    list(range(len(dataset))), random_state=split_seed, use_subset=False
)

# Instantiate datasets and dataloaders
train_dataset = UCMERCEDDataset(image_dir=image_dir, indices=train_indices, transform=train_transform, scale_factor=2)
test_dataset  = UCMERCEDDataset(image_dir=image_dir, indices=test_indices, transform=test_transform, scale_factor=2)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
# No shuffling for evaluation: draw_images() takes the first test batch, so a
# fixed order shows the same images every time and epochs stay comparable.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

print("Size of train set", len(train_loader.dataset))
print("Size of test set",  len(test_loader.dataset))

## Defining the Model

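We will define a simple convolutional neural network model for super-resolution based on the SRCNN (Super-Resolution Convolutional Neural Network) architecture. Three convolutional layers process the low-resolution input at its own size. A final transposed convolution then upsamples the result by a factor of 2 to the high-resolution size.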

In [None]:
class SRCNN(nn.Module):
    """SRCNN-style network followed by a learned transposed-convolution upsampler.

    Three convolutions (9x9 -> 1x1 -> 5x5, with 'same' padding) process the
    low-resolution input at its own size. A ConvTranspose2d then enlarges the
    3-channel result by exactly ``scale_factor`` in height and width.

    Parameters:
    - scale_factor (int): Integer upscaling factor (>= 1). The default of 2 gives
      a transposed convolution with kernel_size=4, stride=2, padding=1.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        if not isinstance(scale_factor, int) or scale_factor < 1:
            raise ValueError(f"scale_factor must be a positive int, got {scale_factor!r}")
        self.scale_factor = scale_factor

        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

        # ConvTranspose2d output size: out = (in - 1) * s - 2 * p + k.
        # Choosing k = 2s, p = s/2 (even s) or k = 2s - 1, p = (s - 1)/2 (odd s)
        # gives out = in * s exactly.
        kernel_size = 2 * scale_factor - scale_factor % 2
        padding = (scale_factor - scale_factor % 2) // 2
        self.deconv = nn.ConvTranspose2d(
            3, 3, kernel_size=kernel_size, stride=scale_factor, padding=padding, output_padding=0
        )

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.conv3(x)
        return self.deconv(x)  # upscale H and W by scale_factor

# Instantiate the network defined above and print its layer summary.
model = SRCNN()
print(repr(model))

## Defining the Loss Function and Optimizer

We'll use the Mean Squared Error (MSE) as our loss function and the Adam optimizer (learning rate 0.001) to update the model weights.

In [None]:
# Pixel-wise mean squared error between super-resolved and HR images.
# nn.L1Loss() can be swapped in here as an alternative criterion.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

## Training the Model

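We'll train the model on the UCMERCED training split. Every other epoch, a batch from the test split is passed through the model, and five high-resolution, low-resolution, and super-resolved images are displayed side by side.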

In [None]:
def show_images(hr_images, lr_images, sr_images, num_images=5):
    """Plot matching high-res, low-res and super-resolved images in a 3-row grid.

    Parameters:
    - hr_images, lr_images, sr_images (sequence of images): Objects accepted by
      ``plt.imshow`` (e.g. PIL images), aligned by index.
    - num_images (int): Maximum number of columns to draw. Fewer are drawn if
      the inputs are shorter; nothing is drawn if any input is empty.
    """
    n = min(num_images, len(hr_images), len(lr_images), len(sr_images))
    if n <= 0:
        return

    rows = (
        ('High Resolution', hr_images),
        ('Low Resolution', lr_images),
        ('Super-Resolved', sr_images),
    )
    # squeeze=False keeps `axs` 2-D even when only one column is drawn.
    # The LR row is half height because the LR images are smaller.
    _, axs = plt.subplots(
        3, n, figsize=(3 * n, 9),
        gridspec_kw={'height_ratios': [1, 0.5, 1]}, squeeze=False,
    )
    for row, (title, images) in enumerate(rows):
        for col in range(n):
            ax = axs[row, col]
            ax.imshow(images[col])
            ax.set_title(title)
            ax.axis('on')
    plt.show()


def denormalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo a per-channel ``transforms.Normalize`` and clip to [0, 1] for display.

    Parameters:
    - tensor (torch.Tensor): Normalized image(s) with channels in dim -3, i.e.
      shape (C, H, W) or (N, C, H, W).
    - mean, std (sequence of float): The per-channel values used for
      normalization, one entry per channel.

    Returns:
    - torch.Tensor: ``tensor * std + mean`` clipped to [0, 1], same shape as ``tensor``.
    """
    # Build the statistics on the input's device so GPU tensors work too, and
    # shape them (C, 1, 1) so they broadcast over single or batched images.
    mean = torch.as_tensor(mean, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, device=tensor.device).view(-1, 1, 1)
    return torch.clip(tensor * std + mean, 0, 1)


def draw_images(test_loader, device, net=None):
    """Run the model on the first test batch and plot HR / LR / SR samples.

    Parameters:
    - test_loader (DataLoader): Yields ``(lr_images, hr_images)`` batches.
    - device (torch.device): Device the model lives on.
    - net (nn.Module, optional): Model to visualize. Defaults to the
      notebook-global ``model``.
    """
    if net is None:
        net = model
    lr_images, hr_images = next(iter(test_loader))

    # Inference only: no autograd graph, eval-mode layers. The caller's
    # train/eval mode is restored afterwards, even if the forward pass fails.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            sr_images = net(lr_images.to(device))
    finally:
        net.train(was_training)

    # Same statistics as the Normalize() step of the data transforms.
    stats = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    to_pil = transforms.ToPILImage()
    hr_pils = [to_pil(img) for img in denormalize(hr_images.cpu(), **stats)]
    lr_pils = [to_pil(img) for img in denormalize(lr_images.cpu(), **stats)]
    sr_pils = [to_pil(img) for img in denormalize(sr_images.cpu(), **stats)]

    show_images(hr_pils, lr_pils, sr_pils)

In [None]:
num_epochs = 50
# Show test samples every `draw_every` epochs: epochs 2, 4, 6, ...
# (1-based), i.e. zero-based epoch indices 1, 3, 5, ...
draw_every = 2
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    # Evaluation below switches to eval mode, so re-enable training mode each epoch.
    model.train()
    running_loss = 0.0
    for lr_images, hr_images in train_loader:
        lr_images, hr_images = lr_images.to(device), hr_images.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(lr_images)
        loss = criterion(outputs, hr_images)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # criterion returns a per-sample mean; weight by batch size so the
        # epoch average is exact even when the last batch is smaller.
        running_loss += loss.item() * lr_images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # Held-out loss on the test split (no gradients), to spot overfitting.
    model.eval()
    test_running_loss = 0.0
    with torch.no_grad():
        for lr_images, hr_images in test_loader:
            lr_images, hr_images = lr_images.to(device), hr_images.to(device)
            batch_loss = criterion(model(lr_images), hr_images)
            test_running_loss += batch_loss.item() * lr_images.size(0)
    test_loss = test_running_loss / len(test_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Test Loss: {test_loss:.4f}')

    if (epoch + 1) % draw_every == 0:
        draw_images(test_loader, device)

print('Training completed')