Skip to content

Commit

Permalink
[IO] Return bytes written in Stream::Write (#686)
Browse files Browse the repository at this point in the history
This commit updates the `Stream::Write` method to return the number of
bytes written, analogous to the current behavior of `Stream::Read`.
This information is necessary to correctly model partial writes to a
buffered stream.  (e.g. Writing to a OS pipe when the pipe's buffer is
smaller than the data to be written.)
  • Loading branch information
Lunderberg committed May 22, 2024
1 parent 2e19f7f commit 3031e4a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 31 deletions.
11 changes: 6 additions & 5 deletions include/dmlc/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,17 @@ class Stream { // NOLINT(*)
/*!
* \brief reads data from a stream
* \param ptr pointer to a memory buffer
* \param size block size
* \return the size of data read
* \param size The maximum number of bytes to read
* \return The number of bytes read from the stream
*/
virtual size_t Read(void *ptr, size_t size) = 0;
virtual size_t Read(void* ptr, size_t size) = 0;
/*!
* \brief writes data to a stream
* \param ptr pointer to a memory buffer
* \param size block size
* \param size The maximum number of bytes to write
* \return The number of bytes written
*/
virtual void Write(const void *ptr, size_t size) = 0;
virtual size_t Write(const void* ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~Stream(void) {}
/*!
Expand Down
22 changes: 12 additions & 10 deletions include/dmlc/memory_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,24 @@ struct MemoryFixedSizeStream : public SeekStream {
buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
CHECK(curr_ptr_ + size <= buffer_size_);
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
virtual size_t Write(const void *ptr, size_t size) override {
if (size == 0) return 0;
CHECK(curr_ptr_ + size <= buffer_size_);
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
return size;
}
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
return curr_ptr_;
}

Expand All @@ -73,25 +74,26 @@ struct MemoryStringStream : public dmlc::SeekStream {
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
CHECK(curr_ptr_ <= p_buffer_->length());
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
virtual size_t Write(const void *ptr, size_t size) override {
if (size == 0) return 0;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
return size;
}
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
return curr_ptr_;
}

Expand Down
9 changes: 5 additions & 4 deletions src/io/hdfs_filesys.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class HDFSStream : public SeekStream {
}
}

virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
char *buf = static_cast<char*>(ptr);
size_t nleft = size;
size_t nmax = static_cast<size_t>(std::numeric_limits<tSize>::max());
Expand All @@ -48,7 +48,7 @@ class HDFSStream : public SeekStream {
return size - nleft;
}

virtual void Write(const void *ptr, size_t size) {
virtual size_t Write(const void *ptr, size_t size) override {
const char *buf = reinterpret_cast<const char*>(ptr);
size_t nleft = size;
// When using builtin-java classes to write, the maximum write size
Expand All @@ -70,14 +70,15 @@ class HDFSStream : public SeekStream {
LOG(FATAL) << "HDFSStream.hdfsWrite Error:" << strerror(errsv);
}
}
return size - nleft;
}
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
if (hdfsSeek(fs_, fp_, pos) != 0) {
int errsv = errno;
LOG(FATAL) << "HDFSStream.hdfsSeek Error:" << strerror(errsv);
}
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
tOffset offset = hdfsTell(fs_, fp_);
if (offset == -1) {
int errsv = errno;
Expand Down
9 changes: 5 additions & 4 deletions src/io/local_filesys.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,22 @@ class FileStream : public SeekStream {
virtual ~FileStream(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
return std::fread(ptr, 1, size, fp_);
}
virtual void Write(const void *ptr, size_t size) {
virtual size_t Write(const void *ptr, size_t size) override {
CHECK(std::fwrite(ptr, 1, size, fp_) == size)
<< "FileStream.Write incomplete";
return 0;
}
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
#ifndef _MSC_VER
CHECK(!std::fseek(fp_, static_cast<long>(pos), SEEK_SET)); // NOLINT(*)
#else // _MSC_VER
CHECK(!_fseeki64(fp_, pos, SEEK_SET));
#endif // _MSC_VER
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
#ifndef _MSC_VER
return std::ftell(fp_);
#else // _MSC_VER
Expand Down
16 changes: 9 additions & 7 deletions src/io/s3_filesys.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,23 +424,24 @@ class CURLReadStreamBase : public SeekStream {
virtual ~CURLReadStreamBase() {
this->Cleanup();
}
virtual size_t Tell(void) {
virtual size_t Tell(void) override {
return curr_bytes_;
}
virtual bool AtEnd(void) const {
return at_end_;
}
virtual void Write(const void *ptr, size_t size) {
virtual size_t Write(const void *ptr, size_t size) override {
LOG(FATAL) << "CURL.ReadStream cannot be used for write";
return 0;
}
// lazy seek function
virtual void Seek(size_t pos) {
virtual void Seek(size_t pos) override {
if (curr_bytes_ != pos) {
this->Cleanup();
curr_bytes_ = pos;
}
}
virtual size_t Read(void *ptr, size_t size);
virtual size_t Read(void *ptr, size_t size) override ;

protected:
CURLReadStreamBase()
Expand Down Expand Up @@ -790,11 +791,11 @@ class WriteStream : public Stream {
ecurl_ = curl_easy_init();
this->Init();
}
virtual size_t Read(void *ptr, size_t size) {
virtual size_t Read(void *ptr, size_t size) override {
LOG(FATAL) << "S3.WriteStream cannot be used for read";
return 0;
}
virtual void Write(const void *ptr, size_t size);
virtual size_t Write(const void *ptr, size_t size) override;
// destructor
virtual ~WriteStream() {
this->Close();
Expand Down Expand Up @@ -863,13 +864,14 @@ class WriteStream : public Stream {
void Finish(void);
};

void WriteStream::Write(const void *ptr, size_t size) {
size_t WriteStream::Write(const void *ptr, size_t size) {
size_t rlen = buffer_.length();
buffer_.resize(rlen + size);
std::memcpy(BeginPtr(buffer_) + rlen, ptr, size);
if (buffer_.length() >= max_buffer_size_) {
this->Upload();
}
return size;
}

void WriteStream::Run(const std::string &method,
Expand Down
3 changes: 2 additions & 1 deletion src/io/single_file_split.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ class SingleFileSplit : public InputSplit {
CHECK(part_index == 0 && num_parts == 1);
this->BeforeFirst();
}
virtual void Write(const void * /*ptr*/, size_t /*size*/) {
virtual size_t Write(const void * /*ptr*/, size_t /*size*/) {
LOG(FATAL) << "InputSplit do not support write";
return 0;
}
virtual bool NextRecord(Blob *out_rec) {
if (chunk_begin_ == chunk_end_) {
Expand Down

0 comments on commit 3031e4a

Please sign in to comment.