Skip to content

Commit

Permalink
Add sample method to data_pipeline (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
najielhachem committed Sep 29, 2023
1 parent 5b47c8d commit af213ff
Show file tree
Hide file tree
Showing 9 changed files with 427 additions and 0 deletions.
20 changes: 20 additions & 0 deletions fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,26 @@ def_data_pipeline(py::module_ &data_module)
return data_pipeline::round_robin(std::move(pipelines));
},
py::arg("pipelines"))
.def_static(
"sample",
[](
std::vector<std::reference_wrapper<data_pipeline>> &refs,
std::optional<std::vector<float>> weights)
{
std::vector<data_pipeline> pipelines{};

pipelines.reserve(refs.size());

std::transform(
refs.begin(), refs.end(), std::back_inserter(pipelines), [](auto &r) {
return std::move(r.get());
});

return data_pipeline::sample(
std::move(pipelines), std::move(weights));
},
py::arg("pipelines"),
py::arg("weights") = std::nullopt)
.def_static(
"constant",
[](data example, std::optional<std::string> key)
Expand Down
1 change: 1 addition & 0 deletions fairseq2n/src/fairseq2n/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ target_sources(fairseq2n
data/py.cc
data/record_reader.cc
data/round_robin_data_source.cc
data/sample_data_source.cc
data/shard_data_source.cc
data/shuffle_data_source.cc
data/skip_data_source.cc
Expand Down
34 changes: 34 additions & 0 deletions fairseq2n/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "fairseq2n/data/prefetch_data_source.h"
#include "fairseq2n/data/take_data_source.h"
#include "fairseq2n/data/round_robin_data_source.h"
#include "fairseq2n/data/sample_data_source.h"
#include "fairseq2n/data/shard_data_source.h"
#include "fairseq2n/data/shuffle_data_source.h"
#include "fairseq2n/data/skip_data_source.h"
Expand Down Expand Up @@ -224,6 +225,39 @@ data_pipeline::round_robin(std::vector<data_pipeline> pipelines)
return data_pipeline_builder{std::move(factory)};
}

data_pipeline_builder
data_pipeline::sample(
std::vector<data_pipeline> pipelines,
std::optional<std::vector<float32>> weights)
{
if (pipelines.empty())
throw_<std::invalid_argument>(
"`pipelines` does not contain any elements. Can not sample from empty set.");

bool is_broken = std::any_of(
pipelines.begin(), pipelines.end(), [](const data_pipeline &pipeline)
{
return pipeline.is_broken();
});
if (is_broken)
throw_<std::invalid_argument>(
"At least one of the specified data pipelines is broken and cannot be used in sample.");

if (!weights)
weights = std::vector<float32>(pipelines.size(), 1.0F / static_cast<float32>(pipelines.size()));
else if (weights.value().size() != pipelines.size())
throw_<std::invalid_argument>(
"The number of `pipelines` and the number of `weights` must be equal, but are {} and {} instead.", pipelines.size(), weights.value().size());

auto tmp = std::make_shared<std::vector<data_pipeline>>(std::move(pipelines));

auto factory = [tmp, weights=std::move(weights.value())]() mutable {
return std::make_unique<sample_data_source>(std::move(*tmp), std::move(weights));
};

return data_pipeline_builder{std::move(factory)};
}

data_pipeline_builder
data_pipeline::constant(data example, std::optional<std::string> key)
{
Expand Down
5 changes: 5 additions & 0 deletions fairseq2n/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ class FAIRSEQ2_API data_pipeline {
static data_pipeline_builder
round_robin(std::vector<data_pipeline> pipelines);

static data_pipeline_builder
sample(
std::vector<data_pipeline> pipelines,
std::optional<std::vector<float>> weights = {});

static data_pipeline_builder
constant(data example, std::optional<std::string> key = {});

Expand Down
72 changes: 72 additions & 0 deletions fairseq2n/src/fairseq2n/data/sample_data_source.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "fairseq2n/data/sample_data_source.h"

#include <ATen/CPUGeneratorImpl.h>
#include <ATen/Context.h>
#include <ATen/Functions.h>

#include <stdexcept>

#include "fairseq2n/utils/tensor.h"

namespace fairseq2n::detail {

sample_data_source::sample_data_source(std::vector<data_pipeline> &&pipelines, std::vector<float32> &&weights)
: pipelines_(std::move(pipelines))
{
weights_ = make_tensor_from_vector(weights, { static_cast<std::int64_t>(pipelines_.size()) });
generator_ = at::globalContext().defaultGenerator(c10::DeviceType::CPU);
}

std::optional<data>
sample_data_source::next()
{
if (eod_)
return std::nullopt;

std::optional<data> output = pipelines_[next_index()].next();
if (!output)
eod_ = true;

return output;
}

void
sample_data_source::reset()
{
eod_ = false;
for (data_pipeline &p : pipelines_)
p.reset();
}

void
sample_data_source::record_position(tape &t) const
{
t.record(eod_);
for (const data_pipeline &p : pipelines_)
p.record_position(t);
}

void
sample_data_source::reload_position(tape &t)
{
eod_ = t.read<bool>();
for (data_pipeline &p : pipelines_)
p.reload_position(t);
}

std::size_t
sample_data_source::next_index()
{
auto result = at::multinomial(weights_, 1, false, generator_)
.item<std::int64_t>();

return static_cast<std::size_t>(result);
}

} // namespace fairseq2::detail
50 changes: 50 additions & 0 deletions fairseq2n/src/fairseq2n/data/sample_data_source.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <memory>
#include <vector>

#include <ATen/Generator.h>
#include <ATen/Tensor.h>

#include "fairseq2n/data/data_pipeline.h"
#include "fairseq2n/data/data_source.h"

namespace fairseq2n::detail {

/// @brief sample from a list of datasources
class sample_data_source final : public data_source {
public:
explicit
sample_data_source(std::vector<data_pipeline> &&pipelines, std::vector<float32> &&weights);

std::optional<data>
next() override;

void
reset() override;

void
record_position(tape &t) const override;

void
reload_position(tape &t) override;

private:
std::size_t
next_index();

private:
std::vector<data_pipeline> pipelines_;
bool eod_ = false;

at::Generator generator_;
at::Tensor weights_;
};

} // namespace fairseq2::detail
31 changes: 31 additions & 0 deletions fairseq2n/src/fairseq2n/utils/tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <memory>
#include <vector>

#include <ATen/Tensor.h>

namespace fairseq2n::detail {

template <typename T>
inline at::Tensor
make_tensor_from_vector(
const std::vector<T> &src,
const std::initializer_list<std::int64_t> &shape) noexcept
{
auto storage = std::make_shared<std::vector<T>>(src);

return at::from_blob(
storage->data(),
c10::ArrayRef<std::int64_t>(shape),
[storage](void*) mutable { storage.reset(); }
);
}

} // namespace fairseq2::detail
13 changes: 13 additions & 0 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,19 @@ def round_robin(pipelines: Sequence["DataPipeline"]) -> "DataPipelineBuilder":
The data pipelines to round robin.
"""

@staticmethod
def sample(
pipelines: Sequence["DataPipeline"],
weights: Optional[Sequence[float]] = None,
) -> "DataPipelineBuilder":
"""Extract examples from ``pipelines`` by sampling based on ``weights``.
:param data_pipelines:
The data pipelines to sample from.
:param weights:
Desired distribution of pipelines. If None, use uniform distribution.
"""

@staticmethod
def constant(example: Any, key: Optional[str] = None) -> "DataPipelineBuilder":
...
Expand Down
Loading

0 comments on commit af213ff

Please sign in to comment.