diff --git a/ReactAndroid/src/main/java/com/facebook/react/modules/network/ProgressRequestBody.java b/ReactAndroid/src/main/java/com/facebook/react/modules/network/ProgressRequestBody.java index 9676071aa714b7..605343024477ab 100644 --- a/ReactAndroid/src/main/java/com/facebook/react/modules/network/ProgressRequestBody.java +++ b/ReactAndroid/src/main/java/com/facebook/react/modules/network/ProgressRequestBody.java @@ -9,60 +9,75 @@ package com.facebook.react.modules.network; +import com.facebook.common.internal.CountingOutputStream; + import java.io.IOException; + import okhttp3.MediaType; import okhttp3.RequestBody; import okio.BufferedSink; -import okio.Buffer; -import okio.Sink; -import okio.ForwardingSink; import okio.Okio; +import okio.Sink; public class ProgressRequestBody extends RequestBody { private final RequestBody mRequestBody; private final ProgressListener mProgressListener; private BufferedSink mBufferedSink; + private long mContentLength = 0L; public ProgressRequestBody(RequestBody requestBody, ProgressListener progressListener) { - mRequestBody = requestBody; - mProgressListener = progressListener; + mRequestBody = requestBody; + mProgressListener = progressListener; } @Override public MediaType contentType() { - return mRequestBody.contentType(); + return mRequestBody.contentType(); } @Override public long contentLength() throws IOException { - return mRequestBody.contentLength(); + if (mContentLength == 0) { + mContentLength = mRequestBody.contentLength(); + } + return mContentLength; } @Override public void writeTo(BufferedSink sink) throws IOException { - if (mBufferedSink == null) { - mBufferedSink = Okio.buffer(sink(sink)); - } - mRequestBody.writeTo(mBufferedSink); - mBufferedSink.flush(); + if (mBufferedSink == null) { + mBufferedSink = Okio.buffer(outputStreamSink(sink)); + } + + // contentLength changes for input streams, since we're using inputStream.available(), + // so get the length before writing to the sink + contentLength(); + + mRequestBody.writeTo(mBufferedSink); + mBufferedSink.flush(); } - private Sink sink(Sink sink) { - return new ForwardingSink(sink) { - long bytesWritten = 0L; - long contentLength = 0L; + private Sink outputStreamSink(BufferedSink sink) { + return Okio.sink(new CountingOutputStream(sink.outputStream()) { + @Override + public void write(byte[] buffer, int off, int len) throws IOException { + super.write(buffer, off, len); + sendProgressUpdate(); + } - @Override - public void write(Buffer source, long byteCount) throws IOException { - super.write(source, byteCount); - if (contentLength == 0) { - contentLength = contentLength(); - } - bytesWritten += byteCount; - mProgressListener.onProgress( - bytesWritten, contentLength, bytesWritten == contentLength); - } - }; + @Override + public void write(int buffer) throws IOException { + super.write(buffer); + sendProgressUpdate(); + } + + private void sendProgressUpdate() throws IOException { + long bytesWritten = getCount(); + long contentLength = contentLength(); + mProgressListener.onProgress( + bytesWritten, contentLength, bytesWritten == contentLength); + } + }); } }