In [1]:
# Import required packages
import os
import numpy as np
import pandas as pd
import rasterio as rio
import matplotlib.pyplot as plt
import xarray as xr
import datetime as dt
import rioxarray
import seaborn as sns
import geopandas as gpd
from glob import glob
import gc
import math
import datashader as dsh
from datashader import transfer_functions as tf 
from colorcet import palette
from datashader.mpl_ext import dsshow, alpha_colormap
from PIL import Image, ImageFont, ImageDraw

In [2]:
# Functions to load MintPy HDF5 outputs into georeferenced xarray Datasets

# Build per-pixel y/x coordinate arrays from MintPy geocoding metadata.
def coord_range(ds):
    """Return ``(latrange, lonrange)`` coordinate arrays for a geocoded MintPy dataset.

    Reads the ``Y_FIRST``/``Y_STEP``/``LENGTH`` and ``X_FIRST``/``X_STEP``/``WIDTH``
    entries of ``ds.attrs`` and returns one coordinate per pixel row / column:
    ``FIRST + i * STEP`` for ``i`` in ``[0, size)``.

    Parameters
    ----------
    ds : object with an ``attrs`` mapping (e.g. an xarray Dataset opened from a MintPy file).
        Attribute values may be strings; they are converted with ``float``/``int``.

    Returns
    -------
    latrange : numpy.ndarray of length ``LENGTH``
    lonrange : numpy.ndarray of length ``WIDTH``
    """
    def _axis(first_key, step_key, size_key):
        first = float(ds.attrs[first_key])
        step = float(ds.attrs[step_key])
        size = int(ds.attrs[size_key])
        # Exact STEP spacing. The previous np.linspace(first, first + step*size, size) put the
        # last sample on the far edge, stretching the spacing to step*size/(size-1).
        return first + step * np.arange(size)

    latrange = _axis('Y_FIRST', 'Y_STEP', 'LENGTH')
    lonrange = _axis('X_FIRST', 'X_STEP', 'WIDTH')
    return latrange, lonrange

# Read a MintPy time-series HDF5 file into a georeferenced xarray Dataset.
def mintpyTS_to_xarray(fn, crs):
    """Open a MintPy time-series file as an xarray Dataset with time/y/x coordinates.

    Steps:
    * open ``fn`` with ``xr.open_dataset(..., cache=False)``;
    * rename the anonymous ``phony_dim_1/2/3`` dimensions to ``time``/``y``/``x``;
    * rename the ``timeseries`` variable to ``displacement``;
    * attach coordinates: ``time`` parsed from the file's ``date`` variable, and
      ``y``/``x`` from the geocoding attributes via ``coord_range``;
    * drop the ``bperp`` and ``date`` variables;
    * write ``crs`` with rioxarray.

    Parameters
    ----------
    fn : str or path-like
        Path to the MintPy time-series HDF5 file.
    crs : any value accepted by ``Dataset.rio.write_crs`` (e.g. an EPSG code like 32612).

    Returns
    -------
    xarray.Dataset
    """
    ds = xr.open_dataset(fn, cache=False)
    # NOTE(review): the phony_dim_* -> (time, y, x) mapping depends on how the HDF5
    # backend numbers the anonymous dimensions; confirm if the MintPy version changes.
    ds = ds.rename_dims({'phony_dim_1': 'time',
                         'phony_dim_2': 'y',
                         'phony_dim_3': 'x'})
    ds = ds.rename({'timeseries': 'displacement'})
    latrange, lonrange = coord_range(ds)
    ds = ds.assign_coords({'time': ('time', pd.to_datetime(ds.date)),
                           'y': ('y', latrange),
                           'x': ('x', lonrange)})
    # Dataset.drop is deprecated for removing variables; drop_vars is the supported API.
    ds = ds.drop_vars(['bperp', 'date'])
    ds = ds.rio.write_crs(crs)
    return ds

def mintpy2d_to_xarray(fn, crs):
    """Open a 2-D MintPy HDF5 product as a georeferenced xarray Dataset.

    ``phony_dim_0`` is renamed to ``y`` and ``phony_dim_1`` to ``x``. Both get
    coordinate values from ``coord_range``, and ``crs`` is written with rioxarray.

    Parameters
    ----------
    fn : str or path-like
        Path to the MintPy HDF5 file.
    crs : any value accepted by ``Dataset.rio.write_crs``.

    Returns
    -------
    xarray.Dataset
    """
    dim_names = {'phony_dim_0': 'y', 'phony_dim_1': 'x'}
    raw = xr.open_dataset(fn).rename_dims(dim_names)
    lat_values, lon_values = coord_range(raw)
    georef = raw.assign_coords({'y': ('y', lat_values), 'x': ('x', lon_values)})
    return georef.rio.write_crs(crs)

def mintpyInputs_to_xarray(fn, crs):
    """Open a MintPy input-stack HDF5 file as a georeferenced xarray Dataset.

    Steps:
    * rename ``phony_dim_0..3`` to ``reference_time``/``secondary_time``/``y``/``x``;
    * build ``reference_time`` from the first element, and ``secondary_time`` from
      the second element, of each entry in the file's ``date`` variable;
    * attach ``y``/``x`` coordinates from ``coord_range``;
    * drop ``date`` and write ``crs`` with rioxarray.

    Parameters
    ----------
    fn : str or path-like
        Path to the MintPy HDF5 file.
    crs : any value accepted by ``Dataset.rio.write_crs``.

    Returns
    -------
    xarray.Dataset
    """
    ds = xr.open_dataset(fn)
    ds = ds.rename_dims({'phony_dim_0': 'reference_time',
                         'phony_dim_1': 'secondary_time',
                         'phony_dim_2': 'y',
                         'phony_dim_3': 'x'})
    latrange, lonrange = coord_range(ds)
    # NOTE(review): assumes each `date` entry is a (reference, secondary) pair; confirm its
    # length matches both the reference_time and secondary_time dimension sizes.
    date_pairs = ds.date.values
    ds = ds.assign_coords({'reference_time': ('reference_time', pd.to_datetime([span[0] for span in date_pairs])),
                           'secondary_time': ('secondary_time', pd.to_datetime([span[1] for span in date_pairs])),
                           'y': ('y', latrange),
                           'x': ('x', lonrange)})
    # Dataset.drop is deprecated for removing variables; drop_vars is the supported API.
    ds = ds.drop_vars('date')
    ds = ds.rio.write_crs(crs)
    return ds

In [6]:
def open_summer_mintpy(orbit, year_list, suffix='', home_path='/mnt/d/indennt/mintpy_app', crs=32612):
    """Load one MintPy time series per year for a given orbit.

    For each year, opens
    ``{home_path}/{orbit}/mintpy_{year}{suffix}/timeseries_ramp_demErr.h5`` with
    ``mintpyTS_to_xarray`` and tags the result with a length-1 ``year`` coordinate.

    Parameters
    ----------
    orbit : str
        Orbit directory name, e.g. ``'AT137'``.
    year_list : iterable of str or int
        Years to load; each must be convertible with ``int``.
    suffix : str, default ''
        Appended to each per-year MintPy directory name.
    home_path : str, default '/mnt/d/indennt/mintpy_app'
        Root of the MintPy runs. This was previously hard-coded; the default keeps the old value.
    crs : default 32612
        Passed through to ``mintpyTS_to_xarray``. This was previously hard-coded to the same value.

    Returns
    -------
    list of xarray.Dataset, in ``year_list`` order.
    """
    ds_list = []
    for year in year_list:
        ts_path = f'{home_path}/{orbit}/mintpy_{year}{suffix}/timeseries_ramp_demErr.h5'
        ts_ds = mintpyTS_to_xarray(ts_path, crs)
        ds_list.append(ts_ds.assign_coords({'year': ('year', [int(year)])}))
    return ds_list

def tile_images(images, ncols=6):
    """Paste PIL images onto one RGB canvas in a row-major grid.

    Images are placed left to right, wrapping to a new row after every ``ncols``
    images. Each grid cell is as wide as the widest image and as tall as the tallest.
    The canvas is ``ncols`` cells wide and has exactly ``ceil(len(images) / ncols)``
    rows, so no image is pasted outside it. The previous version always used 2 rows
    and silently lost every image after the 12th.

    Parameters
    ----------
    images : iterable of PIL.Image.Image
    ncols : int, default 6
        Images per row (the previously hard-coded value).

    Returns
    -------
    PIL.Image.Image

    Raises
    ------
    ValueError
        If ``images`` is empty or ``ncols`` < 1.
    """
    images = list(images)
    if not images:
        raise ValueError('tile_images() needs at least one image')
    if ncols < 1:
        raise ValueError(f'ncols must be >= 1, got {ncols}')

    cell_w = max(im.size[0] for im in images)
    cell_h = max(im.size[1] for im in images)
    nrows = math.ceil(len(images) / ncols)
    output = Image.new("RGB", (cell_w * ncols, cell_h * nrows))
    for i, im in enumerate(images):
        row, col = divmod(i, ncols)
        output.paste(im, (col * cell_w, row * cell_h))
    return output

def draw_text(img, text, size=30, face="Regular", color='black'):
    """Draw ``text`` at the top-left corner of ``img`` and return the same image object.

    Parameters
    ----------
    img : PIL.Image.Image
        Image to draw on (modified in place through ``ImageDraw.Draw``).
    text : str
        Label to draw.
    size : int, default 30
        Font size passed to ``ImageFont.truetype``. The previous version ignored this
        argument and always used 50.
    face : str
        Currently unused; kept so existing calls keep working.
    color : PIL colour spec, default 'black'
        Fill colour of the text.

    Returns
    -------
    PIL.Image.Image
        ``img`` itself.
    """
    d = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype('arial.ttf', size)
    except OSError:
        # arial.ttf is often absent on Linux/WSL; fall back to PIL's built-in default font
        # instead of aborting the whole figure.
        font = ImageFont.load_default()
    d.text((0, 0), text, font=font, fill=color)
    return img

def plot_image_list(ds, correction, out_dir='../../figs', span=(-0.01, 0.01), scale=0.05):
    """Render each time step of ``ds.displacement`` and save all frames as one tiled PNG.

    For every element of ``ds.time``, the slice ``ds.displacement[i]`` is processed:
    * flip its first axis;
    * rasterise it with datashader onto a canvas ``scale`` times its size, using mean
      downsampling;
    * shade it with colorcet's 'coolwarm' palette over ``span``;
    * label it with its ``%Y-%m-%d`` date.

    The frames are combined with ``tile_images`` and saved to
    ``{out_dir}/{year}_{correction}.png``.

    Parameters
    ----------
    ds : xarray.Dataset
        Needs a ``displacement`` variable, a ``time`` coordinate and a single-valued
        ``year`` coordinate. Presumably laid out (time, y, x) as built by
        ``mintpyTS_to_xarray``; confirm for other inputs.
    correction : str
        Tag used in the output file name.
    out_dir : str, default '../../figs'
        Output directory (previously hard-coded). Created if missing.
    span : (float, float), default (-0.01, 0.01)
        Colour-scale limits passed to ``tf.shade`` (previously hard-coded).
    scale : float, default 0.05
        Output canvas size as a fraction of the raster size (previously hard-coded).

    Returns
    -------
    PIL.Image.Image
        The tiled image that was saved.
    """
    cmap = palette['coolwarm']
    os.makedirs(out_dir, exist_ok=True)
    img_list = []

    for i, time in enumerate(ds.time):
        # reverse the first (row) axis of this 2-D slice before rasterising
        da = np.flip(ds.displacement[i], 0)
        # datashader needs at least a 1x1 canvas; small rasters would otherwise round to 0 px
        ys = max(1, int(da.shape[0] * scale))
        xs = max(1, int(da.shape[1] * scale))
        agg = dsh.Canvas(plot_width=xs, plot_height=ys).raster(da, downsample_method='mean')
        img = tf.shade(agg, cmap=cmap, span=list(span), how='linear').to_pil()
        # label size given explicitly (50) so the date labels keep their rendered size
        img_list.append(draw_text(img, time.dt.strftime('%Y-%m-%d').item(), size=50))

    output = tile_images(img_list)
    output.save(os.path.join(out_dir, f'{ds.year.item()}_{correction}.png'))
    return output

In [4]:
orbit = 'AT137'
year_list = [str(year) for year in range(2017, 2022)]

# One list of per-year datasets per processing variant; the unsuffixed directories are
# loaded as ds_list_CNN.
(ds_list_CNN,
 ds_list_uncorrected,
 ds_list_murp,
 ds_list_era5) = [open_summer_mintpy(orbit, year_list, suffix=variant)
                  for variant in ('', '_uncorrected', '_MuRP', '_ERA5')]

In [8]:
# Save one tiled PNG per year for the uncorrected time series.
for ds in ds_list_uncorrected:
    plot_image_list(ds=ds, correction='uncorrected')

In [7]:
# Save one tiled PNG per year for the CNN-corrected time series.
for ds in ds_list_CNN:
    plot_image_list(ds=ds, correction='CNN')

In [None]:
# for ds in ds_list_murp:
#     plot_image_list(ds, 'murp')

In [None]:
# for ds in ds_list_era5:
#     plot_image_list(ds, 'era5')