In [1]:
%load_ext autoreload
%autoreload 2

# Compare predictions with UNOSAT labels

In [2]:
import geopandas as gpd
import multiprocessing as mp
import numpy as np
import pandas as pd
from tqdm import tqdm
import xarray as xr
import warnings


from src.data.buildings.overture_unosat import load_overture_buildings_aoi
from src.data import get_unosat_geometry
from src.data.utils import get_all_aois
from src.postprocessing.drive_to_results import find_post_dates
from src.postprocessing.utils import read_fp_within_geo, vectorize_xarray_3d
from src.constants import PREDS_PATH

idx = pd.IndexSlice

In [3]:
def get_preds_geo(geo, run_name):
    """Read the prediction rasters of a run within ``geo`` and stack them by date.

    One GeoTIFF per entry returned by ``find_post_dates(run_name)`` is read from
    ``PREDS_PATH / run_name``; the file name is built by joining the entry's
    date strings with ``_``. The rasters are concatenated along a ``date``
    dimension labelled with the first date of each entry, and size-1
    dimensions are squeezed out of the result.

    Parameters
    ----------
    geo : geometry passed through to ``read_fp_within_geo``.
    run_name : str
        Name of the prediction run (sub-folder of ``PREDS_PATH``).

    Returns
    -------
    The stacked predictions as returned by ``xr.concat(...).squeeze()``.
    """
    post_dates = find_post_dates(run_name)
    run_dir = PREDS_PATH / run_name

    # Only the first date of each entry labels the layer along the "date" dimension
    date_coord = xr.Variable("date", pd.to_datetime([entry[0] for entry in post_dates]))

    # Read one raster per entry, restricted to the requested geometry
    layers = []
    for entry in post_dates:
        layers.append(read_fp_within_geo(run_dir / f'ukraine_{"_".join(entry)}.tif', geo))

    return xr.concat(layers, dim=date_coord).squeeze()

In [None]:
# AOIs UKR1..UKR4 form the "train" list; every other AOI goes to the "test" list.
aois_train = ['UKR1', 'UKR2', 'UKR3', 'UKR4']
aois_test = [name for name in get_all_aois() if name not in aois_train]

In [None]:
# Name of the prediction run; used below as the sub-folder of PREDS_PATH to read from and write to.
run_name = '240301'

## Pixel-wise

In [None]:
from src.data import load_unosat_labels

def extract_raster_value(point, raster):
    """Return the scalar value of ``raster`` at the cell nearest to ``point``.

    ``point`` must expose ``x`` and ``y`` attributes; ``raster`` must support
    ``.sel(x=..., y=..., method="nearest")`` and the selection must reduce to a
    single element so that ``.item()`` succeeds.
    """
    nearest_cell = raster.sel(x=point.x, y=point.y, method="nearest")
    return nearest_cell.item()

def combine_all_unosat_points_with_preds(run_name):
    """Sample the predictions at every UNOSAT label point and save the result.

    For each AOI, the label points with labels 1 or 2 are loaded and the
    prediction raster value nearest to each point is extracted for the dates
    2021-02-24, 2022-02-24 and 2023-02-24 (columns ``pred_<date>``). The
    per-AOI frames are concatenated and written to
    ``PREDS_PATH / run_name / 'aoi_preds' / 'unosat_points_with_preds.geojson'``.

    Parameters
    ----------
    run_name : str
        Name of the prediction run (sub-folder of ``PREDS_PATH``).

    Raises
    ------
    ValueError
        If ``get_all_aois()`` returns no AOI, so there is nothing to write.
    """
    pred_dates = ['2021-02-24', '2022-02-24', '2023-02-24']

    per_aoi_frames = []
    for aoi in tqdm(get_all_aois()):
        geo = get_unosat_geometry(aoi)
        preds = get_preds_geo(geo, run_name)

        # .copy() so the column assignments below write to an independent frame
        # rather than to a column subset of the loaded labels.
        gdf_labels = load_unosat_labels(aoi, labels_to_keep=[1,2])[['geometry']].copy()
        gdf_labels['aoi'] = aoi

        for date in pred_dates:
            # Select the date layer once per date instead of once per point
            preds_at_date = preds.sel(date=date)
            gdf_labels[f'pred_{date}'] = gdf_labels.geometry.apply(
                lambda point, raster=preds_at_date: extract_raster_value(point, raster)
            )

        per_aoi_frames.append(gdf_labels)

    if not per_aoi_frames:
        raise ValueError("No AOI returned by get_all_aois(): nothing to combine.")

    # Single concat at the end instead of re-concatenating the growing frame per AOI
    gdf_all = pd.concat(per_aoi_frames)

    out_fp = PREDS_PATH / run_name / 'aoi_preds' / 'unosat_points_with_preds.geojson'
    # Make sure the output folder exists, as create_aoi_buildings_with_preds does
    out_fp.parent.mkdir(exist_ok=True, parents=True)
    gdf_all.to_file(out_fp, driver='GeoJSON')

def load_unosat_points_with_preds(run_name):
    """Read the point-level GeoJSON saved for ``run_name``.

    The file is expected at
    ``PREDS_PATH / run_name / 'aoi_preds' / 'unosat_points_with_preds.geojson'``;
    an ``AssertionError`` is raised when it is missing.
    """
    folder = PREDS_PATH / run_name / 'aoi_preds'
    fp = folder / 'unosat_points_with_preds.geojson'
    assert fp.exists(), f"File {fp} does not exist."
    points = gpd.read_file(fp)
    return points

In [None]:
# Step that writes the GeoJSON read below; uncomment to (re)generate it.
# combine_all_unosat_points_with_preds(run_name)
gdf_points = load_unosat_points_with_preds(run_name)

In [None]:
# Build a nested dict d[split][date][threshold] -> formatted score for the UNOSAT points.
d = {}
for split in ['train', 'test']:

    # Keep only the points whose AOI belongs to the current split
    gdf_ = gdf_points[gdf_points.aoi.isin(aois_train if split=='train' else aois_test)]
    d[split] = {}
    for date in ['2021-02-24', '2022-02-24', '2023-02-24']:
        d[split][date] = {}
        for t in [0.5, 0.65, 0.75]:
            # Predictions are compared with 255*t — presumably they are stored on a 0-255 scale (TODO confirm)
            tp = (gdf_[f'pred_{date}']>=255*t).sum()
            fn = (gdf_[f'pred_{date}']<255*t).sum()
            recall = tp/(tp+fn)
            # For 2021-02-24 the complement is reported, i.e. the share of points BELOW the threshold —
            # presumably this date is used as the "no damage" reference (TODO confirm)
            if date == '2021-02-24':
                recall = 1-recall
            d[split][date][t] = f'{recall:.2f}'

from pprint import pprint
pprint(d)


### Find best threshold

In [None]:

# Sweep the decision threshold on one split.
# Predictions at `date` on the UNOSAT points are counted as TP/FN, and predictions at
# `date_neg` on the same points that reach the threshold are counted as FP.
split = 'test'
date_neg = '2021-02-24'
date = '2022-02-24'
gdf_ = gdf_points[gdf_points.aoi.isin(aois_train if split=='train' else aois_test)]


def _fbeta_from_counts(tp, fp, fn, beta):
    """F-beta score computed from raw counts; 0.0 when the denominator is zero."""
    beta_sq = beta**2
    denom = (1 + beta_sq)*tp + fp + beta_sq*fn
    return (1 + beta_sq)*tp/denom if denom > 0 else 0.0


precisions = []
recalls = []
f1s = []
f05s = []
f01s = []
thresholds = np.arange(0.1,0.9, 0.01)
for t in thresholds:
    cutoff = 255*t
    tp = (gdf_[f'pred_{date}']>=cutoff).sum()
    fn = (gdf_[f'pred_{date}']<cutoff).sum()
    fp = (gdf_[f'pred_{date_neg}']>=cutoff).sum()

    # Precision/recall are undefined without any predicted/actual positive: keep NaN
    # explicitly (shown as a gap in the plot) instead of relying on a 0/0 division warning.
    precision = tp/(tp+fp) if (tp+fp) > 0 else float('nan')
    recall = tp/(tp+fn) if (tp+fn) > 0 else float('nan')

    precisions.append(precision)
    recalls.append(recall)
    # All F-scores use the same count-based formula. For beta=1 it equals
    # 2*P*R/(P+R) whenever that is defined, and yields 0 instead of NaN when tp == 0,
    # so a later max/argmax over the scores is well defined.
    f1s.append(_fbeta_from_counts(tp, fp, fn, beta=1))
    f05s.append(_fbeta_from_counts(tp, fp, fn, beta=0.5))
    f01s.append(_fbeta_from_counts(tp, fp, fn, beta=0.1))

In [None]:
import matplotlib.pyplot as plt

# Plot every score of the threshold sweep against the threshold.
_, ax = plt.subplots(figsize=(8,5))
for curve, curve_label in (
    (precisions, 'Precision'),
    (recalls, 'Recall'),
    (f1s, 'F1'),
    (f05s, 'F0.5'),
    (f01s, 'F0.1'),
):
    ax.plot(thresholds, curve, label=curve_label)
ax.legend(loc='lower left')
# Dashed vertical marker at threshold 0.5
ax.vlines(0.5, 0, 1, color='black', linestyle='--')
ax.set(xlabel='Threshold', ylabel='Score', xlim=(0.1, 0.9), ylim=(0, 1))
plt.show()

In [None]:
# Report, for each F-score, its best value and the threshold where it is reached.
# NaN entries (undefined scores) are skipped: the builtin max() and np.argmax do not
# handle NaN reliably and could otherwise report a wrong score/threshold pair.
# np.nanargmax raises ValueError if a score list contains only NaN.
for scores, name in zip([f1s, f05s, f01s], ['F1', 'F0.5', 'F0.1']):
    scores = np.asarray(scores, dtype=float)
    best = np.nanargmax(scores)
    print(f"Best {name} score: {scores[best]:.2f} at threshold {thresholds[best]:.2f}")

In [None]:
from sklearn import metrics

# Binary problem built from the same points at two dates:
# predictions at `date_neg` are labelled 0, predictions at `date` are labelled 1.
y_true = np.array([0] * len(gdf_) + [1] * len(gdf_))
y_preds = np.array(gdf_[f'pred_{date_neg}'].tolist() + gdf_[f'pred_{date}'].tolist()) / 255

# Dedicated name for the ROC thresholds so the `thresholds` grid of the sweep cell is
# not overwritten (re-running the sweep plot / best-score cells stays correct).
fpr, tpr, roc_thresholds = metrics.roc_curve(y_true, y_preds)
roc_auc = metrics.auc(fpr, tpr)

# Operating point maximising the geometric mean of TPR and TNR (= 1 - FPR)
gmeans = np.sqrt(tpr * (1-fpr))
ix = np.argmax(gmeans)

fig, ax = plt.subplots(figsize=(8,5))
ax.plot([0,1], [0,1], linestyle='--', color='black')
ax.plot(fpr, tpr, label=f'ROC AUC = {roc_auc:.2f}')
ax.scatter(fpr[ix], tpr[ix], marker='o', color='black', label=f'Best threshold = {roc_thresholds[ix]:.2f}')
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.legend()
plt.show()

In [None]:
_, ax = plt.subplots(figsize=(8,5))
# Dedicated names so the sweep's `thresholds`, `precision` and `recall` are not overwritten.
pr_precision, pr_recall, pr_thresholds = metrics.precision_recall_curve(y_true, y_preds)

# F1 per operating point; defined as 0 where precision + recall == 0 instead of 0/0 = NaN
# (np.argmax would otherwise select the NaN entry).
pr_sum = pr_precision + pr_recall
fscore = np.divide(2 * pr_precision * pr_recall, pr_sum, out=np.zeros_like(pr_sum, dtype=float), where=pr_sum > 0)

# The last (precision=1, recall=0) point has no associated threshold: exclude it so that
# `ix` is always a valid index into pr_thresholds.
ix = np.argmax(fscore[:-1])

ax.plot(pr_recall, pr_precision, label='Precision-Recall')
ax.scatter(pr_recall[ix], pr_precision[ix], marker='o', color='black', label=f'Best (threshold={pr_thresholds[ix]:.2f})')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.legend()
plt.show()

## Building-wise

In [None]:
def add_preds_to_gdf(gdf, preds, verbose=0):
    """Attach per-building, per-date prediction statistics to ``gdf``.

    The prediction pixels are vectorized, intersected with the building
    polygons, and two statistics are computed for each building and date:
    the mean of the pixel values weighted by the intersection area
    (``weighted_mean``) and the maximum pixel value (``max``).

    Parameters
    ----------
    gdf : GeoDataFrame of buildings; after ``reset_index()`` it must expose a
        ``building_id`` column.
    preds : predictions with a ``date`` coordinate, passed to
        ``vectorize_xarray_3d``.
    verbose : truthy to print the shapes of the intermediate frames.

    Returns
    -------
    ``gdf`` joined with the ``weighted_mean`` and ``max`` columns, sorted by index.
    """
    date_labels = sorted(ts.dt.strftime('%Y-%m-%d').item() for ts in preds.date)
    weighted_cols = [f"{label}_weighted_value" for label in date_labels]

    # Pixels -> polygons, one column per date
    pixel_polygons = vectorize_xarray_3d(preds, date_labels)
    if verbose:
        print(f"Vectorized pixels ({pixel_polygons.shape})")

    # Pieces of pixels falling inside each building
    intersections = gpd.overlay(gdf.reset_index(), pixel_polygons, how="intersection")
    intersections = intersections.set_index("building_id")
    if verbose:
        print(f"Overlap with buildings ({intersections.shape})")

    # Area of every intersection piece (UserWarnings raised by .area are silenced)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning)
        intersections["polygon_area"] = intersections.area

    # Area-weighted mean per building and date: sum(value * area) / sum(area)
    intersections[weighted_cols] = intersections[date_labels].multiply(
        intersections["polygon_area"], axis=0
    )
    per_building = intersections.groupby("building_id")
    weighted_sums = per_building[weighted_cols].sum()
    area_sums = per_building["polygon_area"].sum()
    weighted_mean = weighted_sums.divide(area_sums, axis=0)

    # Wide -> long: one row per (building_id, post_date)
    weighted_mean = weighted_mean.stack().reset_index(level=1)
    weighted_mean.columns = ["post_date", "weighted_mean"]
    # Column names look like "<date>_weighted_value": keep the part before the first "_"
    weighted_mean["post_date"] = weighted_mean["post_date"].apply(lambda col: col.split("_")[0])
    weighted_mean = weighted_mean.set_index("post_date", append=True)

    # Maximum pixel value per building and date, in the same long layout
    max_values = intersections.groupby("building_id")[date_labels].max().stack().to_frame(name="max")
    max_values.index.names = ["building_id", "post_date"]

    # Attach both statistics to the original buildings
    return gdf.join(weighted_mean).join(max_values).sort_index()

def load_aoi_buildings_with_preds(aoi, run_name):
    """Read the buildings-with-predictions GeoJSON of one AOI.

    The file is expected at
    ``PREDS_PATH / run_name / 'aoi_preds' / f'{aoi}_buildings_with_preds.geojson'``;
    an ``AssertionError`` is raised when it is missing. The result is indexed
    by ``building_id`` and ``post_date``.
    """
    fp = PREDS_PATH / run_name / 'aoi_preds' / f'{aoi}_buildings_with_preds.geojson'
    assert fp.exists(), f"File {fp} does not exist."
    buildings = gpd.read_file(fp)
    return buildings.set_index(['building_id', 'post_date'])

def create_aoi_buildings_with_preds(aoi, run_name):
    """Compute the building-level prediction statistics of one AOI and save them.

    The Overture buildings of ``aoi`` are combined with the predictions of
    ``run_name`` through ``add_preds_to_gdf`` and written to
    ``PREDS_PATH / run_name / 'aoi_preds' / f'{aoi}_buildings_with_preds.geojson'``
    (the folder is created if needed).
    """
    aoi_geometry = get_unosat_geometry(aoi)
    buildings = load_overture_buildings_aoi(aoi).set_index("building_id")
    aoi_preds = get_preds_geo(aoi_geometry, run_name)
    buildings_with_preds = add_preds_to_gdf(buildings, aoi_preds, verbose=0)

    out_dir = PREDS_PATH / run_name / 'aoi_preds'
    out_dir.mkdir(exist_ok=True, parents=True)
    out_fp = out_dir / f'{aoi}_buildings_with_preds.geojson'
    buildings_with_preds.to_file(out_fp, driver='GeoJSON')
    print(f"Saved buildings with preds for aoi {aoi}.")

def create_all_aoi_buildings_with_preds(run_name):
    """Run ``create_aoi_buildings_with_preds`` for every AOI in a pool of 4 processes."""
    with mp.Pool(4) as workers:
        job_args = [(aoi, run_name) for aoi in get_all_aois()]
        workers.starmap(create_aoi_buildings_with_preds, job_args)

def combine_all_unosat_buildings_with_preds(run_name):
    """Gather the buildings with ``damage_5m`` in {1, 2} from every AOI into one file.

    Each per-AOI frame is loaded with ``load_aoi_buildings_with_preds``,
    filtered on ``damage_5m``, tagged with an ``aoi`` column, and the
    concatenation is written to
    ``PREDS_PATH / run_name / 'aoi_preds' / 'unosat_buildings_with_preds.geojson'``.

    Raises
    ------
    ValueError
        If ``get_all_aois()`` returns no AOI, so there is nothing to write.
    """
    per_aoi_frames = []
    for aoi in tqdm(get_all_aois()):
        gdf = load_aoi_buildings_with_preds(aoi, run_name)
        damaged = gdf[gdf['damage_5m'].isin([1,2])]
        # .assign returns a new frame, so no column is written onto a filtered slice
        per_aoi_frames.append(damaged.assign(aoi=aoi))

    if not per_aoi_frames:
        raise ValueError("No AOI returned by get_all_aois(): nothing to combine.")

    # Single concat at the end instead of re-concatenating the growing frame per AOI
    gdf_all = pd.concat(per_aoi_frames)
    gdf_all.to_file(PREDS_PATH / run_name / 'aoi_preds' / 'unosat_buildings_with_preds.geojson', driver='GeoJSON')

def load_unosat_buildings_with_preds(run_name):
    """Read the combined building-level GeoJSON saved for ``run_name``.

    The file is expected at
    ``PREDS_PATH / run_name / 'aoi_preds' / 'unosat_buildings_with_preds.geojson'``;
    an ``AssertionError`` is raised when it is missing. The result is indexed
    by ``building_id`` and ``post_date``.
    """
    folder = PREDS_PATH / run_name / 'aoi_preds'
    fp = folder / 'unosat_buildings_with_preds.geojson'
    assert fp.exists(), f"File {fp} does not exist."
    buildings = gpd.read_file(fp)
    return buildings.set_index(['building_id', 'post_date'])

In [None]:
# Steps that write the GeoJSON files read below; uncomment to (re)generate them.
# create_all_aoi_buildings_with_preds(run_name)
# combine_all_unosat_buildings_with_preds(run_name)
gdf = load_unosat_buildings_with_preds(run_name)
gdf.shape

In [None]:
# Display the loaded building-level frame for a quick visual check.
gdf

In [None]:
def compute_metrics(gdf, date, threshold, agg='weighted_mean', buffer=5):
    """Share of damaged UNOSAT labels whose building prediction reaches the threshold.

    Rows whose second index level equals ``date`` are selected and restricted
    to damage classes 1 and 2. The ``agg`` column is then averaged per UNOSAT
    id and compared with ``255 * threshold``.

    Parameters
    ----------
    gdf : DataFrame indexed by two levels, the second being the date, with
        columns ``damage<suffix>``, ``unosat_id<suffix>`` and ``agg``.
    date : str
        Date to evaluate (value of the second index level).
    threshold : float
        Decision threshold, expressed as a fraction of 255.
    agg : str
        Name of the prediction column to evaluate (e.g. ``'weighted_mean'``
        or ``'max'``).
    buffer : int or falsy
        Selects the ``_<buffer>m`` column suffix; a falsy value means no suffix.

    Returns
    -------
    tp / (tp + fn); for ``date == '2021-02-24'`` the complement
    ``1 - tp / (tp + fn)`` is returned instead.
    """
    suffix = f'_{buffer}m' if buffer else ''

    # pd.IndexSlice is used directly so the function does not rely on a notebook-global alias
    at_date = gdf.loc[pd.IndexSlice[:, date], :]
    damaged = at_date[at_date['damage' + suffix].isin([1,2])]

    # Average the column actually requested through `agg`. The aggregation used to be
    # hard-coded to "weighted_mean", which raised KeyError for any other `agg`.
    per_label = damaged.groupby('unosat_id' + suffix)[agg].mean()

    cutoff = 255*threshold
    tp = (per_label>=cutoff).sum()
    fn = (per_label<cutoff).sum()
    recall = tp/(tp+fn)
    if date == '2021-02-24':
        # Complement reported for this date: share of labels BELOW the threshold
        recall = 1-recall
    return recall

In [None]:
# Build a nested dict d[split][date][threshold] -> formatted score at building level.
d = {}
for split in ['train', 'test']:
    # Keep only the buildings whose AOI belongs to the current split
    gdf_ = gdf[gdf.aoi.isin(aois_train if split=='train' else aois_test)]
    d[split] = {}
    for date in ['2021-02-24', '2022-02-24', '2023-02-24']:
        d[split][date] = {}
        for t in [0.5, 0.65, 0.75]:
            # compute_metrics returns the complement (1 - recall) for 2021-02-24
            recall = compute_metrics(gdf_, date, t, buffer=5)
            d[split][date][t] = f'{recall:.2f}'

from pprint import pprint
pprint(d)