In [None]:
import pickle
import os

# Directory with the pickled arrays for this sample, resolved against the
# current working directory.
FOLDER = os.getcwd() + "/7cells_blurry"
# Regularisation constant: divided by the peak spectral power and added to the
# denominator of the Wiener-style PSF estimates further down.
ASSUMED_NOISE_LEVEL = 1e7


def _load_pickle(name):
    """Unpickle and return FOLDER/<name>.pkl.

    SECURITY: pickle.load can execute arbitrary code - only point FOLDER at
    trusted data.
    """
    with open(FOLDER + '/' + name + '.pkl', 'rb') as handle:
        return pickle.load(handle)


# Loaded in the same order as the file names are listed.
ori, mask, out, psf = (_load_pickle(name) for name in ('ori', 'mask', 'out', 'psf'))


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
plt.rcParams['figure.figsize'] = [15, 5]
import numpy as np
from scipy import fftpack
from scipy.ndimage import gaussian_filter
from utils import GetSourcePts, InverseMatrix, ApproxPSFBesselOptimise, GetAvgIntensityWithMask
import cv2

In [None]:
# Rescale each image so its peak value is 1, putting all three on one scale.
# np.max reduces in C; the builtin max(arr.flatten()) walks every pixel in Python.
ori = ori / np.max(ori)
mask = mask / np.max(mask)
out = out / np.max(out)

f, axarr = plt.subplots(1, 3)
for ax, image, title in zip(axarr, (ori, mask, out), ("ori", "mask", "out")):
    ax.imshow(image)
    ax.set_title(title)
plt.show()

In [None]:
# Apply a light Gaussian filter (sigma=1) to the measured output.
outBlur = gaussian_filter(out, sigma=1)

# Resize the mask onto the output's grid. cv2.resize takes dsize as
# (width, height), hence the reversed shape; nearest-neighbour interpolation
# copies source pixel values instead of blending neighbouring labels.
maskRescale = cv2.resize(mask, outBlur.shape[::-1], interpolation=cv2.INTER_NEAREST)

# IMPT: all non-zero mask values are set to one (binary mask), vectorised
# instead of a per-pixel Python loop.
maskUnit = maskRescale.copy()
maskUnit[maskUnit != 0] = 1

# Separate axes, so the blurred output is not drawn underneath the mask.
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(outBlur)
axarr[0].set_title("outBlur")
axarr[1].imshow(maskUnit)
axarr[1].set_title("maskUnit")
plt.show()

In [None]:
# Centred 2D spectra (zero frequency moved to the middle of the array) of the
# original image, the binary mask and the smoothed output.
def _centred_spectrum(image):
    """Return fftshift(fft2(image))."""
    return fftpack.fftshift(fftpack.fft2(image))


oriFreq, maskFreq, outFreq = (_centred_spectrum(img) for img in (ori, maskUnit, outBlur))

f, axarr = plt.subplots(1,3)
for ax, spectrum in zip(axarr, (oriFreq, maskFreq, outFreq)):
    ax.imshow(np.abs(spectrum))
plt.show()

In [None]:
# Target PSF size: the first two dimensions of mask.shape, with any even
# dimension bumped up by one so both are odd (presumably so the PSF has a single
# centre pixel - TODO confirm).
psfShape = [dim if dim % 2 else dim + 1 for dim in mask.shape[:2]]


In [None]:
# Validation (only possible when output and input share a shape): estimate the
# PSF from the known original image and compare the right half of its centre
# row with the ground-truth PSF.
if (out.shape == ori.shape):
    # Wiener-style estimate H = Y*conj(X) / (|X|^2 + K), where K is
    # ASSUMED_NOISE_LEVEL divided by the peak spectral power of X.
    psfFreqIdeal = outFreq * np.conj(oriFreq) / (np.abs(oriFreq)**2 + ASSUMED_NOISE_LEVEL/np.max(np.abs(oriFreq)**2))
    # ifft2 zero-pads / truncates to psfShape; fftshift then centres the result.
    psfIdeal = fftpack.ifft2(psfFreqIdeal, shape=psfShape)
    psfIdeal = fftpack.fftshift(psfIdeal)

    f, axarr = plt.subplots(1, 2)
    axarr[0].imshow(np.abs(psfFreqIdeal))
    axarr[0].set_title("|psfFreqIdeal|")
    axarr[1].imshow(np.abs(psfIdeal))
    axarr[1].set_title("|psfIdeal|")
    plt.show()

    # mask.shape is (rows, cols) = (height, width); only the pixel count is used.
    (maskHeight, maskWidth) = mask.shape
    maskSize = maskWidth*maskHeight

    # Right half of the ground-truth PSF's middle row, from the centre outwards.
    # x = pixel offset from the centre divided by the mask's pixel count; arange
    # keeps exactly one sample per pixel (linspace(0, 1, n)*n would stretch the
    # axis by n/(n-1) and misalign profiles of different lengths).
    original_psf_y = psf[len(psf)//2]
    original_psf_y = original_psf_y[len(original_psf_y)//2:]
    original_psf_x = np.arange(len(original_psf_y))/maskSize
    # NOTE(review): only the estimate is rescaled to peak 1 below - confirm psf
    # is already peak-normalised, otherwise the curves are on different scales.
    plt.plot(original_psf_x, original_psf_y, label="ground-truth psf")

    new_psf_y = np.abs(psfIdeal)
    new_psf_y = new_psf_y/np.max(new_psf_y)
    new_psf_y = new_psf_y[len(new_psf_y)//2]
    new_psf_y = new_psf_y[len(new_psf_y)//2:]
    new_psf_x = np.arange(len(new_psf_y))/maskSize
    plt.plot(new_psf_x, new_psf_y, label="psfIdeal (from ori)")

    plt.xlabel("offset from centre / mask pixel count")
    plt.ylabel("PSF value")
    plt.legend()
    plt.show()

In [None]:

# Estimate the PSF without the ground-truth image: deconvolve the output by the
# binary mask, with the same Wiener-style regularisation as the validation step
# (K = ASSUMED_NOISE_LEVEL / peak spectral power of the mask).
psfFreqMask = outFreq * np.conj(maskFreq) / (np.abs(maskFreq)**2 + ASSUMED_NOISE_LEVEL/np.max(np.abs(maskFreq)**2))
psfMask = fftpack.ifft2(psfFreqMask, shape=psfShape)
psfMask = fftpack.fftshift(psfMask)

f, axarr = plt.subplots(1, 3)
axarr[0].imshow(np.abs(psfFreqMask))
axarr[0].set_title("|psfFreqMask|")
axarr[1].imshow(np.abs(psfMask))
axarr[1].set_title("|psfMask|")

# Fit a model to the raw estimate with the utils helper ApproxPSFBesselOptimise
# (presumably a Bessel-profile fit, per its name - TODO confirm; cutoff=0.5 is
# interpreted by that helper). A plain gaussian_filter smoothing was an
# earlier alternative here.
psfModel = ApproxPSFBesselOptimise(np.abs(psfMask), cutoff=0.5)

axarr[2].imshow(np.abs(psfModel))
axarr[2].set_title("|psfModel|")
plt.show()


In [None]:
# Validation (only possible when output and input share a shape): compare the
# right half of the modelled PSF's centre row with the ground-truth PSF.
if (out.shape == ori.shape):
    # mask.shape is (rows, cols) = (height, width); only the pixel count is used.
    (maskHeight, maskWidth) = mask.shape
    maskSize = maskWidth*maskHeight

    # Right half of the ground-truth PSF's middle row, from the centre outwards.
    # x = pixel offset from the centre divided by the mask's pixel count; arange
    # keeps exactly one sample per pixel (linspace(0, 1, n)*n would stretch the
    # axis by n/(n-1) and misalign profiles of different lengths).
    original_psf_y = psf[len(psf)//2]
    original_psf_y = original_psf_y[len(original_psf_y)//2:]
    original_psf_x = np.arange(len(original_psf_y))/maskSize
    # NOTE(review): only the model is rescaled to peak 1 below - confirm psf is
    # already peak-normalised, otherwise the curves are on different scales.
    plt.plot(original_psf_x, original_psf_y, label="ground-truth psf")

    new_psf_y = np.abs(psfModel)
    new_psf_y = new_psf_y/np.max(new_psf_y)
    new_psf_y = new_psf_y[len(new_psf_y)//2]
    new_psf_y = new_psf_y[len(new_psf_y)//2:]
    new_psf_x = np.arange(len(new_psf_y))/maskSize
    plt.plot(new_psf_x, new_psf_y, label="psfModel (from mask)")

    plt.xlabel("offset from centre / mask pixel count")
    plt.ylabel("PSF value")
    plt.legend()
    plt.show()

In [None]:
# Source points for the original-resolution mask (used for the ground-truth
# intensity comparison later) and for the mask resampled onto the output grid.
# GetSourcePts comes from utils; presumably one point per mask region - TODO confirm.
sourcePts = GetSourcePts(mask)
sourcePtsRescale = GetSourcePts(maskRescale)
# Reconstruct an image from the smoothed output, the rescaled mask, its source
# points and the fitted PSF model. InverseMatrix is an opaque utils helper; the
# meaning of adjPts=0 and learningRate=1 is defined there.
recoveredImage = InverseMatrix(outBlur, maskRescale, sourcePtsRescale, adjPts=0, learningRate = 1, psf=psfModel)

In [None]:
# Side-by-side panels: measured output, original mask, recovered image.
f, axarr = plt.subplots(1,3)
for ax, panel in zip(axarr, (out, mask, recoveredImage)):
    ax.imshow(panel)
plt.show()

In [None]:
# Average intensity inside each mask region: ground truth (ori on the original
# mask) versus raw output and recovered image (both on the rescaled mask).
average_original_mask, average_original_intensity = GetAvgIntensityWithMask(ori,mask,sourcePts)
average_output_mask, average_output_intensity = GetAvgIntensityWithMask(out,maskRescale,sourcePtsRescale)
average_custom_mask, average_custom_intensity = GetAvgIntensityWithMask(recoveredImage,maskRescale,sourcePtsRescale)

# The series are paired element-wise below (scatter plot and corrcoef), so they
# must have identical shapes. Note that `a.all() == b.all()` would only compare
# two booleans ("is every value non-zero?"), not the arrays themselves.
assert np.shape(average_custom_intensity) == np.shape(average_original_intensity), \
    "recovered and original intensity series differ in shape"
assert np.shape(average_output_intensity) == np.shape(average_original_intensity), \
    "output and original intensity series differ in shape"

plt.plot(average_original_intensity, average_output_intensity, 'bo', label="Default (out)")
plt.plot(average_original_intensity, average_custom_intensity, 'ro', label="Custom (recovered)")
plt.xlabel("original average intensity")
plt.ylabel("estimated average intensity")
plt.legend()
plt.show()

# Pearson correlation coefficient: the off-diagonal [0, 1] entry of the 2x2
# matrix returned by np.corrcoef.
print("Default: ", np.corrcoef(average_original_intensity, average_output_intensity)[0, 1])
print("Custom: ", np.corrcoef(average_original_intensity, average_custom_intensity)[0, 1])
