In [1]:
import time
import torch
from torch import nn
import torch.backends.cudnn as cudnn
from models import SRResNet
from datasets import SRDataset
from dataset import LowDataset
from utils import *
import matplotlib.pyplot as plt
import numpy as np

In [5]:
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
data_folder = 'data'
# side length of the target HR crops
crop_size = 96
# generator upscaling factor; LR inputs are the HR targets downsampled by this factor
scaling_factor = 2

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# kernel size of the first and last convolutions (input / output transforms)
large_kernel_size = 9
# kernel size of every convolution in between (residual and subpixel blocks)
small_kernel_size = 3
# channel count used between the first and last convolutions
n_channels = 64
# how many residual blocks the network stacks
n_blocks = 16

# ---------------------------------------------------------------------------
# Learning
# ---------------------------------------------------------------------------
# path to a saved checkpoint to resume from; None starts from scratch
checkpoint = None
batch_size = 16
# first epoch index of the training loop
start_epoch = 0
# total number of training iterations (converted to an epoch count in main())
iterations = 1e6
# DataLoader worker processes
workers = 4
# print training status every `print_freq` batches
print_freq = 1
# learning rate handed to the optimizer
lr = 1e-4
# gradients are clipped only when this is not None
grad_clip = None

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cudnn.benchmark = True

cpu


In [6]:
def train(train_loader, model, criterion, optimizer, epoch):
    """
    Run one epoch of training.

    Reads the notebook-level settings ``device``, ``grad_clip`` and ``print_freq``.

    :param train_loader: iterable yielding ``(lr_imgs, hr_imgs)`` tensor batches; must support ``len()``
    :param model: network applied to the low-resolution batch
    :param criterion: loss, called as ``criterion(sr_imgs, hr_imgs)``
    :param optimizer: optimizer stepped once per batch
    :param epoch: epoch number, used only in the progress message
    """
    model.train()  # switch the model to training mode

    batch_time = AverageMeter()  # wall time per iteration (data loading included)
    data_time = AverageMeter()  # time spent waiting for the next batch
    losses = AverageMeter()  # loss, weighted by batch size

    start = time.time()
    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        data_time.update(time.time() - start)

        # Move the batch to the training device.
        # NOTE(review): tensor shapes and value ranges are set by the dataset the
        # loader wraps and are not visible here -- confirm against that dataset.
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        # Forward pass and loss
        sr_imgs = model(lr_imgs)
        loss = criterion(sr_imgs, hr_imgs)  # scalar

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Optional gradient clipping
        if grad_clip is not None:
            clip_gradient(optimizer, grad_clip)

        # Parameter update
        optimizer.step()
        losses.update(loss.item(), lr_imgs.size(0))

        batch_time.update(time.time() - start)
        start = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]----'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(epoch, i, len(train_loader),
                                                                    batch_time=batch_time,
                                                                    data_time=data_time, loss=losses))
    # No explicit `del` of the batch tensors: they are released when the function
    # returns, and deleting them here raised UnboundLocalError whenever the loader
    # produced no batches.


In [None]:
def main():
    """
    Build (or restore) the model and optimizer, then train, saving a checkpoint after every epoch.

    Reads the notebook-level configuration (model / learning parameters, ``device``,
    ``checkpoint``). Side effects: sets the globals ``epoch`` and, when resuming,
    ``start_epoch``; overwrites ``checkpoint_srresnet.pth.tar`` after each epoch.

    :raises ValueError: if the training loader contains no batches
    """
    global start_epoch, epoch

    if checkpoint is None:
        # Fresh start
        model = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                         n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    else:
        # Load into a local so the global `checkpoint` keeps holding the path and the
        # cell can be re-run. `map_location` lets a checkpoint saved on GPU be resumed
        # on the current device.
        # SECURITY: the checkpoint stores whole pickled objects; torch.load executes
        # pickle code, so only load files from a trusted source.
        # NOTE(review): recent PyTorch releases default to weights_only=True, which
        # rejects pickled modules -- pass weights_only=False there if needed.
        state = torch.load(checkpoint, map_location=device)
        start_epoch = state['epoch'] + 1
        model = state['model']
        optimizer = state['optimizer']

    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    # Alternative dataset, currently disabled:
    # train_dataset = SRDataset(data_folder, split='train', crop_size=crop_size, scaling_factor=scaling_factor,
    #                           lr_img_type='imagenet-norm', hr_img_type='[-1, 1]')
    train_dataset = LowDataset()
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                               num_workers=workers,
                                               # pinned host memory only helps host-to-GPU copies
                                               pin_memory=(device.type == 'cuda'))
    if len(train_loader) == 0:
        raise ValueError('Training loader is empty; cannot derive the number of epochs.')

    # Convert the iteration budget into a whole number of epochs.
    epochs = int(iterations // len(train_loader) + 1)

    for epoch in range(start_epoch, epochs):
        train(train_loader=train_loader, model=model, criterion=criterion, optimizer=optimizer, epoch=epoch)
        # Overwrite the checkpoint after every epoch.
        torch.save({'epoch': epoch, 'model': model, 'optimizer': optimizer}, 'checkpoint_srresnet.pth.tar')

# Inside a notebook kernel __name__ is '__main__', so executing this cell starts
# training immediately; the guard only matters if the code is imported as a module.
if __name__ == '__main__':
    main()

Epoch: [0][0/50]----Batch Time 13.393 (13.393)----Data Time 1.641 (1.641)----Loss 0.3217 (0.3217)
