setStatementTimeout(Duration timeout) {
+ requireNonNull(timeout, "timeout must not be null");
+
+ // TODO: implement me
+ return Mono.empty();
+ }
+
boolean isSessionAutoCommit() {
return (context.getServerStatuses() & ServerStatuses.AUTO_COMMIT) != 0;
}
diff --git a/src/main/java/dev/miku/r2dbc/mysql/MySqlResult.java b/src/main/java/dev/miku/r2dbc/mysql/MySqlResult.java
index 8ada542a..72bc4b1e 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/MySqlResult.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/MySqlResult.java
@@ -19,168 +19,316 @@
import dev.miku.r2dbc.mysql.codec.Codecs;
import dev.miku.r2dbc.mysql.message.FieldValue;
import dev.miku.r2dbc.mysql.message.server.DefinitionMetadataMessage;
-import dev.miku.r2dbc.mysql.message.server.EofMessage;
+import dev.miku.r2dbc.mysql.message.server.ErrorMessage;
import dev.miku.r2dbc.mysql.message.server.OkMessage;
import dev.miku.r2dbc.mysql.message.server.RowMessage;
import dev.miku.r2dbc.mysql.message.server.ServerMessage;
import dev.miku.r2dbc.mysql.message.server.SyntheticMetadataMessage;
import dev.miku.r2dbc.mysql.util.NettyBufferUtils;
import dev.miku.r2dbc.mysql.util.OperatorUtils;
+import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
+import io.r2dbc.spi.R2dbcException;
+import io.r2dbc.spi.Readable;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
-import reactor.core.publisher.MonoProcessor;
import reactor.core.publisher.SynchronousSink;
import reactor.util.annotation.Nullable;
-import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
+import java.util.function.Predicate;
import static dev.miku.r2dbc.mysql.util.AssertUtils.requireNonNull;
/**
* An implementation of {@link Result} representing the results of a query against the MySQL database.
+ *
+ * A {@link Segment} provided by this implementation may be both {@link UpdateCount} and {@link RowSegment},
+ * see also {@link MySqlOkSegment}. It's based on a {@link OkMessage}, when the {@code generatedKeyName} is
+ * not {@code null}.
*/
public final class MySqlResult implements Result {
- private static final Function ROWS_UPDATED = message ->
- (int) message.getAffectedRows();
-
private static final Consumer RELEASE = ReferenceCounted::release;
- private final boolean isBinary;
-
- private final Codecs codecs;
-
- private final ConnectionContext context;
-
- @Nullable
- private final String generatedKeyName;
-
- private final AtomicReference> messages;
+ private static final BiConsumer> ROWS_UPDATED = (segment, sink) -> {
+ if (segment instanceof UpdateCount) {
+ sink.next((int) ((UpdateCount) segment).value());
+ } else if (segment instanceof Message) {
+ sink.error(((Message) segment).exception());
+ } else if (segment instanceof ReferenceCounted) {
+ ReferenceCountUtil.safeRelease(segment);
+ }
+ };
- private final MonoProcessor okProcessor = MonoProcessor.create();
+ private static final BiFunction SUM = Integer::sum;
- private MySqlRowMetadata rowMetadata;
+ private final Flux segments;
- MySqlResult(boolean isBinary, Codecs codecs, ConnectionContext context, @Nullable String generatedKeyName,
- Flux messages) {
- this.isBinary = isBinary;
- this.codecs = requireNonNull(codecs, "codecs must not be null");
- this.context = requireNonNull(context, "context must not be null");
- this.generatedKeyName = generatedKeyName;
- this.messages = new AtomicReference<>(requireNonNull(messages, "messages must not be null"));
+ private MySqlResult(Flux segments) {
+ this.segments = segments;
}
@Override
public Mono getRowsUpdated() {
- return affects().map(ROWS_UPDATED);
+ return segments.handle(ROWS_UPDATED).reduce(SUM);
}
@Override
- public Publisher map(BiFunction f) {
+ public Flux map(BiFunction f) {
requireNonNull(f, "mapping function must not be null");
- if (generatedKeyName == null) {
- return results().handle((message, sink) -> handleResult(message, sink, f));
- }
+ return segments.handle((segment, sink) -> {
+ if (segment instanceof RowSegment) {
+ Row row = ((RowSegment) segment).row();
- return affects().map(message -> {
- InsertSyntheticRow row = new InsertSyntheticRow(codecs, generatedKeyName,
- message.getLastInsertId());
- return f.apply(row, row);
+ try {
+ sink.next(f.apply(row, row.getMetadata()));
+ } finally {
+ ReferenceCountUtil.safeRelease(segment);
+ }
+ } else if (segment instanceof Message) {
+ sink.error(((Message) segment).exception());
+ } else if (segment instanceof ReferenceCounted) {
+ ReferenceCountUtil.safeRelease(segment);
+ }
});
}
- private Mono affects() {
- return this.okProcessor.doOnSubscribe(s -> {
- Flux messages = this.messages.getAndSet(null);
-
- if (messages == null) {
- // Has subscribed, `okProcessor` will be set or cancel.
- return;
- }
+ @Override
+ public Flux map(Function super Readable, ? extends T> f) {
+ requireNonNull(f, "mapping function must not be null");
- messages.subscribe(message -> {
- if (message instanceof OkMessage) {
- // No need check terminal because of OkMessage no need release.
- this.okProcessor.onNext(((OkMessage) message));
- } else if (message instanceof EofMessage) {
- // Metadata EOF message will be not receive in here.
- // EOF message, means it is SELECT statement.
- this.okProcessor.onComplete();
- } else {
- ReferenceCountUtil.safeRelease(message);
+ return segments.handle((segment, sink) -> {
+ if (segment instanceof RowSegment) {
+ try {
+ sink.next(f.apply(((RowSegment) segment).row()));
+ } finally {
+ ReferenceCountUtil.safeRelease(segment);
}
- }, this.okProcessor::onError, this.okProcessor::onComplete);
+ } else if (segment instanceof Message) {
+ sink.error(((Message) segment).exception());
+ } else if (segment instanceof ReferenceCounted) {
+ ReferenceCountUtil.safeRelease(segment);
+ }
});
}
- private Flux results() {
- return Flux.defer(() -> {
- Flux messages = this.messages.getAndSet(null);
+ @Override
+ public MySqlResult filter(Predicate filter) {
+ requireNonNull(filter, "filter must not be null");
- if (messages == null) {
- return Flux.error(new IllegalStateException("Source has been released"));
+ return new MySqlResult(segments.filter(segment -> {
+ if (filter.test(segment)) {
+ return true;
}
- // Result mode, no need ok message.
- this.okProcessor.onComplete();
+ if (segment instanceof ReferenceCounted) {
+ ReferenceCountUtil.safeRelease(segment);
+ }
- return OperatorUtils.discardOnCancel(messages).doOnDiscard(ReferenceCounted.class, RELEASE);
- });
+ return false;
+ }));
}
- private void handleResult(ServerMessage message, SynchronousSink sink,
- BiFunction f) {
- if (message instanceof SyntheticMetadataMessage) {
- DefinitionMetadataMessage[] metadataMessages = ((SyntheticMetadataMessage) message).unwrap();
- if (metadataMessages.length == 0) {
- return;
+ @Override
+ public Flux flatMap(Function> f) {
+ requireNonNull(f, "mapping function must not be null");
+
+ return segments.flatMap(segment -> {
+ Publisher extends T> ret = f.apply(segment);
+
+ if (ret == null) {
+ return Mono.error(new IllegalStateException("The mapper returned a null Publisher"));
}
- this.rowMetadata = MySqlRowMetadata.create(metadataMessages);
- } else if (message instanceof RowMessage) {
- processRow((RowMessage) message, sink, f);
- } else {
- ReferenceCountUtil.safeRelease(message);
+
+ // doAfterTerminate to not release resources before they had a chance to get emitted.
+ if (ret instanceof Mono) {
+ @SuppressWarnings("unchecked")
+ Mono mono = (Mono) ret;
+ return mono.doAfterTerminate(() -> ReferenceCountUtil.release(segment));
+ }
+
+ return Flux.from(ret).doAfterTerminate(() -> ReferenceCountUtil.release(segment));
+ });
+ }
+
+ static MySqlResult toResult(boolean binary, Codecs codecs, ConnectionContext context,
+ @Nullable String generatedKeyName, Flux messages) {
+ requireNonNull(codecs, "codecs must not be null");
+ requireNonNull(context, "context must not be null");
+ requireNonNull(messages, "messages must not be null");
+
+ return new MySqlResult(OperatorUtils.discardOnCancel(messages)
+ .doOnDiscard(ReferenceCounted.class, RELEASE)
+ .handle(new MySqlSegments(binary, codecs, context, generatedKeyName)));
+ }
+
+ private static final class MySqlMessage implements Message {
+
+ private final ErrorMessage message;
+
+ private MySqlMessage(ErrorMessage message) {
+ this.message = message;
+ }
+
+ @Override
+ public R2dbcException exception() {
+ return message.toException();
+ }
+
+ @Override
+ public int errorCode() {
+ return message.getCode();
+ }
+
+ @Override
+ public String sqlState() {
+ return message.getSqlState();
+ }
+
+ @Override
+ public String message() {
+ return message.getMessage();
}
}
- private void processRow(RowMessage message, SynchronousSink sink,
- BiFunction f) {
- MySqlRowMetadata rowMetadata = this.rowMetadata;
+ private static final class MySqlRowSegment extends AbstractReferenceCounted implements RowSegment {
+
+ private final MySqlRow row;
- if (rowMetadata == null) {
- ReferenceCountUtil.safeRelease(message);
- sink.error(new IllegalStateException("No MySqlRowMetadata available"));
- return;
+ private final FieldValue[] fields;
+
+ private MySqlRowSegment(FieldValue[] fields, MySqlRowMetadata metadata, Codecs codecs, boolean binary,
+ ConnectionContext context) {
+ this.row = new MySqlRow(fields, metadata, codecs, binary, context);
+ this.fields = fields;
}
- FieldValue[] fields;
- T t;
+ @Override
+ public Row row() {
+ return row;
+ }
- try {
- fields = message.decode(isBinary, rowMetadata.unwrap());
- } finally {
- // Release row messages' reader.
- ReferenceCountUtil.safeRelease(message);
+ @Override
+ public ReferenceCounted touch(Object hint) {
+ if (this.fields.length == 0) {
+ return this;
+ }
+
+ for (FieldValue field : this.fields) {
+ field.touch(hint);
+ }
+
+ return this;
}
- try {
- // Can NOT just sink.next(f.apply(...)) because of finally release
- t = f.apply(new MySqlRow(fields, rowMetadata, codecs, isBinary, context), rowMetadata);
- } finally {
- // Release decoded field values.
+ @Override
+ protected void deallocate() {
NettyBufferUtils.releaseAll(fields);
}
+ }
+
+ private static class MySqlUpdateCount implements UpdateCount {
+
+ protected final OkMessage message;
+
+ private MySqlUpdateCount(OkMessage message) {
+ this.message = message;
+ }
+
+ @Override
+ public long value() {
+ return message.getAffectedRows();
+ }
+ }
+
+ private static final class MySqlOkSegment extends MySqlUpdateCount implements RowSegment {
+
+ private final Codecs codecs;
+
+ private final String keyName;
+
+ private MySqlOkSegment(OkMessage message, Codecs codecs, String keyName) {
+ super(message);
+
+ this.codecs = codecs;
+ this.keyName = keyName;
+ }
+
+ @Override
+ public Row row() {
+ return new InsertSyntheticRow(codecs, keyName, message.getLastInsertId());
+ }
+ }
+
+ private static final class MySqlSegments implements BiConsumer> {
- sink.next(t);
+ private final boolean binary;
+
+ private final Codecs codecs;
+
+ private final ConnectionContext context;
+
+ @Nullable
+ private final String generatedKeyName;
+
+ private MySqlRowMetadata rowMetadata;
+
+ private MySqlSegments(boolean binary, Codecs codecs, ConnectionContext context,
+ @Nullable String generatedKeyName) {
+ this.binary = binary;
+ this.codecs = codecs;
+ this.context = context;
+ this.generatedKeyName = generatedKeyName;
+ }
+
+ @Override
+ public void accept(ServerMessage message, SynchronousSink sink) {
+ if (message instanceof RowMessage) {
+ MySqlRowMetadata metadata = this.rowMetadata;
+
+ if (metadata == null) {
+ ReferenceCountUtil.safeRelease(message);
+ sink.error(new IllegalStateException("No MySqlRowMetadata available"));
+ return;
+ }
+
+ FieldValue[] fields;
+
+ try {
+ fields = ((RowMessage) message).decode(binary, metadata.unwrap());
+ } finally {
+ ReferenceCountUtil.safeRelease(message);
+ }
+
+ sink.next(new MySqlRowSegment(fields, metadata, codecs, binary, context));
+ } else if (message instanceof SyntheticMetadataMessage) {
+ DefinitionMetadataMessage[] metadataMessages = ((SyntheticMetadataMessage) message).unwrap();
+
+ if (metadataMessages.length == 0) {
+ return;
+ }
+
+ this.rowMetadata = MySqlRowMetadata.create(metadataMessages);
+ } else if (message instanceof OkMessage) {
+ Segment segment = generatedKeyName == null ? new MySqlUpdateCount((OkMessage) message) :
+ new MySqlOkSegment((OkMessage) message, codecs, generatedKeyName);
+
+ sink.next(segment);
+ } else if (message instanceof ErrorMessage) {
+ sink.next(new MySqlMessage((ErrorMessage) message));
+ } else {
+ ReferenceCountUtil.safeRelease(message);
+ }
+ }
}
}
diff --git a/src/main/java/dev/miku/r2dbc/mysql/MySqlRow.java b/src/main/java/dev/miku/r2dbc/mysql/MySqlRow.java
index 4468b8ac..70e7f3a0 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/MySqlRow.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/MySqlRow.java
@@ -19,6 +19,7 @@
import dev.miku.r2dbc.mysql.codec.Codecs;
import dev.miku.r2dbc.mysql.message.FieldValue;
import io.r2dbc.spi.Row;
+import io.r2dbc.spi.RowMetadata;
import reactor.util.annotation.Nullable;
import java.lang.reflect.ParameterizedType;
@@ -101,4 +102,12 @@ public T get(String name, ParameterizedType type) {
MySqlColumnDescriptor info = rowMetadata.getColumnMetadata(name);
return codecs.decode(fields[info.getIndex()], info, type, binary, context);
}
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public RowMetadata getMetadata() {
+ return rowMetadata;
+ }
}
diff --git a/src/main/java/dev/miku/r2dbc/mysql/MySqlRowMetadata.java b/src/main/java/dev/miku/r2dbc/mysql/MySqlRowMetadata.java
index e445ce00..72d96f5f 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/MySqlRowMetadata.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/MySqlRowMetadata.java
@@ -21,7 +21,6 @@
import io.r2dbc.spi.RowMetadata;
import java.util.Arrays;
-import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;
@@ -35,18 +34,10 @@
*/
final class MySqlRowMetadata implements RowMetadata {
- private static final Comparator NAME_COMPARATOR = (left, right) ->
- MySqlNames.compare(left.getName(), right.getName());
-
private final MySqlColumnDescriptor[] originMetadata;
private final MySqlColumnDescriptor[] sortedMetadata;
- /**
- * Copied column names from {@link #sortedMetadata}.
- */
- private final String[] sortedNames;
-
private final ColumnNameSet nameSet;
private MySqlRowMetadata(MySqlColumnDescriptor[] metadata) {
@@ -60,21 +51,19 @@ private MySqlRowMetadata(MySqlColumnDescriptor[] metadata) {
this.originMetadata = metadata;
this.sortedMetadata = metadata;
- this.sortedNames = new String[] { name };
this.nameSet = ColumnNameSet.of(name);
break;
default:
MySqlColumnDescriptor[] sortedMetadata = new MySqlColumnDescriptor[size];
System.arraycopy(metadata, 0, sortedMetadata, 0, size);
- Arrays.sort(sortedMetadata, NAME_COMPARATOR);
+ Arrays.sort(sortedMetadata, ColumnNameSet.NAME_COMPARATOR);
String[] originNames = getNames(metadata);
String[] sortedNames = getNames(sortedMetadata);
this.originMetadata = metadata;
this.sortedMetadata = sortedMetadata;
- this.sortedNames = sortedNames;
this.nameSet = ColumnNameSet.of(originNames, sortedNames);
break;
@@ -94,7 +83,7 @@ public MySqlColumnDescriptor getColumnMetadata(int index) {
public MySqlColumnDescriptor getColumnMetadata(String name) {
requireNonNull(name, "name must not be null");
- int index = MySqlNames.nameSearch(this.sortedNames, name);
+ int index = nameSet.findIndex(name);
if (index < 0) {
throw new NoSuchElementException("Column name '" + name + "' does not exist");
@@ -103,11 +92,19 @@ public MySqlColumnDescriptor getColumnMetadata(String name) {
return sortedMetadata[index];
}
+ @Override
+ public boolean contains(String name) {
+ requireNonNull(name, "name must not be null");
+
+ return nameSet.contains(name);
+ }
+
@Override
public List getColumnMetadatas() {
return InternalArrays.asImmutableList(originMetadata);
}
+ @SuppressWarnings("deprecation")
@Override
public Set getColumnNames() {
return nameSet;
@@ -116,7 +113,7 @@ public Set getColumnNames() {
@Override
public String toString() {
return "MySqlRowMetadata{metadata=" + Arrays.toString(originMetadata) + ", sortedNames=" +
- Arrays.toString(sortedNames) + '}';
+ Arrays.toString(nameSet.getSortedNames()) + '}';
}
MySqlColumnDescriptor[] unwrap() {
diff --git a/src/main/java/dev/miku/r2dbc/mysql/MySqlSyntheticBatch.java b/src/main/java/dev/miku/r2dbc/mysql/MySqlSyntheticBatch.java
index b8e2a36a..36812a23 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/MySqlSyntheticBatch.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/MySqlSyntheticBatch.java
@@ -54,7 +54,7 @@ public MySqlBatch add(String sql) {
@Override
public Flux execute() {
return QueryFlow.execute(client, statements)
- .map(messages -> new MySqlResult(false, codecs, context, null, messages));
+ .map(messages -> MySqlResult.toResult(false, codecs, context, null, messages));
}
@Override
diff --git a/src/main/java/dev/miku/r2dbc/mysql/OptionMapper.java b/src/main/java/dev/miku/r2dbc/mysql/OptionMapper.java
index e28d557e..9d35f9d5 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/OptionMapper.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/OptionMapper.java
@@ -42,16 +42,18 @@ SourceSpec from(Option> option) {
return new SourceSpec(options, option);
}
+ @SuppressWarnings("unchecked")
void consume(Option option, Consumer consumer) {
- T t = options.getValue(option);
+ Object t = options.getValue(option);
if (t != null) {
- consumer.accept(t);
+ consumer.accept((T) t);
}
}
+ @SuppressWarnings("unchecked")
void requiredConsume(Option option, Consumer consumer) {
- consumer.accept(options.getRequiredValue(option));
+ consumer.accept((T) options.getRequiredValue(option));
}
}
diff --git a/src/main/java/dev/miku/r2dbc/mysql/ParametrizedStatementSupport.java b/src/main/java/dev/miku/r2dbc/mysql/ParametrizedStatementSupport.java
index 729c8e87..f8c3c0e1 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/ParametrizedStatementSupport.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/ParametrizedStatementSupport.java
@@ -23,6 +23,7 @@
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
@@ -134,7 +135,7 @@ private ParameterIndex getIndexes(String name) {
ParameterIndex index = query.getNamedIndexes().get(name);
if (index == null) {
- throw new IllegalArgumentException(String.format("No such parameter with name '%s'", name));
+ throw new NoSuchElementException("No such parameter with name: " + name);
}
return index;
diff --git a/src/main/java/dev/miku/r2dbc/mysql/PrepareParametrizedStatement.java b/src/main/java/dev/miku/r2dbc/mysql/PrepareParametrizedStatement.java
index 08b73270..1e32d3df 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/PrepareParametrizedStatement.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/PrepareParametrizedStatement.java
@@ -43,7 +43,7 @@ final class PrepareParametrizedStatement extends ParametrizedStatementSupport {
@Override
public Flux execute(List bindings) {
return QueryFlow.execute(client, query.getFormattedSql(), bindings, fetchSize, prepareCache)
- .map(messages -> new MySqlResult(true, codecs, context, generatedKeyName, messages));
+ .map(messages -> MySqlResult.toResult(true, codecs, context, generatedKeyName, messages));
}
@Override
diff --git a/src/main/java/dev/miku/r2dbc/mysql/PrepareSimpleStatement.java b/src/main/java/dev/miku/r2dbc/mysql/PrepareSimpleStatement.java
index b91722c8..b5dfc6a0 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/PrepareSimpleStatement.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/PrepareSimpleStatement.java
@@ -46,7 +46,7 @@ final class PrepareSimpleStatement extends SimpleStatementSupport {
@Override
public Flux execute() {
return QueryFlow.execute(client, sql, BINDINGS, fetchSize, prepareCache)
- .map(messages -> new MySqlResult(true, codecs, context, generatedKeyName, messages));
+ .map(messages -> MySqlResult.toResult(true, codecs, context, generatedKeyName, messages));
}
@Override
diff --git a/src/main/java/dev/miku/r2dbc/mysql/QueryFlow.java b/src/main/java/dev/miku/r2dbc/mysql/QueryFlow.java
index 146303c1..c4f332d7 100644
--- a/src/main/java/dev/miku/r2dbc/mysql/QueryFlow.java
+++ b/src/main/java/dev/miku/r2dbc/mysql/QueryFlow.java
@@ -48,8 +48,8 @@
import dev.miku.r2dbc.mysql.message.server.SyntheticSslResponseMessage;
import dev.miku.r2dbc.mysql.util.InternalArrays;
import io.netty.util.ReferenceCountUtil;
+import io.netty.util.ReferenceCounted;
import io.r2dbc.spi.IsolationLevel;
-import io.r2dbc.spi.R2dbcException;
import io.r2dbc.spi.R2dbcPermissionDeniedException;
import io.r2dbc.spi.TransactionDefinition;
import org.slf4j.Logger;
@@ -82,32 +82,19 @@ final class QueryFlow {
// Metadata EOF message will be not receive in here.
private static final Predicate RESULT_DONE = message -> message instanceof CompleteMessage;
- private static final Consumer