# Compute connectivity matrix (one file)

Processes a single file (`05m_00-07days`) via OPeNDAP, fetching 75 `hex0` rows per request (stride-75 chunking).
Computes normalized relative dilution factors for each source hex (`hex0`) → target hex (`hex1`) pair:

$$F = \frac{\text{obs\_sum}}{N_{\text{hex0\_sum}} \cdot DT_h \cdot n_{\text{months\_years}}} \cdot \frac{wf_{\text{hex0}}}{wf_{\text{hex1}}}$$

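Here $\text{obs\_sum}$ is `obs` summed over all months and years, $N_{\text{hex0\_sum}}$ is that sum totalled over every target of the source hex (escape hex included), $DT_h = 168$ h is the window length, and $wf_{\text{hex0}}$, $wf_{\text{hex1}}$ are the water fractions of the source and target hexes. The escape hex is never written out as a target.

Produces `database/data/connectivity_05m_00d-07d.pq` (i.e. `connectivity_{DEPTH}_{TIME_LABEL}.pq`).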

In [None]:
import json
import numpy as np
import pandas as pd
import xarray as xr
from pathlib import Path

# OPeNDAP root of the GEOMAR THREDDS submission holding the connectivity files.
BASE = (
    "https://data.geomar.de/thredds/dodsC/"
    "20.500.12085/11cc2d8f-4039-49d3-aaab-04ce0fb23190/submission"
)

# hex1 label of the escape hex; it is never written out as a target.
ESCAPE_HEX = b"(0, 0, 0)"
# Output directory; also where hex_label_to_id.json is read from.
OUT_DIR = Path("../../database/data")

# Single file for viability test
DEPTH, TIME = "05m", "00-07days"
TIME_LABEL = "00d-07d"  # value of the output `time` column and filename suffix
DT_H = 7 * 24  # hours in the 7-day window (168)
STRIDE = 75  # hex0 rows per OPeNDAP request (~100 MB/chunk)

def url(depth, time):
    """Return the OPeNDAP URL of the connectivity file for *depth* and *time*.

    Files live under ``{BASE}/040_connectivity_analysis_{depth}/`` and are named
    ``040_connectivity_analysis_{depth}_{time}.nc``.
    """
    folder = f"040_connectivity_analysis_{depth}"
    return f"{BASE}/{folder}/{folder}_{time}.nc"

# Load hex label → int ID mapping built in notebook 02.
# Keys are str like "(-1, -19, 20)". Read explicitly as UTF-8 (the JSON
# encoding) so the result does not depend on the platform default encoding.
label_to_id = json.loads(
    (OUT_DIR / "hex_label_to_id.json").read_text(encoding="utf-8")
)

print(f"Loaded {len(label_to_id)} hex IDs")

In [None]:
# Open the remote file lazily over OPeNDAP; data is only fetched on .values.
ds = xr.open_dataset(url(DEPTH, TIME), engine="netcdf4")
print(ds)

# Dimension sizes used for chunking and normalization.
sizes = ds.sizes
n_hex0, n_hex1 = sizes["hex0"], sizes["hex1"]
n_months, n_years = sizes["month"], sizes["year"]
n_months_years = n_months * n_years  # number of (month, year) samples summed
print(f"\nhex0={n_hex0}, hex1={n_hex1}, month={n_months}, year={n_years}")
print(f"n_months_years={n_months_years}, STRIDE={STRIDE}, chunks={-(-n_hex0 // STRIDE)}")

In [None]:
# Per-hex labels and water fractions needed for the normalization.
hex0_labels = ds["hex0"].values          # labels arrive as byte strings
hex1_labels = ds["hex1"].values
wf_hex0 = ds["water_fraction_hex0"].values   # one value per hex0
wf_hex1 = ds["water_fraction_hex1"].values   # one value per hex1

# Flag the escape hex among the hex1 targets.
escape_mask_hex1 = hex1_labels == ESCAPE_HEX
print(f"Escape hex in hex1: {escape_mask_hex1.sum()} occurrence(s)")

# Decode bytes → str so labels can be looked up in the JSON mapping.
hex0_str = np.array([label.decode() for label in hex0_labels])
hex1_str = np.array([label.decode() for label in hex1_labels])

# Integer IDs for every hex1 label; -1 marks labels absent from the mapping.
valid_hex1_mask = ~escape_mask_hex1
hex1_ids = np.array([label_to_id.get(label, -1) for label in hex1_str])
ids_valid = hex1_ids[valid_hex1_mask]
print(f"hex1 IDs: min={ids_valid.min()}, max={ids_valid.max()}, missing={(ids_valid == -1).sum()}")

In [None]:
# Process in stride-75 hex0 chunks, accumulate sparse records
import time as time_mod
from tqdm.auto import tqdm

records = []  # one DataFrame per source hex that contributes any weight
n_chunks = int(np.ceil(n_hex0 / STRIDE))

# The reduction below sums axes 0 and 1 and then indexes (hex0, hex1); fail
# loudly instead of silently mis-summing if the variable layout differs.
obs_dims = ds["obs"].dims
if set(obs_dims[:2]) != {"month", "year"} or obs_dims[2:] != ("hex0", "hex1"):
    raise ValueError(
        f"unexpected obs dims {obs_dims}; expected (month, year, hex0, hex1)"
    )

# Targets that may receive a weight regardless of the source:
#  - not the escape hex,
#  - present in label_to_id (hex1_ids is -1 otherwise, which used to leak
#    into end_id as "-1"),
#  - positive water fraction, so wf0 / wf1 is finite (NaN > 0 is False).
eligible_hex1 = valid_hex1_mask & (hex1_ids != -1) & (wf_hex1 > 0)

try:
    progress = tqdm(range(n_chunks), desc="hex0 chunks")
    for chunk_idx in progress:
        i0 = chunk_idx * STRIDE
        i1 = min(i0 + STRIDE, n_hex0)

        # Load obs for this hex0 slice: shape (month, year, slice, hex1).
        t0 = time_mod.time()
        obs_chunk = ds["obs"].isel(hex0=slice(i0, i1)).values  # triggers OPeNDAP request
        progress.set_postfix(fetch_s=f"{time_mod.time() - t0:.1f}")

        # Sum over month and year (NaNs treated as zero) → (slice, hex1).
        obs_sum = np.nansum(obs_chunk, axis=(0, 1))
        del obs_chunk  # release the raw chunk before the next request

        # N_hex0_sum: total particle-hours from each source (include escape hex)
        N_hex0_sum = obs_sum.sum(axis=1)  # (slice,)

        for local_i, global_i in enumerate(range(i0, i1)):
            n_total = N_hex0_sum[local_i]
            if n_total == 0:
                continue
            src_id = label_to_id.get(hex0_str[global_i], -1)
            if src_id == -1:
                continue  # source hex not in the ID mapping
            wf0 = wf_hex0[global_i]
            if wf0 == 0 or np.isnan(wf0):
                continue

            row = obs_sum[local_i]  # (hex1,)
            target_mask = eligible_hex1 & (row > 0)
            if not target_mask.any():
                continue

            tgt_obs = row[target_mask]
            tgt_wf1 = wf_hex1[target_mask]
            F = (tgt_obs / (n_total * DT_H * n_months_years)) * (wf0 / tgt_wf1)

            records.append(pd.DataFrame({
                "start_id": np.int64(src_id),
                # Same integer ID space as start_id (previously cast to str).
                "end_id": hex1_ids[target_mask].astype(np.int64),
                "weight": F,
            }))
finally:
    ds.close()  # close the remote dataset even if a request fails

print(f"\nDone. {len(records)} source hex blocks.")

In [None]:
# Concatenate the per-source blocks and tag them with depth / time window.
# pd.concat([]) only says "No objects to concatenate", so explain the cause.
if not records:
    raise RuntimeError(
        "No connectivity records were produced; check the hex ID mapping "
        "(label_to_id) and the obs data before writing output."
    )
conn = pd.concat(records, ignore_index=True)
conn["depth"] = DEPTH
conn["time"] = TIME_LABEL

# Reorder columns to match schema
conn = conn[["start_id", "end_id", "time", "depth", "weight"]]

print(conn.dtypes)
print(f"\nRows: {len(conn):,}")
print(f"weight range: {conn.weight.min():.3e} – {conn.weight.max():.3e}")
print(conn.head())

In [None]:
# Write one parquet file per (depth, time window), then report its size.
out_path = OUT_DIR / f"connectivity_{DEPTH}_{TIME_LABEL}.pq"
conn.to_parquet(out_path, index=False)
size_mb = out_path.stat().st_size / 1e6
print(f"Written: {out_path} ({size_mb:.1f} MB, {len(conn):,} rows)")