In [1]:
from dataclasses import dataclass, fields

In [2]:
from pathlib import Path
import torch

class BaseConfig:
    def to_dict(self):
        """
        A json serializable dict representation of the dataclass.
        Some types are converted to str for json serialization.
        """
        result = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, Path):
                value = str(value)
            elif isinstance(value, torch.device):
                value = str(value)
            # elif isinstance(value, (ModelConfig, TrainConfig, RunConfig)):
            #     value = value.to_dict()
            result[f.name] = value
        return result


In [3]:
@dataclass
class ModelConfig(BaseConfig):
    seq_len: int = None
    d_model: int = None
    n_layers: int = None
    heads: int = None

In [4]:
mc = ModelConfig(seq_len=100, d_model=512, n_layers=6, heads=8)
print(mc)
print("---")
d = mc.to_dict()
print(d)

ModelConfig(seq_len=100, d_model=512, n_layers=6, heads=8)
---
{'seq_len': 100, 'd_model': 512, 'n_layers': 6, 'heads': 8}


In [8]:
mc2 = ModelConfig(**d)

In [9]:
print(mc2)
print("---")
print(mc2.to_dict())

ModelConfig(seq_len=100, d_model=512, n_layers=6, heads=8)
---
{'seq_len': 100, 'd_model': 512, 'n_layers': 6, 'heads': 8}


In [13]:
from utils.config import RunConfig
base_dir = "/Users/ron/dev/torch/medium/movies"
dataset_dir = base_dir + "/datasets"
rc = RunConfig(base_dir=base_dir, datasets_dir=dataset_dir, run_id="123", parallel_mode="ddp", wandb=False, compile=False,
               dist_master_addr="127.0.0.1", dist_master_port=1234, dist_backend="nccl",
               async_to_device=False, fused_adamw=False, case="movies")
print(rc)
print("---")
d = rc.to_dict()
print(d)
rc2 = RunConfig(**d)
print(rc2)
print("---")
print(rc2.to_dict())
print("---")
print(rc2 == rc)

RunConfig(base_dir=PosixPath('/Users/ron/dev/torch/medium/movies'), run_id='123', parallel_mode='ddp', dist_master_addr='127.0.0.1', dist_master_port=1234, dist_backend='nccl', wandb=False, compile=False, async_to_device=False, fused_adamw=False, datasets_dir='/Users/ron/dev/torch/medium/movies/datasets', run_dir=PosixPath('/Users/ron/dev/torch/medium/movies/runs/run123'), logs_dir=PosixPath('/Users/ron/dev/torch/medium/movies/runs/run123/logs'), checkpoints_dir=PosixPath('/Users/ron/dev/torch/medium/movies/runs/run123/checkpoints'), local_rank=None, device=None, is_primary=True, case='movies')
---
{'base_dir': '/Users/ron/dev/torch/medium/movies', 'run_id': '123', 'parallel_mode': 'ddp', 'dist_master_addr': '127.0.0.1', 'dist_master_port': 1234, 'dist_backend': 'nccl', 'wandb': False, 'compile': False, 'async_to_device': False, 'fused_adamw': False, 'datasets_dir': '/Users/ron/dev/torch/medium/movies/datasets', 'run_dir': '/Users/ron/dev/torch/medium/movies/runs/run123', 'logs_dir': '

In [15]:
@dataclass
class Config(BaseConfig):
    model: ModelConfig
    run: RunConfig

c = Config(model=mc, run=rc)
print(c)
print("---")
d = c.to_dict()
print(d)
c2 = Config(**d)
print(c2)
print("---")
print(c2.to_dict())
print("---")
print(c2 == c)


Config(model=ModelConfig(seq_len=100, d_model=512, n_layers=6, heads=8), run=RunConfig(base_dir=PosixPath('/Users/ron/dev/torch/medium/movies'), run_id='123', parallel_mode='ddp', dist_master_addr='127.0.0.1', dist_master_port=1234, dist_backend='nccl', wandb=False, compile=False, async_to_device=False, fused_adamw=False, datasets_dir='/Users/ron/dev/torch/medium/movies/datasets', run_dir=PosixPath('/Users/ron/dev/torch/medium/movies/runs/run123'), logs_dir=PosixPath('/Users/ron/dev/torch/medium/movies/runs/run123/logs'), checkpoints_dir=PosixPath('/Users/ron/dev/torch/medium/movies/runs/run123/checkpoints'), local_rank=None, device=None, is_primary=True, case='movies'))
---
{'model': ModelConfig(seq_len=100, d_model=512, n_layers=6, heads=8), 'run': RunConfig(base_dir=PosixPath('/Users/ron/dev/torch/medium/movies'), run_id='123', parallel_mode='ddp', dist_master_addr='127.0.0.1', dist_master_port=1234, dist_backend='nccl', wandb=False, compile=False, async_to_device=False, fused_adamw