In [1]:
import hydra
from typing import Any, Dict, Callable, Optional, Tuple, TypeVar, Union, List
from dataclasses import dataclass
from omegaconf import OmegaConf
from hydra.core.config_store import ConfigStore
from pprint import pprint
from hydra.core.global_hydra import GlobalHydra

import logging
import warnings
warnings.filterwarnings('ignore')


In [2]:
# Taken from this thread: https://github.com/facebookresearch/hydra/issues/1982
# This aims to add support for vscode autocompletion



# Maps a `structured_config` group name to the left-hand side of the Hydra
# override emitted for it. structured_config either appends `=<name>` to the
# whole value, or (for configs without a name) keeps only the part after the
# last '@' as the dotted key prefix for per-field overrides.
SUBCONFIG_PATHS = dict(
    optim='optim@model.optim',
    tokenizer='model/tokenizer@model.tokenizer',
)

# Type variable used to annotate the pass-through decorator returned by
# __dataclass_transform__.
_T = TypeVar('_T')

def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (()),
) -> Callable[[_T], _T]:
    """Return a decorator that hands back its argument unchanged.

    All keyword arguments are accepted and ignored at runtime; the function
    only exists so that it can be applied as a marker decorator.
    NOTE(review): per the comment at the top of this cell, the marker is
    presumably there for editor/type-checker autocompletion -- confirm.
    """
    def _identity(a):
        return a

    return _identity


@__dataclass_transform__(eq_default=True, order_default=False, kw_only_default=True)
def structured_config(
    name: Optional[str] = None,
    group: Optional[str] = None,
    package: Optional[str] = None,
    provider: Optional[str] = None,
    init: bool = True,
    repr: bool = True,
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False
):
    """Build a class decorator that turns a class into a dataclass-based config.

    The decorated class is wrapped with ``dataclass`` and gets a
    ``hydra_override_str`` class attribute:

    * with a ``name``: ``'<SUBCONFIG_PATHS[group]>=<name>'``, and the class is
      also stored in Hydra's ``ConfigStore`` under ``group``/``name``;
    * without a ``name``: only the part of ``SUBCONFIG_PATHS[group]`` after the
      last ``'@'``, i.e. a key prefix to which field names can be appended.

    Args:
        name: Config name used for the ConfigStore entry and the override.
        group: Config group; must be a key of ``SUBCONFIG_PATHS``.
        package: Forwarded to ``ConfigStore.store``.
        provider: Forwarded to ``ConfigStore.store``.
        init, repr, eq, order, unsafe_hash, frozen: Forwarded to ``dataclass``.

    Returns:
        A decorator. It can be applied directly to a class, or called with no
        argument to obtain the underlying class wrapper.

    Raises:
        TypeError: If used bare (``@structured_config`` without parentheses),
            which would otherwise silently replace the class with a function.
        KeyError: When the decorator is applied and ``group`` is not a key of
            ``SUBCONFIG_PATHS``. Raised before anything is registered.
    """
    # Bare usage passes the decorated class in as `name`; fail loudly instead
    # of returning `decorator` in place of the class.
    if isinstance(name, type):
        raise TypeError(
            "structured_config must be called with arguments, "
            "e.g. @structured_config(group='optim', name='sgd')"
        )

    def decorator(cls=None):
        def wrapper(cls: Any):
            # Both override flavours below need SUBCONFIG_PATHS[group]. Check
            # it first so a bad group never leaves a half-registered config
            # in the ConfigStore.
            if group not in SUBCONFIG_PATHS:
                raise KeyError(
                    f"structured_config: unknown config group {group!r}; "
                    f"expected one of {sorted(SUBCONFIG_PATHS)}"
                )
            subconfig_path = SUBCONFIG_PATHS[group]

            # Wrap class into a dataclass
            new_cls = dataclass(cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen)

            # Store structured config into the Config Store
            if name is not None:
                config_store = ConfigStore.instance()
                config_store.store(group=group, name=name, package=package, provider=provider, node=new_cls)

            #### WIP: Configs without a name append their fields instead of substituting whole subtree
            if name:
                # Named config: select it as a whole for the mapped group/package.
                new_cls.hydra_override_str = f'{subconfig_path}={name}'
            else:
                # Unnamed config: keep only the package part (after '@') as a
                # key prefix for field-by-field overrides.
                new_cls.hydra_override_str = subconfig_path.split('@')[-1]
            #####
            return new_cls

        # `structured_config(...)` is applied either directly to a class
        # (cls given) or called once more without arguments (cls is None),
        # in which case the class wrapper itself is handed back.
        if cls is None:
            return wrapper
        return wrapper(cls)
    return decorator


In [3]:
class Singleton(type):
    """Metaclass that creates at most one instance per class.

    The first call to a class using this metaclass constructs the instance
    with the given arguments; every later call returns that same instance and
    ignores its arguments.
    """

    # Registry of already created instances, keyed by class.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        created = super().__call__(*args, **kwargs)
        registry[cls] = created
        return created
        
# We are making it a singleton because I still don't know how to handle multiple config dirs and this is a safe option
class ConfigBuilder(metaclass=Singleton):
    """Collects Hydra override strings and composes or inspects configs with them.

    Because the class uses the ``Singleton`` metaclass, only the first
    ``ConfigBuilder(config_dir)`` call runs ``__init__``; later calls return
    the same instance and their ``config_dir`` is ignored.

    Attributes:
        overrides: Override strings, in the order they were added.
    """

    def __init__(self, config_dir):
        """Initialise Hydra with ``config_dir`` unless it is already initialised.

        Args:
            config_dir: Directory passed to ``hydra.initialize_config_dir``.
        """
        if not GlobalHydra().is_initialized():
            hydra.initialize_config_dir(config_dir=config_dir)
        self._hydra = GlobalHydra().hydra
        self.overrides = []

    def compose(self, task):
        """Compose the config named ``'<task>_config'`` with the collected overrides."""
        return hydra.compose(config_name=f'{task}_config', overrides=self.overrides)

    def str_override(self, s):
        """Append a raw override string ``s`` to ``self.overrides``."""
        # TODO: Add some check whether this kind of override is already present
        # The order of overrides is last-one-wins. So when we type model=gpt model=bert
        # then the model ends up being bert

        # TODO: do a grammar check
        self.overrides.append(s)

    def info(self, task, option='all'):
        """Print Hydra's info for ``'<task>_config'`` with the collected overrides.

        Args:
            task: Prefix of the config name (``'<task>_config'``).
            option: Which information to show, forwarded to Hydra's
                ``show_info``. The lookup table previously kept here as a
                comment listed: "all", "defaults", "defaults-tree", "config",
                "plugins", "searchpath".
        """
        try:
            GlobalHydra().hydra.show_info(option, config_name=f'{task}_config', overrides=self.overrides)
        finally:
            # Clean up the logger handles. WAR needed, because hydra leaves them open.
            # Done in `finally` so a failing show_info (e.g. a bad override)
            # does not leave handlers behind; iterate over a copy because the
            # handler list is modified while looping.
            root = logging.getLogger()
            for handler in list(root.handlers):
                root.removeHandler(handler)
                handler.close()

    @staticmethod
    def _format_override_value(value):
        """Render a field value for an override string; ``None`` becomes ``null``."""
        # f'{None}' would yield the text "None", which Hydra's override
        # grammar treats as a string rather than as null.
        return 'null' if value is None else f'{value}'

    def override(self, cfg):
        """Add the override(s) described by a ``structured_config`` class.

        If ``cfg.hydra_override_str`` already contains ``'='`` it is added as
        is. Otherwise it is used as a key prefix and one
        ``'<prefix>.<field>=<value>'`` override is added for every attribute
        of a default-constructed ``cfg()`` instance.

        Args:
            cfg: A class carrying a ``hydra_override_str`` attribute.
        """
        # Dirty but works, needs rethinking!
        target = cfg.hydra_override_str
        if '=' in target:
            self.str_override(target)
            return
        for key, value in vars(cfg()).items():
            self.str_override(f'{target}.{key}={self._format_override_value(value)}')
            


In [4]:
# Named optimizer config: given a name, structured_config stores it in the
# ConfigStore (group 'optim', name 'sgd') and builds a whole-config override.
@structured_config(name='sgd', group='optim')
class SGDConfig:
    lr: float = 0.001
    momentum: float = 0.0
    dampening: float = 0.0
    weight_decay: float = 0.0
        
# Unnamed optimizer config: without a name it is not stored in the ConfigStore;
# its fields are meant to be applied one by one as overrides.
@structured_config(group='optim', name=None)
class MyOptim:
    lr: float = 0.02
    weight_decay: float = 0.2

NameError: name 'SDGConfig' is not defined

In [5]:
# ConfigBuilder uses the Singleton metaclass, so this config directory only takes
# effect the first time the builder is created in this kernel.
# NOTE(review): absolute, machine-specific path -- adjust it to your environment.
builder = ConfigBuilder('/opt/NeMo/examples/nlp/language_modeling/alt_conf/')

In [6]:
# Queue overrides on the shared builder. They accumulate in builder.overrides,
# so re-running this cell appends the same entries again.
builder.str_override('model=gpt')
builder.override(SGDConfig) # Substitution: named config -> a single '<group path>=sgd' override
builder.override(MyOptim) # Merging: unnamed config -> one '<prefix>.<field>=<value>' override per field

# Print the defaults list for 'pretraining_config' with the overrides queued so far.
builder.info(task='pretraining', option='defaults')


Defaults List
*************
| Config path                           | Package             | _self_ | Parent             | 
----------------------------------------------------------------------------------------------
| hydra/output/default                  | hydra               | False  | hydra/config       |
| hydra/launcher/basic                  | hydra.launcher      | False  | hydra/config       |
| hydra/sweeper/basic                   | hydra.sweeper       | False  | hydra/config       |
| hydra/help/default                    | hydra.help          | False  | hydra/config       |
| hydra/hydra_help/default              | hydra.hydra_help    | False  | hydra/config       |
| hydra/hydra_logging/default           | hydra.hydra_logging | False  | hydra/config       |
| hydra/job_logging/default             | hydra.job_logging   | False  | hydra/config       |
| hydra/env/default                     | hydra.env           | False  | hydra/config       |
| hydra/config               

In [7]:
# Compose 'pretraining_config' with the queued overrides and pretty-print it
# as plain Python containers.
cfg = builder.compose(task='pretraining')
pprint(OmegaConf.to_container(cfg))

{'exp_manager': {'checkpoint_callback_params': {'always_save_nemo': False,
                                                'filename': '${name}--{val_loss:.2f}-{step}-{consumed_samples}',
                                                'mode': 'min',
                                                'model_parallel_size': '${multiply:${model.tensor_model_parallel_size}, '
                                                                       '${model.pipeline_model_parallel_size}}',
                                                'monitor': 'val_loss',
                                                'save_nemo_on_train_end': False,
                                                'save_top_k': 10},
                 'create_checkpoint_callback': True,
                 'create_wandb_logger': False,
                 'exp_dir': None,
                 'explicit_log_dir': None,
                 'name': '${name}',
                 'resume_if_exists': True,
                 'resume_ignore_no_chec

In [8]:
# Show the raw override strings collected so far, in insertion order.
builder.overrides

['model=gpt',
 'optim@model.optim=sgd',
 'model.optim.lr=0.02',
 'model.optim.weight_decay=0.2']