In [None]:
#| default_exp serializers

# Serializers
> Various utilities to encode MARIS datasets into `NetCDF`, `csv`, ... formats.

In [None]:
# Pick up edits to imported modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from netCDF4 import Dataset
import pandas as pd
from typing import Dict, Callable
import pandas as pd
from fastcore.basics import patch, store_attr

In [None]:
#| export
class NetCDFEncoder:
    "MARIS NetCDF encoder."
    def __init__(self, 
                 dfs:dict[str, pd.DataFrame], # dict of Dataframes to encode with group name as key {'sediment': df_sed, ...}
                 src_fname:str, # File name and path to the MARIS NetCDF template
                 dest_fname:str, # Name of output file to produce
                 global_attrs:Dict, # Global attributes
                 enums_xtra:Dict=None, # Enumeration types to overwrite, as {enum_name: {label: value}}
                 verbose:bool=False, # Print currently written NetCDF group and variable names
                 ):
        # Store every constructor argument as an attribute of the same name.
        store_attr()
        # Avoid a shared mutable default: each encoder gets its own (empty) override dict.
        if self.enums_xtra is None: self.enums_xtra = {}
        # Enum types created in the destination file, keyed by template type name.
        self.enum_types = {}

In [None]:
# Toy input: two samples with position, time, one nuclide (I-131) and a species id.
df = pd.DataFrame(dict(
    sample=[0, 1],
    lon=[141, 142],
    lat=[37.3, 38.3],
    time=[1234, 1235],
    i131=[1, 1.5],
    i131_dl=[0, 1],
    i131_unit=[1, 1],
    species_id=[107083, 241373],
))

# The same frame is used for both groups.
dfs = dict(seawater=df, biota=df)
attrs = dict(id='123', title='Test title', summary='Summary test')
src = './files/nc/template-test.nc'
dest = './files/nc/encoding-test.nc'
# Replacement values for the template's `species_t` enum, limited to the species above.
enums_xtra = {'species_t': {'Aristeus antennatus': 107083, 'Apostichopus': 241373}}

In [None]:
# Construct an encoder; the constructor only stores its arguments, nothing is written yet.
encoder = NetCDFEncoder(dfs=dfs, src_fname=src, dest_fname=dest, global_attrs=attrs)

In [None]:
#| export
@patch
def copy_global_attributes(self:NetCDFEncoder):
    "Copy the template's global attributes to the destination, then apply `global_attrs` on top (overriding or adding)."
    dest = self.dest
    dest.setncatts(self.src.__dict__)
    for attr_name in self.global_attrs:
        dest.setncattr(attr_name, self.global_attrs[attr_name])

In [None]:
#| export
@patch
def copy_dimensions(self:NetCDFEncoder):
    "Recreate every template dimension in the destination; unlimited dimensions stay unlimited."
    for dim_name, dim in self.src.dimensions.items():
        # `None` is how netCDF4 requests an unlimited dimension.
        size = None if dim.isunlimited() else len(dim)
        self.dest.createDimension(dim_name, size)

In [None]:
#| export
@patch
def process_groups(self:NetCDFEncoder):
    "Encode each DataFrame of `dfs` into the destination group named by its key."
    for name, frame in self.dfs.items():
        self.process_group(name, frame)

In [None]:
#| export
@patch
def process_group(self:NetCDFEncoder, group_name, df):
    "Create group `group_name` in the destination and fill it with the matching template variables from `df`."
    self.copy_variables(group_name, df, self.dest.createGroup(group_name))

In [None]:
#| export
@patch
def copy_variables(self:NetCDFEncoder, group_name, df, group_dest):
    "Copy each variable of template group `group_name` that has a matching column (or index level) in `df`."
    template_vars = self.src.groups[group_name].variables
    # `reset_index` exposes index levels as columns. It copies the whole DataFrame,
    # so compute the available names once instead of once per template variable.
    available = df.reset_index().columns
    for var_name, var_src in template_vars.items():
        if var_name in available:
            self.copy_variable(var_name, var_src, df, group_dest)

In [None]:
#| export
@patch
def copy_variable(self:NetCDFEncoder, var_name, var_src, df, group_dest):
    "Create `var_name` in `group_dest` from template variable `var_src`, fill it from `df` and copy its attributes."
    # NOTE(review): assumes every template datatype exposes `.name` (numpy dtypes and enum types do).
    dtype_name = var_src.datatype.name
    if self.verbose:
        print(f'Group: {group_dest.name}, Variable: {var_name}')
    # An enum type must exist in the destination before a variable can use it.
    if dtype_name in self.src.enumtypes:
        self.copy_enum_type(dtype_name)
    self._create_and_copy_variable(var_name, var_src, df, group_dest, dtype_name)
    self.copy_variable_attributes(var_name, var_src, group_dest)

In [None]:
#| export
@patch
def _create_and_copy_variable(self:NetCDFEncoder, var_name, var_src, df, group_dest, dtype_name):
    "Create `var_name` in `group_dest` with the template's type and dimensions, then write the numeric column from `df`."
    # Enum variables must use the EnumType created in the destination file, not the template's.
    variable_type = self.enum_types.get(dtype_name, var_src.datatype)
    # netCDF only accepts `_FillValue` before any data is written, so carry the
    # template's fill value over at creation time (None keeps the netCDF default).
    fill_value = getattr(var_src, '_FillValue', None)
    var_dest = group_dest.createVariable(var_name, variable_type, var_src.dimensions,
                                         compression='zlib', complevel=9, fill_value=fill_value)
    
    df_sanitized = self.cast_verbose_rf(df, var_name)
    var_dest[:] = df_sanitized.values

In [None]:
#| export
@patch
def copy_enum_type(self:NetCDFEncoder, dtype_name):
    "Create template enum type `dtype_name` in the destination (once), using `enums_xtra` values when an override is given."
    if dtype_name in self.enum_types: return
    enum_info = self.src.enumtypes[dtype_name]
    # Look the override up instead of mutating the template's EnumType object.
    enum_dict = self.enums_xtra.get(enum_info.name, enum_info.enum_dict)
    self.enum_types[dtype_name] = self.dest.createEnumType(enum_info.dtype, 
                                                           enum_info.name, 
                                                           enum_dict)

In [None]:
#| export
@patch
def copy_variable_attributes(self:NetCDFEncoder, var_name, var_src, group_dest):
    "Copy all attributes of template variable `var_src` onto `group_dest[var_name]`."
    var_dest = group_dest[var_name]
    attrs = dict(var_src.__dict__)
    # netCDF rejects (re)defining `_FillValue` once data is written; if the destination
    # variable already has one (set when it was created), don't try to set it again.
    if '_FillValue' in var_dest.ncattrs(): attrs.pop('_FillValue', None)
    var_dest.setncatts(attrs)

In [None]:
#| export
@patch
def cast_verbose_rf(self:NetCDFEncoder, 
                    df, 
                    col):
    """
    Try to cast df column to numeric type:
        - Silently coerce to nan if not possible
        - But log when it failed

    `col` may also be an index level, hence the `reset_index`. Returns the
    coerced values as a Series with a fresh RangeIndex.
    """
    # reset_index copies the frame; do it once rather than once per use.
    values = df.reset_index()[col]
    n_before = values.notna().sum()
    df_after = pd.to_numeric(values, errors='coerce', downcast=None)
    n_after = df_after.notna().sum()
    # Any value that was present before but is NaN now failed to convert.
    if n_before != n_after: 
        print(f'Failed to convert type of {col} in {n_before - n_after} occurences')
    
    return df_after

In [None]:
#| export
@patch
def encode(self:NetCDFEncoder):
    "Encode MARIS NetCDF based on template and dataframes."
    with Dataset(self.src_fname, format='NETCDF4') as self.src, Dataset(self.dest_fname, 'w', format='NETCDF4') as self.dest:
        self.copy_global_attributes()
        self.copy_dimensions()
        self.process_groups()

In [None]:
# Encode the toy `dfs` with the test template, applying the `species_t` enum override
# (`enums_xtra`) so that code path is exercised too.
encoder = NetCDFEncoder(dfs, src_fname=src, dest_fname=dest, global_attrs=attrs,
                        enums_xtra=enums_xtra, verbose=False)
encoder.encode()

In [None]:
#### Legacy code to remove

In [None]:
# def cast_verbose(df, col):
#     """
#     Try to cast df column to numeric type:
#         - Silently coerce to nan if not possible
#         - But log when it failed
#     """
#     n_before = sum(df.reset_index()[col].notna())
#     df_after = pd.to_numeric(df.reset_index()[col],
#                                     errors='coerce', downcast=None)
#     n_after = sum(df_after.notna())
#     if n_before != n_after: 
#         print(f'Failed to convert type of {col} in {n_before - n_after} occurences')
    
#     return df_after

# def to_netcdf(
#     dfs:dict[pd.DataFrame], # dict of Dataframes to encode with group name as key {'sediment': df_sed, ...}
#     src_fname:str, # Input MARIS template NetCDF path and name
#     # fname_output:str, # Name of output file to produce
#     dest_fname:str, # Output NetCDF path and name to produce
#     global_attrs:Dict, # Global attributes
#     units_fn:Callable, # (group, variable) -> unit look up function
# ):
#     "Encode MARIS dataset (provided as Pandas DataFrame) to NetCDF file"
#     with Dataset(src_fname, format='NETCDF4') as src, Dataset(dest_fname, 'w', format='NETCDF4') as dst:
#         # copy global attributes all at once via dictionary
#         dst.setncatts(src.__dict__)
#         dst.setncatts(global_attrs) 
        
#         # copy dimensions
#         for name, dimension in src.dimensions.items():
#             dst.createDimension(
#                 name, (len(dimension) if not dimension.isunlimited() else None))

#         # copy groups
#         for grp_name, df in dfs.items():
#             # TBD: asserting group name
#             grp_dest = dst.createGroup(grp_name)
        
#             n_before = 0
#             n_after = 0
            
#             # copy all variables of interest and fill them
#             for name_var_src, var_src in src.groups[grp_name].variables.items():
#                 # Only if source variable is in destination
#                 if name_var_src in df.reset_index().columns:
#                     # x = grp_dest.createVariable(name_var_src, var_src.datatype, var_src.dimensions,
#                     grp_dest.createVariable(name_var_src, var_src.datatype, var_src.dimensions,
#                                             compression='zlib', complevel=9)
                        
#                     df_sanitized = cast_verbose(df, name_var_src)
#                     grp_dest[name_var_src][:] = df_sanitized.values
                    
#                     # copy variable attributes all at once via dictionary
#                     grp_dest[name_var_src].setncatts(src.groups[grp_name][name_var_src].__dict__)
#                     if (hasattr(src.groups[grp_name][name_var_src], 'units') and
#                         src.groups[grp_name][name_var_src].units == '_to_be_filled_in_'):
#                         grp_dest[name_var_src].units = units_fn(grp_name, name_var_src)