In [3]:
# 2025/7/25
# zhangzhong
# 

In [16]:
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
from torchdata.stateful_dataloader import StatefulDataLoader

In [5]:
# Build a 64-row in-memory dataset with a single column "a", then expose it
# as an IterableDataset made of 4 shards.
in_memory_dataset = Dataset.from_dict({"a": range(64)})
iterable_dataset = in_memory_dataset.to_iterable_dataset(num_shards=4)
print(iterable_dataset)

IterableDataset({
    features: ['a'],
    num_shards: 4
})


In [6]:
# Keep only shard 0 out of 4 and print every row that shard yields.
ds1 = iterable_dataset.shard(index=0, num_shards=4)
for row in ds1:
    print(row)

# split_dataset_by_node
# For iterable datasets:
# If the dataset's number of shards is a multiple of world_size (i.e. if dataset.num_shards % world_size == 0),
# then the shards are evenly assigned across the nodes, which is the most optimized.
# Otherwise, each node keeps 1 example out of world_size, skipping the other examples.
# In other words, with world_size == 4 each process takes one out of every 4 examples, and different processes are responsible for different example indices.

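# Every dataloader needs to record its own state,
# and that state is different in each process.
# It would be nice if only a single copy had to be kept,
# because having every process save its own checkpoint is cumbersome.
# TODO: sketch the whole data-construction pipeline in my notes
# and look for a possible way around this.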

{'a': 0}
{'a': 1}
{'a': 2}
{'a': 3}
{'a': 4}
{'a': 5}
{'a': 6}
{'a': 7}
{'a': 8}
{'a': 9}
{'a': 10}
{'a': 11}
{'a': 12}
{'a': 13}
{'a': 14}
{'a': 15}


In [9]:
# Rank 0's view of the stream when it is split across a world of 4 processes.
dataset = split_dataset_by_node(dataset=iterable_dataset, rank=0, world_size=4)
print(dataset)

IterableDataset({
    features: ['a'],
    num_shards: 1
})


In [None]:
# TODO: this still needs to be combined with DDP.
# Print every row that the rank-0 split yields.
for row in dataset:
    print(row)

{'a': 0}
{'a': 1}
{'a': 2}
{'a': 3}
{'a': 4}
{'a': 5}
{'a': 6}
{'a': 7}
{'a': 8}
{'a': 9}
{'a': 10}
{'a': 11}
{'a': 12}
{'a': 13}
{'a': 14}
{'a': 15}


In [11]:
# Rebuild the same 64-row, single-column dataset and expose it again as a
# 4-shard IterableDataset.
in_memory_dataset = Dataset.from_dict({"a": range(64)})
iterable_dataset = in_memory_dataset.to_iterable_dataset(num_shards=4)
print(iterable_dataset)

IterableDataset({
    features: ['a'],
    num_shards: 4
})


In [13]:
# One per-rank view of the same stream for a simulated world of 4 processes.
# The names ds1..ds4 are kept so other cells can keep referring to them.
ds1, ds2, ds3, ds4 = (
    split_dataset_by_node(iterable_dataset, rank=rank, world_size=4)
    for rank in range(4)
)

# Show the first row each rank yields, then stop.
for first_rows in zip(ds1, ds2, ds3, ds4):
    print(first_rows)
    break



({'a': 0}, {'a': 16}, {'a': 32}, {'a': 48})


In [None]:
# Experiment: snapshot the state of the parent `iterable_dataset`, rebuild the
# dataset from scratch, and load the snapshot into the rebuilt parent *before*
# splitting it per rank.
state_dict = iterable_dataset.state_dict()
iterable_dataset = Dataset.from_dict({"a": range(64)}).to_iterable_dataset(
    num_shards=4
)
iterable_dataset.load_state_dict(state_dict)

# Re-create the four per-rank views from the restored parent.
ds1 = split_dataset_by_node(iterable_dataset, rank=0, world_size=4)
ds2 = split_dataset_by_node(iterable_dataset, rank=1, world_size=4)
ds3 = split_dataset_by_node(iterable_dataset, rank=2, world_size=4)
ds4 = split_dataset_by_node(iterable_dataset, rank=3, world_size=4)

# Print only the first row produced by each rank.
for example in zip(ds1, ds2, ds3, ds4):
    print(example)
    break

# Result: this has no effect -- each rank still starts from its first row
# (0, 16, 32, 48). For an iterable dataset, a layer added on top (here the
# per-rank split) does not affect the layer underneath it.
# NOTE(review): presumably the snapshot holds no progress because only the
# derived per-rank datasets were iterated, never the parent itself -- confirm.
# So the most suitable approach is simply to use a StatefulDataLoader.

({'a': 0}, {'a': 16}, {'a': 32}, {'a': 48})


In [None]:
# Wrap the rank-0 dataset in a StatefulDataLoader and consume exactly one
# batch of 8 rows.
dataloader = StatefulDataLoader(ds1, batch_size=8)
for example in dataloader:
    print(example)
    break

# Checkpoint the loader after that first batch has been consumed.
state_dict = dataloader.state_dict()

# This path works as expected.

# Rebuild the dataset, the rank-0 split and the loader from scratch, then
# restore the checkpoint into the new loader.
iterable_dataset = Dataset.from_dict({"a": range(64)}).to_iterable_dataset(
    num_shards=4
)
ds1 = split_dataset_by_node(iterable_dataset, rank=0, world_size=4)
dataloader = StatefulDataLoader(ds1, batch_size=8)
dataloader.load_state_dict(state_dict)
# The restored loader continues with the second batch (rows 8..15) instead of
# starting over from row 0.
for example in dataloader:
    print(example)
    break


{'a': tensor([0, 1, 2, 3, 4, 5, 6, 7])}
{'a': tensor([ 8,  9, 10, 11, 12, 13, 14, 15])}


In [None]:
# So for now it looks like each process has to save its own copy of this dataloader's state_dict.
# 