# Figuring out MuRP

In [1]:
from glob import glob
import xarray as xr
import rasterio as rio
import rioxarray
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model
from sklearn.metrics import mean_squared_error

In [14]:
# load in single igram and other data 
def hyp3_to_xarray_single(path):
    '''
    Reads unwrapped phase, coherence, and DEM into xarray dataset from single hyp3 folder

    Parameters
    ----------
    path : str
        Folder of one HyP3 interferogram product. It must contain files
        matching ``*unw_phase.tif``, ``*dem.tif`` and ``*corr.tif``; if
        several files match one pattern, the alphabetically first is used.

    Returns
    -------
    xarray.Dataset
        ``unw_phase`` (from the still-open phase file), plus ``coherence``
        and ``elevation`` (copied into memory) on the phase raster's grid,
        with the length-1 ``band`` and ``granule`` dimensions squeezed away.
        Assumes all three rasters are single-band and share the same grid.

    Raises
    ------
    FileNotFoundError
        If any of the three rasters is missing from ``path``.
    '''
    def _find(pattern):
        # sorted() makes the pick deterministic when more than one file matches
        matches = sorted(glob(f'{path}/{pattern}'))
        if not matches:
            raise FileNotFoundError(f'no file matching {pattern!r} in {path!r}')
        return matches[0]

    unw_phase_path = _find('*unw_phase.tif')
    dem_path = _find('*dem.tif')
    corr_path = _find('*corr.tif')

    # granule name = phase file name without its 14-character '_unw_phase.tif' suffix
    granule = os.path.basename(unw_phase_path)[:-14]

    # open unwrapped phase (left open: its data may still be read from disk later)
    phase_ds = xr.open_dataset(unw_phase_path)
    phase_ds = phase_ds.assign_coords({'granule': ('granule', [granule])})
    ds = phase_ds.rename({'band_data': 'unw_phase'})

    # coherence and DEM are copied into memory via .values, so close their files right away
    with xr.open_dataset(corr_path) as corr_ds:
        ds['coherence'] = (('band', 'y', 'x'), corr_ds.band_data.values)
    with xr.open_dataset(dem_path) as dem_ds:
        ds['elevation'] = (('band', 'y', 'x'), dem_ds.band_data.values)

    # drop the length-1 band and granule dimensions
    ds = ds.squeeze()

    return ds

In [22]:
def select_refs(ds, corr_thresh=0.9, n_refs=1000, rng=None):
    '''
    Randomly pick ``n_refs`` distinct pixels whose coherence is >= ``corr_thresh``.

    Parameters
    ----------
    ds : xarray.Dataset
        Must have a ``coherence`` variable with ``y`` and ``x`` dimensions.
    corr_thresh : float
        Minimum coherence for a pixel to be a reference candidate. NaN
        coherence never qualifies.
    n_refs : int
        Number of reference pixels to draw (without replacement).
    rng : None, int or numpy.random.Generator
        Passed to ``np.random.default_rng``; give an int for reproducible picks.

    Returns
    -------
    list of [int, int]
        ``[x, y]`` (column, row) pixel-index pairs, in random order.

    Raises
    ------
    ValueError
        If fewer than ``n_refs`` pixels meet the threshold (the previous
        rejection-sampling loop would never terminate in that case).
    '''
    rng = np.random.default_rng(rng)
    coherence = np.asarray(ds.coherence.transpose('y', 'x').values)

    # one vectorised pass instead of a random isel per draw; NaN compares False
    with np.errstate(invalid='ignore'):
        rows, cols = np.nonzero(coherence >= corr_thresh)

    if len(cols) < n_refs:
        raise ValueError(
            f'only {len(cols)} pixels have coherence >= {corr_thresh}, '
            f'cannot select {n_refs} reference points'
        )

    picks = rng.choice(len(cols), size=n_refs, replace=False)
    return [[int(cols[i]), int(rows[i])] for i in picks]

def plot_points(ds, refs, output_dir):
    '''
    Save a map of the reference points drawn over the elevation raster.

    ``refs`` holds ``[x, y]`` pixel-index pairs, plotted in image
    (column, row) coordinates on top of ``ds.elevation``. The figure is
    written to ``{output_dir}/ref_points.png`` at 300 dpi and left open.
    '''
    cols, rows = zip(*refs)

    fig, ax = plt.subplots()
    ax.imshow(ds.elevation, cmap='viridis')
    ax.plot(cols, rows, linestyle='', marker='o', color='Orange')
    ax.set_title('reference points')
    ax.set_aspect('equal')

    fig.savefig(f'{output_dir}/ref_points.png', dpi=300)

def sample_refs(ds, refs):
    '''
    Sample unwrapped phase and elevation at the given reference pixels.

    Parameters
    ----------
    ds : xarray.Dataset
        Must have ``unw_phase`` and ``elevation`` variables with ``y`` and
        ``x`` dimensions.
    refs : sequence of [x, y]
        Integer (column, row) pixel indices, e.g. from ``select_refs``.

    Returns
    -------
    (list of float, list of float)
        ``(ref_phase, ref_elevation)``, one value per reference, in ``refs`` order.
    '''
    cols = [ref[0] for ref in refs]
    rows = [ref[1] for ref in refs]

    # read each raster once and fancy-index it, instead of one isel(...).item()
    # per point, which for a lazily opened raster can go back to the file each time
    phase = np.asarray(ds.unw_phase.transpose('y', 'x').values)
    elevation = np.asarray(ds.elevation.transpose('y', 'x').values)

    ref_phase = phase[rows, cols].tolist()
    ref_elevation = elevation[rows, cols].tolist()
    return ref_phase, ref_elevation

def plot_first_scatter(ref_phase, ref_elevation, output_dir):
    '''
    Save a scatter plot of reference-point phase against elevation.

    Written to ``{output_dir}/first_igram_elevation_phase.png`` at 300 dpi;
    the figure is left open.
    '''
    fig, ax = plt.subplots()
    ax.plot(ref_elevation, ref_phase, linestyle='', marker='o')
    ax.set(xlabel='elevation', ylabel='phase',
           title='elevation and phase of ref points in first igram')

    fig.savefig(f'{output_dir}/first_igram_elevation_phase.png', dpi=300)

def linear_fits(ds, ref_phase, ref_elevation):
    '''
    Least-squares fit of ``phase = slope * elevation + intercept`` over the reference points.

    Parameters
    ----------
    ds : object
        Unused; kept so existing calls keep working.
    ref_phase, ref_elevation : sequence of float
        Paired samples (same length). Pairs where either value is NaN or
        infinite are ignored.

    Returns
    -------
    list of float
        ``[slope, intercept]``.

    Raises
    ------
    ValueError
        If fewer than two pairs are finite, since no line can be fitted.
    '''
    elevation = np.asarray(ref_elevation, dtype=float)
    phase = np.asarray(ref_phase, dtype=float)

    # drop no-data pairs (isfinite also rejects +/-inf, unlike isnan)
    valid = np.isfinite(elevation) & np.isfinite(phase)
    if np.count_nonzero(valid) < 2:
        raise ValueError(
            f'need at least 2 finite reference points for a linear fit, got {np.count_nonzero(valid)}'
        )

    # degree-1 polyfit is ordinary least squares, same estimator as LinearRegression
    slope, intercept = np.polyfit(elevation[valid], phase[valid], 1)
    return [float(slope), float(intercept)]

def correct_igrams(ds, fits):
    '''
    Subtract the fitted elevation-dependent phase from the unwrapped phase.

    ``fits`` is ``[slope, intercept]``; ``elevation * slope + intercept`` is
    removed from ``ds.unw_phase`` and stored as ``ds['unw_phase_MuRP']``.
    ``ds`` is modified in place and also returned.
    '''
    slope = fits[0]
    intercept = fits[1]
    # plain ndarray (coords stripped), so the subtraction is positional against unw_phase
    modelled_phase = ds.elevation.values * slope + intercept
    ds['unw_phase_MuRP'] = ds.unw_phase - modelled_phase
    return ds

def plot_correction(ds, output_dir, vmin=-5, vmax=5, cmap='RdBu'):
    '''
    Save a two-panel comparison of MuRP-corrected and uncorrected unwrapped phase.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``unw_phase_MuRP`` and ``unw_phase``.
    output_dir : str
        Folder the figure is written to as ``MuRP_correction.png`` (300 dpi).
    vmin, vmax : float
        Shared colour limits for both panels, so they are directly comparable.
    cmap : str
        Matplotlib colormap name.

    The figure is left open.
    '''
    fig, (ax_corr, ax_raw) = plt.subplots(2, 1, figsize=(10, 5))

    # titles describe a single interferogram's phase, not a stack mean
    panels = (
        (ax_corr, ds.unw_phase_MuRP, 'MuRP corrected unwrapped phase'),
        (ax_raw, ds.unw_phase, 'uncorrected unwrapped phase'),
    )
    for ax, data, title in panels:
        data.plot(ax=ax, vmin=vmin, vmax=vmax, cmap=cmap)
        ax.set_aspect('equal')
        ax.set_title(title)
    fig.tight_layout()

    fig.savefig(f'{output_dir}/MuRP_correction.png', dpi=300)

# single function
def MuRP(ds, corr_thresh=0.8, n_refs=1000, figs=True, fig_dir='../figs'):
    '''
    Correct unwrapped phase with linear fit to multiple stable reference points

    Parameters
    ----------
    ds : xarray.Dataset
        Needs ``unw_phase``, ``elevation`` and ``coherence`` variables
        (e.g. the output of ``hyp3_to_xarray_single``).
    corr_thresh : float
        Minimum coherence for a pixel to be used as a reference point.
    n_refs : int
        Number of reference points to sample.
    figs : bool
        If true, save diagnostic figures to ``fig_dir``.
    fig_dir : str
        Output folder for the figures; created if missing.

    Returns
    -------
    xarray.Dataset
        ``ds`` with ``unw_phase_MuRP`` added (the same object, modified in place).
    '''
    print('selecting reference points')
    # forward the caller's settings (previously select_refs always used its own defaults)
    refs = select_refs(ds, corr_thresh=corr_thresh, n_refs=n_refs)

    print('sampling reference points')
    ref_values, ref_elevation = sample_refs(ds, refs)

    print('calculating linear fits')
    fits = linear_fits(ds, ref_values, ref_elevation)

    print('correcting interferograms')
    ds = correct_igrams(ds, fits)

    if figs:
        os.makedirs(fig_dir, exist_ok=True)
        print('saving figures')
        plot_points(ds, refs, fig_dir)
        plot_first_scatter(ref_values, ref_elevation, fig_dir)
        plot_correction(ds, fig_dir)

    return ds

In [9]:
def multi_year_MuRP(orbit, frame_list, year_list, base_dir='../proc/data/hyp3'):
    '''
    Apply MuRP to every HyP3 product folder under ``{base_dir}/{orbit}/{frame}/{year}``.

    For each product folder ``granule``, the corrected phase is written to
    ``{granule_folder}/{granule}_unw_phase_MuRP.tif``. Figures are not produced.

    Parameters
    ----------
    orbit : str
        Orbit folder name (e.g. ``'DT56'``).
    frame_list, year_list : iterable of str
        Frame and year folder names to process.
    base_dir : str
        Root folder containing the orbit folders.
    '''
    for frame in frame_list:
        for year in year_list:
            hyp3_path = f'{base_dir}/{orbit}/{frame}/{year}'
            # only sub-folders are products (stray files would crash the loader);
            # sorted for a reproducible processing order
            granules = sorted(
                name for name in os.listdir(hyp3_path)
                if os.path.isdir(f'{hyp3_path}/{name}')
            )
            for granule in granules:
                print(f'working on {granule}')
                granule_path = f'{hyp3_path}/{granule}'
                ds = hyp3_to_xarray_single(granule_path)
                try:
                    ds = MuRP(ds, figs=False)
                    ds.unw_phase_MuRP.rio.to_raster(f'{granule_path}/{granule}_unw_phase_MuRP.tif')
                finally:
                    # release the raster file handle before moving to the next product
                    ds.close()

In [23]:
# run MuRP on every 2017 frame_3 product of orbit DT56
multi_year_MuRP(orbit='DT56', frame_list=['frame_3'], year_list=['2017'])

working on S1AA_20170607T130955_20170619T130955_VVP012_INT40_G_ueF_3D12
selecting reference points
sampling reference points
calculating linear fits
correcting interferograms
working on S1AA_20170607T130955_20170701T130956_VVP024_INT40_G_ueF_4ED3
selecting reference points
sampling reference points
calculating linear fits
correcting interferograms
working on S1AA_20170607T130955_20170713T130956_VVP036_INT40_G_ueF_E6F8
selecting reference points
sampling reference points
calculating linear fits
correcting interferograms
working on S1AA_20170607T130955_20170725T130957_VVP048_INT40_G_ueF_2D1D
selecting reference points
sampling reference points
calculating linear fits
correcting interferograms
working on S1AA_20170607T130955_20170806T130958_VVP060_INT40_G_ueF_AE9B
selecting reference points
sampling reference points
calculating linear fits
correcting interferograms
working on S1AA_20170619T130955_20170701T130956_VVP012_INT40_G_ueF_AAC0
selecting reference points
sampling reference points
