# Generate subsets

In [1]:
import random
import os
import numpy as np
import pandas as pd
import rasterio as rio
import xarray as xr
import datetime as dt
import rioxarray
import geopandas as gpd
from glob import glob
import time
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
import pickle

In [2]:
# load in single igram and other data 
def hyp3_to_xarray_single(path):
    '''
    Reads hyp3 outputs into xarray dataset from single hyp3 folder

    The unwrapped phase raster defines the grid. The ERA5, MuRP, dem and
    corr rasters found in the same folder are added as the variables
    era5_phase, murp_phase, elevation and coherence, and every
    "key: value" line of the S1*.txt metadata file becomes a coordinate
    on the granule dimension.

    path: folder holding the outputs of one hyp3 product
    returns: xarray dataset with size-one dimensions squeezed out
    raises: IndexError if one of the expected files is missing
    '''
    # globs for data to load
    unw_phase_path = glob(f'{path}/*unw_phase.tif')[0]
    era5_path = glob(f'{path}/*ERA5.tif')[0]
    murp_path = glob(f'{path}/*MuRP.tif')[0]
    dem_path = glob(f'{path}/*dem.tif')[0]
    corr_path = glob(f'{path}/*corr.tif')[0]
    meta_path = glob(f'{path}/S1*.txt')[0]

    # granule name is the file name without its 14 character '_unw_phase.tif' ending
    granule = os.path.split(unw_phase_path)[-1][0:-14]

    # parse metadata; split on the first colon only so values that contain
    # colons survive, and skip lines that hold no key/value pair
    d = {}
    with open(meta_path) as f:
        for line in f:
            key, sep, val = line.partition(':')
            if not sep:
                continue
            d[key] = val.strip()

    # read unw_phase into dataset and assign coordinates
    da = xr.open_dataset(unw_phase_path)
    da = da.assign_coords({'granule': ('granule', [granule])})
    for item, value in d.items():
        da = da.assign_coords({item: ('granule', [value])})

    # rename variable
    ds = da.rename({'band_data': 'unw_phase'})

    # add the auxiliary rasters to the unw_phase dataset; each file is closed
    # again as soon as its values have been copied
    aux_layers = {
        'era5_phase': era5_path,
        'murp_phase': murp_path,
        'elevation': dem_path,
        'coherence': corr_path,  # coherence comes from the *corr.tif raster
    }
    for name, raster_path in aux_layers.items():
        with xr.open_dataset(raster_path) as aux_ds:
            ds[name] = (('band', 'y', 'x'), aux_ds.band_data.values)

    # remove band coordinate
    ds = ds.squeeze()

    return ds

def sample_ds(ds, subset_size=128):
    '''
    Cut a randomly placed square window out of a dataset

    ds: object with x and y dimensions that supports isel
    subset_size: edge length of the window in pixels
    returns: ds.isel(...) restricted to the chosen window
    '''
    # last start index that still leaves room for a full window
    last_col = len(ds.x) - subset_size
    last_row = len(ds.y) - subset_size

    # x is drawn before y
    col0 = random.randint(0, last_col)
    row0 = random.randint(0, last_row)

    window = {
        'x': slice(col0, col0 + subset_size),
        'y': slice(row0, row0 + subset_size),
    }
    return ds.isel(**window)

# set local ref with coherence, not in use currently
def subset_ref(subset_ds, corr_thresh=0.95):
    '''
    Re-reference the signal of a subset to its high-coherence pixels

    The median of signal over all pixels whose coherence is at least
    corr_thresh is subtracted from signal.

    subset_ds: dataset with signal and coherence variables on x and y
    corr_thresh: minimum coherence for a pixel to count as reference
    returns: subset_ds with signal replaced by the re-referenced signal
    '''
    # pixels below the threshold become NaN and are skipped by median
    coherent_signal = subset_ds['signal'].where(subset_ds['coherence'] >= corr_thresh)
    atmo_noise = coherent_signal.median(dim=['x', 'y'])
    subset_ds['signal'] = subset_ds['signal'] - atmo_noise
    return subset_ds

In [3]:
def select_refs(ds, corr_thresh, n_refs):
    '''
    Randomly pick reference pixels whose signal_corr reaches a threshold

    Pixels are drawn uniformly, with replacement, from all pixels where
    ds.signal_corr >= corr_thresh (NaN pixels never qualify).

    ds: dataset with a two-dimensional signal_corr variable on y and x
    corr_thresh: minimum signal_corr of a reference pixel
    n_refs: number of reference pixels to draw
    returns: list of [x, y] integer index pairs, n_refs long
    raises: ValueError if no pixel reaches corr_thresh
    '''
    if n_refs <= 0:
        return []

    # rows are y indices, columns are x indices
    corr = ds.signal_corr.transpose('y', 'x').values
    rows, cols = np.nonzero(corr >= corr_thresh)
    if len(rows) == 0:
        # drawing pixels one at a time would never terminate in this case
        raise ValueError(f'no pixel with signal_corr >= {corr_thresh}')

    picks = np.random.randint(0, len(rows), size=n_refs)
    return [[int(cols[k]), int(rows[k])] for k in picks]

def sample_refs(ds, refs):
    '''
    Look up elevation and signal at each reference pixel

    ds: dataset with elevation and signal variables on x and y
    refs: iterable of index pairs, x index first and y index second
    returns: (ref_phase, ref_elevation), two lists ordered like refs
    '''
    # one pass over refs, reading elevation before signal for every pixel
    picked = [
        (ds.elevation.isel(x=ref[0], y=ref[1]).item(),
         ds.signal.isel(x=ref[0], y=ref[1]).item())
        for ref in refs
    ]
    ref_elevation = [elev for elev, _ in picked]
    ref_phase = [phase for _, phase in picked]
    return ref_phase, ref_elevation

def linear_fits(ds, ref_phase, ref_elevation):
    '''
    Fit phase as a linear function of elevation at the reference points

    ds: not used by the fit
    ref_phase: phase values at the reference points
    ref_elevation: elevation values at the same points
    returns: [slope, intercept] of the least-squares line
    '''
    # one row per reference point: column 0 elevation, column 1 phase
    samples = np.array([ref_elevation, ref_phase]).T

    # drop points where either value is NaN
    complete = ~np.isnan(samples).any(axis=1)
    samples = samples[complete]

    elevation = samples[:, 0].reshape(-1, 1)
    phase = samples[:, 1]

    regressor = linear_model.LinearRegression()
    regressor.fit(elevation, phase)
    return [regressor.coef_.item(), regressor.intercept_]

def correct_igrams(ds, fits):
    '''
    Subtract a linear elevation trend from the signal

    ds: dataset with signal and elevation variables
    fits: sequence holding the slope first and the intercept second
    returns: ds with the detrended signal stored as signal_MuRP
    '''
    slope = fits[0]
    intercept = fits[1]

    # phase predicted from elevation alone
    trend = ds.elevation.values * slope + intercept
    ds['signal_MuRP'] = ds.signal - trend
    return ds

# single function
def MuRP(ds, corr_thresh=0.7, n_refs=100):
    '''
    Correct the signal with a linear phase-elevation fit through multiple
    reference points

    ds: dataset with signal, signal_corr and elevation variables
    corr_thresh: minimum signal_corr of a reference point
    n_refs: number of reference points to draw
    returns: ds with the corrected signal stored as signal_MuRP
    '''
    # pick reference points, then read phase and elevation there
    ref_points = select_refs(ds, corr_thresh=corr_thresh, n_refs=n_refs)
    phase_at_refs, elevation_at_refs = sample_refs(ds, ref_points)

    # fit the elevation trend and take it out of the signal
    ramp = linear_fits(ds, phase_at_refs, elevation_at_refs)
    return correct_igrams(ds, ramp)

In [18]:
def subset_noise(orbit_list,
                 frame_list,
                 year_list,
                 subsets_desired,
                 subset_type,
                 subset_size=128,
                 max_time_s=1,
                 max_per_tile=2):
    '''
    subset hyp3 outputs using tiles

    Each pass walks the orbits, then frames, years and granules in random
    order. A granule is loaded with hyp3_to_xarray_single, the orbit's
    mean signal and mean coherence maps are regridded onto it, and the
    result is clipped to every polygon of the tile grid. Inside a tile,
    random windows are tried for up to max_time_s seconds. A window counts
    as a subset when its median elevation is at least 3300, at least 100
    of its pixels have signal_corr above 0.85, and no NaN is left after
    interpolating gaps in unw_phase and murp_phase. Passes repeat until
    subsets_desired subsets have been counted.

    Writing the subset rasters is currently commented out, so the only
    outputs are progress prints and granules_sampled.pkl in the working
    directory.

    orbit_list: orbit folder names
    frame_list: frame folder names (not modified)
    year_list: year folder names (not modified)
    subsets_desired: number of subsets after which the function returns
    subset_type: prefix of the tile shapefile, also used in the output
        folder names of the disabled raster writes
    subset_size: edge length of a subset in pixels
    max_time_s: seconds spent looking for subsets in one tile
    max_per_tile: maximum number of subsets taken from one tile per visit
    returns: None
    raises: FileNotFoundError if a whole pass finds no granule folder
    '''

    home_path = '/mnt/d/indennt'
    # set number of subsets to 0
    subset_counter = 0
    granules_sampled = []
    tiles = gpd.read_file(f'{home_path}/polygons/{subset_type}_RGI_grid_25km.shp')

    # continue to run until desired subset number is reached
    while subset_counter < subsets_desired:
        granules_seen = 0
        for orbit in orbit_list:
            signal_path = f'{home_path}/signal_maps/{orbit}'
            # random.sample gives shuffled copies, so the caller's lists keep their order
            for frame in random.sample(frame_list, len(frame_list)):
                for year in random.sample(year_list, len(year_list)):
                    data_path = f'{home_path}/hyp3/{orbit}/{frame}/{year}'
                    granule_list = glob(f'{data_path}/*P012*/')

                    #with open('granules_sampled.pkl', 'rb') as f:
                        #granules_sampled = pickle.load(f)

                    # loop through noise maps
                    random.shuffle(granule_list)
                    for granule_path in granule_list:
                        granules_seen += 1
                        ds = hyp3_to_xarray_single(granule_path)
                        granule_name = ds.granule.item()

                        #if ds.granule.item() in granules_sampled:
                            #print(f'granule {ds.granule.item()} already sampled')
                            #continue

                        print(f'working on {orbit}, {frame}, {year}, {granule_name}')
                        granule_counter = 0

                        # regrid the orbit's mean signal and coherence maps onto this granule
                        signal_ds = xr.open_mfdataset(glob(f'{signal_path}/{orbit}_mean_signal_masked.tif'))
                        corr_ds = xr.open_mfdataset(glob(f'{signal_path}/{orbit}_mean_corr.tif'))
                        signal_ds = signal_ds.rio.clip_box(minx=ds.x.min(), miny=ds.y.min(), maxx=ds.x.max(), maxy=ds.y.max())
                        corr_ds = corr_ds.rio.clip_box(minx=ds.x.min(), miny=ds.y.min(), maxx=ds.x.max(), maxy=ds.y.max())
                        signal_ds = signal_ds.rio.reproject_match(ds.unw_phase, nodata=np.nan).squeeze()
                        corr_ds = corr_ds.rio.reproject_match(ds.unw_phase, nodata=np.nan).squeeze()

                        ds['signal'] = (('y', 'x'), signal_ds.band_data.values)
                        ds['signal_corr'] = (('y', 'x'), corr_ds.band_data.values)

                        # loop through tiles in random order
                        tiles = tiles.sample(frac=1)
                        for i, tile in tiles.iterrows():
                            tile_counter = 0

                            # clip to tile extent; use the row's own geometry, because
                            # i is an index label and the rows have been shuffled
                            # NOTE(review): the tile geometry is declared to be in the
                            # interferogram's CRS - confirm the shapefile matches
                            try:
                                tile_ds = ds.rio.clip([tile.geometry], crs=ds.rio.crs, drop=True)
                            except Exception:
                                # tile does not overlap interferogram
                                #print(f'no valid subsets in tile {i}')
                                continue

                            # a window of subset_size pixels must fit into the clipped tile
                            if len(tile_ds.x) < subset_size or len(tile_ds.y) < subset_size:
                                continue

                            # check if valid subset exists in tile
                            if (~np.isnan(tile_ds.unw_phase.values)).sum() < subset_size**2:
                                #print(f'no valid subsets in tile {i}')
                                continue

                            timeout = time.time() + max_time_s # set time to spend on each tile
                            # try to find appropriate subsets for a while
                            while time.time() < timeout:
                                # grab random subset within sample
                                subset_ds = sample_ds(tile_ds, subset_size)

                                # test if subset elevation is above treeline
                                # (written with "not >=" so a NaN median is rejected as well)
                                if not (np.median(subset_ds.elevation.values) >= 3300):
                                    continue

                                # require enough high-coherence pixels in the signal map
                                if not ((subset_ds.signal_corr > 0.85).sum() >= 100):
                                    continue

                                # interpolate small gaps
                                unw_phase_ds = subset_ds.unw_phase.interpolate_na(dim='x', use_coordinate=False)
                                unw_phase_ds = unw_phase_ds.interpolate_na(dim='y', use_coordinate=False)

                                # murp also has gaps to be interpolated, the rest do not
                                murp_phase_ds = subset_ds.murp_phase.interpolate_na(dim='x', use_coordinate=False)
                                murp_phase_ds = murp_phase_ds.interpolate_na(dim='y', use_coordinate=False)

                                # check if data gaps remain in subset
                                nan_count = (np.isnan(subset_ds.elevation.values).sum() +
                                             np.isnan(subset_ds.era5_phase.values).sum() +
                                             np.isnan(murp_phase_ds.values).sum() +
                                             np.isnan(subset_ds.signal.values).sum() +
                                             np.isnan(unw_phase_ds.values).sum())
                                if nan_count != 0:
                                    continue

                                subset_counter += 1
                                tile_counter += 1
                                granule_counter += 1
                                subset_name = f'tile{i}_{orbit}_{granule_name[5:13]}_{granule_name[21:29]}_sub{subset_counter}.tif'

                                # calculate era5 and murp noise (only needed by the disabled writes below)
                                murp_noise = unw_phase_ds - murp_phase_ds
                                era5_noise = unw_phase_ds - subset_ds.era5_phase

                                # center signal on 0 (effective local reference point)
                                #subset_ds['signal'] = subset_ds['signal'] - subset_ds['signal'].median(dim=['x', 'y'])

                                # murp to correct signal maps
                                subset_ds = MuRP(subset_ds)

                                # save subset
                                #unw_phase_ds.rio.to_raster(f'{home_path}/{subset_type}_subsets/noise/{subset_name}')
                                #murp_noise.rio.to_raster(f'{home_path}/{subset_type}_subsets/murp/{subset_name}')
                                #era5_noise.rio.to_raster(f'{home_path}/{subset_type}_subsets/era5/{subset_name}')
                                #subset_ds.elevation.rio.to_raster(f'{home_path}/{subset_type}_subsets/dem/{subset_name}')
                                #subset_ds.signal_MuRP.rio.to_raster(f'{home_path}/{subset_type}_subsets/signal/{subset_name}')
                                if subset_counter >= subsets_desired:
                                    print('desired number of subsets reached, exiting')
                                    return
                                if tile_counter >= max_per_tile:
                                    break
                            #print(f'tile {i} subsets: {tile_counter}')
                        print(f'{granule_name} subsets: {granule_counter}')
                        granules_sampled.append(granule_name)

                        # save list of granules sampled
                        with open('granules_sampled.pkl', 'wb') as f:
                            pickle.dump(granules_sampled, f)

        # without any granule folder no pass can ever produce a subset
        if granules_seen == 0:
            raise FileNotFoundError(f'no hyp3 granule folders found under {home_path}/hyp3')

In [19]:
# orbit, year and frame folder names handed to subset_noise
orbit_list = ['DT56', 'AT151']
year_list = [str(year) for year in range(2017, 2023)]
frame_list = [f'frame_{number}' for number in range(1, 4)]

In [None]:
# build the training subsets; defaults are kept for subset_size,
# max_time_s and max_per_tile
subset_noise(
    orbit_list=orbit_list,
    frame_list=frame_list,
    year_list=year_list,
    subsets_desired=40000,
    subset_type='train',
)

working on DT56, frame_2, 2021, S1AA_20210914T130956_20210926T130957_VVP012_INT40_G_ueF_1D98
