From 4da12d7b3738a7360df877223669cf1f2f924cb0 Mon Sep 17 00:00:00 2001 From: Quynh Nguyen Date: Thu, 9 Feb 2023 15:29:40 -0600 Subject: [PATCH 1/3] Add fix, add test --- .../decision_path_classification.tsx | 11 +- .../use_classification_path_data.test.tsx | 325 ++++++++++++------ .../use_classification_path_data.tsx | 4 + 3 files changed, 241 insertions(+), 99 deletions(-) diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/decision_path_classification.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/decision_path_classification.tsx index d10755b32d7a75..45f19ad7b76f7c 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/decision_path_classification.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/decision_path_classification.tsx @@ -21,6 +21,7 @@ import type { } from '../../../../../../../common/types/feature_importance'; import { DecisionPathChart } from './decision_path_chart'; import { MissingDecisionPathCallout } from './missing_decision_path_callout'; +import { TopClass } from '../../../../../../../common/types/feature_importance'; interface ClassificationDecisionPathProps { predictedValue: string | boolean; @@ -42,12 +43,20 @@ export const ClassificationDecisionPath: FC = ( const [currentClass, setCurrentClass] = useState( getStringBasedClassName(topClasses[0].class_name) ); + const selectedClass = topClasses.find( + (t) => getStringBasedClassName(t.class_name) === getStringBasedClassName(currentClass) + ) as TopClass; + const predictedProbabilityForCurrentClass = selectedClass + ? selectedClass.class_probability + : undefined; + const { decisionPathData } = useDecisionPathData({ baseline, featureImportance, predictedValue: currentClass, - predictedProbability, + predictedProbability: predictedProbabilityForCurrentClass, }); + const options = useMemo(() => { const predictionValueStr = getStringBasedClassName(predictedValue); diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/use_classification_path_data.test.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/use_classification_path_data.test.tsx index 70c62294cae009..53ae0daff084a3 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/use_classification_path_data.test.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/use_classification_path_data.test.tsx @@ -10,9 +10,10 @@ import { buildRegressionDecisionPathData, } from './use_classification_path_data'; import type { FeatureImportance } from '../../../../../../../common/types/feature_importance'; +import { roundToDecimalPlace } from '../../../../../formatters/round_to_decimal_place'; describe('buildClassificationDecisionPathData()', () => { - test('should return correct prediction probability for binary classification', () => { + test('returns correct prediction probability for binary classification', () => { const expectedResults = [ { className: 'yes', probability: 0.28564605871278403 }, { className: 'no', probability: 1 - 0.28564605871278403 }, @@ -71,139 +72,170 @@ describe('buildClassificationDecisionPathData()', () => { expect(result).toHaveLength(featureNames.length); expect(featureNames).toContain(result![0][0]); expect(result![0]).toHaveLength(3); + // Top shown result should equal expected probability expect(result![0][2]).toEqual(probability); + // Make sure probability (result[0]) is always less than 1 + expect(result?.every((r) => r[2] <= 1)).toEqual(true); } }); - test('should return correct prediction probability for multiclass classification', () => { - const expectedResults = [{ className: 1, probability: 0.3551929251919077 }]; + test('returns correct prediction probability & accounts for "other" residual probability for binary classification (boolean)', () => { + const expectedResults = [ + { + class_score: 0.1940750725280285, + class_probability: 0.9034630008985833, + // boolean class name should be converted to string 'True'/'False' + class_name: false, + }, + { + class_score: 0.09653699910141661, + class_probability: 0.09653699910141661, + class_name: true, + }, + ]; const baselinesData = { classes: [ { - class_name: 0, - baseline: 0.1845274610161167, + class_name: false, + baseline: 2.418789842558993, }, { - class_name: 1, - baseline: 0.1331813646384272, - }, - { - class_name: 2, - baseline: 0.1603600353308416, + class_name: true, + baseline: -2.418789842558993, }, ], }; const featureImportanceData: FeatureImportance[] = [ { - feature_name: 'AvgTicketPrice', + feature_name: 'DestWeather', classes: [ - { importance: 0.34413545865934353, class_name: 0 }, - { importance: 0.4781222770431657, class_name: 1 }, - { importance: 0.31847802693610877, class_name: 2 }, + { + importance: 0.5555510565764721, + // string class names 'true'/'false' should be converted to string 'True'/'False' + class_name: 'false', + }, + { + importance: -0.5555510565764721, + class_name: 'true', + }, ], }, { - feature_name: 'Cancelled', + feature_name: 'OriginWeather', classes: [ - { importance: 0.0002822015809810556, class_name: 0 }, - { importance: -0.0033337017702255597, class_name: 1 }, - { importance: 0.0020744732163668696, class_name: 2 }, + { + importance: 0.31139248413258486, + class_name: 'false', + }, + { + importance: -0.31139248413258486, + class_name: 'true', + }, ], }, { - feature_name: 'DistanceKilometers', + feature_name: 'OriginAirportID', classes: [ - { importance: 0.028472232240294063, class_name: 0 }, - { importance: 0.04119838646840895, class_name: 1 }, - { importance: 0.0662663363977551, class_name: 2 }, + { + importance: 0.2895740692218651, + class_name: 'false', + }, + { + importance: -0.2895740692218651, + class_name: 'true', + }, + ], + }, + { + feature_name: 'DestAirportID', + classes: [ + { + importance: 0.1297619730881764, + class_name: 'false', + }, + { + importance: -0.1297619730881764, + class_name: 'true', + }, + ], + }, + { + feature_name: 'hour_of_day', + classes: [ + { + importance: -0.10596307272294636, + class_name: 'false', + }, + { + importance: 0.10596307272294636, + class_name: 'true', + }, ], }, ]; const featureNames = featureImportanceData.map((d) => d.feature_name); - for (const { className, probability } of expectedResults) { + for (const { class_name: className, class_probability: probability } of expectedResults) { const result = buildClassificationDecisionPathData({ baselines: baselinesData.classes, featureImportance: featureImportanceData, currentClass: className, + predictedProbability: probability, }); + expect(result).toBeDefined(); - expect(result).toHaveLength(featureNames.length); + // Should add an 'other' field + expect(result).toHaveLength(featureNames.length + 1); expect(featureNames).toContain(result![0][0]); expect(result![0]).toHaveLength(3); + // Top shown result should equal expected probability expect(result![0][2]).toEqual(probability); + // Make sure probability (result[0]) is always less than 1 + expect(result?.every((r) => r[2] <= 1)).toEqual(true); } }); -}); -describe('buildRegressionDecisionPathData()', () => { - test('should return correct decision path', () => { - const predictedValue = 0.008000000000000005; - const baseline = 0.01570748450465414; - const featureImportanceData: FeatureImportance[] = [ - { feature_name: 'g1', importance: -0.01171550599313763 }, - { feature_name: 'tau4', importance: -0.01190799086101345 }, - ]; - const expectedFeatures = [ - ...featureImportanceData.map((d) => d.feature_name), - 'other', - 'baseline', - ]; - const result = buildRegressionDecisionPathData({ - baseline, - featureImportance: featureImportanceData, - predictedValue: 0.008, - }); - expect(result).toBeDefined(); - expect(result).toHaveLength(expectedFeatures.length); - expect(result![0]).toHaveLength(3); - expect(result![0][2]).toEqual(predictedValue); - }); - - test('buildClassificationDecisionPathData() should return correct prediction probability for binary classification', () => { - const expectedResults = [ - { className: 'yes', probability: 0.28564605871278403 }, - { className: 'no', probability: 1 - 0.28564605871278403 }, - ]; + test('returns correct prediction probability for multiclass classification', () => { + const expectedResults = [{ className: 1, probability: 0.3551929251919077 }]; const baselinesData = { classes: [ { - class_name: 'no', - baseline: 3.228256450715653, + class_name: 0, + baseline: 0.1845274610161167, }, { - class_name: 'yes', - baseline: -3.228256450715653, + class_name: 1, + baseline: 0.1331813646384272, + }, + { + class_name: 2, + baseline: 0.1603600353308416, }, ], }; const featureImportanceData: FeatureImportance[] = [ { - feature_name: 'duration', - classes: [ - { importance: 2.9932577725789455, class_name: 'yes' }, - { importance: -2.9932577725789455, class_name: 'no' }, - ], - }, - { - feature_name: 'job', + feature_name: 'AvgTicketPrice', classes: [ - { importance: -0.8023759403354496, class_name: 'yes' }, - { importance: 0.8023759403354496, class_name: 'no' }, + { importance: 0.34413545865934353, class_name: 0 }, + { importance: 0.4781222770431657, class_name: 1 }, + { importance: 0.31847802693610877, class_name: 2 }, ], }, { - feature_name: 'poutcome', + feature_name: 'Cancelled', classes: [ - { importance: 0.43319318839128396, class_name: 'yes' }, - { importance: -0.43319318839128396, class_name: 'no' }, + { importance: 0.0002822015809810556, class_name: 0 }, + { importance: -0.0033337017702255597, class_name: 1 }, + { importance: 0.0020744732163668696, class_name: 2 }, ], }, { - feature_name: 'housing', + feature_name: 'DistanceKilometers', classes: [ - { importance: -0.3124436380550531, class_name: 'yes' }, - { importance: 0.3124436380550531, class_name: 'no' }, + { importance: 0.028472232240294063, class_name: 0 }, + { importance: 0.04119838646840895, class_name: 1 }, + { importance: 0.0662663363977551, class_name: 2 }, ], }, ]; @@ -219,67 +251,164 @@ describe('buildRegressionDecisionPathData()', () => { expect(result).toHaveLength(featureNames.length); expect(featureNames).toContain(result![0][0]); expect(result![0]).toHaveLength(3); + // Top shown result should equal expected probability expect(result![0][2]).toEqual(probability); + // Make sure probability (result[0]) is always less than 1 + expect(result?.every((r) => r[2] <= 1)).toEqual(true); } }); - test('buildClassificationDecisionPathData() should return correct prediction probability for multiclass classification', () => { - const expectedResults = [{ className: 1, probability: 0.3551929251919077 }]; + test('returns correct prediction probability for multiclass classification with "other"', () => { + const expectedResults = [ + { + class_score: 0.2653792729907741, + class_probability: 0.995901728296372, + class_name: 'Iris-setosa', + }, + { + class_score: 0.002499393297421585, + class_probability: 0.002499393297421585, + class_name: 'Iris-versicolor', + }, + { + class_score: 0.0015399995493349922, + class_probability: 0.0015988784062062893, + class_name: 'Iris-virginica', + }, + ]; const baselinesData = { classes: [ { - class_name: 0, - baseline: 0.1845274610161167, + class_name: 'Iris-setosa', + baseline: -0.25145851617108084, }, { - class_name: 1, - baseline: 0.1331813646384272, + class_name: 'Iris-versicolor', + baseline: 0.46014588263093625, }, { - class_name: 2, - baseline: 0.1603600353308416, + class_name: 'Iris-virginica', + baseline: -0.20868736645984168, }, ], }; const featureImportanceData: FeatureImportance[] = [ { - feature_name: 'AvgTicketPrice', + feature_name: 'petal_length', classes: [ - { importance: 0.34413545865934353, class_name: 0 }, - { importance: 0.4781222770431657, class_name: 1 }, - { importance: 0.31847802693610877, class_name: 2 }, + { + importance: 2.4826228835057464, + class_name: 'Iris-setosa', + }, + { + importance: -0.5861671310095675, + class_name: 'Iris-versicolor', + }, + { + importance: -1.8964557524961734, + class_name: 'Iris-virginica', + }, ], }, { - feature_name: 'Cancelled', + feature_name: 'petal_width', classes: [ - { importance: 0.0002822015809810556, class_name: 0 }, - { importance: -0.0033337017702255597, class_name: 1 }, - { importance: 0.0020744732163668696, class_name: 2 }, + { + importance: 1.4568820749127243, + class_name: 'Iris-setosa', + }, + { + importance: -0.9431104132306853, + class_name: 'Iris-versicolor', + }, + { + importance: -0.5137716616820365, + class_name: 'Iris-virginica', + }, ], }, { - feature_name: 'DistanceKilometers', + feature_name: 'sepal_width', classes: [ - { importance: 0.028472232240294063, class_name: 0 }, - { importance: 0.04119838646840895, class_name: 1 }, - { importance: 0.0662663363977551, class_name: 2 }, + { + importance: 0.3508206289936615, + class_name: 'Iris-setosa', + }, + { + importance: 0.023074695691663594, + class_name: 'Iris-versicolor', + }, + { + importance: -0.3738953246853245, + class_name: 'Iris-virginica', + }, + ], + }, + { + feature_name: 'sepal_length', + classes: [ + { + importance: -0.027900272907686156, + class_name: 'Iris-setosa', + }, + { + importance: 0.13376776004064217, + class_name: 'Iris-versicolor', + }, + { + importance: -0.1058674871329558, + class_name: 'Iris-virginica', + }, ], }, ]; const featureNames = featureImportanceData.map((d) => d.feature_name); - for (const { className, probability } of expectedResults) { + for (const { + class_name: className, + class_probability: classPredictedProbability, + } of expectedResults) { const result = buildClassificationDecisionPathData({ baselines: baselinesData.classes, featureImportance: featureImportanceData, currentClass: className, + predictedProbability: classPredictedProbability, }); expect(result).toBeDefined(); - expect(result).toHaveLength(featureNames.length); + // Result accounts for 'other' or residual importance + expect(result).toHaveLength(featureNames.length + 1); expect(featureNames).toContain(result![0][0]); expect(result![0]).toHaveLength(3); - expect(result![0][2]).toEqual(probability); + expect(roundToDecimalPlace(result![0][2], 3)).toEqual( + roundToDecimalPlace(classPredictedProbability, 3) + ); + // Make sure probability (result[0]) is always less than 1 + expect(result?.every((r) => r[2] <= 1)).toEqual(true); } }); }); +describe('buildRegressionDecisionPathData()', () => { + test('returns correct decision path', () => { + const predictedValue = 0.008000000000000005; + const baseline = 0.01570748450465414; + const featureImportanceData: FeatureImportance[] = [ + { feature_name: 'g1', importance: -0.01171550599313763 }, + { feature_name: 'tau4', importance: -0.01190799086101345 }, + ]; + const expectedFeatures = [ + ...featureImportanceData.map((d) => d.feature_name), + 'other', + 'baseline', + ]; + + const result = buildRegressionDecisionPathData({ + baseline, + featureImportance: featureImportanceData, + predictedValue: 0.008, + }); + expect(result).toBeDefined(); + expect(result).toHaveLength(expectedFeatures.length); + expect(result![0]).toHaveLength(3); + expect(result![0][2]).toEqual(predictedValue); + }); +}); diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/use_classification_path_data.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/use_classification_path_data.tsx index ad9f0b3d0bb712..65954578db6ed1 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/use_classification_path_data.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/feature_importance/use_classification_path_data.tsx @@ -58,6 +58,10 @@ export const getStringBasedClassName = (v: string | boolean | undefined | number if (typeof v === 'boolean') { return v ? 'True' : 'False'; } + + if (v === 'true') return 'True'; + if (v === 'false') return 'False'; + if (typeof v === 'number') { return v.toString(); } From 548f768ff9e143b1826bec60122c44c059c31b8e Mon Sep 17 00:00:00 2001 From: Quynh Nguyen Date: Fri, 10 Feb 2023 11:28:03 -0600 Subject: [PATCH 2/3] Fix row index causing popover to not show up on page >=2 --- .../ml/public/application/components/data_grid/data_grid.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx b/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx index b2fceb58edfa46..95a7955ec9a395 100644 --- a/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx +++ b/x-pack/plugins/ml/public/application/components/data_grid/data_grid.tsx @@ -124,7 +124,7 @@ export const DataGrid: FC = memo( analysisType === ANALYSIS_CONFIG_TYPE.OUTLIER_DETECTION ) { if (schema === 'featureImportance') { - const row = data[rowIndex]; + const row = data[rowIndex - pagination.pageIndex * pagination.pageSize]; if (!row) return
; // if resultsField for some reason is not available then use ml const mlResultsField = resultsField ?? DEFAULT_RESULTS_FIELD; From c3f9233e1a8d57a8dc9ec97868ea910a30dd57c7 Mon Sep 17 00:00:00 2001 From: Quynh Nguyen Date: Fri, 10 Feb 2023 12:04:56 -0600 Subject: [PATCH 3/3] Add tests to check for feature importance when pagination is changed --- .../data_frame_analytics/results_view_content.ts | 7 +++++++ .../functional/services/ml/common_data_grid.ts | 15 +++++++++++++++ .../services/ml/data_frame_analytics_results.ts | 4 ++++ 3 files changed, 26 insertions(+) diff --git a/x-pack/test/functional/apps/ml/data_frame_analytics/results_view_content.ts b/x-pack/test/functional/apps/ml/data_frame_analytics/results_view_content.ts index f3e6eac011b730..308ec43593c0ce 100644 --- a/x-pack/test/functional/apps/ml/data_frame_analytics/results_view_content.ts +++ b/x-pack/test/functional/apps/ml/data_frame_analytics/results_view_content.ts @@ -294,6 +294,13 @@ export default function ({ getService }: FtrProviderContext) { await ml.dataFrameAnalyticsResults.assertFeatureImportancePopoverContent(); }); + it('should display the feature importance decision path after changing page', async () => { + await ml.dataFrameAnalyticsResults.selectResultsTablePage(3); + await ml.dataFrameAnalyticsResults.assertResultsTableNotEmpty(); + await ml.dataFrameAnalyticsResults.openFeatureImportancePopover(); + await ml.dataFrameAnalyticsResults.assertFeatureImportancePopoverContent(); + }); + it('should display the histogram charts', async () => { await ml.testExecution.logTestStep( 'displays the histogram charts when option is enabled' diff --git a/x-pack/test/functional/services/ml/common_data_grid.ts b/x-pack/test/functional/services/ml/common_data_grid.ts index f118af7090b437..9950d6b8f72056 100644 --- a/x-pack/test/functional/services/ml/common_data_grid.ts +++ b/x-pack/test/functional/services/ml/common_data_grid.ts @@ -219,5 +219,20 @@ export function MachineLearningCommonDataGridProvider({ getService }: FtrProvide await browser.pressKeys(browser.keys.ESCAPE); }); }, + + async assertActivePage(tableSubj: string, expectedPage: number) { + const table = await testSubjects.find(tableSubj); + const pagination = await table.findByClassName('euiPagination__list'); + const activePage = await pagination.findByCssSelector( + '.euiPaginationButton[aria-current] .euiButtonEmpty__text' + ); + const text = await activePage.getVisibleText(); + expect(text).to.eql(expectedPage); + }, + + async selectPage(tableSubj: string, page: number) { + await testSubjects.click(`${tableSubj} > pagination-button-${page - 1}`); + await this.assertActivePage(tableSubj, page); + }, }; } diff --git a/x-pack/test/functional/services/ml/data_frame_analytics_results.ts b/x-pack/test/functional/services/ml/data_frame_analytics_results.ts index a4378636673def..0fc99e1e032a1f 100644 --- a/x-pack/test/functional/services/ml/data_frame_analytics_results.ts +++ b/x-pack/test/functional/services/ml/data_frame_analytics_results.ts @@ -57,6 +57,10 @@ export function MachineLearningDataFrameAnalyticsResultsProvider( await testSubjects.existOrFail('mlExplorationDataGrid loaded', { timeout: 5000 }); }, + async selectResultsTablePage(page: number) { + await commonDataGrid.selectPage('mlExplorationDataGrid loaded', page); + }, + async assertResultsTableTrainingFiltersExist() { await testSubjects.existOrFail('mlDFAnalyticsExplorationQueryBarFilterButtons', { timeout: 5000,