In [None]:
from datetime import datetime
import logging
import os
from pathlib import Path
import sys

import geopandas as gpd
import tensorflow as tf
import torch

import gee
from inference_pipeline import get_outpath
import tile_utils

%load_ext autoreload
%autoreload 2

In [None]:
# --- Run configuration: region, date window, model checkpoints, data/inference settings ---
region_name = 'test_region_small'
region_path = f'../data/boundaries/{region_name}.geojson'

start_date = '2025-07-01'
end_date = '2025-09-30'
# Fail fast on a malformed or inverted date window, before the slow model loads below.
assert datetime.fromisoformat(start_date) < datetime.fromisoformat(end_date), (
    f"start_date ({start_date}) must be before end_date ({end_date})"
)

model_name = '48px_v0.46891.0-1.3SSL4EO-MLPensemble_2025-10-21'
model_path = f'../checkpts-tmp/{model_name}.h5'
# compile=False: the model is only used for prediction here, so no optimizer/loss is restored.
model = tf.keras.models.load_model(model_path, compile=False)

# Optional embedding model handed to the InferenceEngine as `embed_model`.
# Set `embed_model = None` instead to skip loading it (presumably the engine then runs
# without embeddings — confirm in gee.InferenceEngine).
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects; only point
# gee.SSL4EO_PATH at a checkpoint you trust.
embed_model = torch.load(gee.SSL4EO_PATH, weights_only=False)

data_config = gee.DataConfig(
    tilesize=576,
    pad=0,
    collection="S2L1C",
    clear_threshold=0.6,
    max_workers=4,  # Turn this down from 8 when caching (e.g. image_cache_dir is set), to ease memory pressure
    image_cache_dir=None,
)

inference_config = gee.InferenceConfig(
    pred_threshold=0.85,
    stride_ratio=2,
    geo_chip_size=48,
    embeddings_cache_dir=None,
)

data_extractor = gee.GEE_Data_Extractor(
    start_date,
    end_date,
    data_config
)


In [None]:
# Load the area of interest and tile it using data_config.tilesize and data_config.pad.
region = gpd.read_file(region_path).geometry.__geo_interface__
tiles = tile_utils.create_tiles(region, data_config.tilesize, data_config.pad)
print(f"Created {len(tiles):,} tiles")
# Fail fast here rather than with an IndexError / empty bulk run in later cells.
assert len(tiles) > 0, f"No tiles were created for region {region_path!r}"

# Echo log records to the notebook's stdout via the root logger (this logger is also
# handed to the InferenceEngine below). The handler is named and looked up first so that
# re-running this cell reuses it instead of attaching another one, which would print
# every message once per re-run.
# NOTE(review): the root logger's level is left unchanged (WARNING by default unless
# configured elsewhere), so INFO/DEBUG records are still filtered out.
logger = logging.getLogger()
stream_handler = next(
    (h for h in logger.handlers if h.get_name() == "notebook_stdout"),
    None,
)
if stream_handler is None:
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.set_name("notebook_stdout")
    logger.addHandler(stream_handler)

# Bundle the data extractor, classifier, optional embedding model, inference settings,
# and logger into a single inference engine used by the cells below.
engine = gee.InferenceEngine(
    data_extractor=data_extractor,
    model=model,
    config=inference_config,
    embed_model=embed_model,
    logger=logger,
)

In [None]:
# Single-tile inference for debugging: inspect one tile's predictions before the bulk run.
DEBUG_TILE_IDX = 11  # index of the tile to inspect; must be a valid index into `tiles`
assert -len(tiles) <= DEBUG_TILE_IDX < len(tiles), (
    f"DEBUG_TILE_IDX={DEBUG_TILE_IDX} is out of range for {len(tiles)} tiles"
)
preds_gdf = engine.predict_on_tile(tiles[DEBUG_TILE_IDX])
display(preds_gdf)

In [None]:
# Output location for the bulk run, derived from the model checkpoint, the region file,
# the date window, and the prediction threshold.
outpath = get_outpath(
    Path(model_path), Path(region_path),
    start_date, end_date,
    inference_config.pred_threshold,
)
outpath  # displayed as the cell result

In [None]:
# Run inference over every tile (likely the slowest cell in the notebook).
# NOTE(review): `outpath` is passed in, so results are presumably written there as well
# as returned — confirm in gee.InferenceEngine.bulk_predict.
preds = engine.bulk_predict(tiles, outpath)

In [None]:
# Interactive map of the most confident predictions only. This display-only cutoff is
# separate from inference_config.pred_threshold, which is applied during inference.
MAP_CONFIDENCE_MIN = 0.9
preds[preds["confidence"] > MAP_CONFIDENCE_MIN].explore()