In [None]:
#| default_exp callbacks

# Callbacks

> Callbacks used in handlers

In [None]:
#| exports
import copy
import fastcore.all as fc
from operator import attrgetter
from cftime import date2num
import numpy as np
import pandas as pd
from marisco.configs import cfg, cdl_cfg
from functools import partial 
from typing import List, Dict, Callable, Tuple
from pathlib import Path 

from marisco.configs import get_lut, nuc_lut_path

In [None]:
#| hide
from marisco.configs import cdl_cfg, CONFIGS_CDL
from marisco.utils import test_dfs

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Core

The `Transformer` class applies a series of callbacks to a set of dataframes. It provides a structured way to apply transformations (i.e. `Callback` instances) to the data, with a focus on flexibility and ease of use.

In [None]:
#| exports
class Callback:
    "Base class for callbacks."
    # Sort key read by `run_cbs`: callbacks are executed in ascending `order`.
    order = 0

In [None]:
#| exports
def run_cbs(cbs, obj=None):
    "Call each callback on `obj`, lowest `order` first, logging its docstring beforehand."
    ordered = list(cbs)
    ordered.sort(key=lambda callback: callback.order)  # stable: ties keep their given order
    for callback in ordered:
        doc = callback.__doc__
        if doc:
            # Keep a human-readable trace of what has been applied to `obj`.
            obj.logs.append(doc)
        callback(obj)

In [None]:
#| exports
class Transformer():
    def __init__(self, 
                 dfs:Dict[str, pd.DataFrame], # Dictionary of DataFrames to transform, keyed by group name
                 cbs:list=None, # List of callbacks to run
                 inplace:bool=False # Whether to modify the dataframes in place
                 ): 
        "Transform the dataframes according to the specified callbacks."
        fc.store_attr()
        # Unless `inplace` is requested, work on per-group copies so the caller's frames are left untouched.
        self.dfs = dfs if inplace else {k: v.copy() for k, v in dfs.items()}
        # Receives the docstring of each callback as it runs (see `run_cbs`).
        self.logs = []
            
    def unique(self, col_name):
        "Distinct values of column `col_name`, pooled over every group that has this column."
        # `DataFrame.get` returns None when the column is absent; look each group up only once.
        candidates = (df.get(col_name) for df in self.dfs.values())
        columns = [col for col in candidates if col is not None]
        values = np.concatenate(columns) if columns else []
        return np.unique(values)
        
    def __call__(self):
        "Run the callbacks (if any) and return the dictionary of transformed dataframes."
        if self.cbs: run_cbs(self.cbs, self)
        return self.dfs

Below are a few examples of how to use the `Transformer` class.
Let's first define a test callback that adds `1` to the `depth`:

In [None]:
class TestCB(Callback):
    "A test callback to add 1 to the depth."
    def __call__(self, tfm):
        # Shift the `depth` column of every group up by one, writing the result back into the frame.
        for df in tfm.dfs.values():
            df['depth'] = df['depth'].map(lambda depth: depth + 1)

And apply it to the following dataframes:

In [None]:
dfs = {
    'biota': pd.DataFrame({'id': [0, 1, 2], 'species': [0, 2, 0], 'depth': [2, 3, 4]}),
    'seawater': pd.DataFrame({'id': [0, 1, 2], 'depth': [3, 4, 5]}),
}

# `inplace` defaults to False, so the callback works on copies and `dfs` itself is not modified.
tfm = Transformer(dfs, cbs=[TestCB()])
df_test = tfm()

# Every depth must have been incremented by exactly one.
expected_depths = {'biota': [3, 4, 5], 'seawater': [4, 5, 6]}
for grp, depths in expected_depths.items():
    fc.test_eq(df_test[grp]['depth'].to_list(), depths)

## Geographical

This section gathers callbacks that are used to transform the geographical coordinates.

In [None]:
#| exports
class SanitizeLonLatCB(Callback):
    "Drop row when both longitude & latitude equal 0. Drop unrealistic longitude & latitude values. Convert longitude & latitude `,` separator to `.` separator."
    def __init__(self, verbose=False): fc.store_attr()
    def __call__(self, tfm):
        # Rebinding an existing key while iterating `items()` is safe: the dict size never changes.
        for grp, df in tfm.dfs.items():
            # Convert `,` decimal separator to `.` and cast every coordinate to float.
            df['lon'] = [float(str(x).replace(',', '.')) for x in df['lon']]
            df['lat'] = [float(str(x).replace(',', '.')) for x in df['lat']]
            
            # Rows located exactly at (0, 0) are dropped.
            mask_zeroes = (df.lon == 0) & (df.lat == 0) 
            nZeroes = mask_zeroes.sum()
            if nZeroes and self.verbose: 
                print(f'The "{grp}" group contains {nZeroes} data points whose (lon, lat) = (0, 0)')
            
            # Mask gps out of bounds ("goob"): lon outside [-180, 180] or lat outside [-90, 90].
            mask_goob = (df.lon < -180) | (df.lon > 180) | (df.lat < -90) | (df.lat > 90)
            nGoob = mask_goob.sum()
            if nGoob and self.verbose: 
                print(f'The "{grp}" group contains {nGoob} data points whose lon or lat are unrealistic. Outside -90 to 90 for latitude and -180 to 180 for longitude.')
                
            # Store an independent frame rather than a slice of `df`, so that callbacks
            # assigning columns afterwards do not trigger chained-assignment warnings.
            tfm.dfs[grp] = df.loc[~(mask_zeroes | mask_goob)].copy()

In [None]:
# Check that measurements located at (0,0) get removed
dfs = {'biota': pd.DataFrame({'lon': [0, 1, 0], 'lat': [0, 2, 0]})}
tfm = Transformer(dfs, cbs=[SanitizeLonLatCB()])
df_biota = tfm()['biota']  # run the callbacks once and reuse the result

expected = [1., 2.]
fc.test_eq(len(df_biota), 1)  # both (0, 0) rows are gone
fc.test_eq(df_biota.iloc[0].to_list(), expected)

In [None]:
# Check that a comma decimal separator gets replaced by a point
dfs = {'biota': pd.DataFrame({'lon': ['45,2'], 'lat': ['43,1']})}
tfm = Transformer(dfs, cbs=[SanitizeLonLatCB()])
df_biota = tfm()['biota']  # run the callbacks once and reuse the result

expected = [45.2, 43.1]
fc.test_eq(len(df_biota), 1)  # the row is converted, not dropped
fc.test_eq(df_biota.iloc[0].to_list(), expected)

In [None]:
# Check that out of bounds lon or lat get removed
dfs = {'biota': pd.DataFrame({'lon': [-190, 190, 1, 2, 1.1], 'lat': [1, 2, 91, -91, 2.2]})}
tfm = Transformer(dfs, cbs=[SanitizeLonLatCB()])
df_biota = tfm()['biota']  # run the callbacks once and reuse the result

expected = [1.1, 2.2]
fc.test_eq(len(df_biota), 1)  # only the last row has both coordinates in range
fc.test_eq(df_biota.iloc[0].to_list(), expected)

## Add columns
This section gathers callbacks that are used to add required columns to the dataframes.

In [None]:
#| exports
class AddSampleTypeIdColumnCB(Callback):
    def __init__(self, 
                 cdl_cfg:Callable=cdl_cfg, # Callable to get the CDL config dictionary
                 col_name:str='samptype_id' # Name of the column to add
                 ): 
        "Add a column holding, for each group, its sample type id as defined in the CDL."
        fc.store_attr()
        # Lookup table built once from the CDL groups: group name -> sample type id.
        grps = cdl_cfg()['grps']
        self.lut = {grp['name']: grp['id'] for grp in grps.values()}
        
    def __call__(self, tfm):
        # The same id is broadcast to every row of a group; an unknown group name raises KeyError.
        for grp in tfm.dfs:
            tfm.dfs[grp][self.col_name] = self.lut[grp]

Let's test the callback:

In [None]:
# One small dataframe per group declared in the CDL.
grp_cfgs = list(CONFIGS_CDL['grps'].values())
dfs = {grp_cfg['name']: pd.DataFrame({'col_test': [0, 1, 2]}) for grp_cfg in grp_cfgs}

tfm = Transformer(dfs, cbs=[AddSampleTypeIdColumnCB(cdl_cfg=lambda: CONFIGS_CDL)])
dfs_test = tfm()

# Each group must carry a single `samptype_id` value: the id declared for it in the CDL.
for grp_cfg in grp_cfgs:
    samptype_ids = dfs_test[grp_cfg['name']]['samptype_id'].unique()
    fc.test_eq(samptype_ids.item(), grp_cfg['id'])

In [None]:
#| exports
class AddNuclideIdColumnCB(Callback):
    def __init__(self, 
                 col_value:str, # Column name containing the nuclide name
                 lut_fname_fn:Callable=nuc_lut_path, # Function returning the lut path
                 col_name:str='nuclide_id' # Column name to store the nuclide id
                 ): 
        "Add a column with the nuclide id."
        fc.store_attr()
        # Resolve the path once; its directory and file name are passed separately to `get_lut`.
        lut_path = lut_fname_fn()
        self.lut = get_lut(lut_path.parent, lut_path.name, 
                           key='nc_name', value='nuclide_id', reverse=False)
        
    def __call__(self, tfm):
        for grp, df in tfm.dfs.items(): 
            # NOTE(review): assuming `self.lut` is a plain dict, names missing from it
            # become NaN (`Series.map` semantics) -- confirm against `get_lut`.
            df[self.col_name] = df[self.col_value].map(self.lut)

In [None]:
dfs = {v['name']: pd.DataFrame({'Nuclide': ['cs137', 'pu239_240_tot']}) for v in CONFIGS_CDL['grps'].values()}

lut_fname_fn = lambda: Path('./files/lut/dbo_nuclide.xlsx')

tfm = Transformer(dfs, cbs=[AddNuclideIdColumnCB(col_value='Nuclide', lut_fname_fn=lut_fname_fn)])
dfs_test = tfm()  # run the callbacks explicitly; the returned frames are checked below

# Every group holds the same two nuclides, hence the same two ids.
expected = [33, 77]
for grp, df in dfs_test.items():
    fc.test_eq(df['nuclide_id'].to_list(), expected)

## Map & Standardize

In [None]:
#| exports
class LowerStripNameCB(Callback):
    def __init__(self, 
                 col_src:str, # Source column name e.g. 'Nuclide'
                 col_dst:str=None, # Destination column name (defaults to `col_src`)
                 fn_transform:Callable=lambda x: x.lower().strip() # Transformation function
                 ):
        "Convert values to lowercase and strip leading/trailing whitespace."
        fc.store_attr()
        # Resolve the default destination BEFORE building the log message; otherwise the
        # message logged by `run_cbs` would read "store in 'None'" when `col_dst` is omitted.
        if not col_dst: self.col_dst = col_src
        # Instance-level docstring: this is the text appended to the transformer logs.
        self.__doc__ = f"Convert values from '{col_src}' to lowercase, strip spaces, and store in '{self.col_dst}'."
        
    def _safe_transform(self, value):
        "Return NA values untouched; otherwise apply `fn_transform` to `str(value)`."
        return value if pd.isna(value) else self.fn_transform(str(value))
            
    def __call__(self, tfm):
        # Element-wise transform of the source column, written to the destination column of each group.
        for key in tfm.dfs.keys():
            tfm.dfs[key][self.col_dst] = tfm.dfs[key][self.col_src].apply(self._safe_transform)

Let's test the callback:

In [None]:
dfs = {'seawater': pd.DataFrame({'Nuclide': ['CS137', '226RA']})}
expected = ['cs137', '226ra']

# Explicit destination: the result lands in a new 'NUCLIDE' column.
tfm = Transformer(dfs, cbs=[LowerStripNameCB(col_src='Nuclide', col_dst='NUCLIDE')])
fc.test_eq(tfm()['seawater']['NUCLIDE'].to_list(), expected)

# No destination given: the source column itself is overwritten (in the transformer's copy).
tfm = Transformer(dfs, cbs=[LowerStripNameCB(col_src='Nuclide')])
fc.test_eq(tfm()['seawater']['Nuclide'].to_list(), expected)

In general, (semi-automatic) remapping of names proceeds in three steps:

1. First, guess the right nuclide name (using fuzzy matching or another method).
2. Then, manually check the result and, if needed, update the lookup table.
3. Finally, apply the lookup table to the dataframe.


## Change structure

Many data providers use a long format (e.g `lat, lon, radionuclide, value, unc, ...`) to store their data. When encoding as `netCDF`, it is often required to use a wide format (e.g `lat, lon, nuclide1_value, nuclide1_unc, nuclide2_value, nuclide2_unc, ...`). The class `ReshapeLongToWide` is designed to perform this transformation.

In [None]:
#| exports
class ReshapeLongToWide(Callback):
    # NOTE: the list defaults below are shared between instances; this class only reads them.
    def __init__(self, columns=['nuclide'], values=['value'], 
                 num_fill_value=-999, str_fill_value='STR FILL VALUE'):
        "Convert data from long to wide with renamed columns."
        fc.store_attr()
        self.derived_cols = self._get_derived_cols()
    
    def _get_derived_cols(self):
        "Retrieve all possible derived vars (e.g 'unc', 'dl', ...) from configs."
        return [value['name'] for value in cdl_cfg()['vars']['suffixes'].values()]

    def renamed_cols(self, cols):
        "Flatten the two-level `(outer, inner)` column names produced by the pivot."
        # - outer == "value"      -> inner alone
        # - other, inner not empty -> inner followed by outer
        # - inner empty            -> outer alone
        return [inner if outer == "value" else f'{inner}{outer}' if inner else outer
                for outer, inner in cols]

    def _get_unique_fill_value(self, df, idx):
        "Get a numeric fill value, starting at `num_fill_value`, absent from the `idx` columns."
        fill_value = self.num_fill_value
        # Decrement until the placeholder cannot be confused with real data.
        while (df[idx] == fill_value).any().any():
            fill_value -= 1
        return fill_value

    def _fill_nan_values(self, df, idx):
        "Fill NaN values in index columns; return the frame and the numeric placeholder used."
        num_fill_value = self._get_unique_fill_value(df, idx)
        for col in idx:
            fill_value = num_fill_value if pd.api.types.is_numeric_dtype(df[col]) else self.str_fill_value
            df[col] = df[col].fillna(fill_value)
        return df, num_fill_value

    def pivot(self, df):
        "Pivot `df` from long to wide, keeping the original index (named 'org_index')."
        derived_coi = [col for col in self.derived_cols if col in df.columns]
        # Keep the original index as a regular column so it can be restored after pivoting.
        df.index.name = 'org_index'
        df = df.reset_index()
        # Index columns = everything that is neither pivoted nor a value/derived column.
        # Built in column order (not from a `set`) so that the pivot sort keys and the
        # output column order are identical from one interpreter run to the next.
        excluded = set(self.columns + derived_coi + self.values)
        idx = [col for col in dict.fromkeys(df.columns) if col not in excluded]
        
        # NaNs in the index columns are temporarily replaced by placeholders
        # and turned back into NaN once the pivot is done (see below).
        df, num_fill_value = self._fill_nan_values(df, idx)

        pivot_df = df.pivot_table(index=idx,
                                  columns=self.columns,
                                  values=self.values + derived_coi,
                                  fill_value=np.nan,
                                  aggfunc=lambda x: x).reset_index()

        # Restore the NaNs that were replaced by placeholders before pivoting.
        pivot_df[idx] = pivot_df[idx].replace({self.str_fill_value: np.nan, num_fill_value: np.nan})
        return pivot_df.set_index('org_index')

    def __call__(self, tfm):
        # Pivot each group, then flatten its two-level column names.
        for grp in tfm.dfs.keys():
            tfm.dfs[grp] = self.pivot(tfm.dfs[grp])
            tfm.dfs[grp].columns = self.renamed_cols(tfm.dfs[grp].columns)

In [None]:
#| eval: false
# Pickled fixtures: the input frames and the frames the reshaped output is compared against.
# NOTE(review): presumably long-format input / wide-format expected output -- confirm against the fixtures.
dfs_test = fc.load_pickle('./files/pkl/dfs_reshape_test_in.pkl')
dfs_expected = fc.load_pickle('./files/pkl/dfs_reshape_test_expected.pkl')

tfm = Transformer(dfs_test, cbs=[ReshapeLongToWide()])
dfs_output = tfm()
# `test_dfs` is imported from `marisco.utils`.
test_dfs(dfs_output, dfs_expected)    

In [None]:
#| exports
class CompareDfsAndTfmCB(Callback):
    def __init__(self, dfs: Dict[str, pd.DataFrame]): 
        "Record, per group, the rows of `dfs` that are missing from the transformed data."
        fc.store_attr()
        
    def __call__(self, tfm: Transformer) -> None:
        self._initialize_tfm_attributes(tfm)
        for grp in tfm.dfs:
            # Dropped rows must be stored first: the stats read them back from `tfm.dfs_dropped`.
            tfm.dfs_dropped[grp] = self._get_dropped_data(grp, tfm)
            tfm.compare_stats[grp] = self._compute_stats(grp, tfm)

    def _initialize_tfm_attributes(self, tfm: Transformer) -> None:
        "Attach two empty result dictionaries to `tfm`."
        tfm.dfs_dropped, tfm.compare_stats = {}, {}

    def _get_dropped_data(self, 
                          grp: str, # The group key
                          tfm: Transformer # The transformation object containing `dfs`
                         ) -> pd.DataFrame: # Dataframe with dropped rows
        "Rows of `dfs[grp]` whose index label no longer appears in `tfm.dfs[grp]`."
        original = self.dfs[grp]
        missing_idx = original.index.difference(tfm.dfs[grp].index)
        return original.loc[missing_idx]
    
    def _compute_stats(self, 
                       grp: str, # The group key
                       tfm: Transformer # The transformation object containing `dfs`
                      ) -> Dict[str, int]: # Dictionary with comparison statistics
        "Row counts of the original, transformed and dropped data for `grp`."
        n_original = len(self.dfs[grp].index)
        n_kept = len(tfm.dfs[grp].index)
        n_dropped = len(tfm.dfs_dropped[grp].index)
        return {
            'Number of rows in dfs': n_original,
            'Number of rows in tfm.dfs': n_kept,
            'Number of dropped rows': n_dropped,
            'Number of rows in tfm.dfs + Number of dropped rows': n_kept + n_dropped
        }

`CompareDfsAndTfmCB` compares the original dataframes to the transformed dataframe. A dictionary of dataframes, `tfm.dfs_dropped`, is created to include the data present in the original dataset but absent from the transformed data. `tfm.compare_stats` provides a quick overview of the number of rows in both the original dataframes and the transformed dataframe. 

For instance:

In [None]:
# Two groups holding identical three-row frames.
dfs_test = {
    grp: pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
    for grp in ('seawater', 'sediment')
}

class TestTfmCB(Callback):
    # Drops index labels 0 and 1 from 'seawater', and label 0 from any other group.
    def __call__(self, tfm):
        for grp, df in tfm.dfs.items():
            labels = [0] if grp != 'seawater' else [0, 1]
            df.drop(labels, inplace=True)

# `inplace=False`: the transformer works on copies, so `dfs_test` stays intact
# and can serve as the reference passed to `CompareDfsAndTfmCB`.
tfm = Transformer(
    dfs_test,
    cbs=[TestTfmCB(), CompareDfsAndTfmCB(dfs_test)],
    inplace=False,
)

print(tfm())

stats = tfm.compare_stats
fc.test_eq(stats['seawater']['Number of dropped rows'], 2)
fc.test_eq(stats['sediment']['Number of dropped rows'], 1)

{'seawater':    a  b
2  3  6, 'sediment':    a  b
1  2  5
2  3  6}


## Time

These callbacks are used to transform the time variable according to netCDF CF standards. For instance, the `EncodeTimeCB` callback encodes the time variable as an integer representing seconds since a reference date, as specified in the `CONFIGS_CDL` dictionary of `configs.ipynb`.

In [None]:
#| exports
class EncodeTimeCB(Callback):
    "Encode time as `int` representing seconds since xxx"
    def __init__(self, cfg , verbose=False): fc.store_attr()
    def __call__(self, tfm): 
        def format_time(x): 
            # Numeric encoding of `x` in the units read from `cfg['units']['time']`.
            return date2num(x, units=self.cfg['units']['time'])
        
        for k in tfm.dfs.keys():
            # Rows without a valid time cannot be encoded: report (if verbose) and drop them.
            is_invalid = tfm.dfs[k]['time'].isna()
            if is_invalid.any():
                if self.verbose:
                    print (f'{is_invalid.sum()} of {len(tfm.dfs[k].index)} entries for `time` are invalid for {k}.')
                # `.copy()` so that the column assignment below writes to an independent
                # frame rather than to a slice of the unfiltered one.
                tfm.dfs[k] = tfm.dfs[k][~is_invalid].copy()
            
            tfm.dfs[k]['time'] = tfm.dfs[k]['time'].apply(format_time)