-
Notifications
You must be signed in to change notification settings - Fork 349
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Refactor]1.Add DatasetCohort 2. Move FeatureFlights to core (#2003)
* move flight to core and add DatasetCohort Signed-off-by: RubyZ10 <zhenzhu@microsoft.com> * fix import Signed-off-by: RubyZ10 <zhenzhu@microsoft.com> * fix probY Signed-off-by: RubyZ10 <zhenzhu@microsoft.com> --------- Signed-off-by: RubyZ10 <zhenzhu@microsoft.com>
- Loading branch information
Showing
14 changed files
with
399 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
|
||
import { | ||
ICategoricalRange, | ||
INumericRange, | ||
RangeTypes | ||
} from "@responsible-ai/mlchartlib"; | ||
|
||
import { ErrorCohortStats, MetricCohortStats } from "./Cohort/CohortStats"; | ||
import { CohortSource, Metrics } from "./Cohort/Constants"; | ||
import { DatasetCohortColumns } from "./DatasetCohortColumns"; | ||
import { IDataset } from "./Interfaces/IDataset"; | ||
import { ModelTypes } from "./Interfaces/IExplanationContext"; | ||
import { FilterMethods, IFilter } from "./Interfaces/IFilter"; | ||
import { getPropertyValues } from "./util/datasetUtils/getPropertyValues"; | ||
import { IsBinary, IsMulticlass } from "./util/ExplanationUtils"; | ||
import { MulticlassClassificationEnum } from "./util/JointDatasetUtils"; | ||
|
||
export class DatasetCohort { | ||
public selectedIndexes: number[] = []; | ||
public cohortStats: MetricCohortStats; | ||
public constructor( | ||
public name: string, | ||
public dataset: IDataset, | ||
public filters: IFilter[] = [], | ||
public modelTypes?: ModelTypes, | ||
private featureRanges?: { | ||
[key: string]: INumericRange | ICategoricalRange; | ||
}, | ||
public source: CohortSource = CohortSource.None, | ||
public isTemporary: boolean = false, | ||
cohortStats: MetricCohortStats | undefined = undefined | ||
) { | ||
this.name = name; | ||
this.selectedIndexes = this.applyFilters(); | ||
if (cohortStats) { | ||
this.cohortStats = cohortStats; | ||
} else { | ||
this.cohortStats = this.updateStatsFromData(); | ||
} | ||
} | ||
|
||
private applyFilters(): number[] { | ||
const indexes = []; | ||
const dataDict = this.getDataDict(this.modelTypes); | ||
for (const [index, row] of dataDict.entries()) { | ||
if (this.filterRow(row, this.filters)) { | ||
indexes.push(index); | ||
} | ||
} | ||
return indexes; | ||
} | ||
|
||
private getDataDict( | ||
modelType?: ModelTypes | ||
): Array<{ [key: string]: unknown }> { | ||
const dataDict = Array.from({ length: this.dataset.features.length }).map( | ||
(_, index) => { | ||
const dict = {}; | ||
dict[DatasetCohortColumns.Index] = index; | ||
return dict; | ||
} | ||
); | ||
this.dataset.features.forEach((row, index) => { | ||
row.forEach((val, colIndex) => { | ||
const featureName = this.dataset.feature_names[colIndex]; | ||
dataDict[index][featureName] = val; | ||
}); | ||
}); | ||
this.dataset.true_y.forEach((val, index) => { | ||
if (Array.isArray(val)) { | ||
val.forEach((subVal, subIndex) => { | ||
dataDict[index][DatasetCohortColumns.TrueY + subIndex.toString()] = | ||
subVal; | ||
}); | ||
} else { | ||
dataDict[index][DatasetCohortColumns.TrueY] = val; | ||
} | ||
}); | ||
this.dataset.predicted_y?.forEach((val, index) => { | ||
if (Array.isArray(val)) { | ||
val.forEach((subVal, subIndex) => { | ||
dataDict[index][ | ||
DatasetCohortColumns.PredictedY + subIndex.toString() | ||
] = subVal; | ||
}); | ||
} else { | ||
dataDict[index][DatasetCohortColumns.PredictedY] = val; | ||
} | ||
}); | ||
// set up errors | ||
if (modelType === ModelTypes.Regression) { | ||
for (const [index, row] of dataDict.entries()) { | ||
dataDict[index][DatasetCohortColumns.RegressionError] = Math.abs( | ||
row[DatasetCohortColumns.TrueY] - row[DatasetCohortColumns.PredictedY] | ||
); | ||
} | ||
} else if (modelType && IsBinary(modelType)) { | ||
// sum pred and 2*true to map to ints 0 - 3, | ||
// 0: TN | ||
// 1: FP | ||
// 2: FN | ||
// 3: TP | ||
for (const [index, row] of dataDict.entries()) { | ||
dataDict[index][DatasetCohortColumns.ClassificationError] = | ||
2 * row[DatasetCohortColumns.TrueY] + | ||
row[DatasetCohortColumns.PredictedY]; | ||
} | ||
} else if (modelType && IsMulticlass(modelType)) { | ||
for (const [index, row] of dataDict.entries()) { | ||
dataDict[index][DatasetCohortColumns.ClassificationError] = | ||
row[DatasetCohortColumns.TrueY] !== | ||
row[DatasetCohortColumns.PredictedY] | ||
? MulticlassClassificationEnum.Misclassified | ||
: MulticlassClassificationEnum.Correct; | ||
} | ||
} | ||
return dataDict; | ||
} | ||
|
||
private filterRow( | ||
row: { [key: string]: unknown }, | ||
filters: IFilter[] | ||
): boolean { | ||
return filters.every((filter) => { | ||
const rowVal = row[filter.column]; | ||
switch (filter.method) { | ||
case FilterMethods.Equal: | ||
return rowVal === filter.arg[0]; | ||
case FilterMethods.GreaterThan: | ||
return typeof rowVal == "number" && rowVal > filter.arg[0]; | ||
case FilterMethods.GreaterThanEqualTo: | ||
return typeof rowVal == "number" && rowVal >= filter.arg[0]; | ||
case FilterMethods.LessThan: | ||
return typeof rowVal == "number" && rowVal < filter.arg[0]; | ||
case FilterMethods.LessThanEqualTo: | ||
return typeof rowVal == "number" && rowVal <= filter.arg[0]; | ||
case FilterMethods.Includes: | ||
return this.includesValue(filter, rowVal); | ||
case FilterMethods.Excludes: | ||
return !this.includesValue(filter, rowVal); | ||
case FilterMethods.InTheRangeOf: | ||
return ( | ||
typeof rowVal == "number" && | ||
rowVal >= filter.arg[0] && | ||
rowVal <= filter.arg[1] | ||
); | ||
default: | ||
return false; | ||
} | ||
}); | ||
} | ||
|
||
private includesValue(filter: IFilter, val: unknown): boolean { | ||
if ( | ||
this.featureRanges && | ||
this.featureRanges[filter.column]?.rangeType === RangeTypes.Categorical | ||
) { | ||
if ( | ||
filter.column === DatasetCohortColumns.PredictedY || | ||
filter.column === DatasetCohortColumns.TrueY || | ||
filter.column === DatasetCohortColumns.ClassificationError | ||
) { | ||
return filter.arg.includes(val as number); | ||
} | ||
const uniqueValues = ( | ||
this.featureRanges[filter.column] as ICategoricalRange | ||
).uniqueValues; | ||
const index = uniqueValues.findIndex((item) => item === val); | ||
return filter.arg.includes(index); | ||
} | ||
return false; | ||
} | ||
|
||
private updateStatsFromData(): ErrorCohortStats { | ||
let totalAll = 0; | ||
let totalCohort = 0; | ||
let totalCorrect = 0; | ||
let totalCohortCorrect = 0; | ||
let totalIncorrect = 0; | ||
let totalCohortIncorrect = 0; | ||
let errorRate = 0; | ||
|
||
const trueYsCohort = getPropertyValues( | ||
this.selectedIndexes, | ||
DatasetCohortColumns.TrueY, | ||
this.dataset | ||
); | ||
const predYsCohort = getPropertyValues( | ||
this.selectedIndexes, | ||
DatasetCohortColumns.PredictedY, | ||
this.dataset | ||
); | ||
const indexes = [...new Array(this.dataset.features.length).keys()]; | ||
const trueYs = getPropertyValues( | ||
indexes, | ||
DatasetCohortColumns.TrueY, | ||
this.dataset | ||
); | ||
const predYs = getPropertyValues( | ||
indexes, | ||
DatasetCohortColumns.PredictedY, | ||
this.dataset | ||
); | ||
|
||
if (trueYsCohort && trueYs && predYsCohort && predYs) { | ||
totalCohort = trueYsCohort.length; | ||
totalAll = trueYs.length; | ||
|
||
for (let i = 0; i < totalAll; i++) { | ||
totalCorrect += trueYs[i] === predYs[i] ? 1 : 0; | ||
} | ||
totalIncorrect = totalAll - totalCorrect; | ||
|
||
for (let i = 0; i < totalCohort; i++) { | ||
totalCohortCorrect += trueYsCohort[i] === predYsCohort[i] ? 1 : 0; | ||
} | ||
totalCohortIncorrect = totalCohort - totalCohortCorrect; | ||
} | ||
// Calculate error rate | ||
if (totalCohort === 0) { | ||
errorRate = 0; | ||
} else { | ||
errorRate = (totalCohortIncorrect / totalCohort) * 100; | ||
} | ||
return new ErrorCohortStats( | ||
totalCohortIncorrect, | ||
totalCohort, | ||
totalIncorrect, | ||
totalAll, | ||
errorRate, | ||
Metrics.ErrorRate | ||
); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the MIT License. | ||
|
||
export enum DatasetCohortColumns { | ||
Index = "Index", | ||
Dataset = "Data", | ||
PredictedY = "Predicted Y", | ||
TrueY = "True Y", | ||
ClassificationError = "Classification outcome", | ||
RegressionError = "Regression error", | ||
ProbabilityY = "Probability Y" | ||
} |
Oops, something went wrong.