In [None]:
%load_ext autoreload
%autoreload 2

# Classification sliding window

In [None]:
from collections import Counter

import pandas as pd
from omegaconf import OmegaConf

from src.classification.model_factory import load_model
from src.classification.trainer import S1TSDD_Trainer
from src.classification.utils import compute_metrics
from src.constants import AOIS_TEST
from src.data import UNOSAT_S1TS_Dataset

In [None]:
def extract_features(df, start, end, prefix=""):
    """Summarise each row's time series over a date window into statistics.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows indexed by a MultiIndex that has a ``"band"`` level containing
        ``"VV"`` and ``"VH"``; columns are dates, sliced label-wise (inclusive)
        between ``start`` and ``end``.
    start, end
        Column labels bounding the window (e.g. ``"2021-02-24"``).
    prefix : str
        Tag inserted into output column names (e.g. ``"pre"`` gives
        ``"VV_pre_mean"``).

    Returns
    -------
    pandas.DataFrame
        Index = ``df``'s index with the ``"band"`` level dropped. Columns are
        ``VV_<prefix>_<stat>`` followed by ``VH_<prefix>_<stat>`` for
        stat in mean, std, median, min, max, skew, kurt.
    """
    window = df.loc[:, start:end]

    # One column per statistic, each computed across the selected dates.
    stat_names = ("mean", "std", "median", "min", "max", "skew", "kurt")
    stats = pd.DataFrame(index=window.index)
    for stat in stat_names:
        stats[stat] = getattr(window, stat)(axis=1)

    # Split by polarisation band and tag the column names with band + prefix.
    per_band = [
        stats.xs(band, level="band").add_prefix(f"{band}_{prefix}_")
        for band in ("VV", "VH")
    ]
    return pd.concat(per_band, axis=1)


cfg = OmegaConf.create(
    dict(
        aggregation_method="mean",
        model_name= "random_forest",
        model_kwargs=dict(
            n_estimators=200,
            min_samples_leaf=2,
            n_jobs=12,
        ),
        data=dict(
            aois_test = [f'UKR{i}' for i in range(1,19) if i not in [1,2,3,4]], #["UKR6", "UKR8", "UKR12", "UKR15"],
            damages_to_keep=[1,2],
            extract_winds = ['3x3'], # ['1x1', '3x3', '5x5']
            random_neg_labels=0.0,  # percentage of negative labels to add in training set (eg 0.1 for 10%)
            time_periods = {
                'pre': [('2020-02-24', '2021-02-23')],
                'post' : [
                    # ('2021-02-24', '2022-02-23'),
                    # ('2022-02-24', '2023-02-23')
                    ('2021-02-24', '2021-05-23'),
                    ('2021-05-24', '2021-08-23'),
                    ('2021-08-24', '2021-11-23'),
                    ('2021-11-24', '2022-02-23'),
                    ('2022-02-24', '2022-05-23'),
                    ('2022-05-24', '2022-08-23'),
                    ('2022-08-24', '2022-11-23'),
                    ('2022-11-24', '2023-02-23'),
                ]
            }
        ),
        seed=123,
        run_name=None,
    )
)

ds = UNOSAT_S1TS_Dataset(cfg.data, extract_features=extract_features)

In [None]:
# Build the model described by cfg and wrap it with the dataset in the trainer.
model = load_model(cfg)
trainer = S1TSDD_Trainer(
    ds,
    model,
    aggregation=cfg.aggregation_method,
    seed=cfg.seed,
    verbose=1,
)
# Presumably fits on the training split and reports test metrics at this threshold.
trainer.train_and_test(threshold_for_metrics=0.5)

In [None]:
# Test split; remove_unknown_labels=False presumably keeps rows without a known label.
# The first element of the returned pair is unused. It is bound to a named throwaway
# rather than `_`: in IPython `_` caches the last cell output, and assigning it by
# hand stops IPython from updating `_`, `__` and `___`.
_discarded, df_test = ds.get_datasets('test', remove_unknown_labels=False)

In [None]:
# Score every test row, then average the probabilities per building, date and window.
# Assumes predict_proba scores each row independently (true for scikit-learn
# forests), so one call gives the same values as one call per post_start group.
FEATURE_PREFIXES = ("VV", "VH")

# groupby("post_start") silently skipped rows with a missing post_start; keep that.
df_test_scored = df_test.dropna(subset=["post_start"])
feature_cols = [c for c in df_test_scored.columns if c.startswith(FEATURE_PREFIXES)]

df_preds = df_test_scored[["aoi", "unosat_id", "orbit", "date", "label"]].copy()
# Convert here so the post_start index level created below is already datetime
# (no need to patch index level 3 positionally afterwards).
df_preds["post_start"] = pd.to_datetime(df_test_scored["post_start"])
df_preds["preds_proba"] = trainer.model.predict_proba(
    df_test_scored[feature_cols].values
)[:, 1]

# One row per (aoi, unosat_id, date, post_start): mean probability over the rows of
# the group (e.g. different orbits); label taken from the group's first row.
df_agg = df_preds.groupby(["aoi", "unosat_id", "date", "post_start"]).agg(
    {"label": "first", "preds_proba": "mean"}
)
df_agg.head(8)

In [None]:
# Collapse df_agg to one row per building and period: windows starting before
# WAR_START vs. windows starting on/after it, taking the max label and max
# probability within each period, then score against the labels.
WAR_START = pd.Timestamp("2022-02-24")
THRESHOLD = 0.5

# Select by the named level instead of positional IndexSlice slicing.
post_starts = df_agg.index.get_level_values("post_start")
df_neg = df_agg[post_starts < WAR_START].groupby(["aoi", "unosat_id"]).max()
df_pos = df_agg[post_starts >= WAR_START].groupby(["aoi", "unosat_id"]).max()
df_agg_agg = pd.concat([df_neg, df_pos])

y_true = df_agg_agg.label
y_preds_proba = df_agg_agg.preds_proba

print(Counter(y_true))
print(Counter(y_preds_proba > THRESHOLD))
compute_metrics(y_true, y_preds_proba, threshold=THRESHOLD)

In [None]:
# Preview: first 8 rows of the aggregated per-window predictions.
df_agg.iloc[:8]

In [None]:
# Un-aggregated rows of df_agg whose post window starts on or before 2022-02-23,
# and the position of the highest predicted probability among them.
idx = pd.IndexSlice
pre_cutoff_rows = idx[:, :, :, :'2022-02-23']
df_neg = df_agg.loc[pre_cutoff_rows]
df_neg['preds_proba'].argmax()

In [None]:
# Show the pre-cutoff row with the highest predicted probability. Uses the argmax
# directly instead of the hard-coded position 3303 (presumably copied from the
# previous cell's output), which goes stale whenever df_agg changes.
df_neg.iloc[df_neg.preds_proba.argmax()]

In [None]:
from src.visualization.time_series import plot_all_ts_from_id

# Building to inspect, hard-coded. NOTE(review): presumably taken from the row
# shown by the previous cell — confirm before reusing.
inspect_aoi, inspect_id = 'UKR5', 20419
plot_all_ts_from_id(inspect_aoi, inspect_id)

In [None]:
from src.data import load_unosat_labels

# Esri World Imagery XYZ tile endpoint used as the basemap.
esri_world_imagery = 'https://services.arcgisonline.com/arcgis/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}.png'

# Interactive map of the UKR5 label geometries (assumes a GeoDataFrame, given .explore).
labels = load_unosat_labels('UKR5')
labels[['geometry']].explore(tiles=esri_world_imagery, attr='ESRI')

In [None]:
# Re-runs the per-building pre/post evaluation. Needed because an intermediate
# cell rebinds df_neg to un-aggregated rows; this restores the per-building
# df_neg / df_pos / df_agg_agg and re-reports the metrics.
WAR_START = pd.Timestamp("2022-02-24")
THRESHOLD = 0.5

# Select by the named level instead of positional IndexSlice slicing.
post_starts = df_agg.index.get_level_values("post_start")
df_neg = df_agg[post_starts < WAR_START].groupby(["aoi", "unosat_id"]).max()
df_pos = df_agg[post_starts >= WAR_START].groupby(["aoi", "unosat_id"]).max()
df_agg_agg = pd.concat([df_neg, df_pos])

y_true = df_agg_agg.label
y_preds_proba = df_agg_agg.preds_proba

print(Counter(y_true))
print(Counter(y_preds_proba > THRESHOLD))
compute_metrics(y_true, y_preds_proba, threshold=THRESHOLD)