# Tutorial 1: Data preprocessing
https://astronomers.skatelescope.org/ska-science-data-challenge-1/

---

### Introduction

This notebook shows how to preprocess astronomy images: correcting for the primary beam (PB) and cropping out the training area for the machine learning (ML) model. It also shows how to find sources in the training images (@@@ why only the training images? Presumably because a ground truth has already been prepared for them).

This notebook will cover the following:
  1) Preprocess the images (correct for the PB) and crop out the training area for building the ML model
  2) Find sources in the PB-corrected training images
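
As background for step 1: PB correction is commonly described as dividing each pixel by the beam response at that position, so that sources away from the pointing centre regain their true brightness. Below is a minimal sketch on synthetic data. The Gaussian beam, the array sizes, and the function names `gaussian_beam` and `apply_pb_correction` are illustrative assumptions, not part of the SDC1 code.

```python
import numpy as np

def gaussian_beam(shape, fwhm_pix):
    """Return a synthetic Gaussian beam response that peaks at 1.0 in the image centre."""
    ny, nx = shape
    y, x = np.mgrid[0:ny, 0:nx]
    cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0
    sigma = fwhm_pix / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    r2 = (y - cy) ** 2 + (x - cx) ** 2
    return np.exp(-r2 / (2.0 * sigma ** 2))

def apply_pb_correction(image, beam, min_response=1e-3):
    """Divide the image by the beam, blanking pixels where the beam is too weak."""
    corrected = np.full(image.shape, np.nan, dtype=np.float64)
    valid = beam > min_response
    corrected[valid] = image[valid] / beam[valid]
    return corrected

# A flat "true sky" is attenuated by the beam, then recovered by the correction.
true_sky = np.ones((65, 65))
beam = gaussian_beam(true_sky.shape, fwhm_pix=40.0)
observed = true_sky * beam
recovered = apply_pb_correction(observed, beam)
```

The `min_response` cut-off is an assumption added here to avoid amplifying noise where the beam is close to zero.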

---

In [None]:
import os
import numpy as np
# from ska.sdc1.models.sdc1_image import Sdc1Image
import bdsf

import matplotlib.pyplot as plt
from astropy.utils.data import get_pkg_data_filename
from astropy.io import fits # define
from astropy.wcs import WCS # define
from astropy import units as u #define
from astropy.coordinates import SkyCoord #define
from astropy.nddata.utils import Cutout2D # define
# from MontagePy.main import * # http://montage.ipac.caltech.edu/MontageNotebooks/

---

<b><i> Get the paths </i></b>


In [None]:
fits1400_1000h = get_pkg_data_filename("data/sample_images/1400mhz_1000h.fits")
fits1400_pb = get_pkg_data_filename("data/sample_images/1400mhz_pb.fits")


---
**Exercise 1:** Get the paths for the two other image frequencies and their PB FITS files
<br>


In [None]:
# -- code goes here --



---

@@ Questions:
- I read that the FITS images are already PB-corrected images, so why do we need the *fits_pb* files (e.g. 1400mhz_pb.fits)?


---

<b><i> Display file information </i></b>

For example, the shape of the data in each FITS file.

In [None]:
# fits.info prints its summary directly and returns None, so no print() wrapper is needed
fits.info(fits1400_1000h)
print()
fits.info(fits1400_pb)

---
**Exercise 2:** Display the info for the other image frequencies and their PB FITS files
<br>


In [None]:
# -- code goes here --



---

<b><i> Display the shape </i></b> 

In [None]:
img1400_1000h = fits.getdata(fits1400_1000h, ext=0)
img1400_1000h = img1400_1000h.reshape(4776, 5204)

print(img1400_1000h.shape)
# print(img_1000h[0][0])
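
The hard-coded `reshape(4776, 5204)` above only works for this particular image. Radio FITS images often carry extra axes of length 1 (for example Stokes and frequency), so a more general option is to drop those axes. The sketch below uses a small stand-in array rather than the real data, and `drop_degenerate_axes` is an illustrative name.

```python
import numpy as np

def drop_degenerate_axes(data):
    """Remove length-1 axes so a (1, 1, ny, nx) cube becomes a (ny, nx) image."""
    squeezed = np.squeeze(data)
    if squeezed.ndim != 2:
        raise ValueError(f"Expected a single 2D plane, got shape {data.shape}")
    return squeezed

# Stand-in for the array returned by fits.getdata(...)
cube = np.zeros((1, 1, 4, 5), dtype=np.float32)
image_2d = drop_degenerate_axes(cube)
print(image_2d.shape)  # (4, 5)
```

Raising an error when more than one plane is present is a design choice made here, so that a multi-channel cube is not silently flattened.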

<b><i> summary statistics </i></b> 

In [None]:
print('Min:', np.min(img1400_1000h))
print('Max:', np.max(img1400_1000h))
print('Mean:', np.mean(img1400_1000h))
print('Stdev:', np.std(img1400_1000h))

You can do the same for the other two images; no one will stop you!

<b><i> Visualise the image </i></b> 

In [None]:
from matplotlib.colors import LogNorm
# https://github.com/HorizonIITM/PythonForAstronomy/blob/master/FITS%20Handling/PythonforAstronomy3.ipynb
plt.figure(figsize=(20, 10))
plt.imshow(img1400_1000h, cmap='PuBu_r', norm=LogNorm())
plt.colorbar()

@Q: Is there no way to make the above image look like the one on the website?
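
It is not clear which colour stretch the website uses, but one thing that may help is choosing the display limits explicitly. `LogNorm` masks non-positive pixels, and without limits it spans the full range of the remaining data, so a few very bright sources can dominate the scale. A minimal sketch, assuming percentile-based limits are acceptable (the function name and the percentile values are illustrative choices):

```python
import numpy as np

def display_limits(img, lower_pct=50.0, upper_pct=99.9):
    """Return (vmin, vmax) from percentiles of the finite, positive pixels."""
    values = img[np.isfinite(img) & (img > 0)]
    if values.size == 0:
        raise ValueError("Image has no finite positive pixels to scale on")
    vmin, vmax = np.percentile(values, [lower_pct, upper_pct])
    return float(vmin), float(vmax)

# Synthetic stand-in: Gaussian noise plus one bright source
rng = np.random.default_rng(0)
noise_image = rng.normal(0.0, 1e-6, size=(200, 200))
noise_image[100, 100] = 1e-3
vmin, vmax = display_limits(noise_image)

# These limits could then be passed to the plot, e.g.:
# plt.imshow(img, cmap='PuBu_r', norm=LogNorm(vmin=vmin, vmax=vmax))
```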


---
**Exercise 3:** Display the 2 other image frequencies 
<br>


In [None]:
# -- code here --



---

### Pre-processing
Now we will do the following:
1) Preprocess the images (correct for the PB)
2) Crop out the training area for building the ML model

---

Importing some packages

In [None]:
from astropy.io import fits
from astropy import units as u
from MontagePy.main import mGetHdr, mProjectQL

from source.utils.image_utils import (
    crop_to_training_area,
    get_image_centre_coord,
    get_pixel_value_at_skycoord,
    save_subimage,
)

from source.pre.sdc1_image import Sdc1Image
from path import image_path, pb_path

---

First, let us create a new image object from the Sdc1Image class in sdc1_image.py, and set the frequency.

In [None]:
freq = 1400
new_image = Sdc1Image(freq, image_path(freq), pb_path(freq))

Next, let us define the PB-correction method.

In [None]:
def _create_pb_corr(image):
    """
    Apply PB correction to the image at image.path, using the primary beam
    file at image.pb_path.

    This uses Montage to regrid the primary beam image to the same pixel scale
    as the image to be corrected.
    """
    image._pb_corr_image = None

    # Establish input image to PB image pixel size ratios:
    with fits.open(image.pb_path) as pb_hdu:
        pb_x_pixel_deg = pb_hdu[0].header["CDELT2"]
    with fits.open(image.path) as image_hdu:
        x_size = image_hdu[0].header["NAXIS1"]
        x_pixel_deg = image_hdu[0].header["CDELT2"]

    ratio_image_pb_pix = (x_size * x_pixel_deg) / pb_x_pixel_deg
    coord_image_centre = get_image_centre_coord(image.path)

    if ratio_image_pb_pix < 2.0:
        # Image not large enough to regrid (< 2 pixels in PB image);
        # apply simple correction
        pb_value = get_pixel_value_at_skycoord(image.pb_path, coord_image_centre)
        image._apply_pb_corr(pb_value)
        return

    with fits.open(image.pb_path) as pb_hdu:
        # Create cropped PB image larger than the input image
        # TODO: May be inefficient when images get large
        size = (
            x_size * x_pixel_deg * u.degree * 2,
            x_size * x_pixel_deg * u.degree * 2,
        )

        save_subimage(
            image.pb_path,
            image._get_pb_cut_path(),
            coord_image_centre,
            size,
            overwrite=True,
        )

    # Regrid image PB cutout to same pixel scale as input image
    mGetHdr(image.path, image._get_hdr_path())

    # TODO: mProjectQL better than mProject, which outputs too-small images?
    rtn = mProjectQL(
        input_file=image._get_pb_cut_path(),
        output_file=image._get_pb_cut_rg_path(),
        template_file=image._get_hdr_path(),
    )
    if rtn["status"] == "1":
        raise ImageNotPreprocessed(
            "Unable to reproject image: {}".format(rtn["msg"])
        )

    # Correct Montage output (convert to 32-bit and fill NaNs)
    pb_array = image._postprocess_montage_out()

    # Apply PB correction and delete temporary files
    image._apply_pb_corr(pb_array)
    image._cleanup_pb()



# @@@ should we reduce this method a bit?
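
One possible way to shorten the method would be to pull the "is the image large enough to regrid?" decision into a small helper. The sketch below reproduces the ratio computed in `_create_pb_corr` (image width in degrees divided by the PB pixel size) with plain numbers. The helper names are hypothetical, and taking absolute values is an assumption added here to guard against negative `CDELT` values.

```python
def pb_pixels_across_image(x_size, x_pixel_deg, pb_pixel_deg):
    """Number of PB-image pixels spanned by the input image along one axis."""
    return abs(x_size * x_pixel_deg) / abs(pb_pixel_deg)

def needs_regrid(x_size, x_pixel_deg, pb_pixel_deg, min_pb_pixels=2.0):
    """True when the image covers enough PB pixels for regridding to be meaningful."""
    return pb_pixels_across_image(x_size, x_pixel_deg, pb_pixel_deg) >= min_pb_pixels

# Illustrative numbers only (not read from the SDC1 headers):
small = needs_regrid(x_size=100, x_pixel_deg=0.001, pb_pixel_deg=0.1)    # ~1 PB pixel
large = needs_regrid(x_size=4000, x_pixel_deg=0.0002, pb_pixel_deg=0.01)  # ~80 PB pixels
print(small, large)
```

With such a helper, the `if ratio_image_pb_pix < 2.0:` branch would read `if not needs_regrid(...)`, with the simple single-value correction inside it.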

In [None]:
# cropping out the training area for building ML model
def _create_train(image, pad_factor=1.0):
    """
    Create the training image (crop to the frequency-dependent training area)
    """
    image._train = None
    train_path = image.path[:-5] + "_train.fits" # creating a new path for the training image 
    crop_to_training_area(image._pb_corr_image, train_path, image.freq, pad_factor)
    image._train = train_path
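
The slice `image.path[:-5]` in `_create_train` assumes that the path ends in exactly `.fits`. A sketch of a more defensive way to build the training path with `pathlib` follows. `train_path_for` is a hypothetical helper, and `PurePosixPath` is used only so that the example prints the same result on every platform.

```python
from pathlib import PurePosixPath

def train_path_for(image_path, suffix="_train"):
    """Insert a suffix before the file extension, whatever its length."""
    p = PurePosixPath(image_path)
    return str(p.with_name(p.stem + suffix + p.suffix))

print(train_path_for("data/sample_images/1400mhz_1000h.fits"))
# data/sample_images/1400mhz_1000h_train.fits
print(train_path_for("data/sample_images/image.fit"))
# data/sample_images/image_train.fit
```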
        
        

In [None]:
def preprocess(image):
    """
    Perform preprocessing steps:
        1) Create PB-corrected image (image.pb_corr_image)
        2) Output separate training image (image.train)
    """
    image._prep = False
    _create_pb_corr(image)
    _create_train(image)
    image._prep = True

@@@ I wonder why we need to cut out an area for training, and what we could do instead.
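
To make the cropping step more concrete, here is a pixel-space sketch of cutting a padded box out of an image. It is not the implementation of `crop_to_training_area`, which takes the frequency and presumably works from sky coordinates. The function name, the pixel ranges, and the rounding behaviour are assumptions for illustration.

```python
import numpy as np

def crop_with_padding(img, y_range, x_range, pad_factor=1.0):
    """Crop img to a pixel box scaled by pad_factor about its centre, clipped to the image."""
    ny, nx = img.shape

    def scaled(lo, hi, limit):
        centre = (lo + hi) / 2.0
        half = (hi - lo) / 2.0 * pad_factor
        start = max(0, int(np.floor(centre - half)))
        stop = min(limit, int(np.ceil(centre + half)))
        return start, stop

    y0, y1 = scaled(y_range[0], y_range[1], ny)
    x0, x1 = scaled(x_range[0], x_range[1], nx)
    return img[y0:y1, x0:x1]

img = np.arange(100 * 100).reshape(100, 100)
train = crop_with_padding(img, y_range=(20, 40), x_range=(30, 60))
padded = crop_with_padding(img, y_range=(20, 40), x_range=(30, 60), pad_factor=1.5)
print(train.shape, padded.shape)  # (20, 30) (30, 46)
```

A `pad_factor` above 1 keeps a margin around the training area, which may be useful so that sources near the boundary are not cut in half.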

In [None]:
image_path(freq)

In [None]:
pb_path(freq)

In [None]:
new_image = Sdc1Image(freq, image_path(freq), pb_path(freq))
preprocess(new_image) 

### Now let us visualise the output images:
- PB-corrected image
- training area image

First, let us see the paths of both images.

In [None]:
print("PB-corrected image:\n   "+new_image.pb_corr_image)
print()
print("training image:\n   "+new_image.train)

In [None]:
arr_path = [new_image.pb_corr_image, new_image.train]

In [None]:

for newImg in arr_path:
    plt.figure(figsize=(20, 10))
    fits.info(newImg)  # prints the HDU summary itself and returns None
    img = fits.getdata(newImg, ext=0)
    print(len(img.shape))
    if len(img.shape) == 4:
        img = img.reshape(img.shape[2:])
        print(img.shape)
    plt.imshow(img, cmap='PuBu_r', norm=LogNorm())
    plt.colorbar()
    plt.show()



@@@ Visually, I don't think I can spot the difference between the original and the PB-corrected image. What can we do to see it? One suggestion would be to test both on an ML model and then compare the results.
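
A cheaper check than training a model might be to look at the per-pixel ratio of the corrected image to the original. If the correction is a division by the beam, the ratio should be close to 1 at the pointing centre and grow towards the edges. The sketch below uses synthetic arrays. With the real data, both images would first have to be loaded onto the same pixel grid. `correction_ratio` is an illustrative name.

```python
import numpy as np

def correction_ratio(original, corrected):
    """Per-pixel ratio corrected/original, NaN where the original is zero or non-finite."""
    ratio = np.full(original.shape, np.nan, dtype=np.float64)
    ok = np.isfinite(original) & np.isfinite(corrected) & (original != 0)
    ratio[ok] = corrected[ok] / original[ok]
    return ratio

# Synthetic stand-ins: a noise image and its beam-corrected version
y, x = np.mgrid[0:51, 0:51]
beam = np.exp(-((y - 25) ** 2 + (x - 25) ** 2) / (2 * 15.0 ** 2))
rng = np.random.default_rng(1)
original = rng.normal(0.0, 1.0, size=beam.shape)
corrected = original / beam

ratio = correction_ratio(original, corrected)
print(ratio[25, 25], ratio[0, 0])  # close to 1 at the centre, larger at the corner
```

Plotting `ratio` with `plt.imshow` should then reveal the shape of the beam directly, which is hard to see when comparing the two images side by side.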

---

### Source finding

In [None]:
from source.utils.source_finder import SourceFinder
from path import write_df_to_disk, train_source_df_path

In [None]:

def _get_beam_from_hdu(sFinder):
    """
    Look up the beam information in the header of the SourceFinder's image
    """
    try:
        with fits.open(sFinder.image_name) as hdu:
            beam_maj = hdu[0].header["BMAJ"]
            beam_min = hdu[0].header["BMIN"]
            beam_pa = 0
            return (beam_maj, beam_min, beam_pa)
    except (IndexError, KeyError):  # a missing header keyword raises KeyError
        raise SourceFinderException("Unable to automatically determine beam info")

In [None]:
def _get_rms_box_from_hdu(sFinder):
    """
    Determine an appropriate RMS box size using the header of the SourceFinder's
    image
    """
    try:
        with fits.open(sFinder.image_name) as hdu:
            beam_maj = hdu[0].header["BMAJ"]
            pix_per_beam = beam_maj / hdu[0].header["CDELT2"]
            return (30 * pix_per_beam, 8 * pix_per_beam)
    except (IndexError, KeyError):  # a missing header keyword raises KeyError
        raise SourceFinderException(
            "Unable to automatically determine RMS box size"
        )
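
As a worked example of the RMS box calculation above: the box is 30 beams wide and the step is 8 beams, both converted to pixels through the number of pixels per beam. The header values below are made-up numbers for illustration. Rounding to integers and taking the absolute value are assumptions added here, since the original returns the raw floats.

```python
def rms_box_from_beam(beam_maj_deg, pixel_deg, box_beams=30, step_beams=8):
    """Return (box size, step size) in pixels for a given beam and pixel scale."""
    pix_per_beam = abs(beam_maj_deg / pixel_deg)
    return (
        int(round(box_beams * pix_per_beam)),
        int(round(step_beams * pix_per_beam)),
    )

ARCSEC = 1.0 / 3600.0  # degrees per arcsecond

# Illustrative values: a 1.5 arcsec beam sampled with 0.6 arcsec pixels
print(rms_box_from_beam(1.5 * ARCSEC, 0.6 * ARCSEC))  # (75, 20)
```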

---

In [None]:
def run(sFinder, **kwargs):
    """
    Run the source finder algorithm.

    Args are the same as for the bdsf.process_image method, with sensible defaults
    submitted for any not given.

    The 'beam' and 'rms_box' arg defaults are determined from the image header if
    not provided.
    """
    # Local module holding the default PyBDSF arguments used below; imported
    # after the docstring so that the string above is a real docstring
    import def_run

    sFinder._run_complete = False

    # Must switch the executor's working directory to the image directory to
    # run PyBDSF, and switch back after the run is complete.
    cwd = os.getcwd()
    print()
    
    os.chdir(sFinder.image_dirname)

    # Get beam info automatically if not provided
    if not def_run.beam:
        def_run.beam = _get_beam_from_hdu(sFinder)
    if not def_run.rms_box:
        def_run.rms_box = _get_rms_box_from_hdu(sFinder)

    # Run PyBDSF
    try:
        bdsf.process_image(
            sFinder.image_name,
            adaptive_rms_box=def_run.adaptive_rms_box,
            advanced_opts=def_run.advanced_opts,
            atrous_do=def_run.atrous_do,
            psf_vary_do=def_run.psf_vary_do,
            psf_snrcut=def_run.psf_snrcut,
            psf_snrcutstack=def_run.psf_snrcutstack,
            output_opts=def_run.output_opts,
            output_all=def_run.output_all,
            opdir_overwrite=def_run.opdir_overwrite,
            beam=def_run.beam,
            blank_limit=def_run.blank_limit,
            thresh=def_run.thresh,
            thresh_isl=def_run.thresh_isl,
            thresh_pix=def_run.thresh_pix,
            psf_snrtop=def_run.psf_snrtop,
            rms_map=def_run.rms_map,
            rms_box=def_run.rms_box,
            do_cache=def_run.do_cache,
            **kwargs
        )
    finally:
        # Always revert the working directory, whether or not PyBDSF raised
        os.chdir(cwd)

    sFinder.clean_tmp()
    sFinder._run_complete = True

    return sFinder.get_source_df()
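
The save-and-restore of the working directory in `run` could also be written as a small context manager, which keeps the restore step in one place. This is a sketch using only the standard library. `working_directory` is a name chosen here. On Python 3.11 and later, `contextlib.chdir` offers similar behaviour.

```python
import os
import tempfile
from contextlib import contextmanager

@contextmanager
def working_directory(path):
    """Temporarily change the current working directory, restoring it on exit."""
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises
        os.chdir(previous)

# Usage with a throwaway directory standing in for sFinder.image_dirname
start = os.getcwd()
with tempfile.TemporaryDirectory() as tmp:
    with working_directory(tmp):
        inside = os.getcwd()
after = os.getcwd()
```

Inside `run`, the call to `bdsf.process_image` would then sit in a `with working_directory(sFinder.image_dirname):` block.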



---

Now let us run source finding on the training image.

In [None]:
sources_training = {}
source_finder = SourceFinder(new_image.train)
print(source_finder.image_dirname)
sl_df = run(source_finder)

sources_training[new_image.freq] = sl_df

# (Optional) Write source list DataFrame to disk
write_df_to_disk(sl_df, train_source_df_path(new_image.freq))

# Remove temp files:
source_finder.reset()

In [None]:
sources_full = {}
source_finder = SourceFinder(new_image.pb_corr_image)
print(source_finder.image_dirname)
sl_df = run(source_finder)

sources_full[new_image.freq] = sl_df

# (Optional) Write source list DataFrame to disk
# write_df_to_disk(sl_df, train_source_df_path(new_image.freq))

# Remove temp files:
source_finder.reset()

In [None]:
# print(sources_training)
# print(sources_full)

---

Save sources_training and sources_full for visualisation.

In [None]:
# %store  sources_training

In [None]:
# %store sources_full
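
`%store` is an IPython magic, so it only works inside a notebook session. If the results also need to be read from a plain Python script, one alternative is to pickle the dictionaries to disk. The sketch below uses a placeholder dictionary and a temporary directory. The file name and the helper names are hypothetical.

```python
import os
import pickle
import tempfile

def save_sources(sources, path):
    """Write a {frequency: source list} dictionary to disk with pickle."""
    with open(path, "wb") as fh:
        pickle.dump(sources, fh)

def load_sources(path):
    """Read back a dictionary written by save_sources."""
    with open(path, "rb") as fh:
        return pickle.load(fh)

# Placeholder standing in for sources_training (the real values are DataFrames)
example_sources = {1400: [{"ra": 0.0, "dec": -30.0, "flux": 1e-4}]}

with tempfile.TemporaryDirectory() as tmp:
    out_path = os.path.join(tmp, "sources_training.pkl")
    save_sources(example_sources, out_path)
    restored = load_sources(out_path)
```

Pickle files should only be loaded from sources you trust, since loading can execute arbitrary code.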