# Imports, setup and paths

In [None]:
# Change the working directory to the parent of the current one.
%cd ..

In [None]:
# Load the autoreload extension and enable it in mode 2 (per the IPython
# docs, modules are reloaded before each execution), then select the ipympl
# backend for matplotlib figures.
%load_ext autoreload
%autoreload 2
%matplotlib ipympl

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pandas as pd
import yaml
import pandas as pd
import plotly.express as px
from collections import defaultdict
from pathlib import Path
from hydra import compose, initialize
from omegaconf import OmegaConf

In [None]:
# Context initialization: set the NEWSCRATCH environment variable, then
# compose the "preproc" Hydra configuration with the "paths=local" override.
os.environ['NEWSCRATCH'] = "/lustre/fsn1/projects/rech/xyw/ute68qj"
with initialize(version_base=None, config_path="../conf"):
    cfg = compose(config_name="preproc", overrides=["paths=local"])
# Keep a handle on the "paths" node, which is read below to locate the data.
paths = cfg['paths']

# Looking at the available sources

Let's first load the samples metadata file, which gives all of the metadata for every sample in the train, val and test sets:

In [None]:
# Directory of the preprocessed dataset, as given by the paths configuration.
dataset_dir = Path(paths['preprocessed_dataset'])
# Read train.csv from that directory; the "time" column is parsed as datetimes.
df = pd.read_csv(
    dataset_dir.joinpath('train.csv'),
    parse_dates=["time"],
)
df.head()

In [None]:
# Print a concise summary of the metadata frame.
df.info()

**Important**: by default, this notebook will use ALL of the preprocessed data, i.e. all sources.

## Distribution by source

In [None]:
# Histogram of the number of rows in df for each source name.
sns.histplot(data=df, x='source_name')
# Draw the x tick labels vertically.
plt.xticks(rotation=90)
plt.show()

While the ERA5 states and infrared observations are each considered as a source on its own, the passive microwave and radar observations are split into one source per satellite, sensor and swath.

## Distribution by satellite over time

In [None]:
# Subset of the dataset made up of pmw and radar data only: keep the rows
# whose source name mentions neither 'infrared' nor 'era5'.
source_names = df['source_name']
is_infrared = source_names.str.contains('infrared')
is_era5 = source_names.str.contains('era5')
pmw_df = df[~is_infrared & ~is_era5]

In [None]:
# Count the observations per calendar month, grouped by the fourth
# '_'-separated field of the source name.
# NOTE(review): the variable is named `sat`; confirm that this field is the
# satellite in the source naming convention.
month = pmw_df['time'].dt.to_period('M')
sat = pmw_df['source_name'].apply(lambda name: name.split('_')[3])
monthly_obs = (
    pmw_df.groupby([sat, month])['sid']
    .count()
    .rename('Monthly overpasses')
    .reset_index()
)
# Convert the monthly periods back to timestamps.
monthly_obs['time'] = monthly_obs['time'].dt.to_timestamp()

In [None]:
# One line per value of the 'source_name' column of monthly_obs,
# distinguished by both color and line style.
sns.lineplot(
    data=monthly_obs,
    x='time',
    y='Monthly overpasses',
    hue='source_name',
    style='source_name',
)
plt.xticks(rotation=90)
# Anchor the legend's upper-left corner just past the right edge of the axes.
plt.legend(loc="upper left", bbox_to_anchor=(1.04, 1))
plt.show()

Let's also look at the counts per sensor and satellite:

In [None]:
# Sensor/satellite pair of each pmw/radar observation: the fourth and fifth
# '_'-separated fields of the source name, joined back with '_'.
sensat = pmw_df.source_name.apply(lambda s: '_'.join(s.split('_')[3:5]))
# Group pmw_df rather than df: both keys (sensat, and month from the earlier
# cell) are derived from pmw_df. Grouping df only worked through implicit
# index alignment, which gives the excluded rows NaN keys that are dropped.
monthly_obs = pmw_df.groupby([sensat, month])['sid'].count().rename('Monthly overpasses').reset_index()
# Convert the monthly periods back to timestamps.
monthly_obs['time'] = monthly_obs.time.dt.to_timestamp()

In [None]:
# One line per value of the 'source_name' column of monthly_obs,
# distinguished by both color and line style.
sns.lineplot(
    data=monthly_obs,
    x='time',
    y='Monthly overpasses',
    hue='source_name',
    style='source_name',
)
plt.xticks(rotation=90)
# Anchor the legend's upper-left corner just past the right edge of the axes.
plt.legend(loc="upper left", bbox_to_anchor=(1.04, 1))
plt.show()

The observations from the sensor/satellite pair ```SSMIS_F19``` abruptly end in early 2016, while ```ATMS_NOAA20``` only started returning observations in late 2017.  
Let's take a clearer look at the time span of each sensor / satellite:

In [None]:
# Earliest and latest observation time of each satellite/sensor pair.
gpby = pmw_df.groupby(sensat)
start_dates, end_dates = gpby['time'].min(), gpby['time'].max()
# Both series are named 'time': the suffixes turn them into the columns
# 'time_start' and 'time_end'. The group key becomes a regular column.
timeline_df = pd.merge(
    start_dates,
    end_dates,
    left_index=True,
    right_index=True,
    suffixes=('_start', '_end'),
).reset_index()

In [None]:
# One horizontal bar per row of timeline_df, ordered by start time and
# labelled with the value of its 'source_name' column.
fig = px.timeline(
    data_frame=timeline_df.sort_values(by='time_start'),
    x_start="time_start",
    x_end="time_end",
    y="source_name",
    text="source_name",
    color_discrete_sequence=["tan"],
)
fig.show()

# From dataset to samples
We now need to define what a *sample* means. We can define a sample from a reference observation:  
Let $x_0 \in S_0$ be an observation from the source $S_0$; let $t_0$ be the time of that observation.  
We'll define the sample referenced by $x_0$ as:  
$$
x = \{x_k \in S_k;\quad x_k = \text{arg min}_{u_k\in S_k; t_{u_k} \leq t_0}(t_0 - t_{u_k})\}
$$
i.e. for each source, we'll use the observation taken no later than $t_0$ that is closest in time to $t_0$. Observations from different storms are never mixed.  
Since observations that are far away in time are less correlated, we'll introduce a maximum time delta between $t_0$ and $t_k$ for $x_k$ to be actually included in $x$:
$$
x_{\text{filtered}} = \{x_k \in x; \Delta t_k := t_0 - t_k \leq \Delta t_{max}\}
$$
A larger $\Delta t_{max}$ will allow more sources to appear in the sample.

## Isolating samples for which specific sources are available
During training and evaluation, we may need to use only the samples for which at least a given set of sources is available.

In [None]:
# Distinct storm ids and distinct source names found in the metadata.
unique_sids = pd.unique(df['sid'])
unique_sources = pd.unique(df['source_name'])

In [None]:
# Raw numpy views of the three metadata columns used by the availability checks.
source_name_col = df['source_name'].to_numpy()
sid_col = df['sid'].to_numpy()
time_col = df['time'].to_numpy()

# Boolean row masks, {source_name: source_mask} and {sid: sid_mask}.
# They are precomputed once because each of them is reused many times,
# and working on numpy arrays directly is much faster than going through pandas.
source_name_mask = {name: source_name_col == name for name in unique_sources}
sid_mask = {storm_id: sid_col == storm_id for storm_id in unique_sids}

Let's now compute the available sources for all samples in the training set:

In [None]:
def is_source_available(ref_obs, source_name, dt_max):
    """Tell whether a source is present in the sample defined by ref_obs.

    ref_obs is a row of the metadata df (the reference observation); its
    'time' and 'sid' entries are read. The source counts as present when
    at least one observation with the same sid and the given source_name
    has a time in [t0 - dt_max, t0], t0 being the time of ref_obs.

    Returns 1 if the source is present and 0 otherwise.
    """
    t0 = ref_obs['time']
    # Times of the observations sharing the reference sid and the requested source.
    candidate_times = time_col[sid_mask[ref_obs['sid']] & source_name_mask[source_name]]
    # Keep those that are not after t0 and not older than t0 - dt_max.
    in_window = (candidate_times >= t0 - dt_max) & (candidate_times <= t0)
    return int(in_window.any())

In [None]:
def sources_availability(dt_max):
    """Build the sample-by-source availability table for a time delta.

    dt_max is the maximum time delta, in hours. The returned dataframe
    has one row per row of df and one column per entry of unique_sources;
    a cell is 1 when that source is available for that sample and 0 otherwise.
    """
    max_delta = pd.Timedelta(hours=dt_max)
    # {source: availability flag of every sample}, evaluated row by row.
    flags_by_source = {
        source: df.apply(is_source_available, args=(source, max_delta), axis=1)
        for source in unique_sources
    }
    return pd.DataFrame(flags_by_source)

Let's check as an example the available sources for $\Delta t_{max}=24h$:

In [None]:
# Availability table for a maximum time delta of 24 hours.
avail_24h = sources_availability(dt_max=24)
avail_24h.head()

From this, we can deduce the number of available sources for each sample, regardless of which sources they are:

In [None]:
# Sum the 0/1 flags across the source columns: number of available sources per sample.
avail_counts_24h = avail_24h.sum(axis='columns')
avail_counts_24h.head()

## Statistical presence of the sources
We'll now study how frequently each source is present in the samples, depending on $\Delta t_{max}$:

In [None]:
def sources_availability_frequency(dt_max):
    """Compute how often each source is available across the samples.

    dt_max is the maximum time delta, in hours. Returns a
    defaultdict(int) mapping each entry of unique_sources to the mean of
    its 0/1 availability flag over all rows of df, i.e. the fraction of
    samples in which that source is available.
    """
    max_delta = pd.Timedelta(hours=dt_max)
    frequency = defaultdict(int)  # {source: frequency of presence}
    for source in unique_sources:
        flags = df.apply(is_source_available, args=(source, max_delta), axis=1)
        frequency[source] = flags.mean()
    return frequency

We can now plot the frequency at which a source is available in the samples, depending on $\Delta t_{max}$:  
(this can take a while to compute)

In [None]:
# Sweep dt_max over 1, 3, ..., 23 hours and collect, for each value, the
# availability frequency of every source (one record per dt_max value).
dt_max_values = np.arange(1, 25, 2)
freq_records = []
for dt_max in dt_max_values:
    print(f"Computing availability for dt_max={dt_max}h")
    avail = sources_availability_frequency(dt_max)
    freq_records.append({source: avail[source] for source in unique_sources})
# One column per source, one row per dt_max value, plus the dt_max column itself.
avail_freq = pd.DataFrame(freq_records)
avail_freq['dt_max'] = dt_max_values

In [None]:
# Display the frequency table (one column per source, plus dt_max).
avail_freq

In [None]:
# Reshape to long format: one row per (dt_max, source_name) pair.
avail_freq_long = avail_freq.melt(
    id_vars=['dt_max'],
    var_name='source_name',
    value_name='frequency',
)
# One line per source, distinguished by both color and line style.
sns.lineplot(data=avail_freq_long, x='dt_max', y='frequency', hue='source_name', style='source_name')
# Anchor the legend's upper-left corner just past the right edge of the axes.
plt.legend(loc="upper left", bbox_to_anchor=(1.04, 1))
plt.show()