diff --git a/build.gradle b/build.gradle index 4c980b3..4353cda 100644 --- a/build.gradle +++ b/build.gradle @@ -35,6 +35,7 @@ dependencies { testImplementation 'run.halo.app:api' testImplementation 'org.springframework.boot:spring-boot-starter-test' + testImplementation 'io.projectreactor:reactor-test' } test { diff --git a/src/main/java/run/halo/s3os/S3OsAttachmentHandler.java b/src/main/java/run/halo/s3os/S3OsAttachmentHandler.java index b68d4c9..8317b34 100644 --- a/src/main/java/run/halo/s3os/S3OsAttachmentHandler.java +++ b/src/main/java/run/halo/s3os/S3OsAttachmentHandler.java @@ -14,7 +14,10 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.pf4j.Extension; +import org.reactivestreams.Publisher; import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.MediaType; import org.springframework.http.MediaTypeFactory; import org.springframework.lang.Nullable; @@ -22,8 +25,10 @@ import org.springframework.web.server.ServerWebInputException; import org.springframework.web.util.UriUtils; import reactor.core.Exceptions; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; +import reactor.util.context.Context; import reactor.util.retry.Retry; import run.halo.app.core.extension.attachment.Attachment; import run.halo.app.core.extension.attachment.Attachment.AttachmentSpec; @@ -232,10 +237,52 @@ private S3Presigner buildS3Presigner(S3OsProperties properties) { .build(); } + Flux reshape(Publisher content, int bufferSize) { + var dataBufferFactory = DefaultDataBufferFactory.sharedInstance; + return Flux.create(sink -> { + var byteBuffer = ByteBuffer.allocate(bufferSize); + Flux.from(content) + .doOnNext(dataBuffer -> { + var count = dataBuffer.readableByteCount(); + for (var i = 0; i < count; i++) { + byteBuffer.put(dataBuffer.read()); + // Emit the buffer when buffer + if (!byteBuffer.hasRemaining()) { + sink.next(deepCopy(byteBuffer)); + byteBuffer.clear(); + } + } + }) + .doOnComplete(() -> { + // Emit the last part of buffer. + if (byteBuffer.position() > 0) { + sink.next(deepCopy(byteBuffer)); + } + }) + .subscribe(DataBufferUtils::release, sink::error, sink::complete, + Context.of(sink.contextView())); + }) + .map(dataBufferFactory::wrap) + .cast(DataBuffer.class) + .doOnDiscard(DataBuffer.class, DataBufferUtils::release); + } + + ByteBuffer deepCopy(ByteBuffer src) { + src.flip(); + var dest = ByteBuffer.allocate(src.limit()); + dest.put(src); + src.rewind(); + dest.flip(); + return dest; + } + Mono upload(UploadContext uploadContext, S3OsProperties properties) { return Mono.using(() -> buildS3Client(properties), client -> { var uploadState = new UploadState(properties, uploadContext.file().filename()); + + var content = uploadContext.file().content(); + return checkFileExistsAndRename(uploadState, client) // init multipart upload .flatMap(state -> Mono.fromCallable(() -> client.createMultipartUpload( @@ -243,12 +290,12 @@ Mono upload(UploadContext uploadContext, S3OsProperties properties .bucket(properties.getBucket()) .contentType(state.contentType) .key(state.objectKey) - .build())).subscribeOn(Schedulers.boundedElastic())) - .flatMapMany((response) -> { + .build()))) + .doOnNext((response) -> { checkResult(response, "createMultipartUpload"); uploadState.uploadId = response.uploadId(); - return uploadContext.file().content(); }) + .thenMany(reshape(content, MULTIPART_MIN_PART_SIZE)) // buffer to part .windowUntil((buffer) -> { uploadState.buffered += buffer.readableByteCount(); diff --git a/src/test/java/run/halo/s3os/S3OsAttachmentHandlerTest.java b/src/test/java/run/halo/s3os/S3OsAttachmentHandlerTest.java index b31c261..48018b6 100644 --- a/src/test/java/run/halo/s3os/S3OsAttachmentHandlerTest.java +++ b/src/test/java/run/halo/s3os/S3OsAttachmentHandlerTest.java @@ -1,12 +1,19 @@ package run.halo.s3os; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; import run.halo.app.core.extension.attachment.Policy; class S3OsAttachmentHandlerTest { @@ -33,4 +40,57 @@ void acceptHandlingWhenPolicyTemplateIsExpected() { // policy is null assertFalse(handler.shouldHandle(null)); } + + @Test + void reshapeDataBufferWithSmallerBufferSize() { + var handler = new S3OsAttachmentHandler(); + var factory = DefaultDataBufferFactory.sharedInstance; + var content = Flux.fromIterable(List.of(factory.wrap("halo".getBytes()))); + + StepVerifier.create(handler.reshape(content, 2)) + .assertNext(dataBuffer -> { + var str = dataBuffer.toString(UTF_8); + assertEquals("ha", str); + }) + .assertNext(dataBuffer -> { + var str = dataBuffer.toString(UTF_8); + assertEquals("lo", str); + }) + .verifyComplete(); + } + + @Test + void reshapeDataBufferWithBiggerBufferSize() { + var handler = new S3OsAttachmentHandler(); + var factory = DefaultDataBufferFactory.sharedInstance; + var content = Flux.fromIterable(List.of(factory.wrap("halo".getBytes()))); + + StepVerifier.create(handler.reshape(content, 10)) + .assertNext(dataBuffer -> { + var str = dataBuffer.toString(UTF_8); + assertEquals("halo", str); + }) + .verifyComplete(); + } + + @Test + void reshapeDataBuffersWithBiggerBufferSize() { + var handler = new S3OsAttachmentHandler(); + var factory = DefaultDataBufferFactory.sharedInstance; + var content = Flux.fromIterable(List.of( + factory.wrap("ha".getBytes()), + factory.wrap("lo".getBytes()) + )); + + StepVerifier.create(handler.reshape(content, 3)) + .assertNext(dataBuffer -> { + var str = dataBuffer.toString(UTF_8); + assertEquals("hal", str); + }) + .assertNext(dataBuffer -> { + var str = dataBuffer.toString(UTF_8); + assertEquals("o", str); + }) + .verifyComplete(); + } }