In [1]:
import time
import torch
from torch import nn
import torch.backends.cudnn as cudnn
from models import SRResNet
from datasets import SRDataset
from utils import *

In [2]:
# Data parameters
data_folder = 'data'
crop_size = 96  # crop size of target HR images
scaling_factor = 4  # the scaling factor for the generator; the input LR images will be downsampled from the target HR images by this factor

# Model parameters
large_kernel_size = 9  # kernel size of the first and last convolutions which transform the inputs and outputs
small_kernel_size = 3  # kernel size of all convolutions in-between, i.e. those in the residual and subpixel convolutional blocks
n_channels = 64  # number of channels in-between, i.e. the input and output channels for the residual and subpixel convolutional blocks
n_blocks = 16  # number of residual blocks

# Learning parameters
checkpoint = None  # path to model checkpoint, None if none
batch_size = 16  # batch size
start_epoch = 0  # start at this epoch
iterations = 1e6  # number of training iterations
workers = 4  # number of workers for loading data in the DataLoader
print_freq = 500  # print training status once every __ batches
lr = 1e-4  # learning rate
grad_clip = None  # clip if gradients are exploding

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cudnn.benchmark = True

In [3]:
def train(train_loader, model, criterion, optimizer, epoch):
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    start = time.time()
    for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
        data_time.update(time.time() - start)

        lr_imgs = lr_imgs.to(device)  # (batch_size (N), 3, 24, 24), imagenet-normed
        hr_imgs = hr_imgs.to(device)  # (batch_size (N), 3, 96, 96), in [-1, 1]

        sr_imgs = model(lr_imgs)  # (N, 3, 96, 96), in [-1, 1]
        loss = criterion(sr_imgs, hr_imgs)  # scalar

        optimizer.zero_grad()
        loss.backward()

        if grad_clip is not None:
            clip_gradient(optimizer, grad_clip)


        optimizer.step()
        losses.update(loss.item(), lr_imgs.size(0))

        batch_time.update(time.time() - start)
        start = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]----'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})----'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})----'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(epoch, i, len(train_loader),
                                                                    batch_time=batch_time,
                                                                    data_time=data_time, loss=losses))
    del lr_imgs, hr_imgs, sr_imgs


In [4]:
def main():
    global start_epoch, epoch, checkpoint

    if checkpoint is None:
        model = SRResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                         n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    model = model.to(device)
    criterion = nn.MSELoss().to(device)

    train_dataset = SRDataset(data_folder, split='train', crop_size=crop_size, scaling_factor=scaling_factor,
                              lr_img_type='imagenet-norm', hr_img_type='[-1, 1]')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
    epochs = int(iterations // len(train_loader) + 1)

    for epoch in range(start_epoch, epochs):
        train(train_loader=train_loader, model=model, criterion=criterion, optimizer=optimizer, epoch=epoch)
        torch.save({'epoch': epoch, 'model': model, 'optimizer': optimizer}, 'checkpoint_srresnet.pth.tar')

if __name__ == '__main__':
    main()

FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/batu/.pyenv/versions/3.7.0/envs/vispro/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/batu/.pyenv/versions/3.7.0/envs/vispro/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/batu/.pyenv/versions/3.7.0/envs/vispro/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/batu/dev/vision/project/res/srresnet/datasets.py", line 58, in __getitem__
    img = Image.open(self.images[i], mode='r')
  File "/home/batu/.pyenv/versions/3.7.0/envs/vispro/lib/python3.7/site-packages/PIL/Image.py", line 2843, in open
    fp = builtins.open(filename, "rb")
FileNotFoundError: [Errno 2] No such file or directory: '0046.png'


In [16]:
import os
images = os.listdir(os.path.join(data_folder, 'train/hr'))

'data/train/hr/'

In [19]:
images

['data/train/hr/0027.png',
 'data/train/hr/0004.png',
 'data/train/hr/0012.png',
 'data/train/hr/0033.png',
 'data/train/hr/0039.png',
 'data/train/hr/0026.png',
 'data/train/hr/0046.png',
 'data/train/hr/0010.png',
 'data/train/hr/0009.png',
 'data/train/hr/0044.png',
 'data/train/hr/0005.png',
 'data/train/hr/0008.png',
 'data/train/hr/0020.png',
 'data/train/hr/0001.png',
 'data/train/hr/0029.png',
 'data/train/hr/0035.png',
 'data/train/hr/0032.png',
 'data/train/hr/0048.png',
 'data/train/hr/0019.png',
 'data/train/hr/0043.png',
 'data/train/hr/0047.png',
 'data/train/hr/0013.png',
 'data/train/hr/0007.png',
 'data/train/hr/0037.png',
 'data/train/hr/0049.png',
 'data/train/hr/0025.png',
 'data/train/hr/0036.png',
 'data/train/hr/0003.png',
 'data/train/hr/0045.png',
 'data/train/hr/0017.png',
 'data/train/hr/0031.png',
 'data/train/hr/0038.png',
 'data/train/hr/0030.png',
 'data/train/hr/0023.png',
 'data/train/hr/0006.png',
 'data/train/hr/0042.png',
 'data/train/hr/0018.png',
 