# Load packages

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
# System libraries
import os
import sys
import h5py
import configparser
import json

#Standard libraries
import numpy as np
import pandas as pd
import geopandas as gpd
import xarray as xr
import rasterio as rs
import rioxarray as rioxr

#For geometries
import shapely
from shapely import box, LineString, MultiLineString, Point, Polygon, LinearRing
from shapely.geometry.polygon import orient

#For REMA
from rasterio import plot
from rasterio.mask import mask
from rasterio.features import rasterize

#Datetime
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from dateutil.relativedelta import relativedelta
import time

#For plotting, ticking, and line collection
from matplotlib import cm 
import matplotlib
import matplotlib.ticker as ticker
import matplotlib.pylab as plt
import matplotlib.gridspec as gridspec
from matplotlib.collections import LineCollection
import matplotlib.colors as mcolors
from matplotlib.colors import LightSource, LinearSegmentedColormap
from matplotlib.lines import Line2D
import cmcrameri.cm as cmc
import contextily as cx
import earthpy.spatial as es
# for legend
from matplotlib.patches import Rectangle
from matplotlib.legend_handler import HandlerTuple

#Personal and application specific utilities
from utils.nsidc import download_is2
#from utils.S2 import plotS2cloudfree, add_inset, convert_time_to_string
from utils.utilities import is2dt2str
import pyTMD

#For error handling
import shutil
import traceback

#For raster
from rasterio.transform import from_origin

# not in use 
from ipyleaflet import Map, basemaps, Polyline, GeoData, LayersControl
from rasterio import warp
from rasterio.crs import CRS

# Load the configuration: region, shapefile path, data paths, and credentials

In [None]:
ini = 'config/C-Cp_all_lumos.ini'

######## Load variables ###########
# Create a ConfigParser object and read the configuration file.
# ConfigParser.read() silently skips files it cannot open (it returns the list
# of files actually parsed), so fail early with a clear message instead of a
# confusing NoSectionError on the first config.get() below.
config = configparser.ConfigParser()
if not config.read(ini):
    raise FileNotFoundError(f'could not read configuration file: {ini}')

# os and pyproj paths
gdal_data = config.get('os', 'gdal_data')
proj_lib = config.get('os', 'proj_lib')
proj_data = config.get('os', 'proj_data')

# path params
basin = config.get('data', 'basin')
region = config.get('data', 'region')
shape = f'shapes/{basin}_{region}.shp'
output_dir = config.get('data', 'output_dir')
rema_path = config.get('data', 'rema_path')
# plot_dir is optional: default to ./plots when the option is absent.
plot_dir = config.get('data', 'plot_dir', fallback='plots')

# access params (NASA Earthdata credentials)
uid = config.get('access', 'uid')
pwd = config.get('access', 'pwd')
email = config.get('access', 'email')

# Export the GDAL/PROJ data paths so gdal and pyproj can find their files.
# NOTE(review): pyproj/GDAL are already imported by the first cell; confirm
# these variables take effect when set after import.
os.environ["GDAL_DATA"] = gdal_data
os.environ["PROJ_LIB"] = proj_lib
os.environ["PROJ_DATA"] = proj_data

# Print results (the password is masked)
print(f"basin: {basin}")
print(f"region: {region}")
print(f"output_dir: {output_dir}")

print(f"uid: {uid}")
print(f"pwd: {'*'*len(pwd)}")
print(f"email: {email}")

# Make Shapes

In [None]:
#Make shapes
# Build the region geometries used for clipping and plotting:
#   gdf_fl      - rows labelled 'Ice shelf'
#   gdf_pp      - rows labelled 'Ice rise or connected island' (pinning points)
#   gdf_ext     - ice-shelf exteriors with holes filled, unioned with the
#                 pinning points, exploded into single parts
#   gdf_gr      - rows with Id == 1
#   gdf_ext_all - gdf_ext unioned with gdf_gr (single geometry)
#   gdf         - union of every shape in the file (single geometry)
####
crs_antarctica = 'EPSG:3031'
crs_latlon = 'EPSG:4326'
short_name = 'ATL11'
# Read shapefile into gdf for everything.
# set_crs(..., allow_override=True) forces EPSG:4326 regardless of the CRS the
# file declares — assumes the shapefile coordinates really are lon/lat.
gdf = gpd.read_file(shape).set_crs(crs_latlon, allow_override=True).to_crs(crs_antarctica)

# Separate by entry type
gdf_fl = gdf[gdf.Id_text=='Ice shelf']
gdf_pp = gdf[(gdf.Id_text=='Ice rise or connected island')]
# Rebuilding each ice-shelf polygon from its exterior ring drops interior
# holes. NOTE(review): .exterior exists only on Polygons, so this assumes no
# MultiPolygon rows; GeoSeries.unary_union is also deprecated in newer
# geopandas in favour of union_all() — confirm against the installed version.
gdf_ext = gpd.GeoDataFrame(geometry=[gdf_fl.apply(lambda p: Polygon(p.geometry.exterior.coords), axis=1).unary_union.union(gdf_pp.unary_union)],
    crs=crs_antarctica).explode(ignore_index=True)
gdf_gr = gdf[gdf.Id==1]
gdf_ext_all = gpd.GeoDataFrame(geometry=[gdf_ext.unary_union.union(gdf_gr.unary_union)], crs=crs_antarctica)

# gdf for subsetting: a single geometry covering every shape in the file
#gdf = gdf[~(gdf.Id==1)]
gdf = gpd.GeoDataFrame(geometry=[gdf.unary_union], crs=crs_antarctica)

# plot the geometries as a quick visual check
fig1, ax1 = plt.subplots(figsize=[10, 8])
gdf.convex_hull.plot(ax=ax1, color='None', edgecolor='black')
gdf_ext.plot(ax=ax1, color='yellow', edgecolor='black')
gdf_fl.plot(ax=ax1, color='red', edgecolor='black')
gdf_pp.plot(ax=ax1, color='lightblue', edgecolor='black')

# Define helper functions (download, file indexing, reading, clipping, statistics, plotting utilities)

In [10]:
def download_data():
    """Download ATL11 granules covering ``shape`` into ``output_dir``.

    Credentials come from getedcreds(); the request itself is made by
    download_is2.

    Returns:
        The output directory the files were saved to.

    Raises:
        RuntimeError: if getedcreds() reports that credentials are not set.
    """
    creds = getedcreds()
    if creds is None:
        # getedcreds() has already printed setup instructions; stop here rather
        # than failing with an opaque "cannot unpack NoneType" TypeError.
        raise RuntimeError('NASA Earthdata credentials are not configured')
    uid, pw, eml = creds
    download_is2(short_name='ATL11', uid=uid, pwd=pw, email=eml, output_dir=output_dir, shape=shape, shape_subset=shape)
    print('saved files to %s' % output_dir)
    return output_dir

def get_file_info():
    """Index the ATL11 granules stored in ``output_dir``.

    Picks the regular files in ``output_dir`` whose names contain
    ``'<short_name>_'`` and ``'.h5'``, sorts them, and parses fields from the
    granule name by fixed character offsets.

    Returns:
        pandas.DataFrame, one row per file, with columns filename, granule_id,
        tides_filename, track, region, direction, cycles, version, release.
    """
    prefix = '%s_' % short_name
    folder = output_dir + '/'
    matches = sorted(
        folder + name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name)) and prefix in name and '.h5' in name
    )
    print('There are %i files.' % len(matches))

    # region number -> orbit direction
    direction_of = {}
    for regions, label in (([1, 2, 3, 12, 13, 14], 'ascending'),
                           ([5, 6, 7, 8, 9, 10], 'descending'),
                           ([4, 11], 'turning')):
        direction_of.update(dict.fromkeys(regions, label))

    df_files = pd.DataFrame({'filename': matches})
    granules = [fn[fn.rfind(prefix):] for fn in df_files.filename]
    df_files['granule_id'] = granules
    df_files['tides_filename'] = [f'{output_dir}/tides/ATL11_CATS2008-v2023_TIDES_{g[6:]}' for g in granules]
    df_files['track'] = [int(g[6:10]) for g in granules]
    df_files['region'] = [int(g[10:12]) for g in granules]
    df_files['direction'] = [direction_of[r] for r in df_files['region']]
    df_files['cycles'] = ['%s-%s' % (g[13:15], g[15:17]) for g in granules]
    df_files['version'] = [int(g[18:21]) for g in granules]
    df_files['release'] = [int(g[22:24]) for g in granules]
    return df_files

def getedcreds():
    """Return the NASA Earthdata credentials loaded from the config file.

    Reads the module-level ``uid``, ``pwd`` and ``email`` (set from the
    ``[access]`` section of the .ini file).

    Returns:
        (uid, pwd, email), or None when uid is still the placeholder value;
        in that case setup instructions are printed.
    """
    # Copy the globals under different local names: writing ``uid = uid``
    # here would make ``uid`` local and raise UnboundLocalError.
    user_id, password, user_email = uid, pwd, email

    # to print a message if they haven't been changed
    if user_id == '<your_nasa_earthdata_user_id>':
        print('\n WARNING: YOU NEED TO SET UP YOUR NASA EARTHDATA CREDENTIALS TO DOWNLOAD ICESAT-2 DATA!\n')
        print('  update the [access] section of your config .ini file:\n')
        print("  [access]")
        print("  uid = <your_nasa_earthdata_user_id>")
        print("  pwd = <your_nasa_earthdata_password>")
        print("  email = <your_nasa_earthdata_account_email>")
        return None
    return user_id, password, user_email

def is2dt2str(lake_mean_delta_time):
    """Convert ICESat-2 delta_time to an ISO-8601 UTC string.

    The input (scalar or array-like) is averaged first; the mean is treated as
    seconds since 2018-01-01T00:00:00Z.

    Returns:
        'YYYY-MM-DDTHH:MM:SSZ', or np.nan when the mean is NaN or +/-inf.
    """
    lake_mean_delta_time = np.mean(lake_mean_delta_time)
    # np.isfinite rejects NaN and both infinities; the old check missed -inf,
    # which made datetime.fromtimestamp raise OverflowError.
    if not np.isfinite(lake_mean_delta_time):
        return np.nan
    ATLAS_SDP_epoch_datetime = datetime(2018, 1, 1, tzinfo=timezone.utc)
    ATLAS_SDP_epoch_timestamp = datetime.timestamp(ATLAS_SDP_epoch_datetime)
    lake_mean_timestamp = ATLAS_SDP_epoch_timestamp + lake_mean_delta_time
    lake_mean_datetime = datetime.fromtimestamp(lake_mean_timestamp, tz=timezone.utc)
    time_format_out = '%Y-%m-%dT%H:%M:%SZ'
    return datetime.strftime(lake_mean_datetime, time_format_out)

def set_axis_color(ax, axcolor):
    """Paint every spine, the x/y ticks, both axis labels and the title of *ax* in *axcolor*."""
    for side in ('bottom', 'top', 'right', 'left'):
        ax.spines[side].set_color(axcolor)
    for which in ('x', 'y'):
        ax.tick_params(axis=which, colors=axcolor)
    for text in (ax.yaxis.label, ax.xaxis.label, ax.title):
        text.set_color(axcolor)

def get_ground_tracks(datadict):
    """Build one ground-track LineString per beam pair.

    Parameters:
        datadict: mapping pt name -> xarray Dataset carrying ``longitude`` and
            ``latitude`` along x and a ``track`` variable (as built by
            read_atl11 / get_data).

    Returns:
        GeoDataFrame in EPSG:4326 with columns pt, geometry, plotcolor (picked
        from the digit in position 2 of the pt name) and track (taken from the
        last dataset iterated).
    """
    latlon = 'EPSG:4326'
    palette = {'col0': 'darkblue', 'col1': 'rebeccapurple', 'col2': 'palevioletred', 'col3': 'thistle'}

    point_frames = []
    for pair_name, pair_ds in datadict.items():
        pts_gdf = gpd.GeoDataFrame(geometry=gpd.points_from_xy(pair_ds.longitude, pair_ds.latitude), crs=latlon)
        #for 3d geometry
        #pts_gdf = gpd.GeoDataFrame(geometry=gpd.points_from_xy(pair_ds.longitude, pair_ds.latitude, pair_ds.h_ano.sel(track=pair_ds.track.data[0], pt=pair_name).mean(dim='cycle_number')), crs=latlon)
        pts_gdf['pt'] = pair_name
        point_frames.append(pts_gdf)

    # Join each pair's points, in order, into a single line.
    lines = pd.concat(point_frames).groupby(['pt'])[['geometry']].apply(
        lambda grp: LineString(grp.geometry.tolist()))
    gdf_gts = gpd.GeoDataFrame(geometry=lines).reset_index().set_crs(latlon)
    gdf_gts['plotcolor'] = gdf_gts.apply(lambda row: palette['col%s' % (int(row.pt[2]) - 1)], axis=1)
    gdf_gts['track'] = pair_ds.track.data[0]
    return gdf_gts

def read_atl11(filename, track, verbose=False):
    """Read every beam pair of one ATL11 HDF5 granule into xarray Datasets.

    For each top-level group whose name contains 'pt', builds a Dataset with
    dims (x, cycle_number, track, pt) holding delta_time, h_corr,
    h_corr_sigma, h_corr_sigma_systematic and quality_summary, plus ``geoid``
    (ref_surf/geoid_h, dims (x, track, pt)). It also has coordinates latitude,
    longitude, ref_pt and x_atc (60 * sample index).

    h_corr is set to NaN where quality_summary > 0, or where the height above
    the geoid (h_corr - geoid) is > 2.5e3 or < -50.

    Parameters:
        filename: path to the ATL11 .h5 file.
        track: track number stored in the ``track`` dimension.
        verbose: print progress and the reason any beam pair is skipped.

    Returns:
        dict mapping pt name -> xarray.Dataset. Beam pairs that fail to load
        are omitted.
    """
    vars_data = ['delta_time', 'h_corr', 'h_corr_sigma', 'h_corr_sigma_systematic', 'quality_summary']
    vars_coords = ['cycle_number', 'latitude', 'longitude', 'ref_pt']
    if verbose: print(f'reading track: {track}')
    datadict = {}
    with h5py.File(filename, 'r') as f:
        pts = [x for x in f.keys() if 'pt' in x]
        for pt in pts:
            try:
                ds = xr.Dataset(
                    {**{v: (['x', 'cycle_number', 'track', 'pt'], f[pt][v][()][:, :, np.newaxis, np.newaxis]) for v in vars_data},
                     'geoid': (['x', 'track', 'pt'], f[pt]['ref_surf/geoid_h'][()][:, np.newaxis, np.newaxis])},
                    coords={'cycle_number': f[pt]['cycle_number'][()],
                            **{v: ('x', f[pt][v][()]) for v in vars_coords[1:]}})
                ds.coords['x'], ds['track'], ds['pt'] = np.arange(len(ds.x)), [track], [pt]
                ds = ds.assign_coords(x_atc=('x', np.arange(len(f[pt]['latitude'][()])) * 60))
                # go to numpy for 2-d boolean indexing; filter on height above the geoid
                h_arr = np.array(ds.h_corr - ds.geoid)
                h_arr[ds.quality_summary > 0] = np.nan
                h_arr[(h_arr > 2.5e3) | (h_arr < -50)] = np.nan
                ds['h_corr'] = (ds.h_corr.dims, h_arr)
                # add the geoid back so h_corr keeps its original reference
                ds['h_corr'] = ds.h_corr + ds.geoid
            except Exception as e:
                # Best effort: skip this beam pair, but say why when asked.
                if verbose:
                    print(f'skipping {pt} of track {track}: {type(e).__name__}: {e}')
                continue
            datadict[pt] = ds
    return datadict

def read_atl11_tides(filename, track):
    """Read per-beam-pair ocean tides from an ATL11 tide HDF5 file.

    For each top-level group whose name contains 'pt', builds a Dataset with
    delta_time and tide_cats (from cycle_stats/tide_ocean), dims
    (x, cycle_number, track, pt), and coordinates latitude, longitude, ref_pt
    and x_atc (60 * sample index). Values of tide_cats > 1e5 are treated as
    missing and set to NaN. They are then linearly interpolated along x, and
    any NaNs that remain are set to 0.0.

    Returns:
        dict mapping pt name -> xarray.Dataset.
    """
    data_vars = ['delta_time', 'cycle_stats/tide_ocean']
    coord_vars = ['cycle_number', 'latitude', 'longitude', 'ref_pt']
    dims = ['x', 'cycle_number', 'track', 'pt']
    tidedict = {}
    with h5py.File(filename, 'r') as h5:
        pairs = [key for key in h5.keys() if 'pt' in key]
        for pair in pairs:
            grp = h5[pair]
            ds = xr.Dataset(
                {name: (dims, grp[name][()][:, :, np.newaxis, np.newaxis]) for name in data_vars},
                coords={'cycle_number': grp['cycle_number'][()],
                        **{name: ('x', grp[name][()]) for name in coord_vars[1:]}})
            ds.coords['x'], ds['track'], ds['pt'] = np.arange(len(ds.x)), [track], [pair]
            ds = ds.assign_coords(x_atc=('x', np.arange(len(grp['latitude'][()])) * 60))
            ds = ds.rename({'cycle_stats/tide_ocean': 'tide_cats'})
            # numpy copy for 2-d boolean indexing of the fill values
            tide = np.array(ds.tide_cats)
            tide[tide > 1e5] = np.nan
            ds['tide_cats'] = (ds.tide_cats.dims, tide)
            ds['tide_cats'] = ds.tide_cats.interpolate_na(dim='x', method='linear').fillna(0.0)
            tidedict[pair] = ds
    return tidedict

def get_data(track, verbose=False):
    """Load one track's ATL11 heights and tides and derive relative heights.

    Looks up the first row of the global ``df_files`` for ``track`` and reads
    its granule and tide file. For each beam pair, the two are merged with an
    inner join, and two variables are added:
        h_abs = h_corr - geoid - tide_cats
        h_ano = h_abs minus its median over cycle_number

    Returns:
        (datadict, gdf_gts): dict pt -> xarray.Dataset, and the ground tracks
        reprojected to ``crs_antarctica``.
    """
    file_row = df_files[df_files.track == track].iloc[0]
    datadict = read_atl11(file_row['filename'], track, verbose)
    tidedict = read_atl11_tides(file_row['tides_filename'], track)

    dims = ('x', 'cycle_number', 'track', 'pt')
    combined = {}
    for pt, ds_pt in datadict.items():
        merged = xr.merge([ds_pt, tidedict[pt]], join='inner', compat='override')
        h_abs = merged['h_corr'] - merged['geoid'] - merged['tide_cats']
        merged = merged.assign(h_abs=(dims, h_abs.data))
        h_ano = merged.h_abs - merged.h_abs.median(dim='cycle_number')
        merged = merged.assign(h_ano=(dims, h_ano.data))
        combined[pt] = merged

    gdf_gts = get_ground_tracks(combined).to_crs(crs_antarctica)
    return combined, gdf_gts

def clip_data(datadict, gdf_gts, mask):
    """Clip each beam pair's data and ground track to a polygon mask.

    Parameters:
        datadict: dict pt -> xarray.Dataset (as returned by get_data).
        gdf_gts: ground tracks with one LineString per pt (as returned by
            get_data); its CRS should match ``mask``.
        mask: GeoDataFrame/GeoSeries of polygons to clip to.

    Returns:
        (datadict_clipped, gdf): datadict_clipped maps pt -> the dataset
        restricted to the x labels whose ground-track vertex lies inside
        ``mask``; gdf concatenates the clipped ground-track lines. A pair whose
        ground track cannot be clipped is skipped, and a message is printed.
    """
    datadict_clipped = {}
    gdf_gts_clipped_list = []
    for pt, ds in datadict.items():
        gt_this = gdf_gts[gdf_gts.pt == pt]
        # One point per LineString vertex, indexed 0..n-1. This assumes the
        # vertex order matches the dataset's x labels (np.arange in read_atl11).
        xy = gt_this.get_coordinates(ignore_index=True)
        gt_this_pts = gpd.GeoDataFrame(geometry=gpd.points_from_xy(xy.x, xy.y), crs=gdf_gts.crs)
        # Clipping the mask to this track's bounding box first is faster than
        # testing the points against the full mask.
        gdf_clipped_index = gt_this_pts.clip(mask.clip(gt_this.bounds.values[0])).index
        try:
            gdf_gts_clipped_list.append(gt_this.clip(mask))
        except Exception as e:
            print(f'clip_data: skipping {pt}: {type(e).__name__}: {e}')
            continue
        datadict_clipped[pt] = ds.sel(x=gdf_clipped_index)
    return datadict_clipped, pd.concat(gdf_gts_clipped_list, ignore_index=True)

def get_ds_dict(tracklist, mask=None):
    """Load (and optionally clip) each track and stack its beam pairs along 'pt'.

    Parameters:
        tracklist: iterable of track numbers. Items may also be tuples/lists
            whose first element is the track number; the other elements are
            ignored.
        mask: optional polygon GeoDataFrame to clip data and ground tracks to.

    Returns:
        (ds_dict, gdf): ds_dict maps track -> Dataset concatenated along 'pt'
        with x renumbered 0..n-1; gdf concatenates the (clipped) ground tracks.
        Tracks left with no beam pairs are skipped.
    """
    ds_list, gdf_gts_clipped_list = [], []
    for p in tracklist:
        # Accept bare track numbers (as this notebook passes them) as well as
        # (track, ...) sequences; p[0] on an int would raise TypeError.
        t = p[0] if isinstance(p, (tuple, list)) else p
        # Load this track's data and ground tracks before (optionally) clipping.
        datadict, gdf_gts_clipped = get_data(t)
        if mask is not None:
            datadict, gdf_gts_clipped = clip_data(datadict, gdf_gts_clipped, mask)
        if not datadict:
            # xr.concat raises on an empty list
            continue
        ds_add = xr.concat([datadict[pt] for pt in datadict], dim='pt')
        ds_add['x'] = np.arange(len(ds_add.x))
        ds_list.append(ds_add.sortby('x'))
        gdf_gts_clipped_list.append(gdf_gts_clipped)
    print('generating dict and gdf')
    return {ds.track.values[0]: ds for ds in ds_list}, pd.concat(gdf_gts_clipped_list, ignore_index=True)

def combine_ds_dict(ds_dict):
    """Concatenate the per-track datasets of ``ds_dict`` along 'track'.

    Datasets with no x samples are dropped. Each remaining dataset's x
    coordinate is renumbered 0..n-1 on a copy, so the datasets in ``ds_dict``
    are left unmodified.
    """
    ds_list = [ds.assign_coords(x=np.arange(len(ds.x)))
               for ds in ds_dict.values() if len(ds.x) != 0]
    return xr.concat(ds_list, dim='track')

def _beam_stats(ds):
    """Reduce one beam-pair Dataset along x to per-cycle summary statistics.

    Returns a Dataset with dims (cycle_number, track, pt).
    """
    h_abs = ds.h_corr - ds.geoid - ds.tide_cats
    n_valid = h_abs.count(dim='x')
    coords_dict = {'cycle_number': ds.cycle_number.data,
                   'track': ds.track.data,
                   'pt': ds.pt.data}
    stat_dict = {'h_min': h_abs.min(dim='x'),
                 'h_max': h_abs.max(dim='x'),
                 'h_sum': h_abs.sum(dim='x'),
                 'h_mean': h_abs.mean(dim='x'),
                 'h_med': h_abs.median(dim='x'),
                 'h_ano': ds.h_ano.median(dim='x'),
                 'h_std': h_abs.std(dim='x', skipna=True, ddof=1),
                 'h_var': h_abs.var(dim='x', skipna=True, ddof=1),
                 't_count': n_valid,
                 # NOTE: despite the key name this is the fraction of non-NaN values
                 'pct_nan': n_valid / h_abs.sizes['x'],
                 # x_atc.max()/1000 broadcast to the per-cycle shape
                 't_dist': n_valid * 0 + h_abs.x_atc.max() / 1000,
                 'tide_min': ds.tide_cats.min(dim='x'),
                 'tide_max': ds.tide_cats.max(dim='x'),
                 'tide_mean': ds.tide_cats.mean(dim='x'),
                 'tide_sum': ds.tide_cats.sum(dim='x')}
    return xr.Dataset({v: (['cycle_number', 'track', 'pt'], stat_dict[v].data) for v in stat_dict},
                      coords=coords_dict)


def get_stats(tracklist, mask=None):
    """Compute per-cycle statistics for every beam pair of every track.

    Parameters:
        tracklist: iterable of track numbers (looked up via get_data).
        mask: optional polygon GeoDataFrame to clip data and ground tracks to.

    Returns:
        (stats, gdf): stats is xr.combine_by_coords of the per-pair statistics
        Datasets; gdf concatenates the (clipped) ground tracks. Pairs whose
        statistics cannot be computed are skipped, and one summary line is
        printed.
    """
    ds_list = []
    gdf_gts_list = []
    failures = []
    for t in tracklist:
        datadict, gdf_gts_clipped = get_data(t)
        # subset here
        if mask is not None:
            datadict, gdf_gts_clipped = clip_data(datadict, gdf_gts_clipped, mask)
        gdf_gts_list.append(gdf_gts_clipped)
        for pt, ds in datadict.items():
            try:
                ds_list.append(_beam_stats(ds))
            except Exception as e:
                # Best effort: skip this pair but record why.
                failures.append((t, pt, e))
    if failures:
        t_bad, pt_bad, err = failures[0]
        print(f'get_stats: skipped {len(failures)} beam pair(s); '
              f'first: track {t_bad}, {pt_bad}: {type(err).__name__}: {err}')
    return xr.combine_by_coords(data_objects=ds_list), pd.concat(gdf_gts_list, ignore_index=True)
        
def reproject_raster(src, target_crs):
    '''
    Reproject band 1 of an open rasterio dataset into ``target_crs``.

    Uses nearest-neighbour resampling via rasterio.warp.reproject, with the
    source CRS and transform taken from ``src``.

    Returns:
        (reprojected_data, dst_transform, target_crs)

    NOTE(review): no ``destination`` is passed, so this relies on
    rasterio.warp.reproject allocating one for a Band source; confirm with the
    installed rasterio version.
    '''
    reprojected_data, dst_transform = rs.warp.reproject(
        source=rs.band(src, 1),
        src_transform=src.transform,
        src_crs=src.crs,
        dst_crs=target_crs,
        resampling=rs.enums.Resampling.nearest,
    )
    return reprojected_data, dst_transform, target_crs

def retry(num_attempts=1, sleep_time=5, func=None):
    """Call ``func`` until it succeeds or ``num_attempts`` calls have failed.

    Parameters:
        num_attempts: maximum number of calls.
        sleep_time: seconds to wait between a failed attempt and the next one.
        func: zero-argument callable to run; defaults to download_data.

    Returns:
        True if an attempt succeeded, otherwise False. After the last failure,
        the error and its traceback are printed.
    """
    if func is None:
        func = download_data
    for i in range(num_attempts):
        try:
            func()
        except Exception as e:
            if i == num_attempts - 1:
                print(f"The following error occurred: {e}")
                traceback.print_exc()
                return False
            # Only wait when another attempt will follow.
            time.sleep(sleep_time)
        else:
            print(f'success on attempt {i + 1}')
            return True
    return False


# Process the data

In [None]:
%%time
# Index the downloaded ATL11 granules in output_dir.
# Uncomment retry(20, 20) (up to 20 attempts, 20 s apart) or download_data()
# to fetch the data first; comment it out again once the files are on disk.
#retry(20, 20)
#download_data()

df_files = get_file_info()
# trailing ';' suppresses the notebook's output display
df_files;

In [106]:
# Scratch cell for inspecting a single track. Only the last `track=`
# assignment (1376) takes effect; the others are a list of tracks of interest.
#track=26
track=727
#track=1077
track=392 #problem child
#track=376 # erases all the data, but maybe it should
track=1376 #c-cp
# Load the track, clip data and ground tracks to the whole-region geometry
# (gdf), and pick one beam pair for inspection.
datadict, gdf_gts = get_data(track)
datadict, gdf_gts = clip_data(datadict, gdf_gts, gdf)
pt = 'pt2'
ds = datadict[pt]

In [None]:
%%time

# Get_stats: per-cycle summary statistics for a short test list of tracks,
# clipped separately to gdf_ext (filled ice-shelf outlines plus pinning
# points) and to gdf_pp (pinning points only).
tracklist=[33, 26, 155, 194, 392, 727, 1077]
dss_short_fl, gdf_gts_short_all = get_stats(tracklist, gdf_ext)
dss_short_pp, gdf_gts_short_all_pp = get_stats(tracklist, gdf_pp)

In [None]:
%%time

# Get_ds_dict for the same short test list: per-track datasets clipped to the
# whole-region geometry (gdf) and to the pinning points (gdf_pp).
# NOTE: this reassigns gdf_gts_short_all / gdf_gts_short_all_pp from the
# Get_stats cell.
tracklist=[33, 26, 155, 194, 392, 727, 1077]
ds_dict_short, gdf_gts_short_all = get_ds_dict(tracklist, gdf)
ds_dict_pp, gdf_gts_short_all_pp = get_ds_dict(tracklist, gdf_pp)

In [None]:
%%time
# Get_stats, a little bit slow
#tracklist=df_files.track
#dss, gdf_gts_all = get_stats(tracklist, gdf_ext)
#dss_pp, gdf_gts_all_pp = get_stats(tracklist, gdf_pp)

In [None]:
%%time

# Get_ds_dict, pretty fast: per-track datasets for every indexed track, clipped
# to all ice (gdf_ext_all), the Id == 1 region (gdf_gr), ice shelves (gdf_fl),
# and pinning points (gdf_pp).
tracklist=df_files.track
ds_dict, gdf_gts_all = get_ds_dict(tracklist, gdf_ext_all)
ds_dict_gr, gdf_gts_gr = get_ds_dict(tracklist, gdf_gr)
ds_dict_fl, gdf_gts_fl = get_ds_dict(tracklist, gdf_fl)
ds_dict_pp, gdf_gts_pp = get_ds_dict(tracklist, gdf_pp)

# Combine for full dataset: one Dataset per region, concatenated along 'track'
ds_all = combine_ds_dict(ds_dict)
ds_gr = combine_ds_dict(ds_dict_gr)
ds_fl = combine_ds_dict(ds_dict_fl)
ds_pp = combine_ds_dict(ds_dict_pp)

In [13]:
# save: netCDF for the datasets, ESRI shapefiles for the ground tracks.
# The output directory can be set with `processed_dir` in the [data] section
# of the config; the fallback is the original hard-coded location.
processed_out_dir = config.get('data', 'processed_dir', fallback='/Volumes/nox2/Chance/processed_data/')
os.makedirs(processed_out_dir, exist_ok=True)

for suffix, ds_out in {'all': ds_all, 'gr': ds_gr, 'fl': ds_fl, 'pp': ds_pp}.items():
    ds_out.to_netcdf(os.path.join(processed_out_dir, f'{basin}_ds_{suffix}.nc'))

for suffix, gdf_out in {'all': gdf_gts_all, 'gr': gdf_gts_gr, 'fl': gdf_gts_fl, 'pp': gdf_gts_pp}.items():
    gdf_out.to_file(os.path.join(processed_out_dir, f'{basin}_gdf_gts_{suffix}.shp'))

# Load Processed Data

In [None]:
# Directory of the processed files: `processed_dir` in the [data] config
# section, falling back to the original hard-coded location.
processed_dir = config.get('data', 'processed_dir', fallback='/Volumes/nox2/Chance/processed_data/')

print('Reading netCDFs into XR datasets...', end='', flush=True)
ds_all = xr.open_dataset(os.path.join(processed_dir, f'{basin}_ds_all.nc'))
ds_gr = xr.open_dataset(os.path.join(processed_dir, f'{basin}_ds_gr.nc'))
ds_fl = xr.open_dataset(os.path.join(processed_dir, f'{basin}_ds_fl.nc'))
ds_pp = xr.open_dataset(os.path.join(processed_dir, f'{basin}_ds_pp.nc'))
print('DONE')

print('Reading ESRI shapefiles into geodataframes...', end='', flush=True)
gdf_gts_all = gpd.read_file(os.path.join(processed_dir, f'{basin}_gdf_gts_all.shp'))
gdf_gts_gr = gpd.read_file(os.path.join(processed_dir, f'{basin}_gdf_gts_gr.shp'))
gdf_gts_fl = gpd.read_file(os.path.join(processed_dir, f'{basin}_gdf_gts_fl.shp'))
gdf_gts_pp = gpd.read_file(os.path.join(processed_dir, f'{basin}_gdf_gts_pp.shp'))
print('DONE')

# Plot individual figures

In [14]:
#Calculate value to plot
# Fit a straight line to h_ano vs cycle_number for every (x, track, pt); the
# degree=1 coefficient is the height-change rate per cycle.
# NOTE(review): the map colorbars are labelled "m yr^-1", but this slope is per
# cycle_number. The time-series axes place one year every 4 cycles, so per-year
# rates would be about 4x this value — confirm the intended units.
h_ano_lin = ds_all.h_ano.polyfit('cycle_number', deg=1, skipna=True);

In [605]:
## make plot function
def make_plot(color_by='', dpi=600, rolling=None, vlims=(-5, 5), save=False):
    """Draw the basin summary figure: dh/dt map plus per-cycle time series.

    Left panel: hillshade of the REMA DEM (``rema_path``), with ICESat-2
    ground tracks coloured by the per-x linear trend in ``h_ano_lin``. Right
    panels: median relative height and normalised data count per cycle for
    floating ice (``ds_fl``) and pinning points (``ds_pp``).

    Parameters:
        color_by: tracks are coloured by trend unless this is None, in which
            case they are drawn in their ``plotcolor``.
        dpi: figure resolution (also used when saving).
        rolling: optional Gaussian rolling-window length applied along x to the
            trend before colouring.
        vlims: (min, max) of the colour scale.
        save: write the figure to ``plot_dir`` when True.

    Returns:
        The matplotlib Figure. It is already closed, so display it explicitly.
    """
    # figure setup
    major_font_size = 12
    minor_font_size = 10
    line_w = 1.0
    cmap = cmc.vik_r

    # make figure and axes: map on the left, three stacked panels on the right
    fig = plt.figure(figsize=[13,8], dpi=dpi)
    gs = fig.add_gridspec(3, 20)
    axs = [fig.add_subplot(gs[:, :9])]
    for i in range(3):
        axs.append(fig.add_subplot(gs[i, 10:20]))
    boxprops = dict(boxstyle='round', facecolor='white', alpha=0.5, edgecolor='none', pad=0.2)

    # Basemap: hillshade of the REMA DEM, cropped to the first gdf_gr polygon
    # within the bounding box of gdf_ext.
    ax = axs[0]
    with rs.open(rema_path) as src:
        bbox = box(*gdf_ext.total_bounds)
        # .iloc[0]: gdf_gr is a filtered slice, so index label 0 may not exist
        src_masked, src_masked_tr = mask(src, shapes=[gdf_gr.geometry.iloc[0].intersection(bbox)],
            crop=True, filled=False)
        src_hs = np.ma.masked_array(es.hillshade(src_masked[0].filled(np.nan)), mask=src_masked.mask[0])
        plot.show(src_hs, ax=ax, transform=src_masked_tr, cmap='Greys_r',
            aspect='equal', vmin=-150, vmax=100)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()

    # add ground tracks
    lc_list = []
    if color_by is not None:
        for i in range(len(gdf_gts_all)):
            row = gdf_gts_all[i:i+1]
            coords = row.get_coordinates()
            t = row.track.iloc[0]
            pt = row.pt.iloc[0]
            # Prepare segments for LineCollection; vertices whose x-gradient
            # exceeds 1e2 (map units) are blanked so data gaps are not bridged.
            dc_dx = np.gradient(coords.x)
            dc_dx_idx = np.abs(dc_dx)>1e2
            coords.loc[dc_dx_idx, 'x'] = np.nan
            points = np.array([coords.x, coords.y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            # LineCollection coloured by the linear trend of h_ano along x
            lc = LineCollection(segments, cmap=cmap, norm=plt.Normalize(vlims[0], vlims[1]))
            h_plot = h_ano_lin.sel(degree=1).sel(track=t, pt=pt).polyfit_coefficients
            if rolling is not None:
                h_plot = np.array(pd.Series(h_plot).rolling(window=rolling, center=True, min_periods=1,
                    win_type='gaussian').mean(std=rolling/3))
            lc.set_array(h_plot)
            lc.set_linewidth(line_w)
            lc_list.append(lc)
            ax.add_collection(lc)
    else:
        gdf_gts_all.plot(ax=ax, color=gdf_gts_all.plotcolor, linewidth=line_w)
        gdf_gts_pp.plot(ax=ax, color='lightblue', linewidth=line_w)

    # outlines: pinning points (black) and gdf_ext buffered by 1e3 (blue)
    gdf_pp.plot(ax=ax, color='None', edgecolor='black', label='grounded ice', linewidth=0.2, zorder=501)
    gdf_ext.apply(lambda p: p.buffer(1e3)).plot(ax=ax, color='None', edgecolor='royalblue', label='floating ice', linewidth=0.4, zorder=502)

    ax.axis('off')
    ax.text(0.5, 0.96, f'ICESat-2 ATL11 Height Change\n Basin {basin} 2019-2024', transform=ax.transAxes,
            ha='center', va='center', fontsize=major_font_size, bbox=boxprops, zorder=502)

    # panel 1: median relative height per cycle, referenced to the first cycle
    ax = axs[1]
    ano_cycle_fl = ds_fl.h_ano.median(dim=['x', 'track', 'pt'])
    ano_cycle_pp = ds_pp.h_ano.median(dim=['x', 'track', 'pt'])
    ax.axhline(y=0.0, linestyle='--', color='black')
    ax.plot(ano_cycle_fl.cycle_number, (ano_cycle_fl - ano_cycle_fl.isel(cycle_number=0)), color='red', label='floating ice')
    ax.plot(ano_cycle_pp.cycle_number, (ano_cycle_pp - ano_cycle_pp.isel(cycle_number=0)), '-', color='lightblue', label='pinning points')
    # starting at cycle 2 is 2019; major ticks every 4 cycles
    ax.set_xlim([2, 24])
    ax.set_ylim([-0.9, 0.9])
    ax.set_xticks(ticks=np.arange(2, 24), minor=True)
    ax.set_xticks(ticks=np.arange(2, 24, 4), labels=[])
    ax.set_yticks(ticks=[-0.5, 0.5])
    ax.tick_params(axis='x', which='major', length=5, width=2)
    ax.tick_params(axis='x', which='minor', length=3, width=1)
    ax.tick_params(which='both', direction='in', zorder=1, labelsize=minor_font_size)
    ax.set_ylabel('height (m)', fontsize=minor_font_size, labelpad=-0.3)
    ax.text(0.5, 0.96, 'Median Relative Height', transform=ax.transAxes, ha='center', va='top',
        fontsize=major_font_size, bbox=boxprops)
    ax.legend(loc='lower left', fontsize=minor_font_size)

    # panel 2: count of valid h_ano values per cycle, normalised by its median
    ax = axs[2]
    ano_count_fl = ds_fl.h_ano.count(dim=['x', 'track', 'pt'])
    ano_count_pp = ds_pp.h_ano.count(dim=['x', 'track', 'pt'])
    ax.axhline(y=1.0, linestyle='--', color='black')
    ax.plot(ano_cycle_fl.cycle_number, ano_count_fl/ano_count_fl.median(), color='red', label=f'mean = {int(ano_count_fl.mean().data)} values')
    ax.plot(ano_cycle_pp.cycle_number, ano_count_pp/ano_count_pp.median(), '-', color='lightblue', label=f'mean = {int(ano_count_pp.mean().data)} values')
    ax.set_xlim([2, 24])
    ax.set_ylim([-0.0, 1.75])
    ax.set_yticks([0.5, 1.0, 1.5])
    ax.set_ylabel('count fraction', fontsize=minor_font_size)
    ax.set_xticks(ticks=np.arange(2, 24), minor=True)
    ax.set_xticks(ticks=np.arange(2, 24, 4), labels=[f'{int(c)}' for c in np.arange(2019, 2025)])
    ax.tick_params(axis='x', which='major', length=5, width=2, labelsize=minor_font_size)
    ax.tick_params(axis='x', which='minor', length=3, width=1)
    ax.tick_params(which='both', direction='in', zorder=1, labelsize=minor_font_size)
    ax.text(0.5, 0.96, 'Relative Elevation Count', transform=ax.transAxes, ha='center', va='top', fontsize=major_font_size, bbox=boxprops)
    ax.legend(loc='lower left', fontsize=minor_font_size)

    # the third right-hand panel is unused
    axs[3].axis('off')

    # ScalarMappable with the same colormap and normalization as the tracks
    norm = mcolors.Normalize(vmin=vlims[0], vmax=vlims[1])
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])  # Only needed for older versions of matplotlib

    # inset axis for the colorbar, placed relative to the map panel
    pos = axs[0].get_position()
    cax_pos = [pos.x0+pos.width*0.63, pos.y0+0.13, pos.width*0.25, pos.height*0.02]
    cax = fig.add_axes(cax_pos)

    # NOTE(review): h_ano_lin slopes are per cycle; confirm the m/yr label.
    cbar = plt.colorbar(sm, cax=cax, orientation='horizontal', fraction=0.1, pad=0.0)
    cbar.ax.tick_params(labelsize=minor_font_size)
    cbar.set_label('height change \n(m yr$^{-1}$)', fontsize=minor_font_size)
    cbar.outline.set_edgecolor('black')
    cbar.outline.set_linewidth(0.5)
    set_axis_color(cax, 'black')
    cbar.set_ticks([-1, 0, 1])

    # save next to make_map's output, in plot_dir
    plotname = f'{plot_dir}/{basin}_dhdt.png'
    if rolling is not None: plotname = f'{plot_dir}/{basin}_dhdt_avg{rolling}.png'
    if save:
        os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(plotname, dpi=dpi, bbox_inches='tight', transparent=False)

    plt.close(fig)

    return fig

In [None]:
%%time
# make_plot closes its figure (plt.close) before returning, so show it with display().
fig = make_plot(color_by='', dpi=400, vlims=[-1, 1], rolling=None, save=False)
display(fig)

In [626]:
## make plot function
def make_map(color_by='', dpi=600, rolling=None, vlims=(-5, 5), save=False, transparent=True):
    """Draw a stand-alone dh/dt map of the basin.

    The map shows a hillshade of the REMA DEM (``rema_path``) masked to
    ``gdf_ext_all``. ICESat-2 ground tracks are coloured by the per-x linear
    trend in ``h_ano_lin``. Outlines of gdf_ext_all, the pinning points and
    gdf_ext are drawn on top, with a horizontal colorbar.

    Parameters:
        color_by: tracks are coloured by trend unless this is None, in which
            case they are drawn in their ``plotcolor``.
        dpi: figure resolution (also used when saving).
        rolling: optional Gaussian rolling-window length applied along x to the
            trend before colouring.
        vlims: (min, max) of the colour scale.
        save: write the figure to ``plot_dir`` when True.
        transparent: passed to ``savefig``.

    Returns:
        The matplotlib Figure. It is already closed, so display it explicitly.
    """
    # figure setup
    minor_font_size = 10
    line_w = 0.5
    cmap = cmc.vik_r

    # make figure and axes
    fig, ax = plt.subplots(figsize=[13,8], dpi=dpi)

    # basemap: hillshade of the REMA DEM, masked to the full region outline
    with rs.open(rema_path) as src:
        src_masked, src_masked_tr = mask(src, shapes=[gdf_ext_all.geometry.iloc[0]],
            crop=True, filled=False)
        src_hs = np.ma.masked_array(es.hillshade(src_masked[0].filled(np.nan)), mask=src_masked.mask[0])
        plot.show(src_hs, ax=ax, transform=src_masked_tr, cmap='Greys_r',
            aspect='equal', vmin=-150, vmax=100)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_axis_off()

    # add ground tracks
    lc_list = []
    if color_by is not None:
        for i in range(len(gdf_gts_all)):
            row = gdf_gts_all[i:i+1]
            coords = row.get_coordinates()
            t = row.track.iloc[0]
            pt = row.pt.iloc[0]
            # Prepare segments for LineCollection; vertices whose x-gradient
            # exceeds 1e2 (map units) are blanked so data gaps are not bridged.
            dc_dx = np.gradient(coords.x)
            dc_dx_idx = np.abs(dc_dx)>1e2
            coords.loc[dc_dx_idx, 'x'] = np.nan
            points = np.array([coords.x, coords.y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            # LineCollection coloured by the linear trend of h_ano along x
            lc = LineCollection(segments, cmap=cmap, norm=plt.Normalize(vlims[0], vlims[1]))
            h_plot = h_ano_lin.sel(degree=1).sel(track=t, pt=pt).polyfit_coefficients
            if rolling is not None:
                h_plot = np.array(pd.Series(h_plot).rolling(window=rolling, center=True, min_periods=1,
                    win_type='gaussian').mean(std=rolling/3))
            lc.set_array(h_plot)
            lc.set_linewidth(line_w)
            lc_list.append(lc)
            ax.add_collection(lc)
    else:
        gdf_gts_all.plot(ax=ax, color=gdf_gts_all.plotcolor, linewidth=line_w)
        gdf_gts_pp.plot(ax=ax, color='lightblue', linewidth=line_w)

    # outlines: gdf_ext_all buffered by 4e3 (white), pinning points (black),
    # gdf_ext buffered by 1e3 (blue)
    gdf_ext_all.apply(lambda p: p.buffer(4e3)).plot(ax=ax, color='None', edgecolor='white', label='grounded ice', linewidth=1, zorder=500)
    gdf_pp.plot(ax=ax, color='None', edgecolor='black', label='grounded ice', linewidth=0.2, zorder=501)
    gdf_ext.apply(lambda p: p.buffer(1e3)).plot(ax=ax, color='None', edgecolor='royalblue', label='floating ice', linewidth=0.4, zorder=502)

    ax.axis('off')

    # ScalarMappable with the same colormap and normalization as the tracks
    norm = mcolors.Normalize(vmin=vlims[0], vmax=vlims[1])
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])  # Only needed for older versions of matplotlib

    # inset axis for the colorbar, placed relative to the map axes
    pos = ax.get_position()
    cax_pos = [pos.x0+pos.width*0.7, pos.y0+0.13, pos.width*0.25, pos.height*0.01]
    cax = fig.add_axes(cax_pos)

    # NOTE(review): h_ano_lin slopes are per cycle; confirm the m/yr label.
    cbar = plt.colorbar(sm, cax=cax, orientation='horizontal', fraction=0.1, pad=0.0)
    cbar.ax.tick_params(labelsize=minor_font_size)
    cbar.set_label('height change \n(m yr$^{-1}$)', fontsize=minor_font_size)
    cbar.outline.set_edgecolor('black')
    cbar.outline.set_linewidth(0.5)
    set_axis_color(cax, 'black')
    cbar.set_ticks([-1, 0, 1])

    plotname = f'{plot_dir}/{basin}_dhdt_map.png'
    if rolling is not None: plotname = f'{plot_dir}/{basin}_dhdt_avg{rolling}_map.png'
    if save:
        os.makedirs(plot_dir, exist_ok=True)
        fig.savefig(plotname, dpi=dpi, bbox_inches='tight', transparent=transparent)

    plt.close(fig)

    return fig

In [None]:
%%time
# make_map closes its figure (plt.close) before returning, so show it with display().
fig = make_map(color_by='', dpi=400, vlims=[-1, 1], rolling=None, save=False, transparent=True)
display(fig)

In [708]:
def _legend_figure(handles, labels, figsize, fontsize):
    """Return a new figure containing only a centred legend (axes hidden).

    Lets a legend be exported as its own image, separate from its data panel.
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.axis('off')
    # ax.legend() already attaches the legend to the axes; also calling
    # ax.add_artist(legend) would draw it a second time.
    ax.legend(handles, labels, loc='center', fontsize=fontsize)
    return fig


def _format_cycle_xaxis(ax, fontsize):
    """Apply the shared cycle-number x axis and inward tick styling to ``ax``.

    Shows cycles 2-24 with a minor tick per cycle. A major tick every fourth
    cycle is labelled with the year, starting with cycle 2 -> 2019.
    """
    ax.set_xlim([2, 24])
    ax.set_xticks(ticks=np.arange(2, 24), minor=True)
    ax.set_xticks(ticks=np.arange(2, 24, 4), labels=[f'{int(c)}' for c in np.arange(2019, 2025)])
    ax.tick_params(axis='x', which='major', length=5, width=2)
    ax.tick_params(axis='x', which='minor', length=3, width=1)
    ax.tick_params(which='both', direction='in', zorder=1, labelsize=fontsize)


def make_ensemble_plots(save=False):
    """Build the basin summary figures for floating ice, pinning points and grounded ice.

    Reads these notebook globals:
    - ``ds_fl``, ``ds_pp``, ``ds_gr``: datasets with ``h_ano`` and ``h_abs``
      over ``x``, ``track``, ``pt`` and ``cycle_number``.
    - ``h_ano_lin``: the result of ``polyfit`` on ``cycle_number``.
    - ``basin``.
    - ``plot_dir``: read only when saving.

    Parameters
    ----------
    save : bool, optional
        If True, write each figure as a PNG into ``plot_dir``. Default False,
        so calling the function has no file-system side effects.

    Returns
    -------
    list of matplotlib.figure.Figure
        In order:
        [median relative height, its legend, relative count, its legend,
        median-elevation histogram, dh/dt histogram].
    """
    major_font_size = 16
    minor_font_size = 16
    figsize = [8, 4]
    legendsize = [5, 1.5]
    boxprops = dict(boxstyle='round', facecolor='white', alpha=0.5, edgecolor='none', pad=0.2)
    reduce_dims = ['x', 'track', 'pt']
    fig_list = []

    # 1) Median height anomaly per cycle, expressed relative to the first cycle.
    fig, ax = plt.subplots(figsize=figsize)
    ano_cycle_fl = ds_fl.h_ano.median(dim=reduce_dims)
    ano_cycle_gr = ds_gr.h_ano.median(dim=reduce_dims)
    ano_cycle_pp = ds_pp.h_ano.median(dim=reduce_dims)
    ax.axhline(y=0.0, linestyle='--', color='black')
    ax.plot(ano_cycle_fl.cycle_number, ano_cycle_fl - ano_cycle_fl.isel(cycle_number=0),
            color='red', label='floating ice')
    ax.plot(ano_cycle_pp.cycle_number, ano_cycle_pp - ano_cycle_pp.isel(cycle_number=0),
            '--', color='red', label='pinning points')
    ax.plot(ano_cycle_gr.cycle_number, ano_cycle_gr - ano_cycle_gr.isel(cycle_number=0),
            '-', color='lightblue', label='upstream grounded ice')
    ax.set_ylim([-0.9, 0.9])
    ax.set_yticks(ticks=[-0.5, 0.5])
    _format_cycle_xaxis(ax, minor_font_size)
    ax.set_ylabel('height (m)', fontsize=minor_font_size, labelpad=-9)
    ax.text(0.5, 0.96, 'Median Relative Height', transform=ax.transAxes, ha='center', va='top',
            fontsize=major_font_size, bbox=boxprops)
    fig_list.append(fig)
    fig_list.append(_legend_figure(*ax.get_legend_handles_labels(), legendsize, minor_font_size))

    # 2) Number of valid anomalies per cycle, normalised by its median over cycles.
    fig, ax = plt.subplots(figsize=figsize)
    ano_count_fl = ds_fl.h_ano.count(dim=reduce_dims)
    ano_count_gr = ds_gr.h_ano.count(dim=reduce_dims)
    ano_count_pp = ds_pp.h_ano.count(dim=reduce_dims)
    ax.axhline(y=1.0, linestyle='--', color='black')
    ax.plot(ano_cycle_fl.cycle_number, ano_count_fl/ano_count_fl.median(), color='red',
            label=f'mean = {int(ano_count_fl.mean().data)} values')
    ax.plot(ano_cycle_pp.cycle_number, ano_count_pp/ano_count_pp.median(), '--', color='red',
            label=f'mean = {int(ano_count_pp.mean().data)} values')
    ax.plot(ano_cycle_gr.cycle_number, ano_count_gr/ano_count_gr.median(), '-', color='lightblue',
            label=f'mean = {int(ano_count_gr.mean().data)} values')
    ax.set_ylim([-0.0, 1.75])
    ax.set_yticks([0.5, 1.0, 1.5])
    ax.set_ylabel('count fraction', fontsize=minor_font_size)
    _format_cycle_xaxis(ax, minor_font_size)
    ax.text(0.5, 0.96, 'Relative Elevation Count', transform=ax.transAxes, ha='center', va='top',
            fontsize=major_font_size, bbox=boxprops)
    fig_list.append(fig)
    fig_list.append(_legend_figure(*ax.get_legend_handles_labels(), legendsize, minor_font_size))

    # 3) Distribution of each point's median absolute height over cycles (10 m bins).
    fig, ax = plt.subplots(figsize=figsize)
    bin_edges = np.arange(-10, 260, 10)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])  # x positions for the step outline
    h_fl = ds_fl.h_abs.median(dim='cycle_number')
    h_pp = ds_pp.h_abs.median(dim='cycle_number')
    h_gr = ds_gr.h_abs.median(dim='cycle_number')
    ax.hist(np.ravel(h_gr.data), bin_edges, color='lightblue', edgecolor='black', alpha=1,
            label=f'{basin} upstream grounded ice', density=True)
    ax.hist(np.ravel(h_fl.data), bin_edges, color='darkred', edgecolor='black', alpha=1,
            label=f'{basin} floating ice', density=True)
    ax.hist(np.ravel(h_pp.data), bin_edges, color='white', edgecolor='darkred', alpha=0.6, hatch=None,
            label=f'{basin} pinning points', density=True)
    # Outline the pinning-point histogram (solid white under dashed dark red)
    # so it stays visible on top of the filled histograms.
    counts, _ = np.histogram(np.ravel(h_pp.data), bins=bin_edges, density=True)
    ax.step(bin_centers, counts, where='mid', linestyle='-', color='white')
    ax.step(bin_centers, counts, where='mid', linestyle='--', color='darkred')
    ax.set_ylabel('count density', fontsize=minor_font_size, labelpad=-9)
    # Draw y ticks and tick labels in white.
    ax.tick_params(axis='y', colors='white')
    ax.set_ylim([0, 0.035])
    ax.set_xlabel('height (m)', fontsize=minor_font_size)
    ax.tick_params(which='both', direction='in', zorder=1, labelsize=minor_font_size)
    ax.text(0.5, 0.96, 'Median elevation distribution', transform=ax.transAxes, ha='center', va='top',
            fontsize=major_font_size, bbox=boxprops)
    ax.set_xlim([-10, 250])
    fig_list.append(fig)

    # 4) Histogram of the per-point linear-fit slope (degree-1 coefficient).
    # NOTE(review): h_ano_lin is fitted against cycle_number, so the slope is per
    # cycle unless cycle_number was rescaled; confirm the m/yr labelling.
    fig, ax = plt.subplots(figsize=[8, 3])
    m = h_ano_lin.sel(degree=1).polyfit_coefficients
    # Drop NaN slopes (points without a valid fit). With an integer bin count the
    # bin range is auto-detected, which fails on NaN input.
    m_vals = np.ravel(m.data)
    m_vals = m_vals[np.isfinite(m_vals)]
    ax.hist(m_vals, 200, color='bisque', edgecolor='gray')
    ax.axvline(x=float(m.mean()), label=f'mean={m.mean().round(3).data} m/yr', color='purple')
    ax.axvline(x=float(m.median()), label=f'median={m.median().round(3).data} m/yr', color='cornflowerblue')
    ax.set_ylabel('count')
    ax.set_xlabel('height change (m/yr)')
    ax.set_title(f'{basin} height change 2018-2024')
    ax.legend()
    ax.set_xlim([-3, 3])
    fig_list.append(fig)

    if save:
        # (file-name suffix, dpi) for each entry of fig_list, in order.
        outputs = [('median_anom', 300), ('median_anom_legend', 300),
                   ('median_count', 300), ('median_count_legend', 300),
                   ('h_abs_hist', 300), ('dhdt_hist', 200)]
        for out_fig, (suffix, out_dpi) in zip(fig_list, outputs):
            out_fig.savefig(f'{plot_dir}/{basin}_{suffix}.png', dpi=out_dpi, bbox_inches='tight')

    return fig_list

In [None]:
# Build the summary figures (median relative height, relative counts, histograms).
make_ensemble_plots()

# Plot with contextily

In [7]:
#Calculate value to plot
# Least-squares linear fit of every height-anomaly series against the
# `cycle_number` coordinate, skipping NaNs. The result's `polyfit_coefficients`
# variable has a `degree` dimension, and degree=1 is the slope.
# NOTE(review): because the fit is against cycle_number, the slope is in metres
# per cycle rather than per year. Elsewhere in this notebook every fourth cycle
# is labelled as a new year, but the slope is plotted with "m yr^-1" labels.
# Confirm whether it should be scaled by the number of cycles per year.
h_ano_lin = ds_all.h_ano.polyfit('cycle_number', deg=1, skipna=True);

In [11]:
## make plot function
def make_cx_plot(color_by='', dpi=600, rolling=None, vlims=(-5, 5), imagery_resolution_adjust=1, save=False):
    """Plot the basin dh/dt map on satellite imagery beside two time-series panels.

    Left panel:
    - ESRI WorldImagery basemap (via contextily).
    - Every ground track in ``gdf_gts_all`` drawn as a LineCollection coloured
      by the degree-1 coefficient of ``h_ano_lin``.
    - Outlines of ``gdf_pp`` and ``gdf_ext``.

    Right panels:
    - Median height anomaly relative to the first cycle.
    - Relative count of valid anomalies per cycle.
    Both show floating ice (``ds_fl``) and pinning points (``ds_pp``). The
    third right-hand axis is left blank.

    Relies on these notebook globals: gdf_ext, gdf_pp, gdf_gts_all, gdf_gts_pp,
    h_ano_lin, ds_fl, ds_pp, crs_antarctica, basin, set_axis_color.

    Parameters
    ----------
    color_by : str or None
        Any non-None value colours tracks by the fitted slope. None draws them
        with the ``plotcolor`` column, and pinning-point tracks in light blue.
    dpi : int
        Figure and output resolution.
    rolling : int or None
        If given, the slope is smoothed along each track with a centred
        Gaussian rolling mean of this window (std = rolling / 3) before colouring.
    vlims : sequence of two floats
        Colour-scale limits for the tracks and the colour bar.
    imagery_resolution_adjust : int
        Passed to ``cx.add_basemap`` as ``zoom_adjust``.
    save : bool
        If True, save the figure as a PNG.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, already closed with ``plt.close``; show it with ``display(fig)``.
    """
    # figure setup
    imagery_aspect = 1.4  # map-panel y extent / x extent
    major_font_size = 12
    minor_font_size = 10
    line_w = 1.0
    boxprops = dict(boxstyle='round', facecolor='white', alpha=0.5, edgecolor='none', pad=0.2)
    # One colour map and normalisation shared by all tracks and the colour bar.
    cmap = cmc.vik_r
    norm = mcolors.Normalize(vmin=vlims[0], vmax=vlims[1])

    # Map panel spans the left 9 of 20 grid columns; three stacked panels on the right.
    fig = plt.figure(figsize=[15, 8], dpi=dpi)
    gs = fig.add_gridspec(3, 20)
    axs = [fig.add_subplot(gs[:, :9])]
    axs += [fig.add_subplot(gs[i, 10:20]) for i in range(3)]

    # --- map panel: basemap from contextily (ESRI WorldImagery) ---
    ax = axs[0]
    bounds = gdf_ext.total_bounds  # (minx, miny, maxx, maxy)
    buffer = 0.001 * np.max([bounds[i+2] - bounds[i] for i in [0, 1]])
    bbox = np.array(box(*bounds).buffer(buffer).bounds)
    # Keep the full x extent and centre a y window of the requested aspect.
    xrng = bbox[2] - bbox[0]
    yrng = xrng * imagery_aspect
    ymid = np.mean(bbox[[1, 3]])
    ax.set_xlim(bbox[[0, 2]])
    ax.set_ylim([ymid - yrng/2, ymid + yrng/2])
    cx.add_basemap(ax=ax, crs=crs_antarctica, source=cx.providers.Esri.WorldImagery,
        zoom_adjust=imagery_resolution_adjust, attribution=' ', attribution_size=3)
    # Assumes the attribution text added by add_basemap is the last text on the
    # axes; move it to the bottom-right corner.
    txt = ax.texts[-1]
    txt.set_position([0.98, 0.01])
    txt.set_ha('right')
    txt.set_va('bottom')

    # --- ground tracks ---
    if color_by is not None:
        for i in range(len(gdf_gts_all)):
            row = gdf_gts_all[i:i+1]
            t = row.track.iloc[0]
            pt = row.pt.iloc[0]
            coords = row.get_coordinates()
            # Break the line where the x gradient (np.gradient, central
            # differences) exceeds 100 CRS units. The flagged x values become NaN,
            # so the segments touching them are not drawn.
            gap = np.abs(np.gradient(coords.x)) > 1e2
            coords.loc[gap, 'x'] = np.nan
            points = np.array([coords.x, coords.y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lc = LineCollection(segments, cmap=cmap, norm=norm)
            # NOTE(review): h_plot has one value per point, but there is one
            # fewer segment; confirm the intended point/segment alignment.
            h_plot = h_ano_lin.sel(degree=1).sel(track=t, pt=pt).polyfit_coefficients
            if rolling is not None:
                h_plot = np.array(pd.Series(h_plot).rolling(window=rolling, center=True, min_periods=1,
                    win_type='gaussian').mean(std=rolling/3))
            lc.set_array(h_plot)
            lc.set_linewidth(line_w)
            ax.add_collection(lc)
    else:
        gdf_gts_all.plot(ax=ax, color=gdf_gts_all.plotcolor, linewidth=line_w)
        gdf_gts_pp.plot(ax=ax, color='lightblue', linewidth=line_w)

    gdf_pp.plot(ax=ax, color='None', edgecolor='black', label='grounded ice', linewidth=0.2, zorder=500)
    gdf_ext.plot(ax=ax, color='None', edgecolor='royalblue', label='floating ice', linewidth=0.4, zorder=501)
    ax.axis('off')
    ax.text(0.5, 0.96, f'ICESat-2 ATL11 Height Change\n Basin {basin} 2019-2024', transform=ax.transAxes,
            ha='center', va='center', fontsize=major_font_size, bbox=boxprops, zorder=502)

    # --- panel 2: median height anomaly relative to the first cycle ---
    ax = axs[1]
    ano_cycle_fl = ds_fl.h_ano.median(dim=['x', 'track', 'pt'])
    ano_cycle_pp = ds_pp.h_ano.median(dim=['x', 'track', 'pt'])
    ax.axhline(y=0.0, linestyle='--', color='black')
    ax.plot(ano_cycle_fl.cycle_number, (ano_cycle_fl - ano_cycle_fl.isel(cycle_number=0)), color='red', label='floating ice')
    ax.plot(ano_cycle_pp.cycle_number, (ano_cycle_pp - ano_cycle_pp.isel(cycle_number=0)), '-', color='lightblue', label='pinning points')
    # Cycle 2 is labelled 2019; major ticks every 4 cycles. Year labels are only
    # shown on the panel below.
    ax.set_xlim([2, 24])
    ax.set_ylim([-0.9, 0.9])
    ax.set_xticks(ticks=np.arange(2, 24), minor=True)
    ax.set_xticks(ticks=np.arange(2, 24, 4), labels=[])
    ax.set_yticks(ticks=[-0.5, 0.5])
    ax.tick_params(axis='x', which='major', length=5, width=2)
    ax.tick_params(axis='x', which='minor', length=3, width=1)
    ax.tick_params(which='both', direction='in', zorder=1, labelsize=minor_font_size)
    ax.set_ylabel('height (m)', fontsize=minor_font_size, labelpad=-0.3)
    ax.text(0.5, 0.96, 'Median Relative Height', transform=ax.transAxes, ha='center', va='top',
        fontsize=major_font_size, bbox=boxprops)
    ax.legend(loc='lower left', fontsize=minor_font_size)

    # --- panel 3: valid-anomaly count per cycle, normalised by its median ---
    ax = axs[2]
    ano_count_fl = ds_fl.h_ano.count(dim=['x', 'track', 'pt'])
    ano_count_pp = ds_pp.h_ano.count(dim=['x', 'track', 'pt'])
    ax.axhline(y=1.0, linestyle='--', color='black')
    ax.plot(ano_cycle_fl.cycle_number, ano_count_fl/ano_count_fl.median(), color='red', label=f'mean = {int(ano_count_fl.mean().data)} values')
    ax.plot(ano_cycle_pp.cycle_number, ano_count_pp/ano_count_pp.median(), '-', color='lightblue', label=f'mean = {int(ano_count_pp.mean().data)} values')
    ax.set_xlim([2, 24])
    ax.set_ylim([-0.0, 1.75])
    ax.set_yticks([0.5, 1.0, 1.5])
    ax.set_ylabel('count fraction', fontsize=minor_font_size)
    ax.set_xticks(ticks=np.arange(2, 24), minor=True)
    ax.set_xticks(ticks=np.arange(2, 24, 4), labels=[f'{int(c)}' for c in np.arange(2019, 2025)])
    ax.tick_params(axis='x', which='major', length=5, width=2, labelsize=minor_font_size)
    ax.tick_params(axis='x', which='minor', length=3, width=1)
    ax.tick_params(which='both', direction='in', zorder=1, labelsize=minor_font_size)
    ax.text(0.5, 0.96, 'Relative Elevation Count', transform=ax.transAxes, ha='center', va='top', fontsize=major_font_size, bbox=boxprops)
    ax.legend(loc='lower left', fontsize=minor_font_size)

    # The fourth (bottom-right) axis is unused.
    axs[3].axis('off')

    # --- colour bar inset in the lower part of the map panel ---
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])  # Only needed for older versions of matplotlib
    pos = axs[0].get_position()
    cax = fig.add_axes([pos.x0 + pos.width*0.63, pos.y0 + 0.13, pos.width*0.25, pos.height*0.02])
    cbar = plt.colorbar(sm, cax=cax, orientation='horizontal', fraction=0.1, pad=0.0)
    cbar.ax.tick_params(labelsize=minor_font_size)
    cbar.set_label('height change \n(m yr$^{-1}$)', fontsize=minor_font_size)
    cbar.outline.set_edgecolor('white')
    cbar.outline.set_linewidth(0.5)
    set_axis_color(cax, 'white')
    cbar.set_ticks([-1, 0, 1])

    plotname = f'/Users/ccroberts/Desktop/{basin}_dhdt.png'
    if rolling is not None:
        plotname = f'/Users/ccroberts/Desktop/{basin}_dhdt_avg{rolling}.png'
    if save:
        fig.savefig(plotname, dpi=dpi, bbox_inches='tight')

    plt.close(fig)

    return fig

In [None]:
##############

# Grid the data (incomplete)

In [238]:
#Calculate value to plot
# Slope (degree-1 coefficient) of each point's linear fit, plus the matching
# geographic coordinates of every along-track point.
m = h_ano_lin.polyfit_coefficients.sel(degree=1)
lon, lat = ds_all.longitude, ds_all.latitude

In [383]:
# Flatten slope, latitude and longitude over (track, pt, x) so that every
# along-track point becomes one record.
# NOTE(review): ds_m uses 'lat' as its dimension and stores the flattened
# latitudes as that dimension's coordinate. Those values may repeat or be NaN
# (NaN rows are only dropped for m_df below), so label-based selection on
# ds_m.lat is unreliable.
ds_m = xr.Dataset(data_vars={'m': (['lat'], m.stack(z=['track', 'pt', 'x']).values)}, coords={'lat': lat.stack(z=['track', 'pt', 'x']).values,
    'lon': (['lat'], lon.stack(z=['track', 'pt', 'x']).values)})
# Point GeoDataFrame of slope per location, built from lon/lat in crs_latlon.
# Rows with any NaN are dropped, then the points are reprojected to crs_antarctica.
m_df = gpd.GeoDataFrame({'m': m.stack(z=['track', 'pt', 'x']), 
    'lon': lon.stack(z=['track', 'pt', 'x']),
    'lat': lat.stack(z=['track', 'pt', 'x']),
    'geometry': gpd.points_from_xy(lon.stack(z=['track', 'pt', 'x']), 
    lat.stack(z=['track', 'pt', 'x']))}, 
    crs=crs_latlon).dropna(how='any', ignore_index=True).to_crs(crs_antarctica)

In [None]:
# Define metadata for the raster
# NOTE(review): this cell is unfinished and is not expected to run as written.
#  - `r` is not imported in the notebook header (rasterio is imported as `rs`),
#    so `r.open` raises NameError unless `r` is bound in an unseen cell.
#  - On an xarray Dataset, `ds_m.values` is the Mapping `values` method, not an
#    array; the band data would be `ds_m['m'].values`.
#  - ds_m is 1-D (one slope per along-track point), yet width and height are
#    both set to its length. The points must first be gridded (e.g. rasterized
#    from m_df) into a 2-D array whose shape matches width/height.
#  - rasterio's from_origin(west, north, xsize, ysize) receives a latitude as
#    `west` and a longitude as `north`, so the axes look swapped. A 1-unit
#    pixel would also be 1 degree in EPSG:4326.
metadata = {
    'driver': 'GTiff',
    'count': 1,
    'dtype': 'float32',
    'width': ds_m.sizes['lat'],
    'height': ds_m.sizes['lat'],
    'crs': 'EPSG:4326',
    'transform': from_origin(ds_m.lat.min(), ds_m.lon.max(), 1, 1)  # Adjust transform based on your data
}

# Write the xarray.DataArray to a raster file
with r.open('Cp-D_h_ano_lin.tif', 'w', **metadata) as dst:
    dst.write(ds_m.values, 1)

# REMA Testing

In [158]:
def get_REMA(filename, crs):
    """Open a REMA DEM raster and reproject it to ``crs``.

    Parameters
    ----------
    filename : str
        Path to the REMA GeoTIFF.
    crs : CRS-like
        Target coordinate reference system, passed to ``reproject_raster``.

    Returns
    -------
    tuple
        The three values returned by ``reproject_raster(src, crs)``, repacked
        as ``(src, transform, crs)``.
    """
    # The previous version read the whole DEM into an unused array; the read
    # is dropped and reproject_raster receives the open dataset directly.
    src = rs.open(filename)
    # NOTE(review): the dataset is not closed because reproject_raster (defined
    # elsewhere) may return it or depend on it lazily. If it returns in-memory
    # data, wrap this call in `with rs.open(filename) as src:`.
    src_out, transform, crs_out = reproject_raster(src, crs)
    return src_out, transform, crs_out

In [7]:
# load DEM
# Path to the REMA v2.0 1 km DEM mosaic (per its file name), used by the cells below.
filename = 'data/REMA/rema_mosaic_1km_v2.0/rema_mosaic_1km_v2.0_dem.tif'
# Alternative path via get_REMA (reprojection), currently disabled:
#src, tr, crs = get_REMA(filename, crs_antarctica)
#src = src[0, :, :]

In [348]:
# Read the full DEM and its affine transform. The context manager closes the
# file handle afterwards; a bare rs.open() would keep it open for the life of
# the kernel.
with rs.open(filename) as src:
    dem = src.read()
    tr = src.transform

In [None]:
# Hillshade of the REMA DEM clipped to the grounded-ice polygon within the
# bounding box of the floating-ice extent, saved as a transparent PNG.
with rs.open(filename) as src:
    dem = src.read()
    tr = src.transform
    siogz_path = 'shapes/scripps_antarctic_polygons_CR.shp'
    gdf_siogz = gpd.read_file(siogz_path).set_crs(crs_antarctica, allow_override=True)

    # Axis-aligned bounding box of the floating-ice extent.
    bbox = box(*gdf_ext.total_bounds)

    fig, ax = plt.subplots(figsize=[4, 8])
    src_masked, src_masked_tr = mask(src, shapes=[gdf_gr.geometry[0].intersection(bbox)],
        crop=True, filled=False)
    # getmaskarray always returns a full boolean array, even if the masked array
    # carries np.ma.nomask, so indexing the first band is always valid.
    band_mask = np.ma.getmaskarray(src_masked)[0]
    src_hs = np.ma.masked_array(es.hillshade(src_masked[0].filled(np.nan)), mask=band_mask)
    plot.show(src_hs, ax=ax, transform=src_masked_tr, cmap='Greys_r',
        aspect='equal', vmin=-150, vmax=100)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_axis_off()
    fig.savefig('/Users/ccroberts/Desktop/rema_2.0.png', dpi=400, bbox_inches='tight', transparent=True)

# Make individual plots

In [None]:
# Distribution of each point's median absolute height (over cycles) for
# floating ice and pinning points, using bins roughly `bin_range` metres wide.
fig, ax = plt.subplots(figsize=[8, 3])
h_fl = ds_fl.h_abs.median(dim='cycle_number')
h_pp = ds_pp.h_abs.median(dim='cycle_number')
bin_range = 10  # target bin width (m)
# Drop NaN points before histogramming. With an integer bin count the range is
# auto-detected, which fails on NaN input.
vals_fl = np.ravel(h_fl.data)
vals_fl = vals_fl[np.isfinite(vals_fl)]
vals_pp = np.ravel(h_pp.data)
vals_pp = vals_pp[np.isfinite(vals_pp)]
# At least one bin, even when the height span is under half a bin width.
bins_fl = max(1, int(np.round((h_fl.max()-h_fl.min())/bin_range)))
bins_pp = max(1, int(np.round((h_pp.max()-h_pp.min())/bin_range)))
ax.hist(vals_fl, bins_fl,
    color='darkred', edgecolor='gray', alpha=0.7,
    label=f'{basin} floating ice', density=True);
ax.hist(vals_pp, bins_pp,
    color='lightblue', edgecolor='gray', alpha=0.8,
    label=f'{basin} pinning points', density=True);
#ax.axvline(x=m.mean(), label=f'mean={m.mean().round(3).data} m', color='purple')
#ax.axvline(x=m.median(), label=f'median={m.median().round(3).data} m', color='cornflowerblue')
ax.set_ylabel('count density')
ax.set_yticks([])
ax.set_xlabel('height (m)')
ax.set_title(f'{basin} median elevation distribution')
ax.legend()
ax.set_xlim([-5, 255])
#fig.savefig(f'/Users/ccroberts/Desktop/{basin}_h_hist.png', dpi=200, bbox_inches='tight')

In [None]:
# Histogram of the per-point linear-fit slope with its mean and median marked.
# NOTE(review): h_ano_lin is fitted against cycle_number, so the slope is per
# cycle unless cycle_number was rescaled; confirm the m/yr labelling.
fig, ax = plt.subplots(figsize=[8, 3])
m = h_ano_lin.sel(degree=1).polyfit_coefficients
# Drop NaN slopes (points without a valid fit). The 200-bin range is
# auto-detected and fails on NaN input.
m_vals = np.ravel(m.data)
m_vals = m_vals[np.isfinite(m_vals)]
ax.hist(m_vals, 200, color='bisque', edgecolor='gray');
ax.axvline(x=m.mean(), label=f'mean={m.mean().round(3).data} m/yr', color='purple')
ax.axvline(x=m.median(), label=f'median={m.median().round(3).data} m/yr', color='cornflowerblue')
ax.set_ylabel('count')
ax.set_xlabel('height change (m/yr)')
ax.set_title(f'{basin} height change 2018-2024')
ax.legend()
ax.set_xlim([-3, 3])
fig.savefig(f'/Users/ccroberts/Desktop/{basin}_dhdt_hist.png', dpi=200, bbox_inches='tight')

In [None]:
# Median absolute height per cycle over all tracks, pts and along-track points,
# for the floating-ice (ds_fl) and pinning-point (ds_pp) datasets.
plt.figure(figsize=[8, 3])
plt.plot(ds_fl.cycle_number.data, ds_fl.h_abs.median(dim=['track', 'pt', 'x']), color='red', label=f'{basin} floating ice')
plt.plot(ds_pp.cycle_number.data, ds_pp.h_abs.median(dim=['track', 'pt', 'x']), color='lightblue', label=f'{basin} pinning points')
# starting at 2 is 2019
#plt.xlim([2, 23])
# Label cycle c with year 2018 + (c+2)/4 only when (c+2) is a multiple of 4, i.e.
# cycles 2, 6, 10, ... -> 2019, 2020, 2021, ...; all other ticks are unlabelled.
plt.xticks(ticks=np.arange(2, 23), labels=[f'{int(2018+((c+2)/4))}' if ((c+2)%4)==0 else '' for c in np.arange(2, 23)])
plt.tick_params(direction='in')
plt.xlabel('cycle')
plt.ylabel('median height (m)')
plt.title(f'Basin {basin} median elevation')
plt.legend()

In [None]:
# Cycle plots

In [None]:
# compare anomaly methods
# Grid up the data
# plot the grid to see change

In [None]:
# Non-NaN counts over the track and pt dimensions, reduced to the maximum count
# per variable (quick data-availability check).
ds_all.count(dim=['track', 'pt']).max()

In [None]:
# Non-NaN counts over track and pt for the cycle at positional index 4.
ds_all.isel(cycle_number=4).count(dim=['track', 'pt'])

In [None]:
# Non-NaN height-anomaly counts over track and pt for the cycle at positional index 1.
ds_all.h_ano.isel(cycle_number=1).count(dim=['track', 'pt'])

In [None]:
# Per-cycle spread of dss.h_med across tracks and pts: max/min envelope in light
# red, with the median (red) and mean (blue).
# NOTE(review): the title hard-codes 'Cp-D' instead of using `basin`.
fig, ax = plt.subplots(figsize=[10, 4])
#plotdict = {f'ano_{t}_{pt}': dss.h_med.sel(track=t, pt=pt) for t in dss.track.values for pt in dss.pt.values}
#for p in plotdict: ax.plot(plotdict[p].cycle_number, plotdict[p], '.', color='black', alpha=0.1)
#ax.axhline(y=0.0, linestyle='--', color='black')
ax.plot(dss.cycle_number, dss.h_med.max(dim=['track', 'pt']), color='lightcoral')
ax.plot(dss.cycle_number, dss.h_med.median(dim=['track', 'pt']), color='red', label='median')
ax.plot(dss.cycle_number, dss.h_med.mean(dim=['track', 'pt']), color='blue', label='mean')
ax.plot(dss.cycle_number, dss.h_med.min(dim=['track', 'pt']), color='lightcoral')
#ax.plot(dss_pp.cycle_number, dss_pp.h_med.median(dim=['track', 'pt']), '-', color='royalblue', label='Cp-D pinning points')
# starting at 2 is 2019
ax.set_xlim([2, 23])
#ax.set_ylim([-4.5, 4.5])
#ax.set_ylim([-0.9, 0.9])

# Minor tick per cycle; major tick every 4 cycles labelled with the year.
ax.set_xticks(ticks=np.arange(2, 23), minor=True)
ax.set_xticks(ticks=np.arange(2, 23, 4), labels=[f'{int(c)}' for c in np.arange(2019, 2025)])
ax.tick_params(axis='x', which='major', length=5, width=2)
ax.tick_params(axis='x', which='minor', length=3, width=1)
ax.tick_params(which='both', direction='out', zorder=1)

ax.set_xlabel('cycle')
ax.set_ylabel('elevation (m)')
ax.set_title('Basin Cp-D: floating ice elevations')
ax.legend()

In [None]:
# Same spread plot as for floating ice, but for the pinning-point dataset dss_pp:
# max/min envelope in light red, with the median (red) and mean (blue).
# NOTE(review): the title hard-codes 'Cp-D' instead of using `basin`.
fig, ax = plt.subplots(figsize=[10, 4])
#plotdict = {f'ano_{t}_{pt}': dss_pp.h_med.sel(track=t, pt=pt) for t in dss_pp.track.values for pt in dss_pp.pt.values}
#for p in plotdict: ax.plot(plotdict[p].cycle_number, plotdict[p], color='lightgray')
#ax.axhline(y=0.0, linestyle='--', color='black')
ax.plot(dss_pp.cycle_number, dss_pp.h_med.max(dim=['track', 'pt']), color='lightcoral')
ax.plot(dss_pp.cycle_number, dss_pp.h_med.median(dim=['track', 'pt']), color='red', label='median')
ax.plot(dss_pp.cycle_number, dss_pp.h_med.mean(dim=['track', 'pt']), color='blue', label='mean')
ax.plot(dss_pp.cycle_number, dss_pp.h_med.min(dim=['track', 'pt']), color='lightcoral')
#ax.plot(dss_pp.cycle_number, dss_pp.h_med.median(dim=['track', 'pt']), '-', color='royalblue', label='Cp-D pinning points')
# starting at 2 is 2019
ax.set_xlim([2, 23])
#ax.set_ylim([-4.5, 4.5])
#ax.set_ylim([-0.9, 0.9])

# Minor tick per cycle; major tick every 4 cycles labelled with the year.
ax.set_xticks(ticks=np.arange(2, 23), minor=True)
ax.set_xticks(ticks=np.arange(2, 23, 4), labels=[f'{int(c)}' for c in np.arange(2019, 2025)])
ax.tick_params(axis='x', which='major', length=5, width=2)
ax.tick_params(axis='x', which='minor', length=3, width=1)
ax.tick_params(which='both', direction='out', zorder=1)

ax.set_xlabel('cycle')
ax.set_ylabel('elevation (m)')
ax.set_title('Basin Cp-D: pinning point elevations')
ax.legend()

In [37]:
# Per-cycle median (and standard deviation) of the height anomaly across all
# tracks and pts, for floating ice (dss) and pinning points (dss_pp).
ano_cycle = dss.h_ano.median(dim=['track', 'pt'])
ano_cycle_std = dss.h_ano.std(dim=['track', 'pt'])
ano_cycle_pp = dss_pp.h_ano.median(dim=['track', 'pt'])

In [None]:
# Median height anomaly per cycle relative to the first cycle (positional index 0),
# for floating ice and pinning points.
fig, ax = plt.subplots(figsize=[10, 4])
#plotdict = {f'ano_{t}_{pt}': dss.h_med.sel(track=t, pt=pt) for t in dss.track.values for pt in dss.pt.values}
#for p in plotdict: ax.plot(plotdict[p].cycle_number, (plotdict[p] - plotdict[p].sel(cycle_number=3)), color='lightgray')
ax.axhline(y=0.0, linestyle='--', color='black')
ax.plot(ano_cycle.cycle_number, (ano_cycle - ano_cycle.isel(cycle_number=0)), color='red', label='Cp-D floating ice')
ax.plot(ano_cycle_pp.cycle_number, (ano_cycle_pp - ano_cycle_pp.isel(cycle_number=0)), '-', color='royalblue', label='Cp-D pinning points')
# starting at 2 is 2019
ax.set_xlim([2, 23])
#ax.set_ylim([-4.5, 4.5])
#ax.set_ylim([-0.9, 0.9])

# Minor tick per cycle; major tick every 4 cycles labelled with the year.
ax.set_xticks(ticks=np.arange(2, 23), minor=True)
ax.set_xticks(ticks=np.arange(2, 23, 4), labels=[f'{int(c)}' for c in np.arange(2019, 2025)])
ax.tick_params(axis='x', which='major', length=5, width=2)
ax.tick_params(axis='x', which='minor', length=3, width=1)
ax.tick_params(which='both', direction='out', zorder=1)

ax.set_xlabel('cycle')
ax.set_ylabel('height change (m)')
ax.set_title('Basin Cp-D')
ax.legend()

In [None]:
# Median absolute height per cycle (dss.h_med over tracks and pts) for floating
# ice and pinning points. The x values come from the same dataset as the y
# values, so the cycles always line up.
plt.figure(figsize=[8, 3])
plt.plot(dss.cycle_number, dss.h_med.median(dim=['track', 'pt']), color='red', label='Cp-D floating ice')
plt.plot(dss_pp.cycle_number, dss_pp.h_med.median(dim=['track', 'pt']), '--', color='red', label='Cp-D pinning points')
# starting at 2 is 2019
#plt.xlim([2, 23])
# Label cycle c with year 2018 + (c+2)/4 only when (c+2) is a multiple of 4.
plt.xticks(ticks=np.arange(2, 23), labels=[f'{int(2018+((c+2)/4))}' if ((c+2)%4)==0 else '' for c in np.arange(2, 23)])
plt.tick_params(direction='in')
plt.xlabel('cycle')
plt.ylabel('median height (m)')
plt.title('Basin Cp-D ')
plt.legend()

In [None]:
##### old stuff

In [428]:
# Data availability

In [None]:
#727
# Debug comparison of three data sources for one track (row 197 of
# gdf_gts_all; track 727 per the comments in this notebook). Relies on the
# globals ds, pt, track, ds_dict_short and ds_dict.
this = gdf_gts_all[197:198]
coords = this.to_crs(crs_latlon).get_coordinates()

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=[11, 7])
# Top: positions. Small longitude offsets (-0.004, -0.002, 0, +0.002) separate
# otherwise overlapping series.
ax1.plot(ds.latitude, ds.longitude-0.004, '.', label='clip_data direct', color='purple')
ax1.plot(ds_dict_short[727].latitude.sel(pt=pt), ds_dict_short[727].longitude.sel(pt=pt)-0.002, '.', label='ds_dict_short', color='blue')
ax1.plot(ds_dict[727].latitude.sel(pt=pt), ds_dict[727].longitude.sel(pt=pt), '.', label='ds_dict', color='orange')
ax1.scatter(coords.y, coords.x+0.002, color='red', label='gdf_gts_all', s=1.5)
ax1.set_title(f'Track data comparison track {track}, {pt}, cycle 13')
ax1.legend()
ax1.set_xlabel('latitude')
ax1.set_ylabel('longitude')
ax1.text(-67.11, 117.79, 
    '*n.b.*, \'x\' is simply a placeholder for latitude and longtiude.\n We want it to line up with available (or NaN) elevation values',
    color='k')
# Bottom: cycle-13 height anomalies against the 'x' index, offset by -2/0/+2 m
# so the three sources can be compared by eye.
ax2.plot(ds.x, ds.h_ano.sel(cycle_number=13, track=track, pt=pt)-2, '.', label='clip_data direct', color='purple')
ax2.plot(ds_dict_short[727].x, ds_dict_short[727].h_ano.sel(cycle_number=13, pt=pt, track=track), '.', label='ds_dict_short', color='blue')
ax2.plot(ds_dict[727].x, ds_dict[727].h_ano.sel(cycle_number=13, pt=pt, track=track)+2, '.', label='ds_dict', color='orange')
#ax2.set_title(f'Height anomaly comparison track {track}, {pt}, cycle 13')
ax2.set_xlabel(f'Pandas index \'x\'')
ax2.set_ylabel('elevation anomaly (m)')
#plt.savefig('/Users/ccroberts/Desktop/track_comparison_727.png', dpi=200, bbox_inches='tight')

In [1006]:
#Tides and stuff

In [None]:
# Compare gap-filling strategies for the along-track `tide_cats` values
# (labelled 'pyTMD output') for one track/pt at cycle 3. The ATL11 absolute
# height is drawn on a twin y axis for context.
fig, ax = plt.subplots(figsize=[8, 4])
axt = ax.twinx()
#plt.axvspan(pp[0]/1000, pp[1]/1000, color='lightgray')
axt.plot(ds.x_atc/1000, ds.h_abs.sel(cycle_number=3, track=track, pt=pt), color='lightgrey', label='ATL11')
# Linear interpolation across gaps, extrapolated past the ends.
ax.plot(ds.x_atc/1000, ds.tide_cats.interpolate_na(dim='x', method='linear', fill_value='extrapolate').sel(cycle_number=3, track=track, pt=pt), label='linear')
# Linear interpolation inside gaps, then nearest-neighbour fill/extrapolation at the ends.
ax.plot(ds.x_atc/1000, ds.tide_cats.interpolate_na(dim='x', method='linear').interpolate_na(dim='x', method='nearest', fill_value='extrapolate').sel(cycle_number=3, track=track, pt=pt), label='linear + nearest')
# Raw values drawn last in black, on top of the interpolated versions. (A
# previous unlabelled copy of this same line underneath has been removed.)
ax.plot(ds.x_atc/1000, ds.tide_cats.sel(cycle_number=3, track=track, pt=pt), color='black', label='pyTMD output')
ax.set_title(f'Tides interpolation, track {track}, {pt}, cycle {3}')
# x_atc is divided by 1000 above, so the axis is in kilometres.
ax.set_xlabel('x_atc (km)')
ax.set_ylabel('elevation (m)')
axt.set_ylabel('absolute elevation (m)')
axt.tick_params(colors='grey')
axt.yaxis.label.set_color('grey')
ax.legend()
axt.legend(loc=4)
fig.savefig('/Users/ccroberts/Desktop/tides_interpolation.png', dpi=200, bbox_inches='tight')

In [None]:
# An xarray combine_by_coords test

def combine_by_coords_problem(name, join="outer"):
    """Run ``xr.combine_by_coords`` on two small datasets that overlap in ``x1``.

    Both datasets hold a string ``data`` variable over ``('x1', name)``:
    - the first has x1 = a, b, c and ``name`` = 10, 20, 30;
    - the second has x1 = c, d and ``name`` = 40, 50, 60.
    The label ``'c'`` therefore appears in both.

    Parameters
    ----------
    name : str
        Name given to the second dimension, varied to test whether it changes
        how the datasets are combined.
    join : str, optional
        Passed through to ``xr.combine_by_coords`` (default "outer").

    Returns
    -------
    The result of ``xr.combine_by_coords`` for the two datasets.
    """
    def _build(rows, x1_labels, name_values):
        # One 'data' variable over (x1, name) with explicit coordinates.
        return xr.Dataset({'data': (['x1', name], rows)},
                          coords={"x1": x1_labels, name: name_values})

    first = _build([['a10', 'a20', 'a30'], ['b10', 'b20', 'b30'], ['c10', 'c20', 'c30']],
                   ['a', 'b', 'c'], [10, 20, 30])
    second = _build([['c40', 'c50', 'c60'], ['d40', 'd50', 'd60']],
                    ['c', 'd'], [40, 50, 60])
    return xr.combine_by_coords([first, second], join=join)

#combine_by_coords_problem("x0") # concatenates 1, 2, 3, 4, 5, 6
#combine_by_coords_problem("x2") # concatenates 10, 20, 30, 40, 50, 60

# Combine with an outer join, then inspect one element that comes only from the
# second dataset (x1='d', x2=50).
out = combine_by_coords_problem("x2", join='outer')
out.sel(x1='d', x2=50)

In [None]:
# a line collection test
# Scratch test: colour two ground tracks (gdf_gts_all rows 197 and 196) by
# their cycle-mean height anomaly using LineCollections. The right panel shows
# the same values as plain lines.

cmap = cmc.vik
this = gdf_gts_all[197:198]  #track 727
these = [gdf_gts_all[197:198], gdf_gts_all[196:197]]
#these = [gdf_gts_all[196:197]]

# Plotting
fig, (ax, ax1) = plt.subplots(1, 2, figsize=[10, 4])
lc_list = []
for this in these:
    # Prepare segments for LineCollection
    # Each segment joins consecutive points: shape (n_points - 1, 2, 2).
    points = np.array([this.get_coordinates().x, this.get_coordinates().y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    # Create a LineCollection
    lc = LineCollection(segments, cmap=cmap, label=f'{this.pt.iloc[0]}', norm=plt.Normalize(-2, 2))
    # NOTE(review): one value per point is supplied for n_points - 1 segments;
    # confirm the intended alignment.
    lc.set_array(ds_dict[this.track.iloc[0]].h_ano.mean(dim='cycle_number').sel(track=this.track.iloc[0], pt=this.pt.iloc[0]))
    #lc.set_array(ds_dict[these[1].track.iloc[0]].h_ano.mean(dim='cycle_number').sel(track=these[1].track.iloc[0], pt=these[1].pt.iloc[0]))
    lc.set_linewidth(4)
    lc_list.append(lc)
    
    ax.add_collection(lc)
    ax.autoscale()

    ax1.plot(ds_dict[this.track.iloc[0]].h_ano.mean(dim='cycle_number').sel(track=this.track.iloc[0], pt=this.pt.iloc[0]), 
             label=f'{this.pt.iloc[0]}')
ax.set_xlabel('x')
ax.set_ylabel('y')

# Adding colorbar
# Uses the last LineCollection; all share the same Normalize(-2, 2) limits.
cb = fig.colorbar(lc, ax=ax, label='line color')

# Creating a custom legend handle
#cmap_colors = cmap(np.linspace(0, 1, 256))
#cmap_gradient = [patches.Patch(facecolor=c, edgecolor=c, label=cmap.name) for c in cmap_colors]
#ax.legend()#(handles=[cmap_gradient], labels=['color gradient line'], handler_map={list: HandlerTuple(ndivide=None, pad=0)})
# Zoom to a fixed window in projected map coordinates.
ax.set_ylim([-1.182e6, -1.171e6])
ax.set_xlim([2.2230e6, 2.2430e6])
ax1.legend()



In [None]:
# useful but not in use
def find_intersect_indices(mask, gt_this, gt_this_pts):
    """Return the nearest track-point index for each place a track crosses polygon outlines.

    Steps:
    1. Clip the polygons in ``mask`` to the bounds of the track in ``gt_this``
       and split them into single parts.
    2. Intersect their exterior rings with the track geometry.
    3. Discard empty intersections and explode multi-part ones into individual
       geometries.

    Parameters
    ----------
    mask : GeoDataFrame or GeoSeries (presumably of polygons)
        Outlines to test. Inside this function the name shadows the imported
        ``rasterio.mask.mask``.
    gt_this : GeoDataFrame
        Only its first geometry and the bounds of its first row are used.
    gt_this_pts : GeoDataFrame
        Points along the track.

    Returns
    -------
    list
        For each crossing, the index label of the closest geometry in ``gt_this_pts``.
    """
    track_line = gt_this.geometry.iloc[0]
    track_bounds = gt_this.bounds.values[0]
    pieces = mask.clip(track_bounds).explode(ignore_index=True)
    crossings = pieces.exterior.intersection(track_line)
    crossings = crossings[~crossings.is_empty].explode(ignore_index=True)
    return [gt_this_pts.geometry.distance(crossing).idxmin() for crossing in crossings]