Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: config tweaks #360

Merged
merged 4 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions secator/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,8 @@ def config():
def config_get(full, key=None):
"""Get config value."""
if key is None:
CONFIG.print(partial=not full)
partial = not full and CONFIG != default_config
CONFIG.print(partial=partial)
return
CONFIG.get(key)

Expand All @@ -471,13 +472,16 @@ def config_get(full, key=None):
@click.argument('value')
def config_set(key, value):
"""Set config value."""
success = CONFIG.set(key, value)
if success:
CONFIG.set(key, value)
config = CONFIG.validate()
if config:
CONFIG.get(key)
saved = CONFIG.save()
if not saved:
return
console.print(f'[bold green]:tada: Saved config to [/]{CONFIG._path}')
else:
console.print('[bold red]:x: Invalid config, not saving it.')


@config.command('edit')
Expand All @@ -489,7 +493,7 @@ def config_edit(resume):
shutil.copyfile(config_path, tmp_config)
click.edit(filename=tmp_config)
config = Config.parse(path=tmp_config)
if config._valid:
if config:
config.save(config_path)
console.print(f'\n[bold green]:tada: Saved config to [/]{config_path}.')
tmp_config.unlink()
Expand Down
222 changes: 123 additions & 99 deletions secator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,57 +188,63 @@ def get(self, key=None, print=True):
Config.print_yaml(yaml_str)
return value

def set(self, key, value, set_partial=True):
def set(self, key, value):
"""Set a value in the configuration using a dotted path.

Args:
key (str | None): Dotted key path.
value (Any): Value.
partial (bool): Also set value in partial config (written to disk).

Returns:
bool: Success boolean.
"""
# Get existing value
existing_value = self.get(key, print=False)

# Convert dotted key path to the corresponding uppercase key used in _keymap
map_key = key.upper().replace('.', '_')
success = False
if map_key in self._keymap:
# Traverse to the second last key to handle the setting correctly
target = self
partial = self._partial
for part in self._keymap[map_key][:-1]:
target = target[part]
if set_partial:
partial = partial[part]

# Set the value on the final part of the path
final_key = self._keymap[map_key][-1]

# Convert the value to the correct type based on the current value type
try:
if isinstance(existing_value, bool):
if isinstance(value, str):
value = value.lower() in ("true", "1", "t")
elif isinstance(value, (int, float)):
value = True if value == 1 else False
elif isinstance(existing_value, int):
value = int(value)
elif isinstance(existing_value, float):
value = float(value)
if existing_value != value:
target[final_key] = value
if set_partial:
partial[final_key] = value
success = True
except ValueError:
success = False
# console.print(f'[bold red]{key}: cannot cast value "{value}" to {type(existing_value).__name__}')
else:

# Check if map key exists
if map_key not in self._keymap:
console.print(f'[bold red]Key "{key}" not found in config keymap[/].')
return success
return

# Traverse to the second last key to handle the setting correctly
target = self
partial = self._partial
for part in self._keymap[map_key][:-1]:
target = target[part]
partial = partial[part]

# Set the value on the final part of the path
final_key = self._keymap[map_key][-1]

# Try to convert value to expected type
try:
if isinstance(existing_value, list):
if isinstance(value, str):
if value.startswith('[') and value.endswith(']'):
value = value[1:-1]
if ',' in value:
value = [c.strip() for c in value.split(',')]
elif isinstance(existing_value, dict):
if isinstance(value, str):
if value.startswith('{') and value.endswith('}'):
import json
value = json.loads(value)
elif isinstance(existing_value, bool):
if isinstance(value, str):
value = value.lower() in ("true", "1", "t")
elif isinstance(value, (int, float)):
value = True if value == 1 else False
elif isinstance(existing_value, int):
value = int(value)
elif isinstance(existing_value, float):
value = float(value)
elif isinstance(existing_value, Path):
value = Path(value)
except ValueError:
pass
finally:
target[final_key] = value
partial[final_key] = value

def save(self, target_path: Path = None, partial=True):
"""Save config as YAML on disk.
Expand Down Expand Up @@ -266,67 +272,81 @@ def print(self, partial=True):
Config.print_yaml(yaml_str)

@staticmethod
def parse(data: dict = {}, path: Path = None, env_overrides: bool = False):
def parse(data: dict = {}, path: Path = None, print_errors: bool = True):
"""Parse config.

Args:
data (dict): Config data.
path (Path | None): Path to YAML config.
env_overrides (bool): Apply env overrides.
print_errors (bool): Print validation errors to console.

Returns:
Config: instance of Config object.
None: if the config was not loaded properly or there are validation errors.
"""
# Load YAML file
if path:
data = Config.read_yaml(path)

# Load data
try:
config = Config.load(SecatorConfig, data)
config._valid = True

# HACK: set default result_backend if unset
if not config.celery.result_backend:
config.celery.result_backend = f'file://{config.dirs.celery_results}'
config = Config.load(SecatorConfig, data, print_errors=print_errors)
valid = config is not None
if not valid:
return None

except ValidationError as e:
error_str = str(e).replace('\n', '\n ')
if path:
error_str.replace('SecatorConfig', f'SecatorConfig ({path})')
console.print(f'[bold red]:x: {error_str}')
# console.print('[bold green]Using default config.[/]')
config = Config.parse()
config._valid = False

# Set hidden attributes
keymap = Config.build_key_map(config)
partial = Config(data)
config._partial = partial
config._path = path
config._keymap = keymap
# Set extras
config.set_extras(data, path)

# Override config values with environment variables
if env_overrides:
config.apply_env_overrides()
data = {k: v for k, v in config.toDict().items() if not k.startswith('_')}
config = Config.parse(data, env_overrides=False) # re-validate config
config._partial = partial
config._path = path
config.apply_env_overrides(print_errors=print_errors)

# Validate config
config.validate(print_errors=print_errors)

return config

def validate(self, print_errors=True):
"""Validate config."""
return Config.load(
SecatorConfig,
data=self._partial.toDict(),
print_errors=print_errors)

def set_extras(self, original_data, original_path):
"""Set extra useful values in config.

Args:
original_data (data): Original dict data.
original_path (pathlib.Path): Original YAML path.
valid (bool): Boolean indicating if config is valid or not.
"""
self._path = original_path
self._partial = Config(original_data)
self._keymap = Config.build_key_map(self)

# HACK: set default result_backend if unset
if not self.celery.result_backend:
self.celery.result_backend = f'file://{self.dirs.celery_results}'

@staticmethod
def load(schema, data: dict = {}):
def load(schema, data: dict = {}, print_errors=True):
"""Validate a config using Pydantic.

Args:
data (dict): Config dict.
schema (pydantic.Schema): Pydantic schema.
data (dict): Input data.
print_errors (bool): Print validation errors.

Returns:
Config: instance of Config object.
Config|None: instance of Config object or None if invalid.
"""
return Config(schema(**data).model_dump())
try:
return Config(schema(**data).model_dump())
except ValidationError as e:
if print_errors:
error_str = str(e).replace('\n', '\n ')
console.print(f'[bold red]:x: {error_str}')
return None

@staticmethod
def read_yaml(yaml_path):
Expand All @@ -338,9 +358,16 @@ def read_yaml(yaml_path):
Returns:
dict: Loaded data.
"""
with yaml_path.open('r') as f:
data = yaml.load(f.read(), Loader=yaml.Loader)
return data or {}
if not yaml_path.exists():
console.print(f'[bold red]Config not found: {yaml_path}.[/]')
return {}
try:
with yaml_path.open('r') as f:
data = yaml.load(f.read(), Loader=yaml.Loader)
return data or {}
except yaml.YAMLError as e:
console.print(f'[bold red]:x: Error loading {yaml_path} {str(e)}')
return {}

@staticmethod
def print_yaml(string):
Expand Down Expand Up @@ -381,6 +408,7 @@ def posix_path_representer(dumper, data):
path = path.replace(home, '~')
return dumper.represent_scalar('tag:yaml.org,2002:str', path)

LineBreakDumper.add_representer(str, posix_path_representer)
LineBreakDumper.add_representer(Path, posix_path_representer)
LineBreakDumper.add_representer(PosixPath, posix_path_representer)
LineBreakDumper.add_representer(WindowsPath, posix_path_representer)
Expand Down Expand Up @@ -413,31 +441,20 @@ def build_key_map(config, base_path=[]):
key_map['_'.join(current_path).upper()] = current_path
return key_map

def apply_env_overrides(self):
def apply_env_overrides(self, print_errors=True):
"""Override config values from environment variables."""
# Build a map of keys from the config
key_map = Config.build_key_map(self)

# Prefix for environment variables to target
prefix = "SECATOR_"

# Loop through environment variables
for var in os.environ:
if var.startswith(prefix):
# Remove prefix and get the path from the key map
key = var[len(prefix):]
if key in key_map:
path = '.'.join(k.lower() for k in key_map[key])
key = var[len(prefix):] # remove prefix
if key in self._keymap:
path = '.'.join(k.lower() for k in self._keymap[key])
value = os.environ[var]

# Set the new value recursively
success = self.set(path, value, set_partial=True)
if success:
console.print(f'[bold green4]{var} (override success)[/]')
else:
console.print(f'[bold red]{var} (override failed: cannot update value)[/]')
else:
console.print(f'[bold red]{var} (override failed: key not found in config)[/]')
self.set(path, value)
if not self.validate(print_errors=False) and print_errors:
console.print(f'[bold red]{var} (override failed)[/]')
elif print_errors:
console.print(f'[bold red]{var} (override failed: key not found)[/]')


def download_files(data: dict, target_folder: Path, offline_mode: bool, type: str):
Expand Down Expand Up @@ -502,8 +519,10 @@ def download_files(data: dict, target_folder: Path, offline_mode: bool, type: st
data[name] = target_path.resolve()


# Load configs
default_config = Config.parse()
# Load default_config
default_config = Config.parse(print_errors=False)

# Load user config
data_root = default_config.dirs.data
config_path = data_root / 'config.yml'
if not config_path.exists():
Expand All @@ -515,7 +534,12 @@ def download_files(data: dict, target_folder: Path, offline_mode: bool, type: st
f'[bold turquoise4]Creating user conf [bold magenta]{config_path}[/]... [/]', end='')
config_path.touch()
console.print('[bold green]ok.[/]')
CONFIG = Config.parse(path=config_path, env_overrides=True)
CONFIG = Config.parse(path=config_path)

# Fallback to default if invalid user config
if not CONFIG:
console.print(f'[bold orange1]Invalid user config {config_path}. Falling back to default config.')
CONFIG = default_config

# Create directories if they don't exist already
for name, dir in CONFIG.dirs.items():
Expand Down
Loading
Loading