diff --git a/DALI_EXTRA_VERSION b/DALI_EXTRA_VERSION index f5dbad711e2..4224411cd3c 100644 --- a/DALI_EXTRA_VERSION +++ b/DALI_EXTRA_VERSION @@ -1 +1 @@ -c6bfd2987d0d180e756232f69137467c5468b193 +0e51e444c4a0970446129db3431044b61d342a6f diff --git a/dali/operators/reader/CMakeLists.txt b/dali/operators/reader/CMakeLists.txt index 5897ef0afe0..5214dbe10f1 100644 --- a/dali/operators/reader/CMakeLists.txt +++ b/dali/operators/reader/CMakeLists.txt @@ -35,6 +35,10 @@ if (BUILD_NVDEC) list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/video_reader_resize_op.cc") endif() +if (BUILD_LIBTAR) + list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/webdataset_reader_op.cc") +endif() + if (BUILD_LMDB) list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/caffe_reader_op.cc") list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/caffe2_reader_op.cc") diff --git a/dali/operators/reader/loader/CMakeLists.txt b/dali/operators/reader/loader/CMakeLists.txt index 370527c5662..438a3e5beee 100644 --- a/dali/operators/reader/loader/CMakeLists.txt +++ b/dali/operators/reader/loader/CMakeLists.txt @@ -54,5 +54,10 @@ if (BUILD_LIBSND) "${CMAKE_CURRENT_SOURCE_DIR}/nemo_asr_loader_test.cc") endif() +if (BUILD_LIBTAR) + set(DALI_OPERATOR_SRCS ${DALI_OPERATOR_SRCS} + "${CMAKE_CURRENT_SOURCE_DIR}/webdataset_loader.cc") +endif() + set(DALI_OPERATOR_SRCS ${DALI_OPERATOR_SRCS} PARENT_SCOPE) set(DALI_OPERATOR_TEST_SRCS ${DALI_OPERATOR_TEST_SRCS} PARENT_SCOPE) diff --git a/dali/operators/reader/loader/webdataset/tar_utils.cc b/dali/operators/reader/loader/webdataset/tar_utils.cc index 1576239ea6f..b6e092fe871 100644 --- a/dali/operators/reader/loader/webdataset/tar_utils.cc +++ b/dali/operators/reader/loader/webdataset/tar_utils.cc @@ -13,11 +13,10 @@ // limitations under the License. 
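The tar_utils changes below assume tar's fixed 512-byte block layout (hence the is_pow2(kBlockSize) assertion and RoundToBlockSize): every member header and payload starts on a block boundary, which is also what the index offsets rely on. A quick, illustrative check with Python's tarfile module (the archive name is a placeholder, not a file from this patch)::

    import tarfile

    # Not part of the patch - just demonstrates the 512-byte alignment assumption.
    with tarfile.open("sample.tar") as archive:      # placeholder archive
        for member in archive:
            # offset_data is where the member's payload starts inside the archive
            assert member.offset_data % 512 == 0
            print(member.name, member.offset_data, member.size)
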
#include "dali/operators/reader/loader/webdataset/tar_utils.h" -#include -#include #include #include #include +#include #include #include #include @@ -31,7 +30,6 @@ namespace detail { namespace { -constexpr size_t kBlockSize = T_BLOCKSIZE; static_assert(is_pow2(kBlockSize), "The implementation assumes that the block size is a power of 2"); @@ -96,8 +94,7 @@ static tartype_t kTarArchiveType = {LibtarOpenTarArchive, [](int) -> int { retur [](int, const void*, size_t) -> ssize_t { return 0; }}; TarArchive::TarArchive(std::unique_ptr stream) - : stream_(std::move(stream)), - instance_handle_(Register(this)) { + : stream_(std::move(stream)), instance_handle_(Register(this)) { tar_open(ToTarHandle(&handle_), "", &kTarArchiveType, 0, instance_handle_, TAR_GNU); stream_->Seek(0); eof_ = stream_->Size() == 0; @@ -142,14 +139,8 @@ bool TarArchive::NextFile() { } const int64_t offset = stream_->Tell() + RoundToBlockSize(filesize_) - readoffset_; - assert(offset >= 0); - if (static_cast(offset) >= stream_->Size()) { - SetEof(); - return false; - } - - stream_->Seek(stream_->Tell() + RoundToBlockSize(filesize_) - readoffset_); - current_header_ = stream_->Tell(); + current_header_ = offset; + stream_->Seek(offset); ParseHeader(); return !eof_; } @@ -159,14 +150,14 @@ bool TarArchive::EndOfArchive() const { } void TarArchive::SeekArchive(int64_t offset) { - assert(offset % T_BLOCKSIZE == 0); - readoffset_ = 0; - if (static_cast(offset) >= stream_->Size()) { - SetEof(); + if (offset == current_header_) { return; } + assert(offset % T_BLOCKSIZE == 0); + eof_ = false; + readoffset_ = 0; stream_->Seek(offset); - current_header_ = stream_->Tell(); + current_header_ = offset; ParseHeader(); } @@ -215,7 +206,6 @@ inline void TarArchive::SetEof() { filename_ = ""; filesize_ = 0; filetype_ = ENTRY_NONE; - current_header_ = stream_ ? 
stream_->Size() : 0; } inline void TarArchive::ParseHeader() { @@ -258,6 +248,7 @@ void TarArchive::Close() { handle_ = nullptr; } readoffset_ = 0; + current_header_ = 0; SetEof(); stream_.reset(); if (instance_handle_ >= 0) { diff --git a/dali/operators/reader/loader/webdataset/tar_utils.h b/dali/operators/reader/loader/webdataset/tar_utils.h index f9967d18731..a88af752593 100644 --- a/dali/operators/reader/loader/webdataset/tar_utils.h +++ b/dali/operators/reader/loader/webdataset/tar_utils.h @@ -15,6 +15,7 @@ #ifndef DALI_OPERATORS_READER_LOADER_WEBDATASET_TAR_UTILS_H_ #define DALI_OPERATORS_READER_LOADER_WEBDATASET_TAR_UTILS_H_ +#include #include #include #include @@ -24,6 +25,8 @@ namespace dali { namespace detail { +constexpr size_t kBlockSize = T_BLOCKSIZE; + /** * @brief Used to access .tar archives through the given FileStream */ @@ -103,6 +106,11 @@ class DLL_PUBLIC TarArchive { */ bool EndOfFile() const; + /** + * @brief Frees the underlying file stream + */ + void Close(); + private: std::unique_ptr stream_; int instance_handle_ = -1; @@ -120,7 +128,6 @@ class DLL_PUBLIC TarArchive { void SetEof(); void ParseHeader(); - void Close(); // resets objects to default values }; } // namespace detail diff --git a/dali/operators/reader/loader/webdataset/tar_utils_test.cc b/dali/operators/reader/loader/webdataset/tar_utils_test.cc index bc1c2170ade..7d287d8c891 100644 --- a/dali/operators/reader/loader/webdataset/tar_utils_test.cc +++ b/dali/operators/reader/loader/webdataset/tar_utils_test.cc @@ -25,14 +25,15 @@ #include "dali/operators/reader/loader/filesystem.h" #include "dali/util/file.h" #include "dali/core/util.h" +#include "dali/test/dali_test_config.h" namespace dali { namespace detail { TEST(LibTarUtilsTestSimple, Interface) { - std::string filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/MNIST/devel-2.tar")); - std::string dummy_filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + std::string dummy_filepath(dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/MNIST/devel-1.tar")); TarArchive dummy_archive(FileStream::Open(dummy_filepath, false, false)); @@ -66,7 +67,7 @@ TEST(LibTarUtilsTestSimple, Interface) { } TEST(LibTarUtilsTestSimple, LongNameIndexing) { - std::string filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/sample-tar/gnu.tar")); TarArchive archive(FileStream::Open(filepath, false, false)); std::string name_prefix(128, '#'); @@ -78,7 +79,7 @@ TEST(LibTarUtilsTestSimple, LongNameIndexing) { } TEST(LibTarUtilsTestSimple, Types) { - std::string filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/sample-tar/types.tar")); std::vector types = { TarArchive::ENTRY_BLOCKDEV, TarArchive::ENTRY_CHARDEV, TarArchive::ENTRY_DIR, @@ -99,7 +100,7 @@ TEST(LibTarUtilsTestSimple, Types) { } TEST(LibTarUtilsTestSimple, Offset) { - std::string filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/sample-tar/types.tar")); TarArchive archive(FileStream::Open(filepath, false, true)); @@ -195,28 +196,28 @@ auto SimpleTarTestsValues() { vector values; SimpleTarTestsData filepaths[] = { - { 
dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + { dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"), false, false, 2000, 3000, {".cls", ".jpg"} }, - { dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + { dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/sample-tar/empty.tar"), false, false, 0, 0, {} }, - { dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + { dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/sample-tar/v7.tar"), false, false, 0, 1000, {""} }, - { dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), + { dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/sample-tar/oldgnu.tar"), false, false, @@ -233,7 +234,7 @@ auto SimpleTarTestsValues() { } } } - return testing::ValuesIn(values.begin(), values.end()); + return ::testing::ValuesIn(values.begin(), values.end()); } INSTANTIATE_TEST_SUITE_P(LibTarUtilsTestParametrized, SimpleTarTests, SimpleTarTestsValues()); @@ -246,7 +247,7 @@ class MultiTarTests : public ::testing::TestWithParam { const std::pair ranges[kMultithreadedSamples] = {{2000, 3000}, {0, 1000}, {1000, 2000}}; void SetUp() final { std::string filepath_prefix( - dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), "db/webdataset/MNIST/devel-")); + dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/MNIST/devel-")); std::string filepaths[kMultithreadedSamples] = { filepath_prefix + "0.tar", filepath_prefix + "1.tar", filepath_prefix + "2.tar"}; diff --git a/dali/operators/reader/loader/webdataset_loader.cc b/dali/operators/reader/loader/webdataset_loader.cc new file mode 100644 index 00000000000..710be95d1ef --- /dev/null +++ b/dali/operators/reader/loader/webdataset_loader.cc @@ -0,0 +1,372 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/operators/reader/loader/webdataset_loader.h" +#include +#include +#include +#include +#include +#include "dali/core/error_handling.h" +#include "dali/pipeline/data/types.h" +#include "dali/util/file.h" + +namespace dali { + +template +inline std::string IndexFileErrMsg(const std::string& index_path, int64_t line, + const Args&... 
details) { + return make_string("Malformed index file at ", index_path, " line ", line, " - ", details...); +} + +namespace detail { +namespace wds { + +inline MissingExtBehavior ParseMissingExtBehavior(std::string missing_component_behavior) { + std::transform(missing_component_behavior.begin(), missing_component_behavior.end(), + missing_component_behavior.begin(), static_cast(std::tolower)); + if (missing_component_behavior == "") { + return MissingExtBehavior::Empty; + } else if (missing_component_behavior == "skip") { + return MissingExtBehavior::Skip; + } else if (missing_component_behavior == "empty") { + return MissingExtBehavior::Empty; + } else if (missing_component_behavior == "error") { + return MissingExtBehavior::Raise; + } else { + return MissingExtBehavior::Invalid; + } +} + + +inline void ParseSampleDesc(std::vector& samples_container, + std::vector& components_container, + std::ifstream& index_file, const std::string& index_path, + int64_t line) { + // Preparing the SampleDesc + samples_container.emplace_back(); + samples_container.back().components = + VectorRange(components_container, components_container.size()); + + // Getting the components data + std::string components_metadata; + std::getline(index_file, components_metadata); + std::stringstream extensions_stream(components_metadata); + + // Reading consecutive components + ComponentDesc component; + while (extensions_stream >> component.ext) { + DALI_ENFORCE(extensions_stream >> component.offset >> component.size, + IndexFileErrMsg(index_path, line, + "size or offset corresponding to the extension not found")); + DALI_ENFORCE( + component.offset % kBlockSize == 0, + IndexFileErrMsg(index_path, line, "tar offset is not a multiple of tar block size (", + kBlockSize, "), perhaps the size value is exported before offset?")); + components_container.emplace_back(std::move(component)); + samples_container.back().components.num++; + } + + // Finishing up the SampleDesc + DALI_ENFORCE(samples_container.back().components.num, + IndexFileErrMsg(index_path, line, "no extensions provided for the sample")); +} + +inline void ParseIndexFile(std::vector& samples_container, + std::vector& components_container, + const std::string& index_path) { + std::ifstream index_file(index_path); + + // Index Checking + std::string global_meta; + getline(index_file, global_meta); + std::stringstream global_meta_stream(global_meta); + std::string index_version; + DALI_ENFORCE(global_meta_stream >> index_version, + IndexFileErrMsg(index_path, 0, "no version signature found")); + DALI_ENFORCE(kCurrentIndexVersion == index_version, + IndexFileErrMsg( + index_path, 0, + "the version of the index file does not match the expected version (expected: ", + kCurrentIndexVersion, " actual: ", index_version, ")")); + + // Getting the number of samples in the index file + int64_t sample_desc_num_signed; + DALI_ENFORCE(global_meta_stream >> sample_desc_num_signed, + IndexFileErrMsg(index_path, 0, "no sample count found")); + DALI_ENFORCE(sample_desc_num_signed > 0, + IndexFileErrMsg(index_path, 0, "sample count must be positive")); + + const size_t sample_desc_num = sample_desc_num_signed; + samples_container.reserve(samples_container.size() + sample_desc_num); + for (size_t sample_index = 0; sample_index < sample_desc_num; sample_index++) { + ParseSampleDesc(samples_container, components_container, index_file, index_path, + sample_index + 1); + } +} + +} // namespace wds +} // namespace detail + +inline std::string SupportedTypesListGen() { + 
std::stringstream out; + for (auto& dtype : detail::wds::kSupportedTypes) { + out << dtype << ", "; + } + std::string out_str = out.str(); + return out_str.substr(0, out_str.size() - 2 * (detail::wds::kSupportedTypes.size() > 0)); +} + +WebdatasetLoader::WebdatasetLoader(const OpSpec& spec) + : Loader(spec), + paths_(spec.GetRepeatedArgument("paths")), + index_paths_(spec.GetRepeatedArgument("index_paths")), + missing_component_behavior_(detail::wds::ParseMissingExtBehavior( + spec.GetArgument("missing_component_behavior"))) { + DALI_ENFORCE(paths_.size() == index_paths_.size(), + "Number of webdataset archives does not match the number of index files"); + DALI_ENFORCE(paths_.size() > 0, "No webdataset archives provided"); + DALI_ENFORCE(missing_component_behavior_ != detail::wds::MissingExtBehavior::Invalid, + make_string("Invalid value for missing_component_behavior '", + spec.GetArgument("missing_component_behavior"), + "' possible values are: skip, error, empty")); + + std::vector samples_exts = spec.GetRepeatedArgument("ext"); + ext_.reserve(samples_exts.size()); + + // splitting extension bundles by the delimiter + for (size_t exts_idx = 0; exts_idx < samples_exts.size(); exts_idx++) { + std::stringstream exts_stream(samples_exts[exts_idx]); + std::string ext; + ext_.emplace_back(); + while (std::getline(exts_stream, ext, detail::wds::kExtDelim)) { + if (!ext_.back().count(ext)) { + ext_.back().insert(ext); + } + } + } + + auto dtypes_ids = spec.HasArgument("dtypes") ? spec.GetRepeatedArgument("dtypes") : + std::vector(ext_.size(), DALI_UINT8); + dtypes_.reserve(dtypes_ids.size()); + std::transform(dtypes_ids.begin(), dtypes_ids.end(), std::back_inserter(dtypes_), + TypeTable::GetTypeInfo); + + for (auto& dtype : dtypes_) { + DALI_ENFORCE(detail::wds::kSupportedTypes.count(dtype.id()), + make_string("Unsupported output dtype ", dtype.name(), + ". 
Supported types are: ", SupportedTypesListGen())); + } + DALI_ENFORCE(ext_.size() == dtypes_.size(), + "Number of extensions does not match the number of provided types"); +} + +WebdatasetLoader::~WebdatasetLoader() {} + +void WebdatasetLoader::PrepareEmpty(vector>& empty) { + empty = std::vector>(ext_.size()); + for (size_t output_index = 0; output_index < ext_.size(); output_index++) { + empty[output_index].set_pinned(false); + empty[output_index].reserve(tensor_init_bytes_); + empty[output_index].set_type(dtypes_[output_index]); + } +} + +inline std::string GetExtension(const std::string& filepath) { + const size_t dot_pos = filepath.find_first_of('.', filepath.find_last_of('/') + 1); + return filepath.substr(dot_pos + 1); +} + +void WebdatasetLoader::ReadSample(vector>& sample) { + MoveToNextShard(sample_index_); + detail::wds::SampleDesc& current_sample = samples_[sample_index_]; + auto& current_wds_shard = wds_shards_[current_sample.wds_shard_index]; + + for (auto& component : current_sample.components) { + current_wds_shard.SeekArchive(component.offset); + + // Checking if the component data from the index file agrees with reality + const auto& index_path = index_paths_[current_sample.wds_shard_index]; + DALI_ENFORCE(!current_wds_shard.EndOfArchive(), + IndexFileErrMsg(index_path, current_sample.line_number, + "offset is outside of the archive file")); + DALI_ENFORCE( + current_wds_shard.GetFileType() == detail::TarArchive::ENTRY_FILE, + IndexFileErrMsg(index_path, current_sample.line_number, "component of a non-file type")); + DALI_ENFORCE(GetExtension(current_wds_shard.GetFileName()) == component.ext, + IndexFileErrMsg(index_path, current_sample.line_number, + "component extension does not match the archive entry extension")); + DALI_ENFORCE(current_wds_shard.GetFileSize() == component.size, + IndexFileErrMsg(index_path, current_sample.line_number, + "component size does not match the archive entry size")); + + + // Skipping cached samples + const std::string source_info = + make_string("archive ", paths_[current_sample.wds_shard_index], "index file ", + index_paths_[current_sample.wds_shard_index], "line ", + current_sample.line_number, "component offset ", component.offset); + DALIMeta meta; + meta.SetSourceInfo(source_info); + if (ShouldSkipImage(source_info)) { + meta.SetSkipSample(true); + for (auto& output : component.outputs) { + sample[output].Reset(); + sample[output].SetMeta(meta); + sample[output].Resize({0}, dtypes_[output]); + } + continue; + } + + // Reading Data + if (copy_read_data_) { + uint8_t* shared_tensor_data = nullptr; + for (auto& output : component.outputs) { + if (!shared_tensor_data) { + if (sample[output].shares_data()) { + sample[output].Reset(); + } + sample[output].Resize( + {static_cast(component.size / sample[output].type().size())}, + dtypes_[output]); + shared_tensor_data = reinterpret_cast(sample[output].raw_mutable_data()); + } else { + sample[output].ShareData( + shared_tensor_data, component.size, + {static_cast(component.size / sample[output].type().size())}, + sample[output].type()); + } + } + DALI_ENFORCE(current_wds_shard.Read(shared_tensor_data, component.size) == component.size, + "Error reading from a file " + paths_[current_sample.wds_shard_index]); + } else { + auto data = current_wds_shard.ReadFile(); + for (auto& output : component.outputs) { + sample[output].SetMeta(meta); + sample[output].ShareData( + data, component.size, + {static_cast(component.size / sample[output].type().size())}, + sample[output].type()); + } + } + } + + // 
Setting non-filled outputs + for (auto& empty_output : current_sample.empty_outputs) { + sample[empty_output].Reset(); + sample[empty_output].Resize({0}, dtypes_[empty_output]); + } + sample_index_++; +} + +Index WebdatasetLoader::SizeImpl() { + return samples_.size(); +} + +void WebdatasetLoader::PrepareMetadataImpl() { + if (!dont_use_mmap_) { + mmap_reserver_ = FileStream::MappingReserver(static_cast(paths_.size())); + } + copy_read_data_ = dont_use_mmap_ || !mmap_reserver_.CanShareMappedData(); + + // initializing all the readers + wds_shards_.reserve(paths_.size()); + for (auto& uri : paths_) { + wds_shards_.emplace_back(FileStream::Open(uri, read_ahead_, !dont_use_mmap_)); + } + + // preparing the map from extensions to outputs + std::unordered_map> ext_map; + for (size_t output_index = 0; output_index < ext_.size(); output_index++) { + for (auto& ext : ext_[output_index]) { + ext_map[ext].push_back(output_index); + } + } + + // collecting and filtering the index files + std::vector unfiltered_samples; + std::vector unfiltered_components; + bitmask was_output_set; + was_output_set.resize(ext_.size(), false); + output_indicies_.reserve(ext_.size()); + for (size_t wds_shard_index = 0; wds_shard_index < index_paths_.size(); wds_shard_index++) { + unfiltered_samples.resize(0); + unfiltered_components.resize(0); + detail::wds::ParseIndexFile(unfiltered_samples, unfiltered_components, + index_paths_[wds_shard_index]); + + for (auto& sample : unfiltered_samples) { + detail::wds::SampleDesc new_sample{ + detail::wds::VectorRange(components_, components_.size()), + detail::wds::VectorRange(empty_outputs_, empty_outputs_.size()), wds_shard_index, + sample.line_number}; + + size_t start_outputs_index = output_indicies_.size(); + + for (auto& component : sample.components) { + component.outputs = + detail::wds::VectorRange(output_indicies_, output_indicies_.size()); + for (auto& output : ext_map[component.ext]) { + if (!was_output_set[output]) { + DALI_ENFORCE( + component.size % dtypes_[output].size() == 0, + make_string("Error in index file at ", index_paths_[wds_shard_index], " line ", + sample.line_number, " - component size and dtype incompatible")); + output_indicies_.push_back(output); + component.outputs.num++; + was_output_set[output] = true; + } + } + if (component.outputs.num) { + components_.push_back(std::move(component)); + new_sample.components.num++; + } + } + + if (new_sample.components.num < ext_.size()) { + switch (missing_component_behavior_) { + case detail::wds::MissingExtBehavior::Empty: + for (size_t output = 0; output < ext_.size(); output++) { + if (!was_output_set[output]) { + empty_outputs_.push_back(output); + new_sample.empty_outputs.num++; + } + } + samples_.push_back(new_sample); + break; + case detail::wds::MissingExtBehavior::Skip: + components_.resize(new_sample.components.start); + output_indicies_.resize(start_outputs_index); + break; + case detail::wds::MissingExtBehavior::Raise: + DALI_FAIL(make_string("Underful sample detected at ", index_paths_[wds_shard_index], + " line ", sample.line_number)); + break; + default: + break; + } + } else { + samples_.push_back(new_sample); + } + was_output_set.fill(false); + } + } + sample_index_ = start_index(shard_id_, num_shards_, samples_.size()); +} + +void WebdatasetLoader::Reset(bool wrap_to_shard) { + sample_index_ = wrap_to_shard ? 
start_index(shard_id_, num_shards_, samples_.size()) : 0; +} + +} // namespace dali diff --git a/dali/operators/reader/loader/webdataset_loader.h b/dali/operators/reader/loader/webdataset_loader.h new file mode 100644 index 00000000000..dc84f99cf3f --- /dev/null +++ b/dali/operators/reader/loader/webdataset_loader.h @@ -0,0 +1,122 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_OPERATORS_READER_LOADER_WEBDATASET_LOADER_H_ +#define DALI_OPERATORS_READER_LOADER_WEBDATASET_LOADER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "dali/core/bitmask.h" +#include "dali/operators/reader/loader/loader.h" +#include "dali/operators/reader/loader/webdataset/tar_utils.h" +#include "dali/pipeline/data/tensor.h" + +namespace dali { +namespace detail { +namespace wds { + +constexpr char kCurrentIndexVersion[] = "v1.0"; +constexpr char kExtDelim = ';'; +const std::set kSupportedTypes = {DALI_UINT8, DALI_UINT16, DALI_UINT32, DALI_UINT64, + DALI_INT8, DALI_INT16, DALI_INT32, DALI_INT64, + DALI_FLOAT16, DALI_FLOAT, DALI_FLOAT64}; + +enum class MissingExtBehavior { + Empty, + Skip, + Raise, + Invalid +}; +MissingExtBehavior ParseMissingExtBehavior(std::string); + +template +class VectorRange { + private: + std::vector* data_ = nullptr; + + public: + size_t start = 0; + size_t num = 0; + VectorRange() = default; + explicit inline VectorRange(std::vector& data, size_t start_idx = 0, size_t count = 0) + : data_(&data), start(start_idx), num(count) {} + + inline T* begin() { + return data_->data() + start; + } + + inline T* end() { + return begin() + num; + } +}; + +struct ComponentDesc { + std::string ext; + size_t size = 0; + int64_t offset = 0; + VectorRange outputs; + + ComponentDesc() = default; + ComponentDesc(std::string new_ext, size_t new_size, int64_t new_offset) + : ext(std::move(new_ext)), size(new_size), offset(new_offset) {} +}; + +struct SampleDesc { + VectorRange components; + VectorRange empty_outputs; + size_t wds_shard_index; + int64_t line_number; +}; + +} // namespace wds +} // namespace detail + +class DLL_PUBLIC WebdatasetLoader : public Loader>> { + public: + explicit WebdatasetLoader(const OpSpec& spec); + ~WebdatasetLoader() override; + + void PrepareEmpty(std::vector>&) override; + void ReadSample(std::vector>&) override; + + protected: + Index SizeImpl() override; + void PrepareMetadataImpl() override; + void Reset(bool wrap_to_shard) override; + + std::vector paths_; + std::vector index_paths_; + std::vector> ext_; + std::vector dtypes_; + detail::wds::MissingExtBehavior missing_component_behavior_; + + private: + std::vector samples_; // data from the index files + std::vector components_; // data about the components held + // together for space optimization + std::vector empty_outputs_; // indices of empty outputs to fill in for space optimization + std::vector output_indicies_; // indices of outputs that a component corresponds to + + 
std::vector wds_shards_; + size_t sample_index_ = 0; + FileStream::MappingReserver mmap_reserver_; +}; + +} // namespace dali +#endif // DALI_OPERATORS_READER_LOADER_WEBDATASET_LOADER_H_ diff --git a/dali/operators/reader/webdataset_reader_op.cc b/dali/operators/reader/webdataset_reader_op.cc new file mode 100644 index 00000000000..28d394b4be6 --- /dev/null +++ b/dali/operators/reader/webdataset_reader_op.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/operators/reader/webdataset_reader_op.h" +#include +#include +#include + +namespace dali { + +bool WebdatasetReader::SetupImpl(std::vector& output_desc, const HostWorkspace& ws) { + DataReader>>::SetupImpl(output_desc, ws); + int num_outputs = ws.NumOutput(); + int num_samples = GetCurrBatchSize(); + + output_desc.resize(num_outputs); + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + output_desc[output_idx].shape = TensorListShape<>(num_samples, 1); + } + + for (int data_idx = 0; data_idx < num_samples; data_idx++) { + auto& sample = GetSample(data_idx); + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + output_desc[output_idx].shape.set_tensor_shape(data_idx, sample[output_idx].shape()); + output_desc[output_idx].type = sample[output_idx].type(); + } + } + return true; +} + +void WebdatasetReader::RunImpl(HostWorkspace& ws) { + int num_outputs = ws.NumOutput(); + int num_samples = GetCurrBatchSize(); + + for (int data_idx = 0; data_idx < num_samples; data_idx++) { + auto& sample = GetSample(data_idx); + for (int output_idx = 0; output_idx < num_outputs; output_idx++) { + ws.OutputRef(output_idx)[data_idx].SetMeta(sample[output_idx].GetMeta()); + std::memcpy(ws.OutputRef(output_idx)[data_idx].raw_mutable_data(), + sample[output_idx].raw_data(), sample[output_idx].nbytes()); + } + } +} + +DALI_SCHEMA(readers__Webdataset) + .DocStr((std::string) R"code(A reader for the webdataset format. + +The webdataset format is a way of providing efficient access to datasets stored in tar archives. + +Storing data in POSIX tar archives greatly speeds up I/O operations on mechanical storage devices +and on network file systems because it allows the operating system to reduce the number of I/O +operations and to read the data ahead. + +WebDataset fulfils a similar function to Tensorflow's TFRecord/tf.Example classes, but is much +easier to adopt because it does not actually require any data conversion. The data is stored in +exactly the same format inside tar files as it is on disk, and all preprocessing and data +augmentation code remains unchanged. + +The dataset consists of one or more tar archives, each of which is further split into samples. +A sample contains one or more components that correspond to the actual files contained within +the archive. 
The components that belong to a specific sample are aggregated by filename without
+extension (for the specifics about the extensions please read the description of the ``ext`` parameter
+below). Note that samples whose filenames start with a dot are not loaded, and neither are
+entries that are not regular files.
+
+In addition to the tar archive with data, each archive should come with a corresponding index file.
+The index file can be generated using a dedicated script::
+
+    ``<path_to_dali>/tools/wds2idx.py <path_to_archive> <path_to_index_file>``
+
+The format of the index file is:
+)code" + detail::wds::kCurrentIndexVersion +
+    R"code( <the number of samples>
+<component 1 extension> <component 1 offset> <component 1 size> ...
+...
+
+Based on https://github.com/webdataset/webdataset)code")
+    .NumInput(0)
+    .OutputFn([](const OpSpec& spec) {
+      return spec.HasArgument("ext") ? spec.GetRepeatedArgument<std::string>("ext").size() : 0;
+    })
+    .AddArg("paths", R"code(The list of (one or more) paths to the webdataset archives.
+
+Has to be the same length as the ``index_paths`` argument.)code",
+            DALI_STRING_VEC)
+    .AddArg("index_paths",
+            R"code(The list of the index files corresponding to the respective webdataset archives.
+
+Has to be the same length as the ``paths`` argument.)code",
+            DALI_STRING_VEC)
+    .AddArg("ext", R"code(The extension sets for each of the outputs produced.
+
+The number of extension sets determines the number of outputs of the reader.
+The extension of a component is the text after the first dot in the name of the file
+(excluding samples starting with a dot). The different extension options should be separated
+with a semicolon (';') and may contain dots.
+
+Example: "left.png;right.jpg")code",
+            DALI_STRING_VEC)
+    .AddOptionalArg(
+        "missing_component_behavior",
+        R"code(Specifies what to do if a sample does not have a file corresponding to one of the outputs.
+
+Possible behaviors:
+  - "empty" (default) - the output with no matching file contains an empty tensor
+  - "skip" - the whole sample is skipped (no penalty to performance except
+    for reduced caching of the archive)
+  - "error" - an exception is raised and the execution stops)code",
+        "")
+    .AddOptionalArg("dtypes", R"code(Data types of the respective outputs.
+
+The default output data types are UINT8. However, if set, each output data type should be specified.
+Moreover, the tar file should be constructed so that the byte size of each component is
+divisible by the size of its data type.)code",
+                    DALI_DATA_TYPE_VEC,
+                    nullptr)  // default is a vector of uint8
+    .AddParent("LoaderBase");
+
+DALI_REGISTER_OPERATOR(readers__Webdataset, WebdatasetReader, CPU);
+
+} // namespace dali
diff --git a/dali/operators/reader/webdataset_reader_op.h b/dali/operators/reader/webdataset_reader_op.h
new file mode 100644
index 00000000000..6de904924c6
--- /dev/null
+++ b/dali/operators/reader/webdataset_reader_op.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
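To make the index layout above concrete, here is a small sketch that reads an index the same way ParseIndexFile does - it is not the C++ parser, and the offsets and sizes are made-up, block-aligned numbers::

    # Conceptual sketch of the v1.0 index layout documented above.
    index_lines = [
        "v1.0 2",                          # global header: version + sample count
        "jpg 1024 2859 cls 4608 3",        # one line per sample...
        "jpg 5632 2712 cls 9216 3",        # ...with (ext, offset, size) triples
    ]

    version, num_samples = index_lines[0].split()
    assert version == "v1.0" and int(num_samples) == len(index_lines) - 1

    for line in index_lines[1:]:
        fields = line.split()
        components = [(fields[i], int(fields[i + 1]), int(fields[i + 2]))
                      for i in range(0, len(fields), 3)]
        # offsets must fall on 512-byte tar block boundaries
        assert all(offset % 512 == 0 for _, offset, _ in components)
        print(components)
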
+ +#ifndef DALI_OPERATORS_READER_WEBDATASET_READER_OP_H_ +#define DALI_OPERATORS_READER_WEBDATASET_READER_OP_H_ + +#include +#include "dali/operators/reader/loader/webdataset_loader.h" +#include "dali/operators/reader/reader_op.h" +#include "dali/pipeline/data/tensor.h" + +namespace dali { + +class DLL_PUBLIC WebdatasetReader : public DataReader>> { + public: + explicit WebdatasetReader(const OpSpec& spec) + : DataReader>>(spec) { + loader_ = InitLoader(spec); + } + + bool SetupImpl(std::vector& output_desc, const HostWorkspace&) override; + void RunImpl(HostWorkspace& ws) override; + bool CanInferOutputs() const override { + return true; + } + + protected: + USE_READER_OPERATOR_MEMBERS(CPUBackend, vector>); +}; + +} // namespace dali + +#endif // DALI_OPERATORS_READER_WEBDATASET_READER_OP_H_ diff --git a/dali/test/python/test_dali_cpu_only.py b/dali/test/python/test_dali_cpu_only.py index 48e9947b71f..e724051ea57 100644 --- a/dali/test/python/test_dali_cpu_only.py +++ b/dali/test/python/test_dali_cpu_only.py @@ -26,6 +26,7 @@ import scipy.io.wavfile from PIL import Image, ImageEnhance from test_detection_pipeline import coco_anchors +from webdataset_base import generate_temp_index_file as generate_temp_wds_index import re import numpy as np @@ -45,6 +46,7 @@ caffe2_dir = os.path.join(data_root, 'db', 'c2lmdb') recordio_dir = os.path.join(data_root, 'db', 'recordio') tfrecord_dir = os.path.join(data_root, 'db', 'tfrecord') +webdataset_dir = os.path.join(data_root, 'db', 'webdataset') coco_dir = os.path.join(data_root, 'db', 'coco', 'images') coco_annotation = os.path.join(data_root, 'db', 'coco', 'instances.json') sequence_dir = os.path.join(data_root, 'db', 'sequence', 'frames') @@ -497,6 +499,17 @@ def test_tfrecord_reader_cpu(): for _ in range(3): pipe.run() +def test_webdataset_reader_cpu(): + webdataset = os.path.join(webdataset_dir, 'MNIST', 'devel-0.tar') + webdataset_idx = generate_temp_wds_index(webdataset) + check_no_input(fn.readers.webdataset, + paths=webdataset, + index_paths=webdataset_idx.name, + ext=["jpg", "cls"], + shard_id=0, + num_shards=1) + + def test_coco_reader_cpu(): check_no_input(fn.readers.coco, file_root=coco_dir, annotations_file=coco_annotation, shard_id=0, num_shards=1) @@ -982,6 +995,7 @@ def test_subscript_dim_check(): "readers.caffe2", "readers.coco", "readers.numpy", + "readers.webdataset", "coin_flip", "uniform", "random.uniform", diff --git a/dali/test/python/test_dali_variable_batch_size.py b/dali/test/python/test_dali_variable_batch_size.py index fcaf31f9f46..0629c96125c 100644 --- a/dali/test/python/test_dali_variable_batch_size.py +++ b/dali/test/python/test_dali_variable_batch_size.py @@ -1210,6 +1210,7 @@ def test_subscript_dim_check(): "readers.nemo_asr", # readers do do not support variable batch size yet "readers.video", # readers do do not support variable batch size yet "readers.video_resize", # readers do do not support variable batch size yet + "readers.webdataset", # readers do do not support variable batch size yet ] def test_coverage(): diff --git a/dali/test/python/test_operator_readers_webdataset_big.py b/dali/test/python/test_operator_readers_webdataset_big.py new file mode 100644 index 00000000000..e8dc640e45d --- /dev/null +++ b/dali/test/python/test_operator_readers_webdataset_big.py @@ -0,0 +1,100 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from webdataset_base import * + +def cross_check( + dont_use_mmap, + batch_size, + num_shards, + shard_id, + skip_cached_images, + pad_last_batch, + stick_to_shard, +): + num_multiplications = 4 + num_samples = 20 * num_multiplications + tar_file_paths = [ + os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/cross.tar") + ] * num_multiplications + index_files = [generate_temp_index_file(tar_file_path) for tar_file_path in tar_file_paths] + + extract_dirs = [generate_temp_extract(tar_file_path) for tar_file_path in tar_file_paths] + equivalent_files = sum( + ( + sorted( + glob(extract_dir.name + "/*"), + key=lambda s: (int(s[s.rfind("/") + 1 : s.find(".")]), s), + ) + for extract_dir in extract_dirs + ), + [], + ) + + compare_pipelines( + webdataset_raw_pipeline( + tar_file_paths, + [index_file.name for index_file in index_files], + ["a.a;a.b;a.a;a.b", "b.a;b.b;b.a;b.b"], + batch_size=batch_size, + device_id=0, + num_threads=10, + dont_use_mmap=dont_use_mmap, + num_shards=num_shards, + shard_id=shard_id, + prefetch_queue_depth=8, + skip_cached_images=skip_cached_images, + pad_last_batch=pad_last_batch, + stick_to_shard=stick_to_shard, + ), + file_reader_pipeline( + equivalent_files, + ["a.a", "b.a"], + batch_size=batch_size, + device_id=0, + num_threads=10, + dont_use_mmap=True, + num_shards=num_shards, + shard_id=shard_id, + skip_cached_images=skip_cached_images, + pad_last_batch=pad_last_batch, + stick_to_shard=stick_to_shard, + ), + batch_size, + math.ceil(num_samples / test_batch_size), + ) + + +def test_cross_check(): + scenarios = [ + ( + dont_use_mmap, + batch_size, + num_shards, + shard_id, + skip_cached_images, + pad_last_batch, + stick_to_shard, + ) + for dont_use_mmap in (False, True) + for stick_to_shard in (False, True) + for pad_last_batch in (False, True) + for skip_cached_images in (False, True) + for batch_size in (1, 8) if batch_size != 1 or not pad_last_batch + for num_shards in (1, 80) + for shard_id in {0, num_shards - 1} + ] + + for args in scenarios: + yield (cross_check,) + args diff --git a/dali/test/python/test_operator_readers_webdataset_corner.py b/dali/test/python/test_operator_readers_webdataset_corner.py new file mode 100644 index 00000000000..9ee082712ba --- /dev/null +++ b/dali/test/python/test_operator_readers_webdataset_corner.py @@ -0,0 +1,301 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from webdataset_base import * + + +def general_corner_case( + test_batch_size=test_batch_size, dtypes=None, missing_component_behavior="", **kwargs +): + num_samples = 1000 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar") + index_file = generate_temp_index_file(tar_file_path) + + extract_dir = generate_temp_extract(tar_file_path) + equivalent_files = sorted( + glob(extract_dir.name + "/*"), key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) + ) + + compare_pipelines( + webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["jpg", "cls"], + missing_component_behavior=missing_component_behavior, + dtypes=dtypes, + batch_size=test_batch_size, + device_id=0, + num_threads=1, + **kwargs + ), + file_reader_pipeline( + equivalent_files, + ["jpg", "cls"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + **kwargs + ), + test_batch_size, + math.ceil(num_samples / test_batch_size), + ) + +def test_mmap_dtype_incompatibility(): + assert_raises( + RuntimeError, + general_corner_case, + dtypes=[dali.types.INT8, dali.types.FLOAT64], + glob="component size and dtype incompatible", + ) + +def test_lazy_init(): + general_corner_case(lazy_init=True) + + +def test_read_ahead(): + general_corner_case(read_ahead=True) + +def test_single_sample(): + test_batch_size = 1 + num_samples = 1 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/single.tar") + index_file = generate_temp_index_file(tar_file_path) + + extract_dir = generate_temp_extract(tar_file_path) + equivalent_files = list(sorted(glob(extract_dir.name + "/*"))) + + compare_pipelines( + webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["txt"], + missing_component_behavior="skip", + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + file_reader_pipeline( + equivalent_files, ["txt"], batch_size=test_batch_size, device_id=0, num_threads=1 + ), + test_batch_size, + math.ceil(num_samples / test_batch_size) * 10, + ) + wds_pipeline = webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["txt"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ) + wds_pipeline.build() + assert_equal(list(wds_pipeline.epoch_size().values())[0], num_samples) + + +def test_single_sample_and_junk(): + test_batch_size = 1 + num_samples = 1 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/single_junk.tar") + index_file = generate_temp_index_file(tar_file_path) + + extract_dir = generate_temp_extract(tar_file_path) + equivalent_files = list(sorted(glob(extract_dir.name + "/*"))) + + compare_pipelines( + webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["txt"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + file_reader_pipeline( + equivalent_files, ["txt"], batch_size=test_batch_size, device_id=0, num_threads=1 + ), + test_batch_size, + math.ceil(num_samples / test_batch_size) * 10, + ) + wds_pipeline = webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["txt"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ) + wds_pipeline.build() + assert_equal(list(wds_pipeline.epoch_size().values())[0], num_samples) + + +def test_wide_sample(): + test_batch_size = 1 + num_samples = 1 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/wide.tar") + index_file = generate_temp_index_file(tar_file_path) + + extract_dir = generate_temp_extract(tar_file_path) + equivalent_files = list(sorted(glob(extract_dir.name + "/*"))) + + 
num_components = 1000 + compare_pipelines( + webdataset_raw_pipeline( + tar_file_path, + index_file.name, + [str(x) for x in range(num_components)], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + file_reader_pipeline( + equivalent_files, + [str(x) for x in range(num_components)], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + test_batch_size, + math.ceil(num_samples / test_batch_size) * 10, + ) + wds_pipeline = webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["txt"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ) + wds_pipeline.build() + assert_equal(list(wds_pipeline.epoch_size().values())[0], num_samples) + + + + +def test_argument_errors(): + def paths_index_paths_error(): + webdataset_pipeline = webdataset_raw_pipeline( + [ + os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"), + os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-1.tar"), + os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-2.tar"), + ], + [], + ["jpg", "cls"], + batch_size=1, + device_id=0, + num_threads=1, + ) + webdataset_pipeline.build() + + assert_raises( + RuntimeError, + paths_index_paths_error, + glob="Number of webdataset archives does not match the number of index files", + ) + + assert_raises( + RuntimeError, + general_corner_case, + missing_component_behavior="SomethingInvalid", + glob="Invalid value for missing_component_behavior", + ) + general_corner_case(missing_component_behavior="Skip") + + assert_raises( + RuntimeError, + general_corner_case, + dtypes=[dali.types.STRING, dali.types.STRING], + glob="Unsupported output dtype *. Supported types are", + ) + assert_raises( + RuntimeError, + general_corner_case, + dtypes=dali.types.INT8, + glob="Number of extensions does not match the number of provided types", + ) + + +def general_index_error( + index_file_contents, tar_file_path="db/webdataset/MNIST/devel-0.tar", ext="jpg" +): + index_file = tempfile.NamedTemporaryFile() + index_file.write(index_file_contents) + index_file.flush() + webdataset_pipeline = webdataset_raw_pipeline( + os.path.join(get_dali_extra_path(), tar_file_path), + index_file.name, + ext, + batch_size=1, + device_id=0, + num_threads=1, + ) + webdataset_pipeline.build() + webdataset_pipeline.run() + webdataset_pipeline.run() + + +def test_index_errors(): + assert_raises(RuntimeError, general_index_error, b"", glob="no version signature found") + assert_raises( + RuntimeError, + general_index_error, + b"v0.1", + glob="the version of the index file does not match the expected version (expected: ", + ) + assert_raises(RuntimeError, general_index_error, b"v1.0", glob="no sample count found") + assert_raises( + RuntimeError, general_index_error, b"v1.0 -1", glob="sample count must be positive" + ) + assert_raises( + RuntimeError, general_index_error, b"v1.0 1\n", glob="no extensions provided for the sample" + ) + assert_raises( + RuntimeError, + general_index_error, + b"v1.0 1\njpg", + glob="size or offset corresponding to the extension not found", + ) + assert_raises( + RuntimeError, + general_index_error, + b"v1.0 1\njpg 1 0", + glob="tar offset is not a multiple of tar block size", + ) + assert_raises( + RuntimeError, + general_index_error, + b"v1.0 1\njpg 1024 1", + "db/webdataset/sample-tar/empty.tar", + glob="offset is outside of the archive file" + ) + assert_raises( + RuntimeError, + general_index_error, + b"v1.0 1\njpg 0 1", + "db/webdataset/sample-tar/types.tar", + glob="component of a non-file type" + ) + 
assert_raises( + RuntimeError, + general_index_error, + b"v1.0 1\njpg 0 1", + glob="component extension does not match the archive entry extension" + ) + assert_raises( + RuntimeError, + general_index_error, + b"v1.0 1\ncls 0 1000", + ext="cls", + glob="component size does not match the archive entry size" + ) diff --git a/dali/test/python/test_operator_readers_webdataset_requirements.py b/dali/test/python/test_operator_readers_webdataset_requirements.py new file mode 100644 index 00000000000..cfec8c7bdcd --- /dev/null +++ b/dali/test/python/test_operator_readers_webdataset_requirements.py @@ -0,0 +1,247 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from webdataset_base import * + +def test_return_empty(): + global test_batch_size + num_samples = 1000 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar") + index_file = generate_temp_index_file(tar_file_path) + + extract_dir = generate_temp_extract(tar_file_path) + equivalent_files = glob(extract_dir.name + "/*") + equivalent_files = sorted( + equivalent_files, key=(lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])) + ) + + compare_pipelines( + webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["jpg", "txt"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + missing_component_behavior="empty", + ), + file_reader_pipeline( + equivalent_files, + ["jpg", []], + batch_size=test_batch_size, + device_id=0, + num_threads=1 + ), + test_batch_size, + math.ceil(num_samples / test_batch_size), + ) + + +def test_skip_sample(): + global test_batch_size + num_samples = 500 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar") + index_file = generate_temp_index_file(tar_file_path) + + extract_dir = generate_temp_extract(tar_file_path) + equivalent_files = list( + filter( + lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) < 2500, + sorted( + glob(extract_dir.name + "/*"), key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) + ), + ) + ) + + compare_pipelines( + webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["jpg", "cls"], + missing_component_behavior="skip", + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + file_reader_pipeline( + equivalent_files, ["jpg", "cls"], batch_size=test_batch_size, device_id=0, num_threads=1 + ), + test_batch_size, + math.ceil(num_samples / test_batch_size), + ) + wds_pipeline = webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["jpg", "cls"], + missing_component_behavior="skip", + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ) + wds_pipeline.build() + assert_equal(list(wds_pipeline.epoch_size().values())[0], num_samples) + + +def test_raise_error_on_missing(): + global test_batch_size + num_samples = 1000 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/missing.tar") + index_file = generate_temp_index_file(tar_file_path) + wds_pipeline = webdataset_raw_pipeline( 
+ tar_file_path, + index_file.name, + ["jpg", "cls"], + missing_component_behavior="error", + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ) + assert_raises(RuntimeError, wds_pipeline.build, glob="Underful sample detected") + + +def test_different_components(): + global test_batch_size + num_samples = 1000 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/scrambled.tar") + index_file = generate_temp_index_file(tar_file_path) + + extract_dir = generate_temp_extract(tar_file_path) + equivalent_files = glob(extract_dir.name + "/*") + equivalent_files = sorted( + equivalent_files, key=(lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")])) + ) + + compare_pipelines( + webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["jpg", "txt;cls"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + file_reader_pipeline( + equivalent_files, + ["jpg", {"txt", "cls"}], + batch_size=test_batch_size, + device_id=0, + num_threads=1 + ), + test_batch_size, + math.ceil(num_samples / test_batch_size), + ) + + +def test_dtypes(): + global test_batch_size + num_samples = 100 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/sample-tar/dtypes.tar") + index_file = generate_temp_index_file(tar_file_path) + + wds_pipeline = webdataset_raw_pipeline( + tar_file_path, + index_file.name, + ["float16", "int32", "float64"], + dtypes=[dali.types.FLOAT16, dali.types.INT32, dali.types.FLOAT64], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ) + wds_pipeline.build() + for sample_idx in range(num_samples): + if sample_idx % test_batch_size == 0: + f16, i32, f64 = wds_pipeline.run() + assert (f16.as_array()[sample_idx % test_batch_size] == [float(sample_idx)] * 10).all() + assert (i32.as_array()[sample_idx % test_batch_size] == [int(sample_idx)] * 10).all() + assert (f64.as_array()[sample_idx % test_batch_size] == [float(sample_idx)] * 10).all() + + +def test_wds_sharding(): + global test_batch_size + num_samples = 3000 + tar_file_paths = [ + os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar"), + os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-1.tar"), + os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-2.tar"), + ] + index_files = [generate_temp_index_file(tar_file_path) for tar_file_path in tar_file_paths] + + extract_dirs = [generate_temp_extract(tar_file_path) for tar_file_path in tar_file_paths] + equivalent_files = sum( + list( + sorted( + glob(extract_dir.name + "/*"), key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) + ) + for extract_dir in extract_dirs + ), + [], + ) + + compare_pipelines( + webdataset_raw_pipeline( + tar_file_paths, + [index_file.name for index_file in index_files], + ["jpg", "cls"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + file_reader_pipeline( + equivalent_files, + ["jpg", "cls"], + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + test_batch_size, + math.ceil(num_samples / test_batch_size), + ) + + +def test_sharding(): + global test_batch_size + num_samples = 1000 + tar_file_path = os.path.join(get_dali_extra_path(), "db/webdataset/MNIST/devel-0.tar") + index_file = generate_temp_index_file(tar_file_path) + + extract_dir = generate_temp_extract(tar_file_path) + equivalent_files = sorted( + glob(extract_dir.name + "/*"), key=lambda s: int(s[s.rfind("/") + 1 : s.rfind(".")]) + ) + + num_shards = 100 + for shard_id in range(num_shards): + compare_pipelines( + webdataset_raw_pipeline( + tar_file_path, + 
index_file.name, + ["jpg", "cls"], + num_shards=num_shards, + shard_id=shard_id, + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + file_reader_pipeline( + equivalent_files, + ["jpg", "cls"], + num_shards=num_shards, + shard_id=shard_id, + batch_size=test_batch_size, + device_id=0, + num_threads=1, + ), + test_batch_size, + math.ceil(num_samples / num_shards / test_batch_size) * 2, + ) diff --git a/dali/test/python/test_pipeline.py b/dali/test/python/test_pipeline.py index db1c9957495..cc842189e31 100644 --- a/dali/test/python/test_pipeline.py +++ b/dali/test/python/test_pipeline.py @@ -30,6 +30,7 @@ from math import floor, ceil import sys import warnings +from webdataset_base import generate_temp_index_file as generate_temp_wds_index from test_utils import check_batch from test_utils import compare_pipelines @@ -49,6 +50,7 @@ coco_image_folder = os.path.join(test_data_root, 'db', 'coco', 'images') coco_annotation_file = os.path.join(test_data_root, 'db', 'coco', 'instances.json') test_data_video = os.path.join(test_data_root, 'db', 'optical_flow', 'sintel_trailer') +webdataset_db_folder = os.path.join(test_data_root, 'db', 'webdataset', 'MNIST') def test_tensor_multiple_uses(): batch_size = 128 @@ -874,6 +876,17 @@ def __init__(self, reader_type, batch_size, is_cached=False, is_cached_batch_cop skip_cached_images = skip_cached_images, features = {"image/encoded" : tfrec.FixedLenFeature((), tfrec.string, ""), "image/class/label": tfrec.FixedLenFeature([1], tfrec.int64, -1)}) + elif reader_type == "readers.Webdataset": + wds = [os.path.join(webdataset_db_folder, archive) + for archive in ['devel-1.tar', 'devel-2.tar', 'devel-0.tar']] + self.wds_index_files = [generate_temp_wds_index(archive) for archive in wds] + self.input = ops.readers.Webdataset(paths = wds, + index_paths = [idx.name for idx in self.wds_index_files], + ext = ["jpg", "cls"], + shard_id = 0, + num_shards=num_shards, + stick_to_shard=True, + skip_cached_images=skip_cached_images) if is_cached: self.decode = ops.decoders.Image(device = "mixed", output_type = types.RGB, @@ -899,21 +912,21 @@ def define_graph(self): def test_nvjpeg_cached_batch_copy_pipelines(): batch_size = 26 - for reader_type in {"readers.MXNet", "readers.Caffe", "readers.Caffe2", "readers.File", "readers.TFRecord"}: + for reader_type in {"readers.MXNet", "readers.Caffe", "readers.Caffe2", "readers.File", "readers.TFRecord", "readers.Webdataset"}: compare_pipelines(CachedPipeline(reader_type, batch_size, is_cached=True, is_cached_batch_copy=True), CachedPipeline(reader_type, batch_size, is_cached=True, is_cached_batch_copy=False), batch_size=batch_size, N_iterations=20) def test_nvjpeg_cached_pipelines(): batch_size = 26 - for reader_type in {"readers.MXNet", "readers.Caffe", "readers.Caffe2", "readers.File", "readers.TFRecord"}: + for reader_type in {"readers.MXNet", "readers.Caffe", "readers.Caffe2", "readers.File", "readers.TFRecord", "readers.Webdataset"}: compare_pipelines(CachedPipeline(reader_type, batch_size, is_cached=False), CachedPipeline(reader_type, batch_size, is_cached=True), batch_size=batch_size, N_iterations=20) def test_skip_cached_images(): batch_size = 1 - for reader_type in {"readers.MXNet", "readers.Caffe", "readers.Caffe2", "readers.File"}: + for reader_type in {"readers.MXNet", "readers.Caffe", "readers.Caffe2", "readers.File", "readers.Webdataset"}: compare_pipelines(CachedPipeline(reader_type, batch_size, is_cached=False), CachedPipeline(reader_type, batch_size, is_cached=True, skip_cached_images=True), 
batch_size=batch_size, N_iterations=100) diff --git a/dali/test/python/webdataset_base.py b/dali/test/python/webdataset_base.py new file mode 100644 index 00000000000..8f1114a6dfd --- /dev/null +++ b/dali/test/python/webdataset_base.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nvidia.dali as dali +from nvidia.dali import pipeline_def +import nvidia.dali.fn.readers as readers +from test_utils import compare_pipelines, get_dali_extra_path +from nose_utils import assert_raises +from nose.tools import assert_equal +import tempfile +from subprocess import call +import os +from glob import glob +import tarfile +import math + +test_batch_size = 4 +wds2idx_script = "../../../tools/wds2idx.py" + + +@pipeline_def() +def webdataset_raw_pipeline( + paths, + index_paths, + ext, + missing_component_behavior="empty", + dtypes=None, + dont_use_mmap=False, + num_shards=1, + shard_id=0, + skip_cached_images=False, + pad_last_batch=False, + lazy_init=False, + read_ahead=False, + stick_to_shard=False, +): + out = readers.webdataset( + paths=paths, + index_paths=index_paths, + ext=ext, + missing_component_behavior=missing_component_behavior, + dtypes=dtypes, + dont_use_mmap=dont_use_mmap, + prefetch_queue_depth=1, + num_shards=num_shards, + shard_id=shard_id, + stick_to_shard=stick_to_shard, + skip_cached_images=skip_cached_images, + pad_last_batch=pad_last_batch, + lazy_init=lazy_init, + read_ahead=read_ahead, + ) + return out if not isinstance(out, list) else tuple(out) + + +def filter_ext(files, exts): + if isinstance(exts, str): + exts = {exts} + return list(filter(lambda s: any(map(lambda ext: s.endswith("." 
+ ext), exts)), files)) + + +@pipeline_def() +def file_reader_pipeline( + files, + exts=None, + dont_use_mmap=False, + num_shards=1, + shard_id=0, + skip_cached_images=False, + pad_last_batch=False, + lazy_init=False, + read_ahead=False, + stick_to_shard=False, +): + if not isinstance(exts, list): + exts = [exts] + + return tuple( + readers.file( + files=filter_ext(files, ext), + dont_use_mmap=dont_use_mmap, + prefetch_queue_depth=1, + num_shards=num_shards, + shard_id=shard_id, + stick_to_shard=stick_to_shard, + skip_cached_images=skip_cached_images, + pad_last_batch=pad_last_batch, + lazy_init=lazy_init, + read_ahead=read_ahead, + )[0] + if type(ext) in {str, set} + else ext + for ext in exts + ) + + +def generate_temp_index_file(tar_file_path): + global wds2idx_script + temp_index_file = tempfile.NamedTemporaryFile() + assert_equal ( + call([wds2idx_script, tar_file_path, temp_index_file.name], stdout=open(os.devnull, "wb")) + , 0 + ) + return temp_index_file + + +def generate_temp_extract(tar_file_path): + temp_extract_dir = tempfile.TemporaryDirectory() + archive = tarfile.open(tar_file_path) + for member in archive: + if member.type != tarfile.REGTYPE: + continue + archive.extract(member, temp_extract_dir.name) + return temp_extract_dir diff --git a/dali/util/mmaped_file.cc b/dali/util/mmaped_file.cc index 9db6d43ed7e..5b5ec0af80b 100644 --- a/dali/util/mmaped_file.cc +++ b/dali/util/mmaped_file.cc @@ -171,7 +171,7 @@ inline uint8_t* ReadAheadHelper(std::shared_ptr &p, size_t &pos, } void MmapedFileStream::Seek(int64 pos) { - DALI_ENFORCE(pos >= 0 && pos < (int64)length_, "Invalid seek"); + DALI_ENFORCE(pos >= 0 && pos <= (int64)length_, "Invalid seek"); pos_ = pos; } diff --git a/tools/wds2idx.py b/tools/wds2idx.py new file mode 100755 index 00000000000..d60faebc75b --- /dev/null +++ b/tools/wds2idx.py @@ -0,0 +1,207 @@ +#!/usr/bin/python3 +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time +import argparse +import subprocess +from shutil import which +import tarfile + + +class IndexCreator: + """Reads `Webdataset` data format, and creates index file + that enables random access. + + Example usage: + ---------- + >>> creator = IndexCreator('data/test.tar','data/test.idx') + >>> creator.create_index() + >>> creator.close() + >>> !ls data/ + test.tar test.idx + + Parameters + ---------- + uri : str + Path to the archive file. + idx_path : str + Path to the index file, that will be created/overwritten. 
+ """ + + tar_block_size = 512 + index_file_version = "v1.0" + + def __init__(self, uri, idx_path): + self.uri = uri + self.idx_path = idx_path + self.fidx = open(self.idx_path, "w") + + def open(self): + """Opens the archive and index files and sets their read heads to 0.""" + if self.fidx.closed: + self.fidx = open(self.idx_path, "w") + else: + self.fidx.seek(0) + + def close(self): + """Closes the archive and index files.""" + if not self.fidx.closed: + self.fidx.close() + + def reset(self): + """Resets the archive and index files.""" + self.close() + self.open() + + @staticmethod + def split_name(filepath): + """Splits the webdataset into the basename and the extension""" + dot_pos = filepath.find(".", filepath.rfind("/") + 1) + return filepath[:dot_pos], filepath[dot_pos + 1 :] + + def _get_data_tar(self): + """Retreives the data about the offset, name and size of each component + using the gnu tar utility, while also filtering out non-file entries""" + + tar_blocks_proc = subprocess.Popen( + ["tar", "--list", "--block-num", "--file", self.uri], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + tar_types_sizes_proc = subprocess.Popen( + ["tar", "--verbose", "--list", "--file", self.uri], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + tar_blocks = tar_blocks_proc.communicate()[0].split(b"\n") # block : + tar_types_sizes = tar_types_sizes_proc.communicate()[0].split( + b"\n" + ) # ... + + # Extracting + for blocks_line, types_sizes_line in zip(tar_blocks, tar_types_sizes): + if not blocks_line or not types_sizes_line: + continue + + name = str(blocks_line[blocks_line.find(b":") + 2 :], "ascii") + entry_type = types_sizes_line[0:1] + + if entry_type != b"-": + continue + + offset = int(blocks_line[blocks_line.find(b"block") + 6 : blocks_line.find(b":")]) * 512 + size = types_sizes_line[: -len(name)] + size = size[: size.rfind(b"-") - 8] # "... 20yy-mm-...." + size = int(size[size.rfind(b" ") :]) + + yield offset, name, size + + def _get_data_tarfile(self): + """Retreives the data about the offset, name and size of each component + using the tarfile module, while also filtering out non-file entries + Intended as a fallback for the gnu tar version (since it is much slower)""" + + print( + "Warning: tar utility not found. Falling back to tarfile." 
+ + " Processing will most likely take much longer", + file=sys.stderr, + ) + farchive = iter(tarfile.open(self.uri)) + farchive = filter(lambda member: member.type == tarfile.REGTYPE, farchive) + farchive = map(lambda member: (member.offset, member.name, member.size), farchive) + return farchive + + def create_index(self): + """Creates the index file from a tar archive""" + self.reset() + + pre_time = time.time() + counter = 0 + report_step = 100000 + + print(f"time: {time.time() - pre_time:.2f} count: {counter} stage: collect") + + # Aggregates extensions in samples + aggregated_data = [] + last_basename = None + + for offset, name, size in ( + self._get_data_tar() if which("tar") is not None else self._get_data_tarfile() + ): + if counter % report_step == 0: + cur_time = time.time() + print(f"time: {cur_time - pre_time:.2f} count: {counter} stage: collect") + counter += 1 + + basename, extension = IndexCreator.split_name(name) + + # check for the files starting with a dot (hidden files) + if not basename or basename.endswith("/"): + continue + + if last_basename != basename: + aggregated_data.append([(extension, offset, size)]) + last_basename = basename + else: + aggregated_data[-1].append((extension, offset, size)) + + if not aggregated_data: + raise ValueError("Webdataset Tar File empty") + + # Constructs the index file out of the aggregated extensions + self.fidx.write(f"{IndexCreator.index_file_version} {len(aggregated_data)}\n") + for bundle in aggregated_data: + if counter % report_step == 0: + cur_time = time.time() + print(f"time: {cur_time - pre_time:.2f} count: {counter} stage: index") + self.fidx.write(" ".join(map(lambda component: " ".join(map(str, component)), bundle))) + self.fidx.write("\n") + counter += 1 + + cur_time = time.time() + print(f"time: {cur_time - pre_time:.2f} count: {counter} stage: done") + + +def parse_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description="Creates a webdataset index file for the use with the fn.readers.webdataset from DALI.", + ) + parser.add_argument("archive", help="path to .tar file.") + parser.add_argument( + "index", + help="path to index file", + nargs="?", + ) + args = parser.parse_args() + if args.index is None: + args.index = args.archive[: args.archive.find(".", args.archive.rfind("/") + 2)] + ".idx" + args.archive = os.path.abspath(args.archive) + args.index = os.path.abspath(args.index) + return args + + +def main(): + args = parse_args() + creator = IndexCreator(args.archive, args.index) + creator.create_index() + creator.close() + + +if __name__ == "__main__": + main()