Skip to content

Commit

Permalink
take absolute value of error calculation for regression scenario (#1301)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Mar 31, 2022
1 parent a97c552 commit 93a8ba3
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 2 deletions.
3 changes: 3 additions & 0 deletions apps/widget-e2e/src/describer/modelAssessment/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ export enum Locators {
WhatIfSetValueButton = "#CounterfactualPanel button:contains('Set Value')",
DECRotatedVerticalBox = "#DatasetExplorerChart div[class*='rotatedVerticalBox']", // DEC- Data explorer chart
DECHorizontalAxis = "#DatasetExplorerChart div[class*='horizontalAxis']",
DECHorizontalAxisButton = "#DatasetExplorerChart div[class*='horizontalAxis'] button",
DECChoiceFieldGroup = "#AxisConfigPanel div[class*='ms-ChoiceFieldGroup']",
DECCloseButton = "#AxisConfigPanel button.ms-Panel-closeButton",
DECAxisPanel = "#AxisConfigPanel div.ms-Panel-main",
Expand All @@ -78,6 +79,8 @@ export enum Locators {
DEIndividualDatapoints = "#ChartTypeSelection label:contains('Individual datapoints')",
DEAggregatePlots = "#ChartTypeSelection label:contains('Aggregate plots')",
DEYAxisPoints = "#DatasetExplorerChart g[class^='cartesianlayer'] g[class^='ytick'] text",
DEPoints = "#DatasetExplorerChart .highcharts-scatter-series > path.highcharts-point",
DEPointTooltip = ".highcharts-tooltip",
MSCRotatedVerticalBox = "#OverallMetricChart div[class*='rotatedVerticalBox']", // MSC- Model statistics chart
MSCHorizontalAxis = "#OverallMetricChart div[class*='horizontalAxis']",
CausalAnalysisHeader = "#ModelAssessmentDashboard #causalAnalysisHeader",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export interface IModelAssessmentData {
featureNames?: string[];
cohortDefaultName?: string;
isMulticlass?: boolean;
isRegression?: boolean;
}

export interface IErrorAnalysisData {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import { ScatterHighchart } from "../../../util/ScatterHighchart";
import { Locators } from "../Constants";
import { IModelAssessmentData } from "../IModelAssessmentData";

import { describeAxisConfigDialog } from "./describeAxisConfigDialog";
Expand All @@ -20,6 +21,30 @@ export function describeIndividualDatapoints(
).click();
props.chart = new ScatterHighchart("#DatasetExplorerChart");
});

if (dataShape.isRegression) {
it("Should have clickable individual datapoints that are positive for regression error", () => {
cy.get(Locators.DEIndividualDatapoints).click();
cy.get(`${Locators.DECohortDropdown} span`).should(
"contain",
dataShape.cohortDefaultName
);
axisSelection("Error");

cy.get(Locators.DEPoints).each((point) => {
cy.wrap(point).trigger("mouseover", { force: true });
cy.get("#DatasetExplorerChart")
.find(Locators.DEPointTooltip)
.then((tooltip) => {
cy.wrap(tooltip).should("contain", "Regression error");
cy.wrap(tooltip).should("not.have.value", "Regression error: -");
});
});
axisSelection("Index");
cy.get(Locators.DEAggregatePlots).click();
});
}

describe.skip("Dataset explorer Chart", () => {
it("should have color label", () => {
cy.get('#DatasetExplorerChart label:contains("Color value")').should(
Expand Down Expand Up @@ -52,3 +77,16 @@ export function describeIndividualDatapoints(
}
});
}

export function axisSelection(label: string): void {
cy.get(Locators.DECHorizontalAxisButton)
.click()
.get(
`#AxisConfigPanel div[class*='ms-ChoiceFieldGroup'] label:contains('${label}')`
)
.click()
.get("#AxisConfigPanel")
.find("button")
.contains("Select")
.click();
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ const modelAssessmentDatasets = {
"age",
"s6"
],
isRegression: true,
modelStatisticsData: {
defaultXAxis: "Error",
defaultXAxisPanelValue: "Error",
Expand Down
5 changes: 3 additions & 2 deletions libs/core-ui/src/lib/util/JointDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,9 @@ export class JointDataset {
modelType: ModelTypes
): void {
if (modelType === ModelTypes.Regression) {
row[JointDataset.RegressionError] =
row[JointDataset.TrueYLabel] - row[JointDataset.PredictedYLabel];
row[JointDataset.RegressionError] = Math.abs(
row[JointDataset.TrueYLabel] - row[JointDataset.PredictedYLabel]
);
return;
}
if (modelType === ModelTypes.Binary) {
Expand Down

0 comments on commit 93a8ba3

Please sign in to comment.