In [None]:
import rasterio as rio
from conncomp import connectComponent
import numpy as np
from pathlib import Path
import time
import multiprocessing
# import matplotlib.pyplot as plt
# plt.rcParams['figure.figsize'] = [20, 10]
# from rasterio.plot import show

In [None]:
def get_unw_conncomp(unw_file, conncomp_file):
    """Read an unwrapped-phase raster and its connected-component raster.

    The unwrapped raster may be single-band (band 1 = unwrapped phase) or
    two-band (band 1 read as amplitude, band 2 as unwrapped phase).

    Parameters
    ----------
    unw_file : str or path-like
        Path to the unwrapped raster (``*.unw.tif``).
    conncomp_file : str or path-like
        Path to the matching connected-component raster; only band 1 is read.

    Returns
    -------
    tuple
        ``(unw_data, conncomp_data, amp_data, unw_profile, transform)``.
        ``amp_data`` is ``np.nan`` for single-band inputs. ``unw_profile`` and
        ``transform`` are the rasterio profile and transform of ``unw_file``.

    Raises
    ------
    ValueError
        If ``unw_file`` has a band count other than 1 or 2.
    """
    with rio.open(unw_file) as unw:
        unw_profile = unw.profile
        transform = unw.transform
        band_count = unw_profile['count']
        if band_count == 1:
            # No amplitude band: np.nan is the "absent" placeholder.
            amp_data = np.nan
            unw_data = unw.read(1)
        elif band_count == 2:
            amp_data = unw.read(1)
            unw_data = unw.read(2)
        else:
            # Without this branch the return below hit an UnboundLocalError.
            raise ValueError(
                f'{unw_file}: expected 1 or 2 bands, found {band_count}'
            )

    with rio.open(conncomp_file) as conncomp:
        conncomp_data = conncomp.read(1)
    return unw_data, conncomp_data, amp_data, unw_profile, transform

In [None]:
def write_file(unw_arr, amp_arr, profile, new_filename):
    """Write unwrapped phase (and optionally amplitude) to a new raster.

    Parameters
    ----------
    unw_arr : array-like
        Unwrapped-phase array. It is written to band 1 for a single-band profile
        and to band 2 for a two-band profile.
    amp_arr : array-like or float
        Amplitude array, written to band 1 when ``profile['count'] == 2``.
        It is ignored for single-band profiles, where callers in this notebook
        pass ``np.nan``.
    profile : dict
        rasterio creation profile; its ``'count'`` entry selects the layout.
    new_filename : str or path-like
        Destination path passed to ``rio.open(..., 'w')``.

    Raises
    ------
    ValueError
        If ``profile['count']`` is not 1 or 2, or if a two-band profile is
        given a scalar instead of an amplitude array.
    """
    band_count = profile['count']
    # Validate before opening, so that no empty output file is created.
    if band_count not in (1, 2):
        raise ValueError(f'Unsupported band count {band_count}; expected 1 or 2')
    # The previous `if np.isnan(amp_arr):` raised "truth value of an array is
    # ambiguous" for every real amplitude array. np.ndim is 0 for the np.nan
    # placeholder and >0 for an actual raster band.
    if band_count == 2 and np.ndim(amp_arr) == 0:
        raise ValueError('Two-band profile requires an amplitude array, got a scalar')

    with rio.open(new_filename, 'w', **profile) as dst:
        if band_count == 1:
            dst.write(unw_arr, 1)
        else:
            dst.write(amp_arr, 1)
            dst.write(unw_arr, 2)

In [None]:
# Folder of *.unw.tif / *.unw.conncomp.tif rasters to process.
# Swap the commented line in to run the DC_F834_T04 case instead.
# input_folder = Path('/u/aurora-r0/havazli/disp-test-cases/DC_F834_T04/unwrapped')
input_folder = Path('/u/aurora-r0/havazli/disp-test-cases/Idaho/Idaho_V11/unwrapped')
if not input_folder.is_dir():
    # Fail loudly: otherwise the glob below is empty and nothing is processed.
    raise FileNotFoundError(f'Input folder not found: {input_folder}')

# Output goes into a sibling of the input folder.
output_folder = input_folder.parent / 'unwrapped_bridging_masked'
output_folder.mkdir(exist_ok=True)

In [None]:
# Pair each X.unw.tif with X.unw.conncomp.tif by name. Zipping two separately
# sorted globs would silently shift every later pair if one file were missing.
unw_file_list = sorted(input_folder.glob('*.unw.tif'))
conncomp_file_list = [
    unw_path.with_name(unw_path.name[:-len('.tif')] + '.conncomp.tif')
    for unw_path in unw_file_list
]
missing_conncomp = [str(path) for path in conncomp_file_list if not path.exists()]
if missing_conncomp:
    raise FileNotFoundError(f'Missing conncomp rasters: {missing_conncomp}')

In [None]:
def bridge_iteration(unw_file, conncomp_file, output_folder, nodata_threshold=-1000):
    """Bridge the connected components of one interferogram and write the result.

    The function reads both rasters via ``get_unw_conncomp`` and masks phase
    values below ``nodata_threshold``. It then calls ``label()``,
    ``find_mst_bridge()`` and ``unwrap_conn_comp()`` on a ``connectComponent``,
    and writes ``<name>_brdg.unw<suffix>`` into ``output_folder``.

    Parameters
    ----------
    unw_file : pathlib.Path
        Unwrapped raster, expected to be named ``<name>.unw.tif``.
    conncomp_file : str or path-like
        Matching connected-component raster.
    output_folder : str or path-like
        Directory that receives the bridged raster.
    nodata_threshold : float, optional
        Phase values strictly below this are masked before bridging. The
        default of ``-1000`` matches the previously hard-coded value.
        NOTE(review): presumably the inputs' nodata fill; confirm.

    Returns
    -------
    str
        Path of the written file.
    """
    unw, conncomp, amp, profile, _transform = get_unw_conncomp(unw_file, conncomp_file)
    unw = np.ma.masked_where(unw < nodata_threshold, unw)

    cc = connectComponent(conncomp=conncomp, metadata=profile)
    # Return values are unused, but the calls are kept in this order:
    # presumably they populate state on `cc` that unwrap_conn_comp needs (TODO confirm).
    cc.label()
    cc.find_mst_bridge()
    bridge_unw = cc.unwrap_conn_comp(unw)

    # Replace only the trailing '.unw' ('X.unw' -> 'X_brdg.unw'). str.replace
    # also rewrote '.unw' inside the base name, e.g. 'a.unwrapped.unw'.
    stem = unw_file.stem
    if stem.endswith('.unw'):
        out_stem = stem[:-len('.unw')] + '_brdg.unw'
    else:
        out_stem = stem + '_brdg'
    outfile_name = str(Path(output_folder) / f'{out_stem}{unw_file.suffix}')

    write_file(bridge_unw, amp, profile, outfile_name)
    print(f'Wrote: {outfile_name}')
    return outfile_name

In [None]:
N_PROCESSES = 4  # worker processes for the bridging pool

# zip() would silently drop unmatched trailing files, so check up front.
if len(unw_file_list) != len(conncomp_file_list):
    raise ValueError(
        f'{len(unw_file_list)} unw files vs {len(conncomp_file_list)} conncomp files'
    )
file_pairs = [(unw, conncomp, output_folder)
              for unw, conncomp in zip(unw_file_list, conncomp_file_list)]

st = time.perf_counter()  # monotonic clock for elapsed-time measurement
with multiprocessing.Pool(processes=N_PROCESSES) as pool:
    # Run bridge_iteration on every (unw_file, conncomp_file, output_folder) triple.
    pool.starmap(bridge_iteration, file_pairs)
et = time.perf_counter()
elapsed_time = (et - st) / 60
# Reference run (Idaho_V11, 4 processes): ~68.65 minutes.
print(f'Elapsed time: {elapsed_time:.2f} minutes')

In [None]:
# st = time.time()
# for idx, val in enumerate(unw_file_list):
#     unw, conncomp, amp, profile, transform = get_unw_conncomp(val, conncomp_file_list[idx])
#     unw = np.ma.masked_where(unw < -1000, unw)
#     cc = connectComponent(conncomp=conncomp, metadata=profile)
#     brdg_labels = cc.label()
#     bridges = cc.find_mst_bridge()
#     bridge_unw = cc.unwrap_conn_comp(unw)
#     outfile_name = f'{output_folder}/{val.stem}_brdg_msk{val.suffix}'
#     write_file(bridge_unw, amp, profile, outfile_name)
# et = time.time()
# elapsed_time = (et - st) / 60
# print(f'Elapsed time: {elapsed_time} minutes')

In [None]:
import gc

# Force a full collection (oldest generation, 2) to release the large raster
# arrays left over from the runs above; the cell shows the number of objects collected.
gc.collect(2)