# FSDP example with Llama 3.2 1B Instruct and the openassistant-guanaco dataset
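In this example, a network is trained on multiple GPUs with the help of FSDP (Fully Sharded Data Parallel). FSDP splits (shards) the model's parameters, gradients and optimizer states across the GPUs instead of keeping a full copy on each one. This makes it possible to train networks that are too large to fit into the memory of a single GPU.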

If we want to use multiple GPUs, we need to write the code to a file and submit the job to the SLURM scheduler, because the JupyterHub that we are using today does not have access to any GPU. This example uses two GPUs on one node, but it can easily be extended by adjusting the number of GPUs and nodes in the SLURM script.

#### First, we write the python code to a file:

In [1]:
%%writefile llama_guanaco_fsdp.py
# Import libraries
import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTTrainer, SFTConfig
import pynvml

def print_gpu_utilization():
    """Print the memory currently used on every GPU that NVML reports.

    The values are printed in units of 1024**3 bytes, on a single line joined
    with ' + '. NVML is initialised for the query and shut down again
    afterwards, even if a query raises, so repeated calls do not leak the
    NVML reference.
    """
    pynvml.nvmlInit()
    try:
        memory_used = [
            pynvml.nvmlDeviceGetMemoryInfo(
                pynvml.nvmlDeviceGetHandleByIndex(device_index)
            ).used / 1024**3
            for device_index in range(pynvml.nvmlDeviceGetCount())
        ]
    finally:
        # Every nvmlInit() must be balanced by an nvmlShutdown().
        pynvml.nvmlShutdown()
    print('Memory occupied on GPUs: ' + ' + '.join([f'{mem:.1f}' for mem in memory_used]) + ' GB.')


# Choose a model and load the tokenizer and model (using 4-bit quantization):
# model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "/leonardo_scratch/fast/EUHPC_D20_063/huggingface/models/meta-llama--Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# For some models (such as Llama-3.2-1B-Instruct), we need to set a padding token and the padding side:
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'


# For multi-GPU training, find out how many processes there are and which one this is.
# NOTE(review): PartialState is created before the model is loaded; presumably
# from_pretrained relies on the distributed setup under FSDP, so keep this order.
ps = PartialState()
num_processes = ps.num_processes
process_index = ps.process_index              # global rank across all nodes
local_process_index = ps.local_process_index  # rank within this node

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.bfloat16,
        # Added for FSDP: store the packed 4-bit weights as bf16, i.e. the same
        # dtype as torch_dtype below, so all parameters share one dtype.
        bnb_4bit_quant_storage=torch.bfloat16,
    ),
    # device_map={'':local_process_index},  # Removed for FSDP
    attn_implementation='eager',  # 'eager', 'sdpa', or "flash_attention_2"
    # trust_remote_code is intentionally not set: Llama is supported natively by
    # transformers, so there is no reason to allow executing Python code found
    # in the (shared) model directory.
    torch_dtype=torch.bfloat16,
)

# Load the train and test splits of the guanaco dataset from the local copy.
# To download from the Hugging Face Hub instead, replace the path below with
# 'timdettmers/openassistant-guanaco'.
guanaco_train, guanaco_test = (
    load_dataset('/leonardo_scratch/fast/EUHPC_D20_063/huggingface/datasets/timdettmers--openassistant-guanaco', split=split_name)
    for split_name in ('train', 'test')
)

def reformat_text(text, include_answer=True, chat_tokenizer=None):
    """Convert a raw guanaco sample into the model's chat-template format.

    The sample is split on '###'. The first segment after the split is taken
    as the user question (prefix ' Human: ' removed) and the second as the
    assistant answer (prefix ' Assistant: ' removed). Any further turns are
    ignored.

    Args:
        text: Raw sample text.
        include_answer: If True, the answer is included as an assistant turn
            (training format). If False, only the user turn is rendered,
            followed by the generation prompt, so the result can be used as
            an inference prompt.
        chat_tokenizer: Tokenizer whose ``apply_chat_template`` is used.
            Defaults to the module-level ``tokenizer``.

    Returns:
        The chat-formatted string (not tokenized).

    Raises:
        ValueError: If ``text`` does not contain both a question and an answer
            segment.
    """
    if chat_tokenizer is None:
        chat_tokenizer = tokenizer
    segments = text.split('###')
    if len(segments) < 3:
        raise ValueError(f'Unexpected guanaco sample format: {text[:80]!r}')
    question1 = segments[1].removeprefix(' Human: ')
    answer1 = segments[2].removeprefix(' Assistant: ')
    messages = [{'role': 'user', 'content': question1}]
    if include_answer:
        messages.append({'role': 'assistant', 'content': answer1})
    return chat_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        # Without an answer, end with the assistant header so that generation
        # starts in the assistant turn.
        add_generation_prompt=not include_answer,
    )

# Now, apply reformat_text(..) to both datasets. The result is stored in a new
# 'reformatted_text' column, which SFTConfig's dataset_text_field points to.
def _add_reformatted_text(entry):
    """Return the new 'reformatted_text' column for one dataset entry."""
    return {'reformatted_text': reformat_text(entry['text'])}


# Every process runs this code. main_process_first() lets the main process do
# the mapping (and write any cache files) before the other processes start,
# instead of all processes writing the same cache at the same time.
with ps.main_process_first():
    guanaco_train = guanaco_train.map(_add_reformatted_text)
    guanaco_test = guanaco_test.map(_add_reformatted_text)

model.config.use_cache = False  # KV cache can only speed up inference, but we are doing training.

# Add low-rank adapters (LoRA) to the model:
peft_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=16,
    lora_alpha=32,  # rule of thumb: lora_alpha should be 2*r
    lora_dropout=0.05,
    bias='none',
    target_modules='all-linear',
)

training_arguments = SFTConfig(
    output_dir='output/llama-3.2-1b-instruct-guanaco-fsdp',
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,  # Gradient checkpointing improves memory efficiency, but slows down training,
        # e.g. Mistral 7B with PEFT using bitsandbytes:
        # - enabled: 11 GB GPU RAM and 8 samples/second
        # - disabled: 40 GB GPU RAM and 12 samples/second
    gradient_checkpointing_kwargs={'use_reentrant': False},  # Use newer implementation that will become the default.
    # We don't need the following two lines for FSDP (compared to DDP):
    # ddp_find_unused_parameters=False,  # Set to False when using gradient checkpointing to suppress warning message.
    # log_level_replica='error',  # Disable warnings in all but the first process.
    optim='adamw_torch',
    learning_rate=2e-4,  # QLoRA suggestions: 2e-4 for 7B or 13B, 1e-4 for 33B or 65B
    logging_strategy='no',
    # logging_strategy='steps',  # 'no', 'epoch' or 'steps'
    # logging_steps=10,
    save_strategy='no',  # 'no', 'epoch' or 'steps'; the final model is saved explicitly at the end
    # save_steps=2000,
    # num_train_epochs=5,
    max_steps=100,
    bf16=True,  # mixed precision training
    report_to='none',  # disable wandb
    max_length=1024,
    dataset_text_field='reformatted_text',  # column added by the .map() calls on the datasets
)

# NOTE(review): no eval_strategy is set in SFTConfig, so presumably eval_dataset
# is only prepared and not evaluated during training.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=training_arguments,
    train_dataset=guanaco_train,
    eval_dataset=guanaco_test,
    processing_class=tokenizer,
)

if ps.is_main_process:  # Only print in the first process.
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

train_result = trainer.train()
if ps.is_main_process:
    print("Training result:")
    print(train_result)

# Print memory usage once per node (print_gpu_utilization reports all GPUs it can see):
if ps.is_local_main_process:
    print_gpu_utilization()

# Save model:
# Note: This must be executed on all Python processes, not only on the first one.
# With FSDP the weights are sharded across the processes, so every process has
# to take part in gathering them.
trainer.save_model()

Overwriting llama_guanaco_fsdp.py


#### Next, we write a file with the configuration for FSDP:

In [2]:
%%writefile fsdp_config.yml
# accelerate configuration for FSDP. The SLURM script passes this file via
# `accelerate launch --config_file` and also passes --num_machines,
# --num_processes and --machine_rank on the command line.
# NOTE(review): presumably those CLI values take precedence over num_machines,
# num_processes and machine_rank below; confirm for the installed accelerate version.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  # NOTE(review): meant to avoid a full model copy in CPU RAM on every process;
  # accelerate expects fsdp_sync_module_states: true with it (as set below).
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  # Shard parameters, gradients and optimizer states across all processes.
  fsdp_sharding_strategy: FULL_SHARD
  # Gather a full (unsharded) state dict when saving.
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1  # see header note: the SLURM script passes --num_processes explicitly
rdzv_backend: c10d
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Overwriting fsdp_config.yml


#### Then, we write the SLURM script:

In [3]:
%%writefile run_llama_guanaco_fsdp.slurm
#!/bin/bash

#SBATCH --partition=boost_usr_prod
# #SBATCH --qos=boost_qos_dbg
#SBATCH --account=EUHPC_D20_063
#SBATCH --reservation=s_tra_ncc

## Specify resources:
## Leonardo Booster: 32 CPU cores and 4 GPUs per node => request 8 * number of GPUs CPU cores
## Leonardo Booster: 512 GB in total => request approx. 120 GB * number of GPUs requested
#SBATCH --nodes=1
#SBATCH --gpus-per-task=2  # up to 4 on Leonardo
#SBATCH --ntasks-per-node=1  # always 1
#SBATCH --mem=240GB  # should be 120GB * gpus-per-task on Leonardo
#SBATCH --cpus-per-task=16  # should be 8 * gpus-per-task on Leonardo

#SBATCH --time=0:30:00

# Include commands in output:
set -x

# Print current time and date:
date

# Print host name:
hostname

# List available GPUs:
nvidia-smi

# Construct command to run container:
export CONTAINER="singularity run --nv --home=$HOME $SINGULARITY_CONTAINER"

# Set environment variables for communication between nodes:
export MASTER_PORT=$(shuf -i 20000-30000 -n 1)  # Choose a random port
export MASTER_ADDR=$(scontrol show hostnames ${SLURM_JOB_NODELIST} | head -n 1)
export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK

# accelerate starts one process per GPU of the node, so split the task's CPU
# cores evenly among them. With cpus-per-task = 8 * gpus-per-task this gives 8,
# but it stays correct if the resource request above is changed.
THREADS_PER_PROCESS=$((SLURM_CPUS_PER_TASK / SLURM_GPUS_ON_NODE))

# Set launcher and launcher arguments.
# \$SLURM_PROCID is escaped so that it is expanded inside each srun task (one
# task per node), which gives every node its own machine rank.
export LAUNCHER="accelerate launch \
    --num_machines $SLURM_NNODES \
    --num_processes $((SLURM_NNODES * SLURM_GPUS_ON_NODE)) \
    --num_cpu_threads_per_process $THREADS_PER_PROCESS \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --config_file \"fsdp_config.yml\" \
    "
# Set training script that will be executed:
export PROGRAM="llama_guanaco_fsdp.py"

# Run:
time srun bash -c "$CONTAINER $LAUNCHER $PROGRAM"

Overwriting run_llama_guanaco_fsdp.slurm


#### We can now submit the SLURM script and, once the job has run, look at the output:

In [4]:
!sbatch --job-name=$TRAINEE_USERNAME run_llama_guanaco_fsdp.slurm

Submitted batch job 19817793


In [5]:
!squeue --name=$TRAINEE_USERNAME

             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)
          19817793 boost_usr   martin mpfister  R       0:04      1 lrdn3366


In [6]:
!cat slurm-19817793.out

+ date
Wed Sep 10 20:33:52 CEST 2025
+ hostname
lrdn3366.leonardo.local
+ nvidia-smi
Wed Sep 10 20:33:52 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.03              Driver Version: 535.54.03    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM-64GB           On  | 00000000:1D:00.0 Off |                    0 |
| N/A   44C    P0             ERR! / 470W |      2MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+-------

#### Finally, we can clean up and delete the files that we just created:

In [9]:
!rm fsdp_config.yml llama_guanaco_fsdp.py run_llama_guanaco_fsdp.slurm slurm-*.out