# quality_assurance

> In progress, including development of diagnostics to be moved elsewhere when complete.

In [None]:
# |default_exp quality_assurance

In [None]:
# | hide
from nbdev.showdoc import show_doc

In [None]:
# |export

import time
from abc import ABC, abstractmethod

import dask
import numpy as np
import pandas as pd
import xarray as xr

from qagmire.data import (
    get_lr_l1_single_files,
    get_lr_l2_stack_files,
    read_class_spec,
    read_class_table,
    read_galaxy_table,
    read_primary_header,
)
from qagmire.utilities import parse_line_names, parse_obstemp

To write checks of the data, we first create a subclass of `Diagnostics`.

In [None]:
# |export


class Diagnostics(ABC):
    """An abstract class to be subclassed to perform specific diagnostic checks.

    A subclass should perform a set of checks, implemented in a method named `tests`.

    Calling the method `run` will combine and compute the tests, returning the results
    as a single boolean `DataArray` for further analysis.

    In the results, `True` marks a failure: the `summary` methods label the values
    as "failed" and count the `True` entries as the number of fails.
    """

    def run(self, **kwargs) -> xr.DataArray:
        """Compute the results of the tests.

        The `kwargs` are passed to `qagmire.data.read_*` functions to obtain the data
        for the tests.

        The descriptions of the tests are stored, keyed by test name, in the
        `test_descriptions` attribute. The time taken to compute is printed.

        Raises a `ValueError` if `tests` returns no tests.
        """
        tests = self.tests(**kwargs)
        if not tests:
            # fail early with a clear message, rather than from inside `xr.concat`
            raise ValueError(f"{type(self).__name__}.tests returned no tests to run.")
        test_names = [t["name"] for t in tests]
        test_desc = [t["description"] for t in tests]
        self.test_descriptions = dict(zip(test_names, test_desc))
        test_array = [t["test"] for t in tests]
        # stack the individual results along a new, labelled "test" dimension
        detail = xr.concat(test_array, pd.Index(test_names, name="test"))
        # only the (potentially lazy) computation itself is timed
        start = time.perf_counter()
        detail = dask.compute(detail)[0]
        dt = time.perf_counter() - start
        print(f"Tests took {dt:.2f} s to perform.")
        return detail

    @abstractmethod
    def tests(self, **kwargs):
        """Return the tests to be performed.

        Implementations of this method must pass `kwargs` to `qagmire.data.read_*` functions
        as necessary to obtain the data for the tests.

        This method must return a list of dictionaries with the structure:
        ```
        [
            {
                "name": "a_short_name",
                "description": "The question that the test answers",
                "test": test_dataset,
            },
            ...
        ]
        ```
        where each `test_dataset` should be a boolean `xr.DataArray` of the same shape, giving
        the results of running the test on the data defined by `kwargs`.
        """
        # placeholder illustrating the required structure; subclasses must override
        return [
            {
                "name": "a_short_name",
                "description": "The question that the test answers",
                "test": None,
            },
        ]

    @staticmethod
    def summary(
        detail: xr.DataArray,  # the detailed test results
        by: str | None = None,  # if given, sum over all other element dimensions
        show_passed_tests=False,  # if `True`, then passed tests are included
        show_passed_elements=False,  # if `True`, then passed elements are included
        sort_by_total_fails=True,  # if `False`, then keep in original order
        show_failure_count=True,  # if `False`, then omit the count of failures per row
        show_only_failure_count=False,  # if `True`, then only show the count of failures
        per_test=False,  # if `True`, then transpose output, such that each row is a test
        top: int | None = None,  # optionally limit to at most `top` elements
    ) -> pd.DataFrame:
        """Return a summary of the test failures in `detail`."""
        if by is not None:
            # collapse every dimension except "test" and `by` into a count of failures
            detail = detail.sum(dim=[d for d in detail.dims if d not in ("test", by)])
        df = detail.to_dataframe(name="failed")
        # rows are elements and columns are tests, or the reverse if `per_test`
        df = df.unstack() if per_test else df.unstack("test")
        # drop columns without any failures
        if (not show_passed_tests and not per_test) or (
            not show_passed_elements and per_test
        ):
            df = df.loc[:, df.any(axis="rows")]
        # drop rows without any failures
        if (not show_passed_elements and not per_test) or (
            not show_passed_tests and per_test
        ):
            df = df.loc[df.loc[:, "failed"].any(axis="columns")]
        df.loc[:, "total fails"] = df.sum(axis="columns")
        if sort_by_total_fails:
            df = df.sort_values("total fails", ascending=False)
        if not (show_failure_count or show_only_failure_count):
            df = df.drop(columns="total fails")
        if show_only_failure_count:
            df = df.drop(columns="failed")
        if top is not None:
            df = df.iloc[:top]
        return df

    @classmethod
    def summary_per_test(
        cls,
        detail: xr.DataArray,  # the detailed test results
        by: str | None = None,  # if given, sum over all other element dimensions
    ) -> pd.DataFrame:
        """Return a per-test summary of the test outcomes in `detail`.

        Each row is a test, including those that passed, and only the total count
        of failures is shown.
        """
        return cls.summary(
            detail,
            by=by,
            per_test=True,
            show_passed_tests=True,
            show_only_failure_count=True,
        )

    @classmethod
    def full_summary(
        cls,
        detail: xr.DataArray,  # the detailed test results
        by: str | None = None,  # if given, sum over all other element dimensions
    ) -> pd.DataFrame:
        """Return a full summary of the test outcomes in `detail`.

        All tests and elements are included, whether passed or failed, in their
        original order and without the count of failures.
        """
        return cls.summary(
            detail,
            by=by,
            show_passed_tests=True,
            show_passed_elements=True,
            sort_by_total_fails=False,
            show_failure_count=False,
            top=None,
        )

In this subclass we need to implement the `tests` method.

In [None]:
# |hide
show_doc(Diagnostics.tests)

---

[source](https://github.com/bamford/qagmire/blob/main/qagmire/quality_assurance.py#L54){target="_blank" style="float:right; font-size:smaller"}

### Diagnostics.tests

>      Diagnostics.tests (**kwargs)

Return the tests to be performed.

Implementations of this method must pass `kwargs` to `qagmire.data.read_*` functions
as necessary to obtain the data for the tests.

This method must return a list of dictionaries with the structure:
```
[
    {
        "name": "a_short_name",
        "description": "The question that the test answers",
        "test": test_dataset,
    },
    ...
]
```
where each `test_dataset` should be a boolean `xr.DataArray` of the same shape, giving
the results of running the test on the data defined by `kwargs`.

These tests are executed by calling the `run` method.

In [None]:
# |hide
show_doc(Diagnostics.run)

---

[source](https://github.com/bamford/qagmire/blob/main/qagmire/quality_assurance.py#L35){target="_blank" style="float:right; font-size:smaller"}

### Diagnostics.run

>      Diagnostics.run (**kwargs)

Compute the results of the tests.

The `kwargs` are passed to `qagmire.data.read_*` functions to obtain the data
for the tests.

The `summary` method outputs a pandas DataFrame summarising the test outcomes. By default, this shows only the failed tests and elements (e.g. OBs or exposures), sorted so that those with the most failures come first.

In [None]:
# |hide
show_doc(Diagnostics.summary)

---

[source](https://github.com/bamford/qagmire/blob/main/qagmire/quality_assurance.py#L83){target="_blank" style="float:right; font-size:smaller"}

### Diagnostics.summary

>      Diagnostics.summary (detail:xarray.core.dataarray.DataArray,
>                           by:str|None=None, show_passed_tests=False,
>                           show_passed_elements=False,
>                           sort_by_total_fails=True, show_failure_count=True,
>                           show_only_failure_count=False, per_test=False,
>                           top:int|None=None)

Return a summary of the test failures in `detail`.

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| detail | DataArray |  | the detailed test results |
| by | str \| None | None | optionally sum over all element dimensions except for this one |
| show_passed_tests | bool | False | if `True`, then passed tests are included |
| show_passed_elements | bool | False | if `True`, then passed elements are included |
| sort_by_total_fails | bool | True | if `False`, then keep in original order |
| show_failure_count | bool | True | if `False`, then omit the count of failures per row |
| show_only_failure_count | bool | False | if `True`, then only show the count of failures |
| per_test | bool | False | if `True`, then transpose output, such that each row is a test |
| top | int \| None | None | optionally limit to at most `top` elements |
| **Returns** | **DataFrame** |  |  |

See the [diagnostics](diagnostics) submodule for example tests.

## To move...

To speed up running calculations on large datasets, we can run a set of workers on a single node. There are also ways to easily [leverage multiple nodes](https://docs.dask.org/en/stable/deploying.html).

In [None]:
# from dask.distributed import Client
# client = Client(n_workers=8, threads_per_worker=1)

In [None]:
# |export


class ObsCondCheck(Diagnostics):
    """Observing conditions check.

    A reproduction of the weaveio [obs_cond_check](https://github.com/bamford/QAG/blob/master/diagnostics/obs_cond_checks.py).

    This tests for the following cases:

    * Is the sky brighter than the requirement?
    * Is the seeing worse than the requirement?

    and also some supplementary tests, applied to the runs grouped either per OB
    (the default, with six runs expected) or per exposure (runs with the same MJD,
    with two runs expected):

    * Are there other than the expected number of runs in the group?
    * Do runs in the same group have different sky brightness?
    * Do runs in the same group have different seeing?

    In every test, `True` indicates a problem.
    """

    def __init__(
        self,
        sky_tolerance: float = 0.0,  # the tolerance in the sky brightness in magnitudes
        seeing_tolerance: float = 0.0,  # the tolerance in the seeing in arcsec
        by_exposure=False,  # should the checks be performed per exposure, or per OB (the default)
    ):
        self.sky_tolerance = sky_tolerance
        self.seeing_tolerance = seeing_tolerance
        # remembered so that the test descriptions can match the grouping used
        self.by_exposure = by_exposure
        if by_exposure:
            self._get_and_check = self._get_and_check_by_exp
        else:
            self._get_and_check = self._get_and_check_by_ob

    @staticmethod
    def _restore_coords(coords, da):
        """Assign `coords` to each of the arrays in `da`, returning them as a list."""
        return [d.assign_coords(coords) for d in da]

    @classmethod
    def _get_and_check_by_exp(cls, col):
        """Group the header column `col` by MJD.

        Returns the first value for each MJD, whether each MJD has exactly two
        runs, and whether the first and last values for each MJD are equal.
        """
        # the coordinates of `col`, reduced to one entry per MJD, to reattach below
        coords = (
            col.swap_dims(filename="MJD")
            .coords.to_dataset()
            .reset_coords()
            .groupby("MJD")
            .first()
        )
        by_exp = col.groupby("MJD")
        count, first, last = cls._restore_coords(
            coords, (by_exp.count(), by_exp.first(), by_exp.last())
        )
        expected_runs = count == 2
        runs_match = first == last
        return first, expected_runs, runs_match

    @staticmethod
    def _get_and_check_by_ob(col):
        """Group the header column `col` by OBID.

        Returns the first value for each OB, whether each OB has exactly six
        runs, and whether every run in each OB has the same value as the first.
        """
        by_ob = col.groupby("OBID")
        count, first = (by_ob.count(), by_ob.first())
        expected_runs = count == 6
        # Compare each run only with the first value of its own OB. Comparing `first`
        # directly with `col` would broadcast every OB against every run in the data.
        reference = first.sel(OBID=col["OBID"])
        # a value that is null in both is treated as matching
        same_as_first = (col == reference) | (col.isnull() & reference.isnull())
        runs_match = same_as_first.groupby(col["OBID"]).all()
        return first, expected_runs, runs_match

    def tests(
        self,
        **kwargs,
    ):
        """Return the observing conditions tests.

        The `kwargs` are passed to `get_lr_l1_single_files` to select the files whose
        primary headers are checked. Each test is `True` where there is a problem.
        """
        files = get_lr_l1_single_files(**kwargs)
        hdr = read_primary_header(files)

        obstemp, expected_runs, _ = self._get_and_check(hdr["OBSTEMP"])
        obs = parse_obstemp(obstemp)

        sky, _, sky_runs_match = self._get_and_check(hdr["SKYBRTEL"])
        # Sky brightness is taken to be in magnitudes (as for `sky_tolerance`),
        # so a smaller value is a brighter sky.
        sky_too_bright = sky < obs["sky_brightness"] - self.sky_tolerance
        seeing, _, seeing_runs_match = self._get_and_check(hdr["SEEINGB"])
        seeing_too_poor = seeing > obs["seeing"] + self.seeing_tolerance

        # word the supplementary test descriptions to match the grouping in use
        if self.by_exposure:
            n_runs, group = "two", "with the same MJD"
        else:
            n_runs, group = "six", "in the same OB"

        tests = [
            {
                "name": "sky_too_bright",
                "description": "Is the sky brighter than the requirement?",
                "test": sky_too_bright,
            },
            {
                "name": "seeing_too_poor",
                "description": "Is the seeing worse than the requirement?",
                "test": seeing_too_poor,
            },
            {
                "name": "wrong_run_count",
                "description": f"Are there other than {n_runs} runs {group}?",
                "test": ~expected_runs,
            },
            {
                "name": "unmatched_runs_sky",
                "description": f"Do runs {group} have different sky brightness?",
                "test": ~sky_runs_match,
            },
            {
                "name": "unmatched_runs_seeing",
                "description": f"Do runs {group} have different seeing?",
                "test": ~seeing_runs_match,
            },
        ]
        return tests

In [None]:
detail = ObsCondCheck().run(date="201*")

Reading files: 100%|██████████| 126/126 [00:07<00:00, 16.57it/s]
Creating Dataset... took 2.31 s. Size is 0.799 Mb
Tests took 0.00 s to perform.


In [None]:
ObsCondCheck.summary_per_test(detail)

Unnamed: 0_level_0,total fails
OBID,Unnamed: 1_level_1
test,Unnamed: 1_level_2
seeing_too_poor,21
sky_too_bright,19
wrong_run_count,0
unmatched_runs_sky,0
unmatched_runs_seeing,0


In [None]:
ObsCondCheck.summary(detail)

Unnamed: 0_level_0,failed,failed,total fails
test,sky_too_bright,seeing_too_poor,Unnamed: 3_level_1
OBID,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
3133,True,True,2
3170,True,True,2
3175,True,True,2
3189,True,True,2
3191,True,True,2
3295,True,True,2
3346,True,True,2
3372,True,True,2
3380,True,True,2
3756,True,True,2


In [None]:
ObsCondCheck.full_summary(detail)

Unnamed: 0_level_0,failed,failed,failed,failed,failed
test,sky_too_bright,seeing_too_poor,wrong_run_count,unmatched_runs_sky,unmatched_runs_seeing
OBID,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
3133,True,True,False,False,False
3170,True,True,False,False,False
3175,True,True,False,False,False
3189,True,True,False,False,False
3191,True,True,False,False,False
3217,False,True,False,False,False
3295,True,True,False,False,False
3346,True,True,False,False,False
3372,True,True,False,False,False
3380,True,True,False,False,False


In [None]:
# |export


class LineFluxCheck(Diagnostics):
    """A reproduction of the weaveio [line_flux_check](https://github.com/bamford/QAG/blob/master/diagnostics/line_flux_check.py).

    This tests for the following cases:

    * Do non-null line fluxes appear in completely null spectra?
    * Do non-null line fluxes appear in the blue chip gap?
    * Do non-null line fluxes appear in the red chip gap?
    * Do non-null line fluxes appear outside the observed wavelength range?
    * Do null line fluxes appear in an observed wavelength range?

    In every test, `True` indicates a problem.
    """

    @staticmethod
    def _line_wavelengths(
        galaxy_table: xr.Dataset,  # provides the wavelengths of all lines in the data
        class_table: xr.Dataset,  # provides the redshift of each spectrum
    ) -> xr.DataArray:  # the observed wavelength of every potential line
        """Determine the expected observed wavelengths of all potential lines."""
        # only the rest wavelengths are needed here, not the species
        _, line_rest_wl = parse_line_names(galaxy_table["LINE"])
        line_wl = (1 + class_table["Z"]) * line_rest_wl
        return line_wl

    @staticmethod
    def _wavelength_boundaries(
        class_spec: xr.Dataset,  # provides the rebinned spectra to check
    ) -> tuple[dict, dict]:  # the determined boundaries
        """Determine wavelength boundaries and wavelength gaps of blue and red spectra.

        Where a spectrum is entirely null, the returned gaps and boundaries will also be null.

        Returns two dictionaries, `boundaries` and `gaps`, keyed by band (`B` and `R`).
        Each entry is a dictionary containing `low` and `high` arrays, which give the
        low and high boundaries and gap edges determined for each spectrum.
        """
        gaps = {}
        boundaries = {}
        # a gap is only searched for between the `low` and `high` wavelengths of each band
        for band, low, high in (("B", 4000, 6000), ("R", 6000, 9000)):
            wl_dim = f"LAMBDA_{band}"
            wl = class_spec[wl_dim]
            null_flux = class_spec[f"FLUX_RR_{band}"].isnull()
            # wavelengths of null pixels within the search range, and of all valid pixels
            wl_null = wl.where(null_flux & (wl > low) & (wl < high))
            wl_not_null = wl.where(~null_flux)
            with np.errstate(invalid="ignore"):
                gaps[band] = {
                    "low": wl_null.min(dim=wl_dim),
                    "high": wl_null.max(dim=wl_dim),
                }
                boundaries[band] = {
                    "low": wl_not_null.min(dim=wl_dim),
                    "high": wl_not_null.max(dim=wl_dim),
                }
        return boundaries, gaps

    def tests(self, **kwargs):
        """Return the line flux tests.

        The `kwargs` are passed to `get_lr_l2_stack_files` to select the files to check.
        Each test is `True` where there is a problem.
        """
        lr_l2_stack_files = get_lr_l2_stack_files(**kwargs)

        class_spec = read_class_spec(lr_l2_stack_files)
        galaxy_table = read_galaxy_table(lr_l2_stack_files)
        class_table = read_class_table(lr_l2_stack_files)

        line_wl = self._line_wavelengths(galaxy_table, class_table)
        boundaries, gaps = self._wavelength_boundaries(class_spec)

        measured_line_flux = galaxy_table["LINES"].sel(QTY="FLUX", drop=True)
        null_flux = measured_line_flux.isnull()

        # The gap edges and boundaries are null for completely null spectra. Comparisons
        # with null values are False, so lines in such spectra are never counted as in a
        # gap or off the spectrum.
        is_in_red_gap = (line_wl > gaps["R"]["low"]) & (line_wl < gaps["R"]["high"])
        is_in_blue_gap = (line_wl > gaps["B"]["low"]) & (line_wl < gaps["B"]["high"])

        is_in_gap = is_in_blue_gap | is_in_red_gap

        # off the spectrum means outside the observed range of both the blue and red arms
        is_off_spectrum = (
            (line_wl < boundaries["B"]["low"]) | (line_wl > boundaries["B"]["high"])
        ) & ((line_wl < boundaries["R"]["low"]) | (line_wl > boundaries["R"]["high"]))

        # NOTE(review): a spectrum is treated as null if *either* arm is entirely null;
        # confirm whether "completely null" should require both arms to be null.
        null_spectrum = (
            boundaries["B"]["low"].isnull() | boundaries["R"]["low"].isnull()
        )

        # A line can only be asserted to be on the spectrum if its wavelength is known
        # and the spectrum has valid boundaries. Without these explicit conditions the
        # negations below would be True for null spectra and unknown wavelengths.
        is_on_spectrum = (
            ~is_in_gap & ~is_off_spectrum & ~null_spectrum & line_wl.notnull()
        )

        tests = [
            {
                "name": "line_in_null_spectrum",
                "description": "Do non-null line fluxes appear in completely null spectra?",
                "test": ~null_flux & null_spectrum,
            },
            {
                "name": "line_in_blue_chip_gap",
                "description": "Do non-null line fluxes appear in the blue chip gap?",
                "test": ~null_flux & is_in_blue_gap,
            },
            {
                "name": "line_in_red_chip_gap",
                "description": "Do non-null line fluxes appear in the red chip gap?",
                "test": ~null_flux & is_in_red_gap,
            },
            {
                "name": "line_off_spectrum",
                "description": "Do non-null line fluxes appear outside the observed wavelength range?",
                "test": ~null_flux & is_off_spectrum,
            },
            {
                "name": "null_line_on_spectrum",
                "description": "Do null line fluxes appear in an observed wavelength range?",
                "test": null_flux & is_on_spectrum,
            },
        ]
        return tests

In [None]:
detail = LineFluxCheck().run(date="201*")

Locating and converting where necessary: 100%|██████████| 17/17 [00:00<00:00, 10015.90it/s]
Reading netCDF files... took 1.55 s. Size is 4851.652 Mb
Locating and converting where necessary: 100%|██████████| 17/17 [00:00<00:00, 4585.71it/s]
Reading netCDF files... took 2.91 s. Size is 77.962 Mb
Locating and converting where necessary: 100%|██████████| 17/17 [00:00<00:00, 5842.13it/s]
Reading netCDF files... took 3.69 s. Size is 509.241 Mb
Tests took 6.52 s to perform.


In [None]:
# Index the results by OBID instead of filename, so that they can be summarised with `by="OBID"`.
detail = detail.swap_dims({"filename": "OBID"}).drop_vars("filename")

In [None]:
LineFluxCheck.summary(detail, by="OBID", top=None, show_passed_tests=True)

Unnamed: 0_level_0,failed,failed,failed,failed,failed,total fails
test,line_in_null_spectrum,line_in_blue_chip_gap,line_in_red_chip_gap,line_off_spectrum,null_line_on_spectrum,Unnamed: 6_level_1
OBID,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
3900,0,162,226,2646,2,3036
3372,0,144,178,2695,4,3021
3756,0,161,194,2651,1,3007
3653,0,165,196,2558,1,2920
3295,0,141,200,2541,0,2882
3803,0,164,195,2519,3,2881
3806,0,165,202,2497,4,2868
3802,0,158,185,2508,5,2856
3217,0,37,87,1781,1,1906
3346,0,22,94,1684,1,1801


In [None]:
LineFluxCheck.summary(detail, by="LINE")

Unnamed: 0_level_0,failed,failed,failed,failed,total fails
test,line_in_blue_chip_gap,line_in_red_chip_gap,line_off_spectrum,null_line_on_spectrum,Unnamed: 5_level_1
LINE,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
[ArIII]_7135.67,0,55,4343,0,4398
[SII2]_6730.68,0,53,3583,0,3636
[SII]_6716.31,0,51,3576,0,3627
[NII]_6583.34,0,47,3361,0,3408
Ha_6562.80,0,52,3313,0,3365
[OI]_6300.20,0,73,2776,0,2849
HeI_5875.60,1,136,1837,1,1975
HeII_3203.15,160,0,1159,2,1321
[NeV]_3345.81,135,0,969,0,1104
[NeV]_3425.81,115,0,871,1,987


In [None]:
LineFluxCheck.summary(detail, by="APS_ID", top=10)

Unnamed: 0_level_0,failed,failed,failed,failed,total fails
test,line_in_blue_chip_gap,line_in_red_chip_gap,line_off_spectrum,null_line_on_spectrum,Unnamed: 5_level_1
APS_ID,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
456,0,5,79,0,84
766,1,7,74,0,82
728,3,3,74,0,80
989,2,8,69,0,79
746,1,1,72,0,74
615,5,15,53,0,73
40,1,5,67,0,73
62,2,6,64,0,72
308,4,3,65,0,72
273,2,3,67,0,72


In [None]:
# |hide
import nbdev

# Run nbdev's export step, which writes the `# |export` cells out to the library modules.
nbdev.nbdev_export()