In [None]:
"""
*Uncomment if running on colab* 
Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab 
"""
# Colab only: clone the project repository and move its contents into the
# current working directory.
# NOTE(review): presumably this is what provides the `models/` package and the
# `data/` folder referenced by later cells - confirm against the repository.
# !git clone https://github.com/ubc-vision/juho-usra.git
# !mv juho-usra/* ./

In [None]:
# pip install pytorch_model_summary

In [None]:
# pip install einops

In [None]:
# pip install lpips

# Import libraries

In [None]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from einops import rearrange
import os
from skimage.metrics import peak_signal_noise_ratio
from pytorch_model_summary import summary
# use tensorboard with pytorch
from torch.utils.tensorboard import SummaryWriter
from models import *
# currently can't install this package 
# from siren_pytorch import SirenNet
from models.siren_pytorch import SirenNet
# import lpips

# Legacy tensor type used by later cells via `.type(dtype)`: the CUDA float
# tensor type when a GPU is visible, otherwise the CPU float tensor type.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

# Load image

In [None]:
# Downsampling factor between the HR target and the LR observation
factor = 4 


def _pil_to_chw_float(img_pil):
    """Convert an 8-bit PIL image (H, W, C) to a float32 array (C, H, W) scaled by 1/255."""
    return np.array(img_pil).transpose(2, 0, 1).astype(np.float32) / 255.


# set up original, HR, LR images for deep image prior 
path_to_image = 'data/sr/zebra_GT.png'
# Force three 8-bit channels: the networks below emit 3 channels, and
# transpose(2, 0, 1) would fail on a single-channel (2-D) image.
img_orig_pil = Image.open(path_to_image).convert('RGB')
img_orig_np = _pil_to_chw_float(img_orig_pil)

# HR 
# we usually need the dimensions to be divisible by a power of two (32 in this case)
new_size = (img_orig_pil.size[0] - img_orig_pil.size[0] % 32, img_orig_pil.size[1] - img_orig_pil.size[1] % 32)
# Centered crop box (left, upper, right, lower) of size `new_size`.
bbox = [(img_orig_pil.size[0] - new_size[0])/2, 
        (img_orig_pil.size[1] - new_size[1])/2, 
        (img_orig_pil.size[0] + new_size[0])/2,
        (img_orig_pil.size[1] + new_size[1])/2,]
img_HR_pil = img_orig_pil.crop(bbox)
img_HR_np = _pil_to_chw_float(img_HR_pil)

# LR
LR_size = [img_HR_pil.size[0] // factor, img_HR_pil.size[1] // factor]
# Image.ANTIALIAS was an alias of Image.LANCZOS and was removed in Pillow 10;
# LANCZOS is the same filter and is available in old and new releases.
img_LR_pil = img_HR_pil.resize(LR_size, Image.LANCZOS)
img_LR_np = _pil_to_chw_float(img_LR_pil)

# PIL reports size as (width, height)
net_input_width, net_input_height = img_HR_pil.size

# show both HR and LR image
plt.figure()
plt.title("HR image (cropped ground truth)")
plt.imshow(img_HR_pil)
plt.figure()
plt.title("LR image (downsampled by factor %d)" % factor)
plt.imshow(img_LR_pil)

# Setup

In [None]:
# Setup input meshgrid 
# Two coordinate axes in [-1, 1]: one sample per image row, one per image column.
tensors = [torch.linspace(-1, 1, steps = net_input_height), torch.linspace(-1, 1, steps = net_input_width)]
# indexing='ij' keeps the (height, width) layout that the 'h w c' pattern below
# expects. It is spelled out because calling torch.meshgrid without `indexing`
# emits a deprecation warning about the default changing in the future.
net_input = torch.stack(torch.meshgrid(*tensors, indexing='ij'), dim=-1).type(dtype)
# (h, w, 2) -> (1, 2, h, w): add a batch axis and move coordinates to the channel axis
net_input = rearrange(net_input, 'h w c -> () c h w', h = net_input_height, w = net_input_width)
# Independent copies so each model gets its own input tensor
Deep_Image_Prior_net_input = net_input.clone().detach().requires_grad_()
SIREN_net_input = net_input.clone().detach().requires_grad_()
# Number of input channels: the two coordinates stacked above
input_depth = 2

# Setup Deep Image Prior 
pad = 'reflection'

# `input_depth` input channels -> 3 output channels.
# .type(dtype) casts the parameters to the same tensor type as the network
# inputs (CUDA when available); without it the weights stay on the CPU while
# the inputs built with .type(dtype) live on the GPU.
deepImagePriorNet = DeepImagePriorNet (
            input_depth, 3, 
            channels_down = [128, 128, 128, 128, 128],
            channels_up = [128, 128, 128, 128, 128],
            channels_skip = [4, 4, 4, 4, 4],
            kernel_size_down = [3, 3, 3, 3, 3],
            kernel_size_up = [3, 3, 3, 3, 3],
            upsample_mode = 'bilinear',
            need_sigmoid=True, need_bias=True, pad=pad).type(dtype)


# Setup SIREN
# .type(dtype) casts the parameters to the same tensor type as the network
# inputs (CUDA when available); without it the weights stay on the CPU while
# the inputs built with .type(dtype) live on the GPU.
sirenNet = SirenNet(
    dim_in = input_depth,              # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
).type(dtype)


# Deep Image Prior train

In [None]:
# Train for Deep Image Prior
def deepImagePriorTrain(Deep_Image_Prior_net_input, learning_rate=0.01, num_iter=25000,
                        reg_noise_std=0.03, loss_type=None, log_every=100):
    """Fit `deepImagePriorNet` so that its downsampled output matches the LR image.

    Reads the notebook globals `deepImagePriorNet`, `img_LR_np`, `img_HR_np`,
    `factor` and `dtype`. The network is optimised in place; nothing is returned.
    Loss, LR/HR PSNR and the HR output image are written to TensorBoard.

    Args:
        Deep_Image_Prior_net_input: input tensor fed to the network.
        learning_rate: Adam learning rate.
        num_iter: number of optimisation steps.
        reg_noise_std: standard deviation of the Gaussian noise added to the
            input at every step; values <= 0 disable the perturbation.
        loss_type: "LPIPS" selects the LPIPS loss; anything else uses MSE.
        log_every: write TensorBoard summaries every `log_every` iterations.
    """
    # Clean copy of the input; noise is re-sampled and added to it every step.
    net_input_saved = Deep_Image_Prior_net_input.detach().clone()
    # Scratch buffer with the input's shape, refilled in place by normal_().
    noise = Deep_Image_Prior_net_input.detach().clone()
    # Add a leading batch axis to the LR target
    img_LR = torch.from_numpy(img_LR_np)[None, :].type(dtype)

    # Make sure the weights use the same tensor type as the inputs
    # (in place; a no-op when the network was already converted).
    deepImagePriorNet.type(dtype)

    # Create optimizer
    optimizer = torch.optim.Adam(deepImagePriorNet.parameters(), lr=learning_rate)

    # Loss and tensorboard log directory.
    # `==` instead of `is`: identity comparison against a string literal depends
    # on interning and raises a SyntaxWarning on Python 3.8+.
    if loss_type == "LPIPS":
        # Learned Perceptual Image Patch Similarity (LPIPS) metric Loss
        # NOTE: requires `import lpips`, which is commented out in the import cell.
        # NOTE(review): lpips.LPIPS is understood to expect inputs in [-1, 1]
        # unless called with normalize=True - confirm before relying on it.
        loss = lpips.LPIPS(net='alex')
        if torch.cuda.is_available():
            loss.cuda()
        log_dir = "./logs/experiment/Deep_Image_Prior/super_resolution/LPIPS"
        loss_tag = "LPIPS_loss"
    else:
        loss = nn.MSELoss().type(dtype)
        log_dir = "./logs/experiment/Deep_Image_Prior/super_resolution/Adam"
        loss_tag = "loss"

    # Create the log directory (if needed) before opening the summary writer
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    try:
        # Training loop
        for i in range(num_iter):

            if reg_noise_std > 0:
                Deep_Image_Prior_net_input = net_input_saved + (noise.normal_() * reg_noise_std)

            # Forward pass: HR prediction, then downsample it to the LR grid
            out_HR = deepImagePriorNet(Deep_Image_Prior_net_input)
            out_LR = nn.functional.interpolate(out_HR, scale_factor=1/factor, mode="bilinear", antialias=True)

            # Compute the loss against the LR observation
            total_loss = loss(out_LR, img_LR)

            # Clear stale gradients, backpropagate, update parameters
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Save the results
            if i % log_every == 0:
                writer.add_image("image_output", out_HR.detach(), global_step=i, dataformats='NCHW')
                writer.add_scalar(loss_tag, total_loss.item(), global_step=i)
                # PSNR of the LR reconstruction; images are floats scaled by 1/255,
                # so the data range is 1.0 (what skimage infers for such inputs).
                psnr_LR = peak_signal_noise_ratio(img_LR_np, out_LR.detach().cpu().numpy()[0], data_range=1.0)
                writer.add_scalar("LR_PSNR", psnr_LR, global_step=i)
                # PSNR of the HR prediction against the cropped ground truth
                psnr_HR = peak_signal_noise_ratio(img_HR_np, out_HR.detach().cpu().numpy()[0], data_range=1.0)
                writer.add_scalar("HR_PSNR", psnr_HR, global_step=i)
    finally:
        # Flush pending events even if training is interrupted
        writer.close()


deepImagePriorTrain(Deep_Image_Prior_net_input)

# SIREN train

In [None]:
# Train for SIREN
def sirenTrain(SIREN_net_input, learning_rate=0.01, num_iter=25000, loss_type=None, log_every=100):
    """Fit `sirenNet` so that its downsampled output matches the LR image.

    Reads the notebook globals `sirenNet`, `img_LR_np`, `img_HR_np`, `factor`
    and `dtype`. The network is optimised in place; nothing is returned.
    Loss, LR/HR PSNR and the HR output image are written to TensorBoard.

    Args:
        SIREN_net_input: coordinate grid laid out as (batch, coords, height, width).
        learning_rate: Adam learning rate.
        num_iter: number of optimisation steps.
        loss_type: "LPIPS" selects the LPIPS loss; anything else uses MSE.
        log_every: write TensorBoard summaries every `log_every` iterations.
    """
    # Add a leading batch axis to the LR target
    img_LR = torch.from_numpy(img_LR_np)[None, :].type(dtype)

    # Make sure the weights use the same tensor type as the inputs
    # (in place; a no-op when the network was already converted).
    sirenNet.type(dtype)

    # Coordinates in channel-last layout, one (row, col) pair per pixel.
    # NOTE(review): assumes the vendored models.siren_pytorch.SirenNet matches
    # upstream siren-pytorch, i.e. an MLP mapping the last axis dim_in -> dim_out
    # - confirm against the local copy.
    # Detached because the coordinates themselves are not optimised.
    coords = rearrange(SIREN_net_input.detach(), 'b c h w -> b h w c')

    # Create optimizer over the SIREN parameters (this previously optimised
    # deepImagePriorNet by mistake, so the SIREN was never trained).
    optimizer = torch.optim.Adam(sirenNet.parameters(), lr=learning_rate)

    # Loss and tensorboard log directory.
    # `==` instead of `is`: identity comparison against a string literal depends
    # on interning and raises a SyntaxWarning on Python 3.8+.
    if loss_type == "LPIPS":
        # Learned Perceptual Image Patch Similarity (LPIPS) metric Loss
        # NOTE: requires `import lpips`, which is commented out in the import cell.
        # NOTE(review): lpips.LPIPS is understood to expect inputs in [-1, 1]
        # unless called with normalize=True - confirm before relying on it.
        loss = lpips.LPIPS(net='alex')
        if torch.cuda.is_available():
            loss.cuda()
        log_dir = "./logs/experiment/Siren/super_resolution/LPIPS"
        loss_tag = "LPIPS_loss"
    else:
        loss = nn.MSELoss().type(dtype)
        log_dir = "./logs/experiment/Siren/super_resolution/Adam"
        loss_tag = "loss"

    # Create the log directory (if needed) before opening the summary writer
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    try:
        # Training loop
        for i in range(num_iter):

            # Forward pass: per-pixel RGB from coordinates, back to NCHW,
            # then downsample to the LR grid
            out_HR = rearrange(sirenNet(coords), 'b h w c -> b c h w')
            out_LR = nn.functional.interpolate(out_HR, scale_factor=1/factor, mode="bilinear", antialias=True)

            # Compute the loss against the LR observation
            total_loss = loss(out_LR, img_LR)

            # Clear stale gradients, backpropagate, update parameters
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Save the results
            if i % log_every == 0:
                writer.add_image("image_output", out_HR.detach(), global_step=i, dataformats='NCHW')
                writer.add_scalar(loss_tag, total_loss.item(), global_step=i)
                # PSNR of the LR reconstruction; images are floats scaled by 1/255,
                # so the data range is 1.0 (what skimage infers for such inputs).
                psnr_LR = peak_signal_noise_ratio(img_LR_np, out_LR.detach().cpu().numpy()[0], data_range=1.0)
                writer.add_scalar("LR_PSNR", psnr_LR, global_step=i)
                # PSNR of the HR prediction against the cropped ground truth
                psnr_HR = peak_signal_noise_ratio(img_HR_np, out_HR.detach().cpu().numpy()[0], data_range=1.0)
                writer.add_scalar("HR_PSNR", psnr_HR, global_step=i)
    finally:
        # Flush pending events even if training is interrupted
        writer.close()


sirenTrain(SIREN_net_input)