In [1]:
import os
from copy import deepcopy
from math import sqrt

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

from models.srcnn import SRCNN

In [2]:
# Area up-scaling factor; resizing code uses the per-side factor int(sqrt(SCALE_FACTOR)).
SCALE_FACTOR = 4
# Side length (pixels) of the square centre crop taken from each training image.
CROP_SIZE = 32
# Lower-case file extensions accepted as images when scanning a data directory.
IMG_FORMATS = (
    'bmp', 'dng', 'jpeg', 'jpg', 'mpo',
    'png', 'tif', 'tiff', 'webp', 'pfm',
)

In [3]:
# Run on the first visible GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Train

In [4]:
# One input channel: the network only ever sees the luminance (Y) plane.
train_model = SRCNN(in_channels=1).to(device)

# Per-layer learning rates: 1e-4 for conv1 and conv2, 1e-5 for conv3.
# Every parameter group sets its own 'lr', so the optimizer-wide default is not used by them.
optimizer = optim.Adam(
    [
        {'params': train_model.conv1.parameters(), 'lr': 1e-4},
        {'params': train_model.conv2.parameters(), 'lr': 1e-4},
        {'params': train_model.conv3.parameters(), 'lr': 1e-5},
    ],
    lr=1e-5,
)

# Pixel-wise mean squared error between the network output and the target.
loss_fn = nn.MSELoss()

In [5]:
class LoadDataset(Dataset):
    """Training pairs of (degraded, original) luminance patches for SRCNN.

    Every file in ``path`` whose extension is listed in ``IMG_FORMATS`` is a sample.
    Each image is converted to YCbCr and only its Y (luminance) channel is used, so
    training sees the same channel as inference. ``__getitem__`` returns ``(img, label)``:

    * ``img``   -- the centre crop, bicubically downscaled by ``int(sqrt(scale_factor))``
      per side and bicubically upscaled back to the crop size (the network input);
    * ``label`` -- the same centre crop at its original quality (the target).

    Both are ``ToTensor`` float tensors of shape ``(1, crop, crop)``, where ``crop`` is
    ``crop_size`` rounded down to a multiple of ``int(sqrt(scale_factor))``.

    Args:
        path: directory containing the training images.
        scale_factor: area scale factor; the per-side factor is ``int(sqrt(scale_factor))``.
        crop_size: requested crop side length in pixels.
    """

    def __init__(self, path, scale_factor=SCALE_FACTOR, crop_size=CROP_SIZE):
        super().__init__()
        scale_resize = int(sqrt(scale_factor))  # per-side downscale factor
        # Round the crop down to a multiple of scale_resize so the downscale is exact.
        crop_size_ = crop_size - (crop_size % scale_resize)

        # Sorted so that sample indices are stable across runs and platforms.
        self.imgs_path_list = [
            os.path.join(path, x)
            for x in sorted(os.listdir(path))
            if os.path.splitext(x)[1][1:].lower() in IMG_FORMATS
        ]

        bicubic = transforms.InterpolationMode.BICUBIC
        self.image_transform = transforms.Compose([
            transforms.CenterCrop(crop_size_),
            # Degrade: bicubic downscale by scale_resize per side ...
            transforms.Resize(crop_size_ // scale_resize, interpolation=bicubic),
            # ... then bicubic upscale back to the *label* size so input and target align.
            transforms.Resize(crop_size_, interpolation=bicubic),
            transforms.ToTensor(),
        ])
        self.label_transform = transforms.Compose([
            transforms.CenterCrop(crop_size_),  # keep the label at its original quality
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.imgs_path_list)

    def __getitem__(self, index):
        y = self.load_image(self.imgs_path_list[index])
        # The transforms return new images, so both can be derived from the same source.
        img = self.image_transform(y)
        label = self.label_transform(y)
        return img, label

    @staticmethod
    def load_image(filepath):
        """Return the Y (luminance) channel of the image at ``filepath`` as a PIL image.

        The image is explicitly converted to YCbCr first. Splitting the raw file would
        give the R band of an RGB image, and would fail to unpack for grayscale or RGBA images.
        """
        with Image.open(filepath) as img:  # close the file handle once decoded
            y, _, _ = img.convert('YCbCr').split()
        return y

In [6]:
# Training patches are read from data/train and served in shuffled mini-batches.
train_dataset = LoadDataset('data/train')
train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,   # new sample order every epoch
    num_workers=0,  # load data in the main process
)

In [7]:
# Train for a fixed number of epochs; after every epoch, overwrite SRCNN.pt with a checkpoint.
epochs = 300
for epoch in range(epochs):
    train_model.train()
    mloss = torch.zeros(1, device=device)  # running mean of this epoch's batch losses

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='batches')
    for i, (images, targets) in pbar:
        images, targets = images.to(device), targets.to(device)
        preds = train_model(images)
        loss = loss_fn(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # detach(): without it mloss chains every batch's autograd graph into its history
        # for the whole epoch instead of being a plain logged number.
        mloss = (mloss * i + loss.detach()) / (i + 1)
        mem = f'{torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0:.3g}G'  # GPU_mem
        pbar.set_postfix(loss=mloss.item(), GPU_mem=mem)

    # The whole module is pickled as an fp16 copy; the live model stays fp32 for training.
    ckpt = {  # checkpoint
        'epoch': epoch,
        'model': deepcopy(train_model).half(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(ckpt, 'SRCNN.pt')

Epoch 0/300: 100%|██████████| 2/2 [00:02<00:00,  1.01s/batches, GPU_mem=1.41G, loss=0.198]
Epoch 1/300: 100%|██████████| 2/2 [00:00<00:00,  3.40batches/s, GPU_mem=1.41G, loss=0.137]
Epoch 2/300: 100%|██████████| 2/2 [00:00<00:00,  3.71batches/s, GPU_mem=1.41G, loss=0.0929]
Epoch 3/300: 100%|██████████| 2/2 [00:00<00:00,  3.66batches/s, GPU_mem=1.41G, loss=0.0587]
Epoch 4/300: 100%|██████████| 2/2 [00:00<00:00,  3.58batches/s, GPU_mem=1.41G, loss=0.0372]
Epoch 5/300: 100%|██████████| 2/2 [00:00<00:00,  3.63batches/s, GPU_mem=1.41G, loss=0.0306]
Epoch 6/300: 100%|██████████| 2/2 [00:00<00:00,  3.56batches/s, GPU_mem=1.41G, loss=0.0343]
Epoch 7/300: 100%|██████████| 2/2 [00:00<00:00,  3.64batches/s, GPU_mem=1.41G, loss=0.0418]
Epoch 8/300: 100%|██████████| 2/2 [00:00<00:00,  3.50batches/s, GPU_mem=1.41G, loss=0.0452]
Epoch 9/300: 100%|██████████| 2/2 [00:00<00:00,  3.52batches/s, GPU_mem=1.41G, loss=0.0409]
Epoch 10/300: 100%|██████████| 2/2 [00:00<00:00,  3.54batches/s, GPU_mem=1.41G, lo

# Test

In [8]:
class LoadImages:
    """Iterate over test images, yielding bicubically degraded Y inputs for SRCNN.

    For each file in ``path`` whose extension is in ``IMG_FORMATS``, the image is:

    1. converted to YCbCr;
    2. bicubically downscaled by ``int(sqrt(scale_factor))`` per side;
    3. bicubically upscaled back to its original size.

    Each iteration yields ``(img_, cb, cr, img_name)``:

    * ``img_`` is the degraded Y channel as a ``(1, 1, H, W)`` tensor;
    * ``cb`` / ``cr`` are the degraded chroma channels as PIL images;
    * ``img_name`` is the file's base name.

    Args:
        path: directory containing the test images.
        scale_factor: area scale factor; the per-side factor is ``int(sqrt(scale_factor))``.
    """

    def __init__(self, path, scale_factor=SCALE_FACTOR):
        self.scale_resize = int(sqrt(scale_factor))  # per-side downscale factor
        # Sorted so outputs are produced in a stable order.
        self.imgs_path_list = [
            os.path.join(path, x)
            for x in sorted(os.listdir(path))
            if os.path.splitext(x)[1][1:].lower() in IMG_FORMATS
        ]
        self.num_files = len(self.imgs_path_list)
        self.count = 0  # lets next() work even before iter() is called

    def __len__(self):
        return self.num_files

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count >= self.num_files:
            raise StopIteration
        img_path = self.imgs_path_list[self.count]
        self.count += 1

        img_name = os.path.basename(img_path)
        with Image.open(img_path) as f:  # close the file handle once decoded
            img0 = f.convert('YCbCr')
        width, height = img0.size

        # Degrade like training does: bicubic down, then bicubic back up to the original size.
        small = img0.resize((width // self.scale_resize, height // self.scale_resize), Image.BICUBIC)
        img = small.resize((width, height), Image.BICUBIC)

        y, cb, cr = img.split()
        # Only the Y channel goes through the network; ToTensor gives (1, H, W) -> add batch dim.
        img_ = transforms.ToTensor()(y).unsqueeze(0)
        return img_, cb, cr, img_name

In [9]:
# The checkpoint pickles a whole nn.Module (not a state_dict), so it needs full unpickling.
# weights_only=False is explicit because newer torch defaults to True.
# Only load checkpoints you trust.
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.
ckpt = torch.load('SRCNN.pt', map_location=device, weights_only=False)
test_model = ckpt['model'].to(device).float()  # stored as fp16; run inference in fp32
test_model.eval()

SRCNN(
  (conv1): Conv2d(1, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(32, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu): ReLU(inplace=True)
)

In [10]:
# Evaluation images (Set5) scanned from data/test.
test_dataset = LoadImages(path='data/test')

In [11]:
# Super-resolve each test image's Y channel and write the RGB result to output/srcnn/.
output_dir = 'output/srcnn'
os.makedirs(output_dir, exist_ok=True)  # Image.save does not create missing directories

with torch.no_grad():
    for image, cb, cr, image_name in test_dataset:
        pred = test_model(image.to(device)).cpu()
        # (1, 1, H, W) in ~[0, 1] -> (H, W) in [0, 255], rounded to nearest (not truncated).
        pred_y = (pred[0, 0].numpy() * 255.0).clip(0, 255).round()
        pred_y = Image.fromarray(pred_y.astype(np.uint8))  # 2-D uint8 -> mode 'L'

        # merge the output of our network with the upscaled Cb and Cr before converting the result to RGB
        pred_img = Image.merge('YCbCr', [pred_y, cb, cr]).convert('RGB')

        pred_img.save(os.path.join(output_dir, f'srcnn_{image_name}'))