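## Benchmarking

Embedding throughput (tqdm rate) and wall-clock time per pipeline run. On the GPU the tqdm bar covers all files; on CPUs each joblib worker shows its own bar, so the CPU rate is per worker.

* GPU (`n_jobs=1`, CUDA):
    * 34.45 it/s
    * 28 min
* 105 CPUs (joblib CPU path):
    * 6.89 it/s
    * 7 min 19 s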

In [None]:
import os
import shutil
import subprocess
from datetime import datetime
from datetime import datetime, timedelta
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from dateutil.parser import parse
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from joblib import Parallel, delayed
from PIL import Image
from tqdm import tqdm

import ollie_clay_utils as utils
from ollie_clay_utils import Thumbnail, to_xarray, apply_model, write_embeddings

def download_thumbnails(date):
    """Copy one day's Sentinel-2 thumbnails from GCS into ``data/<date>``.

    Parameters:
        date (str): Day folder name in the bucket, formatted 'YYYYMMDD'.

    Raises:
        subprocess.CalledProcessError: if ``gsutil`` exits non-zero, so callers
            never go on to process (and mark as done) a partial download.
        FileNotFoundError: if the ``gsutil`` executable is not on PATH.
    """
    dest = f"data/{date}"
    # exist_ok makes this idempotent; makedirs also creates the parent "data" dir.
    os.makedirs(dest, exist_ok=True)

    src = f"gs://gfw-sentinel2-thumbnails-us-central1/sentinel2_world_v20230811/{date}"
    # NOTE(review): because `dest` already exists, gsutil may copy the folder as
    # dest/<date>/...; confirm the PNGs land directly in data/<date>/ where
    # run_s2_pipeline globs for them.
    cmd = ["gsutil", "-m", "cp", "-r", src, dest]
    print(" ".join(cmd))
    # Argument list (no shell) so `date` cannot inject shell syntax; check=True
    # surfaces copy failures instead of silently continuing.
    subprocess.run(cmd, check=True)

    print("Thumbnails downloaded successfully")

def upload_to_bq(df, dataset_id, table_base_name, date_suffix, project='world-fishing-827'):
    """
    Upload a dataframe to a BigQuery date-sharded table, replacing any existing contents.

    Parameters:
        df (pandas.DataFrame): DataFrame to upload; expected columns are
            ``detect_id`` (string) and ``embedding`` (list/array of floats).
        dataset_id (str): BigQuery dataset ID.
        table_base_name (str): Base name of the table (without the date suffix).
        date_suffix (str): Date suffix in 'YYYYMMDD' format.
        project (str): GCP project that owns the dataset and runs the load job.
    """
    client = bigquery.Client(project=project)

    # Date-sharded naming: <base><YYYYMMDD>.
    table_id = f"{table_base_name}{date_suffix}"
    # Fully-qualified string ID; Client.dataset() is deprecated.
    destination = f"{project}.{dataset_id}.{table_id}"

    job_config = bigquery.LoadJobConfig(
        # WRITE_TRUNCATE makes re-running a date overwrite rather than append.
        write_disposition=bigquery.WriteDisposition.WRITE_TRUNCATE,
        schema=[
            bigquery.SchemaField("detect_id", "STRING"),
            bigquery.SchemaField("embedding", "FLOAT", mode="REPEATED"),
        ],
    )

    job = client.load_table_from_dataframe(df, destination, job_config=job_config)
    job.result()  # Block until the load job finishes (raises on failure)

    print(f"Loaded {len(df)} rows into {dataset_id}.{table_id}")


def check_bq_table_exists(dataset_id, table_id, project='world-fishing-827'):
    """
    Check if a BigQuery table exists.

    Parameters:
        dataset_id (str): BigQuery dataset ID.
        table_id (str): BigQuery table ID.
        project (str): GCP project that owns the dataset.

    Returns:
        bool: True if the table exists, False if BigQuery reports it as not found.

    Raises:
        Any other client error (auth, network, permissions) propagates instead of
        being misreported as "table missing", which would trigger reprocessing.
    """
    client = bigquery.Client(project=project)

    try:
        client.get_table(f"{project}.{dataset_id}.{table_id}")
    except NotFound:
        return False
    return True

def get_embeddings_from_png(file, model, device=torch.device("cpu")):
    """Embed a single Sentinel-2 PNG thumbnail.

    Loads ``file`` with ``Thumbnail.load_s2_png``, runs ``apply_model`` on it and
    returns a one-row DataFrame with columns ``detect_id`` and ``embedding`` (the
    flattened model output). Any exception is printed and ``None`` is returned,
    so one bad file does not abort a whole batch.
    """
    try:
        thumb = Thumbnail.load_s2_png(file)
        detection = thumb.detect_id
        vector = apply_model(model, device, thumb).flatten()
        return pd.DataFrame({'detect_id': detection, 'embedding': [vector]})
    except Exception as err:
        print(f"Error processing detection {file}: {err}")
        return None
    
def batch_get_embeddings_from_pngs(batch, model, device=torch.device("cpu")):
    """Embed every PNG in ``batch`` and stack the successful rows.

    Files that fail are skipped (``get_embeddings_from_png`` prints the error and
    returns None).

    Parameters:
        batch (list[str]): PNG file paths.
        model: Model passed through to ``get_embeddings_from_png``.
        device (torch.device): Device passed through; CPU by default.

    Returns:
        pandas.DataFrame: Columns ``detect_id`` and ``embedding``; empty (with
        those columns) when the batch is empty or every file failed.
    """
    rows = [get_embeddings_from_png(file, model, device) for file in tqdm(batch)]
    rows = [row for row in rows if row is not None]
    if not rows:
        # pd.concat([]) raises "No objects to concatenate"; one all-failed batch
        # must not take down the whole parallel run.
        return pd.DataFrame(columns=['detect_id', 'embedding'])
    return pd.concat(rows, ignore_index=True)

def _resolve_n_workers(n_jobs):
    """Turn a joblib-style ``n_jobs`` into a positive worker count."""
    if n_jobs == 0:
        raise ValueError("n_jobs must be non-zero")
    if n_jobs < 0:
        # joblib convention: -1 uses all CPUs, -2 all but one, and so on.
        return max(1, (os.cpu_count() or 1) + 1 + n_jobs)
    return n_jobs


def _split_into_batches(files, n_workers):
    """Split ``files`` into at most ``n_workers`` contiguous, non-empty batches."""
    # Ceiling division, floored at 1: a 0 step would make range() raise, and a
    # negative one (from n_jobs=-1) would silently yield no batches.
    batch_size = max(1, -(-len(files) // n_workers))
    return [files[i:i + batch_size] for i in range(0, len(files), batch_size)]


def run_s2_pipeline(date, n_jobs=-1):
    """Embed every RGB thumbnail in ``data/<date>`` and upload the result to BigQuery.

    ``n_jobs == 1`` runs sequentially on CUDA (falling back to CPU if CUDA is not
    available). Any other value fans out over CPU workers with joblib, following
    joblib's ``n_jobs`` convention (-1 = all CPUs). The local ``data/<date>``
    folder is removed afterwards, whether or not processing succeeded.

    Parameters:
        date (str): Day to process, formatted 'YYYYMMDD'.
        n_jobs (int): 1 for the GPU path, otherwise the joblib worker count.

    Raises:
        FileNotFoundError: no ``*RGB*.png`` files were found for ``date``.
        RuntimeError: no file produced an embedding.
        Both are raised instead of uploading an empty table, which would make
        the day look processed.
    """
    data_dir = f"data/{date}"
    try:
        # Sorted so row order and batch composition are reproducible.
        files = sorted(glob(f"{data_dir}/*RGB*.png"))
        if not files:
            raise FileNotFoundError(f"No RGB thumbnails found in {data_dir}")

        model, _ = utils.load_model()
        if n_jobs == 1:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model.batch_first = True
            # Collect, then concat once below: concatenating inside the loop is O(n^2).
            frames = [get_embeddings_from_png(file, model, device) for file in tqdm(files)]
        else:
            batches = _split_into_batches(files, _resolve_n_workers(n_jobs))
            frames = Parallel(n_jobs=n_jobs)(
                delayed(batch_get_embeddings_from_pngs)(batch, model) for batch in batches
            )

        frames = [frame for frame in frames if frame is not None and not frame.empty]
        if not frames:
            raise RuntimeError(f"No embeddings were produced for {date}; nothing uploaded")
        df = pd.concat(frames, ignore_index=True)

        upload_to_bq(df, 'scratch_ollie', 'sentinel2_world_v20230811_embeddings_', date)
    finally:
        # Always free the disk: a year-long loop would otherwise accumulate failed days.
        shutil.rmtree(data_dir, ignore_errors=True)
    print(f"Finished processing {date}")

def run_year(year, n_jobs=104):
    """Run download → embed → upload for every calendar day of ``year``.

    Days whose ``scratch_ollie.sentinel2_world_v20230811_embeddings_<YYYYMMDD>``
    table already exists are skipped, so the loop can be re-run to resume.
    Any error for a single day (including the existence check) is printed and
    the loop moves on to the next day.

    Parameters:
        year (int): Calendar year to process (366 days in leap years).
        n_jobs (int): Worker count passed to ``run_s2_pipeline``.
    """
    first_day = datetime(year, 1, 1)
    n_days = (datetime(year + 1, 1, 1) - first_day).days
    dates = [(first_day + timedelta(days=i)).strftime('%Y%m%d') for i in range(n_days)]
    for date in tqdm(dates, desc=f"Processing {year}"):
        try:
            if check_bq_table_exists('scratch_ollie', f'sentinel2_world_v20230811_embeddings_{date}'):
                print(f"{date} already processed")
                continue
            download_thumbnails(date)
            run_s2_pipeline(date, n_jobs=n_jobs)
        except Exception as e:
            print(f"Error processing {date}: {e}")



Processing 2024:   0%|          | 1/366 [00:00<04:58,  1.22it/s]

20240101 already processed
gsutil -m cp -r gs://gfw-sentinel2-thumbnails-us-central1/sentinel2_world_v20230811/20240102 data/20240102


Copying gs://gfw-sentinel2-thumbnails-us-central1/sentinel2_world_v20230811/20240102/S2A_MSIL1C_20240102T000011_N0510_R073_T57PUT_20240102T010544;158.1520210;15.9431820_NIR.png...
Copying gs://gfw-sentinel2-thumbnails-us-central1/sentinel2_world_v20230811/20240102/S2A_MSIL1C_20240102T000011_N0510_R073_T57PUT_20240102T010544;158.1520210;15.9431820_RGB.png...
Copying gs://gfw-sentinel2-thumbnails-us-central1/sentinel2_world_v20230811/20240102/S2A_MSIL1C_20240102T000011_N0510_R073_T57PVT_20240102T010544;158.1049910;15.9512180_NIR.png...
Copying gs://gfw-sentinel2-thumbnails-us-central1/sentinel2_world_v20230811/20240102/S2A_MSIL1C_20240102T000011_N0510_R073_T57PVT_20240102T010544;158.1049910;15.9512180_RGB.png...
Copying gs://gfw-sentinel2-thumbnails-us-central1/sentinel2_world_v20230811/20240102/S2A_MSIL1C_20240102T000011_N0510_R073_T57PVT_20240102T010544;158.1356590;15.9461000_NIR.png...
Copying gs://gfw-sentinel2-thumbnails-us-central1/sentinel2_world_v20230811/20240102/S2A_MSIL1C_2024

In [8]:
import time 
# Back-of-envelope wall-clock estimate: n items at it_s items/s per worker,
# spread evenly over `cores` workers.
n = 21348568  #+ 21382497 + 21801647 + 21358629 + 21958098 + 21958098
# Items/s per worker. Benchmarks above: 6.89 (CPU worker), 34.45 (GPU).
# With it_s=6.89 this gives ~0.359 days, the value in the saved output.
it_s = 1.1
total_s = n / it_s  # single-worker seconds
cores = 100

t = total_s / 60 / 60 / 24 / cores  # wall-clock days
t



0.3586210557436973

In [3]:
from datetime import datetime, timedelta

# NOTE: this redefines run_year from the setup cell and relies on that cell's
# tqdm import and its download_thumbnails / run_s2_pipeline /
# check_bq_table_exists definitions. The saved NameError below came from running
# this cell without the setup cell.
def run_year(year, n_jobs=104):
    """Run download → embed → upload for every calendar day of ``year``.

    Days whose ``scratch_ollie.sentinel2_world_v20230811_embeddings_<YYYYMMDD>``
    table already exists are skipped, so the loop can be re-run to resume.
    Any error for a single day (including the existence check) is printed and
    the loop moves on to the next day.

    Parameters:
        year (int): Calendar year to process (366 days in leap years).
        n_jobs (int): Worker count passed to ``run_s2_pipeline``.
    """
    first_day = datetime(year, 1, 1)
    n_days = (datetime(year + 1, 1, 1) - first_day).days
    dates = [(first_day + timedelta(days=i)).strftime('%Y%m%d') for i in range(n_days)]
    for date in tqdm(dates, desc=f"Processing {year}"):
        try:
            if check_bq_table_exists('scratch_ollie', f'sentinel2_world_v20230811_embeddings_{date}'):
                print(f"{date} already processed")
                continue
            download_thumbnails(date)
            run_s2_pipeline(date, n_jobs=n_jobs)
        except Exception as e:
            print(f"Error processing {date}: {e}")

# run_year(2024)  # uncomment to process 2024 (run the setup cell first)

20240101
20240102
20240103
20240104
20240105
20240106
20240107
20240108
20240109
20240110
20240111
20240112
20240113
20240114
20240115
20240116
20240117
20240118
20240119
20240120
20240121
20240122
20240123
20240124
20240125
20240126
20240127
20240128
20240129
20240130
20240131
20240201
20240202
20240203
20240204
20240205
20240206
20240207
20240208
20240209
20240210
20240211
20240212
20240213
20240214
20240215
20240216
20240217
20240218
20240219
20240220
20240221
20240222
20240223
20240224
20240225
20240226
20240227
20240228
20240229
20240301
20240302
20240303
20240304
20240305
20240306
20240307
20240308
20240309
20240310
20240311
20240312
20240313
20240314
20240315
20240316
20240317
20240318
20240319
20240320
20240321
20240322
20240323
20240324
20240325
20240326
20240327
20240328
20240329
20240330
20240331
20240401
20240402
20240403
20240404
20240405
20240406
20240407
20240408
20240409
20240410
20240411
20240412
20240413
20240414
20240415
20240416
20240417
20240418
20240419
20240420
2

NameError: name 'download_thumbnails' is not defined