Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rules to correlations for correlation engine #423

Merged
merged 2 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ public void onResponse(MultiSearchResponse items) {
categoryToQueriesMap.put(query.getCategory(), correlationQueries);
}
}
searchFindingsByTimestamp(detectorType, categoryToQueriesMap);
searchFindingsByTimestamp(detectorType, categoryToQueriesMap,
filteredCorrelationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()));
}

@Override
Expand All @@ -194,15 +195,15 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding());
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of());
}
}

/**
* this method searches for parent findings given the log category & correlation time window & collects all related docs
* for them.
*/
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap) {
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, List<String> correlationRules) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<Pair<String, List<CorrelationQuery>>> categoryToQueriesPairs = new ArrayList<>();
Expand Down Expand Up @@ -255,7 +256,7 @@ public void onResponse(MultiSearchResponse items) {
relatedDocIds));
++idx;
}
searchDocsWithFilterKeys(detectorType, relatedDocsMap);
searchDocsWithFilterKeys(detectorType, relatedDocsMap, correlationRules);
}

@Override
Expand All @@ -264,14 +265,14 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding());
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}

/**
* Given the related docs from parent findings, this method filters only those related docs which match parent join criteria.
*/
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap) {
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, List<String> correlationRules) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();

Expand Down Expand Up @@ -318,7 +319,7 @@ public void onResponse(MultiSearchResponse items) {
filteredRelatedDocIds.put(categories.get(idx), docIds);
++idx;
}
getCorrelatedFindings(detectorType, filteredRelatedDocIds);
getCorrelatedFindings(detectorType, filteredRelatedDocIds, correlationRules);
}

@Override
Expand All @@ -327,15 +328,15 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding());
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}

/**
* Given the filtered related docs of the parent findings, this method gets the actual filtered parent findings for
* the finding to be correlated.
*/
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds) {
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, List<String> correlationRules) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();
Expand Down Expand Up @@ -390,7 +391,7 @@ public void onResponse(MultiSearchResponse items) {
}
++idx;
}
correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings);
correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules);
}

@Override
Expand All @@ -399,7 +400,7 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding());
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public VectorEmbeddingsEngine(Client client, TimeValue indexTimeout, long corrTi
this.correlateFindingAction = correlateFindingAction;
}

public void insertCorrelatedFindings(String detectorType, Finding finding, String logType, List<String> correlatedFindings, float timestampFeature) {
public void insertCorrelatedFindings(String detectorType, Finding finding, String logType, List<String> correlatedFindings, float timestampFeature, List<String> correlationRules) {
long findingTimestamp = finding.getTimestamp().toEpochMilli();
MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery(
"root", true
Expand Down Expand Up @@ -172,6 +172,7 @@ public void onResponse(MultiSearchResponse items) {
corrBuilder.field("corr_vector", corrVector);
corrBuilder.field("recordType", "finding-finding");
corrBuilder.field("scoreTimestamp", 0L);
corrBuilder.field("corrRules", correlationRules);
corrBuilder.endObject();

IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_INDEX)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class CorrelatedFinding implements Writeable, ToXContentObject {

Expand All @@ -24,24 +26,29 @@ public class CorrelatedFinding implements Writeable, ToXContentObject {

private String logType2;

private List<String> correlationRules;

protected static final String FINDING1_FIELD = "finding1";
protected static final String LOGTYPE1_FIELD = "logType1";
protected static final String FINDING2_FIELD = "finding2";
protected static final String LOGTYPE2_FIELD = "logType2";
protected static final String RULES_FIELD = "rules";

public CorrelatedFinding(String finding1, String logType1, String finding2, String logType2) {
public CorrelatedFinding(String finding1, String logType1, String finding2, String logType2, List<String> correlationRules) {
this.finding1 = finding1;
this.logType1 = logType1;
this.finding2 = finding2;
this.logType2 = logType2;
this.correlationRules = correlationRules;
}

public CorrelatedFinding(StreamInput sin) throws IOException {
this(
sin.readString(),
sin.readString(),
sin.readString(),
sin.readString()
sin.readString(),
sin.readStringList()
);
}

Expand All @@ -51,6 +58,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(logType1);
out.writeString(finding2);
out.writeString(logType2);
out.writeStringCollection(correlationRules);
}

@Override
Expand All @@ -59,7 +67,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.field(FINDING1_FIELD, finding1)
.field(LOGTYPE1_FIELD, logType1)
.field(FINDING2_FIELD, finding2)
.field(LOGTYPE2_FIELD, logType2);
.field(LOGTYPE2_FIELD, logType2)
.field(RULES_FIELD, correlationRules);
return builder.endObject();
}

Expand All @@ -68,6 +77,7 @@ public static CorrelatedFinding parse(XContentParser xcp) throws IOException {
String logType1 = null;
String finding2 = null;
String logType2 = null;
List<String> correlationRules = new ArrayList<>();

XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -87,11 +97,17 @@ public static CorrelatedFinding parse(XContentParser xcp) throws IOException {
case LOGTYPE2_FIELD:
logType2 = xcp.text();
break;
case RULES_FIELD:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
correlationRules.add(xcp.text());
}
break;
default:
xcp.skipChildren();
}
}
return new CorrelatedFinding(finding1, logType1, finding2, logType2);
return new CorrelatedFinding(finding1, logType1, finding2, logType2, correlationRules);
}

public static CorrelatedFinding readFrom(StreamInput sin) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,37 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class FindingWithScore implements Writeable, ToXContentObject {

protected static final String FINDING = "finding";
protected static final String DETECTOR_TYPE = "detector_type";
protected static final String SCORE = "score";
protected static final String RULES = "rules";

private String finding;

private String detectorType;

private Double score;

public FindingWithScore(String finding, String detectorType, Double score) {
private List<String> rules;

public FindingWithScore(String finding, String detectorType, Double score, List<String> rules) {
this.finding = finding;
this.detectorType = detectorType;
this.score = score;
this.rules = rules;
}

public FindingWithScore(StreamInput sin) throws IOException {
this(
sin.readString(),
sin.readString(),
sin.readDouble()
sin.readDouble(),
sin.readStringList()
);
}

Expand All @@ -45,6 +52,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(finding);
out.writeString(detectorType);
out.writeDouble(score);
out.writeStringCollection(rules);
}

@Override
Expand All @@ -53,6 +61,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.field(FINDING, finding)
.field(DETECTOR_TYPE, detectorType)
.field(SCORE, score)
.field(RULES, rules)
.endObject();
return builder;
}
Expand All @@ -61,6 +70,7 @@ public static FindingWithScore parse(XContentParser xcp) throws IOException {
String finding = null;
String detectorType = null;
Double score = null;
List<String> rules = new ArrayList<>();

XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -77,9 +87,17 @@ public static FindingWithScore parse(XContentParser xcp) throws IOException {
case SCORE:
score = xcp.doubleValue();
break;
case RULES:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
rules.add(xcp.text());
}
break;
default:
xcp.skipChildren();
}
}
return new FindingWithScore(finding, detectorType, score);
return new FindingWithScore(finding, detectorType, score, rules);
}

public static FindingWithScore readFrom(StreamInput sin) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ public TransportCorrelateFindingAction(TransportService transportService,
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<SubscribeFindingsResponse> actionListener) {
try {
log.info("hit here1");
PublishFindingsRequest transformedRequest = transformRequest(request);

if (!this.correlationIndices.correlationIndexExists()) {
Expand Down Expand Up @@ -157,6 +158,7 @@ public void onFailure(Exception e) {
log.error(ex);
}
} else {
log.info("hit here2");
AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener);
correlateFindingAction.start();
}
Expand Down Expand Up @@ -184,9 +186,11 @@ public class AsyncCorrelateFindingAction {

this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this);
this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this);
log.info("hit here5");
}

void start() {
log.info("hit here4");
TransportCorrelateFindingAction.this.threadPool.getThreadContext().stashContext();
String monitorId = request.getMonitorId();
Finding finding = request.getFinding();
Expand Down Expand Up @@ -246,8 +250,9 @@ public void onFailure(Exception e) {
}
}

public void initCorrelationIndex(String detectorType, Map<String, List<String>> correlatedFindings) {
public void initCorrelationIndex(String detectorType, Map<String, List<String>> correlatedFindings, List<String> correlationRules) {
try {
log.info("hit here6");
if (!IndexUtils.correlationIndexUpdated) {
IndexUtils.updateIndexMapping(
CorrelationIndices.CORRELATION_INDEX,
Expand All @@ -257,7 +262,7 @@ public void initCorrelationIndex(String detectorType, Map<String, List<String>>
public void onResponse(AcknowledgedResponse response) {
if (response.isAcknowledged()) {
IndexUtils.correlationIndexUpdated();
getTimestampFeature(detectorType, correlatedFindings, null);
getTimestampFeature(detectorType, correlatedFindings, null, correlationRules);
} else {
onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR));
}
Expand All @@ -270,14 +275,15 @@ public void onFailure(Exception e) {
}
);
} else {
getTimestampFeature(detectorType, correlatedFindings, null);
getTimestampFeature(detectorType, correlatedFindings, null, correlationRules);
}
} catch (IOException ex) {
onFailures(ex);
}
}

public void getTimestampFeature(String detectorType, Map<String, List<String>> correlatedFindings, Finding orphanFinding) {
public void getTimestampFeature(String detectorType, Map<String, List<String>> correlatedFindings, Finding orphanFinding, List<String> correlationRules) {
log.info("hit here7");
long findingTimestamp = this.request.getFinding().getTimestamp().toEpochMilli();
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery()
.mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L));
Expand Down Expand Up @@ -318,7 +324,7 @@ public void onResponse(IndexResponse response) {
}
for (Map.Entry<String, List<String>> correlatedFinding : correlatedFindings.entrySet()) {
vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(),
Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue());
Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules);
}
} else {
vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue());
Expand All @@ -341,7 +347,7 @@ public void onFailure(Exception e) {
}
for (Map.Entry<String, List<String>> correlatedFinding : correlatedFindings.entrySet()) {
vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(),
timestampFeature);
timestampFeature, correlationRules);
}
} else {
vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature);
Expand Down
Loading