In [None]:
#| default_exp callbacks

# Callbacks

> Callbacks used in handlers

In [None]:
#| export
import copy
import fastcore.all as fc
from operator import attrgetter
from cftime import date2num
import numpy as np
import pandas as pd
from functools import partial 
from pathlib import Path 
from typing import List, Dict, Callable, Tuple, Any, Optional, Union

from marisco.configs import (
    get_lut, 
    nuc_lut_path, 
    nc_tpl_path,
    get_time_units,
    NC_GROUPS,
    SMP_TYPE_LUT,
    cfg, 
    # cdl_cfg
)

from marisco.utils import Match

In [None]:
#| hide
# from marisco.configs import cdl_cfg, CONFIGS_CDL
from marisco.utils import test_dfs

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Core

The `Transformer` class is designed to facilitate the application of a series of callbacks to a set of dataframes. It provides a structured way to apply transformations (i.e., `Callback`s) to the data, with a focus on flexibility and ease of use.

In [None]:
#| exports
class Callback:
    "Base class for callbacks."
    # Default rank of a callback: `run_cbs` sorts callbacks by `order` (ascending) before calling them.
    order = 0

In [None]:
#| exports
def run_cbs(
    cbs: List[Callback], # List of callbacks to run
    obj: Any # Object to pass to the callbacks
    ):
    "Call every callback on `obj` by ascending `order`, logging each callback's docstring first."
    # Work on a sorted copy so the caller's list is left untouched (stable sort keeps ties in given order)
    ordered = list(cbs)
    ordered.sort(key=lambda callback: callback.order)
    for callback in ordered:
        doc = callback.__doc__
        # Callbacks without a docstring are run silently
        if doc:
            obj.logs.append(doc)
        callback(obj)

In [None]:
#| exports
class Transformer():
    "Transform the dataframe(s) according to the specified callbacks."
    def __init__(self, 
                 data: Union[Dict[str, pd.DataFrame], pd.DataFrame], # Data to be transformed
                 cbs: Optional[List[Callback]]=None, # List of callbacks to run
                 inplace: bool=False # Whether to modify the dataframe(s) in place
                 ): 
        fc.store_attr()
        # Filled by `run_cbs` with the docstring of every callback that runs
        self.logs = []
        self.is_single_df = isinstance(data, pd.DataFrame)
        # Exactly one of `df` / `dfs` is populated, the other one is `None`
        self.df, self.dfs = self._prepare_data(data, inplace)

    def _prepare_data(self, data, inplace):
        "Return a `(df, dfs)` pair holding `data` itself when `inplace`, a copy of it otherwise."
        if not self.is_single_df:
            dfs = data if inplace else {grp: df.copy() for grp, df in data.items()}
            return None, dfs
        df = data if inplace else data.copy()
        return df, None

    def unique(self, col_name: str) -> np.ndarray:
        "Sorted distinct non-NA values of `col_name`, pooled over every dataframe that has this column."
        if self.is_single_df:
            column = self.df.get(col_name, pd.Series())
            return np.unique(column.dropna().values)
        # Groups lacking the column are simply skipped
        candidates = (df.get(col_name) for df in self.dfs.values())
        chunks = [col.dropna().values for col in candidates if col is not None]
        return np.unique(np.concatenate(chunks) if chunks else [])

    def __call__(self):
        "Run the callbacks (if any) and return the transformed dataframe(s)."
        if self.cbs:
            run_cbs(self.cbs, self)
        if self.dfs is None:
            return self.df
        return self.dfs

Below are a few examples of how to use the `Transformer` class.
Let's first define a test callback that adds `1` to the `depth`:

In [None]:
class TestCB(Callback):
    "A test callback to add 1 to the depth."
    def __call__(self, tfm: Transformer):
        # Dictionary variant: bump the depth of every group held in `tfm.dfs`
        for df in tfm.dfs.values():
            df['depth'] = df['depth'].map(lambda depth: depth + 1)

And apply it to the following dataframes:

In [None]:
dfs = {
    'biota': pd.DataFrame({'id': [0, 1, 2], 'species': [0, 2, 0], 'depth': [2, 3, 4]}),
    'seawater': pd.DataFrame({'id': [0, 1, 2], 'depth': [3, 4, 5]}),
}

tfm = Transformer(dfs, cbs=[TestCB()])
dfs_test = tfm()

# Every depth must have been incremented by one in each group
for grp, expected_depths in {'biota': [3, 4, 5], 'seawater': [4, 5, 6]}.items():
    fc.test_eq(dfs_test[grp]['depth'].to_list(), expected_depths)

In [None]:
class TestCB(Callback):
    "A test callback to add 1 to the depth."
    def __call__(self, tfm: Transformer):
        # Single-dataframe variant: the data lives in `tfm.df` instead of `tfm.dfs`
        tfm.df['depth'] = tfm.df['depth'].map(lambda depth: depth + 1)

In [None]:
df = pd.DataFrame(dict(id=[0, 1, 2], species=[0, 2, 0], depth=[2, 3, 4]))

# A single dataframe goes in, a single (transformed) dataframe comes out
tfm = Transformer(df, cbs=[TestCB()])
df_test = tfm()

fc.test_eq(df_test['depth'].to_list(), [3, 4, 5])

## Geographical

This section gathers callbacks that are used to transform the geographical coordinates.

In [None]:
#| exports
class SanitizeLonLatCB(Callback):
    "Drop rows with invalid longitude & latitude values. Convert `,` separator to `.` separator."
    def __init__(self, 
                 lon_col: str='LON', # Longitude column name
                 lat_col: str='LAT', # Latitude column name
                 verbose: bool=False # Whether to print the number of invalid longitude & latitude values
                 ):
        fc.store_attr()

    @staticmethod
    def _to_float(value) -> float:
        "Parse a coordinate as a float, accepting `,` as decimal separator."
        return float(str(value).replace(',', '.'))

    def __call__(self, tfm: Transformer):
        "Convert coordinates to floats, then drop (0, 0) and out-of-bounds rows in every group."
        for grp, df in tfm.dfs.items():
            # Convert `,` separator to `.` separator
            for col in (self.lon_col, self.lat_col):
                df[col] = df[col].apply(self._to_float)
            lon, lat = df[self.lon_col], df[self.lat_col]

            # Mask zero values
            mask_zeroes = (lon == 0) & (lat == 0)
            nZeroes = mask_zeroes.sum()
            if nZeroes and self.verbose: 
                print(f'The "{grp}" group contains {nZeroes} data points whose ({self.lon_col}, {self.lat_col}) = (0, 0)')

            # Mask out of bounds values
            mask_goob = (lon < -180) | (lon > 180) | (lat < -90) | (lat > 90)
            nGoob = mask_goob.sum()
            if nGoob and self.verbose: 
                print(f'The "{grp}" group contains {nGoob} data points with unrealistic {self.lon_col} or {self.lat_col} values.')

            # `.copy()` detaches the filtered frame from `df`, so that callbacks assigning
            # columns afterwards do not trigger pandas' SettingWithCopyWarning.
            tfm.dfs[grp] = df.loc[~(mask_zeroes | mask_goob)].copy()

In [None]:
# Check that measurements located at (0,0) get removed
dfs = {'BIOTA': pd.DataFrame({'LON': [0, 1, 0], 'LAT': [0, 2, 0]})}
tfm = Transformer(dfs, cbs=[SanitizeLonLatCB()])
# Run the transformer once and keep the result (calling `tfm()` again would re-run the callbacks)
df_biota = tfm()['BIOTA']

expected = [1., 2.]
# Only the (1, 2) row must survive: both (0, 0) rows are dropped
fc.test_eq(len(df_biota), 1)
fc.test_eq(df_biota.iloc[0].to_list(), expected)

In [None]:
# Check that a comma decimal separator gets replaced by a point
dfs = {'BIOTA': pd.DataFrame({'LON': ['45,2'], 'LAT': ['43,1']})}
tfm = Transformer(dfs, cbs=[SanitizeLonLatCB()])
# Run the transformer once and keep the result (calling `tfm()` again would re-run the callbacks)
df_biota = tfm()['BIOTA']

expected = [45.2, 43.1]
# The single row is valid and must be kept, with coordinates parsed as floats
fc.test_eq(len(df_biota), 1)
fc.test_eq(df_biota.iloc[0].to_list(), expected)

In [None]:
# Check that out of bounds lon or lat get removed
dfs = {'BIOTA': pd.DataFrame({'LON': [-190, 190, 1, 2, 1.1], 'LAT': [1, 2, 91, -91, 2.2]})}
tfm = Transformer(dfs, cbs=[SanitizeLonLatCB()])
# Run the transformer once and keep the result (calling `tfm()` again would re-run the callbacks)
df_biota = tfm()['BIOTA']

expected = [1.1, 2.2]
# The first four rows each have one coordinate out of range; only the last row must survive
fc.test_eq(len(df_biota), 1)
fc.test_eq(df_biota.iloc[0].to_list(), expected)

## Map & Standardize

In [None]:
#| exports
class RemapCB(Callback):
    "Generic MARIS remapping callback."
    def __init__(self, 
                 fn_lut: Callable, # Function that returns the lookup table dictionary
                 col_remap: str, # Name of the column to remap
                 col_src: str, # Name of the column with the source values
                 dest_grps: list[str]|str=NC_GROUPS.keys(), # List of destination groups
                 default_value: Any = -1 # Default value for unmatched entries
                ):
        fc.store_attr()
        # The lookup table is only built when the callback is called
        self.lut = None

        # Normalise to a list: a single group may be given as a string, and the default
        # (`NC_GROUPS.keys()`) is a dict view that would otherwise be rendered as
        # "dict_keys([...])" in the documentation string below.
        self.dest_grps = [dest_grps] if isinstance(dest_grps, str) else list(dest_grps)

        # Human-readable enumeration of the groups: "A", "A and B", "A, B and C"
        grps = [str(grp) for grp in self.dest_grps]
        if len(grps) > 1:
            grp_str = ', '.join(grps[:-1]) + ' and ' + grps[-1]
        else:
            grp_str = grps[0] if grps else ''

        # Instance-level docstring: this is what gets appended to the transformer logs
        self.__doc__ = f"Remap values from '{col_src}' to '{col_remap}' for groups: {grp_str}."

    def __call__(self, tfm):
        "Build the lookup table and remap `col_src` into `col_remap` for every destination group."
        self.lut = self.fn_lut()
        for grp in self.dest_grps:
            if grp in tfm.dfs:
                self._remap_group(tfm.dfs[grp])
            else:
                print(f"Group {grp} not found in the dataframes.")

    def _remap_group(self, df: pd.DataFrame):
        "Write the remapped values of `col_src` into `col_remap` (the dataframe is modified in place)."
        df[self.col_remap] = df[self.col_src].apply(self._remap_value)

    def _remap_value(self, value: str) -> Any:
        "Look `value` up in the lookup table, falling back to `default_value` when absent."
        # String keys are matched ignoring leading/trailing whitespace
        value = value.strip() if isinstance(value, str) else value
        match = self.lut.get(value, Match(self.default_value, None, None, None))
        # Lookup tables hold either `Match` objects (the id is in `matched_id`) or plain values
        if isinstance(match, Match):
            if match.matched_id == self.default_value:
                print(f"Unmatched value: {value}")
            return match.matched_id 
        else:
            return match

In [None]:
#| exports
class LowerStripNameCB(Callback):
    "Convert values to lowercase and strip any trailing spaces."
    def __init__(self, 
                 col_src: str, # Source column name e.g. 'Nuclide'
                 col_dst: str=None, # Destination column name
                 fn_transform: Callable=lambda x: x.lower().strip() # Transformation function
                 ):
        fc.store_attr()
        # Without an explicit destination, the source column is overwritten
        if not col_dst: self.col_dst = col_src
        # Build the logged docstring only once the destination is resolved, so that it never
        # reports "store in 'None' column" when `col_dst` is omitted.
        self.__doc__ = f"Convert '{col_src}' column values to lowercase, strip spaces, and store in '{self.col_dst}' column."
        
    def _safe_transform(self, value):
        "Return NA values untouched; otherwise apply `fn_transform` to the value cast as a string."
        return value if pd.isna(value) else self.fn_transform(str(value))
            
    def __call__(self, tfm):
        "Apply the transformation to `col_src` of every group and store the result in `col_dst`."
        for key in tfm.dfs.keys():
            tfm.dfs[key][self.col_dst] = tfm.dfs[key][self.col_src].apply(self._safe_transform)

Let's test the callback:

In [None]:
dfs = {'seawater': pd.DataFrame({'Nuclide': ['CS137', '226RA']})}
expected = ['cs137', '226ra']

# Explicit destination column: the result is stored in 'NUCLIDE'
tfm = Transformer(dfs, cbs=[LowerStripNameCB(col_src='Nuclide', col_dst='NUCLIDE')])
fc.test_eq(tfm()['seawater']['NUCLIDE'].to_list(), expected)

# No destination column: the source column 'Nuclide' is overwritten
tfm = Transformer(dfs, cbs=[LowerStripNameCB(col_src='Nuclide')])
fc.test_eq(tfm()['seawater']['Nuclide'].to_list(), expected)

When remapping names (semi-automatically), the general workflow is:

1. First, guess the right nuclide name (using fuzzy matching or another method).
2. Then, manually check the result and, if needed, update the lookup table.
3. Finally, apply the lookup table to the dataframe.


## Change structure

In [None]:
#| exports
class AddSampleTypeIdColumnCB(Callback):
    # The docstring lives on the class (not on `__init__`) so that `cb.__doc__` is set
    # and the callback gets recorded in the transformer logs like the other callbacks.
    "Add a column with the sample type id as defined in the CDL."
    def __init__(self, 
                 lut: dict=SMP_TYPE_LUT, # Lookup table for sample type
                 col_name: str='samptype_id' # Column name to store the sample type id
                 ): 
        fc.store_attr()
        
    def __call__(self, tfm):
        "Fill `col_name` of every group with the id found in `lut` under the group name."
        for grp, df in tfm.dfs.items():             
            # Scalar assignment: the same id is broadcast to all rows of the group
            df[self.col_name] = self.lut[grp]

Let's test the callback:

In [None]:
dfs = {smp_type: pd.DataFrame({'col_test': [0, 1, 2]}) for smp_type in SMP_TYPE_LUT}

tfm = Transformer(dfs, cbs=[AddSampleTypeIdColumnCB()])
dfs_test = tfm()

# Each group must carry a single sample type id: the one registered in the lookup table
for smp_type, smp_type_id in SMP_TYPE_LUT.items():
    fc.test_eq(dfs_test[smp_type]['samptype_id'].unique().item(), smp_type_id)

In [None]:
#| exports
class AddNuclideIdColumnCB(Callback):
    # The docstring lives on the class (not on `__init__`) so that `cb.__doc__` is set
    # and the callback gets recorded in the transformer logs like the other callbacks.
    "Add a column with the nuclide id."
    def __init__(self, 
                 col_value: str, # Column name containing the nuclide name
                 lut_fname_fn: Callable=nuc_lut_path, # Function returning the lut path
                 col_name: str='nuclide_id' # Column name to store the nuclide id
                 ): 
        fc.store_attr()
        # Resolve the lookup table path once, then split it into directory and file name
        lut_path = lut_fname_fn()
        self.lut = get_lut(lut_path.parent, lut_path.name, 
                           key='nc_name', value='nuclide_id', reverse=False)
        
    def __call__(self, tfm: Transformer):
        "Map `col_value` through the lookup table into `col_name` for every group."
        for grp, df in tfm.dfs.items(): 
            # `Series.map` yields NaN for names missing from the lookup table
            df[self.col_name] = df[self.col_value].map(self.lut)

In [None]:
dfs = {smp_type: pd.DataFrame({'Nuclide': ['cs137', 'pu239_240_tot']}) for smp_type in SMP_TYPE_LUT}

# Point the callback to the lookup table shipped with the test files
def lut_fname_fn(): return Path('./files/lut/dbo_nuclide.xlsx')

tfm = Transformer(dfs, cbs=[AddNuclideIdColumnCB(col_value='Nuclide', lut_fname_fn=lut_fname_fn)])
# Run the callbacks; the results are inspected through `tfm.dfs` below
tfm()['SEAWATER']

expected = [33, 77]
for grp in tfm.dfs:
    fc.test_eq(tfm.dfs[grp]['nuclide_id'].to_list(), expected)

In [None]:
#| exports
class SelectColumnsCB(Callback):
    "Select columns of interest."
    def __init__(self, 
                 cois: dict # Columns of interest
                 ): 
        fc.store_attr()
        
    def __call__(self, tfm):
        "Keep, in every group, only the columns listed as keys of `cois`."
        wanted = self.cois.keys()
        # Snapshot the group names, then swap each dataframe for its column subset
        for grp in list(tfm.dfs):
            tfm.dfs[grp] = tfm.dfs[grp].loc[:, wanted]

In [None]:
#| exports
class RenameColumnsCB(Callback):
    "Renaming variables to MARIS standard names."
    def __init__(self,
                 renaming_rules: dict # Renaming rules
                 ): 
        fc.store_attr()
        
    def __call__(self, tfm):
        "Rename the columns of every group according to `renaming_rules`."
        # The rename is done in place: the dataframe objects held by `tfm.dfs` are mutated
        for df in tfm.dfs.values():
            df.rename(columns=self.renaming_rules, inplace=True)

In [None]:
#| exports
class RemoveAllNAValuesCB(Callback):
    "Remove rows with all NA values."
    def __init__(self, 
                 cols_to_check: Dict[str, Union[str, List[str]]] # A dictionary with the sample type as key and the column name(s) to check as value
                ):
        fc.store_attr()

    def __call__(self, tfm):
        "Drop, in every group, the rows whose checked columns are all NA."
        for k in tfm.dfs.keys():
            cols = self.cols_to_check[k]
            # A single column name must be wrapped in a list: selecting with a bare string
            # returns a Series, on which `.all(axis=1)` raises.
            cols = [cols] if isinstance(cols, str) else list(cols)
            mask = tfm.dfs[k][cols].isnull().all(axis=1)
            # `.copy()` detaches the filtered frame so that later column assignments
            # do not trigger pandas' SettingWithCopyWarning.
            tfm.dfs[k] = tfm.dfs[k][~mask].copy()

[TO BE REMOVED] Many data providers use a long format (e.g `lat, lon, radionuclide, value, unc, ...`) to store their data. When encoding as `netCDF`, it is often required to use a wide format (e.g `lat, lon, nuclide1_value, nuclide1_unc, nuclide2_value, nuclide2_unc, ...`). The class `ReshapeLongToWide` is designed to perform this transformation.

In [None]:
# class ReshapeLongToWide(Callback):
#     "Convert data from long to wide with renamed columns."
#     def __init__(self, 
#                  columns: List[str]=['nuclide'], # Columns to use as index
#                  values: List[str]=['value'], # Columns to use as values
#                  num_fill_value: int=-999, # Fill value for numeric columns
#                  str_fill_value='STR FILL VALUE'
#                  ):
#         fc.store_attr()
#         self.derived_cols = self._get_derived_cols()
    
#     def _get_derived_cols(self):
#         "Retrieve all possible derived vars (e.g 'unc', 'dl', ...) from configs."
#         return [value['name'] for value in cdl_cfg()['vars']['suffixes'].values()]

#     def renamed_cols(self, cols):
#         "Flatten columns name."
#         return [inner if outer == "value" else f'{inner}{outer}' if inner else outer
#                 for outer, inner in cols]

#     def _get_unique_fill_value(self, df, idx):
#         "Get a unique fill value for NaN replacement."
#         fill_value = self.num_fill_value
#         while (df[idx] == fill_value).any().any():
#             fill_value -= 1
#         return fill_value

#     def _fill_nan_values(self, df, idx):
#         "Fill NaN values in index columns."
#         num_fill_value = self._get_unique_fill_value(df, idx)
#         for col in idx:
#             fill_value = num_fill_value if pd.api.types.is_numeric_dtype(df[col]) else self.str_fill_value
#             df[col] = df[col].fillna(fill_value)
#         return df, num_fill_value

#     def pivot(self, df):
#         derived_coi = [col for col in self.derived_cols if col in df.columns]
#         # In past implementation we added an index column before pivoting 
#         # TO BE REMOVED
#         # making all rows (compound_idx) unique.
#         # df.index.name = 'org_index'
#         # df = df.reset_index()
#         idx = list(set(df.columns) - set(self.columns + derived_coi + self.values))
        
#         df, num_fill_value = self._fill_nan_values(df, idx)

#         pivot_df = df.pivot_table(index=idx,
#                                   columns=self.columns,
#                                   values=self.values + derived_coi,
                                  
                                  
#                                   aggfunc=lambda x: x
#                                   ).reset_index()

#         pivot_df[idx] = pivot_df[idx].replace({self.str_fill_value: np.nan, num_fill_value: np.nan})
#         pivot_df = self.set_index(pivot_df)
#         return pivot_df

#         def set_index(self, df):
#             "Set the index of the dataframe."
#             # TODO: Consider implementing a universal unique index
#             # by hashing the compound index columns (lat, lon, time, depth, etc.)
#             df.index.name = 'org_index'
#             return df
    
#     def __call__(self, tfm):
#         for grp in tfm.dfs.keys():
#             tfm.dfs[grp] = self.pivot(tfm.dfs[grp])
#             tfm.dfs[grp].columns = self.renamed_cols(tfm.dfs[grp].columns)

[TO BE DELETED] Example of usage:

* **Case 1**: `compound_idx` (in our case made of `lon`, `lat`, `time`, `depth`, ...) are unique

In [None]:
# #| eval: false
# dfs_test = {'seawater': pd.DataFrame({
#     'compound_idx': ['a', 'b', 'c', 'd'], 
#     'nuclide': ['cs137', 'cs137', 'pu239_240_tot', 'pu239_240_tot'], 
#     'value': [1, 2, 3, 4],
#     '_unc': [0.1, 0.2, 0.3, 0.4]})}

# tfm = Transformer(dfs_test, cbs=[ReshapeLongToWide()])
# tfm()['seawater']

* **Case 2**: compound_idx are not unique

In [None]:
# #| eval: false
# dfs_test = {'seawater': pd.DataFrame({
#     'compound_idx': ['a', 'a', 'c', 'd'], 
#     'nuclide': ['cs137', 'cs134', 'pu239_240_tot', 'pu239_240_tot'], 
#     'value': [1, 2, 3, 4],
#     '_unc': [0.1, 0.2, 0.3, 0.4]})}

# tfm = Transformer(dfs_test, cbs=[ReshapeLongToWide()])
# tfm()['seawater']

In [None]:
# #| eval: false
# dfs_test = fc.load_pickle('./files/pkl/dfs_reshape_test_in.pkl')
# dfs_expected = fc.load_pickle('./files/pkl/dfs_reshape_test_expected.pkl')

# tfm = Transformer(dfs_test, cbs=[ReshapeLongToWide()])
# dfs_output = tfm()
# test_dfs(dfs_output, dfs_expected)    

In [None]:
#| exports
class CompareDfsAndTfmCB(Callback):
    "Create a dataframe of dropped data. Data included in the `dfs` not in the `tfm`."
    def __init__(self, 
                 dfs: Dict[str, pd.DataFrame] # Original dataframes
                 ): 
        fc.store_attr()
        
    def __call__(self, tfm: Transformer) -> None:
        "Attach `dfs_dropped` and `compare_stats` to `tfm`, with one entry per group of `tfm.dfs`."
        self._initialize_tfm_attributes(tfm)
        for grp in tfm.dfs:
            # The dropped rows must be stored first: the stats read them back from `tfm.dfs_dropped`
            tfm.dfs_dropped[grp] = self._get_dropped_data(grp, tfm)
            tfm.compare_stats[grp] = self._compute_stats(grp, tfm)

    def _initialize_tfm_attributes(self, tfm: Transformer) -> None:
        "Start from empty result containers on `tfm`."
        tfm.dfs_dropped, tfm.compare_stats = {}, {}

    def _get_dropped_data(self, 
                          grp: str, # The group key
                          tfm: Transformer # The transformation object containing `dfs`
                         ) -> pd.DataFrame: # Dataframe with dropped rows
        "Rows of the original group whose index label is no longer present in `tfm.dfs`."
        original = self.dfs[grp]
        dropped_labels = original.index.difference(tfm.dfs[grp].index)
        return original.loc[dropped_labels]
    
    def _compute_stats(self, 
                       grp: str, # The group key
                       tfm: Transformer # The transformation object containing `dfs`
                      ) -> Dict[str, int]: # Dictionary with comparison statistics
        "Row counts of the original, transformed and dropped dataframes of a group."
        n_original = len(self.dfs[grp])
        n_transformed = len(tfm.dfs[grp])
        n_dropped = len(tfm.dfs_dropped[grp])
        return {
            'Number of rows in dfs': n_original,
            'Number of rows in tfm.dfs': n_transformed,
            'Number of rows removed': n_dropped,
        }

`CompareDfsAndTfmCB` compares the original dataframes to the transformed dataframe. A dictionary of dataframes, `tfm.dfs_dropped`, is created to include the data present in the original dataset but absent from the transformed data. `tfm.compare_stats` provides a quick overview of the number of rows in both the original dataframes and the transformed dataframe. 

For instance:

In [None]:
# dfs_test = {
#     'seawater': pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}),
#     'sediment': pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}),
# }

# class TestTfmCB(Callback):
#     def __call__(self, tfm):
#         for key in tfm.dfs.keys():
#             df = tfm.dfs[key]
#             drop_idxs = [0, 1] if key == 'seawater' else [0]
#             df.drop(drop_idxs, inplace=True)
            
# tfm = Transformer(dfs_test, cbs=[
#     TestTfmCB(), 
#     CompareDfsAndTfmCB(dfs_test)], inplace=False)

# print(tfm())

# fc.test_eq(tfm.compare_stats['seawater']['Number of rows removed'], 2)
# fc.test_eq(tfm.compare_stats['sediment']['Number of rows removed'], 1)

In [None]:
#| exports
class UniqueIndexCB(Callback):
    "Set unique index for each group."
    def __init__(self,
                 index_name='ID'):
        fc.store_attr()
        
    def __call__(self, tfm):
        "Renumber the rows of every group and expose that numbering as the `index_name` column."
        for grp, df in tfm.dfs.items():
            # Discard the current index in favour of a fresh default one
            tfm.dfs[grp] = df.reset_index(drop=True)
            # Move that fresh index into a column named `index_name`
            tfm.dfs[grp] = tfm.dfs[grp].reset_index(names=[self.index_name])

## Time

These callbacks are used to transform the time variable according to netCDF CF standards. For instance, the `EncodeTimeCB` callback is used to encode the time variable as an integer representing seconds since a reference date as specified in `maris.cdl`.

In [None]:
#| exports
class EncodeTimeCB(Callback):
    "Encode time as seconds since epoch."    
    def __init__(self, 
                 col_time: str='TIME',
                 fn_units: Callable=get_time_units # Function returning the time units
                 ): 
        fc.store_attr()
        # Units string handed to `date2num`, resolved once at construction time
        self.units = fn_units()
    
    def __call__(self, tfm): 
        "Drop rows with a missing time and encode the remaining times with `date2num`."
        for grp, df in tfm.dfs.items():
            n_missing = df[self.col_time].isna().sum()
            if n_missing:
                print(f"Warning: {n_missing} missing time value(s) in {grp}")
            
            # Remove NaN times. The explicit copy detaches the result from `df`, so the column
            # assignment below does not trigger pandas' SettingWithCopyWarning.
            df_valid = df[df[self.col_time].notna()].copy()
            # Convert to numeric time expressed in `self.units`
            df_valid[self.col_time] = df_valid[self.col_time].apply(lambda x: date2num(x, units=self.units))
            tfm.dfs[grp] = df_valid

In [None]:
# Two groups with two consecutive days each
dfs_test = {
    grp: pd.DataFrame({
        'TIME': [pd.Timestamp(f'2023-01-0{day}') for day in days],
        'value': days
        })
    for grp, days in {'SEAWATER': [1, 2], 'SEDIMENT': [3, 4]}.items()
}

units = 'seconds since 1970-01-01 00:00:00.0'
tfm = Transformer(dfs_test, cbs=[EncodeTimeCB(fn_units=lambda: units)], inplace=False)
dfs_result = tfm()

# Times must be encoded as integers ...
for grp in ['SEAWATER', 'SEDIMENT']:
    fc.test_eq(dfs_result[grp].TIME.dtype, 'int64')

# ... matching a direct conversion of the (untouched) input timestamps
for grp in ['SEAWATER', 'SEDIMENT']:
    fc.test_eq(dfs_result[grp].TIME, dfs_test[grp].TIME.apply(lambda x: date2num(x, units=units)))