## Getting Started with JAX + Flax


<div class="alert alert-info">

<b>Summary</b>
    
We build a basic Transformer layer using regular JAX and Flax modules. This will be our baseline 
for later comparisons with Transformer Engine.

</div>

We will construct the components as follows:

- `LayerNorm`: `flax.linen.LayerNorm`
- `QKV Projection`: `flax.linen.DenseGeneral` (conceptually three `Dense` layers for Q, K, and V 
  separately, but we fuse into a single `Linear` layer that is three times larger)
- `DotProductAttention`: `DotProductAttention` from 
  [quickstart_jax_utils.py](quickstart_jax_utils.py)
- `Projection`: `flax.linen.DenseGeneral`
- `Dropout`: `flax.linen.Dropout`
- `MLP`: `BasicMLP` from [quickstart_jax_utils.py](quickstart_jax_utils.py)

Putting it all together:

In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import quickstart_jax_utils as utils
from typing import Optional
from functools import partial
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

class BasicTransformerLayer(nn.Module):
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int
    layernorm_eps: Optional[float] = 1e-5
    attention_dropout: Optional[float] = 0.1
    hidden_dropout: Optional[float] = 0.1
    dtype: Optional[type] = jnp.float16
    
    def setup(self):
        # Safeguard layer-norm epsilon with machine epsilon
        typed_perturb = self.dtype(self.layernorm_eps)
        typed_macheps = self.dtype(jnp.finfo(self.dtype).eps)
        epsilon = jax.lax.max(typed_perturb, typed_macheps)
        
        self.ln1 = nn.LayerNorm(float(epsilon), dtype=self.dtype)
        self.qkv_projection = nn.DenseGeneral(3 * self.hidden_size, dtype=self.dtype)
        self.kv_channels = self.hidden_size // self.num_attention_heads
        self.attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            dropout_rate=self.attention_dropout,
            dropout_rng='attention',
            dtype=self.dtype
        )
        self.projection = nn.DenseGeneral(self.hidden_size, dtype=self.dtype)
        self.dropout = nn.Dropout(self.hidden_dropout, rng_collection='hidden')
        self.ln2 = nn.LayerNorm(float(epsilon), dtype=self.dtype)
        self.mlp = utils.BasicMLP(
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size,
            dtype=self.dtype
        )

    @nn.compact
    def __call__(
        self, 
        x: jnp.ndarray,
        attention_mask: jnp.ndarray,
        train: Optional[bool] = False
    ) -> jnp.ndarray:
        # Multi-Head Attention from Vaswani et al. "Attention is All You Need":
        #   Layer-norm -> QKV proj. -> DotProductAttention -> context proj. -> dropout -> residual
        res = x
        x = self.ln1(x)
        qkv = self.qkv_projection(x)
        qkv = jnp.reshape(
            qkv,
            (qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        )
        q, k, v = jnp.split(qkv, 3, axis=3)
        x = self.attention(q, k, v, attention_mask=attention_mask, train=train)
        x = self.projection(x)
        x = self.dropout(x, deterministic=(not train))
        x = res + x

        # Output layernorm -> multi-layer perceptron -> residual
        res = x
        x = self.ln2(x)
        x = self.mlp(x)
        return x + res

That's it! We now have a simple Transformer layer. Before testing it, we first have to set problem sizes, create the necessary random number generators, and initialize the layer parameters.

In [2]:
# Layer configuration:
batch_size = 32
sequence_length = 128
num_attention_heads = 16
head_size = 64
hidden_size = num_attention_heads * head_size
ffn_hidden_size = 4 * hidden_size
dtype = jnp.float16

# Create the necessary RNG keys:
#   The dropout in DotProductAttention will use the 'attention' key group, while
#   the hidden layer dropout in BasicTransformerLayer will use 'hidden' key group.
root_key = jax.random.PRNGKey(seed=0)
fwd_key, bwd_key, params_key, attention_key, hidden_key = jax.random.split(root_key, 5)
rngs = {
    'params' : params_key,
    'attention' : attention_key,
    'hidden' : hidden_key
}

# Synthetic data:
x = jax.device_put(
    jax.random.uniform(fwd_key, (sequence_length, batch_size, hidden_size), dtype=dtype)
)

# Initialize the module and inspect parameters:
basic_transformer = BasicTransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    dtype=dtype
)
print(basic_transformer, end='\n\n')
flax_params = basic_transformer.init(rngs, jnp.zeros_like(x), attention_mask=None, train=True)
utils.inspect_params(flax_params)

BasicTransformerLayer(
    # attributes
    hidden_size = 1024
    ffn_hidden_size = 4096
    num_attention_heads = 16
    layernorm_eps = 1e-05
    attention_dropout = 0.1
    hidden_dropout = 0.1
    dtype = float16
)

params
|__ln1
|  |__scale: (1024,)
|  |__bias: (1024,)
|__qkv_projection
|  |__kernel: (1024, 3072)
|  |__bias: (3072,)
|__projection
|  |__kernel: (1024, 1024)
|  |__bias: (1024,)
|__ln2
|  |__scale: (1024,)
|  |__bias: (1024,)
|__mlp
   |__linear1
   |  |__kernel: (1024, 4096)
   |  |__bias: (4096,)
   |__linear2
      |__kernel: (4096, 1024)
      |__bias: (1024,)


Finally, we test our implementation with a sum squared loss over the Transformer layer output.

In [3]:

# Mean squared error loss function:
def sum_squared_loss(module, params, x, **kwargs):
    out = module.apply(params, x, **kwargs)
    out = jnp.reshape(out, (out.size,))
    return jnp.dot(out, out).astype(out.dtype)

# Autograd sum_squared_loss for BasicTransformer w.r.t. input X:
fwd_bwd_func_flax = jax.value_and_grad(partial(sum_squared_loss, basic_transformer), argnums=(0, 1))

# Named arguments for the loss function:
fwd_kwargs = {
    'attention_mask': None,
    'train' : True,
    'rngs' : rngs
}

# Warmup iterations:
for _ in range(50):
    y, (dp, dy) = fwd_bwd_func_flax(flax_params, x, **fwd_kwargs)

# Timed iterations:
print("\nExecution time:")
%timeit -r 5 -n 10 jax.block_until_ready(fwd_bwd_func_flax(flax_params, x, **fwd_kwargs))


Execution time:
84.4 ms ± 1.03 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)


## Meet Transformer Engine

<div class="alert alert-info">

<b>Summary</b>
    
We modify the example Transformer layer to include the simplest TE modules: `Linear` and `LayerNorm`.

</div>

Now that we have a basic Transformer layer, let's use Transformer Engine to speed up the training. 

In [4]:
import transformer_engine.jax as te

TE provides a set of Flax modules that can be used to build Transformer layers. The simplest of the provided modules are the `Linear` and `LayerNorm` layers, which we can use instead of `flax.linen.Linear` and `flax.linen.LayerNorm`. Let's modify `BasicTransformerLayer`:

In [5]:
class BasicTEMLP(nn.Module):
    hidden_size: int
    ffn_hidden_size: int
    dtype : Optional[type] = jnp.float16

    def setup(self):
        self.linear1 = te.flax.DenseGeneral(self.ffn_hidden_size, dtype=self.dtype)
        self.linear2 = te.flax.DenseGeneral(self.hidden_size, dtype=self.dtype)

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.linear1(x)
        x = jax.nn.gelu(x, approximate=True)
        x = self.linear2(x)
        return x


class BasicTETransformerLayer(nn.Module):
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int
    layernorm_eps: Optional[float] = 1e-5
    attention_dropout: Optional[float] = 0.1
    hidden_dropout: Optional[float] = 0.1
    dtype: Optional[type] = jnp.float16

    def setup(self):
        self.ln1 = te.flax.LayerNorm(
            epsilon=self.layernorm_eps,
            dtype=self.dtype,
            transpose_batch_sequence=True
        )
        self.qkv_projection = te.flax.DenseGeneral(
            3 * self.hidden_size, dtype=self.dtype
        )
        self.kv_channels = self.hidden_size // self.num_attention_heads
        self.attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            dropout_rate=self.attention_dropout,
            dropout_rng='attention',
            dtype=self.dtype
        )
        self.projection = te.flax.DenseGeneral(hidden_size, dtype=self.dtype)
        self.dropout = nn.Dropout(self.hidden_dropout, rng_collection='hidden')
        self.ln2 = te.flax.LayerNorm(
            epsilon=self.layernorm_eps,
            dtype=self.dtype,
            transpose_batch_sequence=True
        )
        self.mlp = BasicTEMLP(
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size,
            dtype=self.dtype
        )

    @nn.compact
    def __call__(self,
                x: jnp.ndarray,
                attention_mask: jnp.ndarray,
                train: Optional[bool] = False
    ) -> jnp.ndarray:
        res = x
        x = self.ln1(x)
        qkv = self.qkv_projection(x)
        qkv = jnp.reshape(
            qkv, (qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        )
        q, k, v = jnp.split(qkv, 3, axis=3)
        x = self.attention(q, k, v, attention_mask=attention_mask)
        x = self.projection(x)
        x = self.dropout(x, deterministic=(not train))
        x = res + x
        
        res = x
        x = self.ln2(x.astype(self.dtype))
        x = self.mlp(x)
        return x + res

In [6]:
# Initialize the module and inspect parameters:
basic_te_transformer = BasicTETransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    dtype=dtype
)
print(basic_te_transformer, end='\n\n')

# Share parameters with the Flax implementation
te_params = basic_te_transformer.init(rngs, jnp.zeros_like(x), attention_mask=None, train=True)
te_params = utils.share_params(te_params, flax_params)
utils.inspect_params(te_params)

BasicTETransformerLayer(
    # attributes
    hidden_size = 1024
    ffn_hidden_size = 4096
    num_attention_heads = 16
    layernorm_eps = 1e-05
    attention_dropout = 0.1
    hidden_dropout = 0.1
    dtype = float16
)

params
|__ln1
|  |__scale: (1024,)
|  |__ln_bias: (1024,)
|__qkv_projection
|  |__kernel: (1024, 3072)
|  |__bias: (3072,)
|__projection
|  |__kernel: (1024, 1024)
|  |__bias: (1024,)
|__ln2
|  |__scale: (1024,)
|  |__ln_bias: (1024,)
|__mlp
   |__linear1
   |  |__kernel: (1024, 4096)
   |  |__bias: (4096,)
   |__linear2
      |__kernel: (4096, 1024)
      |__bias: (1024,)


In [7]:
fwd_bwd_func_te = jax.value_and_grad(partial(sum_squared_loss, basic_te_transformer), argnums=(0, 1))

# Warmup iterations:
for _ in range(50):
    y, (dp, dy) = fwd_bwd_func_te(te_params, x, **fwd_kwargs)

# Timed iterations:
%timeit -r 5 -n 10 jax.block_until_ready(fwd_bwd_func_te(te_params, x, **fwd_kwargs))

60.1 ms ± 1.08 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)


## Fused TE Modules

<div class="alert alert-info">

<b>Summary</b>
    
We optimize the example Transformer layer with TE modules for fused operations.

</div>

The `Linear` layer is enough to build any Transformer model and it enables usage of Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations like kernel fusion, increasing the achievable speedup.

Transformer Engine therefore provides coarser modules that span multiple layers:

* `LayerNormDenseGeneral`
* `LayerNormMLP`
* `TransformerLayer`

Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:

In [8]:
class FusedTETransformerLayer(nn.Module):
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int
    layernorm_eps: Optional[float] = 1e-5
    attention_dropout: Optional[float] = 0.1
    hidden_dropout: Optional[float] = 0.1
    dtype: Optional[type] = jnp.float16

    def setup(self):
        self.ln_qkv = te.flax.LayerNormDenseGeneral(
            3 * self.hidden_size,
            epsilon=self.layernorm_eps,
            return_layernorm_output=False,
            dtype=self.dtype,
            transpose_batch_sequence=True
        )
        self.kv_channels = self.hidden_size // self.num_attention_heads
        self.attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            dropout_rate=self.attention_dropout,
            dropout_rng='attention',
            dtype=self.dtype
        )
        self.projection = te.flax.DenseGeneral(hidden_size, dtype=self.dtype)
        self.dropout = nn.Dropout(self.hidden_dropout, rng_collection='hidden')
        self.ln_mlp = te.flax.LayerNormMLP(
            self.ffn_hidden_size,
            epsilon=self.layernorm_eps,
            return_layernorm_output=False,
            activations=('gelu','linear'),
            intermediate_dropout_rate=0.0,
            dtype=self.dtype,
            transpose_batch_sequence=True
        )

    @nn.compact
    def __call__(self,
                x: jnp.ndarray,
                attention_mask: jnp.ndarray,
                train: Optional[bool] = False
    ) -> jnp.ndarray:
        res = x
        qkv, _ = self.ln_qkv(x)
        qkv = jnp.reshape(
            qkv, (qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        )
        q, k, v = jnp.split(qkv, 3, axis=3)
        x = self.attention(q, k, v, attention_mask=attention_mask)
        x = self.projection(x)
        x = self.dropout(x, deterministic=(not train))
        x = res + x
        
        res = x
        x, _ = self.ln_mlp(x.astype(self.dtype))
        return x + res

In [9]:
fused_te_transformer = FusedTETransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    dtype=dtype
)
print(fused_te_transformer, end='\n\n')

fused_te_params = fused_te_transformer.init(rngs, jnp.zeros_like(x), attention_mask=None, train=True)
fused_te_params = utils.share_params(fused_te_params, flax_params)
utils.inspect_params(fused_te_params)

FusedTETransformerLayer(
    # attributes
    hidden_size = 1024
    ffn_hidden_size = 4096
    num_attention_heads = 16
    layernorm_eps = 1e-05
    attention_dropout = 0.1
    hidden_dropout = 0.1
    dtype = float16
)

params
|__ln_qkv
|  |__scale: (1024,)
|  |__ln_bias: (1024,)
|  |__kernel: (1024, 3072)
|__projection
|  |__kernel: (1024, 1024)
|  |__bias: (1024,)
|__ln_mlp
   |__scale: (1024,)
   |__ln_bias: (1024,)
   |__wi_kernel: (1024, 2, 4096)
   |__wo_kernel: (4096, 1024)


In [10]:
fwd_bwd_func_fused_te = jax.value_and_grad(partial(sum_squared_loss, fused_te_transformer), argnums=(0, 1))

# Warmup iterations:
for _ in range(50):
    y, (dp, dy) = fwd_bwd_func_fused_te(fused_te_params, x, **fwd_kwargs)

# Timed iterations:
%timeit -r 5 -n 10 jax.block_until_ready(fwd_bwd_func_fused_te(fused_te_params, x, **fwd_kwargs))

57.5 ms ± 1.22 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)


Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization.

In [11]:
native_te_transformer = te.flax.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    layernorm_epsilon=1e-5,
    intermediate_dropout=0.0,
    dropout_rng_name='transformer',
    mlp_activations=('gelu','linear'),
    enable_relative_embedding=False,
    dtype=dtype,
    transpose_batch_sequence=True
)
print(native_te_transformer, end='\n\n')

native_te_rngs = {
    'params' : params_key,
    'transformer' : jax.random.PRNGKey(seed=1)
}
native_te_params = native_te_transformer.init(
    native_te_rngs,
    jnp.zeros_like(x),
    attention_mask=None,
    deterministic=False   # train=True
)
native_te_params = utils.share_params(native_te_params, flax_params)
utils.inspect_params(native_te_params)

TransformerLayer(
    # attributes
    hidden_size = 1024
    mlp_hidden_size = 4096
    num_attention_heads = 16
    layernorm_type = 'layernorm'
    layernorm_epsilon = 1e-05
    zero_centered_gamma = False
    hidden_dropout = 0.1
    hidden_dropout_dims = ()
    attention_dropout = 0.1
    intermediate_dropout = 0.0
    intermediate_dropout_dims = ()
    dropout_rng_name = 'transformer'
    mha_kernel_init = init
    mlp_kernel_init = init
    mlp_activations = ('gelu', 'linear')
    use_bias = False
    bias_init = zeros
    apply_residual_connection_post_layernorm = False
    output_layernorm = False
    float32_attention_logits = False
    layer_type = <TransformerLayerType.ENCODER: 'encoder'>
    self_attn_mask_type = 'causal'
    enable_relative_embedding = False
    relative_embedding = None
    dtype = float16
    drop_path = 0.0
    fuse_qkv_params = True
    transpose_batch_sequence = True
    scale_attn_logits = False
    scaled_query_init = True
)

params
|__attention
| 

In [12]:
fwd_bwd_func_native_te = jax.value_and_grad(partial(sum_squared_loss, native_te_transformer), argnums=(0, 1))

native_te_kwargs = {
    'attention_mask' : None,
    'deterministic' : False,
    'rngs' : native_te_rngs
}

# Warmup iterations:
for _ in range(50):
    y, (dp, dy) = fwd_bwd_func_native_te(native_te_params, x, **native_te_kwargs)

# Timed iterations:
%timeit -r 5 -n 10 jax.block_until_ready(fwd_bwd_func_native_te(native_te_params, x, **native_te_kwargs))

57.4 ms ± 955 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)


## Enabling FP8

<div class="alert alert-info">

<b>Summary</b>
    
We configure a TE module to perform compute in FP8.

</div>

Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.

In [13]:
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    fp8_te_transformer = te.flax.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        layernorm_epsilon=1e-5,
        intermediate_dropout=0.0,
        dropout_rng_name='transformer',
        mlp_activations=('gelu','linear'),
        enable_relative_embedding=False,
        dtype=dtype,
        transpose_batch_sequence=True
    )

    fp8_te_params = fp8_te_transformer.init(native_te_rngs, jnp.zeros_like(x), attention_mask=None)
    fwd_bwd_func_fp8_te = jax.value_and_grad(partial(sum_squared_loss, fp8_te_transformer), argnums=(0, 1))

    # Warmup iterations:
    for _ in range(50):
        y, (dp, dy) = fwd_bwd_func_fp8_te(fp8_te_arams, x, **native_te_kwargs)

    # Timed iterations:
    %timeit -r 5 -n 10 jax.block_until_ready(fwd_bwd_func_fp8_te(fp8_te_params, x, **native_te_kwargs))

AssertionError: Device compute capability 8.9 or higher required for FP8 execution.