In [1]:
import copy
import itertools
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pywt
from IPython.display import clear_output
from pynufft import NUFFT
from scipy import signal as sci_signal
from scipy.constants import c

from csromer.base import Dataset
from csromer.dictionaries.discrete import DiscreteWavelet
from csromer.dictionaries.undecimated import UndecimatedWavelet
from csromer.io import Reader, Writer
from csromer.objectivefunction import L1, TSV, TV, Chi2, OFunction
from csromer.optimization import FISTA, GradientBasedMethod
from csromer.reconstruction import Parameter
from csromer.simulation import FaradayThickSource, FaradayThinSource
from csromer.transformers import NDFT1D, NUFFT1D, Gridding
from csromer.utils import Gaussian, complex_to_real, real_to_complex

%matplotlib inline

#np.random.seed(666)

In [2]:
def gini_coefficient(w):
    """Return the Gini index of the coefficient vector ``w`` (a sparsity measure).

    The magnitudes of ``w`` are sorted in ascending order and combined as

        G = 1 - 2 * sum_{k=1..M} (|w|_(k) / ||w||_1) * (M - k + 0.5) / M

    which is 0 for a vector whose entries all have the same magnitude and
    ``1 - 1/M`` for a vector with a single non-zero entry.

    Parameters
    ----------
    w : array_like
        Real or complex coefficients; multi-dimensional input is flattened.

    Returns
    -------
    float
        The Gini index, or ``nan`` when ``w`` is empty or identically zero
        (the L1 normalisation is undefined in that case).
    """
    magnitudes = np.sort(np.abs(np.ravel(w)), kind="stable")
    M = magnitudes.size
    l1_norm = np.sum(magnitudes)
    if M == 0 or l1_norm == 0.0:
        return np.nan
    # Ranks run from 1 to M; the previous weight (M - m + 1.5) / M with
    # m = 0..M-1 was offset by two and could push the index below zero.
    rank = np.arange(1, M + 1)
    weights = (M - rank + 0.5) / M
    return 1.0 - 2.0 * np.sum((magnitudes / l1_norm) * weights)

In [3]:
def chi2_calc(residuals):
    """Return the sum of squared residuals.

    For ``complex64``/``complex128`` arrays the real and imaginary parts are
    squared separately and added; any other dtype is squared element-wise.
    """
    complex_dtypes = (np.complex64, np.complex128)
    if residuals.dtype not in complex_dtypes:
        return np.sum(residuals**2)
    real_part, imag_part = residuals.real, residuals.imag
    return np.sum(real_part**2 + imag_part**2)

In [4]:
def aicbic(residuals, x):
    """Return the ``(AIC, BIC)`` pair for a fit with residuals ``residuals``.

    The number of free parameters is the count of non-zero entries of ``x``
    (real and imaginary parts counted separately for complex ``x``).  The
    number of observations is taken as ``2 * len(residuals)``.

    NOTE(review): the factor 2 is applied even when ``residuals`` is real --
    confirm this is intended for non-complex data.
    """
    if x.dtype in (np.complex64, np.complex128):
        n_params = np.count_nonzero(x.real) + np.count_nonzero(x.imag)
    else:
        n_params = np.count_nonzero(x)

    n_obs = 2 * len(residuals)
    # Shared goodness-of-fit term n * ln(RSS / n) of both criteria.
    fit_term = n_obs * np.log(chi2_calc(residuals) / n_obs)

    aic = fit_term + 2 * n_params
    bic = fit_term + n_params * np.log(n_obs)
    return aic, bic

In [5]:
def list_to2darray(x: list = None, cols: int = None, dtype=None):
    """Reshape the flat sequence ``x`` into a 2-D array with ``cols`` columns.

    Consecutive groups of ``cols`` items become rows; trailing items that do
    not fill a complete row are dropped.  ``dtype`` is forwarded to
    ``np.array`` (``None`` lets NumPy infer it).
    """
    shared_iter = iter(x)
    # Every zip argument is the *same* iterator, so each output tuple consumes
    # ``cols`` consecutive items of ``x``.
    rows = [list(group) for group in zip(*(shared_iter for _ in range(cols)))]
    return np.array(rows, dtype=dtype)

In [6]:
class statistics:
    """Running mean / standard deviation of ``(m, n)`` grids, ignoring NaNs.

    Each cell starts with an expected sample count ``z``; every NaN fed to
    ``cumul`` for a cell lowers that cell's count by one so that ``mean`` and
    ``std`` are computed over valid samples only.
    """

    def __init__(self, m, n, z):
        """Create accumulators for an ``m`` x ``n`` grid expecting ``z`` samples per cell."""
        # float64 accumulators: the variance is obtained as E[x^2] - E[x]^2,
        # which loses most of its digits in single precision.
        self.sum = np.zeros((m, n), dtype=np.float64)
        self.sum2 = np.zeros((m, n), dtype=np.float64)
        self.n = z * np.ones((m, n), dtype=np.int32)

    def cumul(self, x):
        """Accumulate one ``(m, n)`` sample grid; NaN entries are skipped."""
        x = np.asarray(x, dtype=np.float64)
        # NaN never compares equal to anything (itself included), so the test
        # must be np.isnan rather than ``x == np.nan``.
        missing = np.isnan(x)
        x_values = np.where(missing, 0.0, x)
        self.sum += x_values
        self.sum2 += x_values * x_values
        self.n -= missing.astype(self.n.dtype)

    def _safe_count(self):
        """Return ``(valid_mask, count)`` with zero counts replaced by 1 to avoid dividing by zero."""
        valid = self.n > 0
        return valid, np.where(valid, self.n, 1)

    def mean(self):
        """Return the per-cell mean (NaN where no valid sample was seen)."""
        valid, count = self._safe_count()
        return np.where(valid, self.sum / count, np.nan)

    def std(self):
        """Return the per-cell population standard deviation (NaN where no valid sample was seen)."""
        valid, count = self._safe_count()
        cell_mean = self.sum / count
        # Round-off can make the difference marginally negative; clamp at zero
        # so the square root does not produce spurious NaNs.
        variance = np.maximum(self.sum2 / count - cell_mean * cell_mean, 0.0)
        return np.where(valid, np.sqrt(variance), np.nan)

In [7]:
class Test:
    """One simulated Faraday-depth reconstruction experiment.

    Construction simulates the selected source(s) on a linear frequency grid,
    optionally removes channels, adds noise and grids the data.  ``run`` then
    reconstructs the spectrum with FISTA and stores the quality metrics
    (``snr``, ``psnr``, ``res_noise``, ``rmse``, ``sparsity``, ``gini``,
    ``aic``, ``bic``) on the instance.
    """

    # Complex noise level handed to the source's ``apply_noise``.
    # NOTE(review): ``noise_frac`` only switches noise on or off; its magnitude
    # is never used and ``avg_signal`` is computed but not consumed in this
    # class -- confirm whether the noise was meant to scale with them.
    NOISE_LEVEL = 2.2804e-03 + 2.38714e-03j

    def __init__(
        self,
        nu_min=None,
        nu_max=None,
        nchannels=None,
        noise_frac=None,
        remove_frac=None,
        use_gridding=False,
        ftransform="nufft",
        use_wavelet=None,
        source_1=None,
        source_2=None,
        scenario=1,
        append_signal=False,
    ):
        """Build the simulated dataset.

        Parameters
        ----------
        nu_min, nu_max, nchannels
            Bounds and size of the linear frequency grid (``np.linspace``).
        noise_frac
            Truthy to add noise at ``NOISE_LEVEL`` during construction.
        remove_frac
            Truthy to call ``remove_channels`` on the source with this value.
        use_gridding
            If true, the dataset is replaced by the output of ``Gridding``.
        ftransform
            ``"nufft"`` makes ``run`` use ``NUFFT1D`` in the data-fidelity
            term; any other value uses the ``NDFT1D`` transform.
        use_wavelet
            Wavelet name for ``UndecimatedWavelet``; falsy disables wavelets.
        source_1, source_2
            Source objects; they are deep-copied, so the caller's instances
            are left untouched.
        scenario
            1 -> ``source_1``, 2 -> ``source_2``, 3 -> their sum.
        append_signal
            Forwarded to the wavelet constructor.

        Raises
        ------
        ValueError
            If ``scenario`` is not 1, 2 or 3.
        """
        self.nu_min = nu_min
        self.nu_max = nu_max
        self.nchannels = nchannels
        self.noise_frac = noise_frac
        self.remove_frac = remove_frac
        self.use_gridding = use_gridding
        self.use_wavelet = use_wavelet
        self.ftransform = ftransform
        self.scenario = scenario
        self.append_signal = append_signal
        self.nu = np.linspace(start=nu_min, stop=nu_max, num=nchannels)
        self.source_1 = copy.deepcopy(source_1)
        self.source_2 = copy.deepcopy(source_2)

        for component in (self.source_1, self.source_2):
            if component is not None:
                component.nu = self.nu
                component.simulate()

        if scenario == 1:
            self.source = self.source_1
        elif scenario == 2:
            self.source = self.source_2
        elif scenario == 3:
            self.source = self.source_1 + self.source_2
        else:
            raise ValueError("This scenario does not exist")

        if remove_frac:
            # A clock-seeded generator (int(time.time())) hands the same random
            # stream to every object created within one second; seed from OS
            # entropy instead so repeated samples are independent.
            self.source.remove_channels(remove_frac, np.random.RandomState())

        if scenario == 1:
            self.avg_signal = np.abs(self.source_1.s_nu)
        elif scenario == 2:
            self.avg_signal = np.abs(self.source_2.s_nu)
        else:
            self.avg_signal = (np.abs(self.source_1.s_nu) + np.abs(self.source_2.s_nu)) / 2.0

        if noise_frac:
            self.source.apply_noise(self.NOISE_LEVEL)

        if use_gridding:
            gridding = Gridding(self.source)
            self.source = gridding.run()

    def apply_noise(self):
        """Add noise at ``NOISE_LEVEL`` to the dataset when ``noise_frac`` is truthy."""
        if self.noise_frac:
            # Entropy-seeded generator; see the note in __init__ about
            # clock-based seeds colliding within the same second.
            self.source.apply_noise(self.NOISE_LEVEL, np.random.RandomState())

    def run(self, lambda_tv: float = None, lambda_tsv: float = None):
        """Reconstruct the Faraday spectrum and store the quality metrics.

        Parameters
        ----------
        lambda_tv
            Accepted for interface compatibility only: no TV term is part of
            the objective that is minimised here.
        lambda_tsv
            Regularisation weight of the TSV term (``None`` means 0).

        Results are stored on the instance (``F_dirty``, ``X``, ``X_residual``,
        ``X_restored``, ``coeffs`` when a wavelet is used, and the scalar
        metrics) and a short summary is printed.
        """
        # One truthiness test for every wavelet branch.  Mixing ``if x:`` with
        # ``if x is not None:`` made falsy non-None values (e.g. "" or False)
        # reach ``self.wavelet`` without it ever being created.
        has_wavelet = bool(self.use_wavelet)

        self.parameter = Parameter()
        self.parameter.calculate_cellsize(dataset=self.source)

        dft = NDFT1D(dataset=self.source, parameter=self.parameter)

        self.F_dirty = dft.backward(self.source.data)

        n_data = len(self.source.data)
        self.lambda_l1 = np.sqrt(n_data + 2. * np.sqrt(n_data)
                                 ) * np.sqrt(2) * np.mean(self.source.sigma)
        if has_wavelet:
            # DiscreteWavelet takes the same arguments if the decimated
            # transform is wanted instead.
            self.wavelet = UndecimatedWavelet(
                wavelet_name=self.use_wavelet,
                mode="periodization",
                append_signal=self.append_signal
            )
            # The wavelet case uses twice the L1 weight of the plain case.
            self.lambda_l1 *= 2.0

        if lambda_tsv is None:
            lambda_tsv = 0.0

        if self.ftransform == "nufft":
            transform = NUFFT1D(dataset=self.source, parameter=self.parameter, solve=True)
        else:
            transform = dft

        if has_wavelet:
            chi2 = Chi2(dft_obj=transform, wavelet=self.wavelet)
        else:
            chi2 = Chi2(dft_obj=transform)

        l1 = L1(reg=self.lambda_l1)
        tsv = TSV(reg=lambda_tsv)

        F_obj = OFunction([chi2, l1, tsv])
        g_obj = OFunction([l1, tsv])

        self.parameter.data = self.F_dirty

        self.parameter.complex_data_to_real()

        if has_wavelet:
            self.parameter.data = self.wavelet.decompose(self.parameter.data)

        opt = FISTA(
            guess_param=self.parameter,
            F_obj=F_obj,
            fx=chi2,
            gx=g_obj,
            noise=2. * self.source.theo_noise,
            verbose=False
        )
        self.obj, self.X = opt.run()

        if has_wavelet:
            # Sparsity statistics are taken on the wavelet coefficients, before
            # transforming the solution back.
            self.coeffs = copy.deepcopy(self.X.data)
            k = np.count_nonzero(self.coeffs)
            self.sparsity = k / len(self.coeffs)
            self.gini = gini_coefficient(self.coeffs)
            self.X.data = self.wavelet.reconstruct(self.X.data)
            self.aic, self.bic = aicbic(self.source.residual, self.coeffs)

        else:
            k = np.count_nonzero(self.X.data)
            self.sparsity = k / len(self.X.data)
            self.gini = gini_coefficient(self.X.data)
            self.aic, self.bic = aicbic(self.source.residual, self.X.data)

        self.X.real_data_to_complex()

        self.X_residual = dft.backward(self.source.residual)

        self.X_restored = self.X.convolve() + self.X_residual

        self.res_noise = 0.5 * (np.std(self.X_residual.real) + np.std(self.X_residual.imag))
        self.rmse = np.sqrt(
            np.sum(self.source.residual.real**2 + self.source.residual.imag**2) /
            (2 * len(self.source.residual))
        )
        meaningful_signal = np.where(np.abs(self.parameter.phi) < self.parameter.max_faraday_depth)
        self.signal = np.mean(np.abs(self.X_restored[meaningful_signal]))
        self.peak_signal = np.max(np.abs(self.X_restored))

        self.snr = self.signal / self.res_noise
        self.psnr = self.peak_signal / self.res_noise

        print("Signal-to-noise ratio: {0}".format(self.snr))
        print("Peak Signal-to-noise ratio: {0}".format(self.psnr))
        print("Standard deviation: ({0})*10**-5".format(self.res_noise * 10**5))
        print("Root Mean Squared Error: ({0})*10**-5".format(self.rmse * 10**5))
        print("\n")
        # The string below is disabled plotting code kept for reference; as a
        # bare string expression it has no effect at run time.
        """
        
        self.fig, self.ax = plt.subplots(nrows=2, ncols=4, sharey='row', figsize=(18, 5))

        # Data
        self.ax[0,0].plot(self.source.lambda2, self.source.data.real, 'k.', label=r"Stokes $Q$")
        self.ax[0,0].plot(self.source.lambda2, self.source.data.imag, 'c.', label=r"Stokes $U$")
        self.ax[0,0].plot(self.source.lambda2, np.abs(self.source.data), 'g.', label=r"$|P|$")
        self.ax[0,0].set_xlabel(r'$\lambda^2$[m$^{2}$]')
        self.ax[0,0].set_ylabel(r'Jy/beam')
        self.ax[0,0].title.set_text("Data")

        self.ax[1,0].plot(self.parameter.phi, self.F_dirty.real, 'c--', label=r"Stokes $Q$")
        self.ax[1,0].plot(self.parameter.phi, self.F_dirty.imag, 'c:', label=r"Stokes $U$")
        self.ax[1,0].plot(self.parameter.phi, np.abs(self.F_dirty), 'k-', label=r"|P|")
        self.ax[1,0].set_xlabel(r'$\phi$[rad m$^{-2}$]')
        self.ax[1,0].set_ylabel(r'Jy/beam m$^2$ rad$^{-1}$ rmtf$^{-1}$')
        self.ax[1,0].set_xlim([-1000,1000])

        # Model
        self.ax[0,1].plot(self.source.lambda2, self.source.model_data.real, 'k.', label=r"Stokes $Q$")
        self.ax[0,1].plot(self.source.lambda2, self.source.model_data.imag, 'c.', label=r"Stokes $U$")
        self.ax[0,1].plot(self.source.lambda2, np.abs(self.source.model_data), 'g.', label=r"$|P|$")
        self.ax[0,1].set_xlabel(r'$\lambda^2$[m$^{2}$]')
        self.ax[0,1].set_ylabel(r'Jy/beam')
        self.ax[0,1].title.set_text("Model")

        self.ax[1,1].get_shared_y_axes().remove(self.ax[1,1])
        self.ax[1,1].clear()
        self.ax[1,1].plot(self.parameter.phi, self.X.data.real, 'c--', label=r"Stokes $Q$")
        self.ax[1,1].plot(self.parameter.phi, self.X.data.imag, 'c:', label=r"Stokes $U$")
        self.ax[1,1].plot(self.parameter.phi, np.abs(self.X.data), 'k-', label=r"$|P|$")
        self.ax[1,1].set_xlabel(r'$\phi$[rad m$^{-2}$]')
        self.ax[1,1].set_ylabel(r'Jy/beam m$^2$ rad$^{-1}$ pix$^{-1}$')
        self.ax[1,1].set_xlim([-1000,1000])
        
        # Residual

        self.ax[0,2].plot(self.source.lambda2, self.source.residual.real, 'k.', label=r"Stokes $Q$")
        self.ax[0,2].plot(self.source.lambda2, self.source.residual.imag, 'c.', label=r"Stokes $U$")
        self.ax[0,2].plot(self.source.lambda2, np.abs(self.source.residual), 'g.', label=r"$|P|$")
        self.ax[0,2].set_xlabel(r'$\lambda^2$[m$^{2}$]')
        self.ax[0,2].set_ylabel(r'Jy/beam')
        self.ax[0,2].title.set_text("Residual")

        self.ax[1,2].plot(self.parameter.phi, self.X_residual.real, 'c--', label=r"Stokes $Q$")
        self.ax[1,2].plot(self.parameter.phi, self.X_residual.imag, 'c:', label=r"Stokes $U$")
        self.ax[1,2].plot(self.parameter.phi, np.abs(self.X_residual), 'k-', label=r"$|P|$")
        self.ax[1,2].set_xlabel(r'$\phi$[rad m$^{-2}$]')
        self.ax[1,2].set_ylabel(r'Jy/beam m$^2$ rad$^{-1}$ rmtf$^{-1}$')
        self.ax[1,2].set_xlim([-1000,1000])
        
        if self.use_wavelet:
            self.ax[0,3].get_shared_y_axes().remove(self.ax[0,3])
            self.ax[0,3].clear()
            self.ax[0,3].plot(self.coeffs)
            self.ax[0,3].title.set_text("Coefficients")

        self.ax[1,3].plot(self.parameter.phi, self.X_restored.real, 'c--', label=r"Stokes $Q$")
        self.ax[1,3].plot(self.parameter.phi, self.X_restored.imag, 'c:', label=r"Stokes $U$")
        self.ax[1,3].plot(self.parameter.phi, np.abs(self.X_restored), 'k-', label=r"$|P|$")
        self.ax[1,3].set_xlim([-1000,1000])
        self.ax[1,3].set_xlabel(r'$\phi$[rad m$^{-2}$]')
        self.ax[1,3].set_ylabel(r'Jy/beam m$^2$ rad$^{-1}$ rmtf$^{-1}$')
        self.ax[1,3].title.set_text("Restored")
        
        self.fig.tight_layout()
        """

In [8]:
def run_test(
    source_1,
    source_2,
    nsigma,
    remove_frac,
    nu_min=1.008e9,
    nu_max=2.031e9,
    nchannels=1000,
    scenario=1,
    use_wavelet=None,
    append_signal=False
):
    """Run one reconstruction per (noise level, removal fraction) pair.

    A noise-free ``Test`` is built for every entry of ``remove_frac``; each is
    then deep-copied once per entry of ``nsigma``, given that noise level and
    passed through ``apply_noise``.  All copies are finally reconstructed with
    ``run``.

    Returns the list of finished ``Test`` objects, ordered with ``nsigma`` as
    the outer index and ``remove_frac`` as the inner one.
    """
    common = dict(
        nu_min=nu_min,
        nu_max=nu_max,
        nchannels=nchannels,
        noise_frac=0.0,
        scenario=scenario,
        source_1=source_1,
        source_2=source_2,
        use_wavelet=use_wavelet,
        append_signal=append_signal,
    )
    # Noise-free templates, one per removal fraction.
    templates = [Test(remove_frac=fraction, **common) for fraction in remove_frac]

    test_objs = []
    for noise_level in nsigma:
        for template in templates:
            noisy = copy.deepcopy(template)
            noisy.noise_frac = noise_level
            noisy.apply_noise()
            test_objs.append(noisy)

    # The templates are no longer needed once every noisy copy exists.
    del templates

    for experiment in test_objs:
        experiment.run()

    return test_objs

In [9]:
# JVLA 1.008 - 2.031 GHz 546 channels
# MeerKAT 0.9 GHz-1.420 GHz 546
# eMERLIN 1.230 - 1.740 GHz 4096
def run_tests(
    source_1,
    source_2,
    nsigma,
    remove_frac,
    nsamples,
    nu_min=1.008e9,
    nu_max=2.031e9,
    nchannels=1000,
    scenario=1,
    use_wavelet=None,
    append_signal=True
):
    """Repeat ``run_test`` and aggregate the metrics over samples and wavelets.

    Parameters
    ----------
    source_1, source_2, nsigma, remove_frac, nu_min, nu_max, nchannels,
    scenario, append_signal
        Forwarded unchanged to ``run_test``.
    nsamples : int
        Number of repetitions per wavelet.
    use_wavelet : None, str or sequence of str
        Wavelet names to average over.  ``None`` (no wavelet) and a single
        name are treated as a one-element list.

    Returns
    -------
    tuple
        ``(psnr_mean, psnr_std, rmse_mean, rmse_std, aic_mean, aic_std,
        bic_mean, bic_std, sparsity_mean, sparsity_std)``; every entry is an
        array of shape ``(len(nsigma), len(remove_frac))``.  Sparsity is
        expressed as a percentage.
    """
    # The default ``None`` used to crash on ``len(None)`` and a bare string was
    # iterated character by character; normalise both to a list.
    if use_wavelet is None or isinstance(use_wavelet, str):
        wavelets = [use_wavelet]
    else:
        wavelets = list(use_wavelet)

    m = len(nsigma)
    n = len(remove_frac)
    total_runs = nsamples * len(wavelets)

    # Attribute read from each finished test -> factor applied before averaging.
    metric_scale = {"psnr": 1.0, "rmse": 1.0, "aic": 1.0, "bic": 1.0, "sparsity": 100.0}
    accumulators = {name: statistics(m, n, total_runs) for name in metric_scale}

    for wavelet_name in wavelets:
        for _ in range(nsamples):
            tests = run_test(
                source_1,
                source_2,
                nsigma,
                remove_frac,
                nu_min=nu_min,
                nu_max=nu_max,
                nchannels=nchannels,
                scenario=scenario,
                use_wavelet=wavelet_name,
                append_signal=append_signal
            )
            for name, scale in metric_scale.items():
                values = [getattr(t, name) * scale for t in tests]
                accumulators[name].cumul(list_to2darray(values, n, dtype=np.float32))
            # Release the finished experiments before the next repetition.
            del tests

    print("Shape: ", accumulators["psnr"].sum)

    psnr_mean, psnr_std = accumulators["psnr"].mean(), accumulators["psnr"].std()
    rmse_mean, rmse_std = accumulators["rmse"].mean(), accumulators["rmse"].std()
    sparsity_mean, sparsity_std = accumulators["sparsity"].mean(), accumulators["sparsity"].std()
    aic_mean, aic_std = accumulators["aic"].mean(), accumulators["aic"].std()
    bic_mean, bic_std = accumulators["bic"].mean(), accumulators["bic"].std()
    return (
        psnr_mean, psnr_std, rmse_mean, rmse_std, aic_mean, aic_std, bic_mean, bic_std,
        sparsity_mean, sparsity_std
    )

In [10]:
# Faraday-thin component (selected by scenario 1, part of scenario 3).
source_1 = FaradayThinSource(
    s_nu=0.0555036,
    phi_gal=-200,
    spectral_idx=0.7,
)
# Faraday-thick component (selected by scenario 2, part of scenario 3).
source_2 = FaradayThickSource(
    s_nu=0.06669743,
    phi_fg=40.0,
    phi_center=200,
    spectral_idx=0.7,
)

In [11]:
# Experiment grid: noise levels and channel-removal fractions handed to
# run_tests, the source scenarios to simulate, and the wavelet families whose
# discrete members are averaged together.
nsigma = [0.2]
remove_frac = [0.3]
scenarios = [
    1,
    2,
    3,
]
families = [
    "haar",
    "coif",
    "db",
    "dmey",
    "sym",
]

In [12]:
# Metrics stored along axis 1 of the result arrays, in this order.
names = ["PSNR", "RMSE", "AIC", "BIC"]
samples = 50
append_signal = False

# Result arrays are indexed as [scenario, metric, wavelet family].
result_shape = (len(scenarios), len(names), len(families))
scenario_means = np.empty(result_shape, dtype=np.float32)
scenario_stds = np.empty(result_shape, dtype=np.float32)

for z, fam in enumerate(families):
    use_wavelet = pywt.wavelist(fam, kind="discrete")
    print("Wavelets in this family: ", use_wavelet)
    for i in range(len(scenarios)):
        (
            psnr_mean, psnr_std, rmse_mean, rmse_std, aic_mean, aic_std, bic_mean, bic_std,
            sparsity_mean, sparsity_std
        ) = run_tests(
            source_1,
            source_2,
            nsigma,
            remove_frac,
            samples,
            scenario=scenarios[i],
            use_wavelet=use_wavelet,
            append_signal=append_signal,
            nu_min=0.9e9,
            nu_max=1.62e9
        )
        metric_means = (psnr_mean, rmse_mean, aic_mean, bic_mean)
        metric_stds = (psnr_std, rmse_std, aic_std, bic_std)
        for k, (metric_mean, metric_std) in enumerate(zip(metric_means, metric_stds)):
            # run_tests returns (len(nsigma), len(remove_frac)) grids, here
            # 1 x 1.  Extract the scalar explicitly: assigning a size-1 array
            # with ndim > 0 to a single element is deprecated in NumPy >= 1.25.
            # ``.item()`` raises ValueError if the grid has more than one cell.
            scenario_means[i, k, z] = np.asarray(metric_mean).item()
            scenario_stds[i, k, z] = np.asarray(metric_std).item()

Wavelets in this family:  ['haar']
FWHM of the main peak of the RMTF: 49.639 rad/m^2
Maximum recovered width structure: 90.921 rad/m^2
Maximum Faraday Depth to which one has more than 50% sensitivity: 17348.916
Signal-to-noise ratio: 1.2983753344084936
Peak Signal-to-noise ratio: 28.15838406150629
Standard deviation: (763.1830871105194)*10**-5
Root Mean Squared Error: (20608.613373325006)*10**-5


FWHM of the main peak of the RMTF: 52.001 rad/m^2
Maximum recovered width structure: 91.736 rad/m^2
Maximum Faraday Depth to which one has more than 50% sensitivity: 18174.208
Signal-to-noise ratio: 1.3125424948132831
Peak Signal-to-noise ratio: 38.860599767311506
Standard deviation: (824.4350552558899)*10**-5
Root Mean Squared Error: (21629.465490583443)*10**-5


FWHM of the main peak of the RMTF: 49.933 rad/m^2
Maximum recovered width structure: 90.272 rad/m^2
Maximum Faraday Depth to which one has more than 50% sensitivity: 17451.507
Signal-to-noise ratio: 1.2984163060796916
Peak Signal-to

Signal-to-noise ratio: 1.3064529566273342
Peak Signal-to-noise ratio: 31.84065495168319
Standard deviation: (797.7122440934181)*10**-5
Root Mean Squared Error: (20872.631949742998)*10**-5


FWHM of the main peak of the RMTF: 52.876 rad/m^2
Maximum recovered width structure: 90.840 rad/m^2
Maximum Faraday Depth to which one has more than 50% sensitivity: 18480.296
Signal-to-noise ratio: 1.329006605656294
Peak Signal-to-noise ratio: 38.992213704965394
Standard deviation: (757.8910794109106)*10**-5
Root Mean Squared Error: (21175.079665049794)*10**-5


FWHM of the main peak of the RMTF: 47.529 rad/m^2
Maximum recovered width structure: 88.179 rad/m^2
Maximum Faraday Depth to which one has more than 50% sensitivity: 16611.505
Signal-to-noise ratio: 1.2989804220633874
Peak Signal-to-noise ratio: 39.02705684971512
Standard deviation: (839.660968631506)*10**-5
Root Mean Squared Error: (21596.316593095336)*10**-5


FWHM of the main peak of the RMTF: 45.262 rad/m^2
Maximum recovered width struc

In [13]:
# Persist the aggregated means and standard deviations next to the notebook.
output_arrays = {
    "uwavelet_meerkat_means_nolambda0.npy": scenario_means,
    "uwavelet_meerkat_stds_nolambda0.npy": scenario_stds,
}
for output_name, output_array in output_arrays.items():
    np.save(output_name, output_array)