In [1]:
%pip install omegaconf

Collecting omegaconf
  Downloading omegaconf-2.3.0-py3-none-any.whl (79 kB)
[K     |████████████████████████████████| 79 kB 2.0 MB/s eta 0:00:011
[?25hCollecting antlr4-python3-runtime==4.9.*
  Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)
[K     |████████████████████████████████| 117 kB 4.2 MB/s eta 0:00:01
Building wheels for collected packages: antlr4-python3-runtime
  Building wheel for antlr4-python3-runtime (setup.py) ... [?25ldone
[?25h  Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.9.3-py3-none-any.whl size=144575 sha256=9b79ba137efae429a5470f9939d607326614f17dcd583b94b0ec1d1b6e00a147
  Stored in directory: /Users/beckhamc/Library/Caches/pip/wheels/b1/a3/c2/6df046c09459b73cc9bb6c4401b0be6c47048baf9a1617c485
Successfully built antlr4-python3-runtime
Installing collected packages: antlr4-python3-runtime, omegaconf
Successfully installed antlr4-python3-runtime-4.9.3 omegaconf-2.3.0
Note: you may need to restart the kernel to use update

In [3]:
# No install needed: `dataclasses` is part of the standard library since
# Python 3.7. The PyPI package of the same name (0.6) is a backport for
# Python 3.6 and should not be installed on newer interpreters.
import sys
assert sys.version_info >= (3, 7), "dataclasses needs Python 3.7+ (stdlib module)"


----------

## Creating a data class

In [1]:
from dataclasses import dataclass, field
from typing import List

In [2]:
@dataclass
class Model:
    """Configuration values for the model (what `create_model` receives).

    ``dim_mults`` is built by a ``default_factory``, so every instance gets
    its own list instead of sharing one mutable default.
    """

    dim: int = 128
    # A fresh [1, 2, 4, 8] list on every instantiation.
    dim_mults: List[int] = field(default_factory=lambda: list((1, 2, 4, 8)))

In [3]:
@dataclass
class Dataset:
    """Dataset configuration; currently only the dataset's name."""

    name: str = field(default="cifar10")

In [134]:
@dataclass
class Arguments:
    """Top-level configuration: nested model/dataset configs plus run settings.

    The nested dataclass fields use ``default_factory``. A dataclass instance
    used directly as a default is mutable and unhashable. Python 3.11+ rejects
    it with a ValueError at class creation. Older versions would silently share
    one object between every ``Arguments()``.
    """

    model: Model = field(default_factory=Model)
    dataset: Dataset = field(default_factory=Dataset)
    num_workers: int = 8
    epochs: int = 200
    eval_every: int = 10
    special_eval_arg: int = 1000  # we'll get to this later

In [5]:
from omegaconf import OmegaConf as OC

### Type checking

In [45]:
# OmegaConf verifies at runtime that each value matches its declared type, e.g. that a List[int] field really holds a list of ints.

In [93]:
# Build an OmegaConf structured config from the dataclass defaults.
OC.structured(Model())

{'dim': 128, 'dim_mults': [1, 2, 4, 8]}

In [94]:
# The dataclass constructor itself accepts the str; OmegaConf must reject it
# because dim_mults is declared as List[int].
invalid_model = Model(dim_mults="1248")
try:
    OC.structured(invalid_model)
except Exception as e:
    print(e)
else:
    raise AssertionError("expected OmegaConf to reject dim_mults='1248'")

Invalid value assigned: str is not a ListConfig, list or tuple.
    full_key: dim_mults
    object_type=None


In [95]:
# dim_mults is not Optional, so None must be rejected as well.
invalid_model = Model(dim_mults=None)
try:
    OC.structured(invalid_model)
except Exception as e:
    print(e)
else:
    raise AssertionError("expected OmegaConf to reject dim_mults=None")

Non optional ListConfig cannot be constructed from None
    full_key: dim_mults
    object_type=None


### Reading from a dotlist

In [96]:
# "key=value" overrides; `bad_arg` is not a field of Arguments.
cfg_dotlist = ["num_workers=8", "epochs=200", "bad_arg=1"]

In [97]:
# Show the raw override strings.
cfg_dotlist

['num_workers=8', 'epochs=200', 'bad_arg=1']

First, parse the dotlist into a (dict-like) config, then pass it to `Arguments` as keyword arguments:

In [98]:
# This should fail, because bad_arg is not a valid argument.
# Catch and print the error so the notebook still runs top to bottom.
try:
    Arguments(**OC.from_dotlist(cfg_dotlist))
except TypeError as e:
    print("{}: {}".format(type(e).__name__, e))
else:
    raise AssertionError("expected Arguments() to reject bad_arg")

TypeError: __init__() got an unexpected keyword argument 'bad_arg'

In [6]:
# `epochs` is declared as int, but this override gives it a non-numeric string.
cfg_dotlist = ["num_workers=8", "epochs=helloworld"]

In [7]:
# This should fail, because epochs needs to be an int, not str.
# OC.structured() is what does the validation; the dataclass constructor
# accepts any value.
try:
    OC.structured(
        Arguments(**OC.from_dotlist(cfg_dotlist))
    )
except Exception as e:
    print(e)
else:
    raise AssertionError("expected epochs='helloworld' to be rejected")

Value 'helloworld' of type 'str' could not be converted to Integer
    full_key: epochs
    object_type=Arguments


Let's try a more sophisticated dotlist, with a nested key:

In [45]:
# `dataset.bad_arg` is a nested key; Dataset has no field called bad_arg.
cfg_dotlist = ["num_workers=8", "epochs=200", "dataset.bad_arg=helloworld"]

In [46]:
# OC.from_dotlist takes a single list of "key=value" strings, not varargs.
user_args = OC.from_dotlist([
    "num_workers=8",
    "epochs=200",
    # dataset does not contain this arg
    "dataset.bad_arg=helloworld",
])
user_args

{'num_workers': 8, 'epochs': 200, 'dataset': {'bad_arg': 'helloworld'}}

In [47]:
# Defaults for every field, as a structured config whose assignments are type-checked.
main_args = OC.structured(Arguments())

In [19]:
# Assigning onto a structured config is type-checked too: "abc" is not an int.
try:
    main_args.epochs = "abc"
except Exception as e:
    print(e)
else:
    raise AssertionError("expected assigning 'abc' to epochs to fail")

Value 'abc' of type 'str' could not be converted to Integer
    full_key: epochs
    object_type=Arguments


### Nested checking

Our dataclasses may be nested within other dataclasses (e.g. `Arguments` contains both a `Model` and a `Dataset`). We would like to pass in a single dotlist that may specify values for several of these classes, and to validate all of them with one method call.

In [23]:
from omegaconf.dictconfig import DictConfig

In [150]:
def validate(main_args: DictConfig, user_args: DictConfig, padding="", verbose=True):
    """Insert user args into default main args, in place.

    Every leaf value in ``user_args`` is assigned onto ``main_args``. Because
    ``main_args`` is a structured config, OmegaConf rejects unknown keys and
    values that cannot be converted to the declared field type. Nested
    sections (values that are themselves DictConfigs) are handled recursively.
    If an error is raised part-way through, keys assigned before it stay
    assigned.

    Args:
        main_args: structured config holding the defaults, e.g.
            ``OC.structured(Arguments())``. Modified in place.
        user_args: overrides to apply, e.g. from ``OC.from_dotlist([...])``.
        padding: prefix for log lines and error messages; grows by two
            spaces per nesting level.
        verbose: if True (default), print each assignment and each recursion.

    Raises:
        ValueError: if ``main_args`` or ``user_args`` is not a DictConfig
            (e.g. a dotlist nests keys under a field that is a plain value).
        OmegaConf errors from the assignment/lookup: unknown key, or a value
            that cannot be converted to the field's type.
    """
    if not isinstance(main_args, DictConfig):
        raise ValueError("{}main_args needs to be a DictConfig".format(padding))
    if not isinstance(user_args, DictConfig):
        raise ValueError("{}user_args needs to be a DictConfig".format(padding))
    for k, v in user_args.items():
        if not isinstance(v, DictConfig):
            if verbose:
                print("{}set {} -> {}".format(padding, k, v))
            # Assigning onto the structured config is what triggers
            # OmegaConf's key and type checks.
            main_args[k] = v
        else:
            if verbose:
                print("{}recurse into key={}".format(padding, k))
            # Looking up main_args[k] itself raises if the schema has no
            # such section.
            validate(main_args[k], v, padding + "  ", verbose)

In [86]:
# Start from the defaults, then parse the user's overrides.
main_args = OC.structured(Arguments())
user_args = OC.from_dotlist(
    [
        "num_workers=8",  # OK
        "epochs=200",  # OK
        # `dataset` has no field called `bad_arg`, so validation should fail.
        "dataset.bad_arg=helloworld",
    ]
)

In [87]:
# Expected to fail on `dataset.bad_arg`.
try:
    validate(main_args, user_args)
except Exception as e:
    print(e)
else:
    raise AssertionError("expected validate() to reject dataset.bad_arg")

set num_workers -> 8
set epochs -> 200
recurse into key=dataset
  set bad_arg -> helloworld
Key 'bad_arg' not in 'Dataset'
    full_key: dataset.bad_arg
    reference_type=Dataset
    object_type=Dataset


In [103]:
# Start from the defaults, then parse the user's overrides.
main_args = OC.structured(Arguments())
user_args = OC.from_dotlist(
    [
        "num_workers=8",  # OK
        "epochs=200",  # OK
        # `model.dim_mults` exists, but it is a List[int]; a list of strings
        # should be rejected.
        "model.dim_mults=['blah','blah']",
    ]
)

In [102]:
# Expected to fail: 'blah' cannot be converted to int.
try:
    validate(main_args, user_args)
except Exception as e:
    print(e)
else:
    raise AssertionError("expected validate() to reject a non-int model.dim_mults")

set num_workers -> 8
set epochs -> 200
recurse into key=model
  set dim_mults -> ['blah', 'blah']
Value 'blah' of type 'str' could not be converted to Integer
    full_key: model.dim_mults[0]
    reference_type=List[int]
    object_type=list


In [95]:
# Start from the defaults, then parse the user's overrides.
main_args = OC.structured(Arguments())
user_args = OC.from_dotlist(
    [
        "num_workers=8",
        "epochs=200",
        # `dataset` is given, but as an empty section: this should pass.
        "dataset={}",
    ]
)

In [96]:
# Expected to succeed (an empty section changes nothing), so let any error
# propagate instead of printing it and carrying on.
validate(main_args, user_args)

set num_workers -> 8
set epochs -> 200
recurse into key=dataset


In [111]:
# Start from the defaults, then parse the user's overrides.
main_args = OC.structured(Arguments())
user_args = OC.from_dotlist(
    [
        "num_workers=8",
        "epochs=200",
        "dataset.name=mnist",  # valid nested override
        # `Dataset` has no key called `nested_dataset`, so this should fail.
        "dataset.nested_dataset={}",
    ]
)

In [112]:
# Expected to fail on `dataset.nested_dataset`.
try:
    validate(main_args, user_args)
except Exception as e:
    print(e)
else:
    raise AssertionError("expected validate() to reject dataset.nested_dataset")

set num_workers -> 8
set epochs -> 200
recurse into key=dataset
  set name -> mnist
  recurse into key=nested_dataset
Key 'nested_dataset' not in 'Dataset'
    full_key: dataset.nested_dataset
    reference_type=Dataset
    object_type=Dataset


In [131]:
main_args = OC.structured(Arguments())    # construct default arg list
user_args = OC.from_dotlist([
    "num_workers=8",
    "epochs=200",
    # empty sections for both nested configs: this should pass
    "dataset={}",
    "model={}",
])

# Expected to succeed, so let any error propagate instead of printing it.
validate(main_args, user_args)

set num_workers -> 8
set epochs -> 200
recurse into key=dataset
recurse into key=model


In [132]:
# The parsed overrides: `dataset={}` / `model={}` become empty nested sections.
user_args

{'num_workers': 8, 'epochs': 200, 'dataset': {}, 'model': {}}

In [140]:
# `replace=True` keeps this cell re-runnable: registering the same resolver
# name a second time raises otherwise.
# WARNING: `eval` executes arbitrary Python taken from config strings. Only use
# this resolver with trusted configs, never with untrusted user input.
OC.register_new_resolver("eval", eval, replace=True)

In [143]:
main_args = OC.structured(Arguments())    # construct default arg list
user_args = OC.from_dotlist([
    "num_workers=8",
    "epochs=200",
    # see if the special eval argument works:
    # `${eval:'10*2'}` is resolved through the "eval" resolver to 20
    "special_eval_arg=${eval:'10*2'}",
])

# Expected to succeed, so let any error propagate instead of printing it.
validate(main_args, user_args)

set num_workers -> 8
set epochs -> 200
set special_eval_arg -> 20


**TODO: add a custom resolver that resolves a dotted path (e.g. `package.module.ClassName`) to the class it names, similar to `importlib.import_module` followed by `getattr`.**

### Dummy model factory

In [120]:
import dataclasses

# Confirm that a Model instance is recognised as a dataclass.
dataclasses.is_dataclass(Model())

True

In [148]:
def create_model(n_in,
                 *,
                 model_args: Model):
    """Placeholder factory that will build the model from ``model_args``.

    Args:
        n_in: first positional argument for the model constructor
            (presumably the number of input channels/features; TODO confirm).
        model_args: a dataclass *instance* holding the model settings.
            Keyword-only, because it follows the bare ``*``. Note that a
            DictConfig from ``OC.structured(...)`` is not a dataclass; it
            would need converting first (e.g. ``OmegaConf.to_object``).

    Returns:
        None for now; the actual model construction is not written yet.

    Raises:
        AssertionError: if ``model_args`` is not a dataclass instance.
    """
    # is_dataclass() is also True for the class itself (e.g. `Model`), so
    # reject classes explicitly: we need an instance with concrete values.
    assert dataclasses.is_dataclass(model_args) and not isinstance(model_args, type), \
        "model_args must be a dataclass instance, got {!r}".format(model_args)
    # create model here, e.g.
    #   return UNet(n_in, **dataclasses.asdict(model_args))
    # A dataclass is not a mapping, so it must go through asdict() before it
    # can be unpacked with **.
    return None

In [149]:
# Smoke test: a Model instance passes create_model's dataclass check.
create_model(10, model_args=Model())

In [146]:
# Dataclasses come with a readable __repr__ (and __eq__) for free.
Model()

In [113]:
# The bare `*` in create_model's signature makes `model_args` keyword-only.
# See https://stackoverflow.com/questions/14301967/bare-asterisk-in-function-parameters