# SAVE MEMORY WITH MIXED PRECISION

# 1 - What is Mixed Precision

Like most deep learning frameworks, PyTorch uses 32-bit floating-point (FP32) arithmetic by default. However, many deep learning models do not need FP32 in every operation to reach full accuracy during training.

Mixed precision training delivers a significant computational speedup by running most operations in half precision while keeping a small set of critical values in single precision, so that as little information as possible is lost in the crucial parts of the network.

Since the introduction of Tensor Cores in the Volta and Turing architectures, switching to mixed precision has resulted in considerable training speedups. Mixed precision combines FP32 with lower-bit floating-point formats (such as FP16) to reduce the memory footprint and increase performance during model training and evaluation. It does so by identifying the steps that require full precision, using 32-bit floating point only for those steps and 16-bit floating point everywhere else.

[Compared to full precision training, mixed precision training delivers all these benefits while ensuring no task-specific accuracy is lost.](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html)
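
To make the memory saving concrete, the following minimal sketch compares the storage needed for the same tensor in FP32 and FP16. The tensor shape and the helper `tensor_bytes` are assumptions chosen for illustration. In mixed precision the weights themselves stay in FP32, so the savings come from tensors, such as activations, that are held in 16-bit.

```python
import torch


def tensor_bytes(t: torch.Tensor) -> int:
    """Return the number of bytes used by the tensor's elements."""
    return t.numel() * t.element_size()


# Hypothetical batch of activations: 256 samples with 1024 features each.
activations_fp32 = torch.zeros(256, 1024, dtype=torch.float32)
activations_fp16 = activations_fp32.to(torch.float16)

fp32_bytes = tensor_bytes(activations_fp32)
fp16_bytes = tensor_bytes(activations_fp16)
print(f"FP32: {fp32_bytes} bytes, FP16: {fp16_bytes} bytes")
```

Each FP16 element takes 2 bytes instead of 4, so the same tensor needs half the storage.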


----


**Note:** In some cases, it is essential to remain in FP32 for numerical stability, so keep this in mind when using mixed precision. For example, when running scatter operations during the forward pass (as in torchpoint3d), the computation must remain in FP32.


---
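
One way to keep a sensitive operation in FP32 inside a mixed precision region is to disable autocast locally and cast the inputs back to `float32`. The sketch below uses plain PyTorch on the CPU with bfloat16 so that it runs without a GPU. The layer sizes and the `bucket` index are hypothetical, and the scatter operation merely stands in for whichever step needs full precision in your model.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(4, 3)
x = torch.randn(8, 4)

# Hypothetical assignment of each of the 8 rows to one of 3 buckets.
bucket = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])
index = bucket.unsqueeze(1).expand(-1, 3)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    hidden = layer(x)  # runs in bfloat16 under autocast

    # Locally leave the autocast region for the numerically sensitive step.
    with torch.autocast(device_type="cpu", enabled=False):
        hidden_fp32 = hidden.float()  # cast back to FP32 explicitly
        pooled = torch.zeros(3, 3).scatter_add(0, index, hidden_fp32)

print(hidden.dtype, pooled.dtype)
```

The explicit `.float()` matters: disabling autocast only stops automatic casting, it does not convert tensors that are already in lower precision.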

## 1.1 - Lightning Fabric

In [None]:
from lightning.fabric import Fabric

# This is the default
fabric = Fabric(precision="32-true")

# Also FP32 (legacy)
fabric = Fabric(precision=32)

# FP32 as well (legacy)
fabric = Fabric(precision="32")

# Float16 mixed precision
fabric = Fabric(precision="16-mixed")

# Float16 true half precision
fabric = Fabric(precision="16-true")

# BFloat16 mixed precision (Volta GPUs and later)
fabric = Fabric(precision="bf16-mixed")

# BFloat16 true half precision (Volta GPUs and later)
fabric = Fabric(precision="bf16-true")

# 8-bit mixed precision via TransformerEngine (Hopper GPUs and later)
fabric = Fabric(precision="transformer-engine")

# Double precision
fabric = Fabric(precision="64-true")

# Or (legacy)
fabric = Fabric(precision="64")

# Or (legacy)
fabric = Fabric(precision=64)

# 2 - FP16 Mixed Precision

In most cases, mixed precision uses FP16. [Supported PyTorch operations](https://pytorch.org/docs/stable/amp.html#op-specific-behavior) automatically run in FP16, saving memory and improving throughput on the supported accelerators.

Since computation happens in FP16, which has a very limited "dynamic range", there is a chance of numerical instability during training. This is handled internally by a dynamic gradient scaler, which skips invalid steps and adjusts the scale factor so that subsequent steps fall within a finite range. For more information, [see the gradient scaling docs.](https://pytorch.org/docs/stable/amp.html#gradient-scaling)
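
The sketch below illustrates both ideas using only the Python standard library. The first part shows a small gradient underflowing to zero in FP16 and surviving once it is scaled. The second part is a toy version of the skip-and-adjust policy. `to_fp16` and `ToyGradScaler` are hypothetical names introduced here; the class is a simplification and not PyTorch's implementation, although its default values are chosen to resemble those of `torch.cuda.amp.GradScaler`.

```python
import math
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


# A gradient this small cannot be represented in FP16 and underflows to zero.
tiny_grad = 1e-8
unscaled = to_fp16(tiny_grad)

# Scaling before the cast keeps it representable; unscaling is done afterwards
# in higher precision.
scale = 2.0**16
recovered = to_fp16(tiny_grad * scale) / scale
print(unscaled, recovered)


class ToyGradScaler:
    """Hypothetical, simplified dynamic loss scaler (not the PyTorch class)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, grads) -> bool:
        """Return True if the optimizer step should run, False if it is skipped."""
        if any(not math.isfinite(g) for g in grads):
            # Overflow detected: skip this step and shrink the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # Enough consecutive finite steps: try a larger scale again.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True


scaler = ToyGradScaler(growth_interval=2)
steps = ([0.1, math.inf], [0.1, 0.2], [0.3, 0.1])
applied = [scaler.update(grads) for grads in steps]
print(applied, scaler.scale)
```

With `growth_interval=2`, the first step is skipped and halves the scale, and the two finite steps that follow restore it.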

----

**Note:** When using TPUs, setting `precision="16-mixed"` enables bfloat16-based mixed precision, the only supported half-precision type on TPUs.

----

## 2.1 - Lightning Fabric

This is how you enable FP16 in Fabric:

In [None]:
# Select FP16 mixed precision
fabric = Fabric(precision="16-mixed")

# 3 - BFloat16 Mixed Precision

BFloat16 mixed precision is similar to FP16 mixed precision. However, it maintains more of the "dynamic range" that FP32 offers, at the cost of fewer bits of precision. This means it can improve numerical stability compared to FP16 mixed precision. For more information, see [this TPU performance blog post](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus).
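
The difference in dynamic range can be inspected directly with `torch.finfo`. The following sketch prints the largest representable value and the machine epsilon for each format, then casts one FP32 value (an arbitrary choice for illustration) that lies outside the FP16 range.

```python
import torch

# Compare range (max) and precision (eps) of the three formats.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} max={info.max:.3e}  eps={info.eps:.3e}")

value = torch.tensor(1e5)  # FP32 value larger than the FP16 maximum
as_fp16 = value.to(torch.float16)   # overflows to inf
as_bf16 = value.to(torch.bfloat16)  # stays finite, but is rounded coarsely
print(as_fp16, as_bf16)
```

BFloat16 keeps the value finite because it shares the exponent width of FP32, but its larger `eps` shows that it resolves fewer significant digits than FP16.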


----

**Note:** BFloat16 may not provide significant speedups or memory improvements; its main benefit is better numerical stability. For GPUs, the most significant benefits require [Ampere](https://en.wikipedia.org/wiki/Ampere_(microarchitecture)) based GPUs or newer, such as A100s or 3090s.

----

## 3.1 - Lightning Fabric

In [None]:
# Select BF16 precision
fabric = Fabric(precision="bf16-mixed")

Under the hood, it uses `torch.autocast` with the `dtype` set to `bfloat16`, with no gradient scaling. It is also possible to use BFloat16 mixed precision on the CPU, relying on MKLDNN.
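
As a rough illustration of that mechanism in plain PyTorch, the sketch below runs one forward and backward pass under `torch.autocast` on the CPU. The model, input shape, and loss are placeholders chosen for this example, and the code is not Fabric's actual implementation.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 4)  # parameters are created in float32
x = torch.randn(2, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)                      # computed in bfloat16
    loss = out.float().pow(2).mean()    # placeholder loss, reduced in float32

# No gradient scaler is involved: backward is called on the loss directly.
loss.backward()

print(out.dtype, model.weight.dtype, model.weight.grad.dtype)
```

The weights and their gradients remain in `float32`; only the operations that autocast supports run in `bfloat16`.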

# 4 - Float8 Mixed Precision via NVIDIA's TransformerEngine

[Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine) is a library for accelerating models on recent NVIDIA GPUs using 8-bit floating point (FP8) precision, available on Hopper GPUs and newer, to provide better performance with lower memory utilization in both training and inference. It offers improved performance over half precision with no degradation in accuracy.

Using TE requires replacing some of the layers in your model. Fabric automatically replaces the `torch.nn.Linear` and `torch.nn.LayerNorm` layers in your model with their TE alternatives. In addition, TE also offers fused layers to squeeze out all the possible performance.
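
To give an idea of what replacing layers involves, here is a minimal sketch of recursively swapping `torch.nn.Linear` modules for another implementation. `StandInLinear` and `swap_linear_layers` are hypothetical names introduced for this example; the stand-in class is an ordinary subclass rather than a TE layer, and the code does not reproduce how Fabric or TE perform the replacement.

```python
import torch
from torch import nn


class StandInLinear(nn.Linear):
    """Hypothetical placeholder for an optimized drop-in linear layer."""


def swap_linear_layers(module: nn.Module) -> int:
    """Replace every nn.Linear in `module` in place; return how many were swapped."""
    swapped = 0
    for name, child in list(module.named_children()):
        if type(child) is nn.Linear:
            replacement = StandInLinear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            # Carry over the existing weights so the model's outputs are unchanged.
            replacement.load_state_dict(child.state_dict())
            setattr(module, name, replacement)
            swapped += 1
        else:
            swapped += swap_linear_layers(child)
    return swapped


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 2)))
x = torch.randn(4, 8)
before = model(x)
count = swap_linear_layers(model)
after = model(x)
print(count, torch.allclose(before, after))
```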

----

**Note:** Float8 Mixed Precision requires Hopper based GPUs or newer, such as the H100.

----

#### What are fused layers?

In PyTorch, fused layers combine multiple operations into a single, optimized operation executed by the hardware (e.g., GPU, CPU).

**Key Benefits of Fused Layers:**

* Reduced Overhead: Fusing operations eliminates the need for separate kernel launches and context switching between them, which minimizes overhead costs associated with function calls, memory management, and intermediate data transfers.

* Improved Hardware Utilization: Fused operations often leverage hardware-specific optimizations, such as vectorization and instruction-level parallelism, leading to more efficient utilization of processing resources.

* Minimized Memory Traffic: Fusing operations can potentially decrease the amount of data transferred between memory and the processing unit (e.g., GPU) by reducing intermediate storage requirements.

**Commonly Fused Operations in PyTorch:**

* Activation Functions: Combining activation functions (e.g., ReLU, LeakyReLU) with the preceding linear layer operation often results in performance improvements.

* Batch Normalization (BN): Fusing BN with the convolutional layer can decrease computational overhead and improve inference speed.

* Element-wise Operations: Combining element-wise operations (e.g., addition, subtraction) with other operations can be beneficial.

**Factors Affecting Fusing Effectiveness:**

* Hardware Architecture: The level of performance gain from fusing layers depends on the specific hardware being used. For example, GPUs typically offer better support for fused operations compared to CPUs.

* Model Architecture: The structure of your deep learning model plays a role as well. Fusing operations within a complex architecture with multiple branches and convolutions might not yield significant benefits due to increased complexity and potentially limited hardware support.

**How to Leverage Fused Layers in PyTorch:**

PyTorch does not fuse layers automatically in eager mode, but several techniques can enable or encourage fusion:

* Utilize Existing Fusion Utilities: PyTorch provides helpers such as `torch.ao.quantization.fuse_modules` and `torch.nn.utils.fusion.fuse_conv_bn_eval`, which fold an `nn.BatchNorm2d` into the preceding `nn.Conv2d` for inference. When a convolution is directly followed by batch normalization, declaring it as `nn.Conv2d(..., bias=False)` avoids a redundant bias term.

* Rearrange Operations: Experiment with rearranging operations within your model's architecture. Sometimes, swapping the order of operations or grouping them strategically can lead to better hardware utilization and potential fusion opportunities.

* Profiling: Use profiling tools to identify performance bottlenecks in your model. This can help you determine if fusing operations might be beneficial for specific sections.

**Additional Considerations:**

* Fusing layers might not always lead to performance improvements. It's essential to evaluate the impact on your specific model and hardware through experimentation.

* Fused layers can sometimes complicate the model and make it less readable, making debugging and maintenance more challenging.
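
As a concrete instance of the batch normalization case listed above, the sketch below folds an `nn.BatchNorm2d` into the preceding `nn.Conv2d` with `torch.nn.utils.fusion.fuse_conv_bn_eval` and checks that the fused layer reproduces the two-layer result. The layer sizes and the random statistics are arbitrary assumptions, and this form of fusion applies to inference only, since both modules must be in eval mode.

```python
import torch
from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)

conv = nn.Conv2d(3, 8, kernel_size=3, bias=False)
bn = nn.BatchNorm2d(8)

# Give batch norm non-trivial statistics so the comparison is meaningful.
with torch.no_grad():
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.normal_()

# The helper requires both modules to be in eval mode.
conv.eval()
bn.eval()
fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))  # two separate operations
    fused_out = fused(x)     # one convolution with folded parameters

max_diff = (reference - fused_out).abs().max().item()
print(type(fused).__name__, max_diff)
```

The fused module is a single `Conv2d`, so the separate normalization pass over the activations is no longer needed at inference time.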



## 4.1 - Lightning Fabric

In [None]:
# Select 8bit mixed precision via TransformerEngine, with model weights in bfloat16
fabric = Fabric(precision="transformer-engine")

# Select 8bit mixed precision via TransformerEngine, with model weights in float16
fabric = Fabric(precision="transformer-engine-float16")

# Customize the fp8 recipe or set a different base precision:
import torch
from lightning.fabric.plugins import TransformerEnginePrecision

recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, recipe=recipe)
fabric = Fabric(plugins=precision)

# 5 - True Half Precision

As mentioned before, for numerical stability, mixed precision keeps the model weights in full `float32` precision while casting only supported operations to lower bit precision. However, in some cases it is indeed possible to train completely in half precision.

Similarly, for inference the model weights can often be cast to half precision without a loss in accuracy (even when trained with mixed precision).
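
In plain PyTorch, casting a trained model for inference can be sketched as follows. The small network is a placeholder with random weights, and bfloat16 is used here only because it runs on the CPU; whether the accuracy is preserved has to be checked on your own model and data.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

# Placeholder network standing in for a trained FP32 model.
model_fp32 = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10)).eval()

# Cast a copy of the weights to half precision for inference.
model_bf16 = copy.deepcopy(model_fp32).to(torch.bfloat16)

x = torch.randn(4, 32)
with torch.no_grad():
    out_fp32 = model_fp32(x)
    out_bf16 = model_bf16(x.to(torch.bfloat16))  # inputs must match the weight dtype

max_diff = (out_fp32 - out_bf16.float()).abs().max().item()
print(out_bf16.dtype, max_diff)
```

Comparing `max_diff` (or a task metric) between the two models is a simple way to decide whether the cast is acceptable.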

In [None]:
# Select FP16 precision
fabric = Fabric(precision="16-true")
model = MyModel()  # MyModel is a placeholder for your own torch.nn.Module
model = fabric.setup(model)  # model gets cast to torch.float16

# Select BF16 precision
fabric = Fabric(precision="bf16-true")
model = MyModel()
model = fabric.setup(model)  # model gets cast to torch.bfloat16

**Tip:** For faster initialization, you can create the model parameters with the desired dtype directly on the device:

In [None]:
fabric = Fabric(precision="bf16-true")

# init the model directly on the device and with parameters in half-precision
with fabric.init_module():
    model = MyModel()

model = fabric.setup(model)