In [1]:
import warnings
warnings.filterwarnings('ignore')

import shutil
import math
import time
import os, glob
import torch
from torch.utils.data import Dataset
from PIL import Image
import json
from torchvision import transforms
import torchvision
import torchvision.models as model
import random
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# Process-wide runtime configuration.
# NOTE(review): both variables are assigned after `import torch` above; confirm
# that OMP_NUM_THREADS still takes effect when set this late.
os.environ['OMP_NUM_THREADS'] = '4'
# Restrict CUDA to the device with index 1; inside this process it is then
# presumably addressed as device 0 — verify on the target machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Flag recording whether PyTorch can see a CUDA device.
cuda = torch.cuda.is_available()

In [2]:
# from google.colab import drive
import torch
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import numpy as np
import os, os.path

# !cat /proc/meminfo
# Announce GPU visibility early; nothing is printed on a CPU-only session.
if torch.cuda.is_available():
    print("CUDA available")

# One shared seed for every random number generator the notebook uses
# (PyTorch, NumPy and the standard library), so runs are repeatable.
seed = 666
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

def log_and_print(logger, message):
    """Append ``message`` to a log stream and echo it on stdout.

    Args:
        logger: writable text stream exposing ``write`` and ``flush``.
        message: text to record; a newline is appended in the log only.
    """
    line = message + '\n'
    logger.write(line)
    # Flush right away so the log file is current even if the kernel dies.
    logger.flush()
    print(message)

CUDA available


In [3]:
class ImageFolder(Dataset):
    """Dataset yielding (grayscale, RGB) versions of every JPEG in one split.

    Files are discovered under ``<root>/seg_<split>/seg_<split>/``. The
    ``train`` and ``test`` splits are searched one sub-directory deep
    (``*/*.jpg``); any other split is searched flat (``*.jpg``).

    Args:
        root: directory that contains the ``seg_<split>`` folders.
        transform: optional callable applied separately to the grayscale and
            the RGB image.
        split: name of the split to load.
    """

    def __init__(self, root, transform=None, split='train'):
        if split in ('train', 'test'):
            pattern = 'seg_{}/seg_{}/*/*.jpg'.format(split, split)
        else:
            pattern = 'seg_{}/seg_{}/*.jpg'.format(split, split)
        # Sorted so that sample indices are stable between runs.
        self.samples = sorted(glob.glob(os.path.join(root, pattern)))
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(gray, rgb)`` for the image at ``index``."""
        # The context manager closes the underlying file as soon as both
        # conversions have produced their own in-memory copies.
        with Image.open(self.samples[index]) as img:
            gray = img.convert('L')
            rgb = img.convert('RGB')
        if self.transform:
            gray = self.transform(gray)
            rgb = self.transform(rgb)
        return gray, rgb

    def __len__(self):
        """Number of image files found for the split."""
        return len(self.samples)

# Dataset location and the preprocessing shared by all splits:
# resize to 128x128, then convert to a tensor.
root = './data'
transform = transforms.Compose([
    transforms.Resize([128, 128]),
    transforms.ToTensor(),
])

# The 'test' folder serves as the validation set and 'pred' as the test set.
train_dataset = ImageFolder(root, transform, 'train')
val_dataset = ImageFolder(root, transform, 'test')
test_dataset = ImageFolder(root, transform, 'pred')

# Training batches are shuffled; evaluation goes one image at a time in file order.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

# Number of batches per loader.
print(len(train_dataloader))
print(len(val_dataloader))
print(len(test_dataloader))

3509
3000
7301


In [4]:
import torch
import torch.nn as  nn
import torch.nn.functional as F

class Bottleneck(nn.Module):
    """Three-convolution residual unit: 1x1 -> 3x3 -> 1x1.

    The last convolution widens the feature map to
    ``out_channels * expansion`` channels. Only the 3x3 convolution is strided.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the two inner convolutions.
        i_downsample: optional module applied to the shortcut so that it
            matches the residual branch; ``None`` keeps the input as is.
        stride: stride of the 3x3 convolution.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()
        widened = out_channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, widened, kernel_size=1, stride=1, padding=0)
        self.batch_norm3 = nn.BatchNorm2d(widened)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        shortcut = x.clone()

        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.batch_norm2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.batch_norm3(out)

        # Project the shortcut when a projection module was supplied.
        if self.i_downsample is not None:
            shortcut = self.i_downsample(shortcut)

        out += shortcut
        return self.relu(out)

class Block(nn.Module):
    """Two-convolution residual unit (3x3 -> 3x3) with an additive shortcut.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        i_downsample: optional module applied to the shortcut so that it
            matches the residual branch; ``None`` keeps the input as is.
        stride: stride of the first convolution.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        # Only the first convolution is strided. Striding both would shrink
        # the residual branch twice while the shortcut is strided once, so
        # the two could not be added whenever stride != 1.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        identity = x.clone()

        # Each convolution is normalized by its own batch-norm layer.
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.batch_norm2(self.conv2(x))

        if self.i_downsample is not None:
            identity = self.i_downsample(identity)
        x += identity
        x = self.relu(x)
        return x

class ResNet(nn.Module):
    """Residual image-to-image network.

    A 7x7 stem is followed by four residual stages and a two-convolution
    head. Every convolution built here uses stride 1 with matching padding,
    so the output has the same height and width as the input.

    Args:
        ResBlock: residual block class exposing an ``expansion`` class
            attribute and the constructor signature
            ``(in_channels, out_channels, i_downsample=None, stride=1)``.
        layer_list: number of blocks in each of the four stages.
        in_ch: channels of the input image.
        out_ch: channels of the predicted image.
    """

    def __init__(self, ResBlock, layer_list, in_ch=1, out_ch=3):
        super().__init__()
        # Running channel count of the feature map; updated by _make_layer.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=7, stride=1, padding=3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()

        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=128)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=64)

        # The head consumes whatever the last stage emits
        # (64 * ResBlock.expansion channels, i.e. 256 for expansion 4)
        # instead of assuming one particular block type.
        self.conv2 = nn.Conv2d(self.in_channels, 64, kernel_size=3, stride=1, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Map an ``in_ch``-channel batch to an ``out_ch``-channel batch."""
        x = self.relu(self.batch_norm1(self.conv1(x)))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.relu(self.batch_norm2(self.conv2(x)))
        x = self.conv3(x)

        return x

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        The first block receives a 1x1 projection shortcut whenever the
        stage changes the channel count or is strided; the remaining blocks
        keep the identity shortcut. ``self.in_channels`` is advanced to the
        stage's output width.
        """
        ii_downsample = None
        out_width = planes * ResBlock.expansion

        if stride != 1 or self.in_channels != out_width:
            ii_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_width, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_width)
            )

        layers = [ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride)]
        self.in_channels = out_width

        for _ in range(blocks - 1):
            layers.append(ResBlock(self.in_channels, planes))

        return nn.Sequential(*layers)
        
def ColorResNet(in_ch, out_ch):
    """Build the colorization network: a Bottleneck ResNet with stages of 3, 4, 6 and 3 blocks.

    Args:
        in_ch: channels of the input image.
        out_ch: channels of the predicted image.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(Bottleneck, blocks_per_stage, in_ch=in_ch, out_ch=out_ch)

In [5]:
# Grayscale (1 channel) in, RGB (3 channels) out; parameters are moved to the GPU.
model = ColorResNet(in_ch=1, out_ch=3).cuda()

In [6]:
# Training objective: mean squared error between predicted and true colour images.
criterion = nn.MSELoss()

# Adam over all model parameters, starting at the base learning rate.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Halve the learning rate after 10 scheduler steps without a lower monitored
# value, never going below 1e-6.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-6,
)

In [7]:
# define metrics
import cv2
import numpy as np
import math

def calculate_psnr(img1, img2, data_range):
    """Peak signal-to-noise ratio, in dB, between two arrays of equal shape.

    Args:
        img1, img2: NumPy arrays holding pixel values on the scale given by
            ``data_range`` (for example 0..255 with ``data_range=255``).
        data_range: largest possible pixel value.

    Returns:
        ``20 * log10(data_range / sqrt(mse))``, or ``inf`` when the inputs
        are identical.
    """
    first = img1.astype(np.float64)
    second = img2.astype(np.float64)
    squared_error = (first - second)**2
    mse = np.mean(squared_error)
    if mse != 0:
        return 20 * math.log10(data_range / math.sqrt(mse))
    # Identical inputs: the ratio is unbounded.
    return float('inf')

def ssim(img1, img2, data_range):
    """Mean structural similarity between two arrays of equal shape.

    Local statistics are taken with an 11-tap Gaussian window (sigma 1.5);
    a 5-pixel border is cropped from every filtered map so that only
    positions fully covered by the window contribute.

    Args:
        img1, img2: NumPy image arrays.
        data_range: largest possible pixel value; it scales the two
            stabilizing constants.

    Returns:
        The mean of the SSIM map.
    """
    C1 = (0.01 * data_range)**2
    C2 = (0.03 * data_range)**2

    first = img1.astype(np.float64)
    second = img2.astype(np.float64)
    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def _windowed(image):
        # Gaussian-weighted local average, keeping the "valid" region only.
        return cv2.filter2D(image, -1, window)[5:-5, 5:-5]

    mu1 = _windowed(first)
    mu2 = _windowed(second)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance via E[xy] - E[x]E[y].
    sigma1_sq = _windowed(first**2) - mu1_sq
    sigma2_sq = _windowed(second**2) - mu2_sq
    sigma12 = _windowed(first * second) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator
    return ssim_map.mean()

def calculate_ssim(img1, img2, data_range):
    '''Calculate SSIM for a grayscale or colour image pair.

    Accepts 2-D arrays (H, W) and 3-D arrays (H, W, C) with C equal to 1 or 3.
    For three channels the SSIM of each channel is computed separately and
    the three values are averaged.

    Args:
        img1, img2: NumPy arrays of identical shape, with pixel values on
            the scale given by data_range (e.g. 0..255 with data_range=255).
        data_range: largest possible pixel value, forwarded to ssim.

    Returns:
        The SSIM score.

    Raises:
        ValueError: if the shapes differ, or the input is neither 2-D nor
            3-D with 1 or 3 channels.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2, data_range)
    if img1.ndim == 3:
        channels = img1.shape[2]
        if channels == 3:
            # One score per colour channel, then the average.
            ssims = [ssim(img1[..., c], img2[..., c], data_range) for c in range(3)]
            return np.array(ssims).mean()
        if channels == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2), data_range)
    # Reached for any other rank and for unsupported channel counts, which
    # must not fall through and return None.
    raise ValueError('Wrong input image dimensions.')

In [8]:
# TensorBoard: scalar curves are written under <tb_logger_dir><exp_name>.
tb_logger_dir = './tb_logger/'
exp_name = 'ColorResNet'
tb_logger = SummaryWriter(log_dir=tb_logger_dir + exp_name)

# Checkpoints, the text log and sample images all live in one experiment folder.
# exist_ok=True makes the cell safe to re-run without a separate existence check.
experiment_dir = os.path.join('./experiments', exp_name)
os.makedirs(experiment_dir, exist_ok=True)

# Text log. Mode 'w' truncates any log left by a previous run; the handle is
# kept open because later cells write to it through log_and_print.
logger = open(os.path.join(experiment_dir, 'train.log'), 'w')

# Sample images written during validation and testing. Images from earlier
# runs are not removed here.
valid_dir = os.path.join(experiment_dir, 'valid')
os.makedirs(valid_dir, exist_ok=True)

test_dir = os.path.join(experiment_dir, 'test')
os.makedirs(test_dir, exist_ok=True)

# print the model
print(model)

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
  (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (batch_norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (batch_norm2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      (batch_norm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (i_downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU()
    )
  

In [9]:
n_epochs = 20
current_step = 0
valid_loss_min = math.inf


def _save_checkpoint(path, epoch, step, best_loss):
    """Write model, optimizer and scheduler state plus bookkeeping to ``path``."""
    torch.save(
        {
            "epoch": epoch,
            "iter": step,
            "state_dict": model.state_dict(),
            "loss": best_loss,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler": lr_scheduler.state_dict(),
        },
        path,
    )


start = time.time()
for epoch in range(n_epochs):
    # ---- training pass ----
    log_and_print(logger, 'Training epoch: {}, lr: {:e}'.format(epoch, optimizer.param_groups[0]['lr']))
    train_loss = 0.0
    model.train()
    for i, data in enumerate(train_dataloader):
        gray, rgb = data
        gray, rgb = gray.cuda(), rgb.cuda()
        optimizer.zero_grad()
        color = model(gray)
        loss = criterion(color, rgb)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for TensorBoard, the text log
        # and the running sum.
        step_loss = loss.item()
        tb_logger.add_scalar('loss_train', step_loss, current_step)
        current_step += 1
        if current_step % 100 == 0:
            elapsed_time = (time.time() - start) / 3600
            log_and_print(logger, 'time = {:.2f}h | epoch = {:d} | current step = {:d} | train loss = {:.4f}'.format(elapsed_time, epoch, current_step, step_loss))
        train_loss += step_loss

    train_loss /= len(train_dataloader)
    log_and_print(logger, 'Average train loss: {:.6f}\n'.format(train_loss))

    # ---- validation pass ----
    # Sample images of each epoch go to their own sub-folder.
    save_dir = os.path.join(valid_dir, '{:03d}'.format(epoch))
    os.makedirs(save_dir, exist_ok=True)
    log_and_print(logger, 'Validating epoch: {}'.format(epoch))
    valid_loss = 0.0
    valid_psnr_list = []
    valid_ssim_list = []
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_dataloader):
            gray, rgb = data
            gray, rgb = gray.cuda(), rgb.cuda()
            color = model(gray)
            loss = criterion(color, rgb)
            batch_loss = loss.item()

            # Metrics are computed on 8-bit images: clamp to [0, 1], convert
            # to PIL, then to uint8 arrays (hence data_range 255).
            gray_PIL = transforms.ToPILImage()(gray[0].clamp(0, 1))
            color_PIL = transforms.ToPILImage()(color[0].clamp(0, 1))
            rgb_PIL = transforms.ToPILImage()(rgb[0].clamp(0, 1))
            color = np.array(color_PIL)
            rgb = np.array(rgb_PIL)
            p = calculate_psnr(color, rgb, 255.0)
            s = calculate_ssim(color, rgb, 255.0)
            valid_psnr_list.append(p)
            valid_ssim_list.append(s)

            # Save and log every 500th sample only.
            if i % 500 == 0:
                gray_PIL.save(os.path.join(save_dir, '{:04d}_input.jpg'.format(i)))
                color_PIL.save(os.path.join(save_dir, '{:04d}_output_{:.3f}dB_{:.4f}.jpg'.format(i, p, s)))
                rgb_PIL.save(os.path.join(save_dir, '{:04d}_gt.jpg'.format(i)))
                elapsed_time = (time.time() - start) / 3600
                log_and_print(logger, 'time = {:.2f}h | epoch = {:d} | i = {:d} | valid loss = {:.4f} | valid psnr = {:.3f}dB | valid ssim = {:.4f}'.format(elapsed_time, epoch, i, batch_loss, p, s))
            valid_loss += batch_loss

    valid_loss /= len(val_dataloader)
    valid_psnr = sum(valid_psnr_list) / len(valid_psnr_list)
    valid_ssim = sum(valid_ssim_list) / len(valid_ssim_list)
    # The scheduler monitors the epoch-average validation loss.
    lr_scheduler.step(valid_loss)
    tb_logger.add_scalar('loss_valid', valid_loss, epoch)
    tb_logger.add_scalar('psnr_valid', valid_psnr, epoch)
    tb_logger.add_scalar('ssim_valid', valid_ssim, epoch)
    log_and_print(logger, 'Average valid loss: {:.6f} | PSNR: {:.3f}dB | SSIM: {:.4f}'.format(valid_loss, valid_psnr, valid_ssim))

    # ---- checkpoints ----
    # Overwrite the "best" checkpoint whenever the validation loss improves.
    if valid_loss < valid_loss_min:
        valid_loss_min = valid_loss
        save_path = os.path.join(experiment_dir, 'checkpoint_best_loss.pth.tar')
        _save_checkpoint(save_path, epoch, current_step, valid_loss_min)
        log_and_print(logger, 'best checkpoint saved!')

    # Periodic snapshot every fifth epoch; its "loss" field holds the best
    # validation loss so far, not this epoch's loss.
    if (epoch + 1) % 5 == 0:
        save_path = os.path.join(experiment_dir, 'checkpoint_epoch_{:03d}.pth.tar'.format(epoch))
        _save_checkpoint(save_path, epoch, current_step, valid_loss_min)
        log_and_print(logger, 'checkpoint saved at epoch {:03d}'.format(epoch))

    log_and_print(logger, '')

Training epoch: 0, lr: 1.000000e-04
time = 0.01h | epoch = 0 | current step = 100 | train loss = 0.0284
time = 0.02h | epoch = 0 | current step = 200 | train loss = 0.0167
time = 0.03h | epoch = 0 | current step = 300 | train loss = 0.0106
time = 0.04h | epoch = 0 | current step = 400 | train loss = 0.0271
time = 0.05h | epoch = 0 | current step = 500 | train loss = 0.0083
time = 0.06h | epoch = 0 | current step = 600 | train loss = 0.0131
time = 0.07h | epoch = 0 | current step = 700 | train loss = 0.0080
time = 0.08h | epoch = 0 | current step = 800 | train loss = 0.0076
time = 0.09h | epoch = 0 | current step = 900 | train loss = 0.0082
time = 0.10h | epoch = 0 | current step = 1000 | train loss = 0.0072
time = 0.11h | epoch = 0 | current step = 1100 | train loss = 0.0064
time = 0.12h | epoch = 0 | current step = 1200 | train loss = 0.0055
time = 0.14h | epoch = 0 | current step = 1300 | train loss = 0.0048
time = 0.15h | epoch = 0 | current step = 1400 | train loss = 0.0054
time = 

time = 1.13h | epoch = 2 | current step = 9700 | train loss = 0.0047
time = 1.14h | epoch = 2 | current step = 9800 | train loss = 0.0052
time = 1.15h | epoch = 2 | current step = 9900 | train loss = 0.0048
time = 1.16h | epoch = 2 | current step = 10000 | train loss = 0.0096
time = 1.17h | epoch = 2 | current step = 10100 | train loss = 0.0075
time = 1.18h | epoch = 2 | current step = 10200 | train loss = 0.0027
time = 1.20h | epoch = 2 | current step = 10300 | train loss = 0.0054
time = 1.21h | epoch = 2 | current step = 10400 | train loss = 0.0043
time = 1.22h | epoch = 2 | current step = 10500 | train loss = 0.0039
Average train loss: 0.006103

Validating epoch: 2
time = 1.22h | epoch = 2 | i = 0 | valid loss = 0.0050 | valid psnr = 22.943dB | valid ssim = 0.9450
time = 1.23h | epoch = 2 | i = 500 | valid loss = 0.0092 | valid psnr = 20.399dB | valid ssim = 0.8185
time = 1.23h | epoch = 2 | i = 1000 | valid loss = 0.0033 | valid psnr = 24.795dB | valid ssim = 0.9334
time = 1.24h | 

time = 2.17h | epoch = 5 | current step = 18100 | train loss = 0.0024
time = 2.18h | epoch = 5 | current step = 18200 | train loss = 0.0051
time = 2.19h | epoch = 5 | current step = 18300 | train loss = 0.0044
time = 2.20h | epoch = 5 | current step = 18400 | train loss = 0.0059
time = 2.21h | epoch = 5 | current step = 18500 | train loss = 0.0042
time = 2.22h | epoch = 5 | current step = 18600 | train loss = 0.0097
time = 2.24h | epoch = 5 | current step = 18700 | train loss = 0.0053
time = 2.25h | epoch = 5 | current step = 18800 | train loss = 0.0035
time = 2.26h | epoch = 5 | current step = 18900 | train loss = 0.0048
time = 2.27h | epoch = 5 | current step = 19000 | train loss = 0.0025
time = 2.28h | epoch = 5 | current step = 19100 | train loss = 0.0035
time = 2.29h | epoch = 5 | current step = 19200 | train loss = 0.0045
time = 2.30h | epoch = 5 | current step = 19300 | train loss = 0.0021
time = 2.31h | epoch = 5 | current step = 19400 | train loss = 0.0069
time = 2.32h | epoch

time = 3.29h | epoch = 7 | current step = 27600 | train loss = 0.0066
time = 3.30h | epoch = 7 | current step = 27700 | train loss = 0.0039
time = 3.31h | epoch = 7 | current step = 27800 | train loss = 0.0021
time = 3.32h | epoch = 7 | current step = 27900 | train loss = 0.0056
time = 3.33h | epoch = 7 | current step = 28000 | train loss = 0.0043
Average train loss: 0.005107

Validating epoch: 7
time = 3.34h | epoch = 7 | i = 0 | valid loss = 0.0045 | valid psnr = 23.430dB | valid ssim = 0.9543
time = 3.35h | epoch = 7 | i = 500 | valid loss = 0.0103 | valid psnr = 19.922dB | valid ssim = 0.8111
time = 3.36h | epoch = 7 | i = 1000 | valid loss = 0.0044 | valid psnr = 23.384dB | valid ssim = 0.9266
time = 3.36h | epoch = 7 | i = 1500 | valid loss = 0.0030 | valid psnr = 25.640dB | valid ssim = 0.9186
time = 3.37h | epoch = 7 | i = 2000 | valid loss = 0.0050 | valid psnr = 22.912dB | valid ssim = 0.9419
time = 3.38h | epoch = 7 | i = 2500 | valid loss = 0.0020 | valid psnr = 26.832dB | 

time = 4.44h | epoch = 10 | current step = 36000 | train loss = 0.0064
time = 4.45h | epoch = 10 | current step = 36100 | train loss = 0.0033
time = 4.46h | epoch = 10 | current step = 36200 | train loss = 0.0043
time = 4.47h | epoch = 10 | current step = 36300 | train loss = 0.0036
time = 4.48h | epoch = 10 | current step = 36400 | train loss = 0.0020
time = 4.49h | epoch = 10 | current step = 36500 | train loss = 0.0038
time = 4.50h | epoch = 10 | current step = 36600 | train loss = 0.0051
time = 4.51h | epoch = 10 | current step = 36700 | train loss = 0.0025
time = 4.52h | epoch = 10 | current step = 36800 | train loss = 0.0034
time = 4.54h | epoch = 10 | current step = 36900 | train loss = 0.0101
time = 4.55h | epoch = 10 | current step = 37000 | train loss = 0.0056
time = 4.56h | epoch = 10 | current step = 37100 | train loss = 0.0027
time = 4.57h | epoch = 10 | current step = 37200 | train loss = 0.0027
time = 4.58h | epoch = 10 | current step = 37300 | train loss = 0.0056
time =

time = 5.54h | epoch = 12 | current step = 45400 | train loss = 0.0113
time = 5.56h | epoch = 12 | current step = 45500 | train loss = 0.0032
time = 5.57h | epoch = 12 | current step = 45600 | train loss = 0.0033
Average train loss: 0.004805

Validating epoch: 12
time = 5.57h | epoch = 12 | i = 0 | valid loss = 0.0024 | valid psnr = 26.111dB | valid ssim = 0.9635
time = 5.58h | epoch = 12 | i = 500 | valid loss = 0.0086 | valid psnr = 20.710dB | valid ssim = 0.8279
time = 5.58h | epoch = 12 | i = 1000 | valid loss = 0.0035 | valid psnr = 24.435dB | valid ssim = 0.9302
time = 5.59h | epoch = 12 | i = 1500 | valid loss = 0.0018 | valid psnr = 27.370dB | valid ssim = 0.9330
time = 5.60h | epoch = 12 | i = 2000 | valid loss = 0.0021 | valid psnr = 26.769dB | valid ssim = 0.9563
time = 5.61h | epoch = 12 | i = 2500 | valid loss = 0.0017 | valid psnr = 27.660dB | valid ssim = 0.9730
Average valid loss: 0.004756 | PSNR: 25.061dB | SSIM: 0.9328

Training epoch: 13, lr: 1.000000e-04
time = 5.62

time = 6.61h | epoch = 15 | current step = 53600 | train loss = 0.0030
time = 6.62h | epoch = 15 | current step = 53700 | train loss = 0.0047
time = 6.63h | epoch = 15 | current step = 53800 | train loss = 0.0081
time = 6.64h | epoch = 15 | current step = 53900 | train loss = 0.0089
time = 6.65h | epoch = 15 | current step = 54000 | train loss = 0.0026
time = 6.67h | epoch = 15 | current step = 54100 | train loss = 0.0053
time = 6.68h | epoch = 15 | current step = 54200 | train loss = 0.0022
time = 6.69h | epoch = 15 | current step = 54300 | train loss = 0.0031
time = 6.70h | epoch = 15 | current step = 54400 | train loss = 0.0046
time = 6.71h | epoch = 15 | current step = 54500 | train loss = 0.0048
time = 6.72h | epoch = 15 | current step = 54600 | train loss = 0.0117
time = 6.73h | epoch = 15 | current step = 54700 | train loss = 0.0061
time = 6.74h | epoch = 15 | current step = 54800 | train loss = 0.0068
time = 6.75h | epoch = 15 | current step = 54900 | train loss = 0.0016
time =

time = 7.80h | epoch = 17 | current step = 63000 | train loss = 0.0048
time = 7.81h | epoch = 17 | current step = 63100 | train loss = 0.0021
Average train loss: 0.004632

Validating epoch: 17
time = 7.82h | epoch = 17 | i = 0 | valid loss = 0.0032 | valid psnr = 24.849dB | valid ssim = 0.9621
time = 7.83h | epoch = 17 | i = 500 | valid loss = 0.0094 | valid psnr = 20.336dB | valid ssim = 0.8247
time = 7.84h | epoch = 17 | i = 1000 | valid loss = 0.0038 | valid psnr = 24.046dB | valid ssim = 0.9295
time = 7.84h | epoch = 17 | i = 1500 | valid loss = 0.0023 | valid psnr = 26.638dB | valid ssim = 0.9308
time = 7.85h | epoch = 17 | i = 2000 | valid loss = 0.0026 | valid psnr = 25.791dB | valid ssim = 0.9511
time = 7.86h | epoch = 17 | i = 2500 | valid loss = 0.0014 | valid psnr = 28.532dB | valid ssim = 0.9755
Average valid loss: 0.004676 | PSNR: 25.056dB | SSIM: 0.9317

Training epoch: 18, lr: 1.000000e-04
time = 7.87h | epoch = 18 | current step = 63200 | train loss = 0.0021
time = 7.88

In [10]:
test_loss = 0.0
test_psnr_list = []
test_ssim_list = []

# Rebuild the network and restore the weights with the lowest validation loss.
# The path is derived from experiment_dir so that it always points at the
# checkpoint written by the training cell, even if exp_name is changed.
model = ColorResNet(in_ch=1, out_ch=3).cuda()
checkpoint = torch.load(os.path.join(experiment_dir, 'checkpoint_best_loss.pth.tar'))
model.load_state_dict(checkpoint['state_dict'])
model.eval()
with torch.no_grad():
    for i, data in enumerate(test_dataloader):
        gray, rgb = data
        gray, rgb = gray.cuda(), rgb.cuda()
        color = model(gray)
        loss = criterion(color, rgb)

        # Metrics are computed on 8-bit images: clamp to [0, 1], convert to
        # PIL, then to uint8 arrays (hence data_range 255).
        gray_PIL = transforms.ToPILImage()(gray[0].clamp(0, 1))
        color_PIL = transforms.ToPILImage()(color[0].clamp(0, 1))
        rgb_PIL = transforms.ToPILImage()(rgb[0].clamp(0, 1))
        color = np.array(color_PIL)
        rgb = np.array(rgb_PIL)
        p = calculate_psnr(color, rgb, 255.0)
        s = calculate_ssim(color, rgb, 255.0)
        test_psnr_list.append(p)
        test_ssim_list.append(s)

        # Unlike validation, every test sample is written to disk.
        gray_PIL.save(os.path.join(test_dir, '{:04d}_input.jpg'.format(i)))
        color_PIL.save(os.path.join(test_dir, '{:04d}_output_{:.3f}dB_{:.4f}.jpg'.format(i, p, s)))
        rgb_PIL.save(os.path.join(test_dir, '{:04d}_gt.jpg'.format(i)))
        test_loss += loss.item()

# Averages over the whole test loader.
test_loss /= len(test_dataloader)
test_psnr = sum(test_psnr_list) / len(test_psnr_list)
test_ssim = sum(test_ssim_list) / len(test_ssim_list)
log_and_print(logger, 'Average test loss: {:.6f} | PSNR: {:.3f}dB | SSIM: {:.4f}'.format(test_loss, test_psnr, test_ssim))

Average test loss: 0.004609 | PSNR: 25.449dB | SSIM: 0.9316
