# Notebook for CML processing and testing

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

In [None]:
import numpy as np
import math
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
import tqdm
import xarray as xr
import pandas as pd
import contextily as cx

import benchmarks.processing.baseline as baseline
import benchmarks.processing.wet_antenna as wet_antenna
import benchmarks.processing.k_R_relation as KR
import benchmarks.processing.spatial.interpolator as interpolator

from utils.load import get_gauge_coordinate_mappings, load_raingauge_dataset
from scipy.stats import pearsonr, spearmanr
from utils.visualisation import *

In [None]:
# Directory prefix for the data files; "" means the current working directory.
data_path = ""
cmls = xr.open_dataset(os.path.join(data_path, "database/cml_data_processed_2025.nc"))
print(cmls.coords)

In [None]:
# One Dataset per link so that each link can be processed independently below.
cml_list = [cmls.sel(link_id=link_id) for link_id in cmls["link_id"]]
# Printing the whole list would dump the repr of every link; show a summary
# and one representative link instead.
print(f"{len(cml_list)} links")
cml_list[0]

In [None]:
# TSL_AVG, RSL_AVG and trsl of the first link, one panel each.
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(12, 5))
for panel, var_name in zip(ax, ('TSL_AVG', 'RSL_AVG', 'trsl')):
    cml_list[0][var_name].plot.line(x='time', ax=panel)

In [None]:
cml = cml_list[0].copy()
threshold = 0.8  # rolling-std threshold above which a sample is flagged wet

# Rolling standard deviation of TRSL over a centred 60-sample window. It is
# computed once and reused for both the wet mask and the diagnostic plot.
roll_std_dev = cml.trsl.rolling(time=60, center=True).std()
cml['wet_2'] = roll_std_dev > threshold

print(cml['wet_2'].values)

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(12, 5), sharex=True)

roll_std_dev.plot.line(x='time', ax=ax[0])
ax[0].axhline(threshold, color='k', linestyle='--')
cml.trsl.plot.line(x='time', ax=ax[1])

# Shift along the time dimension only. np.roll without an axis flattens the
# multi-station array, so one station's last sample leaked into the next
# station's first sample and the series wrapped around at its ends.
wet = cml.wet_2
# True at the last dry sample before a wet period
wet_start = wet.shift(time=-1, fill_value=False) & ~wet
# True at the first dry sample after a wet period
wet_end = wet.shift(time=1, fill_value=False) & ~wet

# Shade the wet periods of each station (blue = first, green = second).
# Starts and ends are paired in order. The centred rolling std is NaN at the
# series edges, so the edges are never wet and every start has a matching end.
for station_idx, color in zip(range(cml.sizes['station']), ['b', 'g']):
    start_idx = wet_start.isel(station=station_idx).values.nonzero()[0]
    end_idx = wet_end.isel(station=station_idx).values.nonzero()[0]
    for wet_start_i, wet_end_i in zip(start_idx, end_idx):
        ax[1].axvspan(cml.time.values[wet_start_i], cml.time.values[wet_end_i], color=color, alpha=0.1)

ax[1].set_title('')

In [None]:
# Constant baseline from the mean of the last 5 dry samples, plotted over TRSL
# (the colour cycle is reset so each baseline shares its TRSL line's colour).
cml['baseline'] = baseline.baseline_constant(
    trsl=cml.trsl,
    wet=cml.wet_2,
    n_average_last_dry=5,
)

fig, ax = plt.subplots(figsize=(12, 3))
cml.trsl.plot.line(x='time', alpha=0.5, ax=ax)
ax.set_prop_cycle(None)
cml.baseline.plot.line(x='time', ax=ax)
ax.set_prop_cycle(None)
ax.set_ylabel('TRSL')

In [None]:
# Wet-antenna attenuation via `waa_schleiss_2013`. It uses the same wet mask
# (`wet_2`) the baseline above was computed with, so the two stay consistent.
cml['waa'] = wet_antenna.waa_schleiss_2013(
    rsl=cml.trsl,
    baseline=cml.baseline,
    wet=cml.wet_2,
    waa_max=2.2,
    delta_t=1,
    tau=15,
)

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(12, 5), sharex=True)
first_station = cml.station[0]

# Rain-induced path attenuation: TRSL above the dry baseline, with and without
# subtracting the wet-antenna attenuation. Negative values are clipped to 0;
# NaNs are kept.
cml['A'] = (cml.trsl - cml.baseline - cml.waa).clip(min=0)
cml['A_no_waa_correct'] = (cml.trsl - cml.baseline).clip(min=0)

# Top panel: TRSL together with the baseline with and without WAA.
top = axs[0]
cml.trsl.sel(station=first_station).plot.line(x='time', ax=top, alpha=0.5, label='TRSL')
cml.baseline.sel(station=first_station).plot.line(x='time', ax=top, linestyle=':', label='baseline without WAA')
(cml.baseline + cml.waa).sel(station=first_station).plot.line(x='time', ax=top, label='baseline with WAA')
top.set_ylabel('TRSL (dB)')
top.set_title('')
top.legend()

# Bottom panel: resulting rain attenuation (same colour, linestyle tells them apart).
bottom = axs[1]
cml.A_no_waa_correct.sel(station=first_station).plot.line(x='time', ax=bottom, linestyle=':', label='without WAA')
bottom.set_prop_cycle(None)
cml.A.sel(station=first_station).plot.line(x='time', ax=bottom, label='with WAA')
bottom.set_ylabel('path attenuation\nfrom rain (dB)')
bottom.set_title('')
bottom.legend();


In [None]:
print(cml.frequency)
# The k-R relation takes a single polarization; cross-polarized ('HV') links
# are treated as vertical, as in the batch-processing loop.
pol = 'V' if cml.polarization == 'HV' else cml.polarization
cml['R'] = KR.calc_R_from_A(A=cml.A, L_km=float(cml.length), f_GHz=cml.frequency, pol=pol)

In [None]:
# TRSL (top) next to the derived rain rate (bottom) for the same link.
fig, axs = plt.subplots(2, 1, figsize=(12, 5), sharex=True)
for panel, field in zip(axs, ('trsl', 'R')):
    cml[field].plot.line(x='time', ax=panel)
axs[1].set_title('')

In [None]:
WET_WINDOW = 60          # rolling window length, in time steps
WET_STD_THRESHOLD = 0.8  # rolling-std threshold above which a sample is wet

# Full processing chain for every link; results are written into each Dataset
# of `cml_list` in place.
for cml in tqdm.tqdm(cml_list):
    # Wet/dry classification from the rolling std of TRSL. With skipna=False
    # a window containing any NaN yields NaN, which compares False (dry).
    cml['wet'] = cml.trsl.rolling(time=WET_WINDOW, center=True).std(skipna=False) > WET_STD_THRESHOLD

    # Share of wet samples over every dimension (all stations and all times),
    # so the value always lies in [0, 1]. Summing over all dimensions but
    # dividing by the number of time steps only would exceed 1 for links
    # with more than one station.
    cml['wet_fraction'] = cml.wet.mean()

    cml['baseline'] = baseline.baseline_constant(
        trsl=cml.trsl,
        wet=cml.wet,
        n_average_last_dry=5,
    )
    cml['waa'] = wet_antenna.waa_schleiss_2013(
        rsl=cml.trsl,
        baseline=cml.baseline,
        wet=cml.wet,
        waa_max=2.2,
        delta_t=1,
        tau=15,
    )

    # Note that we set A < 0 to 0 here, but it is not strictly required for
    # the next step, because calc_R_from_A sets all rainfall rates below
    # a certain threshold (default is 0.1) to 0. Some people might want to
    # keep A as it is to check later if there were negative numbers.
    cml['A'] = (cml.trsl - cml.baseline - cml.waa).clip(min=0)

    # The k-R relation takes a single polarization; cross-polarized ('HV')
    # links are treated as vertical.
    pol = 'V' if cml.polarization == 'HV' else cml.polarization
    cml['R'] = KR.calc_R_from_A(
        A=cml.A, L_km=float(cml.length), f_GHz=cml.frequency, pol=pol
    )

In [None]:
# Visual check of links that are classified wet most of the time.
for cml in cml_list:
    if not cml.wet_fraction > 0.8:
        continue
    cml.trsl.plot.line(x='time', figsize=(12, 2))
    plt.title(f'cml_id: {cml.link_id.values} wet_fraction: {cml.wet_fraction.values:0.2f}')
    plt.show()

In [None]:
# Rain-rate time series at 'Station A' for a few sample links (one figure each).
print(cml_list[0].sel(station='Station A').R)
for link_idx in (0, 12, 57):
    station_a_view = cml_list[link_idx].sel(station='Station A')
    station_a_view.R.plot(x='time', label='CML', color='C0', figsize=(12, 3))
  

In [None]:
# Re-assemble the processed links into one Dataset along `link_id`.
ds_cmls = xr.concat(objs=cml_list, dim="link_id")

In [None]:
# Rich (HTML) display of the combined dataset instead of the plain-text repr.
ds_cmls

In [None]:
# 15-minute mean of the rain rate; windows are labelled by their right edge.
cmls_R_15min = (
    ds_cmls.R
    .resample(time='15min', label='right')
    .mean()
    .to_dataset()
)
print(cmls_R_15min)

In [None]:
# Midpoint of each link, used as the location of its rain estimate.
cmls_R_15min['lat_center'] = 0.5 * (cmls_R_15min.site_a_latitude + cmls_R_15min.site_b_latitude)
cmls_R_15min['lon_center'] = 0.5 * (cmls_R_15min.site_a_longitude + cmls_R_15min.site_b_longitude)

In [None]:
# Inverse-distance-weighting interpolator (nnear / p presumably = number of
# neighbours / IDW power; exclude_nan presumably skips NaN inputs).
# NOTE(review): the x/y passed later are lon/lat in degrees, so a
# max_distance of 100000 presumably never triggers — confirm the expected unit.
idw_interpolator = interpolator.IdwKdtreeInterpolator(
    nnear=15, p=2, exclude_nan=True, max_distance=100_000
)

In [None]:
# Link-midpoint latitudes used as interpolation input.
cmls_R_15min.lat_center

In [None]:
# Sum over all windows of the first link's rain estimate, masked (NaN) when
# that link's wet_fraction is >= 0.3. The mask uses the same link's
# wet_fraction; the full `ds_cmls.wet_fraction` array would instead broadcast
# the single-link result over every link_id.
first_link_id = cmls['link_id'][0]
first_link_total = cmls_R_15min.R.sel(link_id=first_link_id).sum(dim='time')
print(first_link_total.where(ds_cmls.wet_fraction.sel(link_id=first_link_id) < 0.3))

In [None]:
def plot_cml_lines(ds_cml, ax, visualise_station=False):
  """Draw every link of ``ds_cml`` as a black line from site A to site B.

  Args:
    ds_cml: object exposing ``site_a_longitude``, ``site_b_longitude``,
      ``site_a_latitude`` and ``site_b_latitude`` (one entry per link).
      Previously this argument was ignored and the global ``ds_cmls`` was
      plotted instead.
    ax: matplotlib-like axes to draw on.
    visualise_station: if True, additionally mark both end sites with red dots.
  """
  ax.plot(
    [ds_cml.site_a_longitude, ds_cml.site_b_longitude],
    [ds_cml.site_a_latitude, ds_cml.site_b_latitude],
    'k',
    linewidth=1,
  )
  if visualise_station:
    ax.scatter(ds_cml.site_a_longitude, ds_cml.site_a_latitude, color='red', s=12)
    ax.scatter(ds_cml.site_b_longitude, ds_cml.site_b_latitude, color='red', s=12)

GRID_RES = 0.01  # grid spacing in degrees (lon/lat)
xcoords = np.arange(103.605, 104.05, GRID_RES)
# Descending latitudes so that row 0 of the grid is the northern edge.
ycoords = np.arange(1.145, 1.51, GRID_RES)[::-1]
xgrid, ygrid = np.meshgrid(xcoords, ycoords)

# `cmls_R_15min.R` holds 15-min mean rain rates, presumably in mm/h as
# returned by the k-R relation (TODO confirm). Summing the windows and
# multiplying by the window length in hours (15/60) turns the rates into a
# depth in mm, matching the colorbar label. Links whose wet_fraction is
# >= 0.3 are excluded.
rain_sum_mm = (
  cmls_R_15min.R.sel(station='Station A').sum(dim='time') * (15 / 60)
).where(ds_cmls.wet_fraction < 0.3)

R_grid = idw_interpolator(
  x=cmls_R_15min.lon_center,
  y=cmls_R_15min.lat_center,
  z=rain_sum_mm,
  xgrid=xgrid,
  ygrid=ygrid,
  resolution=GRID_RES,
)

fig, ax = plt.subplots(figsize=(8, 6))
pc = ax.pcolormesh(
  idw_interpolator.xgrid,
  idw_interpolator.ygrid,
  R_grid,
  shading='nearest',
  cmap='turbo',
)
cx.add_basemap(ax=ax, crs=4326, source=cx.providers.CartoDB.Voyager)

plot_cml_lines(cmls_R_15min, ax=ax)
fig.colorbar(pc, label="rainfall sum in mm")

In [None]:
# Number of distinct site-A latitudes (a proxy for distinct site-A locations).
unique_site_a_lats = set(cmls_R_15min.site_a_latitude.values.tolist())
print(len(unique_site_a_lats))

In [None]:
from utils.visualisation import visualise_singapore_outline, visualise_with_basemap

# Overview map: every link as a line with its two end sites as red dots.
fig, ax = plt.subplots()
plot_cml_lines(cmls_R_15min, ax=ax, visualise_station=True)
visualise_singapore_outline(ax=ax)
visualise_with_basemap(ax=ax)

In [None]:
# Number of links (length of the per-link site-A latitude array).
print(cmls_R_15min.site_a_latitude.shape[0])

In [None]:
START_IDX = 400  # index of the first 15-min window shown; 9 consecutive windows

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(10, 10))

bounds = [0.1, 0.2, 0.5, 1, 2, 4, 7, 10, 20]
norm = mpl.colors.BoundaryNorm(boundaries=bounds, ncolors=256, extend='both')
cmap = plt.get_cmap('turbo').copy()
cmap.set_under('w')  # values below the lowest bound are drawn white

# Station-A rates with links whose wet_fraction is >= 0.3 masked out
R_station_a = cmls_R_15min.R.sel(station='Station A').where(ds_cmls.wet_fraction < 0.3)

for i, axi in enumerate(ax.flat):
    time_idx = START_IDX + i
    R_grid = idw_interpolator(
        x=cmls_R_15min.lon_center,
        y=cmls_R_15min.lat_center,
        z=R_station_a.isel(time=time_idx),
        xgrid=xgrid,
        ygrid=ygrid,
        resolution=0.01,
    )
    pc = axi.pcolormesh(
        idw_interpolator.xgrid,
        idw_interpolator.ygrid,
        R_grid,
        shading='nearest',
        cmap=cmap,
        norm=norm,
        alpha=0.5
    )
    axi.set_title(str(cmls_R_15min.time.values[time_idx])[:19])
    visualise_with_basemap(axi)
    plot_cml_lines(cmls_R_15min, ax=axi)
    visualise_singapore_outline(ax=axi)

fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])
# The values are 15-min *mean rain rates* (resample().mean() of R), not
# 15-min depths, so the label states the rate.
cb = fig.colorbar(pc, cax=cbar_ax, label='15-min mean rain rate in mm/h')

In [None]:
# Grid coordinates stored by the interpolator after its last call.
for grid_axis in (idw_interpolator.xgrid, idw_interpolator.ygrid):
    print(grid_axis)

In [None]:
# Rain-gauge records and the station -> (lat, lon) lookup used for evaluation.
GAUGE_CSV = 'rainfall_data.csv'
raingauge_df = load_raingauge_dataset(GAUGE_CSV, N=0)
station_dict = get_gauge_coordinate_mappings()
print(raingauge_df.shape)

In [None]:
# Evaluation grid (the stray bare `print` expression that followed is removed).
print(xgrid)
print(ygrid)

In [None]:
GRID_RES = 0.01         # degrees; spacing used to build xcoords / ycoords
MAX_WET_FRACTION = 0.3  # same per-link filter as used for the rain maps

# x12 turns 5-min values into hourly rates; presumably the gauge data are
# 5-min accumulations in mm (TODO confirm), giving mm/h like the CML rates.
raingauge_df_5mins = raingauge_df.mul(12)
# Average all 5-min rates in each window (first() kept only one of them) and
# label windows by their right edge, like `cmls_R_15min` (label='right');
# otherwise gauge and CML windows are offset by one 15-min step.
raingauge_df_15mins = raingauge_df_5mins.resample("15min", label="right").mean()
station_dict = get_gauge_coordinate_mappings()
raingauge_choice = raingauge_df_15mins

cml_R_station_a = cmls_R_15min.R.sel(station='Station A').where(
    ds_cmls.wet_fraction < MAX_WET_FRACTION
)

global_gauge_values = []
global_cml_predictions = []

for timestamp, row_data in raingauge_choice.iterrows():
    # NaN gauge readings are ignored in the correlation
    observed = row_data.dropna()
    if observed.empty:
        continue
    try:
        cml_values = cml_R_station_a.sel(time=timestamp.to_datetime64())
    except KeyError:
        # no CML estimate for this timestamp
        continue

    R_grid = idw_interpolator(
        x=cmls_R_15min.lon_center,
        y=cmls_R_15min.lat_center,
        z=cml_values,
        xgrid=xgrid,
        ygrid=ygrid,
        resolution=GRID_RES,
    )
    n_rows, n_cols = np.shape(R_grid)

    for station, gauge_value in observed.items():
        lat, lon = station_dict[station]
        # Nearest grid point; ycoords is descending, so rows count southwards.
        x_idx = int(round((lon - xcoords[0]) / GRID_RES))
        y_idx = int(round((ycoords[0] - lat) / GRID_RES))
        # Skip gauges outside the grid instead of silently wrapping around
        # via negative indices.
        if not (0 <= y_idx < n_rows and 0 <= x_idx < n_cols):
            continue
        predicted = R_grid[y_idx][x_idx]
        if not np.isnan(predicted):
            global_gauge_values.append(gauge_value)
            global_cml_predictions.append(predicted)

global_cml_predictions = np.array(global_cml_predictions)
global_gauge_values = np.array(global_gauge_values)
assert global_gauge_values.size >= 2, "need at least two matched gauge/CML pairs"

pearson_r_global, pearson_p_global = pearsonr(global_gauge_values, global_cml_predictions)
spearman_r_global, spearman_p_global = spearmanr(global_gauge_values, global_cml_predictions)

# Scatter plot of predicted vs. observed values with the 1:1 line
fig, ax = plt.subplots()
ax.scatter(global_cml_predictions, global_gauge_values)
plot_bound = max(global_gauge_values.max(), global_cml_predictions.max())
ax.plot([0, plot_bound], [0, plot_bound], linestyle='--')
ax.set_xlabel("CML Predictions")
ax.set_ylabel("Gauge Values")

print(f"Pearson correlation: {pearson_r_global}")
print(f"Spearman correlation: {spearman_r_global}")


In [None]:
START_IDX = 400  # index of the first 15-min window shown; 9 consecutive windows

fig, ax = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(10, 10))

bounds = [0.1, 0.2, 0.5, 1, 2, 4, 7, 10, 20]
norm = mpl.colors.BoundaryNorm(boundaries=bounds, ncolors=256, extend='both')
cmap = plt.get_cmap('turbo').copy()
cmap.set_under('w')  # values below the lowest bound are drawn white

# Station-A rates with links whose wet_fraction is >= 0.3 masked out
R_station_a = cmls_R_15min.R.sel(station='Station A').where(ds_cmls.wet_fraction < 0.3)

for i, axi in enumerate(ax.flat):
    time_idx = START_IDX + i
    R_grid = idw_interpolator(
        x=cmls_R_15min.lon_center,
        y=cmls_R_15min.lat_center,
        z=R_station_a.isel(time=time_idx),
        xgrid=xgrid,
        ygrid=ygrid,
        resolution=0.01,
    )
    pc = axi.pcolormesh(
        idw_interpolator.xgrid,
        idw_interpolator.ygrid,
        R_grid,
        shading='nearest',
        cmap=cmap,
        norm=norm,
    )
    axi.set_title(str(cmls_R_15min.time.values[time_idx])[:19])
    plot_cml_lines(cmls_R_15min, ax=axi)

fig.subplots_adjust(right=0.9)
cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])
# The values are 15-min *mean rain rates* (resample().mean() of R), not
# 15-min depths, so the label states the rate.
cb = fig.colorbar(pc, cax=cbar_ax, label='15-min mean rain rate in mm/h')