[ML] Truncate categorization fields (#89827) (#89961)
Truncate the raw categorization field passed to the backend at 1001 characters.
edsavage committed Sep 9, 2022
1 parent 02e0c8f commit 635de5e
Showing 10 changed files with 109 additions and 76 deletions.
@@ -69,6 +69,14 @@ public class AnalysisConfig implements ToXContentObject, Writeable {
public static final String ML_CATEGORY_FIELD = "mlcategory";
public static final Set<String> AUTO_CREATED_FIELDS = new HashSet<>(Collections.singletonList(ML_CATEGORY_FIELD));

// Since the C++ backend truncates the categorization field at length 1000 (see model::CCategoryExamplesCollector::MAX_EXAMPLE_LENGTH),
// appending an ellipsis on truncation, there is no point in sending it potentially very long strings. For the backend's
// truncation logic to still be triggered we must send slightly more than its limit, hence we truncate at length 1001.
//
// Also, because the tokenization is now done on the Java side, the tokens are still sent correctly (and separately) to the
// C++ backend even if they extend beyond the length of the truncated example.
public static final int MAX_CATEGORIZATION_FIELD_LENGTH = 1001;

// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ConstructingObjectParser<AnalysisConfig.Builder, Void> LENIENT_PARSER = createParser(true);
public static final ConstructingObjectParser<AnalysisConfig.Builder, Void> STRICT_PARSER = createParser(false);
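A minimal sketch, not part of the commit, of why the Java-side limit is 1001 rather than 1000: sending one character more than the C++ limit guarantees the backend still classifies the example as over-length and appends its ellipsis, while arbitrarily long strings are never shipped across the pipe. Only the two limits below come from the source; the demo class itself is hypothetical.

// Hypothetical demo, not part of the commit.
public class TruncationDemo {
    // C++ limit (model::CCategoryExamplesCollector::MAX_EXAMPLE_LENGTH)
    static final int CPP_MAX_EXAMPLE_LENGTH = 1000;
    // Java-side limit from this commit: one char above the C++ limit, so the
    // backend still sees an over-length example and appends its ellipsis.
    static final int MAX_CATEGORIZATION_FIELD_LENGTH = CPP_MAX_EXAMPLE_LENGTH + 1;

    static String truncate(String field) {
        return field.substring(0, Math.min(field.length(), MAX_CATEGORIZATION_FIELD_LENGTH));
    }

    public static void main(String[] args) {
        String sent = truncate("x".repeat(5000));
        System.out.println(sent.length()); // 1001; the backend stores 1000 chars plus "..."
    }
}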
@@ -6,7 +6,6 @@
*/
package org.elasticsearch.xpack.ml.integration;

import org.apache.logging.log4j.LogManager;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
@@ -32,7 +31,6 @@
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.CategorizerStats;
import org.elasticsearch.xpack.core.ml.job.results.CategoryDefinition;
import org.elasticsearch.xpack.core.ml.job.results.Result;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.junit.After;
import org.junit.Before;

@@ -342,58 +340,6 @@ public void testCategorizationStatePersistedOnSwitchToRealtime() throws Exception
);
}

public void testCategorizationPerformance() {
// To compare Java/C++ tokenization performance:
// 1. Change false to true in this assumption
// 2. Run the test several times
// 3. Change MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA to false
// 4. Run the test several more times
// 5. Check the timings that get logged
// 6. Revert the changes to this assumption and MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA
assumeTrue("This is time consuming to run on every build - it should be run manually when comparing Java/C++ tokenization", false);

int testBatchSize = 1000;
int testNumBatches = 1000;
String[] possibleMessages = new String[] {
"<sol13m-9402.1.p2ps: Info: Tue Apr 06 19:00:16 2010> Source LOTS on 33080:817 has shut down.<END>",
"<lnl00m-8601.1.p2ps: Alert: Tue Apr 06 18:57:24 2010> P2PS failed to connect to the hrm server. "
+ "Reason: Failed to connect to hrm server - No ACK from SIPC<END>",
"<sol00m-8607.1.p2ps: Debug: Tue Apr 06 18:56:43 2010> Did not receive an image data for IDN_SELECTFEED:7630.T on 493. "
+ "Recalling item. <END>",
"<lnl13m-8602.1.p2ps.rrcpTransport.0.sinkSide.rrcp.transmissionBus: Warning: Tue Apr 06 18:36:32 2010> "
+ "RRCP STATUS MSG: RRCP_REBOOT: node 33191 has rebooted<END>",
"<sol00m-8608.1.p2ps: Info: Tue Apr 06 18:30:02 2010> Source PRISM_VOBr on 33069:757 has shut down.<END>",
"<lnl06m-9402.1.p2ps: Info: Thu Mar 25 18:30:01 2010> Service PRISM_VOB has shut down.<END>" };

String jobId = "categorization-performance";
Job.Builder job = newJobBuilder(jobId, Collections.emptyList(), false);
putJob(job);
openJob(job.getId());

long startTime = System.currentTimeMillis();

for (int batchNum = 0; batchNum < testNumBatches; ++batchNum) {
StringBuilder json = new StringBuilder(testBatchSize * 100);
for (int docNum = 0; docNum < testBatchSize; ++docNum) {
json.append(
String.format(Locale.ROOT, "{\"time\":1000000,\"msg\":\"%s\"}\n", possibleMessages[docNum % possibleMessages.length])
);
}
postData(jobId, json.toString());
}
flushJob(jobId, false);

long duration = System.currentTimeMillis() - startTime;
LogManager.getLogger(CategorizationIT.class)
.info(
"Performance test with tokenization in "
+ (MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA ? "Java" : "C++")
+ " took "
+ duration
+ "ms"
);
}

public void testStopOnWarn() throws IOException {

long testTime = System.currentTimeMillis();
@@ -473,9 +473,6 @@ public class MachineLearning extends Plugin

private static final long DEFAULT_MODEL_CIRCUIT_BREAKER_LIMIT = (long) ((0.50) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes());
private static final double DEFAULT_MODEL_CIRCUIT_BREAKER_OVERHEAD = 1.0D;
// This is for performance testing. It's not exposed to the end user.
// Recompile if you want to compare performance with C++ tokenization.
public static final boolean CATEGORIZATION_TOKENIZATION_IN_JAVA = true;

public static final LicensedFeature.Persistent ML_ANOMALY_JOBS_FEATURE = LicensedFeature.persistent(
MachineLearningField.ML_FEATURE_FAMILY,
@@ -27,7 +27,6 @@
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.TimingStats;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
import org.elasticsearch.xpack.ml.job.persistence.StateStreamer;
import org.elasticsearch.xpack.ml.job.process.CountingInputStream;
@@ -88,8 +87,7 @@ public class AutodetectCommunicator implements Closeable {
this.onFinishHandler = onFinishHandler;
this.xContentRegistry = xContentRegistry;
this.autodetectWorkerExecutor = autodetectWorkerExecutor;
this.includeTokensField = MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA
&& job.getAnalysisConfig().getCategorizationFieldName() != null;
this.includeTokensField = job.getAnalysisConfig().getCategorizationFieldName() != null;
}

public void restoreState(ModelSnapshot modelSnapshot) {
@@ -213,7 +213,7 @@ protected final Map<String, Integer> outputFieldIndexes() {
}
}
// field for categorization tokens
if (MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA && job.getAnalysisConfig().getCategorizationFieldName() != null) {
if (job.getAnalysisConfig().getCategorizationFieldName() != null) {
fieldIndexes.put(LengthEncodedWriter.PRETOKENISED_TOKEN_FIELD, index++);
}

@@ -96,8 +96,7 @@ public AutodetectProcess createAutodetectProcess(
true
);
createNativeProcess(job, params, processPipes, filesToDelete);
boolean includeTokensField = MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA
&& job.getAnalysisConfig().getCategorizationFieldName() != null;
boolean includeTokensField = job.getAnalysisConfig().getCategorizationFieldName() != null;
// The extra 1 is the control field
int numberOfFields = job.allInputFields().size() + (includeTokensField ? 1 : 0) + 1;
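As a worked example of the count above (field names are assumptions, not from the diff): a job with input fields time and message and a categorization field configured opens the process with numberOfFields = 2 + 1 (pre-tokenized tokens field) + 1 (control field) = 4.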

@@ -56,6 +56,8 @@ public abstract class AbstractDataToProcessWriter implements DataToProcessWriter
private long latestEpochMs;
private long latestEpochMsThisUpload;

private final Set<String> termFields;

protected AbstractDataToProcessWriter(
boolean includeControlField,
boolean includeTokensField,
@@ -74,6 +76,7 @@ protected AbstractDataToProcessWriter(
this.logger = Objects.requireNonNull(logger);
this.latencySeconds = analysisConfig.getLatency() == null ? 0 : analysisConfig.getLatency().seconds();
this.bucketSpanMs = analysisConfig.getBucketSpan().getMillis();
this.termFields = analysisConfig.termFields();

Date date = dataCountsReporter.getLatestRecordTime();
latestEpochMsThisUpload = 0;
@@ -90,6 +93,13 @@ protected AbstractDataToProcessWriter(
}
}

/**
 * Truncate the categorization field value to {@link AnalysisConfig#MAX_CATEGORIZATION_FIELD_LENGTH},
 * unless the categorization field is also in use as a term (by/over/partition) field, in which case
 * the full value is required downstream and must not be truncated.
 */
public String maybeTruncateCategorizationField(String categorizationField) {
if (termFields.contains(analysisConfig.getCategorizationFieldName()) == false) {
return categorizationField.substring(0, Math.min(categorizationField.length(), AnalysisConfig.MAX_CATEGORIZATION_FIELD_LENGTH));
}
return categorizationField;
}

/**
* Set up the field index mappings. This must be called before
* {@linkplain DataToProcessWriter#write(InputStream, CategorizationAnalyzer, XContentType, BiConsumer)}
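A hedged usage sketch of the new maybeTruncateCategorizationField helper: when the categorization field doubles as a term (by/over/partition) field its full value is needed downstream, so truncation is skipped. The standalone class below restates the diff's logic with explicit parameters; the class, method signature and example values are assumptions for illustration.

import java.util.Set;

// Hypothetical standalone rendering of the helper's logic.
public class MaybeTruncateDemo {
    static final int MAX_CATEGORIZATION_FIELD_LENGTH = 1001; // AnalysisConfig constant from this commit

    static String maybeTruncate(String value, Set<String> termFields, String categorizationFieldName) {
        if (termFields.contains(categorizationFieldName)) {
            return value; // also a by/over/partition field: keep the full value
        }
        return value.substring(0, Math.min(value.length(), MAX_CATEGORIZATION_FIELD_LENGTH));
    }

    public static void main(String[] args) {
        String msg = "x".repeat(2000);
        System.out.println(maybeTruncate(msg, Set.of("mlcategory"), "message").length()); // 1001: truncated
        System.out.println(maybeTruncate(msg, Set.of("message"), "message").length());    // 2000: kept intact
    }
}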
@@ -154,12 +154,17 @@ private void writeJson(CategorizationAnalyzer categorizationAnalyzer, XContentParser

for (InputOutputMap inOut : inputOutputMap) {
String field = input[inOut.inputIndex];
record[inOut.outputIndex] = (field == null) ? "" : field;
field = (field == null) ? "" : field;
if (categorizationFieldIndex != null && inOut.inputIndex == categorizationFieldIndex) {
field = maybeTruncateCategorizationField(field);
}
record[inOut.outputIndex] = field;
}

if (categorizationAnalyzer != null && categorizationFieldIndex != null) {
tokenizeForCategorization(categorizationAnalyzer, input[categorizationFieldIndex], record);
}

transformTimeAndWrite(record, inputFieldCount);

inputFieldCount = recordReader.read(input, gotFields);
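A design note on the loop above, with a condensed hypothetical view (variable names are assumptions): the record written to the process carries the truncated value, while tokenization reads the original input, so tokens beyond the 1001-character cut are still sent to the backend, which is exactly what the AnalysisConfig comment relies on.

// Hypothetical condensed view of the write path above.
String raw = input[categorizationFieldIndex];                    // full, untruncated value
record[inOut.outputIndex] = maybeTruncateCategorizationField(raw == null ? "" : raw);
tokenizeForCategorization(categorizationAnalyzer, raw, record);  // tokenizes the full value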
@@ -165,4 +165,81 @@ public void testTokenizeForCategorization() throws IOException {
);
}
}

public void testMaybeTruncateCategorizationField() {
{
DataDescription.Builder dd = new DataDescription.Builder();
dd.setTimeField("time_field");

Detector.Builder detector = new Detector.Builder("count", "");
detector.setByFieldName("mlcategory");
AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(detector.build()));
builder.setCategorizationFieldName("message");
AnalysisConfig ac = builder.build();

boolean includeTokensFields = randomBoolean();
AbstractDataToProcessWriter writer = new JsonDataToProcessWriter(
true,
includeTokensFields,
autodetectProcess,
dd.build(),
ac,
dataCountsReporter,
NamedXContentRegistry.EMPTY
);

String truncatedField = writer.maybeTruncateCategorizationField(randomAlphaOfLengthBetween(1002, 2000));
assertEquals(AnalysisConfig.MAX_CATEGORIZATION_FIELD_LENGTH, truncatedField.length());
}
{
DataDescription.Builder dd = new DataDescription.Builder();
dd.setTimeField("time_field");

Detector.Builder detector = new Detector.Builder("count", "");
detector.setByFieldName("mlcategory");
AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(detector.build()));
builder.setCategorizationFieldName("message");
AnalysisConfig ac = builder.build();

boolean includeTokensFields = randomBoolean();
AbstractDataToProcessWriter writer = new JsonDataToProcessWriter(
true,
includeTokensFields,
autodetectProcess,
dd.build(),
ac,
dataCountsReporter,
NamedXContentRegistry.EMPTY
);

String categorizationField = randomAlphaOfLengthBetween(1, 1000);
String truncatedField = writer.maybeTruncateCategorizationField(categorizationField);
assertEquals(categorizationField.length(), truncatedField.length());
}
{
DataDescription.Builder dd = new DataDescription.Builder();
dd.setTimeField("time_field");

Detector.Builder detector = new Detector.Builder("count", "");
detector.setByFieldName("mlcategory");
detector.setPartitionFieldName("message");
AnalysisConfig.Builder builder = new AnalysisConfig.Builder(Collections.singletonList(detector.build()));
builder.setCategorizationFieldName("message");
AnalysisConfig ac = builder.build();

boolean includeTokensFields = randomBoolean();
AbstractDataToProcessWriter writer = new JsonDataToProcessWriter(
true,
includeTokensFields,
autodetectProcess,
dd.build(),
ac,
dataCountsReporter,
NamedXContentRegistry.EMPTY
);

String truncatedField = writer.maybeTruncateCategorizationField(randomAlphaOfLengthBetween(1002, 2000));
assertNotEquals(AnalysisConfig.MAX_CATEGORIZATION_FIELD_LENGTH, truncatedField.length());
}
}
}
@@ -23,7 +23,6 @@
import org.elasticsearch.xpack.core.ml.job.config.CategorizationAnalyzerConfig;
import org.elasticsearch.xpack.core.ml.job.config.DataDescription;
import org.elasticsearch.xpack.core.ml.job.config.Detector;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzerTests;
import org.elasticsearch.xpack.ml.job.process.DataCountsReporter;
@@ -135,15 +134,10 @@ public void testWrite_GivenTimeFormatIsEpochAndCategorization() throws Exception

List<String[]> expectedRecords = new ArrayList<>();
// The "." field is the control field; "..." is the pre-tokenized tokens field
if (MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA) {
expectedRecords.add(new String[] { "time", "message", "...", "." });
expectedRecords.add(new String[] { "1", "Node 1 started", "Node,started", "" });
expectedRecords.add(new String[] { "2", "Node 2 started", "Node,started", "" });
} else {
expectedRecords.add(new String[] { "time", "message", "." });
expectedRecords.add(new String[] { "1", "Node 1 started", "" });
expectedRecords.add(new String[] { "2", "Node 2 started", "" });
}
expectedRecords.add(new String[] { "time", "message", "...", "." });
expectedRecords.add(new String[] { "1", "Node 1 started", "Node,started", "" });
expectedRecords.add(new String[] { "2", "Node 2 started", "Node,started", "" });

assertWrittenRecordsEqualTo(expectedRecords);

verify(dataCountsReporter).finishReporting();
@@ -411,8 +405,7 @@ private static InputStream createInputStream(String input) {
}

private JsonDataToProcessWriter createWriter() {
boolean includeTokensField = MachineLearning.CATEGORIZATION_TOKENIZATION_IN_JAVA
&& analysisConfig.getCategorizationFieldName() != null;
boolean includeTokensField = analysisConfig.getCategorizationFieldName() != null;
return new JsonDataToProcessWriter(
true,
includeTokensField,
