In [1]:
import os
import random

import torch.nn as nn
import torch.cuda.amp as amp
import torch.utils.data
from tqdm import tqdm

from dataset import SRDataset
from loss import PerceptionLoss
from models import Generator, Discriminator
from utils import init_torch_seeds, convert_image

from PIL import PngImagePlugin
# Raise Pillow's PNG text-chunk size limit to 100 MiB (presumably so dataset PNGs
# carrying large metadata chunks can be opened without error — TODO confirm).
PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024

Почти все параметры обучения идентичны тем, которые были использованы при обучении SRResNet на прошлом шаге.

In [9]:
# Dataset parameters
augments = {
    'rotation': False,
    'hflip': True,
}
crop_size = 256
lr_img_type = 'imagenet-norm'
# The generator is optimised against an MSE in VGG19 feature space rather than a
# per-pixel MSE. VGG19 takes normalised images as input, so the HR targets are
# loaded as 'imagenet-norm' as well.
hr_img_type = 'imagenet-norm'
train_data_name = './jsons/train_images.json'

# Training parameters
save_every = 2        # write a numbered generator snapshot every `save_every` epochs
print_every = 2000    # print running-average losses every `print_every` batches
start_epoch = 0       # first epoch index of the training loop
iters = 100_000       # total step budget (an int count); the number of epochs is derived from it
batch_size = 16
lr = 2e-4             # Adam learning rate, shared by generator and discriminator
beta = 1e-3           # weight of the adversarial term in the generator loss (see the original paper https://arxiv.org/pdf/1609.04802.pdf)
manualSeed = None     # set to an int to reproduce a previous run; None draws a fresh seed
workers = 4           # number of DataLoader worker processes

# Model architecture parameters
srresnet_checkpoint = './weights/SRResNet_16blocks_4x.pth'  # SRResNet weights trained in the previous step
upscale_factor = 4
n_blocks = 16

# Fix the random seed (drawing one if none was given) and print it, so the run can be reproduced
if manualSeed is None:
    manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
init_torch_seeds(manualSeed)

Random Seed:  7283


Создаем SRDataset и dataloader

In [10]:
# Training set yielding (LR, HR) image pairs, wrapped in a shuffled DataLoader.
dataset = SRDataset(
    train_data_name=train_data_name,
    crop_size=crop_size,
    scaling_factor=upscale_factor,
    lr_img_type=lr_img_type,
    hr_img_type=hr_img_type,
    augments=augments,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=int(workers),
    # pinned host memory lets the later `.to(device, non_blocking=True)` copies run asynchronously
    pin_memory=True,
)

Создадим генератор и дискриминатор, функции ошибки, оптимизаторы для генератора и дискриминатора (будем использовать Adam).<br>
В качестве функций ошибки мы используем **PerceptionLoss** и **BCEWithLogitsLoss**.<br>
**BCEWithLogitsLoss** - состязательная ошибка (**adversarial loss**), которая используется при обучении GAN-ов.<br> 
**PerceptionLoss** - MSE в пространстве фичей vgg19. <br>
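Финальный loss генератора: $\mathcal{L}_G = \mathcal{L}_{perception} + \beta \cdot \mathcal{L}_{adv}$, где $\mathcal{L}_{perception}$ — **PerceptionLoss**, а $\mathcal{L}_{adv}$ — **BCEWithLogitsLoss** (с целевой меткой «реальное» для сгенерированных изображений). Дискриминатор обучается только на **BCEWithLogitsLoss**: реальные изображения размечаются как 1, сгенерированные — как 0.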

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Build the generator and discriminator. The generator starts from the SRResNet
# weights trained in the previous step; map_location remaps the stored tensors onto
# `device`, so a checkpoint saved on a GPU also loads on a CPU-only machine.
generator = Generator(n_blocks=n_blocks, scaling_factor=upscale_factor).to(device)
generator.load_state_dict(torch.load(srresnet_checkpoint, map_location=device))
discriminator = Discriminator().to(device)

# Losses
perception_criterion = PerceptionLoss().to(device)  # MSE in VGG19 feature space
adversarial_criterion = nn.BCEWithLogitsLoss().to(device)

# Switch both networks to training mode
generator.train()
discriminator.train()

# Number of whole epochs that fit into the `iters` step budget. Clamp to at least one,
# so a budget smaller than a single epoch still trains instead of silently doing nothing.
epochs = max(1, int(iters // len(dataloader)))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.9, 0.999))
optimizer_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.9, 0.999))

# Loss scaling for mixed precision only applies on CUDA; disable it explicitly elsewhere.
scaler_g = amp.GradScaler(enabled=(device.type == 'cuda'))
scaler_d = amp.GradScaler(enabled=(device.type == 'cuda'))

cuda


Запускаем обучение!

In [12]:
for epoch in range(start_epoch, epochs):
    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    # running sums of per-batch losses for the current epoch
    g_avg_loss = 0.0
    d_avg_loss = 0.0
    for i, (lr_imgs, hr_imgs) in progress_bar:
        lr_imgs = lr_imgs.to(device, non_blocking=True)
        hr_imgs = hr_imgs.to(device, non_blocking=True)

        ### 1) Generator update
        optimizer_g.zero_grad()

        with amp.autocast(enabled=(device.type == 'cuda')):
            # Produce fake high-resolution images
            sr_imgs = generator(lr_imgs)
            # VGG19 (inside PerceptionLoss) needs normalised inputs; the HR targets are
            # already 'imagenet-norm', so after this the discriminator also sees real and
            # fake images in the same normalisation.
            sr_imgs = convert_image(sr_imgs, source='[-1, 1]', target='imagenet-norm')

            fake_logits = discriminator(sr_imgs)

            perception_loss = perception_criterion(sr_imgs, hr_imgs)
            # Generator wants the discriminator to label its outputs as real (1)
            g_adversarial_loss = adversarial_criterion(fake_logits, torch.ones_like(fake_logits))
            g_loss = perception_loss + beta * g_adversarial_loss

        # This backward also fills the discriminator's .grad buffers; they are cleared by
        # optimizer_d.zero_grad() below before the discriminator step.
        scaler_g.scale(g_loss).backward()
        scaler_g.step(optimizer_g)
        scaler_g.update()

        ### 2) Discriminator update
        optimizer_d.zero_grad()

        with amp.autocast(enabled=(device.type == 'cuda')):
            hr_logits = discriminator(hr_imgs)
            # detach() keeps the discriminator loss from backpropagating into the generator
            fake_logits = discriminator(sr_imgs.detach())

            # Binary cross-entropy: real -> 1, generated -> 0
            d_loss = adversarial_criterion(fake_logits, torch.zeros_like(fake_logits)) + \
                     adversarial_criterion(hr_logits, torch.ones_like(hr_logits))

        scaler_d.scale(d_loss).backward()
        scaler_d.step(optimizer_d)
        scaler_d.update()

        # .item() synchronises with the GPU, so read each loss only once per step
        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()
        d_avg_loss += d_loss_value
        g_avg_loss += g_loss_value

        progress_bar.set_description(f"[{epoch + 1}/{epochs}][{i + 1}/{len(dataloader)}] "
                                     f"Loss_D: {d_loss_value:.4f} Loss_G: {g_loss_value:.4f} ")

        if i % print_every == 0 and i != 0:
            print(f"Avg Loss_G: {(g_avg_loss/(i+1)):.4f} Avg Loss_D: {(d_avg_loss/(i+1)):.4f}")

    # Checkpoints: always refresh the "latest" generator weights, and additionally keep a
    # numbered snapshot every `save_every` epochs (previously the latest file was skipped
    # on snapshot epochs and went stale).
    torch.save(generator.state_dict(), f"./weights/SRGAN_{n_blocks}blocks_{upscale_factor}x.pth")
    if (epoch + 1) % save_every == 0:
        torch.save(generator.state_dict(),
                   f"./weights/SRGAN_{n_blocks}blocks_{upscale_factor}x_epoch{(epoch+1)}.pth")

[1/19][2001/5136] Loss_D: 0.0000 Loss_G: 4.9509 :  39%|███▉      | 2001/5136 [07:45<12:01,  4.34it/s]

Avg Loss_G: 4.4721 Avg Loss_D: 0.2068


[1/19][4001/5136] Loss_D: 0.0001 Loss_G: 4.1077 :  78%|███████▊  | 4001/5136 [15:25<04:20,  4.36it/s] 

Avg Loss_G: 4.3553 Avg Loss_D: 0.1475


[1/19][5136/5136] Loss_D: 0.0052 Loss_G: 3.6789 : 100%|██████████| 5136/5136 [19:53<00:00,  4.30it/s] 
[2/19][2001/5136] Loss_D: 0.0000 Loss_G: 4.1856 :  39%|███▉      | 2001/5136 [07:41<12:00,  4.35it/s]

Avg Loss_G: 4.1216 Avg Loss_D: 0.0070


[2/19][4001/5136] Loss_D: 0.0001 Loss_G: 3.1869 :  78%|███████▊  | 4001/5136 [15:21<04:21,  4.35it/s]

Avg Loss_G: 4.1099 Avg Loss_D: 0.0082


[2/19][5136/5136] Loss_D: 0.0000 Loss_G: 3.3755 : 100%|██████████| 5136/5136 [19:45<00:00,  4.33it/s]
[3/19][2001/5136] Loss_D: 0.0000 Loss_G: 4.0556 :  39%|███▉      | 2001/5136 [07:43<12:20,  4.23it/s]

Avg Loss_G: 4.0261 Avg Loss_D: 0.0129


[3/19][4001/5136] Loss_D: 0.0000 Loss_G: 4.3282 :  78%|███████▊  | 4001/5136 [15:23<04:24,  4.28it/s]

Avg Loss_G: 4.0152 Avg Loss_D: 0.0111


[3/19][5136/5136] Loss_D: 0.0000 Loss_G: 4.3984 : 100%|██████████| 5136/5136 [19:43<00:00,  4.34it/s]
[4/19][2001/5136] Loss_D: 0.0000 Loss_G: 3.5953 :  39%|███▉      | 2001/5136 [07:39<11:59,  4.36it/s]

Avg Loss_G: 3.9688 Avg Loss_D: 0.0000


[4/19][4001/5136] Loss_D: 0.0000 Loss_G: 4.7709 :  78%|███████▊  | 4001/5136 [15:19<04:20,  4.35it/s]

Avg Loss_G: 3.9620 Avg Loss_D: 0.0035


[4/19][5136/5136] Loss_D: 0.0000 Loss_G: 4.3521 : 100%|██████████| 5136/5136 [19:40<00:00,  4.35it/s]
[5/19][2001/5136] Loss_D: 0.0000 Loss_G: 3.6637 :  39%|███▉      | 2001/5136 [07:40<12:00,  4.35it/s]

Avg Loss_G: 3.9392 Avg Loss_D: 0.0000


[5/19][4001/5136] Loss_D: 0.0000 Loss_G: 4.0640 :  78%|███████▊  | 4001/5136 [15:20<04:22,  4.33it/s]

Avg Loss_G: 3.9279 Avg Loss_D: 0.0085


[5/19][5136/5136] Loss_D: 0.0000 Loss_G: 4.8965 : 100%|██████████| 5136/5136 [19:42<00:00,  4.34it/s]
[6/19][2001/5136] Loss_D: 0.0000 Loss_G: 4.2990 :  39%|███▉      | 2001/5136 [07:41<12:05,  4.32it/s]

Avg Loss_G: 3.8976 Avg Loss_D: 0.0000


[6/19][4001/5136] Loss_D: 0.0000 Loss_G: 3.4976 :  78%|███████▊  | 4001/5136 [15:21<04:20,  4.36it/s]

Avg Loss_G: 3.8854 Avg Loss_D: 0.0000


[6/19][5136/5136] Loss_D: 0.0000 Loss_G: 2.7696 : 100%|██████████| 5136/5136 [19:41<00:00,  4.35it/s]
[7/19][2001/5136] Loss_D: 0.0000 Loss_G: 3.7446 :  39%|███▉      | 2001/5136 [07:40<11:57,  4.37it/s]

Avg Loss_G: 3.8623 Avg Loss_D: 0.0000


[7/19][4001/5136] Loss_D: 0.0000 Loss_G: 3.7017 :  78%|███████▊  | 4001/5136 [15:20<04:38,  4.08it/s] 

Avg Loss_G: 3.8494 Avg Loss_D: 0.0057


[7/19][5136/5136] Loss_D: 0.0000 Loss_G: 3.6677 : 100%|██████████| 5136/5136 [19:42<00:00,  4.34it/s]
[8/19][251/5136] Loss_D: 0.0000 Loss_G: 3.6674 :   5%|▍         | 251/5136 [00:58<19:02,  4.28it/s]


KeyboardInterrupt: 