In [28]:
# ============= imports =============
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import torchvision.transforms as transforms
from Models import Unet_encoder, Unet_decoder, Classifier
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import copy
from test import get_metrics

In [4]:
# ============= torch cuda =============
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    # Hand cached-but-unused blocks held by PyTorch's CUDA allocator back to
    # the driver (useful when cells are re-run in a long-lived kernel).
    torch.cuda.empty_cache()

In [16]:
def motion_blur_horizontal(img_path,kernel_size):
    '''
        Apply a horizontal motion blur to the image stored at img_path.
        :param img_path: path to the input image (decoded with cv2.imread)
        :type img_path: str
        :param kernel_size: blur length in pixels along the x axis (>= 1)
        :type kernel_size: int
        :return: blurred image as a 3-channel RGB array (H, W, 3)
        :raises ValueError: if kernel_size < 1
        :raises FileNotFoundError: if the image cannot be read/decoded
    '''
    if kernel_size < 1:
        raise ValueError('kernel_size must be >= 1, got {}'.format(kernel_size))
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('could not read image: {}'.format(img_path))

    # A single row of 1/k through the kernel's centre averages kernel_size
    # neighbouring pixels horizontally.
    kernel_h = np.zeros((kernel_size, kernel_size))
    kernel_h[(kernel_size - 1) // 2, :] = 1.0 / kernel_size

    horizontal_mb = cv2.filter2D(img, -1, kernel_h)
    # cv2.imread yields BGR channel order; convert to RGB.
    return cv2.cvtColor(horizontal_mb, cv2.COLOR_BGR2RGB)
    
def motion_blur_vertical(img_path,kernel_size):
    '''
        Apply a vertical motion blur to the image stored at img_path.
        :param img_path: path to the input image (decoded with cv2.imread)
        :type img_path: str
        :param kernel_size: blur length in pixels along the y axis (>= 1)
        :type kernel_size: int
        :return: blurred image as a 3-channel RGB array (H, W, 3)
        :raises ValueError: if kernel_size < 1
        :raises FileNotFoundError: if the image cannot be read/decoded
    '''
    if kernel_size < 1:
        raise ValueError('kernel_size must be >= 1, got {}'.format(kernel_size))
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('could not read image: {}'.format(img_path))

    # A single column of 1/k through the kernel's centre averages kernel_size
    # neighbouring pixels vertically.
    kernel_v = np.zeros((kernel_size, kernel_size))
    kernel_v[:, (kernel_size - 1) // 2] = 1.0 / kernel_size

    vertical_mb = cv2.filter2D(img, -1, kernel_v)
    # cv2.imread yields BGR channel order; convert to RGB.
    return cv2.cvtColor(vertical_mb, cv2.COLOR_BGR2RGB)
    
def gaussian_blur(img_path,k_size):
    '''
        Apply a Gaussian blur to the image stored at img_path.
        :param img_path: path to the input image (decoded with cv2.imread)
        :type img_path: str
        :param k_size: kernel width/height; must be positive and odd because
            sigma is passed as 0 (OpenCV derives sigma from the kernel size)
        :type k_size: int
        :return: blurred image as a 3-channel RGB array (H, W, 3)
        :raises ValueError: if k_size is not a positive odd integer
        :raises FileNotFoundError: if the image cannot be read/decoded
    '''
    if k_size < 1 or k_size % 2 == 0:
        raise ValueError('k_size must be a positive odd integer, got {}'.format(k_size))
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('could not read image: {}'.format(img_path))
    blurred = cv2.GaussianBlur(img, (k_size, k_size), 0)
    # cv2.imread yields BGR channel order; convert to RGB.
    return cv2.cvtColor(blurred, cv2.COLOR_BGR2RGB)
    
def avg_blur(img_path,k_size):
    '''
        Apply a box (mean) blur to the image stored at img_path.
        :param img_path: path to the input image (decoded with cv2.imread)
        :type img_path: str
        :param k_size: side length of the square averaging window (>= 1)
        :type k_size: int
        :return: blurred image as a 3-channel RGB array (H, W, 3)
        :raises ValueError: if k_size < 1
        :raises FileNotFoundError: if the image cannot be read/decoded
    '''
    if k_size < 1:
        raise ValueError('k_size must be >= 1, got {}'.format(k_size))
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('could not read image: {}'.format(img_path))
    blurred = cv2.blur(img, (k_size, k_size))
    # cv2.imread yields BGR channel order; convert to RGB.
    return cv2.cvtColor(blurred, cv2.COLOR_BGR2RGB)
        
def getBlurredOutput(img_path, k_size=15):
    '''
        Blur an image with one of four blur types picked at random.
        :param img_path: path to the input image
        :type img_path: str
        :param k_size: kernel size forwarded to the chosen blur (default 15;
            must be odd for the Gaussian branch)
        :type k_size: int
        :return: blurred RGB image array produced by the chosen blur function
    '''
    # A single uniform draw in [0, 1) from torch's RNG (so torch.manual_seed
    # makes the choice reproducible) selects the blur type in 0.25 bands.
    p = torch.rand(1).item()

    if p <= 0.25:
        return motion_blur_vertical(img_path, k_size)
    if p <= 0.5:
        return motion_blur_horizontal(img_path, k_size)
    if p <= 0.75:
        return avg_blur(img_path, k_size)
    return gaussian_blur(img_path, k_size)

In [19]:
def image_loader(imgPath):
    '''
        Function to return image tensors for a ground-truth / blurred pair
        :params imgPath: path to input image
        :type imgPath: str
        :return: dictionary with the ground-truth image under 'image' and the
            blurred image under 'inputImg', both resized to 256x256 and moved
            to `device`
    '''
    from PIL import ImageOps

    # Match what cv2.imread (used by getBlurredOutput) does to the same file:
    # it applies the EXIF orientation and always decodes to 3 colour channels.
    # Without this, rotated JPEGs or RGBA/greyscale files would produce a
    # ground truth that is misaligned with, or has a different channel count
    # than, the blurred input.
    with Image.open(imgPath) as pil_image:
        image = ImageOps.exif_transpose(pil_image).convert('RGB')
    blurredImage = getBlurredOutput(imgPath)

    imageTransform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((256, 256))
    ])

    image = imageTransform(image)
    blurredImage = imageTransform(blurredImage)
    # NOTE(review): no batch dimension is added here (ToTensor yields C x H x W);
    # confirm the Unet models accept unbatched input.
    # Both tensors go to `device`; previously only the ground truth did, so the
    # blurred input fed to the (device-resident) encoder stayed on the CPU.
    return {'image': image.to(device), 'inputImg': blurredImage.to(device)}

In [None]:
# Test image to evaluate on — replace the placeholder with a real file path.
IMG_PATH = 'PATH'
imageDict = image_loader(IMG_PATH)

In [22]:
# Build the three networks on the selected device (construction order is
# encoder, decoder, classifier).
encoder, decoder, classifier = (
    Unet_encoder(in_channels=3).to(device),
    Unet_decoder().to(device),
    Classifier().to(device),
)

# Independent deep copies, presumably the ones adapted by test-time training
# (that step is not shown here) so the originals stay a clean baseline.
encoder_new, classifier_new = (copy.deepcopy(net).to(device) for net in (encoder, classifier))

# NOTE(review): decoder_new is an alias of `decoder`, not a copy, so any weight
# change made through one is visible through the other.
decoder_new = decoder

In [23]:
# =========== TODO: load the state dicts into both encoder and encoder_new (this cell is currently empty) ===========

In [29]:
# Switch every network used below to inference mode (disables dropout and makes
# BatchNorm use running statistics, if the models contain such layers).
# encoder_new/classifier_new are separate deep copies, so calling eval() on the
# originals does not affect them; they must be switched explicitly.
encoder.eval()
decoder.eval()
classifier.eval()
encoder_new.eval()
decoder_new.eval()
classifier_new.eval()

# Baseline: the original encoder/decoder pair.
with torch.no_grad():
    x, skip_connections = encoder(imageDict['inputImg'])
    pred_img = decoder(x, skip_connections)
    psnr, ssim, uqi = get_metrics(imageDict['image'], pred_img)
    print('Without Test Time Training')
    print('psnr: {:.4f}, ssim: {:.4f}, uqi: {:.4f}'.format(psnr, ssim, uqi))

# Same input through the *_new networks.
with torch.no_grad():
    x, skip_connections = encoder_new(imageDict['inputImg'])
    pred_img = decoder_new(x, skip_connections)
    psnr, ssim, uqi = get_metrics(imageDict['image'], pred_img)
    print('With Test Time Training')
    print('psnr: {:.4f}, ssim: {:.4f}, uqi: {:.4f}'.format(psnr, ssim, uqi))