# In [None]:
# ===================== OPTIONAL INSTALLS (uncomment if needed) =====================
# !pip install rasterio --quiet
# !pip install fiona --quiet
# !pip install geopandas --quiet

# ===================== IMPORTS =====================
import os
import json
import glob
import cv2
import numpy as np
from tqdm import tqdm
import rasterio
from rasterio.windows import Window
from rasterio.features import shapes as rio_shapes
from shapely.geometry import shape, mapping
from skimage.morphology import remove_small_objects, opening, closing, disk

import tensorflow as tf
from tensorflow.keras.layers import (
    Conv2D, concatenate, Input, Dropout, MaxPooling2D, Conv2DTranspose
)
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split

# ===================== CONFIG =====================
# Output folders (created up front so later writes find them in place)
CKPT_DIR = "checkpoints"
OUT_DIR = "outputs"
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

# Files produced by training and inference
BEST_MODEL_PATH = os.path.join(CKPT_DIR, "best_change_detection_unet_model.keras")
OUT_PROB = os.path.join(OUT_DIR, "change_prob_improved.tif")
OUT_MASK = os.path.join(OUT_DIR, "change_mask_improved.tif")
OUT_VECTOR_SHP = os.path.join(OUT_DIR, "change_polygons.shp")
OUT_VECTOR_GEOJSON = os.path.join(OUT_DIR, "change_polygons.geojson")

# Raster pair that the tiled inference step is run on
BEFORE_TIF = "/kaggle/input/testtiff12/processed_img1.tif"
AFTER_TIF = "/kaggle/input/testtiff12/processed_img2.tif"

# Model input shape plus tiling / batching / thresholding settings
INPUT_SHAPE = (256, 256, 6)
PATCH_H, PATCH_W = 256, 256
OVERLAP = 64
BATCH_SIZE = 8
THRESH = 0.50
USE_TTA = True

# Mask clean-up settings (minimum component size and morphology disk radii)
MIN_CC_SIZE = 256
MORPH_OPEN_DISK, MORPH_CLOSE_DISK = 2, 3

# ===================== METRIC =====================
# Named "iou" so Keras logs it as "iou"/"val_iou"; note the callbacks in this file monitor "val_loss", not "val_iou".
def iou(y_true, y_pred, smooth=1e-6):
    """Intersection-over-union between `y_true` and `y_pred` binarised at 0.5.

    The sums run over every element of the tensors passed in, so a single
    scalar is returned. `smooth` is added to numerator and denominator to keep
    the ratio defined when both are zero.
    """
    truth = tf.cast(y_true, tf.float32)
    hard_pred = tf.cast(y_pred > 0.5, tf.float32)
    overlap = tf.reduce_sum(truth * hard_pred)
    total = tf.reduce_sum(truth) + tf.reduce_sum(hard_pred)
    return (overlap + smooth) / (total - overlap + smooth)

# ===================== MODEL BUILD =====================
def build_unet_model(input_size=INPUT_SHAPE):
    """Build a two-stream U-Net that maps a stacked image pair to a 1-channel sigmoid map.

    The channel axis of the input is split in half: the first half is treated
    as the first image, the second half as the second image. Each half goes
    through its own three-level encoder and bottleneck (separate layers, no
    weight sharing). The two bottlenecks are concatenated and decoded with
    skip connections taken from both encoders at every level.

    Args:
        input_size: (height, width, channels) of the stacked input. `channels`
            must be a positive even number (bands of image 1 followed by the
            same number of bands of image 2).

    Returns:
        An uncompiled `tensorflow.keras.models.Model`.

    Raises:
        ValueError: if the channel count is unknown, odd, or smaller than 2.
    """
    channels = input_size[-1]
    if channels is None or channels < 2 or channels % 2 != 0:
        raise ValueError(
            f"input_size must end with an even channel count >= 2, got {channels!r}."
        )
    # Previously hard-coded to 3; half the channels keeps the 6-channel default
    # identical while supporting other band counts.
    bands = channels // 2

    def conv_pair(x, filters):
        # Two stacked 3x3 ReLU convolutions with 'same' padding.
        x = Conv2D(filters, 3, activation='relu', padding='same')(x)
        return Conv2D(filters, 3, activation='relu', padding='same')(x)

    def encode(x):
        # Three conv/dropout/pool levels; returns the pre-pool skips and the pooled output.
        skips = []
        for filters in (64, 128, 256):
            x = Dropout(0.3)(conv_pair(x, filters))
            skips.append(x)
            x = MaxPooling2D((2, 2))(x)
        return skips, x

    def decode(x, filters, skip_a, skip_b):
        # Upsample x2, merge with the matching skip of each encoder, then refine.
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip_a, skip_b])
        x = Conv2D(filters, 3, activation='relu', padding='same')(x)
        x = Dropout(0.3)(x)
        return Conv2D(filters, 3, activation='relu', padding='same')(x)

    inputs = Input(input_size)
    input_image1 = inputs[..., :bands]
    input_image2 = inputs[..., bands:]

    # Encoders (layer creation order kept: stream 1 fully, then stream 2)
    skips_1, pooled_1 = encode(input_image1)
    skips_2, pooled_2 = encode(input_image2)

    # Bottleneck: one per stream, then fused along the channel axis
    bottom_1 = Dropout(0.4)(conv_pair(pooled_1, 512))
    bottom_2 = Dropout(0.4)(conv_pair(pooled_2, 512))
    x = concatenate([bottom_1, bottom_2])

    # Decoder: deepest skips first
    for filters, skip_a, skip_b in zip((256, 128, 64), reversed(skips_1), reversed(skips_2)):
        x = decode(x, filters, skip_a, skip_b)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    model = Model(inputs, outputs)
    return model

# Instantiate the network and attach loss / optimizer / metrics
model_unet = build_unet_model(INPUT_SHAPE)
model_unet.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=[iou, 'accuracy'],
)
print("Model built. Input shape:", model_unet.input_shape)

# ===================== DATA LOADING (LEVIR-CD example) =====================
# Point these at your own copy of the dataset if it lives elsewhere.

# Images are resized to this size by load_images
RESIZE_SHAPE = (256, 256)

# Training split: first images (A), second images (B), labels
image1_train_dir = "/kaggle/input/levir-cd/LEVIR CD/train/A"
image2_train_dir = "/kaggle/input/levir-cd/LEVIR CD/train/B"
mask_train_dir = "/kaggle/input/levir-cd/LEVIR CD/train/label"

# Test split, same layout
image1_test_dir = "/kaggle/input/levir-cd/LEVIR CD/test/A"
image2_test_dir = "/kaggle/input/levir-cd/LEVIR CD/test/B"
mask_test_dir = "/kaggle/input/levir-cd/LEVIR CD/test/label"

def load_images(image1_dir, image2_dir, mask_dir):
    """Load image pairs and masks from three directories into arrays.

    Files are paired by position after sorting each directory listing, so the
    three directories are expected to contain matching file names.
    # NOTE(review): names themselves are not compared - confirm the dataset layout.

    Every image is resized to RESIZE_SHAPE and scaled to [0, 1]. The two
    images of a pair are stacked along the channel axis.
    # NOTE(review): cv2.imread returns channels in BGR order; the tiled
    # inference path reads bands via rasterio in file order - confirm both
    # use the same channel order.

    Args:
        image1_dir: directory holding the first image of each pair.
        image2_dir: directory holding the second image of each pair.
        mask_dir: directory holding the label masks.

    Returns:
        (X, y): float32 arrays; each X entry is the stacked pair, each y entry
        is the mask with a trailing channel axis.

    Raises:
        ValueError: if the directories hold different numbers of files, or a
            file cannot be decoded as an image.
    """
    files1 = sorted(os.listdir(image1_dir))
    files2 = sorted(os.listdir(image2_dir))
    filesm = sorted(os.listdir(mask_dir))
    # zip() would silently drop (and possibly mispair) files on a count mismatch.
    if not (len(files1) == len(files2) == len(filesm)):
        raise ValueError(
            "File count mismatch: "
            f"{len(files1)} in {image1_dir}, {len(files2)} in {image2_dir}, "
            f"{len(filesm)} in {mask_dir}."
        )

    def _read(directory, name, flags):
        # cv2.imread signals failure by returning None instead of raising.
        path = os.path.join(directory, name)
        img = cv2.imread(path, flags)
        if img is None:
            raise ValueError(f"Could not read image: {path}")
        return img

    X, y = [], []
    for f1, f2, fm in zip(files1, files2, filesm):
        i1 = cv2.resize(_read(image1_dir, f1, cv2.IMREAD_COLOR), RESIZE_SHAPE)
        i2 = cv2.resize(_read(image2_dir, f2, cv2.IMREAD_COLOR), RESIZE_SHAPE)
        # Nearest-neighbour keeps label values intact; the default bilinear
        # resize would blend mask values at object borders.
        m = cv2.resize(
            _read(mask_dir, fm, cv2.IMREAD_GRAYSCALE),
            RESIZE_SHAPE,
            interpolation=cv2.INTER_NEAREST,
        )
        i1 = i1.astype(np.float32) / 255.0
        i2 = i2.astype(np.float32) / 255.0
        m = m.astype(np.float32) / 255.0
        X.append(np.concatenate([i1, i2], axis=-1))  # (H,W,6) for 3-channel images
        y.append(m[..., np.newaxis])  # keep channel dim (H,W,1)
    return np.array(X), np.array(y)

# Read both splits into memory (can be slow)
X_all, y_all = load_images(
    image1_train_dir, image2_train_dir, mask_train_dir
)
X_test, y_test = load_images(
    image1_test_dir, image2_test_dir, mask_test_dir
)

# Hold out 20% of the training data for validation (fixed seed)
X_train, X_val, y_train, y_val = train_test_split(
    X_all, y_all, random_state=42, test_size=0.2
)
print("Shapes:", X_train.shape, X_val.shape, X_test.shape)

# ===================== CALLBACKS & TRAIN =====================

# All three callbacks watch the validation loss.
callbacks = [
    EarlyStopping(
        monitor="val_loss",
        patience=5,
        verbose=1,
        restore_best_weights=True,
    ),
    ModelCheckpoint(
        BEST_MODEL_PATH,
        monitor="val_loss",
        verbose=1,
        save_best_only=True,
    ),
    ReduceLROnPlateau(
        monitor="val_loss",
        patience=3,
        factor=0.5,
        verbose=1,
    ),
]

history_unet = model_unet.fit(
    X_train,
    y_train,
    epochs=20,
    batch_size=BATCH_SIZE,
    validation_data=(X_val, y_val),
    callbacks=callbacks,
    verbose=1,
)

# Debug aid: which metrics were logged and what ended up in the checkpoint folder
print("history keys:", list(history_unet.history))
print("checkpoint files:", os.listdir(CKPT_DIR))

# ===================== ENSURE & LOAD BEST MODEL =====================
def ensure_and_load_best_model(in_memory_model, path=BEST_MODEL_PATH):
    """Return an uncompiled model loaded from `path`.

    If a file exists at `path` and loads successfully, that model is returned.
    Otherwise (no file, or loading raised) `in_memory_model` is saved to `path`
    without optimizer state and then loaded back from there.
    """
    if os.path.exists(path):
        try:
            print("Loading saved model from:", path)
            restored = tf.keras.models.load_model(path, compile=False)
        except Exception as e:
            print("Warning: load_model failed:", e)
        else:
            print("Loaded model from checkpoint.")
            return restored

    # No usable checkpoint: persist the live model (EarlyStopping restored best weights) and reload it
    print("Saving current in-memory model to:", path)
    in_memory_model.save(path, include_optimizer=False)
    reloaded = tf.keras.models.load_model(path, compile=False)
    return reloaded

MODEL_UNET_BEST = ensure_and_load_best_model(model_unet, BEST_MODEL_PATH)
print("MODEL_UNET_BEST input shape:", MODEL_UNET_BEST.input_shape)

# The model comes back from load_model(..., compile=False), so it carries no
# loss or metrics; evaluate() needs a compiled model. Re-attach the same loss
# and metrics that were used for training.
MODEL_UNET_BEST.compile(loss='binary_crossentropy', metrics=[iou, 'accuracy'])

# Evaluate using the best model (safer). Kept best-effort so that a failure
# here does not stop the inference step further down.
try:
    # return_dict avoids relying on the position of each metric in the result
    # list (loss plus two metrics would not unpack into two names).
    results = MODEL_UNET_BEST.evaluate(
        X_test, y_test, batch_size=BATCH_SIZE, verbose=1, return_dict=True
    )
    loss, acc = results.get("loss"), results.get("accuracy")
    print("Test loss, accuracy (best model):", loss, acc)
    print("Test IoU (best model):", results.get("iou"))
except Exception as e:
    print("Evaluation failed:", e)

# ===================== INFERENCE HELPERS =====================
def read_window(src, col_off, row_off, w, h, bands=None):
    """Read bands 1..`bands` of `src` inside a window and return them channels-last.

    When `bands` is None, `src.count` bands are read. The window starts at
    (`col_off`, `row_off`) and spans `w` columns by `h` rows.
    """
    n_bands = src.count if bands is None else bands
    win = Window(col_off, row_off, w, h)
    data = src.read(list(range(1, n_bands + 1)), window=win)
    # src.read yields band-first data; move the band axis to the end.
    return np.moveaxis(data, 0, -1)

def infer_bands_from_model(model):
    """Return half of the model's input channel count, or None if it cannot be derived.

    None is returned when `model.input_shape` is missing, is not of length 4,
    has an odd last entry, or when inspecting it raises.
    """
    try:
        shape_ = model.input_shape
        if shape_ is not None and len(shape_) == 4:
            half, remainder = divmod(shape_[-1], 2)
            if remainder == 0:
                return half
    except Exception:
        pass
    return None

def gaussian_window(h, w, sigma_scale=0.125):
    """Return an (h, w) float32 Gaussian bump normalised so its maximum is 1.

    Row and column coordinates are spread evenly over [-1, 1]; `sigma_scale`
    is the standard deviation in those normalised units.
    """
    rows = np.linspace(-1, 1, h)[:, np.newaxis]
    cols = np.linspace(-1, 1, w)[np.newaxis, :]
    # Broadcasting the column vector against the row vector gives the full grid.
    radius_sq = cols**2 + rows**2
    bump = np.exp(-radius_sq / (2 * sigma_scale**2))
    bump = bump / bump.max()
    return bump.astype(np.float32)

def apply_tta_predict(model, batch):
    """Average predictions over four flip variants of `batch`.

    The batch is predicted as-is, flipped along axis 2, flipped along axis 1,
    and flipped along both. Each prediction is flipped back before the four
    results are averaged. A trailing singleton channel axis on a 4-D
    prediction is dropped.
    """
    def _predict(x):
        out = model.predict(x, verbose=0)
        if out.ndim == 4 and out.shape[-1] == 1:
            out = np.squeeze(out, axis=-1)
        return out

    flipped_w = batch[:, :, ::-1, :]

    plain = _predict(batch)
    from_w = _predict(flipped_w)[:, :, ::-1]
    from_h = _predict(batch[:, ::-1, :, :])[:, ::-1, :]
    from_hw = _predict(flipped_w[:, ::-1, :, :])[:, ::-1, ::-1]

    return np.mean([plain, from_w, from_h, from_hw], axis=0)

# ===================== TILED INFERENCE + VECTORIZE =====================
def _tile_offsets(extent, patch, overlap):
    """Return the start offsets of `patch`-sized tiles covering [0, extent) with `overlap`."""
    if overlap < 0 or overlap >= patch:
        # A non-positive step would otherwise surface as an obscure range()/IndexError.
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < patch size; got overlap={overlap}, patch={patch}."
        )
    offsets = list(range(0, max(1, extent - overlap), patch - overlap))
    # Add a final tile flush with the far edge if the regular grid stops short.
    if offsets[-1] + patch < extent:
        offsets.append(max(0, extent - patch))
    return offsets


def predict_full_tif_with_model_improved(
        model,
        before_path,
        after_path,
        out_prob_path,
        out_mask_path,
        out_vector_shp_path=None,
        out_vector_geojson_path=None,
        patch_h=None,
        patch_w=None,
        overlap=64,
        threshold=0.65,
        batch_size=8,
        use_tta=True,
        min_cc_size=256,
        morph_open_disk=2,
        morph_close_disk=3
    ):
    """Run tiled change-detection inference over a raster pair and export the results.

    The two rasters are read tile by tile, stacked along the channel axis,
    divided by 255 and predicted in batches. Overlapping tile predictions are
    blended with a Gaussian weight window. The blended probability map is
    written as a float32 GeoTIFF, then thresholded, cleaned (opening, closing,
    small-object removal) and written as a 0/255 uint8 GeoTIFF. Finally the
    mask is polygonised and written as a Shapefile via geopandas; a GeoJSON
    file is written only if the Shapefile attempt fails.
    # NOTE(review): the /255 scaling presumes 8-bit imagery - confirm for the inputs used.

    Args:
        model: object exposing `input_shape` (batch, H, W, C) and `predict`.
        before_path: path of the first raster.
        after_path: path of the second raster (same width/height as the first).
        out_prob_path: output path for the probability raster.
        out_mask_path: output path for the binary mask raster.
        out_vector_shp_path: Shapefile path; defaults to OUT_VECTOR_SHP.
        out_vector_geojson_path: fallback GeoJSON path; defaults to OUT_VECTOR_GEOJSON.
        patch_h: tile height; defaults to the model's input height.
        patch_w: tile width; defaults to the model's input width.
        overlap: pixels shared by neighbouring tiles; must be in [0, patch size).
        threshold: probability above which a pixel is marked as change.
        batch_size: number of tiles predicted per call.
        use_tta: average predictions over flips via `apply_tta_predict`.
        min_cc_size: `min_size` passed to `remove_small_objects` (0 disables).
        morph_open_disk: disk radius for opening (0 disables).
        morph_close_disk: disk radius for closing (0 disables).

    Returns:
        (out_prob_path, out_mask_path, shapefile path or None, GeoJSON path or None).

    Raises:
        RuntimeError: model has no input shape, or band/channel counts do not match.
        AssertionError: the two rasters differ in width or height.
        ValueError: `overlap` is negative or not smaller than the patch size.
    """
    if getattr(model, "input_shape", None) is None:
        raise RuntimeError("Model has no input_shape.")
    inp_shape = model.input_shape
    model_H, model_W, model_C = inp_shape[1], inp_shape[2], inp_shape[3]
    if patch_h is None: patch_h = model_H
    if patch_w is None: patch_w = model_W

    inferred_bands = infer_bands_from_model(model)

    if out_vector_shp_path is None:
        out_vector_shp_path = OUT_VECTOR_SHP
    if out_vector_geojson_path is None:
        out_vector_geojson_path = OUT_VECTOR_GEOJSON

    with rasterio.open(before_path) as sb, rasterio.open(after_path) as sa:
        # Explicit raise (same exception type as before) so the check survives `python -O`.
        if not (sb.width == sa.width and sb.height == sa.height):
            raise AssertionError("Before/After dims mismatch")
        width, height = sb.width, sb.height
        profile = sb.profile.copy()
        profile.update(count=1, dtype="float32", compress="lzw")

        # Running weighted sum of predictions and the matching sum of weights.
        prob_sum = np.zeros((height, width), dtype=np.float32)
        weight_sum = np.zeros((height, width), dtype=np.float32)

        xs = _tile_offsets(width, patch_w, overlap)
        ys = _tile_offsets(height, patch_h, overlap)
        tiles = [(x, y) for y in ys for x in xs]

        gw = gaussian_window(patch_h, patch_w)

        print(f"üß© Raster: {width}x{height}, tiles: {len(tiles)}, patch: {patch_h}x{patch_w}, overlap: {overlap}")

        def _flush(pending, pending_meta):
            # Predict one batch of tiles and blend each result into the accumulators.
            batch_arr = np.stack(pending, axis=0)
            if use_tta:
                preds = apply_tta_predict(model, batch_arr)
            else:
                preds = np.squeeze(model.predict(batch_arr, verbose=0), axis=-1)
            for pred, (x_i, y_i, h_i, w_i) in zip(preds, pending_meta):
                # Crop away the padded area before accumulating.
                gw_crop = gw[:h_i, :w_i]
                prob_sum[y_i:y_i+h_i, x_i:x_i+w_i] += (pred[:h_i, :w_i] * gw_crop)
                weight_sum[y_i:y_i+h_i, x_i:x_i+w_i] += gw_crop

        batches, batch_meta = [], []

        for (x_off, y_off) in tqdm(tiles, desc="Processing tiles"):
            w = min(patch_w, width - x_off)
            h = min(patch_h, height - y_off)

            b_patch = read_window(sb, x_off, y_off, w, h, bands=inferred_bands or sb.count)
            a_patch = read_window(sa, x_off, y_off, w, h, bands=inferred_bands or sa.count)
            if b_patch.shape[-1] != a_patch.shape[-1]:
                raise RuntimeError(f"Band mismatch: {b_patch.shape[-1]} vs {a_patch.shape[-1]}")

            # Tiles that run past the raster edge are padded up to the patch size.
            pad_w, pad_h = (patch_w - w), (patch_h - h)
            if pad_w or pad_h:
                b_patch = np.pad(b_patch, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")
                a_patch = np.pad(a_patch, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect")

            inp = np.concatenate([b_patch, a_patch], axis=-1).astype(np.float32) / 255.0
            if inp.shape[-1] != model_C:
                raise RuntimeError(f"Expected {model_C} channels, got {inp.shape[-1]}.")

            batches.append(inp)
            batch_meta.append((x_off, y_off, h, w))

            if len(batches) >= batch_size:
                _flush(batches, batch_meta)
                batches, batch_meta = [], []

        # leftover
        if batches:
            _flush(batches, batch_meta)

        # normalize (avoid division by zero)
        final_prob = np.zeros_like(prob_sum, dtype=np.float32)
        nonzero = weight_sum > 0
        final_prob[nonzero] = prob_sum[nonzero] / weight_sum[nonzero]

        # write probability raster
        with rasterio.open(out_prob_path, "w", **profile) as dst:
            dst.write(final_prob.astype(np.float32), 1)

        # threshold + morphology + small objects removal
        mask = (final_prob > threshold).astype(bool)
        if morph_open_disk > 0: mask = opening(mask, disk(morph_open_disk))
        if morph_close_disk > 0: mask = closing(mask, disk(morph_close_disk))
        if min_cc_size > 0: mask = remove_small_objects(mask, min_size=min_cc_size)
        mask_final = (mask.astype(np.uint8) * 255)

        mask_profile = profile.copy()
        mask_profile.update(dtype="uint8", count=1)
        with rasterio.open(out_mask_path, "w", **mask_profile) as dst:
            dst.write(mask_final, 1)

        # Vectorize mask
        features = []
        transform = sb.transform
        src_crs = sb.crs
        bin_mask = (mask_final != 0).astype(np.uint8)

        # Pixel footprint is the same for every polygon, so compute it once.
        pix_w = abs(transform.a) if hasattr(transform, "a") else None
        pix_h = abs(transform.e) if hasattr(transform, "e") else None
        pixel_area = (pix_w * pix_h) if (pix_w is not None and pix_h is not None) else None

        for geom, val in rio_shapes(bin_mask, transform=transform):
            if int(val) == 0:
                continue
            geom_shp = shape(geom)
            if geom_shp.is_valid and not geom_shp.is_empty:
                # -1 marks "pixel count unavailable" when the pixel area is unknown or zero.
                pixel_count = int(round(geom_shp.area / pixel_area)) if pixel_area and pixel_area > 0 else -1
                geom_area = float(geom_shp.area)
                features.append({
                    "type": "Feature",
                    "geometry": mapping(geom_shp),
                    "properties": {"value": int(val), "pixel_count": pixel_count, "area_map_units": geom_area}
                })

        shp_written = None
        geojson_written_path = None

        # Try writing Shapefile via geopandas
        try:
            import geopandas as gpd
            geojson_fc = {"type": "FeatureCollection", "features": features}

            crs_for_gdf = None
            prj_wkt = None
            if src_crs is not None:
                try:
                    crs_for_gdf = src_crs.to_dict()
                except Exception:
                    try:
                        prj_wkt = src_crs.to_wkt() if hasattr(src_crs, "to_wkt") else None
                        crs_for_gdf = src_crs.to_string() if hasattr(src_crs, "to_string") else None
                    except Exception:
                        crs_for_gdf = None
                        prj_wkt = None

            gdf = gpd.GeoDataFrame.from_features(geojson_fc, crs=crs_for_gdf)
            shp_dir = os.path.dirname(out_vector_shp_path)
            if shp_dir and not os.path.isdir(shp_dir):
                os.makedirs(shp_dir, exist_ok=True)
            gdf.to_file(out_vector_shp_path, driver="ESRI Shapefile", index=False)
            shp_written = out_vector_shp_path

            # Write .prj if we have WKT (best effort: the shapefile itself is already on disk)
            try:
                if prj_wkt is None and src_crs is not None and hasattr(src_crs, "to_wkt"):
                    prj_wkt = src_crs.to_wkt()
                if prj_wkt:
                    prj_path = os.path.splitext(out_vector_shp_path)[0] + ".prj"
                    with open(prj_path, "w", encoding="utf-8") as pf:
                        pf.write(prj_wkt)
            except Exception as prj_err:
                print(f"Warning: could not write .prj next to shapefile: {prj_err}")

        except Exception as e:
            print(f"‚ö†Ô∏è geopandas->shapefile attempt failed: {e}\nFalling back to GeoJSON at: {out_vector_geojson_path}.")
            try:
                geojson_fc = {"type": "FeatureCollection", "features": features}
                with open(out_vector_geojson_path, "w", encoding="utf-8") as jf:
                    json.dump(geojson_fc, jf, ensure_ascii=False, indent=2)
                geojson_written_path = out_vector_geojson_path
                if src_crs is not None and hasattr(src_crs, "to_wkt"):
                    # Best effort: keep the CRS next to the GeoJSON as a .prj sidecar.
                    try:
                        prj_path = os.path.splitext(out_vector_geojson_path)[0] + ".prj"
                        with open(prj_path, "w", encoding="utf-8") as pf:
                            pf.write(src_crs.to_wkt())
                    except Exception as prj_err:
                        print(f"Warning: could not write .prj next to GeoJSON: {prj_err}")
            except Exception as e2:
                print(f"‚ö†Ô∏è Also failed to write fallback GeoJSON: {e2}")

        # Final prints & return
        print(f"‚úÖ Done. Saved:\n  - Prob map: {out_prob_path}\n  - Mask: {out_mask_path}")
        if shp_written:
            print(f"  - Vector (Shapefile): {shp_written}")
        elif geojson_written_path:
            print(f"  - Vector (GeoJSON fallback): {geojson_written_path}")
        else:
            print("  - Vector: none (failed to write shapefile or geojson)")

        return out_prob_path, out_mask_path, shp_written, geojson_written_path

# ===================== RUN INFERENCE (use the loaded best model) =====================
# Every argument is passed by keyword so the call reads like the CONFIG section.
out_prob, out_mask, out_shp, out_geojson = predict_full_tif_with_model_improved(
    model=MODEL_UNET_BEST,
    before_path=BEFORE_TIF,
    after_path=AFTER_TIF,
    out_prob_path=OUT_PROB,
    out_mask_path=OUT_MASK,
    out_vector_shp_path=OUT_VECTOR_SHP,
    out_vector_geojson_path=OUT_VECTOR_GEOJSON,
    patch_h=PATCH_H,
    patch_w=PATCH_W,
    overlap=OVERLAP,
    threshold=THRESH,
    batch_size=BATCH_SIZE,
    use_tta=USE_TTA,
    min_cc_size=MIN_CC_SIZE,
    morph_open_disk=MORPH_OPEN_DISK,
    morph_close_disk=MORPH_CLOSE_DISK,
)

# Report what the call handed back (None means that output was not written)
print(
    "\nReturn values:",
    f"  out_prob: {out_prob}",
    f"  out_mask: {out_mask}",
    f"  out_shp: {out_shp}",
    f"  out_geojson: {out_geojson}",
    sep="\n",
)
