Skip to content

Commit

Permalink
Refactor PathTrie to tidy it up (#107542)
Browse files Browse the repository at this point in the history
  • Loading branch information
thecoop committed Apr 17, 2024
1 parent 3df8afb commit eb6af0e
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 149 deletions.
138 changes: 56 additions & 82 deletions server/src/main/java/org/elasticsearch/common/path/PathTrie.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@

import org.elasticsearch.common.collect.Iterators;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import java.util.stream.Stream;

import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableMap;
Expand Down Expand Up @@ -50,52 +52,43 @@ enum TrieMatchingMode {
TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED
);

public interface Decoder {
String decode(String value);
}

private final Decoder decoder;
private final UnaryOperator<String> decoder;
private final TrieNode root;
private T rootValue;

private static final String SEPARATOR = "/";
private static final String WILDCARD = "*";

public PathTrie(Decoder decoder) {
public PathTrie(UnaryOperator<String> decoder) {
this.decoder = decoder;
root = new TrieNode(SEPARATOR, null, WILDCARD);
root = new TrieNode(SEPARATOR, null);
}

public class TrieNode {
private transient String key;
private transient T value;
private final String wildcard;

private transient String namedWildcard;

private class TrieNode {
private T value;
private String namedWildcard;
private Map<String, TrieNode> children;

private TrieNode(String key, T value, String wildcard) {
this.key = key;
this.wildcard = wildcard;
private TrieNode(String key, T value) {
this.value = value;
this.children = emptyMap();
if (isNamedWildcard(key)) {
namedWildcard = key.substring(key.indexOf('{') + 1, key.indexOf('}'));
updateNamedWildcard(key);
} else {
namedWildcard = null;
}
}

private void updateKeyWithNamedWildcard(String key) {
this.key = key;
String newNamedWildcard = key.substring(key.indexOf('{') + 1, key.indexOf('}'));
if (namedWildcard != null && newNamedWildcard.equals(namedWildcard) == false) {
throw new IllegalArgumentException(
"Trying to use conflicting wildcard names for same path: " + namedWildcard + " and " + newNamedWildcard
);
private void updateNamedWildcard(String key) {
String newNamedWildcard = key.substring(1, key.length() - 1);
if (newNamedWildcard.equals(namedWildcard) == false) {
if (namedWildcard != null) {
throw new IllegalArgumentException(
"Trying to use conflicting wildcard names for same path: " + namedWildcard + " and " + newNamedWildcard
);
}
namedWildcard = newNamedWildcard;
}
namedWildcard = newNamedWildcard;
}

private void addInnerChild(String key, TrieNode child) {
Expand All @@ -110,16 +103,17 @@ private synchronized void insert(String[] path, int index, T value) {
String token = path[index];
String key = token;
if (isNamedWildcard(token)) {
key = wildcard;
key = WILDCARD;
}

TrieNode node = children.get(key);
if (node == null) {
T nodeValue = index == path.length - 1 ? value : null;
node = new TrieNode(token, nodeValue, wildcard);
node = new TrieNode(token, nodeValue);
addInnerChild(key, node);
} else {
if (isNamedWildcard(token)) {
node.updateKeyWithNamedWildcard(token);
node.updateNamedWildcard(token);
}
/*
* If the target node already exists, but is without a value,
Expand All @@ -139,22 +133,23 @@ private synchronized void insert(String[] path, int index, T value) {
node.insert(path, index + 1, value);
}

private synchronized void insertOrUpdate(String[] path, int index, T value, BiFunction<T, T, T> updater) {
private synchronized void insertOrUpdate(String[] path, int index, T value, BinaryOperator<T> updater) {
if (index >= path.length) return;

String token = path[index];
String key = token;
if (isNamedWildcard(token)) {
key = wildcard;
key = WILDCARD;
}

TrieNode node = children.get(key);
if (node == null) {
T nodeValue = index == path.length - 1 ? value : null;
node = new TrieNode(token, nodeValue, wildcard);
node = new TrieNode(token, nodeValue);
addInnerChild(key, node);
} else {
if (isNamedWildcard(token)) {
node.updateKeyWithNamedWildcard(token);
node.updateNamedWildcard(token);
}
/*
* If the target node already exists, but is without a value,
Expand All @@ -173,7 +168,7 @@ private synchronized void insertOrUpdate(String[] path, int index, T value, BiFu
}

private static boolean isNamedWildcard(String key) {
return key.indexOf('{') != -1 && key.indexOf('}') != -1;
return key.charAt(0) == '{' && key.charAt(key.length() - 1) == '}';
}

private String namedWildcard() {
Expand All @@ -184,7 +179,7 @@ private boolean isNamedWildcard() {
return namedWildcard != null;
}

public T retrieve(String[] path, int index, Map<String, String> params, TrieMatchingMode trieMatchingMode) {
private T retrieve(String[] path, int index, Map<String, String> params, TrieMatchingMode trieMatchingMode) {
if (index >= path.length) return null;

String token = path[index];
Expand All @@ -193,7 +188,7 @@ public T retrieve(String[] path, int index, Map<String, String> params, TrieMatc

if (node == null) {
if (trieMatchingMode == TrieMatchingMode.WILDCARD_NODES_ALLOWED) {
node = children.get(wildcard);
node = children.get(WILDCARD);
if (node == null) {
return null;
}
Expand All @@ -202,7 +197,7 @@ public T retrieve(String[] path, int index, Map<String, String> params, TrieMatc
/*
* Allow root node wildcard matches.
*/
node = children.get(wildcard);
node = children.get(WILDCARD);
if (node == null) {
return null;
}
Expand All @@ -211,7 +206,7 @@ public T retrieve(String[] path, int index, Map<String, String> params, TrieMatc
/*
* Allow leaf node wildcard matches.
*/
node = children.get(wildcard);
node = children.get(WILDCARD);
if (node == null) {
return null;
}
Expand All @@ -220,68 +215,64 @@ public T retrieve(String[] path, int index, Map<String, String> params, TrieMatc
return null;
}
} else {
TrieNode wildcardNode;
if (index + 1 == path.length
&& node.value == null
&& children.get(wildcard) != null
&& EXPLICIT_OR_ROOT_WILDCARD.contains(trieMatchingMode) == false) {
&& EXPLICIT_OR_ROOT_WILDCARD.contains(trieMatchingMode) == false
&& (wildcardNode = children.get(WILDCARD)) != null) {
/*
* If we are at the end of the path, the current node does not have a value but
* there is a child wildcard node, use the child wildcard node.
*/
node = children.get(wildcard);
node = wildcardNode;
usedWildcard = true;
} else if (index == 1
&& node.value == null
&& children.get(wildcard) != null
&& trieMatchingMode == TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED) {
&& trieMatchingMode == TrieMatchingMode.WILDCARD_ROOT_NODES_ALLOWED
&& (wildcardNode = children.get(WILDCARD)) != null) {
/*
* If we are at the root, and root wildcards are allowed, use the child wildcard
* node.
*/
node = children.get(wildcard);
node = wildcardNode;
usedWildcard = true;
} else {
usedWildcard = token.equals(wildcard);
usedWildcard = token.equals(WILDCARD);
}
}

put(params, node, token);
recordWildcardParam(params, node, token);

if (index == (path.length - 1)) {
return node.value;
}

T nodeValue = node.retrieve(path, index + 1, params, trieMatchingMode);
if (nodeValue == null && usedWildcard == false && trieMatchingMode != TrieMatchingMode.EXPLICIT_NODES_ONLY) {
node = children.get(wildcard);
node = children.get(WILDCARD);
if (node != null) {
put(params, node, token);
recordWildcardParam(params, node, token);
nodeValue = node.retrieve(path, index + 1, params, trieMatchingMode);
}
}

return nodeValue;
}

private void put(Map<String, String> params, TrieNode node, String value) {
private void recordWildcardParam(Map<String, String> params, TrieNode node, String value) {
if (params != null && node.isNamedWildcard()) {
params.put(node.namedWildcard(), decoder.decode(value));
params.put(node.namedWildcard(), decoder.apply(value));
}
}

Iterator<T> allNodeValues() {
private Iterator<T> allNodeValues() {
final Iterator<T> childrenIterator = Iterators.flatMap(children.values().iterator(), TrieNode::allNodeValues);
if (value == null) {
return childrenIterator;
} else {
return Iterators.concat(Iterators.single(value), childrenIterator);
}
}

@Override
public String toString() {
return key;
}
}

public void insert(String path, T value) {
Expand All @@ -308,7 +299,7 @@ public void insert(String path, T value) {
* </pre>
* allowing the value to be updated if desired.
*/
public void insertOrUpdate(String path, T value, BiFunction<T, T, T> updater) {
public void insertOrUpdate(String path, T value, BinaryOperator<T> updater) {
String[] strings = path.split(SEPARATOR);
if (strings.length == 0) {
if (rootValue != null) {
Expand All @@ -334,8 +325,8 @@ public T retrieve(String path, Map<String, String> params) {
return retrieve(path, params, TrieMatchingMode.WILDCARD_NODES_ALLOWED);
}

public T retrieve(String path, Map<String, String> params, TrieMatchingMode trieMatchingMode) {
if (path.length() == 0) {
T retrieve(String path, Map<String, String> params, TrieMatchingMode trieMatchingMode) {
if (path.isEmpty()) {
return rootValue;
}
String[] strings = path.split(SEPARATOR);
Expand All @@ -353,29 +344,12 @@ public T retrieve(String path, Map<String, String> params, TrieMatchingMode trie
}

/**
* Returns an iterator of the objects stored in the {@code PathTrie}, using
* Returns a stream of the objects stored in the {@code PathTrie}, using
* all possible {@code TrieMatchingMode} modes. The {@code paramSupplier}
* is called between each invocation of {@code next()} to supply a new map
* of parameters.
* is called for each mode to supply a new map of parameters.
*/
public Iterator<T> retrieveAll(String path, Supplier<Map<String, String>> paramSupplier) {
return new Iterator<>() {

private int mode;

@Override
public boolean hasNext() {
return mode < TrieMatchingMode.values().length;
}

@Override
public T next() {
if (hasNext() == false) {
throw new NoSuchElementException("called next() without validating hasNext()! no more modes available");
}
return retrieve(path, paramSupplier.get(), TrieMatchingMode.values()[mode++]);
}
};
public Stream<T> retrieveAll(String path, Supplier<Map<String, String>> paramSupplier) {
return Arrays.stream(TrieMatchingMode.values()).map(m -> retrieve(path, paramSupplier.get(), m));
}

public Iterator<T> allNodeValues() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ Iterator<MethodHandlers> getAllHandlers(@Nullable Map<String, String> requestPar
// we use rawPath since we don't want to decode it while processing the path resolution
// so we can handle things like:
// my_index/my_type/http%3A%2F%2Fwww.google.com
return handlers.retrieveAll(rawPath, paramsSupplier);
return handlers.retrieveAll(rawPath, paramsSupplier).iterator();
}

/**
Expand Down
4 changes: 2 additions & 2 deletions server/src/main/java/org/elasticsearch/rest/RestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
package org.elasticsearch.rest;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.path.PathTrie;
import org.elasticsearch.core.Booleans;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.function.UnaryOperator;
import java.util.regex.Pattern;

import static org.elasticsearch.rest.RestRequest.PATH_RESTRICTED;
Expand All @@ -28,7 +28,7 @@ public class RestUtils {
*/
private static final boolean DECODE_PLUS_AS_SPACE = Booleans.parseBoolean(System.getProperty("es.rest.url_plus_as_space", "false"));

public static final PathTrie.Decoder REST_DECODER = RestUtils::decodeComponent;
public static final UnaryOperator<String> REST_DECODER = RestUtils::decodeComponent;

public static void decodeQueryString(String s, int fromIndex, Map<String, String> params) {
if (fromIndex < 0) {
Expand Down

0 comments on commit eb6af0e

Please sign in to comment.