Fix load from checkpoint in python 3.12 #741
Conversation
@KnathanM thank you for this incredible and thorough investigation of the problem. I have learned a lot just from reading this! I like the way this is implemented as is - I don't think we need to modify the other method for the sake of just cutting down the code a bit. I also like the way you have provided the implementation in […]
So if I'm understanding you correctly, we can't load any models in Python 3.12 because of the in-lining of comprehensions?
chemprop/models/model.py (Outdated)
@@ -255,3 +248,22 @@ def load_from_file(cls, model_path, map_location=None, strict=True) -> MPNN:

```python
        model.load_state_dict(state_dict, strict=strict)

        return model


def load_submodules(checkpoint_path, **kwargs):
```
Can this be an `@classmethod`? If possible, it seems better because this function is strongly coupled to the `load_checkpoint()` method.
Originally I put it outside `MPNN` to suggest that it also loads submodules for `MulticomponentMPNN`. But thinking about it again, I agree it should probably be a class method of `MPNN` because then it is more clearly inherited by `MulticomponentMPNN` when it pulls from `MPNN`. I'll wait to see if Hao-Wei has any comments to address before making that change on Monday.
chemprop/models/model.py (Outdated)

```python
    if "blocks" in hparams["message_passing"]:
        mp_hparams = hparams["message_passing"]
        mp_hparams["blocks"] = [
            block_hparams.pop("cls")(**block_hparams) for block_hparams in mp_hparams["blocks"]
        ]
        message_passing = mp_hparams.pop("cls")(**mp_hparams)
        kwargs["message_passing"] = message_passing
```
This seems somewhat error prone. What if the `hyper_parameters` key for an `MPNN` contains `"blocks"`? Reasonably, we should ignore this, but sharing this code between the two classes necessitates the assumption of its structure.
I agree this isn't ideal. If we left `load_submodules()` outside of `MPNN`, we could bring back `load_from_checkpoint()` in `MulticomponentMPNN` and put a similar function in `multi.py`. Or, if we make `load_submodules()` a classmethod of `MPNN`, then `MulticomponentMPNN` could overload it. The only catch is that each needs to be self-sufficient because we can't call `super()`.
In `MPNN`:

```python
@classmethod
def load_submodules(cls, checkpoint_path, **kwargs):
    hparams = torch.load(checkpoint_path)["hyper_parameters"]
    kwargs |= {
        key: hparams[key].pop("cls")(**hparams[key])
        for key in ("message_passing", "agg", "predictor")
        if key not in kwargs
    }
    return kwargs
```
In `MulticomponentMPNN`:

```python
@classmethod
def load_submodules(cls, checkpoint_path, **kwargs):
    hparams = torch.load(checkpoint_path)["hyper_parameters"]
    hparams["message_passing"]["blocks"] = [
        block_hparams.pop("cls")(**block_hparams)
        for block_hparams in hparams["message_passing"]["blocks"]
    ]
    kwargs |= {
        key: hparams[key].pop("cls")(**hparams[key])
        for key in ("message_passing", "agg", "predictor")
        if key not in kwargs
    }
    return kwargs
```
This works for loading both `MPNN` and `MulticomponentMPNN` models when `load_from_checkpoint()` is a class method of `MPNN` that calls `cls.load_submodules()`. What do you think?
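For reference, a sketch (my illustration, not code from this PR) of how such a `load_from_checkpoint()` on `MPNN` could hand off to `cls.load_submodules()` so the right subclass's version runs:

```python
@classmethod
def load_from_checkpoint(cls, checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs):
    # cls is MulticomponentMPNN when called on that subclass, so its overload runs.
    kwargs = cls.load_submodules(checkpoint_path, **kwargs)
    return super().load_from_checkpoint(
        checkpoint_path, map_location=map_location, hparams_file=hparams_file, strict=strict, **kwargs
    )
```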
I like the second option better. Obviously there's some duplicate code, but it's only across one class and it's not strictly worse style-wise either. In this case, a user doesn't have to bark up the class hierarchy to see what the function does, as it's all contained in one place.
Models can still be loaded, but the `message_passing`, `agg`, and `predictor` submodules need to be loaded in a function that does not call `super()`. In 3.11 these submodules were loaded in a dictionary comprehension, which added an extra function call, but 3.12's in-lining of comprehensions removed that extra call. I think we were just lucky that before 3.12 comprehensions had an extra function call. The wider problem is that lightning assumes that the call to […]. But it is also confusing to me that they do this recursion based on […].
You're right that it is brittle, but that's just the way code goes... My guess is that for the large majority of cases, the code works even if it has an obvious failure mode, and the reason is that most client code is not designed with modularity in mind. That is, if you ever go to a repo for a particular ML model's implementation, the model is defined by way of its parameterization rather than its composition. Put more concretely, consider most FFN implementations in PyTorch (including ours); they're usually defined as something like this:

```python
from torch import nn

class FFN(nn.Module):
    def __init__(self, in_features, num_layers, hidden_size, out_features, dropout):
        super().__init__()
        layers = [nn.Linear(in_features, hidden_size)]
        layers += sum(
            ([nn.Dropout(dropout), nn.ReLU(), nn.Linear(hidden_size, hidden_size)] for _ in range(num_layers - 1)),
            [],
        )
        layers += [nn.Linear(hidden_size, out_features)]
        self.layers = nn.Sequential(*layers)
```

You'll notice that the logical concept of an "FFN" here is defined by some set of hyperparameters, but really it's only one specific parameterization of an FFN. A resnet with a skip connection between the input and final layer is also an FFN, but not according to this implementation. Really, an FFN can be configured in any number of ways, and a more general view is that an FFN is a differentiable function composed of individual transformations that feed forward from the input to the output. You could define your FFN like that, taking in the individual transformations (e.g., […]).

Consider also the v1 chemprop model: the model's architecture was hardcoded and accepted a number of hyperparameters with which to configure it. But really, the chemprop model defined in the original publication is just one type of MPNN. There are many ways to define an MPNN, which is abstractly a composition of functions on a graph: […]

However, the pytorch lightning object model assumes that most users opt for the former object model in their own code. It's an unfortunate reality of the situation and why we've had to design our own workarounds.
LGTM, would like this in soon.
linter was not run on chemprop#741
This PR fixes an interesting error with our use of `save_hyperparameters()` that prevents loading checkpoint files in Python 3.12. We have had discussion of this in several places (#714 and #738). I learned a lot while debugging this and felt it was too long for a comment, so I opened this PR both to isolate the change to demonstrate that it works and to have the space to explain what is going on. I'd also be happy to include this change in the other PR and close this one.

The short version

When we call `save_hyperparameters()` in the `__init__` of our modules (e.g. `_MessagePassingBase`), lightning will "Recursively collects the arguments passed to the child constructors in the inheritance tree." It knows when to stop by looking for the absence of `__class__` in the local variables. It can do this because `super().__init__()` is required in the `__init__` of modules that inherit from `LightningModule`, and this adds `__class__` to the local variables. If our call to `__init__` originates in a scope that also has `__class__` (as is the case in `load_from_checkpoint()`), then lightning won't know where to stop. In Python 3.11 we were okay because we instantiate the modules in `load_from_checkpoint()` within a comprehension, which has its own scope. Python 3.12 sped up comprehensions by no longer giving them their own scope. We can fix our use of `save_hyperparameters()` by instantiating the modules within a different scope that doesn't have `__class__`, done here by calling a helper function, `load_submodules()`.

The longer version
The error
A simple test reproduces the error:
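The original test snippet isn't visible in this view; a minimal sketch of the kind of test, with a hypothetical checkpoint path and assuming `MPNN` is importable from `chemprop.models`:

```python
from chemprop.models import MPNN

# Hypothetical path to a checkpoint saved by a previous training run.
checkpoint_path = "example_model.ckpt"

# Under Python 3.12 this raises inside lightning's save_hyperparameters() machinery.
model = MPNN.load_from_checkpoint(checkpoint_path)
```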
The traceback shows the error originates in `_get_init_args()` in `lightning/pytorch/utilities/parsing.py`.

The stack frame
The error occurs because lightning looks through the local variables (`local_vars`) to pick out what parameters are needed to re-initialize the module later (i.e. the hyperparameters), but some are missing, specifically `self`, the module object that is currently being created. We see in `_get_init_args()` that `local_vars` comes from `_, _, _, local_vars = inspect.getargvalues(frame)`. The frame here is part of the stack frame, which is like an onion whose layers keep track of the scopes (local variables) of the various function calls. (Potentially helpful reference.) When the code hits an error, I think the stack trace comes from following the stack frame back out to where we started. Side note: `init_parameters` comes from inspecting the signature of the class's `__init__`: `init_parameters = inspect.signature(cls.__init__).parameters`.
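A small standalone sketch (not from the PR) of what `inspect.getargvalues()` on a caller's frame gives you:

```python
import inspect

def inner(a, b):
    # Grab the frame of the function that called inner(), i.e. outer().
    caller_frame = inspect.currentframe().f_back
    _, _, _, local_vars = inspect.getargvalues(caller_frame)
    return sorted(local_vars)

def outer(x):
    y = x + 1
    return inner(x, y)

print(outer(1))  # ['x', 'y'] -- the caller's local variables
```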
Recursion
Why does the frame not have the variables that `__init__` needs? `_get_init_args()` is called by `collect_init_args()`, which says it "Recursively collects the arguments passed to the child constructors in the inheritance tree." Part of this function checks if `__class__` is in the local_vars and, if it is, it will call `collect_init_args()` again but with `frame.f_back` as the new frame. (Side note: it also checks if the class in the current frame inherits from `HyperparametersMixin`, but I don't think that is relevant to our error.) What `frame.f_back` does is get the frame one up in the stack. So `collect_init_args()` will continue back up the stack until it hits a frame that doesn't have `__class__`.
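As a toy illustration of the walk just described (my own sketch, not lightning's actual code):

```python
import inspect

def collect_locals(frame, collected=None):
    """Record each frame's local variable names and keep climbing via f_back
    while the frame has __class__ among its local variables."""
    collected = [] if collected is None else collected
    _, _, _, local_vars = inspect.getargvalues(frame)
    collected.append(set(local_vars))
    if "__class__" in local_vars and frame.f_back is not None:
        return collect_locals(frame.f_back, collected)
    return collected
```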
Inspecting the frames
The initial frame that gets given to `collect_init_args()` is generated in the call to `save_hyperparameters()` in the module `__init__`. I'll use `_MessagePassingBase` as an example. When `self.save_hyperparameters()` is called, Python goes into `lightning.pytorch.core.mixins.hparams_mixin`, which defines `save_hyperparameters()`. This internally calls a different `save_hyperparameters()` that is in `lightning.pytorch.utilities.parsing`. The call looks like `save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)`, where `frame` comes from the couple of lines above: `current_frame = inspect.currentframe()`, `frame = current_frame.f_back`. `save_hyperparameters()` gets the current frame, but really we want the frame of the function that called `save_hyperparameters()`, which is `__init__`, so it uses `f_back`. This means it is sufficient to look at the stack frame in `__init__` to figure out what is going on. Adding a block of code before `self.save_hyperparameters()` in `_MessagePassingBase` will show us the variable names that are defined in each frame's scope (a sketch of such a block follows).
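The debugging block itself isn't shown in this view; a minimal sketch of one way to do it, pasted just before `self.save_hyperparameters()`:

```python
import inspect

# Walk outward from the current frame and print the local variable
# names visible in each enclosing scope.
frame = inspect.currentframe()
while frame is not None:
    print(frame.f_locals.keys())
    frame = frame.f_back
```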
Python 3.12

Running the simple test example from earlier gives a stack trace of:
The first dict_keys line is the variables that were available inside `__init__`. This makes sense because those variables are either arguments to the function (`d_v`, `d_vd`, etc.) or are created in `__init__` (`import inspect`, `frame = ...`). The second dict_keys line comes from `load_from_checkpoint()` in `MPNN` in `model.py`, and the third comes from the test example's script file.

So when `save_hyperparameters()` runs, it first tries to save the arguments from the first line (`d_v`, etc.) for `<class 'chemprop.nn.message_passing.base._MessagePassingBase'>`. Then it sees that the function call before also has `__class__`, so it tries to save what it needs to init a `<class 'chemprop.models.model.MPNN'>`, starting with `self`, but that wasn't given in our function call because it was `load_from_checkpoint()`, not `__init__`.
Python 3.11
If I run the same simple test example in Python 3.11, the stack trace is instead:

This is the same as before, but there is an extra line. This line comes from the dictionary comprehension used to loop through the three submodules that need to be instantiated. In this case, after saving the hyperparameters for `<class 'chemprop.nn.message_passing.base._MessagePassingBase'>`, it sees the next function call doesn't have `__class__` and correctly terminates.

What changed in 3.12
PEP 709 changed comprehensions to be inline. This means "there is no longer a separate frame for the comprehension in tracebacks".
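A quick way to see the change (my own illustration, not from the PR): capture the stack from inside a comprehension and look at the frame names.

```python
import traceback

def make():
    # On Python <= 3.11 the dict comprehension runs in its own frame ("<dictcomp>");
    # on Python >= 3.12 (PEP 709) it is inlined into make()'s frame.
    return {key: traceback.extract_stack() for key in ("stack",)}["stack"]

print([frame_summary.name for frame_summary in make()])
# Python 3.11: ['<module>', 'make', '<dictcomp>']
# Python 3.12: ['<module>', 'make']
```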
Solution
A simple solution is to add another function call between `load_from_checkpoint()` and the `__init__` calls of the modules. The current implementation gives:
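The code block that followed here isn't visible in this view; reconstructed from the diff discussed above (details may differ from the actual PR), the helper looks roughly like this:

```python
def load_submodules(checkpoint_path, **kwargs):
    # Instantiating the submodules here, rather than in load_from_checkpoint()
    # (which calls super() and therefore has __class__ in its locals), gives them
    # a scope where lightning's save_hyperparameters() recursion stops correctly.
    hparams = torch.load(checkpoint_path)["hyper_parameters"]

    if "blocks" in hparams["message_passing"]:
        mp_hparams = hparams["message_passing"]
        mp_hparams["blocks"] = [
            block_hparams.pop("cls")(**block_hparams) for block_hparams in mp_hparams["blocks"]
        ]
        kwargs["message_passing"] = mp_hparams.pop("cls")(**mp_hparams)

    kwargs |= {
        key: hparams[key].pop("cls")(**hparams[key])
        for key in ("message_passing", "agg", "predictor")
        if key not in kwargs
    }
    return kwargs
```

`load_from_checkpoint()` builds `kwargs` through this helper before handing off to `super().load_from_checkpoint()`.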
Where does `__class__` come from?

As I tried to debug this, I wondered why `load_from_file()` works just fine while `load_from_checkpoint()` didn't. The stack for `load_from_file()` in the current v2/dev with Python 3.12 gives:

The second dict_keys line comes from `load_from_file()` but doesn't have `__class__`. I learned here that `__class__` gets added to the local vars when `super()` is called. `load_from_checkpoint()` has a call to `super()` while `load_from_file()` doesn't.

(Side note: apparently `super()` doesn't even need to be called, just present. Consider the example below, where the `super()` call is not reachable.)
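The linked example isn't shown here; a small standalone sketch (my own) demonstrating the behavior:

```python
import inspect

class WithSuper:
    def probe(self):
        if False:
            super()  # never executed; its mere presence adds a __class__ cell
        return sorted(inspect.currentframe().f_locals)

class WithoutSuper:
    def probe(self):
        return sorted(inspect.currentframe().f_locals)

print(WithSuper().probe())     # ['__class__', 'self']
print(WithoutSuper().probe())  # ['self']
```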
Implementation and Questions
I decided to only have one `load_submodules()` shared between `MPNN` and `MulticomponentMPNN`. This was to keep things simple, but it then requires checking for `"blocks" in hparams["message_passing"]` in the function in `MPNN`, which is a multicomponent-only idea. Alternatively, we could be more explicit that `load_submodules()` is for both `MPNN` and `MulticomponentMPNN` by putting it in a separate file. I couldn't put it in `models/utils.py` though, because this file imports from `MPNN`.

Both `MPNN` and `MulticomponentMPNN` have their own `load_from_file()` method. These could be condensed into a single function by checking for `"blocks" in hparams["message_passing"]` in the same way that I did with `load_from_checkpoint()`. Do we want to do this? If we do, there is also the question of whether we want `load_from_file()` to call `load_submodules()` instead of using its current `for key in` loop.