In [1]:
import os
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


class SuperResolutionDataset(Dataset):
    def __init__(self, hr_dir, scale = 2, transform = None, patch_size = 41, stride = 21):
        self.hr_dir = hr_dir
        self.scale = scale
        self.transform = transform
        self.patch_size = patch_size
        self.stride = stride
        self.hr_img = [os.path.join(hr_dir, img) for img in os.listdir(hr_dir) ]
        self.hr_patches, self.lr_patches = self.prepare_dataset()
        
    def prepare_dataset(self):
        hr_patches = []
        lr_patches = []
        for img in self.hr_img:
            image = cv2.imread(img)
#             image = preprocess_image(image)
            image = image.astype(np.float32)
#             image = torch.from_numpy(image).permute(2, 0, 1)
#             print(image.shape)
            image = image/255.0
            patches = self.extract_patches(image, self.patch_size, self.stride)
            hr_patches.extend(patches)
            lr_patches.extend(self.low_res_patches(patches))
        return hr_patches, lr_patches

    def extract_patches(self, image, patch, stride):
        patches = []
        h, w = image.shape[:2]
        for i in range(0, h-patch+1, stride):
            for j in range(0, w-patch+1, stride):
                new_patch = image[i:i+patch, j:j+patch]
                patches.append(new_patch)
        return patches
    
    def low_res_patches(self, patches):
        lr_patches = []
        for patch in patches:
            lr_patch = cv2.resize(patch, (patch.shape[1] // self.scale, patch.shape[0] // self.scale), interpolation=cv2.INTER_CUBIC)
            lr_patch = cv2.resize(lr_patch, (patch.shape[1], patch.shape[0]), interpolation=cv2.INTER_CUBIC)
            lr_patches.append(lr_patch)
            
        return lr_patches
    
    def __len__(self):
        return len(self.hr_patches)

    def __getitem__(self, idx):
        hr_patch = self.hr_patches[idx]
        lr_patch = self.lr_patches[idx]
        
        if self.transform:
            hr_patch = self.transform(hr_patch)
            lr_patch = self.transform(lr_patch)
        
        return lr_patch, hr_patch
    
hr_folder_91 = '91/'
train_dataset = SuperResolutionDataset(hr_dir=hr_folder_91, transform=None)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Check dimensions
for lr_patches, hr_patches in train_loader:
    print(f"Low-Resolution Patches: {lr_patches}, High-Resolution Patches: {hr_patches}")
    cv2.imshow()
    break

Low-Resolution Patches: tensor([[[[ 1.7525e-01,  4.5395e-01,  3.8735e-01],
          [ 1.6797e-01,  4.5051e-01,  3.7313e-01],
          [ 1.5540e-01,  4.4479e-01,  3.4964e-01],
          ...,
          [ 1.0442e-01,  4.5375e-01,  3.0080e-01],
          [ 1.0478e-01,  4.6123e-01,  3.0648e-01],
          [ 1.0516e-01,  4.6589e-01,  3.1009e-01]],

         [[ 1.9150e-01,  4.4175e-01,  3.9946e-01],
          [ 1.8364e-01,  4.4112e-01,  3.8355e-01],
          [ 1.7051e-01,  4.4012e-01,  3.5733e-01],
          ...,
          [ 1.0429e-01,  4.5629e-01,  3.0489e-01],
          [ 1.0246e-01,  4.6060e-01,  3.0862e-01],
          [ 1.0136e-01,  4.6325e-01,  3.1100e-01]],

         [[ 2.1802e-01,  4.1982e-01,  4.2166e-01],
          [ 2.0931e-01,  4.2432e-01,  4.0261e-01],
          [ 1.9554e-01,  4.3188e-01,  3.7130e-01],
          ...,
          [ 1.0275e-01,  4.5893e-01,  3.1180e-01],
          [ 9.7486e-02,  4.5827e-01,  3.1252e-01],
          [ 9.4021e-02,  4.5772e-01,  3.1297e-01]],

       

error: OpenCV(4.9.0) :-1: error: (-5:Bad argument) in function 'imshow'
> Overload resolution failed:
>  - imshow() missing required argument 'winname' (pos 1)
>  - imshow() missing required argument 'winname' (pos 1)
>  - imshow() missing required argument 'winname' (pos 1)


In [None]:
pip install tqdm


In [2]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
from tqdm import tqdm
import time

# Check if MPS (Metal Performance Shaders) is available for Apple silicon devices
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

class DRCN(nn.Module):
    def __init__(self, num_channels=3, base_filter=256, num_recursions=5):
        super(DRCN, self).__init__()
        self.num_recursions = num_recursions
        
        self.embedding_layers = nn.Sequential(
            nn.Conv2d(num_channels, base_filter, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(base_filter, base_filter, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        
        self.recursive_layer = nn.Conv2d(base_filter, base_filter, kernel_size=3, padding=1)
        
        self.reconstruction = nn.Conv2d(base_filter, num_channels, kernel_size=3, padding=1)
        
        # Initialize weights for weighted average
        self.weights = nn.Parameter(torch.ones(num_recursions + 1) / (num_recursions + 1))
        
    
    def forward(self, x):
        embedding_out = self.embedding_layers(x)
        out = embedding_out 
        feature_maps = [out]
        
        for _ in range(self.num_recursions):
            out = self.recursive_layer(out)
            feature_maps.append(out)
        
        # Compute weighted average of outputs
        outputs = [self.reconstruction(fm) for fm in feature_maps]
        weighted_sum = sum(w * self.reconstruction(fm) for w, fm in zip(self.weights, feature_maps))
        
        return (weighted_sum, outputs)

model = DRCN().to(device)

# Example loss function combining all supervision points
def loss_fn(outputs, target, feature_maps, alpha=0.7):
    weighted_avg = outputs
    layers_losses = [F.mse_loss(fm, target) for fm in feature_maps]
    layers_loss = sum(layers_losses) / len(layers_losses)
    l2_loss = F.mse_loss(weighted_avg, target)
    
    # Regularization term on weights
#     reg_loss = beta * torch.sum(model.weights**2)
    
    total_loss = alpha * layers_loss + (1 - alpha) * l2_loss
    return total_loss


In [3]:
for recursion in range(3, 11):
    model = DRCN(num_recursions = recursion).to(device)
    def ultimate_train(model, dataloader, num_epochs=30, alpha=0.9, patience=5):
        # Using SGD with momentum and weight decay
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=patience, verbose=True, min_lr=1e-12)

        model.train()
        best_loss = float('inf')
        print(f"no_of_recursions: {recursion}")
        for epoch in range(num_epochs):
            start_time = time.time()
            epoch_loss = 0.0
            with tqdm(total=len(dataloader), desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch') as pbar:
                for lr_img, hr_img in dataloader:
                    lr_img, hr_img = lr_img.to(device), hr_img.to(device)
                    lr_img = lr_img.permute(0, 3, 1, 2).float()  # (batch_size, height, width, channels) -> (batch_size, channels, height, width)
                    hr_img = hr_img.permute(0, 3, 1, 2).float()
                    optimizer.zero_grad()
                    weighted_output, output = model(lr_img)
                    loss = loss_fn(weighted_output, hr_img, output, alpha)
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.item()
                    pbar.set_postfix(loss=epoch_loss / (pbar.n + 1))
                    pbar.update(1)
            end_time = time.time()
            elapsed_time = end_time - start_time
            avg_epoch_loss = epoch_loss / len(dataloader)
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_epoch_loss}, Time: {elapsed_time:.2f}s")
            print(f"Current learning rate: {current_lr}")

            # Step the scheduler
            scheduler.step(avg_epoch_loss)

            # Save best model
            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                torch.save(model.state_dict(), f"{recursion}_rec_256_filter_best_model_sf2.pth")

    ultimate_train(model, train_loader, num_epochs=100)

    



no_of_recursions: 3


Epoch 1/100:   4%|▌             | 26/587 [00:03<01:11,  7.86batch/s, loss=0.097]


KeyboardInterrupt: 

In [None]:
# loss.backward()

for name, param in model.named_parameters():
    if param.grad is None:
        print(f"No gradient for {name}")
    else:
        print(f"Gradient for {name}: {param.grad.abs().mean().item()}")

In [7]:
import cv2
psnr = 0

for i in range(1, 6):

    low_resolution_image = cv2.imread(f"Set5/image_SRF_2/img_00{i}_SRF_2_LR.png")
    high_resolution_image = cv2.imread(f"Set5/image_SRF_2/img_00{i}_SRF_2_HR.png")
    high_resolution_image = torch.tensor(high_resolution_image)
    
    low_resol_cubic = cv2.resize(low_resolution_image, (high_resolution_image.shape[1], high_resolution_image.shape[0]),
                                 interpolation=cv2.INTER_CUBIC)

    low_resol_cubic = low_resol_cubic.astype(np.float32)
    low_resol_cubic /= 255.0
    low_resol_cubic = torch.from_numpy(low_resol_cubic)
    low_resol_cubic = low_resol_cubic.permute(2, 0, 1)
    low_resol_cubic = low_resol_cubic.to(device)

    final_image,  _ = model(low_resol_cubic)

    final_image = final_image.permute(1, 2, 0)

    final_image = final_image.cpu()
    final_image = final_image.detach().numpy()
    
    psnr += calculate_psnr(final_image, high_resolution_image/255)
    
print(psnr)



-192.611680425026


In [8]:
def calculate_psnr(original_image, compressed_image):
    # Convert images to float32
    original_image = original_image.astype(np.float32)
    compressed_image = compressed_image.cpu()
    compressed_image = compressed_image.detach().numpy()
    compressed_image = compressed_image.astype(np.float32)

    # Compute the Mean Squared Error (MSE)
    mse = np.mean((original_image - compressed_image) ** 2)

    if mse == 0:
        return float('inf') 

    # Assume 8-bit image
    max_pixel_value = 1.0

    # Compute PSNR
    psnr = 10 * np.log10((max_pixel_value ** 2) / mse)
    
    return psnr


In [9]:
for i in range(4, 8):
    model2 = DRCN(num_recursions=i, base_filter=64).to(device)
    model2.load_state_dict(torch.load(f"{i}_rec_256_filter_best_model_sf2.pth"))
    model2.eval()
    psnr = 0
    for j in range(1, 6):

        low_resolution_image = cv2.imread(f"Downloads/Set5/image_SRF_2/img_00{j}_SRF_2_LR.png")
        high_resolution_image = cv2.imread(f"Downloads/Set5/image_SRF_2/img_00{j}_SRF_2_HR.png")
        high_resolution_image = torch.tensor(high_resolution_image)
        

        low_resol_cubic = cv2.resize(low_resolution_image, (high_resolution_image.shape[1], high_resolution_image.shape[0]),
                                     interpolation=cv2.INTER_CUBIC)


        low_resol_cubic = low_resol_cubic.astype(np.float32)
        low_resol_cubic /= 255
        low_resol_cubic = torch.from_numpy(low_resol_cubic)
        low_resol_cubic = low_resol_cubic.permute(2, 0, 1)
        low_resol_cubic = low_resol_cubic.to(device)

        final_image,  _ = model2(low_resol_cubic)

        final_image = final_image.permute(1, 2, 0)

        final_image = final_image.cpu()
        final_image = final_image.detach().numpy()

        psnr += calculate_psnr(final_image, high_resolution_image)

    print(f"Psnr of model with recursions {i}:", psnr/5)


Psnr of model with recursions 4: 31.737763852785633
Psnr of model with recursions 5: 31.677008259662983
Psnr of model with recursions 6: 31.622615623541737
Psnr of model with recursions 7: 31.647760980471173
