In [None]:
import datetime
import os
import re

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import gridspec
from cmcrameri import cm as cmc

from tools.utils import create_polydict_ptsarr, xy2ll, directory_spider
from tools.landsat_utils import easy_raster, normalize_band as normalize
from tools.flexural_fitting import deflection
from tools.icesat2_utils import advect_IS2
from tools.plot_tools import plot_latlon_lines, shifted_colormap

In [None]:
import plotSettings
# Apply the project's paper figure style; 1.4 is passed straight to plotSettings.paper
# (presumably a size/scale factor — TODO confirm in plotSettings).
plotSettings.paper(1.4)
# Track colour palette: COLS[i] colours track TRKs[i] in the figure below.
COLS = plotSettings.COLS

In [None]:
# Tk backend: figures open in separate interactive windows instead of inline output.
%matplotlib tk

In [None]:
# General data directory; every data/GIS path below is joined onto it.
# Set the DATA_DIR environment variable to point at your copy of the data
# instead of editing this cell (defaults to '/', as before).
data_dir = os.environ.get('DATA_DIR', '/')

In [None]:
# Landsat 8/9 scene used as the context image (path relative to data_dir)
ls89_files = 'data/LandSat8-9_Color/LC08_L1GT_127111_20220210_20220222_02_T2'

# bounding boxes for masking images; only the first polygon is used below
bounding_box_file = 'GIS/BoundingBoxes2.shp'
bbs_dict, _ = create_polydict_ptsarr(os.path.join(data_dir, bounding_box_file))
bb = bbs_dict[0]

# MEaSUREs ITS_LIVE velocity mosaics for IS2 track advection:
# one fill mosaic plus one annual mosaic per year 2013-2018
vel_fill_file = os.path.join(data_dir, 'data/MEaSUREs_ITS_LIVE/ANT_G0120_0000.nc')
vel_files = {y: os.path.join(data_dir, f'data/MEaSUREs_ITS_LIVE/ANT_G0240_{y}.nc') for y in range(2013, 2018 + 1)}
# 2019 reuses the 2018 annual mosaic.
# NOTE(review): there are no entries for 2020-2022 although IS2 dates run to 2022 —
# confirm advect_IS2 falls back to vel_fill_file for missing years.
vel_files[2019] = vel_files[2018]

# locations of the masked IS2 (.pkl) files for all tracks
masked_IS2 = directory_spider(os.path.join(data_dir, 'data/ICESat-2'), file_pattern='.pkl')

# Compiled regex for pulling the acquisition date out of an IS2 file path:
# extract_date.findall(path).pop() -> (YYYY, MM, DD, hhmmss, then three granule-id fields).
# The '.' before 'pkl' is escaped so it only matches a literal dot.
extract_date = re.compile(r'ATL06_MASKED_(\d{4})(\d{2})(\d{2})(\d{6})_(\d{8})_(\d{3})_(\d{2})\.pkl')

In [None]:
# IS2 reference tracks to plot and the beam pair used for each
TRKs = ['0470', '0401']
gts = ['gt1', 'gt1']

# Initial year window searched for each track; overwritten below with the years
# actually spanned by the usable acquisitions.
start_year = [2018, 2018]
end_year = [2022, 2022]

# Acquisitions ignored when finding each track's first/last date and dt range.
EXCLUDE_FROM_RANGE = {
    '0401': {datetime.datetime(2019, 1, 23)},
}
# Acquisitions dropped from the list of dates that are plotted.
EXCLUDE_FROM_PLOT = {
    '0401': {datetime.datetime(2019, 1, 23), datetime.datetime(2019, 7, 24)},
    '0470': {datetime.datetime(2020, 1, 26)},
}

all_dates = []                    # per track: sorted acquisition dates to plot
start_files, end_files = [], []   # per track: file of the first / last usable acquisition
start_dates, end_dates = [], []   # per track: first / last usable acquisition date
dt_mins, dt_maxs = [], []         # per track: smallest positive / largest days before end_date

for i, (gt, TRK) in enumerate(zip(gts, TRKs)):
    years = list(range(start_year[i], end_year[i] + 1))
    # this track's files, skipping '._' files (presumably macOS resource forks)
    is2_files = [file for file in masked_IS2 if f'/{TRK}/' in file and '._' not in file]

    # (date, file) for every acquisition stored under a '/<year>/' directory in the window
    dated_files = []
    for year in years:
        for f in is2_files:
            if f'/{year}/' in f:
                YY, MM, DD, *_ = extract_date.findall(f).pop()
                dated_files.append((datetime.datetime(int(YY), int(MM), int(DD)), f))

    skip_range = EXCLUDE_FROM_RANGE.get(TRK, set())
    usable = [(date, f) for date, f in dated_files if date not in skip_range]
    if not usable:
        # previously this fell through to an undefined (or stale) start_file
        raise ValueError(
            f'no usable ICESat-2 acquisitions for track {TRK} in {start_year[i]}-{end_year[i]}'
        )

    # first/last usable acquisition; min/max keep the first file seen on ties
    start_date, start_file = min(usable, key=lambda pair: pair[0])
    end_date, end_file = max(usable, key=lambda pair: pair[0])

    # days from each usable acquisition to the last one
    dts = [(end_date - date).days for date, _ in usable]
    positive_dts = [dt for dt in dts if dt > 0]
    dt_min = min(positive_dts) if positive_dts else np.inf  # days
    dt_max = max(dts)                                       # days

    start_year[i] = start_date.year
    end_year[i] = end_date.year

    start_dates.append(start_date)
    end_dates.append(end_date)
    start_files.append(start_file)
    end_files.append(end_file)
    dt_mins.append(dt_min)
    dt_maxs.append(dt_max)  # was dt_maxs.append(dt_maxs), which appended the list to itself

    # dates to plot: every acquisition found, minus this track's excluded dates
    skip_plot = EXCLUDE_FROM_PLOT.get(TRK, set())
    all_dates.append(sorted(date for date, _ in dated_files if date not in skip_plot))

In [None]:
SAVE = False              # set True to write the figure to fig_save_dir
fig_save_dir = 'figs'
fig_type = '.svg'
save_name = 'figure8'

##########################################################################################
# plot settings
##########################################################################################
figsize = [17.515, 5.248]
NROWS = 4
NCOLS = 6

# colormap used to colour acquisitions by date in the elevation panels
cmap = cmc.batlow_r
norm = mpl.colors.Normalize(vmin=0, vmax=1.09)
CM = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)

# latitude ticks (and their labels) shared by all elevation-profile panels
xticks = [-72.05 - 0.025, -72.05, -72.025, -72.0, -72 + 0.025]
xticklabels = [f'{v:0.2f}°' for v in xticks]

##########################################################################################
# create figure
##########################################################################################
fig = plt.figure(num=4, clear=True, figsize=figsize)
gs = gridspec.GridSpec(NROWS, NCOLS)

# context image fills every row of the first third of the grid columns
imax = fig.add_subplot(gs[:NROWS, :NCOLS // 3])

# 2x2 array of elevation panels: row = track index, column 0 = left beam, column 1 = right beam
left_cols, right_cols = slice(2, 4), slice(4, 6)
trk = np.array([
    [fig.add_subplot(gs[:2, left_cols]), fig.add_subplot(gs[:2, right_cols])],
    [fig.add_subplot(gs[2:, left_cols]), fig.add_subplot(gs[2:, right_cols])],
])

##########################################################################################
# context image
##########################################################################################
left, bottom, right, top = bb.bounds

# lon/lat of the four bounding-box corners (column 0 = lon, column 1 = lat),
# used to choose the graticule ranges
ll = np.array([xy2ll(left, bottom), xy2ll(left, top), xy2ll(right, top), xy2ll(right, bottom)])
lon_range = [ll[:, 0].min(), ll[:, 0].max()]
lat_range = [ll[:, 1].min(), ll[:, 1].max()]

# Landsat scene masked to the bounding box (the path variable defined earlier is ls89_files;
# the previous ls89_file reference raised NameError)
rast = easy_raster(os.path.join(data_dir, ls89_files),
                   mask=bb, always_TOA=True)
print('..pan sharpening img')
ps = rast.pan_sharpen(method='brovey', TOA_correction=True)
R, G, B = ps['B4'], ps['B3'], ps['B2']

# Stretch each channel between 0 and its 98th percentile, scale to 0-255 and clip.
# (normalize presumably maps [MIN, MAX] onto [0, 1] — TODO confirm in landsat_utils)
MIN = (0, 0, 0)
MAX = [np.quantile(b, 0.98) for b in [R, G, B]]
shape = R.shape
RGB = np.empty((shape[0], shape[1], 3), dtype=int)
for k, band in enumerate((R, G, B)):
    RGB[:, :, k] = np.clip(255 * normalize(band, MIN[k], MAX[k]), 0, 255)

imax.imshow(RGB, extent=(left, right, bottom, top), aspect='equal', rasterized=True)
imax.set(xticklabels=[], yticklabels=[],
         xticks=[], yticks=[],
         xlim=(left, right), ylim=(bottom, top))

imax = plot_latlon_lines(imax, lat_range=lat_range, lon_range=lon_range,
                         lat_step=0.05, lon_step=0.15,
                         extents=(left, right, bottom, top))
##########################################################################################
# plot is2 tracks and elevations
##########################################################################################
letters = [('b', 'c'), ('d', 'e')]
for i, TRK in enumerate(TRKs):
    label_l = f'({letters[i][0]}) TRK {TRK} {gts[i].upper()}L'
    label_r = f'({letters[i][1]}) TRK {TRK} {gts[i].upper()}R'

    # panel labels: drawn once per track instead of once per acquisition
    trk[i, 0].text(0.03, 0.1, label_l, va='center', ha='left', fontsize=21,
                   color=COLS[i], transform=trk[i, 0].transAxes)
    trk[i, 1].text(0.03, 0.15, label_r, va='center', ha='left', fontsize=21,
                   color=COLS[i], transform=trk[i, 1].transAxes)

    # Colour position runs from 0 (this track's first usable acquisition) to 1 (its last).
    # Uses this track's own span rather than the dt_max left over from the last track.
    span_days = max((end_dates[i] - start_dates[i]).days, 1)

    is2_files = [file for file in masked_IS2 if f'/{TRK}/' in file and '._' not in file]
    # per-track year window (previously the `years` list leaked from the last track)
    for year in range(start_year[i], end_year[i] + 1):
        for file in (f for f in is2_files if f'/{year}/' in f):
            YY, MM, DD, *_ = extract_date.findall(file).pop()
            date = datetime.datetime(int(YY), int(MM), int(DD))
            if date not in all_dates[i]:
                continue

            dt = (end_dates[i] - date).days
            color_index = 1 - dt / span_days
            col = CM.cmap(color_index)
            ans = advect_IS2(file, end_files[i], vel_files, vel_fill_file, date, dt, gt=gts[i])
            x1l, y1l, x2l, y2l, h1l_lagrangian, h2l, x1r, y1r, x2r, y2r, h1r_lagrangian, h2r = ans

            lonl, latl = xy2ll(x2l, y2l)
            lonr, latr = xy2ll(x2r, y2r)

            # newer acquisitions (higher color_index) are drawn on top
            trk[i, 0].plot(latl, h1l_lagrangian, c=col, zorder=255 * color_index)
            trk[i, 1].plot(latr, h1r_lagrangian, c=col, zorder=255 * color_index)

            # date legend to the right of the right-beam panel, at the acquisition's colour position
            trk[i, 1].text(1.1, color_index, date.strftime('%Y-%m-%d'),
                           fontsize=18, color=col,
                           va='center', ha='left', transform=trk[i, 1].transAxes)

    # track lines on the context image, from the coordinates of the last acquisition plotted above
    imax.plot(x2l, y2l, c=COLS[i], ls=':', label=label_l, zorder=255)
    imax.plot(x2r, y2r, c=COLS[i], ls='--', label=label_r, zorder=255)

imax.legend(loc='lower right', fontsize=18)

##########################################################################################
# shared axis limits, ticks and labels; optional save
##########################################################################################
# common elevation range across all four profile panels
elev_ylim = (np.min([a.get_ylim()[0] for a in trk.flat]),
             np.max([a.get_ylim()[1] for a in trk.flat]))

# x range comes from the left-beam latitudes of the last profile plotted
for ax in trk.flat:
    ax.set(ylim=elev_ylim, xticks=xticks, xlim=(latl.min(), latl.max()))
for ax in trk[0, :]:
    ax.set_xticklabels([])             # top row: tick labels only on the bottom row
for ax in trk[1, :]:
    ax.set_xticklabels(xticklabels)

# dashed reference latitudes in every panel
for ax in trk.flat:
    for v in (-72.05, -72.0):
        ax.axvline(v, c='k', ls='--', lw=0.3)

for ax in trk[1, :]:
    ax.set_xlabel('Latitude')
for ax in trk[:, 1]:
    ax.set_yticklabels([])             # right-beam panels share the left panels' y labels
for ax in trk[:, 0]:
    ax.set_ylabel('Surface elevation (m)')

gs.update(left=0.000, right=0.99, top=0.95, bottom=0.1, wspace=0.05, hspace=0.1)

if SAVE:
    # fig_save_dir is the configured output directory (save_dir was undefined)
    os.makedirs(fig_save_dir, exist_ok=True)
    fig.savefig(os.path.join(fig_save_dir, save_name + fig_type), dpi=600)

In [None]:
# (Removed a leftover interactive `plt.text?` help lookup; run such lookups in a scratch cell.)