
Fix load from checkpoint in python 3.12 #741

Merged
merged 2 commits into chemprop:v2/dev on Apr 1, 2024

Conversation

KnathanM (Contributor) commented Mar 23, 2024

This PR fixes an interesting error with our use of save_hyperparameters() that prevents loading checkpoint files in python 3.12. We have discussed this in several places (#714 and #738). I learned a lot while debugging this and felt it was too long for a comment, so I opened this PR both to isolate the change to demonstrate that it works and to have the space to explain what is going on. I'd also be happy to include this change in the other PR and close this one.

The short version

When we call save_hyperparameters() in the __init__ of our modules (e.g. _MessagePassingBase), lightning will recursively collect "the arguments passed to the child constructors in the inheritance tree." It knows when to stop by looking for the absence of __class__ in the local variables. It can do this because super().__init__() is required in the __init__ of modules that inherit from LightningModule, and this adds __class__ to the local variables. If our call to __init__ originates in a scope that also has __class__ (as is the case in load_from_checkpoint()), then lightning won't know where to stop. In python 3.11 we were okay because we instantiate the modules in load_from_checkpoint() within a comprehension, which has its own scope. Python 3.12 sped up comprehensions by removing their separate scope. We can fix our use of save_hyperparameters() by instantiating the modules within a different scope that doesn't have __class__, done here by calling a helper function load_submodules().
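For reference, the helper looks roughly like the sketch below (a minimal sketch based on the classmethod variant discussed later in this thread; the actual version in this PR also handles the multicomponent "blocks" case):

import torch

def load_submodules(checkpoint_path, **kwargs):
    # No super() appears here, so this frame has no __class__ in its locals
    # and lightning's collect_init_args() stops walking the stack at this frame.
    hparams = torch.load(checkpoint_path)["hyper_parameters"]
    kwargs |= {
        key: hparams[key].pop("cls")(**hparams[key])
        for key in ("message_passing", "agg", "predictor")
        if key not in kwargs
    }
    return kwargs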

The longer version

The error

A simple test reproduces the error:

from chemprop.models import MPNN
checkpoint_path = "tests/data/example_model_v2_regression_mol.ckpt"
model = MPNN.load_from_checkpoint(checkpoint_path)

The traceback shows the error originates in lightning/pytorch/utilities/parsing.py in _get_init_args():

Traceback (most recent call last):
  File "---/simple_example.py", line 3, in <module>
    model = MPNN.load_from_checkpoint(checkpoint_path)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "---/mambaforge/envs/chemprop_3-12/lib/python3.12/site-packages/chemprop/models/model.py", line 230, in load_from_checkpoint
    key: hparams[key].pop("cls")(**hparams[key])
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "---/mambaforge/envs/chemprop_3-12/lib/python3.12/site-packages/chemprop/nn/message_passing/base.py", line 62, in __init__
    self.save_hyperparameters()
  File "---/mambaforge/envs/chemprop_3-12/lib/python3.12/site-packages/lightning/pytorch/core/mixins/hparams_mixin.py", line 130, in save_hyperparameters
    save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams)
  File "---/mambaforge/envs/chemprop_3-12/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py", line 175, in save_hyperparameters
    for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "---/mambaforge/envs/chemprop_3-12/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py", line 139, in collect_init_args
    return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "---/mambaforge/envs/chemprop_3-12/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py", line 135, in collect_init_args
    local_self, local_args = _get_init_args(frame)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "---/mambaforge/envs/chemprop_3-12/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py", line 101, in _get_init_args
    local_args = {k: local_vars[k] for k in init_parameters}
                     ~~~~~~~~~~^^^
KeyError: 'self'

The stack frame

The error occurs because lightning looks through the local variables (local_vars) to pick out which parameters are needed to re-initialize the module later (i.e. the hyperparameters), but some are missing, specifically 'self', the module object that is currently being created. We see in _get_init_args() that local_vars comes from _, _, _, local_vars = inspect.getargvalues(frame). The frame here is part of the call stack, which is like an onion whose layers keep track of the scopes (local variables) of the various function calls. (Potentially helpful reference) When the code hits an error, I think the stack trace comes from following the stack frames back out to where we started. Side note: init_parameters comes from inspecting the signature of the class's __init__, i.e. init_parameters = inspect.signature(cls.__init__).parameters.
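As a tiny standalone example (hypothetical names, unrelated to chemprop) of walking this onion with inspect:

import inspect

def inner():
    frame = inspect.currentframe()
    # Each frame.f_back is the caller's frame, one layer further out.
    while frame is not None:
        _, _, _, local_vars = inspect.getargvalues(frame)
        print(frame.f_code.co_name, sorted(local_vars))
        frame = frame.f_back

def outer(x=1):
    inner()

outer()
# inner ['frame']
# outer ['x']
# <module> [... module-level names ...]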

Recursion

Why does the frame not have the variables that __init__ needs? _get_init_args() is called by collect_init_args(), which says it "Recursively collects the arguments passed to the child constructors in the inheritance tree." Part of this function checks whether __class__ is in local_vars and, if it is, calls collect_init_args() again with frame.f_back as the new frame. (Side note: it also checks whether the class in the current frame inherits from HyperparametersMixin, but I don't think that is relevant to our error.) frame.f_back gets the frame one level up in the stack, so collect_init_args() will continue back up the stack until it hits a frame that doesn't have __class__.
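Schematically, the recursion looks something like this (my paraphrase of the lightning logic, not the actual source):

import inspect

def collect_init_args_sketch(frame, path_args):
    # Paraphrase of lightning's collect_init_args()/_get_init_args().
    _, _, _, local_vars = inspect.getargvalues(frame)
    if "__class__" not in local_vars:
        # First frame without __class__: assume we have walked out of the
        # chain of __init__ calls and stop.
        return path_args
    # Otherwise treat this frame as an __init__ of local_vars["__class__"],
    # record its arguments, and keep walking up the stack via frame.f_back.
    init_parameters = inspect.signature(local_vars["__class__"].__init__).parameters
    # This lookup is exactly what raises KeyError: 'self' when the frame is
    # really load_from_checkpoint() (a classmethod) rather than an __init__.
    local_args = {k: local_vars[k] for k in init_parameters}
    return collect_init_args_sketch(frame.f_back, path_args + [local_args])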

Inspecting the frames

The initial frame given to collect_init_args() is generated in the call to save_hyperparameters() in the module __init__. I'll use _MessagePassingBase as an example. When self.save_hyperparameters() is called, python goes into lightning.pytorch.core.mixins.hparams_mixin, which defines save_hyperparameters(). This internally calls a different save_hyperparameters() that lives in lightning.pytorch.utilities.parsing. The call looks like save_hyperparameters(self, *args, ignore=ignore, frame=frame, given_hparams=given_hparams), where frame comes from the couple of lines above: current_frame = inspect.currentframe(); frame = current_frame.f_back. save_hyperparameters() gets the current frame, but really we want the frame of the function that called save_hyperparameters(), which is __init__, so it uses f_back. This means it is sufficient to look at the stack frames in __init__ to figure out what is going on. Adding a block of code before self.save_hyperparameters() in _MessagePassingBase will show us the variable names that are defined in each frame's scope.

    super().__init__()
    import inspect
    frame = inspect.currentframe()
    while frame is not None:
        _, _, _, local_vars = inspect.getargvalues(frame)
        print(local_vars.keys())
        if "__class__" in local_vars:
            print("__class__:", local_vars["__class__"])
        frame = frame.f_back
    return  # stop early so the error in save_hyperparameters() isn't hit
    self.save_hyperparameters()

Python 3.12

Running the simple test example from earlier gives a stack trace of:

dict_keys(['self', 'd_v', 'd_e', 'd_h', 'bias', 'depth', 'dropout', 'activation', 'undirected', 'd_vd', 'inspect', 'frame', '__class__'])
__class__: <class 'chemprop.nn.message_passing.base._MessagePassingBase'>
dict_keys(['cls', 'checkpoint_path', 'map_location', 'hparams_file', 'strict', 'kwargs', 'hparams', 'key', '__class__'])
__class__: <class 'chemprop.models.model.MPNN'>
dict_keys(['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'MPNN', 'checkpoint_path'])

The first dict_keys line shows the variables that were available inside __init__. This makes sense because those variables are either arguments to the function (d_v, d_vd, etc.) or are created in __init__ (import inspect, frame = ...). The second dict_keys line comes from load_from_checkpoint() in MPNN in model.py, and the third comes from the test example's script file.
So when save_hyperparameters() runs, it first saves the arguments from the first frame (d_v, etc.) for <class 'chemprop.nn.message_passing.base._MessagePassingBase'>. Then it sees that the calling frame also has __class__, so it tries to save what it needs to initialize a <class 'chemprop.models.model.MPNN'>, starting with self. But self isn't in that frame's local variables, because the call was load_from_checkpoint(), not __init__.

Python 3.11

If I run the same simple test example in python 3.11, the stack trace is instead:

dict_keys(['self', 'd_v', 'd_e', 'd_h', 'bias', 'depth', 'dropout', 'activation', 'undirected', 'd_vd', 'inspect', 'frame', '__class__'])
__class__: <class 'chemprop.nn.message_passing.base._MessagePassingBase'>
dict_keys(['.0', 'key', 'hparams', 'kwargs'])
dict_keys(['cls', 'checkpoint_path', 'map_location', 'hparams_file', 'strict', 'kwargs', 'hparams', '__class__'])
__class__: <class 'chemprop.models.model.MPNN'>
dict_keys(['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'MPNN', 'checkpoint_path'])

This is the same as before, but there is an extra line. This line comes from the dictionary comprehension used to loop through the three submodules that need to be instantiated.

kwargs |= {
    key: hparams[key].pop("cls")(**hparams[key])
    for key in ("message_passing", "agg", "predictor")
    if key not in kwargs
}

In this case, after saving the hyperparameters for <class 'chemprop.nn.message_passing.base._MessagePassingBase'>, it sees that the next frame up doesn't have __class__ and correctly terminates.

What changed in 3.12

PEP 709 changed comprehensions to be inline. This means "there is no longer a separate frame for the comprehension in tracebacks".
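A quick way to see the difference (sys._getframe reports which code object's frame the comprehension body runs in):

import sys

def which_frame():
    # Python 3.11: the comprehension body runs in its own "<listcomp>" frame.
    # Python 3.12 (PEP 709): it is inlined, so we see this function's frame.
    return [sys._getframe().f_code.co_name for _ in range(1)]

print(which_frame())  # ['<listcomp>'] on 3.11, ['which_frame'] on 3.12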

Solution

A simple solution is to add another function call between load_from_checkpoint() and the __init__ calls of the modules. With this PR's implementation, the stack looks like the following (a sketch of the new call arrangement is shown after the listing):

dict_keys(['self', 'd_v', 'd_e', 'd_h', 'bias', 'depth', 'dropout', 'activation', 'undirected', 'd_vd', 'inspect', 'frame', '__class__'])
__class__: <class 'chemprop.nn.message_passing.base._MessagePassingBase'>
dict_keys(['checkpoint_path', 'kwargs', 'hparams', 'key'])
dict_keys(['cls', 'checkpoint_path', 'map_location', 'hparams_file', 'strict', 'kwargs', '__class__'])
__class__: <class 'chemprop.models.model.MPNN'>
dict_keys(['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'MPNN', 'checkpoint_path'])
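For concreteness, the arrangement is roughly this (a sketch with simplified argument handling, not the exact implementation):

from lightning import pytorch as pl

class MPNN(pl.LightningModule):
    ...
    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs):
        # The submodules are instantiated inside load_submodules(), whose frame
        # has no __class__; this classmethod calls super(), so its own frame does.
        kwargs = load_submodules(checkpoint_path, **kwargs)
        return super().load_from_checkpoint(
            checkpoint_path, map_location, hparams_file, strict, **kwargs
        )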

Where does __class__ come from?

As I tried to debug this, I wondered why load_from_file() worked just fine while load_from_checkpoint() didn't. The stack for load_from_file() in the current v2/dev with python 3.12 gives:

dict_keys(['self', 'd_v', 'd_e', 'd_h', 'bias', 'depth', 'dropout', 'activation', 'undirected', 'd_vd', 'inspect', 'frame', '__class__'])
__class__: <class 'chemprop.nn.message_passing.base._MessagePassingBase'>
dict_keys(['cls', 'model_path', 'map_location', 'strict', 'd', 'hparams', 'state_dict', 'key', 'hparam_kwargs', 'hparam_cls'])
dict_keys(['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__annotations__', '__builtins__', '__file__', '__cached__', 'MPNN', 'checkpoint_path'])

The second dict_keys line comes from load_from_file() but doesn't have __class__. I learned here that __class__ gets added to the local vars whenever super() appears in a method: load_from_checkpoint() has a call to super() while load_from_file() doesn't.

(Side note: apparently super() doesn't even need to be called, just present. Consider this example where super() is not reachable.)

class MyClass:
    def f(self):
        import inspect
        frame = inspect.currentframe()
        _, _, _, local_vars = inspect.getargvalues(frame)
        print(local_vars.keys())
        return
        print("hi")  # unreachable
        super()      # never executed, yet its presence adds __class__

temp = MyClass()
temp.f()
# output: dict_keys(['self', 'inspect', 'frame', '__class__'])

Implementation and Questions

I decided to have only one load_submodules() shared between MPNN and MulticomponentMPNN. This was to keep things simple, but it then requires checking for "blocks" in hparams["message_passing"] in the function in MPNN, which is a multicomponent-only concept. Alternatively, we could be more explicit that load_submodules() is for both MPNN and MulticomponentMPNN by putting it in a separate file. I couldn't put it in models/utils.py, though, because that file already imports MPNN.

Both MPNN and MulticomponentMPNN have their own load_from_file() method. These could be condensed into a single function by checking for "blocks" in hparams["message_passing"] in the same way I did with load_from_checkpoint(). Do we want to do this? If we do, there is also the question of whether we want load_from_file() to call load_submodules() instead of using its current for key in loop.

KnathanM changed the title to Fix load from checkpoint in python 3.12 Mar 23, 2024
KnathanM marked this pull request as ready for review March 23, 2024 16:12
JacksonBurns (Member) commented Mar 23, 2024

@KnathanM thank you for this incredible and thorough investigation of the problem. I have learned a lot just from reading this!

I like the way this is implemented as is - I don't think we need to modify the other method for the sake of just cutting down the code a bit. I also like the way you have provided the implementation in model.py and I don't think we need to do anything else to clarify that it applies to the subclass.

My only open question would be if we can also resolve this by changing this from a comprehension to just a simple loop. I'll try that elsewhere and get back to this. Didn't work, I think this is good!

davidegraff (Contributor) left a comment:

So if I'm understanding you correctly, we can't load any models in python 3.12 because of the in-lining of comprehensions?

@@ -255,3 +248,22 @@ def load_from_file(cls, model_path, map_location=None, strict=True) -> MPNN:
model.load_state_dict(state_dict, strict=strict)

return model


def load_submodules(checkpoint_path, **kwargs):
davidegraff (Contributor):

can this be an @classmethod? If possible, it seems better because this function is strongly coupled to the load_checkpoint() method.

KnathanM (Contributor Author):

Originally I put it outside MPNN to suggest that it also loads submodules for MulticomponentMPNN. But thinking about it again, I agree it should probably be a classmethod of MPNN because then it is more clearly inherited by MulticomponentMPNN when it pulls from MPNN. I'll wait to see if Hao-Wei has any comments to address before making that change on Monday.

Comment on lines 256 to 262
if "blocks" in hparams["message_passing"]:
mp_hparams = hparams["message_passing"]
mp_hparams["blocks"] = [
block_hparams.pop("cls")(**block_hparams) for block_hparams in mp_hparams["blocks"]
]
message_passing = mp_hparams.pop("cls")(**mp_hparams)
kwargs["message_passing"] = message_passing
davidegraff (Contributor):

this seems somewhat error-prone. What if the hyper_parameters key for an MPNN contains "blocks"? Reasonably, we should ignore it, but sharing this code between the two classes necessitates assumptions about its structure.

KnathanM (Contributor Author):

I agree this isn't ideal. If we left load_submodules() outside of MPNN, we could bring back load_from_checkpoint() in MulticomponentMPNN and put a similar function in multi.py. Or, if we make load_submodules() a classmethod of MPNN, then MulticomponentMPNN could override it. The only catch is that each needs to be self-sufficient because we can't call super().
In MPNN

    @classmethod
    def load_submodules(cls, checkpoint_path, **kwargs):
        hparams = torch.load(checkpoint_path)["hyper_parameters"]

        kwargs |= {
            key: hparams[key].pop("cls")(**hparams[key])
            for key in ("message_passing", "agg", "predictor")
            if key not in kwargs
        }
        return kwargs

In MulticomponentMPNN

    @classmethod
    def load_submodules(cls, checkpoint_path, **kwargs):
        hparams = torch.load(checkpoint_path)["hyper_parameters"]

        hparams["message_passing"]["blocks"] = [
            block_hparams.pop("cls")(**block_hparams)
            for block_hparams in hparams["message_passing"]["blocks"]
        ]
        kwargs |= {
            key: hparams[key].pop("cls")(**hparams[key])
            for key in ("message_passing", "agg", "predictor")
            if key not in kwargs
        }
        return kwargs

This works for loading both MPNN and MulticomponentMPNN models when load_from_checkpoint() is a classmethod of MPNN that calls cls.load_submodules(). What do you think?

davidegraff (Contributor):

I like the second option better. Obviously there's some duplicate code, but it's only across one class and it's not strictly worse style-wise either. In this case, a user doesn't have to bark up the class hierarchy to see what the function does, as it's all contained in one place.

KnathanM (Contributor Author) replied, quoting the question above:

So if I'm understanding you correctly, we can't load any models in python 3.12 because of the in-lining of comprehensions?

Models can still be loaded, but the message_passing, agg, and predictor submodules need to be loaded in a function that does not call super(). In 3.11 these submodules were loaded inside a dictionary comprehension, which added an extra function call; 3.12's in-lining of comprehensions removed it.

I think we were just lucky that before 3.12 comprehensions had an extra function call. The wider problem is that lightning assumes that the call to __init__ will come from a function that doesn't have super(), which seems brittle to me. Perhaps this is an issue that should be opened on the Lightning repo, but I don't understand their code well enough to propose a solution.

But it is also confusing to me that they do this recursion based on __class__ in the first place. I've assumed this is so that, if you initialize a submodule inside a parent module's __init__, loading the parent from a checkpoint will also capture the hyperparameters of the child submodule. But then you need to pass all the child's hyperparameters to it through the parent, instead of creating the submodule yourself and passing that object to the parent. #444 is somewhat related to this.

davidegraff (Contributor) commented Mar 23, 2024

You're right that it is brittle, but that's just the way code goes... My guess is that for the large majority of cases, the code works even if it has an obvious failure mode, and the reason is that most client code is not designed with modularity in mind. That is, if you ever go to a repo for a particular ML model's implementation, the model is defined by way of its parameterization rather than its composition. Put more concretely, consider most FFN implementations in PyTorch; they're usually defined as something like this (including ours):

import torch.nn as nn

class FFN(nn.Module):
    def __init__(self, in_features, num_layers, hidden_size, out_features, dropout):
        super().__init__()
        layers = [nn.Linear(in_features, hidden_size)]
        layers += sum(
            ([nn.Dropout(dropout), nn.ReLU(), nn.Linear(hidden_size, hidden_size)] for _ in range(num_layers - 1)),
            [],
        )
        layers += [nn.Linear(hidden_size, out_features)]
        self.layers = nn.Sequential(*layers)

You'll notice that the logical concept of an "FFN" here is defined by some set of hyperparameters, but really it's only a specific parameterization of an FFN. A resnet with a skip connection between the input and final layer is also an FFN, but not according to this implementation. Really, an FFN can be configured in any number of ways, and a more general view is that an FFN is a differentiable function composed of individual transformations that feed forward from the input to the output. You could define your FFN like that, taking in individual transformations (e.g., nn.Linears, nn.Dropouts, etc.) and stacking them together, but usually that degree of customization is unimportant for typical use-cases.

Consider also the v1 chemprop model; its architecture was hardcoded and accepted a number of hyperparameters with which to configure it. But really the chemprop model defined in the original publication is just one type of MPNN. There are many ways to define an MPNN, which is abstractly a composition of functions on a graph: $h \circ \mathtt{agg} \circ \Phi (\mathcal G)$. You can swap out any of these components and still have a valid MPNN, and that's why the object model of v2 is designed as it is.

However, the pytorch lightning object model assumes that most users opt for the former object model in their own code. It's an unfortunate reality of the situation and why we've had to design our own workarounds.

JacksonBurns (Member) left a comment:

LGTM, would like this in soon.

KnathanM added this to the v2.0.0 milestone Mar 28, 2024
KnathanM merged commit 4c18d2e into chemprop:v2/dev Apr 1, 2024
4 of 6 checks passed
KnathanM deleted the v2/fix/loading-checkpoint branch April 1, 2024 16:14
JacksonBurns added a commit to JacksonBurns/chemprop that referenced this pull request Apr 1, 2024

Successfully merging this pull request may close these issues.

[BUG]: MulticomponentMPNN and MPNN can't be loaded from checkpoint file in Python 3.12