In [None]:
import os
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime

import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.coordinates import SkyCoord
from astropy.io.votable import parse
from astroquery.casda import Casda
from astroquery.utils.tap.core import TapPlus
from IPython.display import Audio, display


def allDone():
    """Play a short notification sound in the notebook to signal that a run has finished."""
    jingle = Audio(url="http://www.mario-museum.net/sons/smb2_perdu.wav", autoplay=True)
    display(jingle)


def extractor(df, name):
    """
    Look up the sky position of a named object in a source table.
    Parameters:
    - df (pd.DataFrame): Table containing the columns: `name`, `ra_deg_cont`, and `dec_deg_cont`.
    - name (str): Value of the `name` column identifying the object whose coordinates are wanted.
    Returns:
    - ra, dec: The `ra_deg_cont` and `dec_deg_cont` values of the first row whose `name` matches.
    Raises:
    - IndexError: If no row of `df` has the requested name.
    """
    is_target = df["name"] == name
    # Only the first matching row is used when a name appears more than once.
    return (
        df.loc[is_target, "ra_deg_cont"].values[0],
        df.loc[is_target, "dec_deg_cont"].values[0],
    )


def convert_xml_to_pandas(xml_file_name):
    """
    Load the first table of a VOTable XML file as a pandas DataFrame.
    Parameters:
    - xml_file_name (str): Path of the VOTable XML file to read.
    Returns:
    - pd.DataFrame: The first table in the file, with columns labelled by field name rather than field ID.
    """
    first_table = parse(xml_file_name).get_first_table()
    return first_table.to_table(use_names_over_ids=True).to_pandas()


def extract_sb_number(filename):
    """
    Find the scheduling block ID (SBID) number embedded in a filename.
    Parameters:
    - filename (string): Text to search, e.g. a catalogue filename containing "SB<digits>".
    Returns:
    - SBID (string or None): The digits following the first "SB" that is followed by at least
      one digit, or None when the text contains no such pattern.
    """
    found = re.search(r"SB(\d+)", filename)
    return found.group(1) if found else None


def search_within_sky_separation(df, ra, dec, max_sep):
    """
    Search a pandas DataFrame for rows with coordinates within a certain sky separation from a given location.
    Parameters:
    - df (pd.DataFrame): DataFrame with "ra_deg_cont" and "dec_deg_cont" columns holding celestial
      coordinates in degrees.
    - ra (float): Right ascension of the target location in degrees.
    - dec (float): Declination of the target location in degrees.
    - max_sep (float or astropy Quantity): Maximum allowed angular separation. A plain number is
      interpreted as arcseconds; an angular Quantity in any unit is converted to arcseconds.
    Returns:
    - pd.DataFrame: Subset of the input DataFrame containing rows within the specified sky separation.
    """
    # Normalise the limit to a bare number of arcseconds. Reading `.value` directly would fail for
    # plain floats and would silently compare in the wrong unit for e.g. a limit given in degrees.
    max_sep_arcsec = u.Quantity(max_sep, u.arcsec).value

    target_coord = SkyCoord(ra=ra * u.deg, dec=dec * u.deg, frame="icrs")
    df_coords = SkyCoord(
        ra=np.array(df["ra_deg_cont"]) * u.deg, dec=np.array(df["dec_deg_cont"]) * u.deg, frame="icrs"
    )
    separation_arcsec = target_coord.separation(df_coords).arcsec
    return df[separation_arcsec <= max_sep_arcsec]


def error_quad(flux, flux_err, rms, scale=0.06):
    """
    Combine the flux uncertainties in quadrature.
    Parameters:
    - flux (float): Peak flux value from the fit
    - flux_err (float): Uncertainty on the peak flux from the fit
    - rms (float): Root mean square noise for the image
    - scale (float): Fractional flux-scale uncertainty applied to the flux - default 0.06 (6%)
    Returns:
    - error (float): Square root of the summed squares of the fitted error, the rms, and scale * flux
    """
    fit_term = flux_err**2
    noise_term = rms**2
    scale_term = (scale * flux) ** 2
    return np.sqrt(fit_term + noise_term + scale_term)


def scatter_plot(df, name):
    """
    Plot and save the light curve held in a dataframe of observations, colouring each point by
    its observing frequency.
    Parameters:
    - df (pd.DataFrame): Dataframe containing the columns: flux_peak, flux_err_quad, obs_start, freq, obs_length
    - name (string): Name of target object; used as the plot title and in the saved filename.
    Returns:
    - None. The figure is written to "<name>_Scatter_CASDA.pdf" and shown. If the dataframe has no
      rows a message is printed and nothing is plotted.
    """
    peak_flux = df["flux_peak"]
    peak_flux_err = df["flux_err_quad"]
    epochs = df["obs_start"]
    freqs = df["freq"]
    _ = df["obs_length"]  # not plotted, but the column is still required to be present

    if not len(peak_flux):
        print("There are no rows to plot, consider enlarging your search radius")
        return

    # Colour scale spans exactly the range of frequencies present in the data.
    colour_scale = plt.Normalize(freqs.min(), freqs.max())
    points = plt.scatter(
        epochs, peak_flux, c=freqs, cmap=plt.get_cmap("viridis_r"), norm=colour_scale, s=50
    )
    plt.errorbar(epochs, peak_flux, yerr=peak_flux_err, fmt="none", ecolor="gray", capsize=2, alpha=0.4)

    plt.colorbar(points).set_label("Frequency [MHz]")

    plt.xlabel("Dates [MJD]")
    plt.ylabel("Flux Density [mJy/beam]")
    plt.title(f"{name}")

    plt.tight_layout()
    plt.savefig(f"{name}_Scatter_CASDA.pdf", dpi=300)
    plt.show()


def create_directory(directory_path):
    """
    Make sure a directory exists, reporting the outcome on stdout.
    Parameters:
    - directory_path (string): Path of the directory to create (missing parents are created too).
    Returns:
    - None. An OSError raised while creating the directory is printed rather than propagated.
    """
    if os.path.exists(directory_path):
        print(f"Directory '{directory_path}' already exists.")
        return

    try:
        os.makedirs(directory_path)
        print(f"Directory '{directory_path}' created successfully.")
    except OSError as e:
        print(f"Error creating directory '{directory_path}': {e}")


# Authenticate with CASDA once, at notebook start-up. The module-level `casda` object created
# here is the one casda_download() later uses to stage and download files.
username = input("Enter your OPAL username")  # Your OPAL username; you can hardcode your username/email here instead
casda = Casda()
casda.login(username=username)  # This will ask you to enter your OPAL password


def _download_single_file(url, savedir, casda_obj, file_num, total_files):
    """
    Helper function to download a single file (used for parallel downloads).

    Parameters:
    - url: URL to download
    - savedir: Directory to save file
    - casda_obj: Casda instance for downloading
    - file_num: Current file number (for progress tracking)
    - total_files: Total number of files

    Returns:
    - Tuple of (success, filename, error_message)
    """
    try:
        # Download single file
        filelist = casda_obj.download_files([url], savedir=savedir)
        filename = filelist[0] if filelist else "unknown"
        print(f"  [{file_num}/{total_files}] Downloaded: {os.path.basename(filename)}")
        return (True, filename, None)
    except Exception as e:
        print(f"  [{file_num}/{total_files}] ERROR downloading {url}: {e}")
        return (False, url, str(e))


def _parallel_download_files(url_list, savedir, casda_obj, max_workers=4):
    """
    Download multiple files in parallel using ThreadPoolExecutor.
    Parameters:
    - url_list: List of URLs to download
    - savedir: Directory to save files
    - casda_obj: Casda instance for downloading
    - max_workers: Number of parallel download threads
    Returns:
    - List of successfully downloaded files
    """
    downloaded_files = []
    failed_downloads = []
    total_files = len(url_list)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all download tasks
        future_to_url = {
            executor.submit(_download_single_file, url, savedir, casda_obj, i + 1, total_files): url
            for i, url in enumerate(url_list)
        }
        # Process completed downloads as they finish
        for future in as_completed(future_to_url):
            url = future_to_url[future]
            try:
                success, filename, error = future.result()
                if success:
                    downloaded_files.append(filename)
                else:
                    failed_downloads.append((url, error))
            except Exception as e:
                print(f"  Unexpected error processing {url}: {e}")
                failed_downloads.append((url, str(e)))

    print(f"\nDownload summary: {len(downloaded_files)}/{total_files} successful")
    if failed_downloads:
        print(f"Failed downloads ({len(failed_downloads)}):")
        for url, error in failed_downloads:
            print(f"  - {url}: {error}")

    return downloaded_files


def _append_unique(url_list, seen, urls):
    """Append each URL in `urls` to `url_list` unless it is already recorded in `seen`."""
    for url in urls:
        if url not in seen:
            seen.add(url)
            url_list.append(url)


def _stage_urls(public_data, matching_files, optimized, chunk_size=50):
    """Stage `matching_files` with the module-level `casda` session; return unique URLs in order."""
    url_list = []
    seen = set()

    if optimized and len(matching_files) > 0:
        print("Staging files (chunked batch mode - optimized)...")
        # Files are staged in chunks to avoid URL length limits.
        n_chunks = (len(matching_files) + chunk_size - 1) // chunk_size
        for chunk_no, first in enumerate(range(0, len(matching_files), chunk_size), start=1):
            chunk = matching_files[first : first + chunk_size]
            print(f"Staging chunk {chunk_no}/{n_chunks} ({len(chunk)} files)...")
            chunk_data = public_data[np.isin(public_data["filename"], chunk)]

            try:
                _append_unique(url_list, seen, casda.stage_data(chunk_data))
            except Exception as e:
                print(f"Error staging chunk: {e}")
                print("Falling back to individual staging for this chunk...")
                for mfile in chunk:
                    try:
                        pdata = public_data[public_data["filename"] == mfile]
                        # stage_data returns a list of URLs, so membership must be checked per
                        # URL; testing the whole list against url_list never finds a duplicate.
                        _append_unique(url_list, seen, casda.stage_data(pdata))
                    except Exception as e2:
                        print(f"Error staging {mfile}: {e2}")

        print(f"Staged {len(url_list)} total files")
    else:
        # If your resources aren't suited to the optimisation you can stage files one by one (slower)
        print("Staging files (individual mode - original)...")
        for mfile in matching_files:
            pdata = public_data[public_data["filename"] == mfile]
            _append_unique(url_list, seen, casda.stage_data(pdata))

    return url_list


def _load_catalogues(casda_filepath, pubdat):
    """Read each `.components.xml` file in `casda_filepath` and attach its observation metadata."""
    # One metadata record per obs_id (first occurrence wins), built without a Python-level row loop.
    obs_lookup = (
        pubdat.drop_duplicates(subset="obs_id")
        .set_index("obs_id")[["t_min", "t_exptime", "obs_collection"]]
        .to_dict("index")
    )

    df_list = []
    # Sorted so the order of the returned dataframes does not depend on the filesystem.
    for filename in sorted(os.listdir(casda_filepath)):
        if not filename.endswith(".components.xml"):
            continue
        print(filename)

        sbid = extract_sb_number(filename)
        if sbid is None:
            # Without an SBID the catalogue cannot be tied to an observation.
            print(f"Skipping {filename}: no SBID found in the filename")
            continue

        df = convert_xml_to_pandas(os.path.join(casda_filepath, filename))

        obs_info = obs_lookup.get(f"ASKAP-{sbid}") or obs_lookup.get(f"Beta-{sbid}")
        if obs_info is None:
            # Keep the columns present (as missing values) so downstream code that reads
            # obs_start / obs_length / project does not fail on this catalogue.
            print(f"No observation metadata found for SB{sbid}; leaving obs columns empty")
            obs_info = {"t_min": np.nan, "t_exptime": np.nan, "obs_collection": None}

        df["obs_start"] = obs_info["t_min"]
        df["obs_length"] = obs_info["t_exptime"]
        df["project"] = obs_info["obs_collection"]
        df_list.append(df)

    return df_list


def casda_download(
    name,
    ra,
    dec,
    optimized=True,
    parallel_downloads=True,
    max_workers=4,
    download_root="/Downloads",
    search_radius_deg=10,
):
    """
    Searches the CASDA archive for catalogues covering the given coordinates.
    Downloads the catalogue .xml files for each scheduling block ID (SBID) into a new directory,
    created with the given target name, then loads every `.components.xml` file found in that
    directory into a dataframe tagged with its observation start, length and project.

    Staging and downloading use the module-level `casda` session, which must already be logged in.

    Parameters:
    - name (string): Name of the target; used as the name of the download sub-directory.
    - ra (float): Right ascension of the target location in degrees.
    - dec (float): Declination of the target location in degrees.
    - optimized (bool): Use chunked batch staging (default True). Set to False to stage file by file.
    - parallel_downloads (bool): Use parallel downloads (default True). Only works with optimized=True.
    - max_workers (int): Number of parallel download threads (default 4).
    - download_root (string): Directory under which the `<name>/` folder is created
      (default "/Downloads", i.e. files go to "/Downloads/<name>/").
    - search_radius_deg (float): Catalogues whose `s_ra`/`s_dec` position lies within this many
      degrees of the target are downloaded (default 10).
    Returns:
    - df_list (list): One dataframe per catalogue file, each with `obs_start`, `obs_length` and
      `project` columns added (filled with missing values when no observation record matches).
    """
    tap = TapPlus(url="https://casda.csiro.au/casda_vo_tools/tap")

    # This is the search function, here I'm searching only
    # for continuum catalogues. You can change this to search
    # for other data products instead
    job = tap.launch_job_async(
        ("SELECT TOP 50000 * FROM ivoa.obscore where(dataproduct_subtype = 'catalogue.continuum.component')")
    )
    r = job.get_results()
    data = r[(r["quality_level"] == "GOOD") | (r["quality_level"] == "UNCERTAIN")]

    # You have to do this step unless you have permission
    # for embargoed data associated with your OPAL account login
    public_data = Casda.filter_out_unreleased(data)

    pubdat = public_data.to_pandas()
    print("Number of rows: ", len(pubdat.index))

    print("Source coordinates:")
    print(ra, dec)
    source_coords = SkyCoord(ra * u.deg, dec * u.deg)
    search_radius = search_radius_deg * u.deg
    public_coords = SkyCoord(np.array(public_data["s_ra"]) * u.deg, np.array(public_data["s_dec"]) * u.deg)
    seps = source_coords.separation(public_coords)
    matches = np.where(seps < search_radius)[0]

    # Checksum entries are dropped here, so no later step needs to filter them again.
    matching_files = [f for f in pubdat.iloc[matches]["filename"] if "checksum" not in f]
    print(f"Found {len(matching_files)} matching files to download")

    # Trailing "" keeps the path ending in a separator, e.g. "/Downloads/<name>/".
    casda_filepath = os.path.join(download_root, name, "")
    create_directory(casda_filepath)

    url_list = _stage_urls(public_data, matching_files, optimized)

    if not url_list:
        print("No files were staged; nothing to download.")
    elif optimized and parallel_downloads and len(url_list) > 1:
        print(f"Downloading {len(url_list)} files in parallel (max {max_workers} workers)...")
        _parallel_download_files(url_list, casda_filepath, casda, max_workers)
    else:
        print("Downloading files (sequential)...")
        casda.download_files(url_list, savedir=casda_filepath)
    print("Downloads completed.")

    print("Generating dataframes.")
    return _load_catalogues(casda_filepath, pubdat)


def data_filter(name, ra, dec, df_list, sep=5, optimized=True):
    """
    Reduce a list of catalogue dataframes to one row per catalogue: the first row lying within
    `sep` arcseconds of the target coordinates. Adds the quadrature error, which is more
    conservative than the fitted error and accounts for flux scaling uncertainties.
    The combined dataframe contains the columns: component_id, component_name, ra_deg_cont,
    dec_deg_cont, freq, flux_peak, flux_peak_err, rms_image, obs_start, obs_length, project,
    flux_err_quad.

    Parameters:
    - name (string): Name of target object (only used in the progress message).
    - ra (float): Right ascension of the target location in degrees.
    - dec (float): Declination of the target location in degrees.
    - df_list (list): Dataframes containing all catalogue rows from each SBID
    - sep (float): maximum separation between catalogued sources and target coordinates in arcsec - default 5
    - optimized (bool): Build the target coordinate once and match in-line (default True). Set to
      False to match through search_within_sky_separation instead.
    Returns:
    - df_combined_sorted (pd.DataFrame): Combined and filtered dataframe sorted by obs_start
    - data_no_matches (list): Dataframes with no row within the maximum separation
    """
    output_columns = [
        "component_id",
        "component_name",
        "ra_deg_cont",
        "dec_deg_cont",
        "freq",
        "flux_peak",
        "flux_peak_err",
        "rms_image",
        "obs_start",
        "obs_length",
        "project",
    ]

    max_sep = sep * u.arcsec
    if optimized:
        target_coord = SkyCoord(ra=ra * u.deg, dec=dec * u.deg, frame="icrs")

    data_combined = []
    data_no_matches = []
    for df in df_list:
        if optimized:
            if len(df) == 0:
                data_no_matches.append(df)
                continue

            df_coords = SkyCoord(
                ra=df["ra_deg_cont"].values * u.deg, dec=df["dec_deg_cont"].values * u.deg, frame="icrs"
            )
            within = target_coord.separation(df_coords) <= max_sep
            if not within.any():
                data_no_matches.append(df)
                continue

            # The first qualifying row in catalogue order is kept (not necessarily the nearest).
            row = df.iloc[np.flatnonzero(within)[0]]
            data_combined.append([row[col] for col in output_columns])
        else:
            # non-vectorized method (slower)
            matches = search_within_sky_separation(df, ra, dec, max_sep)
            if len(matches) == 0:
                data_no_matches.append(df)
                continue
            data_combined.append([matches[col].values[0] for col in output_columns])

    df_combined = pd.DataFrame(data_combined, columns=output_columns)
    df_combined["flux_err_quad"] = error_quad(
        df_combined["flux_peak"], df_combined["flux_peak_err"], df_combined["rms_image"]
    )

    df_combined_sorted = df_combined.sort_values(by=["obs_start"])
    print(f"Rows for {name} have been filtered and sorted.")

    return df_combined_sorted, data_no_matches

### Single Source CASDA Extraction:

In [None]:
# Target name and its coordinates (passed to casda_download / data_filter as ra, dec in degrees).
name = "M31"
ra, dec = 10.684708, 41.268750
# The start time is used both for the elapsed-time report and to stamp the output CSV filename.
start = datetime.now()
run_dt_stamp = start.strftime("%Y%m%d_%H%M")
# Download the catalogues, keep the rows within 5 arcsec of the target, then plot and save them.
df_list = casda_download(name, ra, dec, parallel_downloads=True, max_workers=8)
df_combined_sorted, data_no_matches = data_filter(name, ra, dec, df_list, sep=5)
scatter_plot(df_combined_sorted, name)
df_combined_sorted.to_csv(f"{name}_CASDA_matches_{run_dt_stamp}.csv")
end = datetime.now()
elapsed = end - start
print(f"{name} CASDA executed in seconds: {elapsed.total_seconds()}")
# Audible notification that the cell has finished.
allDone()

### Multi Source CASDA Extraction:

In [None]:
### Assumes the input "sources.csv" has the columns: name, ra_deg_cont, dec_deg_cont
df = pd.read_csv("sources.csv")
names = df["name"].to_list()

# Run the full download / filter / plot / save pipeline once per source name.
for name in names:
    start = datetime.now()
    run_dt_stamp = start.strftime("%Y%m%d_%H%M")  # overwritten below before it is used
    ra, dec = extractor(df, name)
    df_list = casda_download(name, ra, dec, parallel_downloads=True, max_workers=4)
    df_combined_sorted, data_no_matches = data_filter(name, ra, dec, df_list, sep=5)
    scatter_plot(df_combined_sorted, name)
    # The CSV filename is stamped with the time at which processing finished, not the start time.
    now = datetime.now()
    run_dt_stamp = now.strftime("%Y%m%d_%H%M")
    df_combined_sorted.to_csv(f"{name}_CASDA_matches_{run_dt_stamp}.csv")
    end = datetime.now()
    elapsed = end - start
    print(f"{name} CASDA executed in seconds: {elapsed.total_seconds()}")

# Audible notification once every source has been processed.
allDone()