In [1]:
import os
from copy import deepcopy
from math import sqrt

import torch
import torchvision
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from models.srgan import Generator, Discriminator

In [2]:
SCALE_FACTOR = 4
CROP_SIZE = 32
IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm'  # include image suffixes

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
# Total Variation Loss
class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]

In [5]:
class GeneratorLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = torchvision.models.vgg.vgg16(weights='VGG16_Weights.IMAGENET1K_V1')
        loss_net = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_net.parameters():
            param.requires_grad = False
        self.loss_net = loss_net.to(device)
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, fakes, p, targets):
        # p: (G)fakes: fake_images, (D)probabilities of real_images, targets: target_images(high-resolution)
        image_loss = self.mse_loss(fakes, targets)  # Image Loss
        adversarial_loss = torch.mean(1 - p)  # Adversarial Loss
        perception_loss = self.mse_loss(self.loss_net(fakes), self.loss_net(targets))  # Perception Loss
        tv_loss = self.tv_loss(fakes)  # TV Loss

        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss

In [6]:
class DiscriminatorLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCELoss()

    def forward(self, p, p_gt):
        # p: probability of real_image, p_gt: probability(ground truth) of real_image
        return self.bce_loss(p, torch.zeros_like(p, device=device)) + self.bce_loss(p_gt, torch.ones_like(p_gt, device=device))

In [7]:
# train_dataset
class LoadDataset(Dataset):
    def __init__(self, path, scale_factor=SCALE_FACTOR, crop_size=CROP_SIZE):
        super().__init__()
        scale_resize = int(sqrt(scale_factor))
        crop_size_ = crop_size - (crop_size % scale_resize) # Valid crop size

        self.imgs_path_list = [os.path.join(path, x) for x in os.listdir(path) if x.split('.')[-1].lower() in IMG_FORMATS]

        self.image_transform = transforms.Compose([
            transforms.CenterCrop(crop_size_),
            transforms.Resize(crop_size_ // scale_resize, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ])

        self.label_transform = transforms.Compose([
            transforms.CenterCrop(crop_size_),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.imgs_path_list)

    def __getitem__(self, index):
        img = Image.open(self.imgs_path_list[index])
        label = img.copy()

        img = self.image_transform(img)
        label = self.label_transform(label)
        return img, label

In [8]:
# test_dataloader
class LoadImages:
    def __init__(self, path, scale_factor=SCALE_FACTOR):
        self.scale_factor = scale_factor
        self.scale_resize = int(sqrt(scale_factor))
        self.imgs_path_list = [os.path.join(path, x) for x in os.listdir(path) if x.split('.')[-1].lower() in IMG_FORMATS]
        self.num_files = len(self.imgs_path_list)

    def __len__(self):
        return self.num_files

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.num_files:
            raise StopIteration
        img_path = self.imgs_path_list[self.count]
        self.count += 1

        img_name = img_path.split(os.sep)[-1]
        img0 = Image.open(img_path)
        img = img0.resize((
            int(img0.size[0] / self.scale_resize),
            int(img0.size[1] / self.scale_resize)
        ),
            Image.BICUBIC,  # scale the image via bicubic interpolation
        )

        # img = img0.resize((
        #     int(img0.size[0] / self.scale_factor),
        #     int(img0.size[1] / self.scale_factor)
        # ),
        #     Image.BICUBIC,  # scale the image via bicubic interpolation
        # )

        img_ = transforms.ToTensor()(img).view(1, -1, img.size[1], img.size[0])

        return img_, img_name

# Train

In [9]:
train_generator = Generator(in_channels=3, scale_factor=SCALE_FACTOR).to(device)
train_discriminator = Discriminator(in_channels=3).to(device)

optimizerG = torch.optim.Adam(train_generator.parameters(), lr=1e-4)
optimizerD = torch.optim.Adam(train_discriminator.parameters(), lr=1e-3)
lr_schedulerG = torch.optim.lr_scheduler.StepLR(optimizer=optimizerG, step_size=20, gamma=0.5)
lr_schedulerD = torch.optim.lr_scheduler.StepLR(optimizer=optimizerD, step_size=20, gamma=0.5)

gloss_fn = GeneratorLoss()
dloss_fn = DiscriminatorLoss()

train_dataset = LoadDataset('data/train')
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True, num_workers=0)

In [14]:
epochs = 300
for epoch in range(epochs):
    train_generator.train()
    train_discriminator.train()
    mloss = torch.zeros(2, device=device)  # mean_loss of generator and discriminator

    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch}/{epochs}', unit='batches')
    for i, (images, targets) in pbar:
        images, targets = images.to(device), targets.to(device)  # images: low_resolution, targets: high-resolution
        fakes = train_generator(images)  # fake_images
        ps = train_discriminator(fakes)  # probabilities of real_images
        ps_gt = train_discriminator(targets)  # probabilities(ground truth) of real_images

        gloss = gloss_fn(fakes, ps, targets)
        optimizerG.zero_grad()
        gloss.backward()
        optimizerG.step()

        dloss = dloss_fn(ps.detach(), ps_gt)
        optimizerD.zero_grad()
        dloss.backward()
        optimizerD.step()

        mloss[0] = (mloss[0] * i + gloss) / (i + 1)  # mean_loss of generator
        mloss[1] = (mloss[1] * i + dloss) / (i + 1)  # mean_loss of discriminator
        mem = f'{torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0:.3g}G'  # GPU_mem
        pbar.set_postfix(gloss=mloss[0].item(), dloss=mloss[1].item(), GPU_mem=mem)

    lr_schedulerG.step()
    lr_schedulerD.step()

    ckpt = {  # checkpoint
        'epoch': epoch,
        'generator': deepcopy(train_generator).half(),
        'discriminator': deepcopy(train_discriminator).half(),
        'optimizerG': optimizerG.state_dict(),
        'optimizerD': optimizerD.state_dict(),
    }
    torch.save(ckpt, 'srgan.pt')

Epoch 0/300: 100%|██████████| 2/2 [00:01<00:00,  1.70batches/s, GPU_mem=2.39G, dloss=87.3, gloss=0.0107]
Epoch 1/300: 100%|██████████| 2/2 [00:00<00:00,  2.58batches/s, GPU_mem=2.39G, dloss=88.5, gloss=0.0103]
Epoch 2/300: 100%|██████████| 2/2 [00:00<00:00,  2.59batches/s, GPU_mem=2.39G, dloss=87.4, gloss=0.0102]
Epoch 3/300: 100%|██████████| 2/2 [00:00<00:00,  2.52batches/s, GPU_mem=2.39G, dloss=87.1, gloss=0.0102]
Epoch 4/300: 100%|██████████| 2/2 [00:00<00:00,  2.57batches/s, GPU_mem=2.39G, dloss=88.8, gloss=0.0104]
Epoch 5/300: 100%|██████████| 2/2 [00:00<00:00,  2.54batches/s, GPU_mem=2.39G, dloss=87.8, gloss=0.0107]
Epoch 6/300: 100%|██████████| 2/2 [00:00<00:00,  2.56batches/s, GPU_mem=2.39G, dloss=86, gloss=0.0105]  
Epoch 7/300: 100%|██████████| 2/2 [00:00<00:00,  2.57batches/s, GPU_mem=2.39G, dloss=87.7, gloss=0.0106]
Epoch 8/300: 100%|██████████| 2/2 [00:00<00:00,  2.53batches/s, GPU_mem=2.39G, dloss=87.1, gloss=0.0103] 
Epoch 9/300: 100%|██████████| 2/2 [00:00<00:00,  2.54b

# Test

In [15]:
ckpt = torch.load('srgan.pt')
test_generator = ckpt['generator'].to(device).float()
test_generator.eval()
test_loader = LoadImages('data/test')  # test dataset: Set5

In [16]:
with torch.no_grad():
    for image, image_name in test_loader:
        image = image.to(device)
        pred = test_generator(image).cpu()
        pred_img = transforms.ToPILImage()(pred.squeeze(0))

        pred_img.save(f'output/srgan/srgan_{image_name}')