Skip to content

Commit

Permalink
EQL: Fix aggressive/incorrect until policy in sequences (#65156)
Browse files Browse the repository at this point in the history
The current until implementation in sequences is too optimistic, leading
to an aggressive match that discards correct data leading to invalid
results.
This commit addresses this issue and also unifies the until usage inside
TumblingWindow.
Further more it packs together the UntilGroup with SequenceGroup to
minimize memory usage and improve clean-up.
  • Loading branch information
costin committed Nov 18, 2020
1 parent 4a89bee commit de2724e
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 179 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.StringJoiner;

import static java.util.stream.Collectors.toList;

Expand Down Expand Up @@ -154,14 +155,29 @@ private RestHighLevelClient highLevelClient() {

protected void assertEvents(List<Event> events) {
assertNotNull(events);
logger.info("Events {}", events);
logger.debug("Events {}", new Object() {
public String toString() {
return eventsToString(events);
}
});

long[] expected = eventIds;
long[] actual = extractIds(events);
assertArrayEquals(LoggerMessageFormat.format(null, "unexpected result for spec[{}] [{}] -> {} vs {}", name, query, Arrays.toString(
expected), Arrays.toString(actual)),
expected, actual);
}

private String eventsToString(List<Event> events) {
StringJoiner sj = new StringJoiner(",", "[", "]");
for (Event event : events) {
sj.add(event.id() + "|" + event.index());
sj.add(event.sourceAsMap().toString());
sj.add("\n");
}
return sj.toString();
}

@SuppressWarnings("unchecked")
private long[] extractIds(List<Event> events) {
final int len = events.size();
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -9,116 +9,128 @@
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.xpack.eql.execution.search.Ordinal;

import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;

/**
* Dedicated collection for mapping a key to a list of sequences
* The list represents the sequence for each stage (based on its index) and is fixed in size
*/
class KeyToSequences {

/**
* Utility class holding the sequencegroup/until tuple that also handles
* lazy initialization.
*/
private class SequenceEntry {

private final SequenceGroup[] groups;
// created lazily
private UntilGroup until;

SequenceEntry(int stages) {
this.groups = new SequenceGroup[stages];
}

void add(int stage, Sequence sequence) {
// create the group on demand
if (groups[stage] == null) {
groups[stage] = new SequenceGroup();
}
groups[stage].add(sequence);
}

public void remove(int stage) {
groups[stage] = null;
}

void until(Ordinal ordinal) {
if (until == null) {
until = new UntilGroup();
}
until.add(ordinal);
}
}

private final int listSize;
/** for each key, associate the frame per state (determined by index) */
private final Map<SequenceKey, SequenceGroup[]> keyToSequences;
private final Map<SequenceKey, UntilGroup> keyToUntil;
private final Map<SequenceKey, SequenceEntry> keyToSequences;

KeyToSequences(int listSize) {
this.listSize = listSize;
this.keyToSequences = new LinkedHashMap<>();
this.keyToUntil = new LinkedHashMap<>();
}

private SequenceGroup[] groups(SequenceKey key) {
return keyToSequences.computeIfAbsent(key, k -> new SequenceGroup[listSize]);
}

SequenceGroup groupIfPresent(int stage, SequenceKey key) {
SequenceGroup[] groups = keyToSequences.get(key);
return groups == null ? null : groups[stage];
SequenceEntry sequenceEntry = keyToSequences.get(key);
return sequenceEntry == null ? null : sequenceEntry.groups[stage];
}

UntilGroup untilIfPresent(SequenceKey key) {
return keyToUntil.get(key);
SequenceEntry sequenceEntry = keyToSequences.get(key);
return sequenceEntry == null ? null : sequenceEntry.until;
}

void add(int stage, Sequence sequence) {
SequenceKey key = sequence.key();
SequenceGroup[] groups = groups(key);
// create the group on demand
if (groups[stage] == null) {
groups[stage] = new SequenceGroup(key);
}
groups[stage].add(sequence);
SequenceEntry info = keyToSequences.computeIfAbsent(key, k -> new SequenceEntry(listSize));
info.add(stage, sequence);
}

void until(Iterable<KeyAndOrdinal> until) {
for (KeyAndOrdinal keyAndOrdinal : until) {
// ignore unknown keys
SequenceKey key = keyAndOrdinal.key();
if (keyToSequences.containsKey(key)) {
UntilGroup group = keyToUntil.computeIfAbsent(key, UntilGroup::new);
group.add(keyAndOrdinal);
SequenceEntry sequenceEntry = keyToSequences.get(key);
if (sequenceEntry != null) {
sequenceEntry.until(keyAndOrdinal.ordinal);
}
}
}

void remove(int stage, SequenceGroup group) {
SequenceKey key = group.key();
SequenceGroup[] groups = keyToSequences.get(key);
groups[stage] = null;
// clean-up the key if all groups are empty
boolean shouldRemoveKey = true;
for (SequenceGroup gp : groups) {
if (gp != null && gp.isEmpty() == false) {
shouldRemoveKey = false;
break;
}
}
if (shouldRemoveKey) {
keyToSequences.remove(key);
}
}

void dropUntil() {
// clean-up all candidates that occur before until
for (Entry<SequenceKey, UntilGroup> entry : keyToUntil.entrySet()) {
SequenceGroup[] groups = keyToSequences.get(entry.getKey());
if (groups != null) {
for (Ordinal o : entry.getValue()) {
for (SequenceGroup group : groups) {
if (group != null) {
group.trimBefore(o);
}
}
}
}
}

keyToUntil.clear();
void remove(int stage, SequenceKey key) {
SequenceEntry info = keyToSequences.get(key);
info.remove(stage);
}

/**
* Remove all matches expect the latest.
*/
void trimToTail() {
for (SequenceGroup[] groups : keyToSequences.values()) {
for (SequenceGroup group : groups) {
for (Iterator<SequenceEntry> it = keyToSequences.values().iterator(); it.hasNext(); ) {
SequenceEntry seqs = it.next();
// first remove the sequences
// and remember the last item from the first
// initialized stage to be used with until
Sequence firstTail = null;
for (SequenceGroup group : seqs.groups) {
if (group != null) {
group.trimToLast();
Sequence sequence = group.trimToLast();
if (firstTail == null) {
firstTail = sequence;
}
}
}
// there are no sequences on any stage for this key, drop it
if (firstTail == null) {
it.remove();
} else {
// drop any possible UNTIL that occurs before the last tail
UntilGroup until = seqs.until;
if (until != null) {
until.trimBefore(firstTail.ordinal());
}
}
}
}

public void clear() {
keyToSequences.clear();
keyToUntil.clear();
}

@Override
public String toString() {
return LoggerMessageFormat.format(null, "Keys=[{}], Until=[{}]", keyToSequences.size(), keyToUntil.size());
return LoggerMessageFormat.format(null, "Keys=[{}]", keyToSequences.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
*/
abstract class OrdinalGroup<E> implements Iterable<Ordinal> {

private final SequenceKey key;
private final Function<E, Ordinal> extractor;

// NB: since the size varies significantly, use a LinkedList
Expand All @@ -32,15 +31,10 @@ abstract class OrdinalGroup<E> implements Iterable<Ordinal> {

private Ordinal start, stop;

protected OrdinalGroup(SequenceKey key, Function<E, Ordinal> extractor) {
this.key = key;
protected OrdinalGroup(Function<E, Ordinal> extractor) {
this.extractor = extractor;
}

SequenceKey key() {
return key;
}

void add(E element) {
Ordinal ordinal = extractor.apply(element);
if (start == null || start.compareTo(ordinal) > 0) {
Expand Down Expand Up @@ -82,14 +76,15 @@ E before(Ordinal ordinal) {
return match != null ? match.v1() : null;
}

void trimToLast() {
E trimToLast() {
E last = elements.peekLast();
if (last != null) {
elements.clear();
start = null;
stop = null;
add(last);
}
return last;
}

private Tuple<E, Integer> findBefore(Ordinal ordinal) {
Expand Down Expand Up @@ -132,7 +127,7 @@ public Ordinal next() {

@Override
public int hashCode() {
return key.hashCode();
return elements.hashCode();
}

@Override
Expand All @@ -146,12 +141,11 @@ public boolean equals(Object obj) {
}

OrdinalGroup<?> other = (OrdinalGroup<?>) obj;
return Objects.equals(key, other.key)
&& Objects.equals(elements, other.elements);
return Objects.equals(elements, other.elements);
}

@Override
public String toString() {
return format(null, "[{}][{}-{}]({} seqs)", key, start, stop, elements.size());
return format(null, "[{}-{}]({} seqs)", start, stop, elements.size());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

public class SequenceGroup extends OrdinalGroup<Sequence> {

SequenceGroup(SequenceKey key) {
super(key, Sequence::ordinal);
SequenceGroup() {
super(Sequence::ordinal);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,21 @@ public class SequenceMatcher {
static class Stats {
long seen = 0;
long ignored = 0;
long until = 0;
long rejectionMaxspan = 0;
long rejectionUntil = 0;

@Override
public String toString() {
return LoggerMessageFormat.format(null, "Stats: Seen [{}]/Ignored [{}]/Until [{}]/Rejected {Maxspan [{}]/Until [{}]}",
return LoggerMessageFormat.format(null, "Stats: Seen [{}]/Ignored [{}]/Rejected {Maxspan [{}]/Until [{}]}",
seen,
ignored,
until,
rejectionMaxspan,
rejectionUntil);
}

public void clear() {
seen = 0;
ignored = 0;
until = 0;
rejectionMaxspan = 0;
rejectionUntil = 0;
}
Expand Down Expand Up @@ -160,7 +157,7 @@ private void match(int stage, SequenceKey key, Ordinal ordinal, HitReference hit

// remove the group early (as the key space is large)
if (group.isEmpty()) {
keyToSequences.remove(previousStage, group);
keyToSequences.remove(previousStage, key);
stageToKeys.remove(previousStage, key);
}

Expand All @@ -177,10 +174,10 @@ private void match(int stage, SequenceKey key, Ordinal ordinal, HitReference hit
// until
UntilGroup until = keyToSequences.untilIfPresent(key);
if (until != null) {
KeyAndOrdinal nearestUntil = until.before(ordinal);
Ordinal nearestUntil = until.before(ordinal);
if (nearestUntil != null) {
// check if until matches
if (nearestUntil.ordinal().between(sequence.ordinal(), ordinal)) {
if (nearestUntil.between(sequence.ordinal(), ordinal)) {
stats.rejectionUntil++;
return;
}
Expand Down Expand Up @@ -247,10 +244,6 @@ List<Sequence> completed() {
return limit != null ? limit.view(asList) : asList;
}

void dropUntil() {
keyToSequences.dropUntil();
}

void until(Iterable<KeyAndOrdinal> markers) {
keyToSequences.until(markers);
}
Expand Down

0 comments on commit de2724e

Please sign in to comment.