Skip to content

Commit

Permalink
Webdataset reader operator implementation (NVIDIA#3306)
Browse files Browse the repository at this point in the history
Implementation of nvidia.dali.fn.readers.webdataset
  • Loading branch information
barci2 authored and cyyever committed Jan 23, 2022
1 parent 89bd76a commit 5ce5320
Show file tree
Hide file tree
Showing 19 changed files with 1,723 additions and 36 deletions.
2 changes: 1 addition & 1 deletion DALI_EXTRA_VERSION
@@ -1 +1 @@
c6bfd2987d0d180e756232f69137467c5468b193
0e51e444c4a0970446129db3431044b61d342a6f
4 changes: 4 additions & 0 deletions dali/operators/reader/CMakeLists.txt
Expand Up @@ -35,6 +35,10 @@ if (BUILD_NVDEC)
list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/video_reader_resize_op.cc")
endif()

# The WebDataset reader operator is added to the operator sources only when
# libtar support (BUILD_LIBTAR) is enabled.
if (BUILD_LIBTAR)
list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/webdataset_reader_op.cc")
endif()

if (BUILD_LMDB)
list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/caffe_reader_op.cc")
list(APPEND DALI_OPERATOR_SRCS "${CMAKE_CURRENT_SOURCE_DIR}/caffe2_reader_op.cc")
Expand Down
5 changes: 5 additions & 0 deletions dali/operators/reader/loader/CMakeLists.txt
Expand Up @@ -54,5 +54,10 @@ if (BUILD_LIBSND)
"${CMAKE_CURRENT_SOURCE_DIR}/nemo_asr_loader_test.cc")
endif()

# The WebDataset loader is added to the operator sources only when libtar
# support (BUILD_LIBTAR) is enabled.
if (BUILD_LIBTAR)
set(DALI_OPERATOR_SRCS ${DALI_OPERATOR_SRCS}
"${CMAKE_CURRENT_SOURCE_DIR}/webdataset_loader.cc")
endif()

set(DALI_OPERATOR_SRCS ${DALI_OPERATOR_SRCS} PARENT_SCOPE)
set(DALI_OPERATOR_TEST_SRCS ${DALI_OPERATOR_TEST_SRCS} PARENT_SCOPE)
29 changes: 10 additions & 19 deletions dali/operators/reader/loader/webdataset/tar_utils.cc
Expand Up @@ -13,11 +13,10 @@
// limitations under the License.

#include "dali/operators/reader/loader/webdataset/tar_utils.h"
#include <libtar.h>
#include <cstring>
#include <algorithm>
#include <cstdarg>
#include <cstdlib>
#include <cstring>
#include <list>
#include <string>
#include <tuple>
Expand All @@ -31,7 +30,6 @@ namespace detail {

namespace {

constexpr size_t kBlockSize = T_BLOCKSIZE;
static_assert(is_pow2(kBlockSize),
"The implementation assumes that the block size is a power of 2");

Expand Down Expand Up @@ -96,8 +94,7 @@ static tartype_t kTarArchiveType = {LibtarOpenTarArchive, [](int) -> int { retur
[](int, const void*, size_t) -> ssize_t { return 0; }};

TarArchive::TarArchive(std::unique_ptr<FileStream> stream)
: stream_(std::move(stream)),
instance_handle_(Register(this)) {
: stream_(std::move(stream)), instance_handle_(Register(this)) {
tar_open(ToTarHandle(&handle_), "", &kTarArchiveType, 0, instance_handle_, TAR_GNU);
stream_->Seek(0);
eof_ = stream_->Size() == 0;
Expand Down Expand Up @@ -142,14 +139,8 @@ bool TarArchive::NextFile() {
}

const int64_t offset = stream_->Tell() + RoundToBlockSize(filesize_) - readoffset_;
assert(offset >= 0);
if (static_cast<size_t>(offset) >= stream_->Size()) {
SetEof();
return false;
}

stream_->Seek(stream_->Tell() + RoundToBlockSize(filesize_) - readoffset_);
current_header_ = stream_->Tell();
current_header_ = offset;
stream_->Seek(offset);
ParseHeader();
return !eof_;
}
Expand All @@ -159,14 +150,14 @@ bool TarArchive::EndOfArchive() const {
}

void TarArchive::SeekArchive(int64_t offset) {
assert(offset % T_BLOCKSIZE == 0);
readoffset_ = 0;
if (static_cast<size_t>(offset) >= stream_->Size()) {
SetEof();
if (offset == current_header_) {
return;
}
assert(offset % T_BLOCKSIZE == 0);
eof_ = false;
readoffset_ = 0;
stream_->Seek(offset);
current_header_ = stream_->Tell();
current_header_ = offset;
ParseHeader();
}

Expand Down Expand Up @@ -215,7 +206,6 @@ inline void TarArchive::SetEof() {
filename_ = "";
filesize_ = 0;
filetype_ = ENTRY_NONE;
current_header_ = stream_ ? stream_->Size() : 0;
}

inline void TarArchive::ParseHeader() {
Expand Down Expand Up @@ -258,6 +248,7 @@ void TarArchive::Close() {
handle_ = nullptr;
}
readoffset_ = 0;
current_header_ = 0;
SetEof();
stream_.reset();
if (instance_handle_ >= 0) {
Expand Down
9 changes: 8 additions & 1 deletion dali/operators/reader/loader/webdataset/tar_utils.h
Expand Up @@ -15,6 +15,7 @@
#ifndef DALI_OPERATORS_READER_LOADER_WEBDATASET_TAR_UTILS_H_
#define DALI_OPERATORS_READER_LOADER_WEBDATASET_TAR_UTILS_H_

#include <libtar.h>
#include <memory>
#include <mutex>
#include <string>
Expand All @@ -24,6 +25,8 @@

namespace dali {
namespace detail {
constexpr size_t kBlockSize = T_BLOCKSIZE;

/**
* @brief Used to access .tar archives through the given FileStream
*/
Expand Down Expand Up @@ -103,6 +106,11 @@ class DLL_PUBLIC TarArchive {
*/
bool EndOfFile() const;

/**
 * @brief Closes the archive and frees the underlying file stream
 *
 * Besides releasing the stream, this resets the read offset and the current header
 * position to 0 and clears the current entry's name, size and type.
 */
void Close();

private:
std::unique_ptr<FileStream> stream_;
int instance_handle_ = -1;
Expand All @@ -120,7 +128,6 @@ class DLL_PUBLIC TarArchive {
void SetEof();

void ParseHeader();
void Close(); // resets objects to default values
};

} // namespace detail
Expand Down
23 changes: 12 additions & 11 deletions dali/operators/reader/loader/webdataset/tar_utils_test.cc
Expand Up @@ -25,14 +25,15 @@
#include "dali/operators/reader/loader/filesystem.h"
#include "dali/util/file.h"
#include "dali/core/util.h"
#include "dali/test/dali_test_config.h"

namespace dali {
namespace detail {

TEST(LibTarUtilsTestSimple, Interface) {
std::string filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/MNIST/devel-2.tar"));
std::string dummy_filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
std::string dummy_filepath(dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/MNIST/devel-1.tar"));

TarArchive dummy_archive(FileStream::Open(dummy_filepath, false, false));
Expand Down Expand Up @@ -66,7 +67,7 @@ TEST(LibTarUtilsTestSimple, Interface) {
}

TEST(LibTarUtilsTestSimple, LongNameIndexing) {
std::string filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/sample-tar/gnu.tar"));
TarArchive archive(FileStream::Open(filepath, false, false));
std::string name_prefix(128, '#');
Expand All @@ -78,7 +79,7 @@ TEST(LibTarUtilsTestSimple, LongNameIndexing) {
}

TEST(LibTarUtilsTestSimple, Types) {
std::string filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/sample-tar/types.tar"));
std::vector<TarArchive::EntryType> types = {
TarArchive::ENTRY_BLOCKDEV, TarArchive::ENTRY_CHARDEV, TarArchive::ENTRY_DIR,
Expand All @@ -99,7 +100,7 @@ TEST(LibTarUtilsTestSimple, Types) {
}

TEST(LibTarUtilsTestSimple, Offset) {
std::string filepath(dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
std::string filepath(dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/sample-tar/types.tar"));

TarArchive archive(FileStream::Open(filepath, false, true));
Expand Down Expand Up @@ -195,28 +196,28 @@ auto SimpleTarTestsValues() {
vector<SimpleTarTestsData> values;

SimpleTarTestsData filepaths[] = {
{ dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
{ dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/MNIST/devel-0.tar"),
false,
false,
2000,
3000,
{".cls", ".jpg"} },
{ dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
{ dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/sample-tar/empty.tar"),
false,
false,
0,
0,
{} },
{ dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
{ dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/sample-tar/v7.tar"),
false,
false,
0,
1000,
{""} },
{ dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"),
{ dali::filesystem::join_path(testing::dali_extra_path(),
"db/webdataset/sample-tar/oldgnu.tar"),
false,
false,
Expand All @@ -233,7 +234,7 @@ auto SimpleTarTestsValues() {
}
}
}
return testing::ValuesIn(values.begin(), values.end());
return ::testing::ValuesIn(values.begin(), values.end());
}

INSTANTIATE_TEST_SUITE_P(LibTarUtilsTestParametrized, SimpleTarTests, SimpleTarTestsValues());
Expand All @@ -246,7 +247,7 @@ class MultiTarTests : public ::testing::TestWithParam<bool> {
const std::pair<int, int> ranges[kMultithreadedSamples] = {{2000, 3000}, {0, 1000}, {1000, 2000}};
void SetUp() final {
std::string filepath_prefix(
dali::filesystem::join_path(std::getenv("DALI_EXTRA_PATH"), "db/webdataset/MNIST/devel-"));
dali::filesystem::join_path(testing::dali_extra_path(), "db/webdataset/MNIST/devel-"));

std::string filepaths[kMultithreadedSamples] = {
filepath_prefix + "0.tar", filepath_prefix + "1.tar", filepath_prefix + "2.tar"};
Expand Down

0 comments on commit 5ce5320

Please sign in to comment.