In [2]:
!pip install -r requirements.txt

Collecting pandas==1.5.3 (from -r requirements.txt (line 1))
  Downloading pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting numpy==1.23.5 (from -r requirements.txt (line 2))
  Downloading numpy-1.23.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.3 kB)
Collecting opendatasets==0.1.22 (from -r requirements.txt (line 3))
  Downloading opendatasets-0.1.22-py3-none-any.whl.metadata (9.2 kB)
Collecting Pillow==9.4.0 (from -r requirements.txt (line 4))
  Downloading Pillow-9.4.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (9.3 kB)
Collecting torch==2.0.0 (from -r requirements.txt (line 5))
  Downloading torch-2.0.0-cp310-cp310-manylinux1_x86_64.whl.metadata (24 kB)
Collecting torchvision==0.15.1 (from -r requirements.txt (line 7))
  Downloading torchvision-0.15.1-cp310-cp310-manylinux1_x86_64.whl.metadata (11 kB)
Collecting tqdm==4.65.0 (from -r requirements.txt (line 8))
  Downloading tqdm-4.65.0-py3-none-any.whl

## Preparing Data

In [None]:
import random
import glob
import pickle
import os

In [None]:
all_images_list = glob.glob(f"/content/drive/MyDrive/images_001/images/*.png", recursive=True)
len(all_images_list)

# Shuffle the data if the list is not empty
if all_images_list:
    random.shuffle(all_images_list)
else:
    print("No images found. Please check the file path or directory structure.")

# Print the first 10 image paths if available
print(all_images_list[:10])

['/content/drive/MyDrive/images_001/images/00000877_037.png', '/content/drive/MyDrive/images_001/images/00000193_013.png', '/content/drive/MyDrive/images_001/images/00000632_012.png', '/content/drive/MyDrive/images_001/images/00001219_000.png', '/content/drive/MyDrive/images_001/images/00000508_000.png', '/content/drive/MyDrive/images_001/images/00000963_008.png', '/content/drive/MyDrive/images_001/images/00000368_003.png', '/content/drive/MyDrive/images_001/images/00000248_001.png', '/content/drive/MyDrive/images_001/images/00000591_004.png', '/content/drive/MyDrive/images_001/images/00001104_018.png']


In [None]:
## Spliting the data into train and test
train_images = all_images_list[:4500]
test_images = all_images_list[4500:]

## Training

In [None]:
import os
import math
import pandas as pd
import torch
import torchvision
from prepare_data import TrainDataset, ValDataset
from torch.utils.data import DataLoader
from model_architecture import Generator, Discriminator
from custom_loss import Generator_Loss
from model_metrics import ssim
from tqdm import tqdm
import argparse
import torchvision.transforms as transforms
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed_all(42)

In [None]:
## Initialize parameters
CROP_SIZE = 100
UPSCALE_FACTOR = 4
NUM_EPOCHS = 8
BATCH_SIZE = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
## Initialize the generator
netG = Generator(upscale_factor=4).to(DEVICE)
print("# generator parameters:", sum(param.numel() for param in netG.parameters()))

# generator parameters: 201209


In [None]:
## Initialize the discriminator
netD = Discriminator().to(DEVICE)
print("# discriminator parameters:", sum(param.numel() for param in netD.parameters()))

# discriminator parameters: 19413279


In [None]:
## Initialize the loss function
generator_criterion = Generator_Loss().to(DEVICE)

In [None]:
## Initialize the optimizer
optimizerG = torch.optim.AdamW(netG.parameters(), lr=1e-3)
optimizerD = torch.optim.AdamW(netD.parameters(), lr=1e-3)

## Initialize the dictionary to store the results
results = {
    "d_loss": [],
    "g_loss": [],
    "d_score": [],
    "g_score": [],
    "psnr": [],
    "ssim": [],
}

In [None]:
## Load the train dataset
print("[INFO] Loading Train dataset")
train_set = TrainDataset(train_images)

## Load the validation dataset
print("[INFO] Loading Val dataset")
val_set = ValDataset(test_images)

[INFO] Loading Train dataset
[INFO] Loading Val dataset


In [None]:
## Create the train data loader
print("Creating Train data loader")
train_loader = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True,)

## Create the validation data loader
print("Creating Val data loader")
val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)

Creating Train data loader
Creating Val data loader


In [None]:
for epoch in range(1, NUM_EPOCHS + 1):

    train_bar = tqdm(train_loader, total=len(train_loader))

    running_results = {"batch_sizes": 0,
                       "d_loss": 0, "g_loss": 0,
                       "d_score": 0, "g_score": 0,
                    }

    netG.train()
    netD.train()

    ## Iterate over the batch of images
    for lr_img, hr_img in train_bar:

        batch_size = lr_img.size(0)
        running_results["batch_sizes"] += batch_size

        hr_img = hr_img.to(DEVICE) # high resolution image
        lr_img = lr_img.to(DEVICE) # low resolution image
        with torch.no_grad():
            sr_img = netG(lr_img) # super resolution image

        ## Set the gradients of Discriminator to zero
        netD.zero_grad()

        ## Formward propagate the HR image and SR image through the discriminator
        real_out = netD(hr_img).mean()
        fake_out = netD(sr_img).mean()

        ## Calculate the discriminator loss
        d_loss = 1 - real_out + fake_out

        ## Backpropagate the loss
        d_loss.backward(retain_graph=True)

        ## Update the weights
        optimizerD.step()

        ## Forward propagate the SR image through the discriminator
        with torch.no_grad():
            fake_out = netD(sr_img).mean()

        ## Set the gradients of the Generator to zero
        netG.zero_grad()

        ## Forward propagate the LR image through the generator to get the SR image
        sr_img = netG(lr_img)

        ## Calculate the generator loss
        g_loss = generator_criterion(fake_out, sr_img, hr_img)

        ## Backpropagate the loss
        g_loss.backward()

        ## Update the weights
        optimizerG.step()

        running_results["g_loss"] += g_loss.item() * batch_size
        running_results["d_loss"] += d_loss.item() * batch_size
        running_results["d_score"] += real_out.item() * batch_size
        running_results["g_score"] += fake_out.item() * batch_size

        train_bar.set_description(
            desc="[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f"
            % (
                epoch,
                NUM_EPOCHS,
                running_results["d_loss"] / running_results["batch_sizes"],
                running_results["g_loss"] / running_results["batch_sizes"],
                running_results["d_score"] / running_results["batch_sizes"],
                running_results["g_score"] / running_results["batch_sizes"],
            )
        )

    torch.cuda.empty_cache()

    ## Set the Generator to evaluation mode
    netG.eval()

    ## Run the validation loop
    with torch.no_grad():

        ## Progress bar for validation loop
        val_bar = tqdm(val_loader, total=len(val_loader))

        valing_results = {
            "mse": 0,
            "ssims": 0,
            "psnr": 0,
            "ssim": 0,
            "batch_sizes": 0,
        }

        val_images = []

        ## Iterate over the batch of images
        for val_lr, val_hr in val_bar:

            ## Get the current batch size
            batch_size = val_lr.size(0)
            valing_results["batch_sizes"] += batch_size

            lr = val_lr
            hr = val_hr
            if torch.cuda.is_available():
                lr = lr.cuda()
                hr = hr.cuda()

            ## Forward propagate the LR image through the generator to get the SR image
            sr = netG(lr)

            ## Calculate All the metrics
            ## Calculate and store the MSE
            batch_mse = ((sr - hr) ** 2).data.mean()
            valing_results["mse"] += batch_mse * batch_size

            ## Calculate and store the SSIMs
            batch_ssim = ssim(sr, hr).item()
            valing_results["ssims"] += batch_ssim * batch_size

            ## Calculate and store the PSNR
            valing_results["psnr"] = 10 * math.log10( (hr.max() ** 2) / (valing_results["mse"] / valing_results["batch_sizes"]))

            ## Calculate and store the SSIM
            valing_results["ssim"] = (valing_results["ssims"] / valing_results["batch_sizes"])

            ## Update the progress bar and print the results
            val_bar.set_description(
                desc="[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f"
                % (valing_results["psnr"], valing_results["ssim"])
            )

    netG.train()
    netD.train()

    ## Save the Generator model
    torch.save({"model": netG.state_dict()},
        f"netG_{UPSCALE_FACTOR}x_epoch{epoch}.pth.tar")

    ## Save the Discriminator model
    torch.save({"model": netD.state_dict()},
        f"netD_{UPSCALE_FACTOR}x_epoch{epoch}.pth.tar")

    ## Store the losses and scores for the current epoch
    results["d_loss"].append(running_results["d_loss"] / running_results["batch_sizes"])
    results["g_loss"].append(running_results["g_loss"] / running_results["batch_sizes"])
    results["d_score"].append(running_results["d_score"] / running_results["batch_sizes"])
    results["g_score"].append(running_results["g_score"] / running_results["batch_sizes"])
    results["psnr"].append(valing_results["psnr"])
    results["ssim"].append(valing_results["ssim"])

    print(results)

[1/8] Loss_D: 0.9969 Loss_G: 0.0041 D(x): 0.9992 D(G(z)): 0.9960: 100%|██████████| 2250/2250 [15:22<00:00,  2.44it/s]
[converting LR images to SR images] PSNR: 34.5777 dB SSIM: 0.9340: 100%|██████████| 761/761 [00:46<00:00, 16.30it/s]


{'d_loss': [0.9969477570586734], 'g_loss': [0.004061847448400739], 'd_score': [0.9992202426989873], 'g_score': [0.9960130303783549], 'psnr': [34.57768020777527], 'ssim': [0.9340427720750365]}


[2/8] Loss_D: 1.0000 Loss_G: 0.0015 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 2250/2250 [15:16<00:00,  2.45it/s]
[converting LR images to SR images] PSNR: 36.4357 dB SSIM: 0.9450: 100%|██████████| 761/761 [00:44<00:00, 16.93it/s]


{'d_loss': [0.9969477570586734, 1.0], 'g_loss': [0.004061847448400739, 0.0015366882274910393], 'd_score': [0.9992202426989873, 1.0], 'g_score': [0.9960130303783549, 1.0], 'psnr': [34.57768020777527, 36.43571701741517], 'ssim': [0.9340427720750365, 0.9449760649113398]}


[3/8] Loss_D: 1.0000 Loss_G: 0.0014 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 2250/2250 [15:15<00:00,  2.46it/s]
[converting LR images to SR images] PSNR: 38.7566 dB SSIM: 0.9503: 100%|██████████| 761/761 [00:44<00:00, 16.92it/s]


{'d_loss': [0.9969477570586734, 1.0, 1.0], 'g_loss': [0.004061847448400739, 0.0015366882274910393, 0.0013853977413899783], 'd_score': [0.9992202426989873, 1.0, 1.0], 'g_score': [0.9960130303783549, 1.0, 1.0], 'psnr': [34.57768020777527, 36.43571701741517, 38.75657330482375], 'ssim': [0.9340427720750365, 0.9449760649113398, 0.9503182663084173]}


[4/8] Loss_D: 1.0000 Loss_G: 0.0013 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 2250/2250 [15:19<00:00,  2.45it/s]
[converting LR images to SR images] PSNR: 39.2482 dB SSIM: 0.9560: 100%|██████████| 761/761 [00:44<00:00, 16.94it/s]


{'d_loss': [0.9969477570586734, 1.0, 1.0, 1.0], 'g_loss': [0.004061847448400739, 0.0015366882274910393, 0.0013853977413899783, 0.0013097206223497374], 'd_score': [0.9992202426989873, 1.0, 1.0, 1.0], 'g_score': [0.9960130303783549, 1.0, 1.0, 1.0], 'psnr': [34.57768020777527, 36.43571701741517, 38.75657330482375, 39.2481752884387], 'ssim': [0.9340427720750365, 0.9449760649113398, 0.9503182663084173, 0.9559727827290198]}


[5/8] Loss_D: 1.0000 Loss_G: 0.0013 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 2250/2250 [15:17<00:00,  2.45it/s]
[converting LR images to SR images] PSNR: 39.5467 dB SSIM: 0.9571: 100%|██████████| 761/761 [00:44<00:00, 16.93it/s]


{'d_loss': [0.9969477570586734, 1.0, 1.0, 1.0, 1.0], 'g_loss': [0.004061847448400739, 0.0015366882274910393, 0.0013853977413899783, 0.0013097206223497374, 0.0012697992790231688], 'd_score': [0.9992202426989873, 1.0, 1.0, 1.0, 1.0], 'g_score': [0.9960130303783549, 1.0, 1.0, 1.0, 1.0], 'psnr': [34.57768020777527, 36.43571701741517, 38.75657330482375, 39.2481752884387, 39.54670890577556], 'ssim': [0.9340427720750365, 0.9449760649113398, 0.9503182663084173, 0.9559727827290198, 0.9571178509277364]}


[6/8] Loss_D: 1.0000 Loss_G: 0.0012 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 2250/2250 [15:22<00:00,  2.44it/s]
[converting LR images to SR images] PSNR: 40.3222 dB SSIM: 0.9580: 100%|██████████| 761/761 [00:44<00:00, 16.95it/s]


{'d_loss': [0.9969477570586734, 1.0, 1.0, 1.0, 1.0, 1.0], 'g_loss': [0.004061847448400739, 0.0015366882274910393, 0.0013853977413899783, 0.0013097206223497374, 0.0012697992790231688, 0.001243539672681234], 'd_score': [0.9992202426989873, 1.0, 1.0, 1.0, 1.0, 1.0], 'g_score': [0.9960130303783549, 1.0, 1.0, 1.0, 1.0, 1.0], 'psnr': [34.57768020777527, 36.43571701741517, 38.75657330482375, 39.2481752884387, 39.54670890577556, 40.32220822584583], 'ssim': [0.9340427720750365, 0.9449760649113398, 0.9503182663084173, 0.9559727827290198, 0.9571178509277364, 0.9580364776346906]}


[7/8] Loss_D: 1.0000 Loss_G: 0.0012 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 2250/2250 [15:17<00:00,  2.45it/s]
[converting LR images to SR images] PSNR: 40.1413 dB SSIM: 0.9579: 100%|██████████| 761/761 [00:44<00:00, 16.95it/s]


{'d_loss': [0.9969477570586734, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'g_loss': [0.004061847448400739, 0.0015366882274910393, 0.0013853977413899783, 0.0013097206223497374, 0.0012697992790231688, 0.001243539672681234, 0.0012359368288372126], 'd_score': [0.9992202426989873, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'g_score': [0.9960130303783549, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'psnr': [34.57768020777527, 36.43571701741517, 38.75657330482375, 39.2481752884387, 39.54670890577556, 40.32220822584583, 40.14125488547145], 'ssim': [0.9340427720750365, 0.9449760649113398, 0.9503182663084173, 0.9559727827290198, 0.9571178509277364, 0.9580364776346906, 0.9579243424217585]}


[8/8] Loss_D: 1.0000 Loss_G: 0.0012 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 2250/2250 [15:20<00:00,  2.44it/s]
[converting LR images to SR images] PSNR: 39.1469 dB SSIM: 0.9577: 100%|██████████| 761/761 [00:45<00:00, 16.85it/s]


{'d_loss': [0.9969477570586734, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'g_loss': [0.004061847448400739, 0.0015366882274910393, 0.0013853977413899783, 0.0013097206223497374, 0.0012697992790231688, 0.001243539672681234, 0.0012359368288372126, 0.0012128635379227086], 'd_score': [0.9992202426989873, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'g_score': [0.9960130303783549, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'psnr': [34.57768020777527, 36.43571701741517, 38.75657330482375, 39.2481752884387, 39.54670890577556, 40.32220822584583, 40.14125488547145, 39.14688253257503], 'ssim': [0.9340427720750365, 0.9449760649113398, 0.9503182663084173, 0.9559727827290198, 0.9571178509277364, 0.9580364776346906, 0.9579243424217585, 0.9577492126184131]}


##Testing

In [3]:
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.transforms.functional import to_tensor
import sys

## Add the scripts folder to the path
sys.path.insert(0, '../scripts/')
from model_architecture import Generator

## Set the seed for reproducibility
torch.backends.cudnn.benchmark = True
torch.cuda.manual_seed_all(42)

In [4]:
## Set the device
DEVICE = "cuda"

## Load the model
model = Generator(upscale_factor=4).to(DEVICE)

## Load the model weights state dict
state_dict = torch.load('/content/netG_4x_epoch8.pth.tar', map_location=torch.device(DEVICE))

## Load the model from state dict
model.load_state_dict(state_dict["model"], )

## Set the model to evaluation mode
model.eval()

Generator(
  (initial): ConvBlock(
    (cnn): SeperableConv2d(
      (depthwise): Conv2d(3, 3, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), groups=3)
      (pointwise): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (bn): Identity()
    (act): PReLU(num_parameters=64)
  )
  (residual): Sequential(
    (0): ResidualBlock(
      (block1): ConvBlock(
        (cnn): SeperableConv2d(
          (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): PReLU(num_parameters=64)
      )
      (block2): ConvBlock(
        (cnn): SeperableConv2d(
          (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (pointwise): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=Fa

In [5]:
# Load an image
hr_image = Image.open('/content/sample_hr_input.png').convert('RGB')

## Create the LR image transformer by downsampling the HR image and applying bicubic interpolation
lr_scale = transforms.Resize((256,256), interpolation=Image.BICUBIC)

## Create the restored HR image tranformer (simple classical method) by upsampling the LR image and applying bicubic interpolation
hr_scale = transforms.Resize((1024,1024), interpolation=Image.BICUBIC)

## Create the LR Image from the original HR Image using the LR Image transformer
lr_image = lr_scale(hr_image)
lr_image.save("/content/sample_lr_input.png")

## Create the restored HR Image from the LR Image using the classical method of restored HR Image transforms
hr_restore_img = hr_scale(lr_image)

## Convert the LR Image to a tensor
lr_image = to_tensor(lr_image)

# Move the image and model to GPU if available
if torch.cuda.is_available():
    lr_image = lr_image.cuda()

## Add a batch dimension to the image
lr_image = lr_image.unsqueeze(0)

lr_image.shape

# Perform model inference
with torch.no_grad():
    output = model(lr_image)

In [6]:
## Remove the batch dimension
out = output.squeeze(0)

## Transforms for displaying the images
display_transform = transforms.Compose([
    transforms.ToPILImage(),
])

## Transform the output image
out = display_transform(out)

## Save the output image
out.save("/content/sample_sr_output.png")