## Getting Started with JAX

<div class="alert alert-info">

<b>Summary</b>
    
We build a basic Transformer layer using regular JAX and Flax modules. This will be our baseline 
for later comparisons with Transformer Engine.

</div>

We construct the components as follows:

- `LayerNorm`: `flax.linen.LayerNorm`
- `QKV Projection`: `flax.linen.DenseGeneral` (conceptually three dense layers for Q, K, and V 
  separately, but we fuse them into a single layer with three times as many output features)
- `DotProductAttention`: `DotProductAttention` from 
  [quickstart_jax_utils.py](quickstart_jax_utils.py)
- `Projection`: `flax.linen.DenseGeneral`
- `Dropout`: `flax.linen.Dropout`
- `MLP`: `BasicMLP` from [quickstart_jax_utils.py](quickstart_jax_utils.py)
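The fused QKV projection produces one tensor with `3 * hidden_size` features, which has to be reshaped into heads and split into Q, K and V before attention. The sketch below shows that step plus standard, unmasked scaled dot-product attention in plain JAX on the `[sequence, batch, ...]` layout used in this notebook. It assumes `utils.DotProductAttention` follows the standard formulation (its source isn't shown here) and is not a copy of that helper; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def split_fused_qkv(qkv, num_heads):
    """[seq, batch, 3 * hidden] -> three [seq, batch, heads, head_dim] arrays.

    Assumes the layout used by the layers in this notebook: after reshaping,
    each head's slice of the last axis holds [q | k | v] back to back.
    """
    seq, batch, three_hidden = qkv.shape
    head_dim = three_hidden // (3 * num_heads)
    qkv = qkv.reshape(seq, batch, num_heads, 3 * head_dim)
    # jnp.split takes the number of equal sections (torch.split takes a section size)
    return jnp.split(qkv, 3, axis=-1)

def scaled_dot_product_attention(q, k, v):
    """Unmasked softmax(Q K^T / sqrt(d)) V per head on [seq, batch, heads, head_dim] inputs."""
    seq, batch, heads, head_dim = q.shape
    scores = jnp.einsum("qbhd,kbhd->bhqk", q, k) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(scores, axis=-1)  # normalize over key positions
    out = jnp.einsum("bhqk,kbhd->qbhd", weights, v)
    return out.reshape(seq, batch, heads * head_dim)  # merge heads back into the hidden axis

# seq=8, batch=2, 4 heads of size 16
demo_qkv = jax.random.normal(jax.random.PRNGKey(0), (8, 2, 3 * 4 * 16))
demo_q, demo_k, demo_v = split_fused_qkv(demo_qkv, num_heads=4)
demo_out = scaled_dot_product_attention(demo_q, demo_k, demo_v)  # shape (8, 2, 64)
```

Splitting into a fixed number of sections along the per-head axis keeps each head's Q, K and V contiguous, which is why the reshape must happen before the split.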

Putting it all together:

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import quickstart_jax_utils as utils

class BasicTransformerLayer(nn.Module):
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int
    layernorm_eps: float = 1e-6
    attention_dropout: float = 0.1
    hidden_dropout: float = 0.1

    def setup(self):
        self.kv_channels = self.hidden_size // self.num_attention_heads
        # flax.linen.LayerNorm infers the feature size from its input; only epsilon is set
        self.ln1 = nn.LayerNorm(epsilon=self.layernorm_eps)
        self.qkv_projection = nn.DenseGeneral(3 * self.hidden_size, use_bias=True)
        self.attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=self.attention_dropout,
        )
        self.projection = nn.DenseGeneral(self.hidden_size, use_bias=True)
        # Draws its dropout mask from the 'hidden' RNG stream
        self.dropout = nn.Dropout(self.hidden_dropout, rng_collection='hidden')
        self.ln2 = nn.LayerNorm(epsilon=self.layernorm_eps)
        self.mlp = utils.BasicMLP(
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size,
        )

    def __call__(
        self,
        x: jnp.ndarray,
        attention_mask: jnp.ndarray,
        deterministic: bool = False,
    ) -> jnp.ndarray:
        res = x
        x = self.ln1(x)

        # Fused QKV projection, then split each head's slice into Q, K and V
        qkv = self.qkv_projection(x)
        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = jnp.split(qkv, 3, axis=3)

        x = self.attention(q, k, v, attention_mask)
        x = self.projection(x)
        x = self.dropout(x, deterministic=deterministic)
        x = res + x
        res = x
        x = self.ln2(x)
        x = self.mlp(x)

        return x + res
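In equation form, and assuming `utils.DotProductAttention` implements standard scaled dot-product attention (its source isn't shown here), the layer is a pre-LayerNorm residual block. Here $d_k$ is `kv_channels` $=$ `hidden_size / num_attention_heads`, $h$ indexes heads, $M$ encodes `attention_mask` ($M = 0$ when it is `None`), and attention dropout is omitted for brevity:

$$
\begin{aligned}
[\,Q \mid K \mid V\,] &= \mathrm{LN}_1(x)\,W_{qkv} + b_{qkv} \\
A &= \mathrm{concat}_{h}\Big(\mathrm{softmax}\big(Q_h K_h^\top / \sqrt{d_k} + M\big)\,V_h\Big) \\
r &= x + \mathrm{Dropout}\big(A\,W_o + b_o\big) \\
y &= r + \mathrm{MLP}\big(\mathrm{LN}_2(r)\big)
\end{aligned}
$$

The two residual additions correspond to `x = res + x` and `return x + res` in `__call__`.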

That's it! We now have a simple Transformer layer. We can test it:

In [2]:
# Layer configuration:
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = jnp.float16

# Create the necessary RNG keys:
#   The dropout in DotProductAttention will use the 'attention' key group, while
#   the hidden layer dropout in BasicTransformerLayer will use 'hidden' key group.
root_key = jax.random.PRNGKey(seed=0)
fwd_key, bwd_key, params_key, attention_key, hidden_key = jax.random.split(root_key, 5)

# Synthetic data:
x = jax.device_put(
    jax.random.uniform(fwd_key, (sequence_length, batch_size, hidden_size), dtype=dtype)
)
dy = jax.device_put(
    jax.random.uniform(bwd_key, (sequence_length, batch_size, hidden_size), dtype=dtype)
)

In [3]:
basic_transformer = BasicTransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads
)

In [4]:
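# init needs a sample input and a key for every RNG stream used in the forward pass
rngs = {"params": params_key, "attention": attention_key, "hidden": hidden_key}
flax_params = basic_transformer.init(rngs, x, attention_mask=None)
y = basic_transformer.apply(
    flax_params, x, attention_mask=None,
    rngs={"attention": attention_key, "hidden": hidden_key},
)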

In [5]:
utils.speedometer(
    basic_transformer,
    flax_params,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
)

Mean time: 43.0663916015625 ms
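`utils.speedometer` is defined in `quickstart_jax_utils.py`, which isn't shown here. As a rough idea of what timing a forward plus backward pass involves in JAX, the hypothetical helper below jit-compiles one step that runs the forward pass and backpropagates a supplied output gradient (the role `dy` plays above), warms up to exclude compilation, and waits on results because JAX dispatches work asynchronously. It is a sketch, not that helper's implementation.

```python
import time
import jax
import jax.numpy as jnp

def time_fwd_bwd(apply_fn, params, inputs, grad_out, warmup=3, iters=10):
    """Mean time in ms of one forward + backward pass of apply_fn(params, inputs).

    Hypothetical helper for illustration; not the implementation of utils.speedometer.
    """
    def fwd_bwd(p, inp, g):
        # Forward pass, then backpropagate the supplied output gradient g
        _, vjp_fn = jax.vjp(apply_fn, p, inp)
        return vjp_fn(g)  # (gradient w.r.t. params, gradient w.r.t. inputs)

    step = jax.jit(fwd_bwd)
    for _ in range(warmup):  # the first call triggers compilation
        jax.block_until_ready(step(params, inputs, grad_out))

    start = time.perf_counter()
    for _ in range(iters):
        # Wait for the device to finish before moving on, or the clock measures only dispatch
        jax.block_until_ready(step(params, inputs, grad_out))
    return (time.perf_counter() - start) / iters * 1e3

def toy_apply(w, inp):
    return jnp.tanh(inp @ w)

toy_w = jnp.ones((256, 256)) * 0.01
toy_inp = jnp.ones((64, 256))
toy_dy = jnp.ones((64, 256))
print(f"Mean time: {time_fwd_bwd(toy_apply, toy_w, toy_inp, toy_dy):.3f} ms")
```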


## Meet Transformer Engine

<div class="alert alert-info">

<b>Summary</b>
    
We modify the example Transformer layer to include the simplest TE modules: `DenseGeneral` and `LayerNorm`.

</div>

Now that we have a basic Transformer layer, let's use Transformer Engine to speed up the training. 

In [None]:
import transformer_engine.jax as te

TE provides a set of Flax modules, under `te.flax`, that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral` and `LayerNorm` layers, which we can use instead of `flax.linen.DenseGeneral` and `flax.linen.LayerNorm`. Let's modify `BasicTransformerLayer` accordingly, also building its MLP from TE layers:

In [7]:
class BasicTEMLP(nn.Module):
    hidden_size: int
    ffn_hidden_size: int

    def setup(self):
        self.linear1 = te.flax.DenseGeneral(self.ffn_hidden_size, use_bias=True)
        self.linear2 = te.flax.DenseGeneral(self.hidden_size, use_bias=True)

    def __call__(self, x):
        x = self.linear1(x)
        x = jax.nn.gelu(x, approximate=True)
        x = self.linear2(x)
        return x

class BasicTETransformerLayer(nn.Module):
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int
    layernorm_eps: float = 1e-6
    attention_dropout: float = 0.1
    hidden_dropout: float = 0.1

    def setup(self):
        self.kv_channels = self.hidden_size // self.num_attention_heads
        self.ln1 = te.flax.LayerNorm(epsilon=self.layernorm_eps)
        self.qkv_projection = te.flax.DenseGeneral(3 * self.hidden_size, use_bias=True)
        self.attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=self.attention_dropout,
        )
        self.projection = te.flax.DenseGeneral(self.hidden_size, use_bias=True)
        self.dropout = nn.Dropout(self.hidden_dropout, rng_collection='hidden')
        self.ln2 = te.flax.LayerNorm(epsilon=self.layernorm_eps)
        self.mlp = BasicTEMLP(
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size,
        )

    def __call__(self,
                 x: jnp.ndarray,
                 attention_mask: jnp.ndarray,
                 deterministic: bool = False) -> jnp.ndarray:
        res = x
        x = self.ln1(x)

        # Fused QKV projection, then split each head's slice into Q, K and V
        qkv = self.qkv_projection(x)
        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = jnp.split(qkv, 3, axis=3)

        x = self.attention(q, k, v, attention_mask)
        x = self.projection(x)
        x = self.dropout(x, deterministic=deterministic)
        x = res + x
        res = x
        x = self.ln2(x)
        x = self.mlp(x)

        return x + res

In [8]:
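basic_te_transformer = BasicTETransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
)
rngs = {"params": params_key, "attention": attention_key, "hidden": hidden_key}
te_params = basic_te_transformer.init(rngs, x, attention_mask=None)
y = basic_te_transformer.apply(
    te_params, x, attention_mask=None,
    rngs={"attention": attention_key, "hidden": hidden_key},
)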

In [10]:
utils.speedometer(
    basic_te_transformer,
    te_params,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
)

Mean time: 43.1413232421875 ms


## Fused TE Modules

<div class="alert alert-info">

<b>Summary</b>
    
We optimize the example Transformer layer with TE modules for fused operations.

</div>

The `DenseGeneral` layer is enough to build any Transformer model, and it enables the use of Transformer Engine even for highly customized Transformers. However, knowing more about the model's structure allows additional optimizations such as kernel fusion, which increase the achievable speedup.

Transformer Engine therefore provides coarser modules that span multiple layers:

* `LayerNormDenseGeneral`
* `LayerNormMLP`
* `TransformerLayer`
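As a reference for what a module like `LayerNormDenseGeneral` computes, the sketch below writes "LayerNorm over the hidden axis, then a dense projection" as two unfused steps in plain JAX. It describes the math only, under the assumption of a standard (non-zero-centered) LayerNorm, and is not TE's kernel; the parameter names `gamma`, `beta`, `kernel` and `bias` are hypothetical. The normalized tensor here is a full-size intermediate, which is the kind of extra memory traffic and kernel launch that fusion aims to remove.

```python
import jax
import jax.numpy as jnp

def layernorm(x, gamma, beta, eps=1e-6):
    """LayerNorm over the last (hidden) axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = jnp.mean((x - mean) ** 2, axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps) * gamma + beta

def layernorm_dense_reference(x, gamma, beta, kernel, bias, eps=1e-6):
    """Unfused reference for 'LayerNorm, then Dense' (hypothetical helper)."""
    normalized = layernorm(x, gamma, beta, eps)  # materialized intermediate
    return normalized @ kernel + bias

hidden = 8
demo_x = jax.random.normal(jax.random.PRNGKey(0), (4, hidden))
# Identity projection, unit scale, zero shift: the output is just the normalized input
demo_y = layernorm_dense_reference(
    demo_x, jnp.ones(hidden), jnp.zeros(hidden), jnp.eye(hidden), jnp.zeros(hidden)
)
```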

Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:

In [11]:
class FusedTETransformerLayer(nn.Module):
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int
    layernorm_eps: float = 1e-6
    attention_dropout: float = 0.1
    hidden_dropout: float = 0.1

    def setup(self):
        self.kv_channels = self.hidden_size // self.num_attention_heads
        # LayerNorm fused with the QKV projection
        self.ln_qkv = te.flax.LayerNormDenseGeneral(
            3 * self.hidden_size,
            epsilon=self.layernorm_eps,
            use_bias=True,
        )
        self.attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=self.attention_dropout,
        )
        self.projection = te.flax.DenseGeneral(self.hidden_size, use_bias=True)
        self.dropout = nn.Dropout(self.hidden_dropout, rng_collection='hidden')
        # LayerNorm fused with the two-layer MLP
        self.ln_mlp = te.flax.LayerNormMLP(
            self.ffn_hidden_size,
            epsilon=self.layernorm_eps,
            use_bias=True,
            intermediate_dropout_rate=0.0,
        )

    def __call__(self,
                 x: jnp.ndarray,
                 attention_mask: jnp.ndarray,
                 deterministic: bool = False) -> jnp.ndarray:
        res = x
        # The fused TE modules return (output, layernorm_output); only the output is used here
        qkv, _ = self.ln_qkv(x)

        # Split qkv into query, key and value
        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = jnp.split(qkv, 3, axis=3)

        x = self.attention(q, k, v, attention_mask)
        x = self.projection(x)
        x = self.dropout(x, deterministic=deterministic)
        x = res + x
        res = x
        x, _ = self.ln_mlp(x, deterministic=deterministic)

        return x + res

In [12]:
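fused_te_transformer = FusedTETransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)
rngs = {"params": params_key, "attention": attention_key, "hidden": hidden_key}
fused_te_params = fused_te_transformer.init(rngs, x, attention_mask=None)
y = fused_te_transformer.apply(
    fused_te_params, x, attention_mask=None,
    rngs={"attention": attention_key, "hidden": hidden_key},
)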

In [14]:
utils.speedometer(
    fused_te_transformer,
    fused_te_params,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
)

Mean time: 43.1981201171875 ms


Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:

In [15]:
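te_transformer = te.flax.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)
# TE's TransformerLayer draws its dropout masks from the 'dropout' RNG stream by default
rngs = {"params": params_key, "dropout": hidden_key}
transformer_params = te_transformer.init(rngs, x, attention_mask=None)
y = te_transformer.apply(transformer_params, x, attention_mask=None, rngs={"dropout": hidden_key})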

In [17]:
utils.speedometer(
    te_transformer,
    transformer_params,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
)

Mean time: 39.99169921875 ms


## Enabling FP8

<div class="alert alert-info">

<b>Summary</b>
    
We configure a TE module to perform compute in FP8.

</div>

Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.

In [18]:
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

# Create and initialize the module inside fp8_autocast so that TE sets up its FP8 metadata
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    te_transformer = te.flax.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)
    rngs = {"params": params_key, "dropout": hidden_key}
    transformer_params = te_transformer.init(rngs, x, attention_mask=None)
    y = te_transformer.apply(transformer_params, x, attention_mask=None, rngs={"dropout": hidden_key})
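The `DelayedScaling` recipe sets each FP8 scaling factor from a window of recently observed absolute maxima (`amax_history_len=16`, reduced with `amax_compute_algo="max"`) rather than from the current tensor alone. In the `HYBRID` format, forward-pass tensors use E4M3 and gradients use E5M2. The sketch below illustrates the idea in plain JAX for an E4M3 tensor. It assumes the scaling rule `fp8_max / amax / 2**margin` and JAX's `jnp.float8_e4m3fn` dtype, simplifies TE's bookkeeping, and uses hypothetical function names.

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite float8 E4M3 value

def record_amax(history, tensor):
    """Shift the window by one slot and store the current absolute maximum in slot 0."""
    history = jnp.roll(history, 1)
    return history.at[0].set(jnp.max(jnp.abs(tensor)))

def scale_from_history(history, fp8_max, margin=0):
    """'max' reduction over the window; fall back to 1.0 before any amax is recorded."""
    amax = jnp.max(history)
    return jnp.where(amax > 0, fp8_max / amax / 2.0 ** margin, 1.0)

def fake_quantize_e4m3(x, scale):
    """Scale into the E4M3 range, round by casting to float8, then undo the scaling."""
    x_fp8 = (x * scale).astype(jnp.float8_e4m3fn)
    return x_fp8.astype(jnp.float32) / scale

amax_history = jnp.zeros(16)  # amax_history_len=16
# A previous step saw values up to |2.0|
amax_history = record_amax(amax_history, jnp.array([0.5, -2.0]))
scale = scale_from_history(amax_history, E4M3_MAX)  # 448 / 2.0 = 224
demo_x = jnp.array([0.25, -0.5, 1.0, 2.0])
demo_x_hat = fake_quantize_e4m3(demo_x, scale)
```

Because the scale comes from earlier steps, a sudden spike larger than anything in the window can overflow the FP8 range; the history length trades responsiveness against stability.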

In [19]:
utils.speedometer(
    te_transformer,
    transformer_params,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
    fp8_autocast_kwargs = { "enabled": True, "fp8_recipe": fp8_recipe },
)

Mean time: 28.61394775390625 ms
