In [1]:
import geopandas as gpd
from joblib import Parallel, delayed
import pandas as pd
from pathlib import Path
import duckdb
import os



# Local directory that holds the embedding Parquet files.
local_dir = "/home/christopher.x.ren/embeddings/ra_tea/planet_embeddings_v2"
# Every *.parquet file sitting directly inside local_dir.
local_paths = [parquet_path for parquet_path in Path(local_dir).glob("*.parquet")]
# Mapping table read from GCS; later cells use its 'mgrs_id' and 'id' columns.
mgrs_id_mapping = pd.read_parquet(
    "gs://demeter-labs/tea/mgrs_id_mapping_tom_tiles_deduplicated.parquet"
)

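## Union embeddings per MGRS tile, in parallel

For each MGRS tile, union the matching rows from every local embeddings Parquet file and write one Parquet file per tile to GCS, processing tiles in parallel.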

In [2]:
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from tqdm import tqdm
import duckdb
import pandas as pd
from google.cloud import storage

# Directory of local embedding Parquet files that process_mgrs scans.
local_dir = "/home/christopher.x.ren/embeddings/ra_tea/planet_embeddings_v2"

# GCS bucket handle that process_mgrs uses to check for already-written outputs.
# mgrs_id_mapping is expected to have been loaded by an earlier cell.
storage_client = storage.Client()
bucket = storage_client.bucket("demeter-labs")

def _sql_quote(value):
    """Return ``value`` as a single-quoted SQL string literal, doubling embedded quotes."""
    return "'" + str(value).replace("'", "''") + "'"


def _time_period_from_name(file_name):
    """Return the time-period label ``<parts[-3]>_<parts[-2]>`` of an underscore-separated file name."""
    parts = file_name.split('_')
    if len(parts) < 3:
        raise ValueError(
            f"Cannot extract time period from file name {file_name!r}: "
            "expected at least three '_'-separated parts"
        )
    return f"{parts[-3]}_{parts[-2]}"


def _build_union_query(file_paths, tile_ids):
    """Build one UNION ALL query selecting the rows for ``tile_ids`` from every file in ``file_paths``."""
    # The id list is identical for every file, so render it once.
    ids_string = ", ".join(_sql_quote(tile_id) for tile_id in tile_ids)
    selects = [
        f"""
            SELECT id, embedding AS embedding, {_sql_quote(_time_period_from_name(Path(file_path).name))} AS time_period
            FROM read_parquet({_sql_quote(file_path)})
            WHERE id IN ({ids_string})
        """
        for file_path in file_paths
    ]
    return "\nUNION ALL\n".join(selects)


def process_mgrs(mgrs_id):
    """Collect the embeddings of one MGRS tile from all local Parquet files and write them to GCS.

    The tile ids mapped to ``mgrs_id`` in the global ``mgrs_id_mapping`` are looked up in
    every ``*.parquet`` file under the global ``local_dir``; matching rows (``id``,
    ``embedding``, plus a ``time_period`` label derived from the file name) are unioned
    and written to ``gs://demeter-labs/tea/planet_embeddings/{mgrs_id}.parquet``.
    If that object already exists in the global ``bucket`` the tile is skipped.

    Args:
        mgrs_id: Value from the ``mgrs_id`` column of ``mgrs_id_mapping``.

    Returns:
        ``mgrs_id``, both when the output was written and when it was skipped.

    Raises:
        ValueError: No tile ids are mapped to ``mgrs_id``, or a file name has fewer
            than three ``_``-separated parts.
        FileNotFoundError: ``local_dir`` contains no ``*.parquet`` files.
        RuntimeError: The DuckDB query failed. Nothing is written in that case, so the
            "output already exists" check cannot hide the failure on a later run.
    """
    blob_name = f"tea/planet_embeddings/{mgrs_id}.parquet"
    output_path = f"gs://demeter-labs/{blob_name}"

    if bucket.blob(blob_name).exists():
        print(f"Skipping {mgrs_id} - output already exists")
        return mgrs_id

    # Filter the mapping for this mgrs_id.
    tile_ids = mgrs_id_mapping.loc[mgrs_id_mapping['mgrs_id'] == mgrs_id, 'id'].tolist()
    if not tile_ids:
        # An empty list would render as "WHERE id IN ()", which is invalid SQL.
        raise ValueError(f"No tile ids mapped to mgrs_id {mgrs_id}")

    # Sorted so the row order of the output does not depend on directory listing order.
    local_files = sorted(Path(local_dir).glob("*.parquet"))
    if not local_files:
        raise FileNotFoundError(f"No parquet files found in {local_dir}")

    big_query = _build_union_query(local_files, tile_ids)

    # Create a new DuckDB connection in this process.
    con = duckdb.connect()
    try:
        con.sql("INSTALL spatial; LOAD spatial;")
        df = con.execute(big_query).df()
    except Exception as e:
        # Do not fall back to an empty frame: writing it would mark the tile as done.
        raise RuntimeError(f"Error processing mgrs_id {mgrs_id}: {e}") from e
    finally:
        con.close()

    # Only reached on success; each tile writes its own object.
    df.to_parquet(output_path)

    return mgrs_id

# Get all unique mgrs_ids; rows without an mgrs_id are dropped because they
# cannot be matched back to any tile ids inside process_mgrs.
mgrs_ids = mgrs_id_mapping['mgrs_id'].dropna().unique()

# Number of worker processes used for the pool below.
MAX_WORKERS = 4

# mgrs_ids whose task raised, collected so failures are not lost in the progress output.
failed_mgrs_ids = []

# Process in parallel using ProcessPoolExecutor with at most MAX_WORKERS workers.
with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Submit all tasks, remembering which mgrs_id each future belongs to.
    futures = {executor.submit(process_mgrs, mgrs_id): mgrs_id for mgrs_id in mgrs_ids}

    # tqdm advances as each task finishes, in completion order.
    for future in tqdm(as_completed(futures), total=len(futures)):
        mgrs_id = futures[future]
        try:
            # result() re-raises any exception raised inside the worker.
            future.result()
        except Exception as exc:
            failed_mgrs_ids.append(mgrs_id)
            print(f"mgrs_id {mgrs_id} generated an exception: {exc}")

if failed_mgrs_ids:
    print(f"{len(failed_mgrs_ids)} of {len(futures)} mgrs_ids failed: {failed_mgrs_ids}")


  0%|          | 0/63 [00:00<?, ?it/s]

Skipping 47NKF - output already existsSkipping 47NLF - output already exists

Skipping 47NLE - output already exists
Skipping 47NKE - output already exists


  6%|▋         | 4/63 [00:00<00:01, 29.79it/s]

Skipping 47NLC - output already exists
Skipping 47NMC - output already exists
Skipping 47NME - output already exists
Skipping 47NMB - output already exists
Skipping 47NMD - output already exists


 14%|█▍        | 9/63 [00:00<00:01, 38.45it/s]

Skipping 47NNB - output already exists
Skipping 47NNC - output already exists
Skipping 47NND - output already exists
Skipping 47NNA - output already exists
Skipping 47MPV - output already exists
Skipping 47NPA - output already exists
Skipping 47NPB - output already exists
Skipping 47MPU - output already exists
Skipping 47MQU - output already exists


 29%|██▊       | 18/63 [00:00<00:00, 51.99it/s]

Skipping 47NQA - output already exists
Skipping 47NPC - output already exists
Skipping 47NLD - output already exists
Skipping 47MQV - output already exists
Skipping 47NQB - output already exists
Skipping 47MQT - output already exists
Skipping 47MRT - output already exists
Skipping 47MRV - output already exists


 41%|████▏     | 26/63 [00:00<00:00, 58.92it/s]

Skipping 47MRU - output already exists
Skipping 47NRA - output already exists
Skipping 48MSB - output already exists
Skipping 48MSC - output already exists
Skipping 48MSD - output already exists
Skipping 48MSE - output already exists
Skipping 48MTB - output already exists
Skipping 48MTD - output already exists
Skipping 48MTC - output already exists
Skipping 48MTE - output already exists


 57%|█████▋    | 36/63 [00:00<00:00, 71.13it/s]

Skipping 47MRS - output already exists
Skipping 48MUC - output already exists
Skipping 48MUB - output already exists
Skipping 48MTA - output already exists
Skipping 48MVA - output already exists
Skipping 48MVB - output already exists
Skipping 48MUA - output already exists
Skipping 48MUV - output already exists


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 57%|█████▋    | 36/63 [00:19<00:00, 71.13it/s]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 71%|███████▏  | 45/63 [13:17<09:37, 32.06s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 73%|███████▎  | 46/63 [16:12<11:16, 39.82s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 75%|███████▍  | 47/63 [17:00<10:45, 40.35s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 76%|███████▌  | 48/63 [19:09<12:09, 48.63s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 78%|███████▊  | 49/63 [23:45<17:35, 75.43s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 79%|███████▉  | 50/63 [32:59<31:14, 144.21s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 81%|████████  | 51/63 [35:44<29:32, 147.69s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 83%|████████▎ | 52/63 [37:49<26:16, 143.31s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 84%|████████▍ | 53/63 [40:48<25:10, 151.02s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 86%|████████▌ | 54/63 [48:13<33:09, 221.11s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 87%|████████▋ | 55/63 [54:37<34:58, 262.34s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 89%|████████▉ | 56/63 [57:06<27:05, 232.28s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 90%|█████████ | 57/63 [57:52<18:04, 180.81s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 92%|█████████▏| 58/63 [1:07:44<24:44, 296.97s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

 94%|█████████▎| 59/63 [1:11:36<18:33, 278.32s/it]

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

100%|██████████| 63/63 [1:20:27<00:00, 76.63s/it] 
