
Add multi target classification #441

Merged

Conversation

@YonyBresler (Contributor) commented on May 1, 2024:

Added the ability for multi-target classification, discussed in #430.

Highlighting some of the changes:

  1. Multi-target classification is initiated by passing a list of targets, the same as in multi-target regression (a usage sketch follows this list)
  2. In the inferred config, added 'output_cardinality' to track the number of classes per target
  3. label_encoder is now a list; label_encoder[i] corresponds to the i-th target
  4. Updated unit tests for all models to include a multi-target variation of classification
  5. Updated documentation to remove the stated limitation that multi-target classification is not supported
  6. [pre-commit.ci] auto fixes from pre-commit.com hooks
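
For context, here is a minimal usage sketch of what the feature enables. The column names and synthetic data are purely illustrative, and CategoryEmbeddingModelConfig is just one example model config; the essential part is passing a list of targets together with task="classification":

```python
import numpy as np
import pandas as pd

from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import CategoryEmbeddingModelConfig

# Tiny synthetic dataset with two classification targets (illustrative only)
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "num_1": rng.normal(size=200),
    "num_2": rng.normal(size=200),
    "cat_1": rng.choice(["a", "b"], size=200),
    "target_a": rng.choice([0, 1, 2], size=200),  # 3 classes
    "target_b": rng.choice([0, 1], size=200),     # 2 classes
})
train_df, test_df = df.iloc[:150], df.iloc[150:]

data_config = DataConfig(
    target=["target_a", "target_b"],  # a list of targets, same pattern as multi-target regression
    continuous_cols=["num_1", "num_2"],
    categorical_cols=["cat_1"],
)
model_config = CategoryEmbeddingModelConfig(task="classification")

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=OptimizerConfig(),
    trainer_config=TrainerConfig(max_epochs=5),
)
tabular_model.fit(train=train_df)
pred_df = tabular_model.predict(test_df)  # predictions for each target column
```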

Things that I'm happy to improve:

  1. With custom metrics, we now need a set of parameters for each metric and for each target (since the number of classes can vary). I'm not very familiar with OmegaConf; I create a sub_params_list config object with OmegaConf.create() on the fly, but I wonder whether there's a cleaner way to do this (a sketch of the approach follows this list).
  2. Bagging predict would require significant changes to allow multi-target classification. It is not clear to me whether multi-target regression is currently supported in tabular_model._combine_predictions(). For now, bagging only works correctly with single-target classification.
  3. While there are tests checking that multi-target classification works, all current model tests only verify that it runs without errors, not that it produces correct results.
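
As a rough illustration of the current approach (the field names "task" and "num_classes" are illustrative torchmetrics-style kwargs, not necessarily exactly what the code passes through):

```python
from omegaconf import OmegaConf

# One dict of metric parameters per target, since output_cardinality can differ between targets
output_cardinality = [3, 2]  # e.g. target_a has 3 classes, target_b has 2
sub_params_list = OmegaConf.create(
    [{"task": "multiclass", "num_classes": c} for c in output_cardinality]
)

assert sub_params_list[0].num_classes == 3  # indexable like a list, attribute access like a config
```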

📚 Documentation preview 📚: https://pytorch-tabular--441.org.readthedocs.build/en/441/

@manujosephv (Owner) commented:

Thanks a lot for this PR. This is something that was requested by a lot of folks in the community! I'm running a bit from pillar to post at the moment, but I'll take a look at the code as soon as I get a chance :)

@manujosephv (Owner) commented:

Thank you @YonyBresler for the excellent PR. I have almost no suggestions, except one minor thing about an error message! And apologies for the tardiness in reviewing the PR; things have been crazy at work.

Also, the custom_metrics solution is something I can live with.

One more thing I would add is a tutorial notebook, maybe a how-to guide? That would be a good place to set down how one would use multi-target classification.

@manujosephv (Owner) commented:

@YonyBresler Did you manage to take a look at the comment? Or do you want me to make the changes and merge the PR? Let me know. This is a very useful and much-requested feature.

@YonyBresler (Contributor, Author) commented:

Hi @manujosephv, it's in progress. While preparing the notebook I found and fixed a bug in how the metric is reported in some circumstances (already committed).

There's still an issue with certain metrics (those that require probabilities) and auto-lr for models like GANDALF: it causes an error, and I believe things are not getting initialized properly with the new changes.

Let me see if I can resolve it; if not, I may add a restriction so that auto-lr is not used for multi-target classification with models that can't handle it right now. But I hope I can fix it.
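
In the meantime, a possible user-side workaround is simply to leave the learning-rate finder off (a sketch; auto_lr_find is the relevant TrainerConfig flag):

```python
from pytorch_tabular.config import TrainerConfig

# Keeping auto_lr_find disabled avoids triggering the LR finder for
# multi-target classification until the initialization issue is resolved.
trainer_config = TrainerConfig(auto_lr_find=False, max_epochs=20)
```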

@YonyBresler (Contributor, Author) commented:

Thanks for your patience @manujosephv. I've resolved the issue (it wasn't a bug, but rather an improper configuration) and added a tutorial notebook that walks through multi-target classification (albeit with a very rudimentary second target).

As far as I can tell, it should be good to go. Please take a look when you get a chance and let me know if there's any issue or if it's ready to merge.

Thanks!

@manujosephv (Owner) commented:

There is some error in the test cases. I think it's not your code but some library compatibility issue. I'll try to figure it out as soon as I get some time.

@YonyBresler (Contributor, Author) commented:

It's weird: I tested on a clean Python 3.10.14 install and all tests pass with no library issue, so I'm not sure what's causing it in your test script.
Let me know if there's something I can do to help resolve this.

@YonyBresler (Contributor, Author) commented:

Thanks for resolving this issue @Borda!

@@ -2035,23 +2036,21 @@ def _combine_predictions(
        elif callable(aggregate):
            bagged_pred = aggregate(pred_prob_l)
        if self.config.task == "classification":
            classes = self.datamodule.label_encoder.classes_
            # FIXME need to iterate .label_encoder[x]
@manujosephv (Owner) commented on this diff:

Maybe we should raise an error if somebody attempts bagging predict with multi-target classification?
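
A minimal sketch of such a guard, assuming the number of targets is available via self.config.target and that label_encoder is now a list (change 3 in the PR description):

```python
if self.config.task == "classification":
    if len(self.config.target) > 1:
        # Assumed attribute names; the guard simply refuses multi-target bagging for now
        raise NotImplementedError(
            "Bagging predict is not yet supported for multi-target classification."
        )
    classes = self.datamodule.label_encoder[0].classes_
```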

@manujosephv merged commit 25691f5 into manujosephv:main on Sep 17, 2024. 9 checks passed.