diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp
index 844eaed91d3ce..7026b6c69ca43 100644
--- a/test/cpp/api/dataloader.cpp
+++ b/test/cpp/api/dataloader.cpp
@@ -97,17 +97,19 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
   auto initialization_function = [&](size_t preloader_count,
                                      size_t batch_size,
                                      size_t cache_size) {
+    std::shared_ptr<datasets::SequentialChunkSelector> chunk_selector =
+        std::make_shared<datasets::SequentialChunkSelector>(
+            data_reader.chunk_count());
+
     datasets::SharedBatchDataset<datasets::ChunkDataset<
         DummyChunkDataReader,
         samplers::SequentialSampler>>
         dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
             DummyChunkDataReader,
             samplers::SequentialSampler>>(
             data_reader,
             sampler,
-            sampler,
+            chunk_selector,
             datasets::ChunkDatasetOptions(
                 preloader_count, batch_size, cache_size));
   };
@@ -831,6 +833,179 @@ TEST(DataTest, CanUseCustomTypeAsIndexType) {
   }
 }
 
+TEST(DataTest, RandomChunkSelectorSingleReplicaSingleThread) {
+  size_t chunk_count = 10;
+  datasets::RandomChunkSelector cs(chunk_count);
+  ASSERT_EQ(cs.local_chunk_count(), chunk_count);
+
+  std::vector<size_t> res;
+  torch::optional<size_t> idx;
+  while ((idx = cs.next()).has_value()) {
+    res.push_back(*idx);
+  }
+
+  ASSERT_EQ(res.size(), chunk_count);
+
+  std::sort(res.begin(), res.end());
+  for (size_t i = 0; i < res.size(); ++i) {
+    ASSERT_EQ(res[i], i);
+  }
+}
+
+TEST(DataTest, RandomChunkSelectorMultiReplicaSingleThread) {
+  size_t chunk_count = 10;
+  size_t num_replicas = 3;
+
+  auto test_function = [&](bool allow_duplicates,
+                           size_t local_chunk_count,
+                           std::vector<size_t>& output) {
+    std::vector<std::unique_ptr<datasets::RandomChunkSelector>> selectors;
+
+    for (size_t i = 0; i < num_replicas; ++i) {
+      selectors.emplace_back(torch::make_unique<datasets::RandomChunkSelector>(
+          chunk_count, num_replicas, i, allow_duplicates));
+    }
+    // local_chunk_count does not depend on the rank, so only check one replica.
+    ASSERT_EQ((*selectors[0]).local_chunk_count(), local_chunk_count);
+
+    std::vector<size_t> res;
+    for (size_t i = 0; i < num_replicas; ++i) {
+      (*selectors[i]).reset();
+      torch::optional<size_t> idx;
+      while ((idx = (*selectors[i]).next()).has_value()) {
+        res.push_back(*idx);
+      }
+      ASSERT_EQ(res.size(), local_chunk_count * (i + 1));
+    }
+    std::sort(res.begin(), res.end());
+    ASSERT_EQ(res, output);
+  };
+
+  size_t local_chunk_count =
+      static_cast<size_t>(std::ceil(chunk_count * 1.0 / num_replicas));
+  std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+  test_function(true, local_chunk_count, output1);
+
+  local_chunk_count =
+      static_cast<size_t>(std::floor(chunk_count * 1.0 / num_replicas));
+  std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
+  test_function(false, local_chunk_count, output2);
+}
+
+TEST(DataTest, RandomChunkSelectorMultiReplicaMultiThread) {
+  size_t chunk_count = 10;
+
+  datasets::RandomChunkSelector cs(chunk_count);
+  std::vector<size_t> res;
+  std::shared_ptr<std::mutex> guard_ptr = std::make_shared<std::mutex>();
+
+  auto loader = [&] {
+    torch::optional<size_t> idx;
+    while ((idx = cs.next()).has_value()) {
+      {
+        std::lock_guard<std::mutex> lock(*guard_ptr);
+        res.push_back(*idx);
+      }
+    }
+  };
+  std::thread t1(loader);
+  std::thread t2(loader);
+  t1.join();
+  t2.join();
+
+  std::sort(res.begin(), res.end());
+  for (size_t i = 0; i < res.size(); ++i) {
+    ASSERT_EQ(res[i], i);
+  }
+}
+
+TEST(DataTest, SequentialChunkSelectorSingleReplicaSingleThread) {
+  size_t chunk_count = 10;
+  datasets::SequentialChunkSelector cs(chunk_count);
+  ASSERT_EQ(cs.local_chunk_count(), chunk_count);
+
+  std::vector<size_t> res;
+  torch::optional<size_t> idx;
+  while ((idx = cs.next()).has_value()) {
+    res.push_back(*idx);
+  }
+
+  ASSERT_EQ(res.size(), chunk_count);
+
+  std::sort(res.begin(), res.end());
+  for (size_t i = 0; i < res.size(); ++i) {
+    ASSERT_EQ(res[i], i);
+  }
+}
+
+TEST(DataTest, SequentialChunkSelectorMultiReplicaSingleThread) {
+  size_t chunk_count = 10;
+  size_t num_replicas = 3;
+
+  auto test_function = [&](bool allow_duplicates,
+                           size_t local_chunk_count,
+                           std::vector<size_t>& output) {
+    std::vector<std::unique_ptr<datasets::SequentialChunkSelector>> selectors;
+
+    for (size_t i = 0; i < num_replicas; ++i) {
+      selectors.emplace_back(
+          torch::make_unique<datasets::SequentialChunkSelector>(
+              chunk_count, num_replicas, i, allow_duplicates));
+    }
+    // local_chunk_count does not depend on the rank, so only check one replica.
+    ASSERT_EQ((*selectors[0]).local_chunk_count(), local_chunk_count);
+
+    std::vector<size_t> res;
+    for (size_t i = 0; i < num_replicas; ++i) {
+      (*selectors[i]).reset();
+      torch::optional<size_t> idx;
+      while ((idx = (*selectors[i]).next()).has_value()) {
+        res.push_back(*idx);
+      }
+      ASSERT_EQ(res.size(), local_chunk_count * (i + 1));
+    }
+    std::sort(res.begin(), res.end());
+    ASSERT_EQ(res, output);
+  };
+
+  size_t local_chunk_count =
+      static_cast<size_t>(std::ceil(chunk_count * 1.0 / num_replicas));
+  std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+  test_function(true, local_chunk_count, output1);
+
+  local_chunk_count =
+      static_cast<size_t>(std::floor(chunk_count * 1.0 / num_replicas));
+  std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
+  test_function(false, local_chunk_count, output2);
+}
+
+TEST(DataTest, SequentialChunkSelectorMultiReplicaMultiThread) {
+  size_t chunk_count = 10;
+
+  datasets::SequentialChunkSelector cs(chunk_count);
+  std::vector<size_t> res;
+  std::shared_ptr<std::mutex> guard_ptr = std::make_shared<std::mutex>();
+
+  auto loader = [&] {
+    torch::optional<size_t> idx;
+    while ((idx = cs.next()).has_value()) {
+      {
+        std::lock_guard<std::mutex> lock(*guard_ptr);
+        res.push_back(*idx);
+      }
+    }
+  };
+  std::thread t1(loader);
+  std::thread t2(loader);
+  t1.join();
+  t2.join();
+
+  std::sort(res.begin(), res.end());
+  for (size_t i = 0; i < res.size(); ++i) {
+    ASSERT_EQ(res[i], i);
+  }
+}
+
 TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
   DataLoaderOptions partial_options;
   FullDataLoaderOptions full_options(partial_options);
@@ -1445,17 +1620,19 @@ TEST(DataLoaderTest, ChunkDataSetGetBatch) {
   for (auto prefetch_count : prefetch_counts) {
     for (auto batch_size : batch_sizes) {
       for (auto dataloader_worker_count : dataloader_worker_counts) {
+        std::shared_ptr<datasets::SequentialChunkSelector> chunk_selector =
+            std::make_shared<datasets::SequentialChunkSelector>(
+                data_reader.chunk_count());
+
         datasets::SharedBatchDataset<datasets::ChunkDataset<
             DummyChunkDataReader,
             samplers::SequentialSampler>>
            dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
                DummyChunkDataReader,
                samplers::SequentialSampler>>(
                data_reader,
                sampler,
-                sampler,
+                chunk_selector,
                datasets::ChunkDatasetOptions(prefetch_count, batch_size));
 
         auto data_loader = torch::data::make_data_loader(
@@ -1499,23 +1676,22 @@ TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
   DummyChunkDataReader data_reader;
   samplers::SequentialSampler sampler(0);
+  std::shared_ptr<datasets::SequentialChunkSelector> chunk_selector =
+      std::make_shared<datasets::SequentialChunkSelector>(
+          data_reader.chunk_count());
 
-  datasets::SharedBatchDataset<datasets::ChunkDataset<
-      DummyChunkDataReader,
-      samplers::SequentialSampler,
-      samplers::SequentialSampler>>
+  datasets::SharedBatchDataset<
+      datasets::ChunkDataset<DummyChunkDataReader, samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyChunkDataReader,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
-          sampler,
+          chunk_selector,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));
 
   auto data_loader = torch::data::make_data_loader(
-      dataset,
-      DataLoaderOptions(requested_batch_size).workers(0));
+      dataset, DataLoaderOptions(requested_batch_size).workers(0));
 
   std::string exception_msg =
       "The requested batch size does not match with the initialized batch "
@@ -1546,18 +1722,19 @@ TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
   const size_t batch_size = 5;
   DummyEmptyChunkDataReader data_reader;
   samplers::SequentialSampler sampler(0);
+  std::shared_ptr<datasets::SequentialChunkSelector> chunk_selector =
+      std::make_shared<datasets::SequentialChunkSelector>(
+          data_reader.chunk_count());
 
   datasets::SharedBatchDataset<datasets::ChunkDataset<
       DummyEmptyChunkDataReader,
       samplers::SequentialSampler>>
      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
          DummyEmptyChunkDataReader,
          samplers::SequentialSampler>>(
          data_reader,
          sampler,
-          sampler,
+          chunk_selector,
          datasets::ChunkDatasetOptions(prefetch_count, batch_size));
 
   auto data_loader = torch::data::make_data_loader(
@@ -1569,39 +1746,42 @@ TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
   }
 }
 
-TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
-  struct D : public datasets::ChunkDataReader<std::vector<int>> {
-   public:
-    using BatchType = std::vector<int>;
-
-    BatchType read_chunk(size_t chunk_index) override {
-      BatchType batch_data(10, 0);
-      return batch_data;
-    }
+struct DummyTwoChunkReader
+    : public datasets::ChunkDataReader<std::vector<int>> {
+ public:
+  using BatchType = std::vector<int>;
 
-    size_t chunk_count() override {
-      return 2;
-    };
+  BatchType read_chunk(size_t chunk_index) override {
+    BatchType batch_data(10, 0);
+    return batch_data;
+  }
 
-    void reset() override{};
-  };
+  size_t chunk_count() override {
+    return 2;
+  };
 
+  void reset() override{};
+};
+
+TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
   const size_t batch_sizes[] = {17, 30};
-  D data_reader;
+  DummyTwoChunkReader data_reader;
   samplers::SequentialSampler sampler(0);
 
   for (auto batch_size : batch_sizes) {
+    std::shared_ptr<datasets::SequentialChunkSelector> chunk_selector =
+        std::make_shared<datasets::SequentialChunkSelector>(
+            data_reader.chunk_count());
+
     datasets::SharedBatchDataset<datasets::ChunkDataset<
         DummyTwoChunkReader,
         samplers::SequentialSampler>>
        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
            DummyTwoChunkReader,
            samplers::SequentialSampler>>(
            data_reader,
            sampler,
-            sampler,
+            chunk_selector,
            datasets::ChunkDatasetOptions(1, batch_size));
 
     auto data_loader = torch::data::make_data_loader(
@@ -1620,3 +1800,37 @@ TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
     }
   }
 }
+
+TEST(DataLoaderTest, ChunkDataSetEnumerationWithMultipleEpochs) {
+  DummyTwoChunkReader data_reader;
+  samplers::SequentialSampler sampler(0);
+  size_t batch_size = 17;
+
+  std::shared_ptr<datasets::RandomChunkSelector> chunk_selector =
+      std::make_shared<datasets::RandomChunkSelector>(
+          data_reader.chunk_count());
+
+  datasets::SharedBatchDataset<
+      datasets::ChunkDataset<DummyTwoChunkReader, samplers::SequentialSampler>>
+      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
+          DummyTwoChunkReader,
+          samplers::SequentialSampler>>(
+          data_reader,
+          sampler,
+          chunk_selector,
+          datasets::ChunkDatasetOptions(1, batch_size));
+
+  auto data_loader = torch::data::make_data_loader(
+      dataset, DataLoaderOptions(batch_size).workers(0));
+
+  for (size_t epoch = 0; epoch < 3; ++epoch) {
+    // Set the epoch number for this enumeration.
+    chunk_selector->set_epoch(epoch);
+    for (auto iterator = data_loader->begin(); iterator != data_loader->end();
+         ++iterator) {
+      std::vector<int> batch = *iterator;
+      ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
+    }
+  }
+}
diff --git a/torch/csrc/api/include/torch/data/datasets/chunk.h b/torch/csrc/api/include/torch/data/datasets/chunk.h
index ae8e1c8b0394a..6a6a445940cb1 100644
--- a/torch/csrc/api/include/torch/data/datasets/chunk.h
+++ b/torch/csrc/api/include/torch/data/datasets/chunk.h
@@ -1,6 +1,15 @@
 #pragma once
 
 #include <torch/data/datasets/stateful.h>
+#include <algorithm>
+#include <atomic>
+#include <cmath>
+#include <cstddef>
+#include <memory>
+#include <numeric>
+#include <random>
+#include <utility>
+#include <vector>
 
 namespace torch {
 namespace data {
@@ -242,6 +251,145 @@ class BatchDataBuffer {
 };
 } // namespace detail
 
+/// Selects the chunks to load and defines the chunk sampling behavior.
+/// In a distributed setting, it selects a subset of the chunks based on the
+/// provided num_replicas and rank parameters. The `next()` method must be
+/// thread-safe, as it is called from multiple threads during chunk loading.
+/// When computing the number of chunks each replica loads, the selector
+/// rounds up or down depending on the `allow_duplicates` parameter.
+class ChunkSelector {
+ public:
+  virtual ~ChunkSelector() = default;
+  ChunkSelector(
+      size_t chunk_count,
+      size_t num_replicas = 1,
+      size_t rank = 0,
+      bool allow_duplicates = true)
+      : chunk_count_(chunk_count),
+        num_replicas_(num_replicas),
+        rank_(rank),
+        epoch_(0) {
+    if (allow_duplicates) {
+      local_chunk_count_ =
+          static_cast<size_t>(std::ceil(chunk_count_ * 1.0 / num_replicas_));
+    } else {
+      local_chunk_count_ =
+          static_cast<size_t>(std::floor(chunk_count_ * 1.0 / num_replicas_));
+    }
+  }
+
+  /// Get the next chunk index for loading.
+  /// Note: this method needs to be thread-safe.
+  virtual optional<size_t> next() = 0;
+
+  /// Reset the chunk selector for a new enumeration.
+  virtual void reset() = 0;
+
+  /// Set the epoch for the current enumeration. This can be used to alter the
+  /// chunk selection and shuffling behavior.
+  void set_epoch(size_t epoch) {
+    epoch_ = epoch;
+  }
+
+  /// Return the number of chunks to be loaded. In distributed training this
+  /// differs from chunk_count, as each loader only loads a subset of the
+  /// chunks.
+  size_t local_chunk_count() {
+    return local_chunk_count_;
+  }
+
+ protected:
+  size_t chunk_count_;
+  size_t num_replicas_;
+  size_t rank_;
+  size_t epoch_;
+  size_t local_chunk_count_;
+};
+
+/// Selects chunks randomly. The chunk order is reshuffled at each `reset()`
+/// call.
+class RandomChunkSelector : public ChunkSelector {
+ public:
+  RandomChunkSelector(
+      size_t chunk_count,
+      size_t num_replicas = 1,
+      size_t rank = 0,
+      bool allow_duplicates = true)
+      : ChunkSelector(chunk_count, num_replicas, rank, allow_duplicates) {
+    size_t index_count =
+        num_replicas_ == 1 ? chunk_count_ : local_chunk_count_ * num_replicas_;
+    all_indices_.resize(index_count);
+    std::iota(std::begin(all_indices_), std::end(all_indices_), 0);
+    if (num_replicas_ > 1 && index_count > chunk_count) {
+      for (size_t i = chunk_count; i < index_count; ++i) {
+        // Duplicate chunks are appended so that all replicas load the same
+        // number of chunks.
+        all_indices_[i] = i % chunk_count_;
+      }
+    }
+
+    begin_index_ = rank_ * local_chunk_count_;
+    end_index_ = begin_index_ + local_chunk_count_;
+    chunk_index_ = begin_index_;
+    // Shuffle for the first time.
+    reset();
+  }
+
+  optional<size_t> next() override {
+    size_t idx = chunk_index_.fetch_add(1, std::memory_order_relaxed);
+    if (idx < end_index_) {
+      return all_indices_[idx];
+    } else {
+      return nullopt;
+    }
+  }
+
+  void reset() override {
+    std::minstd_rand rand(epoch_);
+    std::shuffle(all_indices_.begin(), all_indices_.end(), rand);
+    chunk_index_ = begin_index_;
+  }
+
+ private:
+  size_t begin_index_;
+  size_t end_index_;
+  std::atomic<size_t> chunk_index_;
+  std::vector<size_t> all_indices_;
+};
+
+/// Selects chunks sequentially.
+class SequentialChunkSelector : public ChunkSelector {
+ public:
+  SequentialChunkSelector(
+      size_t chunk_count,
+      size_t num_replicas = 1,
+      size_t rank = 0,
+      bool allow_duplicates = true)
+      : ChunkSelector(chunk_count, num_replicas, rank, allow_duplicates) {
+    begin_index_ = rank_ * local_chunk_count_;
+    end_index_ = begin_index_ + local_chunk_count_;
+    chunk_index_ = begin_index_;
+  }
+
+  optional<size_t> next() override {
+    size_t idx = chunk_index_.fetch_add(1, std::memory_order_relaxed);
+    if (idx < end_index_) {
+      return idx % chunk_count_;
+    } else {
+      return nullopt;
+    }
+  }
+
+  void reset() override {
+    chunk_index_ = begin_index_;
+  }
+
+ private:
+  size_t begin_index_;
+  size_t end_index_;
+  std::atomic<size_t> chunk_index_;
+};
+
 /// Options to configure a `ChunkDataset`.
 struct ChunkDatasetOptions {
   ChunkDatasetOptions() = delete;
@@ -273,7 +421,7 @@ struct ChunkDatasetOptions {
   /// The size of each batch.
   TORCH_ARG(size_t, batch_size);
 
-  // the capacity of the queue for batch caching.
+  // The capacity of the queue for batch caching.
   TORCH_ARG(size_t, cache_size) = 2048;
 };
 
@@ -287,31 +435,28 @@ struct ChunkDatasetOptions {
 /// inspired by this paper http://martin.zinkevich.org/publications/nips2010.pdf
 template <
     typename ChunkReader,
-    typename ChunkSampler = samplers::RandomSampler,
     typename ExampleSampler = samplers::RandomSampler>
 class ChunkDataset final
     : public StatefulDataset<
-          ChunkDataset<ChunkReader, ChunkSampler, ExampleSampler>,
+          ChunkDataset<ChunkReader, ExampleSampler>,
           typename ChunkReader::BatchType,
           size_t> {
  public:
   using BatchType = torch::optional<typename ChunkReader::BatchType>;
   using UnwrappedBatchType = typename ChunkReader::BatchType;
   using BatchRequestType = size_t;
-  using ChunkSamplerType = ChunkSampler;
   using ExampleSamplerType = ExampleSampler;
 
   ChunkDataset(
       ChunkReader chunk_reader,
-      ChunkSampler chunk_sampler,
       ExampleSampler example_sampler,
+      std::shared_ptr<ChunkSelector> chunk_selector,
       ChunkDatasetOptions options)
       : chunk_reader_(std::move(chunk_reader)),
-        chunk_sampler_(std::move(chunk_sampler)),
        example_sampler_(std::move(example_sampler)),
+        chunk_selector_(std::move(chunk_selector)),
        options_(std::move(options)),
-        quit_worker_(false) {
-  }
+        quit_worker_(false) {}
 
   virtual ~ChunkDataset() {
     free_workers();
@@ -344,8 +489,12 @@ class ChunkDataset final
 
     chunk_reader_.reset();
 
-    size_t chunks_to_load = chunk_reader_.chunk_count();
-    chunk_sampler_.reset(chunks_to_load);
+    // Reset the chunk selector for the new enumeration.
+    chunk_selector_->reset();
+
+    // In distributed training, the local chunk count may differ from the
+    // total number of chunks available; the chunk selector holds the truth.
+    size_t chunks_to_load = chunk_selector_->local_chunk_count();
 
     // Throw out any existing cached batch in the buffer and re-creates a new
     // chunk buffer.
@@ -360,8 +509,7 @@ class ChunkDataset final
     quit_worker_ = false;
 
     for (size_t i = 0; i < options_.preloader_count_; ++i) {
-      preload_threads_.emplace_back(
-          [this, i]() { this->preloader(i); });
+      preload_threads_.emplace_back([this, i]() { this->preloader(i); });
     }
   }
@@ -376,8 +524,8 @@ class ChunkDataset final
     while (!quit_worker_.load()) {
       try {
         size_t chunk_id = 0;
-        if (auto chunk_sampler_result = chunk_sampler_.next(1)) {
-          chunk_id = chunk_sampler_result.value()[0];
+        if (auto chunk_sampler_result = chunk_selector_->next()) {
+          chunk_id = chunk_sampler_result.value();
         } else {
           break;
         }
@@ -386,8 +534,7 @@ class ChunkDataset final
           // if the chunk is empty, skip the current chunk data and move on to
           // the next.
           batch_buffer_->skip_chunk();
-        }
-        else {
+        } else {
           batch_buffer_->add_chunk_data(std::move(data));
         }
       } catch (...) {
@@ -415,14 +562,15 @@ class ChunkDataset final
   // batches and caches them in batch_buffer_.
   ChunkReader chunk_reader_;
 
-  // chunk sampler to shuffle different chunks
-  samplers::LockedSampler<ChunkSamplerType> chunk_sampler_;
-
   // example sampler to shuffle examples in a specific chunk
   ExampleSamplerType example_sampler_;
 
+  // Selects chunks and their order for this reader.
+  std::shared_ptr<ChunkSelector> chunk_selector_;
+
   // batch data buffer which holds chunk data from preloading thread.
-  std::shared_ptr<detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
+  std::shared_ptr<
+      detail::BatchDataBuffer<UnwrappedBatchType, ExampleSamplerType>>
       batch_buffer_;
 
   // worker thread pool
diff --git a/torch/csrc/api/include/torch/data/samplers/base.h b/torch/csrc/api/include/torch/data/samplers/base.h
index d57620815ac70..94fd488addfa8 100644
--- a/torch/csrc/api/include/torch/data/samplers/base.h
+++ b/torch/csrc/api/include/torch/data/samplers/base.h
@@ -5,7 +5,6 @@
 #include <cstddef>
 #include <vector>
-#include <mutex>
 
 namespace torch {
 namespace serialize {
@@ -41,42 +40,6 @@ class Sampler {
   /// Deserializes the `Sampler` from the `archive`.
   TORCH_API virtual void load(serialize::InputArchive& archive) = 0;
 };
-
-/// Wraps a provided sampler to make it thread safe.
-template <typename OriginalSampler>
-class LockedSampler
-    : public Sampler<typename OriginalSampler::BatchRequestType> {
- public:
-  using BatchRequestType = typename OriginalSampler::BatchRequestType;
-
-  explicit LockedSampler(OriginalSampler sampler)
-      : sampler_(std::move(sampler)) {}
-
-  void reset(optional<size_t> new_size) override {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    sampler_.reset(new_size);
-  }
-
-  optional<BatchRequestType> next(size_t batch_size) override {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    return sampler_.next(batch_size);
-  }
-
-  void save(serialize::OutputArchive& archive) const override {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    sampler_.save(archive);
-  }
-
-  void load(serialize::InputArchive& archive) override {
-    std::lock_guard<std::mutex> lock(this->mutex_);
-    sampler_.load(archive);
-  }
-
- private:
-  // Mutex for the multi-threading lock; declared mutable so that const
-  // member functions (such as save()) can lock it.
-  mutable std::mutex mutex_;
-  OriginalSampler sampler_;
-};
 } // namespace samplers
 } // namespace data
 } // namespace torch
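
For reference, a minimal usage sketch of the API this patch introduces. It assumes the `ChunkSelector`/`RandomChunkSelector` classes and the new `ChunkDataset` constructor exactly as defined in the diff above; `MyChunkReader`, the chunk/batch sizes, and the epoch count are hypothetical placeholders, not part of the patch.

```cpp
// Sketch only: follows the API added in this patch. MyChunkReader, the sizes,
// and the epoch count are hypothetical placeholders.
#include <torch/torch.h>

using namespace torch::data;

struct MyChunkReader : public datasets::ChunkDataReader<std::vector<int>> {
  using BatchType = std::vector<int>;

  BatchType read_chunk(size_t chunk_index) override {
    // Each chunk yields 100 examples tagged with its chunk index.
    return BatchType(100, static_cast<int>(chunk_index));
  }
  size_t chunk_count() override {
    return 8;
  }
  void reset() override {}
};

void enumerate_shard(size_t num_replicas, size_t rank) {
  MyChunkReader data_reader;
  samplers::RandomSampler example_sampler(0);

  // Each replica loads only its own shard of the chunks. With
  // allow_duplicates=true the shard size is rounded up, so every replica
  // sees the same number of chunks (some chunks are repeated).
  auto chunk_selector = std::make_shared<datasets::RandomChunkSelector>(
      data_reader.chunk_count(), num_replicas, rank, /*allow_duplicates=*/true);

  auto dataset = datasets::make_shared_dataset<
      datasets::ChunkDataset<MyChunkReader, samplers::RandomSampler>>(
      data_reader,
      example_sampler,
      chunk_selector,
      datasets::ChunkDatasetOptions(/*preloader_count=*/2, /*batch_size=*/32));

  auto data_loader =
      make_data_loader(dataset, DataLoaderOptions(32).workers(0));

  for (size_t epoch = 0; epoch < 3; ++epoch) {
    // Re-seed the shuffle so each epoch enumerates the chunks in a new order.
    chunk_selector->set_epoch(epoch);
    for (auto& batch : *data_loader) {
      (void)batch; // consume the batch
    }
  }
}
```

When a deterministic chunk order is needed, `SequentialChunkSelector` can be dropped in unchanged, as the updated tests above do.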