In [304]:
import torch
import inspect
import importlib

## GLOBAL

In [305]:
class ConfigDict(dict):
    '''Dictionary whose items can also be read, written and deleted as attributes.

    ``cfg.key`` and ``cfg['key']`` refer to the same stored value. Real
    attributes of ``dict`` (for example ``update`` or ``items``) still win on
    attribute *reads*, because ``__getattr__`` is only consulted after normal
    attribute lookup has failed.
    '''

    def __init__(self, **kwargs):
        '''Create the mapping from keyword arguments only.'''
        super().__init__()
        super().update(**kwargs)

    def __getattr__(self, key):
        '''Return the item stored under ``key``.

        Raises:
            AttributeError: if ``key`` is not an item of the mapping, so that
                ``hasattr``/``getattr(..., default)`` keep working as usual.
        '''
        if key in self:
            return self[key]
        raise AttributeError(
            f'{type(self).__name__!r} object has no attribute {key!r}'
        )

    def __setattr__(self, key, value):
        '''Store ``value`` as the item ``key``.

        Writing to the mapping (rather than to the instance ``__dict__``)
        keeps attribute access and item access in sync.
        '''
        self[key] = value

    def __delattr__(self, key):
        '''Delete the item ``key``.

        Raises:
            AttributeError: if ``key`` is not an item of the mapping.
        '''
        if key in self:
            del self[key]
        else:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {key!r}'
            )

In [306]:
# Process-wide registry: maps a registered class name to its schema dict.
GLOBAL_CONFIG = dict()

## UTILS

In [307]:
def wrapModule(cls):
    '''Build the registry schema describing how to construct ``cls``.

    The schema is a plain dict with these entries:

    * ``'name'``   - ``cls.__name__``.
    * ``'module'`` - the module object named by ``cls.__module__``.
    * ``'cls'``    - the class itself.
    * ``'kwargs'`` - ``{parameter name: default}`` for every parameter listed
      in ``argspec.args`` of ``cls.__init__`` except ``self``; parameters
      without a default are mapped to ``None``.
    * ``'args_requeired'`` - number of those parameters that have no default.
      The misspelled key is kept as-is for backward compatibility.

    ``*args``, ``**kwargs`` and keyword-only parameters are not captured.
    '''
    schema = {}
    schema['name'] = cls.__name__
    schema['module'] = importlib.import_module(cls.__module__)
    schema['cls'] = cls

    argspec = inspect.getfullargspec(cls.__init__)
    keys = [arg for arg in argspec.args if arg != 'self']
    defaults = [] if argspec.defaults is None else list(argspec.defaults)

    # Defaults always belong to the trailing parameters, so the leading
    # parameters are the required ones. Count them BEFORE padding: once the
    # value list is padded to len(keys) the difference is always zero.
    num_required = len(keys) - len(defaults)
    values = [None, ] * num_required + defaults
    assert len(keys) == len(values), ''

    # Trace of the discovered signature, echoed at registration time.
    print(keys, values)
    schema['kwargs'] = dict(zip(keys, values))
    schema['args_requeired'] = num_required

    return schema



def register(cls):
    '''Class decorator that records ``cls`` in ``GLOBAL_CONFIG``.

    The schema produced by ``wrapModule(cls)`` is stored under
    ``cls.__name__`` and the class is returned unchanged.

    Raises:
        ValueError: if a class with the same name is already registered.
    '''
    key = cls.__name__
    if key not in GLOBAL_CONFIG:
        GLOBAL_CONFIG[key] = wrapModule(cls)
        return cls

    raise ValueError(f'{key} already exist.')


def create(name):
    '''Instantiate the class registered under ``name``.

    The class is looked up by its recorded name on its recorded module and
    called with the keyword arguments currently stored in its schema.

    Raises:
        KeyError: if ``name`` is not present in ``GLOBAL_CONFIG``.
    '''
    entry = GLOBAL_CONFIG[name]
    factory = getattr(entry['module'], entry['name'])
    return factory(**entry['kwargs'])


def inject(cfgs):
    '''Merge user overrides into ``GLOBAL_CONFIG`` in place.

    ``cfgs`` maps a registered name to ``{schema field: {key: value}}``; each
    inner mapping is applied with ``.update`` to the matching schema field
    (for example ``{'MM': {'kwargs': {'a': 22}}}``).

    Raises:
        KeyError: if a name or a schema field does not exist.
    '''
    for target in cfgs:
        entry, patch = GLOBAL_CONFIG[target], cfgs[target]
        for field in patch:
            entry[field].update(patch[field])
    

## COMPONENT

In [308]:
@register
class MM(torch.nn.Module):
    '''Demo component: a module holding a single 3x3, stride-2 convolution.

    ``a`` and ``b`` are the input and output channel counts of the
    convolution.
    '''

    def __init__(self, a=1, b=2):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=a,
            out_channels=b,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, data):
        '''Placeholder forward pass: ignores ``data`` and returns ``None``.'''
        return None

['a', 'b'] [1, 2]


## EXAMPLE

In [309]:
# Overrides keyed by registered class name, then by schema field.
config = dict(MM=dict(kwargs=dict(a=22, b=20)))

# Show the registry before and after the overrides are merged in.
print(GLOBAL_CONFIG)
inject(config)
print('------------')
print(GLOBAL_CONFIG)

{'MM': {'name': 'MM', 'module': <module '__main__'>, 'cls': <class '__main__.MM'>, 'kwargs': {'a': 1, 'b': 2}, 'args_requeired': 0}}
------------
{'MM': {'name': 'MM', 'module': <module '__main__'>, 'cls': <class '__main__.MM'>, 'kwargs': {'a': 22, 'b': 20}, 'args_requeired': 0}}


In [310]:
# Build MM from the keyword arguments currently stored in GLOBAL_CONFIG
# (the values injected above) and show the resulting module.
mm = create('MM')
print(mm)

MM(
  (conv): Conv2d(22, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
