286 changes: 250 additions & 36 deletions test/cpp/api/dataloader.cpp
@@ -97,17 +97,19 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {

auto initialization_function =
[&](size_t preloader_count, size_t batch_size, size_t cache_size) {
std::shared_ptr<datasets::ChunkSelector> chunk_selector =

Review comment on this line:

One question I have had since the first iteration: is it a good idea to make this a shared pointer? It gives the user access to mutate the chunk_selector while our ChunkDataset is still using it; for example, the user may unintentionally call reset() while the epoch is not yet fully exhausted.

Maybe we should consider making a copy of the chunk_selector inside ChunkDataset instead? (A short illustrative sketch of the concern follows at the end of this hunk.)

std::make_shared<datasets::SequentialChunkSelector>(
data_reader.chunk_count());

datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
chunk_selector,
datasets::ChunkDatasetOptions(
preloader_count, batch_size, cache_size));
};
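
A minimal sketch of the shared-ownership concern raised in the review comment above. It reuses DummyChunkDataReader, SequentialSampler, and the construction pattern of the new ChunkDataSetEnumerationWithMultipleEpochs test near the end of this diff; the timing assumption (that the dataset starts draining the selector as soon as iteration begins) is mine, not something this PR asserts.

// Illustrative only: the caller keeps a handle to the same selector the dataset uses.
DummyChunkDataReader data_reader;
samplers::SequentialSampler sampler(0);
std::shared_ptr<datasets::ChunkSelector> chunk_selector =
    std::make_shared<datasets::SequentialChunkSelector>(data_reader.chunk_count());

auto dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
    DummyChunkDataReader,
    samplers::SequentialSampler>>(
    data_reader,
    sampler,
    chunk_selector,
    datasets::ChunkDatasetOptions(1 /*preloader_count*/, 5 /*batch_size*/));

auto data_loader = torch::data::make_data_loader(
    dataset, DataLoaderOptions(5).workers(0));

auto iterator = data_loader->begin();  // the dataset begins consuming chunk_selector
chunk_selector->reset();               // caller rewinds the shared selector mid-epoch;
                                       // the dataset can now re-serve chunks it already read

Copying the selector at construction time (or holding it by value) would make such a call invisible to the running dataset, though the dataset would then presumably need to expose an epoch hook of its own in place of calling set_epoch() on the shared object.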
@@ -831,6 +833,179 @@ TEST(DataTest, CanUseCustomTypeAsIndexType) {
}
}

TEST(DataTest, RandomChunkSelectorSingleReplicaSingleThread) {
size_t chunk_count = 10;
datasets::RandomChunkSelector cs(chunk_count);
ASSERT_EQ(cs.local_chunk_count(), chunk_count);

std::vector<size_t> res;
torch::optional<size_t> idx;
while ((idx = cs.next()).has_value()) {
res.push_back(*idx);
}

ASSERT_EQ(res.size(), chunk_count);

std::sort(res.begin(), res.end());
for (size_t i = 0; i < res.size(); ++i) {
ASSERT_EQ(res[i], i);
}
}

TEST(DataTest, RandomChunkSelectorMultiReplicaSingleThread) {
size_t chunk_count = 10;
size_t num_replicas = 3;

auto test_function = [&](bool allow_duplicates,
size_t local_chunk_count,
std::vector<size_t>& output) {
std::vector<std::unique_ptr<datasets::RandomChunkSelector>> selectors;

for (size_t i = 0; i < num_replicas; ++i) {
selectors.emplace_back(torch::make_unique<datasets::RandomChunkSelector>(
chunk_count, num_replicas, i, allow_duplicates));
}
// local_chunk_count does not depend on the rank, so we only check the first selector.
ASSERT_EQ((*selectors[0]).local_chunk_count(), local_chunk_count);

std::vector<size_t> res;
for (size_t i = 0; i < num_replicas; ++i) {
(*selectors[i]).reset();
torch::optional<size_t> idx;
while ((idx = (*selectors[i]).next()).has_value()) {
res.push_back(*idx);
}
ASSERT_EQ(res.size(), local_chunk_count * (i + 1));
}
std::sort(res.begin(), res.end());
ASSERT_EQ(res, output);
};

size_t local_chunk_count =
static_cast<size_t>(std::ceil(chunk_count * 1.0 / num_replicas));
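// With allow_duplicates = true each replica owns ceil(10 / 3) = 4 chunks, so the
// merged output has 12 entries and two chunk indices appear twice.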
std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
test_function(true, local_chunk_count, output1);

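// With allow_duplicates = false each replica owns floor(10 / 3) = 3 chunks, so the
// leftover chunk (index 9) is never handed out.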
local_chunk_count =
static_cast<size_t>(std::floor(chunk_count * 1.0 / num_replicas));
std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
test_function(false, local_chunk_count, output2);
}

TEST(DataTest, RandomChunkSelectorSingleReplicaMultiThread) {
size_t chunk_count = 10;

datasets::RandomChunkSelector cs(chunk_count);
std::vector<size_t> res;
std::shared_ptr<std::mutex> guard_ptr = std::make_shared<std::mutex>();

auto loader = [&] {
torch::optional<size_t> idx;
while ((idx = cs.next()).has_value()) {
{
std::lock_guard<std::mutex> lock(*guard_ptr);
res.push_back(*idx);
}
}
};
std::thread t1(loader);
std::thread t2(loader);
t1.join();
t2.join();

std::sort(res.begin(), res.end());
for (size_t i = 0; i < res.size(); ++i) {
ASSERT_EQ(res[i], i);
}
}

TEST(DataTest, SequentialChunkSelectorSingleReplicaSingleThread) {
size_t chunk_count = 10;
datasets::SequentialChunkSelector cs(chunk_count);
ASSERT_EQ(cs.local_chunk_count(), chunk_count);

std::vector<size_t> res;
torch::optional<size_t> idx;
while ((idx = cs.next()).has_value()) {
res.push_back(*idx);
}

ASSERT_EQ(res.size(), chunk_count);

std::sort(res.begin(), res.end());
for (size_t i = 0; i < res.size(); ++i) {
ASSERT_EQ(res[i], i);
}
}

TEST(DataTest, SequentialChunkSelectorMultiReplicaSingleThread) {
size_t chunk_count = 10;
size_t num_replicas = 3;

auto test_function = [&](bool allow_duplicates,
size_t local_chunk_count,
std::vector<size_t>& output) {
std::vector<std::unique_ptr<datasets::SequentialChunkSelector>> selectors;

for (size_t i = 0; i < num_replicas; ++i) {
selectors.emplace_back(
torch::make_unique<datasets::SequentialChunkSelector>(
chunk_count, num_replicas, i, allow_duplicates));
}
// local_chunk_count does not depend on the rank, so we only check the first selector.
ASSERT_EQ((*selectors[0]).local_chunk_count(), local_chunk_count);

std::vector<size_t> res;
for (size_t i = 0; i < num_replicas; ++i) {
(*selectors[i]).reset();
torch::optional<size_t> idx;
while ((idx = (*selectors[i]).next()).has_value()) {
res.push_back(*idx);
}
ASSERT_EQ(res.size(), local_chunk_count * (i + 1));
}
std::sort(res.begin(), res.end());
ASSERT_EQ(res, output);
};

size_t local_chunk_count =
static_cast<size_t>(std::ceil(chunk_count * 1.0 / num_replicas));
std::vector<size_t> output1{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
test_function(true, local_chunk_count, output1);

local_chunk_count =
static_cast<size_t>(std::floor(chunk_count * 1.0 / num_replicas));
std::vector<size_t> output2{0, 1, 2, 3, 4, 5, 6, 7, 8};
test_function(false, local_chunk_count, output2);
}

TEST(DataTest, SequentialChunkSelectorSingleReplicaMultiThread) {
size_t chunk_count = 10;

datasets::SequentialChunkSelector cs(chunk_count);
std::vector<size_t> res;
std::shared_ptr<std::mutex> guard_ptr = std::make_shared<std::mutex>();

auto loader = [&] {
torch::optional<size_t> idx;
while ((idx = cs.next()).has_value()) {
{
std::lock_guard<std::mutex> lock(*guard_ptr);
res.push_back(*idx);
}
}
};
std::thread t1(loader);
std::thread t2(loader);
t1.join();
t2.join();

std::sort(res.begin(), res.end());
for (size_t i = 0; i < res.size(); ++i) {
ASSERT_EQ(res[i], i);
}
}

TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
DataLoaderOptions partial_options;
FullDataLoaderOptions full_options(partial_options);
@@ -1445,17 +1620,19 @@ TEST(DataLoaderTest, ChunkDataSetGetBatch) {
for (auto prefetch_count : prefetch_counts) {
for (auto batch_size : batch_sizes) {
for (auto dataloader_worker_count : dataloader_worker_counts) {
std::shared_ptr<datasets::ChunkSelector> chunk_selector =
std::make_shared<datasets::SequentialChunkSelector>(
data_reader.chunk_count());

datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
chunk_selector,
datasets::ChunkDatasetOptions(prefetch_count, batch_size));

auto data_loader = torch::data::make_data_loader(
@@ -1499,23 +1676,22 @@ TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {

DummyChunkDataReader data_reader;
samplers::SequentialSampler sampler(0);
std::shared_ptr<datasets::ChunkSelector> chunk_selector =
std::make_shared<datasets::SequentialChunkSelector>(
data_reader.chunk_count());

datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>
datasets::SharedBatchDataset<
datasets::ChunkDataset<DummyChunkDataReader, samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
chunk_selector,
datasets::ChunkDatasetOptions(prefetch_count, batch_size));

auto data_loader = torch::data::make_data_loader(
dataset,
DataLoaderOptions(requested_batch_size).workers(0));
dataset, DataLoaderOptions(requested_batch_size).workers(0));

std::string exception_msg =
"The requested batch size does not match with the initialized batch "
@@ -1546,18 +1722,19 @@ TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
const size_t batch_size = 5;
DummyEmptyChunkDataReader data_reader;
samplers::SequentialSampler sampler(0);
std::shared_ptr<datasets::ChunkSelector> chunk_selector =
std::make_shared<datasets::SequentialChunkSelector>(
data_reader.chunk_count());

datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyEmptyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyEmptyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
chunk_selector,
datasets::ChunkDatasetOptions(prefetch_count, batch_size));

auto data_loader = torch::data::make_data_loader(
@@ -1569,39 +1746,42 @@
}
}

TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
struct D : public datasets::ChunkDataReader<std::vector<int>> {
public:
using BatchType = std::vector<int>;

BatchType read_chunk(size_t chunk_index) override {
BatchType batch_data(10, 0);
return batch_data;
}
struct DummyTwoChunkReader
: public datasets::ChunkDataReader<std::vector<int>> {
public:
using BatchType = std::vector<int>;

size_t chunk_count() override {
return 2;
};
BatchType read_chunk(size_t chunk_index) override {
BatchType batch_data(10, 0);
return batch_data;
}

void reset() override{};
size_t chunk_count() override {
return 2;
};

void reset() override{};
};

TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
const size_t batch_sizes[] = {17, 30};
D data_reader;
DummyTwoChunkReader data_reader;
samplers::SequentialSampler sampler(0);

for (auto batch_size : batch_sizes) {
std::shared_ptr<datasets::ChunkSelector> chunk_selector =
std::make_shared<datasets::SequentialChunkSelector>(
data_reader.chunk_count());

datasets::SharedBatchDataset<datasets::ChunkDataset<
D,
samplers::SequentialSampler,
DummyTwoChunkReader,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
D,
samplers::SequentialSampler,
DummyTwoChunkReader,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
chunk_selector,
datasets::ChunkDatasetOptions(1, batch_size));

auto data_loader = torch::data::make_data_loader(
Expand All @@ -1620,3 +1800,37 @@ TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
}
}
}

TEST(DataLoaderTest, ChunkDataSetEnumerationWithMultipleEpochs) {
DummyTwoChunkReader data_reader;
samplers::SequentialSampler sampler(0);
size_t batch_size = 17;

std::shared_ptr<datasets::ChunkSelector> chunk_selector =
std::make_shared<datasets::SequentialChunkSelector>(
data_reader.chunk_count());

datasets::SharedBatchDataset<
datasets::ChunkDataset<DummyTwoChunkReader, samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyTwoChunkReader,
samplers::SequentialSampler>>(
data_reader,
sampler,
chunk_selector,
datasets::ChunkDatasetOptions(1, batch_size));

auto data_loader = torch::data::make_data_loader(
dataset, DataLoaderOptions(batch_size).workers(0));

for (size_t epoch = 0; epoch < 3; ++epoch) {
  // Set the epoch number for this enumeration.
  chunk_selector->set_epoch(epoch);
  for (auto iterator = data_loader->begin(); iterator != data_loader->end();
       ++iterator) {
    std::vector<int> batch = *iterator;
    ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
  }
}
}