In [1]:
import torch
import torch.nn as nn
from models.awnet import AWNet
from utils.dataset import LoadData
from torch.utils.data import DataLoader
from config import trainConfig
import numpy as np
import imageio
import PIL.Image as Image
import time
import os
from utils.utils import validation

In [2]:
class Model(nn.Module):
    """Wrapper module that holds an AWNet instance under the attribute ``module``.

    PyTorch derives ``state_dict`` keys from attribute names, so every AWNet
    parameter is exposed as ``module.<name>``. The checkpoint restored later via
    ``load_state_dict`` must therefore use that prefix, and the attribute must
    not be renamed. NOTE(review): this presumably mirrors the ``.module``
    naming of ``nn.DataParallel`` so checkpoints saved from a DataParallel-wrapped
    AWNet load directly — confirm against the training script.
    """

    def __init__(self):
        super(Model, self).__init__()
        # NOTE(review): presumably 4 input channels (packed RAW) and 3 output
        # channels (RGB); confirm against AWNet's signature.
        self.module = AWNet(4, 3, block=[3, 3, 3, 4, 4])

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output unchanged."""
        wrapped = self.module
        return wrapped(x)

In [3]:
config = trainConfig(
        lr=[1e-4, 5e-5, 1e-5, 5e-6, 1e-6],
        batch_size=2,
        epoch=50,
        print_loss=False,
        pretrain=True,
        data_dir='./data',
        weights_dir='./weights'
    )
# Only `config.weights_dir` is read in this notebook; the other fields are not
# used by this evaluation.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("CUDA visible devices: " + str(torch.cuda.device_count()))
# torch.cuda.get_device_name() only accepts CUDA devices, so guard it;
# otherwise the CPU fallback chosen above would raise here.
if device.type == "cuda":
    print("CUDA Device Name: " + str(torch.cuda.get_device_name(device)))
else:
    print("CUDA not available; running on CPU.")

net = Model().to(device)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
checkpoint_path = os.path.join(config.weights_dir, 'weight_best.pkl')
checkpoint = torch.load(checkpoint_path, map_location=device)
# The checkpoint is a mapping; the network weights are stored under "model_state".
net.load_state_dict(checkpoint["model_state"])

print('weight loaded.')

CUDA visible devices: 1
CUDA Device Name: NVIDIA GeForce RTX 2080 Ti
weight loaded.


In [4]:
test_input_dir = "data/test/huawei_raw"
test_target_dir = "data/test/canon"
img_size = (224, 224)
dlsr_scale = 2
input_channels = 4
output_channels = 3
batch_size = 16  # same value the test DataLoader is constructed with


def _sorted_files(directory, suffix):
    """Return paths of the files in `directory` whose names end with `suffix`, sorted by path."""
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(suffix)
    )


def _stem(path):
    """Return the file name without directory or extension, e.g. 'data/x/12.png' -> '12'."""
    return os.path.splitext(os.path.basename(path))[0]


test_input_img_paths = _sorted_files(test_input_dir, ".png")
test_target_img_paths = _sorted_files(test_target_dir, ".jpg")

# The two lists are built independently. The preview below pairs them by
# position, and LoadData presumably does the same (TODO confirm). Both
# directories must therefore hold exactly the same file stems. Fail loudly
# here instead of silently mis-pairing, or truncating via zip, when they don't.
input_stems = [_stem(p) for p in test_input_img_paths]
target_stems = [_stem(p) for p in test_target_img_paths]
assert input_stems == target_stems, (
    "input/target file names do not pair up one-to-one: "
    f"{len(input_stems)} inputs vs {len(target_stems)} targets"
)

print("Number test of samples:", len(test_input_img_paths))

for input_path, target_path in zip(test_input_img_paths[:10], test_target_img_paths[:10]):
    print(input_path, "|", target_path)

Number test of samples: 1204
data/test/huawei_raw/0.png | data/test/canon/0.jpg
data/test/huawei_raw/1.png | data/test/canon/1.jpg
data/test/huawei_raw/10.png | data/test/canon/10.jpg
data/test/huawei_raw/100.png | data/test/canon/100.jpg
data/test/huawei_raw/1000.png | data/test/canon/1000.jpg
data/test/huawei_raw/1001.png | data/test/canon/1001.jpg
data/test/huawei_raw/1002.png | data/test/canon/1002.jpg
data/test/huawei_raw/1003.png | data/test/canon/1003.jpg
data/test/huawei_raw/1004.png | data/test/canon/1004.jpg
data/test/huawei_raw/1005.png | data/test/canon/1005.jpg


In [15]:
test_gen = LoadData(
    img_size, dlsr_scale, test_input_img_paths, test_target_img_paths
)
# Cap the worker count at the machine's CPU count (DataLoader warns when asked
# for more workers than it suggests). Only pin host memory when a CUDA device
# exists to copy into, consistent with the CPU fallback chosen for `device`.
test_loader = DataLoader(
    dataset=test_gen,
    batch_size=16,
    shuffle=False,  # deterministic evaluation order
    num_workers=min(16, os.cpu_count() or 1),
    pin_memory=torch.cuda.is_available(),
    drop_last=False,  # evaluate every image, including a final partial batch
)

In [16]:
# Evaluate in inference mode. eval() switches layers such as dropout and
# batch-norm to their inference behaviour, and no_grad() skips autograd
# bookkeeping, which saves memory.
# NOTE(review): utils.validation is not visible here; both guards are harmless
# if it already does the same internally. save_tag=True presumably writes the
# predicted images to disk; confirm in utils.validation.
net.eval()
with torch.no_grad():
    val_psnr, val_ssim = validation(
        net,
        test_loader,
        device,
        save_tag=True
    )
print('PSNR: {:.4f}, SSIM: {:.4f}'.format(
        val_psnr, val_ssim))

PSNR: 21.4198, SSIM: 0.7484
