Skip to content

Commit

Permalink
Add AWS SDK support to all file readers (FileReader, NumpyReader, Web…
Browse files Browse the repository at this point in the history
…datasetReader...) (NVIDIA#5415)

Adds support for s3:// URLs to the readers (FileReader, NumpyReader, WebdatasetReader...).
An S3 URL can also be used as the file_root of FileReader (automatic discovery of files).

Signed-off-by: Joaquin Anton <janton@nvidia.com>
  • Loading branch information
jantonguirao committed Apr 24, 2024
1 parent 3ff4bfe commit 521a59d
Show file tree
Hide file tree
Showing 29 changed files with 817 additions and 88 deletions.
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ cmake_dependent_option(BUILD_CFITSIO "Build with cfitsio support" ON
"NOT BUILD_DALI_NODEPS" OFF)
cmake_dependent_option(BUILD_NVIMAGECODEC "Build with support for nvimagecodec library" ON
"NOT BUILD_DALI_NODEPS" OFF)
# Optional AWS SDK support; defaults to ON and is forced OFF when BUILD_DALI_NODEPS is set.
cmake_dependent_option(BUILD_AWSSDK "Build with support for AWS SDK library" ON
"NOT BUILD_DALI_NODEPS" OFF)
set(NVIMGCODEC_DEFAULT_INSTALL_PATH "/opt/nvidia/nvimgcodec_cuda${CUDA_VERSION_MAJOR}" CACHE STRING
"Path of the nvimagecodec installation")

Expand Down Expand Up @@ -309,6 +311,7 @@ propagate_option(BUILD_NVCOMP)
propagate_option(BUILD_NVML)
propagate_option(BUILD_CUFILE)
propagate_option(BUILD_NVIMAGECODEC)
propagate_option(BUILD_AWSSDK)
propagate_option(LINK_DRIVER)
propagate_option(WITH_DYNAMIC_NVJPEG)
propagate_option(WITH_DYNAMIC_CUFFT)
Expand Down
2 changes: 1 addition & 1 deletion DALI_DEPS_VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
77d583ea6cf4dd0a97db2c155044cd23140c4c99
f0fe9cb92317a788b7dd9c01d73f79bb6aace349
22 changes: 22 additions & 0 deletions cmake/Dependencies.common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,25 @@ if(BUILD_NVIMAGECODEC)
endif()
endif()
endif()


##################################################################
# AWS SDK
##################################################################
# Locate the AWS SDK headers and the two component libraries used by DALI (core and S3).
# When any of them is missing, a warning is printed and BUILD_AWSSDK is switched off so the
# configure step continues without S3 support.
if(BUILD_AWSSDK)
  find_path(AWSSDK_INCLUDE_DIR aws/core/Aws.h)
  find_library(AWS_CPP_SDK_CORE_LIB NAMES aws-cpp-sdk-core)
  find_library(AWS_CPP_SDK_S3_LIB NAMES aws-cpp-sdk-s3)
  # if(NOT <var>) is true for "<var>-NOTFOUND" and also for empty/unset values, so a blank
  # cache entry is treated as "not found" instead of silently producing an empty link line.
  if(NOT AWSSDK_INCLUDE_DIR OR NOT AWS_CPP_SDK_CORE_LIB OR NOT AWS_CPP_SDK_S3_LIB)
    message(WARNING "AWS SDK not found. Disabling AWS SDK support.")
    set(BUILD_AWSSDK OFF)
  else()
    # Same order as before: S3 first, then core.
    set(AWSSDK_LIBRARIES ${AWS_CPP_SDK_S3_LIB} ${AWS_CPP_SDK_CORE_LIB})
    message(STATUS "AWSSDK_INCLUDE_DIR=${AWSSDK_INCLUDE_DIR}")
    message(STATUS "AWSSDK_LIBRARIES=${AWSSDK_LIBRARIES}")
  endif()
endif()
5 changes: 5 additions & 0 deletions dali/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ if (BUILD_CUFILE)
target_link_libraries(dali PRIVATE dynlink_cufile)
endif()

# When AWS SDK support is enabled, add its headers and libraries to the dali target.
# Both are PRIVATE, i.e. they are not propagated to targets that link against dali.
if (BUILD_AWSSDK)
target_include_directories(dali PRIVATE ${AWSSDK_INCLUDE_DIR})
target_link_libraries(dali PRIVATE ${AWSSDK_LIBRARIES})
endif()

# Build test suite
################################################
if (BUILD_DALI_PIPELINE AND BUILD_TEST)
Expand Down
67 changes: 51 additions & 16 deletions dali/operators/reader/file_reader_op.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,26 +33,61 @@ class FileReader : public DataReader<CPUBackend, ImageLabelWrapper, ImageLabelWr
this->SetInitialSnapshot();
}

void RunImpl(SampleWorkspace &ws) override {
const int idx = ws.data_idx();

const auto& image_label = GetSample(idx);
// Always reports true. NOTE(review): presumably this tells the executor that output shapes
// and types are supplied by SetupImpl (which fills output_desc) - confirm against the
// operator base class.
bool CanInferOutputs() const override {
return true;
}

// copy from raw_data -> outputs directly
auto &image_output = ws.Output<CPUBackend>(0);
auto &label_output = ws.Output<CPUBackend>(1);
bool SetupImpl(std::vector<OutputDesc>& output_desc, const Workspace& ws) override {
// If necessary start prefetching thread and wait for a consumable batch
DataReader<CPUBackend, ImageLabelWrapper, ImageLabelWrapper, true>::SetupImpl(output_desc, ws);
auto samples = GetCurrBatch();
int batch_size = samples.size();

Index image_size = image_label.image.size();
output_desc.resize(2);
output_desc[0].shape.resize(batch_size, 1);
output_desc[0].type = DALI_UINT8;
output_desc[1].shape = uniform_list_shape<1>(batch_size, {1});
output_desc[1].type = DALI_INT32;

image_output.Resize({image_size}, DALI_UINT8);
label_output.Resize({1}, DALI_INT32);
// Fill the per-sample shapes: output 0 holds the encoded file bytes, output 1 a single label.
// (The previously declared local `out_shape` was never read and has been dropped.)
for (int sample_idx = 0; sample_idx < batch_size; ++sample_idx) {
  auto &sample = *samples[sample_idx];
  // A non-null file_stream means the read was deferred to RunImpl, so the size has to come
  // from the stream; otherwise the bytes are already held in sample.image.
  int64_t image_size =
      sample.file_stream != nullptr ? sample.file_stream->Size() : sample.image.size();
  output_desc[0].shape.tensor_shape_span(sample_idx)[0] = image_size;
  output_desc[1].shape.tensor_shape_span(sample_idx)[0] = 1;
}
return true;
}

std::memcpy(image_output.raw_mutable_data(),
image_label.image.raw_data(),
image_size);
image_output.SetSourceInfo(image_label.image.GetSourceInfo());
void RunImpl(Workspace &ws) override {
auto &image_output = ws.Output<CPUBackend>(0);
auto &label_output = ws.Output<CPUBackend>(1);
auto samples = GetCurrBatch();
int batch_size = samples.size();

label_output.mutable_data<int>()[0] = image_label.label;
auto &thread_pool = ws.GetThreadPool();
// Schedule one task per sample; the sample's byte size is passed as the second AddWork
// argument. Each task either performs the deferred read from the sample's file stream or
// copies the bytes that the loader already placed in sample.image.
for (int sample_idx = 0; sample_idx < batch_size; ++sample_idx) {
  thread_pool.AddWork([&, sample_idx](int tid) {
    auto &sample = *samples[sample_idx];
    if (sample.file_stream != nullptr) {
      // Deferred read: rewind and read the whole stream directly into the output tensor.
      sample.file_stream->SeekRead(0, SEEK_SET);
      const int64_t sz = sample.file_stream->Size();
      const int64_t read_nbytes =
          sample.file_stream->Read(image_output.raw_mutable_tensor(sample_idx), sz);
      // The error message must be composed while the stream is still alive: it used to be
      // built from file_stream->path() *after* file_stream.reset(), which dereferences a
      // null pointer exactly when the read fails. The extra string is negligible next to
      // the read itself.
      auto error_msg = make_string("Failed to read file: ", sample.file_stream->path());
      // Release the stream before a possible throw so it does not stay attached to the sample.
      sample.file_stream->Close();
      sample.file_stream.reset();
      DALI_ENFORCE(read_nbytes == sz, error_msg);
    } else {
      std::memcpy(image_output.raw_mutable_tensor(sample_idx), sample.image.raw_data(),
                  sample.image.size());
    }
    image_output.SetSourceInfo(sample_idx, sample.image.GetSourceInfo());
    label_output.mutable_tensor<int>(sample_idx)[0] = sample.label;
  }, image_output.shape().tensor_size(sample_idx));
}
thread_pool.RunAll();
}

protected:
Expand Down
5 changes: 5 additions & 0 deletions dali/operators/reader/loader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,10 @@ if (BUILD_LIBTAR)
"${CMAKE_CURRENT_SOURCE_DIR}/webdataset_loader.cc")
endif()

# The S3 file discovery source is only added to the operator sources when AWS SDK support
# is enabled.
if (BUILD_AWSSDK)
set(DALI_OPERATOR_SRCS ${DALI_OPERATOR_SRCS}
"${CMAKE_CURRENT_SOURCE_DIR}/discover_files_s3.cc")
endif()

set(DALI_OPERATOR_SRCS ${DALI_OPERATOR_SRCS} PARENT_SCOPE)
set(DALI_OPERATOR_TEST_SRCS ${DALI_OPERATOR_TEST_SRCS} PARENT_SCOPE)
7 changes: 7 additions & 0 deletions dali/operators/reader/loader/discover_files.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
#include "dali/core/error_handling.h"
#include "dali/operators/reader/loader/filesystem.h"
#include "dali/operators/reader/loader/utils.h"
#if AWSSDK_ENABLED
#include "dali/operators/reader/loader/discover_files_s3.h"
#endif

namespace dali {

Expand Down Expand Up @@ -111,7 +114,11 @@ std::vector<FileLabelEntry> discover_files(const std::string &file_root,
const FileDiscoveryOptions &opts) {
bool is_s3 = starts_with(file_root, "s3://");
if (is_s3) {
#if AWSSDK_ENABLED
return s3_discover_files(file_root, opts);
#else
DALI_FAIL("This version of DALI was not built with AWS S3 storage support.");
#endif
}

std::vector<std::string> subdirs;
Expand Down
96 changes: 96 additions & 0 deletions dali/operators/reader/loader/discover_files_s3.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/operators/reader/loader/discover_files_s3.h"
#include <fnmatch.h>
#include <cassert>
#include <filesystem>
#include <iterator>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
#include "dali/operators/reader/loader/discover_files.h"
#include "dali/util/s3_client_manager.h"
#include "dali/util/s3_filesystem.h"

namespace dali {

// We are using std::filesystem to analyze URI relative paths, which wouldn't be OK in non-UNIX
// based systems
#ifndef __unix__
#error This code works only in UNIX-based systems
#endif

/**
 * @brief Lists the objects below an `s3://` location and returns one entry per matching object.
 *
 * Only objects located exactly one "subdirectory" below the root key are considered
 * (`<root>/<subdir>/<file>`); anything shallower or deeper is ignored. Subdirectory and file
 * names are matched with `fnmatch` against `opts.dir_filters` / `opts.file_filters`
 * (an empty filter list accepts everything), case-insensitively unless
 * `opts.case_sensitive_filter` is set.
 *
 * @param file_root `s3://` URI of the root location; must start with "s3://".
 * @param opts      Discovery options (filters, case sensitivity, label_from_subdir).
 * @return Entries holding the path relative to the root, the object size and - when
 *         `opts.label_from_subdir` is set - an integer label that is unique per subdirectory
 *         and assigned in order of first appearance; otherwise the label is `std::nullopt`.
 */
std::vector<FileLabelEntry> s3_discover_files(const std::string &file_root,
                                              const FileDiscoveryOptions &opts) {
  assert(starts_with(file_root, "s3://"));
  auto s3_object_location = s3_filesystem::parse_uri(file_root);
  const std::filesystem::path parent_object_key(s3_object_location.object);
  const int fnmatch_flags = opts.case_sensitive_filter ? 0 : FNM_CASEFOLD;

  // True when `filters` is empty or when at least one pattern matches `name`.
  auto matches_any = [fnmatch_flags](const auto &filters, const std::string &name) {
    if (filters.empty())
      return true;
    for (const auto &filter : filters) {
      if (fnmatch(filter.c_str(), name.c_str(), fnmatch_flags) == 0)
        return true;
    }
    return false;
  };

  std::vector<FileLabelEntry> entries;
  // in case that files are not visited in lexicographical order, we remember previously assigned
  // labels
  std::unordered_map<std::string, int> labels;
  int next_label = 0;  // next free-label to be assigned
  s3_filesystem::list_objects_f(
      S3ClientManager::Instance().client(), s3_object_location,
      [&](const std::string &object_key, size_t object_size) {
        // Object keys are not local paths: compare them purely lexically.
        // std::filesystem::relative would consult the local filesystem and the current
        // working directory, and would normalize the key.
        const auto p = std::filesystem::path(object_key).lexically_relative(parent_object_key);
        // We only look at one subdir level. Keys with fewer elements (the root key itself or
        // an object placed directly under the root) are skipped as well, instead of relying
        // on an assert that is compiled out in release builds.
        if (std::distance(p.begin(), p.end()) != 2)
          return;
        // Copy the elements: a path iterator may own the element it refers to, so a
        // reference obtained through a temporary iterator can dangle.
        auto elem = p.begin();
        const std::string subdir = elem->native();
        ++elem;
        const std::string fname = elem->native();
        // A leading ".." means the key is a sibling of the root rather than a child.
        // NOTE(review): assumes list_objects_f may report such keys (prefix match) - confirm.
        if (subdir == "..")
          return;

        if (!matches_any(opts.dir_filters, subdir) || !matches_any(opts.file_filters, fname))
          return;

        if (opts.label_from_subdir) {
          // First time we see a subdir it receives the next free label.
          auto [it, inserted] = labels.try_emplace(subdir, next_label);
          if (inserted)
            ++next_label;
          entries.push_back({p, it->second, object_size});
        } else {
          entries.push_back({p, std::nullopt, object_size});
        }
      });
  return entries;
}

} // namespace dali
29 changes: 29 additions & 0 deletions dali/operators/reader/loader/discover_files_s3.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_OPERATORS_READER_LOADER_DISCOVER_FILES_S3_H_
#define DALI_OPERATORS_READER_LOADER_DISCOVER_FILES_S3_H_

#include <string>
#include <vector>
#include "dali/operators/reader/loader/discover_files.h"

namespace dali {

/**
 * @brief Discovers files stored under an `s3://` location.
 *
 * Looks exactly one subdirectory level below `file_root`, applies the directory and file
 * name filters from `opts`, and - when `opts.label_from_subdir` is set - assigns one integer
 * label per subdirectory.
 *
 * @param file_root URI of the root location; expected to start with "s3://".
 * @param opts      File discovery options.
 * @return One entry per accepted object (relative path, optional label, object size).
 */
std::vector<FileLabelEntry> s3_discover_files(const std::string &file_root,
const FileDiscoveryOptions &opts);

} // namespace dali

#endif // DALI_OPERATORS_READER_LOADER_DISCOVER_FILES_S3_H_
45 changes: 31 additions & 14 deletions dali/operators/reader/loader/file_label_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "dali/operators/reader/loader/filesystem.h"
#include "dali/operators/reader/loader/utils.h"
#include "dali/util/file.h"
#include "dali/util/uri.h"
#include "dali/core/call_at_exit.h"

namespace dali {

Expand All @@ -35,6 +37,9 @@ void FileLabelLoaderBase<checkpointing_supported>::ReadSample(ImageLabelWrapper
// handle wrap-around
MoveToNextShard(current_index_);

// should be cleared by now
assert(image_label.file_stream == nullptr);

// copy the label
image_label.label = entry.label.value();
DALIMeta meta;
Expand All @@ -50,28 +55,40 @@ void FileLabelLoaderBase<checkpointing_supported>::ReadSample(ImageLabelWrapper
return;
}

auto uri = filesystem::join_path(file_root_, entry.filename);
auto current_image = FileStream::Open(uri, {read_ahead_, !copy_read_data_, false}, entry.size);
Index image_size = current_image->Size();
auto path = filesystem::join_path(file_root_, entry.filename);
auto uri = URI::Parse(path);
FileStream::Options opts;
bool local_file = !uri.valid() || uri.scheme() == "file";
opts.read_ahead = read_ahead_;
opts.use_mmap = local_file && !copy_read_data_;
opts.use_odirect = false;
auto current_file = FileStream::Open(path, opts, entry.size);
auto current_file_cleanup = AtScopeExit([&current_file] {
if (current_file)
current_file->Close();
});
Index file_size = current_file->Size();

if (copy_read_data_) {
if (copy_read_data_ || !current_file->CanMemoryMap()) {
if (image_label.image.shares_data()) {
image_label.image.Reset();
}
image_label.image.Resize({image_size}, DALI_UINT8);
// copy the image
Index ret = current_image->Read(image_label.image.mutable_data<uint8_t>(), image_size);
DALI_ENFORCE(ret == image_size, make_string("Failed to read file: ", entry.filename));
if (local_file) {
// if local file, read right away
image_label.image.Resize({file_size}, DALI_UINT8);
int64_t read_nbytes =
current_file->Read(image_label.image.mutable_data<uint8_t>(), file_size);
DALI_ENFORCE(read_nbytes == file_size, make_string("Failed to read file: ", entry.filename));
} else {
// if URI, defer reading
image_label.file_stream = std::move(current_file);
}
} else {
auto p = current_image->Get(image_size);
auto p = current_file->Get(file_size);
DALI_ENFORCE(p != nullptr, make_string("Failed to read file: ", entry.filename));
// Wrap the raw data in the Tensor object.
image_label.image.ShareData(p, image_size, false, {image_size}, DALI_UINT8, CPU_ONLY_DEVICE_ID);
image_label.image.ShareData(p, file_size, false, {file_size}, DALI_UINT8, CPU_ONLY_DEVICE_ID);
}

// close the file handle
current_image->Close();

image_label.image.SetMeta(meta);
}

Expand Down
Loading

0 comments on commit 521a59d

Please sign in to comment.