In [None]:
import datetime as dt
import sys
sys.path.insert(0, '/Home/antonk/py/geodataset')
import glob
from collections import defaultdict

from netCDF4 import Dataset
import numpy as np
import scipy.ndimage as nd
from skimage.util import view_as_windows
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature


import matplotlib.pyplot as plt
%matplotlib inline
from pyresample import bilinear
from pyresample import kd_tree
from pyresample.utils.cf import load_cf_area
from pyresample.geometry import AreaDefinition, GridDefinition

from scipy.interpolate import RectBivariateSpline

import cmocean

from geodataset.tools import open_netcdf

In [None]:
def nan_filter(arr_in, size, func=np.nanmedian):
    '''Apply a NaN-aware moving-window filter (median by default) to a 2D image.

    Parameters
    ----------
    arr_in : 2D array
        Input image; NaNs are ignored by the default ``np.nanmedian`` reducer.
    size : int
        Window edge length in pixels (window is ``size x size``), ``>= 1``.
    func : callable
        Reducer called as ``func(windows, axis=(2, 3))`` on the stacked windows.

    Returns
    -------
    ndarray
        Same shape as ``arr_in``. Pixels for which a full window does not fit
        (the ``size // 2`` border, one pixel less on the far side for even
        ``size``) are NaN.
    '''
    gap = size // 2
    arr_out = np.zeros_like(arr_in) + np.nan
    # windows has shape (rows - size + 1, cols - size + 1, size, size)
    windows = np.lib.stride_tricks.sliding_window_view(arr_in, (size, size))
    filtered = func(windows, axis=(2, 3))
    # Explicit end indices: ``gap:-gap`` breaks for size=1 (empty slice) and
    # is one element short of ``filtered`` for even sizes.
    n_rows, n_cols = filtered.shape
    arr_out[gap:gap + n_rows, gap:gap + n_cols] = filtered
    return arr_out

def circ_integral_r_d_lu(a, r, s):
    '''Contour-integral gradient of ``a`` over a right triangle of grid nodes.

    For every node (i, j) the triangle has corners a0 = a[i, j] (upper left),
    a1 = a[i, j+s] (upper right) and a2 = a[i+s, j+s] (lower right) and is
    traversed right -> down -> left/up. Green's theorem with the trapezoidal
    rule on this contour (side length ``L = r * s``) gives
    ``int da/dx dA = L/2 * (a1 - a0)`` and ``int da/dy dA = L/2 * (a1 - a2)``
    (y positive towards row 0); the diagonal edge contributes nothing.

    Parameters
    ----------
    a : 2D ndarray
        Field to differentiate.
    r : float
        Grid spacing of ``a`` (physical length per index step).
    s : int
        Stencil step in grid cells, ``>= 1``.

    Returns
    -------
    dadx, dady : ndarray
        Same shape as ``a``: twice the area integrals, so ``dadx / (2 * area)``
        is the mean gradient over the triangle. The last ``s`` rows/columns
        (no full stencil) are NaN.
    area : float
        Triangle area, ``(r * s) ** 2 / 2``.
    '''
    if s < 1:
        raise ValueError(f's must be a positive integer step, got {s}')
    side = r * s
    dadx, dady = np.zeros(a.shape) + np.nan, np.zeros(a.shape) + np.nan
    a0 = a[:-s, :-s]  # upper left
    a1 = a[:-s, s:]   # upper right
    a2 = a[s:, s:]    # lower right
    # Zero-weighted terms are kept on purpose: they propagate NaN so a value is
    # produced only where all three corners are valid.
    #                     right                down                 left/up
    dadx[:-s, :-s] = +side * (a1 - a0) + 0 * (a2 - a1) + 0 * (a0 - a2)
    dady[:-s, :-s] = +0 * (a1 - a0) - side * (a2 - a1) + 0 * (a0 - a2)

    return dadx, dady, side * side / 2

def circ_integral_r_d_l_u(a, r, s):
    '''Contour-integral gradient of ``a`` over a square of grid nodes.

    For every node (i, j) the square has corners a0 = a[i, j] (upper left),
    a1 = a[i, j+s] (upper right), a2 = a[i+s, j+s] (lower right) and
    a3 = a[i+s, j] (lower left), traversed right -> down -> left -> up.
    Green's theorem with the trapezoidal rule on this contour (side length
    ``L = r * s``) gives
    ``int da/dx dA = L/2 * ((a1 - a0) + (a2 - a3))`` and
    ``int da/dy dA = L/2 * ((a0 - a3) + (a1 - a2))`` (y positive towards row 0).

    Parameters
    ----------
    a : 2D ndarray
        Field to differentiate.
    r : float
        Grid spacing of ``a`` (physical length per index step).
    s : int
        Stencil step in grid cells, ``>= 1``.

    Returns
    -------
    dadx, dady : ndarray
        Same shape as ``a``: twice the area integrals, so ``dadx / (2 * area)``
        is the mean gradient over the square. The last ``s`` rows/columns
        (no full stencil) are NaN.
    area : float
        Square area, ``(r * s) ** 2``.
    '''
    if s < 1:
        raise ValueError(f's must be a positive integer step, got {s}')
    # Each edge spans s grid cells, so its physical length is r * s (not r);
    # this matches the returned area.
    side = r * s
    dadx, dady = np.zeros(a.shape) + np.nan, np.zeros(a.shape) + np.nan
    a0 = a[:-s, :-s]  # upper left
    a1 = a[:-s, s:]   # upper right
    a2 = a[s:, s:]    # lower right
    a3 = a[s:, :-s]   # lower left
    #                     right                down              left                  up
    dadx[:-s, :-s] = +side * (a1 - a0) + 0 * (a2 - a1) - side * (a3 - a2) + 0 * (a0 - a3)
    dady[:-s, :-s] = +0 * (a1 - a0) - side * (a2 - a1) + 0 * (a3 - a2) + side * (a0 - a3)
    return dadx, dady, side * side

def get_deformation(func, u, v, r, dt=24*60*60, s=2):
    '''Derive divergence, shear and a rotation term from velocity components.

    Parameters
    ----------
    func : callable
        Gradient estimator called as ``func(field, r, s)`` and returning
        ``(d/dx, d/dy, area)``, e.g. ``circ_integral_r_d_l_u``. Its gradient
        outputs are converted here as ``grad * dt / 2 / area``.
    u, v : ndarray
        Velocity components.
    r, s :
        Passed through unchanged to ``func``.
    dt : float
        Time multiplier (default: seconds per day, i.e. rates "per day" for
        velocities in units per second).

    Returns
    -------
    tuple of ndarray
        ``(du/dx + dv/dy, hypot(du/dx - dv/dy, du/dy + dv/dx), du/dy - dv/dx)``.
    '''
    du_dx, du_dy, _ = func(u, r, s)
    # The area from the v call is the one used for scaling.
    dv_dx, dv_dy, area = func(v, r, s)

    def to_rate(grad):
        # scale to 1/d; keep the operation order ((grad * dt) / 2.) / area
        return grad * dt / 2. / area

    du_dx, du_dy, dv_dx, dv_dy = map(to_rate, (du_dx, du_dy, dv_dx, dv_dy))

    divergence = du_dx + dv_dy
    shear = np.hypot(du_dx - dv_dy, du_dy + dv_dx)
    # NOTE(review): the usual vorticity is dv/dx - du/dy; this term has the
    # opposite sign under the y-towards-row-0 convention of the estimators.
    rotation = du_dy - dv_dx
    return divergence, shear, rotation
    

In [None]:
def get_averaged_sar_defor(ifile):
    '''Average SAR-based divergence and shear over all files of one date.

    The ``000000.nc`` suffix of ``ifile`` is replaced by ``*.nc`` and the
    resulting pattern is globbed (presumably several mosaics per day share the
    date prefix — confirm against the product). If ``ifile`` lacks that suffix,
    only ``ifile`` itself is matched.

    Parameters
    ----------
    ifile : str
        Path of one daily SAR drift/deformation file.

    Returns
    -------
    list of ndarray
        ``[divergence, shear]``: NaN-mean across the matched files of the first
        index of the ``divergence`` and ``vorticity`` variables (masked values
        become NaN; arrays assumed 2D — TODO confirm).

    Raises
    ------
    FileNotFoundError
        If the pattern matches no file.
    '''
    # shear is wrongly named vorticity in CMEMS/DTU products
    names = ['divergence', 'vorticity']
    pattern = ifile.replace('000000.nc', '*.nc')
    # sorted: deterministic file (and summation) order
    paths = sorted(glob.glob(pattern))
    if not paths:
        raise FileNotFoundError(f'No SAR files match {pattern}')

    idata = defaultdict(list)
    for path in paths:
        # Close each file after reading; .filled() returns an in-memory copy.
        with Dataset(path) as ds:
            for name in names:
                idata[name].append(ds[name][0].filled(np.nan))
    return [np.nanmean(np.dstack(idata[name]), axis=2) for name in names]


In [None]:
# Output directory for the per-date .npz files.
odir = '/Data/sim/antonk/sat_data_4cnn'

# Each product has a strftime/glob file pattern and a day offset: for a target
# date D the file is searched for at date D + timedelta days.
ice_conc_dir = '/Data/sim/data/OSISAF_ice_conc_amsr/'
ice_conc_format = f'{ice_conc_dir}/ice_conc_nh_polstere-100_amsr2_%Y%m%d*.nc'
ice_conc_timedelta = 0

ice_thic_dir = '/Data/sim/data/CS2_SMOS_v2.3/'
ice_thick_format = f'{ice_thic_dir}/W_XX-ESA,SMOS_CS2,NH_25KM_EASE2_%Y%m%d*_l4sit.nc'
# NOTE(review): rationale for the -3 day shift is not documented here;
# presumably it aligns the file date with the product's averaging window.
ice_thick_timedelta = -3

ice_drift_dir = '/Data/sim/data/OSISAF_ice_drift/'
ice_dirft_format = f'{ice_drift_dir}/*/*/ice_drift_nh_polstere-625_multi-oi_%Y%m%d*.nc'
ice_drift_timedelta = -1

sar_drift_dir = '/Data/sim/data/SEAICE_GLO_SEAICE_L4_NRT_OBSERVATIONS_011_006/cmems_sat-si_glo_drift_nrt_north_d'
sar_dirft_format = f'{sar_drift_dir}/ice_drift_mosaic_polstereo_sarbased_north_*-%Y%m%d000000.nc'
sar_drift_timedelta = 0

# Order matters: later cells unpack each entry of ``ifiles`` as
# (ice conc, ice thickness, OSISAF drift, SAR drift).
formats = [ice_conc_format, ice_thick_format, ice_dirft_format, sar_dirft_format]
timedeltas = [ice_conc_timedelta, ice_thick_timedelta, ice_drift_timedelta, sar_drift_timedelta]

date_start = dt.datetime(2021,1,1)
date_stop = dt.datetime(2021,1,31)
date_step = 1  # days

# Keep only the dates for which every product has at least one file.
ifiles = []
idates = []
date0 = date_start
while date0 <= date_stop:
    date_files = []
    for fmt, delta in zip(formats, timedeltas):
        file_mask = (date0 + dt.timedelta(days=delta)).strftime(fmt)
        # glob order is arbitrary; sort so the chosen file is deterministic
        files = sorted(glob.glob(file_mask))
        if not files:
            break
        date_files.append(files[0])
    else:
        # no product was missing for this date
        ifiles.append(date_files)
        idates.append(date0)
    date0 += dt.timedelta(days=date_step)

In [None]:
seconds_in_day = 24 * 60 * 60
# NOTE(review): assumes the OSISAF drift dX/dY are km displacements over two
# days, so this converts them to m/s — confirm against the product manual.
factor = 1000 / 2 / seconds_in_day

# Grid spacing used as ``r`` for the deformation stencil on the OSISAF drift
# grid (the drift file pattern says polstere-625, i.e. presumably 62.5 km).
sid_res = 62500

# Grid of the ice concentration product (first date's file), with lat_ts=70
# appended to the CF-derived proj4 string.
n0_area_raw = load_cf_area(ifiles[0][0])[0]
n0_area = AreaDefinition('area_id', 'descr', 'proj_id', n0_area_raw.proj4_string + ' +lat_ts=70', n0_area_raw.width, n0_area_raw.height, n0_area_raw.area_extent)

# Common destination grid (same projection as n0_area), in projection metres.
dst_xmin = -3000000
dst_ymin = -3000000
dst_xmax = 3000000
dst_ymax = 3000000
dst_res = 6000
# The extent must be a whole number of cells, otherwise the actual pixel size
# would silently differ from dst_res.
assert (dst_xmax - dst_xmin) % dst_res == 0 and (dst_ymax - dst_ymin) % dst_res == 0
# Integer pixel counts (true division would give floats).
dst_width = (dst_xmax - dst_xmin) // dst_res
dst_height = (dst_ymax - dst_ymin) // dst_res
dst_extent = [dst_xmin, dst_ymin, dst_xmax, dst_ymax]
dst_area = AreaDefinition('area_id', 'descr', 'proj_id', n0_area_raw.proj4_string + ' +lat_ts=70', dst_width, dst_height, dst_extent)
print(dst_area)


In [None]:
def _to_dst(src_area, field, radius, sigma):
    '''Gaussian-weighted resampling of ``field`` from ``src_area`` onto ``dst_area`` (NaN fill).'''
    return kd_tree.resample_gauss(src_area, field, dst_area, radius, sigma, fill_value=np.nan)


# One entry per date in ``ifiles``: six fields on ``dst_area`` in the order
# sic, sit, divergence, shear, SAR divergence, SAR shear.
data_pros = []
for conc_file, thick_file, drift_file, sar_file in ifiles:
    print(conc_file)
    conc_nc = open_netcdf(conc_file)
    thick_nc = open_netcdf(thick_file)
    drift_nc = open_netcdf(drift_file)

    # Shape and extent come from the SAR file's CF metadata; the projection
    # string is fixed here rather than read from the file.
    sar_cf_area = load_cf_area(sar_file)[0]
    n3_area = AreaDefinition(
        'area_id', 'descr', 'proj_id',
        "+proj=stere +R=6370997 +lat_0=90 +lat_ts=70 +lon_0=0",
        sar_cf_area.width, sar_cf_area.height, sar_cf_area.area_extent,
    )

    sic = conc_nc.get_var('ice_conc')[:].filled(np.nan)
    sit = thick_nc.get_var('analysis_sea_ice_thickness')[:].filled(np.nan)

    # Drift components scaled by ``factor``, then smoothed with a 3x3 NaN-median.
    u_raw, v_raw = (drift_nc.get_var(var_name)[:].filled(np.nan) * factor for var_name in ('dX', 'dY'))
    u, v = nan_filter(u_raw, 3), nan_filter(v_raw, 3)
    e1, e2, e3 = get_deformation(circ_integral_r_d_l_u, u, v, sid_res, s=2)

    e1_sar, e2_sar = get_averaged_sar_defor(sar_file)

    # SAR fields are multiplied by seconds_in_day (presumably 1/s -> 1/day).
    data_pros.append([
        _to_dst(n0_area, sic, 20000, 10000),
        _to_dst(thick_nc.area, sit, 60000, 30000),
        _to_dst(drift_nc.area, e1, 150000, 50000),
        _to_dst(drift_nc.area, e2, 150000, 50000),
        _to_dst(n3_area, e1_sar, 20000, 10000) * seconds_in_day,
        _to_dst(n3_area, e2_sar, 20000, 10000) * seconds_in_day,
    ])
        

In [None]:
# Per-panel titles, colour limits and colormaps, in the order of each
# ``data_pro`` entry: sic, sit, divergence, shear, SAR divergence, SAR shear.
titles = ['sic', 'sit', 'divergence', 'shear', 'divergence_sar', 'shear_sar']
clims = [
    [0, 100],
    [0, 3],
    [-0.05, 0.05],
    [0, 0.05],
    [-0.15, 0.15],
    [0, 0.15],
]
cmaps = [
    'jet',
    'jet',
    cmocean.cm.balance,
    'jet',
    cmocean.cm.balance,
    'jet',
]

dst_crs = ccrs.NorthPolarStereo(central_longitude=-45, true_scale_latitude=70)
# imshow's extent is (left, right, bottom, top). With origin='upper' array
# row 0 is drawn at `top`, and pyresample grids store the ymax row first, so
# `top` must be dst_ymax (putting dst_ymin there draws the maps upside down).
img_extent = [dst_xmin, dst_xmax, dst_ymin, dst_ymax]
fig_xlim = [dst_xmin, dst_xmax]
fig_ylim = [dst_ymin, dst_ymax]

# Quick-look figures for the first two dates only.
for idate, data_pro in zip(idates[:2], data_pros[:2]):
    print(idate)
    # squeeze=False: always a 2D axes array, even for a single panel
    fig, ax = plt.subplots(1, len(data_pro), figsize=(30, 10), squeeze=False,
                           subplot_kw=dict(projection=dst_crs))
    for panel, array, clim, cmap, title in zip(ax[0], data_pro, clims, cmaps, titles):
        panel.imshow(array, clim=clim, cmap=cmap, extent=img_extent,
                     origin='upper', interpolation='nearest')
        panel.add_feature(cfeature.LAND, zorder=10, edgecolor='black')
        panel.set_xlim(fig_xlim)
        panel.set_ylim(fig_ylim)
        panel.set_title(f'{title} {idate:%Y-%m-%d}')
    plt.show()

In [None]:
# One .npz per date with the six resampled fields.
for idate, data_pro in zip(idates, data_pros):
    # %Y%m%d (year-month-day), matching the input file patterns and sorting
    # chronologically; "%Y%d%m" swapped day and month.
    ofilename = f'{odir}/sic_sit_def_{idate.strftime("%Y%m%d")}.npz'
    print(ofilename)
    np.savez(
        ofilename,
        sic=data_pro[0],
        sit=data_pro[1],
        divergence=data_pro[2],
        shear=data_pro[3],
        divergence_sar=data_pro[4],
        shear_sar=data_pro[5],
    )
