# Mixture of Layers (MoL)

We propose a method for neural networks to route information dynamically through their layers in an *arbitrary order*, allowing for in-context parameter tying.

![](https://i.ibb.co/XsMYr0c/mol.png)

## 1. Setup

In [1]:
#@markdown Install dependencies.

!pip -q install transformers \
    diffusers \
    datasets \
    accelerate \
    einops


## 2. LayerRouter

The core of MoL is *LayerRouter*, a module that decides which layer the preceding layer's activations are forwarded to next. Formally, LayerRouter is a function $f(\mathbf{x}_t, t)$ of the activations $\mathbf{x}_t$ at step $t$, given by

$$
    f(\mathbf{x}_t, t) = (g(\mathbf{x}_t, t), h(\mathbf{x}_t, t)),
$$

where $g(\mathbf{x}_t, t)$ returns a distribution over candidate layer indices and $h(\mathbf{x}_t, t)$ is an arbitrary transformation of $\mathbf{x}_t$. The next layer is the one with index $\operatorname{argmax}\, g(\mathbf{x}_t, t)$, and it receives $h(\mathbf{x}_t, t)$ as input.
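
A single routing step can be sketched as follows. This is a minimal illustration rather than the implementation used below: `g`, `h`, and `layers` are hypothetical linear stand-ins, and the conditioning on $t$ is omitted for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embedding_dimension, n_layers = 8, 3

# Hypothetical stand-ins for g, h, and the candidate layers.
g = nn.Linear(embedding_dimension, n_layers)  # Scores over layer indices.
h = nn.Linear(embedding_dimension, embedding_dimension)  # Transformation of x_t.
layers = nn.ModuleList(
    [nn.Linear(embedding_dimension, embedding_dimension) for _ in range(n_layers)]
)

x_t = torch.randn(5, embedding_dimension)  # (T, E): a single sequence.

# g(x_t): a distribution over layer indices, pooled over the sequence.
distribution = torch.softmax(g(x_t), dim=-1).mean(dim=0)  # (n_layers,)
index = int(distribution.argmax())

# h(x_t) is the input to the selected layer.
x_next = layers[index](h(x_t))  # (T, E)
```

Repeating this step with a fresh `index` each time is what lets the same layer be visited more than once, or not at all, for a given input.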

In [136]:
#@markdown Implement the router.

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class MLP(nn.Module):
    """MLP.

    Example
    -------
    >>> module = MLP(embedding_dimension=256, condition_dimension=16)
    >>> x = torch.randn((1, 10, 256))
    >>> c = torch.randn((16,))
    >>> x = module(x, c)  # Shape: (1, 10, 256).
    """

    def __init__(
        self,
        embedding_dimension: int,
        condition_dimension: int,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        embedding_dimension : int
            The embedding dimension.
        condition_dimension : int
            The condition dimension.
        """

        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(
                in_features=embedding_dimension + condition_dimension,
                out_features=embedding_dimension * 3,
            ),
            nn.GELU(),
            nn.Linear(
                in_features=embedding_dimension * 3,
                out_features=embedding_dimension,
            ),
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Forward the module.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor (B, T, E).
        c : torch.Tensor
            The condition tensor (C,), shared across the batch.

        Returns
        -------
        x : torch.Tensor
            The output tensor (B, T, E).
        """

        c = c[None, None, :].repeat((x.size(0), x.size(1), 1))  # Make `c` catable.
        x = torch.cat((x, c), dim=-1)
        x = self.layers(x)

        return x


class LayerRouter(nn.Module):
    """LayerRouter.

    Example
    -------
    >>> module = LayerRouter(
    ...     embedding_dimension=256,
    ...     steps=16,
    ...     layers=(
    ...         ...
    ...     ),
    ... )
    >>> x = ...
    >>> x = module(x)
    """

    def __init__(
        self,
        embedding_dimension: int,
        steps: int,
        layers: Tuple[nn.Module, ...],
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        embedding_dimension : int
            The embedding dimension.
        steps : int
            The number of steps.
        layers : Tuple[nn.Module, ...]
            The candidate layers to route between.
        """

        super().__init__()

        self.steps = steps
        self.layers = nn.ModuleList(layers)  # Register the layers' parameters.

        self.mlp_1 = MLP(
            embedding_dimension=embedding_dimension,
            condition_dimension=steps,
        )

        self.mlp_2 = MLP(
            embedding_dimension=embedding_dimension,
            condition_dimension=steps,
        )

        self.head = nn.Linear(
            in_features=embedding_dimension,
            out_features=len(layers),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward the module.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor (B, T, E).

        Returns
        -------
        x : torch.Tensor
            The output tensor (B, T, E).
        """

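        for step in torch.arange(self.steps, device=x.device):

            condition = F.one_hot(step, num_classes=self.steps).to(x.dtype)
            score = F.softmax(self.head(self.mlp_1(x, condition)), dim=-1)
            score = score.mean(dim=-2)  # Pool over tokens: (B, len(layers)).
            index = (score + (score.argmax(dim=-1).view(-1, 1).detach() - score)).mean(dim=-1)  # STE.
            x = self.mlp_2(x, condition)

            # Reconstruct batch with x routed to chosen layers. Stacking
            # fresh outputs avoids in-place writes that can break autograd.

            x = torch.stack(
                [
                    # Round before casting: float error can land just below
                    # the integer. The cast also detaches the index.
                    self.layers[int(index[i].round().item())](x[i, ...])
                    for i in range(x.size(0))
                ],
                dim=0,
            )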

        return x
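
Because the layer index is cast to a Python integer before dispatch, the routing choice itself does not pass gradients back to the head. One common alternative, shown here as a sketch under assumptions and not as the method above, applies the straight-through estimator to a one-hot weight vector instead. The names `head`, `layers`, and `weight` are illustrative, and the sketch evaluates every layer, so it trades compute for a differentiable routing decision.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch, tokens, embedding_dimension, n_layers = 2, 4, 8, 3

head = nn.Linear(embedding_dimension, n_layers)  # Hypothetical router head.
layers = nn.ModuleList(
    [nn.Linear(embedding_dimension, embedding_dimension) for _ in range(n_layers)]
)

x = torch.randn(batch, tokens, embedding_dimension)

# Soft routing distribution, pooled over tokens: (B, n_layers).
score = F.softmax(head(x), dim=-1).mean(dim=-2)

# Straight-through: the forward value is one-hot, the gradient is that of `score`.
hard = F.one_hot(score.argmax(dim=-1), num_classes=n_layers).to(score.dtype)
weight = hard + score - score.detach()  # (B, n_layers)

# Evaluate every layer, then keep the chosen one through the one-hot weights.
outputs = torch.stack([layer(x) for layer in layers], dim=1)  # (B, n_layers, T, E)
y = (weight[:, :, None, None] * outputs).sum(dim=1)  # (B, T, E)

# Gradients now reach the router head through `score`.
y.sum().backward()
```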