diff --git a/apps/widget/src/app/ModelAssessment.tsx b/apps/widget/src/app/ModelAssessment.tsx index 3b1c9720dc..493a0273af 100644 --- a/apps/widget/src/app/ModelAssessment.tsx +++ b/apps/widget/src/app/ModelAssessment.tsx @@ -56,16 +56,17 @@ export class ModelAssessment extends React.Component { return callFlaskService(this.props.config, data, "/get_exp"); }; callBack.requestObjectDetectionMetrics = async ( - trueY: number[][][], - predictedY: number[][][], + selectionIndexes: number[][], aggregateMethod: string, className: string, - iouThresh: number + iouThresh: number, + abortSignal: AbortSignal ): Promise => { return callFlaskService( this.props.config, - [trueY, predictedY, aggregateMethod, className, iouThresh], - "/get_object_detection_metrics" + [selectionIndexes, aggregateMethod, className, iouThresh], + "/get_object_detection_metrics", + abortSignal ); }; callBack.requestPredictions = async (data: any[]): Promise => { diff --git a/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx b/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx index d632112b58..23e1a9de90 100644 --- a/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx +++ b/libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx @@ -138,11 +138,11 @@ export interface IModelAssessmentContext { | undefined; requestObjectDetectionMetrics?: | (( - trueY: number[][][], - predictedY: number[][][], + selectionIndexes: number[][], aggregateMethod: string, className: string, - iouThresh: number + iouThresh: number, + abortSignal: AbortSignal ) => Promise) | undefined; requestSplinePlotDistribution?: ( @@ -174,6 +174,7 @@ export const defaultModelAssessmentContext: IModelAssessmentContext = { modelType: undefined, requestExp: undefined, requestLocalFeatureExplanations: undefined, + requestObjectDetectionMetrics: undefined, requestPredictions: undefined, selectedErrorCohort: {} as ErrorCohort, setAsCategorical: () => undefined, diff --git a/libs/core-ui/src/lib/util/JointDataset.ts b/libs/core-ui/src/lib/util/JointDataset.ts index de4ad2a64d..35b76c4692 100644 --- a/libs/core-ui/src/lib/util/JointDataset.ts +++ b/libs/core-ui/src/lib/util/JointDataset.ts @@ -42,6 +42,9 @@ export class JointDataset { public static readonly PredictedYLabel = "PredictedY"; public static readonly ProbabilityYRoot = "ProbabilityClass"; public static readonly TrueYLabel = "TrueY"; + public static readonly ObjectDetectionPredictedYLabel = + "ObjectDetectionPredictedY"; + public static readonly ObjectDetectionTrueYLabel = "ObjectDetectionTrueY"; public static readonly DitherLabel = "Dither"; public static readonly DitherLabel2 = "Dither2"; public static readonly ClassificationError = "ClassificationError"; diff --git a/libs/core-ui/src/lib/util/JointDatasetUtils.ts b/libs/core-ui/src/lib/util/JointDatasetUtils.ts index 2b69c7307c..e59a4e2abd 100644 --- a/libs/core-ui/src/lib/util/JointDatasetUtils.ts +++ b/libs/core-ui/src/lib/util/JointDatasetUtils.ts @@ -23,6 +23,8 @@ export interface IJointDatasetArgs { metadata: IExplanationModelMetadata; featureMetaData?: IFeatureMetaData; targetColumn?: string | string[]; + objectDetectionTrueY?: number[][][]; + objectDetectionPredictedY?: number[][][]; } export enum ColumnCategories { diff --git a/libs/core-ui/src/lib/util/ObjectDetectionStatisticsUtils.ts b/libs/core-ui/src/lib/util/ObjectDetectionStatisticsUtils.ts index 88b3dd283a..0a2d02fb96 100644 --- a/libs/core-ui/src/lib/util/ObjectDetectionStatisticsUtils.ts +++ b/libs/core-ui/src/lib/util/ObjectDetectionStatisticsUtils.ts @@ -8,8 +8,6 @@ import { TotalCohortSamples } from "../Interfaces/IStatistic"; -import { JointDataset } from "./JointDataset"; - export enum ObjectDetectionMetrics { MeanAveragePrecision = "meanAveragePrecision", AveragePrecision = "averagePrecision", @@ -17,41 +15,33 @@ export enum ObjectDetectionMetrics { } export const generateObjectDetectionStats: ( - jointDataset: JointDataset, selectionIndexes: number[][] ) => ILabeledStatistic[][] = ( - jointDataset: JointDataset, selectionIndexes: number[][] ): ILabeledStatistic[][] => { - const numLabels = jointDataset.numLabels; return selectionIndexes.map((selectionArray) => { const count = selectionArray.length; - // TODO: replace placeholder values with flask endpoint calls to python backend. - const meanAveragePrecision = 42; - const averagePrecision = 42; - const averageRecall = 42; - return [ { key: TotalCohortSamples, label: localization.Interpret.Statistics.samples, - stat: count / numLabels // TODO: remove numLabels from here when using jointDataset elsewhere. + stat: count }, { key: ObjectDetectionMetrics.MeanAveragePrecision, label: localization.Interpret.Statistics.meanAveragePrecision, - stat: meanAveragePrecision + stat: Number.NaN }, { key: ObjectDetectionMetrics.AveragePrecision, label: localization.Interpret.Statistics.averagePrecision, - stat: averagePrecision + stat: Number.NaN }, { key: ObjectDetectionMetrics.AverageRecall, label: localization.Interpret.Statistics.averageRecall, - stat: averageRecall + stat: Number.NaN } ]; }); diff --git a/libs/core-ui/src/lib/util/StatisticsUtils.ts b/libs/core-ui/src/lib/util/StatisticsUtils.ts index 3273a0ecc5..af331d2308 100644 --- a/libs/core-ui/src/lib/util/StatisticsUtils.ts +++ b/libs/core-ui/src/lib/util/StatisticsUtils.ts @@ -279,7 +279,7 @@ export const generateMetrics: ( }); } if (modelType === ModelTypes.ObjectDetection) { - return generateObjectDetectionStats(jointDataset, selectionIndexes); + return generateObjectDetectionStats(selectionIndexes); } const outcomes = jointDataset.unwrap(JointDataset.ClassificationError); return selectionIndexes.map((selectionArray) => { diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Interfaces/IVisionExplanationDashboardProps.ts b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Interfaces/IVisionExplanationDashboardProps.ts index 3cb254bb45..d20dece04d 100644 --- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Interfaces/IVisionExplanationDashboardProps.ts +++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Interfaces/IVisionExplanationDashboardProps.ts @@ -17,11 +17,11 @@ export interface IVisionExplanationDashboardProps { abortSignal: AbortSignal ) => Promise; requestObjectDetectionMetrics?: ( - trueY: number[][][], - predictedY: number[][][], + selectionIndexes: number[][], aggregateMethod: string, className: string, - iouThresh: number + iouThresh: number, + abortSignal: AbortSignal ) => Promise; selectedCohort: ErrorCohort; setSelectedCohort: (cohort: ErrorCohort) => void; diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx index bab28125a0..4b49c12944 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ModelOverview.tsx @@ -33,7 +33,8 @@ import { TelemetryLevels, TelemetryEventName, DatasetTaskType, - ImageClassificationMetrics + ImageClassificationMetrics, + TotalCohortSamples } from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; import React from "react"; @@ -52,6 +53,12 @@ import { getSelectableMetrics } from "./StatsTableUtils"; interface IModelOverviewProps { telemetryHook?: (message: ITelemetryEvent) => void; + requestObjectDetectionMetrics?: ( + selectionIndexes: number[][], + aggregateMethod: string, + className: string, + iouThresh: number + ) => Promise; } interface IModelOverviewState { @@ -62,14 +69,17 @@ interface IModelOverviewState { selectedFeatureBasedCohorts?: number[]; chartConfigurationIsVisible: boolean; datasetCohortViewIsVisible: boolean; + aggregateMethod: string; datasetCohortChartIsVisible: boolean; featureConfigurationIsVisible: boolean; metricConfigurationIsVisible: boolean; showHeatmapColors: boolean; datasetCohortLabeledStatistics: ILabeledStatistic[][]; datasetBasedCohorts: ErrorCohort[]; + className: string; featureBasedCohortLabeledStatistics: ILabeledStatistic[][]; featureBasedCohorts: ErrorCohort[]; + iouThresh: number; } const datasetCohortViewPivotKey = "datasetCohortView"; @@ -86,8 +96,12 @@ export class ModelOverview extends React.Component< public constructor(props: IModelOverviewProps) { super(props); + this.state = { + aggregateMethod: + localization.ModelAssessment.ModelOverview.metricTypes.macro, chartConfigurationIsVisible: false, + className: "", datasetBasedCohorts: [], datasetCohortChartIsVisible: true, datasetCohortLabeledStatistics: [], @@ -95,6 +109,7 @@ export class ModelOverview extends React.Component< featureBasedCohortLabeledStatistics: [], featureBasedCohorts: [], featureConfigurationIsVisible: false, + iouThresh: 70, metricConfigurationIsVisible: false, selectedFeatures: [], selectedFeaturesContinuousFeatureBins: {}, @@ -333,6 +348,11 @@ export class ModelOverview extends React.Component< )} @@ -512,12 +532,55 @@ export class ModelOverview extends React.Component< ); } - private updateDatasetCohortStats(): void { + private setAggregateMethod = (value: string): void => { + this.setState({ aggregateMethod: value }, () => { + if (this.state.datasetCohortChartIsVisible) { + this.updateDatasetCohortStats(); + } else { + this.updateFeatureCohortStats(); + } + }); + + this.logButtonClick( + TelemetryEventName.ModelOverviewMetricsSelectionUpdated + ); + }; + + private setClassName = (value: string): void => { + this.setState({ className: value }, () => { + if (this.state.datasetCohortChartIsVisible) { + this.updateDatasetCohortStats(); + } else { + this.updateFeatureCohortStats(); + } + }); + + this.logButtonClick( + TelemetryEventName.ModelOverviewMetricsSelectionUpdated + ); + }; + + private setIoUThreshold = (value: number): void => { + this.setState({ iouThresh: value }, () => { + if (this.state.datasetCohortChartIsVisible) { + this.updateDatasetCohortStats(); + } else { + this.updateFeatureCohortStats(); + } + }); + + this.logButtonClick( + TelemetryEventName.ModelOverviewMetricsSelectionUpdated + ); + }; + + private updateDatasetCohortStats = (): void => { + const selectionIndexes: number[][] = this.context.errorCohorts.map( + (errorCohort) => errorCohort.cohort.unwrap(JointDataset.IndexLabel) + ); const datasetCohortMetricStats = generateMetrics( this.context.jointDataset, - this.context.errorCohorts.map((errorCohort) => - errorCohort.cohort.unwrap(JointDataset.IndexLabel) - ), + selectionIndexes, this.context.modelMetadata.modelType ); @@ -525,9 +588,81 @@ export class ModelOverview extends React.Component< datasetBasedCohorts: this.context.errorCohorts, datasetCohortLabeledStatistics: datasetCohortMetricStats }); + + this.updateObjectDetectionMetrics(selectionIndexes, true); + }; + + private updateObjectDetectionMetrics( + selectionIndexes: number[][], + isDatasetCohort: boolean + ): void { + if ( + this.context.requestObjectDetectionMetrics && + selectionIndexes.length > 0 && + this.state.aggregateMethod.length > 0 && + this.state.className.length > 0 && + this.state.iouThresh + ) { + this.context + .requestObjectDetectionMetrics( + selectionIndexes, + this.state.aggregateMethod, + this.state.className, + this.state.iouThresh, + new AbortController().signal + ) + .then((result) => { + // Assumption: the lengths of `result` and `selectionIndexes` are the same. + const updatedMetricStats: ILabeledStatistic[][] = []; + + for (const [ + cohortIndex, + [meanAveragePrecision, averagePrecision, averageRecall] + ] of result.entries()) { + const count = selectionIndexes[cohortIndex].length; + + const updatedCohortMetricStats = [ + { + key: TotalCohortSamples, + label: localization.Interpret.Statistics.samples, + stat: count + }, + { + key: ObjectDetectionMetrics.MeanAveragePrecision, + label: localization.Interpret.Statistics.meanAveragePrecision, + stat: meanAveragePrecision + }, + { + key: ObjectDetectionMetrics.AveragePrecision, + label: localization.Interpret.Statistics.averagePrecision, + stat: averagePrecision + }, + { + key: ObjectDetectionMetrics.AverageRecall, + label: localization.Interpret.Statistics.averageRecall, + stat: averageRecall + } + ]; + + updatedMetricStats.push(updatedCohortMetricStats); + } + + isDatasetCohort + ? this.updateDatasetCohortState(updatedMetricStats) + : this.updateFeatureCohortState(updatedMetricStats); + }); + } + } + + private updateDatasetCohortState( + cohortMetricStats: ILabeledStatistic[][] + ): void { + this.setState({ + datasetCohortLabeledStatistics: cohortMetricStats + }); } - private async updateFeatureCohortStats(): Promise { + private updateFeatureCohortStats = async (): Promise => { // generate table contents for selected feature cohorts const featureBasedCohorts = generateOverlappingFeatureBasedCohorts( this.context.baseErrorCohort, @@ -537,11 +672,13 @@ export class ModelOverview extends React.Component< this.state.selectedFeaturesContinuousFeatureBins ); + const selectionIndexes: number[][] = featureBasedCohorts.map( + (errorCohort) => errorCohort.cohort.unwrap(JointDataset.IndexLabel) + ); + const featureCohortMetricStats = generateMetrics( this.context.jointDataset, - featureBasedCohorts.map((errorCohort) => - errorCohort.cohort.unwrap(JointDataset.IndexLabel) - ), + selectionIndexes, this.context.modelMetadata.modelType ); @@ -549,6 +686,16 @@ export class ModelOverview extends React.Component< featureBasedCohortLabeledStatistics: featureCohortMetricStats, featureBasedCohorts }); + + this.updateObjectDetectionMetrics(selectionIndexes, false); + }; + + private updateFeatureCohortState( + cohortMetricStats: ILabeledStatistic[][] + ): void { + this.setState({ + featureBasedCohortLabeledStatistics: cohortMetricStats + }); } private ifCohortIndexesEquals(a: number[], b: number[]): boolean { diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ObjectDetectionModelOverview.tsx b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ObjectDetectionModelOverview.tsx index 74b4a708ad..c4d11ee62a 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ObjectDetectionModelOverview.tsx +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/ObjectDetectionModelOverview.tsx @@ -3,12 +3,17 @@ import { ComboBox, + IComboBox, IComboBoxOption, IProcessedStyleSet, Slider, Stack } from "@fluentui/react"; -import { FluentUIStyles, IDataset } from "@responsible-ai/core-ui"; +import { + FluentUIStyles, + IDataset, + ITelemetryEvent +} from "@responsible-ai/core-ui"; import { localization } from "@responsible-ai/localization"; import React from "react"; @@ -44,6 +49,12 @@ export function getSelectableClassNames(dataset: IDataset): IComboBoxOption[] { export interface IObjectDetectionWidgetsProps { classNames: IProcessedStyleSet; dataset: IDataset; + setAggregateMethod: (value: string) => void; + setClassName: (value: string) => void; + setIoUThreshold: (value: number) => void; + updateDatasetCohortStats: () => void; + updateFeatureCohortStats: () => Promise; + telemetryHook?: (message: ITelemetryEvent) => void; } export class ObjectDetectionWidgets extends React.PureComponent { @@ -53,8 +64,9 @@ export class ObjectDetectionWidgets extends React.PureComponent @@ -68,6 +80,7 @@ export class ObjectDetectionWidgets extends React.PureComponent @@ -77,11 +90,37 @@ export class ObjectDetectionWidgets extends React.PureComponent `IoU=${value}%`} showValue /> ); } + + private onAggregateMethodChange = ( + _: React.FormEvent, + item?: IComboBoxOption + ): void => { + if (item) { + this.props.setAggregateMethod(item.text.toString()); + } + }; + + private onClassNameChange = ( + _: React.FormEvent, + item?: IComboBoxOption + ): void => { + if (item) { + this.props.setClassName(item.text.toString()); + } + }; + + private onIoUThresholdChange = (_: React.MouseEvent, value: number): void => { + if (value) { + this.props.setIoUThreshold(value); + } + }; } diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts index 21b2644699..ab191bcd52 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/ModelOverview/StatsTableUtils.ts @@ -140,7 +140,9 @@ export function generateCohortsStatsTable( // only 1 unique value in the set, set color to 0 colorValue = 0; } + const colorConfig = { color: "transparent" }; items.push({ + ...colorConfig, colorValue, value: Number(labeledStat.stat.toFixed(3)), x: metricIndex + 1, diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts index 7c8b37a8e6..6372f982a0 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Controls/TabsView/TabsViewProps.ts @@ -45,11 +45,11 @@ export interface ITabsViewProps { abortSignal: AbortSignal ) => Promise; requestObjectDetectionMetrics?: ( - trueY: number[][][], - predictedY: number[][][], + selectionIndexes: number[][], aggregateMethod: string, className: string, - iouThresh: number + iouThresh: number, + abortSignal: AbortSignal ) => Promise; requestPredictions?: ( request: any[], diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts index 4b2797611f..68c1432b2f 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/ModelAssessmentDashboardProps.ts @@ -116,11 +116,11 @@ export interface IModelAssessmentDashboardProps abortSignal: AbortSignal ) => Promise; requestObjectDetectionMetrics?: ( - trueY: number[][][], - predictedY: number[][][], + selectionIndexes: number[][], aggregateMethod: string, className: string, - iouThresh: number + iouThresh: number, + abortSignal: AbortSignal ) => Promise; requestBubblePlotData?: ( filter: unknown[], diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts index 5313f68cd0..cb18385e79 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/utils/getModelTypeFromProps.ts @@ -47,5 +47,8 @@ export function getModelTypeFromProps( if (taskType === DatasetTaskType.MultilabelTextClassification) { return ModelTypes.TextMultilabel; } + if (taskType === DatasetTaskType.ObjectDetection) { + return ModelTypes.ObjectDetection; + } return modelType; } diff --git a/raiwidgets/raiwidgets/responsibleai_dashboard.py b/raiwidgets/raiwidgets/responsibleai_dashboard.py index 3c3d7e6a5d..bdad46826f 100644 --- a/raiwidgets/raiwidgets/responsibleai_dashboard.py +++ b/raiwidgets/raiwidgets/responsibleai_dashboard.py @@ -88,7 +88,7 @@ def get_object_detection_metrics(): data = request.get_json(force=True) return jsonify(self.input.get_object_detection_metrics(data)) self.add_url_rule( - get_exp, + get_object_detection_metrics, '/get_object_detection_metrics', methods=["POST"] ) diff --git a/raiwidgets/raiwidgets/responsibleai_dashboard_input.py b/raiwidgets/raiwidgets/responsibleai_dashboard_input.py index 0ca545dc84..5c3fc1a577 100644 --- a/raiwidgets/raiwidgets/responsibleai_dashboard_input.py +++ b/raiwidgets/raiwidgets/responsibleai_dashboard_input.py @@ -326,14 +326,12 @@ def get_object_detection_metrics(self, post_data): :rtype: Dict[str, List] """ try: - true_y = post_data[0] - predicted_y = post_data[1] - aggregate_method = post_data[2] - class_name = post_data[3] - iou_thresh = post_data[4] + selection_indexes = post_data[0] + aggregate_method = post_data[1] + class_name = post_data[2] + iou_thresh = post_data[3] exp = self._analysis.compute_object_detection_metrics( - true_y, - predicted_y, + selection_indexes, aggregate_method, class_name, iou_thresh