In [1]:
import copy
import os
from glob import glob

import dask
import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio as rio
import rioxarray
import xarray as xr
import zarr
from dask.distributed import Client

## read in data

In [2]:
# Root directory of the UAVSAR data. Set the RG_UAVSAR_DATA environment variable to run
# this notebook on another machine; the original local path is used otherwise.
data_path = os.environ.get('RG_UAVSAR_DATA', '/mnt/Backups/gbrench/repos/rg_uavsar/data')

# Swatches included in the stack (kept in sorted order: the commented-out zarr build step
# pairs these names with sorted(glob(...)) file names).
SWATCH_NAMES = (
    'swatch_00540',
    'swatch_08301',
    'swatch_09510',
    'swatch_17304_s1',
    'swatch_17304_s2',
    'swatch_18525',
    'swatch_26300',
    'swatch_27518',
)

# Directory holding each swatch's 'geo' products, keyed by swatch name.
path_dic = {name: f'{data_path}/{name}/geo' for name in SWATCH_NAMES}

In [3]:
# count_map = xr.open_dataset(f'{data_path}/coverage_mosaic/uavsar-tile-0-count.tif')
# count_map = count_map.rio.reproject('EPSG:32613')

# # open all data and save to netcdfs
# ds_dic = {}
# for key, item in path_dic.items():
#     print(f'opening {key}')
#     ds = xr.open_dataset(f'{item}/{key}_20230725_20230925.nc').chunk({'x': 1000, 'y': 1000})
#     ds = ds.rio.write_crs('32613')
#     ds = ds.rio.reproject_match(count_map)
#     # save to full size netcdf
#     ds.to_netcdf(f'{data_path}/coverage_mosaic/{key}_20230725_20230925.nc')

In [4]:
# # reopen data as single dataset and save to zarr file, chunked along time
# NOTE(review): the chunking below puts the whole 'swatch' dimension in one chunk; there is
# no time dimension here, so "chunked along time" above presumably means along 'swatch'.
# NOTE(review): the 'swatch' labels come from path_dic key order while the files come from
# sorted(glob(...)); they only line up because the path_dic keys are already in sorted order.
# fns = sorted(glob(f'{data_path}/coverage_mosaic/*_20230725_20230925.nc'))
# ds = xr.open_mfdataset(fns, combine='nested', concat_dim=xr.DataArray(list(path_dic.keys()), dims='swatch', name='swatch'), chunks={'x': 1000, 'y': 1000})
# ds = ds.drop_vars(['latitude', 'longitude', 'slantRangeDistance', 'displacement', 'rgi'])
# ds = ds.chunk({'swatch': -1})

# ds.to_zarr(f'{data_path}/combined/stack_20230725_20230925.zarr')

In [5]:
# Start a local Dask cluster. The Client registers itself as the default scheduler,
# so the lazy xarray/dask work in the cells below runs on it.
client = Client()
print(f'{client.dashboard_link}')  # dashboard URL for monitoring task progress

http://127.0.0.1:8787/status


In [6]:
# Lazily open the combined stack (dask-backed via chunks='auto') and tag it as EPSG:32613
# (WGS 84 / UTM zone 13N).
ds = xr.open_dataset(
    f'{data_path}/combined/stack_20230725_20230925.zarr',
    engine='zarr',
    chunks='auto',
).rio.write_crs('EPSG:32613')

## create geometry arrays

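**Note:** UAVSAR is left-looking (the radar looks to the left of the flight direction), which matters when the azimuth angle is converted into a line-of-sight angle below.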

In [7]:
# Line-of-sight angle (degrees) derived from the azimuth angle:
# positive azimuths map to 360 - az, the rest to |az|.
ds['losAngle'] = xr.where(ds['azimuthAngle'] > 0, 360 - ds['azimuthAngle'], np.fabs(ds['azimuthAngle']))

# Temporary fix: these two swatches use 180 + azimuthAngle instead.
# TODO: find out why they differ from the other swatches.
ds['losAngle'] = ds['losAngle'].where(
    ~ds.swatch.isin(['swatch_27518', 'swatch_26300']),
    180 + ds['azimuthAngle'],
)

In [8]:
# Line-of-sight unit-vector components (north, east, up) from the LOS angle and the
# incidence angle, both in degrees.
# NOTE(review): the sign convention (ground->sensor vs sensor->ground) is not visible here — confirm.
ds = ds.assign(
    n_hat=np.cos(np.radians(ds.losAngle)) * np.sin(np.radians(ds.incidenceAngle)),
    e_hat=np.sin(np.radians(ds.losAngle)) * np.sin(np.radians(ds.incidenceAngle)),
    u_hat=-np.cos(np.radians(ds.incidenceAngle)),
)

In [9]:
# Mask by coherence: keep only pixels whose coherence is at least corr_value.
# (A NaN coherence compares False, so those pixels are always masked.)
corr_value = 0

ds = ds.assign(
    displacement_masked=ds.displacement_MuRP.where(ds.cor >= corr_value),
    **{
        component: ds[component].where(ds.cor >= corr_value)
        for component in ('n_hat', 'e_hat', 'u_hat')
    },
)

In [None]:
# plot for sanity check
# my_slice = ds.isel(swatch=6, x=slice(3000, 4000), y=slice(6000, 7000))
# f, ax = plt.subplots(1, 3, figsize=(10, 5))
# my_slice.n_hat.plot.imshow(ax=ax[0], cbar_kwargs={'shrink':0.5})
# ax[0].set_aspect('equal')
# my_slice.e_hat.plot.imshow(ax=ax[1], cbar_kwargs={'shrink':0.5})
# ax[1].set_aspect('equal')
# my_slice.u_hat.plot.imshow(ax=ax[2], cbar_kwargs={'shrink':0.5})
# ax[2].set_aspect('equal')
# f.tight_layout()

## Calculate 3D displacement

In [10]:
def calculate_pixel_displacement(n_hat, e_hat, u_hat, displacement, coherence=None, min_obs=3):
    """Invert one pixel's LOS displacements from several flight lines into 3-D displacement.

    Solves the (optionally weighted) least-squares system
    ``A @ [d_n, d_e, d_u] ~= displacement``, where each row of ``A`` is one flight
    line's LOS unit vector ``[n_hat, e_hat, u_hat]``.

    Parameters
    ----------
    n_hat, e_hat, u_hat : 1-D array-like
        North, east and up components of the LOS unit vector, one entry per flight line.
    displacement : 1-D array-like
        LOS displacement, one entry per flight line.
    coherence : 1-D array-like, optional
        Per-flight-line coherence. When given, each observation gets weight
        ``coherence**2`` in the least-squares misfit.
    min_obs : int, default 3
        Minimum number of valid (all-finite) flight lines needed to attempt the inversion.

    Returns
    -------
    y_displacement, x_displacement, z_displacement, standard_error : np.ndarray
        Each is a 1-element float array. The first three are the north, east and up
        components; they are NaN when fewer than ``min_obs`` valid flight lines remain
        or the geometry does not constrain all three components (rank < 3).
        ``standard_error`` is ``sqrt(SSR / (n_obs - 3))`` of the (weighted) fit, or NaN
        when the system has no redundancy.
    """
    n_hat, e_hat, u_hat, displacement = (np.asarray(a) for a in (n_hat, e_hat, u_hat, displacement))

    # Keep only flight lines where every input is finite.
    valid = np.isfinite(n_hat) & np.isfinite(e_hat) & np.isfinite(u_hat) & np.isfinite(displacement)
    if coherence is not None:
        coherence = np.asarray(coherence)
        # A NaN coherence would otherwise turn the weights, and hence the whole solve, into NaN.
        valid &= np.isfinite(coherence)
    n_hat, e_hat, u_hat, displacement = (a[valid] for a in (n_hat, e_hat, u_hat, displacement))

    if displacement.size < min_obs:
        return tuple(np.array([np.nan]) for _ in range(4))

    A = np.column_stack([n_hat, e_hat, u_hat])
    b = displacement
    if coherence is not None:
        # Weighted least squares with weights w_i = coherence_i**2 is ordinary least squares
        # on rows scaled by sqrt(w_i) = |coherence_i|. Scaling by w_i itself would weight by coherence**4.
        row_scale = np.abs(coherence[valid])
        A = A * row_scale[:, np.newaxis]
        b = b * row_scale

    solution, residuals, rank, _ = np.linalg.lstsq(A, b, rcond=None)

    if rank < A.shape[1]:
        # The flight-line geometry does not constrain all three components; lstsq would
        # silently return the minimum-norm solution, which is not a physical answer.
        return tuple(np.array([np.nan]) for _ in range(4))

    # lstsq returns an empty residual array when the system is exactly determined,
    # and a shape-(1,) array otherwise (b is 1-D).
    if residuals.size > 0:
        standard_error = np.sqrt(residuals[0] / (b.size - rank))
    else:
        standard_error = np.nan

    # 1-element arrays, matching the output shape callers of this function already expect.
    return (
        np.array([solution[0]]),
        np.array([solution[1]]),
        np.array([solution[2]]),
        np.array([standard_error]),
    )

def calculate_displacement(dataset):
    """Invert every pixel's stack of LOS displacements into 3-D displacement components.

    ``calculate_pixel_displacement`` is applied along the ``swatch`` dimension of
    ``dataset``. It uses the ``n_hat``, ``e_hat``, ``u_hat``, ``displacement_masked``
    and ``cor`` variables; ``cor`` is passed as the coherence used for weighting.

    Parameters
    ----------
    dataset : xr.Dataset
        Stack with a ``swatch`` dimension holding the variables listed above.

    Returns
    -------
    tuple of xr.DataArray
        ``(y_displacement, x_displacement, z_displacement, standard_error)``. Each array
        spans the inputs' non-``swatch`` dimensions and has been passed through
        ``dask.persist``.
    """
    core_dim = 'swatch'
    inputs = [dataset[name] for name in ('n_hat', 'e_hat', 'u_hat', 'displacement_masked', 'cor')]

    results = xr.apply_ufunc(
        calculate_pixel_displacement,
        *inputs,
        input_core_dims=[[core_dim]] * len(inputs),
        output_core_dims=[[], [], [], []],
        vectorize=True,
        dask='parallelized',
        # Declare the output dtypes explicitly rather than letting dask infer them by
        # probing the vectorized function.
        output_dtypes=[float] * 4,
        # dask='parallelized' requires the core dimension in a single chunk. Let dask merge
        # 'swatch' chunks if chunks='auto' ever splits it, instead of raising ValueError.
        dask_gufunc_kwargs={'allow_rechunk': True},
    )

    # Persist so the inversion is computed once, not again on every later access.
    y_displacement, x_displacement, z_displacement, standard_error = dask.persist(*results)
    return y_displacement, x_displacement, z_displacement, standard_error

In [11]:
# Solve for the three displacement components (plus standard error) at every pixel.
(y_displacement,
 x_displacement,
 z_displacement,
 standard_error) = calculate_displacement(ds)

In [None]:
# total_displacement = np.sqrt(y_displacement**2 + x_displacement**2 + z_displacement**2)

In [12]:
def save_out(data_path, y_displacement, x_displacement, z_displacement, standard_error):
    """Write the displacement components and standard error as tiled, LZW-compressed GeoTIFFs.

    Files are written, in this order, to ``{data_path}/combined/``: ``ns_disp.tif``,
    ``ew_disp.tif``, ``ud_disp.tif`` and ``standard_error_disp.tif``, each via the
    array's ``rio.to_raster``.
    """
    rasters_by_filename = {
        'ns_disp.tif': y_displacement,
        'ew_disp.tif': x_displacement,
        'ud_disp.tif': z_displacement,
        'standard_error_disp.tif': standard_error,
    }
    for filename, raster in rasters_by_filename.items():
        raster.rio.to_raster(f'{data_path}/combined/{filename}', tiled=True, windowed=True, compress="LZW")

save_out(
    data_path=data_path,
    y_displacement=y_displacement,
    x_displacement=x_displacement,
    z_displacement=z_displacement,
    standard_error=standard_error,
)

In [None]:
#total_displacement.rio.to_raster(f'{data_path}/combined/total_disp.tif', tiled=True, windowed=True, compress="LZW")

In [13]:
# Shut down the Dask client; since Client() was created with no arguments,
# this also closes the local cluster it started.
client.close()