## 1. Import Packages

In [26]:
import os
from utils import psnr, save_plot, save_model, save_validation_results, create_patches, make_inference_images, get_dataloaders
from model import SRCNN
import torch
import time
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from torchvision.utils import save_image

## 2. Arguments

In [2]:
# Central run configuration; every later cell reads its settings from here.
args = {
    # Directories holding the source images for training and testing.
    'train_image_path': './data/train/image',
    'test_image_path': './data/test/image',
    # Root directory for everything this notebook writes.
    'result_dir': './result',
    # Passed to make_inference_images as `scale_factor_type`.
    'scale_factor': '2x',
    # Optimisation hyper-parameters.
    'lr': 0.001,
    'batch_size': 16,
    'epoch': 100,
}

In [3]:
# Make sure the output root and its validation sub-folder exist before training.
os.makedirs(args['result_dir'], exist_ok=True)
os.makedirs(os.path.join(args['result_dir'], 'valid_results'), exist_ok=True)

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Computation device: ', device)

Computation device:  cuda


## 3. Make Patches from Image dataset

In [4]:
train_image_dir = args['train_image_path']
# Patch directories are named by replacing every 'image' substring of the
# training path (str.replace is not limited to the last path component).
hr_path = train_image_dir.replace('image', 'hr_patches')
lr_path = train_image_dir.replace('image', 'lr_patches')
create_patches(
    [train_image_dir],
    hr_path,
    lr_path,
    SHOW_PATCHES=False,
    STRIDE=14,
    SIZE=32,
)

Creating patches for 13 images


100%|██████████| 13/13 [00:06<00:00,  2.10it/s]


## 4. Make Test Dataset

In [5]:
# Prepare the test set at the configured scale factor; the helper hands back
# the input path and the label path that the validation loader uses.
test_input_path, test_label_path = make_inference_images(
    [args['test_image_path']],
    scale_factor_type=args['scale_factor'],
)

2
Scaling factor: 2x
Low resolution images save path: ./data/test/test_bicubic_rgb_2x
Original image dimensions: 640, 480
Original image dimensions: 640, 480


## 5. Construct Torch dataloader

In [6]:
# Training pairs are the LR/HR patches; validation pairs are the test images.
# Validation uses a batch size of 1.
train_loader, valid_loader = get_dataloaders(
    train_image_paths=lr_path,
    train_label_paths=hr_path,
    valid_image_path=test_input_path,
    valid_label_paths=test_label_path,
    TRAIN_BATCH_SIZE=args['batch_size'],
    TEST_BATCH_SIZE=1,
)

Training samples: 18876
Validation samples: 2


## 6. Build Model

In [7]:
# Instantiate the network, then move its parameters to the selected device.
model = SRCNN()
model = model.to(device)
print(model)

SRCNN(
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
)


## 7. Optimizer, Loss

In [8]:
# Mean-squared-error objective between the network output and the label.
criterion = nn.MSELoss()
# Adam over all model parameters, using the learning rate from `args`.
optimizer = optim.Adam(params=model.parameters(), lr=args['lr'])

## 8. Train

In [15]:
# Global switch read by validate(): when True, validation outputs are written
# to disk on every epoch whose number is divisible by 5.
SAVE_VALIDATION_RESULTS = True

def train(model, dataloader):
    """Run one training epoch and report its average loss and PSNR.

    Uses the notebook-level `device`, `optimizer`, `criterion` and `psnr`.

    Args:
        model: Network to optimise; it is put into training mode.
        dataloader: Loader yielding batches whose first element is the input
            tensor and whose second element is the label tensor. It must
            expose a sized `dataset` attribute.

    Returns:
        tuple: `(final_loss, final_psnr)` where `final_loss` is the mean loss
        per sample over the whole dataset and `final_psnr` is the mean of the
        per-batch PSNR values.
    """
    model.train()
    running_loss = 0.0
    running_psnr = 0.0
    for data in tqdm(dataloader, total=len(dataloader)):
        image_data = data[0].to(device)
        label = data[1].to(device)

        # Zero grad the optimizer.
        optimizer.zero_grad()
        outputs = model(image_data)
        loss = criterion(outputs, label)
        # Backpropagation.
        loss.backward()
        # Update the parameters.
        optimizer.step()
        # `criterion` is nn.MSELoss() with its default mean reduction, so
        # loss.item() is already a batch average. Weight it by the batch size
        # so that dividing by the dataset size below yields a true per-sample
        # mean (the last batch may be smaller than the others).
        running_loss += loss.item() * image_data.size(0)
        # PSNR is computed once per batch and averaged over batches below.
        batch_psnr = psnr(label, outputs)
        running_psnr += batch_psnr
    final_loss = running_loss/len(dataloader.dataset)
    final_psnr = running_psnr/len(dataloader)
    return final_loss, final_psnr

    
def validate(model, dataloader, epoch, path):
    """Evaluate the model on a validation loader without tracking gradients.

    Uses the notebook-level `device`, `criterion`, `psnr` and
    `SAVE_VALIDATION_RESULTS`.

    Args:
        model: Network to evaluate; it is put into evaluation mode.
        dataloader: Loader yielding batches whose first element is the input
            tensor and whose second element is the label tensor. It must
            expose a sized `dataset` attribute.
        epoch: Epoch number; outputs are saved when it is divisible by 5.
        path: Directory handed to `save_validation_results`.

    Returns:
        tuple: `(final_loss, final_psnr)` where `final_loss` is the mean loss
        per sample over the whole dataset and `final_psnr` is the mean of the
        per-batch PSNR values.
    """
    model.eval()
    running_loss = 0.0
    running_psnr = 0.0
    with torch.no_grad():
        for bi, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            image_data = data[0].to(device)
            label = data[1].to(device)

            outputs = model(image_data)
            loss = criterion(outputs, label)
            # `criterion` is nn.MSELoss() with its default mean reduction, so
            # loss.item() is a batch average. Weight it by the batch size so
            # the division by the dataset size below is a per-sample mean for
            # any batch size, not only for a batch size of 1.
            running_loss += loss.item() * image_data.size(0)
            # PSNR is computed once per batch and averaged over batches below.
            batch_psnr = psnr(label, outputs)
            running_psnr += batch_psnr
            # Save this batch's outputs every 5 epochs when saving is enabled.
            if SAVE_VALIDATION_RESULTS and (epoch % 5) == 0:
                save_validation_results(outputs, epoch, bi, path)

    final_loss = running_loss/len(dataloader.dataset)
    final_psnr = running_psnr/len(dataloader)
    return final_loss, final_psnr


In [16]:
# Per-epoch metric histories; they are passed to save_plot after every epoch.
train_loss, val_loss = [], []
train_psnr, val_psnr = [], []
start = time.time()
for epoch in range(args['epoch']):
    print(f"Epoch {epoch + 1} of {args['epoch']}")
    train_epoch_loss, train_epoch_psnr = train(model, train_loader)
    # validate() receives the 1-based epoch number.
    val_epoch_loss, val_epoch_psnr = validate(model, valid_loader, epoch+1, os.path.join(args['result_dir'], 'valid_results'))
    print(f"Train PSNR: {train_epoch_psnr:.3f}")
    print(f"Val PSNR: {val_epoch_psnr:.3f}")
    train_loss.append(train_epoch_loss)
    train_psnr.append(train_epoch_psnr)
    val_loss.append(val_epoch_loss)
    val_psnr.append(val_epoch_psnr)

    # Save the model with all information every 100 epochs, and also after
    # the final epoch so that a run whose length is not a multiple of 100
    # still leaves a checkpoint behind. Can be used for resuming training.
    if (epoch+1) % 100 == 0 or (epoch+1) == args['epoch']:
        save_model(epoch, model, optimizer, criterion, args['result_dir'])
    # Save the PSNR and loss plots every epoch.
    save_plot(train_loss, val_loss, train_psnr, val_psnr, args['result_dir'])
end = time.time()
print(f"Finished training in: {((end-start)/60):.3f} minutes") 

Epoch 1 of 100


100%|██████████| 1180/1180 [00:09<00:00, 127.26it/s]
100%|██████████| 2/2 [00:00<00:00, 52.59it/s]


Train PSNR: 35.166
Val PSNR: 35.062
Epoch 2 of 100


 12%|█▏        | 143/1180 [00:01<00:07, 130.20it/s]


KeyboardInterrupt: 

## 9. Test

In [28]:
def inference(model, dataloader, path, epoch=None):
    """Run the model over a loader and save each output image under `path`.

    Uses the notebook-level `device` and `test_input_path`. Each output is
    saved under the name of the entry at the same index in
    `os.listdir(test_input_path)`.

    NOTE(review): this naming assumes the loader yields one image per batch
    and iterates in the same order as `os.listdir(test_input_path)` - confirm
    against `get_dataloaders`.

    Args:
        model: Network to run; it is put into evaluation mode.
        dataloader: Loader yielding batches whose first element is the input
            tensor.
        path: Existing directory that receives the output images.
        epoch: Optional epoch label. When given, the output is additionally
            passed to `save_validation_results` with this label. It defaults
            to None so the function no longer depends on an `epoch` variable
            left behind by the training loop.
    """
    model.eval()
    # The directory listing does not change inside the loop, so read it once.
    input_names = os.listdir(test_input_path)
    with torch.no_grad():
        for bi, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            image_data = data[0].to(device)
            outputs = model(image_data)
            save_path = os.path.join(path, input_names[bi])
            save_image(outputs, save_path)
            if epoch is not None:
                save_validation_results(outputs, epoch, bi, path)

In [29]:
# Final super-resolved outputs go in their own sub-folder of the result dir.
final_dir = os.path.join(args['result_dir'], 'final')
os.makedirs(final_dir, exist_ok=True)
inference(model, valid_loader, final_dir)

100%|██████████| 2/2 [00:00<00:00,  5.10it/s]
