In [1]:
from srgan.model import Generator, Discriminator
from srgan.dataset import DatasetFromQuery
from srgan.utils import get_logger

from torch.utils.data import DataLoader
import torch
from torch.optim import Adam
from torch import nn

In [2]:
logger = get_logger(__name__)

In [3]:
dataset = DatasetFromQuery(query='/workdir/dataset/BSDS300/images/train/*.jpg', shrink_scale=4, max_size=96, input_upsample=False)
loader = DataLoader(dataset, batch_size=64, num_workers=6, shuffle=True)

In [4]:
from torchvision.models.vgg import vgg19

In [5]:
import numpy as np

In [6]:
from torchvision.transforms import Normalize

In [7]:
import torch.functional as F

In [8]:
class VGGLoss(nn.Module):
    def __init__(self, layer='22', device=None):
        super().__init__()
        vgg = vgg19(pretrained=True)
        
        if layer == '22':
            self.vgg_feature = vgg.features[:11]
            
        self.mean = [0.485, 0.456, 0.406]
        self.mean = torch.tensor(self.mean, device=device)[None, :, None, None]
        self.std = [0.229, 0.224, 0.225]
        self.std = torch.tensor(self.std, device=device)[None, :, None, None]            
        self.loss = nn.MSELoss()
        
    def normalize(self, tensor):
        tensor = tensor.clone()
        tensor.sub_(self.mean).div_(self.std)
        return tensor
        
    def forward(self, img_high, img_fake):
        z_high = self.vgg_feature(self.normalize(img_high))
        z_low = self.vgg_feature(self.normalize(img_fake))
        return self.loss(z_high, z_low)

## Load Model

pretrained model の読み込み

In [9]:
gen = Generator()
dis = Discriminator()

In [10]:
DEVICE = torch.device('cuda')

In [11]:
pretrained_path = '/workdir/dataset/gen_pretrain.hdf5'
gen.load_state_dict(torch.load(pretrained_path))

In [12]:
bce_criterion = nn.BCELoss()
vgg_criterion = VGGLoss(device=DEVICE)

In [13]:
gen, dis = gen.to(DEVICE), dis.to(DEVICE)
vgg_criterion.to(DEVICE)

VGGLoss(
  (vgg_feature): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (loss): MSELoss()
)

In [14]:
opt_gen = Adam(params=gen.parameters(), lr=1e-4)
opt_dis = Adam(params=dis.parameters(), lr=1e-4)

In [15]:
from collections import defaultdict

In [16]:
import pandas as pd

In [17]:
n_epochs = 20

In [None]:
for epoch in range(1, n_epochs + 1):
    logger.info(f'start epoch: {epoch}')
    watch_logs = defaultdict(list)
    log_df = pd.DataFrame()
    for i, (img_low, img_high) in enumerate(loader):
        img_low, img_high = img_low.to(DEVICE), img_high.to(DEVICE)
        img_fake = gen(img_low)

        # step1: update discriminator
        pred_high = dis(img_high)
        pred_fake = dis(img_fake.detach())
        dis_loss = bce_criterion(pred_high, torch.full_like(pred_high, 1, device=DEVICE)) + bce_criterion(pred_fake, torch.full_like(pred_fake, 0))
        opt_dis.zero_grad()
        dis_loss.backward()
        opt_dis.step()

        # step2: generator の update
        pred_fake = dis(img_fake)
        content_loss = vgg_criterion(img_fake, img_high)
        adv_loss = bce_criterion(pred_fake, torch.full_like(pred_fake, 1, device=DEVICE))
        gen_loss = content_loss + 1e-2 * adv_loss
        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        update_data = {
            'dis_loss': dis_loss.item(),
            'gen_loss': gen_loss.item(),
            'content_loss': content_loss.item(),
            'adv_loss': adv_loss.item()
        }

        for k, v in update_data.items():
            watch_logs[k].append(v)

        if (i + 1) % 200 == 0:
            s = pd.DataFrame(watch_logs).mean()
            str_log = [f'{i}/{len(loader)}']
            for k, v in s.items():
                str_log.append(f'{k} {v:.4e}')
            logger.info('\t'.join(str_log))
            s['n_steps'] = i
            s['epoch'] = epoch
            log_df = log_df.append(s, ignore_index=True)
            watch_logs = defaultdict(list)

[2020-01-12 20:17:23,514] start epoch: 1
[2020-01-12 20:18:33,986] 199/1563	dis_loss 8.2639e-04	gen_loss 1.3214e+01	content_loss 1.2864e+01	adv_loss 3.5032e+01
[2020-01-12 20:19:44,436] 399/1563	dis_loss 2.6447e-05	gen_loss 1.3270e+01	content_loss 1.2905e+01	adv_loss 3.6420e+01
[2020-01-12 20:20:54,766] 599/1563	dis_loss 1.0314e-08	gen_loss 1.3306e+01	content_loss 1.2915e+01	adv_loss 3.9072e+01
[2020-01-12 20:22:05,101] 799/1563	dis_loss 1.0696e-08	gen_loss 1.3171e+01	content_loss 1.2788e+01	adv_loss 3.8318e+01
[2020-01-12 20:23:15,889] 999/1563	dis_loss 1.1344e-07	gen_loss 1.3102e+01	content_loss 1.2742e+01	adv_loss 3.5996e+01
[2020-01-12 20:24:27,389] 1199/1563	dis_loss 8.7718e-05	gen_loss 1.3112e+01	content_loss 1.2729e+01	adv_loss 3.8297e+01


In [None]:
torch.save(gen.state_dict(), '/workdir/dataset/generator.hdf5')
torch.save(dis.state_dict(), '/workdir/dataset/discriminator.hdf5')