This notebook illustrates how to apply the transformer metric to a single burst time series. 

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded automatically before code is executed.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rasterio.plot import show
from scipy.special import expit, logit

from distmetrics import (
    compute_mahalonobis_dist_2d,
    compute_transformer_zscore,
    despeckle_rtc_arrs_with_tv,
    get_asf_rtc_burst_ts,
    load_transformer_model,
    read_asf_rtc_image_data,
)

from distmetrics.asf_burst_search import get_pre_post_df_rtc_df

# Parameters

In [3]:
# Candidate events, keyed by a short name: (burst id, event timestamp).
EVENTS = {
    # Papua New Guinea landslide - https://www.cnn.com/2024/05/25/world/video/damage-papua-new-guinea-landslide-ldn-digvid
    "png_landslide": ("T009_019294_IW2", pd.Timestamp("2024-05-28", tz="utc")),
    # Los Angeles fire - Pacific Palisades burst
    "la_fire": ("T071-151228-IW3", pd.Timestamp("2025-01-08", tz="utc")),
}

# Event analysed by this run. Both events used to be assigned one after the
# other, so the second silently overwrote the first; select it explicitly instead.
EVENT_NAME = "la_fire"
BURST_ID, EVENT_TS = EVENTS[EVENT_NAME]

# Device handed to the model loader and to the metric computation.
DEVICE = "cpu"

# Passed as `n_pre_imgs` when selecting the pre-event products.
N_PRE_IMAGES = 10
# When True, images are passed through `logit` before the metric is computed.
APPLY_LOGIT = True

# NOTE(review): not referenced by any cell visible in this notebook - confirm it is still needed.
N_PRE_IMGS_PER_YEAR = 3

# Download Data

In [4]:
# Fetch the RTC time series table for the selected burst and report its size.
df_rtc_ts = get_asf_rtc_burst_ts(BURST_ID)
n_results = df_rtc_ts.shape[0]
print("# of results: ", n_results)
df_rtc_ts.head()

ASFSearchError: Connection Error (Timeout): CMR took too long to respond. Set asf constant "asf_search.constants.INTERNAL.CMR_TIMEOUT" to increase. (url='https://cmr.earthdata.nasa.gov/search/granules.umm_json', timeout=30)

In [None]:
# Narrow the time series to the products used around EVENT_TS.
pre_post_kwargs = dict(n_anniversaries=3, n_pre_imgs=N_PRE_IMAGES)
df_prod = get_pre_post_df_rtc_df(df_rtc_ts, EVENT_TS, **pre_post_kwargs)
df_prod.tail()

In [None]:
# Acquisition datetimes of the selected products.
df_prod["acq_datetime"].tolist()

**Warning**: sometimes this cell fails because of a transient server error and must be re-run.

In [None]:
# Load only the filtered images (much more memory efficient!)
# The co-pol rasters come from `url_copol`; the second return value is kept as `profiles`.
arrs_vv, profiles = read_asf_rtc_image_data(df_prod.url_copol)
# The cross-pol rasters come from `url_crosspol`; the second return value is discarded.
arrs_vh, _ = read_asf_rtc_image_data(df_prod.url_crosspol)

In [None]:
# Both polarizations are despeckled with the same settings.
tv_kwargs = {"n_jobs": 10, "interp_method": "bilinear"}
arrs_vv_d = despeckle_rtc_arrs_with_tv(arrs_vv, **tv_kwargs)
arrs_vh_d = despeckle_rtc_arrs_with_tv(arrs_vh, **tv_kwargs)

In [None]:
# Quick look at the last despeckled VH image in the series.
last_vh_d = arrs_vh_d[-1]
plt.imshow(last_vh_d, vmin=0, vmax=0.15)

In [None]:
# Display ALLOWED_MODELS - presumably the accepted `lib_model_token` values
# for `load_transformer_model`; confirm against distmetrics.
from distmetrics.model_load import ALLOWED_MODELS
ALLOWED_MODELS

In [None]:
# Load the "transformer_original" model on the configured device.
transformer = load_transformer_model(
    lib_model_token="transformer_original",
    device=DEVICE,
)

In [None]:
def apply_logit_func(arr):
    """Return ``logit`` of ``arr`` with NaN entries first replaced by ``1e-7``.

    The input array is left untouched; the substitution happens on a copy.
    """
    filled = arr.copy()
    nan_mask = np.isnan(filled)
    filled[nan_mask] = 1e-7
    return logit(filled)

# Every image except the last one is treated as pre-event.
pre_imgs_vv, pre_imgs_vh = arrs_vv_d[:-1], arrs_vh_d[:-1]
if APPLY_LOGIT:
    # The plain `logit` is applied here (NaNs are left as they are).
    pre_imgs_vv = [logit(img) for img in pre_imgs_vv]
    pre_imgs_vh = [logit(img) for img in pre_imgs_vh]

In [None]:
# The last image of each stack is the post-event acquisition.
post_vv, post_vh = arrs_vv_d[-1], arrs_vh_d[-1]
if APPLY_LOGIT:
    post_vv, post_vh = logit(post_vv), logit(post_vh)

In [None]:
# For the landslide burst, crop every image to a fixed window;
# for any other burst, use the full images unchanged.
if BURST_ID == "T009_019294_IW2":
    sy, sx = np.s_[1250:1500], np.s_[400:750]

    pre_vv_c = [img[sy, sx] for img in pre_imgs_vv]
    pre_vh_c = [img[sy, sx] for img in pre_imgs_vh]
    post_vv_c, post_vh_c = post_vv[sy, sx], post_vh[sy, sx]
else:
    pre_vv_c, pre_vh_c = list(pre_imgs_vv), list(pre_imgs_vh)
    post_vv_c, post_vh_c = post_vv, post_vh

In [None]:
# Show the last pre-event VV image as it will be passed to the metric.
plt.imshow(pre_vv_c[-1])
plt.colorbar()

In [None]:
# Number of pre-event VV images available for the metric.
len(pre_vv_c)

In [None]:
# Run the transformer metric: pre-event VV/VH stacks against the post-event VV/VH images.
# NOTE(review): the meaning of stride/agg/batch_size/memory_strategy is defined by
# distmetrics - confirm against its documentation before tuning.
dist_ob = compute_transformer_zscore(
    transformer,
    pre_vv_c[:],  # shallow copy of the pre-event VV list
    pre_vh_c[:],  # shallow copy of the pre-event VH list
    post_vv_c,
    post_vh_c,
    stride=16,
    agg="max",
    batch_size=512,
    memory_strategy="high",
    device=DEVICE,
    # tile_size=512
)

In [None]:
# Shorthand for the metric array stored on the result object.
dist = dist_ob.dist

In [None]:
# Metric map with the color scale capped at 5.
plt.imshow(dist_ob.dist, vmax=5)
plt.colorbar()

In [None]:
# Binary mask of the pixels whose metric exceeds 3.5.
exceeds_threshold = dist_ob.dist > 3.5
plt.imshow(exceeds_threshold, vmax=1, interpolation="none")
plt.colorbar()

**Warning**: the statistics returned by the function above are in `logit` units, NOT `gamma naught`. Hence the `expit` below!

In [None]:
# `expit` is already imported in the notebook's import cell, so it is not re-imported here.
# It maps the mean estimate from logit units back before plotting.
# Raw string: "\g" in the LaTeX "$\gamma$" is otherwise an invalid escape sequence.
plt.title(r"Mean Estimate VV ($\gamma$)")
plt.imshow(expit(dist_ob.mean[0, ...]))
plt.colorbar()

`expit` is nonlinear, so it can't really be applied to `sigma`; the standard deviation is therefore shown in logit units.

In [None]:
# Std estimate at index 0 of the first axis (VV, per the title), left in logit units
# and displayed on a fixed 0-1 color scale.
plt.title("Std Estimate logit(VV)")
plt.imshow(dist_ob.std[0, ...], vmax=1, vmin=0)
plt.colorbar()