* The following code is modified based on the original torchcv project. 
* We use PASCAL VOC12 as the dataset.
* You can download the VOC2012
  train/validation data from: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/#devkit
* test data: https://pjreddie.com/projects/pascal-voc-dataset-mirror/


In [1]:
import os
import cv2
import math
import random
import pickle
import argparse
import matplotlib.pyplot as plt
from PIL import Image
from time import time
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as T

from __future__ import print_function
from dbpn import Net as DBPN
from dbpn import get_pair_set
from ssd import SSD
from ssd import build_ssd
from ssd.layers.modules import MultiBoxLoss
from ssd.data.config import voc
from ssd.data import detection_collate
from ssd.data import VOCAnnotationTransform, VOCDetection, BaseTransform, SRDetection, DBPNLoader
from ssd.eval import test_net
from metric import *

# Reset sys.argv to a single empty entry. Presumably this keeps the
# parser.parse_args() call in the settings cell from seeing the notebook
# kernel's own command-line arguments, so every option takes its default.
import sys; sys.argv=['']; del sys

# clean up device
torch.cuda.empty_cache()


In [2]:
# Global arguments & settings

def _str2bool(value):
    """Convert a command-line value such as 'False' or '0' into a real bool.

    argparse's ``type=bool`` cannot be used for this: ``bool('False')`` is
    ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='PyTorch Super Resolution Detection Networks')
parser.add_argument('--upscale_factor', type=int, default=4, help="super resolution upscale factor")
parser.add_argument('--testBatchSize', type=int, default=1, help='testing batch size')
parser.add_argument('--batch_size', type=int, default=2, help='training batch size') # GPU: 9GB!
parser.add_argument('--threads', type=int, default=2, help='number of threads for data loading')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--gpu_mode', type=_str2bool, default=True, help='run on GPU (true/false)')
# int, not float: the value feeds range() below, which rejects floats.
parser.add_argument('--gpus', default=1, type=int, help='number of gpu')
#parser.add_argument('--test_dataset', type=str, default='VOC12-LR-X8-test')
#parser.add_argument('--sr_dataset', type=str, default='VOC12-SR-X8')
parser.add_argument('--train_dataset', type=str, default='VOC12-LR-x4')
parser.add_argument('--hr_dataset', type=str, default='VOC12-HR')
parser.add_argument('--anno_path', type=str, default='Annotations')
parser.add_argument('--imSetpath', type=str, default='ImageSets')
parser.add_argument('--input_dir', type=str, default='./dataset')
#parser.add_argument('--output', default='./dataset/results', help='Location to save some outputs')
parser.add_argument('--checkpt', default='./checkpoint', help='Location to save checkpoint models')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning Rate. Default=0.0001')
parser.add_argument('--nEpochs', type=int, default=10, help='number of epochs to fine tune net S over target loss')

opt = parser.parse_args()

# Device ids 0 .. gpus-1; a concrete list so it can be indexed and passed as device_ids.
gpus_list = list(range(opt.gpus))
print(opt)

cuda = opt.gpu_mode
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run with --gpu_mode False")

# Seed every RNG this notebook imports, not only torch's.
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)
    torch.cuda.empty_cache()

Epochs = opt.nEpochs
tbz = opt.batch_size
num_classes = 21           # 20 (VOC0712) +1 for background



Namespace(anno_path='Annotations', batch_size=2, checkpt='./checkpoint', gpu_mode=True, gpus=1, hr_dataset='VOC12-HR', imSetpath='ImageSets', input_dir='./dataset', lr=0.0001, nEpochs=10, seed=123, testBatchSize=1, threads=2, train_dataset='VOC12-LR-x4', upscale_factor=4)


In [3]:
# Task validation setup - VOC07 detection validation is used for the SROD network.

YEAR = '2007'
HOME = os.path.expanduser("~")
voc_root = os.path.join(HOME, "data/VOCdevkit/")
# voc_root ends with a separator, so plain concatenation yields the devkit dir.
devkit_path = voc_root + 'VOC' + YEAR

# Path templates inside the devkit; the placeholder is filled with an image id / set name.
annopath = os.path.join(devkit_path, 'Annotations', '%s.xml')
imgpath = os.path.join(devkit_path, 'JPEGImages', '%s.jpg')
imgsetpath = os.path.join(devkit_path, 'ImageSets', 'Main', '{:s}.txt')

# VOC0712 image channel MEANS
dataset_mean = (104, 117, 123)

# Randomly picked end-task validation set for the detection network.
# N.B. it is not used for the SROD (DBPN+SSD) training target loss.
# The available choices for set_type are described in the string below.
set_type = '07val32-5702'
"""
We have 5 validation sets to choose, size of 16, 32, 64, 128, 256 repectively,
Due to the time & computation constraint, we have these 5 downsampled val set.
They are randomly(no replacement) picked from  VOC2007/ImageSets/Main/val.txt.

Baseline mAP on Pre-trained ssd300 is measured as well:       (mAP)
These Detection baseline are measured on frozen SSD300 network using
'weights/ssd300_mAP_77.43_v2.pth', for 75x75 LR images been SR-ed by
VOC12 retrained DBPN baseline w. weight 'VOC12-LR-x4-DBPN-ep100.pth'.

~/data/VOCdevkit/VOC2007/ImageSets/Main/07val16-5648.txt      56.48%
~/data/VOCdevkit/VOC2007/ImageSets/Main/07val32-5702.txt      57.02%
~/data/VOCdevkit/VOC2007/ImageSets/Main/07val64-5826.txt      58.26%
~/data/VOCdevkit/VOC2007/ImageSets/Main/07val128-6019.txt     60.19%
~/data/VOCdevkit/VOC2007/ImageSets/Main/07val256-6030.txt     60.30%

These can be used to evaluate training in progress quantatively.
Again we do not run through whole val set for periodical check out, too slow!
"""

# Utility functions
def save_img(img, img_path):
    """Write an image tensor to ``img_path`` with OpenCV.

    The tensor is squeezed, clamped to [0, 1], moved to channel-last order,
    scaled by 255 and passed through a BGR<->RGB channel swap before writing.

    Args:
        img: tensor supporting ``squeeze/clamp/numpy``; assumes a CPU tensor
            that is (3, H, W) after squeezing -- TODO confirm with callers.
        img_path: destination file path.
    """
    channel_last = img.squeeze().clamp(0, 1).numpy().transpose(1, 2, 0)
    scaled = channel_last * 255
    swapped = cv2.cvtColor(scaled, cv2.COLOR_BGR2RGB)
    write_params = [cv2.IMWRITE_PNG_COMPRESSION, 0]
    cv2.imwrite(img_path, swapped, write_params)

# Dataloader for SuperVision Subnetwork Evaluation
eval_set = DBPNLoader(root='./dataset')
eval_loader = DataLoader(dataset=eval_set, batch_size=1, num_workers=1, shuffle=False)


# Network and Dataloader for Detection subnetwork to End Task Eval
# Tried to reuse 2nd half of our SRD net for detection, not lucky!
ssd300net = build_ssd('test', 300, num_classes)   # initialize SSD
# Load the checkpoint onto the CPU first (same map_location as DBPN2SSD uses for
# this file) so loading does not depend on the device the weights were saved
# from; the network is moved to the GPU below when cuda is set.
ssd300net.load_state_dict(torch.load('ssd/weights/ssd300_mAP_77.43_v2.pth',
                                     map_location=lambda storage, loc: storage))
ssd300net.eval()

print('Finished loading SSD300 task evaluation network!')
# Detection eval set; images are read from sr_path, where train() writes the
# super-resolved eval images.
eval_det = VOCDetection(voc_root, [('2007', set_type)],
                       BaseTransform(300, dataset_mean),
                       VOCAnnotationTransform(), sr_path='./dataset/VOC07-SR-x4')
if cuda:
    ssd300net = ssd300net.cuda()
    cudnn.benchmark = True


Finished loading SSD300 task evaluation network!


  self.priors = Variable(self.priorbox.forward(), volatile=True)
  init.constant(self.weight,self.gamma)


In [4]:
print('===> Loading SRD sub Net-S (DBPN) fine-tune training datasets (VOC12)')

# Training dataset for the compound SR + detection loss; train() unpacks each
# batch as (LR input, SR target, detection targets).
sd_dataset = SRDetection(
    root='./dataset',
    image_sets='trainval.txt',
    data_mean=dataset_mean,
    target_transform=VOCAnnotationTransform(),
)

# NOTE(review): shuffle=False combined with train()'s max_iter cut-off means
# every epoch iterates over the same leading batches -- confirm this is intended.
sd_data_loader = DataLoader(
    dataset=sd_dataset,
    batch_size=opt.batch_size,
    num_workers=opt.threads,
    shuffle=False,
    collate_fn=detection_collate,
    pin_memory=True,
)


===> Loading SRD sub Net-S (DBPN) fine-tune training datasets (VOC12)


In [5]:
# TODO:
# 1. come up with test/eval functions
# 2. come up with visualization image

# Banner for the model definition that follows (the DBPN2SSD class).
print('===> Architecture Definition of SRD network model: Net-S (DBPN) + Net-D (SSD300)')


class DBPN2SSD(nn.Module):
    """Compound SRD model: Net-S (DBPN super-resolution) feeding Net-D (SSD300).

    Attributes:
        supervis: the DBPN network; wrapped in ``DataParallel`` only when its
            weight file exists (the checkpoint is loaded into the wrapper).
        detector: SSD300 built in 'train' phase so gradients can flow through it.
    """

    def __init__(self, s_model_name, d_model_name, d_frozen):
        """Build both sub-networks and load their weights when the files exist.

        Args:
            s_model_name: path to the Net-S (DBPN) state dict.
            d_model_name: path to the Net-D (SSD) state dict.
            d_frozen: if true, every detector parameter gets
                ``requires_grad = False``.

        A missing weight file is not an error, but a warning is emitted because
        the corresponding sub-network then keeps its freshly constructed weights.
        """
        super(DBPN2SSD, self).__init__()
        import warnings  # local import keeps the class self-contained

        self.supervis = DBPN(num_channels=3, base_filter=64, feat=256, num_stages=7, scale_factor=4)
        if os.path.exists(s_model_name):
            self.supervis = torch.nn.DataParallel(self.supervis, device_ids=gpus_list)
            self.supervis.load_state_dict(self._load_on_cpu(s_model_name))
        else:
            warnings.warn("Net-S weights %r not found: DBPN keeps its freshly constructed "
                          "weights and is not wrapped in DataParallel" % (s_model_name,))

        # self.detector = SSD(), setup ssd as 'train' mode for gradient flow
        # later at test/eval situation, we will overwrite it's mode to 'test'
        self.detector = build_ssd('train', 300, num_classes)
        if os.path.exists(d_model_name):
            self.detector.load_state_dict(self._load_on_cpu(d_model_name))
        else:
            warnings.warn("Net-D weights %r not found: SSD keeps its freshly constructed "
                          "weights" % (d_model_name,))
        if d_frozen:
            for param in self.detector.parameters():
                param.requires_grad = False

    @staticmethod
    def _load_on_cpu(path):
        """Load a checkpoint with all storages kept on the CPU."""
        return torch.load(path, map_location=lambda storage, loc: storage)

    def forward(self, x):
        """Super-resolve ``x`` and run detection on the result.

        Returns:
            tuple ``(superx, detect)``: the Net-S output and the Net-D output
            computed from it.
        """
        superx = self.supervis(x)
        # current detector: SSD300, so assume superx: 300x300!
        detect = self.detector(superx)
        return (superx, detect)


# Earlier Net-S weight choices, kept for reference:
#SDnet = DBPN2SSD('dbpn/models/DBPN_x4.pth', 'ssd/weights/ssd300_mAP_77.43_v2.pth', True)
#SDnet = DBPN2SSD('checkpoint/SRx4_4000.pth', 'ssd/weights/ssd300_mAP_77.43_v2.pth', True)
# Model that is trained below: VOC12-retrained DBPN weights plus a frozen SSD300.
SRDnet = DBPN2SSD(
    s_model_name='dbpn/models/VOC12-LR-x4-DBPN-ep100.pth',
    d_model_name='ssd/weights/ssd300_mAP_77.43_v2.pth',
    d_frozen=True,
)


===> Architecture Definition of SRD network model: Net-S (DBPN) + Net-D (SSD300)


  torch.nn.init.kaiming_normal(m.weight)
  torch.nn.init.kaiming_normal(m.weight)
  self.priors = Variable(self.priorbox.forward(), volatile=True)
  init.constant(self.weight,self.gamma)


In [6]:
def mAP_eval():
    """Evaluate the standalone SSD300 network on the eval detection set.

    Uses the notebook globals ``ssd300net``, ``cuda``, ``eval_det``,
    ``dataset_mean`` and ``set_type``.

    Returns:
        Whatever ``test_net`` returns (presumably the mean average precision).
    """
    eval_transform = BaseTransform(ssd300net.size, dataset_mean)
    return test_net('./eval', ssd300net, cuda, eval_det, eval_transform,
                    5, 300, thresh=0.01, eval_set=set_type)


# Baseline score before any fine-tuning.
print(mAP_eval())


Evaluating detections
0.5011552918037575


In [7]:
# Main SRD training Loop

# Hyperparameters - balancing weights of net S loss vs. net D loss
# alpha = 1.0
# beta = 1.0

# NOTE: this overrides the Epochs = opt.nEpochs value set in the settings cell.
Epochs = 2
# decay_steps = [5000, 10000, 20000, 40000, 80000] in iterations
# decay_steps = [5, 10, 20, 40, 80]              # in epochs
# decay_steps = [4, 8]                           # in epochs
# experimental: halve the LR at the start of every epoch listed here
decay_steps = [2]                                # in epochs
# Decay factor that train() passes to adjust_learning_rate(). 0.5 is the factor
# the schedule has effectively been using (halving), now stated in one place.
gamma = 0.5
def adjust_learning_rate(optimizer, gamma=0.1, step=0):
    """Set the learning rate of every param group to ``opt.lr * gamma ** step``.

    Adapted from the PyTorch ImageNet example:
    https://github.com/pytorch/examples/blob/master/imagenet/main.py

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry.
        gamma: multiplicative decay factor applied once per step.
        step: number of decay steps taken so far (0 gives the initial LR).

    Returns:
        The learning rate that was written into every param group.
    """
    lr = opt.lr * (gamma ** step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# discover 'best' dbpn model according to an eval set
# NOTE: train() tracks best_val / prev_file in its own locals; these
# module-level names are not updated by it.
best_dbpn_model = None
best_val = -1
prev_file = None


#################################################################
#
# Parameters:
#   SDnet  - DBPN2SSD model; only its Net-S (supervis) part is optimized
#   alpha  - Loss weight from SuperV Net-S
#   beta   - Loss weight from Detect Net-D
#   pfname - Pickle file name for all data records
#
#################################################################
def train(SDnet, alpha = 1.0, beta = 1.0, pfname='a100b100.pkl'):
    """Fine-tune Net-S of ``SDnet`` on the compound SR + detection loss.

    Every iteration minimizes ``alpha * loss_sr + beta * (loss_l + loss_c)``
    with Adam over ``net.supervis`` parameters only. Every 200 iterations the
    eval set is super-resolved and written to disk, mean SSIM / PSNR and
    ``mAP_eval()`` are recorded, and the Net-S weights are checkpointed when
    the mAP improves (the previous best file is removed). All records are
    pickled to ``pfname`` once training ends.

    Relies on notebook globals: ``opt``, ``cuda``, ``gpus_list``, ``Epochs``,
    ``decay_steps``, ``gamma``, ``sd_data_loader``, ``eval_loader``,
    ``mAP_eval``, ``save_img``, ``metric_ssim`` and ``metric_psnr``.

    Args:
        SDnet: model exposing ``supervis`` and returning ``(sr, detections)``.
        alpha: weight of the Net-S (MSE) loss.
        beta: weight of the Net-D (localization + confidence) loss.
        pfname: path of the pickle file that receives the training records.

    Returns:
        None.
    """
    log_every = 20        # iterations between progress prints
    eval_every = 200      # iterations between validation passes
    save_every = 2000     # iterations between unconditional checkpoints
    # Directory the detection evaluator reads the super-resolved images from.
    sr_out_dir = './dataset/VOC07-SR-x4/'

    best_val = -1
    prev_file = None

    net = SDnet
    # current lr
    lrt = opt.lr

    # Dictionary to pickle all training/validation data
    trainval = {}

    # Following lists are appended continuously from start to end,
    # one entry per validation invocation.
    # total running epochs
    epos = []
    # total running iterations vs. per epoch
    itrs = []
    # mean SROD training (compound) losses btw. validation
    mlos = []
    # mean APs,   sampled at each validation
    maps = []
    # mean PSNRs, sampled at each validation
    mpsn = []
    # mean SSIMs, sampled at each validation
    msim = []
    # current learning rate
    clr  = []

    # Criterions:
    #
    # net SR loss function (MSE)
    criterion_sr = nn.MSELoss()

    # ssd loss function
    criterion_ssd = MultiBoxLoss(voc['num_classes'], 0.5, True, 0, True, 3, 0.5, False, use_gpu=cuda)

    if cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        #net.supervis = torch.nn.DataParallel(net.supervis)
        cudnn.benchmark = True
        net = net.cuda(gpus_list[0])
        criterion_sr = criterion_sr.cuda(gpus_list[0])
        criterion_ssd = criterion_ssd.cuda(gpus_list[0])

    # Create the output directories up front so the first checkpoint / image
    # write cannot fail after training has already been running for a while.
    for out_dir in (opt.checkpt, sr_out_dir):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    # optimizer = optim.SGD(net.supervis.parameters(), lr=opt.lr, \
    #                       momentum=opt.momentum, weight_decay=opt.weight_decay)
    #
    # Use universal Adam first :)
    # we can be Specific to what part of network's parameters to optimize, here is SuperR Net
    optimizer = optim.Adam(net.supervis.parameters(), lr=opt.lr, betas=(0.9, 0.999), eps=1e-8)

    # set model to training mode
    net.train()

    #epoch = 1
    #max_iter = 20000
    max_iter  =   600    # Approx. within 1 Epoch 5600/5770
    decay_step = 0
    imagef = os.path.join(sr_out_dir, '%s.jpg')
    for epoch in range(1, Epochs + 1):
        t_0 = time()
        epoch_loss = 0.0
        iters_loss = 0.0
        iters_done = 0    # iterations actually run this epoch
        for iteration, batch in enumerate(sd_data_loader, 1):

            if iteration == max_iter + 1:
                break
            if epoch in decay_steps and iteration == 1:
                decay_step += 1
                lrt = adjust_learning_rate(optimizer, gamma, decay_step)

            # lr_batch is LR image; sr_target is pseudo (300x300) original HR image, det_target is bboxes
            lr_batch, sr_target, det_target = batch[0], batch[1], batch[2]
            if cuda:
                lr_batch = Variable(lr_batch.cuda(gpus_list[0]))
                sr_target = Variable(sr_target.cuda(gpus_list[0]))
                # no volatile flag: it is deprecated and only triggers a warning
                det_target = [Variable(ann.cuda(gpus_list[0])) for ann in det_target]
            else:
                lr_batch = Variable(lr_batch)
                sr_target = Variable(sr_target)
                det_target = [Variable(ann) for ann in det_target]

            optimizer.zero_grad()
            t0 = time()

            sr_out, ssd_out = net(lr_batch)
            loss_sr = criterion_sr(sr_out, sr_target)

            # Compound Loss from net-S: loss_sr and net-D: loss_l, loss_c
            loss_l, loss_c = criterion_ssd(ssd_out, det_target)
            loss = loss_sr * alpha + beta *(loss_l + loss_c)

            # .item() replaces the deprecated loss.data[0] scalar access
            loss_val = loss.item()
            epoch_loss += loss_val
            iters_loss += loss_val
            iters_done += 1
            loss.backward()
            optimizer.step()
            t1 = time()

            if iteration % log_every == 0:
                print("===> Epoch[{}]({}/{}): Loss: {:.4f} || Time: {:.4f} sec.".format(epoch, \
                            iteration, len(sd_data_loader), loss_val, (t1 - t0)))

            # End Task (detection) & Imaging Quality Evaluation vs. Model Loss steps below:
            if iteration % eval_every == 0:
                mItersLoss = iters_loss / eval_every
                iters_loss = 0.0
                net.eval()
                t2 = time()
                i, ssim, psnr  = 0, 0, 0
                # Validation forward passes need no autograd graph.
                with torch.no_grad():
                    for bat in eval_loader:
                        # loader output:    lr,     hr,   fname
                        lr_img, hr, fname = bat[0], bat[1], bat[2]
                        if cuda:
                            lr_v = Variable(lr_img.cuda(gpus_list[0]))
                        else:
                            lr_v = Variable(lr_img)
                        sr, _ = net(lr_v)
                        srimg = sr.cpu().data
                        hrimg = hr.data

                        # Struct Similarity Index & Peak Signal-Noise Ratio
                        ssim += metric_ssim(hrimg.numpy(), srimg.numpy())
                        psnr += metric_psnr(hrimg.numpy(), srimg.numpy())
                        i += 1

                        # min-max normalize before saving for the detector
                        image = srimg
                        image = (image - image.min())/(image.max() - image.min())
                        save_img(image, imagef % fname)

                mSSIM = ssim / i
                mPSNR = psnr / i
                mAP_acc = mAP_eval()
                t3 = time()
                print("Mean Loss = {:.4f}, Mean AP = {:.4f}, Mean SSIM = {:.4f}, Mean PSNR = {:.4f} || Time: {:.4f} sec."\
                      .format(mItersLoss, mAP_acc, mSSIM, mPSNR, (t3 - t2)))

                # Updating lists to be pickled
                epos.append(epoch)
                itrs.append(iteration + max_iter*(epoch-1))
                mlos.append(float(mItersLoss))
                maps.append(mAP_acc)
                mpsn.append(mPSNR)
                msim.append(mSSIM)
                clr.append(lrt)

                mAP_int = int(mAP_acc * 10000)     #  2318 =>  23.18 %
                mPSNR_i = int(mPSNR * 100)         #  2318 =>  23.18 dB
                if math.isnan(mSSIM):
                    mSSIM_i = 'NaN'
                else:
                    mSSIM_i = int(mSSIM * 10000)   #  2318 => 0.2318
                # Keep only the best-mAP checkpoint on disk.
                if mAP_int > best_val:
                    best_val = mAP_int
                    if prev_file is not None:
                        os.remove(prev_file)
                    chkfile = opt.checkpt + '/' + 'SRx4_e' + repr(epoch) + '_i' + repr(iteration) + \
                              '_mAP' + repr(best_val) + '_mPSNR' + repr(mPSNR_i) + '_mSSIM' + repr(mSSIM_i) + '.pth'
                    torch.save(net.supervis.state_dict(), chkfile)
                    prev_file = chkfile

                # back to training mode
                net.train()

            if iteration % save_every == 0:
                print('Saving state, iter:', iteration)
                torch.save(net.supervis.state_dict(), opt.checkpt + '/' + 'SRx4_' + repr(iteration+4000) + '.pth')
        t_1 = time()
        # Mean of the per-iteration losses over the iterations actually run.
        print("===> Epoch {} Complete: Avg. Loss: {:.4f} || Time: {:.4f} sec."\
              .format(epoch, epoch_loss / max(iters_done, 1), (t_1 - t_0)))

    # pickle all the data records
    trainval['epochs'] = epos
    trainval['iters'] = itrs
    trainval['mloss'] = mlos
    trainval['m_aps'] = maps
    trainval['mpsnr'] = mpsn
    trainval['mssim'] = msim
    trainval['lrate'] = clr
    with open(pfname, 'wb') as f:
        pickle.dump(trainval, f)
        



In [8]:
# SROD compound loss ratio between the two sub-networks:
#   alpha weights the Net-S (e.g. DBPN) loss,
#   beta  weights the Net-D (e.g. SSD ) loss.
# Each (alpha, beta) pair has its pickle file name at the same index.
AlphaBeta = [(1.00, 0.01),   (1.00, 0.10),   (1.00, 1.00),   (0.10, 1.00),   (0.01, 1.00)]
pickle_fs = ['a100b001.pkl', 'a100b010.pkl', 'a100b100.pkl', 'a010b100.pkl', 'a001b100.pkl']

# Run the first configuration; pf is reused later to unpickle the records.
(a, b), pf = AlphaBeta[0], pickle_fs[0]

train(SRDnet, alpha=a, beta=b, pfname=pf)




===> Epoch[1](20/5770): Loss: 2775.5461 || Time: 0.3392 sec.
===> Epoch[1](40/5770): Loss: 2012.5754 || Time: 0.3478 sec.
===> Epoch[1](60/5770): Loss: 1785.6118 || Time: 0.3460 sec.
===> Epoch[1](80/5770): Loss: 1393.4235 || Time: 0.3319 sec.
===> Epoch[1](100/5770): Loss: 1373.8630 || Time: 0.3375 sec.
===> Epoch[1](120/5770): Loss: 911.0864 || Time: 0.3413 sec.
===> Epoch[1](140/5770): Loss: 1561.3210 || Time: 0.3486 sec.
===> Epoch[1](160/5770): Loss: 1358.3650 || Time: 0.3312 sec.
===> Epoch[1](180/5770): Loss: 1492.9572 || Time: 0.3441 sec.
===> Epoch[1](200/5770): Loss: 564.3323 || Time: 0.3306 sec.


  return (np.prod(mcs[0:levels - 1]**weights[0:levels - 1]) *


Evaluating detections
Mean Loss = 3403.0847, Mean AP = -0.0053, Mean SSIM = nan, Mean PSNR = 19.0337 || Time: 11.0813 sec.
===> Epoch[1](220/5770): Loss: 751.6975 || Time: 0.3542 sec.
===> Epoch[1](240/5770): Loss: 1069.4471 || Time: 0.3336 sec.
===> Epoch[1](260/5770): Loss: 952.1415 || Time: 0.3513 sec.
===> Epoch[1](280/5770): Loss: 468.7316 || Time: 0.3470 sec.
===> Epoch[1](300/5770): Loss: 998.2669 || Time: 0.3516 sec.
===> Epoch[1](320/5770): Loss: 619.3965 || Time: 0.3423 sec.
===> Epoch[1](340/5770): Loss: 628.1650 || Time: 0.3326 sec.
===> Epoch[1](360/5770): Loss: 931.2533 || Time: 0.3288 sec.
===> Epoch[1](380/5770): Loss: 734.4723 || Time: 0.3527 sec.
===> Epoch[1](400/5770): Loss: 802.0309 || Time: 0.3287 sec.
Evaluating detections
Mean Loss = 841.2629, Mean AP = 0.3313, Mean SSIM = 0.8113, Mean PSNR = 20.4841 || Time: 10.4097 sec.
===> Epoch[1](420/5770): Loss: 412.8867 || Time: 0.3468 sec.
===> Epoch[1](440/5770): Loss: 1331.4495 || Time: 0.3357 sec.
===> Epoch[1](460/5

In [None]:
# clean up device after the training run (same call as in the first cell)
torch.cuda.empty_cache()


In [None]:
# unpickle data records for visual
# `pf` is the pickle file name chosen in the run cell and written by train().
# NOTE: pickle.load can execute arbitrary code; only load files this notebook produced.
with open(pf, 'rb') as f:
    trainval = pickle.load(f)

### Debugging code below, no need to run :)

In [None]:
print(' YOU Donot need to run this ===> Loading some datasets')

# LR / HR image folders, both located under the configured input directory.
lr_path, hr_path = (os.path.join(opt.input_dir, sub_dir)
                    for sub_dir in (opt.train_dataset, opt.hr_dataset))

# Paired LR/HR dataset, served unshuffled at the test batch size.
fine_train_set = get_pair_set(lr_path, hr_path)
train_data_loader = DataLoader(
    dataset=fine_train_set,
    num_workers=opt.threads,
    batch_size=opt.testBatchSize,
    shuffle=False,
)


In [None]:
def simple_test():
    """Smoke test: build a DBPN2SSD model and place it on the configured device.

    Sets torch's default tensor type to match the global ``cuda`` flag. On GPU
    the model is additionally wrapped in ``DataParallel``, cudnn benchmarking is
    enabled and the model is moved to CUDA. Nothing is returned.
    """
    net = DBPN2SSD('dbpn/models/DBPN_x4.pth', 'ssd/weights/ssd300_mAP_77.43_v2.pth', True)

    if not cuda:
        torch.set_default_tensor_type('torch.FloatTensor')
        return

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
    net = net.cuda()


simple_test()

In [None]:
# Leftover debugging snippet.
# NOTE(review): `logits_real` is not defined anywhere in the visible notebook, so
# this cell raises NameError on a fresh kernel -- confirm whether it can be removed.
# dtype = torch.FloatTensor
dtype = torch.cuda.FloatTensor ## UNCOMMENT THIS LINE IF YOU'RE ON A GPU!

# Real samples were all associated with True labels - 1
T_labels = torch.ones(logits_real.size()).type(dtype)
