In [1]:
# ============================================================
# Generate visual-argument images for neural forecasting
# Creates:
#   - nn_front_visual_argument_4panel.png
#   - nn_front_template_motion.png
# Output folder:
#   ../../images/img16/
# ============================================================

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# All figures produced below are written into this folder. The path is
# relative to the current working directory; missing parents are created now.
OUTDIR = Path("..", "..", "images", "img16")
OUTDIR.mkdir(exist_ok=True, parents=True)

def make_field(nx=220, ny=140, x0=60, y0=70, angle_deg=20, width=10, noise=0.08, seed=0):
    """
    Build a synthetic 2-D field containing one straight 'front-like' transition.

    The result is the sum of three parts:
      * a smooth sinusoidal background evaluated on normalised [-1, 1]
        coordinates,
      * a tanh step of amplitude 0.9 across the line through pixel
        (column x0, row y0) oriented at ``angle_deg``; ``width`` sets the
        transition scale in pixels,
      * Gaussian noise with standard deviation ``noise``, drawn from a
        generator seeded with ``seed``.

    Returns an array of shape (ny, nx).
    """
    theta = np.deg2rad(angle_deg)
    generator = np.random.default_rng(seed)

    # Normalised coordinates drive the smooth background.
    u = np.linspace(-1, 1, nx)[None, :]
    v = np.linspace(-1, 1, ny)[:, None]
    background = (
        0.4 * np.sin(2.0 * u)
        + 0.25 * np.cos(1.5 * v)
        + 0.15 * np.sin(1.2 * (u + v))
    )

    # Pixel coordinates drive the front: signed perpendicular distance
    # (in pixels) from each cell to the oriented line through (x0, y0).
    cols = np.arange(nx)[None, :]
    rows = np.arange(ny)[:, None]
    signed_dist = (cols - x0) * np.sin(theta) - (rows - y0) * np.cos(theta)
    step = np.tanh(signed_dist / width)

    return background + 0.9 * step + noise * generator.normal(size=(ny, nx))

def make_template(size=19, angle_deg=20):
    """
    Simple 'front template': an oriented, zero-mean, unit-norm edge kernel.

    Each cell holds ``tanh(d / 2)``, where ``d`` is the signed perpendicular
    distance (in pixels) from the cell centre to a line through the kernel
    centre at ``angle_deg``. This uses the same ``x*sin(a) - y*cos(a)``
    convention as ``make_field``. The kernel is then made zero-mean and
    scaled to unit L2 norm.

    Parameters
    ----------
    size : int
        Side length in pixels (>= 1). For odd sizes the line passes through
        the middle pixel; for even sizes it passes between the two middle
        pixels. Grid spacing is one pixel in both cases.
    angle_deg : float
        Orientation of the edge, in degrees.

    Returns
    -------
    ndarray of shape (size, size)
        Zero-mean kernel with unit L2 norm (all zeros if the kernel is flat,
        e.g. ``size == 1``).

    Raises
    ------
    ValueError
        If ``size < 1``.
    """
    if size < 1:
        raise ValueError(f"size must be >= 1, got {size}")
    # Unit-spaced pixel-centre offsets from the kernel centre: integers
    # -(size//2)..size//2 for odd sizes, half-integers for even sizes.
    # (A linspace over [-(size//2), size//2] would stretch even-sized grids by
    # size/(size-1), changing the effective edge width.)
    offsets = np.arange(size) - (size - 1) / 2.0
    X, Y = np.meshgrid(offsets, offsets)
    a = np.deg2rad(angle_deg)
    # signed distance to line through center
    d = X*np.sin(a) - Y*np.cos(a)
    tmpl = np.tanh(d/2.0)
    # zero mean: correlating with a uniform field gives no response
    tmpl = tmpl - tmpl.mean()
    # unit L2 norm; the epsilon avoids division by zero for a flat kernel
    tmpl = tmpl / (np.sqrt((tmpl**2).sum()) + 1e-12)
    return tmpl

def corr2d(field, tmpl):
    """
    Simple valid correlation (no scipy): sliding dot product, then ReLU.

    ``out[j, i] = max(0, sum(field[j:j+sy, i:i+sx] * tmpl))`` for every
    offset at which the template lies entirely inside the field (no padding).

    Parameters
    ----------
    field : array_like, 2-D, shape (ny, nx)
    tmpl : array_like, 2-D, shape (sy, sx)
        Need not be square.

    Returns
    -------
    ndarray of float, shape (max(ny - sy + 1, 0), max(nx - sx + 1, 0))
        Non-negative responses. The result is empty when the template does
        not fit inside the field.
    """
    field = np.asarray(field, dtype=float)
    tmpl = np.asarray(tmpl, dtype=float)
    ny, nx = field.shape
    sy, sx = tmpl.shape  # use both template dims: rectangular kernels are valid
    hy, hx = ny - sy + 1, nx - sx + 1
    if hy <= 0 or hx <= 0:
        # No position where the template fits: there are no valid values.
        return np.zeros((max(hy, 0), max(hx, 0)))
    # NOTE: the field is not normalised locally; raw dot products are enough
    # for the visual argument.
    # Zero-copy view of every (sy, sx) window, shape (hy, hx, sy, sx);
    # einsum contracts each window with the template without materialising
    # the windows.
    windows = np.lib.stride_tricks.sliding_window_view(field, (sy, sx))
    out = np.einsum("jikl,kl->ji", windows, tmpl)
    # relu-like: keep only positive matches
    return np.maximum(out, 0)

def shift_map(M, dx, dy):
    """
    Shift a heatmap by integer (dx, dy) pixels, padding vacated cells with zeros.

    ``out[..., y + dy, x + dx] = M[..., y, x]`` wherever both indices are in
    range: positive ``dx`` moves content toward larger column indices and
    positive ``dy`` toward larger row indices. The shift acts on the last two
    axes, so a stack of maps with shape (..., ny, nx) is shifted map by map.
    A shift at least as large as the map returns all zeros.

    Parameters
    ----------
    M : array_like, shape (..., ny, nx)
    dx, dy : int
        Column and row shift in pixels.

    Returns
    -------
    ndarray with the same shape and dtype as ``M``.
    """
    M = np.asarray(M)
    out = np.zeros_like(M)
    ny, nx = M.shape[-2:]
    # Everything moves off-grid. Handle this explicitly: for |shift| > size the
    # slice arithmetic below would yield a negative stop on the source side
    # (e.g. slice(0, -5)) while the destination slice is empty.
    if abs(dx) >= nx or abs(dy) >= ny:
        return out
    dst_rows = slice(max(0, dy), ny + min(0, dy))
    dst_cols = slice(max(0, dx), nx + min(0, dx))
    src_rows = slice(max(0, -dy), ny - max(0, dy))
    src_cols = slice(max(0, -dx), nx - max(0, dx))
    out[..., dst_rows, dst_cols] = M[..., src_rows, src_cols]
    return out

# ------------------------------------------------------------
# Generate synthetic example
# ------------------------------------------------------------
# Geometry of the synthetic front, stated once and passed explicitly so the
# current field, the template and the next field cannot silently disagree.
FRONT_X0, FRONT_Y0 = 60, 70   # pixel (column, row) the front passes through at t
FRONT_ANGLE_DEG = 20          # orientation shared by the fields and the template

field_t = make_field(x0=FRONT_X0, y0=FRONT_Y0, angle_deg=FRONT_ANGLE_DEG, seed=3)
tmpl = make_template(angle_deg=FRONT_ANGLE_DEG)
heat = corr2d(field_t, tmpl)

# pad heat to field size for overlay: a valid correlation loses (size - 1)
# rows/columns, so each response is placed at the centre pixel of its window.
s = tmpl.shape[0]
sy, sx = tmpl.shape
heat_pad = np.zeros_like(field_t)
heat_pad[sy//2:sy//2 + heat.shape[0], sx//2:sx//2 + heat.shape[1]] = heat

# "motion" direction
dx, dy = 18, -6  # eastward & a bit northward
heat_shift = shift_map(heat_pad, dx=dx, dy=dy)

# create a plausible next field by shifting the front position
field_t1 = make_field(
    x0=FRONT_X0 + dx, y0=FRONT_Y0 + dy, angle_deg=FRONT_ANGLE_DEG, seed=4
)

# ------------------------------------------------------------
# Figure 1: 4-panel visual argument
# ------------------------------------------------------------
fig = plt.figure(figsize=(14, 8))
gs = fig.add_gridspec(2, 2, wspace=0.08, hspace=0.18)

# Field panels (2-4) share one colour scale so "now" and "next" are directly
# comparable; otherwise each imshow autoscales to its own min/max.
fig1_field_kw = dict(
    cmap="RdYlBu_r",
    vmin=min(field_t.min(), field_t1.min()),
    vmax=max(field_t.max(), field_t1.max()),
)
# Heat overlays: cells <= 0 (ReLU'd responses and zero padding) are masked and
# masked cells render transparent, so the semi-transparent magma layer marks
# matches instead of darkening the whole map. Both overlays share a 0..max
# scale so the shifted pattern keeps the same colours as the detected one.
fig1_heat_kw = dict(
    cmap="magma",
    alpha=0.55,
    vmin=0.0,
    vmax=max(heat_pad.max(), heat_shift.max()),
)

# (1) template: diverging colormap centred on zero (the template is zero-mean)
ax = fig.add_subplot(gs[0, 0])
tmpl_lim = np.abs(tmpl).max()
ax.imshow(tmpl, cmap="coolwarm", vmin=-tmpl_lim, vmax=tmpl_lim)
ax.set_title("1) Front template (\"stencil\")", fontsize=14)
ax.set_xticks([]); ax.set_yticks([])

# (2) detection heatmap
ax = fig.add_subplot(gs[0, 1])
ax.imshow(field_t, **fig1_field_kw)
ax.imshow(np.ma.masked_less_equal(heat_pad, 0), **fig1_heat_kw)
ax.set_title("2) Detection heatmap: where does it match?", fontsize=14)
ax.set_xticks([]); ax.set_yticks([])

# (3) motion direction / shift heatmap. The arrow starts at (60, 70), which is
# make_field's default front anchor (x0, y0), and spans the (dx, dy) shift.
ax = fig.add_subplot(gs[1, 0])
ax.imshow(field_t, alpha=0.85, **fig1_field_kw)
ax.imshow(np.ma.masked_less_equal(heat_shift, 0), **fig1_heat_kw)
ax.arrow(60, 70, dx, dy, width=1.2, head_width=7, head_length=10, color="white")
ax.set_title("3) Motion direction: shift the pattern", fontsize=14)
ax.set_xticks([]); ax.set_yticks([])

# (4) next map
ax = fig.add_subplot(gs[1, 1])
ax.imshow(field_t1, **fig1_field_kw)
ax.set_title("4) Next forecast map (pattern transported)", fontsize=14)
ax.set_xticks([]); ax.set_yticks([])

fig.suptitle("Neural forecasting intuition: template → heatmap → shift → next field", fontsize=16, y=0.98)

out1 = OUTDIR / "nn_front_visual_argument_4panel.png"
fig.savefig(out1, dpi=220, bbox_inches="tight")
plt.close(fig)

# ------------------------------------------------------------
# Figure 2: simpler horizontal diagram for Slide 20
# ------------------------------------------------------------
fig = plt.figure(figsize=(14, 3.6))
gs = fig.add_gridspec(1, 3, wspace=0.15)

# "State now" and "State later" share one colour scale, so the viewer compares
# the fields themselves rather than two independent autoscalings.
fig2_field_kw = dict(
    cmap="RdYlBu_r",
    vmin=min(field_t.min(), field_t1.min()),
    vmax=max(field_t.max(), field_t1.max()),
)

ax1 = fig.add_subplot(gs[0, 0])
ax1.imshow(field_t, **fig2_field_kw)
ax1.set_title("State now", fontsize=13)
ax1.set_xticks([]); ax1.set_yticks([])

ax2 = fig.add_subplot(gs[0, 1])
ax2.imshow(heat_pad, cmap="magma")
ax2.set_title("Pattern location", fontsize=13)
ax2.set_xticks([]); ax2.set_yticks([])

ax3 = fig.add_subplot(gs[0, 2])
ax3.imshow(field_t1, **fig2_field_kw)
ax3.set_title("State later", fontsize=13)
ax3.set_xticks([]); ax3.set_yticks([])

fig.suptitle("Learned rule: detect pattern + move it forward in time", fontsize=15, y=0.98)

out2 = OUTDIR / "nn_front_template_motion.png"
fig.savefig(out2, dpi=220, bbox_inches="tight")
plt.close(fig)

print("Saved:")
print(" -", out1)
print(" -", out2)


Saved:
 - ../../images/img16/nn_front_visual_argument_4panel.png
 - ../../images/img16/nn_front_template_motion.png
