In [1]:
import os
from copy import deepcopy
from math import sqrt

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

from models.srcnn import SRCNN

In [2]:
# Super-resolution settings shared by the training and test loaders.
# Both loaders downsample by int(sqrt(SCALE_FACTOR)) per side (2 for the default of 4).
SCALE_FACTOR = 4
# Requested side length (px) of the square center crop used for training patches.
CROP_SIZE = 32
# Lower-case file extensions (no leading dot) that the loaders accept as images.
IMG_FORMATS = (
    'bmp', 'dng', 'jpeg', 'jpg', 'mpo',
    'png', 'tif', 'tiff', 'webp', 'pfm',
)

In [3]:
# Use the first CUDA GPU when one is available; otherwise run everything on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
train_model = SRCNN(in_channels=3, scale_factor=SCALE_FACTOR).to(device)

# Per-layer learning rates: conv1/conv2 use 1e-4, conv3 and the upsampling module use 1e-5.
# Every group sets its own 'lr', so the Adam default below only applies to param groups
# added later without one.
optimizer = optim.Adam(
    [
        {"params": getattr(train_model, layer_name).parameters(), 'lr': layer_lr}
        for layer_name, layer_lr in (
            ('conv1', 1e-4),
            ('conv2', 1e-4),
            ('conv3', 1e-5),
            ('upsample', 1e-5),
        )
    ],
    lr=1e-5,
)
loss_fn = nn.MSELoss()  # pixel-wise mean squared error between prediction and target

In [5]:
class LoadDataset(Dataset):
    """Training dataset of (low-resolution input, high-resolution target) image pairs.

    Every file directly inside ``path`` whose extension is in ``IMG_FORMATS`` is
    center-cropped to a square patch. The target is that patch unchanged; the input is
    the same patch bicubically downsampled by ``int(sqrt(scale_factor))`` per side.

    Args:
        path: Directory containing the training images (not searched recursively).
        scale_factor: Super-resolution factor; the per-side downsampling factor is
            ``int(sqrt(scale_factor))``.
        crop_size: Requested patch side length. It is rounded down to a multiple of the
            per-side downsampling factor so the low-resolution size is exact.

    Raises:
        ValueError: if ``int(sqrt(scale_factor))`` is smaller than 1.
    """

    def __init__(self, path, scale_factor=SCALE_FACTOR, crop_size=CROP_SIZE):
        super().__init__()
        scale_resize = int(sqrt(scale_factor))
        if scale_resize < 1:
            # Would otherwise surface as an opaque ZeroDivisionError in the modulo below.
            raise ValueError(f'scale_factor must be >= 1, got {scale_factor}')
        crop_size_ = crop_size - (crop_size % scale_resize)  # largest valid crop <= crop_size

        self.imgs_path_list = [os.path.join(path, x) for x in os.listdir(path) if x.split('.')[-1].lower() in IMG_FORMATS]

        self.image_transform = transforms.Compose([
            transforms.CenterCrop(crop_size_),
            # Downsample by `scale_resize` per side to build the low-resolution input.
            transforms.Resize(crop_size_ // scale_resize, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
        ])
        self.label_transform = transforms.Compose([
            transforms.CenterCrop(crop_size_),  # same crop, kept at full resolution
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.imgs_path_list)

    def __getitem__(self, index):
        """Return ``(low_res_tensor, high_res_tensor)`` for the image at ``index``."""
        # Decode eagerly and close the file handle; force RGB so grayscale, palette and
        # RGBA files all yield 3-channel tensors matching the 3-channel model.
        with Image.open(self.imgs_path_list[index]) as f:
            img = f.convert('RGB')

        # The transforms return new images rather than mutating their input,
        # so one decoded image can feed both pipelines.
        return self.image_transform(img), self.label_transform(img)

In [6]:
# Training pairs come from every image in data/train; shuffle=True reshuffles each epoch.
train_dataset = LoadDataset('data/train')
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=0)  # load in the main process

In [7]:
epochs = 300
for epoch in range(epochs):
    train_model.train()
    mloss = torch.zeros(1, device=device)  # running mean of this epoch's batch losses

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='batches')
    for i, (images, targets) in pbar:
        images, targets = images.to(device), targets.to(device)
        preds = train_model(images)
        loss = loss_fn(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach before accumulating: otherwise every iteration's autograd graph stays
        # chained to `mloss` (and referenced) until the epoch ends.
        mloss = (mloss * i + loss.detach()) / (i + 1)
        mem = f'{torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0:.3g}G'  # GPU_mem
        pbar.set_postfix(loss=mloss.item(), GPU_mem=mem)

    # Overwrite the checkpoint after every epoch. The model is stored as a half-precision
    # copy; deepcopy leaves the model being trained untouched.
    ckpt = {
        'epoch': epoch,
        'model': deepcopy(train_model).half(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(ckpt, 'srcnn.pt')

Epoch 0/300: 100%|██████████| 2/2 [00:02<00:00,  1.11s/batches, GPU_mem=0.174G, loss=0.261]
Epoch 1/300: 100%|██████████| 2/2 [00:00<00:00,  3.53batches/s, GPU_mem=0.174G, loss=0.257]
Epoch 2/300: 100%|██████████| 2/2 [00:00<00:00,  3.42batches/s, GPU_mem=0.174G, loss=0.243]
Epoch 3/300: 100%|██████████| 2/2 [00:00<00:00,  3.76batches/s, GPU_mem=0.174G, loss=0.241]
Epoch 4/300: 100%|██████████| 2/2 [00:00<00:00,  2.70batches/s, GPU_mem=0.174G, loss=0.229]
Epoch 5/300: 100%|██████████| 2/2 [00:00<00:00,  2.40batches/s, GPU_mem=0.174G, loss=0.229]
Epoch 6/300: 100%|██████████| 2/2 [00:00<00:00,  2.35batches/s, GPU_mem=0.174G, loss=0.211]
Epoch 7/300: 100%|██████████| 2/2 [00:00<00:00,  2.11batches/s, GPU_mem=0.174G, loss=0.199]
Epoch 8/300: 100%|██████████| 2/2 [00:00<00:00,  2.17batches/s, GPU_mem=0.174G, loss=0.193]
Epoch 9/300: 100%|██████████| 2/2 [00:00<00:00,  2.09batches/s, GPU_mem=0.174G, loss=0.186]
Epoch 10/300: 100%|██████████| 2/2 [00:00<00:00,  2.03batches/s, GPU_mem=0.174G,

In [8]:
class LoadImages:
    """Iterate over a directory of images, yielding bicubically downsampled test inputs.

    Each step yields ``(tensor, file_name)``:

    - ``tensor`` has shape ``(1, 3, H // s, W // s)``, where ``s = int(sqrt(scale_factor))``
      and ``(W, H)`` is the original image size.
    - ``file_name`` is the image's base name.

    Only files directly inside ``path`` whose extension is in ``IMG_FORMATS`` are used.
    They are visited in sorted path order so outputs are reproducible.
    """

    def __init__(self, path, scale_factor=SCALE_FACTOR):
        self.scale_resize = int(sqrt(scale_factor))
        self.imgs_path_list = sorted(
            os.path.join(path, x) for x in os.listdir(path) if x.split('.')[-1].lower() in IMG_FORMATS
        )
        self.num_files = len(self.imgs_path_list)
        self.count = 0  # so next() works even if __iter__ was never called

    def __len__(self):
        return self.num_files

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count >= self.num_files:
            raise StopIteration
        img_path = self.imgs_path_list[self.count]
        self.count += 1

        img_name = os.path.basename(img_path)
        # Decode eagerly and close the file; force RGB to match the 3-channel model.
        with Image.open(img_path) as f:
            img0 = f.convert('RGB')
        width, height = img0.size
        img_small = img0.resize(
            (width // self.scale_resize, height // self.scale_resize),
            Image.BICUBIC,  # scale the image via bicubic interpolation
        )

        # ToTensor gives (C, H, W); add a leading batch dimension for the model.
        img_ = transforms.ToTensor()(img_small).unsqueeze(0)

        return img_, img_name

In [9]:
# The checkpoint pickles the whole nn.Module, so it must be loaded with weights_only=False.
# PyTorch >= 2.6 defaults to True, and that load would fail. Only do this for checkpoints
# you produced yourself.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
ckpt = torch.load('srcnn.pt', map_location=device, weights_only=False)
test_model = ckpt['model'].to(device).float()  # checkpoint holds a .half() copy; run inference in fp32
test_model.eval()
test_loader = LoadImages('data/test')  # test dataset: Set5

In [10]:
os.makedirs('output/srcnn', exist_ok=True)  # save() below does not create missing directories
with torch.no_grad():
    for image, image_name in test_loader:
        image = image.to(device)
        # Clamp to [0, 1]: ToPILImage multiplies float tensors by 255 and casts to uint8,
        # which does not saturate out-of-range values, so they would corrupt the image.
        pred = test_model(image).clamp(0, 1).cpu()
        pred_img = transforms.ToPILImage()(pred.squeeze(0))

        pred_img.save(f'output/srcnn/srcnn_{image_name}')