# Mean Shift Segmentation Batch Runner

Author: [Jerry Clayton](https://github.com/jerry-clayton)

For: [Ni-Meister Lab](http://www.geography.hunter.cuny.edu/~wenge/)

Adapted from [Ian Grant's Script](https://github.com/i-c-grant/ni-meister-gedi-biomass/blob/main/run_on_maap.py)

#### This Jupyter notebook handles the batch processing of the modified Mean Shift tree segmentation workflow developed by the author and Dr. Ni-Meister on the NASA MAAP platform. 

#### The [AMS3D](https://www.sciencedirect.com/science/article/abs/pii/S0034425716302292) was first proposed by Ferraz et. al, and this implementation depends on Dr. Nikolai Knapp's [MeanShiftR](https://github.com/niknap/MeanShiftR/tree/master/R) package

#### This workflow is executed in four parts: Split, Segment, Reconcile, and Merge

In [19]:
import datetime
import logging
import os
import shutil
import time
import glob
import tarfile

import warnings
from pathlib import Path
from typing import Dict, List

import click
import geopandas as gpd
import pandas as pd
from tqdm import tqdm
from geopandas import GeoDataFrame
from maap.maap import MAAP
from maap.Result import Granule

maap = MAAP(maap_host='api.maap-project.org')

In [20]:
def build_file_url(filename):
    url_first_part = "s3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/" 
    url = f'{url_first_part}{filename}'
    return url

def build_test_file_url(filename):
    url_first_part = "s3://maap-ops-workspace/jclayton0/test-input-sm-tiles/" 
    url = f'{url_first_part}{filename}'
    return url

def get_split_kwargs(fileurl):
     job_kwargs = {
            "identifier": "Mean Shift Split LAS",
            "algo_id": "MS-Step-1-Split",
            "version": "main",
            "username": "jclayton0",
            "queue": "maap-dps-worker-64gb",
            "LAS": fileurl,
            "Subplot_width": 25,
            "Buffer_width": 10
    }
    
     return job_kwargs

def get_segment_kwargs(fileurl):
    job_kwargs = {
            "identifier": "Mean Shift Segment LAS",
            "algo_id": "MS-Step-2-Segment-v2",
            "version": "main",
            "username": "jclayton0",
            "queue": "maap-dps-worker-64gb",
            "Point Cloud RDS": fileurl,
            "Subplot_widthFrac_cores": 0.9
    }
    
    return job_kwargs

def get_reconcile_kwargs(tarball_url, las_url):
    job_kwargs = {
            "identifier": "Mean Shift Reconcile LAS",
            "algo_id": "MS-Step-3-Reconcile-v3",
            "version": "main",
            "username": "jclayton0",
            "queue": "maap-dps-worker-64gb",
            "tarball": tarball_url,
            "original_las": las_url
    }
    
    return job_kwargs

def get_merge_kwargs(fileurl):
    job_kwargs = {
            "identifier": "Mean Shift Merge Trees",
            "algo_id": "MS-Step-4-Merge",
            "version": "main",
            "username": "jclayton0",
            "queue": "maap-dps-worker-64gb",
            "segmented_las": fileurl
    }
    
    return job_kwargs



In [21]:
def local_url_to_s3(url):
    second = str.split(url,"my-private-bucket")[1]
    full = f"s3://maap-ops-workspace/jclayton0{second}"
    return full

def get_old_jobs_list(old_jobs_path):
    
    with open(old_jobs_path, 'r') as file:
        old_jobs = file.readlines()
    
    old_jobs = [line.strip() for line in old_jobs]
    return old_jobs

def get_succeeded_jobs_in_list(job_ids):
    
    succeeded_job_ids = [job_id for job_id in job_ids
                         if job_status_for(job_id) == "Succeeded"]
    return succeeded_job_ids

def get_failed_jobs_in_list(job_ids):
    
    failed_job_ids = [job_id for job_id in job_ids
                      if job_status_for(job_id) == "Failed"]
    return failed_job_ids

def get_other_status_jobs_in_list(job_ids):

    other_job_ids = [job_id for job_id in job_ids
                     if job_status_for(job_id)
                     not in ["Succeeded", "Failed"]]
    return other_job_ids


import json
def get_failed_job_input_LAS(failed_json_path):

    with open(failed_json_path,'r') as file:
        data = json.load(file)
        
    return data.get('params').get('job_specification').get('params')[0].get('value')



In [22]:
def job_status_for(job_id: str) -> str:
    return maap.getJobStatus(job_id)

def job_result_for(job_id: str) -> str:
    return maap.getJobResult(job_id)[0]

def to_job_output_dir(job_result_url: str, username: str) -> str:
    return (f"/projects/my-private-bucket/"
            f"{job_result_url.split(f'/{username}/')[1]}")

def to_failed_job_params(job_result_url: str) -> str:
    return (f"../triaged-jobs/"
            f"{job_result_url.split(f'/triaged_job/')[1]}/_job.json")


def log_and_print(message: str):
    logging.info(message)
    click.echo(message)

def update_job_states(job_states: Dict[str, str],
                      final_states: List[str],
                      batch_size: int,
                      delay: int) -> Dict[str, str]:
    """Update the job states dictionary in place.

    Updating occurs in batches, with a delay in seconds between batches.

    Return the number of jobs updated to final states.
    """
    batch_count = 0
    n_updated_to_final = 0
    for job_id, state in job_states.items():
        if state not in final_states:
            new_state: str = job_status_for(job_id)
            job_states[job_id] = new_state
            if new_state in final_states:
                n_updated_to_final += 1
            batch_count += 1
        # Sleep after each batch to avoid overwhelming the API
        if batch_count == batch_size:
            time.sleep(delay)
            batch_count = 0

    return n_updated_to_final



In [30]:
maap.getQueues().json()

{'message': 'Not authorized.'}

In [6]:
## Source Dir for large tiles, relative path
long_list = glob.glob('../my-private-bucket/sq_km_tiles_norm/*')
test_list = glob.glob('../my-private-bucket/test-input-sm-tiles/*')

In [7]:
files = list()
for file in long_list:
    files.append(str.split(file, '/')[3])


In [8]:
files

['TEAK_large_001.las',
 'TEAK_large_002.las',
 'TEAK_large_003.las',
 'TEAK_large_004.las',
 'TEAK_large_005.las',
 'TEAK_large_006.las',
 'TEAK_large_007.las',
 'TEAK_large_008.las',
 'TEAK_large_009.las',
 'TEAK_large_010.las',
 'TEAK_large_011.las',
 'TEAK_large_012.las',
 'TEAK_large_013.las',
 'TEAK_large_014.las',
 'TEAK_large_015.las',
 'TEAK_large_016.las',
 'TEAK_large_017.las',
 'TEAK_large_018.las',
 'TEAK_large_019.las',
 'TEAK_large_020.las',
 'TEAK_large_021.las',
 'TEAK_large_022.las',
 'TEAK_large_023.las',
 'TEAK_large_024.las',
 'TEAK_large_025.las',
 'TEAK_large_026.las',
 'TEAK_large_027.las',
 'TEAK_large_028.las',
 'TEAK_large_029.las',
 'TEAK_large_030.las',
 'TEAK_large_031.las',
 'TEAK_large_032.las',
 'TEAK_large_033.las',
 'TEAK_large_034.las',
 'TEAK_large_035.las',
 'TEAK_large_036.las',
 'TEAK_large_037.las',
 'TEAK_large_038.las',
 'TEAK_large_039.las',
 'TEAK_large_040.las',
 'TEAK_large_041.las',
 'TEAK_large_042.las',
 'TEAK_large_043.las',
 'TEAK_larg

In [9]:
test_file = "s3://maap-ops-workspace/jclayton0/normalized/norm_TEAK_047_lidar_2021.las"
url = build_file_url(files[145])

## Get JobIDs from previous submissions

In [None]:
old_jobs_path = "../my-private-bucket/run_output_20241115_131005/job_ids.txt"
old_output_dir = "../my-private-bucket/run_output_20241115_131005/"
old_jobs = get_old_jobs_list(old_jobs_path)
old_jobs

In [73]:
## Get the failed job IDs, find the submission parameter locations,
# and scrape the input file s3 url from the parameter JSON

failed_job_ids = get_failed_jobs_in_list(old_jobs)
failed_job_urls = [job_result_for(job) for job in failed_job_ids]
failed_params = [to_failed_job_params(url) for url in failed_job_urls]
failed_inputs = [get_failed_job_input_LAS(json) for json in failed_params]

failed_inputs

['s3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_008.las',
 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_009.las',
 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_010.las',
 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_011.las',
 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_012.las',
 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_013.las',
 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_014.las',
 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_055.las',
 's3://maap-ops-workspace/jclayton0/sq_km_tiles_norm/TEAK_large_097.las']

In [77]:
# resubmit_kwargs = [get_split_kwargs(url) for url in failed_inputs]
# resubmit_kwargs

In [None]:
failed_tiles = [tile.split('_norm/')[1] for tile in failed_inputs]
failed_tiles
#save them to a txtfile

## Run Split on all files and collect the outputs

In [10]:
start_time = datetime.datetime.now()

# Set up output directory
output_dir = Path(f"/projects/my-private-bucket/run_output_"
                      f"{start_time.strftime('%Y%m%d_%H%M%S')}")
os.makedirs(output_dir, exist_ok=False)

# Set up log
logging.basicConfig(filename=output_dir / "run.log",
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')

log_and_print(f"Starting new model run at MAAP at {start_time}.")



Starting new model run at MAAP at 2024-11-15 13:10:05.350759.


In [78]:
    # Submit jobs for each pair of granules
username = "jclayton0"
job_limit = 1000
check_interval = 90 #seconds between updates

##### REMOVE this block if doing new run
start_time = datetime.datetime.now()
log_and_print(f"Starting new model run at MAAP at {start_time}.")
output_dir = old_output_dir 
files = failed_inputs
######


    if job_limit:
        n_jobs = min(len(files), job_limit)
    else:
        n_jobs = len(files)
    log_and_print(f"Submitting {n_jobs} "
                  f"jobs.")

    job_kwargs_list = []
    for file in files:
        
        #job_kwargs = get_split_kwargs(build_file_url(file))
        job_kwargs = get_split_kwargs(build_file_url(file))

        job_kwargs_list.append(job_kwargs)

    jobs = []
    for job_kwargs in job_kwargs_list[:job_limit]:
        job = maap.submitJob(**job_kwargs)
        jobs.append(job)

    print(f"Submitted {len(jobs)} jobs.")

    job_ids = [job.id for job in jobs]

    # Write job IDs to a file in case processing is interrupted
    job_ids_file = output_dir / "job_ids.txt"
    with open(job_ids_file, 'w') as f:
        for job_id in job_ids:
            f.write(f"{job_id}\n")
    log_and_print(f"Job IDs written to {job_ids_file}")

    # Give the jobs time to start
    click.echo("Waiting for jobs to start...")
    time.sleep(10)

    # Initialize job states
    final_states = ["Succeeded", "Failed", "Deleted"]

    job_states = {job_id: "" for job_id in job_ids}
    update_job_states(job_states, final_states, batch_size=50, delay=10)

    known_completed = len([state for state in job_states.values()
                           if state in final_states])

    while True:
        try:
            with tqdm(total=len(job_ids), desc="Jobs Completed", unit="job") as pbar:
                while any(state not in final_states for state in job_states.values()):

                    # Update the job states
                    n_new_completed: int = update_job_states(job_states,
                                                             final_states,
                                                             batch_size = 50,
                                                             delay = 10)

                    # Update the progress bar
                    pbar.update(n_new_completed)
                    last_updated = datetime.datetime.now()
                    known_completed += n_new_completed
                    
                    status_counts = {status: list(job_states.values()).count(status)
                                     for status in final_states + ["Accepted", "Running"]}
                    status_counts["Other"] = len(job_states) - sum(status_counts.values())
                    status_counts["Last updated"] = last_updated.strftime("%H:%M:%S")

                    pbar.set_postfix(status_counts, refresh=True)

                    if known_completed == len(job_ids):
                        break

                    time.sleep(check_interval)

        except KeyboardInterrupt:
            print("Are you sure you want to cancel the process?")
            print("Press Ctrl+C again to confirm, or wait to continue.")
            try:
                time.sleep(3)
                print("Continuing...")
            except KeyboardInterrupt:
                print("Model run aborted.")
                pending_jobs = [job_id for job_id, state in job_states.items()
                                if state not in final_states]
                click.echo(f"Cancelling {len(pending_jobs)} pending jobs.")
                for job_id in pending_jobs:
                    maap.cancelJob(job_id)
                break
        else:
            break

    # Process the results once all jobs are completed
    succeeded_job_ids = get_succeeded_jobs_in_list(job_ids)
    
    failed_job_ids = get_failed_jobs_in_list(job_ids)

    other_job_ids = get_other_status_jobs_in_list(job_ids) 

    click.echo(f"Processing results for {len(succeeded_job_ids)} "
               f"succeeded jobs.")

    click.echo(f"Gathering tarball paths from succeeded jobs.")

    tar_paths = []
    for job_id in tqdm(succeeded_job_ids):
        job_result_url = job_result_for(job_id)
        time.sleep(1) # to avoid overwhelming the API
        job_output_dir = to_job_output_dir(job_result_url, username)
        # Find .tar.gz file in the output dir
        tar_file = [f for f in os.listdir(job_output_dir)
                     if f.endswith('.tar.gz')]
        if len(tar_file) > 1:
            warnings.warn(f"Multiple .tar.gz files found in "
                          f"{job_output_dir}.")
        if len(tar_file) == 0:
            warnings.warn(f"No .tar.gz files found in "
                          f"{job_output_dir}.")
        if tar_file:
            tar_paths.append(os.path.join(job_output_dir, tar_file[0]))

    # Log the succeeded and failed job IDs
    logging.info(f"{len(succeeded_job_ids)} jobs succeeded.")
    logging.info(f"Succeeded job IDs: {succeeded_job_ids}\n")
    logging.info(f"{len(failed_job_ids)} jobs failed.")
    logging.info(f"Failed job IDs: {failed_job_ids}\n")
    logging.info(f"{len(other_job_ids)} jobs in other states.")
    logging.info(f"Other job IDs: {other_job_ids}\n")

    # Copy all tarballs to the output directory
    click.echo(f"Copying {len(tar_paths)} Tarballs to {output_dir}.")
    copy_batch_count = 0
    for tar_path in tqdm(tar_paths):
        try:
            shutil.copy(tar_path, output_dir)
            copy_batch_count += 1
            if copy_batch_count == 50:
                time.sleep(60)
                copy_batch_count = 0
            else:
                time.sleep(2)
            
        except Exception as e:
            warnings.warn(f"Error copying {tar_path} to {output_dir}: {str(e)}")
            click.echo("Retrying in 10 seconds.")
            time.sleep(10)
            try:
                shutil.copy(tar_path, output_dir)
            except Exception as e:
                click.echo(f"Retry failed: {str(e)}")
                click.echo(f"Skipping {tar_path}.")
                continue

    # Compress the output directory
    # click.echo(f"Compressing output directory.")
    # shutil.make_archive(output_dir, 'zip', output_dir)
    # click.echo(f"Output directory compressed to {output_dir}.zip.")

    end_time = datetime.datetime.now()

    log_and_print(f"Model run completed at {end_time}.")
    
# if __name__ == "__main__":
#     main()

Starting new model run at MAAP at 2024-11-19 15:29:05.771577.
Submitting 9 jobs.
Submitted 9 jobs.


TypeError: unsupported operand type(s) for /: 'str' and 'str'

In [85]:
new_job_ids = job_ids
new_job_ids

['a39f19f7-b9dc-4f81-aaf9-d2a7d7d4bf0b',
 '43bb3357-a447-41a6-8d18-2c24947e349e',
 '886c04df-acbc-4e3e-8807-356951c86da9',
 '0e238056-d343-4924-b8ff-7e04d0fb7224',
 '7af49420-7269-4955-b6f9-2565849d8be1',
 '96001cf0-2a46-469f-9f40-d7335e8823d5',
 'cf180c75-0169-4c11-a8dc-96ae5bb1d4be',
 '7737a35c-5724-40b4-a352-b479b14df366',
 '668f99ba-5fca-4980-a82b-276d06681ca5']

In [6]:
# comment out if doing new stuff
job_ids = old_jobs
output_dir = old_output_dir
username = "jclayton0"

# Process the results once all jobs are completed
succeeded_job_ids = [job_id for job_id in job_ids
                     if job_status_for(job_id) == "Succeeded"]

failed_job_ids = [job_id for job_id in job_ids
                  if job_status_for(job_id) == "Failed"]

other_job_ids = [job_id for job_id in job_ids
                 if job_status_for(job_id)
                 not in ["Succeeded", "Failed"]]

click.echo(f"Processing results for {len(succeeded_job_ids)} "
           f"succeeded jobs.")

click.echo(f"Gathering tarball paths from succeeded jobs.")

tar_paths = []
for job_id in tqdm(succeeded_job_ids):
    job_result_url = job_result_for(job_id)
    time.sleep(1) # to avoid overwhelming the API
    job_output_dir = to_job_output_dir(job_result_url, username)
    # Find .tar.gz file in the output dir
    tar_file = [f for f in os.listdir(job_output_dir)
                 if f.endswith('.tar.gz')]
    if len(tar_file) > 1:
        warnings.warn(f"Multiple .tar.gz files found in "
                      f"{job_output_dir}.")
    if len(tar_file) == 0:
        warnings.warn(f"No .tar.gz files found in "
                      f"{job_output_dir}.")
    if tar_file:
        tar_paths.append(os.path.join(job_output_dir, tar_file[0]))

# Log the succeeded and failed job IDs
logging.info(f"{len(succeeded_job_ids)} jobs succeeded.")
logging.info(f"Succeeded job IDs: {succeeded_job_ids}\n")
logging.info(f"{len(failed_job_ids)} jobs failed.")
logging.info(f"Failed job IDs: {failed_job_ids}\n")
logging.info(f"{len(other_job_ids)} jobs in other states.")
logging.info(f"Other job IDs: {other_job_ids}\n")

Processing results for 239 succeeded jobs.
Gathering tarball paths from succeeded jobs.


100%|██████████| 239/239 [04:49<00:00,  1.21s/it]


In [6]:
## pickle tarpaths 

import pickle

# with open('tar_paths_step_1.pkl', 'wb') as f:
#    pickle.dump(tar_paths, f)
with open('tar_paths_step_1.pkl', 'rb') as f:
    tar_paths = pickle.load(f)


In [7]:
output_dir = old_output_dir
outdir_files = os.listdir(output_dir)
# Copy all tarballs to the output directory
click.echo(f"Copying {len(tar_paths)} Tarballs to {output_dir}.")
copy_batch_count = 0
for tar_path in tqdm(tar_paths):
    fname = tar_path.split('/')[-1]
    if fname in outdir_files:
        copy_batch_count +=1
        continue
    else:
        try:
            shutil.copy(tar_path, output_dir)
            copy_batch_count += 1
            if copy_batch_count == 10:
                time.sleep(60)
                copy_batch_count = 0
            else:
                time.sleep(2)
            
        except Exception as e:
            warnings.warn(f"Error copying {tar_path} to {output_dir}: {str(e)}")
            click.echo("Retrying in 10 seconds.")
            time.sleep(10)
            try:
                shutil.copy(tar_path, output_dir)
            except Exception as e:
                click.echo(f"Retry failed: {str(e)}")
                click.echo(f"Skipping {tar_path}.")
                continue


# Compress the output directory
# click.echo(f"Compressing output directory.")
# shutil.make_archive(output_dir, 'zip', output_dir)
# click.echo(f"Output directory compressed to {output_dir}.zip.")

end_time = datetime.datetime.now()

log_and_print(f"Model run completed at {end_time}.")


Copying 239 Tarballs to ../my-private-bucket/run_output_20241115_131005/.


100%|██████████| 239/239 [07:29<00:00,  1.88s/it]

Model run completed at 2024-11-19 20:15:00.425728.





## Above, split all files in the directory by submitting a split job then gather the results into a single directory

## This has to be run several times. You also need, in the terminal, to type ls -lhS | tac and remove any files that have a size of 0; they need to be re-transferred



## Then, run extract_tarballs.sh after modifying it to include the correct directory path. 

## It will take some TIME

## Now we below test batching 1460 jobs for one tile. This will need to be re-written as a for loop

In [16]:
# Directory containing your files
output_dir = old_output_dir



'../my-private-bucket/run_output_20241115_131005/TEAK_large_016/'

## What is gone:

* Code to make the kwargs from the directory
* Code to save the kwargs as pickles
* Code to read the kwargs from pickles and submit the jobs
* Any other code and notes I had there.

## What to do: 

1. Restart the shell script
2. Make internal hyperlinks
3. Plan every part of the rest of this
4. rewrite missing code
5. Push to github regularly
6. Something fucking else that I don't remember
7. List it all, 100% of everything that needs to happen for sure. What else though. There was something fucking else. it will come back to me.
8. Change step 1 so that it does not tarball things.

# What does the script still need to do?

1. Missing old parts up to submitting the segmentation jobs
2. Check every job ID for missing, resubmit the failed, and then overwrite the failed IDs
   * probably this looks like copying the successful IDs over and adding the newly submitted IDs to that list, then saving that list in the same file
3. Move all the segmented tiles into one directory and tarball them
   * this looks like what I did before: gather the paths for each and save them to a pkl, then move them en masse to new dir.
   * Then this looks like: tarballing them all
4. Then write a script to gather all the tarball paths and make kwargs from them
5. Then submit these all as jobs to step 3
6. Then gather outputs and move them, again.
7. Finally submit everything that's left to step 4.
8. Gather Step 4 outputs and then make the CSV. Probably on our server.