<a href="https://www.kaggle.com/code/liuserr/notebook2313628fdd?scriptVersionId=225601960" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Apache 2.0 License


Copyright (C) 2025 Cole Liu

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Task B: FSDP2 with QLoRA
This notebook demonstrates a functional application of QLoRA using FSDP2. The acceleration is moderate, due to large overhead. Here are the three loss curves generated. 1 is using FSDP1 on a single GPU. 2 is using FSDP2 without gradient accumulation, making it more susceptible to memory overload. 3 is using FSDP2 with gradient accumulation. As you can see, the loss curves are the same for all configurations.

*FSDP1*

![FSDP1](loss_plot.png "FSDP1") 

*FSDP2 No Gradient Accum*

![alt text](loss_plot_v2.png "FSDP2 No Gradient Accum") 

*FSDP2 Gradient Accum*

![alt text](loss_plot_gradient_accum.png "FSDP2 Gradient Accum")

In [1]:
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29 peft trl triton
!pip install --no-deps cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
!pip install --no-deps unsloth

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting xformers==0.0.29
  Downloading xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting trl
  Downloading trl-0.15.2-py3-none-any.whl.metadata (11 kB)
Collecting triton
  Downloading triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading xformers-0.0.29-cp310-cp310-manylinux_2_28_x86_64.whl (15.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.3/15.3 MB[0m [31m80.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl (76.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading trl-0.15.2-py3-none-any.whl (318 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m318.9/318.9 kB[0m [31m16.

In [2]:
!pip install torch accelerate transformers bitsandbytes peft



In [3]:
!pip install --upgrade  torchvision

Collecting torchvision
  Downloading torchvision-0.21.0-cp310-cp310-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torch==2.6.0 (from torchvision)
  Downloading torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl.metadata (28 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch==2.6.0->torchvision)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch==2.6.0->torchvision)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch==2.6.0->torchvision)
  Dow

In [4]:
!pip install --upgrade --force-reinstall "numpy<2.0"

Collecting numpy<2.0
  Downloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m79.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
unsloth 2025.2.15 requires tyro, which is not installed.
unsloth-zoo 2025.2.7 requires tyro, which is not installed.
cudf-cu12 24.12.0 requires 

In [5]:
!pip install importlib-metadata



# Single GPU FSDP1

![image](loss_plot.png)

In [None]:
%%writefile single_gpu.py
import os
import random
import numpy as np
import torch

seed = 3407
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# --- Environment Setup ---
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from accelerate import Accelerator
from functools import partial

# --- 1) Distributed Setup ---
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
else:
    local_rank = 0

accelerator = Accelerator(mixed_precision="bf16")
device = torch.device("cuda", local_rank)

# --- 2) Model & 4-bit Quantization ---
model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
    bnb_4bit_quant_storage=dtype,  # try storing quantized weights as BF16
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)

# --- 3) Apply LoRA ---
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

for name, param in model.named_parameters():
    if ".lora_" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# --- Integrate torch.compile (PyTorch 2.0 feature) ---
try:
    model = torch.compile(model)
    accelerator.print("torch.compile: Model compiled successfully.")
except Exception as e:
    accelerator.print(f"torch.compile: Compilation failed with error: {e}. Continuing without compilation.")

torch.cuda.empty_cache()

# --- Debug: Print buffer names and dtypes ---
if local_rank == 0:
    buffer_info = [(name, buf.dtype) for name, buf in model.named_buffers()]
    print("Buffers in model (name, dtype):")
    for name, dt in buffer_info:
        print(name, dt)

# --- 5) FSDP Wrapping with FSDP1 ---
# Ignore the frozen base module from FSDP sharding.
base_model_submodule = model.model  
ignored_modules = [base_model_submodule]

auto_wrap_policy = partial(
    size_based_auto_wrap_policy,
    min_num_params=5e7,
    recurse=True,
    nonwrapped_numel=0,
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # Ensure full sharding (FSDP1)
    use_orig_params=True,
    mixed_precision=MixedPrecision(param_dtype=dtype),
    ignored_modules=ignored_modules,
    device_id=device,
)

# --- 6) Tokenizer & Dataset ---
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "</s>"

dataset = load_dataset(
    "json",
    data_files={"train": "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"},
    split="train[:10%]"
)

def tokenize_example(ex):
    enc = tokenizer(
        ex["text"],
        max_length=512,
        padding="max_length",
        truncation=True,
    )
    enc["labels"] = enc["input_ids"].copy()
    return enc

dataset = dataset.map(tokenize_example, batched=True, remove_columns=["text"])
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
train_dataloader = accelerator.prepare(train_dataloader)

# --- 7) Optimizer ---
trainable_params = [p for p in model.parameters() if p.requires_grad]
try:
    from bitsandbytes.optim import Adam8bit
    optimizer = Adam8bit(trainable_params, lr=2e-4)
    accelerator.print("Using bitsandbytes Adam8bit optimizer.")
except ImportError:
    optimizer = torch.optim.AdamW(trainable_params, lr=2e-4)
    accelerator.print("Using torch.optim.AdamW optimizer.")
optimizer = accelerator.prepare(optimizer)

# --- 8) Training Loop with Gradient Accumulation ---
loss_history = []
model.train()
gradient_accumulation_steps = 16  # Adjust this value as needed

for step, batch in enumerate(train_dataloader, start=1):
    # The accelerator.accumulate context will accumulate gradients over the specified steps.
    with accelerator.accumulate(model):
        outputs = model(**batch)
        # Scale loss to account for gradient accumulation
        loss = outputs.loss / gradient_accumulation_steps
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        # Clear cache after the optimizer step to lower VRAM usage
        if accelerator.sync_gradients:
            torch.cuda.empty_cache()
    
    if accelerator.is_main_process:
        loss_history.append((step, outputs.loss.item()))
        accelerator.print(f"Step {step} - Loss: {outputs.loss.item():.4f}")
    
    # Terminate after 60 steps (micro-batches)
    if step >= 60:
        break
accelerator.wait_for_everyone()

if accelerator.is_main_process:
    try:
        import matplotlib.pyplot as plt
        steps, losses = zip(*loss_history) if loss_history else ([], [])
        if steps:
            plt.figure(figsize=(8, 4))
            plt.plot(steps, losses, marker="o")
            plt.xlabel("Training Step")
            plt.ylabel("Loss")
            plt.title("Loss over the First 60 Training Steps")
            plt.grid(True)
            plt.savefig("loss_plot.png")
            plt.show()
        else:
            print("No loss data recorded.")
    except ImportError:
        accelerator.print("matplotlib is not installed. Skipping loss plot.")

# --- 9) Save Model ---
if accelerator.is_main_process:
    final_model = accelerator.unwrap_model(model)
    final_model.save_pretrained("llama-8b-finetuned", safe_serialization=True)
    tokenizer.save_pretrained("llama-8b-finetuned")

if dist.is_initialized():
    dist.destroy_process_group()


In [None]:
!torchrun --nproc_per_node=2 --nnodes=1 single_gpu.py

# FSDP2 without Gradient Accumulation
![image](loss_plot_v2.png)

In [None]:
%%writefile finetune_llama31_8b_fsdp2_qlora_revised_v2.py
import os

import random
import numpy as np
import torch

seed = 3407
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Optionally, enforce determinism (may slow down training):
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False


# --- Environment Setup ---
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    ShardingStrategy,
    MixedPrecision,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from accelerate import Accelerator
from functools import partial

# --- 1) Distributed Setup ---
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
else:
    local_rank = 0

accelerator = Accelerator(mixed_precision="bf16")
device = torch.device("cuda", local_rank)

# --- 2) Model & 4-bit Quantization ---
model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
    bnb_4bit_quant_storage=dtype,  # try storing quantized weights as BF16
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)

# --- 3) Apply LoRA ---
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

for name, param in model.named_parameters():
    if ".lora_" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# --- Integrate torch.compile (PyTorch 2.0 feature) ---
try:
    model = torch.compile(model)
    accelerator.print("torch.compile: Model compiled successfully.")
except Exception as e:
    accelerator.print(f"torch.compile: Compilation failed with error: {e}. Continuing without compilation.")

# --- Optionally, we previously converted int parameters to buffers.
# For this version, we skip that conversion and rely on bitsandbytes to manage quantized weights.
# def convert_int_params_to_buffers(module: torch.nn.Module):
#     for name, param in list(module.named_parameters(recurse=False)):
#         if param is not None and not param.dtype.is_floating_point:
#             del module._parameters[name]
#             module.register_buffer(name, param.data)
#     for child in module.children():
#         convert_int_params_to_buffers(child)
#
# convert_int_params_to_buffers(model)

torch.cuda.empty_cache()

# --- Debug: Print buffer names and dtypes ---
if local_rank == 0:
    buffer_info = [(name, buf.dtype) for name, buf in model.named_buffers()]
    print("Buffers in model (name, dtype):")
    for name, dt in buffer_info:
        print(name, dt)

# --- 5) FSDP Wrapping ---
# Ignore the frozen base module from FSDP sharding.
base_model_submodule = model.model  
ignored_modules = [base_model_submodule]

auto_wrap_policy = partial(
    size_based_auto_wrap_policy,
    min_num_params=5e7,
    recurse=True,
    nonwrapped_numel=0,
)

mp_policy = MixedPrecision(
    param_dtype=dtype,      # cast parameters to BF16
    reduce_dtype=dtype,     # use BF16 for gradient reduction
    buffer_dtype=None,      # do not cast buffers
)

cpu_offload = None
# Uncomment the next line if needed:
# cpu_offload = CPUOffload(offload_params=True)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    use_orig_params=True,
    cpu_offload=cpu_offload,
    limit_all_gathers=True,
    device_id=device,
    ignored_modules=ignored_modules,
    mixed_precision=mp_policy,
)

# --- 6) Tokenizer & Dataset ---
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "</s>"

dataset = load_dataset(
    "json",
    data_files={"train": "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"},
    split="train[:10%]"
)

def tokenize_example(ex):
    enc = tokenizer(
        ex["text"],
        max_length=512,
        padding="max_length",
        truncation=True,
    )
    enc["labels"] = enc["input_ids"].copy()
    return enc

dataset = dataset.map(tokenize_example, batched=True, remove_columns=["text"])
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
train_dataloader = accelerator.prepare(train_dataloader)

# --- 7) Optimizer ---
trainable_params = [p for p in model.parameters() if p.requires_grad]
try:
    from bitsandbytes.optim import Adam8bit
    optimizer = Adam8bit(trainable_params, lr=2e-4)
    accelerator.print("Using bitsandbytes Adam8bit optimizer.")
except ImportError:
    optimizer = torch.optim.AdamW(trainable_params, lr=2e-4)
    accelerator.print("Using torch.optim.AdamW optimizer.")
optimizer = accelerator.prepare(optimizer)

# --- 8) Training Loop ---
loss_history = []
model.train()
for step, batch in enumerate(train_dataloader, start=1):
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    torch.cuda.empty_cache()
    if accelerator.is_main_process:
        loss_history.append((step, loss.item()))
    
    accelerator.print(f"Step {step} - Loss: {loss.item():.4f}")
    
    # Terminate after 60 steps.
    if step >= 60:
        break
accelerator.wait_for_everyone()

if accelerator.is_main_process:
    try:
        import matplotlib.pyplot as plt
        steps, losses = zip(*loss_history) if loss_history else ([], [])
        if steps:
            plt.figure(figsize=(8, 4))
            plt.plot(steps, losses, marker="o")
            plt.xlabel("Training Step")
            plt.ylabel("Loss")
            plt.title("Loss over the First 60 Training Steps")
            plt.grid(True)
            plt.savefig("loss_plot_v2.png")
            plt.show()
        else:
            print("No loss data recorded.")
    except ImportError:
        accelerator.print("matplotlib is not installed. Skipping loss plot.")


# --- 9) Save Model ---
if accelerator.is_main_process:
    final_model = accelerator.unwrap_model(model)
    final_model.save_pretrained("llama-8b-finetuned", safe_serialization=True)
    tokenizer.save_pretrained("llama-8b-finetuned")

if dist.is_initialized():
    dist.destroy_process_group()


In [None]:
!accelerate launch --num_processes=2 --mixed_precision=bf16 finetune_llama31_8b_fsdp2_qlora_revised_v2.py

# FSDP2 with Gradient Accumulation
![image](loss_plot_gradient_accum.png)

In [9]:
%%writefile gradient_accumulation.py
import os
import random
import numpy as np
import torch

seed = 3407
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Optionally, enforce determinism (may slow down training):
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# --- Environment Setup ---
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch.distributed as dist
# Import FSDP2 from the new location.
from torch.distributed.fsdp import fully_shard as FSDP2
from torch.distributed.fsdp import CPUOffload, ShardingStrategy, MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from accelerate import Accelerator
from functools import partial

# --- 1) Distributed Setup ---
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
else:
    local_rank = 0

accelerator = Accelerator(mixed_precision="bf16")
device = torch.device("cuda", local_rank)

# --- 2) Model & 4-bit Quantization ---
model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
    bnb_4bit_quant_storage=dtype,  # store quantized weights as BF16
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=dtype,
    trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)

# --- 3) Apply LoRA ---
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Freeze all parameters except LoRA ones.
for name, param in model.named_parameters():
    if ".lora_" in name:
        param.requires_grad = True
    else:
        param.requires_grad = False

model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# --- Integrate torch.compile (PyTorch 2.0 feature) ---
try:
    model = torch.compile(model)
    accelerator.print("torch.compile: Model compiled successfully.")
except Exception as e:
    accelerator.print(f"torch.compile: Compilation failed with error: {e}. Continuing without compilation.")

torch.cuda.empty_cache()

# --- Debug: Print buffer names and dtypes ---
if local_rank == 0:
    buffer_info = [(name, buf.dtype) for name, buf in model.named_buffers()]
    print("Buffers in model (name, dtype):")
    for name, dt in buffer_info:
        print(name, dt)

# --- 5) FSDP2 Wrapping ---
# Ignore the frozen base module from FSDP sharding.
base_model_submodule = model.model  
ignored_modules = [base_model_submodule]

auto_wrap_policy = partial(
    size_based_auto_wrap_policy,
    min_num_params=5e7,
    recurse=True,
    nonwrapped_numel=0,
)

mp_policy = MixedPrecision(
    param_dtype=dtype,      # cast parameters to BF16
    reduce_dtype=dtype,     # use BF16 for gradient reduction
    buffer_dtype=None,      # do not cast buffers
)

# Uncomment the next line if you want CPU offloading.
cpu_offload = CPUOffload(offload_params=True)
# Alternatively, set cpu_offload = None if offloading is not desired.

model = FSDP2(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    use_orig_params=True,
    cpu_offload=cpu_offload,
    limit_all_gathers=True,
    device_id=device,
    ignored_modules=ignored_modules,
    mixed_precision=mp_policy,
)

# --- 6) Tokenizer & Dataset ---
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token else "</s>"

dataset = load_dataset(
    "json",
    data_files={"train": "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"},
    split="train[:10%]"
)

def tokenize_example(ex):
    enc = tokenizer(
        ex["text"],
        max_length=512,
        padding="max_length",
        truncation=True,
    )
    enc["labels"] = enc["input_ids"].copy()
    return enc

dataset = dataset.map(tokenize_example, batched=True, remove_columns=["text"])
dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
train_dataloader = accelerator.prepare(train_dataloader)

# --- 7) Optimizer ---
trainable_params = [p for p in model.parameters() if p.requires_grad]
try:
    from bitsandbytes.optim import Adam8bit
    optimizer = Adam8bit(trainable_params, lr=2e-4)
    accelerator.print("Using bitsandbytes Adam8bit optimizer.")
except ImportError:
    optimizer = torch.optim.AdamW(trainable_params, lr=2e-4)
    accelerator.print("Using torch.optim.AdamW optimizer.")
optimizer = accelerator.prepare(optimizer)

# --- 8) Training Loop with Gradient Accumulation ---
loss_history = []
model.train()
gradient_accumulation_steps = 16  # Adjust this value as needed

for step, batch in enumerate(train_dataloader, start=1):
    # The accelerator.accumulate context will accumulate gradients over the specified steps.
    with accelerator.accumulate(model):
        outputs = model(**batch)
        # Scale loss to account for gradient accumulation
        loss = outputs.loss / gradient_accumulation_steps
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        # Clear cache after the optimizer step to lower VRAM usage
        if accelerator.sync_gradients:
            torch.cuda.empty_cache()
    
    if accelerator.is_main_process:
        loss_history.append((step, outputs.loss.item()))
        accelerator.print(f"Step {step} - Loss: {outputs.loss.item():.4f}")
    
    # Terminate after 60 steps (micro-batches)
    if step >= 60:
        break
accelerator.wait_for_everyone()

if accelerator.is_main_process:
    try:
        import matplotlib.pyplot as plt
        steps, losses = zip(*loss_history) if loss_history else ([], [])
        if steps:
            plt.figure(figsize=(8, 4))
            plt.plot(steps, losses, marker="o")
            plt.xlabel("Training Step")
            plt.ylabel("Loss")
            plt.title("Loss over the First 60 Training Steps")
            plt.grid(True)
            plt.savefig("loss_plot_gradient_accum.png")
            plt.show()
        else:
            print("No loss data recorded.")
    except ImportError:
        accelerator.print("matplotlib is not installed. Skipping loss plot.")

# --- 9) Save Model ---
if accelerator.is_main_process:
    final_model = accelerator.unwrap_model(model)
    final_model.save_pretrained("llama-8b-finetuned", safe_serialization=True)
    tokenizer.save_pretrained("llama-8b-finetuned")

if dist.is_initialized():
    dist.destroy_process_group()


Overwriting gradient_accumulation.py


In [12]:
%%writefile gradient_accumulation.py
import os
import random
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, MixedPrecision, OffloadPolicy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from datasets import load_dataset

# --- 1) Environment Setup & Seeds ---
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

seed = 3407
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --- 2) Distributed Setup ---
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(local_rank)

# --- 3) Model & Tokenizer & Quantization ---
model_name = "unsloth/meta-Llama-3.1-8B-Instruct-bnb-4bit"
dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
if not tokenizer.pad_token_id:
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=dtype,
    trust_remote_code=True,
    device_map=None,
)
model = prepare_model_for_kbit_training(model)

# --- 4) LoRA Configuration ---
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Freeze non-LoRA parameters and ensure valid dtypes
for name, param in model.named_parameters():
    if ".lora_" not in name:
        param.requires_grad = False
        
        # Ensure the parameter is a floating-point type
        if param.dtype != torch.float32 and param.dtype != torch.float16 and param.dtype != torch.bfloat16:
            print(f"Converting parameter {name} from {param.dtype} to bfloat16.")
            param.data = param.data.to(torch.bfloat16)

# --- 5) Dataset Preparation ---
def prepare_dataset():
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=512,
        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
    return tokenized_dataset

# --- 6) FSDP2 Configuration ---
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

def check_fn(submodule):
    return isinstance(submodule, (LlamaDecoderLayer,))

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=check_fn,
)

model = fully_shard(
    model,
    mp_policy=mp_policy,
    offload_policy=OffloadPolicy(),
    reshard_after_forward=True,
)

# --- 7) HuggingFace Trainer Setup ---
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=1,  # Consider reducing this if OOM persists
    gradient_accumulation_steps=16,
    learning_rate=2e-4,
    fp16=False,
    bf16=True,  # Using bfloat16 precision
    logging_steps=10,
    save_steps=100,
    evaluation_strategy="steps",
    eval_steps=100,
    save_total_limit=2,
)

# --- 8) Trainer Setup ---
from transformers import DataCollatorForLanguageModeling

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepare_dataset(),
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)

# --- 9) Training ---
try:
    trainer.train()
except RuntimeError as e:
    if 'out of memory' in str(e):
        print("Caught OOM error! Consider reducing batch size or using CPU offloading.")
        raise e

# --- 10) Model Saving ---
if local_rank == 0:
    trainer.save_model("llama-8b-finetuned")
    tokenizer.save_pretrained("llama-8b-finetuned")

dist.destroy_process_group()


Overwriting gradient_accumulation.py


In [13]:
!torchrun --standalone --nnodes=1 --nproc_per_node=2 gradient_accumulation.py


W0304 01:27:21.819000 224 torch/distributed/run.py:792] 
W0304 01:27:21.819000 224 torch/distributed/run.py:792] *****************************************
W0304 01:27:21.819000 224 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0304 01:27:21.819000 224 torch/distributed/run.py:792] *****************************************
[W304 01:27:32.632184263 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
[W304 01:27:42.642664837 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3
2025-03-04 01:27:46.831552: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-04 01:27:46.841127: E external/local_xla/xl