In [3]:
#!/usr/bin/env python
# node_edge_equivalence.py
#
# Verify that a FlexAttention rewrite of NodeEdgeBlock is numerically
# identical to the hand-rolled implementation.

import math, torch
import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention


# --------------------------------------------------------------------------- #
# Helper layers                                                               #
# --------------------------------------------------------------------------- #
def masked_softmax_dim2(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Soft-max over dim=2 with –inf masking. Accepts (bs,n,n,h,df) or (bs,n,n,h)."""
    if scores.dim() == 5:                      # (bs,n_q,n_k,h,df)
        scores = scores.sum(-1)                # -> (bs,n,n,h)
    scores = scores.masked_fill(mask == 0, -1e30)
    return torch.softmax(scores, dim=2)


class Xtoy(nn.Module):
    def __init__(self, dx, dy):
        super().__init__()
        self.lin = nn.Linear(dx, dy)

    def forward(self, X, x_mask):
        summed = (X * x_mask).sum(1)
        denom  = x_mask.sum(1).clamp_min(1e-6)
        return self.lin(summed / denom)


class Etoy(nn.Module):
    def __init__(self, de, dy):
        super().__init__()
        self.lin = nn.Linear(de, dy)

    def forward(self, E, e_mask1, e_mask2):
        mask   = e_mask1 * e_mask2
        summed = (E * mask).sum((1, 2))
        denom  = mask.sum((1, 2)).clamp_min(1e-6)
        return self.lin(summed / denom)


# --------------------------------------------------------------------------- #
# 1.  Original block                                                          #
# --------------------------------------------------------------------------- #
class NodeEdgeBlock(nn.Module):
    def __init__(self, dx, de, dy, n_head):
        super().__init__()
        assert dx % n_head == 0
        self.dx, self.de, self.dy = dx, de, dy
        self.n_head, self.df      = n_head, dx // n_head

        # projections
        self.q = nn.Linear(dx, dx)
        self.k = nn.Linear(dx, dx)
        self.v = nn.Linear(dx, dx)

        self.e_mul, self.e_add = nn.Linear(de, dx), nn.Linear(de, dx)
        self.y_e_mul, self.y_e_add = nn.Linear(dy, dx), nn.Linear(dy, dx)
        self.y_x_mul, self.y_x_add = nn.Linear(dy, dx), nn.Linear(dy, dx)

        self.y_y  = nn.Linear(dy, dy)
        self.x_y  = Xtoy(dx, dy)
        self.e_y  = Etoy(de, dy)

        self.x_out = nn.Linear(dx, dx)
        self.e_out = nn.Linear(dx, de)
        self.y_out = nn.Sequential(nn.Linear(dy, dy), nn.ReLU(), nn.Linear(dy, dy))

    # --------------------------------------------------------------------- #
    def forward(self, X, E, y, node_mask):
        bs, n, _ = X.shape
        x_mask  = node_mask.unsqueeze(-1)          # (bs,n,1)
        e_mask1 = x_mask.unsqueeze(2)              # (bs,n,1,1)
        e_mask2 = x_mask.unsqueeze(1)              # (bs,1,n,1)

        # -------- Q, K ---------------------------------------------------- #
        Q = self.q(X) * x_mask
        K = self.k(X) * x_mask
        Qv = Q.view(bs, n, self.n_head, self.df)   # (bs,n,h,df)
        Kv = K.view(bs, n, self.n_head, self.df)

        # full vector scores
        Y = (Qv.unsqueeze(2) * Kv.unsqueeze(1)) / math.sqrt(self.df)  # (bs,n,n,h,df)

        E1 = (self.e_mul(E) * e_mask1 * e_mask2).view(bs, n, n, self.n_head, self.df)
        E2 = (self.e_add(E) * e_mask1 * e_mask2).view_as(E1)
        Y  = Y * (E1 + 1) + E2                         # (bs,n,n,h,df)

        # -------- edge update -------------------------------------------- #
        newE = Y.flatten(3)
        ye1, ye2 = self.y_e_add(y)[:, None, None, :], self.y_e_mul(y)[:, None, None, :]
        newE = self.e_out(ye1 + (ye2 + 1) * newE) * e_mask1 * e_mask2

        # -------- attention scalars -------------------------------------- #
        soft_mask = e_mask2.expand(-1, n, -1, self.n_head)
        attn = masked_softmax_dim2(Y, soft_mask)      # (bs,n,n,h)

        # -------- value --------------------------------------------------- #
        V = self.v(X) * x_mask                        # (bs,n,dx)
        V = V.view(bs, n, self.n_head, self.df).unsqueeze(1)  # (bs,1,n,h,df)
        weighted = (attn.unsqueeze(-1) * V).sum(2)    # (bs,n,h,df)
        weighted = weighted.flatten(2)                # (bs,n,dx)

        # -------- node update -------------------------------------------- #
        yx1, yx2 = self.y_x_add(y)[:, None, :], self.y_x_mul(y)[:, None, :]
        newX = self.x_out(yx1 + (yx2 + 1) * weighted) * x_mask

        # -------- global update ------------------------------------------ #
        new_y = self.y_out(self.y_y(y) + self.x_y(X, x_mask) + self.e_y(E, e_mask1, e_mask2))
        return newX, newE, new_y


# --------------------------------------------------------------------------- #
# 2.  FlexAttention version                                                   #
# --------------------------------------------------------------------------- #
class NodeEdgeBlockFlex(NodeEdgeBlock):
    """Same math as NodeEdgeBlock but with a fused Flash/Flex kernel."""
    def __init__(self, dx, de, dy, n_head):
        super().__init__(dx, de, dy, n_head)
        self.sqrt_df = math.sqrt(self.df)

    def forward(self, X, E, y, node_mask):
        bs, n, _ = X.shape
        x_mask  = node_mask.unsqueeze(-1)
        e_mask1 = x_mask.unsqueeze(2)
        e_mask2 = x_mask.unsqueeze(1)

        # -------- projections -------------------------------------------- #
        Q = self.q(X) * x_mask                 # (bs,n,dx)
        K = self.k(X) * x_mask
        V = self.v(X) * x_mask

        Qv = Q.view(bs, n, self.n_head, self.df)   # (bs,n,h,df)
        Kv = K.view(bs, n, self.n_head, self.df)
        Vv = V.view(bs, n, self.n_head, self.df)

        # same Y as reference for edge update
        Y = (Qv.unsqueeze(2) * Kv.unsqueeze(1)) / self.sqrt_df

        E1 = (self.e_mul(E) * e_mask1 * e_mask2).view(bs, n, n, self.n_head, self.df)
        E2 = (self.e_add(E) * e_mask1 * e_mask2).view_as(E1)
        Y  = Y * (E1 + 1) + E2                        # (bs,n,n,h,df)

        # -------- edge update identical ---------------------------------- #
        newE = Y.flatten(3)
        newE = self.e_out(self.y_e_add(y)[:, None, None, :] +
                          (self.y_e_mul(y)[:, None, None, :] + 1) * newE) * e_mask1 * e_mask2

        # -------- FlexAttention fusion ----------------------------------- #
        # Layout: (bs, h, seq, df)
        Q_flex = Qv.transpose(1, 2)       # (bs,h,n,df)
        K_flex = Kv.transpose(1, 2)
        V_flex = Vv.transpose(1, 2)

        # Closure tensors must already live on the right device/dtype
        E1_flat = E1                      # (bs,n,n,h,df)
        E2_flat = E2

        key_padding = (node_mask == 0)    # True = ignore

        def score_mod(_score, b, h, q, k):
            """Exact scalar = Σ_df [ (q_d k_d / √df)*(1+E1) + E2 ]"""
            q_vec = Qv[b, q, h]           # (df)
            k_vec = Kv[b, k, h]
            base  = q_vec * k_vec / self.sqrt_df
            return (base * (E1_flat[b, q, k, h] + 1) +
                    E2_flat[b, q, k, h]).sum(-1)

        attn_V = flex_attention(
            Q_flex, K_flex, V_flex,
            score_mod=score_mod,
            key_padding_mask=key_padding
        )                                 # (bs,h,n,df)

        weighted = attn_V.transpose(1, 2).contiguous()  # (bs,n,h,df)
        weighted = weighted.flatten(2)                  # (bs,n,dx)

        # -------- node + global updates as before ------------------------ #
        newX = self.x_out(self.y_x_add(y)[:, None, :] +
                          (self.y_x_mul(y)[:, None, :] + 1) * weighted) * x_mask
        new_y = self.y_out(self.y_y(y) + self.x_y(X, x_mask) + self.e_y(E, e_mask1, e_mask2))
        return newX, newE, new_y


# --------------------------------------------------------------------------- #
# 3.  Sanity check                                                            #
# --------------------------------------------------------------------------- #
def _run_equivalence():
    torch.manual_seed(0)

    bs, n      = 3, 5
    dx, de, dy = 32, 16, 8
    n_head     = 8

    X  = torch.randn(bs, n, dx)
    E  = torch.randn(bs, n, n, de)
    y  = torch.randn(bs, dy)
    mask = (torch.rand(bs, n) > 0.3).float()

    ref  = NodeEdgeBlock(dx, de, dy, n_head)
    flex = NodeEdgeBlockFlex(dx, de, dy, n_head)
    flex.load_state_dict(ref.state_dict())           # identical weights

    ref.eval(); flex.eval()
    with torch.no_grad():
        x0, e0, y0 = ref (X, E, y, mask)
        x1, e1, y1 = flex(X, E, y, mask)

    def err(a, b): return (a - b).abs().max().item()
    print(f"max |ΔX| = {err(x0,x1):.3e}")
    print(f"max |ΔE| = {err(e0,e1):.3e}")
    print(f"max |Δy| = {err(y0,y1):.3e}")

    tol = 1e-5
    assert err(x0,x1) < tol and err(e0,e1) < tol and err(y0,y1) < tol
    print("🎉  FlexAttention block matches the original within", tol)


In [4]:
_run_equivalence()

TypeError: flex_attention() got an unexpected keyword argument 'key_padding_mask'