# Model form

This notebook introduces the model form, its connections to tensor methods, and equivalent parameterizations shown in the paper.

At the end, we construct each expert's weights explicitly, showing the absence of structural low-rankness.

In [1]:
import torch

In [None]:
torch.random.manual_seed(42)

input_dim = 16
output_dim = 64
n_experts = 32

C = torch.randn(n_experts, output_dim) # expert factors: row n is c_n
D = torch.randn(input_dim, output_dim) # input factors: row h is d_h

# dummy gating parameter -- any linear transformation
G = torch.randn(input_dim, n_experts)
E = torch.randn(input_dim, input_dim) # dummy encoder

x = torch.randn(input_dim) # pre-MLP token representation
z = torch.nn.functional.gelu(E.T@x) # hidden units, with whatever activation function
a = torch.nn.functional.softmax(G.T@z, dim=0) # generate the conditional units (or 'expert coefficients') 

# Hadamard-factorization

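Our weight tensor $\boldsymbol{\mathcal{W}}\in\mathbb{R}^{N\times H \times O}$ (with $N$ experts, hidden dimension $H$, and output dimension $O$) is defined fiber by fiber as the elementwise product $*$ of the rows $\mathbf{c}_n$ of $\mathbf{C}\in\mathbb{R}^{N\times O}$ and $\mathbf{d}_h$ of $\mathbf{D}\in\mathbb{R}^{H\times O}$: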

$
    \boldsymbol{\mathcal{W}}(n,h,:) = \mathbf{c}_n * \mathbf{d}_h \in \mathbb{R}^{O},
      \quad \forall\,n\!\in\{1,\ldots,N\},\; h\!\in\{1,\ldots,H\}
$

Constructing the tensor explicitly:

In [3]:
W = torch.zeros(n_experts, input_dim, output_dim)

for n in range(n_experts):
    for h in range(input_dim):
        W[n, h] = C[n] * D[h] # note: all O-dimensional vectors
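
The double loop can also be written as a single broadcasted product. The following is a minimal sketch with small, arbitrary dimensions chosen for illustration; the names `W_loop` and `W_broadcast` are hypothetical, and the only assumption is that `C` and `D` hold the vectors $\mathbf{c}_n$ and $\mathbf{d}_h$ as rows, as in the loop above.

```python
import torch

torch.manual_seed(0)
n_experts, input_dim, output_dim = 4, 3, 5  # small illustrative sizes (assumed)

C = torch.randn(n_experts, output_dim)  # rows are c_n
D = torch.randn(input_dim, output_dim)  # rows are d_h

# Reference: the explicit double loop
W_loop = torch.zeros(n_experts, input_dim, output_dim)
for n in range(n_experts):
    for h in range(input_dim):
        W_loop[n, h] = C[n] * D[h]

# Broadcasting: (N, 1, O) * (1, H, O) -> (N, H, O)
W_broadcast = C[:, None, :] * D[None, :, :]

assert W_broadcast.shape == (n_experts, input_dim, output_dim)
assert torch.allclose(W_loop, W_broadcast)
```

The broadcasted form avoids the Python-level loops, which matters once the tensor is built for realistic sizes.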

**Note**: as per *Sect. A.3.1* of the Appendix, we can also understand this parameterization through the Khatri-Rao product:

$
\mathbf{W}_{(3)} := \left(\mathbf{C} \odot \mathbf{D}\right)^\top\in\mathbb{R}^{O\times(N\cdot H)}.
$


In [4]:
import tensorly as tl
tl.set_backend('pytorch')

W_kr_unfolded = tl.tenalg.khatri_rao([C, D]).T # parameterize the mode-3 unfolding
W_kr = tl.fold(W_kr_unfolded, mode=2, shape=W.shape) # and re-shape

assert torch.allclose(W, W_kr, atol=1e-6)
print("Khatri-Rao parameterization is equivalent")

Khatri-Rao parameterization is equivalent
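
For readers without TensorLy, the same check can be sketched in plain PyTorch. This sketch assumes the Khatri-Rao product is the column-wise Kronecker product, and that the mode-3 unfolding flattens the first two modes in row-major order (row index $n\cdot H + h$); both are convention choices made here for illustration, and the variable names are hypothetical.

```python
import torch

torch.manual_seed(0)
N, H, O = 4, 3, 5  # small illustrative sizes (assumed)

C = torch.randn(N, O)
D = torch.randn(H, O)

# Column-wise Kronecker (Khatri-Rao) product: column o is kron(C[:, o], D[:, o]),
# so row n*H + h holds C[n, o] * D[h, o].
kr = torch.stack([torch.kron(C[:, o], D[:, o]) for o in range(O)], dim=1)  # (N*H, O)

# Transposing gives the (O, N*H) unfolding; reshaping its transpose restores (N, H, O).
W_unfolded = kr.T
W_folded = W_unfolded.T.reshape(N, H, O)

# Compare against the direct Hadamard construction
W_direct = C[:, None, :] * D[None, :, :]
assert W_unfolded.shape == (O, N * H)
assert torch.allclose(W_folded, W_direct)
```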


# Forward pass

The full MxD forward pass is given by:

In [5]:
y_mod = torch.zeros(output_dim)
for n in range(n_experts):
    # linear combination of N experts' outputs
    y_mod += a[n] * W[n].T @ z

And the equivalent forward pass from Lemma 2:

$
\mathbf{y} = (\mathbf{C}^\top\mathbf{a})
*
(\mathbf{D}^\top\mathbf{z})
$
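
To see why the two forward passes agree, one can expand the explicit mixture coordinate by coordinate. The sketch below assumes, as in the code, that the $n$-th expert is the slice $\boldsymbol{\mathcal{W}}(n,:,:)\in\mathbb{R}^{H\times O}$ and that $\mathbf{c}_n$, $\mathbf{d}_h$ are the rows of $\mathbf{C}$, $\mathbf{D}$:

$
\begin{aligned}
y_o
&= \sum_{n=1}^{N} a_n \sum_{h=1}^{H} \boldsymbol{\mathcal{W}}(n,h,o)\, z_h
 = \sum_{n=1}^{N}\sum_{h=1}^{H} a_n\, \mathbf{c}_n(o)\, \mathbf{d}_h(o)\, z_h \\
&= \Big(\sum_{n=1}^{N} a_n\, \mathbf{c}_n(o)\Big)\Big(\sum_{h=1}^{H} \mathbf{d}_h(o)\, z_h\Big)
 = (\mathbf{C}^\top\mathbf{a})_o\, (\mathbf{D}^\top\mathbf{z})_o .
\end{aligned}
$

The double sum factorizes because each entry of the tensor is a product of one term depending only on $n$ and one depending only on $h$.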

In [6]:
y = (C.T@a) * (D.T@z)

assert torch.allclose(y, y_mod)
print('Elementwise equals explicit MoE!')

Elementwise equals explicit MoE!


Or, as a series of two tensor contractions (following the tensorized interpretation in *Sect. A.3.2* of the Appendix):

In [7]:
import tensorly as tl
tl.set_backend('pytorch')

y_moden = tl.tenalg.multi_mode_dot(W, [a, z], modes=[0, 1])

assert torch.allclose(y, y_moden)
print('Elementwise equals mode-n product forward pass!')

Elementwise equals mode-n product forward pass!
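
The two contractions can also be sketched with `torch.einsum`, which makes the contracted indices explicit. The dimensions, the stand-in values for `a` and `z`, and the variable names below are assumptions for illustration only.

```python
import torch

torch.manual_seed(0)
N, H, O = 4, 3, 5  # small illustrative sizes (assumed)

C = torch.randn(N, O, dtype=torch.float64)
D = torch.randn(H, O, dtype=torch.float64)
W = C[:, None, :] * D[None, :, :]  # (N, H, O) weight tensor

a = torch.softmax(torch.randn(N, dtype=torch.float64), dim=0)  # stand-in expert coefficients
z = torch.randn(H, dtype=torch.float64)                        # stand-in hidden units

# Contract mode 0 with a and mode 1 with z in a single call
y_einsum = torch.einsum('n,h,nho->o', a, z, W)

# The same result as two successive contractions
W_a = torch.einsum('n,nho->ho', a, W)  # (H, O): coefficient-weighted sum of experts
y_two_step = W_a.T @ z

# Both agree with the factorized forward pass
y_factored = (C.T @ a) * (D.T @ z)
assert torch.allclose(y_einsum, y_factored)
assert torch.allclose(y_two_step, y_factored)
```

The intermediate `W_a` is the single dense matrix that the mixture collapses to for a given input, which is what the factorized form avoids materializing.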


# Full-rankness

We can also verify that the normalized rank of each expert (its rank divided by $\min(H, O)$) is close to 1 when randomly initialized, i.e., the parameterization imposes no structural rank constraint.

In [8]:
y_rank = torch.zeros(output_dim)

ranks = []
for n in range(n_experts):
    # materialize the nth expert: scaling the columns of D by c_n gives W[n], shape (H, O)
    Wn = D @ torch.diag(C[n])

    ranks.append(torch.linalg.matrix_rank(Wn))

    # compute the output again to assert correctness
    y_rank += a[n] * Wn.T @ z

assert torch.allclose(y_rank, y_mod)
print('Elementwise equals explicit MoE!')

Elementwise equals explicit MoE!


In [9]:
print('Mean (normalized) rank of experts:')
torch.mean(torch.stack(
    [rank / min(input_dim, output_dim) for rank in ranks]
)).item()

Mean (normalized) rank of experts:


1.0
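
One way to read this result: right-multiplying $\mathbf{D}$ by $\mathrm{diag}(\mathbf{c}_n)$ only rescales the columns of $\mathbf{D}$, so the rank is preserved whenever $\mathbf{c}_n$ has no zero entries. The sketch below illustrates this under the assumption of a generic (random) `D`; the sizes and the names `c_dense` and `c_sparse` are hypothetical. The rank drops only when so many entries of $\mathbf{c}_n$ are zeroed that fewer than $\min(H, O)$ nonzero columns remain.

```python
import torch

torch.manual_seed(0)
H, O = 4, 8  # small illustrative sizes (assumed)
D = torch.randn(H, O, dtype=torch.float64)

c_dense = torch.randn(O, dtype=torch.float64)  # nonzero entries (almost surely)
c_sparse = c_dense.clone()
c_sparse[2:] = 0.0                             # keep only 2 nonzero entries

# Column scaling with nonzero factors preserves the rank of D
rank_dense = torch.linalg.matrix_rank(D @ torch.diag(c_dense)).item()

# Zeroing all but 2 columns leaves at most 2 linearly independent columns
rank_sparse = torch.linalg.matrix_rank(D @ torch.diag(c_sparse)).item()

assert rank_dense == min(H, O)
assert rank_sparse == 2
```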