In [1]:
import os
import glob
import sys
import numpy as np
import datetime as dt
from multiprocessing import Pool
import tqdm
import gc
from scipy.interpolate import griddata
import numpy as np
from matplotlib.tri import Triangulation

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def jacobian(x0, y0, x1, y1, x2, y2):
    """
    Return det([[x1-x0, y1-y0], [x2-x0, y2-y0]]).

    The magnitude is twice the area of the triangle with vertices
    (x0, y0), (x1, y1), (x2, y2); the value is positive when the vertices
    are ordered counter-clockwise and negative when ordered clockwise.

    Parameters:
    x0, x1, x2 (numpy arrays or floats) - x coords of the 3 points
    y0, y1, y2 (numpy arrays or floats) - y coords of the 3 points

    Returns:
    jac (same type as inputs)
    """
    # edge vectors from vertex 0 to vertices 1 and 2
    ax, ay = x1 - x0, y1 - y0
    bx, by = x2 - x0, y2 - y0
    return ax*by - bx*ay

def measure(x, y, t):
    """
    Compute geometric measures of every triangle of a triangulation.

    Parameters:
    x, y (numpy arrays) - node coordinates
    t (numpy array of int) - node indices of the triangles, one row of 3 per triangle

    Returns:
    area (numpy array) - signed triangle areas, as returned by get_area
    edges (numpy array) - edge lengths, one row per triangle with the
        lengths of edges v0->v1, v1->v2 and v2->v0
    perim (numpy array) - perimeter of each triangle
    ap_ratio (numpy array) - sqrt(area) / perimeter
    """
    xt, yt = x[t], y[t]
    # close each triangle by appending its first vertex as a 4th column,
    # so that np.diff also yields the closing edge v2->v0
    x_closed = np.hstack([xt, xt[:, :1]])
    y_closed = np.hstack([yt, yt[:, :1]])
    edges = np.hypot(np.diff(x_closed), np.diff(y_closed))
    perim = edges.sum(axis=1)
    area = get_area(x, y, t)
    ap_ratio = area**0.5 / perim
    return area, edges, perim, ap_ratio

def get_area(x, y, t):
    """
    Return the signed area of every triangle: half the jacobian of its
    three vertices (positive for counter-clockwise vertex order).

    x, y - node coordinates; t - node indices of the triangles, one row of 3 per triangle
    """
    xt, yt = x[t], y[t]
    return 0.5 * jacobian(xt[:, 0], yt[:, 0], xt[:, 1], yt[:, 1], xt[:, 2], yt[:, 2])

class IrregularGridInterpolator(object):
    """
    Interpolate fields from a triangulated (irregular) source mesh onto a set
    of destination points.

    The containing triangle of every destination point and its barycentric
    weights are computed once in ``__init__``, so that many fields defined on
    the same source mesh can then be interpolated cheaply with ``interp_field``.
    """
    def __init__(self, x0, y0, x1, y1, triangles=None):
        r'''
        Parameters:
        -----------
        x0 : np.ndarray(float)
            x-coords of source points
        y0 : np.ndarray(float)
            y-coords of source points
        x1 : np.ndarray(float)
            x-coords of destination points
        y1 : np.ndarray(float)
            y-coords of destination points
        triangles : np.ndarray(int)
            shape (num_triangles, 3)
            indices of nodes for each triangle

        Sets:
        -----
        self.inside: np.ndarray(bool)
            shape = (num_target_points,)
        self.vertices: np.ndarray(int)
            shape = (num_good_target_points, 3)
            good target points are those inside the source triangulation
        self.weights: np.ndarray(float)
            shape = (num_good_target_points, 3)
            good target points are those inside the source triangulation

        Follows this suggestion:
        https://stackoverflow.com/questions/20915502/speedup-scipy-griddata-for-multiple-interpolations-between-two-irregular-grids
        x_target[i] = \sum_{j=0}^2 weights[i, j]*x_source[vertices[i, j]]
        y_target[i] = \sum_{j=0}^2 weights[i, j]*y_source[vertices[i, j]]
        We can do (linear) interpolation by replacing x_target, x_source with z_target, z_source
        where z_source is the field to be interpolated and z_target is the interpolated field
        '''

        # define and triangulate source points
        self.src_shape = x0.shape
        self.src_points = np.array([x0.flatten(), y0.flatten()]).T
        self.tri = Triangulation(x0.flatten(), y0.flatten(), triangles=triangles)
        self.tri_finder = self.tri.get_trifinder()
        self.num_triangles = len(self.tri.triangles)
        self._set_transform()

        # define target points; the trifinder gives a negative triangle index
        # for points outside the triangulation
        self.dst_points = np.array([x1.flatten(), y1.flatten()]).T
        self.dst_shape = x1.shape
        self.triangle_map = self.tri_finder(x1, y1)
        self.dst_mask = (self.triangle_map < 0)
        # replace the negative indices by a valid one so that triangle_map can
        # be used for indexing; those points stay flagged by dst_mask / inside
        self.triangle_map[self.dst_mask] = 0
        self.inside = ~self.dst_mask.flatten()

        # get barycentric coords
        # https://en.wikipedia.org/wiki/Barycentric_coordinate_system#Barycentric_coordinates_on_triangles
        # each row of bary is (lambda_1, lambda_2) for 1 destination point
        d = 2
        inds = self.triangle_map.flatten()[self.inside]
        self.vertices = np.take(self.tri.triangles, inds, axis=0)
        temp = np.take(self.transform, inds, axis=0)
        delta = self.dst_points[self.inside] - temp[:, d]
        bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)

        # set weights; the 3rd barycentric coordinate is 1 - lambda_1 - lambda_2
        self.weights = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

    def _set_transform(self):
        """
        Used for getting the barycentric coordinates on a triangle.
        Follows:
        https://en.wikipedia.org/wiki/Barycentric_coordinate_system#Barycentric_coordinates_on_triangles

        Sets:
        -----
        self.transform : numpy.ndarray
            For the i-th triangle,
                self.transform[i] = [[a', b'], [c', d'], [x_3, y3]]
            where the first 2 rows are the inverse of the matrix T in the wikipedia link
            and (x_3, y_3) are the coordinates of the 3rd vertex of the triangle
        """
        x = self.tri.x[self.tri.triangles]
        y = self.tri.y[self.tri.triangles]
        a = x[:,0] - x[:,2]
        b = x[:,1] - x[:,2]
        c = y[:,0] - y[:,2]
        d = y[:,1] - y[:,2]
        # NOTE(review): degenerate (zero-area) triangles give det == 0 and
        # therefore inf/nan entries below
        det = a*d-b*c

        self.transform = np.zeros((self.num_triangles, 3, 2))
        self.transform[:,0,0] = d/det
        self.transform[:,0,1] = -b/det
        self.transform[:,1,0] = -c/det
        self.transform[:,1,1] = a/det
        self.transform[:,2,0] = x[:,2]
        self.transform[:,2,1] = y[:,2]

    def interp_field(self, fld, method='linear'):
        """
        Interpolate field from elements or nodes of source triangulation
        to destination points

        Parameters:
        -----------
        fld: np.ndarray
            field to be interpolated; either with the same shape as the source
            points (node field) or with one value per source triangle
            (element field)
        method : str
            interpolation method if interpolating from nodes
            - 'linear'  : linear interpolation
            - 'nearest' : nearest neighbour

        Returns:
        -----------
        fld_interp : np.ndarray
            field interpolated onto the destination points
            (NaN at destination points outside the source triangulation)

        Raises:
        -------
        ValueError
            if fld matches neither the node shape nor the number of triangles,
            or if method is unknown
        """
        # accept any array-like (keeps ndarray subclasses such as masked arrays)
        fld = np.asanyarray(fld)
        if fld.shape == self.src_shape:
            return self._interp_nodes(fld, method=method)
        fld_ = fld.flatten()
        if len(fld_) == self.num_triangles:
            return self._interp_elements(fld_)
        msg = (f"Field to interpolate should have the same size as the source points "
               f"i.e. {self.src_shape}, or be a vector with the same number of triangles "
               f"as the source triangulation i.e. {self.num_triangles}")
        raise ValueError(msg)

    def _interp_elements(self, fld):
        """
        Interpolate field from elements of source triangulation to destination points

        Parameters:
        -----------
        fld: np.ndarray
            1D field with one value per source triangle

        Returns:
        -----------
        fld_interp : np.ndarray
            field interpolated onto the destination points, shape self.dst_shape;
            NaN outside the source triangulation
        """
        fld_interp = fld[self.triangle_map]
        if not np.issubdtype(fld_interp.dtype, np.inexact):
            # integer/bool arrays cannot hold NaN: promote to float
            fld_interp = fld_interp.astype(float)
        fld_interp[self.dst_mask] = np.nan
        return fld_interp

    def _interp_nodes(self, fld, method='linear'):
        """
        Interpolate field from nodes of source triangulation to destination points

        Parameters:
        -----------
        fld: np.ndarray
            field to be interpolated, same shape as the source points
        method : str
            interpolation method
            - 'linear'  : linear interpolation
            - 'nearest' : nearest neighbour

        Returns:
        -----------
        fld_interp : np.ndarray
            field interpolated onto the destination points, shape self.dst_shape;
            NaN outside the source triangulation

        Raises:
        -------
        ValueError
            if method is neither 'linear' nor 'nearest'
        """
        ndst = self.dst_points.shape[0]
        fld_interp = np.full((ndst,), np.nan)
        w = self.weights
        v = self.vertices  # shape = (ngood, 3)
        src = fld.flatten()
        if method == 'linear':
            # weighted sum over the 3 nodes of the containing triangle
            fld_interp[self.inside] = np.einsum('nj,nj->n', np.take(src, v), w)

        elif method == 'nearest':
            # take the node of the triangle with the largest weight
            nearest = v[np.arange(len(w), dtype=int), np.argmax(w, axis=1)]  # shape = (ngood,)
            fld_interp[self.inside] = src[nearest]

        else:
            raise ValueError("'method' should be 'nearest' or 'linear'")

        return fld_interp.reshape(self.dst_shape)


  '''


In [3]:
import cmocean as cm
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# Map projection for plotting: Lambert azimuthal equal-area centred on the pole
sid_srs = ccrs.LambertAzimuthalEqualArea(central_longitude=0, central_latitude=90)
# Land polygons for plotting.
# NOTE(review): despite "50m" in the name, this is cartopy's LAND feature with
# no explicit scale requested — confirm the intended resolution.
land_50m = cfeature.LAND

In [4]:
force = False
lag_dir = '/data2/Anton/sia/cdr_1991_2023'
dst_dir = f'{lag_dir}/age'

# Destination grid: x/y coordinates and validity mask of the 25 km mesh.
# The archive is opened once and closed by the context manager instead of
# being re-opened for every array.
mesh_init_file = 'mesh_arctic_ease_25km_max7.npz'
with np.load(mesh_init_file) as mesh_npz:
    xc = mesh_npz['xc']
    # yc is stored in the opposite order; reverse it (as the original code did)
    yc = mesh_npz['yc'][::-1]
    mask = mesh_npz['mask']
xgrd, ygrd = np.meshgrid(xc, yc)

# landmask: 1.0 where the reference concentration 'c' is NaN, NaN elsewhere
with np.load('/Data/sim/data/OSISAF_ice_drift_CDR_postproc/1991/ice_drift_nh_ease2-750_cdr-v1p0_24h-199101011200.nc.npz') as sic_npz:
    sic = sic_npz['c']
landmask = np.isnan(sic).astype(float)
landmask[landmask == 0] = np.nan

In [5]:
age_files = sorted(glob.glob(f'{dst_dir}/*/age_????????.npz'))
# Only files from index 11500 onward are gridded.
# NOTE(review): presumably the earlier files were gridded in a previous run — confirm.
age_files = age_files[11500:]
print(len(age_files), age_files[0], age_files[-1])

731 /data2/Anton/sia/cdr_1991_2023/age/2023/age_20230311.npz /data2/Anton/sia/cdr_1991_2023/age/2025/age_20250310.npz


In [6]:
def grid_age(age_file):
    """
    Interpolate one Lagrangian ice-age file onto the regular grid and save it.

    ``age_file`` must be an .npz archive with arrays ``t`` (triangles),
    ``x``, ``y`` (node coordinates), ``a`` (age) and ``f`` (age fractions).
    The result is written to ``age_file`` with ``/age/`` replaced by
    ``/age_grd/`` and ``.npz`` by ``_grd.npz``. It holds ``age`` and one
    ``fraction_<N>yi`` array per row of ``f``; the first row of ``f`` gets the
    highest N. Grid cells where the module-level ``mask`` is 0 are set to NaN.

    Uses the module-level ``force``, ``mask``, ``xgrd`` and ``ygrd``.
    Does nothing if the output exists and ``force`` is False.

    Raises:
    ValueError - if loading, triangulating or interpolating fails; the
        underlying exception is chained as ``__cause__``.
    """
    age_grd_file = age_file.replace('.npz', '_grd.npz').replace('/age/', '/age_grd/')
    if os.path.exists(age_grd_file) and not force:
        return
    try:
        # open the archive once; the context manager closes the file handle
        with np.load(age_file) as age_npz:
            t = age_npz['t']
            x = age_npz['x']
            y = age_npz['y']
            a = age_npz['a']
            f = age_npz['f']
    except Exception as err:
        raise ValueError(f'Cannot load from {age_file}') from err

    try:
        igi = IrregularGridInterpolator(x, y, xgrd, ygrd, t)
    except Exception as err:
        raise ValueError(f'Cannot create IGI {age_file}') from err

    # row 0: age field `a`; following rows: the rows of `f`
    src_data = np.vstack([a[None], f])
    dst_data = []
    for d in src_data:
        try:
            dgrd = igi.interp_field(d)
        except Exception as err:
            raise ValueError(f'Cannot interpolate {age_file}') from err
        dgrd[mask == 0] = np.nan
        dst_data.append(dgrd.astype(np.float32))

    save_data = dict(age=dst_data.pop(0))
    frac_num = len(dst_data)
    for i, frac in enumerate(dst_data):
        save_data[f'fraction_{frac_num - i}yi'] = frac
    os.makedirs(os.path.dirname(age_grd_file), exist_ok=True)
    np.savez_compressed(age_grd_file, **save_data)

# Grid all selected age files with 4 worker processes (tqdm shows progress).
# grid_age returns nothing, so r is just a list of None values.
with Pool(4) as p:
    r = list(tqdm.tqdm(p.imap(grid_age, age_files), total=len(age_files)))

100%|██████████| 731/731 [00:41<00:00, 17.83it/s] 


In [None]:
"""
age_grd_files = sorted(glob.glob(f'{dst_dir}_grd/*/age_????????_grd.npz'))[::10]
print(len(age_grd_files))

def plot_age_grd(i_age_grd_file):
    i, age_grd_file = i_age_grd_file
    dst_png_file = f'{dst_dir}/frame_grd_{i:04}.png'
    #if os.path.exists(dst_png_file):
    #    return
    age = np.load(age_grd_file)['age']
    age[mask == 0] = np.nan
    fig = plt.figure(figsize=(10,10))
    ax = plt.axes((0,0,1,1), projection=sid_srs)
    clim = [0, 5]
    cmap = 'jet'
    alpha = 1
    extent = np.array([xc.min(), xc.max(), yc.min(), yc.max()])*1000
    imsh = ax.imshow(age, extent=extent, clim=clim, cmap=cmap, zorder=0, alpha=alpha, interpolation='nearest')
    ax.add_feature(land_50m, edgecolor='black', zorder=10)
    xlim = (-2000000, 2200000)
    ylim = (-2200000, 2700000)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.colorbar(imsh, ax=ax, shrink=0.7)
    datestr = os.path.basename(age_grd_file).split('_')[1]
    plt.text(-1900000, 2200000, f'{datestr[:4]}-{datestr[4:6]}-{datestr[6:]}', zorder=10, fontsize=14)
    plt.savefig(dst_png_file, bbox_inches='tight', pad_inches=0.1, facecolor='white')
    plt.close('all')
    plt.close()
    gc.collect()

#plot_age_grd((0, age_grd_files[0]))
with Pool(10) as p:
    r = list(tqdm.tqdm(p.imap(plot_age_grd, enumerate(age_grd_files)), total=len(age_grd_files)))
"""

In [None]:
#!ffmpeg -y -r 25 -f image2 -i /data2/Anton/lagrangian_v5/age/frame_grd_%04d.png -vcodec libx264 -crf 8 -pix_fmt yuv420p -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" /data2/Anton/lagrangian_v5/age/age_grd.mp4

In [7]:
from netCDF4 import Dataset
from geodataset.geodataset import GeoDatasetWrite
import pyproj

class SeaIceAgeDataset(GeoDatasetWrite):
    """ wrapper for netCDF4.Dataset with info about Ice Age products """
    grid_mapping_variable = 'Lambert_Azimuthal_Equal_Area'
    projection = pyproj.Proj("+proj=laea +lat_0=90 +lon_0=0 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs +type=crs")
    # "source" and "title" global attributes; must be set on the instance
    # before calling set_global_attributes
    global_attributes_source = None
    global_attributes_title = None

    def get_grid_mapping_ncattrs(self):
        """Return the attributes of the Lambert azimuthal equal-area grid mapping variable."""
        return dict(
            grid_mapping_name = "lambert_azimuthal_equal_area" ,
            longitude_of_projection_origin = 0. ,
            latitude_of_projection_origin = 90. ,
            false_easting = 0. ,
            false_northing = 0. ,
            semi_major_axis = 6378137. ,
            inverse_flattening = 298.257223563 ,
            proj4_string = "+proj=laea +lat_0=90 +lon_0=0 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs +type=crs" ,
        )

    def set_global_attributes(self, date):
        """
        Write the product's global attributes.

        Parameters:
        -----------
        date : datetime.datetime
            product date; the time coverage is [date 00:00Z, date + 1 day 00:00Z)
        """
        # NOTE(review): the GCMD keywords below describe "Sea Ice Motion" and
        # date_created is a fixed date — confirm both are intended for the age product.
        global_attributes = dict(
            title = self.global_attributes_title,
            summary = "This climate data record of sea ice age is obtained from coarse resolution ice drift and concentration OSI SAF products. The processing chain features: 1) Lagrangian advection of ice age fractions, 2) Weighted averaging of fractions.",
            topiccategory = "Oceans ClimatologyMeteorologyAtmosphere",
            keywords = "Earth Science > Cryosphere > Sea Ice > Sea Ice Motion\n, Earth Science > Oceans > Sea Ice > Sea Ice Motion\n, Earth Science > Climate Indicators > Cryospheric Indicators > Sea Ice Motion\n, Geographic Region > Northern Hemisphere\n, Vertical Location > Sea Surface\n, NERSC > Nansen Environmental and Remote Sensing Centre",
            keywords_vocabulary = "GCMD Science Keywords",
            northernmost_latitude = 90.,
            southernmost_latitude = 17.61202,
            easternmost_longitude = 180.,
            westernmost_longitude = -180.,
            geospatial_vertical_min = 0.,
            geospatial_vertical_max = 0.,
            sensor = "SSM/I,SSMIS,AMSR-E,AMSR2",
            platform = "DMSP-F<08,10,11,13,14,15>,DMSP-F<16,17,18>,Aqua,GCOM-W1",
            source = self.global_attributes_source,
            time_coverage_start = date.strftime("%Y-%m-%dT00:00:00Z"),
            time_coverage_end = (date + dt.timedelta(1)).strftime("%Y-%m-%dT00:00:00Z"),
            time_coverage_duration = "P1D",
            time_coverage_resolution = "P1D",
            project = "TARDIS - Norwegian Research Council",
            institution = "Nansen Environmental and Remote Sensing Centre",
            creator_name = "NERSC",
            creator_type = "institution",
            creator_url = "https://nersc.no",
            creator_email = "anton.korosov@nersc.no",
            license = "All intellectual property rights of the Sea Ice Age product belong to NERSC. The use of these products is granted to every user, free of charge. If users wish to use these products, NERSC\'s copyright credit must be shown by displaying the words \'Copyright NERSC\' under each of the products shown. NERSC offers no warranty and accepts no liability in respect of the Sea Ice Age products. NERSC neither commits to nor guarantees the continuity, availability, or quality or suitability for any purpose of, the Sea Ice Age product.",
            references = "Korosov, A. A., Rampal, P., Pedersen, L. T., Saldo, R., Ye, Y., Heygster, G., Lavergne, T., Aaboe, S., and Girard-Ardhuin, F.: A new tracking algorithm for sea ice age distribution estimation, The Cryosphere, 12, 2073–2085, https://doi.org/10.5194/tc-12-2073-2018, 2018.",
            date_created = "2023-06-13",
            cdm_data_type = "Grid",
            spatial_resolution = "25.0 km grid spacing",
            algorithm = "lagrangian_sea_ice_age_v2p1",
            geospatial_bounds_crs = "EPSG:6931",
            contributor_name = "Anton Korosov, Leo Edel, Laurent Bertino",
            contributor_role = "Author, Assistant, PrincipalInvestigator",
            naming_authority = "NERSC",
            Conventions = "CF-1.7 ACDD-1.3",
            standard_name_vocabulary = "CF Standard Name Table (Version 78, 21 September 2021)",
            product_name = "nersc_arctic_sea_ice_age_climate_data_record",
            product_id = "arctic25km_sea_ice_age_v2p1",
            product_version = "v2.1",
        )
        for key, value in global_attributes.items():
            self.setncattr(key, value)

    def set_variable(self, vname, data, dims, atts, dtype=np.float32):
        """
        set variable data and attributes

        Parameters:
        -----------
        vname : str
            name of new variable
        data : numpy.ndarray
            data to set in variable; written to index 0 of the first dimension
            (the single time step when dims start with 'time')
        dims : list(str)
            list of dimension names for the variable
        atts : dict
            netcdf attributes to set; any 'grid_mapping' entry is replaced by
            self.grid_mapping_variable
        dtype : type
            netcdf data type for new variable (eg np.float32 or np.double)
        """
        ncatts = {k:v for k,v in atts.items() if k != '_FillValue'}
        kw = dict(zlib=True)# use compression
        if '_FillValue' in atts:
            # needs to be a keyword for createVariable and of right data type
            kw['fill_value'] = dtype(atts['_FillValue'])
        if 'missing_value' in atts:
            # needs to be of right data type
            ncatts['missing_value'] = dtype(atts['missing_value'])
        dst_var = self.createVariable(vname, dtype, dims, **kw)
        ncatts['grid_mapping'] = self.grid_mapping_variable
        dst_var.setncatts(ncatts)
        dst_var[0] = data



In [11]:
# Input and output locations for the NetCDF export
sic_cdr_dir = '/Data/sim/data/OSISAF_ice_conc_CDR_v3p0'
age_grd_dir = f'{lag_dir}/age_grd'
dst_root_dir = f'{lag_dir}/nc'

# OSI SAF sea ice concentration CDR file used as a template for the output
# x/y coordinates, time attributes and status flag
template_file = f'{sic_cdr_dir}/1991/01/ice_conc_nh_ease2-250_cdr-v3p0_199101011200.nc'

time_atts = {}
with Dataset(template_file) as template_ds:
    # NOTE: these rebind the module-level xc/yc that were read from the mesh file earlier
    xc = template_ds['xc'][:]
    yc = template_ds['yc'][:]
    lon = template_ds['lon'][:]
    lat = template_ds['lat'][:]
    # copy every attribute of the template's time variable
    time_var = template_ds['time']
    for key in time_var.ncattrs():
        time_atts[key] = time_var.getncattr(key)
    # status flag of the first time step of the template
    status_flag = template_ds['status_flag'][0]

# Binary mask: 1 where the template status flag equals 1, 0 elsewhere.
# NOTE(review): presumably template flag value 1 means land (status_atts below
# describes flag 1 as "Position is over land") — confirm against the OSI SAF docs.
# This array is a template reused for every exported date; do not modify it in place.
status_flag = (status_flag == 1).astype(int)

# Attributes of the 'sia' (weighted average age) variable.
# NOTE(review): no 'units' attribute is set — confirm whether one is required.
age_atts = dict(
    standard_name = 'age_of_sea_ice',
    long_name = 'Weighted Average of Sea Ice Age',
    name = 'sia',
    ancillary_variables = 'status_flag',
    comment = 'The weighted average is computed over all available fractions.',
)

# Attribute template for the per-year concentration variables; the
# placeholders $Numeral$, $YEAR$ and $numeral$ are substituted for each fraction.
conc_atts = dict(
    long_name = "Concentration of $Numeral$ Year Sea Ice",
    name = "conc_$YEAR$yi",
    units = "1",
    standard_name = "$numeral$_year_sea_ice_area_fraction",
)

# Attributes of the status_flag variable (0 nominal, 1 land, 2 invalid).
# NOTE(review): CF uses flag_values for mutually exclusive values and flag_masks
# for bit fields (a mask of 0 matches no bits) — confirm whether flag_values was
# intended. SeaIceAgeDataset.set_variable overwrites the grid_mapping entry with
# its own grid_mapping_variable.
status_atts = dict(
    long_name = "status flag array for sea ice age",
    standard_name = "age_of_sea_ice status_flag",
    valid_min = np.byte(0),
    valid_max = np.byte(2),
    grid_mapping = "Lambert_Azimuthal_Grid",
    coordinates = "lat lon",
    flag_masks = (np.byte(0), np.byte(1), np.byte(2)),
    flag_meanings = "nominal land invalid",
    flag_descriptions = ("\n"
        "flag = 0: Nominal retrieval by the SIA algorithm\n"
        "flag = 1: Position is over land\n"
        "flag = 2: Pixel is invalid\n"),
)

# numerals[N] is the English ordinal for N-year ice; index 0 is unused
numerals = [None, 'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth']

In [12]:
# Re-create NetCDF files that already exist only when True
force = False

# Products dated on or after this date are labelled as the interim CDR (iCDR)
iCDR_start_date = dt.datetime(2021,1,1)

# "title" global attribute for the CDR and the interim CDR
title_CDR = "Arctic Sea Ice Age Climate Data Record Version 2.1 from NERSC"
title_iCDR = "Arctic Sea Ice Age Interim Climate Data Record Version 2.1 from NERSC"

# "source" global attribute for the CDR and the interim CDR
source_CDR = "Global Sea Ice Drift Climate Data Record Version 1 from the EUMETSAT OSI SAF, \n Sea Ice Concentration Climate Data Record Version 3 from the EUMETSAT OSI SAF"
source_iCDR = "Daily Low Resolution Sea Ice Displacement from OSI SAF EUMETSAT (OSI-405), \n Sea Ice Concentration Interim Climate Data Record Version 3 from the EUMETSAT OSI SAF"

# every gridded age file, in chronological (sorted file name) order
age_grd_files = sorted(glob.glob(f'{age_grd_dir}/*/age_????????_grd.npz'))

def export_netcdf(age_grd_file):
    """
    Convert one gridded age file (``age_YYYYMMDD_grd.npz``) into a NetCDF file
    ``<dst_root_dir>/YYYY/arctic25km_sea_ice_age_v2p0_YYYYMMDD.nc``.

    The NetCDF file contains:
    - ``sia``, the gridded age;
    - one ``conc_<N>yi`` variable per ``fraction_<N>yi`` array, divided by 100;
    - ``status_flag``, built from the module-level template ``status_flag``,
      with nominal pixels (0) that have no valid age set to invalid (2).

    Does nothing if the output file exists and ``force`` is False.
    Uses module-level ``dst_root_dir``, ``force``, ``status_flag``, ``xc``,
    ``yc``, ``time_atts``, ``age_atts``, ``conc_atts``, ``status_atts``,
    ``numerals``, ``iCDR_start_date`` and the CDR/iCDR title and source strings.
    """
    age_grd_date = dt.datetime.strptime(os.path.basename(age_grd_file).split('_')[1], '%Y%m%d')
    dst_year_dir = age_grd_date.strftime(f'{dst_root_dir}/%Y')
    os.makedirs(dst_year_dir, exist_ok=True)
    # NOTE(review): the file name says v2p0 while the global attributes
    # (product_id, product_version, title) describe v2.1 — confirm which is intended.
    dst_file = age_grd_date.strftime(f'{dst_year_dir}/arctic25km_sea_ice_age_v2p0_%Y%m%d.nc')
    if os.path.exists(dst_file) and not force:
        return

    # seconds from 1978-01-01 00:00 to 12:00 on the product date
    # NOTE(review): assumes the template time_atts declare "seconds since 1978-01-01" — confirm
    time_data = np.array([(dt.datetime(age_grd_date.year, age_grd_date.month, age_grd_date.day, 12) - dt.datetime(1978,1,1)).total_seconds()], float)

    with np.load(age_grd_file) as age_npz:
        age_grd = dict(age_npz)

    # Work on a copy: the module-level status_flag is a template shared by every
    # date exported in this process. Modifying it in place would carry one date's
    # invalid pixels into all later dates (and, via fork, into Pool workers).
    flags = status_flag.copy()
    flags[(flags == 0) & np.isnan(age_grd['age'])] = 2

    with SeaIceAgeDataset(dst_file, 'w') as ds:
        if age_grd_date < iCDR_start_date:
            ds.global_attributes_source = source_CDR
            ds.global_attributes_title = title_CDR
        else:
            ds.global_attributes_source = source_iCDR
            ds.global_attributes_title = title_iCDR

        ds.set_projection_variable()
        ds.set_xy_dims(xc*1000, yc*1000)
        #ds.set_lonlat(lon, lat)
        ds.set_global_attributes(age_grd_date)
        ds.set_time_variable(time_data, time_atts)

        ds.set_variable('sia', age_grd['age'], ('time', 'y', 'x'), age_atts, dtype=np.float32)

        for age_grd_var, frac in age_grd.items():
            if 'fraction' not in age_grd_var:
                continue
            # 'fraction_<N>yi' -> '<N>' (all digits, not only the first one)
            year = age_grd_var.split('_')[1][:-len('yi')]
            numeral = numerals[int(year)]
            frac_atts = {
                key: value.replace('$Numeral$', numeral.title()).replace('$YEAR$', year).replace('$numeral$', numeral)
                for key, value in conc_atts.items()
            }
            ds.set_variable(frac_atts['name'], frac[None]/100., ('time', 'y', 'x'), frac_atts, dtype=np.float32)

        ds.set_variable('status_flag', flags, ('time', 'y', 'x'), status_atts, dtype=np.byte)

# Export the first and the last file serially before the parallel run
# (presumably a quick check of the output before processing everything).
export_netcdf(age_grd_files[0])
export_netcdf(age_grd_files[-1])

# Export all files with 4 worker processes, 10 files per task; export_netcdf
# skips files that already exist unless force is True.
with Pool(4) as p:
    r = list(tqdm.tqdm(p.imap(export_netcdf, age_grd_files, chunksize=10), total=len(age_grd_files)))


  0%|          | 0/12162 [00:00<?, ?it/s]

100%|██████████| 12162/12162 [00:10<00:00, 1130.98it/s] 
