Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-class classification implementation #329

Merged
merged 70 commits into from
Dec 6, 2022

Conversation

MortenHolmRep
Copy link
Collaborator

Implementation of multi-class classification predictions as per #111 , where the new operational structure would imply
Raw Data -> [noise, muon, neutrino]

In the current implementation, this could also easily be expanded to
Raw Data -> [noise, muon, nu_e, nu_mu, nu_tau]
But a two-step classification, first on [noise, muon, neutrino] and then a second classification, with a neutrino-specific task, training with the identified neutrinos on [nu_e, nu_mu, nu_tau] is probably preferred for neutrino classifications.
comments on an approach or alteration to this multi-class classification implementation are welcome.

Training with the task using examples/train_model.py works with the following alterations, where the target are pid:

task = MulticlassificationTask(
        hidden_size=gnn.nb_outputs,
        target_labels=config["target"],
        loss_function=MultiClassificationCrossEntropyLoss(),
    )

and for the results

results = get_predictions(
        trainer,
        model,
        validation_dataloader,
        [config["target"] + "_noise_pred", config["target"] + "_muon_pred", config["target"]+ "_neutrino_pred"],
        additional_attributes=[config["target"], "event_no"],
    )

Implementation by @Peterandresen12 and I

@MortenHolmRep MortenHolmRep added the feature New feature or request label Oct 27, 2022
@MortenHolmRep MortenHolmRep added this to the v1.0.0 / Features milestone Oct 27, 2022
Copy link
Contributor

@Peterandresen12 Peterandresen12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me :-)

@asogaard asogaard removed this from the v1.0.0 / Features milestone Oct 28, 2022
Copy link
Collaborator

@asogaard asogaard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MortenHolmRep and @Peterandresen12,

Thanks for this PR! It would make a nice addition to the repo. 🙂

I have added a few suggestions and comments that center on the fact that, currently, the proposed task and loss function hard-code assumption about the number of classes, and the PIDs of the particles being classified. We should try to be more general if we want to implement generic multi-class classification.

src/graphnet/models/task/reconstruction.py Outdated Show resolved Hide resolved
src/graphnet/training/loss_functions.py Outdated Show resolved Hide resolved
src/graphnet/training/loss_functions.py Outdated Show resolved Hide resolved
src/graphnet/training/loss_functions.py Outdated Show resolved Hide resolved
src/graphnet/training/loss_functions.py Outdated Show resolved Hide resolved
src/graphnet/training/loss_functions.py Outdated Show resolved Hide resolved
src/graphnet/training/loss_functions.py Outdated Show resolved Hide resolved
@asogaard asogaard linked an issue Oct 28, 2022 that may be closed by this pull request
@MortenHolmRep
Copy link
Collaborator Author

Hi @MortenHolmRep and @Peterandresen12,

Thanks for this PR! It would make a nice addition to the repo. 🙂

I have added a few suggestions and comments that center on the fact that, currently, the proposed task and loss function hard-code assumption about the number of classes, and the PIDs of the particles being classified. We should try to be more general if we want to implement generic multi-class classification.

I will look into a rework tonight, based on the suggestions 👍

MortenHolmRep and others added 6 commits October 28, 2022 19:39
device inheritance

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
loss function renaming

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
description reformulation

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
trimming description

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
description reformulation

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
dynamical multiclass classification

Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
@asogaard
Copy link
Collaborator

asogaard commented Nov 1, 2022

Hi @MortenHolmRep,

Thanks for the quick iteration! :) I have added comments to a few threads. Please let me know if there's anything else you'd like to discuss there. Otherwise, feel free to re-request a review once you think the code in the PR is ready for a second look. :)

@asogaard asogaard removed their assignment Nov 1, 2022
@MortenHolmRep
Copy link
Collaborator Author

Thanks for your review @asogaard! I have added some comments for now. The suggestion you made I'll have to test later today.

@asogaard asogaard assigned MortenHolmRep and unassigned asogaard Dec 1, 2022
MortenHolmRep and others added 3 commits December 1, 2022 14:14
Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
Co-authored-by: Andreas Søgaard <andreas.sogaard@gmail.com>
@asogaard
Copy link
Collaborator

asogaard commented Dec 1, 2022

@asogaard Can you review the PR again? I have attempted to clean the code, config and unused packages. The generic classification now works and produces good results. The build fails for the LogCMK class in loss_functions, which is unrelated to my implementation, I do not know why I am getting this error.

Regarding the failing unit test: As far as I can tell, the fact that we change the assertion logic regarding transform_target/transform_inference will affect (some of) the unit test(s) that check for this. This is perfectly OK — we will just have to update the unit tests to reflect this new logic. :) I can have a look at the unit test results once you've had a chance to look at the comments above.

@MortenHolmRep
Copy link
Collaborator Author

I have added a script called "train_classification_model_without_configs.py" that follows the old structure and a new script called "train_classification_model.py" that should adhere to the new structure.
The new setup with configs is not done as I need input on how to set it up properly.

Copy link
Collaborator

@asogaard asogaard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @MortenHolmRep,

I think this looks great! If it runs as expected, I am all for merging this. There is still the problem of the failing unit test that needs to be resolved. I think it is just a matter of deleting this block as it no longer represents expected behaviour. :)

@asogaard asogaard removed their assignment Dec 6, 2022
@Peterandresen12
Copy link
Contributor

Well done!

@Peterandresen12 Peterandresen12 merged commit b77b25c into graphnet-team:main Dec 6, 2022
RasmusOrsoe pushed a commit to RasmusOrsoe/graphnet that referenced this pull request Oct 25, 2023
Multi-class classification implementation
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement multi-class classification
3 participants