# Compute connectivity matrix (one file)

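Processes a single connectivity file (default `05m_00-07days`; the parameters cell can override the depth and time window), read either from a local NetCDF file or via OPeNDAP using stride-75 chunking.
Computes normalized relative dilution factors: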

$$F = \frac{\text{obs\_sum}}{N_{\text{hex0\_sum}} \cdot DT_h \cdot n_{\text{months\_years}}} \cdot \frac{wf_{\text{hex0}}}{wf_{\text{hex1}}}$$

Produces `database/data/connectivity_{DEPTH}_{TIME_LABEL}.pq` (for example `connectivity_15m_07d-14d.pq`).

In [1]:
import json
import numpy as np
import pandas as pd
import xarray as xr
from pathlib import Path

# Root of the OPeNDAP endpoint; url() appends the per-depth folder and file name
BASE = (
    "https://data.geomar.de/thredds/dodsC/"
    "20.500.12085/11cc2d8f-4039-49d3-aaab-04ce0fb23190/submission"
)

# Label of the "escape" hex; hex1 entries matching it are excluded from the targets
ESCAPE_HEX = b"(0, 0, 0)"
# Directory holding hex_label_to_id.json (input) and the parquet file written at the end
OUT_DIR = Path("../../database/data")

# Single file for viability test; the later "Parameters" cell reassigns these defaults
DEPTH = "05m"
TIME = "00-07days"
DT_H = 168  # hours in window
TIME_LABEL = "00d-07d"
STRIDE = 75  # hex0 rows per OPeNDAP request (~100 MB/chunk); ignored for local files

# True: read the local NetCDF file under LOCAL_INPUT_DIR; False: read via OPeNDAP
LOCAL_INPUT = True
LOCAL_INPUT_DIR = Path("../input_comp")

# Maps a TIME window name to the day suffix used in local file names (see local_path)
_TIME_TO_DAYS = {"00-07days": "07", "07-14days": "14", "07-28days": "28"}

def url(depth, time):
    """Return the OPeNDAP URL of the connectivity file for ``depth`` and ``time``.

    The file lives in a per-depth folder below ``BASE``; both the folder and
    the file name embed the depth, and the file name also embeds the time window.
    """
    parts = [
        BASE,
        f"040_connectivity_analysis_{depth}",
        f"040_connectivity_analysis_{depth}_{time}.nc",
    ]
    return "/".join(parts)

def local_path(depth, time):
    """Return the path of the local NetCDF file for ``depth`` and ``time``.

    ``time`` must be a key of ``_TIME_TO_DAYS``; an unknown window raises ``KeyError``.
    """
    days = _TIME_TO_DAYS[time]
    filename = f"{depth}_ds_conn_{days}.nc"
    return LOCAL_INPUT_DIR.joinpath(filename)

# Read the hex label → int ID lookup built in notebook 02
# (keys are str like "(-1, -19, 20)")
label_to_id = json.loads((OUT_DIR / "hex_label_to_id.json").read_text())

print(f"Loaded {len(label_to_id)} hex labels")
# Report which data source the run will use
if not LOCAL_INPUT:
    print(f"Using OPeNDAP: {url(DEPTH, TIME)}")
else:
    p = local_path(DEPTH, TIME)
    print(f"Using local file: {p} (exists={p.exists()})")

Loaded 8425 hex labels
Using local file: ../input_comp/05m_ds_conn_07.nc (exists=True)


In [2]:
# Parameters
# These assignments override the defaults set in the configuration cell, so
# every later cell processes this depth / time window instead.
# NOTE(review): looks like a papermill-style injected cell — confirm before hand-editing.
DEPTH = "15m"
TIME = "07-14days"
DT_H = 168
TIME_LABEL = "07d-14d"
LOCAL_INPUT = True


In [3]:
# Both sources are opened with the same backend; only the location differs
source = local_path(DEPTH, TIME) if LOCAL_INPUT else url(DEPTH, TIME)
ds = xr.open_dataset(source, engine="netcdf4")
print(ds)

# Axis lengths used for chunking and for the normalization denominator
n_hex0, n_hex1 = ds.sizes["hex0"], ds.sizes["hex1"]
n_months, n_years = ds.sizes["month"], ds.sizes["year"]
n_months_years = n_months * n_years
print(f"\nhex0={n_hex0}, hex1={n_hex1}, month={n_months}, year={n_years}")
print(f"n_months_years={n_months_years}, STRIDE={STRIDE}, chunks={int(np.ceil(n_hex0/STRIDE))}")

<xarray.Dataset> Size: 22GB
Dimensions:              (hex0: 8223, hex1: 8282, month: 5, year: 4, corner: 7)
Coordinates: (12/13)
  * hex0                 (hex0) <U47 2MB '(-1, -2, 3)' ... '(9, 9, -18)'
    hex_label            (hex0) <U14 460kB ...
    lat_hex0             (hex0) float64 66kB ...
    lat_hex0_corners     (corner, hex0) float64 460kB ...
    lon_hex0             (hex0) float64 66kB ...
    lon_hex0_corners     (corner, hex0) float64 460kB ...
    ...                   ...
    lat_hex1             (hex1) float64 66kB ...
    lat_hex1_corners     (corner, hex1) float64 464kB ...
    lon_hex1             (hex1) float64 66kB ...
    lon_hex1_corners     (corner, hex1) float64 464kB ...
  * month                (month) float64 40B 0.0 1.0 2.0 3.0 4.0
  * year                 (year) float64 32B 0.0 1.0 2.0 3.0
Dimensions without coordinates: corner
Data variables: (12/28)
    aqc_count_hex0       (hex0) float64 66kB ...
    aqc_count_hex1       (hex1) float64 66kB ...
    dep

In [4]:
# Load per-hex metadata needed for normalization
# Hex labels along each axis; they are converted to str later because their
# element type depends on how the file was opened.
hex0_labels = ds["hex0"].values          # byte strings (OPeNDAP) or str (local netCDF4)
hex1_labels = ds["hex1"].values
# Water fractions used in F: the source value multiplies, the target value divides
wf_hex0 = ds["water_fraction_hex0"].values   # shape (hex0,)
wf_hex1 = ds["water_fraction_hex1"].values   # shape (hex1,)

def _to_str(v):
    return v.decode() if isinstance(v, bytes) else str(v)

# Encode hex labels as str once; used for the escape comparison and the ID lookup
hex0_str = np.array([_to_str(b) for b in hex0_labels])
hex1_str = np.array([_to_str(b) for b in hex1_labels])

# Find escape hex index in hex1 (dtype=bool keeps `~` valid even for an empty axis)
escape_label = _to_str(ESCAPE_HEX)
escape_mask_hex1 = np.array([s == escape_label for s in hex1_str], dtype=bool)
print(f"Escape hex in hex1: {escape_mask_hex1.sum()} occurrence(s)")

# Integer IDs for hex1; labels absent from the mapping get the sentinel -1
hex1_ids = np.array([label_to_id.get(s, -1) for s in hex1_str])
missing_hex1_mask = ~escape_mask_hex1 & (hex1_ids == -1)

# Valid targets are non-escape hexes WITH a known ID. Unmapped labels are
# excluded here so the -1 sentinel can never be written out as an end_id
# (sources without an ID are skipped the same way during the computation).
valid_hex1_mask = ~escape_mask_hex1 & ~missing_hex1_mask

known_hex1_ids = hex1_ids[valid_hex1_mask]
if known_hex1_ids.size:
    print(f"hex1 IDs: min={known_hex1_ids.min()}, max={known_hex1_ids.max()}, missing={missing_hex1_mask.sum()}")
else:
    # min()/max() raise on an empty selection; report the situation instead
    print(f"hex1 IDs: none mapped, missing={missing_hex1_mask.sum()}")

Escape hex in hex1: 0 occurrence(s)
hex1 IDs: min=1, max=8424, missing=0


In [5]:
import time as time_mod


def _source_block(global_i, row, n_src_total):
    """Build the connectivity rows for one source hex.

    Parameters
    ----------
    global_i : int
        Index of the source hex along the ``hex0`` axis.
    row : np.ndarray
        Observation counts summed over month and year for this source,
        one value per ``hex1`` entry.
    n_src_total : float
        Sum of ``row`` over all ``hex1`` entries (the normalization total).

    Returns
    -------
    pd.DataFrame or None
        Columns ``start_id``, ``end_id``, ``weight``; ``None`` when the source
        has no observations, no ID, an unusable water fraction, or no usable
        target.
    """
    if n_src_total == 0:
        return None
    src_id = label_to_id.get(hex0_str[global_i], -1)
    if src_id == -1:
        return None
    wf0 = wf_hex0[global_i]
    if wf0 == 0 or np.isnan(wf0):
        return None
    target_mask = usable_hex1_mask & (row > 0)
    if not target_mask.any():
        return None
    # F = obs_sum / (N_hex0_sum * DT_h * n_months_years) * wf_hex0 / wf_hex1
    F_raw = (row[target_mask] / (n_src_total * DT_H * n_months_years)) * (wf0 / wf_hex1[target_mask])
    # Round to 6 significant digits: scale so the leading digit sits at 1e5, round, scale back
    exp = np.floor(np.log10(F_raw))
    F = np.round(F_raw * 10 ** (5 - exp)) / 10 ** (5 - exp)
    return pd.DataFrame({
        "start_id": np.int64(src_id),
        # NOTE(review): end_id is stored as str while start_id is int64 — confirm
        # this asymmetry is what the downstream schema expects.
        "end_id": hex1_ids[target_mask].astype(str),
        "weight": F,
    })


# The reductions below sum over axes (0, 1), which is only correct for this layout
expected_obs_dims = ("month", "year", "hex0", "hex1")
if ds["obs"].dims != expected_obs_dims:
    raise ValueError(f"Unexpected obs dims {ds['obs'].dims}; expected {expected_obs_dims}")

# Targets must be non-escape, have a known ID, and have a positive water fraction.
# A zero or NaN wf_hex1 would otherwise produce NaN weights (NaN > 0 is False,
# so NaN entries are dropped as well).
usable_hex1_mask = valid_hex1_mask & (hex1_ids != -1) & (wf_hex1 > 0)

records = []  # list of DataFrames, one per source hex that produced rows

if LOCAL_INPUT:
    # Load entire obs array at once — fast from local disk
    print("Loading obs from local file ...")
    t0 = time_mod.time()
    obs_all = ds["obs"].values  # shape (month, year, hex0, hex1)
    print(f"  loaded in {time_mod.time()-t0:.1f}s, shape={obs_all.shape}")
    obs_sum_all = np.nansum(obs_all, axis=(0, 1))  # (hex0, hex1)
    del obs_all  # only the reduced array is needed from here on; free the full array
    N_hex0_sum_all = obs_sum_all.sum(axis=1)       # (hex0,)

    for global_i in range(n_hex0):
        if global_i % 1000 == 0:
            print(f"  {global_i}/{n_hex0} hex0 ...")
        block = _source_block(global_i, obs_sum_all[global_i], N_hex0_sum_all[global_i])
        if block is not None:
            records.append(block)
else:
    # OPeNDAP: process in stride-75 hex0 chunks
    n_chunks = int(np.ceil(n_hex0 / STRIDE))
    for chunk_idx in range(n_chunks):
        i0 = chunk_idx * STRIDE
        i1 = min(i0 + STRIDE, n_hex0)
        if chunk_idx % 10 == 0:
            print(f"  chunk {chunk_idx}/{n_chunks} ...")
        obs_chunk = ds["obs"].isel(hex0=slice(i0, i1)).values
        obs_sum = np.nansum(obs_chunk, axis=(0, 1))
        N_hex0_sum = obs_sum.sum(axis=1)
        for local_i in range(i1 - i0):
            block = _source_block(i0 + local_i, obs_sum[local_i], N_hex0_sum[local_i])
            if block is not None:
                records.append(block)

ds.close()
print(f"\nDone. {len(records)} source hex blocks.")

Loading obs from local file ...


  loaded in 11.3s, shape=(5, 4, 8223, 8282)


  0/8223 hex0 ...
  1000/8223 hex0 ...
  2000/8223 hex0 ...


  3000/8223 hex0 ...
  4000/8223 hex0 ...


  5000/8223 hex0 ...
  6000/8223 hex0 ...
  7000/8223 hex0 ...


  8000/8223 hex0 ...

Done. 8204 source hex blocks.


In [6]:
# Stack the per-source blocks, tag every row with depth and time window,
# and put the columns in schema order
schema_columns = ["start_id", "end_id", "time", "depth", "weight"]
conn = (
    pd.concat(records, ignore_index=True)
    .assign(depth=DEPTH, time=TIME_LABEL)
)[schema_columns]

print(conn.dtypes)
print(f"\nRows: {len(conn):,}")
print(f"weight range: {conn.weight.min():.3e} – {conn.weight.max():.3e}")
print(conn.head())

start_id      int64
end_id       object
time         object
depth        object
weight      float64
dtype: object

Rows: 1,805,871


weight range: 2.610e-10 – 5.480e-04
   start_id end_id     time depth    weight
0         2      2  07d-14d   15m  0.000030
1         2   1028  07d-14d   15m  0.000126
2         2   1030  07d-14d   15m  0.000044
3         2   2022  07d-14d   15m  0.000137
4         2   3174  07d-14d   15m  0.000034


In [7]:
# Persist the table as parquet, named after the depth and time window, then report its size
out_path = OUT_DIR.joinpath(f"connectivity_{DEPTH}_{TIME_LABEL}.pq")
conn.to_parquet(out_path, index=False)
print(f"Written: {out_path} ({out_path.stat().st_size / 1e6:.1f} MB, {len(conn):,} rows)")

Written: ../../database/data/connectivity_15m_07d-14d.pq (16.6 MB, 1,805,871 rows)
