Skip to content

Commit

Permalink
[Security Solution] DetectionRulesClient: return RuleResponse fro…
Browse files Browse the repository at this point in the history
…m all methods (#186179)

**Partially addresses: #184364

## Summary

This PR is a follow-up to [PR
#185748](#185748) and it converts
the remaining `DetectionRulesClient` methods to return `RuleResponse`.

Changes in this PR:
- These methods now return `RuleResponse` instead of internal
`RuleAlertType` type:
  - `updateRule`
  - `patchRule`
  - `upgradePrebuiltRule`
  - `importRule`
  • Loading branch information
nikitaindik committed Jun 21, 2024
1 parent 385bb2b commit 55687dd
Show file tree
Hide file tree
Showing 26 changed files with 218 additions and 159 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import type { SecuritySolutionPluginRouter } from '../../../../../types';
import { buildRouteValidation } from '../../../../../utils/build_validation/route_validation';
import type { PromisePoolError } from '../../../../../utils/promise_pool';
import { buildSiemResponse } from '../../../routes/utils';
import { internalRuleToAPIResponse } from '../../../rule_management/normalization/rule_converters';
import { aggregatePrebuiltRuleErrors } from '../../logic/aggregate_prebuilt_rule_errors';
import { performTimelinesInstallation } from '../../logic/perform_timelines_installation';
import { createPrebuiltRuleAssetsClient } from '../../logic/rule_assets/prebuilt_rule_assets_client';
Expand Down Expand Up @@ -182,7 +181,7 @@ export const performRuleUpgradeRoute = (router: SecuritySolutionPluginRouter) =>
failed: ruleErrors.length,
},
results: {
updated: updatedRules.map(({ result }) => internalRuleToAPIResponse(result)),
updated: updatedRules.map(({ result }) => result),
skipped: skippedRules,
},
errors: allErrors,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import {
typicalMlRulePayload,
} from '../../../../routes/__mocks__/request_responses';
import { serverMock, requestContextMock, requestMock } from '../../../../routes/__mocks__';
import {
getRulesSchemaMock,
getRulesMlSchemaMock,
} from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';
import { bulkPatchRulesRoute } from './route';
import { getCreateRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/mocks';
import { getMlRuleParams, getQueryRuleParams } from '../../../../rule_schema/mocks';
Expand All @@ -34,7 +38,7 @@ describe('Bulk patch rules route', () => {

clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit()); // rule exists
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams())); // update succeeds
clients.detectionRulesClient.patchRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.patchRule.mockResolvedValue(getRulesSchemaMock());

bulkPatchRulesRoute(server.router, logger);
});
Expand Down Expand Up @@ -72,14 +76,11 @@ describe('Bulk patch rules route', () => {
...getFindResultWithSingleHit(),
data: [getRuleMock(getMlRuleParams())],
});
clients.detectionRulesClient.patchRule.mockResolvedValueOnce(
getRuleMock(
getMlRuleParams({
anomalyThreshold,
machineLearningJobId: [machineLearningJobId],
})
)
);
clients.detectionRulesClient.patchRule.mockResolvedValueOnce({
...getRulesMlSchemaMock(),
anomaly_threshold: anomalyThreshold,
machine_learning_job_id: [machineLearningJobId],
});

const request = requestMock.create({
method: 'patch',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import {
import type { SecuritySolutionPluginRouter } from '../../../../../../types';
import { transformBulkError, buildSiemResponse } from '../../../../routes/utils';
import { getIdBulkError } from '../../../utils/utils';
import { transformValidateBulkError } from '../../../utils/validate';
import { readRules } from '../../../logic/detection_rules_client/read_rules';
import { getDeprecatedBulkEndpointHeader, logDeprecatedBulkEndpoint } from '../../deprecation';
import { validateRuleDefaultExceptionList } from '../../../logic/exceptions/validate_rule_default_exception_list';
Expand Down Expand Up @@ -86,11 +85,11 @@ export const bulkPatchRulesRoute = (router: SecuritySolutionPluginRouter, logger
ruleId: payloadRule.id,
});

const rule = await detectionRulesClient.patchRule({
const patchedRule = await detectionRulesClient.patchRule({
nextParams: payloadRule,
});

return transformValidateBulkError(rule.id, rule);
return patchedRule;
} catch (err) {
return transformBulkError(idOrRuleIdOrUnknown, err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
typicalMlRulePayload,
} from '../../../../routes/__mocks__/request_responses';
import { serverMock, requestContextMock, requestMock } from '../../../../routes/__mocks__';
import { getRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';
import { bulkUpdateRulesRoute } from './route';
import type { BulkError } from '../../../../routes/utils';
import { getCreateRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/mocks';
Expand All @@ -32,7 +33,7 @@ describe('Bulk update rules route', () => {

clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit());
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.updateRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.updateRule.mockResolvedValue(getRulesSchemaMock());
clients.appClient.getSignalsIndex.mockReturnValue('.siem-signals-test-index');

bulkUpdateRulesRoute(server.router, logger);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import {
import type { SecuritySolutionPluginRouter } from '../../../../../../types';
import { DETECTION_ENGINE_RULES_BULK_UPDATE } from '../../../../../../../common/constants';
import { getIdBulkError } from '../../../utils/utils';
import { transformValidateBulkError } from '../../../utils/validate';
import {
transformBulkError,
buildSiemResponse,
Expand Down Expand Up @@ -97,11 +96,11 @@ export const bulkUpdateRulesRoute = (router: SecuritySolutionPluginRouter, logge
ruleId: payloadRule.id,
});

const rule = await detectionRulesClient.updateRule({
const updatedRule = await detectionRulesClient.updateRule({
ruleUpdate: payloadRule,
});

return transformValidateBulkError(rule.id, rule);
return updatedRule;
} catch (err) {
return transformBulkError(idOrRuleIdOrUnknown, err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
ruleIdsToNdJsonString,
rulesToNdJsonString,
} from '../../../../../../../common/api/detection_engine/rule_management/mocks';
import { getRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';

import type { requestMock } from '../../../../routes/__mocks__';
import { createMockConfig, requestContextMock, serverMock } from '../../../../routes/__mocks__';
Expand Down Expand Up @@ -47,7 +48,8 @@ describe('Import rules route', () => {

clients.rulesClient.find.mockResolvedValue(getEmptyFindResult()); // no extant rules
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.importRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.createCustomRule.mockResolvedValue(getRulesSchemaMock());
clients.detectionRulesClient.importRule.mockResolvedValue(getRulesSchemaMock());
clients.actionsClient.getAll.mockResolvedValue([]);
context.core.elasticsearch.client.asCurrentUser.search.mockResolvedValue(
elasticsearchClientMock.createSuccessTransportRequestPromise(getBasicEmptySearchResponse())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ import {

import { getMlRuleParams, getQueryRuleParams } from '../../../../rule_schema/mocks';

import {
getRulesSchemaMock,
getRulesMlSchemaMock,
} from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';

import { patchRuleRoute } from './route';
import { HttpAuthzError } from '../../../../../machine_learning/validation';

Expand All @@ -34,7 +39,7 @@ describe('Patch rule route', () => {
clients.rulesClient.get.mockResolvedValue(getRuleMock(getQueryRuleParams())); // existing rule
clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit()); // existing rule
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams())); // successful update
clients.detectionRulesClient.patchRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.patchRule.mockResolvedValue(getRulesSchemaMock());

patchRuleRoute(server.router);
});
Expand Down Expand Up @@ -99,14 +104,11 @@ describe('Patch rule route', () => {

const anomalyThreshold = 4;
const machineLearningJobId = 'some_job_id';
clients.detectionRulesClient.patchRule.mockResolvedValueOnce(
getRuleMock(
getMlRuleParams({
anomalyThreshold,
machineLearningJobId: [machineLearningJobId],
})
)
);
clients.detectionRulesClient.patchRule.mockResolvedValueOnce({
...getRulesMlSchemaMock(),
anomaly_threshold: anomalyThreshold,
machine_learning_job_id: [machineLearningJobId],
});

const request = requestMock.create({
method: 'patch',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import { readRules } from '../../../logic/detection_rules_client/read_rules';
import { checkDefaultRuleExceptionListReferences } from '../../../logic/exceptions/check_for_default_rule_exception_list';
import { validateRuleDefaultExceptionList } from '../../../logic/exceptions/validate_rule_default_exception_list';
import { getIdError } from '../../../utils/utils';
import { transformValidate } from '../../../utils/validate';

export const patchRuleRoute = (router: SecuritySolutionPluginRouter) => {
router.versioned
Expand Down Expand Up @@ -76,12 +75,12 @@ export const patchRuleRoute = (router: SecuritySolutionPluginRouter) => {
ruleId: params.id,
});

const rule = await detectionRulesClient.patchRule({
const patchedRule = await detectionRulesClient.patchRule({
nextParams: params,
});

return response.ok({
body: transformValidate(rule),
body: patchedRule,
});
} catch (err) {
const error = transformError(err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
typicalMlRulePayload,
} from '../../../../routes/__mocks__/request_responses';
import { requestContextMock, serverMock, requestMock } from '../../../../routes/__mocks__';
import { getRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';
import { DETECTION_ENGINE_RULES_URL } from '../../../../../../../common/constants';
import { updateRuleRoute } from './route';
import {
Expand All @@ -34,7 +35,7 @@ describe('Update rule route', () => {
clients.rulesClient.get.mockResolvedValue(getRuleMock(getQueryRuleParams())); // existing rule
clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit()); // rule exists
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams())); // successful update
clients.detectionRulesClient.updateRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.updateRule.mockResolvedValue(getRulesSchemaMock());
clients.appClient.getSignalsIndex.mockReturnValue('.siem-signals-test-index');

updateRuleRoute(server.router);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { readRules } from '../../../logic/detection_rules_client/read_rules';
import { checkDefaultRuleExceptionListReferences } from '../../../logic/exceptions/check_for_default_rule_exception_list';
import { validateRuleDefaultExceptionList } from '../../../logic/exceptions/validate_rule_default_exception_list';
import { getIdError } from '../../../utils/utils';
import { transformValidate, validateResponseActionsPermissions } from '../../../utils/validate';
import { validateResponseActionsPermissions } from '../../../utils/validate';

export const updateRuleRoute = (router: SecuritySolutionPluginRouter) => {
router.versioned
Expand Down Expand Up @@ -80,12 +80,12 @@ export const updateRuleRoute = (router: SecuritySolutionPluginRouter) => {
existingRule
);

const rule = await detectionRulesClient.updateRule({
const updatedRule = await detectionRulesClient.updateRule({
ruleUpdate: request.body,
});

return response.ok({
body: transformValidate(rule),
body: updatedRule,
});
} catch (err) {
const error = transformError(err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import { buildMlAuthz } from '../../../../machine_learning/authz';
import { throwAuthzError } from '../../../../machine_learning/validation';
import { createDetectionRulesClient } from './detection_rules_client';
import type { IDetectionRulesClient } from './detection_rules_client_interface';
import { RuleResponseValidationError } from './utils';
import type { RuleAlertType } from '../../../rule_schema';

jest.mock('../../../../machine_learning/authz');
jest.mock('../../../../machine_learning/validation');
Expand Down Expand Up @@ -70,20 +68,6 @@ describe('DetectionRulesClient.createCustomRule', () => {
expect(rulesClient.create).not.toHaveBeenCalled();
});

it('throws if RuleResponse validation fails', async () => {
const internalRuleMock: RuleAlertType = getRuleMock({
...getQueryRuleParams(),
/* Casting as 'query' suppress to TS error */
type: 'fake-non-existent-type' as 'query',
});

rulesClient.create.mockResolvedValueOnce(internalRuleMock);

await expect(
detectionRulesClient.createCustomRule({ params: getCreateMachineLearningRulesSchemaMock() })
).rejects.toThrow(RuleResponseValidationError);
});

it('calls the rulesClient with legacy ML params', async () => {
await detectionRulesClient.createCustomRule({
params: getCreateMachineLearningRulesSchemaMock(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ describe('DetectionRulesClient.importRule', () => {

beforeEach(() => {
rulesClient = rulesClientMock.create();
rulesClient.create.mockResolvedValue(getRuleMock(getQueryRuleParams()));
rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams()));
detectionRulesClient = createDetectionRulesClient(rulesClient, mlAuthz);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import type { RulesClient } from '@kbn/alerting-plugin/server';
import type { MlAuthz } from '../../../../machine_learning/authz';

import type { RuleAlertType } from '../../../rule_schema';
import type { RuleResponse } from '../../../../../../common/api/detection_engine/model/rule_schema';
import type {
IDetectionRulesClient,
Expand Down Expand Up @@ -47,13 +46,13 @@ export const createDetectionRulesClient = (
});
},

async updateRule(args: UpdateRuleArgs): Promise<RuleAlertType> {
async updateRule(args: UpdateRuleArgs): Promise<RuleResponse> {
return withSecuritySpan('DetectionRulesClient.updateRule', async () => {
return updateRule(rulesClient, args, mlAuthz);
});
},

async patchRule(args: PatchRuleArgs): Promise<RuleAlertType> {
async patchRule(args: PatchRuleArgs): Promise<RuleResponse> {
return withSecuritySpan('DetectionRulesClient.patchRule', async () => {
return patchRule(rulesClient, args, mlAuthz);
});
Expand All @@ -65,13 +64,13 @@ export const createDetectionRulesClient = (
});
},

async upgradePrebuiltRule(args: UpgradePrebuiltRuleArgs): Promise<RuleAlertType> {
async upgradePrebuiltRule(args: UpgradePrebuiltRuleArgs): Promise<RuleResponse> {
return withSecuritySpan('DetectionRulesClient.upgradePrebuiltRule', async () => {
return upgradePrebuiltRule(rulesClient, args, mlAuthz);
});
},

async importRule(args: ImportRuleArgs): Promise<RuleAlertType> {
async importRule(args: ImportRuleArgs): Promise<RuleResponse> {
return withSecuritySpan('DetectionRulesClient.importRule', async () => {
return importRule(rulesClient, args, mlAuthz);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ describe('DetectionRulesClient.upgradePrebuiltRule', () => {
ruleId: 'rule-id',
});
beforeEach(() => {
jest.resetAllMocks();
rulesClient.create.mockResolvedValue(getRuleMock(getQueryRuleParams()));
(readRules as jest.Mock).mockResolvedValue(installedRule);
});

it('deletes the old rule ', async () => {
it('deletes the old rule', async () => {
await detectionRulesClient.upgradePrebuiltRule({ ruleAsset });
expect(rulesClient.delete).toHaveBeenCalled();
});
Expand Down Expand Up @@ -153,6 +155,8 @@ describe('DetectionRulesClient.upgradePrebuiltRule', () => {
});

it('patches the existing rule with the new params from the rule asset', async () => {
rulesClient.update.mockResolvedValue(getRuleMock(getEqlRuleParams()));

await detectionRulesClient.upgradePrebuiltRule({ ruleAsset });
expect(rulesClient.update).toHaveBeenCalledWith(
expect.objectContaining({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@ import type {
RuleToImport,
RuleResponse,
} from '../../../../../../common/api/detection_engine';
import type { RuleAlertType } from '../../../rule_schema';
import type { PrebuiltRuleAsset } from '../../../prebuilt_rules';

export interface IDetectionRulesClient {
createCustomRule: (args: CreateCustomRuleArgs) => Promise<RuleResponse>;
createPrebuiltRule: (args: CreatePrebuiltRuleArgs) => Promise<RuleResponse>;
updateRule: (args: UpdateRuleArgs) => Promise<RuleAlertType>;
patchRule: (args: PatchRuleArgs) => Promise<RuleAlertType>;
updateRule: (args: UpdateRuleArgs) => Promise<RuleResponse>;
patchRule: (args: PatchRuleArgs) => Promise<RuleResponse>;
deleteRule: (args: DeleteRuleArgs) => Promise<void>;
upgradePrebuiltRule: (args: UpgradePrebuiltRuleArgs) => Promise<RuleAlertType>;
importRule: (args: ImportRuleArgs) => Promise<RuleAlertType>;
upgradePrebuiltRule: (args: UpgradePrebuiltRuleArgs) => Promise<RuleResponse>;
importRule: (args: ImportRuleArgs) => Promise<RuleResponse>;
}

export interface CreateCustomRuleArgs {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import type { CreateCustomRuleArgs } from '../detection_rules_client_interface';
import type { MlAuthz } from '../../../../../machine_learning/authz';
import type { RuleParams } from '../../../../rule_schema';
import { RuleResponse } from '../../../../../../../common/api/detection_engine/model/rule_schema';
import { convertCreateAPIToInternalSchema } from '../../../normalization/rule_converters';
import { transform } from '../../../utils/utils';
import {
convertCreateAPIToInternalSchema,
internalRuleToAPIResponse,
} from '../../../normalization/rule_converters';
import { validateMlAuth, RuleResponseValidationError } from '../utils';

export const createCustomRule = async (
Expand All @@ -29,7 +31,7 @@ export const createCustomRule = async (
});

/* Trying to convert the rule to a RuleResponse object */
const parseResult = RuleResponse.safeParse(transform(rule));
const parseResult = RuleResponse.safeParse(internalRuleToAPIResponse(rule));

if (!parseResult.success) {
throw new RuleResponseValidationError({
Expand Down
Loading

0 comments on commit 55687dd

Please sign in to comment.