In [None]:
import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import torch as pt
from torch.autograd import Variable
from torch import optim
from fsim import FSIM, FSIMc
from PIL import Image, ImageDraw, ImageFont
import os
%matplotlib inline

In [None]:
def read_convert_pt_image(image_path):
    '''
    Read an image file and convert it to a PyTorch tensor.

    The image is converted to 3-channel RGB by PIL before the tensor is built,
    so grayscale or RGBA files also yield three channels.

    Args:
        image_path: path of the image file to open.

    Returns:
        uint8 tensor of shape (3, H, W) with pixel values in 0..255.
    '''
    # Context manager closes the underlying file; convert() returns a new,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(image_path) as pil_image:
        rgb_image = pil_image.convert('RGB')

    # np.array (not np.asarray) copies the pixels: the buffer PIL exposes is
    # read-only, and torch.from_numpy on a non-writable array warns and returns
    # a tensor aliasing memory that must not be written to.
    pixels = np.array(rgb_image)

    # HWC (PIL/numpy layout) -> CHW (PyTorch layout)
    return pt.from_numpy(pixels).permute(2, 0, 1)

def _to_uint8_hwc(img_torch):
    '''
    Convert an image tensor to a uint8 numpy array in (H, W, C) layout.

    Accepts a (N, C, H, W) or (C, H, W) tensor on any device, with or without
    requires_grad. For a 4-D input only the first image is used. Values are
    clipped to [0, 255] on a copy, so the caller's tensor is never modified.
    '''
    img = img_torch.detach().cpu()
    if img.dim() == 4:
        # Keep only the first image of the batch.
        img = img[0]
    # On CPU this numpy view may share memory with img_torch ...
    hwc = img.permute(1, 2, 0).numpy()
    # ... so clip out-of-place: np.clip returns a new array.
    # The uint8 cast truncates fractional values.
    return np.uint8(np.clip(hwc, 0., 255.))


def save_image_score(img_torch, img_name, score, bw = False ):
    '''
    Save an image to the output folder with its FSIM score imprinted.

    The file is written to ``img_name + '.png'`` and the text
    ``FSIM=<score rounded to 3 decimals>`` is drawn in the top-left corner.

    Args:
        img_torch: image tensor of shape (1, C, H, W) or (C, H, W), values
            nominally in [0, 255]. May live on the GPU and may require grad.
            For a larger batch only the first image is saved. The tensor is
            not modified.
        img_name: output path without the ``.png`` extension.
        score: FSIM value to imprint.
        bw: accepted for backward compatibility and currently unused. A
            single-channel tensor is saved as grayscale automatically.
    '''
    pixels = _to_uint8_hwc(img_torch)
    if pixels.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel
        # axis and use a scalar fill for the resulting grayscale image.
        pixels = pixels[:, :, 0]
        text_fill = 255
    else:
        text_fill = (255, 255, 255)

    img = Image.fromarray(pixels)

    try:
        font = ImageFont.truetype("./misc/arial.ttf", 40)
    except OSError:
        # A missing font file should not abort a long optimisation run.
        font = ImageFont.load_default()

    d = ImageDraw.Draw(img)
    d.text((10,10), 'FSIM='+str(round(score*1000)/1000), font=font, fill=text_fill)

    img.save(img_name+'.png')

In [None]:
# The metric expects images to be in the range from 0 to 255.

# Path to reference image
img1_path ='./misc/mandril_color.tif'
# Is it black and white? (forwarded to save_image_score)
bw = False
# Size of the batch for training
batch_size = 1
# Do we regenerate the image from noise (True), or clean up the noise from the image (False)
noise = True
# Save image
save_image = False
# Optimisation settings
num_iterations = 1000
learning_rate = 0.1
# Log progress (and save a snapshot when save_image is True) every `log_every` iterations
log_every = 20
# Seed for the random distortion so that runs are reproducible
seed = 0

pt.manual_seed(seed)

if save_image:
    os.makedirs('output', exist_ok=True)

device = pt.device('cuda' if pt.cuda.is_available() else 'cpu')

# Read reference and distorted images
img1 = read_convert_pt_image(img1_path).unsqueeze(0).float()
if noise:
    img2 = pt.clamp(pt.rand(img1.size())*255.0,0,255.0)
else:
    img2 = pt.clamp(img1+200*pt.rand(img1.size()),0,255.0)

# Create fake batch (for testing).
# Move to the device BEFORE enabling gradients: calling .cuda() on a tensor
# that already requires grad returns a non-leaf tensor, which optim.Adam
# refuses to optimise ("can't optimize a non-leaf Tensor").
img1b = pt.cat(batch_size*[img1],0).to(device)
img2b = pt.cat(batch_size*[img2],0).to(device).requires_grad_(True)

# Create FSIM loss
FSIM_loss = FSIMc()

# Tie optimizer to the distorted batch
optimizer = optim.Adam([img2b], lr=learning_rate)

# Check if the gradient propagates: maximise FSIM by minimising its negative sum
for ii in range(num_iterations):
    optimizer.zero_grad()

    loss = pt.sum(-FSIM_loss(img1b,img2b))
    loss.backward()
    optimizer.step()

    if ii % log_every == 0:
        print(f'iter {ii}: loss = {loss.item():.6f}')
        if save_image:
            # Detach, move to CPU and copy so the snapshot works for GPU
            # tensors and cannot alias the optimised tensor.
            save_image_score(img2b.detach().cpu().clone(),
                             './output/optimized_image_'+str(ii),
                             loss.item()*-1.0, bw=bw)

