# Multi-Scale Multi-Modal Data: Fine-Grained Correspondences

This notebook builds on the **multimodal synthetic data** framework (see `synthetic_multiscale_multimodal_data.ipynb` in the multiscale repo):

- **S1**: High-resolution driving signal (multiscale in frequency: multiple k, ω components).
- **S2**: Lower-resolution coupled signal, driven by S1 via Kuramoto-like phase coupling (strength K).

We extend it to **explicit multi-scale data** with **fine-grained correspondence** maps so that:
1. Multiple resolution levels (pyramid) are defined.
2. For each coarse cell we know exactly which fine-grid indices it corresponds to.
3. Data can be used to learn correspondences (and/or predict K).

See `feasibility_multiscale_multimodal.md` (same directory) for the design.

## 1. Imports and core generator (from multimodal framework)

In [None]:
import numpy as np
from scipy.integrate import solve_ivp
from tqdm import tqdm


def generate_base_signal(x, t, params):
    """Build the multiscale driving signal S1(x, t) as a sum of travelling sine waves.

    Each entry of ``params`` is a dict with keys 'A', 'k', 'omega', 'phi'; its
    phase on the (t, x) mesh is ``k * x - omega * t + phi``.

    Returns:
      signal: array of shape (len(t), len(x)), the amplitude-weighted sum of
        sin(phase) over all components.
      phases: per-component phase fields stacked along axis 0, i.e. shape
        (len(params), len(t), len(x)) when params is non-empty.
    """
    xx, tt = np.meshgrid(x, t)
    signal = np.zeros(xx.shape, dtype=float)
    phase_stack = []
    for comp in params:
        theta = comp['k'] * xx - comp['omega'] * tt + comp['phi']
        phase_stack.append(theta)
        signal += comp['A'] * np.sin(theta)
    return signal, np.array(phase_stack)


def kuramoto_ode(t_point, theta, t_grid, base_phases_interp, K, omega_prime):
    """Right-hand side of the Kuramoto-like phase ODE for the S2 oscillators.

    dθ_j/dt = ω'_j + K * mean_m sin(ψ_m(t) - θ_j), where ψ_m(t) is the m-th
    driving phase, obtained by calling ``base_phases_interp[m](t_point)``.

    Args:
      t_point: current time supplied by the ODE solver.
      theta: 1-D array of current S2 phases, one per oscillator.
      t_grid: not used here (callers pass the fine time grid through the
        solver's ``args``).
      base_phases_interp: sequence of callables mapping a time to a driving phase.
      K: coupling strength.
      omega_prime: natural frequencies of the S2 oscillators.

    Returns:
      Array of dθ/dt values, one per entry of ``theta``.
    """
    drive = np.array([phase_at(t_point) for phase_at in base_phases_interp])
    # Rows index S2 oscillators, columns index S1 driving components.
    phase_gap = drive[np.newaxis, :] - theta[:, np.newaxis]
    coupling = np.sin(phase_gap).mean(axis=1)
    return omega_prime + K * coupling


def generate_coupled_signal_two_resolutions(x, t, x2, t2, s1_params, s2_params, K):
    """
    Generate S1 on the fine grid (x, t) and S2 on the coarse grid (x2, t2).

    Each coarse spatial point x2[i] is driven by S1 at the nearest fine spatial
    point: the S1 component phases at that point are linearly interpolated in
    time and fed to ``kuramoto_ode``, which is integrated over [t[0], t[-1]]
    and sampled at t2. S2 is the amplitude-weighted sum of sin(coupled phases).

    Returns:
      S1: (len(t), len(x)) fine driving signal.
      S2: (len(t2), len(x2)) coarse coupled signal.
      base_phase_slice: S1 component phases at fine index len(x) // 2,
        shape (n_s1_components, len(t)).
      coupled_phase_slice: S2 phases at the coarse point nearest to that same
        location, shape (n_s2_components, len(t2)), so both slices describe
        one place and can be compared directly.

    Raises:
      ValueError: if x2 is empty (there is no coupled trajectory to return).
      RuntimeError: if the ODE solver fails at some coarse point.
    """
    if len(x2) == 0:
        raise ValueError("x2 must contain at least one spatial point")

    S1, base_phases = generate_base_signal(x, t, s1_params)
    S2 = np.zeros((len(t2), len(x2)), dtype=float)
    omega_prime = np.array([p['omega'] for p in s2_params])
    initial_phases_s2 = np.array([p['phi'] for p in s2_params])
    amplitudes_s2 = np.array([p['A'] for p in s2_params])

    # Visualization location: the fine-grid centre, and the coarse point
    # closest to it (previously the last coarse point was returned instead).
    x_viz = len(x) // 2
    i_viz = int(np.argmin(np.abs(np.asarray(x2) - x[x_viz])))
    coupled_viz = None

    for i in range(len(x2)):
        s1_x_idx = np.argmin(np.abs(x - x2[i]))
        base_phases_at_x = base_phases[:, :, s1_x_idx]
        # Bind each series through a default argument to avoid late binding.
        base_phases_interp = [
            lambda t_eval, phase_series=series: np.interp(t_eval, t, phase_series)
            for series in base_phases_at_x
        ]
        sol = solve_ivp(
            kuramoto_ode,
            [t[0], t[-1]],
            initial_phases_s2,
            t_eval=t2,
            args=(t, base_phases_interp, K, omega_prime),
            method='RK45',
        )
        if not sol.success:
            # A failed solve returns fewer columns than t2; fail loudly instead
            # of with a confusing broadcasting error below.
            raise RuntimeError(f"solve_ivp failed at coarse x index {i}: {sol.message}")
        coupled_phases = sol.y
        S2[:, i] = np.sum(amplitudes_s2[:, np.newaxis] * np.sin(coupled_phases), axis=0)
        if i == i_viz:
            coupled_viz = coupled_phases

    return S1, S2, base_phases[:, :, x_viz], coupled_viz


## 2. Multi-scale pyramid and fine-grained correspondence

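Define several resolution levels and **explicit maps** between them: for each coarse index, which fine indices (space and time) it corresponds to, and for each fine index, which coarse index contains it. Each coarsening ratio is applied to the previous level's point count, so ratios compound from level to level.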

In [None]:
def build_pyramid_grids(x_fine, t_fine, ratios):
    """
    Build a resolution pyramid of evenly spaced (x, t) grids by repeated downsampling.

    Level 0 is (x_fine, t_fine). Each (r_x, r_t) in ``ratios`` is one coarsening
    step applied to the *previous* level: its point counts are floor-divided by
    r_x and r_t. Ratios therefore compound; e.g. [(2, 2), (2, 2)] on 64 x-points
    gives 64 -> 32 -> 16. Every coarse axis spans the same range as the fine
    axis (first to last fine coordinate) with evenly spaced points. These
    points are not snapped to fine-grid coordinates.

    Building stops at the first step that would leave fewer than one point on
    either axis; later ratios are then ignored.

    Returns:
      List of (x, t) tuples, level 0 first.

    Raises:
      ValueError: if a ratio used before stopping is smaller than 1.
    """
    grids = [(x_fine, t_fine)]
    n_x, n_t = len(x_fine), len(t_fine)
    for r_x, r_t in ratios:
        if r_x < 1 or r_t < 1:
            raise ValueError(f"coarsening ratios must be >= 1, got {(r_x, r_t)}")
        n_x //= r_x
        n_t //= r_t
        if n_x < 1 or n_t < 1:
            break
        x_c = np.linspace(x_fine[0], x_fine[-1], n_x)
        t_c = np.linspace(t_fine[0], t_fine[-1], n_t)
        grids.append((x_c, t_c))
    return grids


def build_correspondence_fine_to_coarse(n_x_fine, n_t_fine, ratios):
    """
    For each pyramid level, give the fine-index range covered by every coarse index.

    Each (r_x, r_t) in ``ratios`` divides the previous level's point count, so
    block sizes compound: at level L a coarse cell spans
    R_x = r_x(1) * ... * r_x(L) fine x-indices and R_t = r_t(1) * ... * r_t(L)
    fine t-indices. When the fine count is not a multiple of R, the last coarse
    cell also absorbs the leftover fine indices. Every fine index therefore
    belongs to exactly one coarse cell.

    Returns (coarse_to_fine_x, coarse_to_fine_t), one list entry per level
    (level 0 = fine, where each index maps to itself):
    - coarse_to_fine_x[L][i_c] = slice of fine x-indices for coarse index i_c
    - coarse_to_fine_t[L][k_c] = slice of fine t-indices for coarse time k_c
    At level L there are n_x_L x n_t_L coarse points. Coarsening stops at the
    first step that would leave fewer than one point on either axis.

    Raises:
      ValueError: if a ratio used before stopping is smaller than 1.
    """

    def _blocks(n_fine, n_coarse, block):
        # Consecutive blocks of `block` fine indices; the last one runs to n_fine.
        return [
            slice(i_c * block, n_fine if i_c == n_coarse - 1 else (i_c + 1) * block)
            for i_c in range(n_coarse)
        ]

    n_x, n_t = n_x_fine, n_t_fine
    block_x = block_t = 1  # cumulative coarsening factor of the current level
    coarse_to_fine_x = [_blocks(n_x_fine, n_x, block_x)]
    coarse_to_fine_t = [_blocks(n_t_fine, n_t, block_t)]

    for r_x, r_t in ratios:
        if r_x < 1 or r_t < 1:
            raise ValueError(f"coarsening ratios must be >= 1, got {(r_x, r_t)}")
        n_x //= r_x
        n_t //= r_t
        if n_x < 1 or n_t < 1:
            break
        # Blocks are measured in *fine* indices, so use the cumulative factor,
        # not this step's ratio alone.
        block_x *= r_x
        block_t *= r_t
        coarse_to_fine_x.append(_blocks(n_x_fine, n_x, block_x))
        coarse_to_fine_t.append(_blocks(n_t_fine, n_t, block_t))

    return coarse_to_fine_x, coarse_to_fine_t


def build_fine_to_coarse_maps(n_x_fine, n_t_fine, ratios):
    """
    For each fine (i, k), return the coarse (i_c, k_c) that contains it at every level.

    fine_to_coarse_x[L][i_f] = i_c and fine_to_coarse_t[L][k_f] = k_c.
    Each (r_x, r_t) in ``ratios`` divides the previous level's point count, so
    at level L a coarse cell spans the *cumulative* factor R = r(1) * ... * r(L)
    fine indices. Leftover fine indices (when the fine count is not a multiple
    of R) are assigned to the last coarse cell, so every index is in range.
    Level 0 is the identity map. Coarsening stops at the first step that would
    leave fewer than one point on either axis.

    Returns:
      (f2c_x_per_level, f2c_t_per_level): lists (one per level) of int lists.

    Raises:
      ValueError: if a ratio used before stopping is smaller than 1.
    """
    n_x, n_t = n_x_fine, n_t_fine
    block_x = block_t = 1  # cumulative coarsening factor of the current level
    f2c_x_per_level = [list(range(n_x_fine))]  # level 0: identity
    f2c_t_per_level = [list(range(n_t_fine))]

    for r_x, r_t in ratios:
        if r_x < 1 or r_t < 1:
            raise ValueError(f"coarsening ratios must be >= 1, got {(r_x, r_t)}")
        n_x_c = n_x // r_x
        n_t_c = n_t // r_t
        if n_x_c < 1 or n_t_c < 1:
            break
        block_x *= r_x
        block_t *= r_t
        # Clip so remainder fine indices fall in the last coarse cell instead
        # of pointing one past the end.
        f2c_x_per_level.append([min(i_f // block_x, n_x_c - 1) for i_f in range(n_x_fine)])
        f2c_t_per_level.append([min(k_f // block_t, n_t_c - 1) for k_f in range(n_t_fine)])
        n_x, n_t = n_x_c, n_t_c

    return f2c_x_per_level, f2c_t_per_level


## 3. Generate S1 (fine) and S2 at multiple coarse levels

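S1 is generated once on the fine grid. For each coarse level we generate S2 using the same Kuramoto coupling. Each coarse spatial point is driven by S1 at the **center** fine index of its block (the lower of the two middle indices when the block has even length). S2 is not generated at level 0, so `S2_per_level[L - 1]` holds the S2 field for pyramid level `L`.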

In [None]:
def generate_multiscale_dataset(x_fine, t_fine, ratios, s1_params, s2_params, K):
    """
    Generate S1 on the fine grid and S2 on every coarse pyramid level, plus correspondence maps.

    Args:
      x_fine, t_fine: 1-D fine spatial / temporal grids.
      ratios: list of (r_x, r_t) per coarsening step (see build_pyramid_grids).
      s1_params: S1 components, dicts with 'A', 'k', 'omega', 'phi'.
      s2_params: S2 oscillators, dicts with 'A', 'omega', 'phi'.
      K: Kuramoto coupling strength.

    Returns:
      S1: (n_t_fine, n_x_fine)
      S2_per_level: list of arrays for pyramid levels 1, 2, ... (no S2 is
        generated at level 0), so S2_per_level[L - 1] lives on grids[L] and
        has shape (n_t_L, n_x_L).
      grids: list of (x, t) per level, level 0 = fine.
      correspondence: dict with 'coarse_to_fine_x', 'coarse_to_fine_t',
        'fine_to_coarse_x', 'fine_to_coarse_t', each indexed by pyramid level
        (including level 0).

    Raises:
      RuntimeError: if the ODE solver fails at some coarse point.
    """
    n_x_fine, n_t_fine = len(x_fine), len(t_fine)
    grids = build_pyramid_grids(x_fine, t_fine, ratios)
    c2f_x, c2f_t = build_correspondence_fine_to_coarse(n_x_fine, n_t_fine, ratios)
    f2c_x, f2c_t = build_fine_to_coarse_maps(n_x_fine, n_t_fine, ratios)

    # S1 on fine grid only
    S1, base_phases = generate_base_signal(x_fine, t_fine, s1_params)
    omega_prime = np.array([p['omega'] for p in s2_params])
    initial_phases_s2 = np.array([p['phi'] for p in s2_params])
    amplitudes_s2 = np.array([p['A'] for p in s2_params])

    S2_per_level = []
    # Level 0 is skipped: S2 starts at level 1, hence the index offset above.
    for level in range(1, len(grids)):
        x_c, t_c = grids[level]
        S2 = np.zeros((len(t_c), len(x_c)), dtype=float)
        for i_c in range(len(x_c)):
            # Drive each coarse column with S1 at the centre fine index of its
            # block (lower middle index for even-length blocks).
            sl_x = c2f_x[level][i_c]
            fine_x_idx = (sl_x.start + sl_x.stop - 1) // 2 if sl_x.stop > sl_x.start else sl_x.start
            base_phases_at_x = base_phases[:, :, fine_x_idx]
            # Bind each series through a default argument to avoid late binding.
            base_phases_interp = [
                lambda t_eval, phase_series=series: np.interp(t_eval, t_fine, phase_series)
                for series in base_phases_at_x
            ]
            sol = solve_ivp(
                kuramoto_ode,
                [t_fine[0], t_fine[-1]],
                initial_phases_s2,
                t_eval=t_c,
                args=(t_fine, base_phases_interp, K, omega_prime),
                method='RK45',
            )
            if not sol.success:
                # A failed solve returns fewer columns than t_c; report the
                # cause instead of a broadcasting error on assignment.
                raise RuntimeError(
                    f"solve_ivp failed at level {level}, coarse x index {i_c}: {sol.message}"
                )
            S2[:, i_c] = np.sum(amplitudes_s2[:, np.newaxis] * np.sin(sol.y), axis=0)
        S2_per_level.append(S2)

    correspondence = {
        'coarse_to_fine_x': c2f_x,
        'coarse_to_fine_t': c2f_t,
        'fine_to_coarse_x': f2c_x,
        'fine_to_coarse_t': f2c_t,
    }
    return S1, S2_per_level, grids, correspondence


## 4. Run and validate

In [None]:
# Small grids for a quick run: 64 fine points in x, 128 in t.
x_fine = np.linspace(0, 10, 64)
t_fine = np.linspace(0, 50, 128)
# Two halving steps -> level 1: 32x64, level 2: 16x32 (x by t).
ratios = [(2, 2), (2, 2)]

# S1: three travelling-wave components with wavenumbers 1, 5 and 15.
s1_params = [
    dict(A=1.0, k=1.0, omega=1.5, phi=0),
    dict(A=0.5, k=5.0, omega=3.0, phi=np.pi / 2),
    dict(A=0.2, k=15.0, omega=10.0, phi=np.pi),
]
# S2: two oscillators whose natural frequencies sit near S1's first two components.
s2_params = [
    dict(A=1.0, omega=1.4, phi=0.1),
    dict(A=0.8, omega=3.2, phi=0.2),
]
K = 2.5  # coupling strength

S1, S2_per_level, grids, correspondence = generate_multiscale_dataset(
    x_fine, t_fine, ratios, s1_params, s2_params, K=K
)

print("S1 shape:", S1.shape)
for L, S2 in enumerate(S2_per_level):
    print(f"S2 level {L+1} shape: {S2.shape}")
print("Grids:", [len(gx) for gx, _ in grids], "x", [len(gt) for _, gt in grids])
print("Correspondence: coarse_to_fine_x levels:", len(correspondence['coarse_to_fine_x']))
print(
    "Example: at level 2, coarse x index 0 -> fine x indices",
    correspondence['coarse_to_fine_x'][2][0],
)
print(
    "Example: at level 2, coarse t index 0 -> fine t indices",
    correspondence['coarse_to_fine_t'][2][0],
)


## 5. Save dataset with correspondence (for learning)

We save S1, S2 per level, K, grids, and the correspondence maps so that a model can use fine-grained alignment.

In [None]:
def save_multiscale_dataset(filepath, S1, S2_per_level, grids, correspondence, K):
    """
    Save a multi-scale dataset to a compressed ``.npz`` archive.

    Keys written:
      'S1', 'x_fine', 't_fine', 'K', 'num_levels' (= len(S2_per_level)).
      'S2_level{i}' for i = 0 .. len(S2_per_level) - 1. When S2_per_level comes
        from generate_multiscale_dataset (which skips level 0), 'S2_level{i}'
        lives on pyramid level i + 1. It then pairs with 'x_L{i+1}', 't_L{i+1}'
        and 'c2f_*_L{i+1}'.
      'x_L{l}', 't_L{l}': grid coordinates of every pyramid level l in ``grids``
        (l = 0 is the fine grid).
      'c2f_x_L{l}', 'c2f_t_L{l}': int arrays of shape (n_coarse_l, 2) holding the
        [start, stop) fine-index bounds of each coarse index at level l.
      'f2c_x_L{l}', 'f2c_t_L{l}': int arrays mapping each fine index to its coarse
        index at level l; written only if ``correspondence`` contains
        'fine_to_coarse_x' / 'fine_to_coarse_t'.

    Every level is stored under its own key, so the file loads without
    ``allow_pickle``.
    """

    def _bounds(slices):
        # (n, 2) [start, stop] table; reshape keeps two columns even when empty.
        return np.array([[s.start, s.stop] for s in slices], dtype=np.int64).reshape(-1, 2)

    out = {
        'S1': S1,
        'x_fine': grids[0][0],
        't_fine': grids[0][1],
        'K': np.array(K),
        'num_levels': len(S2_per_level),
    }
    for i, S2 in enumerate(S2_per_level):
        out[f'S2_level{i}'] = S2
    for lvl, (x_l, t_l) in enumerate(grids):
        out[f'x_L{lvl}'] = x_l
        out[f't_L{lvl}'] = t_l
    c2f_pairs = zip(correspondence['coarse_to_fine_x'], correspondence['coarse_to_fine_t'])
    for lvl, (slices_x, slices_t) in enumerate(c2f_pairs):
        out[f'c2f_x_L{lvl}'] = _bounds(slices_x)
        out[f'c2f_t_L{lvl}'] = _bounds(slices_t)
    # Optional: older correspondence dicts may lack the fine->coarse maps.
    f2c_pairs = zip(
        correspondence.get('fine_to_coarse_x', []),
        correspondence.get('fine_to_coarse_t', []),
    )
    for lvl, (map_x, map_t) in enumerate(f2c_pairs):
        out[f'f2c_x_L{lvl}'] = np.asarray(map_x, dtype=np.int64)
        out[f'f2c_t_L{lvl}'] = np.asarray(map_t, dtype=np.int64)
    np.savez_compressed(filepath, **out)
    print(f"Saved to {filepath}")


# Persist everything generated above into one compressed archive.
save_multiscale_dataset(
    filepath="multiscale_multimodal_sample.npz",
    S1=S1,
    S2_per_level=S2_per_level,
    grids=grids,
    correspondence=correspondence,
    K=K,
)


## 6. Quick visualization

Plot S1 (fine) and S2 at the coarsest level to confirm they are aligned in space/time.

In [None]:
import matplotlib.pyplot as plt


def _pixel_edge_extent(x, t):
    """Return an imshow ``extent`` [x_left, x_right, t_bottom, t_top] for sample centres x, t.

    ``imshow`` places the image's outer edges at ``extent``. Passing the
    first/last sample coordinates would shift pixel centres by up to half a
    cell, and by a different amount per resolution, hiding the alignment this
    plot is meant to check. Assumes evenly spaced coordinates (the grids here
    come from np.linspace).
    """

    def _edges(c):
        n = len(c)
        half = 0.5 * (c[-1] - c[0]) / (n - 1) if n > 1 else 0.5
        return c[0] - half, c[-1] + half

    return [*_edges(x), *_edges(t)]


# Shared axes so the fine and coarse panels are drawn in the same (x, t) frame.
fig, axs = plt.subplots(2, 1, figsize=(10, 6), sharex=True, sharey=True)
x0, t0 = grids[0][0], grids[0][1]
axs[0].imshow(S1, aspect='auto', origin='lower', extent=_pixel_edge_extent(x0, t0), cmap='viridis')
axs[0].set_title('S1 (fine)')
axs[0].set_xlabel('x')
axs[0].set_ylabel('t')

if S2_per_level:
    x_last, t_last = grids[-1][0], grids[-1][1]
    axs[1].imshow(
        S2_per_level[-1],
        aspect='auto',
        origin='lower',
        extent=_pixel_edge_extent(x_last, t_last),
        cmap='plasma',
    )
    axs[1].set_title('S2 (coarsest level)')
    axs[1].set_xlabel('x')
    axs[1].set_ylabel('t')
plt.tight_layout()
plt.show()
print("Feasibility check: multi-scale data with explicit correspondence generated successfully.")
