In [None]:
import numpy as np
import math
import time
import os
import glob
import re
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from datetime import datetime

from astropy.io.votable import parse
from astroquery.casda import Casda
from astroquery.utils.tap.core import TapPlus
from astroquery.utils.tap.core import Tap
from astropy.io.votable import parse, parse_single_table

import astropy.coordinates as coord
from astropy.coordinates import SkyCoord
import astropy.units as u


from IPython.display import Audio, display


def allDone():
    """
    Play an audible notification once a long-running task has finished.

    Embeds an auto-playing audio widget (fetched from a remote URL) in the
    notebook output via IPython's display machinery.
    """
    notification = Audio(url="http://www.mario-museum.net/sons/smb2_perdu.wav", autoplay=True)
    display(notification)


def extractor(df, name):
    """
    Look up the coordinates of a named object in a dataframe.
    Only the first row whose "name" equals `name` is used.
    Parameters:
    - df (pd.DataFrame): Dataframe containing the columns: name, ra_deg_cont, dec_deg_cont
    - name (string): Name of the object whose coordinates should be returned
    Returns:
    - ra, dec (floats): Coordinates of the first matching row
    Raises:
    - IndexError: if no row has the requested name
    """
    hit = df[df["name"] == name]
    return tuple(hit[column].to_numpy()[0] for column in ("ra_deg_cont", "dec_deg_cont"))


def convert_xml_to_pandas(xml_file_name):
    """
    Converts the first table of a VOTable XML file to a pandas dataframe.
    Column names (rather than IDs) are used as dataframe column labels.
    Parameters:
    - xml_file_name (string): Path to the VOTable .xml file
    Returns:
    - pd.DataFrame
    """
    first_table = parse(xml_file_name).get_first_table()
    astropy_table = first_table.to_table(use_names_over_ids=True)
    return astropy_table.to_pandas()


def extract_sb_number(filename):
    """
    Searches a string of text for the scheduling block ID (SBID) number.
    The first occurrence of "SB" followed by one or more digits is used.
    Parameters:
    - filename (string): Text string of the inputted filename
    Returns:
    - SBID (string or None): The digits following "SB", or None if no SBID is present.
    """
    found = re.search(r"SB(\d+)", filename)
    return found.group(1) if found else None


def search_within_sky_separation(df, ra, dec, max_sep):
    """
    Search a pandas DataFrame for rows with coordinates within a certain sky separation from a given location.
    Parameters:
    - df (pd.DataFrame): DataFrame with "ra_deg_cont" and "dec_deg_cont" columns (ICRS, degrees).
    - ra (float): Right ascension of the target location in degrees.
    - dec (float): Declination of the target location in degrees.
    - max_sep (astropy Quantity or float): Maximum allowed angular separation (inclusive).
      A Quantity in any angular unit is converted; a bare number is interpreted as arcseconds.
    Returns:
    - pd.DataFrame: Subset of the input DataFrame containing rows within the specified sky separation.
    """
    # Normalise to an arcsec Quantity so that e.g. 1 * u.arcmin is not silently
    # treated as 1 arcsec (the raw `.value` ignores the unit).
    max_sep = u.Quantity(max_sep, u.arcsec)

    if len(df) == 0:
        return df

    target_coord = SkyCoord(ra=ra * u.deg, dec=dec * u.deg, frame="icrs")
    df_coords = SkyCoord(
        ra=df["ra_deg_cont"].to_numpy() * u.deg,
        dec=df["dec_deg_cont"].to_numpy() * u.deg,
        frame="icrs",
    )
    # Angle-vs-Quantity comparison is unit-aware.
    within = np.asarray(target_coord.separation(df_coords) <= max_sep)
    return df[within]


def error_quad(flux, flux_err, rms, scale=0.06):
    """
    Calculates the quadrature error.
    Parameters:
    - flux (float): Peak flux value from the fit
    - flux_err (float): Uncertainty on the peak flux from the fit
    - rms (float): Root mean square noise for the image
    - scale (float): Fractional percentage by which to scale the flux - default 6%
    Returns:
    - error (float): the quadrature error from the fitted error, rms, and flux scale uncertainty
    """
    return np.sqrt((flux_err**2) + (rms**2) + (scale * flux) ** 2)


def scatter_plot(df, name):
    """
    Scatter plot helper function to plot the light curve from a given dataframe containing observations
    at multiple frequencies. Points are coloured by frequency and drawn with quadrature error bars.
    The figure is saved as "<name>_Scatter_CASDA.pdf", shown, and then closed.
    Parameters:
    - df (pd.DataFrame): Dataframe containing the columns: flux_peak, flux_err_quad, obs_start, freq
    - name (string): Name of target object (used for the title and output filename).
    Returns:
    - None
    """
    if len(df) == 0:
        print("There are no rows to plot, consider enlarging your search radius")
        return

    flux = df["flux_peak"]
    flux_errors = df["flux_err_quad"]
    dates = df["obs_start"]
    frequencies = df["freq"]

    # Use an explicit figure so repeated calls (e.g. the multi-source loop) never draw on
    # top of a previous source's axes when the backend does not close figures on show().
    fig, ax = plt.subplots()

    norm = plt.Normalize(frequencies.min(), frequencies.max())
    cmap = plt.get_cmap("viridis_r")

    sc = ax.scatter(dates, flux, c=frequencies, cmap=cmap, norm=norm, s=50)
    ax.errorbar(dates, flux, yerr=flux_errors, fmt="none", ecolor="gray", capsize=2, alpha=0.4)

    cbar = fig.colorbar(sc, ax=ax)
    cbar.set_label("Frequency [MHz]")

    ax.set_xlabel("Dates [MJD]")
    ax.set_ylabel("Flux Density [mJy/beam]")
    ax.set_title("{}".format(name))

    fig.tight_layout()
    fig.savefig("{}_Scatter_CASDA.pdf".format(name), dpi=300)
    plt.show()
    # Release the figure so figures do not accumulate across many sources.
    plt.close(fig)


def create_directory(directory_path):
    """
    Create `directory_path` (including missing parents) if it does not exist yet.
    Reports the outcome with print(); an OSError during creation is printed, not raised.
    Parameters:
    - directory_path (string): Path of the directory to create
    """
    if os.path.exists(directory_path):
        print(f"Directory '{directory_path}' already exists.")
        return

    try:
        os.makedirs(directory_path)
    except OSError as err:
        print(f"Error creating directory '{directory_path}': {err}")
    else:
        print(f"Directory '{directory_path}' created successfully.")


def casda_download(name, ra, dec, username=None, search_radius=10):
    """
    Searches the CASDA archive for continuum component catalogues whose observations are centred
    within `search_radius` of the given coordinates. Downloads the catalogue .xml files for each
    scheduling block ID (SBID) into "Downloads/<name>/" and loads every "*.components.xml" file
    found in that directory.
    Parameters:
    - name (string): Name of the target object; used as the download sub-directory name
    - ra (float): Right ascension of the target location in degrees.
    - dec (float): Declination of the target location in degrees.
    - username (string, optional): OPAL username. If None (default) the user is prompted for it.
    - search_radius (astropy Quantity or float, optional): Maximum distance between the target and
      an observation's centre; a bare number is interpreted as degrees. Default 10 deg.
    Returns:
    - df_list (list): One dataframe per SBID catalogue, each with added obs_start, obs_length and
      project columns taken from the matching obscore row
    """
    if username is None:
        username = input("Enter your OPAL username")  # You can hardcode your username/email here
    casda = Casda()
    casda.login(username=username)  # This will ask you to enter your OPAL password

    tap = TapPlus(url="https://casda.csiro.au/casda_vo_tools/tap")
    # Only continuum component catalogues are requested; change dataproduct_subtype
    # to search for other data products instead.
    job = tap.launch_job_async(
        "SELECT TOP 50000 * FROM ivoa.obscore where(dataproduct_subtype = 'catalogue.continuum.component')"
    )
    r = job.get_results()

    # Keep only good or uncertain data.
    data = r[(r["quality_level"] == "GOOD") | (r["quality_level"] == "UNCERTAIN")]
    # Required unless your OPAL account has permission for embargoed data.
    public_data = Casda.filter_out_unreleased(data)

    # Centre coordinates of every observation in the table.
    public_coords = SkyCoord(np.array(public_data["s_ra"]) * u.deg, np.array(public_data["s_dec"]) * u.deg)

    casda_filepath = "Downloads/{}/".format(name)  # named folder for the target source
    create_directory(casda_filepath)

    pubdat = public_data.to_pandas()
    print("Number of rows: ", len(pubdat.index))

    print("Source coordinates:")
    print(ra, dec)

    source_coords = SkyCoord(ra * u.deg, dec * u.deg)
    radius = u.Quantity(search_radius, u.deg)

    # Rows of the public table whose image centre lies within `radius` of the source.
    # pubdat has the same row order as public_data/public_coords.
    matches = np.where(source_coords.separation(public_coords) < radius)[0]
    matching_files = np.array(pubdat.iloc[matches]["filename"])

    url_list = _casda_stage_urls(casda, public_data, matching_files)
    casda.download_files(url_list, savedir=casda_filepath)
    print("Downloads completed.")

    print("Generating dataframes.")
    return _casda_load_components(casda_filepath, pubdat)


def _casda_stage_urls(casda, public_data, matching_files):
    """Stage every matching non-checksum catalogue file and return the unique download URLs in order."""
    url_list = []
    seen = set()
    # dict.fromkeys de-duplicates filenames while keeping their order; staging can take a while.
    for mfile in dict.fromkeys(matching_files):
        if "checksum" in mfile:
            continue
        pdata = public_data[public_data["filename"] == mfile]
        # stage_data returns a list of URLs; add each one individually so duplicates are skipped.
        for url in casda.stage_data(pdata):
            if url not in seen:
                seen.add(url)
                url_list.append(url)
    return url_list


def _casda_load_components(xml_dir, pubdat):
    """Load each "*.components.xml" file in xml_dir and tag it with its obscore observation metadata."""
    df_list = []
    for filename in sorted(os.listdir(xml_dir)):
        if not filename.endswith(".components.xml"):
            continue
        print(filename)

        sbid = extract_sb_number(filename)
        if sbid is None:
            print(f"Skipping {filename}: no scheduling block ID found in the filename.")
            continue

        # obs_id carries a telescope prefix (ASKAP- or Beta-) in front of the SBID.
        obs_info = pubdat[(pubdat["obs_id"] == "ASKAP-" + sbid) | (pubdat["obs_id"] == "Beta-" + sbid)]
        if obs_info.empty:
            print(f"Skipping {filename}: no observation metadata found for SB{sbid}.")
            continue

        df = convert_xml_to_pandas(os.path.join(xml_dir, filename))
        df["obs_start"] = obs_info["t_min"].iloc[0]
        df["obs_length"] = obs_info["t_exptime"].iloc[0]
        df["project"] = obs_info["obs_collection"].iloc[0]
        df_list.append(df)
    return df_list


def data_filter(name, ra, dec, df_list, sep=5):
    """
    Filters the list of dataframes for sources within `sep` arcseconds of the target coordinates.
    From each dataframe the single component nearest to the target is kept.
    Calculates the quadrature error, which is more conservative than the fitted error and
    accounts for flux scaling uncertainties. Builds a combined dataframe containing the
    following columns: component_id, component_name, ra_deg_cont, dec_deg_cont, freq,
    flux_peak, flux_peak_err, rms_image, obs_start, obs_length, project, flux_err_quad.
    Parameters:
    - name (string): Name of target object.
    - ra (float): Right ascension of the target location in degrees.
    - dec (float): Declination of the target location in degrees.
    - df_list (list): Dataframes containing all catalogue rows from each SBID
    - sep (float or astropy Quantity): maximum separation between catalogued sources and target
      coordinates; a bare number is in arcsec - default 5
    Returns:
    - df_combined_sorted (pd.DataFrame): Combined and filtered dataframe sorted by observation date
    - data_no_matches (list): Dataframes (observations) without matches within the maximum separation
    """
    columns = [
        "component_id",
        "component_name",
        "ra_deg_cont",
        "dec_deg_cont",
        "freq",
        "flux_peak",
        "flux_peak_err",
        "rms_image",
        "obs_start",
        "obs_length",
        "project",
    ]
    # Quantity(...) accepts both plain numbers and angle Quantities (sep * u.arcsec would not).
    max_sep = u.Quantity(sep, u.arcsec)
    target = SkyCoord(ra=ra * u.deg, dec=dec * u.deg, frame="icrs")

    data_combined = []
    data_no_matches = []
    for df in df_list:
        matches = search_within_sky_separation(df, ra, dec, max_sep)
        if len(matches) == 0:
            data_no_matches.append(df)
            continue

        # Several components may fall inside the radius; keep the one closest to the
        # target rather than whichever happens to come first in the catalogue.
        match_coords = SkyCoord(
            ra=matches["ra_deg_cont"].to_numpy() * u.deg,
            dec=matches["dec_deg_cont"].to_numpy() * u.deg,
            frame="icrs",
        )
        nearest = int(np.argmin(target.separation(match_coords).arcsec))
        data_combined.append([matches[column].values[nearest] for column in columns])

    df_combined = pd.DataFrame(data_combined, columns=columns)

    df_combined["flux_err_quad"] = error_quad(
        df_combined.flux_peak, df_combined.flux_peak_err, df_combined.rms_image
    )
    df_combined_sorted = df_combined.sort_values(by=["obs_start"])
    print("Rows for {} have been filtered and sorted.".format(name))
    return df_combined_sorted, data_no_matches

### Single Source CASDA Extraction:

In [None]:
# Target to extract: name plus ICRS coordinates in degrees.
name = "M31"
ra, dec = 10.684708, 41.268750

# Download all nearby catalogues, keep the rows matching the target, and plot its light curve.
df_list = casda_download(name, ra, dec)
df_combined_sorted, data_no_matches = data_filter(name, ra, dec, df_list, sep=5)
scatter_plot(df_combined_sorted, name)

# Save the matches with a run timestamp so repeated runs don't overwrite each other.
now = datetime.now()
run_dt_stamp = now.strftime("%Y%m%d_%H%M")
output_csv = f"{name}_CASDA_matches_{run_dt_stamp}.csv"
df_combined_sorted.to_csv(output_csv)

allDone()

### Multi Source CASDA Extraction:

In [None]:
### Assumes the input sources csv has the columns: name, ra_deg_cont, dec_deg_cont

df = pd.read_csv("sources.csv")
names = df["name"].to_list()

for name in names:
    ra, dec = extractor(df, name)
    df_list = casda_download(name, ra, dec)
    # Bind to df_combined_sorted: that is the frame plotted and saved below, so a
    # different name would silently reuse a stale result from an earlier source/cell.
    df_combined_sorted, data_no_matches = data_filter(name, ra, dec, df_list, sep=5)
    scatter_plot(df_combined_sorted, name)
    # Timestamp each output so repeated runs don't overwrite each other.
    now = datetime.now()
    run_dt_stamp = now.strftime("%Y%m%d_%H%M")
    df_combined_sorted.to_csv("{}_CASDA_matches_{}.csv".format(name, run_dt_stamp))

allDone()