In [None]:
from astropy.io import fits
from matplotlib import pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid
from IPython.display import clear_output
from IPython.html import widgets 
from IPython.display import display, clear_output
from plotly.widgets import GraphWidget
from IPython.html.widgets import FloatProgress
%matplotlib inline

### Fit a Chebyshev polynomial along the channel axis of every pixel

In [None]:
# Tunable parameters used by the fitting cells below.
POLY_ORDER, BEAT_TERMS = 30, 15  # Chebyshev degree along channels; n_term handed to compute_beat
img_corr = "RL"  # suffix spliced into the FITS filenames (presumably the correlation product -- TODO confirm)

In [None]:
# Path of the beam cubes; the correlation suffix (img_corr) and the part
# ("real"/"imag") are substituted in.  NOTE(review): absolute, site-specific
# path -- consider moving it to the configuration cell.
FITS_PATH_TEMPLATE = "/scratch/JVLA_HOLOGRAPHY/JVLA-L/ant10{corr}{part}.fits"


def _load_fits_cube(path):
    """Return the data array of the primary HDU of the FITS file at ``path``.

    The ``with`` block closes the file even if reading the data fails.
    """
    with fits.open(path) as hdul:
        return hdul[0].data


def _chebyshev_fit_columns(x, columns, order):
    """Least-squares Chebyshev fit of every column of ``columns`` against ``x``.

    ``columns`` has shape (len(x), K).  Returns ``(coef, fitted)`` where
    ``coef`` has shape (K, order + 1) -- one row of coefficients per column --
    and ``fitted`` is the evaluated fit with the same shape as ``columns``.
    All K columns are solved in a single least-squares call.
    """
    coef = np.polynomial.chebyshev.chebfit(x, columns, order)  # (order + 1, K)
    fitted = np.polynomial.chebyshev.chebvander(x, order).dot(coef)
    return coef.T, fitted


img_data_real = _load_fits_cube(FITS_PATH_TEMPLATE.format(corr=img_corr, part="real"))
img_data_imag = _load_fits_cube(FITS_PATH_TEMPLATE.format(corr=img_corr, part="imag"))
assert img_data_real.shape == img_data_imag.shape, "real/imag cubes differ in shape"

# The first axis is the channel axis (it is the one fitted); the two remaining
# axes are flattened so that every column is one pixel's spectrum.
n_chan = img_data_real.shape[0]
n_px = img_data_real.shape[1] * img_data_real.shape[2]
flattened_real_view = img_data_real.reshape((n_chan, n_px))
flattened_imag_view = img_data_imag.reshape((n_chan, n_px))
pxrange = np.arange(n_px)
chrange = np.arange(n_chan)

# Fit on channel indices mapped linearly onto [-1, 1], the interval on which
# the Chebyshev basis is well conditioned.  On the raw indices 0..n_chan-1 a
# degree-POLY_ORDER Chebyshev Vandermonde matrix is severely ill-conditioned
# (numpy may emit RankWarning and return a poor fit).
# NOTE: cheby_*_coef are therefore coefficients in this scaled variable;
# evaluate them with chebval(ch_scaled, coef), not chebval(chrange, coef).
ch_scaled = 2.0 * chrange / max(n_chan - 1, 1) - 1.0
cheby_real_coef, flattened_real_reconstructed = _chebyshev_fit_columns(
    ch_scaled, flattened_real_view, POLY_ORDER)
cheby_imag_coef, flattened_imag_reconstructed = _chebyshev_fit_columns(
    ch_scaled, flattened_imag_view, POLY_ORDER)

reconstructed_real = flattened_real_reconstructed.reshape(img_data_real.shape)
reconstructed_imag = flattened_imag_reconstructed.reshape(img_data_imag.shape)

diff_img_real = img_data_real - reconstructed_real
diff_img_imag = img_data_imag - reconstructed_imag
print("Done fitting...")

In [None]:
def compute_beat(l, m, img_data, reconstructed, n_term):
    """Model the residual at pixel (l, m) by its ``n_term`` strongest Fourier terms.

    The residual ``img_data[:, l, m] - reconstructed[:, l, m]`` is Fourier
    transformed.  Its positive-frequency bins 1 .. N//2 - 1 (DC and the last
    bin excluded) are ranked by magnitude.  Every bin (DC and the last bin
    included) whose magnitude is below the ``n_term``-th largest ranked
    magnitude is zeroed, and the inverse transform is returned.

    Parameters
    ----------
    l, m : int
        Indices into the second and third axes of ``img_data``/``reconstructed``.
    img_data, reconstructed : ndarray
        3-D real-valued arrays of identical shape; the first axis is transformed.
    n_term : int
        Number of ranked frequency bins to keep.  ``n_term <= 0``, or a residual
        too short to have any rankable bin, yields all zeros.  Values larger than
        the number of rankable bins keep all of them.

    Returns
    -------
    ndarray
        Real float array with the same length as the residual.
    """
    residue = img_data[:, l, m] - reconstructed[:, l, m]
    n_samples = residue.shape[0]
    # rfft/irfft enforce Hermitian symmetry, so a positive-frequency bin and its
    # negative-frequency mirror are always kept or dropped together.  With the
    # full fft, rounding could make the two magnitudes straddle the threshold.
    spectrum = np.fft.rfft(residue)
    magnitudes = np.abs(spectrum)
    ranked = np.sort(magnitudes[1:n_samples // 2])
    if n_term <= 0 or ranked.size == 0:
        return np.zeros(n_samples)
    # ranked is ascending, so ranked[-k] is the k-th largest magnitude: exactly
    # min(n_term, ranked.size) ranked bins survive the mask below.
    threshold = ranked[-min(n_term, ranked.size)]
    spectrum[magnitudes < threshold] = 0
    return np.fft.irfft(spectrum, n=n_samples)

# Add the strongest residual Fourier terms ("beating", see compute_beat) back
# onto the polynomial model, pixel by pixel, for the real and imaginary cubes.
pbar_fit = FloatProgress(min=0, max=100)
display(pbar_fit)
n_rows, n_cols = img_data_real.shape[1], img_data_real.shape[2]
reconstructed_real_with_beating = np.zeros(reconstructed_real.shape)
reconstructed_imag_with_beating = np.zeros(reconstructed_imag.shape)
for l in range(n_rows):
    pbar_fit.value = 100.0 * l / n_rows
    for m in range(n_cols):
        reconstructed_real_with_beating[:, l, m] = (
            reconstructed_real[:, l, m]
            + compute_beat(l, m, img_data_real, reconstructed_real, n_term=BEAT_TERMS))
        reconstructed_imag_with_beating[:, l, m] = (
            reconstructed_imag[:, l, m]
            + compute_beat(l, m, img_data_imag, reconstructed_imag, n_term=BEAT_TERMS))
# The loop only reports the start of each row, so mark completion explicitly.
pbar_fit.value = 100.0
diff_img_real_with_beating = img_data_real - reconstructed_real_with_beating
diff_img_imag_with_beating = img_data_imag - reconstructed_imag_with_beating
print("Done adding beating residual...")

In [None]:
# Interactive per-channel comparison.  Top row: real part; bottom row:
# imaginary part.  Columns: data, Chebyshev model, residual, and residual after
# the beating terms were added back to the model.
channel_slider = widgets.IntSlider()
channel_slider.min = 0
channel_slider.max = flattened_real_view.shape[0] - 1
channel_slider.value = 0
channel_slider.description = 'Channel'

# (cube, title) for each grid cell in row-major order; every cube is indexed
# by channel on its first axis.
panels = [
    (img_data_real, "Beam RE"),
    (reconstructed_real, "Reconstructed RE"),
    (diff_img_real, "RE difference"),
    (diff_img_real_with_beating, "RE difference (with beating)"),
    (img_data_imag, "Beam IM"),
    (reconstructed_imag, "Reconstructed IM"),
    (diff_img_imag, "IM difference"),
    (diff_img_imag_with_beating, "IM difference (with beating)"),
]

F = plt.figure(1, (15, 15))
# add_all=True was the default; the keyword is omitted because it was
# deprecated in Matplotlib 3.3 and removed in later releases.
grid = ImageGrid(F, 111,  # similar to subplot(111)
                 nrows_ncols=(2, 4),
                 direction="row",
                 axes_pad=0.5,
                 label_mode="1",
                 share_all=True,
                 cbar_location="right",
                 cbar_mode="each",
                 cbar_size="3%")
images = []
for idx, (cube, title) in enumerate(panels):
    image = grid[idx].imshow(cube[channel_slider.value], interpolation="nearest", cmap="cubehelix")
    grid[idx].set_title(title)
    grid[idx].cax.colorbar(image)
    images.append(image)
# Individual names kept for code that still refers to them.
im0, im1, im2, im3, im4, im5, im6, im7 = images


def stats(ch):
    """Print the RMS of each residual map at channel ``ch``."""
    for label, cube in (("Real (w/o beating)", diff_img_real),
                        ("Real (with beating)", diff_img_real_with_beating),
                        ("Imag (w/o beating)", diff_img_imag),
                        ("Imag (with beating)", diff_img_imag_with_beating)):
        print("RMS %s: %f" % (label, np.sqrt(np.mean(cube[ch] ** 2))))


stats(channel_slider.value)


def animate(*_):
    """Redraw every panel (and print the RMS stats) for the slider's channel.

    Any arguments are ignored, so this works both as a legacy
    ``on_trait_change`` callback and as an ``observe`` handler.
    """
    ch = channel_slider.value
    clear_output(wait=True)
    stats(ch)
    for image, (cube, _title) in zip(images, panels):
        image.set_data(cube[ch])
        # set_data keeps the previous colour limits; autoscale re-fits them to
        # this channel and fires the image's 'changed' callback, which the
        # attached colorbar listens to (so it is not re-created every frame).
        image.autoscale()
    F.canvas.draw()
    display(F)


if hasattr(channel_slider, "observe"):  # ipywidgets / traitlets >= 4.1
    channel_slider.observe(animate, names="value")
else:  # legacy IPython.html widgets
    channel_slider.on_trait_change(animate, 'value')
display(channel_slider)

In [None]:
def plot_beam_px(l, m):
    """Plot the real-part spectrum of pixel (l, m) across all channels.

    Three curves are drawn: the measured data, the Chebyshev model, and the
    model with the beating terms added back.
    """
    # A fresh figure: plt.figure(1, ...) would re-use figure number 1, which the
    # viewer cell assigns to the ImageGrid, and draw onto one of its panels.
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.plot(img_data_real[:, l, m], label="data (RE)")
    ax.plot(reconstructed_real[:, l, m], label="Chebyshev model (RE)")
    ax.plot(reconstructed_real_with_beating[:, l, m], label="model + beating (RE)")
    ax.set_xlabel("channel")
    ax.set_title("Pixel (%d, %d)" % (l, m))
    ax.legend()
    plt.show()


plot_beam_px(55, 55)