In [None]:
# Sparsely-Gated Mixture-of-Experts Layers.
# See "Outrageously Large Neural Networks"
# https://arxiv.org/abs/1701.06538
#
# Author: David Rau
#
# The code is based on the TensorFlow implementation:
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py


import numpy as np
import torch
import torch.nn as nn
from torch.distributions.normal import Normal


class MLP(nn.Module):
    """Single-hidden-layer feed-forward network that emits log-probabilities.

    The input is mapped through ``fc1`` -> ReLU -> ``fc2`` and then normalised
    with a log-softmax over dimension 1.

    Args:
        input_size: number of input features.
        output_size: number of output features (size of the log-softmax axis).
        hidden_size: width of the hidden layer.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        # Registration order (fc1, fc2, relu, log_soft) is kept so that
        # parameter initialisation order and state_dict keys do not change.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.relu = nn.ReLU()
        self.log_soft = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return ``log_softmax(fc2(relu(fc1(x))))`` taken over dimension 1."""
        hidden = self.relu(self.fc1(x))
        return self.log_soft(self.fc2(hidden))


class SparseDispatcher(object):
    """Helper for implementing a mixture of experts.

    The purpose of this class is to create input minibatches for the
    experts and to combine the results of the experts to form a unified
    output tensor.

    There are two main functions:
      dispatch - take an input tensor and create input tensors for each expert.
      combine - take output tensors from each expert and form a combined output
        tensor.  Outputs from different experts for the same batch element are
        summed together, weighted by the provided "gates".

    The class is initialized with a "gates" tensor, which specifies which
    batch elements go to which experts, and the weights to use when combining
    the outputs.  Batch element b is sent to expert e iff gates[b, e] != 0.
    Gates are expected to be non-negative: the per-expert batch sizes are
    counted with ``gates > 0``.

    The inputs and outputs are all two-dimensional [batch, depth].
    The caller is responsible for collapsing additional dimensions prior to
    calling this class and for reshaping the output to the original shape.

    Example use:
      gates: a float32 tensor with shape `[batch_size, num_experts]`
      inputs: a float32 tensor with shape `[batch_size, input_size]`
      experts: a list of length `num_experts` containing sub-networks.

      dispatcher = SparseDispatcher(num_experts, gates)
      expert_inputs = dispatcher.dispatch(inputs)
      expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
      outputs = dispatcher.combine(expert_outputs)

    `combine` treats the expert outputs as log-space values, so the preceding
    code sets the output for a particular example b to:
      output[b] = log(Sum_i(gates[b, i] * exp(experts[i](inputs[b]))))

    This class takes advantage of sparsity in the gate matrix by including in the
    tensors for expert i only the batch elements for which `gates[b, i] > 0`.
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher.

        Args:
          num_experts: number of experts (columns of `gates`).
          gates: a tensor of shape `[batch_size, num_experts]`.
        """

        self._gates = gates
        self._num_experts = num_experts
        # torch.nonzero yields one (batch, expert) row per non-zero gate.  Each
        # column is sorted independently: the expert column becomes ordered by
        # expert id, and the returned indices say which nonzero row every sorted
        # position came from.
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # keep only the (sorted) expert column, shape [num_nonzero, 1]
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # Batch index belonging to each sorted expert entry.  The batch column of
        # torch.nonzero is already non-decreasing, so indexing its sorted copy
        # with the expert permutation recovers the matching batch ids.
        self._batch_index = sorted_experts[index_sorted_experts[:, 1], 0]
        # Number of samples each expert receives, as plain Python ints.
        # .tolist() (unlike .numpy()) also works when the gates live on a GPU.
        self._part_sizes = (gates > 0).sum(0).tolist()
        # expand gates to match with self._batch_index
        gates_exp = gates[self._batch_index.flatten()]
        # gate value of every (batch, expert) pair, shape [num_nonzero, 1]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input tensor for each expert.

        The tensor for expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] > 0`.

        Args:
          inp: a tensor of shape `[batch_size, <extra_input_dims>]`
        Returns:
          a tuple of `num_experts` tensors with shapes
            `[expert_batch_size_i, <extra_input_dims>]`.
        """

        # Expand according to the batch index so the rows are grouped by expert
        # and can simply be split by _part_sizes.
        # NOTE(review): squeeze(1) is a no-op unless dimension 1 has size 1;
        # kept as in the original implementation.
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Sum together the expert output, weighted by the gates.

        The expert outputs are exponentiated (they are treated as log-space
        values), summed per batch element `b` over all experts `i`, weighted by
        the corresponding gate values, and mapped back to log space.  If
        `multiply_by_gates` is set to False, the gate values are ignored.

        Args:
          expert_out: a sequence of `num_experts` tensors, each with shape
            `[expert_batch_size_i, <extra_output_dims>]`.
          multiply_by_gates: a boolean
        Returns:
          a tensor with shape `[batch_size, <extra_output_dims>]`.
        """
        # apply exp to expert outputs, so we are no longer in log space
        stitched = torch.cat(expert_out, 0).exp()
        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates)
        # Allocate the accumulator next to the expert outputs so this also works
        # when they are not on the CPU.
        zeros = torch.zeros(
            self._gates.size(0),
            expert_out[-1].size(1),
            requires_grad=True,
            device=stitched.device,
        )
        # combine samples that have been processed by the same k experts
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        # add eps to all zero values in order to avoid nans when going back to log space
        combined[combined == 0] = np.finfo(float).eps
        # back to log space
        return combined.log()

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert tensors.

        Returns:
          a tuple of `num_experts` tensors, one per expert, each with shape
              `[expert_batch_size_i, 1]`
        """
        # split nonzero gates for each expert
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)


class MoE(nn.Module):

    """Noisy top-k gating of a sparsely-gated mixture of experts.

    This module owns only the gating parameters; it does not create expert
    networks.  `forward` computes the gates from `query` and combines the
    expert outputs supplied by the caller through `key`.

    Args:
    input_size: integer - size of the input
    output_size: integer - size of the output (stored, not used by the gating)
    num_experts: an integer - number of experts
    noisy_gating: a boolean - whether noise is added to the gate logits
      during training
    k: an integer - how many experts to use for each batch element
    """

    def __init__(self, input_size, output_size, num_experts, noisy_gating=True, k=5):
        super(MoE, self).__init__()
        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.output_size = output_size
        self.input_size = input_size
        self.k = k
        # Both projections start at zero, so all experts initially get equal logits.
        self.w_gate = nn.Parameter(
            torch.zeros(input_size, num_experts), requires_grad=True
        )
        self.w_noise = nn.Parameter(
            torch.zeros(input_size, num_experts), requires_grad=True
        )

        self.softplus = nn.Softplus()
        # Normalise over the selected experts (last axis; axis 1 for 2-D input).
        self.softmax = nn.Softmax(-1)
        # Standard normal used for the load estimate.  It is a plain attribute
        # and therefore does not follow Module.to(); see _prob_in_top_k.
        self.normal = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
        assert self.k <= self.num_experts

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.
        Useful as a loss to encourage a positive distribution to be more uniform.
        Epsilons added for numerical stability.
        Returns 0 (as a one-element tensor) when `x` has fewer than two entries
        along its first dimension, where the variance is undefined.
        Args:
        x: a `Tensor`.
        Returns:
        a `Scalar`.
        """
        eps = 1e-10
        # a single expert (or an empty tensor) has no spread to penalise
        if x.shape[0] <= 1:
            return torch.zeros(1, device=x.device)
        values = x.float()
        return values.var() / (values.mean() ** 2 + eps)

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.
        The load is the number of examples for which the corresponding gate is >0.
        Args:
        gates: a `Tensor` of shape [batch_size, n]; extra leading dimensions
          are counted as additional batch elements.
        Returns:
        an integer `Tensor` of shape [n]
        """
        return (gates > 0).reshape(-1, gates.size(-1)).sum(0)

    def _prob_in_top_k(
        self, clean_values, noisy_values, noise_stddev, noisy_top_values
    ):
        """Helper function to NoisyTopKGating.
        Computes the probability that value is in top k, given different random noise.
        This gives us a way of backpropagating from a loss that balances the number
        of times each expert is in the top k experts per example.
        Args:
        clean_values: a `Tensor` of shape [batch, n].
        noisy_values: a `Tensor` of shape [batch, n].  Equal to clean values plus
          normally distributed noise with standard deviation noise_stddev.
        noise_stddev: a `Tensor` of shape [batch, n] with strictly positive entries.
        noisy_top_values: a `Tensor` of shape [batch, m] holding the m largest
          noisy values per row in descending order.  m >= k+1
        Returns:
        a `Tensor` of shape [batch, n].
        """
        k = self.k
        # An expert that is currently selected stays selected if it beats the
        # (k+1)-th largest noisy value; one that is not selected has to beat
        # the k-th largest.
        threshold_if_in = noisy_top_values[..., k : k + 1]
        threshold_if_out = noisy_top_values[..., k - 1 : k]
        # is each value currently in the top k.
        is_in = torch.gt(noisy_values, threshold_if_in)

        normal = self.normal
        if normal.loc.device != clean_values.device:
            # self.normal stays where it was created; rebuild it next to the logits.
            normal = Normal(clean_values.new_zeros(1), clean_values.new_ones(1))
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        return torch.where(is_in, prob_if_in, prob_if_out)

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Noisy top-k gating.
        See paper: https://arxiv.org/abs/1701.06538.
        Noise is only added when `self.noisy_gating` is set and `train` is true;
        otherwise the clean logits are used and the load is the hard count of
        examples routed to each expert.
        Args:
          x: input Tensor with shape [batch_size, input_size]; extra leading
            dimensions are treated as additional batch dimensions.
          train: a boolean - we only add noise at training time.
          noise_epsilon: a float - lower bound of the noise standard deviation
        Returns:
          gates: a Tensor with shape [batch_size, num_experts]
          load: a Tensor with shape [num_experts]
        """
        clean_logits = x @ self.w_gate
        use_noise = self.noisy_gating and bool(train)
        if use_noise:
            raw_noise_stddev = x @ self.w_noise
            # softplus keeps the learned standard deviation positive
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + (
                torch.randn_like(clean_logits) * noise_stddev
            )
            logits = noisy_logits
        else:
            logits = clean_logits

        # calculate topk + 1 that will be needed for the noisy gates
        num_top = min(self.k + 1, self.num_experts)
        top_logits, top_indices = logits.topk(num_top, dim=-1)
        top_k_logits = top_logits[..., : self.k]
        top_k_indices = top_indices[..., : self.k]
        top_k_gates = self.softmax(top_k_logits)

        # scatter the k gate values back into a dense [.., num_experts] tensor
        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(-1, top_k_indices, top_k_gates)
        if use_noise and self.k < self.num_experts:
            # smooth, differentiable estimate of how often each expert is chosen
            prob = self._prob_in_top_k(
                clean_logits, noisy_logits, noise_stddev, top_logits
            )
            load = prob.reshape(-1, self.num_experts).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def forward(self, query, key, train=True, loss_coef=1e-2):
        """Args:
        query: tensor shape [batch_size, input_size] - used to compute the gates
        key: the expert outputs, passed unchanged to `SparseDispatcher.combine`
          (a sequence with one log-space tensor per expert).
          NOTE(review): the rows each expert must provide depend on the routing
          computed inside this call - confirm how callers build `key`.
        train: a boolean scalar.
        loss_coef: a scalar - multiplier on load-balancing losses
        Returns:
        y: the combined expert outputs returned by `SparseDispatcher.combine`.
        extra_training_loss: a scalar.  This should be added into the overall
        training loss of the model.  The backpropagation of this loss
        encourages all experts to be approximately equally used across a batch.
        """
        gates, load = self.noisy_top_k_gating(query, train)
        # calculate importance loss
        importance = gates.sum(0)
        # balance both the summed gate values and the per-expert load
        loss = self.cv_squared(importance) + self.cv_squared(load)
        loss = loss * loss_coef

        dispatcher = SparseDispatcher(self.num_experts, gates)
        y = dispatcher.combine(key)
        return y, loss

In [None]:
import torch

# Three independent [100, 32] standard-normal tensors, drawn in the order x_1, x_2, x_3.
x_1, x_2, x_3 = (torch.randn(100, 32) for _ in range(3))

In [None]:
# Stack the three [100, 32] tensors along a new leading axis -> [3, 100, 32].
x = torch.stack([x_1, x_2, x_3])
# Display the stacked shape.
x.size()

In [None]:
# instantiate the MoE layer: 10 experts, the top 5 are selected per example,
# with noisy gating enabled
moe = MoE(
    input_size=768,
    output_size=768,
    num_experts=10,
    k=5,
    noisy_gating=True,
)

# forward
# NOTE(review): call left disabled; x_1 as defined in the earlier cell has
# last dimension 32, which does not match input_size=768.
# y_hat, aux_loss = moe(x_1, x)

In [None]:
# Random query of shape [20, 12, 1, 768] (singleton axis inserted at position 2).
x_1 = torch.randn(20, 12, 768).unsqueeze(2)
# Random tensor of shape [20, 12, 10, 768]; not used further in this cell.
value = torch.randn(20, 12, 10, 768)
# Run only the gating step with train=True.
# NOTE(review): x_1 is 4-D here - confirm the gating treats the leading
# dimensions as intended.
gates, load = moe.noisy_top_k_gating(x_1, True)

In [None]:
# Shape of the gates after removing axis 2 (squeeze only drops it if its size is 1).
gates.squeeze(2).size()

In [1]:
import torch
from transformers import AdapterFusionConfig, AdapterType, AutoModel, BertTokenizer

In [2]:
# Load the pretrained BERT encoder.
base_model = AutoModel.from_pretrained("bert-base-uncased")
# Register 20 text-task adapters named adapter_0 ... adapter_19.
adapter_names = [f"adapter_{i}" for i in range(20)]
for adapter_name in adapter_names:
    base_model.add_adapter(adapter_name, adapter_type=AdapterType.text_task)
# NOTE(review): the meaning of the "dynamic" fusion config and of the negative
# temperature is defined by the installed transformers/adapters build - confirm.
fusion_config = AdapterFusionConfig.load("dynamic", temperature=-5)
# Activate all adapters and add a fusion layer over them.
base_model.set_active_adapters(adapter_names)
base_model.add_fusion(adapter_names,fusion_config)

In [3]:
# `adapter_name` still holds the last value of the loop in the setup cell
# ("adapter_19"); show the configuration stored for that adapter.
base_model.config.adapters.get(adapter_name)

PfeifferConfig(original_ln_before=True, original_ln_after=True, residual_before_ln=True, adapter_residual_before_ln=False, ln_before=False, ln_after=False, mh_adapter=False, output_adapter=True, non_linearity='relu', reduction_factor=16, invertible_adapter=InvertibleAdapterConfig(block_type='nice', non_linearity='relu', reduction_factor=2), leave_out=[])

In [4]:
# load pre-trained BERT tokenizer from Huggingface
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# tokenize an input sentence
sentence = "It's also, clearly, great fun."

# convert input tokens to indices once, then repeat the same row 20 times to
# build a batch (encoding inside the loop would redo identical work per row)
token_ids = tokenizer.encode(sentence)
input_tensor = torch.tensor([token_ids] * 20)
# display the batch shape
input_tensor.size()

torch.Size([20, 12])

In [None]:
# Read the "temperature" entry of the fusion config and flip its sign.
base_model.config.adapter_fusion["temperature"]*-1

In [5]:
# Forward pass of the adapter-equipped model on the token-id batch.
outputs = base_model(input_tensor)

In [6]:
# Display the raw model outputs (the recorded output below is a tuple of tensors).
outputs

(tensor([[[ 0.1207,  0.1630, -0.1404,  ..., -0.3977,  0.1520,  0.5696],
          [ 0.4179, -0.0688, -0.2795,  ..., -0.3041,  0.3532,  0.3990],
          [ 0.8635, -0.1008,  0.5260,  ..., -0.8003,  0.5659,  0.0475],
          ...,
          [ 0.6075,  0.2061,  0.5748,  ..., -0.2470, -0.2852, -0.3002],
          [ 0.8881,  0.3087,  0.0201,  ...,  0.0293, -0.5853, -0.3446],
          [ 0.1115,  0.4390,  0.0401,  ..., -0.1743,  0.3040, -0.0831]],
 
         [[ 0.1363,  0.1695, -0.1281,  ..., -0.3921,  0.1543,  0.5669],
          [ 0.3726, -0.0795, -0.2825,  ..., -0.2572,  0.3715,  0.3324],
          [ 0.8525, -0.0550,  0.4882,  ..., -0.7421,  0.5554,  0.0065],
          ...,
          [ 0.6609,  0.1591,  0.5604,  ..., -0.2804, -0.1932, -0.2588],
          [ 0.8860,  0.3071,  0.0184,  ...,  0.0306, -0.5896, -0.3420],
          [ 0.1699,  0.4331,  0.0997,  ..., -0.1782,  0.3389, -0.0417]],
 
         [[ 0.1029,  0.1710, -0.1524,  ..., -0.4141,  0.1474,  0.5850],
          [ 0.4229, -0.0842,

In [None]:
# Re-run the forward pass (same call as the earlier cell) and overwrite `outputs`.
outputs = base_model(input_tensor)

In [7]:
# Shape of the first output tensor; the recorded result is [20, 12, 768].
outputs[0].size()

torch.Size([20, 12, 768])

In [None]:
# Shape after inserting a singleton axis at position 2.
outputs[0].unsqueeze(2).size()

In [None]:
# Shape after swapping the last two axes.
outputs[0].transpose(-2, -1).size()

In [None]:
# Batched product of the first example (kept as a batch of one) with every
# example's transpose, then squeeze(dim=2) and show the shape.
# NOTE(review): assuming outputs[0] is [20, 12, 768] as recorded elsewhere in
# this notebook, the product is [20, 12, 12] and the squeeze is a no-op - confirm.
torch.squeeze(torch.matmul(outputs[0][:1], outputs[0].transpose(-2, -1)), dim=2).size()

In [None]:
# Compare the full batch shape with the shape of the one-example slice.
outputs[0].size(),outputs[0][:1].size()

In [None]:
# Shape after swapping the last two axes (repeat of an earlier cell).
outputs[0].transpose(-2, -1).size()

In [19]:
# Short alias for the first output tensor.
a= outputs[0]

In [21]:
from torch.distributions.normal import Normal

# Standard-normal CDF applied element-wise to `a`.  The distribution parameters
# are created on `a`'s own device instead of a hard-coded "cuda", so the cell
# also runs when `a` lives on the CPU or no GPU is available.
Normal(torch.zeros(1, device=a.device), torch.ones(1, device=a.device)).cdf(a)
            

tensor([[[0.5480, 0.5647, 0.4442,  ..., 0.3454, 0.5604, 0.7155],
         [0.6620, 0.4726, 0.3899,  ..., 0.3805, 0.6380, 0.6551],
         [0.8061, 0.4599, 0.7005,  ..., 0.2118, 0.7143, 0.5190],
         ...,
         [0.7282, 0.5816, 0.7173,  ..., 0.4025, 0.3877, 0.3820],
         [0.8127, 0.6212, 0.5080,  ..., 0.5117, 0.2792, 0.3652],
         [0.5444, 0.6697, 0.5160,  ..., 0.4308, 0.6194, 0.4669]],

        [[0.5542, 0.5673, 0.4490,  ..., 0.3475, 0.5613, 0.7146],
         [0.6453, 0.4683, 0.3888,  ..., 0.3985, 0.6449, 0.6302],
         [0.8030, 0.4781, 0.6873,  ..., 0.2290, 0.7107, 0.5026],
         ...,
         [0.7457, 0.5632, 0.7124,  ..., 0.3896, 0.4234, 0.3979],
         [0.8122, 0.6206, 0.5074,  ..., 0.5122, 0.2777, 0.3662],
         [0.5675, 0.6675, 0.5397,  ..., 0.4293, 0.6326, 0.4834]],

        [[0.5410, 0.5679, 0.4394,  ..., 0.3394, 0.5586, 0.7207],
         [0.6638, 0.4665, 0.3809,  ..., 0.3896, 0.6417, 0.6490],
         [0.7966, 0.4778, 0.7068,  ..., 0.2365, 0.7091, 0.

In [22]:
# Display the first output tensor.
outputs[0]

tensor([[[ 0.1207,  0.1630, -0.1404,  ..., -0.3977,  0.1520,  0.5696],
         [ 0.4179, -0.0688, -0.2795,  ..., -0.3041,  0.3532,  0.3990],
         [ 0.8635, -0.1008,  0.5260,  ..., -0.8003,  0.5659,  0.0475],
         ...,
         [ 0.6075,  0.2061,  0.5748,  ..., -0.2470, -0.2852, -0.3002],
         [ 0.8881,  0.3087,  0.0201,  ...,  0.0293, -0.5853, -0.3446],
         [ 0.1115,  0.4390,  0.0401,  ..., -0.1743,  0.3040, -0.0831]],

        [[ 0.1363,  0.1695, -0.1281,  ..., -0.3921,  0.1543,  0.5669],
         [ 0.3726, -0.0795, -0.2825,  ..., -0.2572,  0.3715,  0.3324],
         [ 0.8525, -0.0550,  0.4882,  ..., -0.7421,  0.5554,  0.0065],
         ...,
         [ 0.6609,  0.1591,  0.5604,  ..., -0.2804, -0.1932, -0.2588],
         [ 0.8860,  0.3071,  0.0184,  ...,  0.0306, -0.5896, -0.3420],
         [ 0.1699,  0.4331,  0.0997,  ..., -0.1782,  0.3389, -0.0417]],

        [[ 0.1029,  0.1710, -0.1524,  ..., -0.4141,  0.1474,  0.5850],
         [ 0.4229, -0.0842, -0.3032,  ..., -0