Skip to content

Commit

Permalink
x-pack batch 2 instanceof pattern matching replacement (#81936) (#81973)
Browse files Browse the repository at this point in the history
  • Loading branch information
astefan committed Dec 21, 2021
1 parent 2bbb4c2 commit 5fe8fdd
Show file tree
Hide file tree
Showing 23 changed files with 47 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,8 @@ private static Attribute handleSpecialFields(UnresolvedAttribute u, Attribute na
FieldAttribute fa = (FieldAttribute) named;

// incompatible mappings
if (fa.field() instanceof InvalidMappedField) {
named = u.withUnresolvedMessage(
"Cannot use field [" + fa.name() + "] due to ambiguities being " + ((InvalidMappedField) fa.field()).errorMessage()
);
if (fa.field()instanceof InvalidMappedField field) {
named = u.withUnresolvedMessage("Cannot use field [" + fa.name() + "] due to ambiguities being " + field.errorMessage());
}
// unsupported types
else if (DataTypes.isUnsupported(fa.dataType())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ Collection<Failure> verify(LogicalPlan plan, Function<String, Collection<String>

Set<Failure> localFailures = new LinkedHashSet<>();

if (p instanceof Unresolvable) {
localFailures.add(fail(p, ((Unresolvable) p).unresolvedMessage()));
if (p instanceof Unresolvable unresolvable) {
localFailures.add(fail(p, unresolvable.unresolvedMessage()));
} else {
p.forEachExpression(e -> {
// everything is fine, skip expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,15 @@ public Executable assemble(
break;
}
// optional field
} else if (extractor instanceof ComputingExtractor) {
keyFields.add(((ComputingExtractor) extractor).hitName());
} else if (extractor instanceof ComputingExtractor computingExtractor) {
keyFields.add(computingExtractor.hitName());
}
}

PhysicalPlan query = plans.get(i);
// search query
if (query instanceof EsQueryExec) {
SearchSourceBuilder source = ((EsQueryExec) query).source(session, false);
if (query instanceof EsQueryExec esQueryExec) {
SearchSourceBuilder source = esQueryExec.source(session, false);
QueryRequest original = () -> source;
BoxedQueryRequest boxedRequest = new BoxedQueryRequest(original, timestampName, keyFields, optionalKeys);
Criterion<BoxedQueryRequest> criterion = new Criterion<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ public static HitExtractor createExtractor(FieldExtraction ref, EqlConfiguration
return new FieldHitExtractor(f.name(), f.getDataType(), cfg.zoneId(), f.hitName(), false);
}

if (ref instanceof ComputedRef) {
Pipe proc = ((ComputedRef) ref).processor();
if (ref instanceof ComputedRef computedRef) {
Pipe proc = computedRef.processor();
// collect hitNames
Set<String> hitNames = new LinkedHashSet<>();
proc = proc.transformDown(ReferenceInput.class, l -> {
Expand Down Expand Up @@ -175,8 +175,8 @@ public static SearchSourceBuilder addFilter(QueryBuilder filter, SearchSourceBui
BoolQueryBuilder bool = null;
QueryBuilder query = source.query();

if (query instanceof BoolQueryBuilder) {
bool = (BoolQueryBuilder) query;
if (query instanceof BoolQueryBuilder boolQueryBuilder) {
bool = boolQueryBuilder;
if (filter != null && bool.filter().contains(filter) == false) {
bool.filter(filter);
}
Expand All @@ -202,8 +202,8 @@ public static SearchSourceBuilder replaceFilter(
BoolQueryBuilder bool = null;
QueryBuilder query = source.query();

if (query instanceof BoolQueryBuilder) {
bool = (BoolQueryBuilder) query;
if (query instanceof BoolQueryBuilder boolQueryBuilder) {
bool = boolQueryBuilder;
if (oldFilters != null) {
bool.filter().removeAll(oldFilters);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ private static void sorting(QueryContainer container, SearchSourceBuilder source
Attribute attr = as.attribute();

// sorting only works on not-analyzed fields - look for a multi-field replacement
if (attr instanceof FieldAttribute) {
FieldAttribute fa = ((FieldAttribute) attr).exactAttribute();
if (attr instanceof FieldAttribute fieldAttribute) {
FieldAttribute fa = fieldAttribute.exactAttribute();

sortBuilder = fieldSort(fa.name()).missing(as.missing().searchOrder(as.direction()))
.unmappedType(fa.dataType().esType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,8 @@ private SequenceKey key(Object[] keys) {
} else {
for (int i = 0; i < keys.length; i++) {
Object o = keys[i];
if (o instanceof String) {
keys[i] = cache((String) o);
if (o instanceof String s) {
keys[i] = cache(s);
}
}
key = new SequenceKey(keys);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ public FieldExtraction fieldExtraction(Expression expression) {
}

private FieldExtraction createFieldExtractionFor(Expression expression) {
if (expression instanceof FieldAttribute) {
FieldAttribute fa = ((FieldAttribute) expression).exactAttribute();
if (expression instanceof FieldAttribute fieldAttribute) {
FieldAttribute fa = fieldAttribute.exactAttribute();
if (fa.isNested()) {
throw new UnsupportedOperationException("Nested not yet supported");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,7 @@ public String describeCredentials(Collection<? extends Credential> credentials)
return "<null>";
}
byte[] encoded;
if (c instanceof X509Credential) {
X509Credential x = (X509Credential) c;
if (c instanceof X509Credential x) {
try {
encoded = x.getEntityCertificate().getEncoded();
} catch (CertificateEncodingException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,7 @@ protected static Long parseTerm(Object value) {
if (lv >= 0) {
return lv;
}
} else if (value instanceof BigInteger) {
BigInteger bigIntegerValue = (BigInteger) value;
} else if (value instanceof BigInteger bigIntegerValue) {
if (bigIntegerValue.compareTo(BigInteger.ZERO) >= 0 && bigIntegerValue.compareTo(BIGINTEGER_2_64_MINUS_ONE) <= 0) {
return bigIntegerValue.longValue();
}
Expand Down Expand Up @@ -596,8 +595,7 @@ private static long parseUnsignedLong(Object value) {
throw new IllegalArgumentException("Value \"" + value + "\" has a decimal part");
}
return parseUnsignedLong(v.longValue());
} else if (value instanceof BigInteger) {
BigInteger bigIntegerValue = (BigInteger) value;
} else if (value instanceof BigInteger bigIntegerValue) {
if (bigIntegerValue.compareTo(BIGINTEGER_2_64_MINUS_ONE) > 0 || bigIntegerValue.compareTo(BigInteger.ZERO) < 0) {
throw new IllegalArgumentException("Value [" + bigIntegerValue + "] is out of range for unsigned long");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,12 @@ public void test() throws IOException {

// We should have got here if and only if no ML endpoints were called
for (ExecutableSection section : testCandidate.getTestSection().getExecutableSections()) {
if (section instanceof DoSection) {
String apiName = ((DoSection) section).getApiCallSection().getApi();
if (section instanceof DoSection doSection) {
String apiName = doSection.getApiCallSection().getApi();

if (apiName.startsWith("ml.")) {
fail("call to ml endpoint [" + apiName + "] should have failed because of missing role");
} else if (apiName.startsWith("search")) {
DoSection doSection = (DoSection) section;
List<Map<String, Object>> bodies = doSection.getApiCallSection().getBodies();
boolean containsInferenceAgg = false;
for (Map<String, Object> body : bodies) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ public void test() throws IOException {

// We should have got here if and only if the only ML endpoints in the test were in the allowed list
for (ExecutableSection section : testCandidate.getTestSection().getExecutableSections()) {
if (section instanceof DoSection) {
String apiName = ((DoSection) section).getApiCallSection().getApi();
if (section instanceof DoSection doSection) {
String apiName = doSection.getApiCallSection().getApi();

if (apiName.startsWith("ml.") && isAllowed(apiName) == false) {
fail("call to ml endpoint [" + apiName + "] should have failed because of missing role");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,7 @@ protected void validate(ValidationContext context) {
);
}

if (inferenceConfig instanceof ClassificationConfigUpdate) {
ClassificationConfigUpdate classUpdate = (ClassificationConfigUpdate) inferenceConfig;
if (inferenceConfig instanceof ClassificationConfigUpdate classUpdate) {

// error if the top classes result field is set and not equal to the only acceptable value
String topClassesField = classUpdate.getTopClassesResultsField();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,9 @@ private void processAggs(long docCount, List<Aggregation> aggregations) throws I
for (Aggregation agg : aggregations) {
if (agg instanceof MultiBucketsAggregation) {
bucketAggregations.add((MultiBucketsAggregation) agg);
} else if (agg instanceof SingleBucketAggregation) {
} else if (agg instanceof SingleBucketAggregation singleBucketAggregation) {
// Skip a level down for single bucket aggs, if they have a sub-agg that is not
// a bucketed agg we should treat it like a leaf in this bucket
SingleBucketAggregation singleBucketAggregation = (SingleBucketAggregation) agg;
for (Aggregation subAgg : singleBucketAggregation.getAggregations()) {
if (subAgg instanceof MultiBucketsAggregation || subAgg instanceof SingleBucketAggregation) {
singleBucketAggregations.add(singleBucketAggregation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ private static Map<String, Object> getNextLevel(Map<String, Object> source, Stri
if (nextLevel instanceof Map<?, ?>) {
return (Map<String, Object>) source.get(key);
}
if (nextLevel instanceof List<?>) {
List<?> asList = (List<?>) nextLevel;
if (nextLevel instanceof List<?> asList) {
if (asList.isEmpty() == false) {
Object firstElement = asList.get(0);
if (firstElement instanceof Map<?, ?>) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {

@Override
public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
if (config instanceof NerConfig) {
NerConfig nerConfig = (NerConfig) config;
if (config instanceof NerConfig nerConfig) {
return new NerResultProcessor(iobMap, nerConfig.getResultsField(), ignoreCase);
}
return new NerResultProcessor(iobMap, resultsField, ignoreCase);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {

@Override
public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
if (config instanceof TextClassificationConfig) {
TextClassificationConfig textClassificationConfig = (TextClassificationConfig) config;
if (config instanceof TextClassificationConfig textClassificationConfig) {
return (tokenization, pytorchResult) -> processResult(
tokenization,
pytorchResult,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ public void validateInputs(List<String> inputs) {
@Override
public NlpTask.RequestBuilder getRequestBuilder(NlpConfig nlpConfig) {
final String[] labelsValue;
if (nlpConfig instanceof ZeroShotClassificationConfig) {
ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig) nlpConfig;
if (nlpConfig instanceof ZeroShotClassificationConfig zeroShotConfig) {
labelsValue = zeroShotConfig.getLabels().toArray(new String[0]);
} else {
labelsValue = this.labels;
Expand All @@ -86,8 +85,7 @@ public NlpTask.ResultProcessor getResultProcessor(NlpConfig nlpConfig) {
final String[] labelsValue;
final boolean isMultiLabelValue;
final String resultsFieldValue;
if (nlpConfig instanceof ZeroShotClassificationConfig) {
ZeroShotClassificationConfig zeroShotConfig = (ZeroShotClassificationConfig) nlpConfig;
if (nlpConfig instanceof ZeroShotClassificationConfig zeroShotConfig) {
labelsValue = zeroShotConfig.getLabels().toArray(new String[0]);
isMultiLabelValue = zeroShotConfig.isMultiLabel();
resultsFieldValue = zeroShotConfig.getResultsField();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,16 @@ protected void assertFromXContent(InternalInferenceAggregation agg, ParsedAggreg
ParsedInference parsed = ((ParsedInference) parsedAggregation);

InferenceResults result = agg.getInferenceResult();
if (result instanceof WarningInferenceResults) {
WarningInferenceResults warning = (WarningInferenceResults) result;
if (result instanceof WarningInferenceResults warning) {
assertEquals(warning.getWarning(), parsed.getWarning());
} else if (result instanceof RegressionInferenceResults) {
RegressionInferenceResults regression = (RegressionInferenceResults) result;
} else if (result instanceof RegressionInferenceResults regression) {
assertEquals(regression.value(), parsed.getValue());
List<RegressionFeatureImportance> featureImportance = regression.getFeatureImportance();
if (featureImportance.isEmpty()) {
featureImportance = null;
}
assertEquals(featureImportance, parsed.getFeatureImportance());
} else if (result instanceof ClassificationInferenceResults) {
ClassificationInferenceResults classification = (ClassificationInferenceResults) result;
} else if (result instanceof ClassificationInferenceResults classification) {
assertEquals(classification.predictedValue(), parsed.getValue());

List<ClassificationFeatureImportance> featureImportance = classification.getFeatureImportance();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ public void setup() throws Exception {
when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
when(client.threadPool()).thenReturn(threadPool);
doAnswer(invocationOnMock -> {
if (invocationOnMock.getArguments()[0] instanceof ActionType<?>) {
ActionType<?> v = (ActionType<?>) invocationOnMock.getArguments()[0];
if (invocationOnMock.getArguments()[0]instanceof ActionType<?> v) {
ActionListener<?> l = (ActionListener<?>) invocationOnMock.getArguments()[2];
ParameterizedType parameterizedType = (ParameterizedType) v.getClass().getGenericSuperclass();
Type t = parameterizedType.getActualTypeArguments()[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ public static <T extends BaseNodeResponse> void ensureNoTimeouts(TimeValue colle
public static void ensureNoTimeouts(TimeValue collectionTimeout, BaseTasksResponse response) {
HashSet<String> timedOutNodeIds = null;
for (ElasticsearchException nodeFailure : response.getNodeFailures()) {
if (nodeFailure instanceof FailedNodeException) {
FailedNodeException failedNodeException = (FailedNodeException) nodeFailure;
if (nodeFailure instanceof FailedNodeException failedNodeException) {
if (isTimeoutFailure(failedNodeException)) {
if (timedOutNodeIds == null) {
timedOutNodeIds = new HashSet<>();
Expand All @@ -73,8 +72,7 @@ public static void ensureNoTimeouts(TimeValue collectionTimeout, BroadcastRespon
HashSet<String> timedOutNodeIds = null;
for (DefaultShardOperationFailedException shardFailure : response.getShardFailures()) {
final Throwable shardFailureCause = shardFailure.getCause();
if (shardFailureCause instanceof FailedNodeException) {
FailedNodeException failedNodeException = (FailedNodeException) shardFailureCause;
if (shardFailureCause instanceof FailedNodeException failedNodeException) {
if (isTimeoutFailure(failedNodeException)) {
if (timedOutNodeIds == null) {
timedOutNodeIds = new HashSet<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ public static TypeResolution isIP(Expression e, String operationName, ParamOrdin
}

public static TypeResolution isExact(Expression e, String message) {
if (e instanceof FieldAttribute) {
EsField.Exact exact = ((FieldAttribute) e).getExactInfo();
if (e instanceof FieldAttribute fa) {
EsField.Exact exact = fa.getExactInfo();
if (exact.hasExact() == false) {
return new TypeResolution(format(null, message, e.dataType().typeName(), exact.errorMsg()));
}
Expand All @@ -83,8 +83,8 @@ public static TypeResolution isExact(Expression e, String message) {
}

public static TypeResolution isExact(Expression e, String operationName, ParamOrdinal paramOrd) {
if (e instanceof FieldAttribute) {
EsField.Exact exact = ((FieldAttribute) e).getExactInfo();
if (e instanceof FieldAttribute fa) {
EsField.Exact exact = fa.getExactInfo();
if (exact.hasExact() == false) {
return new TypeResolution(
format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ public Object process(Object input) {
return null;
}

if (input instanceof Number) {
return operation.apply((Number) input);
if (input instanceof Number number) {
return operation.apply(number);
}
throw new QlIllegalArgumentException("A number is required; received {}", input);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class TransportTaskHelper {
static void doProcessTasks(String id, Consumer<RollupJobTask> operation, TaskManager taskManager) {
RollupJobTask matchingTask = null;
for (Task task : taskManager.getTasks().values()) {
if (task instanceof RollupJobTask && ((RollupJobTask) task).getConfig().getId().equals(id)) {
if (task instanceof RollupJobTask rollupJobTask && rollupJobTask.getConfig().getId().equals(id)) {
if (matchingTask != null) {
throw new IllegalArgumentException(
"Found more than one matching task for rollup job [" + id + "] when " + "there should only be one."
Expand Down

0 comments on commit 5fe8fdd

Please sign in to comment.