forked from allenai/allennlp
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
96ff585
commit 2f41cc8
Showing
1 changed file
with
73 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Any, Dict, Type | ||
import inspect | ||
|
||
|
||
def get_arg_params(argument: Any, annotation: Type) -> Any: | ||
existing_params = getattr(argument, '_params', None) | ||
if existing_params is not None: | ||
return existing_params | ||
elif annotation in {float, int}: | ||
return argument | ||
else: | ||
# There are several more cases to handle here if we want to actually do this, but not *that* | ||
# many more. We just have to cover all base python types. Unions might be tricky to figure | ||
# out, but should still be doable. This logic basically mirrors what we do when creating | ||
# objects in FromParams, we're just going the other way. | ||
raise ValueError() | ||
|
||
|
||
def get_params(init, *args, **kwargs) -> Dict[str, Any]: | ||
signature = inspect.signature(init) | ||
parameters = dict(signature.parameters) | ||
# need some fancy logic here to match *args and **kwargs with the parametrs. It's deterministic | ||
# and doable, but more than I want to write right now. The below is a quick and dirty first | ||
# pass. Oh, hmm, super classes and **kwargs _inside_ the __init__ method make this more tricky, | ||
# but still doable. | ||
arg_list = list(args) | ||
saved_params = {} | ||
for param_name, param in parameters.items(): | ||
if param_name == "self": | ||
continue | ||
if param_name in kwargs: | ||
argument = kwargs.pop(param_name) | ||
else: | ||
argument = arg_list.pop(0) | ||
argument_params = get_arg_params(argument, param.annotation) | ||
saved_params[param_name] = argument_params | ||
return saved_params | ||
|
||
|
||
class Meta(type): | ||
def __new__(mcs, name, bases, namespace, **kwargs): | ||
new_cls = super(Meta, mcs).__new__(mcs, name, bases, namespace, **kwargs) | ||
user_init = new_cls.__init__ | ||
def __init__(self, *args, **kwargs): | ||
self._params = get_params(user_init, *args, **kwargs) | ||
user_init(self, *args, **kwargs) | ||
setattr(new_cls, '__init__', __init__) | ||
return new_cls | ||
|
||
|
||
class FromParams(metaclass=Meta): | ||
pass | ||
|
||
|
||
class Gaussian(FromParams): | ||
def __init__(self, mean: float, variance: float): | ||
self.mean = mean | ||
self.variance = variance | ||
|
||
|
||
class NestedGaussian(FromParams): | ||
def __init__(self, gaussian: Gaussian, alpha: float): | ||
self.gaussian = gaussian | ||
self.alpha = alpha | ||
|
||
|
||
g = Gaussian(1.3, 2.3) | ||
print(g._params) | ||
g = Gaussian(mean=1.8, variance=4.3) | ||
print(g._params) | ||
|
||
n = NestedGaussian(Gaussian(mean=1.8, variance=4.3), alpha=.01) | ||
print(n._params) |