# In [None]:
import numpy as np
import scipy
from scipy import stats
from scipy.ndimage.filters import convolve
try:
    from astropy.io import fits
except ImportError:
    import pyfits as fits
import matplotlib.pyplot as plt
import pdb

def pad_psf(psf, spectrum):
    """Embed ``psf`` in a zero array shaped like ``spectrum``, centred.

    The PSF is copied into the slice starting at
    ``len(spectrum) // 2 - len(psf) // 2``, so its centre sample lines up
    with the centre sample of ``spectrum``; all other entries are zero.

    Parameters
    ----------
    psf : array_like
        Point-spread function samples (1-D assumed).
    spectrum : ndarray
        Array whose shape the output takes; only its shape and length are
        used, not its values.

    Returns
    -------
    ndarray
        Float array with ``spectrum.shape`` containing ``psf`` in the central
        slice and zeros elsewhere.

    Raises
    ------
    ValueError
        If ``psf`` is longer than ``spectrum``.
    """
    psf = np.asarray(psf)
    n_psf = len(psf)
    n_spec = len(spectrum)
    if n_psf > n_spec:
        raise ValueError(
            "psf length {} exceeds spectrum length {}".format(n_psf, n_spec))

    out_psf = np.zeros(spectrum.shape)
    # Floor division keeps the slice bounds integral: on Python 3 '/' yields
    # a float and slicing with it raises TypeError.
    start = n_spec // 2 - n_psf // 2
    out_psf[start:start + n_psf] = psf

    return out_psf

def rl_fft(raw_image, psf, niter, k=1, con_var=None):
    """Richardson-Lucy deconvolution of a 1-D signal using FFT convolution.

    The data are first scaled so their mean is 10. The estimate starts at
    all ones and is updated ``niter`` times with the RL ratio
    ``(data + con_var) / (model + con_var)`` correlated with the PSF. The
    result is scaled back to the input units before returning.

    Parameters
    ----------
    raw_image : array_like
        Observed data (1-D assumed). It is NOT modified.
    psf : array_like
        Point-spread function with the same length as ``raw_image``, because
        the transforms are multiplied elementwise. Presumably centred at
        index ``len // 2`` (e.g. via ``pad_psf``), since the result is rotated
        by that amount at the end -- confirm with callers.
    niter : int
        Number of RL iterations. ``0`` returns the flat initial estimate.
    k : float, optional
        Scales both the update ratio and its normalisation, so it cancels
        in the update.
    con_var : float or ndarray, optional
        Constant added to data and model in the RL ratio. If None, it is
        estimated with ``sample_noise`` (defined elsewhere) on the scaled
        data.

    Returns
    -------
    ndarray
        Real-valued estimate, rotated left by ``size // 2`` samples and
        rescaled to the units of ``raw_image``.

    Notes
    -----
    Prints the noise term, the chi-square value and per-iteration
    max/mean/min diagnostics to stdout.
    """
    def calc_chisq(model, data, noise, npix):
        # NOTE(review): the denominator squares (model + noise); a Poisson +
        # read-noise chi-square normally divides by (model + noise) itself.
        # Kept unchanged so the printed diagnostic matches -- confirm intent.
        return np.sum((model - data) ** 2 / (model + noise) ** 2 / (npix - 1))

    # Scale a float copy. In-place division would overwrite the caller's
    # array, and it fails outright on integer input.
    raw_image = np.asarray(raw_image, dtype=float)
    conversion = raw_image.mean() / 10
    raw_image = raw_image / conversion

    size = raw_image.size
    lucy = np.ones(raw_image.shape)
    ratio = k * np.ones(raw_image.shape)
    fft_psf = np.fft.fft(psf)

    # Respect a caller-supplied noise term; only estimate it when absent.
    if con_var is None:
        con_var = sample_noise(raw_image)
    print("using: ", con_var)

    # Inverse transforms of products of real-signal spectra are real up to
    # rounding. Keep the real part so 'lucy' stays float and the in-place
    # update below is a legal numpy cast.
    norm = np.fft.ifft(np.fft.fft(ratio) * np.conj(fft_psf)).real
    lucy_conv = np.fft.ifft(fft_psf * np.fft.fft(lucy)).real

    chisq = calc_chisq(lucy_conv, raw_image, con_var, raw_image.size)
    print("initial Chisq: {}".format(chisq))

    for iteration in range(niter):
        ratio = k * (raw_image + con_var) / (lucy_conv + con_var)
        # Multiplying by conj(fft_psf) correlates the ratio with the PSF.
        fft_srat = np.fft.fft(ratio) * np.conj(fft_psf)

        lucy *= np.fft.ifft(fft_srat).real / norm
        print(lucy.max(), lucy.mean(), lucy.min())
        lucy_conv = np.fft.ifft(fft_psf * np.fft.fft(lucy)).real
        chisq = calc_chisq(lucy_conv, raw_image, con_var, raw_image.size)
        print("Iteration {} Chisq: {}".format(iteration, chisq))

    # Rotate so result[i] == lucy[(i + size // 2) % size]. This is the same
    # reordering as indexing with [size//2 .. size-1, 0 .. size//2 - 1].
    lucy = np.roll(lucy, -(size // 2))
    return lucy * conversion