In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local gee / tile_utils / inference_pipeline code) are reloaded
# before each cell execution and edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from datetime import datetime
import logging
import os
from pathlib import Path
import sys

import geopandas as gpd
import tensorflow as tf
import torch

import gee
from inference_pipeline import get_outpath
import tile_utils

In [None]:
# --- Area of interest -------------------------------------------------------
region_name = 'test_region'
region_path = f'../data/boundaries/{region_name}.geojson'

# --- Model checkpoint -------------------------------------------------------
# Alternative checkpoint (swap in by uncommenting):
#model_name = '48px_v0.2SSL4EO-MLPep70_2025-09-12'
model_name = '48px_v4.0ep300_2025-08-29'
model_path = f'../checkpts-tmp/{model_name}.h5'

# Optional embedding model handed to the inference engine; None disables it.
# Uncomment the torch.load line to use the SSL4EO weights instead.
embed_model = None
#embed_model = torch.load(gee.SSL4EO_PATH, weights_only=False)

# --- Date window passed to the data extractor -------------------------------
start_date = '2024-01-01'
end_date = '2024-12-31'

# --- Data extraction settings -----------------------------------------------
data_config = gee.DataConfig(
    tilesize=576,
    pad=24,
    collection="S2L1C",
    clear_threshold=0.6,
    # Lower this from 8 when caching is enabled, to ease memory pressure.
    max_workers=8,
)

# --- Inference settings -----------------------------------------------------
inference_config = gee.InferenceConfig(
    pred_threshold=0.8,
    stride_ratio=2,
    geo_chip_size=48,
    #cache_dir='cache-tmp'
)

# Extractor bound to the date window and data settings defined above.
data_extractor = gee.GEE_Data_Extractor(start_date, end_date, data_config)


In [None]:
# Load the Keras checkpoint without compiling it (inference only).
model = tf.keras.models.load_model(model_path, compile=False)

# Use the first geometry in the boundary file as the region, expressed as a
# GeoJSON-like mapping, and split it into tiles for the extractor.
region = gpd.read_file(region_path).geometry[0].__geo_interface__
tiles = tile_utils.create_tiles(region, data_config.tilesize, data_config.pad)
print(f"Created {len(tiles):,} tiles")

# Route root-logger output to stdout so it shows up in the notebook.
# Re-running this cell must not stack another handler on the root logger,
# otherwise every log record is printed once per re-run. Reuse an existing
# stdout StreamHandler when one is already attached.
logger = logging.getLogger()
stream_handler = next(
    (
        handler
        for handler in logger.handlers
        if isinstance(handler, logging.StreamHandler)
        and getattr(handler, "stream", None) is sys.stdout
    ),
    None,
)
if stream_handler is None:
    stream_handler = logging.StreamHandler(sys.stdout)
    logger.addHandler(stream_handler)

# Wire the extractor, model, settings and logger into the inference engine.
engine = gee.InferenceEngine(
    data_extractor=data_extractor,
    model=model,
    config=inference_config,
    embed_model=embed_model,
    logger=logger,
)

In [None]:
# Debugging aid: run inference on just the first tile and inspect the result.
# predict_on_tile returns a pair; only the first element is kept here.
debug_tile = tiles[0]
preds, _ = engine.predict_on_tile(debug_tile)
preds

In [None]:
# Build the output path from the model checkpoint, region file, date window
# and prediction threshold, then display it for a quick sanity check.
outpath = get_outpath(
    Path(model_path), Path(region_path),
    start_date, end_date,
    inference_config.pred_threshold,
)
outpath

In [None]:
# Full run: predict over every tile, passing `outpath` to the engine
# (presumably where results are written — confirm in gee.InferenceEngine).
# NOTE: this rebinds `preds`, replacing the single-tile debug result above.
preds = engine.make_predictions(tiles, outpath)

In [None]:
# Count predictions whose confidence is strictly above 0.91 (an ad-hoc cut,
# stricter than the configured pred_threshold).
high_confidence_mask = preds.confidence > 0.91
len(preds[high_confidence_mask])