In [6]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image, ImageOps, ImageMath
from astropy.utils.data import download_file
from mpol.__init__ import zenodo_record
from mpol import coordinates, gridding, fourier, losses, precomposed, utils
from mpol.images import ImageCube

In [7]:
###########################################################################################
def RML_imager(visibility_file, cell_size, npix, learning_rate, n_iter, hyperparams_config, start_from_dirty_image=False, learning_rate_dim=None, n_iter_dim=None,):
    '''---------------------------------------------------------------------------------------
    Performs the RML imaging process using the SimpleNet model on the input visibilities
    to produce surface brightness maps of the source in the image plane. Currently uses
    only the NLL loss function, without other regularizers.

    # TODO:
    1. Start the training loop from the dirty image instead of a flat BaseCube. #*DONE*#
    2. Implement regularizers in the loss function. #*DONE*#
    3. Optimise regularisers to improve the quality of the final image.
    4. Use cross-validation to find the best hyperparameters for the RML imaging process.
    5. Compare the final image with that from CASA tclean.
    6. Add functionality to save the final image cube as a FITS file.
    7. Look into how to image multi-channel continuum visibilities into a single channel image.
    ----------------------------------------------------------------------------------------'''

    # create a directory to store the outputs of the RML imaging process
    if not os.path.exists('RML_loop_outputs/'):
        os.makedirs('RML_loop_outputs/')

    # load the mock visibilities from the .npz file
    d = np.load(visibility_file)
    uu = d["uu"]
    vv = d["vv"]
    weight = d["weight"]
    data = d["data"]
    data_re = np.real(data)
    data_im = np.imag(data)
    nvis = len(uu)
    print(f'Loaded visibilities from {visibility_file}.')
    print(f'The dataset has {nvis} visibilities.\n')

    # check if the input cell size Nyquist samples the spatial frequency represented by the maximum u,v value
    max_uv = np.max(np.array([uu,vv]))
    max_cell_size = utils.get_maximum_cell_size(max_uv)
    if cell_size > max_cell_size:
        raise ValueError(f'The input cell size ({cell_size} arcseconds) does not Nyquist sample the spatial frequency represented by the maximum u,v value.\nThe maximum cell_size that will still Nyquist sample the spatial frequency represented by the maximum u,v value is {max_cell_size:.2f} arcseconds).\nPlease change the cell size to be less than {max_cell_size:.2f} arcseconds.\n')
    else:
        print(f'The input cell size ({cell_size} arcseconds) Nyquist samples the spatial frequency represented by the maximum u,v value.\n(The maximum cell_size that will still Nyquist sample the spatial frequency represented by the maximum u,v value is {max_cell_size:.2f} arcseconds).\n')

    # plot and save the downloaded (u,v) distribution
    fig, ax = plt.subplots(nrows=1)
    ax.scatter(uu, vv, s=1, rasterized=True, linewidths=0.0, c="k")
    ax.set_xlabel(r"$u$ [k$\lambda$]")
    ax.set_ylabel(r"$v$ [k$\lambda$]")
    ax.set_title("uv distribution")
    plt.savefig('RML_loop_outputs/uv_distribution.pdf', format='pdf', bbox_inches='tight')
    print(f'(u,v) distribution plot saved to: RML_loop_outputs/uv_distribution.pdf\n')
    plt.close()

    # instantiate the gridcoords object
    coords = coordinates.GridCoords(cell_size=cell_size, npix=npix)

    # instantiate the dirty imager object
    imager = gridding.DirtyImager(
        coords=coords,
        uu=uu,
        vv=vv,
        weight=weight,
        data_re=data_re,
        data_im=data_im,
    )

    # calculate the dirty image and the beam using Briggs weighting with robust=0.0
    img, beam = imager.get_dirty_image(weighting="briggs", robust=0.0)
    print(f"Calculated dirty beam and dirty image using Briggs weighting with robust=0.0.")

    # plot and save the calculated dirty image and dirty beam
    chan = 0
    kw = {"origin": "lower", "interpolation": "none", "extent": imager.coords.img_ext}
    fig, ax = plt.subplots(ncols=2, figsize=(6, 3))
    bmplot = ax[0].imshow(beam[chan], **kw)
    #plt.colorbar(bmplot, ax=ax[0])
    ax[0].set_title("Dirty beam")
    imgplot = ax[1].imshow(img[chan], **kw)
    #plt.colorbar(imgplot, ax=ax[1])
    ax[1].set_title("Dirty image")
    for a in ax:
        a.set_xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
        a.set_ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
    plt.tight_layout()
    plt.savefig('RML_loop_outputs/dirty_beam_and_dirty_image.pdf', format='pdf', bbox_inches='tight')
    print(f'Dirty beam and dirty image plot saved to: RML_loop_outputs/dirty_beam_and_dirty_image.pdf')
    plt.close()
    print(f'The dirty image contains {np.sum(img < 0)} negative pixels.\n')

    # instantiate the data averager object
    averager = gridding.DataAverager(
        coords=coords,
        uu=uu,
        vv=vv,
        weight=weight,
        data_re=data_re,
        data_im=data_im,
        )

    # convert the gridded visibilities to a pytorch dataset
    dset = averager.to_pytorch_dataset()
    print('Gridded and converted visibilities to a pytorch dataset.')
    print(f"The dataset has {dset.nchan} channel(s).\n") # TODO: Look into how to image multi-channel continuum visibilities into a single channel image.

    if start_from_dirty_image:
        # use the dirty image as the initial model image in BaseCube
        # create a loss function corresponding to the mean squared error (MSE) between the RML model image pixel fluxes and the dirty image pixel fluxes and then optimize this RML model
        # It calculates the loss based off of the image-plane distance between the dirty image and the state of the ImageCube in order to make the state of the ImageCube closer to the dirty image
        ###-----------------------------------------------------------------------------------------------------------------###
        print(f"Starting the optimisation loop with {n_iter_dim} iterations to optimise the initial model image (BaseCube) based on the dirty image...")
        dirty_image = torch.tensor(img.copy())  # converts the dirty image into a pytorch tensor
        rml_dim = precomposed.SimpleNet(coords=coords, nchan=dset.nchan) # initialise SimpleNet
        optimizer_dim = torch.optim.SGD(rml_dim.parameters(), lr=learning_rate_dim) # instantiate the SGD optimizer

        loss_tracker_dim = []
        for i_dim in range(n_iter_dim):
            optimizer_dim.zero_grad() # zero out any gradients attached to the tensor components so that they aren’t counted twice
            rml_dim() # calculate the model visibilities from the current model image
            sky_cube_dim = rml_dim.icube.sky_cube # get the model image from the BaseCube object
            lossfunc_dim = torch.nn.MSELoss(reduction="sum")  # the MSELoss calculates mean squared error (squared L2 norm)
            loss_dim = (lossfunc_dim(sky_cube_dim, dirty_image)) ** 0.5 # square root of the MSE is our loss value
            loss_tracker_dim.append(loss_dim.item()) # append the loss value to the loss tracker list
            loss_dim.backward() # calculate the gradients of the loss with respect to the model parameters
            optimizer_dim.step() # subtract the gradient image to the base image in order to advance base parameters in the direction of the minimum loss value

        # plot the loss per iteration
        fig, ax = plt.subplots(nrows=1)
        ax.plot(loss_tracker_dim)
        ax.set_xlabel("iteration")
        ax.set_ylabel("loss")
        ax.set_title("loss per iteration - L2 norm (MSE) between dirty image and BaseCube model image")
        plt.savefig('RML_loop_outputs/loss_per_iteration_dim.pdf', format='pdf', bbox_inches='tight')
        print(f'Loss per iteration (for optimising the BaseCube image to the dirty image) plot saved to: RML_loop_outputs/loss_per_iteration_dim.pdf')
        plt.close()

        # plot the final model image (BaseCube) after the last iteration
        img_dim = np.squeeze(rml_dim.icube.sky_cube.detach().numpy())
        fig, ax = plt.subplots(nrows=1)
        im_dim = ax.imshow(img_dim, origin="lower", interpolation="none", extent=rml_dim.icube.coords.img_ext)
        plt.colorbar(im_dim)
        plt.savefig('RML_loop_outputs/optimised_input_model_image_based_on_dirty_image.pdf', format='pdf', bbox_inches='tight')
        print(f'Loss per iteration (for optimising the initial BaseCube to the dirty image) plot saved to: RML_loop_outputs/optimised_input_model_image_based_on_dirty_image.pdf')
        plt.close()
        print(f'The optimised initial image based on the dirty image contains {np.sum(img_dim < 0)} negative pixels.')

        # save the optimised initial model image (BaseCube) to a .pt file
        torch.save(rml_dim.state_dict(), "RML_loop_outputs/dirty_image_model.pt")
        print('Optimised initial model image (BaseCube) based on the dirty image saved to: RML_loop_outputs/dirty_image_model.pt\n')
        ###-----------------------------------------------------------------------------------------------------------------###

    # initialise SimpleNet
    rml = precomposed.SimpleNet(coords=coords, nchan=dset.nchan)
    print(f"SimpleNet network initialised.")

    # choose the model image to set as the initial BaseCube
    if start_from_dirty_image:
        # load the optimised initial model image (BaseCube) from the .pt file
        rml.load_state_dict(torch.load("RML_loop_outputs/dirty_image_model.pt"))
        print('Using optimised initial model image (loaded from: RML_loop_outputs/dirty_image_model.pt) based on the dirty image as the initial BaseCube image\n')
    else:
        # use the default flat initial model image in the BaseCube
        print('Starting from the default flat initial model image (BaseCube).\n')

    # Because we want to compute a clean set of gradient values in a later step, we “zero out” any gradients attached to the tensor components so that they aren’t counted twice.
    rml.zero_grad()

    # instantiate the SGD optimizer
    optimizer = torch.optim.SGD(rml.parameters(), lr=learning_rate)
    optimizer.zero_grad()

    start_time = time.time()
    ###~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~###
    # initiate a list to store the loss values at each iteration
    loss_tracker = []

    # Run the training loop
    print(f"Starting the optimisation loop with {n_iter} iterations...")
    for i in range(n_iter):
        rml.zero_grad()

        # STEP 1: calculate the model visibilities from the current model image
        vis = rml() # calculate model visibilities
        sky_cube = rml.icube.sky_cube # get the current model 'sky' image

        # STEP 2: calculate the loss between the model visibilities and the data visibilities
        # loss = losses.nll_gridded(vis, dset) # loss function without regularizers, using only the NLL
        loss = (
            losses.nll_gridded(vis, dset)
            + hyperparams_config["lambda_sparsity"] * losses.sparsity(sky_cube)
            + hyperparams_config["lambda_TV"] * losses.TV_image(sky_cube)
            + hyperparams_config["entropy"] * losses.entropy(sky_cube, hyperparams_config["prior_intensity"])
            + hyperparams_config["TSV"] * losses.TSV(sky_cube)
        ) # loss function with regularizers #TODO: Implement regularizers properly and check how to get best hyperparameter values
        loss_tracker.append(loss.item()) # append the loss value to the loss tracker list

        # STEP 3: calculate the gradients of the loss with respect to the model parameters
        loss.backward()

        # STEP 4: subtract the gradient image to the base image in order to advance base parameters in the direction of the minimum loss value
        optimizer.step()
    ###~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~###
    end_time = time.time()
    print('Done.')

    # calculate the time taken for the optimisation loop to finish
    elapsed_time = end_time - start_time
    print(f"The optimisation loop finished in {elapsed_time:.2f} seconds.\n")

    # print the initial and final loss values
    print(f"Initial loss: {loss_tracker[0]:.2f}")
    print(f"Final loss: {loss_tracker[-1]:.2f}\n")

    # plot the loss per iteration
    fig, ax = plt.subplots(nrows=1)
    ax.plot(np.arange(n_iter), loss_tracker, marker=".", color="k", linewidth=0.5)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.set_title("Loss per iteration")
    plt.savefig('RML_loop_outputs/loss_per_iteration.pdf', format='pdf', bbox_inches='tight')
    print(f'Loss per iteration plot saved to: RML_loop_outputs/loss_per_iteration.pdf\n')
    plt.close()

    # detach the model 'sky' image from the computational graph and convert it to a numpy array
    img_cube = rml.icube.sky_cube.detach().numpy() # TODO: Add functionality to save this image cube as a FITS file

    # plot the final model image after the last iteration
    fig, ax = plt.subplots(nrows=1)
    im = ax.imshow(
        np.squeeze(img_cube),
        origin="lower",
        interpolation="none",
        extent=rml.icube.coords.img_ext,
    )
    ax.set_xlabel(r"$\Delta \alpha \cos \delta$ [${}^{\prime\prime}$]")
    ax.set_ylabel(r"$\Delta \delta$ [${}^{\prime\prime}$]")
    ax.set_title("Maximum likelihood image")
    plt.colorbar(im, label=r"Jy/$\mathrm{arcsec}^2$")
    plt.savefig('RML_loop_outputs/maximum_likelihood_image.pdf', format='pdf', bbox_inches='tight')
    print(f'Maximum likelihood image plot saved to: RML_loop_outputs/maximum_likelihood_image.pdf\n')
    plt.close()

    return img_cube
###########################################################################################

In [8]:
# input parameters
###########################################################################################################################################
visibility_file = '../data/visibilities/mock_visibilities_model_star_new.npz' # path to the .npz file containing the observed visibilities
cell_size = 0.03 # arcseconds
npix = 128 # number of pixels per image axis
learning_rate = 3 # learning rate for the optimizer
n_iter = 25 # number of iterations for the optimizer

# hyperparameters used in the function and the optimizer (set those not being used to 0)
hyperparams_config = (
    {"lambda_sparsity": 7.0e-05,
    "lambda_TV": 0.00,
    "entropy": 1e-03,
    "prior_intensity": 1.5e-07,
    "TSV": 0.00,
    "epochs": 1000,
    }
)

start_from_dirty_image = False # If True, the initial BaseCube image is set to the dirty image, else to the default flat image.
###########################################################################################################################################

In [9]:
# function call
###########################################################################################################################################
img_cube = RML_imager(visibility_file=visibility_file, cell_size=cell_size, npix=npix, learning_rate=learning_rate, n_iter=n_iter, hyperparams_config=hyperparams_config, start_from_dirty_image = True, learning_rate_dim=5, n_iter_dim=1000)
###########################################################################################################################################

Loaded visibilities from ../data/visibilities/mock_visibilities_model_star_new.npz.
The dataset has 325080 visibilities.

The input cell size (0.03 arcseconds) Nyquist samples the spatial frequency represented by the maximum u,v value.
(The maximum cell_size that will still Nyquist sample the spatial frequency represented by the maximum u,v value is 0.09 arcseconds).



(u,v) distribution plot saved to: RML_loop_outputs/uv_distribution.pdf

Calculated dirty beam and dirty image using Briggs weighting with robust=0.0.
Dirty beam and dirty image plot saved to: RML_loop_outputs/dirty_beam_and_dirty_image.pdf
The dirty image contains 11266 negative pixels.

Gridded and converted visibilities to a pytorch dataset.
The dataset has 1 channel(s).

Starting the optimisation loop with 1000 iterations to optimise the initial model image (BaseCube) based on the dirty image...
Loss per iteration (for optimising the BaseCube image to the dirty image) plot saved to: RML_loop_outputs/loss_per_iteration_dim.pdf
Loss per iteration (for optimising the initial BaseCube to the dirty image) plot saved to: RML_loop_outputs/optimised_input_model_image_based_on_dirty_image.pdf
The optimised initial image based on the dirty image contains 0 negative pixels.
Optimised initial model image (BaseCube) based on the dirty image saved to: RML_loop_outputs/dirty_image_model.pt

Simple