# Mean Shift Segmentation Batch Runner

Author: [Jerry Clayton](https://github.com/jerry-clayton)

For: [Ni-Meister Lab](http://www.geography.hunter.cuny.edu/~wenge/)

Adapted from [Ian Grant's Script](https://github.com/i-c-grant/ni-meister-gedi-biomass/blob/main/run_on_maap.py)

#### This Jupyter notebook handles the batch processing of the modified Mean Shift tree segmentation workflow developed by the author and Dr. Ni-Meister on the NASA MAAP platform. 

#### The [AMS3D](https://www.sciencedirect.com/science/article/abs/pii/S0034425716302292) was first proposed by Ferraz et. al, and this implementation depends on Dr. Nikolai Knapp's [MeanShiftR](https://github.com/niknap/MeanShiftR/tree/master/R) package

#### This workflow is executed in four parts: Split, Segment, Reconcile, and Merge

In [1]:
import datetime
import logging
import os
import shutil
import time
import glob
import tarfile

import warnings
from pathlib import Path
from typing import Dict, List

import click
import geopandas as gpd
import pandas as pd
from tqdm import tqdm
from geopandas import GeoDataFrame
from maap.maap import MAAP
from maap.Result import Granule

maap = MAAP(maap_host='api.maap-project.org')

In [2]:
def build_file_url(filename):
    url_first_part = "s3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/" 
    url = f'{url_first_part}{filename}'
    return url

def build_test_file_url(filename):
    url_first_part = "s3://maap-ops-workspace/jclayton0/test-input-sm-tiles/" 
    url = f'{url_first_part}{filename}'
    return url

def get_split_kwargs(fileurl):
     job_kwargs = {
            "identifier": "Mean Shift Split LAS",
            "algo_id": "MS-Step-1-Split",
            "version": "main",
            "username": "jclayton0",
            "queue": "maap-dps-worker-64gb",
            "LAS": fileurl,
            "Subplot_width": 25,
            "Buffer_width": 10
    }
    
     return job_kwargs

def get_segment_kwargs(fileurl):
    job_kwargs = {
            "identifier": "Mean Shift Segment LAS",
            "algo_id": "MS-Step-2-Segment-v2",
            "version": "main",
            "username": "jclayton0",
            "queue": "maap-dps-worker-64gb",
            "Point Cloud RDS": fileurl,
            "Subplot_widthFrac_cores": 0.9
    }
    
    return job_kwargs

def get_reconcile_kwargs(tarball_url, las_url):
    job_kwargs = {
            "identifier": "Mean Shift Reconcile LAS",
            "algo_id": "MS-Step-3-Reconcile-v3",
            "version": "main",
            "username": "jclayton0",
            "queue": "maap-dps-worker-64gb",
            "tarball": tarball_url,
            "original_las": las_url
    }
    
    return job_kwargs

def get_merge_kwargs(fileurl):
    job_kwargs = {
            "identifier": "Mean Shift Merge Trees",
            "algo_id": "MS-Step-4-Merge",
            "version": "main",
            "username": "jclayton0",
            "queue": "maap-dps-worker-64gb",
            "segmented_las": fileurl
    }
    
    return job_kwargs



In [3]:
def local_url_to_s3(url):
    second = str.split(url,"my-private-bucket")[1]
    full = f"s3://maap-ops-workspace/jclayton0{second}"
    return full

In [4]:
def job_status_for(job_id: str) -> str:
    return maap.getJobStatus(job_id)

def job_result_for(job_id: str) -> str:
    return maap.getJobResult(job_id)[0]

def to_job_output_dir(job_result_url: str, username: str) -> str:
    return (f"/projects/my-private-bucket/"
            f"{job_result_url.split(f'/{username}/')[1]}")

def log_and_print(message: str):
    logging.info(message)
    click.echo(message)

def update_job_states(job_states: Dict[str, str],
                      final_states: List[str],
                      batch_size: int,
                      delay: int) -> Dict[str, str]:
    """Update the job states dictionary in place.

    Updating occurs in batches, with a delay in seconds between batches.

    Return the number of jobs updated to final states.
    """
    batch_count = 0
    n_updated_to_final = 0
    for job_id, state in job_states.items():
        if state not in final_states:
            new_state: str = job_status_for(job_id)
            job_states[job_id] = new_state
            if new_state in final_states:
                n_updated_to_final += 1
            batch_count += 1
        # Sleep after each batch to avoid overwhelming the API
        if batch_count == batch_size:
            time.sleep(delay)
            batch_count = 0

    return n_updated_to_final



In [5]:
maap.getQueues().json()

{'code': 200,
 'message': 'success',
 'queues': ['maap-dps-sandbox',
  'maap-dps-worker-64gb',
  'maap-dps-cuny-worker-512gb',
  'maap-dps-worker-32vcpu-64gb',
  'maap-dps-worker-32gb',
  'maap-dps-worker-8gb',
  'maap-dps-worker-16gb']}

In [6]:
## Source Dir for large tiles, relative path
long_list = glob.glob('../my-private-bucket/sq_km_tiles_norm/*')
test_list = glob.glob('../my-private-bucket/test-input-sm-tiles/*')

In [7]:
files = list()
for file in long_list:
    files.append(str.split(file, '/')[3])


In [8]:
files

['TEAK_large_001.las',
 'TEAK_large_002.las',
 'TEAK_large_003.las',
 'TEAK_large_004.las',
 'TEAK_large_005.las',
 'TEAK_large_006.las',
 'TEAK_large_007.las',
 'TEAK_large_008.las',
 'TEAK_large_009.las',
 'TEAK_large_010.las',
 'TEAK_large_011.las',
 'TEAK_large_012.las',
 'TEAK_large_013.las',
 'TEAK_large_014.las',
 'TEAK_large_015.las',
 'TEAK_large_016.las',
 'TEAK_large_017.las',
 'TEAK_large_018.las',
 'TEAK_large_019.las',
 'TEAK_large_020.las',
 'TEAK_large_021.las',
 'TEAK_large_022.las',
 'TEAK_large_023.las',
 'TEAK_large_024.las',
 'TEAK_large_025.las',
 'TEAK_large_026.las',
 'TEAK_large_027.las',
 'TEAK_large_028.las',
 'TEAK_large_029.las',
 'TEAK_large_030.las',
 'TEAK_large_031.las',
 'TEAK_large_032.las',
 'TEAK_large_033.las',
 'TEAK_large_034.las',
 'TEAK_large_035.las',
 'TEAK_large_036.las',
 'TEAK_large_037.las',
 'TEAK_large_038.las',
 'TEAK_large_039.las',
 'TEAK_large_040.las',
 'TEAK_large_041.las',
 'TEAK_large_042.las',
 'TEAK_large_043.las',
 'TEAK_larg

In [9]:
test_file = "s3://maap-ops-workspace/jclayton0/normalized/norm_TEAK_047_lidar_2021.las"
url = build_file_url(files[145])

## Run Split on all files and collect the outputs

In [10]:
start_time = datetime.datetime.now()

# Set up output directory
output_dir = Path(f"/projects/my-private-bucket/run_output_"
                      f"{start_time.strftime('%Y%m%d_%H%M%S')}")
os.makedirs(output_dir, exist_ok=False)

# Set up log
logging.basicConfig(filename=output_dir / "run.log",
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')

log_and_print(f"Starting new model run at MAAP at {start_time}.")



Starting new model run at MAAP at 2024-11-15 12:50:06.899425.


In [None]:
    # Submit jobs for each pair of granules
username = "jclayton0"
job_limit = 20
check_interval = 90 #seconds between updates
    if job_limit:
        n_jobs = min(len(files), job_limit)
    else:
        n_jobs = len(files)
    log_and_print(f"Submitting {n_jobs} "
                  f"jobs.")

    job_kwargs_list = []
    for file in files:
        
        #job_kwargs = get_split_kwargs(build_file_url(file))
        job_kwargs = get_split_kwargs(build_file_url(file))

        job_kwargs_list.append(job_kwargs)

    jobs = []
    for job_kwargs in job_kwargs_list[:job_limit]:
        job = maap.submitJob(**job_kwargs)
        jobs.append(job)

    print(f"Submitted {len(jobs)} jobs.")

    job_ids = [job.id for job in jobs]

    # Write job IDs to a file in case processing is interrupted
    job_ids_file = output_dir / "job_ids.txt"
    with open(job_ids_file, 'w') as f:
        for job_id in job_ids:
            f.write(f"{job_id}\n")
    log_and_print(f"Job IDs written to {job_ids_file}")

    # Give the jobs time to start
    click.echo("Waiting for jobs to start...")
    time.sleep(10)

    # Initialize job states
    final_states = ["Succeeded", "Failed", "Deleted"]

    job_states = {job_id: "" for job_id in job_ids}
    update_job_states(job_states, final_states, batch_size=50, delay=10)

    known_completed = len([state for state in job_states.values()
                           if state in final_states])

    while True:
        try:
            with tqdm(total=len(job_ids), desc="Jobs Completed", unit="job") as pbar:
                while any(state not in final_states for state in job_states.values()):

                    # Update the job states
                    n_new_completed: int = update_job_states(job_states,
                                                             final_states,
                                                             batch_size = 50,
                                                             delay = 10)

                    # Update the progress bar
                    pbar.update(n_new_completed)
                    last_updated = datetime.datetime.now()
                    known_completed += n_new_completed
                    
                    status_counts = {status: list(job_states.values()).count(status)
                                     for status in final_states + ["Accepted", "Running"]}
                    status_counts["Other"] = len(job_states) - sum(status_counts.values())
                    status_counts["Last updated"] = last_updated.strftime("%H:%M:%S")

                    pbar.set_postfix(status_counts, refresh=True)

                    if known_completed == len(job_ids):
                        break

                    time.sleep(check_interval)

        except KeyboardInterrupt:
            print("Are you sure you want to cancel the process?")
            print("Press Ctrl+C again to confirm, or wait to continue.")
            try:
                time.sleep(3)
                print("Continuing...")
            except KeyboardInterrupt:
                print("Model run aborted.")
                pending_jobs = [job_id for job_id, state in job_states.items()
                                if state not in final_states]
                click.echo(f"Cancelling {len(pending_jobs)} pending jobs.")
                for job_id in pending_jobs:
                    maap.cancelJob(job_id)
                break
        else:
            break

    # Process the results once all jobs are completed
    succeeded_job_ids = [job_id for job_id in job_ids
                         if job_status_for(job_id) == "Succeeded"]
    
    failed_job_ids = [job_id for job_id in job_ids
                      if job_status_for(job_id) == "Failed"]

    other_job_ids = [job_id for job_id in job_ids
                     if job_status_for(job_id)
                     not in ["Succeeded", "Failed"]]

    click.echo(f"Processing results for {len(succeeded_job_ids)} "
               f"succeeded jobs.")

    click.echo(f"Gathering tarball paths from succeeded jobs.")

    tar_paths = []
    for job_id in tqdm(succeeded_job_ids):
        job_result_url = job_result_for(job_id)
        time.sleep(1) # to avoid overwhelming the API
        job_output_dir = to_job_output_dir(job_result_url, username)
        # Find .tar.gz file in the output dir
        tar_file = [f for f in os.listdir(job_output_dir)
                     if f.endswith('.tar.gz')]
        if len(tar_file) > 1:
            warnings.warn(f"Multiple .tar.gz files found in "
                          f"{job_output_dir}.")
        if len(tar_file) == 0:
            warnings.warn(f"No .tar.gz files found in "
                          f"{job_output_dir}.")
        if tar_file:
            tar_paths.append(os.path.join(job_output_dir, tar_file[0]))

    # Log the succeeded and failed job IDs
    logging.info(f"{len(succeeded_job_ids)} jobs succeeded.")
    logging.info(f"Succeeded job IDs: {succeeded_job_ids}\n")
    logging.info(f"{len(failed_job_ids)} jobs failed.")
    logging.info(f"Failed job IDs: {failed_job_ids}\n")
    logging.info(f"{len(other_job_ids)} jobs in other states.")
    logging.info(f"Other job IDs: {other_job_ids}\n")

    # Copy all tarballs to the output directory
    click.echo(f"Copying {len(tar_paths)} Tarballs to {output_dir}.")
    copy_batch_count = 0
    for tar_path in tqdm(tar_paths):
        try:
            shutil.copy(tar_path, output_dir)
            copy_batch_count += 1
            if copy_batch_count == 50:
                time.sleep(60)
                copy_batch_count = 0
            else:
                time.sleep(2)
            
        except Exception as e:
            warnings.warn(f"Error copying {tar_path} to {output_dir}: {str(e)}")
            click.echo("Retrying in 10 seconds.")
            time.sleep(10)
            try:
                shutil.copy(tar_path, output_dir)
            except Exception as e:
                click.echo(f"Retry failed: {str(e)}")
                click.echo(f"Skipping {tar_path}.")
                continue

    # Compress the output directory
    # click.echo(f"Compressing output directory.")
    # shutil.make_archive(output_dir, 'zip', output_dir)
    # click.echo(f"Output directory compressed to {output_dir}.zip.")

    end_time = datetime.datetime.now()

    log_and_print(f"Model run completed at {end_time}.")
    
# if __name__ == "__main__":
#     main()

Submitting 20 jobs.
Submitted 20 jobs.
Job IDs written to /projects/my-private-bucket/run_output_20241115_125006/job_ids.txt
Waiting for jobs to start...


Jobs Completed:  35%|███▌      | 7/20 [07:36<02:49, 13.03s/job, Succeeded=7, Failed=0, Deleted=0, Accepted=0, Running=13, Other=0, Last updated=12:57:57]

In [None]:

# Process the results once all jobs are completed
succeeded_job_ids = [job_id for job_id in job_ids
                     if job_status_for(job_id) == "Succeeded"]

failed_job_ids = [job_id for job_id in job_ids
                  if job_status_for(job_id) == "Failed"]

other_job_ids = [job_id for job_id in job_ids
                 if job_status_for(job_id)
                 not in ["Succeeded", "Failed"]]

click.echo(f"Processing results for {len(succeeded_job_ids)} "
           f"succeeded jobs.")

click.echo(f"Gathering tarball paths from succeeded jobs.")

tar_paths = []
for job_id in tqdm(succeeded_job_ids):
    job_result_url = job_result_for(job_id)
    time.sleep(1) # to avoid overwhelming the API
    job_output_dir = to_job_output_dir(job_result_url, username)
    # Find .tar.gz file in the output dir
    tar_file = [f for f in os.listdir(job_output_dir)
                 if f.endswith('.tar.gz')]
    if len(tar_file) > 1:
        warnings.warn(f"Multiple .tar.gz files found in "
                      f"{job_output_dir}.")
    if len(tar_file) == 0:
        warnings.warn(f"No .tar.gz files found in "
                      f"{job_output_dir}.")
    if tar_file:
        tar_paths.append(os.path.join(job_output_dir, tar_file[0]))

# Log the succeeded and failed job IDs
logging.info(f"{len(succeeded_job_ids)} jobs succeeded.")
logging.info(f"Succeeded job IDs: {succeeded_job_ids}\n")
logging.info(f"{len(failed_job_ids)} jobs failed.")
logging.info(f"Failed job IDs: {failed_job_ids}\n")
logging.info(f"{len(other_job_ids)} jobs in other states.")
logging.info(f"Other job IDs: {other_job_ids}\n")

# Copy all tarballs to the output directory
click.echo(f"Copying {len(tar_paths)} Tarballs to {output_dir}.")
copy_batch_count = 0
for tar_path in tqdm(tar_paths):
    try:
        shutil.copy(tar_path, output_dir)
        copy_batch_count += 1
        if copy_batch_count == 50:
            time.sleep(60)
            copy_batch_count = 0
        else:
            time.sleep(2)
        
    except Exception as e:
        warnings.warn(f"Error copying {tar_path} to {output_dir}: {str(e)}")
        click.echo("Retrying in 10 seconds.")
        time.sleep(10)
        try:
            shutil.copy(tar_path, output_dir)
        except Exception as e:
            click.echo(f"Retry failed: {str(e)}")
            click.echo(f"Skipping {tar_path}.")
            continue

# Compress the output directory
# click.echo(f"Compressing output directory.")
# shutil.make_archive(output_dir, 'zip', output_dir)
# click.echo(f"Output directory compressed to {output_dir}.zip.")

end_time = datetime.datetime.now()

log_and_print(f"Model run completed at {end_time}.")


## Above, split all files in the directory by submitting a split job then gather the results into a single directory

## Below, for each tarball (original tile), 

In [45]:
jobkwargs = get_split_kwargs(url)

jobkwargs

{'identifier': 'Mean Shift Split LAS',
 'algo_id': 'MS-Step-1-Split',
 'version': 'main',
 'username': 'jclayton0',
 'queue': 'maap-dps-worker-64gb',
 'LAS': 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_146.las',
 'Subplot_width': 25,
 'Buffer_width': 10}

In [92]:
# jobkwargs = get_split_kwargs(test_file)

# jobkwargs

In [46]:
maap.submitJob(**jobkwargs)

{'job_id': '454a0bc9-87fc-476f-aae6-5066d1b77830', 'status': 'success', 'machine_type': None, 'architecture': None, 'machine_memory_size': None, 'directory_size': None, 'operating_system': None, 'job_start_time': None, 'job_end_time': None, 'job_duration_seconds': None, 'cpu_usage': None, 'cache_usage': None, 'mem_usage': None, 'max_mem_usage': None, 'swap_usage': None, 'read_io_stats': None, 'write_io_stats': None, 'sync_io_stats': None, 'async_io_stats': None, 'total_io_stats': None, 'error_details': None, 'response_code': 200, 'outputs': []}

In [7]:
split_output_dir = "../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/"

split_output_dir_s3 = "s3://maap-ops-workspace/jclayton0/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/"
tarball = glob.glob(f"{split_output_dir}/*.tar.gz")

tarball

['../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/norm_TEAK_047_lidar_2021_point_clouds.tar.gz']

In [11]:

# Path to the tarball file and extraction directory
#tarball_path = "path/to/yourfile.tar.gz"
extract_path = f"{split_output_dir}rds/"

#os.makedirs(extract_path, exist_ok=False)
# Open the tarball file
with tarfile.open(tarball[0], "r:gz") as tar:
    # Extract all contents to the specified directory
    tar.extractall(path=extract_path)


In [14]:
pcs = glob.glob(f"{extract_path}/norm*")

## Change the above to exclude find.txt; also remove the generation of find.txt from Split, if not done already

In [15]:
pcs

['../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/rds/norm_TEAK_047_lidar_2021_pc_1',
 '../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/rds/norm_TEAK_047_lidar_2021_pc_2',
 '../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/rds/norm_TEAK_047_lidar_2021_pc_3',
 '../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/rds/norm_TEAK_047_lidar_2021_pc_4',
 '../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/rds/norm_TEAK_047_lidar_2021_pc_5',
 '../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main/Mean Shift Split LAS/2024/10/29/13/37/59/605257/rds/norm_TEAK_047_lidar_2021_pc_6',
 '../my-private-bucket/dps_output/MS-Step-1-Split-Buffer-PC-fixed/main

In [16]:
jobResponses = list()
for file in pcs:
    s3url = local_url_to_s3(file)
    jobargs = get_segment_kwargs(s3url)
    response = maap.submitJob(**jobargs)
    jobResponses.append(response)

In [20]:
from ast import literal_eval

In [21]:
resulturl = job_result_for(literal_eval(str(jobResponses[0]))['job_id'])

In [22]:
resulturl

'https://maap-ops-workspace.s3-website-us-west-2.amazonaws.com/jclayton0/dps_output/MS-Step-2-Segment-v2/main/Mean Shift Split LAS/2024/11/11/20/42/30/628677'

In [29]:
tile_list = list()
for job in jobResponses:
    resulturl = to_job_output_dir(job_result_for(literal_eval(str(job))['job_id']), "jclayton0")
    file = glob.glob(f"{resulturl}/seg*")[0]
    print(file)
    tile_list.append(file)

/projects/my-private-bucket/dps_output/MS-Step-2-Segment-v2/main/Mean Shift Split LAS/2024/11/11/20/42/30/628677/seg_norm_TEAK_047_lidar_2021_pc_1


IndexError: list index out of range

In [24]:
seg_tile_output_dir = f"{output_dir}/segmented_tiles"
os.makedirs(seg_tile_output_dir, exist_ok=False)
log_and_print("making segmented tile output directory")

making segmented tile output directory


## Seems that the .rds extension is not being put on these files as necessary

In [25]:
# isolate filenames to move them
filename_list = list()

for tile in tile_list:
    file = str.split(tile, '/')[-1]
    filename_list.append(file)
    

In [26]:
# move all files to the folder we need them in
for index in range(len(tile_list)):
    old_path = tile_list[index]
    new_path = f"{seg_tile_output_dir}/{filename_list[index]}"
    os.rename(old_path, new_path)



In [27]:
# tarball them to submit to job

# Directory to compress and name of the output tarball
output_tarball = f"{seg_tile_output_dir}/TEAK_000_segmented.tar.gz"

# Create the tarball
with tarfile.open(output_tarball, "w:gz") as tar:
    # Add files to the tarball
    for root, _, files in os.walk(seg_tile_output_dir):
        for file in files:
            file_path = os.path.join(root, file)
            tar.add(file_path, arcname=os.path.relpath(file_path, seg_tile_output_dir))

print(f"Compressed directory '{seg_tile_output_dir}' into '{output_tarball}'.")


Compressed directory '/projects/my-private-bucket/run_output_20241111_203526/segmented_tiles' into '/projects/my-private-bucket/run_output_20241111_203526/segmented_tiles/TEAK_000_segmented.tar.gz'.


In [24]:
#tarball_url = local_url_to_s3(output_tarball)
tarball_url = "s3://maap-ops-workspace/jclayton0/run_output_20241111_203526/segmented_tiles/TEAK_000_segmented.tar.gz"
tarball_url 

's3://maap-ops-workspace/jclayton0/run_output_20241111_203526/segmented_tiles/TEAK_000_segmented.tar.gz'

In [30]:
test_file

's3://maap-ops-workspace/jclayton0/normalized/norm_TEAK_047_lidar_2021.las'

In [47]:
reconcile_kwargs = get_reconcile_kwargs(tarball_url, test_file)
rec_response = maap.submitJob(**reconcile_kwargs)