Skip to content

Commit

Permalink
Merge pull request #258 from nlesc-nano/noise
Browse files Browse the repository at this point in the history
ENH: Round the net charge to the nearest integer
  • Loading branch information
BvB93 committed Oct 28, 2021
2 parents 6d2a241 + 1af1d28 commit bbe36c7
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 37 deletions.
5 changes: 5 additions & 0 deletions FOX/armc/param_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,11 @@ def _set_net_charge(self) -> None:
else:
self._net_charge = None

def _net_charge_to_integer(self) -> None:
if self._net_charge is not None:
self._net_charge = np.round(self._net_charge).astype(np.int64)
self._net_charge.setflags(write=False)

# The actual meat of the class

def add_param(self, idx: Tup3, value: float, **kwargs: Any) -> None:
Expand Down
1 change: 1 addition & 0 deletions FOX/armc/sanitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def dict_to_armc(input_dict: MainMapping) -> Tuple[MonteCarloABC, RunDict]:
param._set_net_charge()
validate_charge(param._net_charge, tolerance=validation_dict['charge_tolerance'])
validate_constraints(param, enforce_constraints=validation_dict['enforce_constraints'])
param._net_charge_to_integer()

mc.param.param.sort_index(inplace=True)
mc.param.param_old.sort_index(inplace=True)
Expand Down
106 changes: 69 additions & 37 deletions FOX/functions/charge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
"""

from __future__ import annotations

from itertools import chain
from typing import (
Hashable, Optional, Collection, Container, Tuple, SupportsFloat, Set, TypeVar, Type, Generic
)
from collections.abc import Hashable, Collection, Container
from typing import SupportsFloat, TypeVar, Generic

import numpy as np
import pandas as pd
Expand All @@ -32,31 +33,35 @@
class _StateDict(TypedDict):
"""A dictionary representing the keyword-only arguments of :exc:`ChargeError`."""

reference: Optional[float]
value: Optional[float]
tol: Optional[float]
reference: None | float
value: None | float
tol: None | float


class ChargeError(ValueError, Generic[T]):
"""A :exc:`ValueError` subclass for charge-related errors."""

__slots__ = ('__weakref__', 'reference', 'value', 'tol')

reference: Optional[float]
value: Optional[float]
tol: Optional[float]
args: Tuple[T, ...]

def __init__(self, *args: T, reference: Optional[SupportsFloat] = None,
value: Optional[SupportsFloat] = None,
tol: Optional[SupportsFloat] = 0.001) -> None:
reference: None | float
value: None | float
tol: None | float
args: tuple[T, ...]

def __init__(
self,
*args: T,
reference: None | SupportsFloat = None,
value: None | SupportsFloat = None,
tol: None | SupportsFloat = 0.001,
) -> None:
"""Initialize an instance."""
super().__init__(*args)
self.reference = float(reference) if reference is not None else None
self.value = float(value) if value is not None else None
self.tol = float(tol) if tol is not None else None

def __reduce__(self: ST) -> Tuple[Type[ST], Tuple[T, ...], _StateDict]:
def __reduce__(self: ST) -> tuple[type[ST], tuple[T, ...], _StateDict]:
"""Helper for :mod:`pickle`."""
cls = type(self)
kwargs = _StateDict(reference=self.reference, value=self.value, tol=self.tol)
Expand All @@ -68,8 +73,11 @@ def __setstate__(self, state: _StateDict) -> None:
setattr(self, k, v)


def get_net_charge(param: pd.Series, count: pd.Series,
index: Optional[Collection] = None) -> float:
def get_net_charge(
param: pd.Series,
count: pd.Series,
index: None | Collection[Hashable] = None,
) -> float:
"""Calculate the total charge in **df**.
Returns the (summed) product of the ``"param"`` and ``"count"`` columns in **df**.
Expand Down Expand Up @@ -98,12 +106,17 @@ def get_net_charge(param: pd.Series, count: pd.Series,
return ret.sum()


def update_charge(atom: KT, value: float, param: pd.Series, count: pd.Series,
atom_coefs: Optional[Collection[pd.Series]] = None,
prm_min: Optional[pd.Series] = None,
prm_max: Optional[pd.Series] = None,
exclude: Optional[Collection[KT]] = None,
net_charge: Optional[float] = None) -> Optional[ChargeError]:
def update_charge(
atom: KT,
value: float,
param: pd.Series,
count: pd.Series,
atom_coefs: None | Collection[pd.Series] = None,
prm_min: None | pd.Series = None,
prm_max: None | pd.Series = None,
exclude: None | Collection[KT] = None,
net_charge: None | float = None,
) -> None | ChargeError:
"""Set the atomic charge of **at** equal to **charge**.
The atomic charges in **df** are furthermore exposed to the following constraints:
Expand Down Expand Up @@ -148,10 +161,16 @@ def update_charge(atom: KT, value: float, param: pd.Series, count: pd.Series,
return None


def constrained_update(atom: KT, value: float, param: pd.Series, count: pd.Series,
atom_coefs: Collection[pd.Series],
param_min: pd.Series, param_max: pd.Series,
exclude: Optional[Set[KT]] = None) -> None:
def constrained_update(
atom: KT,
value: float,
param: pd.Series,
count: pd.Series,
atom_coefs: Collection[pd.Series],
param_min: pd.Series,
param_max: pd.Series,
exclude: None | set[KT] = None,
) -> None:
"""Perform a constrained update of atomic charges.
Performs an inplace update **param**.
Expand Down Expand Up @@ -208,9 +227,14 @@ def constrained_update(atom: KT, value: float, param: pd.Series, count: pd.Serie
exclude_set.update(idx.intersection(idx_ref))


def _update_1st_charge(atom: KT, value: float, param: pd.Series,
param_min: pd.Series, param_max: pd.Series,
exclude: Optional[Set[KT]] = None) -> Set[KT]:
def _update_1st_charge(
atom: KT,
value: float,
param: pd.Series,
param_min: pd.Series,
param_max: pd.Series,
exclude: None | set[KT] = None,
) -> set[KT]:
"""Helper function for :func:`constrained_update`."""
if exclude is not None:
exclude_set = exclude.copy()
Expand All @@ -230,10 +254,14 @@ def _update_1st_charge(atom: KT, value: float, param: pd.Series,
return exclude_set


def unconstrained_update(net_charge: float, param: pd.Series, count: pd.Series,
prm_min: Optional[pd.Series] = None,
prm_max: Optional[pd.Series] = None,
exclude: Optional[Container[Hashable]] = None) -> None:
def unconstrained_update(
net_charge: float,
param: pd.Series,
count: pd.Series,
prm_min: None | pd.Series = None,
prm_max: None | pd.Series = None,
exclude: None | Container[Hashable] = None,
) -> None:
"""Perform an unconstrained update of atomic charges."""
if exclude is None:
include = param.astype(bool, copy=True)
Expand Down Expand Up @@ -270,8 +298,12 @@ def unconstrained_update(net_charge: float, param: pd.Series, count: pd.Series,
s_clip = np.clip(s, s_min, s_max).loc[include]


def _check_net_charge(param: pd.Series, count: pd.Series, net_charge: float,
tolerance: float = 0.001) -> None:
def _check_net_charge(
param: pd.Series,
count: pd.Series,
net_charge: float,
tolerance: float = 0.001,
) -> None:
"""Check if the net charge is actually conserved."""
net_charge_new = get_net_charge(param, count)
condition = abs(net_charge - net_charge_new) > tolerance
Expand All @@ -280,6 +312,6 @@ def _check_net_charge(param: pd.Series, count: pd.Series, net_charge: float,
return

raise ChargeError(
f"Failed to conserve the net charge: ref = {net_charge:.4f}); {net_charge_new:.4f} != ref",
f"Failed to conserve the net charge: ref = {net_charge:.4f}; {net_charge_new:.4f} != ref",
reference=net_charge, value=net_charge_new, tol=tolerance
)
1 change: 1 addition & 0 deletions FOX/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def wrap_plams_logger(logfile: Union[None, str, os.PathLike] = None,

# Modify the plams logger
config.log.time = False
config.log.date = False
config.log.file = 0

# Replace the plams logger with a proper logging.Logger instance
Expand Down

0 comments on commit bbe36c7

Please sign in to comment.