# mHC-Pytorch 

mHC: [mHC: Manifold-Constrained Hyper-Connections](https://arxiv.org/abs/2512.24880)

git: [dhcode-cpp/mHC-pytorch](https://github.com/dhcode-cpp/mHC-pytorch)

blog: [【手撕 mHC】详解DeepSeek残差链接mHC进化之路（超长文、附代码）](https://zhuanlan.zhihu.com/p/1990683672337223894)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(42)

<torch._C.Generator at 0x1179c0e70>

In [2]:
# Notebook-wide configuration; later cells read these globals by name.

dim = 512  # feature dimension D of each residual stream
rate = 2  # expansion rate N (number of parallel residual streams)
layer_id = 10  # passed to ManifoldHyperConnectionFuse in the single-layer demo
dynamic = True  # NOTE(review): not referenced by any visible cell — confirm it is still needed

bsz = 1  # batch size of the demo inputs
seq_len = 16  # sequence length of the demo inputs

## Manifold-Constrained Hyper-Connections

## Math

\begin{equation}
    \begin{cases}
        \vec{\mathbf{x}}'_l = \text{RMSNorm}(\vec{\mathbf{x}}_l) \\
        \tilde{\mathcal{H}}_l^\text{pre} = \alpha_l^\mathrm{pre} \cdot (\vec{\mathbf{x}}'_l\phi^\mathrm{pre}_l) + \mathbf{b}_l^\mathrm{pre} \\
        \tilde{\mathcal{H}}_l^\text{post} = \alpha_l^\mathrm{post} \cdot (\vec{\mathbf{x}}'_l\phi^\mathrm{post}_l) + \mathbf{b}_l^\mathrm{post} \\
        \tilde{\mathcal{H}}_l^\text{res} = \alpha_l^\mathrm{res} \cdot \text{mat}(\vec{\mathbf{x}}'_l\phi^\mathrm{res}_l) + \mathbf{b}_l^\mathrm{res}, \\
    \end{cases}
\end{equation}



\begin{equation}
    \begin{cases}
        \mathcal{H}_l^\text{pre} = \sigma(\tilde{\mathcal{H}}_l^\text{pre}) \\
        \mathcal{H}_l^\text{post} = 2\sigma(\tilde{\mathcal{H}}_l^\text{post}) \\
        \mathcal{H}_l^\text{res} = \text{Sinkhorn-Knopp}(\tilde{\mathcal{H}}_l^\text{res}),
    \end{cases}
\end{equation}


## implemented

## sinkhorn_knopp 归一化

原论文

To this end, we restrict $\mathcal{H}^\text{res}_{l}$ to be a doubly stochastic matrix, which has non-negative entries where both the rows and columns sum to 1. Formally, let $\mathcal{M}^\mathrm{res}$ denote the manifold of doubly stochastic matrices (also known as the Birkhoff polytope).
We constrain $\mathcal{H}^\text{res}_{l}$ to $\mathcal{P}_{\mathcal{M}^\mathrm{res}}(\mathcal{H}^\text{res}_{l})$, defined as:
$$
\begin{equation}
    \mathcal{P}_{\mathcal{M}^\mathrm{res}}(\mathcal{H}^\text{res}_{l}) \coloneq \left\{ \mathcal{H}^\text{res}_{l} \in \mathbb{R}^{n \times n} \mid \mathcal{H}^\text{res}_{l}\mathbf{1}_n = \mathbf{1}_n, \ \mathbf{1}^\top_n\mathcal{H}^\text{res}_{l} = \mathbf{1}^\top_n, \ \mathcal{H}^\text{res}_{l} \geq 0 \right\},
\end{equation}
$$
where $\mathbf{1}_n$ represents the $n$-dimensional vector of all ones.

### sinkhorn_knopp

In [3]:
def sinkhorn_knopp_batched(A, it=1000, eps=1e-8):
    """Scale a batch of non-negative square matrices towards doubly stochastic.

    Runs ``it`` rounds of the Sinkhorn-Knopp updates
    ``u = 1 / (A v + eps)`` and ``v = 1 / (A^T u + eps)``, starting from
    ``u = v = 1``, and returns ``P = diag(u) @ A @ diag(v)``.

    Args:
        A: tensor of shape ``(B, n, n)`` with non-negative entries (the caller
            is responsible for that, e.g. by applying ``exp`` first).
        it: number of update rounds; ``it=0`` returns ``A`` unscaled.
        eps: added to every denominator to avoid division by zero.

    Returns:
        ``(P, U, V)``: the scaled matrices and the diagonal scaling matrices
        ``U = diag(u)``, ``V = diag(v)``, each of shape ``(B, n, n)``.

    Raises:
        ValueError: if ``A`` is not a batch of square matrices.
    """
    if A.dim() != 3 or A.shape[1] != A.shape[2]:
        raise ValueError(
            f"A must have shape (B, n, n), got {tuple(A.shape)}"
        )

    batch_size, n, _ = A.shape
    A_t = A.transpose(1, 2)

    # Allocate the scaling vectors on A's device and in A's dtype; plain
    # torch.ones() would make torch.bmm fail for float64 or non-CPU inputs.
    u = A.new_ones(batch_size, n)
    v = A.new_ones(batch_size, n)

    for _ in range(it):
        # Row scaling: u_i = 1 / sum_j A_ij v_j
        u = 1.0 / (torch.bmm(A, v.unsqueeze(2)).squeeze(2) + eps)
        # Column scaling: v_j = 1 / sum_i A_ij u_i
        v = 1.0 / (torch.bmm(A_t, u.unsqueeze(2)).squeeze(2) + eps)

    U = torch.diag_embed(u)  # (B, n, n)
    V = torch.diag_embed(v)  # (B, n, n)
    # diag(u) @ A @ diag(v) written as a broadcast product.
    P = u.unsqueeze(2) * A * v.unsqueeze(1)

    return P, U, V

In [4]:
# Demo: scale a random batch of two 3x3 matrices and print the column sums
# (sum over dim=0) and row sums (sum over dim=1) of the first matrix.
A = torch.randn(2,3,3)
A = A.exp() # exp makes every entry strictly positive, as Sinkhorn-Knopp requires
# example1: only 2 iterations
P, _, _, = sinkhorn_knopp_batched(A, it=2)
print(P.shape)
print('it=2\t', P[0].sum(dim=0), P[0].sum(dim=1))

# example2: 20 iterations on the same input
P, _, _, = sinkhorn_knopp_batched(A, it=20)
print('it=20\t', P[0].sum(dim=0), P[0].sum(dim=1))

torch.Size([2, 3, 3])
it=2	 tensor([1., 1., 1.]) tensor([0.9788, 1.0334, 0.9878])
it=20	 tensor([1.0000, 1.0000, 1.0000]) tensor([1.0000, 1.0000, 1.0000])


In [5]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Computes ``gamma * x / sqrt(mean(x ** 2) + eps)`` where the mean is taken
    over the last axis and ``gamma`` is a learnable per-feature scale
    initialised to ones.

    Args:
        d_model: size of the last dimension of the input.
        eps: constant added to the mean square before the square root.
    """

    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        """Return ``x`` divided by its RMS and scaled by ``gamma``."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        return self.gamma * (x / rms)

## mHC Fuse Kernel

$$
\begin{align}
    \phi_l                                                                      &: \text{tfloat32}          &&[nC, n^2+2n]                                                             \\
    \vec{\mathbf{x}}_l                                                                      &: \text{bfloat16}          &&[1, nC]                                                                     \\
    \alpha_l^\mathrm{pre}, \alpha_l^\mathrm{post}, \alpha_l^\mathrm{res}                                                    &: \text{float32}           &&\text{Scalars}                                                         \\
    \mathbf{b}_l                                                                      &: \text{float32}           &&[1, n^2+2n]                                                                  \\
    \left[{\tilde{\tilde{\mathcal{H}}}^{\mathrm{pre}}_{l}}, {\tilde{\tilde{\mathcal{H}}}^{\mathrm{post}}_{l}}, {\tilde{\tilde{\mathcal{H}}}^{\mathrm{res}}_{l}}\right]   &: \text{float32}           &&= \vec{\mathbf{x}}_l\phi_l                                                  \\
    r                                                                               &: \text{float32}           &&= \left\|\vec{\mathbf{x}}_l\right\|_2 / \sqrt{nC}  ;  \text{——RMS: } r = \text{RMS} = \frac{\|\vec{\mathbf{x}}_l\|_2}{\sqrt{nC}}, \quad \frac{1}{r} = \frac{1}{\text{RMS}}                                              \\
    \left[\tilde{\mathcal{H}}^{\mathrm{pre}}_{l}, \tilde{\mathcal{H}}^{\mathrm{post}}_{l}, \tilde{\mathcal{H}}^{\mathrm{res}}_{l}\right]         &: \text{float32}           &&= 1/r \left[\alpha_l^\mathrm{pre}{\tilde{\tilde{\mathcal{H}}}^{\mathrm{pre}}_{l}}, \alpha_l^\mathrm{post}{\tilde{\tilde{\mathcal{H}}}^{\mathrm{post}}_{l}}, \alpha_l^\mathrm{res}{\tilde{\tilde{\mathcal{H}}}^{\mathrm{res}}_{l}}\right] + \mathbf{b}_l \\
    \mathcal{H}^{\mathrm{pre}}_{l}                                                                      &: \text{float32}           &&= \sigma\left(\tilde{\mathcal{H}}^{\mathrm{pre}}_{l}\right)                                   \\
    \mathcal{H}^{\mathrm{post}}_{l}                                                                      &: \text{float32}           &&= 2\sigma\left(\tilde{\mathcal{H}}^{\mathrm{post}}_{l}\right)                                  \\
    \mathcal{H}^{\mathrm{res}}_{l}                                                                      &: \text{float32}           &&= \text{Sinkhorn-Knopp}\left(\tilde{\mathcal{H}}^{\mathrm{res}}_{l}\right)   
\end{align}
$$


> Observing that RMSNorm in mHC imposes significant latency when operating on the high-dimensional hidden state $\vec{\mathbf{x}}_l \in \mathbb{R}^{1\times nC}$, we reorder the dividing-by-norm operation to follow the matrix multiplication. This optimization maintains mathematical equivalence while improving efficiency.


For RMSNorm, given $\hat{x} = \gamma \frac{x}{\text{RMS}}$, the division by RMS is fused after the matrix multiplication ($\vec{\mathbf{x}}_l\phi_l$):

$$
\begin{align}
\tilde{x} &= \gamma x, x\in\mathbb{R}^{nC}  \\
\hat{x} &=  \tilde{x} \frac{1}{\text{RMS}}
\\
\frac{1}{\text{RMS}} &= \frac{1}{\sqrt{ \frac{1}{nC}(\sum_{j=1}^{nC} x_j^2)}} = \frac{1}{\sqrt{ \frac{1}{nC}}\sqrt{(\sum_{j=1}^{nC} x_j^2)}} \\
&=\frac{1}{\sqrt{\frac{1}{nC}}} \frac{1}{\sqrt{(\sum_{j=1}^{nC} x_j^2)}} \\
&=\sqrt{nC}\frac{1}{\vert\vert x\vert\vert_2} \\
r = \text{RMS} &= \frac{\vert\vert x\vert\vert_2}{\sqrt{nC}}
\end{align}
$$

In [6]:
import math
class ManifoldHyperConnectionFuse(nn.Module):
    """Manifold-constrained hyper-connection (mHC) with a fused RMSNorm.

    Operates on a hyper hidden state ``h`` of shape ``(B, L, N, D)``:
        B: batch size
        L: sequence length
        N: expansion rate (``rate``)
        D: feature dim (``dim``)

    Args:
        dim: feature dimension D.
        rate: expansion rate N.
        layer_id: accepted for interface compatibility; not used by this class.
        max_sk_it: iteration count handed to the Sinkhorn-Knopp routine.
    """

    def __init__(self, dim, rate, layer_id, max_sk_it):
        super().__init__()

        self.n = rate
        self.dim = dim

        self.nc = self.n * self.dim
        self.n2 = self.n * self.n

        # Norm over the flattened (N*D) hidden state. Only its `gamma` and
        # `eps` are used: following the mHC paper, the divide-by-RMS step is
        # reordered to come after the projection, which is mathematically
        # equivalent to normalising first.
        self.norm = RMSNorm(dim * rate)

        # Joint projection producing [H_pre (n) | H_post (n) | H_res (n*n)].
        self.w = nn.Parameter(torch.zeros(self.nc, self.n2 + 2 * self.n))
        # One gate per mapping: alpha[0]=pre, alpha[1]=post, alpha[2]=res.
        self.alpha = nn.Parameter(torch.ones(3) * 0.01)
        # Static bias, laid out like the columns of `w`.
        self.beta = nn.Parameter(torch.zeros(self.n2 + 2 * self.n))

        # max sinkhorn knopp iterations
        self.max_sk_it = max_sk_it

    def mapping(self, h, res_norm):
        """Compute the pre, post and residual mappings for ``h``.

        Args:
            h: hyper hidden state of shape ``(B, L, N, D)``.
            res_norm: callable ``res_norm(A, it) -> (P, U, V)`` where ``A`` is
                a ``(B*L, N, N)`` non-negative batch and ``U``/``V`` are the
                diagonal scaling matrices (e.g. ``sinkhorn_knopp_batched``).

        Returns:
            ``(H_pre, H_post, H_res)`` with shapes ``(B, L, N)``,
            ``(B, L, N)`` and ``(B, L, N, N)``. ``H_pre`` is in (0, 1),
            ``H_post`` in (0, 2) and ``H_res`` is ``U @ exp(H~_res) @ V``.
        """
        B, L, N, D = h.shape

        # 1. vectorize
        x = h.reshape(B, L, N * D)

        # 2. RMS of the *unscaled* vector, with the same eps placement as
        #    RMSNorm.forward, so that the fused form equals
        #    RMSNorm(x) @ w = ((gamma * x) @ w) / RMS(x).
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.norm.eps)

        # 3. projection, then the deferred division by the RMS
        H = ((self.norm.gamma * x) @ self.w) / rms

        # 4. gated mapping plus static bias
        n = N
        H_pre = H[..., :n] * self.alpha[0] + self.beta[:n]
        H_post = H[..., n:2 * n] * self.alpha[1] + self.beta[n:2 * n]
        H_res = H[..., 2 * n:] * self.alpha[2] + self.beta[2 * n:]

        # 5. final constrained mapping
        H_pre = torch.sigmoid(H_pre)
        H_post = 2 * torch.sigmoid(H_post)

        # 6. Sinkhorn-Knopp: exp makes the entries positive, then the
        #    matrix is rescaled by the diagonal factors U and V.
        H_res_exp = H_res.exp().reshape(B * L, N, N)
        with torch.no_grad():
            # The scaling factors are treated as constants; gradients flow
            # only through H_res_exp in the product below.
            _, U, V = res_norm(H_res_exp, self.max_sk_it)
        P = torch.bmm(torch.bmm(U.detach(), H_res_exp), V.detach())
        H_res = P.reshape(B, L, N, N)

        return H_pre, H_post, H_res

    def process(self, h, H_pre, H_res):
        """Apply the pre and residual mappings to ``h``.

        Args:
            h: ``(B, L, N, D)`` hyper hidden state.
            H_pre: ``(B, L, N)`` stream-mixing weights.
            H_res: ``(B, L, N, N)`` residual mixing matrix.

        Returns:
            ``(h_pre, h_res)``: the mixed layer input ``(B, L, 1, D)`` and
            the mixed residual streams ``(B, L, N, D)``.
        """
        h_pre = H_pre.unsqueeze(dim=2) @ h
        h_res = H_res @ h
        return h_pre, h_res

    def depth_connection(self, H_pre, h_out, beta):
        """Write the layer output back onto the residual streams.

        Args:
            H_pre: the residual branch ``h_res`` returned by ``process``
                (``(B, L, N, D)``). The parameter name is kept for
                compatibility even though callers pass ``h_res`` here.
            h_out: layer output of shape ``(B, L, 1, D)``.
            beta: post-mapping weights ``H_post`` of shape ``(B, L, N)``.

        Returns:
            ``beta[..., None] @ h_out + H_pre`` with shape ``(B, L, N, D)``.
        """
        post_mapping = beta.unsqueeze(dim=-1) @ h_out
        # Use the argument, not a same-named global from the notebook scope.
        out = post_mapping + H_pre
        return out
        
# Number of Sinkhorn-Knopp iterations used by every mHC module in this notebook.
max_sk_it = 20
# Single stand-alone mHC module used by the forward demo cell.
mHC = ManifoldHyperConnectionFuse(dim = dim, 
                                  rate = rate, 
                                  layer_id = layer_id,
                                  max_sk_it = max_sk_it)

### forward

In [7]:
# A plain linear layer stands in for the attention sub-layer in the demo.
attn = nn.Linear(dim, dim)

In [8]:
# One mHC-wrapped sub-layer applied to a random hyper hidden state.
h = torch.randn(bsz, seq_len, rate, dim)  # (B, L, N, D)
# Per-token mappings; sinkhorn_knopp_batched is the residual normaliser.
H_pre, H_post, H_res = mHC.mapping(h, sinkhorn_knopp_batched)
# Mix the streams into the layer input and into the residual branch.
h_pre, h_res = mHC.process(h, H_pre, H_res)
h_out = attn(h_pre) 
# H_post is passed as `beta`; h_res is passed as the first positional argument.
out = mHC.depth_connection(h_res, h_out, beta=H_post)
print('out', out.shape)

out torch.Size([1, 16, 2, 512])


## Decoder Block

In [9]:
class DecoderBlockmHC(nn.Module):
    """Decoder block with two sub-layers, each wrapped in its own mHC module.

    ``attn`` and ``ffn`` are single ``nn.Linear(dim, dim)`` layers here; each
    one is paired with a ``ManifoldHyperConnectionFuse`` that mixes the
    residual streams around it.

    Args:
        dim: feature dimension.
        rate: expansion rate (number of residual streams).
        layer_id: forwarded to both mHC modules.
        max_sk_it: Sinkhorn-Knopp iteration count forwarded to both modules.
    """

    def __init__(self, dim, rate, layer_id, max_sk_it):
        super().__init__()
        self.attn = nn.Linear(dim, dim)
        self.attn_mHC = ManifoldHyperConnectionFuse(
            dim=dim, rate=rate, layer_id=layer_id, max_sk_it=max_sk_it
        )
        self.ffn = nn.Linear(dim, dim)
        self.ffn_mHC = ManifoldHyperConnectionFuse(
            dim=dim, rate=rate, layer_id=layer_id, max_sk_it=max_sk_it
        )

    @staticmethod
    def _sublayer(mhc, layer, h):
        """Run ``layer`` on ``h`` through the hyper-connection ``mhc``."""
        H_pre, H_post, H_res = mhc.mapping(h, sinkhorn_knopp_batched)
        h_pre, h_res = mhc.process(h, H_pre, H_res)
        h_out = layer(h_pre)
        return mhc.depth_connection(h_res, h_out, beta=H_post)

    def forward(self, h):
        """Apply both sub-layers in order.

        Args:
            h: hyper hidden state of shape ``[bsz, seq_len, rate, dim]``.

        Returns:
            Tensor produced by the second sub-layer's ``depth_connection``.
        """
        # Each sub-layer uses only its own mHC module for mapping, process
        # and depth_connection (the FFN path previously called
        # attn_mHC.process).
        h = self._sublayer(self.attn_mHC, self.attn, h)
        h = self._sublayer(self.ffn_mHC, self.ffn, h)
        return h

## Model

In [10]:
class LanguageModelmHC(nn.Module):
    """Toy language model built from a stack of ``DecoderBlockmHC`` blocks.

    Args:
        num_layer: number of decoder blocks.
        vocab_size: size of the embedding table and of the output logits.
        dim: feature dimension.
        rate: expansion rate; the embedding is copied into this many streams.
        max_sk_it: Sinkhorn-Knopp iteration count passed to every block.
    """

    def __init__(self, num_layer, vocab_size, dim, rate, max_sk_it):
        super().__init__()
        self.n = rate
        self.embd = nn.Embedding(vocab_size, dim)
        blocks = []
        for layer_id in range(num_layer):
            blocks.append(DecoderBlockmHC(dim, rate, layer_id, max_sk_it))
        self.decoder = nn.ModuleList(blocks)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to logits over the vocabulary."""
        # Embed, then copy the embedding into `n` streams along a new dim 2.
        streams = self.embd(x).unsqueeze(dim=2).repeat(1, 1, self.n, 1)

        # Run the decoder stack on the expanded state.
        for layer in self.decoder:
            streams = layer(streams)

        # Collapse the streams by summing over dim 2, then project.
        merged = streams.sum(dim=2)
        return self.lm_head(merged)

## mHC Model Forward

In [11]:
# Build a 2-layer mHC language model and a random batch of token ids.

vocab_size = 100
model =LanguageModelmHC(2, vocab_size, dim, rate, max_sk_it)
x = torch.randint(vocab_size, (bsz, seq_len))  # ids drawn from [0, vocab_size)

In [12]:
# Full forward pass; the recorded output shape is (bsz, seq_len, vocab_size).
logits = model(x)
print(logits.shape)

torch.Size([1, 16, 100])
