In [2]:
import numpy as np
import pandas as pd
from scipy import stats
import pprint
import warnings
from pandas.errors import ParserWarning
pd.set_option('display.max_columns', 8)
pd.set_option('display.width', 1000)

import re
from pathlib import Path
import warnings
from typing import List, Dict, Tuple, Any, Union, Callable
from astropy.table import Table, MaskedColumn

from astroquery import mast
from astroquery.mast import Observations
from astropy.time import Time, TimeJD, TimeDelta
from astroquery.mast.missions import MastMissions
hst_mission = MastMissions(mission='hst')

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
%matplotlib inline
plt.rcParams['figure.figsize'] = [7, 7]
plt.rcParams['figure.dpi'] = 300

from tqdm import tqdm

# Autoreload extension
if 'autoreload' not in get_ipython().extension_manager.loaded:
    %load_ext autoreload

%autoreload 2

In [3]:
# local setup
WDIR = Path().cwd().parent

EMISSION_DATABASE_PATH = WDIR / "data/emissionspec.csv"
TRANSMISSION_DATABASE_PATH = WDIR / "data/transitspec.csv"

In [34]:
def alias_this(s: str):
    alias_dict = {
        "Wide Field Camera 3": "WFC3",
        "Space Telescope Imaging Spectrograph": "STIS",
        "Hubble Space Telescope satellite": "HST",
        "Hubble Space Telescope": "HST",
        # Add more aliases as necessary
    }
    if s in alias_dict:
        return alias_dict[s]
    else:
        return s

def extract_abbreviation(string, force_abbrv=False):
    match = re.search(r'\(([A-Z]+)\)', string)
    if match:
        abbrv = match.group(1)
        abbrv = re.sub(r'[\[\]()]', '', abbrv)
        return abbrv
    elif force_abbrv:
        words = string.split()
        abbrv = ""
        for word in words:
            abbrv += word[0].upper()
        abbrv = re.sub(r'[\[\]()]', '', abbrv)
        return abbrv
    else:
        return string

def append_next_string(first_list: List[str], second_list: List[str]) -> List[str]:
    """
    This function takes in a list of strings and appends the next string in another
    list of strings that is not already in the first list. It returns the first list with the appended string.

    Args:
    - first_list: A list of strings.
    - second_list: Another list of strings.

    Returns:
    - A list of strings that contains all the strings in the first list and the next string in the
      second list that is not already in the first list.
    """
    for string in second_list:
        if string not in first_list:
            first_list.append(string)
            break

    return first_list

default_warning_filters = [
    {"category":DeprecationWarning, "module":''},
    {"category":FutureWarning, "module":''},
    {"category":RuntimeWarning, "module":''},
    {"category":Warning, "module":'numpy'},
    {"category":Warning, "module":'scipy'},
    {"category":Warning, "module":'pandas'},
    {"category":Warning, "module":'astropy'},
    {"category":Warning, "module":'astroquery'},
]

def run_function_with_warnings_filtered(func: Callable[..., Any], *args: Tuple[Any], filters: (List[Dict[str, Any]], None)=None, verbose: bool=True, **kwargs: Any) -> Any:
    """
    Executes the input function while filtering warnings as specified by the filters list.

    Args:
        func (Callable[..., Any]): The function to be executed.
        filters (List[Dict[str, Any]]): A list of dictionaries containing the warning category and module to be filtered.
        *args (Tuple[Any]): Positional arguments to be passed to the input function.
        verbose (bool, optional): If True, prints the warnings after the function execution. Defaults to True.
        **kwargs (Any): Keyword arguments to be passed to the input function.

    Returns:
        Any: The result of the executed input function.
    """
    if filters is None:
        filters = []

    # Create a filter that includes warnings from all relevant modules
    all_warnings = []
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        filters = default_warning_filters + filters
        for f in filters:
            warnings.filterwarnings("ignore", **f)

        # Call the input function with the provided *args and **kwargs
        result = func(*args, **kwargs)

        all_warnings.extend(w)

    # print out the resulting warnings
    if (verbose and len(all_warnings) > 0):
        print("The following warnings were raised:")
        warns, idx, counts = np.unique([repr(w.message) for w in all_warnings], return_index=True, return_counts=True, )
        warns = np.array(all_warnings, dtype=object)[idx]

        for w, c in zip(warns, counts):
            message_str = re.sub(r'^.*?(Warning\()', r'\1', str(w.message).rsplit('Warning(', 1)[-1], count=1)
            print(f'Line: {w.lineno}, count: {c}: \n'
                  f'{pprint.pformat(message_str)}')
    elif verbose:
        print("No warnings were raised.")

    # Return the result of the function call
    return result

In [5]:
df_transmission = pd.read_csv(TRANSMISSION_DATABASE_PATH,
                              header=26, index_col=0,
                              dtype={"plntranmid": str, },
                              )

# fix time
plntranmid = df_transmission["plntranmid"].dropna(axis="index")
plntranmid_jd_time = pd.DataFrame(data=Time(plntranmid.to_numpy().astype(float), format="jd"),
                                  index=plntranmid.index, columns=["plntranmid_jd_time"])
plntranmid_jd_time["plntranmid_mjd_time"] = [t.mjd for t in plntranmid_jd_time["plntranmid_jd_time"]]
df_transmission = df_transmission.join(plntranmid_jd_time, how="left")

In [6]:
def add_source_key(data_dict: Dict) -> Dict:
    """
    Adds source information to the input data dictionary by extracting the author(s) and year from the reflink.

    Args:
        data_dict (Dict): A dictionary containing data related to a planet observation.

    Returns:
        Dict: The modified data dictionary with the added 'source', 'author', and 'year' keys.
    """
    # Extract the plntranreflink value
    reflink = data_dict.get('reflink', '')

    # Search for the 'authoryear' pattern using a regular expression
    pattern = re.compile(r'>(.+? et al\. \d{4})</a>|>(.+? \d{4})</a>')
    match = pattern.search(reflink)

    author_pattern = re.compile(r'\+\d{4}$')
    year_pattern = re.compile(r'\d{4}')

    if match:
        authoryear = match.group(1) if match.group(1) else match.group(2)
        if 'et al.' in authoryear:
            authoryear = authoryear.replace(' et al. ', '+')

        # Add the 'source' key with the extracted authoryear value
        data_dict['source'] = authoryear.strip()
        data_dict['author'] = re.sub(author_pattern, '',  data_dict['source'])
        year_match = re.search(year_pattern, data_dict['source'])
        data_dict['year'] = int(year_match.group(0)) if match else None
    else:
        data_dict['source'] = None
        data_dict['author'] = None
        data_dict['year'] = None

    data_dict["data"]['prelim_group_id'] = data_dict['prelim_group_id']
    data_dict["data"]['source'] = data_dict['source']
    data_dict["data"]['author'] = data_dict['author']
    data_dict["data"]['year'] = data_dict['year']

    return data_dict

def extract_data(df: pd.DataFrame) -> List[Dict]:
    # Group the DataFrame by unique combinations of planet_name, instrument, facility, and reflink
    grouped = df.groupby(['plntname', 'facility', 'instrument', 'plntranreflink'])

    result = []

    for i, ((planet_name, facility, instrument, reflink), group) in enumerate(grouped):
        # Convert each group to a dictionary
        entry = {
            'plntname': planet_name,
            'facility': facility,
            'instrument': instrument,
            'reflink': reflink,
            "prelim_group_id": i,
            'data': group
        }

        # Add the entry to the result list
        entry = add_source_key(entry)

        result.append(entry)

    return result

def concat_specs(dicts_list):
    dfs = [d['data'] for d in dicts_list]
    return pd.concat(dfs, axis=1)

split_transmission = extract_data(df_transmission)
spec_transmission = concat_specs(split_transmission)

In [7]:
def unique_planets_and_counts(dict_list: List[Dict]) -> (List[str], List[float]):
    planet_counts = {}

    for item in dict_list:
        planet_name = item['plntname']
        if planet_name in planet_counts:
            planet_counts[planet_name] += 1
        else:
            planet_counts[planet_name] = 1

    unique_planet_names = list(planet_counts.keys())
    counts = list(planet_counts.values())

    return unique_planet_names, counts

# TODO: write regex that checks for number of star systems: all non-special characters must be equal except for the letter at the last position.

def summarize_spectra(spectra_dicts: List[Dict[str, str]], verbose: bool = True) -> Dict[str, int]:
    """
    Summarize the input list of dictionaries by counting occurrences of specific elements.

    Parameters
    ----------
    spectra_dicts : List[Dict[str, str]]
        A list of dictionaries containing specific elements.
    verbose : bool, optional
        If True, print the summary, default is True.

    Returns
    -------
    Dict[str, int]
        A dictionary with the counts of specific elements in the input list.
    """
    # Initialize the output dictionary
    summary = {}

    # Find the number of spectra in the input
    summary['spectra'] = len(spectra_dicts)

    # Define columns for counting occurrences
    count_columns = ['plntname', 'facility', 'instrument']

    # Count occurrences for each element in count_columns
    for column in count_columns:
        summary[column] = {}
        for d in spectra_dicts:
            key = d[column]
            summary[column][key] = summary[column].get(key, 0) + 1

    # Print the summary if verbose is True
    if verbose:
        print("Summary:")
        for key, value in summary.items():
            print(f"  {key}:")
            if isinstance(value, dict):
                inverted = {}
                for k, v in value.items():
                    if v not in inverted:
                        inverted[v] = [k]
                    else:
                        inverted[v].append(k)

                for count, items in inverted.items():
                    if key in ['facility', 'instrument']:
                        print(f"    {count} x {', '.join([extract_abbreviation(item, force_abbrv=True) if len(item) > 10  else extract_abbreviation(item, force_abbrv=False) for item in items])}")
                    else:
                        print(f"    {count} x {', '.join(items)}")
            else:
                print(f"    {value}")

    return summary

all_planets, counts = unique_planets_and_counts(split_transmission)

print(f"Number of unique planets with transmission spectra: {len(all_planets)}.\n"
      f"\tMedian count: {np.median(counts)}, mean count: {np.median(counts):.1f} +/- {np.std(counts):.1f}, max count: {np.max(counts)} ({all_planets[np.argmax(counts)]}).")

summary = summarize_spectra(split_transmission)

Number of unique planets with transmission spectra: 103.
	Median count: 2.0, mean count: 2.0 +/- 3.8, max count: 21 (GJ 1214 b).
Summary:
  spectra:
    346
  plntname:
    4 x 55 Cnc e, GJ 1132 b, HAT-P-3 b, K2-18 b, WASP-31 b, WASP-46 b, WASP-57 b, WASP-67 b, WASP-74 b, WASP-79 b
    2 x CoRoT-1 b, GJ 436 b, HAT-P-18 b, HAT-P-23 b, K2-26 b, K2-3 b, K2-3 c, K2-3 d, KELT-9 b, KOI-12 b, KOI-13 b, KOI-94 d, Kepler-10 c, Kepler-102 d, Kepler-102 e, Kepler-104 d, Kepler-11 e, Kepler-125 b, Kepler-126 d, Kepler-127 d, Kepler-138 c, Kepler-14 b, Kepler-158 c, Kepler-18 c, Kepler-18 d, Kepler-19 b, Kepler-20 c, Kepler-20 d, Kepler-205 c, Kepler-22 b, Kepler-236 c, Kepler-249 d, Kepler-25 b, Kepler-25 c, Kepler-26 c, Kepler-32 d, Kepler-37 d, Kepler-410 A b, Kepler-49 b, Kepler-49 c, Kepler-61 b, Kepler-62 e, Kepler-68 b, Kepler-93 b, Kepler-94 b, WASP-107 b, WASP-121 b, WASP-21 b, WASP-52 b, XO-1 b
    21 x GJ 1214 b, GJ 3470 b
    7 x HAT-P-1 b, HAT-P-32 b, WASP-12 b, WASP-17 b, WASP-43 b, W

In [8]:
def get_count_elements(column: MaskedColumn):
    """Return the most frequent element in the given Astropy MaskedColumn."""
    if isinstance(column, MaskedColumn):
        column = column.filled(np.nan)

    mode, counts = stats.mode(column, nan_policy="omit")
    return mode.flatten(), counts.flatten()

def check_time_window(table: Table, value: float) -> (bool, np.ndarray):
    """Return an array indicating if the float is between two column values for each row."""
    col1, col2 = table[table.colnames[0]], table[table.colnames[1]]

    col1, col2 = Time(np.array(col1), format='isot', scale='utc'), Time(np.array(col2), format='isot', scale='utc')

    if value == np.nan:
        return False, []
    else:
        value = Time(value, format="mjd")
        delta1, delta2 = col1 - value, col2 - value
        return True, np.unique([np.argwhere(delta1.isclose(TimeDelta(0., format="sec"), atol=TimeDelta(600, format="sec"))), np.argwhere(delta2.isclose(TimeDelta(0., format="sec"), atol=TimeDelta(600, format="sec")))])

def check_wavelength_window(table: Table, value: float, scale: float) -> (bool, np.ndarray):
    """Return an array indicating if the float is between two column values for each row."""
    col1, col2 = table[table.colnames[0]], table[table.colnames[1]]

    col1, col2 = np.array(col1, dtype=float), np.array(col2, dtype=float)

    if value == np.nan:
        return False, []
    else:
        delta1, delta2 = np.abs(col1 - value * scale), np.abs(col2 - value * scale)
        return True, np.unique([delta1.argmin(), delta2.argmin()])

def check_central_wavelength(col: MaskedColumn, value: float, scale: float) -> (bool, np.ndarray):
    """Return an array indicating if the float is close to the column values for each row."""

    col = np.array(col, dtype=float)

    if value == np.nan:
        return False, []
    else:
        delta1 = np.abs(col - value * scale)
        return True, np.argwhere(np.isclose(delta1, 0., atol=10., rtol=0.)).flatten()

In [37]:
def assign_spec_element_to_obs(obs: pd.Series, results: pd.DataFrame, planet_name: str, instrument: str) -> pd.Series:
    """
    Assigns the spectral element and aperture to the given observation based on the given results.

    Parameters
    ----------
    obs : pd.Series
        A pandas Series containing the observation data.
    results : pd.DataFrame
        A pandas DataFrame containing the results data.
    planet_name : str
        The name of the planet for the observation.
    instrument : str
        The name of the instrument used for the observation.

    Returns
    -------
    pd.Series
        A pandas Series containing the assigned spectral element and aperture.

    Raises
    ------
    ValueError
        If a ValueError is encountered while processing the time windows.
    """



    # Check if the observation is within the time window criteria
    try:
        time_match, t_window_crit = check_time_window(results["sci_start_time", "sci_stop_time"], obs["plntranmid_mjd_time"])
    except ValueError:
        t_window_crit = np.arange(len(results))

    try:
        assert (np.all(results['sci_aper_1234'] == results['sci_aper_1234'][0])
                and np.all(results['sci_spec_1234'] == results['sci_spec_1234'][0]))
        out =  pd.Series({
                    'spectral_element': results['sci_spec_1234'][0],
                    "aperture": results['sci_aper_1234'][0],
                    "spec_match_frac": "1/1",
                    "aper_match_frac": "1/1",
                })
    except AssertionError:
        try:
            assert (np.all(results['sci_aper_1234'][t_window_crit] == results['sci_aper_1234'][t_window_crit][0])
                    and np.all(results['sci_spec_1234'][t_window_crit] == results['sci_spec_1234'][t_window_crit][0]))

            out =  pd.Series({
                'spectral_element': results['sci_spec_1234'][t_window_crit][0],
                "aperture": results['sci_aper_1234'][t_window_crit][0],
                "spec_match_frac": "1/1",
                "aper_match_frac": "1/1",
            })
        except (IndexError,  AssertionError):
            # Get the most frequent spectral element and aperture
            if len(t_window_crit) > 1:
                spec_mode, spec_count = get_count_elements(results['sci_spec_1234'][t_window_crit])
                aper_mode, aper_count = get_count_elements(results['sci_aper_1234'][t_window_crit])
            else:
                spec_mode, spec_count = get_count_elements(results['sci_spec_1234'])
                aper_mode, aper_count = get_count_elements(results['sci_aper_1234'])

            spec_match_frac = f"{np.sum(spec_count[0])}/{np.sum(spec_count)}"
            aper_match_frac = f"{np.sum(aper_count[0])}/{np.sum(aper_count)}"

            if not (len(spec_mode)==1 and len(aper_mode)==1):
            # Issue a warning about the inconsistency
                warnings.warn(f"{planet_name} - {instrument}:\n"
                              f"Could not find matching spectral element and aperture."
                              f"Spectral_element: {spec_mode[0]} ({spec_count[0] / np.sum(spec_count) * 100:.3f} % match, total: {np.sum(spec_count[0])}/{np.sum(spec_count)})\n"
                              f"Aperture:         {aper_mode[0]} ({aper_count[0] / np.sum(aper_count) * 100:.3f} % match, total: {np.sum(aper_count[0])}/{np.sum(aper_count)})")

            # Return the most frequent spectral element and aperture
            out =  pd.Series({
                'spectral_element': spec_mode[0],
                "aperture": aper_mode[0],
                "spec_match_frac": spec_match_frac,
                "aper_match_frac": aper_match_frac,
            })
    finally:
        if not 'out' in locals():
            warnings.warn(f"{planet_name} - {instrument} -  no valid match found in results after criteria relaxation.")
            out = pd.Series({
                'spectral_element': "NO VALID MATCH FAILURE",
                "aperture": "NO VALID MATCH FAILURE",
                "spec_match_frac": "0/0",
                "aper_match_frac": "0/0",
            })


    return out

def get_mast_obs_data(spec: Dict[str, Any], pbar: Any=None, relax: (List[str], None)=None) -> Dict[str, Any]:
    """
    Retrieves observation data from the MAST database for the given input specification and adds the data to the specification.

    Parameters
    ----------
    spec : Dict[str, Any]
        A dictionary containing the input specification including planet name, facility, instrument, data, source, author, and year.

    Returns
    -------
    Dict[str, Any]
        The input specification with the retrieved observation data added.

    Raises
    ------
    IndexError
        If no matches are found in the MAST database for the given target and instrument.
    """

    if relax is None:
        relax = []

    # Extract information from the input specification
    planet_name = spec["plntname"]
    facility = spec["facility"]
    instrument = spec['instrument']
    data = spec["data"]

    # TODO: remove: - Remove unnecessary keys from the data
    drop_keys = ["spectral_element", "aperture"]
    for key in drop_keys:
        try:
            spec["data"].drop(key, axis="columns", inplace=True)
        except KeyError:
            pass

    # Create target name and instrument aliases
    target_name = re.sub(r'[a-z]$', '', planet_name).strip()
    sci_targname = target_name + "*," + (target_name.replace(' ', '-'))
    instrument = alias_this(instrument) + "*"

    relaxations = [
        "RELAX_sci_targname",
        "RELAX_target",
        "FAIL",
    ]
    if "FAIL" in relax:
        print(f"FAILED: {target_name} - {sci_targname} - {instrument.lower()} - {relax}")
        warnings.warn (f"No matches in {target_name} - {sci_targname} - {instrument.lower()}")
        spec["data"]["spectral_element"] = "NO MATCH FAILURE"
        spec["data"]["aperture"] = "NO MATCH FAILURE"
        return spec
    if "RELAX_sci_targname" in relax:
        sci_targname = "*"
    if "RELAX_target" in relax:
        # TODO: not implemented yet
        target_name = target_name

    # Query the MAST database for observation data
    try:
        results = hst_mission.query_criteria(
            target=target_name,
            sci_targname=sci_targname,
            select_cols=[
                'sci_targname',
                'sci_instrume', 'sci_instrument_config', 'sci_aper_1234', 'sci_spec_1234',
                'sci_central_wavelength', 'sci_spectrum_end', 'sci_spectrum_start', 'sci_bandwidth', 'sci_spectral_res',
                'sci_preview_name', 'sci_refnum', 'sci_aec', 'sci_status', 'sci_data_set_name', 'sci_pi_last_name',
                'sci_actual_duration', 'sci_start_time', 'sci_stop_time',
            ],
            sci_instrume=instrument.lower(),
            sci_obs_type='all',
            sci_aec='S',
        )

        # Handle cases where no matches are found
        if len(results) == 0:
            if isinstance(pbar, tqdm):
                pbar.set_postfix_str(f"Relaxing search for {target_name} - {sci_targname} - {instrument}")
            relax = append_next_string(relax, relaxations)
            spec = get_mast_obs_data(spec, relax=relax)
        else:
            # Assign spectral element and aperture to the retrieved data
            new = data.apply(assign_spec_element_to_obs, axis="columns", args=(results, planet_name, instrument))
            data = pd.concat([data, new], axis=1)

            spec["data"] = data

    # Handle cases where no matches are found
    except NotImplementedError:
        print(f"FAILED: {target_name} - {sci_targname} - {instrument.lower()} - {relax}")
        warnings.warn(f"No matches in {target_name} - {sci_targname} - {instrument.lower()}")
        spec["data"]["spectral_element"] = "OUTER FAILURE"
        spec["data"]["aperture"] = "OUTER FAILURE"

    return spec

In [38]:
def process_data(fac: List[str], ins: List[str], split_transmission: List[Dict]) -> pd.DataFrame:
    """
    Processes data by filtering the split_transmission list based on the ins list and retrieving
    MAST observation data for each filtered element.

    Args:
        ins (List[str]): A list of instrument names to filter the split_transmission list.
        split_transmission (List[Dict]): A list of dictionaries containing transmission data for each instrument.

    Returns:
        pd.DataFrame: A DataFrame containing the concatenated MAST observation data for the filtered instruments.
    """
    new_split_transmission = []

    # Iterate through the filtered list of split_transmission elements based on the ins list
    spec_iterable = tqdm([spec for spec in split_transmission if (spec["instrument"] in ins and spec["facility"] in fac)])
    for spec in spec_iterable:
        spec_iterable.set_postfix_str("")
        # Append the MAST observation data for the current instrument specification
        new_split_transmission.append(get_mast_obs_data(spec, pbar=spec_iterable))

    # Concatenate the MAST observation data for all filtered instruments
    df_out_with_spectral_elements = pd.concat([spec["data"] for spec in new_split_transmission], axis=0)

    return df_out_with_spectral_elements
 
ins = ["Wide Field Camera 3", "Space Telescope Imaging Spectrograph", "WFC3", "WFC", "STIS"]
fac = ["HST", "Hubble Space Telescope", "Hubble Space Telescope satellite"]

output_df = run_function_with_warnings_filtered(process_data, fac, ins, split_transmission)
output_df.to_csv(WDIR / "data/transitspec_with_spectral_elements.csv")


100%|██████████| 65/65 [00:50<00:00,  1.29it/s]                                                                 




