In [None]:
"""
*Uncomment the two commands below when running on Google Colab.*
To use a GPU on Colab, select Runtime -> Change runtime type -> Hardware accelerator -> GPU.
"""
# !git clone https://github.com/ubc-vision/juho-usra.git
# !mv juho-usra/* ./

# Import libraries

In [None]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from einops import rearrange
import os
from skimage.metrics import peak_signal_noise_ratio
from pytorch_model_summary import summary
# use tensorboard with pytorch
from torch.utils.tensorboard import SummaryWriter
from models import *
# currently can't install this package 
# from siren_pytorch import SirenNet
from models.siren_pytorch import SirenNet

# Tensor type used to place modules and tensors via `.type(dtype)`:
# CUDA floats when a GPU is visible, plain CPU floats otherwise.
dtype = (
    torch.cuda.FloatTensor
    if torch.cuda.is_available()
    else torch.FloatTensor
)

# Load image

In [None]:
# Downsampling factor between the HR ground truth and the LR observation
factor = 4

path_to_image = 'data/sr/zebra_GT.png'


def pil_to_chw_float(pil_image):
    """Convert an RGB PIL image to a float32 array of shape (C, H, W) scaled to [0, 1]."""
    return np.array(pil_image).transpose(2, 0, 1).astype(np.float32) / 255.


# Force 3-channel RGB: a PNG may carry an alpha channel or be greyscale/palette,
# which would break the (H, W, 3) -> (3, H, W) transpose and the 3-channel network output.
# The `with` block closes the file handle PIL would otherwise keep open.
with Image.open(path_to_image) as image_file:
    image = image_file.convert('RGB')
image_width, image_height = image.size  # PIL reports (width, height)

# HR: full-resolution ground truth
img_HR_pil = image.copy()
img_HR_np = pil_to_chw_float(img_HR_pil)

# LR: downsample by `factor`. Image.LANCZOS is the filter the old Image.ANTIALIAS
# alias referred to (that alias was removed in Pillow 10).
LR_size = (img_HR_pil.size[0] // factor, img_HR_pil.size[1] // factor)
img_LR_pil = img_HR_pil.resize(LR_size, Image.LANCZOS)
img_LR_np = pil_to_chw_float(img_LR_pil)

# Show both HR and LR images
fig, (ax_HR, ax_LR) = plt.subplots(1, 2, figsize=(12, 6))
ax_HR.imshow(img_HR_pil)
ax_HR.set_title('HR %dx%d' % img_HR_pil.size)
ax_LR.imshow(img_LR_pil)
ax_LR.set_title('LR %dx%d (factor %d)' % (img_LR_pil.size[0], img_LR_pil.size[1], factor))
plt.show()

# Setup

In [None]:
# Setup input meshgrid: one (y, x) coordinate pair in [-1, 1] per HR pixel
input_depth = 2  # number of input channels = the two coordinate dimensions
tensors = [torch.linspace(-1, 1, steps=image_height), torch.linspace(-1, 1, steps=image_width)]
# indexing='ij' keeps the (row, column) layout used so far and avoids the torch>=1.10 warning
net_input = torch.stack(torch.meshgrid(*tensors, indexing='ij'), dim=-1)
net_input = rearrange(net_input, 'h w c -> () c h w', h=image_height, w=image_width)
# Independent leaf copies per model, cast to `dtype` so they live on the same device
# as the training targets (which are built with `.type(dtype)`).
Deep_Image_Prior_net_input = net_input.type(dtype).clone().detach().requires_grad_()
SIREN_net_input = net_input.type(dtype).clone().detach().requires_grad_()

# Setup Deep Image Prior
pad = 'reflection'

deepImagePriorNet = DeepImagePriorNet(
            input_depth, 3,
            channels_down = [128, 128, 128, 128, 128],
            channels_up = [128, 128, 128, 128, 128],
            channels_skip = [4, 4, 4, 4, 4],
            kernel_size_down = [3, 3, 3, 3, 3],
            kernel_size_up = [3, 3, 3, 3, 3],
            upsample_mode = 'bilinear',
            need_sigmoid=True, need_bias=True, pad=pad).type(dtype)  # same device as inputs/targets


# Setup SIREN
sirenNet = SirenNet(
    dim_in = input_depth,              # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.).type(dtype)      # omega_0 of the first layer - a per-signal hyperparameter


# Deep Image Prior train

In [None]:
# Train for Deep Image Prior
def deepImagePriorTrain(Deep_Image_Prior_net_input, LR=0.01, num_iter=2000,
                        reg_noise_std=0.03, tv_weight=0.0, log_every=100,
                        log_dir="./logs/experiment/Deep_Image_Prior/super_resolution"):
    """Fit the global `deepImagePriorNet` to the LR image (Deep Image Prior super-resolution).

    Each iteration:
      1. The fixed input is perturbed with fresh Gaussian noise (std `reg_noise_std`).
      2. `deepImagePriorNet` maps it to a high-resolution estimate `out_HR`.
      3. `out_HR` is downsampled to the spatial size of `img_LR_np`.
      4. The MSE against the LR image, plus an optional total-variation penalty
         on `out_HR`, is minimised with Adam.

    Only the LR image drives the loss; `img_HR_np` is used purely for PSNR monitoring.

    Reads the notebook globals `deepImagePriorNet`, `img_LR_np`, `img_HR_np` and `dtype`.
    `deepImagePriorNet` is cast to `dtype` and its parameters are updated in place.

    Args:
        Deep_Image_Prior_net_input: network input tensor; presumably (1, C, H, W) as
            built by the setup cell -- TODO confirm for other callers.
        LR: Adam learning rate.
        num_iter: number of optimisation steps.
        reg_noise_std: std of the input perturbation; 0 disables it.
        tv_weight: weight of the total-variation penalty; 0 disables it.
        log_every: write TensorBoard image/loss every `log_every` steps (must be > 0).
        log_dir: TensorBoard log directory, created if missing.

    Returns:
        list of [psnr_LR, psnr_HR] pairs, one per iteration.
    """
    # Keep network, input and target on one device/dtype; mixing a CPU output with a
    # CUDA target makes the loss fail on a GPU runtime.
    deepImagePriorNet.type(dtype)
    net_input_saved = Deep_Image_Prior_net_input.detach().clone().type(dtype)
    img_LR_var = torch.from_numpy(img_LR_np)[None, :].type(dtype)
    LR_spatial_size = tuple(img_LR_var.shape[-2:])
    psnr_history = []

    # Create optimizer
    optimizer = torch.optim.Adam(deepImagePriorNet.parameters(), lr=LR)

    # Loss
    loss = nn.MSELoss().type(dtype)

    # Create the log directory before the TensorBoard writer that uses it
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    try:
        for i in range(num_iter):
            net_input = net_input_saved
            if reg_noise_std > 0:
                net_input = net_input_saved + torch.randn_like(net_input_saved) * reg_noise_std

            optimizer.zero_grad()

            # Forward pass (calling the module, not .forward(), so hooks run)
            out_HR = deepImagePriorNet(net_input)
            # Downsample the network's HR estimate to the LR target size. The loss must
            # flow through out_HR so gradients reach the network parameters.
            out_LR = nn.functional.interpolate(out_HR, size=LR_spatial_size, mode='bilinear',
                                               align_corners=False, antialias=True)

            # Compute the loss
            total_loss = loss(out_LR, img_LR_var)
            if tv_weight > 0:
                # Isotropic total variation of the HR estimate. The epsilon keeps the
                # sqrt gradient finite where neighbouring pixels are identical.
                dh = torch.pow(out_HR[:, :, :, 1:] - out_HR[:, :, :, :-1], 2)
                dw = torch.pow(out_HR[:, :, 1:, :] - out_HR[:, :, :-1, :], 2)
                tv_loss = torch.sum(torch.pow(dh[:, :, :-1] + dw[:, :, :, :-1] + 1e-8, 0.5))
                total_loss = total_loss + tv_weight * tv_loss

            # Compute gradients and update parameters
            total_loss.backward()
            optimizer.step()

            # Log PSNR (images are float in [0, 1])
            out_LR_np = out_LR.detach().cpu().numpy()[0]
            out_HR_np = out_HR.detach().cpu().numpy()[0]
            psnr_LR = peak_signal_noise_ratio(img_LR_np, out_LR_np, data_range=1.0)
            psnr_HR = peak_signal_noise_ratio(img_HR_np, out_HR_np, data_range=1.0)
            print('Iteration %05d    PSNR_LR %.3f   PSNR_HR %.3f' % (i, psnr_LR, psnr_HR), '\r', end='')

            # History
            psnr_history.append([psnr_LR, psnr_HR])

            # Periodically monitor results in TensorBoard
            if i % log_every == 0:
                writer.add_image("image_output", out_HR.detach(), global_step=i, dataformats='NCHW')
                writer.add_scalar("loss", total_loss.item(), global_step=i)
    finally:
        # Flush pending events and release the event file even if training is interrupted
        writer.close()

    return psnr_history


# Keep the per-iteration [PSNR_LR, PSNR_HR] history for later inspection
dip_psnr_history = deepImagePriorTrain(Deep_Image_Prior_net_input)

# SIREN train

# Result