In [None]:
# Automatically reload edited modules (e.g. the project's `src` package) before each cell runs.
%load_ext autoreload
%autoreload 2

# Analyze results (postprocessing)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from src.postprocessing.preds_buildings import load_buildings_with_preds
from shapely.geometry import box

# Area of interest and model run whose building-level predictions are analysed.
aoi = 'UKR6'
run_name = "240212"
# Small bounding box (minx, miny, maxx, maxy), presumably lon/lat in EPSG:4326 — TODO confirm.
# NOTE(review): not used in the cells below.
zoom_box = box(31.3649273, 51.5272942, 31.3693541, 51.5301153)

gdf = load_buildings_with_preds(aoi, run_name)
print(gdf.shape)
gdf.head()

In [None]:
# Ground-truth damage classes counted as "destroyed" for the binary evaluation.
labels_to_keep = [1, 2]
# Decision threshold at 65% of the score range (scores presumably on a 0-255 scale — TODO confirm).
threshold = 255 * 0.65
print(f'Threshold: {threshold}')
# Binary (0/1) prediction and label columns used by the evaluation cells below.
gdf['preds'] = gdf['weighted_mean'].gt(threshold).astype(int)
gdf['label'] = gdf['damage_5m'].isin(labels_to_keep).astype(int)

In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1; label 0 -> "intact", label 1 -> "destroyed".
print(classification_report(y_true=gdf.label, y_pred=gdf.preds, target_names=["intact", "destroyed"]))

In [None]:
# Confusion-matrix counts for the thresholded predictions.
# Use explicit boolean masks: `~` on the int `preds` column is a bitwise NOT
# (0 -> -1, 1 -> -2), which only happened to work when combined with `&`.
# Ground truth is the same `label` column used by the classification report above.
pred_pos = gdf.preds.astype(bool)
true_pos = gdf.label.astype(bool)
tp = int((pred_pos & true_pos).sum())
fp = int((pred_pos & ~true_pos).sum())
fn = int((~pred_pos & true_pos).sum())
tn = int((~pred_pos & ~true_pos).sum())

print(tp, fp, fn, tn)

In [None]:
# Sweep the decision threshold and plot recall and precision of the positive class.
# NOTE: the positive class here uses damage labels 1, 2 and 3, i.e. broader than
# the [1, 2] used for the classification report above.
threshold = np.arange(0.5, 1, 0.05)
labels_to_keep = [1,2,3]
is_positive = gdf.damage_5m.isin(labels_to_keep)
n_positive = int(is_positive.sum())
recalls = []
precisions = []
for t in threshold:
    # Keep predictions local: overwriting gdf['preds'] here would silently change
    # the results of the report / confusion cells if they are re-run.
    preds_t = gdf['weighted_mean'] > t * 255
    n_tp = int((preds_t & is_positive).sum())
    n_pred = int(preds_t.sum())
    # Undefined ratios (no positives / no predictions) become NaN instead of raising.
    recalls.append(n_tp / n_positive if n_positive else np.nan)
    precisions.append(n_tp / n_pred if n_pred else np.nan)

fig, ax = plt.subplots()
ax.plot(threshold, recalls, label='Recall')
ax.plot(threshold, precisions, label='Precision')
ax.set_xlabel('Threshold')
ax.set_ylabel('Score')
ax.set_title('Recall and precision vs threshold')
ax.legend()
plt.show()

# Analyze time windows

In [None]:
import pandas as pd
import joblib
from src.data.utils import read_ts
from src.data.time_series.stacked_ts import get_df_from_single_ts
from src.classification.features import default_extract_features


def predict_for_ts(ts, model, start_pre, end_pre, start_post, end_post):
    """Return the model's positive-class probability for one time series.

    Features are extracted separately for the "pre" window [start_pre, end_pre]
    and the "post" window [start_post, end_post], concatenated column-wise, and
    only the columns whose names start with "VV" or "VH" are fed to the model.

    Args:
        ts: a single time series accepted by `get_df_from_single_ts`; must expose
            `extraction_strategy` (used in the feature prefix).
        model: fitted classifier with a `predict_proba` method.
        start_pre, end_pre, start_post, end_post: window bounds as date strings.

    Returns:
        Array of `predict_proba(...)[:, 1]` values, one per feature row.
    """
    df = get_df_from_single_ts(ts)

    def _period_features(name_period, start, end):
        # Features for one window, tagged with the window bounds it came from.
        feats = default_extract_features(df, start, end, prefix=f"{name_period}_{ts.extraction_strategy}")
        feats[f"{name_period}_start"] = start
        feats[f"{name_period}_end"] = end
        return feats

    # "pre" is extracted before "post", matching the column order the model expects.
    df_final = pd.concat(
        [
            _period_features("pre", start_pre, end_pre),
            _period_features("post", start_post, end_post),
        ],
        axis=1,
    )
    band_cols = [col for col in df_final.columns if col.startswith(("VV", "VH"))]
    return model.predict_proba(df_final[band_cols].values)[:, 1]

from src.constants import UKRAINE_WAR_START
import datetime as dt
import matplotlib.pyplot as plt

def _date_within_ts(ts, date_str):
    """Return True if the ISO date string `date_str` lies strictly inside the date span of `ts`.

    Both ends of the series are formatted as "YYYY-MM-DD", so a plain string
    comparison orders the dates chronologically.
    """
    first = ts.date[0].dt.strftime("%Y-%m-%d").item()
    last = ts.date[-1].dt.strftime("%Y-%m-%d").item()
    return first < date_str < last


def plot_ts_with_periods(
    ts,
    periods,
    preds=0,
    loc_legend="lower left",
    add_legend=True,
    title=None,
    add_invasion_date=True,
    add_analysis_date=True,
    show=True,
    xlim=("2020-02-01", "2023-05-24"),
    ylim=(-25, 10),
):
    """Plot the VV/VH backscatter of `ts` with shaded analysis windows.

    Args:
        ts: time series with `band` and `date` coordinates and `aoi`, `orbit`,
            `unosat_id`, `extraction_strategy` and `date_of_analysis` attributes
            (presumably an xarray DataArray returned by `read_ts` — TODO confirm).
        periods: mapping window name -> (start, end) ISO date strings ("YYYY-MM-DD").
            Every window present is shaded; "pre", "post", "post_neg" and
            "post_pos" have dedicated colours, any other name is shaded grey.
        preds: value appended to the default title (e.g. a predicted probability).
        loc_legend: legend location passed to `ax.legend`.
        add_legend: whether to draw the legend.
        title: custom title; when None a title is built from the series metadata.
        add_invasion_date: draw a line at `UKRAINE_WAR_START` if it falls inside the series.
        add_analysis_date: draw a line at `ts.date_of_analysis` if it falls inside the series.
        show: call `plt.show()` before returning.
        xlim: (start, end) ISO dates for the x-axis.
        ylim: (bottom, top) limits of the y-axis, in dB.

    Returns:
        The matplotlib Figure.
    """
    # Shading colour per window name.
    period_colors = {"pre": "C2", "post": "C3", "post_neg": "C3", "post_pos": "C4"}

    fig, ax = plt.subplots(figsize=(10, 3))
    ax.axhline(0, color="k", linewidth=0.5)
    d_color = {"VV": "C1", "VH": "C0"}
    for band in ts.band.values:
        if band not in d_color:
            # ignore additional bands if any
            continue
        ts.sel(band=band).plot(x="date", color=d_color[band], label=band, ax=ax)

    invasion_date = UKRAINE_WAR_START
    if add_invasion_date and _date_within_ts(ts, invasion_date):
        ax.axvline(
            dt.date.fromisoformat(invasion_date),
            color="r",
            linestyle="--",
            label="date of invasion",
        )

    analysis_date = ts.date_of_analysis
    if add_analysis_date and _date_within_ts(ts, analysis_date):
        ax.axvline(
            dt.date.fromisoformat(analysis_date),
            color="g",
            linestyle="--",
            label="date of analysis",
        )

    # Shade whichever windows were supplied (previously exactly "pre" and "post"
    # were required, so e.g. a "post_neg"/"post_pos" comparison raised KeyError).
    for name, (start, end) in periods.items():
        ax.axvspan(
            dt.date.fromisoformat(start),
            dt.date.fromisoformat(end),
            color=period_colors.get(name, "0.5"),
            alpha=0.3,
        )

    ax.set_xlabel("Date")
    ax.set_ylabel("Backscatter (dB)")
    ax.set_xlim([dt.date.fromisoformat(xlim[0]), dt.date.fromisoformat(xlim[1])])
    ax.set_ylim(list(ylim))
    ax.grid(axis="x")
    if add_legend:
        ax.legend(loc=loc_legend)
    if title is None:
        title = f"{ts.aoi} - orbit {ts.orbit} - ID {ts.unosat_id} - {ts.extraction_strategy} - {preds}"
    # Apply the title in both cases; a caller-supplied title used to be ignored.
    ax.set_title(title)
    if show:
        plt.show()
    return fig

In [None]:
aoi = "UKR15"
orbit = 65
id_ = 15487
extraction_strategy = "3x3"
ts = read_ts(aoi, orbit, id_, extraction_strategy=extraction_strategy)
# NOTE(review): joblib.load unpickles the file and can execute arbitrary code —
# only load model files from a trusted source.
model = joblib.load('current_model.joblib')

# Windows for `predict_for_ts`. Use zero-padded ISO dates ("YYYY-MM-DD") throughout:
# `plot_ts_with_periods` parses window bounds with `date.fromisoformat`, which
# rejects non-padded dates such as "2020-4-01".
start_pre = "2020-04-01"
end_pre = "2021-11-01"
start_post = "2022-04-01"
end_post = "2022-11-01"
# Three consecutive one-year windows; presumably "post_neg" is a control window
# without war damage and "post_pos" the first year after the invasion — TODO confirm.
periods = {
    "pre": ('2020-02-24', '2021-02-23'),
    "post_neg": ('2021-02-24', '2022-02-23'),
    "post_pos": ('2022-02-24', '2023-02-23'),
}
#preds = predict_for_ts(ts, model, start_pre, end_pre, start_post, end_post)
plot_ts_with_periods(ts, periods, add_legend=False);

In [None]:
aoi = "UKR15"
orbit = 65
id_ = 15487
extraction_strategy = "3x3"
ts = read_ts(aoi, orbit, id_, extraction_strategy=extraction_strategy)
# NOTE(review): joblib.load unpickles the file — only load trusted model files.
model = joblib.load('current_model.joblib')

#preds = predict_for_ts(ts, model, start_pre, end_pre, start_post, end_post)
# Consecutive three-month "post" windows, each plotted against a fixed one-year "pre" window.
post_periods = [
    ('2021-02-24', '2021-05-23'),
    ('2021-05-24', '2021-08-23'),
    ('2021-08-24', '2021-11-23'),
    ('2021-11-24', '2022-02-23'),
    ('2022-02-24', '2022-05-23'),
    ('2022-05-24', '2022-08-23'),
    ('2022-08-24', '2022-11-23'),
    ('2022-11-24', '2023-02-23'),
    ('2023-02-24', '2023-05-23'),
]

from pathlib import Path

from PIL import Image

# savefig does not create missing directories.
Path("./plot_sliding").mkdir(parents=True, exist_ok=True)

frames = []
# Named `post_period` (not `tp`) so the true-positive count computed earlier is not overwritten.
for post_period in post_periods:
    periods = {
        "pre": ('2020-02-24', '2021-02-23'),
        "post": post_period,
    }
    fig = plot_ts_with_periods(ts, periods, add_legend=False, show=False)
    # save the figure
    fig.tight_layout()
    filename = f"./plot_sliding/ts_{post_period[0]}_{post_period[1]}.png"
    fig.savefig(filename, bbox_inches='tight')
    # Close each figure explicitly so the loop does not accumulate open figures.
    plt.close(fig)

    # Image.open is lazy and keeps the file open; copy the frame into memory and
    # release the handle right away.
    with Image.open(filename) as image:
        frames.append(image.copy())

# Save into a GIF
frames[0].save('./plot_sliding/output.gif', format='GIF',
               append_images=frames[1:],
               save_all=True,
               duration=300, loop=0)

In [None]:
import ipywidgets as widgets
from ipywidgets import interact

# Candidate month-start dates for the *post*-window sliders below (despite the
# "pre" in their names, these feed the post start/end sliders).
start_pre_dates = pd.date_range('2021-11-01', '2023-01-01', freq="MS").strftime("%Y-%m-%d").tolist()
end_pre_dates = pd.date_range('2022-03-01', '2023-04-01', freq="MS").strftime("%Y-%m-%d").tolist()

def sliding_predictions(ts, model):
    """Interactively predict and plot `ts` for user-selected pre/post windows.

    Shows four selection sliders; on each change `predict_for_ts` is run with the
    selected windows and the series is plotted with those windows shaded. The
    pre-window sliders currently offer a single fixed date each.

    NOTE(review): `sliding_predictions` is defined again in a later cell with a
    different signature; once that cell runs it shadows this definition.
    """

    def predict_and_plot(start_pre, end_pre, start_post, end_post):
        # ISO "YYYY-MM-DD" strings compare chronologically. The post sliders allow
        # an end before (or equal to) the start; skip the model for such windows
        # instead of feeding it an empty/inverted range.
        if end_pre <= start_pre or end_post <= start_post:
            print(
                f"Invalid window: each end date must be after its start date "
                f"(pre: {start_pre} -> {end_pre}, post: {start_post} -> {end_post})."
            )
            return None
        periods = {
            "pre": (start_pre, end_pre),
            "post": (start_post, end_post)
        }
        preds = predict_for_ts(ts, model, start_pre, end_pre, start_post, end_post)
        plot_ts_with_periods(ts, periods, preds[0])
        return preds

    pre_start = widgets.SelectionSlider(options=['2021-04-11'], description="Start Pre Date", continuous_update=True)
    pre_end = widgets.SelectionSlider(options=['2021-11-01'], description="End Pre Date", continuous_update=True)
    post_start = widgets.SelectionSlider(options=start_pre_dates, description="Start Post Date", continuous_update=False)
    post_end = widgets.SelectionSlider(options=end_pre_dates, description="End Post Date", continuous_update=False)
    interact(predict_and_plot, start_pre=pre_start, end_pre=pre_end, start_post=post_start, end_post=post_end)

In [None]:
# Launch the interactive pre/post window explorer for the `ts` and `model` loaded in the cells above.
sliding_predictions(ts, model)

In [None]:
import ipywidgets as widgets
from ipywidgets import interact

def sliding_predictions(ts, model, earliest_date="2020-06-01", latest_date="2022-06-01"):
    """Interactive view of monthly sliding-window predictions for one time series.

    Runs `predict_for_each_date` for every month start between `earliest_date` and
    `latest_date`, then shows a slider over those start dates. For the selected
    start date, the top panel shows the series with the window of the next 32
    acquisitions shaded. The bottom panel shows the prediction curve vs. the labels.

    NOTE(review): this redefines `sliding_predictions` from an earlier cell (with a
    different signature). It relies on `predict_for_each_date`, `plot_ts`,
    `_plot_label`, `SLIDING_COLOR` and `PREDS_COLOR`, none of which are defined or
    imported in the cells shown — import them before running this cell.

    Args:
        ts: time series with a `date` coordinate (presumably xarray — TODO confirm).
        model: model passed through to `predict_for_each_date`.
        earliest_date, latest_date: ISO "YYYY-MM-DD" bounds of the window start dates.
    """
    # For each model, predict for each month
    start_dates = pd.date_range(earliest_date, latest_date, freq="MS").strftime("%Y-%m-%d").tolist()
    df_results = predict_for_each_date(ts, model, start_dates)

    def analyze_result(start_date):
        # The first 32 acquisitions on or after the selected start date.
        ts_ = ts.sel(date=slice(start_date, None)).isel(date=slice(0, 32))
        # Convert to `date` objects like every other span/line on these date axes;
        # raw strings are not reliably converted by matplotlib's date axis.
        first_date = dt.date.fromisoformat(ts_.date[0].dt.strftime("%Y-%m-%d").item())
        last_date = dt.date.fromisoformat(ts_.date[-1].dt.strftime("%Y-%m-%d").item())

        _, axs = plt.subplots(2, 1, figsize=(10, 10))

        # Plot TS with sliding window
        plot_ts(ts, axs[0], loc_legend="lower left")
        # Extend past the last window start by 32 * 12 days (presumably 32 acquisitions
        # at a ~12-day revisit — TODO confirm).
        axs[0].set_xlim(
            [dt.date.fromisoformat(earliest_date), dt.date.fromisoformat(latest_date) + dt.timedelta(days=32 * 12)]
        )
        axs[0].axvspan(first_date, last_date, color=SLIDING_COLOR, alpha=0.3)

        # Plot predictions for each model vs true labels
        _plot_label(df_results, ax=axs[1], color="k")

        axs[1].plot(df_results.index, df_results.preds, label="predictions", color=PREDS_COLOR)
        axs[1].axvline(dt.date.fromisoformat(start_date), color=SLIDING_COLOR, linewidth=3)
        axs[1].axhline(0.5, color="k", linestyle="--")
        axs[1].set_title("Predictions vs Labels")
        axs[1].set_ylabel("Probability of destruction")
        axs[1].set_xlabel("Start Month")
        axs[1].legend(loc="upper left")
        plt.tight_layout()

    date_slider = widgets.SelectionSlider(options=start_dates, description="Start Date", continuous_update=True)
    interact(analyze_result, start_date=date_slider)