Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions include/svs/core/data/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ void populate_impl(
}
}

template <data::MemoryDataset Data> void populate(std::istream& is, Data& data) {
auto accessor = DefaultWriteAccessor();

size_t num_vectors = data.size();
size_t dims = data.dimensions();

auto max_lines = Dynamic;
auto nvectors = std::min(num_vectors, max_lines);

auto reader = lib::VectorReader<typename Data::element_type>(dims);
for (size_t i = 0; i < nvectors; ++i) {
reader.read(is);
accessor.set(data, i, reader.data());
}
}

// Intercept the native file to perform dispatch on the actual file type.
template <data::MemoryDataset Data, typename WriteAccessor>
void populate_impl(
Expand Down Expand Up @@ -120,6 +136,15 @@ void save(const Dataset& data, const File& file, const lib::UUID& uuid = lib::Ze
return save(data, accessor, file, uuid);
}

template <data::ImmutableMemoryDataset Dataset>
void save(const Dataset& data, std::ostream& os) {
auto accessor = DefaultReadAccessor();
auto writer = svs::io::v1::StreamWriter<void>(os);
for (size_t i = 0; i < data.size(); ++i) {
writer << accessor.get(data, i);
}
}

///
/// @brief Save the dataset as a "*vecs" file.
///
Expand Down Expand Up @@ -169,6 +194,14 @@ lib::lazy_result_t<F, size_t, size_t> load_dataset(const File& file, const F& la
return load_impl(detail::to_native(file), default_accessor, lazy);
}

template <lib::LazyInvocable<size_t, size_t> F>
lib::lazy_result_t<F, size_t, size_t>
load_dataset(std::istream& is, const F& lazy, size_t num_vectors, size_t dims) {
auto data = lazy(num_vectors, dims);
populate(is, data);
return data;
}

// Return whether or not a file is directly loadable via file-extension.
inline bool special_by_file_extension(std::string_view path) {
return (path.ends_with("svs") || path.ends_with("vecs") || path.ends_with("bin"));
Expand Down
87 changes: 79 additions & 8 deletions include/svs/core/data/simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,42 @@ class GenericSerializer {
}

template <data::ImmutableMemoryDataset Data>
static lib::SaveTable save(const Data& data, const lib::SaveContext& ctx) {
static lib::SaveTable save_table(const Data& data) {
using T = typename Data::element_type;
// UUID used to identify the file.
auto uuid = lib::UUID{};
auto filename = ctx.generate_name("data");
io::save(data, io::NativeFile(filename), uuid);
return lib::SaveTable(
auto table = lib::SaveTable(
serialization_schema,
save_version,
{
{"name", "uncompressed"},
{"binary_file", lib::save(filename.filename())},
{"dims", lib::save(data.dimensions())},
{"num_vectors", lib::save(data.size())},
{"uuid", uuid.str()},
{"eltype", lib::save(datatype_v<T>)},
}
);
return table;
}

template <data::ImmutableMemoryDataset Data, class FileName_t>
static lib::SaveTable
save_table(const Data& data, const FileName_t& filename, const lib::UUID& uuid) {
auto table = save_table(data);
table.insert("binary_file", filename);
table.insert("uuid", uuid.str());
return table;
}

template <data::ImmutableMemoryDataset Data>
static lib::SaveTable save(const Data& data, const lib::SaveContext& ctx) {
// UUID used to identify the file.
auto uuid = lib::UUID{};
auto filename = ctx.generate_name("data");
io::save(data, io::NativeFile(filename), uuid);
return save_table(data, lib::save(filename.filename()), uuid);
}

template <data::ImmutableMemoryDataset Data>
static void save(const Data& data, std::ostream& os) {
io::save(data, os);
}

template <typename T, lib::LazyInvocable<size_t, size_t> F>
Expand All @@ -116,6 +134,25 @@ class GenericSerializer {
}
return io::load_dataset(binaryfile.value(), lazy);
}

template <typename T, lib::LazyInvocable<size_t, size_t> F>
static lib::lazy_result_t<F, size_t, size_t>
load(const lib::ContextFreeLoadTable& table, std::istream& is, const F& lazy) {
auto datatype = lib::load_at<DataType>(table, "eltype");
if (datatype != datatype_v<T>) {
throw ANNEXCEPTION(
"Trying to load an uncompressed dataset with element types {} to a dataset "
"with element types {}.",
name(datatype),
name<datatype_v<T>>()
);
}

size_t num_vectors = lib::load_at<size_t>(table, "num_vectors");
size_t dims = lib::load_at<size_t>(table, "dims");

return io::load_dataset(is, lazy, num_vectors, dims);
}
};

struct Matcher {
Expand Down Expand Up @@ -405,6 +442,10 @@ class SimpleData {
return GenericSerializer::save(*this, ctx);
}

void save(std::ostream& os) const { return GenericSerializer::save(*this, os); }

lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); }

static bool check_load_compatibility(std::string_view schema, lib::Version version) {
return GenericSerializer::check_compatibility(schema, version);
}
Expand All @@ -431,6 +472,20 @@ class SimpleData {
);
}

static SimpleData load(
const lib::ContextFreeLoadTable& table,
std::istream& is,
const allocator_type& allocator = {}
)
requires(!is_view)
{
return GenericSerializer::load<T>(
table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) {
return SimpleData(n_elements, n_dimensions, allocator);
})
);
}

///
/// @brief Try to automatically load the dataset.
///
Expand Down Expand Up @@ -805,6 +860,10 @@ class SimpleData<T, Extent, Blocked<Alloc>> {
return GenericSerializer::save(*this, ctx);
}

void save(std::ostream& os) const { return GenericSerializer::save(*this, os); }

lib::SaveTable save_table() const { return GenericSerializer::save_table(*this); }

static bool check_load_compatibility(std::string_view schema, lib::Version version) {
return GenericSerializer::check_compatibility(schema, version);
}
Expand All @@ -818,6 +877,18 @@ class SimpleData<T, Extent, Blocked<Alloc>> {
);
}

static SimpleData load(
const lib::ContextFreeLoadTable& table,
std::istream& is,
const Blocked<Alloc>& allocator = {}
) {
return GenericSerializer::load<T>(
table, is, lib::Lazy([&allocator](size_t n_elements, size_t n_dimensions) {
return SimpleData(n_elements, n_dimensions, allocator);
})
);
}

static SimpleData
load(const std::filesystem::path& path, const Blocked<Alloc>& allocator = {}) {
if (detail::is_likely_reload(path)) {
Expand Down
78 changes: 50 additions & 28 deletions include/svs/core/io/native.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,28 +344,16 @@ struct Header {
static_assert(sizeof(Header) == header_size, "Mismatch in Native io::v1 header sizes!");
static_assert(std::is_trivially_copyable_v<Header>, "Header must be trivially copyable!");

template <typename T = void> class Writer {
// CRTP
template <typename T, class Derived> class Writer {
public:
Writer(
const std::string& path,
size_t dimension,
lib::UUID uuid = lib::UUID(lib::ZeroInitializer())
)
: dimension_{dimension}
, uuid_{uuid}
, stream_{lib::open_write(path, std::ofstream::out | std::ofstream::binary)} {
// Write a temporary header.
stream_.seekp(0, std::ofstream::beg);
lib::write_binary(stream_, Header());
}

size_t dimensions() const { return dimension_; }
void overwrite_num_vectors(size_t num_vectors) { vectors_written_ = num_vectors; }

// TODO: Error checking to make sure the length is correct.
template <typename U> Writer& append(U&& v) {
std::ostream& os = static_cast<Derived*>(this)->stream();
for (const auto& i : v) {
lib::write_binary(stream_, lib::io_convert<T>(i));
lib::write_binary(os, lib::io_convert<T>(i));
}
++vectors_written_;
return *this;
Expand All @@ -374,21 +362,45 @@ template <typename T = void> class Writer {
template <typename... Ts>
requires std::is_same_v<T, void>
Writer& append(std::tuple<Ts...>&& v) {
lib::foreach (v, [&](const auto& x) { lib::write_binary(stream_, x); });
std::ostream& os = static_cast<Derived*>(this)->stream();
lib::foreach (v, [&](const auto& x) { lib::write_binary(os, x); });
++vectors_written_;
return *this;
}

template <typename U> Writer& operator<<(U&& v) { return append(std::forward<U>(v)); }

protected:
size_t vectors_written_ = 0;
};

template <typename T = void> class FileWriter : public Writer<T, FileWriter<T>> {
public:
FileWriter(
const std::string& path,
size_t dimension,
lib::UUID uuid = lib::UUID(lib::ZeroInitializer())
)
: dimension_{dimension}
, uuid_{uuid}
, stream_{lib::open_write(path, std::ofstream::out | std::ofstream::binary)} {
// Write a temporary header.
stream_.seekp(0, std::ofstream::beg);
lib::write_binary(stream_, Header());
}

std::ostream& stream() { return stream_; }

size_t dimensions() const { return dimension_; }

void flush() { stream_.flush(); }

void writeheader(bool resume = true) {
auto position = stream_.tellp();
// Write to the header the number of vectors actually written.
stream_.seekp(0);
assert(stream_.good());
lib::write_binary(stream_, Header(vectors_written_, dimension_, uuid_));
lib::write_binary(stream_, Header(this->vectors_written_, dimension_, uuid_));
if (resume) {
stream_.seekp(position, std::ofstream::beg);
}
Expand All @@ -402,20 +414,30 @@ template <typename T = void> class Writer {
//
// We delete the copy constructor and copy assignment operators because
// `std::ofstream` isn't copyable anyways.
Writer(const Writer&) = delete;
Writer& operator=(const Writer&) = delete;
Writer(Writer&&) = delete;
Writer& operator=(Writer&&) = delete;
FileWriter(const FileWriter&) = delete;
FileWriter& operator=(const FileWriter&) = delete;
FileWriter(FileWriter&&) = delete;
FileWriter& operator=(FileWriter&&) = delete;

// Write the header for the file.
~Writer() noexcept { writeheader(); }
~FileWriter() noexcept { writeheader(); }

private:
size_t dimension_;
lib::UUID uuid_;
std::ofstream stream_;
size_t writes_this_vector_ = 0;
size_t vectors_written_ = 0;
};

template <typename T = void> class StreamWriter : public Writer<T, StreamWriter<T>> {
public:
StreamWriter(std::ostream& os)
: stream_{os} {}

std::ostream& stream() { return stream_; }

private:
std::ostream& stream_;
};

///
Expand Down Expand Up @@ -449,13 +471,13 @@ class NativeFile {
}

template <typename T>
Writer<T> writer(
FileWriter<T> writer(
lib::Type<T> SVS_UNUSED(type), size_t dimension, lib::UUID uuid = lib::ZeroUUID
) const {
return Writer<T>(path_, dimension, uuid);
return FileWriter<T>(path_, dimension, uuid);
}

Writer<> writer(size_t dimensions, lib::UUID uuid = lib::ZeroUUID) const {
FileWriter<> writer(size_t dimensions, lib::UUID uuid = lib::ZeroUUID) const {
return writer(lib::Type<void>(), dimensions, uuid);
}

Expand Down Expand Up @@ -715,7 +737,7 @@ class NativeFile {
public:
using compatible_file_types = lib::Types<vtest::NativeFile, v1::NativeFile>;

template <typename T> using Writer = v1::Writer<T>;
template <typename T> using Writer = v1::FileWriter<T>;

explicit NativeFile(std::filesystem::path path)
: path_{std::move(path)} {}
Expand Down
2 changes: 2 additions & 0 deletions include/svs/index/flat/flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,8 @@ class FlatIndex {
void save(const std::filesystem::path& data_directory) const {
lib::save_to_disk(data_, data_directory);
}

void save(std::ostream& os) const { lib::save_to_stream(data_, os); }
};

///
Expand Down
Loading
Loading