[ML] frequent items filter (#91137)
Add a filter to the frequent_items agg that excludes documents from the analysis while still calculating support on the full document set.

A filter is specified top-level in frequent_items:

"frequent_items": {
  "filter": {
    "term": {
      "host.name.keyword": "i-12345"
    }
  },
...

The above excludes documents that don't match from the analysis, but still counts them when
calculating support. That's in contrast to specifying a query at the top level, in which case you
find the same item sets but don't know their importance relative to the full document set.
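For contrast, a sketch of the alternative: filtering with a top-level query instead (request body abbreviated; both forms find the same item sets, but with the query, support is calculated only over the matching documents):

"query": {
  "term": {
    "host.name.keyword": "i-12345"
  }
},
"aggs": {
  "frequent_items": {
...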
Hendrik Muhs committed Nov 3, 2022
1 parent 01f77da commit 14b2d2d
Showing 13 changed files with 249 additions and 34 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/91137.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 91137
summary: Add a filter parameter to frequent items
area: Machine Learning
type: enhancement
issues: []
@@ -51,6 +51,7 @@ A `frequent_items` aggregation looks like this in isolation:
|`minimum_set_size` | (integer) The <<frequent-items-minimum-set-size,minimum size>> of one item set. | Optional | `1`
|`minimum_support` | (integer) The <<frequent-items-minimum-support,minimum support>> of one item set. | Optional | `0.1`
|`size` | (integer) The number of top item sets to return. | Optional | `10`
|`filter` | (object) A query that filters documents from the analysis. | Optional | `match_all`
|===


@@ -102,6 +103,18 @@ parameter has a significant effect on the required memory and the runtime of the
aggregation.


[discrete]
[[frequent-items-filter]]
==== Filter

A query to filter documents to use as part of the analysis. Documents that
don't match the filter are ignored when generating the item sets; however, they
still count when calculating the support of an item set.

Use the filter if you want to narrow the item set analysis to fields of interest.
Use a top-level query to filter the data set.


[discrete]
[[frequent-items-example]]
==== Examples
@@ -123,7 +136,7 @@ example.

[source,console]
-------------------------------------------------
POST /kibana_sample_data_ecommerce /_async_search
POST /kibana_sample_data_ecommerce/_async_search
{
"size": 0,
"aggs": {
@@ -224,7 +237,45 @@ from New York. Finally, the item set with the third highest support is


[discrete]
==== Analizing numeric values by using a runtime field
==== Aggregation with two analyzed fields and a filter

We take the first example but want to narrow the item sets to places in Europe.
For that, we add a filter:

[source,console]
-------------------------------------------------
POST /kibana_sample_data_ecommerce/_async_search
{
"size": 0,
"aggs": {
"my_agg": {
"frequent_items": {
"minimum_set_size": 3,
"fields": [
{ "field": "category.keyword" },
{ "field": "geoip.city_name" }
],
"size": 3,
"filter": {
"term": {
"geoip.continent_name": "Europe"
}
}
}
}
}
}
-------------------------------------------------
// TEST[skip:setup kibana sample data]

The result will only show item sets created from documents matching the
filter, namely purchases in Europe. With `filter`, the calculated `support`
still takes all purchases into account. That's different from specifying a
query at the top level, in which case `support` is calculated only from
purchases in Europe.
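The support bookkeeping can be sketched in a few lines of Java (a minimal illustration with hypothetical numbers, not the actual aggregation code):

```java
// Minimal sketch (not the Elasticsearch implementation) of how filtered
// transactions affect the support calculation.
public class SupportSketch {
    private long totalTransactionCount = 0;
    private long filteredTransactionCount = 0;
    private long itemSetCount = 0; // transactions containing the item set of interest

    /** A document that matches the filter and enters the analysis. */
    public void addTransaction(boolean containsItemSet) {
        ++totalTransactionCount;
        if (containsItemSet) {
            ++itemSetCount;
        }
    }

    /** A document that does not match the filter: it generates no items,
     *  but still counts toward the total used for support. */
    public void addFilteredTransaction() {
        ++filteredTransactionCount;
        ++totalTransactionCount;
    }

    public double support() {
        return (double) itemSetCount / totalTransactionCount;
    }
}
```

With hypothetical numbers: 300 European purchases (60 of them containing the item set) and 700 filtered non-European purchases give a support of 60/1000 = 0.06, whereas a top-level query would yield 60/300 = 0.2.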


[discrete]
==== Analyzing numeric values by using a runtime field

The frequent items aggregation enables you to bucket numeric values by using
<<runtime,runtime fields>>. The next example demonstrates how to use a script to
@@ -188,6 +188,12 @@ public HashBasedTransactionStore map(Stream<Tuple<Field, List<Object>>> keyValue
return transactionStore;
}

@Override
public HashBasedTransactionStore mapFiltered(HashBasedTransactionStore transactionStore) {
transactionStore.addFilteredTransaction();
return transactionStore;
}

@Override
protected ImmutableTransactionStore mapFinalize(HashBasedTransactionStore transactionStore) {

@@ -197,6 +203,7 @@ protected ImmutableTransactionStore mapFinalize(HashBasedTransactionStore transa
profilingInfoMap.put("ram_bytes_transactionstore_after_map", transactionStore.ramBytesUsed());
profilingInfoMap.put("total_items_after_map", transactionStore.getTotalItemCount());
profilingInfoMap.put("total_transactions_after_map", transactionStore.getTotalTransactionCount());
profilingInfoMap.put("filtered_transactions_after_map", transactionStore.getFilteredTransactionCount());
profilingInfoMap.put("unique_items_after_map", transactionStore.getUniqueItemsCount());
profilingInfoMap.put("unique_transactions_after_map", transactionStore.getUniqueTransactionCount());
}
@@ -283,6 +290,7 @@ public EclatResult reduceFinalize(HashBasedTransactionStore transactionStore, Li
profilingInfoReduce.put("ram_bytes_transactionstore_after_reduce", transactionStore.ramBytesUsed());
profilingInfoReduce.put("total_items_after_reduce", transactionStore.getTotalItemCount());
profilingInfoReduce.put("total_transactions_after_reduce", transactionStore.getTotalTransactionCount());
profilingInfoReduce.put("filtered_transactions_after_reduce", transactionStore.getFilteredTransactionCount());
profilingInfoReduce.put("unique_items_after_reduce", transactionStore.getUniqueItemsCount());
profilingInfoReduce.put("unique_transactions_after_reduce", transactionStore.getUniqueTransactionCount());
}
@@ -293,6 +301,7 @@ public EclatResult reduceFinalize(HashBasedTransactionStore transactionStore, Li
profilingInfoReduce.put("ram_bytes_transactionstore_after_prune", transactionStore.ramBytesUsed());
profilingInfoReduce.put("total_items_after_prune", transactionStore.getTotalItemCount());
profilingInfoReduce.put("total_transactions_after_prune", transactionStore.getTotalTransactionCount());
profilingInfoReduce.put("filtered_transactions_after_prune", transactionStore.getFilteredTransactionCount());
profilingInfoReduce.put("unique_items_after_prune", transactionStore.getUniqueItemsCount());
profilingInfoReduce.put("unique_transactions_after_prune", transactionStore.getUniqueTransactionCount());
}
@@ -10,6 +10,8 @@
import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder;
@@ -21,6 +23,7 @@
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ContextParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.ml.aggs.frequentitemsets.mr.ItemSetMapReduceValueSource;
@@ -50,22 +53,29 @@ public final class FrequentItemSetsAggregationBuilder extends AbstractAggregatio
double minimumSupport = args[1] == null ? DEFAULT_MINIMUM_SUPPORT : (double) args[1];
int minimumSetSize = args[2] == null ? DEFAULT_MINIMUM_SET_SIZE : (int) args[2];
int size = args[3] == null ? DEFAULT_SIZE : (int) args[3];
QueryBuilder filter = (QueryBuilder) args[4];

return new FrequentItemSetsAggregationBuilder(context, fields, minimumSupport, minimumSetSize, size);
return new FrequentItemSetsAggregationBuilder(context, fields, minimumSupport, minimumSetSize, size, filter);
}
);

static {
ContextParser<Void, MultiValuesSourceFieldConfig.Builder> metricParser = MultiValuesSourceFieldConfig.parserBuilder(
ContextParser<Void, MultiValuesSourceFieldConfig.Builder> fieldsParser = MultiValuesSourceFieldConfig.parserBuilder(
false, // scriptable
false, // timezone aware
false, // filtered
false, // filtered (not defined per field, but for all fields below)
false // format
);
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, n) -> metricParser.parse(p, null).build(), FIELDS);
PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), (p, n) -> fieldsParser.parse(p, null).build(), FIELDS);
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), MINIMUM_SUPPORT);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MINIMUM_SET_SIZE);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Aggregation.CommonFields.SIZE);
PARSER.declareField(
ConstructingObjectParser.optionalConstructorArg(),
(p, context) -> AbstractQueryBuilder.parseTopLevelQuery(p),
MultiValuesSourceFieldConfig.FILTER,
ObjectParser.ValueType.OBJECT
);
}

static final ValuesSourceRegistry.RegistryKey<ItemSetMapReduceValueSource.ValueSourceSupplier> REGISTRY_KEY =
@@ -92,13 +102,15 @@ public static void registerAggregators(ValuesSourceRegistry.Builder registry) {
private final double minimumSupport;
private final int minimumSetSize;
private final int size;
private final QueryBuilder filter;

public FrequentItemSetsAggregationBuilder(
String name,
List<MultiValuesSourceFieldConfig> fields,
double minimumSupport,
int minimumSetSize,
int size
int size,
QueryBuilder filter
) {
super(name);
this.fields = fields;
@@ -118,6 +130,7 @@ public FrequentItemSetsAggregationBuilder(
throw new IllegalArgumentException("[size] must be greater than 0. Found [" + size + "] in [" + name + "]");
}
this.size = size;
this.filter = filter;
}

public FrequentItemSetsAggregationBuilder(StreamInput in) throws IOException {
@@ -126,6 +139,11 @@ public FrequentItemSetsAggregationBuilder(StreamInput in) throws IOException {
this.minimumSupport = in.readDouble();
this.minimumSetSize = in.readVInt();
this.size = in.readVInt();
if (in.getVersion().onOrAfter(Version.V_8_6_0)) {
this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
} else {
this.filter = null;
}
}

@Override
Expand All @@ -135,7 +153,7 @@ public boolean supportsSampling() {

@Override
protected AggregationBuilder shallowCopy(Builder factoriesBuilder, Map<String, Object> metadata) {
return new FrequentItemSetsAggregationBuilder(name, fields, minimumSupport, minimumSetSize, size);
return new FrequentItemSetsAggregationBuilder(name, fields, minimumSupport, minimumSetSize, size, filter);
}

@Override
@@ -149,6 +167,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeDouble(minimumSupport);
out.writeVInt(minimumSetSize);
out.writeVInt(size);
if (out.getVersion().onOrAfter(Version.V_8_6_0)) {
out.writeOptionalNamedWriteable(filter);
}
}

@Override
Expand All @@ -164,7 +185,8 @@ protected AggregatorFactory doBuild(AggregationContext context, AggregatorFactor
fields,
minimumSupport,
minimumSetSize,
size
size,
filter
);
}

@@ -179,6 +201,9 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param
builder.field(MINIMUM_SUPPORT.getPreferredName(), minimumSupport);
builder.field(MINIMUM_SET_SIZE.getPreferredName(), minimumSetSize);
builder.field(Aggregation.CommonFields.SIZE.getPreferredName(), size);
if (filter != null) {
builder.field(MultiValuesSourceFieldConfig.FILTER.getPreferredName(), filter);
}
builder.endObject();
return builder;
}
@@ -8,6 +8,7 @@
package org.elasticsearch.xpack.ml.aggs.frequentitemsets;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.Aggregator;
@@ -59,6 +60,7 @@ public class FrequentItemSetsAggregatorFactory extends AggregatorFactory {
private final double minimumSupport;
private final int minimumSetSize;
private final int size;
private final QueryBuilder filter;

public FrequentItemSetsAggregatorFactory(
String name,
@@ -69,13 +71,15 @@ public FrequentItemSetsAggregatorFactory(
List<MultiValuesSourceFieldConfig> fields,
double minimumSupport,
int minimumSetSize,
int size
int size,
QueryBuilder filter
) throws IOException {
super(name, context, parent, subFactoriesBuilder, metadata);
this.fields = fields;
this.minimumSupport = minimumSupport;
this.minimumSetSize = minimumSetSize;
this.size = size;
this.filter = filter;
}

@Override
@@ -109,7 +113,8 @@ protected Aggregator createInternal(Aggregator parent, CardinalityUpperBound car
parent,
metadata,
new EclatMapReducer(FrequentItemSetsAggregationBuilder.NAME, minimumSupport, minimumSetSize, size, context.profiling()),
configs
configs,
filter
) {
};
}
@@ -96,6 +96,7 @@ public final class HashBasedTransactionStore extends TransactionStore {
private BytesRefHash transactions;
private LongArray transactionCounts;
private long totalTransactionCount;
private long filteredTransactionCount;

public HashBasedTransactionStore(BigArrays bigArrays) {
super(bigArrays);
@@ -209,6 +210,14 @@ public void add(Stream<Tuple<Field, List<Object>>> keyValues) {
transactionCounts.increment(id, 1);
}

/**
* Report a filtered transaction to the store.
*/
public void addFilteredTransaction() {
++filteredTransactionCount;
++totalTransactionCount;
}

@Override
public long getTotalItemCount() {
return totalItemCount;
@@ -219,6 +228,11 @@ public long getTotalTransactionCount() {
return totalTransactionCount;
}

@Override
public long getFilteredTransactionCount() {
return filteredTransactionCount;
}

@Override
public BytesRefArray getItems() {
return items.getBytesRefs();
@@ -292,6 +306,7 @@ public void merge(TransactionStore other) throws IOException {

totalItemCount += other.getTotalItemCount();
totalTransactionCount += other.getTotalTransactionCount();
filteredTransactionCount += other.getFilteredTransactionCount();
}

/**
@@ -445,7 +460,8 @@ public ImmutableTransactionStore createImmutableTransactionStore() {
totalItemCount,
transactions.takeBytesRefsOwnership(),
transactionCounts,
totalTransactionCount
totalTransactionCount,
filteredTransactionCount
);

items = null;
