In [0]:
from pathlib import Path
import os

# Root working directory: taken from the PREDICT_DIR environment variable when it is
# set, otherwise the default DBFS mount path below is used.
pathP = Path(os.getenv('PREDICT_DIR', '/dbfs/mnt/lab/unrestricted/KritiM'))


In [0]:
import os
from pathlib import Path
import numpy as np
import rasterio
from scipy import stats
from tqdm import tqdm

In [0]:
# Discover all Predict_* folders in pathP (prefix match is case-insensitive).
# The list is sorted because iterdir() yields entries in arbitrary order, and the
# folder order decides which model wins a "first occurrence" tie in the ensemble.
predict_folders = sorted(
    f for f in pathP.iterdir()
    if f.is_dir() and f.name.lower().startswith('predict_')
)
print(f"Discovered Predict folders: {[str(f) for f in predict_folders]}")

# Set output directory within pathP and make sure it exists before anything is
# copied into it.
output_dir = pathP / 'outputMaps'
output_dir.mkdir(parents=True, exist_ok=True)

In [0]:
# Collect the distinct grid identifiers: the leading five characters of the name of
# every *_predict*.tif file found in any of the Predict folders, in sorted order.
grid_names = sorted({
    tif_path.name[:5]
    for predict_dir in predict_folders
    for tif_path in predict_dir.glob('*_predict*.tif')
})

print(f"Found {len(grid_names)} grids to process.")

In [0]:
import tempfile
import shutil

def _ensemble_vote(preds, confs, nodata):
    """Combine stacked per-model class rasters into one raster by majority vote.

    Parameters
    ----------
    preds : np.ndarray, shape (K, H, W)
        Predicted class per model and pixel.
    confs : np.ndarray, shape (K, H, W)
        Confidence of each prediction in ``preds``.
    nodata : number or None
        Value in ``preds`` marking "no prediction". When None, 0 is treated as
        the invalid value.

    Returns
    -------
    (out_pred, out_conf) : tuple of np.ndarray, each of shape (H, W)
        ``out_pred`` has the dtype of ``preds`` and ``out_conf`` the dtype of
        ``confs``. Pixels where no model has a valid prediction hold the
        invalid value in ``out_pred`` and 0 in ``out_conf``.

    Rule applied per pixel (only valid predictions take part):
      1. the class predicted by the most models wins;
      2. when several classes share the top vote count, the one backed by the
         single highest confidence wins;
      3. when that is still tied, the first model in stacking order wins.
    The stored confidence is the highest confidence among the models that
    voted for the winning class. A NaN confidence among the top-voted models
    is treated as the highest value (NumPy max convention) and propagates to
    the output.
    """
    invalid_value = 0 if nodata is None else nodata
    if nodata is not None and np.isnan(nodata):
        # NaN never compares equal to itself, so it needs an explicit test.
        valid = ~np.isnan(preds)
    else:
        valid = preds != invalid_value

    # votes[k] = number of valid models that agree with model k's class.
    # An invalid model automatically gets 0 because only invalid models share
    # its value and they are removed by `valid`.
    votes = np.zeros(preds.shape, dtype=np.int32)
    for k in range(preds.shape[0]):
        votes[k] = np.count_nonzero(valid & (preds == preds[k]), axis=0)

    # Models whose class reached the top vote count at each pixel.
    candidates = valid & (votes == votes.max(axis=0))

    # Rank candidates by confidence; non-candidates can never be selected.
    ranked = np.where(candidates, confs, -np.inf)
    top_conf = ranked.max(axis=0)
    is_winner = candidates & ((ranked == top_conf) | np.isnan(ranked))
    # argmax over booleans returns the first True, i.e. the first model in order.
    winner = np.argmax(is_winner, axis=0)[np.newaxis]

    any_valid = valid.any(axis=0)
    out_pred = np.full(preds.shape[1:], invalid_value, dtype=preds.dtype)
    out_conf = np.full(confs.shape[1:], 0, dtype=confs.dtype)
    out_pred[any_valid] = np.take_along_axis(preds, winner, axis=0)[0][any_valid]
    out_conf[any_valid] = np.take_along_axis(confs, winner, axis=0)[0][any_valid]
    return out_pred, out_conf


def process_grid(grid, predict_folders, output_dir, n_models=5):
    """Build and save the ensemble prediction/confidence rasters for one grid.

    For every folder in ``predict_folders`` the files
    ``<grid>_predict_xgb.tif`` and ``<grid>_confidence_xgb.tif`` are looked up;
    a folder contributes only when both exist. Band 1 of each raster is read,
    the models are combined with ``_ensemble_vote`` and the results are written
    to ``<output_dir>/<grid>_ensemble_predict.tif`` and
    ``<output_dir>/<grid>_ensemble_confidence.tif``.

    Parameters
    ----------
    grid : str
        Grid identifier used as the file-name prefix.
    predict_folders : iterable of pathlib.Path
        Folders holding one model's rasters each. Their order decides the final
        tie-break between models.
    output_dir : pathlib.Path or str
        Destination folder; created if it does not exist.
    n_models : int, optional
        Number of prediction/confidence pairs required. The grid is skipped
        (with a message) when a different number is found. Defaults to 5.

    Returns
    -------
    None
    """
    print(f"\nProcessing grid: {grid}")
    predict_paths = []
    conf_paths = []
    for folder in predict_folders:
        pred = folder / f"{grid}_predict_xgb.tif"
        conf = folder / f"{grid}_confidence_xgb.tif"
        if pred.exists() and conf.exists():
            predict_paths.append(pred)
            conf_paths.append(conf)
    # Both lists always grow together, so one length check is enough.
    if len(predict_paths) != n_models:
        print(f"  Skipping {grid}: found {len(predict_paths)} predictions and {len(conf_paths)} confidences.")
        return

    print(f"  Loading rasters for grid {grid}...")
    preds = []
    confs = []
    for p, c in zip(predict_paths, conf_paths):
        with rasterio.open(p) as src:
            preds.append(src.read(1))
            # Metadata of the last prediction raster is the template for both outputs.
            meta = src.meta.copy()
        with rasterio.open(c) as src:
            confs.append(src.read(1))
    preds = np.stack(preds, axis=0)  # shape: (n_models, H, W)
    confs = np.stack(confs, axis=0)  # shape: (n_models, H, W)

    nodata = meta.get('nodata', None)

    print(f"  Running ensemble logic for grid {grid}...")
    out_pred, out_conf = _ensemble_vote(preds, confs, nodata)

    print(f"  Saving results for grid {grid} in {output_dir} ...")
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    out_pred_path = output_dir / f"{grid}_ensemble_predict.tif"
    out_conf_path = output_dir / f"{grid}_ensemble_confidence.tif"

    # Write to a local temporary directory first, then copy to output_dir.
    # NOTE(review): presumably needed because the output mount does not support
    # the random-access writes done while creating a GeoTIFF - confirm.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_pred = os.path.join(tmp_dir, f'{grid}_ensemble_predict.tif')
        tmp_conf = os.path.join(tmp_dir, f'{grid}_ensemble_confidence.tif')

        pred_meta = {**meta, 'dtype': out_pred.dtype, 'count': 1}
        with rasterio.open(tmp_pred, 'w', **pred_meta) as dst:
            dst.write(out_pred, 1)

        # The confidence raster reuses the prediction metadata, changing only the dtype.
        conf_meta = {**meta, 'dtype': out_conf.dtype, 'count': 1}
        with rasterio.open(tmp_conf, 'w', **conf_meta) as dst:
            dst.write(out_conf, 1)

        shutil.copy2(tmp_pred, out_pred_path)
        shutil.copy2(tmp_conf, out_conf_path)

    print(f"  Finished processing grid: {grid}. Output saved to {output_dir}")


In [0]:
# Run the ensemble for every discovered grid, one after another.
total_grids = len(grid_names)
print(f"\nStarting processing of {total_grids} grids...")
for grid_id in grid_names:
    process_grid(grid_id, predict_folders, output_dir)

print("\nAll grids processed.")