Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pre-built cohort into adult census notebook #1243

Merged
merged 23 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
cba4923
[WIP] Add pre-built cohort into adult census notebook
gaugup Feb 24, 2022
df4fe0b
erroranalysis version bump in raiwidgets to 0.1.31 (#1245)
imatiach-msft Feb 24, 2022
1a37fcf
Make cohrtData empty list in case no pre-bdefined cohorts are injecte…
gaugup Feb 26, 2022
f020a00
Simplify the train pipeline responsibleaidashboard-census-classificat…
gaugup Mar 8, 2022
0c9ee31
Add regression test for pre-defined cohorts in raiwidgets (#1249)
gaugup Feb 28, 2022
6e837f1
color (#1248)
zhb000 Mar 1, 2022
8872e0a
Add feature importance box & bar chart (#1241)
zhb000 Mar 2, 2022
1b2bc5e
PreBuilt cohorts UX changes (#1242)
gaugup Mar 2, 2022
b47627e
Make _cohort.py module a public module (#1253)
gaugup Mar 3, 2022
9d8f0b6
fix notebook build failures due to pywinpty dependency release failin…
imatiach-msft Mar 4, 2022
ea553ee
Add supported models and data types to README.md responsibleai (#1259)
gaugup Mar 4, 2022
e2e289c
make getting-started notebook a markdown file showing APIs (#1223)
imatiach-msft Mar 4, 2022
26b2453
refactor tabs out of RAI dashboard into a separate component (#1256)
imatiach-msft Mar 4, 2022
ac1563d
Add individual causal scatter chart (#1258)
zhb000 Mar 5, 2022
422db55
minor fix to url for responsibleai package in setup.py (#1260)
imatiach-msft Mar 7, 2022
5610f64
Merge branch 'main' into gaugup/NB
gaugup Mar 8, 2022
39c68a6
Fix UX e2e tests and address code review comments
gaugup Mar 8, 2022
4249ae1
Merge branch 'main' into gaugup/NB
gaugup Mar 8, 2022
5e03481
Fix eslint
gaugup Mar 9, 2022
fb8f38d
Merge branch 'main' into gaugup/NB
gaugup Apr 20, 2022
60cdf37
Address review comments
gaugup Apr 20, 2022
05a4f3e
Reset the number of samples in test dataset
gaugup Apr 20, 2022
652587b
Merge branch 'main' into gaugup/NB
gaugup Apr 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ const modelAssessmentDatasets = {
"capital-loss"
],
modelStatisticsData: {
cohortDropDownValues: [
"All data",
"Cohort Age and Hours-Per-Week",
"Cohort Marital-Status",
"Cohort Index",
"Cohort Predicted Y",
"Cohort True Y"
],
defaultXAxis: "Probability : <=50K",
defaultXAxisPanelValue: "Prediction probabilities",
defaultYAxis: "Cohort",
Expand Down Expand Up @@ -115,6 +123,7 @@ const modelAssessmentDatasets = {
"s6"
],
modelStatisticsData: {
cohortDropDownValues: ["All data"],
defaultXAxis: "Error",
defaultXAxisPanelValue: "Error",
defaultYAxis: "Cohort",
Expand Down Expand Up @@ -180,6 +189,7 @@ const modelAssessmentDatasets = {
],
isRegression: true,
modelStatisticsData: {
cohortDropDownValues: ["All data"],
defaultXAxis: "Error",
defaultXAxisPanelValue: "Error",
defaultYAxis: "Cohort",
Expand Down Expand Up @@ -272,6 +282,7 @@ const modelAssessmentDatasets = {
"YrSold"
],
modelStatisticsData: {
cohortDropDownValues: ["All data"],
defaultXAxis: "Probability : Less than median",
defaultXAxisPanelValue: "Prediction probabilities",
defaultYAxis: "Cohort",
Expand Down Expand Up @@ -364,6 +375,7 @@ const modelAssessmentDatasets = {
"YrSold"
],
modelStatisticsData: {
cohortDropDownValues: ["All data"],
hasModelStatisticsComponent: false,
hasSideBar: false
},
Expand Down Expand Up @@ -416,6 +428,7 @@ const modelAssessmentDatasets = {
],
isMulticlass: true,
modelStatisticsData: {
cohortDropDownValues: ["All data"],
defaultXAxis: "Predicted Y",
defaultXAxisPanelValue: "Prediction probabilities",
defaultYAxis: "Cohort",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ export function describeModelPerformanceSideBar(
});

it("Side bar should be updated with updated values", () => {
cy.get(Locators.MSSideBarCards).should("have.length", 1);
cy.get(Locators.MSSideBarCards).should(
"have.length",
dataShape.modelStatisticsData?.cohortDropDownValues
? dataShape.modelStatisticsData?.cohortDropDownValues.length
: 0
);
cy.get(`${Locators.MSCRotatedVerticalBox} button`)
.click()
.get(
Expand Down Expand Up @@ -50,7 +55,12 @@ export function describeModelPerformanceSideBar(
cy.get(`${Locators.MSCRotatedVerticalBox}`).contains(
dataShape.modelStatisticsData?.defaultYAxis || "Cohort"
);
cy.get(Locators.MSSideBarCards).should("have.length", 1);
cy.get(Locators.MSSideBarCards).should(
"have.length",
dataShape.modelStatisticsData?.cohortDropDownValues
? dataShape.modelStatisticsData?.cohortDropDownValues.length
: 0
);
});

it("Should have dropdown to select cohort when y axis is changed to different value than cohort", () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,80 @@
"rai_insights.compute()"
]
},
{
"cell_type": "markdown",
"id": "b84c6c0d",
"metadata": {},
"source": [
"Compose some cohorts which can be injected into the `ResponsibleAIDashboard`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0994b7d6",
"metadata": {},
"outputs": [],
"source": [
"from raiwidgets.cohort import Cohort, CohortFilter, CohortFilterMethods\n",
"\n",
"# Cohort on age and hours-per-week features in the dataset\n",
"cohort_filter_age = CohortFilter(\n",
" method=CohortFilterMethods.METHOD_LESS,\n",
" arg=[65],\n",
" column='age')\n",
"cohort_filter_hours_per_week = CohortFilter(\n",
" method=CohortFilterMethods.METHOD_GREATER,\n",
" arg=[40],\n",
" column='hours-per-week')\n",
"\n",
"user_cohort_age_and_hours_per_week = Cohort(name='Cohort Age and Hours-Per-Week')\n",
"user_cohort_age_and_hours_per_week.add_cohort_filter(cohort_filter_age)\n",
"user_cohort_age_and_hours_per_week.add_cohort_filter(cohort_filter_hours_per_week)\n",
"\n",
"# Cohort on marital-status feature in the dataset\n",
"cohort_filter_marital_status = CohortFilter(\n",
" method=CohortFilterMethods.METHOD_INCLUDES,\n",
" arg=[\"Never-married\", \"Divorced\"],\n",
" column='marital-status')\n",
"\n",
"user_cohort_marital_status = Cohort(name='Cohort Marital-Status')\n",
"user_cohort_marital_status.add_cohort_filter(cohort_filter_marital_status)\n",
"\n",
"# Cohort on index of the row in the dataset\n",
"cohort_filter_index = CohortFilter(\n",
" method=CohortFilterMethods.METHOD_LESS,\n",
" arg=[20],\n",
" column='Index')\n",
"\n",
"user_cohort_index = Cohort(name='Cohort Index')\n",
"user_cohort_index.add_cohort_filter(cohort_filter_index)\n",
"\n",
"# Cohort on predicted target value\n",
"cohort_filter_predicted_y = CohortFilter(\n",
" method=CohortFilterMethods.METHOD_INCLUDES,\n",
" arg=['>50K'],\n",
" column='Predicted Y')\n",
"\n",
"user_cohort_predicted_y = Cohort(name='Cohort Predicted Y')\n",
"user_cohort_predicted_y.add_cohort_filter(cohort_filter_predicted_y)\n",
"\n",
"# Cohort on predicted target value\n",
"cohort_filter_true_y = CohortFilter(\n",
" method=CohortFilterMethods.METHOD_INCLUDES,\n",
" arg=['>50K'],\n",
" column='True Y')\n",
"\n",
"user_cohort_true_y = Cohort(name='Cohort True Y')\n",
"user_cohort_true_y.add_cohort_filter(cohort_filter_true_y)\n",
"\n",
"cohort_list = [user_cohort_age_and_hours_per_week,\n",
" user_cohort_marital_status,\n",
" user_cohort_index,\n",
" user_cohort_predicted_y,\n",
" user_cohort_true_y]"
]
},
{
"cell_type": "markdown",
"id": "elder-fleet",
Expand All @@ -267,7 +341,7 @@
"metadata": {},
"outputs": [],
"source": [
"ResponsibleAIDashboard(rai_insights)"
"ResponsibleAIDashboard(rai_insights, cohort_list=cohort_list)"
]
},
{
Expand Down Expand Up @@ -510,7 +584,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
"version": "3.6.12"
}
},
"nbformat": 4,
Expand Down