# Subset training data
This notebook contains functions that generate random training subsets from multiple raster sources.

In [None]:
def subset_ASO(subset_size=128,
               max_time_s=1,
               max_per_tile=3,
               subset_type=None):
    """
    Get random samples from multiple ASO rasters with different extents and timestamps.

    Work in progress: the function currently only loads the tile grid. The
    sampling loops are not implemented yet, so it returns None.

    Parameters
    ----------
    subset_size : int
        Edge length (pixels) of each square subset. Not used yet.
    max_time_s : float
        Time budget in seconds to spend on each tile. Not used yet.
    max_per_tile : int
        Maximum number of subsets to take from a single tile. Not used yet.
    subset_type : str, optional
        Prefix of the tile-grid shapefile
        ``../data/polygons/{subset_type}_RGI_grid_25km.shp``. If omitted, a
        module/notebook-level ``subset_type`` global is used, which is what the
        original body relied on implicitly.

    Raises
    ------
    ValueError
        If ``subset_type`` is neither passed in nor defined as a global.
    """
    if subset_type is None:
        # The original body referenced an undeclared ``subset_type``. Keep honouring a
        # notebook-level global so existing cells continue to work.
        subset_type = globals().get('subset_type')
    if subset_type is None:
        raise ValueError('subset_type must be passed to subset_ASO or defined globally')

    home_path = '..'
    # Scaffolding for the not-yet-written sampling loop (mirrors subset_noise).
    subset_counter = 0
    granules_sampled = []
    tiles = gpd.read_file(f'{home_path}/data/polygons/{subset_type}_RGI_grid_25km.shp')

    # TODO: loop through tiles
    # TODO: loop through ASO rasters


In [None]:
# code from other project: 
def _fill_gaps(da):
    """Return ``da`` with NaN gaps interpolated along ``x`` and then along ``y``.

    ``use_coordinate=False`` interpolates over array index positions rather than
    coordinate values. Gaps that neither pass can fill may remain NaN.
    """
    return (da.interpolate_na(dim='x', use_coordinate=False)
              .interpolate_na(dim='y', use_coordinate=False))


def _attach_signal_layers(ds, signal_path, orbit):
    """Add the orbit's mean signal and coherence maps to ``ds`` on its own grid.

    Opens ``{orbit}_mean_signal_masked.tif`` and ``{orbit}_mean_corr.tif`` from
    ``signal_path`` and clips each one to the x/y bounding box of ``ds``. Each is
    then reprojected onto the ``unw_phase`` grid (nodata -> NaN) and stored in
    ``ds`` as the ``signal`` and ``signal_corr`` variables. The source files are
    closed afterwards. ``ds`` is modified in place and also returned.
    """
    bbox = dict(minx=ds.x.min(), miny=ds.y.min(), maxx=ds.x.max(), maxy=ds.y.max())
    layers = (('signal', f'{signal_path}/{orbit}_mean_signal_masked.tif'),
              ('signal_corr', f'{signal_path}/{orbit}_mean_corr.tif'))
    for var_name, raster_path in layers:
        # The context manager releases the file handle; the original leaked one per granule.
        with xr.open_dataset(raster_path, cache=False) as src:
            matched = src.rio.clip_box(**bbox)
            matched = matched.rio.reproject_match(ds.unw_phase, nodata=np.nan).squeeze()
            # .values materialises the data before the source file is closed.
            ds[var_name] = (('y', 'x'), matched.band_data.values)
    return ds


def _sample_tile(tile_ds, tile_id, orbit, granule_id, out_root, subset_counter,
                 subset_size, max_time_s, max_per_tile,
                 min_elevation, min_coherence, min_coherent_px):
    """Randomly sample one tile for up to ``max_time_s`` seconds and save valid subsets.

    A candidate subset (drawn by ``sample_ds``) is kept only when all of these hold:

    - its median elevation is >= ``min_elevation``;
    - at least ``min_coherent_px`` pixels have ``signal_corr > min_coherence``;
    - no NaN remains in elevation, ERA5 phase, gap-filled MuRP phase, signal, or
      gap-filled unwrapped phase.

    Each kept subset is written as five GeoTIFFs under
    ``out_root/{noise,murp,era5,dem,signal}/``. Sampling stops after
    ``max_per_tile`` saves or when the time budget runs out.

    Returns
    -------
    tuple of (int, int)
        The updated running ``subset_counter`` and the number of subsets saved from
        this tile.
    """
    n_saved = 0
    deadline = time.time() + max_time_s  # time budget for this tile
    while time.time() < deadline:
        subset_ds = sample_ds(tile_ds, subset_size)

        # Written as `not (... >= ...)` so a NaN median is rejected, as in the original.
        if not (np.median(subset_ds.elevation.values) >= min_elevation):
            continue
        if not ((subset_ds.signal_corr > min_coherence).sum() >= min_coherent_px):
            continue

        # unw_phase and murp_phase may have small gaps to interpolate; other layers are used as-is.
        unw_phase_ds = _fill_gaps(subset_ds.unw_phase)
        murp_phase_ds = _fill_gaps(subset_ds.murp_phase)

        # Reject the candidate if any data gap remains.
        checked = (subset_ds.elevation.values, subset_ds.era5_phase.values,
                   murp_phase_ds.values, subset_ds.signal.values, unw_phase_ds.values)
        if any(np.isnan(arr).any() for arr in checked):
            continue

        subset_counter += 1
        n_saved += 1
        # NOTE(review): slices [5:13] and [21:29] presumably extract the two acquisition
        # dates from the granule name -- confirm against the HyP3 naming scheme.
        subset_name = f'tile{tile_id}_{orbit}_{granule_id[5:13]}_{granule_id[21:29]}_sub{subset_counter}.tif'

        # Residual phase after the MuRP and ERA5 corrections.
        murp_noise = unw_phase_ds - murp_phase_ds
        era5_noise = unw_phase_ds - subset_ds.era5_phase

        # MuRP-correct the signal map; the returned dataset provides signal_MuRP.
        subset_ds = MuRP(subset_ds)

        unw_phase_ds.rio.to_raster(f'{out_root}/noise/{subset_name}')
        murp_noise.rio.to_raster(f'{out_root}/murp/{subset_name}')
        era5_noise.rio.to_raster(f'{out_root}/era5/{subset_name}')
        subset_ds.elevation.rio.to_raster(f'{out_root}/dem/{subset_name}')
        subset_ds.signal_MuRP.rio.to_raster(f'{out_root}/signal/{subset_name}')

        if n_saved >= max_per_tile:
            break
    return subset_counter, n_saved


def subset_noise(orbit_list,
                 frame_list,
                 year_list,
                 subsets_desired,
                 subset_type,
                 subset_size=128,
                 max_time_s=1,
                 max_per_tile=3,
                 *,
                 home_path='/mnt/d/indennt',
                 min_elevation=3300,
                 min_coherence=0.85,
                 min_coherent_px=100):
    '''
    Subset HyP3 interferogram outputs into training chips using a tile grid.

    For every orbit / frame / year, each ``*P012*`` granule under
    ``{home_path}/hyp3/{orbit}/{frame}/{year}`` is loaded and gets the orbit's
    mean signal and coherence maps attached. The granule is then clipped to each
    tile of ``{home_path}/polygons/{subset_type}_RGI_grid_25km.shp``, visiting
    tiles in random order. Each tile that holds at least ``subset_size**2`` valid
    unwrapped-phase pixels is randomly sampled for up to ``max_time_s`` seconds,
    saving up to ``max_per_tile`` subsets under
    ``{home_path}/{subset_type}_subsets/``.

    Parameters
    ----------
    orbit_list, frame_list, year_list : iterable
        Directory components to search for granules.
    subsets_desired : int
        Currently unused; the early-exit check on it is disabled.
    subset_type : str
        Prefix of the tile shapefile and of the output directory.
    subset_size : int
        Edge length (pixels) of each square subset.
    max_time_s : float
        Sampling time budget per tile, in seconds.
    max_per_tile : int
        Maximum subsets saved per tile per granule.
    home_path : str
        Root directory for inputs and outputs (keyword-only).
    min_elevation : float
        Minimum median subset elevation (original comment: "above treeline";
        units presumably metres -- confirm). Keyword-only.
    min_coherence, min_coherent_px : float, int
        A subset needs at least ``min_coherent_px`` pixels with
        ``signal_corr > min_coherence``. Keyword-only.

    Side effects: writes GeoTIFF subsets. After each granule, it also overwrites
    ``granules_sampled.pkl`` in the current working directory with the granules
    processed in this call. The list starts empty each call, so a previous run is
    not resumed.
    '''
    out_root = f'{home_path}/{subset_type}_subsets'
    subset_counter = 0  # running total across all granules; used in file names
    granules_sampled = []
    tiles = gpd.read_file(f'{home_path}/polygons/{subset_type}_RGI_grid_25km.shp')

    for orbit in orbit_list:
        signal_path = f'{home_path}/signal_maps/{orbit}'
        for frame in frame_list:
            for year in year_list:
                data_path = f'{home_path}/hyp3/{orbit}/{frame}/{year}'
                for granule_path in glob(f'{data_path}/*P012*/'):
                    granule = os.path.basename(os.path.normpath(granule_path))
                    if granule in granules_sampled:
                        print(f'granule {granule} already sampled')
                        continue

                    print(f'working on {orbit}, {frame}, {year}, {granule}')
                    granule_counter = 0

                    ds = hyp3_to_xarray_single(granule_path)
                    _attach_signal_layers(ds, signal_path, orbit)
                    granule_id = ds.granule.item()

                    # Visit tiles in random order. iterrows() yields the original
                    # index label, so the geometry must come from the yielded row:
                    # tiles.iloc[label] would pick a different tile after the shuffle.
                    tiles = tiles.sample(frac=1)
                    for tile_id, tile in tiles.iterrows():
                        # NOTE(review): the tile geometry is declared to be in the
                        # granule's CRS -- confirm the shapefile uses that CRS.
                        try:
                            tile_ds = ds.rio.clip([tile.geometry], crs=ds.rio.crs, drop=True)
                        except Exception:
                            # The clip fails when the tile does not overlap the interferogram.
                            # Exception (not a bare except) keeps Ctrl-C working.
                            continue
                        try:
                            n_valid = np.count_nonzero(~np.isnan(tile_ds.unw_phase.values))
                            if n_valid < subset_size**2:
                                continue  # no room for a fully valid subset in this tile
                            subset_counter, n_saved = _sample_tile(
                                tile_ds, tile_id, orbit, granule_id, out_root, subset_counter,
                                subset_size, max_time_s, max_per_tile,
                                min_elevation, min_coherence, min_coherent_px)
                            granule_counter += n_saved
                        finally:
                            # Close on every path, including the `continue` above.
                            tile_ds.close()

                    print(f'{granule_id} subsets: {granule_counter}')
                    ds.close()
                    gc.collect()
                    granules_sampled.append(granule)

                    # Persist progress after every granule.
                    with open('granules_sampled.pkl', 'wb') as f:
                        pickle.dump(granules_sampled, f)