Modularising models #40
Conversation
…ing to call close_connection even if database has not been queried prior to being passed to DataLoader
…s of graphs) and make event_no separate from features and targets
Hi Andreas,
What medium/link shall we meet on?
Best regards, Troels
… On 26 Oct 2021, at 09.38, Andreas Søgaard ***@***.***> wrote:
Hi @RasmusOrsoe and @mhaminh,
I am creating this (massive) pull request as an initial proposal for how we can structure model building in a flexible, modular fashion. The main contribution of this PR is the creation of separate Detector, GNN, and Task classes that are chained together through the model Model class. You can see what this looks like in examples/test_model_training_sqlite.py. This way:
Different objectives (detector-specific data ingestion and preprocessing; GNN model building; and task-specific readout), respectively, are handled by dedicated, restricted classes that can be switched out in a plug-and-play fashion
Code duplication (e.g. graph building, read-out) is reduced between models
Comparing different model architectures 1:1 becomes as easy as switching GNN=DynEdge(**kwargs) to GNN=ConvNet(**kwargs)
There is less boilerplate/overhead required to develop, implement, and test novel GNN architectures in the future
Applying existing GNN architectures to novel tasks becomes close to trivial, as it just requires switching the task(s) parameter in Model.
Etc.
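To make the proposed structure concrete, here is a minimal, hypothetical sketch of how the chaining could look. The class internals, constructor signatures, and the plain-Python style are illustrative only (the actual classes are PyTorch modules in the PR), but the composition pattern is the one described above:

```python
# Illustrative sketch (not the actual PR code): Detector, GNN, and Task
# components chained by a Model class. Internals are placeholders.

class Detector:
    """Detector-specific data ingestion/preprocessing (e.g. IceCube-86)."""
    def forward(self, data):
        # e.g. scale features, build the graph for this detector geometry
        return data

class GNN:
    """Backbone architecture, e.g. DynEdge or ConvNet."""
    def forward(self, data):
        # produce per-event latent representations
        return data

class Task:
    """Task-specific readout, e.g. energy or angular reconstruction."""
    def forward(self, latent):
        return latent

class Model:
    """Chains detector -> gnn -> task(s), so any part can be swapped out."""
    def __init__(self, detector, gnn, tasks):
        self.detector, self.gnn, self.tasks = detector, gnn, tasks

    def forward(self, data):
        data = self.detector.forward(data)
        latent = self.gnn.forward(data)
        return [task.forward(latent) for task in self.tasks]

# Swapping architectures is then a one-argument change, e.g.
# gnn=DynEdge(**kwargs) -> gnn=ConvNet(**kwargs):
model = Model(detector=Detector(), gnn=GNN(), tasks=[Task()])
```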
All I have done is tried to take the common bits from your respective models and make dedicated, common classes for these; and take the unique bits and turn them into separate classes inheriting from GNN.
@RasmusOrsoe: The only non-trivial (intentional, at least) difference wrt. the previous modelling code is that I have commented out the re-calculation of graph connectivity at each message-passing layer in ConvNet, because I have tried to relegate graph-building to the detector-specific read-in layer, much like @mhaminh has been doing. We can discuss whether we should prioritise supporting the other functionality as well.
Please provide as many and as detailed comments as you want; and if we feel this is a promising direction to pursue, I will do my best to update the PR to everyone's liking.
NB: I haven't touched the Trainer/Predictor classes, which may be handled by pytorch-lightning, but we can explore/discuss that in a separate PR.
Commit Summary
Refactor contents of components/utils to better reflect contents <9323abe>
Move save_results to relevant utils module <ac8d279>
Remove call to establish_connection in initialiser so as to avoid having to call close_connection even if database has not been queried prior to being passed to DataLoader <2185527>
Convert all graph attributes to tensors (to allow for indexing batches of graphs) and make event_no separate from features and targets <6a7133d>
Make features and targets package-level constants (DRY) <89bb57f>
Make loss functions class-type, inheriting from pytorch losses <db4bfb4>
Add GraphBuilder class and first example using k-nearest neighbours <d80f74f>
Refactor model construction and training <74c3dbb>
Update model training example to use refactored model structure <7a7ba53>
Move common functionality to base Detector class <3628c6f>
Change reduction from sum to mean to make results comparable across batch size <02dae10>
Add a few docstrings <a781498>
Add type hinting <459fc39>
Add sanity checks on inputs <29a33d2>
Resolve merge conflict <https://github.com/icecube/gnn-reco/pull/40/commits/752ac3466d22a5b0404a4e3dd6b026659b297317>
File Changes (24 files <https://github.com/icecube/gnn-reco/pull/40/files>)
M examples/test_model_training_sqlite.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-101a4808aa1ba4c30b19db22124024bc083a66f2caae9748885bb73f5843f096> (128)
M src/gnn_reco/components/loss_functions.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-28ad1fad75685196f83a2ca79cd2705f27d161f8ee73798a9afb22fd1f76dfb0> (138)
M src/gnn_reco/components/utils.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-431a01f1b0c9763a3dfe66a2a87af9ee1c5672f05d432705649c97470662dcc7> (324)
A src/gnn_reco/data/constants.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-ed780645b4ee307933f01d16d9ed5777eeb0e36558a2158a80c3d27c656960b9> (27)
M src/gnn_reco/data/sqlite_dataset.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-abb56473e2b8bd6f07a5127359fb83021cea64a2308541034378cc5b958df695> (29)
A src/gnn_reco/models/__init__.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-17605bc07a958f6300418bbfaa7aa9105eb0114948cc556fda43e128cc3b9515> (1)
D src/gnn_reco/models/convnet.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-5169f1fe62be0d7978c3510df1fe545b9a570822929f6fc64d45681aaf0db101> (134)
A src/gnn_reco/models/detector/__init__.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-42b31af34befcf971957f583cb397bc5e5437521b754d9b9393b4e96366fe6a4> (1)
A src/gnn_reco/models/detector/detector.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-dac99a48dbe75cad35d95f5367cdc1a784f79738d238d37b8d17cd4d1e5af5d8> (50)
A src/gnn_reco/models/detector/icecube86.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-3bedac14bc63cb05b33b468e796436ca79beefd382852b99abfe01e8f7c6a489> (49)
D src/gnn_reco/models/dynedge.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-13a5f01c463d840b56756f5bf46948f303a08450f7c79ccf6d0be8da858a3cf5> (110)
A src/gnn_reco/models/gnn/__init__.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-d11c2ffea1eb4b6044ca5203c1d28a8184aefce8e18ea38716f8d00d48b4139f> (3)
A src/gnn_reco/models/gnn/convnet.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-17e5f5aaec8124631d6d6c158f42dddfd2b26c67e51541b6a051d2e01d6541d7> (111)
A src/gnn_reco/models/gnn/dynedge.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-61323b71a3a29b8d643f8e8c62eec8a00c65f9f8702e32d9780f6b8c617b4b2c> (140)
A src/gnn_reco/models/gnn/gnn.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-da2988156e7ccb0ff980f1c9924cce256db71e17633332fdfebbcd2ab22e1b2a> (21)
A src/gnn_reco/models/graph_builders.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-28cea8e279c54b00730c7042b354dff01d5ac9616e6c658160e41d7ffae03b39> (31)
A src/gnn_reco/models/model.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-c6017f4c1fd7e6174f1b0e43921f15a16d10153668e071c0883cefe3a1741b96> (65)
A src/gnn_reco/models/task/__init__.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-527dcbf4eec02401a0d5d31365646fe58405065a0235c2c8b1bafed7b3630363> (1)
A src/gnn_reco/models/task/reconstruction.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-9af66145d1e613baebb979f12b146a070e0f950bd0ccf55020d253b81fd4dd69> (33)
A src/gnn_reco/models/task/task.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-93591d2ec2265c35e38a02788a31ee892717b219ac4ff6448c71f1ae86336bc7> (32)
A src/gnn_reco/models/training/__init__.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-1829a37c18638b576f9132b5a1ae5038fe3f60e373480d96d3d16f93fb867ce8> (0)
A src/gnn_reco/models/training/callbacks.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-0491bea380c86f5cfbf7672ae02556d4051df1ec4d52633bfcebf0f90043c598> (81)
A src/gnn_reco/models/training/trainers.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-dc9a7d1522fdf5ef075e9a232915abf83ed0bd5fb32b9a93f3e3a78de00015a2> (178)
A src/gnn_reco/models/training/utils.py <https://github.com/icecube/gnn-reco/pull/40/files#diff-6d6ac9f24a007bb1a55f7d25dc9b5ed35895d9f5aa10eeeb138469020ce68f3b> (35)
… them and can update parameters during training
I quite like this.
However, it is absolutely necessary for the edge_indices to be recalculated on the fly in DynEdge. That's the "Dyn" in DynEdge. We have no idea how this change affects the performance, and it's something I feel strongly we should not mess with.
Also, I cannot see how we get to play around with preprocessing - the scalers seem to have been replaced with some static code? Maybe we should talk about this - it kind of breaks the work I'm doing for the paper.
Also, two minor things:
Let's please keep the original log-cosh loss function as well (we can have both the original one and the approximated one)
How do we handle loss-function-specific final activation functions? Currently it seems we have lost the tanh(x) activation required by the von Mises-Fisher loss in the angular task.
Rasmus
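For context on the two log-cosh variants discussed above, here is a minimal pure-Python sketch. The actual implementations inherit from PyTorch losses and operate on tensors; the stable rewrite shown is one standard formulation and is assumed, not taken from the PR:

```python
# Sketch: naive log-cosh vs. a numerically stable rewrite.
import math

def log_cosh_exact(x):
    # Direct form; math.cosh overflows for |x| greater than roughly 710.
    return math.log(math.cosh(x))

def log_cosh_stable(x):
    # Equivalent rewrite: log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log(2).
    # Stays finite for large |x|, where it approaches |x| - log(2).
    return abs(x) + math.log1p(math.exp(-2.0 * abs(x))) - math.log(2.0)
```

The two agree to machine precision wherever the exact form is representable, which is why keeping both (as suggested above) costs nothing in accuracy for typical residuals.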
…t general, exact von-Mises Fisher loss function; and improve structure of Tasks and LossFunctions
…having to manually call close_connection
…to modularising-models
Hi @RasmusOrsoe,
…to modularising-models
…to modularising-models
…down-stream task(s)
…ng script to illustrate use
Hi @RasmusOrsoe,
Looks good!
# Calculate homophily (scalar variables)
h_x, h_y, h_z, h_t = calculate_xyzt_homophily(x, edge_index, batch)
This is not my model! We need to recalculate the connectivity between each convolutional pass!
It's a central point of the approach (see https://arxiv.org/abs/1801.07829)
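To illustrate the point being made here, the following is a hypothetical pure-Python sketch of dynamic edge recomputation in the spirit of arXiv:1801.07829: the k-nearest-neighbour edges are rebuilt from the *updated* node features after every layer, so connectivity evolves with the learned features. The "convolution" below is a neighbour-averaging placeholder, not the real EdgeConv:

```python
# Sketch of the "Dyn" in DynEdge: edges are recomputed each layer.

def knn_edges(features, k):
    """Return (source, target) edges linking each node to its k nearest
    neighbours in feature space (brute force, for illustration)."""
    edges = []
    for i, fi in enumerate(features):
        dists = sorted(
            (sum((a - b) ** 2 for a, b in zip(fi, fj)), j)
            for j, fj in enumerate(features) if j != i
        )
        edges.extend((j, i) for _, j in dists[:k])
    return edges

def dynedge_forward(features, num_layers=3, k=2):
    for _ in range(num_layers):
        edges = knn_edges(features, k)   # recomputed from current features
        # Placeholder "convolution": average each node with its neighbours.
        # Each node receives exactly k incoming edges by construction.
        new = [list(f) for f in features]
        for src, dst in edges:
            new[dst] = [a + b for a, b in zip(new[dst], features[src])]
        features = [[v / (k + 1) for v in f] for f in new]
    return features
```

Moving graph-building to a detector-specific read-in layer, as this PR does, fixes `edges` once up front; the sketch shows what is lost when the `knn_edges` call inside the loop is removed.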