-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Description
Describe the bug
I have a large dataset that I split into 1024 shards and saved to disk during preprocessing. During training, I load it with `load_from_disk()`, convert it into an iterable dataset, shuffle it, and split the shards across DDP nodes using the recommended method (`split_dataset_by_node`).
However, when the training is resumed mid-epoch, I get thousands of identical warning messages:
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Steps to reproduce the bug
- Run a multi-node training job using the following Python script, and interrupt training after a few seconds so that a mid-epoch checkpoint is saved.
#!/usr/bin/env python
import os
import time
from typing import Dict, List
import torch
import lightning as pl
from torch.utils.data import DataLoader
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
import datasets
from transformers import AutoTokenizer
from more_itertools import flatten, chunked
from torchdata.stateful_dataloader import StatefulDataLoader
from lightning.pytorch.callbacks.on_exception_checkpoint import (
OnExceptionCheckpoint,
)
# Raise the `datasets` library's log level to DEBUG so that its debug-level
# messages are printed in addition to warnings.
datasets.logging.set_verbosity_debug()
def dummy_generator():
    """Yield 64 toy examples, each a dict with a single ``token_ids`` list.

    Eight base sequences (lengths 4-9, together covering ids 3-49 with no gaps)
    are emitted in order. That group is repeated 8 times, with every id shifted
    by ``50 * repeat``. Ids therefore never collide across repeats; the largest
    id produced is 399.
    """
    # Ids start at 3, presumably so they never collide with the special ids
    # 0/1/2 (bos/eos/pad) that setup() passes to group_texts -- confirm if
    # those ids ever change.
    base_sequences = (
        range(3, 10),
        range(10, 15),
        range(15, 21),
        range(21, 27),
        range(27, 31),
        range(31, 36),
        range(36, 45),
        range(45, 50),
    )
    for repeat in range(8):
        # A stride of 50 exceeds the largest base id (49), so repeats never overlap.
        offset = repeat * 50
        for ids in base_sequences:
            yield {"token_ids": [idx + offset for idx in ids]}
def group_texts(
    examples: Dict[str, List[List[int]]],
    block_size: int,
    eos_token_id: int,
    bos_token_id: int,
    pad_token_id: int,
) -> Dict[str, List[List[int]]]:
    """Pack a batch of token sequences into fixed-length, padded blocks.

    All sequences in ``examples["token_ids"]`` are concatenated into a single
    token stream, which is cut into chunks of ``block_size - 2`` tokens. Each
    chunk is framed as ``[bos] + chunk + [eos]`` and right-padded with
    ``pad_token_id`` up to ``block_size``.

    Returns a dict with ``input_ids`` (the padded blocks) and ``attention_mask``
    (``True`` for real tokens including bos/eos, ``False`` for padding).
    """
    payload_len = block_size - 2  # leave room for the bos and eos tokens
    token_stream = flatten(examples["token_ids"])
    blocks = []
    masks = []
    for chunk in chunked(token_stream, payload_len):
        framed = [bos_token_id, *chunk, eos_token_id]
        n_pad = block_size - len(framed)
        blocks.append(framed + [pad_token_id] * n_pad)
        masks.append([True] * len(framed) + [False] * n_pad)
    return {"input_ids": blocks, "attention_mask": masks}
def collate_fn(batch):
    """Stack a list of example dicts into a dict of ``torch.long`` tensors.

    Only the ``input_ids`` and ``attention_mask`` fields are kept; boolean masks
    are converted to 0/1 integers by the dtype.
    """
    fields = ("input_ids", "attention_mask")
    return {
        name: torch.tensor([example[name] for example in batch], dtype=torch.long)
        for name in fields
    }
class DummyModule(pl.LightningModule):
    """LightningModule stand-in that records the batches it sees instead of training.

    Each training step:

    - sleeps 5 seconds, leaving time to interrupt the run mid-epoch;
    - prints the batch and appends the same line to ``data_<global_rank>.txt``;
    - returns a constant zero loss.
    """

    def __init__(self):
        super().__init__()
        # A dummy linear layer (not used for actual computation); it gives the
        # optimizer created in configure_optimizers() parameters to hold.
        self.layer = torch.nn.Linear(1, 1)
        self.ds = None
        self.prepare_data_per_node = False
        # Per-rank record file: opened in on_train_start(), closed in on_train_end().
        self.data_file = None

    def _write_record(self, message):
        """Append ``message`` as one line to this rank's data file."""
        self.data_file.write(f"{message}\n")

    def on_train_start(self):
        # This hook is called once training begins on each process.
        print(f"[Rank {self.global_rank}] Training started.", flush=True)
        # Line buffering (buffering=1) pushes every record to disk as it is
        # written, so the file stays complete even if the run is interrupted
        # before on_train_end() can close it.
        self.data_file = open(
            f"data_{self.global_rank}.txt", "w", encoding="utf-8", buffering=1
        )

    def on_train_end(self):
        if self.data_file is not None:
            self.data_file.close()
            self.data_file = None

    def training_step(self, batch, batch_idx):
        # Slow each step down so there is time to interrupt training mid-epoch.
        time.sleep(5)
        message = (
            f"[Rank {self.global_rank}] Training step, epoch "
            f"{self.trainer.current_epoch}, batch {batch_idx}: {batch['input_ids']}"
        )
        # stdout gets a leading blank line; the data file gets the bare record.
        print(f"\n{message}", flush=True)
        self._write_record(message)
        # Compute a dummy loss (here, simply a constant tensor)
        return torch.tensor(0.0, requires_grad=True)

    def on_train_epoch_start(self):
        message = (
            f"[Rank {self.global_rank}] Training epoch "
            f"{self.trainer.current_epoch} started."
        )
        print(message, flush=True)
        self._write_record(message)

    def configure_optimizers(self):
        # Return a dummy optimizer.
        return torch.optim.SGD(self.parameters(), lr=0.001)
class DM(pl.LightningDataModule):
    """DataModule for the reproduction: saved to disk in shards, streamed as an iterable.

    ``prepare_data`` generates the toy dataset and saves it in ``NUM_SHARDS``
    shards. ``setup`` then:

    - reloads it as a sharded iterable dataset;
    - packs tokens into fixed-size blocks;
    - adds a shuffle buffer;
    - keeps only this rank's shards.
    """

    # prepare_data() writes and setup() reads these; keeping them in one place
    # guarantees the two stay in agreement.
    DATASET_PATH = "dataset"
    NUM_SHARDS = 4

    def __init__(self):
        super().__init__()
        self.ds = None
        self.prepare_data_per_node = False

    def set_epoch(self, epoch: int):
        """Forward ``epoch`` to the iterable dataset built by setup()."""
        # NOTE(review): nothing in this script calls set_epoch(); confirm whether
        # Lightning invokes it, otherwise every epoch reuses the same shuffle order.
        if self.ds is None:
            raise RuntimeError("set_epoch() was called before setup() built the dataset")
        self.ds.set_epoch(epoch)

    def prepare_data(self):
        """Generate the toy dataset and save it to disk in NUM_SHARDS shards."""
        dataset = Dataset.from_generator(dummy_generator)
        dataset.save_to_disk(self.DATASET_PATH, num_shards=self.NUM_SHARDS)

    def setup(self, stage: str):
        """Build ``self.ds``: this rank's shuffled, block-packed share of the data."""
        ds = datasets.load_from_disk(self.DATASET_PATH).to_iterable_dataset(
            num_shards=self.NUM_SHARDS
        )
        ds = ds.map(
            group_texts,
            batched=True,
            batch_size=5,
            fn_kwargs={
                "block_size": 5,
                "eos_token_id": 1,
                "bos_token_id": 0,
                "pad_token_id": 2,
            },
            remove_columns=["token_ids"],
        )
        # Buffer-based shuffle over 8 examples. On resume, `datasets` warns that
        # the buffer contents are not part of the saved state and refills it.
        ds = ds.shuffle(seed=42, buffer_size=8)
        # Keep only the portion of the shards assigned to this rank.
        self.ds = split_dataset_by_node(
            ds,
            rank=self.trainer.global_rank,
            world_size=self.trainer.world_size,
        )

    def train_dataloader(self):
        """Wrap ``self.ds`` in a StatefulDataLoader so iteration state can be saved and restored."""
        rank = self.trainer.global_rank
        world_size = self.trainer.world_size
        print(f"[Rank {rank}] Preparing train_dataloader...", flush=True)
        print(f"[Rank {rank}] Global rank: {rank}", flush=True)
        print(f"[Rank {rank}] World size: {world_size}", flush=True)
        return StatefulDataLoader(
            self.ds,
            batch_size=2,
            num_workers=2,
            collate_fn=collate_fn,
            drop_last=True,
            # Keep the worker processes alive between epochs.
            persistent_workers=True,
        )
if __name__ == "__main__":
    print("Starting Lightning training", flush=True)

    # The node count comes from SLURM; outside a SLURM allocation it falls back to 1.
    slurm_nodes = os.environ.get("SLURM_NNODES", "1")
    print(f"SLURM_NNODES: {slurm_nodes}", flush=True)
    num_nodes = int(slurm_nodes)

    model = DummyModule()
    dm = DM()

    # Saves checkpoints/on_exception.ckpt when training stops with an exception
    # (e.g. an interrupt); the commented resume call below loads that file.
    on_exception = OnExceptionCheckpoint(
        dirpath="checkpoints",
        filename="on_exception",
    )

    multi_node = num_nodes > 1
    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,  # one process per node
        strategy="ddp" if multi_node else "auto",  # DDP only when spanning several nodes
        num_nodes=num_nodes,
        max_epochs=2,
        logger=False,
        enable_checkpointing=True,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        callbacks=[on_exception],
    )

    # resume (uncomment to resume)
    # trainer.fit(model, datamodule=dm, ckpt_path="checkpoints/on_exception.ckpt")
    # train
    trainer.fit(model, datamodule=dm)
#!/bin/bash
# SLURM batch job that launches script.py on every allocated node
# (one task per node, one GPU per task).
#SBATCH --job-name=pl_ddp_test
#SBATCH --nodes=2 # Adjust number of nodes as needed
#SBATCH --ntasks-per-node=1 # One GPU (process) per node
#SBATCH --cpus-per-task=3 # At least as many dataloader workers as required
#SBATCH --gres=gpu:1 # Request one GPU per node
#SBATCH --time=00:10:00 # Job runtime (adjust as needed)
#SBATCH --partition=gpu-preempt # Partition or queue name
#SBATCH -o script.out
# Disable Python output buffering.
export PYTHONUNBUFFERED=1
# Print basic information about the job for debugging.
echo "SLURM job starting on $(date)"
echo "Running on nodes: $SLURM_NODELIST"
echo "Current directory: $(pwd)"
# List the working directory (presumably to see whether dataset/ or
# checkpoints/ already exist from a previous run).
ls -l
# Launch the script using srun so that each process starts the Lightning module.
srun script.py
- Uncomment the "resume" line (second to last) and comment out the original
`trainer.fit` call (last line), then run the job again.
It produces the following log:
[Rank 0] Preparing train_dataloader...
[Rank 0] Global rank: 0
[Rank 0] World size: 2
[Rank 1] Preparing train_dataloader...
[Rank 1] Global rank: 1
[Rank 1] World size: 2
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Assigning 2 shards (or data sources) of the dataset to each node.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#1, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
node#0 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#0 dataloader worker#0, ': Finished iterating over 1/1 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
[Rank 0] Training started.
[Rank 0] Training epoch 0 started.
[Rank 0] Training epoch 1 started.
Assigning 2 shards (or data sources) of the dataset to each node.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#1, ': Starting to iterate over 1/2 shards.
node#0 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
node#0 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#0 dataloader worker#0, ': Finished iterating over 1/1 shards.
`Trainer.fit` stopped: `max_epochs=2` reached.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#1 dataloader worker#0, ': Finished iterating over 1/1 shards.
[Rank 1] Training started.
[Rank 1] Training epoch 0 started.
[Rank 1] Training epoch 1 started.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
node#1 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
node#1 dataloader worker#0, ': Finished iterating over 1/1 shards.
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.
I'm also attaching the relevant dataloader `state_dict` to confirm that the state is checkpointed as expected:
{'_iterator_finished': True,
'_snapshot': {'_last_yielded_worker_id': 1,
'_main_snapshot': {'_IterableDataset_len_called': None,
'_base_seed': 3992758080362545099,
'_index_sampler_state': {'samples_yielded': 64},
'_num_workers': 2,
'_sampler_iter_state': None,
'_sampler_iter_yielded': 32,
'_shared_seed': None},
'_snapshot_step': 32,
'_worker_snapshots': {'worker_0': {'dataset_state': {'ex_iterable': {'shard_example_idx': 0,
'shard_idx': 1},
'num_examples_since_previous_state': 0,
'previous_state': {'shard_example_idx': 0,
'shard_idx': 1},
'previous_state_example_idx': 33},
'fetcher_state': {'dataset_iter_state': None,
'fetcher_ended': False},
'worker_id': 0},
'worker_1': {'dataset_state': {'ex_iterable': {'shard_example_idx': 0,
'shard_idx': 1},
'num_examples_since_previous_state': 0,
'previous_state': {'shard_example_idx': 0,
'shard_idx': 1},
'previous_state_example_idx': 33},
'fetcher_state': {'dataset_iter_state': None,
'fetcher_ended': False},
'worker_id': 1}}},
'_steps_since_snapshot': 0}
Expected behavior
Since I'm following all the recommended steps, I don't expect to see any warning when resuming. Am I doing something wrong? Also, can someone explain why I'm seeing 20 identical messages in the log in this reproduction setting? I'm trying to understand why I see thousands of these messages with the actual dataset.
One more surprising thing I noticed in the logs is the change in the number of shards per worker: in the following messages, the denominator changes from 2 to 1.
node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
...
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.
Environment info
python: 3.11.10
datasets: 3.3.2
lightning: 2.3.1