DMSHN model definition

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Three-stage convolutional encoder made of residual conv blocks.

    Each stage is one "head" convolution followed by two residual blocks
    (conv -> ReLU -> conv, with the block input added back to its output).
    The head of stage 1 maps 3 -> 32 channels; the heads of stages 2 and 3
    are stride-2 convolutions mapping 32 -> 64 and 64 -> 128 channels.

    Attribute names (layer1, layer2, ...) are part of the state_dict keys and
    must not be renamed.
    """

    def __init__(self):
        super().__init__()

        def residual_body(width):
            # conv -> ReLU -> conv at constant width; the skip connection
            # itself is applied in forward().
            return nn.Sequential(
                nn.Conv2d(width, width, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(width, width, kernel_size=3, padding=1),
            )

        # Stage 1 (Conv1): 3 -> 32 channels, stride 1.
        self.layer1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.layer2 = residual_body(32)
        self.layer3 = residual_body(32)

        # Stage 2 (Conv2): 32 -> 64 channels, stride-2 head.
        self.layer5 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.layer6 = residual_body(64)
        self.layer7 = residual_body(64)

        # Stage 3 (Conv3): 64 -> 128 channels, stride-2 head.
        self.layer9 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.layer10 = residual_body(128)
        self.layer11 = residual_body(128)

    def forward(self, x):
        """Run the three stages in order and return the 128-channel feature map."""
        stages = (
            (self.layer1, self.layer2, self.layer3),
            (self.layer5, self.layer6, self.layer7),
            (self.layer9, self.layer10, self.layer11),
        )
        for head, first_block, second_block in stages:
            x = head(x)
            x = first_block(x) + x
            x = second_block(x) + x
        return x

class Decoder(nn.Module):
    """Three-stage convolutional decoder mirroring the residual-block layout.

    Each stage applies two residual blocks (conv -> ReLU -> conv, with the
    block input added back to its output) followed by a "tail" layer. The
    tails of stages 3 and 2 are transposed convolutions (kernel 4, stride 2,
    padding 1) mapping 128 -> 64 and 64 -> 32 channels; the tail of stage 1
    is a plain 3x3 convolution mapping 32 -> 3 channels.

    Attribute names (layer13, layer14, ...) are part of the state_dict keys
    and must not be renamed.
    """

    def __init__(self):
        super().__init__()

        def residual_body(width):
            # conv -> ReLU -> conv at constant width; the skip connection
            # itself is applied in forward().
            return nn.Sequential(
                nn.Conv2d(width, width, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Conv2d(width, width, kernel_size=3, padding=1),
            )

        # Stage 3 (Deconv3): 128 channels, then transposed conv to 64.
        self.layer13 = residual_body(128)
        self.layer14 = residual_body(128)
        self.layer16 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)

        # Stage 2 (Deconv2): 64 channels, then transposed conv to 32.
        self.layer17 = residual_body(64)
        self.layer18 = residual_body(64)
        self.layer20 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)

        # Stage 1 (Deconv1): 32 channels, then projection to 3 channels.
        self.layer21 = residual_body(32)
        self.layer22 = residual_body(32)
        self.layer24 = nn.Conv2d(32, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the three stages in order and return the 3-channel output."""
        stages = (
            (self.layer13, self.layer14, self.layer16),
            (self.layer17, self.layer18, self.layer20),
            (self.layer21, self.layer22, self.layer24),
        )
        for first_block, second_block, tail in stages:
            x = first_block(x) + x
            x = second_block(x) + x
            x = tail(x)
        return x

class DMSHN(nn.Module):
    """Three-level coarse-to-fine network built from Encoder/Decoder pairs.

    The input is processed at full (lv1), half (lv2) and quarter (lv3)
    resolution. Each coarser level produces a residual image and a feature
    map; both are bilinearly upsampled by 2 and injected into the next finer
    level (residual added to that level's input image, features added to
    that level's encoder output). The output is what the level-1 decoder
    returns.

    Input sizes: the two 0.5x rescales combined with the two stride-2
    convolutions in each Encoder mean the additions below only line up when
    height and width are multiples of 16. Other sizes are replicate-padded
    on the bottom/right up to the next multiple of 16 and the result is
    cropped back, so any H x W is accepted. Inputs that are already a
    multiple of 16 are processed without any padding or cropping.
    """

    # 4x from the image pyramid (two 0.5x rescales) times 4x from the two
    # stride-2 convolutions inside each Encoder.
    _SIZE_MULTIPLE = 16

    def __init__(self):
        super(DMSHN, self).__init__()
        self.encoder_lv1 = Encoder()
        self.encoder_lv2 = Encoder()
        self.encoder_lv3 = Encoder()

        self.decoder_lv1 = Decoder()
        self.decoder_lv2 = Decoder()
        self.decoder_lv3 = Decoder()

    def forward(self,images_lv1):
        """Compute the level-1 decoder output for a batch of images.

        Args:
            images_lv1: 4-D tensor indexed as (batch, channels, H, W); H and W
                are read from dimensions 2 and 3.

        Returns:
            Tensor with the same H and W as the input.
        """
        H = images_lv1.size(2)
        W = images_lv1.size(3)

        # Amount needed to reach the next multiple of 16 (0 if already there).
        pad_h = (-H) % self._SIZE_MULTIPLE
        pad_w = (-W) % self._SIZE_MULTIPLE
        needs_pad = bool(pad_h or pad_w)
        if needs_pad:
            # 'replicate' works for any pad amount, including pads larger
            # than the image (which 'reflect' does not allow).
            images_lv1 = F.pad(images_lv1, (0, pad_w, 0, pad_h), mode='replicate')

        # align_corners=False is the default for bilinear mode; spelled out
        # so the behavior is explicit.
        images_lv2 = F.interpolate(images_lv1, scale_factor = 0.5, mode = 'bilinear', align_corners=False)
        images_lv3 = F.interpolate(images_lv2, scale_factor = 0.5, mode = 'bilinear', align_corners=False)

        # Level 3 (quarter resolution).
        feature_lv3 = self.encoder_lv3(images_lv3)
        residual_lv3 = self.decoder_lv3(feature_lv3)

        # Level 2 (half resolution), guided by the upsampled level-3 results.
        residual_lv3 = F.interpolate(residual_lv3, scale_factor=2, mode= 'bilinear', align_corners=False)
        feature_lv3 = F.interpolate(feature_lv3, scale_factor=2, mode= 'bilinear', align_corners=False)
        feature_lv2 = self.encoder_lv2(images_lv2 + residual_lv3)
        residual_lv2 = self.decoder_lv2(feature_lv2 + feature_lv3)

        # Level 1 (full resolution), guided by the upsampled level-2 results.
        residual_lv2 = F.interpolate(residual_lv2, scale_factor=2, mode= 'bilinear', align_corners=False)
        feature_lv2 = F.interpolate(feature_lv2, scale_factor=2, mode= 'bilinear', align_corners=False)
        feature_lv1 = self.encoder_lv1(images_lv1 + residual_lv2)
        bokeh_image = self.decoder_lv1(feature_lv1 + feature_lv2)

        if needs_pad:
            # Drop the rows/columns that were added on the bottom/right.
            bokeh_image = bokeh_image[:, :, :H, :W]

        return bokeh_image


Stacked DMSHN model definition

In [45]:
class stacked_DMSHN(nn.Module):
    """Two DMSHN networks applied back to back.

    The output of ``net1`` is fed directly into ``net2``; the output of
    ``net2`` is returned. The attribute names ``net1``/``net2`` are part of
    the state_dict keys and must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Construction order is kept: first stage, then second stage.
        self.net1 = DMSHN()
        self.net2 = DMSHN()

    def forward(self, x):
        """Return ``net2(net1(x))``."""
        first_stage_output = self.net1(x)
        return self.net2(first_stage_output)

Run Stacked DMSHN on sample images

In [59]:
from __future__ import absolute_import, division, print_function
import cv2

import os
import sys
import glob
import argparse
import numpy as np
import PIL.Image as pil
import matplotlib as mpl
import matplotlib.cm as cm

import torch
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision.utils import save_image

from tqdm import tqdm

import math
import numbers
import sys



# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


# Every input is resized to this resolution (pixels) before inference.
# Both values are multiples of 16.
feed_width = 1536
feed_height = 1024


bokehnet = stacked_DMSHN().to(device)
# The state dict is loaded into the DataParallel wrapper, so the wrapper is
# created first. NOTE(review): presumably the checkpoint was saved from a
# DataParallel-wrapped model — confirm.
bokehnet = nn.DataParallel(bokehnet)

# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
bokehnet.load_state_dict(torch.load('checkpoints/SDMSHN/sdmshn.pth',map_location=device))
# Inference only: put the model in evaluation mode.
bokehnet.eval()


os.makedirs('sample_outputs',exist_ok= True)

src_dir  = 'sample_inputs/'
# Sorted so the processing order is deterministic across runs/platforms.
listfiles = sorted(os.listdir(src_dir))

# Newer Pillow releases expose the filters on Image.Resampling; older ones
# only on the Image module itself. The recorded cell output shows a warning
# pointing at the resize call, which this lookup is meant to avoid.
resample_filter = getattr(pil, 'Resampling', pil).LANCZOS
to_tensor = transforms.ToTensor()

with torch.no_grad():
    for file in listfiles :

        image_path = os.path.join(src_dir, file)

        # os.listdir also returns sub-directories (e.g. notebook checkpoint
        # folders); those cannot be opened as images, so skip them.
        if not os.path.isfile(image_path):
            continue

        # Load image and preprocess. convert() returns a new, fully loaded
        # image, so the underlying file can be closed right away.
        with pil.open(image_path) as opened_image:
            input_image = opened_image.convert('RGB')
        original_width, original_height = input_image.size

        input_image = input_image.resize((feed_width, feed_height), resample_filter)
        input_image = to_tensor(input_image).unsqueeze(0)

        # PREDICTION
        input_image = input_image.to(device)

        bok_pred = bokehnet(input_image)

        # Resize the prediction back to the source image's resolution.
        # align_corners=False is the default for bilinear mode; spelled out
        # so the behavior is explicit.
        bok_pred = F.interpolate(bok_pred,(original_height,original_width),mode = 'bilinear',align_corners=False)

        # The output keeps the input's file name.
        save_image(bok_pred, os.path.join('./sample_outputs/', file))

print ("Done!")

  input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)


Done!


In [57]:
import cv2
import argparse
import os

from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

from tqdm import tqdm
import sys

# sys.path.insert(1,'PerceptualSimilarity')
# import models
from util import util

## Initializing the model
# model = models.PerceptualLoss(model='net-lin',net='alex',use_gpu=opt.use_gpu)

# crawl directories
input_dir = 'sample_inputs'
output_dir = 'sample_outputs'

# Sorted so the evaluation order is deterministic.
files = sorted(os.listdir(input_dir))

# total_dist = 0
total_psnr = 0
total_ssim = 0
# Names of the files that actually contributed to the totals.
evaluated_files = []

for file in tqdm(files):
    # The inference cell writes each prediction under the input's own file
    # name, so the counterpart of an input is the same name in output_dir.
    file1 = file
    input_path = os.path.join(input_dir, file)
    output_path = os.path.join(output_dir, file1)

    # Check the file that is actually read below; skip inputs that have no
    # prediction (and non-file entries such as sub-directories).
    if not (os.path.isfile(input_path) and os.path.isfile(output_path)):
        continue

    # LPIPS is disabled. To re-enable it, load the pair as tensors with
    # util.im2tensor(util.load_image(path)), move them to the GPU if needed,
    # and accumulate model.forward(img0, img1).item() into total_dist.

    I0 = cv2.imread(input_path)
    I1 = cv2.imread(output_path)
    # cv2.imread returns None (instead of raising) for unreadable files.
    if I0 is None or I1 is None:
        continue

    total_psnr += compare_psnr(I0,I1)
    # channel_axis replaces the deprecated `multichannel=True`; the images
    # read by cv2 keep their color channels in the last axis.
    total_ssim += compare_ssim(I0,I1,channel_axis=-1)
    evaluated_files.append(file)

# Downstream averages divide by len(files); keep only the files that were
# actually scored so skipped entries do not dilute the averages.
files = evaluated_files

  total_ssim += compare_ssim(I0,I1,multichannel=True)
100%|██████████| 4/4 [00:03<00:00,  1.24it/s]


In [58]:
# print ('Avg LPIPS: ', total_dist/len(files))
# Number of files the totals are averaged over.
num_files = len(files)
if num_files == 0:
    # Guard: an empty file list would otherwise raise ZeroDivisionError.
    print ('No files to average over.')
else:
    print ('Avg PSNR: ', total_psnr/num_files)
    print ('Avg SSIM: ', total_ssim/num_files)


Avg PSNR:  24.75821945388755
Avg SSIM:  0.7442593146718459
