# License

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at:

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.



# Composable Function-preserving Expansions for Transformer Architectures

This notebook contains implementations of the six function-preserving transformations of transformer-based models proposed in "Composable Function-preserving Expansions for Transformer Architectures". We provide a basic implementation of a generic transformer architecture and show that each transformation is function-preserving, both for individual architectural components and the whole transformer model, as well as for individual transformations and combinations of transformations.


## Imports


In [None]:
# On Colab, we recommend a GPU or CPU runtime.

import numpy as np
import math
from typing import Any, Callable, Tuple

try:
  import jax
except ModuleNotFoundError: # Install jax if missing
  !pip install --quiet jax
  import jax

import jax.numpy as jnp
from jax import random

# Seeding for random operations
main_rng = random.PRNGKey(42)

try:
  import flax
except ModuleNotFoundError: # Install flax if missing
  !pip install --quiet flax
  import flax
from flax import linen as nn

print("Device:", jax.devices()[0])

## Name constants

In [None]:
# Names given to Flax submodules / parameters. They end up in the parameter
# tree paths, which params_pad_to_shape matches by substring to decide how
# each parameter is padded.
NAME_P = "PosE"  # positional-encoding module and its parameter
NAME_W_OUT = "W_OUT"  # NOTE(review): not referenced in the cells visible here
NAME_MLP_l1 = "MLP_l1"  # first (input) Dense layer of the MLP
NAME_MLP_l2 = "MLP_l2"  # second (output) Dense layer of the MLP
NAME_W_Oe = "W_Oe"  # prefix of the per-head output projections ("W_Oe_<i>")
NAME_W_Q = "W_Q"  # query projection of an attention head
NAME_W_K = "W_K"  # key projection of an attention head
NAME_W_V = "W_V"  # value projection of an attention head
NAME_norm1 = "norm1"  # RMSNorm applied before the MHA block
NAME_norm2 = "norm2"  # RMSNorm applied before the MLP block


## Architectural components

In [None]:
def scaled_dot_product(q, k, v, mask=None):
  """Computes scaled dot-product attention.

  Args:
    q: queries; the size of the last axis is used as the scaling dimension.
    k: keys; its last two axes are swapped before the product with `q`.
    v: values, combined using the attention weights.
    mask: optional array; positions where `mask == 0` have their logits
      replaced by -9e15 before the softmax.

  Returns:
    A `(values, attention)` tuple: the attention-weighted values and the
    softmax attention weights (normalised over the last axis).
  """
  scale = math.sqrt(q.shape[-1])
  scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / scale
  if mask is not None:
    # A very large negative logit drives the softmax weight to ~0.
    scores = jnp.where(mask == 0, -9e15, scores)
  weights = nn.softmax(scores, axis=-1)
  return jnp.matmul(weights, v), weights

In [None]:
class Attention(nn.Module):
  """Single attention head.

  Projects the input to queries, keys and values with bias-free Dense layers
  and applies scaled dot-product attention to them.
  """
  dim_k : int # key/query representation dimension
  dim_v : int # attention head representation dimension

  def setup(self):
    # Settings shared by the three projections.
    projection_kwargs = dict(kernel_init=nn.initializers.xavier_uniform(),
                             use_bias=False)
    self.W_Q = nn.Dense(self.dim_k, name=NAME_W_Q, **projection_kwargs)
    self.W_K = nn.Dense(self.dim_k, name=NAME_W_K, **projection_kwargs)
    self.W_V = nn.Dense(self.dim_v, name=NAME_W_V, **projection_kwargs)

  def __call__(self, x, mask=None):
    """Returns the `(values, attention)` pair of this head for input `x`."""
    queries = self.W_Q(x)
    keys = self.W_K(x)
    vals = self.W_V(x)
    # The softmax is applied inside scaled_dot_product.
    return scaled_dot_product(queries, keys, vals, mask=mask)


class SimpleMultiHeadAttention(nn.Module):
  """Multi-head attention written as a sum of per-head output projections.

  Each head has its own bias-free output Dense layer mapping the head values
  to the hidden dimension; the block output is the sum over heads.
  """
  dim_k : int # key/query representation dimension
  dim_v : int # attention head representation dimension
  dim_E : int # number of attention heads
  dim_h : int # transformer hidden dimension (MHA in/out dimension)

  def setup(self):
    head_ids = range(self.dim_E)
    self.heads = [Attention(dim_k=self.dim_k, dim_v=self.dim_v)
                  for _ in head_ids]
    # One output projection per head, named "<NAME_W_Oe>_<head index>".
    self.embeds = [
        nn.Dense(self.dim_h,
                 kernel_init=nn.initializers.xavier_uniform(),
                 use_bias=False,
                 name=f"{NAME_W_Oe}_{i}")
        for i in head_ids
    ]

  def __call__(self, x):
    """Returns the sum over heads of the projected head values."""
    total = 0
    for head, embed in zip(self.heads, self.embeds):
      head_values, _ = head(x)  # the attention weights are not needed here
      total = total + embed(head_values)
    return total


class SimpleMultiLayerPerceptron(nn.Module):
  """Two-layer MLP: Dense(dim_p) -> ReLU -> Dense(dim_h)."""
  dim_p : int # MLP inner dimension
  dim_h : int # transformer hidden dimension (MLP in/out dimension)

  def setup(self):
    self.layer1 = nn.Dense(features=self.dim_p,
                           name=NAME_MLP_l1,
                           kernel_init=nn.initializers.xavier_uniform())
    self.nonlinearity = nn.relu
    self.layer2 = nn.Dense(features=self.dim_h,
                           name=NAME_MLP_l2,
                           kernel_init=nn.initializers.xavier_uniform())

  def __call__(self, x):
    """Applies the first layer, the nonlinearity and the second layer."""
    inner = self.layer1(x)
    activated = self.nonlinearity(inner)
    return self.layer2(activated)


class SimpleTransformerLayer(nn.Module):
  """Transformer layer: RMSNorm + MHA and RMSNorm + MLP, each with a residual."""
  dim_k : int # key/query representation dimension
  dim_v : int # attention head representation dimension
  dim_E : int # number of attention heads
  dim_h : int # transformer hidden dimension
  dim_p : int # MLP inner dimension

  def setup(self):
    self.MHA = SimpleMultiHeadAttention(dim_k=self.dim_k,
                                        dim_v=self.dim_v,
                                        dim_E=self.dim_E,
                                        dim_h=self.dim_h)
    self.MLP = SimpleMultiLayerPerceptron(dim_p=self.dim_p,
                                          dim_h=self.dim_h)
    # norms[0] precedes the MHA block, norms[1] precedes the MLP block.
    self.norms = [nn.RMSNorm(name=norm_name)
                  for norm_name in (NAME_norm1, NAME_norm2)]

  def __call__(self, x):
    """Returns the layer output for the hidden representation `x`."""
    pre_mha_norm, pre_mlp_norm = self.norms

    # MHA block with residual connection
    x = x + self.MHA(pre_mha_norm(x))

    # MLP block with residual connection
    return x + self.MLP(pre_mlp_norm(x))


class SimplePositionalEncoding(nn.Module):
  """Adds a learned positional-encoding parameter to its input."""
  pos_init : Callable[[Any, Tuple[int], Any], Any]

  @nn.compact
  def __call__(self, inputs):
    """Returns `inputs` plus a parameter of shape (1, axis-1 size, axis-2 size)."""
    seq_len = inputs.shape[1]
    width = inputs.shape[2]
    # The leading axis of size 1 broadcasts the encoding over the first axis.
    encoding = self.param(NAME_P, self.pos_init, (1, seq_len, width))
    return inputs + encoding


class SimpleTransformer(nn.Module):
  """Positional encoding, `dim_N` transformer layers and a final Dense layer."""
  dim_k : int # key/query representation dimension
  dim_v : int # attention head representation dimension
  dim_E : int # number of attention heads
  dim_h : int # transformer hidden dimension
  dim_p : int # MLP inner dimension
  dim_N : int # number of transformer layers
  dim_hout : int # final output dimension

  def setup(self):
    self.encoding = SimplePositionalEncoding(
        name=NAME_P,
        pos_init=nn.initializers.normal(stddev=0.02))
    layer_dims = dict(dim_k=self.dim_k,
                      dim_v=self.dim_v,
                      dim_E=self.dim_E,
                      dim_h=self.dim_h,
                      dim_p=self.dim_p)
    self.layers = [SimpleTransformerLayer(**layer_dims)
                   for _ in range(self.dim_N)]
    self.outlayer = nn.Dense(self.dim_hout,
                             use_bias=False,
                             kernel_init=nn.initializers.xavier_uniform())

  def __call__(self, x):
    """Returns the model output for `x`, zero-padding `x` up to `dim_h`."""
    # Inputs narrower than the hidden dimension are zero-padded on the right
    # of the last axis only.
    missing = self.dim_h - x.shape[-1]
    if missing > 0:
      pad_widths = [(0, 0)] * (len(x.shape) - 1) + [(0, missing)]
      x = jnp.pad(x, pad_widths, mode="constant", constant_values=0)

    # Positional encoding
    x = self.encoding(x)

    # Transformer layers
    for block in self.layers:
      x = block(x)

    # Final output
    return self.outlayer(x)


## Expansion utilities

In [None]:
def keys_to_string_id(keys):
  """Joins the `.key` attribute of every path entry in `keys` with '/'."""
  return '/'.join(entry.key for entry in keys)


In [None]:
def params_to_dict(keys_params):
  """Maps each '/'-joined path id to its parameter.

  Args:
    keys_params: iterable of `(keys, params)` pairs, where `keys` is a path
      accepted by `keys_to_string_id`.

  Returns:
    A dict from string id to `params`; a later pair overrides an earlier one
    that has the same id.
  """
  return {keys_to_string_id(path): leaf for path, leaf in keys_params}

In [None]:
def params_pad_to_shape(params_source,
                        params_target,
                        function_preserving=True):
  """Expands the parameters of a source model to the shapes of a target model.

  Every leaf of `params_target` is looked up in `params_source` by its
  '/'-joined tree path. A leaf found in the source is padded at the end of
  each axis up to the target shape; a leaf missing from the source is taken
  from `params_target` (or replaced by zeros). Source leaves whose path does
  not occur in the target are ignored. Progress is printed for every leaf.

  Args:
    params_source: parameter tree of the original (smaller) model.
    params_target: parameter tree of the expanded model; it provides the
      output tree structure, the target shapes and the values of parameters
      that do not exist in the source.
    function_preserving: if True, the new entries that must be zero for the
      expansion to preserve the model function are zero-initialised. All other
      new entries of padded parameters are filled with the placeholder value
      42. The key-matrix and norm rescalings are applied regardless of this
      flag.

  Returns:
    A parameter tree with the structure of `params_target`.

  Raises:
    AssertionError: if the target has fewer leaves than the source, or if a
      matched pair of leaves differs in rank.
  """
  source_leaves, _ = jax.tree_util.tree_flatten_with_path(params_source)
  target_leaves, target_treedef = jax.tree_util.tree_flatten_with_path(
      params_target)

  assert len(target_leaves) >= len(source_leaves)

  sid2params_source = params_to_dict(source_leaves)
  padded_flat_params = []

  for target_path, target_param in target_leaves:
    string_id = keys_to_string_id(target_path)
    print(string_id, "shape:", target_param.shape)

    if string_id in sid2params_source:
      to_pad_params = sid2params_source[string_id]
      print("Found in source model, shape:",
            to_pad_params.shape, "sum:", jnp.sum(to_pad_params))
      original_shape = to_pad_params.shape
      target_shape = target_param.shape

      assert len(target_shape) == len(original_shape)
      # Per axis: nothing is added in front, the size difference at the end.
      to_pad_shape = [(0, new_size - old_size)
                      for new_size, old_size in zip(target_shape,
                                                    original_shape)]

      print("to pad shape:", to_pad_shape)

      zero_init = False
      if function_preserving:
        # MLP expansion Sec 3.1
        if NAME_MLP_l2 in string_id and to_pad_shape[0][1] > 0:
          zero_init = True
        # Heads expansion Sec 3.3
        if NAME_W_Oe in string_id and to_pad_shape[0][1] > 0:
          zero_init = True
        # Attention expansion Sec 3.4
        if NAME_W_K in string_id and to_pad_shape[1][1] > 0:
          zero_init = True
        # Hidden dimension expansion Sec 3.5
        if ((NAME_MLP_l2 in string_id and to_pad_shape[-1][1] > 0) or
            (NAME_W_Oe in string_id and to_pad_shape[1][1] > 0) or
            (NAME_P in string_id and to_pad_shape[-1][1] > 0)):
          zero_init = True

      if zero_init:
        print("0 init")
        constant_values = 0
      else:
        print("No 0 init")
        constant_values = 42 # placeholder

      # Key matrix scaling for attention expansion Sec 3.4
      if NAME_W_K in string_id and to_pad_shape[1][1] > 0:
        to_pad_params = to_pad_params * jnp.sqrt(target_shape[1]/
                                                 original_shape[1])

      # norm scaling for hidden dimension expansion Sec 3.5
      if ((NAME_norm1 in string_id or NAME_norm2 in string_id) and
              to_pad_shape[-1][1] > 0):
        to_pad_params = to_pad_params * jnp.sqrt(original_shape[0]/
                                                 target_shape[-1])

      padded_params = jnp.pad(to_pad_params, to_pad_shape,
                              'constant', constant_values=constant_values)

      print("Padded", padded_params.shape, "sum:", jnp.sum(padded_params))
      padded_flat_params.append(padded_params)
    else:
      print("Not found in source model")
      zero_init = False
      if function_preserving:
        # Head addition Sec 3.2
        if NAME_W_Oe in string_id:
          zero_init = True
        # Layer addition Sec 3.6
        if ((NAME_MLP_l2 in string_id) or
            (NAME_W_Oe in string_id)):
          zero_init = True

      if zero_init:
        print("0 init")
        padded_flat_params.append(jnp.zeros_like(target_param))
      else:
        print("No 0 init")
        padded_flat_params.append(target_param)

    print("----")
  return jax.tree_util.tree_unflatten(target_treedef, padded_flat_params)

## MultiLayer Perceptron block


#### MLP base

In [None]:
# Sizes of the base MLP and shape of the random test input.
ORIGINAL_p = 6  # MLP inner dimension
ORIGINAL_h = 5  # hidden (MLP in/out) dimension
BATCH = 3
SEQUENCE = 2

# Split off a dedicated key for the input; main_rng is carried forward.
main_rng, X_rng = random.split(main_rng)
X_MLP = random.normal(X_rng, (BATCH, SEQUENCE, ORIGINAL_h))

mlp = SimpleMultiLayerPerceptron(dim_h=ORIGINAL_h, dim_p=ORIGINAL_p)
main_rng, init_rng = random.split(main_rng)
params = mlp.init(init_rng, X_MLP)['params']
# Reference output that the expanded MLP is compared against.
O_MLP = mlp.apply({'params': params}, X_MLP)

#### MLP expansion (Section 3.1)

In [None]:
# Grow the MLP inner dimension by 3; the hidden dimension is unchanged.
MOD_p = ORIGINAL_p + 3
MOD_mlp = SimpleMultiLayerPerceptron(dim_h=ORIGINAL_h, dim_p=MOD_p)
main_rng, init_rng = random.split(main_rng)
# A fresh initialisation provides the target tree structure and shapes.
MOD_params_init = MOD_mlp.init(init_rng, X_MLP)['params']

# Pad the base parameters up to the expanded shapes.
MOD_params = params_pad_to_shape(params, MOD_params_init)

MOD_O_MLP = MOD_mlp.apply({'params': MOD_params}, X_MLP)

print("Original:",  jnp.sum(O_MLP), "Expanded MLP:", jnp.sum(MOD_O_MLP))
# The expanded MLP must reproduce the output of the base MLP.
assert jnp.allclose(MOD_O_MLP, O_MLP, rtol = 1e-3)

## MultiHead Attention block

#### MHA base

In [None]:
# Sizes of the base MHA block and shape of the random test input.
ORIGINAL_k = 6  # key/query representation dimension
ORIGINAL_v = 6  # attention head representation dimension
ORIGINAL_h = 5  # hidden (MHA in/out) dimension
ORIGINAL_E = 2  # number of attention heads
BATCH = 3
SEQUENCE = 2

# Split off a dedicated key for the input; main_rng is carried forward.
main_rng, I_rng = random.split(main_rng)
X_MHA = random.normal(I_rng, (BATCH, SEQUENCE, ORIGINAL_h))

mha = SimpleMultiHeadAttention(dim_k=ORIGINAL_k,
                                dim_v=ORIGINAL_v,
                                dim_h=ORIGINAL_h,
                                dim_E=ORIGINAL_E)
main_rng, init_rng = random.split(main_rng)
# Note: `params` now refers to the MHA parameters, replacing the MLP ones.
params = mha.init(init_rng, X_MHA)['params']
# Reference output that the expanded MHA blocks are compared against.
O_MHA = mha.apply({'params': params}, X_MHA)

#### Head addition on MHA (Section 3.2)

In [None]:
# Add 3 attention heads; all other dimensions are unchanged.
MOD_E = ORIGINAL_E + 3
MOD_mha = SimpleMultiHeadAttention(
    dim_k=ORIGINAL_k,
    dim_v=ORIGINAL_v,
    dim_h=ORIGINAL_h,
    dim_E=MOD_E)
main_rng, init_rng = random.split(main_rng)
# A fresh initialisation provides the target tree structure and the values of
# parameters that do not exist in the base model.
MOD_params_init = MOD_mha.init(init_rng, X_MHA)['params']

MOD_params = params_pad_to_shape(params, MOD_params_init)

MOD_O_MHA = MOD_mha.apply({'params': MOD_params}, X_MHA)

print("Original:",  jnp.sum(O_MHA), "Added heads:", jnp.sum(MOD_O_MHA))
# The block with added heads must reproduce the output of the base block.
assert jnp.allclose(MOD_O_MHA, O_MHA, rtol = 1e-3)

#### Heads expansion on MHA (Section 3.3)

In [None]:
# Grow the per-head representation dimension by 3; the rest is unchanged.
MOD_v = ORIGINAL_v + 3
MOD_mha = SimpleMultiHeadAttention(
    dim_k=ORIGINAL_k,
    dim_v=MOD_v,
    dim_h=ORIGINAL_h,
    dim_E=ORIGINAL_E)
main_rng, init_rng = random.split(main_rng)
# A fresh initialisation provides the target tree structure and shapes.
MOD_params_init = MOD_mha.init(init_rng, X_MHA)['params']

MOD_params = params_pad_to_shape(params, MOD_params_init)

MOD_O_MHA = MOD_mha.apply({'params': MOD_params}, X_MHA)

print("Original:",  jnp.sum(O_MHA), "Expanded heads:", jnp.sum(MOD_O_MHA))
# The block with expanded heads must reproduce the output of the base block.
assert jnp.allclose(MOD_O_MHA, O_MHA, rtol = 1e-3)

#### Attention expansion on MHA (Section 3.4)

In [None]:
# Grow the key/query representation dimension by 3; the rest is unchanged.
MOD_k = ORIGINAL_k + 3
MOD_mha = SimpleMultiHeadAttention(
    dim_k=MOD_k,
    dim_v=ORIGINAL_v,
    dim_h=ORIGINAL_h,
    dim_E=ORIGINAL_E)
main_rng, init_rng = random.split(main_rng)
# A fresh initialisation provides the target tree structure and shapes.
MOD_params_init = MOD_mha.init(init_rng, X_MHA)['params']

MOD_params = params_pad_to_shape(params, MOD_params_init)

MOD_O_MHA = MOD_mha.apply({'params': MOD_params}, X_MHA)

print("Original:",  jnp.sum(O_MHA),
      "Expanded key/query representation:", jnp.sum(MOD_O_MHA))
# The expanded block must reproduce the output of the base block.
assert jnp.allclose(MOD_O_MHA, O_MHA, rtol = 1e-3)

#### MHA Combination

In [None]:
# Apply head addition, heads expansion and attention expansion together,
# reusing the MOD_k, MOD_v and MOD_E values set by the previous cells.
MOD_mha = SimpleMultiHeadAttention(
    dim_k=MOD_k,
    dim_v=MOD_v,
    dim_h=ORIGINAL_h,
    dim_E=MOD_E)
main_rng, init_rng = random.split(main_rng)
# A fresh initialisation provides the target tree structure and the values of
# parameters that do not exist in the base model.
MOD_params_init = MOD_mha.init(init_rng, X_MHA)['params']

MOD_params = params_pad_to_shape(params, MOD_params_init)

MOD_O_MHA = MOD_mha.apply({'params': MOD_params}, X_MHA)

# The label names the combined transformation (it previously repeated the
# attention-expansion label of the preceding cell).
print("Original:",  jnp.sum(O_MHA),
      "Combined MHA expansions:", jnp.sum(MOD_O_MHA))
# The combined expansion must reproduce the output of the base block.
assert jnp.allclose(MOD_O_MHA, O_MHA, rtol = 1e-3)

## Transformer layer

#### Layer base

In [None]:
# Sizes of the base transformer layer and shape of the random test input.
ORIGINAL_k = 6  # key/query representation dimension
ORIGINAL_v = 6  # attention head representation dimension
ORIGINAL_h = 5  # hidden dimension
ORIGINAL_E = 2  # number of attention heads
ORIGINAL_p = 4  # MLP inner dimension
BATCH = 3
SEQUENCE = 2

# Split off a dedicated key for the input; main_rng is carried forward.
main_rng, I_rng = random.split(main_rng)
I_layer = random.normal(I_rng, (BATCH, SEQUENCE, ORIGINAL_h))

layer = SimpleTransformerLayer(dim_k=ORIGINAL_k,
                               dim_v=ORIGINAL_v,
                               dim_h=ORIGINAL_h,
                               dim_E=ORIGINAL_E,
                               dim_p=ORIGINAL_p)
main_rng, init_rng = random.split(main_rng)
# Note: `params` now refers to the layer parameters.
params = layer.init(init_rng, I_layer)['params']
# Reference output that the expanded layer is compared against.
O_layer = layer.apply({'params': params}, I_layer)

#### Hidden Dimension Expansion (Section 3.5)




In [None]:
# Grow the hidden dimension by 3; all other dimensions are unchanged.
MOD_h = ORIGINAL_h + 3
# Zero-pad the input and the reference output along the last axis so that
# they match the width of the expanded layer.
MOD_I_layer = jnp.pad(I_layer, ((0, 0), (0, 0), (0, MOD_h-ORIGINAL_h)),
                      'constant', constant_values=0)
O_layer_padded = jnp.pad(O_layer, ((0, 0), (0, 0), (0, MOD_h-ORIGINAL_h)),
                         'constant', constant_values=0)
MOD_layer = SimpleTransformerLayer(dim_h=MOD_h,
                                   dim_E=ORIGINAL_E,
                                   dim_p=ORIGINAL_p,
                                   dim_k=ORIGINAL_k,
                                   dim_v=ORIGINAL_v)
# Use a fresh key: the previous init_rng was already consumed by the
# initialisation of the base layer.
main_rng, init_rng = random.split(main_rng)
MOD_params_init = MOD_layer.init(init_rng, MOD_I_layer)['params']
MOD_params = params_pad_to_shape(params, MOD_params_init)
MOD_O_layer = MOD_layer.apply({'params': MOD_params}, MOD_I_layer)

print("Original:",  jnp.sum(O_layer),
      "Expanded hidden dimension:", jnp.sum(MOD_O_layer))
# The expanded layer must reproduce the zero-padded output of the base layer.
assert jnp.allclose(MOD_O_layer, O_layer_padded, rtol = 1e-3)


## Full model

#### Transformer base

In [None]:
# Sizes of the base transformer model and shape of the random test input.
ORIGINAL_k = 6  # key/query representation dimension
ORIGINAL_v = 6  # attention head representation dimension
ORIGINAL_h = 5  # hidden dimension
ORIGINAL_E = 2  # number of attention heads
ORIGINAL_p = 4  # MLP inner dimension
ORIGINAL_N = 2  # number of transformer layers
BATCH = 3
SEQUENCE = 2
HOUT = 2  # final output dimension

# Input representation
main_rng, I_rng = random.split(main_rng)
I = random.normal(I_rng, (BATCH, SEQUENCE, ORIGINAL_h))

# Full transformer model
model = SimpleTransformer(dim_k=ORIGINAL_k,
                          dim_v=ORIGINAL_v,
                          dim_h=ORIGINAL_h,
                          dim_E=ORIGINAL_E,
                          dim_p=ORIGINAL_p,
                          dim_N=ORIGINAL_N,
                          dim_hout=HOUT)
main_rng, init_rng = random.split(main_rng)
# Note: `params` now refers to the full-model parameters.
params = model.init(init_rng, I)['params']
# Reference output that the expanded models are compared against.
O = model.apply({'params': params}, I)

#### Layer Addition (Section 3.6)



In [None]:
# Add 3 transformer layers; all other dimensions are unchanged.
MOD_N = ORIGINAL_N + 3
MOD_model = SimpleTransformer(dim_h=ORIGINAL_h,
                              dim_E=ORIGINAL_E,
                              dim_p=ORIGINAL_p,
                              dim_k=ORIGINAL_k,
                              dim_v=ORIGINAL_v,
                              dim_N=MOD_N,
                              dim_hout=HOUT)
# Use a fresh key: the previous init_rng was already consumed by the
# initialisation of the base model.
main_rng, init_rng = random.split(main_rng)
# The fresh initialisation provides the target tree structure and the values
# of the parameters of the added layers.
MOD_params_init = MOD_model.init(init_rng, I)['params']
MOD_params = params_pad_to_shape(params, MOD_params_init)
MOD_O = MOD_model.apply({'params': MOD_params}, I)

print("Original:",  jnp.sum(O), "Added layers:", jnp.sum(MOD_O))
# The deeper model must reproduce the output of the base model.
assert jnp.allclose(MOD_O, O, rtol = 1e-3)


#### Hidden Dimension Expansion (Section 3.5) on full model

In [None]:
# Expand only the hidden dimension, reusing MOD_h from the layer-level cell.
MOD_model = SimpleTransformer(dim_h=MOD_h,
                              dim_E=ORIGINAL_E,
                              dim_p=ORIGINAL_p,
                              dim_k=ORIGINAL_k,
                              dim_v=ORIGINAL_v,
                              dim_N=ORIGINAL_N,
                              dim_hout=HOUT)
# Use a fresh key: the previous init_rng was already consumed by an earlier
# initialisation.
main_rng, init_rng = random.split(main_rng)
# The input is passed unpadded: the model zero-pads inputs narrower than
# dim_h itself.
MOD_params_init = MOD_model.init(init_rng, I)['params']
MOD_params = params_pad_to_shape(params, MOD_params_init)
MOD_O = MOD_model.apply({'params': MOD_params}, I)

print("Original:",  jnp.sum(O), "Expanded hidden dimension:", jnp.sum(MOD_O))
# The wider model must reproduce the output of the base model.
assert jnp.allclose(MOD_O, O, rtol = 1e-3)

#### All transformations on full model

In [None]:
# Apply all six transformations at once, reusing the MOD_* values set by the
# previous cells.
MOD_model = SimpleTransformer(dim_h=MOD_h,
                              dim_E=MOD_E,
                              dim_p=MOD_p,
                              dim_k=MOD_k,
                              dim_v=MOD_v,
                              dim_N=MOD_N,
                              dim_hout=HOUT)
# Use a fresh key: the previous init_rng was already consumed by an earlier
# initialisation.
main_rng, init_rng = random.split(main_rng)
# The fresh initialisation provides the target tree structure and the values
# of parameters that do not exist in the base model.
MOD_params_init = MOD_model.init(init_rng, I)['params']
MOD_params = params_pad_to_shape(params, MOD_params_init)
MOD_O = MOD_model.apply({'params': MOD_params}, I)

# The label names the combined transformation (it previously repeated the
# hidden-dimension label of the preceding cell).
print("Original:",  jnp.sum(O), "All transformations:", jnp.sum(MOD_O))
# The fully expanded model must reproduce the output of the base model.
assert jnp.allclose(MOD_O, O, rtol = 1e-3)