
Add model comparison plot #666

Closed
romanlutz opened this issue Dec 31, 2020 · 17 comments
Labels: enhancement (New feature or request), help wanted


@romanlutz
Member

romanlutz commented Dec 31, 2020

Is your feature request related to a problem? Please describe.

With the introduction of matplotlib-based plotting functions in #561, we are adding ways to plot metrics for all groups as defined by sensitive_features. To reach parity with the FairlearnDashboard, we should also add model comparison visualizations. This would be pretty much identical to the following existing view:
[Image: the existing model comparison view in the FairlearnDashboard]

Just like the functions in #561 this should live in the metrics module.

Describe the solution you'd like

plot_model_comparison(
    y_true=y_true,
    y_preds={"model 1": y_pred_1, "model 2": y_pred_2},  # ... one entry per model
    sensitive_features=sensitive_features,
    show_plot=True)
@romanlutz romanlutz added the enhancement (New feature or request) and help wanted labels Dec 31, 2020
@adrinjalali
Member

The function should either accept a metric or the user should compute the metrics outside the function and pass the scores to the plotting method.

@romanlutz
Member Author

You're right, of course! I missed the metric function. Thanks!

@romanlutz
Member Author

@rishabhsamb expressed some interest in this, so I wanted to make sure we agree on the basics. In my description above, it looks like a standalone function. I think that's also our only option, since MetricFrame doesn't take multiple predictions, unless we're willing to extend MetricFrame to support multiple models/sets of predictions(?). I'm assuming we won't (at least not in the short term), so @rishabhsamb could implement this as

plot_model_comparison(
    metrics=metrics,
    y_true=y_true,
    y_preds={"model 1": y_pred_1, "model 2": y_pred_2},  # ... one entry per model
    sensitive_features=sensitive_features,
    show_plot=True)

But there are still a few design decisions to make, such as:

  • Should metrics include some indication of which one goes on the x-axis and which one on the y-axis? Or do we just choose arbitrarily?
  • Related: Is metrics a set of arbitrarily many metrics? If so, do we generate charts for all combinations? Or do we restrict it to just two? In that case, would it be simpler to provide both explicitly as x_metric and y_metric? In the existing dashboard, x is always the "performance" metric and y is always the "fairness" metric. We could also make that restriction.
  • How do we specify the metrics? Simply putting selection_rate isn't sufficient, since we may want to evaluate the group_max, group_min, ratio, or difference.

After thinking through these questions, I'm leaning toward the following:

plot_model_comparison(
    x_metric=accuracy_score,  # no default value, this is just for illustrative purposes
    x_metric_aggregator=None,  # optional, default None (means "overall")
    y_metric=selection_rate,  # no default value, this is just for illustrative purposes
    y_metric_aggregator="difference",  # optional, default None (means "overall")
    y_true=y_true,
    y_preds={"model 1": y_pred_1, "model 2": y_pred_2},  # ... one entry per model
    sensitive_features=sensitive_features,
    show_plot=True)

which would then plot overall accuracy on the x-axis (the value you get from MetricFrame.overall) and the demographic parity difference, in other words MetricFrame.difference() based on selection_rate, on the y-axis. If an aggregation isn't needed, it can be left out (i.e., it's an optional arg). I would also expect y_preds to be a dictionary (or DataFrame?) rather than a single prediction, because this is meant to be a scatter plot after all.
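To make the metric-plus-aggregator idea concrete, here is a minimal, dependency-free sketch of what each aggregator value might compute for one model. The helper aggregate_metric, the toy_accuracy and toy_selection_rate functions, and the aggregator strings are hypothetical stand-ins chosen to mirror MetricFrame.overall, group_min(), group_max(), difference() and ratio(); this is not Fairlearn's implementation.

```python
from collections import defaultdict


def aggregate_metric(metric, y_true, y_pred, sensitive_features, aggregator=None):
    """Hypothetical helper: evaluate `metric` overall or aggregate it across groups."""
    if aggregator is None:
        # "overall": ignore the sensitive features entirely
        return metric(y_true, y_pred)

    # Split labels and predictions by sensitive feature value
    groups = defaultdict(lambda: ([], []))
    for t, p, s in zip(y_true, y_pred, sensitive_features):
        groups[s][0].append(t)
        groups[s][1].append(p)
    by_group = [metric(ts, ps) for ts, ps in groups.values()]

    if aggregator == "group_min":
        return min(by_group)
    if aggregator == "group_max":
        return max(by_group)
    if aggregator == "difference":
        # Largest minus smallest group value
        return max(by_group) - min(by_group)
    if aggregator == "ratio":
        # Smallest divided by largest group value (assumes the largest is non-zero)
        return min(by_group) / max(by_group)
    raise ValueError(f"unknown aggregator: {aggregator!r}")


def toy_accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def toy_selection_rate(y_true, y_pred):
    # Fraction of positive predictions; y_true is unused
    return sum(p == 1 for p in y_pred) / len(y_pred)
```

For example, with y_true = [1, 0, 1, 1, 0, 0], y_pred = [1, 0, 0, 1, 1, 0] and sensitive features ["a", "a", "a", "b", "b", "b"], overall accuracy is 4/6, and the selection rates are 1/3 (group a) and 2/3 (group b), so the "difference" aggregator gives 1/3. That (x, y) pair would be this model's single point on the scatter plot.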

@fairlearn/fairlearn-maintainers wdyt? Also, @rishabhsamb please chime in with your own thoughts. I realize I may be missing some pieces or skipping over important details.

@adrinjalali
Member

Just like the functions in #561 this should live in the metrics module.

I'm not sure; an inspection module makes more sense to me for this.

I guess ideally, the user would have a few different models, and maybe one or more grid searches (GridSearchCV) over a few hyperparameter sets, and would like to visualize those results.

Have you looked into the outputs of GridSearchCV? We could use cv_results_ and multimetric_, for instance; just an idea to explore.

  • How do we specify the metrics? Simply putting selection_rate isn't sufficient, since we may want to evaluate the group_max, group_min, ratio, or difference.

I would discourage us from incorporating these concepts into the signature of such a function. I would rather have simple metric functions for X and Y, with whatever logic is needed living inside those functions. We have a bunch of y_pred (or estimators, for that matter), one y_true, and the sensitive features, and that's enough.

@romanlutz
Member Author

Regarding the module, it just seemed intuitive to have it in metrics since this is basically a metrics visualization. That's not a terribly important point for me, though, so whatever works for people is fine by me.

Note that this isn't just for GridSearch either. I can plug in my original model, a postprocessed model, and an ExponentiatedGradient model and compare the three of them without touching GridSearch at all. Ideally I'd like it to be generic enough to support such use cases.

I understand the concern with exposing so many knobs for metric aggregation. However, I'm not sure how else we would achieve the same outcome. This isn't just the simple metric, but aggregations of that metric across the groups. If I need to plot overall accuracy vs. selection rate difference, there isn't really another way to express that (as far as I can tell) other than

  • having two args per axis like I described above (metric + aggregator)
  • creating a new object that encapsulates metric + aggregator, and pass that object for each axis. I don't like creating new objects that people aren't familiar with, though.
  • defining different functions plot_model_comparison_* where the * would be all kinds of variations of aggregating a certain way on x and y axis. That's obviously a ton of combinations and not really a good option.

If I've missed another option please point it out.

@adrinjalali
Member

Just so I understand: if I have 3 models and a sensitive attribute with 2 groups, how many points do I want to see in that plot? I was assuming 3.

@romanlutz
Member Author

That matches my expectation exactly. But for each of those 3 models, I have many choices for what the axes should be. Let's say I want to plot accuracy vs. selection rate. Sticking with my earlier API proposal, that would result in

plot_model_comparison(
    x_metric=accuracy_score,
    y_metric=selection_rate,
    y_true=y_true,
    y_preds={"model 1": y_pred_1, "model 2": y_pred_2},  # ... one entry per model
    sensitive_features=sensitive_features,
    show_plot=True)

How does this function internally know whether I want

  • the overall accuracy and selection rate (i.e., ignoring sensitive features)
  • the min (or max, although in this case that wouldn't make sense) accuracy or selection rate
  • the difference between max and min accuracy and selection rates
  • the ratio of min and max accuracy and selection rates
  • OR most likely: a combination thereof, like what I specified in my example with overall accuracy of the model (aka "performance metric") and the selection rate difference (aka "fairness metric").

I see no way of expressing that without extra arguments, hence:

plot_model_comparison(
    x_metric=accuracy_score,  # no default value, this is just for illustrative purposes
    x_metric_aggregator=None,  # optional, default None (means "overall")
    y_metric=selection_rate,  # no default value, this is just for illustrative purposes
    y_metric_aggregator="difference",  # optional, default None (means "overall")
    y_true=y_true,
    y_preds={"model 1": y_pred_1, "model 2": y_pred_2},  # ... one entry per model
    sensitive_features=sensitive_features,
    show_plot=True)

@adrinjalali
Member

plot_model_comparison would expect each metric function to have one of two signatures:

  • metric(y_true, y_pred, sensitive_attribute)
  • metric(y_true, y_pred)

Internally, we can call the first one and, if that raises an error, fall back to the second.
The user would call it as:

from functools import partial

plot_model_comparison(
    x_metric=accuracy_score,
    y_metric=partial(equalized_odds_ratio, method="max"),
    y_true=y_true,
    y_preds={"model 1": y_pred_1, "model 2": y_pred_2},  # ... one entry per model
    sensitive_features=sensitive_features,
    show_plot=True)
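The try-then-fall-back dispatch described in this comment could look roughly like the sketch below. call_metric and the two toy metrics are hypothetical; the sketch assumes that metrics with the first signature take sensitive_features as a keyword argument (as Fairlearn's equalized_odds_ratio does), and it catches only TypeError rather than any exception.

```python
from functools import partial


def call_metric(metric, y_true, y_pred, sensitive_features):
    """Hypothetical dispatcher: try the 3-argument signature, then the 2-argument one."""
    try:
        # Signature 1: metric(y_true, y_pred, sensitive_features=...)
        return metric(y_true, y_pred, sensitive_features=sensitive_features)
    except TypeError:
        # Signature 2: plain metric(y_true, y_pred), e.g. an sklearn-style score
        return metric(y_true, y_pred)


def toy_accuracy(y_true, y_pred):
    # Illustrative 2-argument metric
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def toy_group_metric(y_true, y_pred, *, sensitive_features, method="between_groups"):
    # Illustrative 3-argument metric; it only echoes `method` to show that
    # options bound with functools.partial reach the metric
    return method


y_true, y_pred, sf = [1, 0, 1], [1, 1, 1], ["a", "b", "b"]
x = call_metric(toy_accuracy, y_true, y_pred, sf)  # falls back to signature 2
y = call_metric(partial(toy_group_metric, method="to_overall"), y_true, y_pred, sf)
```

Catching only TypeError keeps the fallback narrow, but it can still mask a genuine TypeError raised inside a metric that does accept sensitive_features. Inspecting inspect.signature(metric) up front would avoid that at the cost of slightly more code.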

@romanlutz
Member Author

I had thought about a similar option, but I still think it would cause some extra trouble for users. equalized_odds_ratio is already defined, but what about just getting the group with the minimum accuracy? Thinking about this a bit more, you could define aggregation variants (at least min, max, diff, ratio) for every kind of metric, but we didn't want that many custom functions like equalized_odds_ratio. Otherwise, we'll end up with #metrics × #aggregations custom functions (potentially multiplied by the number of options within those aggregations, as you've demonstrated with method="max").

That leaves users with something like this:

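from fairlearn.metrics import MetricFrame
from sklearn.metrics import accuracy_score
def accuracy_score_min(y_true, y_pred, sensitive_features):
    return MetricFrame(metrics=accuracy_score, y_true=y_true, y_pred=y_pred,
                       sensitive_features=sensitive_features).group_min()
plot_model_comparison(..., y_metric=accuracy_score_min, ...)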

This is more flexible than my proposal, which can't accommodate the method argument. Still, it feels like quite a lot to write, given that the key idea is easy to express. One could put it all on one line with a lambda, but that certainly doesn't help readability.

A middle path would be to go with your suggestion but create "shortcuts" like equalized_odds_ratio for the most common cases. I imagine there are somewhere between 10 and 20 of them, so perhaps it's not quite as dramatic as my estimate above suggested.

I certainly like having fewer arguments, and for me that pretty much outweighs the concerns I've mentioned. I'm just wondering whether there's another way to avoid creating so many custom functions.

@adrinjalali
Member

I very much prefer to make the user write 3 lines of code rather than having an API which is really hard to understand and remember.

Ideally, for more complicated cases, users would pass a custom metric function created with make_derived_metric.

@MiroDudik
Member

I like @adrinjalali's proposal above. Just a request to make all the arguments keyword-only!
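In Python, a keyword-only signature like the one requested here comes from a bare * at the start of the parameter list. The sketch below is hypothetical: model_comparison_points and the toy metrics are illustrative names, it only computes the points rather than drawing them, and it assumes both metrics accept sensitive_features as a keyword argument, per @adrinjalali's proposal.

```python
def model_comparison_points(*, x_metric, y_metric, y_true, y_preds, sensitive_features):
    """Hypothetical: return one (x, y) point per model; every argument is keyword-only."""
    if not isinstance(y_preds, dict):
        raise TypeError("y_preds must map model names to predictions")
    return {
        name: (
            x_metric(y_true, y_pred, sensitive_features=sensitive_features),
            y_metric(y_true, y_pred, sensitive_features=sensitive_features),
        )
        for name, y_pred in y_preds.items()
    }


def toy_accuracy(y_true, y_pred, *, sensitive_features):
    # Overall accuracy; sensitive_features is accepted but unused here
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def toy_selection_rate(y_true, y_pred, *, sensitive_features):
    # Fraction of positive predictions; also ignores sensitive_features
    return sum(p == 1 for p in y_pred) / len(y_pred)


points = model_comparison_points(
    x_metric=toy_accuracy,
    y_metric=toy_selection_rate,
    y_true=[1, 0, 1, 0],
    y_preds={"model 1": [1, 0, 0, 0], "model 2": [1, 1, 1, 1], "model 3": [1, 0, 1, 0]},
    sensitive_features=["a", "a", "b", "b"],
)
```

Three models yield three points, however many sensitive groups there are. Calling the function positionally, e.g. model_comparison_points(toy_accuracy, toy_selection_rate, ...), raises TypeError, so every call site spells out which metric goes on which axis.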

@MiroDudik
Member

[@romanlutz: for the specific use case you're mentioning, we already support accuracy_score_group_min, but of course I don't want to add too many entries to that table.]

riedgar-ms pushed a commit that referenced this issue May 19, 2021
…766)

With a slight delay (originally targeted for April), I'm finally removing the `FairlearnDashboard` since a newer version already exists in `raiwidgets`. The documentation is updated to instead use the plots @MiroDudik created with a single line directly from the `MetricFrame`. In the future we want to add more kinds of plots, as already mentioned in #758, #666, and #668. Specifically, the model comparison plots do not have a replacement yet.

Note that the "example" added to the `examples` directory is not shown under "Example notebooks" on the webpage, which is intentional since it's technically not a notebook.

This also makes #561 mostly redundant, which I'll close shortly. 
#667 is also directly addressed with this PR as the examples illustrate.

Signed-off-by: Roman Lutz <rolutz@microsoft.com>
@romanlutz romanlutz added this to To do in Visualizations May 27, 2021
@romanlutz
Member Author

This issue is still up for grabs, but whoever wants to do it can start from where @rishabhsamb and I left off, which is captured in this branch: https://github.com/rishabhsamb/fairlearn/tree/refactored-input

There's quite a bit of input validation we needed to refactor first because the validation is almost identical to what we're doing for mitigation techniques, but doesn't require certain parts of it. I suppose that could be its own PR to start with, and maybe I should get that out of the way before someone does this...

@SeanMcCarren
Contributor

Hi! I've been working on this :) I'm still a bit new, so I'd appreciate feedback/help. I also have a few questions. You can check out my progress at https://github.com/SeanMcCarren/fairlearn/tree/model_comparison_plot

  1. Currently, I'm having trouble with from fairlearn.metrics import plot_model_comparison: it doesn't seem to import. I tried adding an import in metrics/__init__.py, to no avail. I also added it to __all__ at the bottom (this is new for me, so I assume something must be wrong here).

  2. Secondly, I'm not sure about the protocol here: do I make a PR at this point? I don't want to waste the precious CI server every time I make a silly commit haha, but it's nice to have tests checked.

  3. Lastly (for now hihi), I saw that there are tests using pytest.mark.mpl_image_compare for other plotting scenarios, but that those are disabled. Should I also write those tests, run them locally, and then disable them?

@riedgar-ms
Member

For (1), have you run pip install -e . from the root of your local copy? Or have you accidentally installed Fairlearn from PyPI in your Python environment?

For (2), definitely go ahead and make the PR. You'll only trigger builds when you do a push, not on every commit. And CPU time isn't that precious.

For (3), I don't think we have a good solution for plotting tests. You can write an example notebook to show their use, but that would end up being a manual check.

@romanlutz
Member Author

I agree with everything @riedgar-ms wrote, and I'm SUPER excited you're interested in this issue!!!

I would completely ignore the matplotlib tests since I could never get them to work in CI. That perhaps warrants a separate issue at some point, but the main thing is that the plot renders correctly on the webpage, which we need to check every time we make a change to the corresponding page. So, in summary, @riedgar-ms is right that we should have documentation with an example. However, I would recommend adding a user guide page as well, ideally based on the example notebook (if you search for literalinclude you'll find examples of what I mean... if those don't make sense, please reach out!).

SeanMcCarren added a commit to SeanMcCarren/fairlearn that referenced this issue Sep 10, 2021
…numpydoc

- also support extra kwargs which are passed to ax.scatter
- only if no ax is provided, automatically infer axis labels
hildeweerts added a commit that referenced this issue Oct 12, 2022
* refactored input

Signed-off-by: Rishabh Sambare <rishabhsamb@gmail.com>

* plot_model_comparison and example in a notebook

The (initial attempt at) the body of plot_model_comparison (probably still missing sanity w.r.t. the x_axis_metric and y_axis_metric... can I just pass sensitive_features always? nothing else required?)

To make it work, added the import in __init__.py

Also created an example in one of the existing notebooks as I believe it is a fitting context!

* Implemented plot_model_comparison as specified in #666 with numpydoc

- also support extra kwargs which are passed to ax.scatter
- only if no ax is provided, automatically infer axis labels

* Fixed small bugs

* Start documentation

Example notebook section
User guide section (with literalinclude)

* rewording

* Update docs/user_guide/assessment.rst

Co-authored-by: Roman Lutz <romanlutz13@gmail.com>

* Update fairlearn/metrics/_multi_curve_plot.py

Co-authored-by: Roman Lutz <romanlutz13@gmail.com>

* Update fairlearn/metrics/_multi_curve_plot.py

Co-authored-by: Roman Lutz <romanlutz13@gmail.com>

* Added :end-before: for safety with literalinclude

* Typechecking y_preds and not showing plot

* Testcase for comparison plot when there is no matplotlib installed

* Removed show_plot in User Guide example

* remove show_plot here also

* impl roman's suggest on 15 sept to replace "_" with " "

* add point color and label as a baseline

* flake8

* Update _multi_curve_plot.py

* Enumerate over dictionary keys to get indices

* group_by_name kw: plot models with same prefix same color

* Also change point label text

* changes as described in PR

* fixed point_labels

* rename groups to model_group

* model_kwargs instead of group_kwargs

* Combine calls to scatter if kwargs are equal

* lint

* Delete ._plotting.py.swp

* legend_kwargs

* Improve plot_model_comparison example

* Got rid of redundant preprocessing changes

* flake8

* fix test

* implement changes suggested by @MiroDudik

* variable name change

* fix param name

* unused import

* Add point_labels_position suggestion from @MiroDudik

* Suggestion to not use global variable

* Bug fix

* Basic testing

* flake8

* small rewrite

* test coverage

* fix tests with pytest.raises

* more coverage

* More consistent handling of sf and cf

* Additional test coverage

* Fixing faulty test cases

* flake8

* Apply suggestions from code review

Co-authored-by: MiroDudik <mdudik@gmail.com>

* simplified selection rate difference

* ask for list instead of tuple

* make backwards compatible

* Update fairlearn/metrics/_multi_curve_plot.py

Co-authored-by: MiroDudik <mdudik@gmail.com>

* Remove legend kwargs

* Update fairlearn/metrics/_multi_curve_plot.py

Co-authored-by: Richard Edgar <riedgar@microsoft.com>

* refactor renaming

* Testing errors

* Remove test for legend_kwargs as we removed legend_kwargs as arg

* Remove unnecessary comment & some final test coverage

Signed-off-by: Rishabh Sambare <rishabhsamb@gmail.com>
Co-authored-by: Rishabh Sambare <rishabhsamb@gmail.com>
Co-authored-by: Roman Lutz <romanlutz13@gmail.com>
Co-authored-by: MiroDudik <mdudik@gmail.com>
Co-authored-by: Hilde Weerts <24417440+hildeweerts@users.noreply.github.com>
Co-authored-by: Richard Edgar <riedgar@microsoft.com>
@romanlutz
Member Author

#947 addressed this! Closing 🙂

Projects: Visualizations (To do)
Development: no branches or pull requests
5 participants