In [8]:
import typing
from dataclasses import dataclass, is_dataclass, fields, MISSING

import numpy as np

# --- Leaf-level dataclasses ---
@dataclass
class QuadParams:
    """Quadrotor parameters (the ``quad`` section of ``Params``).

    NOTE(review): units are not stated in this file — confirm against the
    dynamics model that consumes these values.
    """
    mass: float
    inertia: float  # a single scalar here; presumably one principal moment of inertia — TODO confirm
    arm_length: float
    
@dataclass
class kkParams:
    """Parameters stored under ``PayloadParams.kk``.

    NOTE(review): the meaning of "kk" is not evident from this file (it has a
    mass and a rope length, so presumably a second suspended body — confirm).
    The lowercase class name departs from PEP 8 PascalCase; it is kept as-is
    because renaming it would break any external references.
    """
    mass: float
    rope_length: float

@dataclass
class PayloadParams:
    """Payload parameters (the ``payload`` section of ``Params``)."""
    mass: float
    rope_length: float
    kk: kkParams  # nested section: the config must supply a sub-dict for it

@dataclass
class EnvironmentParams:
    """Environment parameters (the ``environment`` section of ``Params``)."""
    gravity: float  # presumably gravitational acceleration magnitude — units TODO confirm

@dataclass
class WinchParams:
    """Winch parameters (the ``winch`` section of ``Params``)."""
    model: str  # free-form string; the set of accepted values is not defined in this file
    omega: float

# --- Top-level dataclass that groups them ---
@dataclass
class Params:
    """Typed view of the ``dynamics.params`` section of the config file.

    Every field is itself a dataclass. ``validate_params_dataclass`` builds an
    instance from the raw config dict, so values can then be read with dot
    notation (e.g. ``params.payload.kk.rope_length``).
    """
    quad: QuadParams
    payload: PayloadParams
    environment: EnvironmentParams
    winch: WinchParams

In [None]:
import importlib

from pendulum_ml.utils import load_cfg

# Relative path: resolved against the current working directory
# (normally the notebook's folder), assuming load_cfg opens it directly.
cfg = load_cfg("../configs/quad_config.yaml")

# Raw, untyped parameter dict; it is validated against the Params dataclasses below.
params_from_cfg = cfg["dynamics"]["params"]

# Dynamics module selected by the config's ``system`` key
# (pendulum_ml.dynamics.<system>). importlib.import_module returns the
# submodule itself, replacing the __import__(..., fromlist=['']) workaround.
cps = importlib.import_module(f"pendulum_ml.dynamics.{cfg['system']}")

def _check_required(required, params, prefix=""):
    """Recursively check that *params* contains every key of the *required* schema.

    Args:
        required (dict): schema whose keys must all be present in *params*; a
            dict value marks that parameter as a nested section with its own schema.
        params (dict): the (sub-)dictionary of configuration values to check.
        prefix (str): dotted path of the enclosing section, used in error messages.
    Raises:
        ValueError: if a key is missing or a nested section is not a dictionary.
    """
    for key, sub_schema in required.items():
        path = f"{prefix}{key}"
        if key not in params:
            raise ValueError(f"Missing required parameter: {path}")
        if isinstance(sub_schema, dict):
            value = params[key]
            if not isinstance(value, dict):
                raise ValueError(f"Parameter {path} should be a dictionary.")
            # Recurse with the *nested* schema (not the module-level one) and
            # extend the path so errors point at e.g. "payload.kk.mass".
            _check_required(sub_schema, value, prefix=f"{path}.")


def validate_params(cps, params):
    """Validate that all required parameters are present at any level of nesting.

    The schema is ``cps.REQUIRED_PARAMS``, a dict whose keys are required
    parameter names; a dict value means the parameter is a nested section
    that is checked recursively against that sub-schema. Only presence and
    "is a dictionary" are checked; leaf value types are not.

    Args:
        cps (module): dynamics module exposing ``REQUIRED_PARAMS``
            (e.g. pendulum_ml.dynamics.pendulum).
        params (dict): parameters dictionary.
    Raises:
        ValueError: if ``params`` is not a dict, a required parameter is
            missing, or a nested section is not a dictionary. Nested names are
            reported as dotted paths (e.g. ``quad.mass``).
    Returns:
        bool: True if all required parameters are present.
    """
    if not isinstance(params, dict):
        raise ValueError("Parameters should be a dictionary.")
    _check_required(cps.REQUIRED_PARAMS, params)
    return True


def _field_type_hints(cls):
    """Return ``{field_name: annotation}`` for dataclass *cls*, resolving string annotations.

    ``Field.type`` is a plain string when the defining module uses
    ``from __future__ import annotations``; ``typing.get_type_hints`` evaluates
    such strings. If evaluation fails (e.g. an unresolvable forward reference),
    fall back to the raw ``Field.type`` values.
    """
    try:
        return typing.get_type_hints(cls)
    except (NameError, TypeError):
        return {f.name: f.type for f in fields(cls)}


def validate_params_dataclass(dataclass_params, params):
    """Validate a config dict against a (possibly nested) dataclass schema and build it.

    Args:
        dataclass_params (type): dataclass *class* describing the expected
            parameters (e.g. pendulum_ml.dynamics.pendulum.Params). Fields whose
            type is itself a dataclass are validated recursively from a sub-dict.
        params (dict): parameters dictionary from the config file.
    Raises:
        ValueError: if ``dataclass_params`` is not a dataclass class, ``params``
            is not a dict, a required parameter is missing, a nested section is
            not a dict, or a value has the wrong type.
    Returns:
        An instance of ``dataclass_params`` built from the validated values.
        Fields absent from ``params`` that declare a default keep that default.

    Type-checking rules:
        * ``int`` values are accepted for ``float`` fields and converted to
          ``float`` (a YAML ``mass: 1`` loads as an int); ``bool`` is not.
        * Only plain-class annotations are enforced. Parameterized generics and
          unions (``list[float]``, ``Optional[float]``) are passed through
          unchecked, because ``isinstance`` cannot test them.
    """
    # is_dataclass() is also True for dataclass *instances*, which cannot be
    # called to build a new object; require the class itself.
    if not (isinstance(dataclass_params, type) and is_dataclass(dataclass_params)):
        raise ValueError("The Params attribute in the dynamics module is not a dataclass.")
    if not isinstance(params, dict):
        raise ValueError(f"Parameters for {dataclass_params.__name__} should be a dictionary.")

    type_hints = _field_type_hints(dataclass_params)
    processed_params = {}

    for dc_field in fields(dataclass_params):
        # init=False fields are not constructor arguments; passing them would raise TypeError.
        if not dc_field.init:
            continue

        name = dc_field.name
        expected = type_hints.get(name, dc_field.type)

        if name not in params:
            if dc_field.default is MISSING and dc_field.default_factory is MISSING:
                raise ValueError(f"Missing required parameter: {name}")
            continue  # optional field: the dataclass default / default_factory applies

        value = params[name]

        if isinstance(expected, type) and is_dataclass(expected):
            if not isinstance(value, dict):
                raise ValueError(f"Parameter {name} should be a dictionary.")
            # Recursively build the nested dataclass instance.
            processed_params[name] = validate_params_dataclass(expected, value)
            continue

        # Numeric promotion (as in PEP 484): an int is acceptable where a float
        # is declared. bool is excluded even though it subclasses int.
        if expected is float and isinstance(value, int) and not isinstance(value, bool):
            value = float(value)

        # Enforce plain classes only. get_origin() is None for them and non-None
        # for parameterized generics (some Python versions report
        # isinstance(list[int], type) as True, so both checks are needed).
        if isinstance(expected, type) and typing.get_origin(expected) is None:
            if not isinstance(value, expected):
                raise ValueError(
                    f"Parameter {name} should be of type {expected.__name__}. "
                    f"Got {value} of type {type(value).__name__} instead."
                )

        processed_params[name] = value

    return dataclass_params(**processed_params)

# Turn the raw config dict into a typed Params tree; a ValueError here means
# the config is missing a field or has a value of the wrong type.
print("Validating params from config file...")
params = validate_params_dataclass(Params, params_from_cfg)
print("All required parameters are present.")

# Nested sections are plain attributes of nested dataclass instances.
quad_mass = params.quad.mass
print(quad_mass)
kk_rope_length = params.payload.kk.rope_length
print(kk_rope_length)

Validating params from config file...
All required parameters are present.
1.5
0.5
