Skip to content

Commit

Permalink
Merge branch 'main' into gaugup/SerializationAPIs
Browse files Browse the repository at this point in the history
  • Loading branch information
gaugup committed Apr 1, 2022
2 parents 59138b3 + d3f2bac commit 8be1826
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 10 deletions.
3 changes: 3 additions & 0 deletions apps/widget-e2e/src/describer/modelAssessment/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ export enum Locators {
WhatIfSetValueButton = "#CounterfactualPanel button:contains('Set Value')",
DECRotatedVerticalBox = "#DatasetExplorerChart div[class*='rotatedVerticalBox']", // DEC- Data explorer chart
DECHorizontalAxis = "#DatasetExplorerChart div[class*='horizontalAxis']",
DECHorizontalAxisButton = "#DatasetExplorerChart div[class*='horizontalAxis'] button",
DECChoiceFieldGroup = "#AxisConfigPanel div[class*='ms-ChoiceFieldGroup']",
DECCloseButton = "#AxisConfigPanel button.ms-Panel-closeButton",
DECAxisPanel = "#AxisConfigPanel div.ms-Panel-main",
Expand All @@ -78,6 +79,8 @@ export enum Locators {
DEIndividualDatapoints = "#ChartTypeSelection label:contains('Individual datapoints')",
DEAggregatePlots = "#ChartTypeSelection label:contains('Aggregate plots')",
DEYAxisPoints = "#DatasetExplorerChart g[class^='cartesianlayer'] g[class^='ytick'] text",
DEPoints = "#DatasetExplorerChart .highcharts-scatter-series > path.highcharts-point",
DEPointTooltip = ".highcharts-tooltip",
MSCRotatedVerticalBox = "#OverallMetricChart div[class*='rotatedVerticalBox']", // MSC- Model statistics chart
MSCHorizontalAxis = "#OverallMetricChart div[class*='horizontalAxis']",
CausalAnalysisHeader = "#ModelAssessmentDashboard #causalAnalysisHeader",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export interface IModelAssessmentData {
featureNames?: string[];
cohortDefaultName?: string;
isMulticlass?: boolean;
isRegression?: boolean;
}

export interface IErrorAnalysisData {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

import { ScatterHighchart } from "../../../util/ScatterHighchart";
import { Locators } from "../Constants";
import { IModelAssessmentData } from "../IModelAssessmentData";

import { describeAxisConfigDialog } from "./describeAxisConfigDialog";
Expand All @@ -20,6 +21,30 @@ export function describeIndividualDatapoints(
).click();
props.chart = new ScatterHighchart("#DatasetExplorerChart");
});

if (dataShape.isRegression) {
it("Should have clickable individual datapoints that are positive for regression error", () => {
cy.get(Locators.DEIndividualDatapoints).click();
cy.get(`${Locators.DECohortDropdown} span`).should(
"contain",
dataShape.cohortDefaultName
);
axisSelection("Error");

cy.get(Locators.DEPoints).each((point) => {
cy.wrap(point).trigger("mouseover", { force: true });
cy.get("#DatasetExplorerChart")
.find(Locators.DEPointTooltip)
.then((tooltip) => {
cy.wrap(tooltip).should("contain", "Regression error");
cy.wrap(tooltip).should("not.have.value", "Regression error: -");
});
});
axisSelection("Index");
cy.get(Locators.DEAggregatePlots).click();
});
}

describe.skip("Dataset explorer Chart", () => {
it("should have color label", () => {
cy.get('#DatasetExplorerChart label:contains("Color value")').should(
Expand Down Expand Up @@ -52,3 +77,16 @@ export function describeIndividualDatapoints(
}
});
}

export function axisSelection(label: string): void {
cy.get(Locators.DECHorizontalAxisButton)
.click()
.get(
`#AxisConfigPanel div[class*='ms-ChoiceFieldGroup'] label:contains('${label}')`
)
.click()
.get("#AxisConfigPanel")
.find("button")
.contains("Select")
.click();
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ const modelAssessmentDatasets = {
"age",
"s6"
],
isRegression: true,
modelStatisticsData: {
defaultXAxis: "Error",
defaultXAxisPanelValue: "Error",
Expand Down
1 change: 1 addition & 0 deletions libs/core-ui/src/lib/Interfaces/ICounterfactualData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

export interface ICounterfactualData {
cfs_list: Array<Array<Array<string | number>>>;
errorMessage?: string;
feature_names: string[];
feature_names_including_target: string[];
summary_importance?: number[];
Expand Down
3 changes: 2 additions & 1 deletion libs/core-ui/src/lib/components/LabelWithCallout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { labelWithCalloutStyles } from "./LabelWithCallout.styles";
export interface ILabelWithCalloutProps {
label: string;
calloutTitle: string | undefined;
renderOnNewLayer?: boolean;
type?: "label" | "button";
}
interface ILabelWithCalloutState {
Expand Down Expand Up @@ -62,7 +63,7 @@ export class LabelWithCallout extends React.Component<
)}
{this.state.showCallout && (
<FabricCallout
doNotLayer
doNotLayer={!this.props.renderOnNewLayer}
target={`#${id}`}
setInitialFocus
onDismiss={this.toggleCallout}
Expand Down
5 changes: 3 additions & 2 deletions libs/core-ui/src/lib/util/JointDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,9 @@ export class JointDataset {
modelType: ModelTypes
): void {
if (modelType === ModelTypes.Regression) {
row[JointDataset.RegressionError] =
row[JointDataset.TrueYLabel] - row[JointDataset.PredictedYLabel];
row[JointDataset.RegressionError] = Math.abs(
row[JointDataset.TrueYLabel] - row[JointDataset.PredictedYLabel]
);
return;
}
if (modelType === ModelTypes.Binary) {
Expand Down
11 changes: 10 additions & 1 deletion libs/counterfactuals/src/lib/CounterfactualPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import {
TooltipDelay,
DirectionalHint,
IconButton,
ITooltipProps
ITooltipProps,
MessageBar,
MessageBarType
} from "office-ui-fabric-react";
import React from "react";

Expand Down Expand Up @@ -81,6 +83,13 @@ export class CounterfactualPanel extends React.Component<
onRenderFooterContent={this.renderClose}
>
<Stack tokens={{ childrenGap: "m1" }}>
<Stack.Item>
{this.props.data?.errorMessage && (
<MessageBar messageBarType={MessageBarType.error}>
{this.props.data.errorMessage}
</MessageBar>
)}
</Stack.Item>
<Stack.Item className={classes.counterfactualList}>
<CounterfactualList
selectedIndex={this.props.selectedIndex}
Expand Down
6 changes: 4 additions & 2 deletions libs/localization/src/lib/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1300,8 +1300,10 @@
"CorrectPredictions": "Correct predictions",
"GlobalExplanation": "Aggregate feature importance",
"IncorrectPredictions": "Incorrect predictions",
"IndividualFeature": "Select a datapoint by clicking on a datapoint (or multiple datapoints) in the table to view their local feature importance values (local explanation) and individual conditional expectation (ICE) plot below. For datasets with more than 5000 datapoints, the view is a random subsample to enable easy exploration.",
"LocalExplanation": "Individual feature importance"
"IndividualFeature": "Select a datapoint by clicking on a datapoint (up to 5 datapoints) in the table to view their local feature importance values (local explanation) and individual conditional expectation (ICE) plot below. For datasets with more than 5000 datapoints, the view is a random subsample to enable easy exploration.",
"LocalExplanation": "Individual feature importance",
"SelectionCounter": "{0}/{1} datapoints selected",
"SelectionLimit": "Up to 5 datapoints can be selected at this time."
},
"MainMenu": {
"DashboardSettings": "Dashboard configuration",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
IStyle,
mergeStyleSets,
IProcessedStyleSet,
getTheme
} from "office-ui-fabric-react";

export interface IFeatureImportanceStyles {
chevronButton: IStyle;
header: IStyle;
headerCount: IStyle;
headerTitle: IStyle;
selectionCounter: IStyle;
}

export const individualFeatureImportanceViewStyles: () => IProcessedStyleSet<IFeatureImportanceStyles> =
() => {
const theme = getTheme();
return mergeStyleSets({
chevronButton: {
marginLeft: 48,
paddingTop: 6,
width: 36
},
header: {
margin: `8px 0`,
padding: 8,
// Overlay the sizer bars
position: "relative",
zIndex: 100
},
headerCount: [
"headerCount",
theme.fonts.medium,
{
paddingTop: 4
}
],
headerTitle: [
theme.fonts.medium,
{
paddingTop: 4
}
],
selectionCounter: {
paddingTop: 12
}
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import {
FabricStyles,
constructRows,
constructCols,
ModelTypes
ModelTypes,
LabelWithCallout
} from "@responsible-ai/core-ui";
import { IGlobalSeries, LocalImportancePlots } from "@responsible-ai/interpret";
import { localization } from "@responsible-ai/localization";
Expand All @@ -34,10 +35,14 @@ import {
TooltipHost,
IColumn,
IGroup,
Text
Text,
IDetailsGroupDividerProps,
Icon
} from "office-ui-fabric-react";
import React from "react";

import { individualFeatureImportanceViewStyles } from "./IndividualFeatureImportanceView.styles";

export interface IIndividualFeatureImportanceProps {
features: string[];
jointDataset: JointDataset;
Expand All @@ -59,6 +64,8 @@ export interface IIndividualFeatureImportanceTableState {
export interface IIndividualFeatureImportanceState
extends IIndividualFeatureImportanceTableState {
featureImportances: IGlobalSeries[];
indexToUnselect?: number;
selectedIndices: number[];
sortArray: number[];
sortingSeriesIndex?: number;
}
Expand All @@ -70,9 +77,22 @@ export class IndividualFeatureImportanceView extends React.Component<
public static contextType = ModelAssessmentContext;
public context: React.ContextType<typeof ModelAssessmentContext> =
defaultModelAssessmentContext;
private readonly maxSelectable = 5;

private selection: Selection = new Selection({
onSelectionChanged: (): void => {
const c = this.selection.getSelectedCount();
const indices = this.selection.getSelectedIndices();
if (c === this.maxSelectable) {
this.setState({ selectedIndices: indices });
}
if (c > this.maxSelectable) {
for (const index of indices) {
if (!this.state.selectedIndices.includes(index)) {
this.setState({ indexToUnselect: index });
}
}
}
this.updateViewedFeatureImportances();
}
});
Expand All @@ -84,6 +104,8 @@ export class IndividualFeatureImportanceView extends React.Component<

this.state = {
featureImportances: [],
indexToUnselect: undefined,
selectedIndices: [],
sortArray: [],
...tableState
};
Expand All @@ -95,6 +117,10 @@ export class IndividualFeatureImportanceView extends React.Component<
if (this.props.selectedCohort !== prevProps.selectedCohort) {
this.setState(this.updateItems());
}
if (this.state.indexToUnselect) {
this.selection.toggleIndexSelected(this.state.indexToUnselect);
this.setState({ indexToUnselect: undefined });
}
}

public render(): React.ReactNode {
Expand Down Expand Up @@ -132,6 +158,7 @@ export class IndividualFeatureImportanceView extends React.Component<
text: meta.abbridgedLabel
};
});
const classNames = individualFeatureImportanceViewStyles();

return (
<Stack tokens={{ padding: "l1" }}>
Expand All @@ -140,6 +167,22 @@ export class IndividualFeatureImportanceView extends React.Component<
{localization.ModelAssessment.FeatureImportances.IndividualFeature}
</Text>
</Stack.Item>
<Stack.Item className={classNames.selectionCounter}>
<LabelWithCallout
label={localization.formatString(
localization.ModelAssessment.FeatureImportances.SelectionCounter,
this.selection.count,
this.maxSelectable
)}
calloutTitle={undefined}
renderOnNewLayer
type="label"
>
<Text block>
{localization.ModelAssessment.FeatureImportances.SelectionLimit}
</Text>
</LabelWithCallout>
</Stack.Item>
<Stack.Item className="tabularDataView">
<div style={{ height: "500px", position: "relative" }}>
<Fabric>
Expand All @@ -155,10 +198,12 @@ export class IndividualFeatureImportanceView extends React.Component<
onRenderDetailsHeader={this.onRenderDetailsHeader}
selectionPreservedOnEmptyClick
ariaLabelForSelectionColumn="Toggle selection"
ariaLabelForSelectAllCheckbox="Toggle selection for all items"
checkButtonAriaLabel="Row checkbox"
// checkButtonGroupAriaLabel="Group checkbox"
groupProps={{ showEmptyGroups: true }}
groupProps={{
onRenderHeader: this._onRenderGroupHeader,
showEmptyGroups: true
}}
selectionMode={SelectionMode.multiple}
selection={this.selection}
/>
Expand Down Expand Up @@ -329,4 +374,32 @@ export class IndividualFeatureImportanceView extends React.Component<
</div>
);
};

private _onRenderGroupHeader = (props?: IDetailsGroupDividerProps) => {
const classNames = individualFeatureImportanceViewStyles();
const iconName = props?.group?.isCollapsed
? "ChevronRightMed"
: "ChevronDownMed";
return (
<Stack className={classNames.header} horizontal>
<Icon
ariaLabel="expand collapse group"
className={classNames.chevronButton}
iconName={iconName}
onClick={this._onToggleCollapse(props)}
/>
<span className={classNames.headerTitle}>{props?.group!.name}</span>
&nbsp;
<span className={classNames.headerCount}>
{`(${props?.group!.count})`}
</span>
</Stack>
);
};

private _onToggleCollapse = (props?: IDetailsGroupDividerProps) => {
return () => {
props!.onToggleCollapse!(props!.group!);
};
};
}

0 comments on commit 8be1826

Please sign in to comment.