# Large Model Initialization

In this notebook, we study the initialization of very large models. There are several optimizations for large model initialization, and several ways to implement each of them. Here we study their impact on:

1. peak memory
2. running time

## Memory Footprint

Suppose we have a pretrained checkpoint of a model. The peak memory of model initialization can then be reduced by:

1. checkpoint sharding
2. deferred model materialization
3. data type (model quantization)
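To make item 3 concrete: the memory needed just to hold the weights scales with the bytes per element of the chosen dtype. The sketch below uses a hypothetical parameter count of 7 billion and counts raw weight storage only (no buffers, no framework overhead, and no per-group scales that real int8 quantization schemes also store).

```python
import torch

def weight_bytes(num_params: int, dtype: torch.dtype) -> int:
    """Bytes needed to store `num_params` elements of `dtype` (weights only)."""
    # element_size() reports the per-element storage size for this dtype.
    bytes_per_elem = torch.empty(0, dtype=dtype).element_size()
    return num_params * bytes_per_elem

# Hypothetical parameter count, used only for illustration.
NUM_PARAMS = 7_000_000_000

for dtype in (torch.float32, torch.float16, torch.int8):
    gb = weight_bytes(NUM_PARAMS, dtype) / 10**9
    print(f"{str(dtype):>14}: {gb:.0f} GB")
```

This prints 28, 14 and 7 GB respectively, so halving the element size halves every `model size` term discussed below.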

A typical model initialization process is: 1) materialize a model with random weights; 2) load the checkpointed weights; 3) copy the weights into the model. Unfortunately, the peak memory of such a process is `twice the model size` (random weights + checkpointed weights).
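A minimal sketch of this three-step process in native PyTorch, assuming a tiny hypothetical `TinyModel` and an in-memory buffer standing in for the checkpoint file. Between steps 1 and 3 both the random weights and the checkpoint tensors are alive, which is where the `twice the model size` peak comes from.

```python
import io

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Hypothetical stand-in for a large model.
    def __init__(self, dim: int = 64):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Pretend this buffer is a checkpoint file on disk.
reference = TinyModel().state_dict()
buffer = io.BytesIO()
torch.save(reference, buffer)
buffer.seek(0)

# 1) Materialize a model with random weights (allocates one model size).
model = TinyModel()
# 2) Load the checkpointed weights (allocates a second model size).
state_dict = torch.load(buffer, map_location="cpu")
# 3) Copy the checkpoint values into the existing parameter storage.
model.load_state_dict(state_dict)
# Only now can the checkpoint tensors be released.
del state_dict
```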

One way to remedy this is to use a sharded checkpoint: we load one shard of the checkpoint at a time and assign its weights to the model before loading the next. This process has a peak memory of `model size + shard size`.
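A sketch of both sides of this idea in native PyTorch, assuming a simple greedy split by byte size and a hypothetical `shard-XXXXX.pt` naming scheme (this is not the HuggingFace sharded format). Loading uses `load_state_dict(strict=False)` because each shard covers only a subset of the parameters.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

def save_sharded(state_dict, out_dir: Path, max_shard_bytes: int) -> list[Path]:
    """Greedily split a state dict into files of at most ~max_shard_bytes each."""
    shards, current, current_bytes = [], {}, 0
    for name, tensor in state_dict.items():
        nbytes = tensor.numel() * tensor.element_size()
        # Start a new shard if adding this tensor would exceed the budget.
        if current and current_bytes + nbytes > max_shard_bytes:
            shards.append(current)
            current, current_bytes = {}, 0
        current[name] = tensor
        current_bytes += nbytes
    if current:
        shards.append(current)
    paths = []
    for i, shard in enumerate(shards):
        path = out_dir / f"shard-{i:05d}.pt"  # hypothetical naming scheme
        torch.save(shard, path)
        paths.append(path)
    return paths

def load_sharded(model: nn.Module, paths: list[Path]) -> None:
    """Load one shard at a time, so at most one shard is held besides the model."""
    loaded = set()
    for path in paths:
        shard = torch.load(path, map_location="cpu")
        # strict=False: missing keys are expected, they live in other shards.
        result = model.load_state_dict(shard, strict=False)
        if result.unexpected_keys:
            raise KeyError(f"unexpected keys in {path}: {result.unexpected_keys}")
        loaded.update(shard)
        del shard  # release this shard before reading the next one
    missing = set(model.state_dict()) - loaded
    if missing:
        raise KeyError(f"missing from all shards: {sorted(missing)}")

# Usage: copy weights from `src` into `dst` through 4 KB shards.
src = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
dst = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
with tempfile.TemporaryDirectory() as tmp:
    paths = save_sharded(src.state_dict(), Path(tmp), max_shard_bytes=4096)
    load_sharded(dst, paths)
```

The peak during loading is the model plus the largest shard, so `max_shard_bytes` directly trades the number of files against peak memory.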

To achieve a peak memory of `model size`, we can first initialize the model as an empty shell and then materialize it directly from the loaded checkpoint. In PyTorch, this requires initializing the model on the `meta` device, where tensors carry shape and dtype information but no data.

Here we compare these ideas using two implementations: HuggingFace and native PyTorch.

> _NOTE:_ We will not use `torchdistx` and its `deferred_init`, as it does not work with PyTorch 2.0+.

```python
import torch
import tracemalloc

from pathlib import Path
```

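```python
tracemalloc.start()
original, _ = tracemalloc.get_traced_memory()

llama_weight_path = Path('/project/llama/7B')
# Load the LLaMA-7B checkpoint file into CPU memory
weights = torch.load(
    llama_weight_path / 'consolidated.00.pth', map_location='cpu')

current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

print(f"Original memory usage is {original / 10**6}MB; Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
```

```
Original memory usage is 0.000704MB; Current memory usage is 0.259698MB; Peak was 0.591343MB
```

> _NOTE:_ `tracemalloc` only traces memory blocks allocated through Python's own allocator. PyTorch allocates tensor storage through its C++ allocator, so the many gigabytes of checkpoint weights do not show up in these numbers; the sub-megabyte values above only cover Python-level objects such as the state dict and the tensor wrappers.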
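Because `tracemalloc` misses PyTorch tensor storage, one alternative is to ask the OS for the process's peak resident set size. This sketch uses the standard-library `resource` module (Unix only) and a synthetic allocation in place of the LLaMA checkpoint; swapping in the `torch.load(...)` call is left to the reader. Peak RSS is a process-wide high-water mark that never decreases, so each configuration is best measured in a fresh process.

```python
import resource
import sys

import torch

def peak_rss_mb() -> float:
    """Peak resident set size of this process so far, in MB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    scale = 1 if sys.platform == "darwin" else 1024
    return peak * scale / 10**6

before = peak_rss_mb()
# ~200 MB of float32 storage allocated by PyTorch's own allocator; torch.ones
# writes every element, so the pages actually become resident.
x = torch.ones(50_000_000, dtype=torch.float32)
after = peak_rss_mb()
print(f"Peak RSS grew from {before:.0f}MB to {after:.0f}MB")
```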


```python
weights['layers.30.feed_forward.w2.weight']
```

```
tensor([[-0.0199,  0.0213,  0.0250,  ..., -0.0610,  0.0007, -0.0143],
        [ 0.0409,  0.0204,  0.0125,  ...,  0.0070, -0.0222,  0.0151],
        [ 0.0217,  0.0078,  0.0133,  ...,  0.0112,  0.0403,  0.0081],
        ...,
        [-0.0366,  0.0066,  0.0679,  ..., -0.0173, -0.0131,  0.0312],
        [ 0.0116, -0.0162,  0.0045,  ...,  0.0458,  0.0015, -0.0046],
        [ 0.0330,  0.0108, -0.0049,  ..., -0.0088,  0.0036, -0.0050]],
       dtype=torch.float16)
```
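The `float16` dtype shown above matters for the memory estimates: each element takes 2 bytes. A small pair of helpers, sketched here under the assumption that the loaded checkpoint is a plain dict of tensors, can total the tensor bytes of a state dict (the `model size` term in the peak-memory formulas) and confirm which dtypes it contains. Tensors that share storage would be counted more than once by this simple sum.

```python
import torch

def state_dict_nbytes(state_dict: dict) -> int:
    """Total bytes of tensor data in a state dict (non-tensors are skipped)."""
    return sum(
        t.numel() * t.element_size()
        for t in state_dict.values()
        if isinstance(t, torch.Tensor)
    )

def dtype_counts(state_dict: dict) -> dict:
    """Number of tensors per dtype, e.g. to confirm a checkpoint is float16."""
    counts = {}
    for t in state_dict.values():
        if isinstance(t, torch.Tensor):
            counts[t.dtype] = counts.get(t.dtype, 0) + 1
    return counts

# Synthetic example; in the notebook, pass `weights` instead.
example = {
    "w": torch.zeros(4, 8, dtype=torch.float16),  # 32 elements * 2 bytes
    "b": torch.zeros(8, dtype=torch.float32),     # 8 elements * 4 bytes
}
print(state_dict_nbytes(example), "bytes")
print(dtype_counts(example))
```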