In [1]:
# bptt_scalar_mxnet.py
# 一个最简单的一维 RNN，两步序列：
# 1）用纯 Python 手动推导 BPTT 梯度
# 2）用 MXNet autograd 求梯度并对比

import math

import mxnet as mx
from mxnet import nd, autograd


def forward_and_backward_manual(
    w_x: float = 1.0,
    w_h: float = 1.0,
    w_y: float = 1.0,
    x1: float = 1.0,
    x2: float = 1.0,
    y1_hat: float = 0.0,
    y2_hat: float = 0.0,
    h0: float = 0.0,
) -> dict:
    """Hand-derive the BPTT gradients of a scalar two-step RNN, printing each step.

    The model evaluated is, for t = 1, 2:

        a_t = w_x * x_t + w_h * h_{t-1}
        h_t = tanh(a_t)
        y_t = w_y * h_t
        l_t = 0.5 * (y_t - y_t_hat) ** 2

    with total loss L = l_1 + l_2.

    Args:
        w_x: Input-to-hidden weight.
        w_h: Hidden-to-hidden (recurrent) weight.
        w_y: Hidden-to-output weight.
        x1: Input at step 1.
        x2: Input at step 2.
        y1_hat: Target the output y1 is compared against.
        y2_hat: Target the output y2 is compared against.
        h0: Initial hidden state.

    Returns:
        A dict with keys "w_x", "w_h", "w_y" holding dL/dw_x, dL/dw_h and
        dL/dw_y, and key "L" holding the total loss.

    All defaults reproduce the constants that were previously hard-coded, so
    calling the function without arguments yields the same values and prints.
    """
    print("===== 手动计算 BPTT（标量两步 RNN）=====")

    # ========= Forward pass =========
    # t = 1
    a1 = w_x * x1 + w_h * h0
    h1 = math.tanh(a1)
    y1 = w_y * h1
    l1 = 0.5 * (y1 - y1_hat) ** 2

    # t = 2 (the recurrence feeds h1 back in through w_h)
    a2 = w_x * x2 + w_h * h1
    h2 = math.tanh(a2)
    y2 = w_y * h2
    l2 = 0.5 * (y2 - y2_hat) ** 2

    L = l1 + l2

    print(f"a1 = {a1:.4f}, h1 = tanh(a1) = {h1:.4f}, y1 = {y1:.4f}, l1 = {l1:.4f}")
    print(f"a2 = {a2:.4f}, h2 = tanh(a2) = {h2:.4f}, y2 = {y2:.4f}, l2 = {l2:.4f}")
    print(f"Total Loss L = {L:.4f}")
    print()

    # ========= Backward pass =========
    # Output-layer gradient: dl_t/dy_t = y_t - y_t_hat
    e1 = y1 - y1_hat
    e2 = y2 - y2_hat

    # Gradient w.r.t. each hidden state coming from that step's own loss
    dL_dh1_from_l1 = e1 * w_y
    dL_dh2 = e2 * w_y  # t = 2 is the last step, so nothing flows in from the future

    print("== 输出层梯度 ==")
    print(f"e1 = y1 - y1_hat = {e1:.4f}")
    print(f"e2 = y2 - y2_hat = {e2:.4f}")
    print(f"dL/dh1 (仅来自 l1) = {dL_dh1_from_l1:.4f}")
    print(f"dL/dh2 = {dL_dh2:.4f}")
    print()

    # ---- t = 2: h2 -> a2 -> (w_x, w_h, h1) ----
    # d tanh(a)/da = 1 - tanh(a)^2, expressed through the already-computed h2
    dh2_da2 = 1.0 - h2 ** 2
    dL_da2 = dL_dh2 * dh2_da2

    # Partials of a2 = w_x * x2 + w_h * h1
    da2_dw_x = x2
    da2_dw_h = h1
    da2_dh1 = w_h

    dL_dw_x_t2 = dL_da2 * da2_dw_x
    dL_dw_h_t2 = dL_da2 * da2_dw_h
    dL_dh1_from_t2 = dL_da2 * da2_dh1

    print("== t = 2 反向 ==")
    print(f"dh2/da2 = 1 - h2^2 = {dh2_da2:.4f}")
    print(f"dL/da2 = dL/dh2 * dh2/da2 = {dL_da2:.4f}")
    print(f"dL/dw_x (来自 t=2) = {dL_dw_x_t2:.4f}")
    print(f"dL/dw_h (来自 t=2) = {dL_dw_h_t2:.4f}")
    print(f"dL/dh1 (来自 t=2)  = {dL_dh1_from_t2:.4f}")
    print()

    # ---- t = 1: h1 receives gradient from l1 directly and from step 2 via a2 ----
    dL_dh1_total = dL_dh1_from_l1 + dL_dh1_from_t2

    dh1_da1 = 1.0 - h1 ** 2
    dL_da1 = dL_dh1_total * dh1_da1

    # Partials of a1 = w_x * x1 + w_h * h0
    da1_dw_x = x1
    da1_dw_h = h0

    dL_dw_x_t1 = dL_da1 * da1_dw_x
    dL_dw_h_t1 = dL_da1 * da1_dw_h  # zero whenever h0 == 0 (the default)

    print("== t = 1 反向 ==")
    print(f"dL/dh1_total = dL/dh1_from_l1 + dL/dh1_from_t2 = {dL_dh1_total:.4f}")
    print(f"dh1/da1 = 1 - h1^2 = {dh1_da1:.4f}")
    print(f"dL/da1 = dL/dh1_total * dh1/da1 = {dL_da1:.4f}")
    print(f"dL/dw_x (来自 t=1) = {dL_dw_x_t1:.4f}")
    print(f"dL/dw_h (来自 t=1) = {dL_dw_h_t1:.4f}")
    print()

    # ---- Shared weights: total gradient is the sum over time steps ----
    dL_dw_x = dL_dw_x_t1 + dL_dw_x_t2
    dL_dw_h = dL_dw_h_t1 + dL_dw_h_t2

    # Output weight: dL/dw_y = sum_t (y_t - y_t_hat) * h_t
    dL_dw_y = e1 * h1 + e2 * h2

    print("== 最终参数梯度（手动） ==")
    print(f"dL/dw_x = {dL_dw_x:.4f}")
    print(f"dL/dw_h = {dL_dw_h:.4f}")
    print(f"dL/dw_y = {dL_dw_y:.4f}")
    print()

    return {
        "w_x": dL_dw_x,
        "w_h": dL_dw_h,
        "w_y": dL_dw_y,
        "L": L,
    }


def forward_and_backward_mxnet(
    w_x: float = 1.0,
    w_h: float = 1.0,
    w_y: float = 1.0,
    x1: float = 1.0,
    x2: float = 1.0,
    y1_hat: float = 0.0,
    y2_hat: float = 0.0,
    h0: float = 0.0,
) -> dict:
    """Compute the same two-step scalar RNN gradients with MXNet autograd.

    The forward pass recorded under ``autograd.record()`` is, for t = 1, 2:

        a_t = w_x * x_t + w_h * h_{t-1}
        h_t = tanh(a_t)
        y_t = w_y * h_t
        l_t = 0.5 * (y_t - y_t_hat) ** 2

    with total loss L = l_1 + l_2, followed by ``L.backward()``.

    Args:
        w_x: Input-to-hidden weight.
        w_h: Hidden-to-hidden (recurrent) weight.
        w_y: Hidden-to-output weight.
        x1: Input at step 1.
        x2: Input at step 2.
        y1_hat: Target the output y1 is compared against.
        y2_hat: Target the output y2 is compared against.
        h0: Initial hidden state.

    Returns:
        A dict with keys "w_x", "w_h", "w_y" holding the gradients read from
        each weight's ``.grad`` and key "L" holding the loss, each as returned
        by ``asscalar()``.

    All defaults reproduce the constants that were previously hard-coded, so
    calling the function without arguments yields the same values and prints.
    """
    print("===== 用 MXNet autograd 验证 =====")

    ctx = mx.cpu()

    def scalar(value):
        # Wrap a Python number as a one-element NDArray on the chosen context.
        return nd.array([value], ctx=ctx)

    # Weights: one-element NDArrays with gradient buffers attached
    w_x_nd = scalar(w_x)
    w_h_nd = scalar(w_h)
    w_y_nd = scalar(w_y)
    for p in (w_x_nd, w_h_nd, w_y_nd):
        p.attach_grad()

    # Inputs, targets and initial hidden state (no gradient buffers attached)
    x1_nd = scalar(x1)
    x2_nd = scalar(x2)
    y1_hat_nd = scalar(y1_hat)
    y2_hat_nd = scalar(y2_hat)
    h0_nd = scalar(h0)

    with autograd.record():
        # t = 1
        a1 = w_x_nd * x1_nd + w_h_nd * h0_nd
        h1 = nd.tanh(a1)
        y1 = w_y_nd * h1
        l1 = 0.5 * (y1 - y1_hat_nd) ** 2

        # t = 2 (reuses h1, so both steps contribute to the weight gradients)
        a2 = w_x_nd * x2_nd + w_h_nd * h1
        h2 = nd.tanh(a2)
        y2 = w_y_nd * h2
        l2 = 0.5 * (y2 - y2_hat_nd) ** 2

        L = l1 + l2

    # Pull the loss out once and reuse it for both the print and the return
    loss_value = L.asscalar()
    print(f"MXNet forward: L = {loss_value:.4f}")

    # Backward pass
    L.backward()

    # Read each gradient once instead of calling asscalar() twice per weight
    grad_w_x = w_x_nd.grad.asscalar()
    grad_w_h = w_h_nd.grad.asscalar()
    grad_w_y = w_y_nd.grad.asscalar()

    print(f"dL/dw_x (mxnet) = {grad_w_x:.4f}")
    print(f"dL/dw_h (mxnet) = {grad_w_h:.4f}")
    print(f"dL/dw_y (mxnet) = {grad_w_y:.4f}")
    print()

    return {
        "w_x": grad_w_x,
        "w_h": grad_w_h,
        "w_y": grad_w_y,
        "L": loss_value,
    }


if __name__ == "__main__":
    # Hand-derived gradients first, then the autograd reference values.
    manual = forward_and_backward_manual()
    mxnet_grads = forward_and_backward_mxnet()

    # Print both results per weight on one line so they can be compared.
    print("===== 对比 手算 vs MXNet autograd =====")
    for k in ("w_x", "w_h", "w_y"):
        by_hand = manual[k]
        by_autograd = mxnet_grads[k]
        print(f"{k}: manual = {by_hand:.4f}, mxnet = {by_autograd:.4f}")


===== 手动计算 BPTT（标量两步 RNN）=====
a1 = 1.0000, h1 = tanh(a1) = 0.7616, y1 = 0.7616, l1 = 0.2900
a2 = 1.7616, h2 = tanh(a2) = 0.9427, y2 = 0.9427, l2 = 0.4443
Total Loss L = 0.7343

== 输出层梯度 ==
e1 = y1 - y1_hat = 0.7616
e2 = y2 - y2_hat = 0.9427
dL/dh1 (仅来自 l1) = 0.7616
dL/dh2 = 0.9427

== t = 2 反向 ==
dh2/da2 = 1 - h2^2 = 0.1114
dL/da2 = dL/dh2 * dh2/da2 = 0.1050
dL/dw_x (来自 t=2) = 0.1050
dL/dw_h (来自 t=2) = 0.0799
dL/dh1 (来自 t=2)  = 0.1050

== t = 1 反向 ==
dL/dh1_total = dL/dh1_from_l1 + dL/dh1_from_t2 = 0.8666
dh1/da1 = 1 - h1^2 = 0.4200
dL/da1 = dL/dh1_total * dh1/da1 = 0.3639
dL/dw_x (来自 t=1) = 0.3639
dL/dw_h (来自 t=1) = 0.0000

== 最终参数梯度（手动） ==
dL/dw_x = 0.4689
dL/dw_h = 0.0799
dL/dw_y = 1.4687

===== 用 MXNet autograd 验证 =====
MXNet forward: L = 0.7343
dL/dw_x (mxnet) = 0.4689
dL/dw_h (mxnet) = 0.0799
dL/dw_y (mxnet) = 1.4687

===== 对比 手算 vs MXNet autograd =====
w_x: manual = 0.4689, mxnet = 0.4689
w_h: manual = 0.0799, mxnet = 0.0799
w_y: manual = 1.4687, mxnet = 1.4687
