Skip to content

Commit

Permalink
[api] Refactor PublisherBytesSupplier.java
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Oct 31, 2023
1 parent 3927867 commit 4863366
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,10 @@

import ai.djl.ndarray.BytesSupplier;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/**
Expand All @@ -29,16 +26,14 @@
*/
public class PublisherBytesSupplier implements BytesSupplier {

private final List<byte[]> allData;
private final AtomicBoolean completed;
private Consumer<byte[]> subscriber;
private final AtomicInteger dataPushed;
private CountDownLatch latch;
private CompletableFuture<Void> future;

/** Constructs a {@link PublisherBytesSupplier}. */
public PublisherBytesSupplier() {
allData = new ArrayList<>();
completed = new AtomicBoolean();
dataPushed = new AtomicInteger();
latch = new CountDownLatch(1);
future = new CompletableFuture<>();
}

/**
Expand All @@ -48,83 +43,42 @@ public PublisherBytesSupplier() {
* @param lastChunk true if this is the last chunk
*/
public void appendContent(byte[] data, boolean lastChunk) {
synchronized (allData) {
allData.add(data);
if (subscriber == null) {
try {
if (!latch.await(2, TimeUnit.MINUTES)) {
throw new IllegalStateException("Wait for subscriber timeout.");
}
} catch (InterruptedException e) {
throw new IllegalStateException("Append content interrupted.", e);
}
}
subscriber.accept(data);
if (lastChunk) {
completed.set(true);
subscriber.accept(null);
future.complete(null);
}
pushData();
}

/**
* Adds the subscriber to the {@link BytesSupplier} to get notified about additional data.
*
* @param subscriber a consumer function that will receive bytes when new daata is added and
* null when completed
* @return a {@code CompletableFuture} object
*/
public void subscribe(Consumer<byte[]> subscriber) {
public CompletableFuture<Void> subscribe(Consumer<byte[]> subscriber) {
if (this.subscriber != null) {
throw new IllegalStateException(
"The PublisherBytesSupplier only allows a single Subscriber");
}
this.subscriber = subscriber;
pushData();
}

private void pushData() {
if (subscriber == null) {
return;
}

int dataAvailable;
synchronized (allData) {
dataAvailable = allData.size();
}

int sent = dataPushed.getAndSet(dataAvailable);
if (sent < dataAvailable) {
synchronized (this) {
for (; sent < dataAvailable; sent++) {
subscriber.accept(allData.get(sent));
}
if (completed.get()) {
subscriber.accept(null);
}
}
}
}

/** Waits until completed before passing thread (BLOCKS THREAD!). */
@SuppressWarnings("PMD.EmptyControlStatement")
public void waitToRead() {
// Block until complete!!!
while (!completed.get()) {
// Do nothing
}
}

/** {@inheritDoc} */
@Override
public byte[] getAsBytes() {
if (!completed.get()) {
throw new IllegalStateException(
"PublisherByteSupplier must be completely filled before reading.");
}

try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
for (byte[] data : allData) {
bos.write(data);
}
return bos.toByteArray();
} catch (IOException e) {
throw new AssertionError("Failed to read BytesSupplier", e);
}
latch.countDown();
return future;
}

/** {@inheritDoc} */
@Override
public ByteBuffer toByteBuffer() {
return ByteBuffer.wrap(getAsBytes());
throw new UnsupportedOperationException("Not supported.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,38 @@
import org.testng.Assert;
import org.testng.annotations.Test;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;

public class PublisherBytesSupplierTest {

@Test
public void test() {
public void test() throws ExecutionException, InterruptedException {
AtomicInteger contentCount = new AtomicInteger();
PublisherBytesSupplier supplier = new PublisherBytesSupplier();

// Add to supplier without subscriber
supplier.appendContent(new byte[] {1}, false);
Assert.assertEquals(contentCount.get(), 0);
new Thread(
() -> {
// Add to supplier without subscriber
supplier.appendContent(new byte[] {1}, false);
// Add to supplier with subscriber
supplier.appendContent(new byte[] {1}, true);
})
.start();

// Subscribing with data should trigger subscriptions
supplier.subscribe(
d -> {
if (d == null) {
// Do nothing on completion
return;
}
contentCount.getAndIncrement();
});
Assert.assertEquals(contentCount.get(), 1);
CompletableFuture<Void> future =
supplier.subscribe(
d -> {
if (d == null) {
// Do nothing on completion
return;
}
contentCount.getAndIncrement();
});

// Add to supplier with subscriber
supplier.appendContent(new byte[] {1}, true);
future.get();
Assert.assertEquals(contentCount.get(), 2);
}
}

0 comments on commit 4863366

Please sign in to comment.