In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import time

import torch
from torch.autograd import Variable

from net.hourglass_dmitry import hourglassNetwork
from net.trumpet import trumpetNetwork
from net.downsampler import Downsampler
from utils.common import create_random_input
from utils.recorder import experimentalRecord
from utils.visual import *

In [None]:
# ---- Run switches -------------------------------------------------------
# is_recorded: when True, later cells write images/messages via `record`.
# is_plot:     when True, later cells show images inline with plot_image.
is_recorded = False
is_plot = True

# ---- Optimisation schedule ----------------------------------------------
# Total number of optimisation steps, and how often (in steps) the training
# cell reports / saves an intermediate result.
num_steps = 1501
save_frequency = 250

# Scale of the Gaussian perturbation added to the network input in the
# training cell.
sigma = 1./30


def learningRate(step):
    """Return the learning rate for ``step``: 1e-1 up to and including
    step 500, and 1e-2 afterwards."""
    if step <= 500:
        return 1e-1
    return 1e-2

# Alternative that was tried: a constant rate.
#learningRate = lambda step: 1e-3

# ---- Problem definition -------------------------------------------------
# Factor passed to the Downsampler in the training cell.
factor = 8
# Architecture description handed to hourglassNetwork.
net_file = 'net/architecture/dmitry5_skip543.json'
# High- and low-resolution versions of the test image.
HR_image_path = 'images/superresolution/butterfly_256rgb.png'
LR_image_path = 'images/superresolution/butterfly_32rgb.png'

In [None]:
# Set up the experiment recorder; skipped entirely unless recording is on.
# `record` is used by the following cells (add_image / close) whenever
# is_recorded is True.
if is_recorded:
    # Name and free-text description attached to this experiment.
    expName = 'butterfly8X_dmitry5_skip543'
    expDiscribe = 'just as name'

    record = experimentalRecord(
        basePath='../DIP_result/superresolution',
        expName=expName,
        author='HanXuan',
        describe=expDiscribe,
        fileName='Experimental Record.txt'
    )
    # Opened here; the training cell calls record.close() when it finishes.
    record.open()

In [None]:
# Load the low-resolution image that serves as the fitting target in the
# training cell (it is compared there against the downsampled network output).
# NOTE(review): jpg_to_tensor / plot_image are not imported by name in this
# notebook -- presumably they come from `from utils.visual import *`; confirm.
deconstructed = jpg_to_tensor(LR_image_path)
deconstructed = Variable(deconstructed)

# Build the NumPy copy whenever *either* consumer needs it. Previously it was
# created only under `is_plot`, so recording with plotting disabled raised a
# NameError on `image_np`.
if is_plot or is_recorded:
    image_np = deconstructed.cpu().data.numpy()

if is_plot:
    plot_image(image_np,(8,8))
if is_recorded:
    record.add_image(image=image_np,
                    imageName='origin.png',
                    message='Original deconstructed image.',
                    mode='NP')

In [None]:
# ---- Network input -------------------------------------------------------
# Random input fed to the network on every step. The shape is defined once so
# the perturbation added further down always matches it.
# Earlier variant, kept for reference:
# noise = Variable(torch.randn([1,32,512,512]).cuda())
noise_shape = [1,32,256,256]
noise = create_random_input(size=noise_shape, xigma=0.5)

# ---- Model and downsampler -----------------------------------------------
net = hourglassNetwork(net_file, ch_in=32)
downsampler = Downsampler(n_planes=3, factor=factor, kernel_type='lanczos2', phase=0.5, preserve_size=True)

net.cuda()
downsampler.cuda()

# ---- Optimisation --------------------------------------------------------
start_time = time.time()
counter = 0
optimizer = torch.optim.Adam(net.parameters(), lr=learningRate(counter))
# dummy index to provide names to output files
save_img_ind = 0
for step in range(num_steps):
    # Apply the learning-rate schedule. The optimizer only received
    # learningRate(0) at construction, so without this update the scheduled
    # change of rate would never take effect.
    current_lr = learningRate(counter)
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    # Forward pass: network output, then its downsampled version, which is
    # compared with the low-resolution target by sum of squared differences.
    output = net(noise)
    LR_output = downsampler(output)
    optimizer.zero_grad()
    loss = torch.sum((LR_output - deconstructed)**2)
    loss.backward()
    optimizer.step()
    counter += 1

    # every save_frequency steps, report progress and optionally plot/save
    if step % save_frequency == 0:
        time_cost = time.time() - start_time
        # Flatten before indexing so this works whether the loss comes back
        # as a 1-element array or as a 0-dim array.
        loss_value = float(loss.cpu().data.numpy().reshape(-1)[-1])
        # Example: Cost time 236.978221s  At step 02000  loss is 28938.4121094
        print_message = 'Cost time %fs  At step %05d  loss is %f' % (time_cost, step, loss_value)

        # Convert the current output once if either consumer needs it.
        # Previously this happened only under `is_plot`, so recording with
        # plotting off used a stale or undefined `image_np`.
        if is_plot or is_recorded:
            image_np = output.cpu().data.numpy()
        if is_plot:
            plot_image(image_np,(8,8))
        if is_recorded:
            record.add_image(image=image_np,
                             imageName=str(save_img_ind)+'.png',
                             message=print_message,
                             mode='NP')
        print(print_message)
        save_img_ind += 1

        # NOTE(review): the input is perturbed only on save steps, and the
        # perturbation accumulates in `noise`. Behaviour kept as-is; confirm
        # whether a per-step perturbation was intended.
        noise.data += sigma * torch.randn(noise_shape).cuda()

if is_recorded:
    record.close()