# In [None]:
import numpy as np
import joblib
from read_tif import read_tif
from write_tif import write_tif


def predict_tif(input_tif, model_path, scaler_path, output_tif, *, proba_class=None):
    """
    Predict class/probability for a GeoTIFF image and save as new GeoTIFF.

    Every pixel is treated as one sample whose features are its band values.
    The features are transformed with the saved scaler and fed to the saved
    model. The per-pixel outputs are written back as a single-band float32
    raster that reuses the input's geotransform and projection.

    Parameters
    ----------
    input_tif : str
        Path to input GeoTIFF (multi-band or single-band).
    model_path : str
        Path to trained ML model (.pkl), loaded with ``joblib.load``.
    scaler_path : str
        Path to trained scaler (.pkl), loaded with ``joblib.load``.
    output_tif : str
        Path to save prediction result.
    proba_class : int, optional
        If None (default), the output is ``model.predict(...)``. Otherwise
        the output is column ``proba_class`` of ``model.predict_proba(...)``,
        i.e. the probability of the class at that position in
        ``model.classes_`` (the model must implement ``predict_proba``).

    Raises
    ------
    ValueError
        If the raster array read from ``input_tif`` is neither 2-D
        (rows, cols) nor 3-D (bands, rows, cols).

    Notes
    -----
    ``joblib.load`` unpickles its input, which can execute arbitrary code.
    Only load model/scaler files from trusted sources.
    """

    # Load model & scaler
    model = joblib.load(model_path)
    scaler = joblib.load(scaler_path)

    # Read input image. Only the array and georeferencing are used; the
    # spatial size and band count are taken from the array itself (below).
    data, _, _, geo, proj, _ = read_tif(input_tif)
    print(f"Input raster shape: {data.shape} (bands, rows, cols)")

    if data.ndim not in (2, 3):
        raise ValueError(
            f"Expected a 2-D (rows, cols) or 3-D (bands, rows, cols) raster, "
            f"got array with shape {data.shape}"
        )

    # Use the array's own spatial shape so the C-order flattening below and
    # the reshape back to a raster are guaranteed to use the same pixel order.
    rows, cols = data.shape[-2:]

    # Reshape data (bands, rows, cols) -> (pixels, bands)
    if data.ndim == 3:
        data_reshaped = data.reshape(data.shape[0], -1).T
    else:  # single-band fallback
        data_reshaped = data.reshape(-1, 1)

    # Scale features
    data_scaled = scaler.transform(data_reshaped)

    # Predict: either the model's direct output or one class-probability column
    if proba_class is None:
        y_pred = model.predict(data_scaled)
    else:
        y_pred = model.predict_proba(data_scaled)[:, proba_class]

    # Reshape back to raster (float32 output, so predictions must be numeric)
    pred_map = np.asarray(y_pred).reshape(rows, cols).astype(np.float32)

    # Save prediction as GeoTIFF
    write_tif(pred_map, geo, proj, output_tif)
    print(f"Prediction saved: {output_tif}")


# ========= Example usage =========
if __name__ == "__main__":
    # Placeholder paths -- replace them with real files before running.
    # Keep output_tif different from input_tif so the source image is not
    # overwritten by the prediction raster.
    predict_tif(
        input_tif=r"input.tif",  # your multi-band input image
        model_path=r"model.pkl",  # trained model
        scaler_path=r"scaler.pkl",  # trained scaler
        output_tif=r"prediction.tif",  # save prediction
    )