In [177]:
import numpy as np

# re-adding np.asscalar() and np.alen(), which were deprecated and later removed from newer NumPy releases, so that code still calling them keeps working
###~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~###
def patch_asscalar(a):
    """Stand-in for ``np.asscalar``: hand back whatever ``a.item()`` yields."""
    scalar_value = a.item()
    return scalar_value


# expose the stand-in on the numpy module under the old name
np.asscalar = patch_asscalar

def patch_alen(a):
    """Stand-in for the removed ``np.alen``: length of the first dimension of ``a``.

    Parameters
    ----------
    a : array_like
        Any object supporting ``len()``, or a scalar.

    Returns
    -------
    int
        ``len(a)``; objects without a length (e.g. plain scalars) are treated
        as 1-element arrays, which is how ``np.alen`` used to behave.
    """
    try:
        # ndarrays, lists and tuples have no ``.len()`` method; the built-in
        # ``len`` is the correct way to get the size of the first axis.
        return len(a)
    except TypeError:
        # scalars and 0-d arrays have no len(); promote them to 1-d first
        return len(np.array(a, ndmin=1))
setattr(np, "alen", patch_alen)
###~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~###

import os
import time
import torch
import matplotlib.pyplot as plt
from astropy.utils.data import download_file

from PIL import Image, ImageOps, ImageMath

# Mpol utilities
from mpol.__init__ import zenodo_record
from mpol import coordinates, gridding, fourier, losses, precomposed, utils
from mpol.images import ImageCube

In [178]:
###########################################################################################
def RML_imager(visibility_file, cell_size, npix, learning_rate, n_iter, output_dir='RML_loop_outputs/'):
    '''---------------------------------------------------------------------------------------
    Performs the RML imaging process using the SimpleNet model on the input visibilities
    to produce surface brightness maps of the source in the image plane. Currently uses
    only the NLL loss function, without other regularizers.

    Parameters
    ----------
    visibility_file : str
        Path to a .npz archive containing the arrays "uu", "vv", "weight" and "data"
        ("data" holds the complex visibilities).
    cell_size : float
        Image cell size in arcseconds.
    npix : int
        Number of pixels per image axis.
    learning_rate : float
        Learning rate handed to the SGD optimizer.
    n_iter : int
        Number of optimisation iterations.
    output_dir : str, optional
        Directory that receives the diagnostic PDF plots; created if missing.
        Defaults to 'RML_loop_outputs/'.

    Returns
    -------
    numpy.ndarray
        The model sky cube (rml.icube.sky_cube) after the last iteration, detached
        from the autograd graph.

    Raises
    ------
    ValueError
        If cell_size is larger than the maximum cell size reported by
        utils.get_maximum_cell_size for the largest |u|,|v| value in the data.

    Side effects
    ------------
    Writes uv_distribution.pdf, dirty_beam_and_dirty_image.pdf, loss_per_iteration.pdf
    and maximum_likelihood_image.pdf into output_dir and prints progress messages.

    # TODO:
    1. Implement regularizers in the loss function.
    2. Add functionality to save the final image cube as a FITS file.
    3. Look into how to image multi-channel continuum visibilities into a single channel image.
    ----------------------------------------------------------------------------------------'''

    # create a directory to store the outputs of the RML imaging process
    # (exist_ok avoids the check-then-create race of os.path.exists + os.makedirs)
    os.makedirs(output_dir, exist_ok=True)

    # load the mock visibilities from the .npz file; the context manager closes the
    # underlying archive once the arrays have been read into memory
    with np.load(visibility_file) as d:
        uu = d["uu"]
        vv = d["vv"]
        weight = d["weight"]
        data = d["data"]
    data_re = np.real(data)
    data_im = np.imag(data)
    nvis = len(uu)
    print(f'Loaded visibilities from {visibility_file}.')
    print(f'The dataset has {nvis} visibilities.\n')

    # check if the input cell size Nyquist samples the spatial frequency represented by the maximum u,v value
    # NOTE: the largest spatial frequency is the largest *absolute* u or v value; baselines may
    # be negative, so taking the signed maximum could underestimate it.
    max_uv = max(np.max(np.abs(uu)), np.max(np.abs(vv)))
    max_cell_size = utils.get_maximum_cell_size(max_uv)
    if cell_size > max_cell_size:
        raise ValueError(f'The input cell size ({cell_size} arcseconds) does not Nyquist sample the spatial frequency represented by the maximum u,v value.\nThe maximum cell_size that will still Nyquist sample the spatial frequency represented by the maximum u,v value is {max_cell_size:.2f} arcseconds.\nPlease change the cell size to be less than {max_cell_size:.2f} arcseconds.\n')
    else:
        print(f'The input cell size ({cell_size} arcseconds) Nyquist samples the spatial frequency represented by the maximum u,v value.\n(The maximum cell_size that will still Nyquist sample the spatial frequency represented by the maximum u,v value is {max_cell_size:.2f} arcseconds).\n')

    # plot and save the downloaded (u,v) distribution
    fig, ax = plt.subplots(nrows=1)
    ax.scatter(uu, vv, s=1, rasterized=True, linewidths=0.0, c="k")
    ax.set_xlabel(r"$u$ [k$\lambda$]")
    ax.set_ylabel(r"$v$ [k$\lambda$]")
    ax.set_title("uv distribution")
    _save_and_close(fig, os.path.join(output_dir, 'uv_distribution.pdf'), '(u,v) distribution plot')

    # instantiate the gridcoords object
    coords = coordinates.GridCoords(cell_size=cell_size, npix=npix)

    # instantiate the dirty imager object
    imager = gridding.DirtyImager(
        coords=coords,
        uu=uu,
        vv=vv,
        weight=weight,
        data_re=data_re,
        data_im=data_im,
    )

    # calculate the dirty image and the beam using Briggs weighting with robust=0.0
    img, beam = imager.get_dirty_image(weighting="briggs", robust=0.0)
    print("Calculated dirty beam and dirty image using Briggs weighting with robust=0.0.")

    # plot and save the calculated dirty image and dirty beam (first channel only)
    chan = 0
    kw = {"origin": "lower", "interpolation": "none", "extent": imager.coords.img_ext}
    fig, ax = plt.subplots(ncols=2, figsize=(6, 3))
    ax[0].imshow(beam[chan], **kw)
    ax[0].set_title("Dirty beam")
    ax[1].imshow(img[chan], **kw)
    ax[1].set_title("Dirty image")
    for a in ax:
        a.set_xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
        a.set_ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
    fig.tight_layout()
    _save_and_close(fig, os.path.join(output_dir, 'dirty_beam_and_dirty_image.pdf'), 'Dirty beam and dirty image plot')

    # instantiate the data averager object
    averager = gridding.DataAverager(
        coords=coords,
        uu=uu,
        vv=vv,
        weight=weight,
        data_re=data_re,
        data_im=data_im,
    )

    # convert the gridded visibilities to a pytorch dataset
    dset = averager.to_pytorch_dataset()
    print('Gridded and converted visibilities to a pytorch dataset.')
    print(f"The dataset has {dset.nchan} channel(s).\n") # TODO: Look into how to image multi-channel continuum visibilities into a single channel image.

    # initialise SimpleNet
    rml = precomposed.SimpleNet(coords=coords, nchan=dset.nchan)
    print("SimpleNet network initialised.")

    # instantiate the SGD optimizer
    optimizer = torch.optim.SGD(rml.parameters(), lr=learning_rate)

    # perf_counter is a monotonic clock, so the elapsed time cannot be skewed by
    # system clock adjustments
    start_time = time.perf_counter()
    ###~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~###
    # initiate a list to store the loss values at each iteration
    loss_tracker = []

    # loop over the number of iterations
    print(f"Starting the optimisation loop with {n_iter} iterations...")
    for i in range(n_iter):
        # gradients accumulate across backward() calls, so clear them at the start of
        # every iteration to make sure they are not counted twice
        rml.zero_grad()

        # STEP 1: calculate the model visibilities from the current model image
        vis = rml()

        # STEP 2: calculate the loss between the model visibilities and the data visibilities
        #NOTE: for now, the loss function is just the negative log likelihood (nll) only. Regularizers will be added later
        loss = losses.nll_gridded(vis, dset) #TODO: Implement regularizers in the loss function
        loss_tracker.append(loss.item()) # append the loss value to the loss tracker list

        # STEP 3: calculate the gradients of the loss with respect to the model parameters
        loss.backward()

        # STEP 4: let the optimizer update the model parameters by stepping against the
        # gradient (scaled by the learning rate), i.e. towards lower loss
        optimizer.step()
    ###~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~###
    end_time = time.perf_counter()
    print('Done.')

    # calculate the time taken for the optimisation loop to finish
    elapsed_time = end_time - start_time
    print(f"The optimisation loop finished in {elapsed_time:.2f} seconds.\n")

    # plot the loss per iteration (x axis sized from the recorded losses so the two
    # sequences always have matching lengths)
    fig, ax = plt.subplots(nrows=1)
    ax.plot(np.arange(len(loss_tracker)), loss_tracker, marker=".", color="k", linewidth=0.5)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Loss per iteration")
    _save_and_close(fig, os.path.join(output_dir, 'loss_per_iteration.pdf'), 'Loss per iteration plot')

    # detach the model 'sky' image from the computational graph and convert it to a numpy array
    # (.cpu() is a no-op for CPU tensors and keeps .numpy() valid if the model lives on another device)
    img_cube = rml.icube.sky_cube.detach().cpu().numpy() # TODO: Add functionality to save this image cube as a FITS file

    # plot the final model image after the last iteration
    fig, ax = plt.subplots(nrows=1)
    im = ax.imshow(
        np.squeeze(img_cube),
        origin="lower",
        interpolation="none",
        extent=rml.icube.coords.img_ext,
    )
    ax.set_xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
    ax.set_ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
    ax.set_title("Maximum likelihood image")
    fig.colorbar(im, ax=ax, label=r"Jy/$\mathrm{arcsec}^2$")
    _save_and_close(fig, os.path.join(output_dir, 'maximum_likelihood_image.pdf'), 'Maximum likelihood image plot')

    return img_cube


def _save_and_close(fig, path, description):
    # Write `fig` to `path` as a PDF, report where it went, then release the figure.
    fig.savefig(path, format='pdf', bbox_inches='tight')
    print(f'{description} saved to: {path}\n')
    plt.close(fig)
###########################################################################################

In [179]:
# user-configurable inputs for the RML imaging run
###########################################################################################################################################
# .npz archive holding the observed visibilities
visibility_file = '../data/visibilities/mock_visibilities_model_star_new.npz'
# image cell size [arcseconds]
cell_size = 0.03
# pixels along each image axis
npix = 128
# step size handed to the optimizer
learning_rate = 300.0
# how many optimizer iterations to run
n_iter = 250
###########################################################################################################################################

In [180]:
# run the imager with the configured input parameters and keep the resulting image cube
###########################################################################################################################################
img_cube = RML_imager(
    visibility_file=visibility_file,
    cell_size=cell_size,
    npix=npix,
    learning_rate=learning_rate,
    n_iter=n_iter,
)
###########################################################################################################################################

Loaded visibilities from ../data/visibilities/mock_visibilities_model_star_new.npz.
The dataset has 325080 visibilities.

The input cell size (0.03 arcseconds) Nyquist samples the spatial frequency represented by the maximum u,v value.
(The maximum cell_size that will still Nyquist sample the spatial frequency represented by the maximum u,v value is 0.09 arcseconds).

(u,v) distribution plot saved to: RML_loop_outputs/uv_distribution.pdf

Calculated dirty beam and dirty image using Briggs weighting with robust=0.0.
Dirty beam and dirty image plot saved to: RML_loop_outputs/dirty_beam_and_dirty_image.pdf

Gridded and converted visibilities to a pytorch dataset.
The dataset has 1 channel(s).

SimpleNet newtwork initialised.
Starting the optimisation loop with 250 iterations...
Done.
The optimisation loop finished in 0.89 seconds.

Loss per iteration plot saved to: RML_loop_outputs/loss_per_iteration.pdf

Maximum likelihood image plot saved to: RML_loop_outputs/maximum_likelihood_image.pd