Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model Overview] Object Detection Metrics support #2025

Merged
merged 30 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
441781b
OD support to surface metrics
Advitya17 Mar 27, 2023
9f77608
endpoint schema ckpt
Advitya17 Mar 28, 2023
51d9268
flask endpoint ckpt
Advitya17 Mar 30, 2023
0cc09f2
flask endpoint invocation fix
Advitya17 Mar 31, 2023
9af0d79
od fn template rearrangement
Advitya17 Apr 1, 2023
2e0b893
ckpt
Advitya17 Apr 2, 2023
aa1f3f1
added endpoint return value
Advitya17 Apr 2, 2023
7e2fdc7
endpoint return update
Advitya17 Apr 2, 2023
d234899
ckpt
Advitya17 Apr 3, 2023
fa672bd
onchange functions ckpt
Advitya17 Apr 3, 2023
c0e0f08
backend state update
Advitya17 Apr 3, 2023
2cf8664
fn ckpt
Advitya17 Apr 3, 2023
7be8a2e
state function update
Advitya17 Apr 3, 2023
bc15797
state update fix
Advitya17 Apr 3, 2023
1b8d9c2
endpoint refactor ckpt
Advitya17 Apr 4, 2023
72163f5
code cleanup ckpt
Advitya17 Apr 4, 2023
ef931e7
endpoint condition update
Advitya17 Apr 5, 2023
8e32d85
filter change & state var fix
Advitya17 Apr 5, 2023
7f8001b
input saving ckpt
Advitya17 Apr 6, 2023
fb57341
merge conflict fix
Advitya17 Apr 6, 2023
50d0cc8
lint fixes
Advitya17 Apr 10, 2023
e813bdb
N/A support & telemetry reference update
Advitya17 Apr 10, 2023
cacb8e9
transparent color support for numbers
Advitya17 Apr 10, 2023
18f2aa7
lint fixes
Advitya17 Apr 10, 2023
452b981
state refactor
Advitya17 Apr 12, 2023
698ca35
type & style updates
Advitya17 Apr 12, 2023
7ebf0fa
lint fixes
Advitya17 Apr 13, 2023
ce4604f
removed console logs
Advitya17 Apr 13, 2023
35ba4e5
Merge branch 'main' of https://github.com/microsoft/responsible-ai-to…
Advitya17 Apr 13, 2023
1d7ab76
Merge branch 'main' into agemawat/od_metrics
Advitya17 Apr 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 6 additions & 5 deletions apps/widget/src/app/ModelAssessment.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,17 @@ export class ModelAssessment extends React.Component<IModelAssessmentProps> {
return callFlaskService(this.props.config, data, "/get_exp");
};
callBack.requestObjectDetectionMetrics = async (
trueY: number[][][],
predictedY: number[][][],
selectionIndexes: number[][],
aggregateMethod: string,
className: string,
iouThresh: number
iouThresh: number,
abortSignal: AbortSignal
): Promise<any[]> => {
return callFlaskService(
this.props.config,
[trueY, predictedY, aggregateMethod, className, iouThresh],
"/get_object_detection_metrics"
[selectionIndexes, aggregateMethod, className, iouThresh],
"/get_object_detection_metrics",
abortSignal
);
};
callBack.requestPredictions = async (data: any[]): Promise<any[]> => {
Expand Down
7 changes: 4 additions & 3 deletions libs/core-ui/src/lib/Context/ModelAssessmentContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ export interface IModelAssessmentContext {
| undefined;
requestObjectDetectionMetrics?:
| ((
trueY: number[][][],
predictedY: number[][][],
selectionIndexes: number[][],
aggregateMethod: string,
className: string,
iouThresh: number
iouThresh: number,
abortSignal: AbortSignal
) => Promise<any[]>)
| undefined;
requestSplinePlotDistribution?: (
Expand Down Expand Up @@ -173,6 +173,7 @@ export const defaultModelAssessmentContext: IModelAssessmentContext = {
modelMetadata: {} as IExplanationModelMetadata,
modelType: undefined,
requestExp: undefined,
requestObjectDetectionMetrics: undefined,
requestLocalFeatureExplanations: undefined,
requestPredictions: undefined,
selectedErrorCohort: {} as ErrorCohort,
Expand Down
2 changes: 2 additions & 0 deletions libs/core-ui/src/lib/util/JointDataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ export class JointDataset {
public static readonly PredictedYLabel = "PredictedY";
public static readonly ProbabilityYRoot = "ProbabilityClass";
public static readonly TrueYLabel = "TrueY";
public static readonly ObjectDetectionPredictedYLabel = "ObjectDetectionPredictedY";
public static readonly ObjectDetectionTrueYLabel = "ObjectDetectionTrueY";
public static readonly DitherLabel = "Dither";
public static readonly DitherLabel2 = "Dither2";
public static readonly ClassificationError = "ClassificationError";
Expand Down
2 changes: 2 additions & 0 deletions libs/core-ui/src/lib/util/JointDatasetUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ export interface IJointDatasetArgs {
metadata: IExplanationModelMetadata;
featureMetaData?: IFeatureMetaData;
targetColumn?: string | string[];
objectDetectionTrueY?: number[][][],
objectDetectionPredictedY?: number[][][]
}

export enum ColumnCategories {
Expand Down
19 changes: 5 additions & 14 deletions libs/core-ui/src/lib/util/ObjectDetectionStatisticsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,50 +8,41 @@ import {
TotalCohortSamples
} from "../Interfaces/IStatistic";

import { JointDataset } from "./JointDataset";

export enum ObjectDetectionMetrics {
MeanAveragePrecision = "meanAveragePrecision",
AveragePrecision = "averagePrecision",
AverageRecall = "averageRecall"
}

export const generateObjectDetectionStats: (
jointDataset: JointDataset,
selectionIndexes: number[][]
) => ILabeledStatistic[][] = (
jointDataset: JointDataset,
selectionIndexes: number[][]
): ILabeledStatistic[][] => {
const numLabels = jointDataset.numLabels;

return selectionIndexes.map((selectionArray) => {
const count = selectionArray.length;

// TODO: replace placeholder values with flask endpoint calls to python backend.
const meanAveragePrecision = 42;
const averagePrecision = 42;
const averageRecall = 42;

return [
{
key: TotalCohortSamples,
label: localization.Interpret.Statistics.samples,
stat: count / numLabels // TODO: remove numLabels from here when using jointDataset elsewhere.
stat: count
},
{
key: ObjectDetectionMetrics.MeanAveragePrecision,
label: localization.Interpret.Statistics.meanAveragePrecision,
stat: meanAveragePrecision
stat: 0
Advitya17 marked this conversation as resolved.
Show resolved Hide resolved
},
{
key: ObjectDetectionMetrics.AveragePrecision,
label: localization.Interpret.Statistics.averagePrecision,
stat: averagePrecision
stat: 0
},
{
key: ObjectDetectionMetrics.AverageRecall,
label: localization.Interpret.Statistics.averageRecall,
stat: averageRecall
stat: 0
}
];
});
Expand Down
2 changes: 1 addition & 1 deletion libs/core-ui/src/lib/util/StatisticsUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ export const generateMetrics: (
});
}
if (modelType === ModelTypes.ObjectDetection) {
return generateObjectDetectionStats(jointDataset, selectionIndexes);
return generateObjectDetectionStats(selectionIndexes);
}
const outcomes = jointDataset.unwrap(JointDataset.ClassificationError);
return selectionIndexes.map((selectionArray) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ export interface IVisionExplanationDashboardProps {
abortSignal: AbortSignal
) => Promise<any[]>;
requestObjectDetectionMetrics?: (
trueY: number[][][],
predictedY: number[][][],
selectionIndexes: number[][],
aggregateMethod: string,
className: string,
iouThresh: number
iouThresh: number,
abortSignal: AbortSignal
) => Promise<any[]>;
selectedCohort: ErrorCohort;
setSelectedCohort: (cohort: ErrorCohort) => void;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ import {
TelemetryLevels,
TelemetryEventName,
DatasetTaskType,
ImageClassificationMetrics
ImageClassificationMetrics,
TotalCohortSamples
} from "@responsible-ai/core-ui";
import { localization } from "@responsible-ai/localization";
import React from "react";
Expand All @@ -52,6 +53,12 @@ import { getSelectableMetrics } from "./StatsTableUtils";

interface IModelOverviewProps {
telemetryHook?: (message: ITelemetryEvent) => void;
requestObjectDetectionMetrics?: (
selectionIndexes: number[][],
aggregateMethod: string,
className: string,
iouThresh: number
) => Promise<any[]>;
}

interface IModelOverviewState {
Expand All @@ -70,6 +77,9 @@ interface IModelOverviewState {
datasetBasedCohorts: ErrorCohort[];
featureBasedCohortLabeledStatistics: ILabeledStatistic[][];
featureBasedCohorts: ErrorCohort[];
aggregateMethod: string,
className: string,
iouThresh: number
}

const datasetCohortViewPivotKey = "datasetCohortView";
Expand All @@ -86,6 +96,7 @@ export class ModelOverview extends React.Component<

public constructor(props: IModelOverviewProps) {
super(props);

this.state = {
chartConfigurationIsVisible: false,
datasetBasedCohorts: [],
Expand All @@ -99,7 +110,10 @@ export class ModelOverview extends React.Component<
selectedFeatures: [],
selectedFeaturesContinuousFeatureBins: {},
selectedMetrics: [],
showHeatmapColors: true
showHeatmapColors: true,
Advitya17 marked this conversation as resolved.
Show resolved Hide resolved
aggregateMethod: localization.ModelAssessment.ModelOverview.metricTypes.macro,
className: "",
iouThresh: 70
};
}

Expand Down Expand Up @@ -333,6 +347,7 @@ export class ModelOverview extends React.Component<
<ObjectDetectionWidgets
classNames={classNames}
dataset={this.context.dataset}
modelOverview={this}
/>
)}
</Stack>
Expand Down Expand Up @@ -513,18 +528,96 @@ export class ModelOverview extends React.Component<
}

private updateDatasetCohortStats(): void {
let selectionIndexes : number[][] = this.context.errorCohorts.map((errorCohort) =>
errorCohort.cohort.unwrap(JointDataset.IndexLabel)
);
const datasetCohortMetricStats = generateMetrics(
this.context.jointDataset,
this.context.errorCohorts.map((errorCohort) =>
errorCohort.cohort.unwrap(JointDataset.IndexLabel)
),
selectionIndexes,
this.context.modelMetadata.modelType
);

this.setState({
datasetBasedCohorts: this.context.errorCohorts,
datasetCohortLabeledStatistics: datasetCohortMetricStats
});

this.updateObjectDetectionMetrics(selectionIndexes, true);
}

private updateObjectDetectionMetrics(
selectionIndexes: number[][],
isDatasetCohort: boolean): void {
if (this.context.requestObjectDetectionMetrics &&
selectionIndexes.length > 0 &&
this.state.aggregateMethod.length > 0 &&
this.state.className.length > 0 &&
this.state.iouThresh) {
console.log('entered endpoint')
Advitya17 marked this conversation as resolved.
Show resolved Hide resolved

this.context.requestObjectDetectionMetrics(
selectionIndexes,
this.state.aggregateMethod,
this.state.className,
this.state.iouThresh,
new AbortController().signal
).then(
(result) => {
// TODO: assert length of result and selectionIndexes are the same.
Advitya17 marked this conversation as resolved.
Show resolved Hide resolved
let updatedMetricStats : ILabeledStatistic[][] = [];

for (let cohortIndex = 0; cohortIndex < result.length; cohortIndex++) {
let count = selectionIndexes[cohortIndex].length;

let [meanAveragePrecision, averagePrecision, averageRecall] = result[cohortIndex];

let updatedCohortMetricStats = [
{
key: TotalCohortSamples,
label: localization.Interpret.Statistics.samples,
stat: count
},
{
key: ObjectDetectionMetrics.MeanAveragePrecision,
label: localization.Interpret.Statistics.meanAveragePrecision,
stat: meanAveragePrecision
},
{
key: ObjectDetectionMetrics.AveragePrecision,
label: localization.Interpret.Statistics.averagePrecision,
stat: averagePrecision
},
{
key: ObjectDetectionMetrics.AverageRecall,
label: localization.Interpret.Statistics.averageRecall,
stat: averageRecall
}
];

updatedMetricStats.push(updatedCohortMetricStats);

}

isDatasetCohort
? this.updateDatasetCohortState(updatedMetricStats)
: this.updateFeatureCohortState(updatedMetricStats);

}
)
}
else {
console.log(this.context.requestObjectDetectionMetrics);
Advitya17 marked this conversation as resolved.
Show resolved Hide resolved
console.log(selectionIndexes.length)
console.log(this.state.aggregateMethod);
console.log(this.state.className);
console.log(this.state.iouThresh);
}
}

private updateDatasetCohortState(cohortMetricStats: ILabeledStatistic[][]): void {
this.setState({
datasetCohortLabeledStatistics: cohortMetricStats
});
}

private async updateFeatureCohortStats(): Promise<void> {
Expand All @@ -537,18 +630,28 @@ export class ModelOverview extends React.Component<
this.state.selectedFeaturesContinuousFeatureBins
);

let selectionIndexes: number[][] = featureBasedCohorts.map((errorCohort) =>
errorCohort.cohort.unwrap(JointDataset.IndexLabel)
);

const featureCohortMetricStats = generateMetrics(
this.context.jointDataset,
featureBasedCohorts.map((errorCohort) =>
errorCohort.cohort.unwrap(JointDataset.IndexLabel)
),
selectionIndexes,
this.context.modelMetadata.modelType
);

this.setState({
featureBasedCohortLabeledStatistics: featureCohortMetricStats,
featureBasedCohorts
});

this.updateObjectDetectionMetrics(selectionIndexes, false);
}

private updateFeatureCohortState(cohortMetricStats: ILabeledStatistic[][]): void {
this.setState({
featureBasedCohortLabeledStatistics: cohortMetricStats
});
}

private ifCohortIndexesEquals(a: number[], b: number[]): boolean {
Expand Down