Skip to content

Commit

Permalink
add rules to correlations for correlation engine (#423)
Browse files Browse the repository at this point in the history
Signed-off-by: Subhobrata Dey <sbcd90@gmail.com>
(cherry picked from commit 98663af)
  • Loading branch information
sbcd90 authored and github-actions[bot] committed May 3, 2023
1 parent 6c273d6 commit 940a12a
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

public class CorrelatedFindingAction extends ActionType<CorrelatedFindingResponse> {
public static final CorrelatedFindingAction INSTANCE = new CorrelatedFindingAction();
public static final String NAME = "cluster:admin/opensearch/securityanalytics/findings/correlated";
public static final String NAME = "cluster:admin/opensearch/securityanalytics/correlations/findings";

public CorrelatedFindingAction() {
super(NAME, CorrelatedFindingResponse::new);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ public void onResponse(MultiSearchResponse items) {
categoryToQueriesMap.put(query.getCategory(), correlationQueries);
}
}
searchFindingsByTimestamp(detectorType, categoryToQueriesMap);
searchFindingsByTimestamp(detectorType, categoryToQueriesMap,
filteredCorrelationRules.stream().map(CorrelationRule::getId).collect(Collectors.toList()));
}

@Override
Expand All @@ -194,15 +195,15 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding());
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), List.of());
}
}

/**
* this method searches for parent findings given the log category & correlation time window & collects all related docs
* for them.
*/
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap) {
private void searchFindingsByTimestamp(String detectorType, Map<String, List<CorrelationQuery>> categoryToQueriesMap, List<String> correlationRules) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<Pair<String, List<CorrelationQuery>>> categoryToQueriesPairs = new ArrayList<>();
Expand Down Expand Up @@ -255,7 +256,7 @@ public void onResponse(MultiSearchResponse items) {
relatedDocIds));
++idx;
}
searchDocsWithFilterKeys(detectorType, relatedDocsMap);
searchDocsWithFilterKeys(detectorType, relatedDocsMap, correlationRules);
}

@Override
Expand All @@ -264,14 +265,14 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding());
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}

/**
* Given the related docs from parent findings, this method filters only those related docs which match parent join criteria.
*/
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap) {
private void searchDocsWithFilterKeys(String detectorType, Map<String, DocSearchCriteria> relatedDocsMap, List<String> correlationRules) {
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();

Expand Down Expand Up @@ -318,7 +319,7 @@ public void onResponse(MultiSearchResponse items) {
filteredRelatedDocIds.put(categories.get(idx), docIds);
++idx;
}
getCorrelatedFindings(detectorType, filteredRelatedDocIds);
getCorrelatedFindings(detectorType, filteredRelatedDocIds, correlationRules);
}

@Override
Expand All @@ -327,15 +328,15 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding());
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}

/**
* Given the filtered related docs of the parent findings, this method gets the actual filtered parent findings for
* the finding to be correlated.
*/
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds) {
private void getCorrelatedFindings(String detectorType, Map<String, List<String>> filteredRelatedDocIds, List<String> correlationRules) {
long findingTimestamp = request.getFinding().getTimestamp().toEpochMilli();
MultiSearchRequest mSearchRequest = new MultiSearchRequest();
List<String> categories = new ArrayList<>();
Expand Down Expand Up @@ -390,7 +391,7 @@ public void onResponse(MultiSearchResponse items) {
}
++idx;
}
correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings);
correlateFindingAction.initCorrelationIndex(detectorType, correlatedFindings, correlationRules);
}

@Override
Expand All @@ -399,7 +400,7 @@ public void onFailure(Exception e) {
}
});
} else {
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding());
correlateFindingAction.getTimestampFeature(detectorType, null, request.getFinding(), correlationRules);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public VectorEmbeddingsEngine(Client client, TimeValue indexTimeout, long corrTi
this.correlateFindingAction = correlateFindingAction;
}

public void insertCorrelatedFindings(String detectorType, Finding finding, String logType, List<String> correlatedFindings, float timestampFeature) {
public void insertCorrelatedFindings(String detectorType, Finding finding, String logType, List<String> correlatedFindings, float timestampFeature, List<String> correlationRules) {
long findingTimestamp = finding.getTimestamp().toEpochMilli();
MatchQueryBuilder queryBuilder = QueryBuilders.matchQuery(
"root", true
Expand Down Expand Up @@ -172,6 +172,7 @@ public void onResponse(MultiSearchResponse items) {
corrBuilder.field("corr_vector", corrVector);
corrBuilder.field("recordType", "finding-finding");
corrBuilder.field("scoreTimestamp", 0L);
corrBuilder.field("corrRules", correlationRules);
corrBuilder.endObject();

IndexRequest indexRequest = new IndexRequest(CorrelationIndices.CORRELATION_INDEX)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class CorrelatedFinding implements Writeable, ToXContentObject {

Expand All @@ -24,24 +26,29 @@ public class CorrelatedFinding implements Writeable, ToXContentObject {

private String logType2;

private List<String> correlationRules;

protected static final String FINDING1_FIELD = "finding1";
protected static final String LOGTYPE1_FIELD = "logType1";
protected static final String FINDING2_FIELD = "finding2";
protected static final String LOGTYPE2_FIELD = "logType2";
protected static final String RULES_FIELD = "rules";

public CorrelatedFinding(String finding1, String logType1, String finding2, String logType2) {
public CorrelatedFinding(String finding1, String logType1, String finding2, String logType2, List<String> correlationRules) {
this.finding1 = finding1;
this.logType1 = logType1;
this.finding2 = finding2;
this.logType2 = logType2;
this.correlationRules = correlationRules;
}

public CorrelatedFinding(StreamInput sin) throws IOException {
this(
sin.readString(),
sin.readString(),
sin.readString(),
sin.readString()
sin.readString(),
sin.readStringList()
);
}

Expand All @@ -51,6 +58,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(logType1);
out.writeString(finding2);
out.writeString(logType2);
out.writeStringCollection(correlationRules);
}

@Override
Expand All @@ -59,7 +67,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.field(FINDING1_FIELD, finding1)
.field(LOGTYPE1_FIELD, logType1)
.field(FINDING2_FIELD, finding2)
.field(LOGTYPE2_FIELD, logType2);
.field(LOGTYPE2_FIELD, logType2)
.field(RULES_FIELD, correlationRules);
return builder.endObject();
}

Expand All @@ -68,6 +77,7 @@ public static CorrelatedFinding parse(XContentParser xcp) throws IOException {
String logType1 = null;
String finding2 = null;
String logType2 = null;
List<String> correlationRules = new ArrayList<>();

XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -87,11 +97,17 @@ public static CorrelatedFinding parse(XContentParser xcp) throws IOException {
case LOGTYPE2_FIELD:
logType2 = xcp.text();
break;
case RULES_FIELD:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
correlationRules.add(xcp.text());
}
break;
default:
xcp.skipChildren();
}
}
return new CorrelatedFinding(finding1, logType1, finding2, logType2);
return new CorrelatedFinding(finding1, logType1, finding2, logType2, correlationRules);
}

public static CorrelatedFinding readFrom(StreamInput sin) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,37 @@
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class FindingWithScore implements Writeable, ToXContentObject {

protected static final String FINDING = "finding";
protected static final String DETECTOR_TYPE = "detector_type";
protected static final String SCORE = "score";
protected static final String RULES = "rules";

private String finding;

private String detectorType;

private Double score;

public FindingWithScore(String finding, String detectorType, Double score) {
private List<String> rules;

public FindingWithScore(String finding, String detectorType, Double score, List<String> rules) {
this.finding = finding;
this.detectorType = detectorType;
this.score = score;
this.rules = rules;
}

public FindingWithScore(StreamInput sin) throws IOException {
this(
sin.readString(),
sin.readString(),
sin.readDouble()
sin.readDouble(),
sin.readStringList()
);
}

Expand All @@ -45,6 +52,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(finding);
out.writeString(detectorType);
out.writeDouble(score);
out.writeStringCollection(rules);
}

@Override
Expand All @@ -53,6 +61,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
.field(FINDING, finding)
.field(DETECTOR_TYPE, detectorType)
.field(SCORE, score)
.field(RULES, rules)
.endObject();
return builder;
}
Expand All @@ -61,6 +70,7 @@ public static FindingWithScore parse(XContentParser xcp) throws IOException {
String finding = null;
String detectorType = null;
Double score = null;
List<String> rules = new ArrayList<>();

XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -77,9 +87,17 @@ public static FindingWithScore parse(XContentParser xcp) throws IOException {
case SCORE:
score = xcp.doubleValue();
break;
case RULES:
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp);
while (xcp.nextToken() != XContentParser.Token.END_ARRAY) {
rules.add(xcp.text());
}
break;
default:
xcp.skipChildren();
}
}
return new FindingWithScore(finding, detectorType, score);
return new FindingWithScore(finding, detectorType, score, rules);
}

public static FindingWithScore readFrom(StreamInput sin) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ public TransportCorrelateFindingAction(TransportService transportService,
@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<SubscribeFindingsResponse> actionListener) {
try {
log.info("hit here1");
PublishFindingsRequest transformedRequest = transformRequest(request);

if (!this.correlationIndices.correlationIndexExists()) {
Expand Down Expand Up @@ -157,6 +158,7 @@ public void onFailure(Exception e) {
log.error(ex);
}
} else {
log.info("hit here2");
AsyncCorrelateFindingAction correlateFindingAction = new AsyncCorrelateFindingAction(task, transformedRequest, actionListener);
correlateFindingAction.start();
}
Expand Down Expand Up @@ -184,9 +186,11 @@ public class AsyncCorrelateFindingAction {

this.joinEngine = new JoinEngine(client, request, xContentRegistry, corrTimeWindow, this);
this.vectorEmbeddingsEngine = new VectorEmbeddingsEngine(client, indexTimeout, corrTimeWindow, this);
log.info("hit here5");
}

void start() {
log.info("hit here4");
TransportCorrelateFindingAction.this.threadPool.getThreadContext().stashContext();
String monitorId = request.getMonitorId();
Finding finding = request.getFinding();
Expand Down Expand Up @@ -246,8 +250,9 @@ public void onFailure(Exception e) {
}
}

public void initCorrelationIndex(String detectorType, Map<String, List<String>> correlatedFindings) {
public void initCorrelationIndex(String detectorType, Map<String, List<String>> correlatedFindings, List<String> correlationRules) {
try {
log.info("hit here6");
if (!IndexUtils.correlationIndexUpdated) {
IndexUtils.updateIndexMapping(
CorrelationIndices.CORRELATION_INDEX,
Expand All @@ -257,7 +262,7 @@ public void initCorrelationIndex(String detectorType, Map<String, List<String>>
public void onResponse(AcknowledgedResponse response) {
if (response.isAcknowledged()) {
IndexUtils.correlationIndexUpdated();
getTimestampFeature(detectorType, correlatedFindings, null);
getTimestampFeature(detectorType, correlatedFindings, null, correlationRules);
} else {
onFailures(new OpenSearchStatusException("Failed to create correlation Index", RestStatus.INTERNAL_SERVER_ERROR));
}
Expand All @@ -270,14 +275,15 @@ public void onFailure(Exception e) {
}
);
} else {
getTimestampFeature(detectorType, correlatedFindings, null);
getTimestampFeature(detectorType, correlatedFindings, null, correlationRules);
}
} catch (IOException ex) {
onFailures(ex);
}
}

public void getTimestampFeature(String detectorType, Map<String, List<String>> correlatedFindings, Finding orphanFinding) {
public void getTimestampFeature(String detectorType, Map<String, List<String>> correlatedFindings, Finding orphanFinding, List<String> correlationRules) {
log.info("hit here7");
long findingTimestamp = this.request.getFinding().getTimestamp().toEpochMilli();
BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery()
.mustNot(QueryBuilders.termQuery("scoreTimestamp", 0L));
Expand Down Expand Up @@ -318,7 +324,7 @@ public void onResponse(IndexResponse response) {
}
for (Map.Entry<String, List<String>> correlatedFinding : correlatedFindings.entrySet()) {
vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(),
Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue());
Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue(), correlationRules);
}
} else {
vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, Long.valueOf(CorrelationIndices.FIXED_HISTORICAL_INTERVAL / 1000L).floatValue());
Expand All @@ -341,7 +347,7 @@ public void onFailure(Exception e) {
}
for (Map.Entry<String, List<String>> correlatedFinding : correlatedFindings.entrySet()) {
vectorEmbeddingsEngine.insertCorrelatedFindings(detectorType, request.getFinding(), correlatedFinding.getKey(), correlatedFinding.getValue(),
timestampFeature);
timestampFeature, correlationRules);
}
} else {
vectorEmbeddingsEngine.insertOrphanFindings(detectorType, orphanFinding, timestampFeature);
Expand Down
Loading

0 comments on commit 940a12a

Please sign in to comment.