# Fully Sharded Data Parallelism (FSDP2)

## FSDP is Data Parallelism...

Discussions about FSDP can get confusing, because it implements a lot of different techniques (some of them quite advanced). But the first thing to know is that it is data parallelism, just as we saw in [the DDP notebook](4_Distributed_data_parallel.ipynb), and as illustrated in the images there:

![Overview of data parallelism](images/data-par-1.png)

So each instance is responsible for training the entire model on a separate batch of data; you need something like `DistributedSampler` in your data loader, and so on. It's data parallelism.
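To make "each instance sees a separate batch of data" concrete, here is a toy, pure-Python stand-in for what a distributed sampler does. The function `partition_indices` is hypothetical; it mimics, as we understand it, how `torch.utils.data.DistributedSampler` assigns indices with `shuffle=False` and `drop_last=False` (pad by wrapping around, then give each rank every `num_replicas`-th index). In real code you would simply pass `DistributedSampler(dataset)` as the `sampler` of your `DataLoader`.

```python
def partition_indices(dataset_len, num_replicas, rank):
    """Hypothetical stand-in for a distributed sampler (no shuffling).

    Pads the index list by wrapping around so it divides evenly by
    num_replicas, then gives each rank every num_replicas-th index.
    Assumes dataset_len >= 1.
    """
    indices = list(range(dataset_len))
    total_size = -(-dataset_len // num_replicas) * num_replicas  # round up
    pad = total_size - dataset_len
    indices += (indices * num_replicas)[:pad]  # wrap around to pad
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    # Three "ranks" sharing a 10-example dataset: each gets 4 indices, and
    # together they cover every example (two are repeated as padding).
    for rank in range(3):
        print(rank, partition_indices(10, num_replicas=3, rank=rank))
```

Every rank gets the same number of examples, which keeps the ranks in lockstep; the cost is that a few examples are seen twice per epoch.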

## ... and FSDP also uses Model Parallelism Techniques (Amongst Others) To Reduce Memory Usage

FSDP differs by implementing a number of techniques to reduce memory usage, so that **FSDP can work even if the entire model won't fit on a single GPU**.  

The signature method, sharding, means that each replica persistently stores only a shard of the entire model, and the full state is materialized in place only when it is needed:

![Sharded data parallelism](images/sharded-data-par-2.png)

This means each GPU can train a replica of a model that is, in principle, significantly larger than the GPU's memory.

(FSDP uses [DTensors](https://docs.pytorch.org/docs/stable/distributed.tensor.html) and [Device Meshes](https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.device_mesh.DeviceMesh), from the tensor parallelism framework, to handle the sharding. This isn't tensor parallelism, though; computation isn't parallelized over pieces of tensors. Each replica trains the entire model over its own subset of batches; it's the _persistent storage_ of tensor shards that is distributed.)
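Here is a toy, pure-Python sketch of that idea: a parameter (a flat list standing in for a tensor) is split into one contiguous chunk per rank, and an all-gather materializes the full copy on every rank only while it's needed. The helpers `shard` and `all_gather` are hypothetical, and contiguous chunking is our simplifying assumption about layout; real FSDP2 keeps shards as DTensors and communicates with `torch.distributed` collectives (such as `all_gather_into_tensor`) rather than Python lists.

```python
def shard(flat_param, world_size):
    """Split a parameter into world_size contiguous chunks, one per rank.

    Chunk size is rounded up, so trailing ranks may get shorter chunks.
    """
    chunk = -(-len(flat_param) // world_size)  # ceil division
    return [flat_param[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards):
    """Toy all-gather: every rank ends up with the concatenation of all shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]  # one full copy per rank


if __name__ == "__main__":
    weight = list(range(10))              # stand-in for one layer's parameters
    local = shard(weight, world_size=4)   # what each rank stores persistently
    print([len(s) for s in local])
    full_on_every_rank = all_gather(local)  # materialize just before compute
    assert all(f == weight for f in full_on_every_rank)
    del full_on_every_rank                # "free" the full copy after use
```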

My clumsy diagrams above probably make it look like only the model parameters are sharded, but in fact GPU memory is needed for parameters, gradients, and the potentially quite large optimizer state, and all of those can be sharded (or at least not persisted):

![Diagram showing a memory-use graph demonstrating sharding of parameters, gradients, and optimizer state, from the ZeRO paper](images/ZeRO.png)

The figure above is from the paper "[ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054)", which described and implemented these approaches; PyTorch's implementation of these methods is FSDP. The math to the side sketches out the memory requirements: if there are $x$ parameters and we're using FP16 (2 bytes per parameter), the sizes of the different kinds of state are:

* Parameters - $2x$
* Gradients - $2x$
* Optimizer State (for, say, Adam) - $12x$
  * Parameter copy $4x$ (4 bytes for float32)
  * Momentum $4x$
  * Variance $4x$
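The arithmetic in the list above is easy to turn into a small calculator. This is a sketch under the same assumptions as the ZeRO figure (FP16 parameters and gradients, 12 bytes per parameter of Adam state); it counts only model state and ignores activations, buffers, and temporarily all-gathered copies. The function `per_gpu_bytes` and its `sharded` argument are our own invention, not part of any library.

```python
# Bytes per parameter, from the list above (mixed-precision Adam).
PARAM_BYTES = 2   # FP16 parameters
GRAD_BYTES = 2    # FP16 gradients
OPTIM_BYTES = 12  # FP32 parameter copy + momentum + variance


def per_gpu_bytes(num_params, world_size, sharded=()):
    """Approximate model-state memory per GPU.

    `sharded` names which kinds of state are split across `world_size`
    ranks: any subset of {"optim", "grads", "params"}. Everything else
    is assumed to be fully replicated on every GPU.
    """
    sizes = {"params": PARAM_BYTES, "grads": GRAD_BYTES, "optim": OPTIM_BYTES}
    total = 0.0
    for kind, bytes_per_param in sizes.items():
        divisor = world_size if kind in sharded else 1
        total += bytes_per_param * num_params / divisor
    return total


if __name__ == "__main__":
    psi, gpus = 7.5e9, 64
    for stage in [(), ("optim",), ("optim", "grads"), ("optim", "grads", "params")]:
        gb = per_gpu_bytes(psi, gpus, sharded=stage) / 1e9
        print(f"sharding {stage or 'nothing'}: {gb:.1f} GB per GPU")
```

With 7.5 billion parameters on 64 GPUs (the example we believe the ZeRO figure uses), sharding nothing, then the optimizer state, then gradients as well, then parameters as well gives roughly 120 GB, 31.4 GB, 16.6 GB and 1.9 GB per GPU. The optimizer state dominates, which is why sharding it alone already buys the biggest drop.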

When it comes time to materialize an entire layer and its associated state to compute on it, there are many different options, with different computation vs. memory vs. communication trade-offs:

* Do we free this materialized layer and all its state immediately afterwards, or do we keep some of it around even though that requires memory?
* Do we prefetch upcoming layers, so that the data is already in place when it's time to start the computation, or do we save the memory until it's needed?
* Is there state that can be offloaded to CPU memory rather than relying on a possibly off-node GPU copy?
* Can some of the state be recomputed rather than copied?
* Can we use reduced precision for some of the data we're copying over?
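The first two questions above are a memory-versus-communication trade-off that a toy simulation can make visible. The sketch below is hypothetical and heavily simplified: it tracks only parameter memory for a stack of layers, counts all-gathers, and ignores activations, gradients, and the overlap of communication with compute. In FSDP2 the real knob for the first question is the `reshard_after_forward` argument of `fully_shard`; the `prefetch` flag here is just our stand-in for fetching the next layer early.

```python
def simulate_step(layer_bytes, world_size, reshard_after_forward=True, prefetch=False):
    """Toy model of one forward+backward pass over a stack of sharded layers.

    Returns (peak_bytes, num_all_gathers). Every rank always holds its
    1/world_size shard of each layer; a materialized (all-gathered) layer
    adds the remaining (1 - 1/world_size) of that layer's size.
    """
    n = len(layer_bytes)
    persistent = sum(b / world_size for b in layer_bytes)
    materialized = set()  # indices of layers currently held in full
    stats = {"peak": persistent, "gathers": 0}

    def gather(i):
        # Materialize layer i unless it is out of range or already present.
        if 0 <= i < n and i not in materialized:
            materialized.add(i)
            stats["gathers"] += 1

    def record_peak():
        extra = sum(layer_bytes[i] * (1 - 1 / world_size) for i in materialized)
        stats["peak"] = max(stats["peak"], persistent + extra)

    # Forward pass: layers 0 .. n-1.
    for i in range(n):
        gather(i)
        if prefetch:
            gather(i + 1)  # fetch the next layer before it is needed
        record_peak()
        if reshard_after_forward:
            materialized.discard(i)  # free the full copy right away

    # Backward pass: layers n-1 .. 0, always resharding afterwards here.
    for i in reversed(range(n)):
        gather(i)  # a no-op if the layer was kept from the forward pass
        if prefetch:
            gather(i - 1)
        record_peak()
        materialized.discard(i)

    return stats["peak"], stats["gathers"]


if __name__ == "__main__":
    layers = [100.0] * 4
    for reshard in (True, False):
        for prefetch in (False, True):
            peak, gathers = simulate_step(layers, 4, reshard, prefetch)
            print(f"reshard={reshard} prefetch={prefetch} peak={peak} all-gathers={gathers}")
```

For four layers of 100 units each on 4 ranks, resharding after forward keeps the peak at 175 units but needs 8 all-gathers per step; keeping layers materialized after forward cuts that to 4 all-gathers at a peak of 400 units (the whole model). Prefetching with resharding raises the peak to 250 while leaving the all-gather count at 8: it buys overlap of communication with compute, not fewer transfers.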

This range of options (not all of them relevant to every kind of state) can make FSDP a little intimidating to learn about; there are so many knobs that can be turned! But getting started with FSDP is fairly easy, and its defaults are pretty good. Our goal should normally be to get our model working under FSDP first, so that each GPU only has to hold a sharded version of it; if it doesn't fit right away, we can be aggressive about sharding and reducing memory needs until things start working. Then, at our leisure, we can start playing with options and turning things back on to see if they speed up training.

Let's get started. While all of the above is notionally fairly complex, the FSDP2 APIs hide almost all of it from us! Here is a simple example, slightly modified from the [PyTorch Getting Started with FSDP2 guide](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html):

* [Before FSDP](code/fsdp-example-single.py)
* [After](code/fsdp-example-multi-simple.py)

We can run them as below:

!python3 code/fsdp-example-single.py

!./code/run_w_torchrun.sh 2  ./code/fsdp-example-multi-simple.py

Note how simple this is! There are really only three changes:

* The now usual `torch.distributed` boilerplate for any distributed job
* Wrapping the model layers and then the full model with `fully_shard`
* A somewhat more complicated `torch.save`, so that we can reconstitute the full model in CPU memory. (This isn't necessary, desirable, or even possible for very large models, but for models run on a modest number of ranks it is simple and convenient. Otherwise we'd do something more like the pipeline parallelism approach.)

And that's it!
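Conceptually, the extra work around `torch.save` is consolidation: the shards held by different ranks are gathered and stitched back into one full state dict on a single rank, in CPU memory. The toy `consolidate` function below is hypothetical, pure Python, and shows only that stitching step, assuming each rank holds a contiguous chunk along the first dimension. In PyTorch, one real route for this is `torch.distributed.checkpoint.state_dict.get_model_state_dict(model, options=StateDictOptions(full_state_dict=True, cpu_offload=True))`; we haven't checked that this is exactly what the linked script does.

```python
def consolidate(per_rank_state):
    """Toy consolidation of sharded state into one full state dict.

    per_rank_state[r] maps parameter name -> the values rank r holds (its
    contiguous chunk). In a real job these chunks would first be gathered
    over the network onto one rank and copied to CPU memory.
    """
    names = per_rank_state[0].keys()
    return {
        name: [x for rank_state in per_rank_state for x in rank_state[name]]
        for name in names
    }


if __name__ == "__main__":
    rank0 = {"weight": [1.0, 2.0], "bias": [0.5]}
    rank1 = {"weight": [3.0, 4.0], "bias": [-0.5]}
    full = consolidate([rank0, rank1])
    print(full)
    # A real script would now save `full` (e.g. with torch.save) on rank 0 only.
```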

## The FSDP2 training workflow

![The FSDP training workflow; from the FSDP paper](images/fsdp-workflow.png)

The training workflow we've described is shown above. We set up our model and shard it; then, in each forward pass, we all-gather the shards of a layer, compute with it, and move on to the next layer. When it comes time to back-propagate through that layer, we all-gather it again and do the backward computation, then perform our usual data-parallel reduction of the gradients, except that this time it is a reduce-scatter: each rank receives only the reduced gradients for its own shard, since no one needs all the gradients. Then comes the optimizer step, with each rank updating only its own shard, and we continue.
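The gradient side of that workflow can be sketched the same toy way. `reduce_scatter_mean` below (hypothetical, pure Python) averages the per-rank gradients, as DDP's all-reduce does, but hands each rank only the chunk matching its own parameter shard; `sgd_step` then updates just that shard. Plain SGD stands in for Adam to keep it short, and the closest real collective in `torch.distributed` is `reduce_scatter_tensor`.

```python
def chunk_bounds(n, world_size, rank):
    """Contiguous [lo, hi) range of a length-n parameter owned by `rank`."""
    chunk = -(-n // world_size)  # ceil division
    return min(rank * chunk, n), min((rank + 1) * chunk, n)


def reduce_scatter_mean(per_rank_grads):
    """Toy reduce-scatter: rank r receives the mean of chunk r over all ranks."""
    world_size = len(per_rank_grads)
    n = len(per_rank_grads[0])
    out = []
    for r in range(world_size):
        lo, hi = chunk_bounds(n, world_size, r)
        out.append([sum(g[i] for g in per_rank_grads) / world_size
                    for i in range(lo, hi)])
    return out


def sgd_step(param_shard, grad_shard, lr):
    """Each rank updates only the parameter shard (and optimizer state) it owns."""
    return [p - lr * g for p, g in zip(param_shard, grad_shard)]


if __name__ == "__main__":
    params = [1.0, 2.0, 3.0, 4.0]
    grads = [[1.0, 1.0, 1.0, 1.0],  # rank 0's gradient on its own batch
             [3.0, 3.0, 3.0, 3.0]]  # rank 1's gradient on its own batch
    world_size = len(grads)
    grad_shards = reduce_scatter_mean(grads)
    param_shards = [params[slice(*chunk_bounds(len(params), world_size, r))]
                    for r in range(world_size)]
    new_shards = [sgd_step(p, g, lr=0.5) for p, g in zip(param_shards, grad_shards)]
    # Same result as DDP: all-reduce (mean) the full gradient, update everything.
    mean_grad = [sum(col) / world_size for col in zip(*grads)]
    assert [x for s in new_shards for x in s] == sgd_step(params, mean_grad, lr=0.5)
    print(new_shards)
```

The final check is the point: sharding changes who stores and updates which parameters, not the result of the data-parallel update.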

