
[Feature Request] Partial instantiate/call #1283

Closed
n2cholas opened this issue Jan 8, 2021 · 10 comments · Fixed by #1905
Labels
enhancement Enhancement request

Comments

@n2cholas

n2cholas commented Jan 8, 2021

First off, thanks for the wonderful library!!

🚀 Feature Request

I would like to be able to partially instantiate/call a class/function. For example:

# config.yaml
loss_fn:
  _target_: main.cross_entropy_loss
  _partial_: True
  label_smooth_amount: 0.1
# main.py
def cross_entropy_loss(logits, labels, label_smooth_amount):
    ...

loss_fn = hydra.utils.instantiate(cfg.loss_fn)   
# loss_fn = functools.partial(cross_entropy_loss, label_smooth_amount=0.1)

Motivation

Currently, to make the above use-case work, I would do the following:

# config.yaml
loss_fn:
  _target_: main.get_cross_entropy_loss
  label_smooth_amount: 0.1
# main.py
def get_cross_entropy_loss(label_smooth_amount):
    def cross_entropy_loss(logits, labels):
        ...
    return cross_entropy_loss

loss_fn = hydra.utils.instantiate(cfg.loss_fn)

This is acceptable for code I write, but does not work well when I am trying to configure functions that libraries provide. Many libraries follow a more functional style (compared to PyTorch and TensorFlow), so losses/activations/metrics are provided as simple functions as opposed to callable objects. For example, Flax for JAX (and several other neural network libraries for JAX) defines all its activation functions and pooling layers as straightforward functions instead of classes, making partial instantiation crucial for configuration.

Also, code is often clearer when it has more simple functions and fewer higher-order functions/classes. Partial instantiation would keep code from accumulating too many of the latter.

Pitch

Describe the solution you'd like

Having an optional _partial_ entry in the config (similar to _recursive_ and _convert_) is, in my view, the most straightforward way to achieve this. It would default to False; when True, the class/function is partially instantiated/called instead of being fully instantiated/called.
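
To make the proposed semantics concrete, here is a minimal plain-Python sketch of what a _partial_-aware call would do (illustrative only, not Hydra code; call_node is a made-up helper name, and cross_entropy_loss is the stub from the example above):

import functools

def call_node(target, kwargs, is_partial=False):
    # Sketch of the proposed behavior: with _partial_: True,
    # wrap the target in functools.partial instead of calling it.
    if is_partial:
        return functools.partial(target, **kwargs)
    return target(**kwargs)

def cross_entropy_loss(logits, labels, label_smooth_amount):
    ...

loss_fn = call_node(cross_entropy_loss, {"label_smooth_amount": 0.1}, is_partial=True)
# loss_fn can now be called later as loss_fn(logits, labels)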

Describe alternatives you've considered

Another option is to introduce two new methods: hydra.utils.partial_instantiate and hydra.utils.partial_call. This removes the need for another config entry and makes it clearer at the call-site what's going on. There is one major disadvantage: it's not clear how this would work with _recursive_=True. Would all the recursive instantiations be partial? You probably don't want that. Would only the top-level instantiation be partial? That would limit some use cases as well.

For this reason, I think the _partial_ entry makes the most sense.

Are you willing to open a pull request? (See CONTRIBUTING)

Yes! I was planning to add a _pop_is_partial function (like pop_convert_mode), then add the appropriate functools.partial calls here.

@n2cholas added the enhancement label on Jan 8, 2021
@omry
Collaborator

omry commented Jan 8, 2021

Can you explain what problem you are trying to solve in more detail?

@n2cholas
Author

n2cholas commented Jan 8, 2021

Sure, below is an example similar to what I am facing in my current project. Consider this simple CNN class:

# main.py
from torch import nn
from typing import Sequence


class CNN(nn.Module):
    def __init__(self,
                 input_channels: int,
                 hidden_sizes: Sequence[int],
                 conv_layer = nn.Conv2d,
                 norm_layer = nn.BatchNorm2d,
                 activation_fn = nn.ReLU):
        super().__init__()
        
        layers = []
        for h_size in hidden_sizes:
            layers.append(conv_layer(input_channels, h_size, kernel_size=(3, 3)))
            layers.append(norm_layer(h_size))
            layers.append(activation_fn())
            input_channels = h_size

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x).mean((-1, -2))

I would like to be able to change what type of conv_layer, norm_layer, or activation_fn the model uses, as well as their hyperparameters. For example, I may want to use SyncBatchNorm instead of BatchNorm2d. Or, I may want to adjust the eps or momentum of the batchnorm used in this model. Currently, I don't see any way to do this, because instantiate will only create one specific instance of the norm_layer.

With partial instantiation, I can specify my config as follows:

# config.yaml
model:
  _target_: main.CNN
  input_channels: 3
  hidden_sizes: [8, 16, 32, 10]
  norm_layer:
    _target_: torch.nn.SyncBatchNorm
    _partial_: True
    momentum: 0.2
    eps: 0.01

Then in my main,

# main.py
model = instantiate(cfg.model)

which would be equivalent to:

# main.py
from functools import partial

model = CNN(
    input_channels=3,
    hidden_sizes=[8, 16, 32, 10],
    norm_layer=partial(torch.nn.SyncBatchNorm, momentum=0.2, eps=0.01)
)

Does this make sense? Is there any way to accomplish this currently?

@omry
Collaborator

omry commented Jan 8, 2021

Yes, it takes some acrobatics but you can already achieve it in a relatively clean way:

from typing import Any

from functools import partial

import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Register a resolver that can return a callable method.
# This also works on classes, but you can also register get_class, which is almost identical.
OmegaConf.register_resolver("get_method", hydra.utils.get_method)

cfg = OmegaConf.create({"method": "${get_method:math.sin}"})
print("sin(1)", cfg.method(1))


class Foo:
    def __init__(self) -> None:
        self.a = 1
        self.b = 1

    def run(self) -> int:
        return self.a + self.b


class UsefulFoo(Foo):
    def __init__(self, a: int, b: int) -> None:
        self.a = a
        self.b = b


prt = partial(UsefulFoo, 10)
print("20 + 10 = ", prt(20).run())


class Bar:
    def __init__(self, foo: Any = Foo, y: int = 1) -> None:
        self.result = foo().run()


def partial2(func: Any, *args, **kwargs) -> Any:
    """
    Normal partial requires func to be passed as a positional argument.
    This is not currently supported by instantiate; this function bridges that gap.
    """
    return partial(func, *args, **kwargs)


print("direct instantiate", Bar(y=2).result)

print(
    "instantiate",
    instantiate({"_target_": "__main__.Bar"}).result,
)

print(
    "instantiate partial",
    instantiate(
        {
            "_target_": "__main__.Bar",
            "foo": {
                "_target_": "__main__.partial2",
                "func": "${get_method:__main__.UsefulFoo}",
                "a": 10,
                "b": 20,
            },
        }
    ).result,
)

Output:

sin(1) 0.8414709848078965
20 + 10 =  30
direct instantiate 2
instantiate 2
instantiate partial 30
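
For reference, a hedged sketch of how the same partial2 bridge maps back to the CNN/SyncBatchNorm example above. It assumes CNN and partial2 are defined in main.py and continues the snippet above (instantiate already imported, get_method resolver already registered):

# Sketch only: the nested norm_layer node is instantiated first, producing
# functools.partial(torch.nn.SyncBatchNorm, momentum=0.2, eps=0.01), which is
# then passed to CNN as norm_layer.
model = instantiate(
    {
        "_target_": "main.CNN",
        "input_channels": 3,
        "hidden_sizes": [8, 16, 32, 10],
        "norm_layer": {
            "_target_": "main.partial2",
            "func": "${get_method:torch.nn.SyncBatchNorm}",
            "momentum": 0.2,
            "eps": 0.01,
        },
    }
)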

@n2cholas
Author

n2cholas commented Jan 8, 2021

Thanks, this solves my problem! I think it would be clean to have this integrated into the library via _partial_, but if you think that it's an unnecessary addition, please close this issue. This fits well with my use case :). Thanks again!!

@omry
Collaborator

omry commented Jan 8, 2021

Great!
While it would be nice to add explicit support for partial, it seems like a pretty remote corner of Python and I would rather not commit to explicit support at this time (also, imagine how difficult it would be to explain to a reader of the instantiate doc what _partial_ means).

If this issue generates a lot of interest I can consider adding explicit support later.

I hope my answer can help others in a similar situation.

Closing for now.

@omry
Collaborator

omry commented Feb 27, 2021

Will re-evaluate for 1.2.

@omry reopened this on Feb 27, 2021
@omry
Collaborator

omry commented Apr 29, 2021

FYI: Hydra 1.1 instantiate now supports positional arguments.
This means the acrobatics in the example above are no longer needed, and the same thing can be achieved in a significantly cleaner way:

from functools import partial

import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

# python version
basetwo = partial(int, base=2)
assert basetwo("10010") == 18

# instantiate version

# Register a resolver that can return a callable method.
# This also works on classes, but you can also register get_class, which is almost identical.
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)

basetwo2 = instantiate(
    {
        "_target_": "functools.partial",
        "_args_": ["${get_method:builtins.int}"],
        "base": 2,
    }
)

assert basetwo2("10010") == 18
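
Continuing that snippet, a hedged sketch of the same functools.partial target applied to the loss_fn example from the top of the issue (assumes main.cross_entropy_loss is importable and the get_method resolver is registered as above):

loss_fn = instantiate(
    {
        "_target_": "functools.partial",
        "_args_": ["${get_method:main.cross_entropy_loss}"],
        "label_smooth_amount": 0.1,
    }
)
# loss_fn now behaves like functools.partial(cross_entropy_loss, label_smooth_amount=0.1)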

@Queuecumber
Contributor

You can make this very clean with a wrapper around partial:

mm/run/partial.py

import functools

from hydra.utils import get_method


def partial(_partial_, *args, **kwargs):
    return functools.partial(get_method(_partial_), *args, **kwargs)

mm/partial.yaml

_target_: mm.run.partial
_partial_: ???

optimizers/adam.yaml

defaults:
  - /mm/partial@_here_

_partial_: torch.optim.Adam
lr: ???

I tried a bunch of different ways of doing this with interpolations and more tricks with the base config to try to make it even cleaner, but this was the only thing I could get working.
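
For clarity, a hedged plain-Python sketch of what the two configs above produce once composed and instantiated (the lr value and the Linear model are assumptions used only for illustration):

import functools

import torch

# What instantiating the composed adam.yaml node yields: a partial that binds the
# optimizer class and its hyperparameters, with the parameters supplied later.
adam_partial = functools.partial(torch.optim.Adam, lr=1e-3)

model = torch.nn.Linear(4, 2)
optimizer = adam_partial(model.parameters())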

@xvr-hlt

xvr-hlt commented Sep 1, 2021

FWIW, we'd love it if _partial_ became a first-class citizen too: it comes up a lot with anything pytorch_lightning related, e.g. passing a partial of the optimizer through to later call with self.parameters, or passing the function that generates instances to the DataModule rather than the instances themselves, so they're created in each DDP process without IO overhead, etc.

We're using @Queuecumber's solution at the moment, and we understand if you don't end up including explicit partials (it's a little complicated), but we just wanted to mention that we'd love to see them too and are happy to answer any questions about our use case.
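
To illustrate the second use case mentioned above (the optimizer case is covered by the demo in the next comment), here is a hedged sketch of passing a dataset factory, rather than a dataset, into a LightningDataModule; all names here are illustrative, not from the original comment:

from functools import partial

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_dataset(n: int):
    return TensorDataset(torch.randn(n, 4), torch.randint(0, 2, (n,)))


class FactoryDataModule(pl.LightningDataModule):
    def __init__(self, dataset_factory):
        super().__init__()
        # dataset_factory is the kind of callable a _partial_-style config would produce.
        self.dataset_factory = dataset_factory

    def setup(self, stage=None):
        # The dataset is built here, in each DDP process, instead of up front.
        self.train_ds = self.dataset_factory()

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=8)


dm = FactoryDataModule(dataset_factory=partial(make_dataset, n=64))
dm.setup()
batch = next(iter(dm.train_dataloader()))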

@addisonklinke

Thanks for the excellent example @omry. Here's a PyTorch Lightning-specific demo I created to illustrate @xvr-hlt's use case of passing a partial optimizer through for later use with the model parameters:

from functools import partial

import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf
import pytorch_lightning as pl
import torch


class BoringModel(pl.LightningModule):

    def __init__(self, optim_partial, in_feats=4, out_feats=2):
        super().__init__()

        # Hydra config components
        self.optim_partial = optim_partial
        self.in_feats = in_feats
        self.out_feats = out_feats

        # Control weight randomness
        pl.seed_everything(1234)
        self.layer = torch.nn.Linear(in_feats, out_feats)

    def forward(self, x):
        return self.layer(x)

    def configure_optimizers(self):
        return self.optim_partial(self.parameters())


# Plain python approach
model = BoringModel(optim_partial=partial(torch.optim.Adam, lr=1e-5, weight_decay=0.2))
optimizer = model.configure_optimizers()

# Partial instantiate approach
OmegaConf.register_new_resolver("get_method", hydra.utils.get_method)
cfg = {
    '_target_': '__main__.BoringModel',
    'in_feats': 4,
    'out_feats': 2,
    'optim_partial': {
        '_target_': 'functools.partial',
        '_args_': ['${get_method: torch.optim.Adam}'],
        'lr': 1e-5,
        'weight_decay': 0.2}}
model2 = instantiate(cfg)
optimizer2 = model2.configure_optimizers()

# Equality comparison of all optimization hyperparameters + model parameters
for g, group in enumerate(optimizer2.param_groups):
    for k, v in group.items():
        if k == 'params':
            for p, param in enumerate(v):
                assert torch.all(param == optimizer.param_groups[g][k][p])
        else:
            assert v == optimizer.param_groups[g][k]
