Skip to content

Commit

Permalink
[Refactor]1.Add DatasetCohort 2. Move FeatureFlights to core (#2003)
Browse files Browse the repository at this point in the history
* move flight to core and add DatasetCohort

Signed-off-by: RubyZ10 <zhenzhu@microsoft.com>

* fix import

Signed-off-by: RubyZ10 <zhenzhu@microsoft.com>

* fix probY

Signed-off-by: RubyZ10 <zhenzhu@microsoft.com>

---------

Signed-off-by: RubyZ10 <zhenzhu@microsoft.com>
  • Loading branch information
RubyZ10 committed Mar 14, 2023
1 parent bbb12fd commit a373cd0
Show file tree
Hide file tree
Showing 14 changed files with 399 additions and 49 deletions.
3 changes: 1 addition & 2 deletions apps/dashboard/src/app/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
// Licensed under the MIT License.

import { ITheme } from "@fluentui/react";
import { generateRoute } from "@responsible-ai/core-ui";
import { generateRoute, parseFeatureFlights } from "@responsible-ai/core-ui";
import { Language } from "@responsible-ai/localization";
import { parseFeatureFlights } from "@responsible-ai/model-assessment";
import _ from "lodash";
import React from "react";
import { Redirect, generatePath } from "react-router-dom";
Expand Down
4 changes: 2 additions & 2 deletions apps/dashboard/src/app/AppHeader.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ import {
ICommandBarItemProps,
IContextualMenuItem
} from "@fluentui/react";
import { Language } from "@responsible-ai/localization";
import {
featureFlights,
featureFlightSeparator,
parseFeatureFlights
} from "@responsible-ai/model-assessment";
} from "@responsible-ai/core-ui";
import { Language } from "@responsible-ai/localization";
import React from "react";

import { applications, IApplications } from "./applications";
Expand Down
6 changes: 3 additions & 3 deletions apps/widget/src/app/ModelAssessment.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import {
IHighchartBoxData,
IHighchartBubbleSDKClusterData,
ICounterfactualData,
ILocalExplanations
ILocalExplanations,
parseFeatureFlights
} from "@responsible-ai/core-ui";
import {
ModelAssessmentDashboard,
IModelAssessmentData,
IModelAssessmentDashboardProps,
parseFeatureFlights
IModelAssessmentDashboardProps
} from "@responsible-ai/model-assessment";
import React from "react";

Expand Down
1 change: 1 addition & 0 deletions libs/core-ui/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ export * from "./lib/components/MissingParametersPlaceholder";
export * from "./lib/components/LabelWithCallout";
export * from "./lib/components/NoData";
export * from "./lib/components/SVGToolTip";
export * from "./lib/FeatureFlights";
export * from "./lib/Interfaces/ComparisonTypes";
export * from "./lib/Interfaces/ExplanationInterfaces";
export * from "./lib/Interfaces/IExplanationContext";
Expand Down
8 changes: 0 additions & 8 deletions libs/core-ui/src/lib/Cohort/Cohort.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,6 @@ import { compare } from "../util/compare";
import { JointDataset } from "../util/JointDataset";
import { ModelExplanationUtils } from "../util/ModelExplanationUtils";

export enum CohortSource {
None = "None",
TreeMap = "Tree map",
HeatMap = "Heat map",
ManuallyCreated = "Manually created",
Prebuilt = "Prebuilt"
}

export class Cohort {
private static _cohortIndex = 0;

Expand Down
3 changes: 2 additions & 1 deletion libs/core-ui/src/lib/Cohort/CohortBasedComponent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import { IDataset } from "../Interfaces/IDataset";
import { IFilter, ICompositeFilter } from "../Interfaces/IFilter";
import { JointDataset } from "../util/JointDataset";

import { Cohort, CohortSource } from "./Cohort";
import { Cohort } from "./Cohort";
import { MetricCohortStats } from "./CohortStats";
import { CohortSource } from "./Constants";
import { ErrorCohort } from "./ErrorCohort";

export interface ICohortBasedComponentProps {
Expand Down
8 changes: 8 additions & 0 deletions libs/core-ui/src/lib/Cohort/Constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,11 @@ export class Metrics {
public static RecallScore = "Recall score";
public static SampleSize = "Sample size";
}

export enum CohortSource {
None = "None",
TreeMap = "Tree map",
HeatMap = "Heat map",
ManuallyCreated = "Manually created",
Prebuilt = "Prebuilt"
}
4 changes: 2 additions & 2 deletions libs/core-ui/src/lib/Cohort/ErrorCohort.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import { getBasicFilterString } from "../util/getBasicFilterString";
import { getCompositeFilterString } from "../util/getCompositeFilterString";
import { JointDataset } from "../util/JointDataset";

import { Cohort, CohortSource } from "./Cohort";
import { Cohort } from "./Cohort";
import { MetricCohortStats, ErrorCohortStats } from "./CohortStats";
import { Metrics } from "./Constants";
import { CohortSource, Metrics } from "./Constants";

export class ErrorCohort {
public cohortStats: MetricCohortStats;
Expand Down
236 changes: 236 additions & 0 deletions libs/core-ui/src/lib/DatasetCohort.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

import {
ICategoricalRange,
INumericRange,
RangeTypes
} from "@responsible-ai/mlchartlib";

import { ErrorCohortStats, MetricCohortStats } from "./Cohort/CohortStats";
import { CohortSource, Metrics } from "./Cohort/Constants";
import { DatasetCohortColumns } from "./DatasetCohortColumns";
import { IDataset } from "./Interfaces/IDataset";
import { ModelTypes } from "./Interfaces/IExplanationContext";
import { FilterMethods, IFilter } from "./Interfaces/IFilter";
import { getPropertyValues } from "./util/datasetUtils/getPropertyValues";
import { IsBinary, IsMulticlass } from "./util/ExplanationUtils";
import { MulticlassClassificationEnum } from "./util/JointDatasetUtils";

export class DatasetCohort {
public selectedIndexes: number[] = [];
public cohortStats: MetricCohortStats;
public constructor(
public name: string,
public dataset: IDataset,
public filters: IFilter[] = [],
public modelTypes?: ModelTypes,
private featureRanges?: {
[key: string]: INumericRange | ICategoricalRange;
},
public source: CohortSource = CohortSource.None,
public isTemporary: boolean = false,
cohortStats: MetricCohortStats | undefined = undefined
) {
this.name = name;
this.selectedIndexes = this.applyFilters();
if (cohortStats) {
this.cohortStats = cohortStats;
} else {
this.cohortStats = this.updateStatsFromData();
}
}

private applyFilters(): number[] {
const indexes = [];
const dataDict = this.getDataDict(this.modelTypes);
for (const [index, row] of dataDict.entries()) {
if (this.filterRow(row, this.filters)) {
indexes.push(index);
}
}
return indexes;
}

private getDataDict(
modelType?: ModelTypes
): Array<{ [key: string]: unknown }> {
const dataDict = Array.from({ length: this.dataset.features.length }).map(
(_, index) => {
const dict = {};
dict[DatasetCohortColumns.Index] = index;
return dict;
}
);
this.dataset.features.forEach((row, index) => {
row.forEach((val, colIndex) => {
const featureName = this.dataset.feature_names[colIndex];
dataDict[index][featureName] = val;
});
});
this.dataset.true_y.forEach((val, index) => {
if (Array.isArray(val)) {
val.forEach((subVal, subIndex) => {
dataDict[index][DatasetCohortColumns.TrueY + subIndex.toString()] =
subVal;
});
} else {
dataDict[index][DatasetCohortColumns.TrueY] = val;
}
});
this.dataset.predicted_y?.forEach((val, index) => {
if (Array.isArray(val)) {
val.forEach((subVal, subIndex) => {
dataDict[index][
DatasetCohortColumns.PredictedY + subIndex.toString()
] = subVal;
});
} else {
dataDict[index][DatasetCohortColumns.PredictedY] = val;
}
});
// set up errors
if (modelType === ModelTypes.Regression) {
for (const [index, row] of dataDict.entries()) {
dataDict[index][DatasetCohortColumns.RegressionError] = Math.abs(
row[DatasetCohortColumns.TrueY] - row[DatasetCohortColumns.PredictedY]
);
}
} else if (modelType && IsBinary(modelType)) {
// sum pred and 2*true to map to ints 0 - 3,
// 0: TN
// 1: FP
// 2: FN
// 3: TP
for (const [index, row] of dataDict.entries()) {
dataDict[index][DatasetCohortColumns.ClassificationError] =
2 * row[DatasetCohortColumns.TrueY] +
row[DatasetCohortColumns.PredictedY];
}
} else if (modelType && IsMulticlass(modelType)) {
for (const [index, row] of dataDict.entries()) {
dataDict[index][DatasetCohortColumns.ClassificationError] =
row[DatasetCohortColumns.TrueY] !==
row[DatasetCohortColumns.PredictedY]
? MulticlassClassificationEnum.Misclassified
: MulticlassClassificationEnum.Correct;
}
}
return dataDict;
}

private filterRow(
row: { [key: string]: unknown },
filters: IFilter[]
): boolean {
return filters.every((filter) => {
const rowVal = row[filter.column];
switch (filter.method) {
case FilterMethods.Equal:
return rowVal === filter.arg[0];
case FilterMethods.GreaterThan:
return typeof rowVal == "number" && rowVal > filter.arg[0];
case FilterMethods.GreaterThanEqualTo:
return typeof rowVal == "number" && rowVal >= filter.arg[0];
case FilterMethods.LessThan:
return typeof rowVal == "number" && rowVal < filter.arg[0];
case FilterMethods.LessThanEqualTo:
return typeof rowVal == "number" && rowVal <= filter.arg[0];
case FilterMethods.Includes:
return this.includesValue(filter, rowVal);
case FilterMethods.Excludes:
return !this.includesValue(filter, rowVal);
case FilterMethods.InTheRangeOf:
return (
typeof rowVal == "number" &&
rowVal >= filter.arg[0] &&
rowVal <= filter.arg[1]
);
default:
return false;
}
});
}

private includesValue(filter: IFilter, val: unknown): boolean {
if (
this.featureRanges &&
this.featureRanges[filter.column]?.rangeType === RangeTypes.Categorical
) {
if (
filter.column === DatasetCohortColumns.PredictedY ||
filter.column === DatasetCohortColumns.TrueY ||
filter.column === DatasetCohortColumns.ClassificationError
) {
return filter.arg.includes(val as number);
}
const uniqueValues = (
this.featureRanges[filter.column] as ICategoricalRange
).uniqueValues;
const index = uniqueValues.findIndex((item) => item === val);
return filter.arg.includes(index);
}
return false;
}

private updateStatsFromData(): ErrorCohortStats {
let totalAll = 0;
let totalCohort = 0;
let totalCorrect = 0;
let totalCohortCorrect = 0;
let totalIncorrect = 0;
let totalCohortIncorrect = 0;
let errorRate = 0;

const trueYsCohort = getPropertyValues(
this.selectedIndexes,
DatasetCohortColumns.TrueY,
this.dataset
);
const predYsCohort = getPropertyValues(
this.selectedIndexes,
DatasetCohortColumns.PredictedY,
this.dataset
);
const indexes = [...new Array(this.dataset.features.length).keys()];
const trueYs = getPropertyValues(
indexes,
DatasetCohortColumns.TrueY,
this.dataset
);
const predYs = getPropertyValues(
indexes,
DatasetCohortColumns.PredictedY,
this.dataset
);

if (trueYsCohort && trueYs && predYsCohort && predYs) {
totalCohort = trueYsCohort.length;
totalAll = trueYs.length;

for (let i = 0; i < totalAll; i++) {
totalCorrect += trueYs[i] === predYs[i] ? 1 : 0;
}
totalIncorrect = totalAll - totalCorrect;

for (let i = 0; i < totalCohort; i++) {
totalCohortCorrect += trueYsCohort[i] === predYsCohort[i] ? 1 : 0;
}
totalCohortIncorrect = totalCohort - totalCohortCorrect;
}
// Calculate error rate
if (totalCohort === 0) {
errorRate = 0;
} else {
errorRate = (totalCohortIncorrect / totalCohort) * 100;
}
return new ErrorCohortStats(
totalCohortIncorrect,
totalCohort,
totalIncorrect,
totalAll,
errorRate,
Metrics.ErrorRate
);
}
}
12 changes: 12 additions & 0 deletions libs/core-ui/src/lib/DatasetCohortColumns.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

export enum DatasetCohortColumns {
Index = "Index",
Dataset = "Data",
PredictedY = "Predicted Y",
TrueY = "True Y",
ClassificationError = "Classification outcome",
RegressionError = "Regression error",
ProbabilityY = "Probability Y"
}
Loading

0 comments on commit a373cd0

Please sign in to comment.