In [1]:
!pip -q install -U pip
!pip -q install the-well torch numpy matplotlib
# Optional (only needed if you later switch to local HDF5 validation/repair)
!pip -q install h5py


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.5/1.8 MB[0m [31m14.4 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.8/1.8 MB[0m [31m27.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.10.0 which is incompatible.[0m[31m
[0m

In [2]:
import math
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from the_well.data import WellDataset

# Prefer the GPU when one is visible; everything below also runs on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("torch:", torch.__version__, "| device:", device)


torch: 2.9.0+cpu | device: cpu


In [None]:
# Which Well dataset / split to read and how the time window is cut.
DATASET = "turbulent_radiative_layer_2D"
SPLIT   = "train"

T_IN  = 32          # number of input steps per window
T_OUT = 32          # number of output steps per window
WINDOW_INDEX = 0    # which window of the split to inspect

# Read directly from the Hugging Face hub path; keep raw (un-normalised) values.
ds = WellDataset(
    well_dataset_name=DATASET,
    well_split_name=SPLIT,
    well_base_path="hf://datasets/polymathic-ai/",
    use_normalization=False,
    n_steps_input=T_IN,
    n_steps_output=T_OUT,
)

item = ds[WINDOW_INDEX]
print("input_fields:", item["input_fields"].shape)
print("output_fields:", item["output_fields"].shape)
print("metadata.field_names:", ds.metadata.field_names)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
def flatten_field_names(field_names_dict):
    """Concatenate every name list stored in ``field_names_dict`` into one list.

    Groups are visited in the dict's iteration order and names inside a group
    keep their order; the keys themselves are ignored.
    """
    return [name for names in field_names_dict.values() for name in names]

# Flat channel-name list; its indices are used below to pick channels out of the
# last axis of the field tensors (assumes the same ordering — TODO confirm).
metadata_groups = ds.metadata.field_names
field_names = flatten_field_names(metadata_groups)
print("Flattened field names:", field_names)

def find_velocity_indices(names):
    """Return the (x, y) indices of the velocity components within ``names``.

    Known naming pairs are tried in priority order and the first pair whose two
    names are both present wins. Otherwise the first two names containing
    "velocity" (case-insensitive) are used.

    Raises:
        RuntimeError: if no velocity pair is found; the message lists all names.
    """
    known_pairs = (
        ("velocity_x", "velocity_y"),
        ("u", "v"),
        ("vel_x", "vel_y"),
        ("vx", "vy"),
    )
    exact = next(
        (
            (names.index(x_name), names.index(y_name))
            for x_name, y_name in known_pairs
            if x_name in names and y_name in names
        ),
        None,
    )
    if exact is not None:
        return exact

    candidates = [idx for idx, name in enumerate(names) if "velocity" in name.lower()]
    if len(candidates) < 2:
        raise RuntimeError("No velocity channels found.\n" + "\n".join(names))
    return candidates[0], candidates[1]

ix, iy = find_velocity_indices(field_names)
print("Using velocity channels:", *(field_names[i] for i in (ix, iy)))


In [None]:
# One contiguous clip: the input frames followed by the output frames (time = dim 0).
fields = torch.cat(
    (item["input_fields"], item["output_fields"]),
    dim=0,
).to(torch.float32)  # [T,H,W,F]

# Keep only the two velocity components, in (x, y) channel order.
# NOTE(review): the advection code treats axis 1 as rows (y) and axis 2 as
# columns (x). If the dataset stores spatial axes as (x, y) instead, u and v are
# applied along swapped axes — confirm against the Well data layout.
vel = torch.stack((fields[..., ix], fields[..., iy]), dim=-1).to(device)  # [T,H,W,2]

T, H, W, _ = vel.shape
print("vel shape:", vel.shape)


In [None]:
def finite_diff_grad(A, periodic=True):
    """Central-difference gradient of a 2-D field, in grid-cell units.

    Axis 1 of ``A`` is treated as x (columns) and axis 0 as y (rows).

    Args:
        A: 2-D tensor [H, W].
        periodic: if True (default), the field is assumed to wrap around and
            edge cells use the opposite edge as neighbour (``torch.roll``).
            If False, edge cells use one-sided differences instead. This avoids
            the large spurious gradient a non-periodic field (e.g. a linear ramp)
            shows at the wrap-around seam.

    Returns:
        (dA_dx, dA_dy): tensors shaped like ``A``.
    """
    if periodic:
        dA_dx = 0.5 * (torch.roll(A, -1, 1) - torch.roll(A, +1, 1))
        dA_dy = 0.5 * (torch.roll(A, -1, 0) - torch.roll(A, +1, 0))
        return dA_dx, dA_dy

    def _one_sided(dim):
        # torch.gradient needs at least 2 samples along dim; a size-1 axis has
        # no variation, so its derivative is zero.
        if A.shape[dim] < 2:
            return torch.zeros_like(A)
        # Central differences inside, first-order one-sided differences at the edges.
        return torch.gradient(A, dim=dim)[0]

    return _one_sided(1), _one_sided(0)

def make_base_grid(H, W, device):
    """Integer pixel coordinates of an H x W grid as a float32 [H, W, 2] tensor.

    Channel 0 holds the column index (x) and channel 1 the row index (y).
    """
    rows = torch.arange(H, device=device).float().view(H, 1).expand(H, W)
    cols = torch.arange(W, device=device).float().view(1, W).expand(H, W)
    return torch.stack((cols, rows), dim=-1)

def pixels_to_gridnorm(xy, H, W):
    """Map pixel coordinates onto the [-1, 1] range expected by ``F.grid_sample``.

    Pixel 0 maps to -1 and pixel ``size - 1`` to +1 (the ``align_corners=True``
    convention). A size of 1 uses a divisor of 1 to avoid division by zero.

    Args:
        xy: tensor [..., 2] with x (column) in channel 0 and y (row) in channel 1.
        H, W: grid height and width in pixels.
    """
    extent = (max(W - 1, 1), max(H - 1, 1))
    return torch.stack(
        [2.0 * xy[..., axis] / extent[axis] - 1.0 for axis in (0, 1)],
        dim=-1,
    )

def advect_scalar(C, u, v, dt, base_xy, periodic=True):
    """One semi-Lagrangian advection step: C_{t+1}(x) = C_t(x - vel * dt).

    Each output pixel is traced back along the velocity to its departure point,
    and ``C`` is bilinearly sampled there with ``F.grid_sample``
    (``align_corners=True``, so integer pixel coordinates hit sample points exactly).

    Args:
        C: field to advect, shape [N, C, H, W] (the notebook uses [1, 1, H, W]).
        u, v: x- and y-velocity in pixels per unit time, shape [H, W].
        dt: time step.
        base_xy: [H, W, 2] pixel grid (x in channel 0, y in channel 1),
            e.g. from ``make_base_grid``.
        periodic: wrap departure points around the domain. When False, points
            outside the domain take the nearest border value.

    Returns:
        The advected field, same shape as ``C``.
    """
    H, W = u.shape
    back_x = base_xy[..., 0] - u * dt
    back_y = base_xy[..., 1] - v * dt

    if periodic:
        # Wrapped coordinates lie in [0, W) x [0, H). A point in (W-1, W) must
        # interpolate between the last column and column 0. Sampling the unpadded
        # field with "border" padding would clamp it to the last column, which
        # breaks periodicity and loses mass at the seam. Appending a circular copy
        # of row/column 0 gives grid_sample that missing neighbour.
        back_x = torch.remainder(back_x, W)
        back_y = torch.remainder(back_y, H)
        src = F.pad(C, (0, 1, 0, 1), mode="circular")
    else:
        src = C

    src_H, src_W = src.shape[-2], src.shape[-1]
    # Pixel -> [-1, 1] normalisation for align_corners=True (size-1 axes guarded).
    grid = torch.stack(
        (
            2.0 * back_x / max(src_W - 1, 1) - 1.0,
            2.0 * back_y / max(src_H - 1, 1) - 1.0,
        ),
        dim=-1,
    ).unsqueeze(0)  # [1, H, W, 2]
    if src.shape[0] != 1:
        # Same departure points for every batch entry.
        grid = grid.expand(src.shape[0], -1, -1, -1)
    return F.grid_sample(src, grid, mode="bilinear", padding_mode="border", align_corners=True)

def make_milk_blob(H, W, device, sigma=10.0):
    """Isotropic Gaussian blob centred in an H x W grid, divided by its peak.

    Returns a float32 tensor of shape [1, 1, H, W] on ``device``. The peak
    value is 1 unless every entry underflows, in which case the divisor is
    clamped at 1e-12.
    """
    centre_x = (W - 1) / 2.0
    centre_y = (H - 1) / 2.0
    dx2 = (torch.arange(W, device=device).float() - centre_x) ** 2  # [W]
    dy2 = (torch.arange(H, device=device).float() - centre_y) ** 2  # [H]
    r2 = dx2[None, :] + dy2[:, None]                                # [H, W]
    blob = torch.exp(-r2 / (2 * sigma * sigma))
    blob = blob / blob.max().clamp_min(1e-12)
    return blob[None, None]

def fidelity_corr(Ca, Cb):
    """Normalised overlap <Ca, Cb> / (|Ca| |Cb|) of two fields, as a Python float.

    This is the cosine similarity of the flattened fields. It is not
    mean-centred, so identical non-zero fields score 1.0. The denominator is
    clamped at 1e-12, so an all-zero field scores 0.0 instead of NaN.
    """
    flat_a = Ca.reshape(-1)
    flat_b = Cb.reshape(-1)
    scale = (torch.linalg.norm(flat_a) * torch.linalg.norm(flat_b)).clamp_min(1e-12)
    ratio = torch.dot(flat_a, flat_b) / scale
    return float(ratio.detach().cpu().item())

def build_A_theta(mode, H, W, device, alpha=1.0, kx=3, ky=0, field_scalar=None):
    """Build the [H, W] potential Aθ whose gradient sets the X-θ drift.

    Modes:
        "ramp":  alpha * x, with x going linearly from 0 to 1 across the width.
        "sine":  alpha * sin(2π (kx x + ky y)), with x, y spanning [0, 1]
                 (both endpoints included).
        "field": ``field_scalar`` shifted so its minimum is 0, divided by its
                 range (clamped at 1e-12), times alpha. Its own device/dtype are
                 kept and ``device`` is not used.

    Raises:
        ValueError: for "field" without ``field_scalar``, or an unknown mode.
    """
    if mode == "field":
        if field_scalar is None:
            raise ValueError("mode='field' needs field_scalar")
        shifted = field_scalar - field_scalar.min()
        return alpha * (shifted / (shifted.max() - shifted.min()).clamp_min(1e-12))

    if mode not in ("ramp", "sine"):
        raise ValueError("Unknown Aθ mode")

    xs = torch.linspace(0, 1, W, device=device)
    if mode == "ramp":
        return alpha * xs.unsqueeze(0).repeat(H, 1)

    ys = torch.linspace(0, 1, H, device=device)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    return torch.sin(2 * math.pi * (kx * xx + ky * yy)) * alpha


In [None]:
def run_echo(vel_seq, A_theta, N_wind, kappa, dt, reverse_mode="hidden", vel_scale=1.0, C0=None, noise_vel=0.0):
    """Loschmidt-echo experiment: advect a scalar forward, then reverse the flow.

    Two copies of ``C0`` are evolved with semi-Lagrangian advection:

    * control (``Cc``): forward through velocity frames 0 .. T//2 - 1, then back
      through the same frames in reverse order with negated velocity;
    * X-θ (``Cx``): the same, plus an extra drift ``kappa * grad(A_theta) * θ̇``.
      In the forward half, θ̇ is chosen so θ advances by exactly ``2π·N_wind``.
      In the backward half, θ̇ is ``-θ̇_fwd`` ("perfect") or 0 ("hidden", i.e.
      the θ drift is not undone).

    Args:
        vel_seq: [T, H, W, 2] velocity (channel 0 = x, channel 1 = y), in pixels
            per unit time.
        A_theta: [H, W] potential whose gradient sets the X-θ drift.
        N_wind: total winding number over the forward half.
        kappa: X-θ coupling strength.
        dt: time step of every advection step.
        reverse_mode: "perfect" or "hidden" (see above).
        vel_scale: multiplier applied to ``vel_seq``.
        C0: initial field, e.g. the [1, 1, H, W] output of ``make_milk_blob``.
            Required despite the ``None`` default.
        noise_vel: std of Gaussian noise added independently to each velocity
            component at every step. It draws from torch's global RNG, so seed
            that for reproducible runs.

    Returns:
        (F_control, F_xtheta, Cc, Cx): fidelities of the final fields against
        ``C0`` (via ``fidelity_corr``) and the final fields themselves.

    Raises:
        ValueError: for an unknown ``reverse_mode`` or a missing ``C0``.
    """
    # Validate up front so a bad argument fails before the whole forward pass runs.
    if reverse_mode not in ("perfect", "hidden"):
        raise ValueError("reverse_mode must be perfect|hidden")
    if C0 is None:
        raise ValueError("run_echo needs an initial field C0 (e.g. from make_milk_blob)")

    device = vel_seq.device
    T, H, W, _ = vel_seq.shape
    T_half = T // 2

    base_xy = make_base_grid(H, W, device=device)
    dA_dx, dA_dy = finite_diff_grad(A_theta)

    # The forward loop takes T_half steps of length dt, so dividing by
    # T_half * dt makes θ advance by exactly 2π·N_wind over the forward half.
    # Dividing by (T_half - 1) * dt would overshoot by T_half / (T_half - 1).
    theta_dot_fwd = (2.0 * math.pi * N_wind) / (max(1, T_half) * dt)
    theta_dot_bwd = -theta_dot_fwd if reverse_mode == "perfect" else 0.0

    def _advance(Cc, Cx, t, reverse, theta_dot):
        # Advance both fields one step with frame t (velocity negated when reverse).
        u = vel_seq[t, :, :, 0] * vel_scale
        v = vel_seq[t, :, :, 1] * vel_scale
        if reverse:
            u, v = -u, -v
        if noise_vel > 0:
            # Fresh noise every step, so the reversal is deliberately imperfect.
            u = u + torch.randn_like(u) * noise_vel
            v = v + torch.randn_like(v) * noise_vel

        Cc = advect_scalar(Cc, u, v, dt, base_xy, periodic=True)

        u_x = u + kappa * dA_dx * theta_dot
        v_x = v + kappa * dA_dy * theta_dot
        Cx = advect_scalar(Cx, u_x, v_x, dt, base_xy, periodic=True)
        return Cc, Cx

    Cc = C0.clone()
    Cx = C0.clone()

    # forward half
    for t in range(T_half):
        Cc, Cx = _advance(Cc, Cx, t, False, theta_dot_fwd)

    # backward half: same frames in reverse order, base flow negated
    for t in reversed(range(T_half)):
        Cc, Cx = _advance(Cc, Cx, t, True, theta_dot_bwd)

    F_control = fidelity_corr(Cc, C0)
    F_xtheta  = fidelity_corr(Cx, C0)
    return F_control, F_xtheta, Cc, Cx


In [None]:
# dt: if dataset gives time grids, use them; else 1.0
dt = 1.0
t_in = item.get("input_time_grid", None)
t_out = item.get("output_time_grid", None)
if t_in is not None and t_out is not None:
    t_full = torch.cat([t_in, t_out], dim=0).to(torch.float32)
    if len(t_full) >= 2:
        dt = float((t_full[1] - t_full[0]).abs().cpu().item()) or 1.0
print("dt =", dt)

# Milk blob
C0 = make_milk_blob(H, W, device=device, sigma=10.0)

# Choose Aθ
A_MODE = "ramp"   # ramp | sine | field
ALPHA  = 10.0
A_KX, A_KY = 3, 0

A_theta = build_A_theta(A_MODE, H, W, device=device, alpha=ALPHA, kx=A_KX, ky=A_KY)

# X-θ strength
KAPPA = 30.0

# Reverse mode: 'hidden' shows controlled irreversibility
REVERSE_MODE = "hidden"  # hidden | perfect

# Winding list
N_LIST = [-8, -4, -2, -1, 0, 1, 2, 4, 8]

# Flow scaling (if you want stronger mixing visually)
VEL_SCALE = 1.0

# Optional noise on velocity (simulating imperfect control)
NOISE_VEL = 0.0

print("Config:", dict(A_MODE=A_MODE, ALPHA=ALPHA, KAPPA=KAPPA, REVERSE_MODE=REVERSE_MODE, VEL_SCALE=VEL_SCALE))


In [None]:
# Sweep the winding number N. Record the control and X-θ echo fidelities for
# each run, and keep the final fields of the last run for plotting.
echo_kwargs = dict(
    vel_seq=vel,
    A_theta=A_theta,
    kappa=KAPPA,
    dt=dt,
    reverse_mode=REVERSE_MODE,
    vel_scale=VEL_SCALE,
    C0=C0,
    noise_vel=NOISE_VEL,
)
F_ctrl, F_xth = [], []
last = None

for N in N_LIST:
    fc, fx, Cc, Cx = run_echo(N_wind=N, **echo_kwargs)
    F_ctrl.append(fc)
    F_xth.append(fx)
    last = (N, Cc.detach().cpu(), Cx.detach().cpu())
    print(f"N={N:>3}  F_control={fc:.6f}  F_Xθ={fx:.6f}")

Ns = np.asarray(N_LIST, dtype=float)
F_ctrl = np.asarray(F_ctrl, dtype=float)
F_xth = np.asarray(F_xth, dtype=float)

fig, ax = plt.subplots()
ax.plot(Ns, F_ctrl, marker="o", label="Control (no drift)")
ax.plot(Ns, F_xth, marker="o", label=f"X-θ ({REVERSE_MODE})")
ax.set_xlabel("Winding number N")
ax.set_ylabel("Loschmidt echo fidelity (corr)")
ax.set_title(f"Milk-in-Coffee Echo on {DATASET} (stream)")
ax.grid(True)
ax.legend()
plt.show()


In [None]:
N_last, Cc_last, Cx_last = last

# Initial field next to the final control and X-θ fields of the last sweep run.
panels = [
    (C0.detach().cpu()[0, 0], "Initial Milk C0"),
    (Cc_last[0, 0], f"Final Control (N={N_last})"),
    (Cx_last[0, 0], f"Final X-θ (N={N_last})"),
]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, (image, title) in zip(axes, panels):
    im = ax.imshow(image, origin="upper")
    ax.set_title(title)
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

fig.tight_layout()
plt.show()


In [None]:
def run_mode(reverse_mode):
    """X-θ echo fidelity for every N in ``N_LIST`` under ``reverse_mode``.

    Reads the notebook-level ``vel``, ``A_theta``, ``C0``, ``dt`` and config
    constants. Returns a float numpy array aligned with ``N_LIST``.
    """
    fidelities = [
        run_echo(
            vel_seq=vel, A_theta=A_theta, N_wind=N, kappa=KAPPA, dt=dt,
            reverse_mode=reverse_mode, vel_scale=VEL_SCALE, C0=C0, noise_vel=NOISE_VEL,
        )[1]  # index 1 = F_xtheta
        for N in N_LIST
    ]
    return np.array(fidelities, dtype=float)

Fx_hidden  = run_mode("hidden")
Fx_perfect = run_mode("perfect")

# Compare the echo when θ's drift is left un-reversed vs. reversed with the flow.
fig, ax = plt.subplots()
for series, label in (
    (Fx_hidden, "X-θ hidden (θ not reversed)"),
    (Fx_perfect, "X-θ perfect (θ reversed)"),
):
    ax.plot(Ns, series, marker="o", label=label)
ax.set(
    xlabel="Winding number N",
    ylabel="Loschmidt echo fidelity (corr)",
    title="X-θ control test: hidden vs perfect reversal",
)
ax.grid(True)
ax.legend()
plt.show()
