Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions src/easyscience/fitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable
from typing import List
from typing import Optional
from typing import Union

import numpy as np

Expand All @@ -14,7 +15,7 @@
from .minimizers.factory import factory
from .minimizers.factory import from_string_to_enum

DEFAULT_MINIMIZER = 'lmfit-leastsq'
DEFAULT_MINIMIZER = AvailableMinimizers.LMFit_leastsq


class Fitter:
Expand All @@ -27,9 +28,9 @@ def __init__(self, fit_object, fit_function: Callable):
self._fit_function = fit_function
self._dependent_dims = None

self._name_current_minimizer = DEFAULT_MINIMIZER
self._enum_current_minimizer = DEFAULT_MINIMIZER
self._minimizer: MinimizerBase # _minimizer is set in the create method
self._update_minimizer(self._name_current_minimizer)
self._update_minimizer(self._enum_current_minimizer)

def fit_constraints(self) -> list:
return self._minimizer.fit_constraints()
Expand Down Expand Up @@ -62,26 +63,32 @@ def initialize(self, fit_object, fit_function: Callable) -> None:
self._update_minimizer(DEFAULT_MINIMIZER)

# TODO: remove this method when we are ready to adjust the dependent products
def create(self, minimizer_name: str = DEFAULT_MINIMIZER) -> None:
def create(self, minimizer_enum: Union[AvailableMinimizers, str] = DEFAULT_MINIMIZER) -> None:
"""
Create the required minimizer.
:param minimizer_name: The label of the minimization engine to create.
:param minimizer_enum: The enum of the minimization engine to create.
"""
self._update_minimizer(minimizer_name)
if isinstance(minimizer_enum, str):
print(f'minimizer should be set with enum {minimizer_enum}')
minimizer_enum = from_string_to_enum(minimizer_enum)
self._update_minimizer(minimizer_enum)

def switch_minimizer(self, minimizer_name: str) -> None:
def switch_minimizer(self, minimizer_enum: Union[AvailableMinimizers, str]) -> None:
"""
Switch minimizer and initialize.
:param minimizer_name: The label of the minimizer to create and instantiate.
:param minimizer_enum: The enum of the minimizer to create and instantiate.
"""
if isinstance(minimizer_enum, str):
print(f'minimizer should be set with enum {minimizer_enum}')
minimizer_enum = from_string_to_enum(minimizer_enum)

constraints = self._minimizer.fit_constraints()
self._update_minimizer(minimizer_name)
self._update_minimizer(minimizer_enum)
self._minimizer.set_fit_constraint(constraints)

def _update_minimizer(self, minimizer_name: str) -> None:
minimizer_enum = from_string_to_enum(minimizer_name)
def _update_minimizer(self, minimizer_enum: AvailableMinimizers) -> None:
self._minimizer = factory(minimizer_enum=minimizer_enum, fit_object=self._fit_object, fit_function=self.fit_function)
self._name_current_minimizer = minimizer_name
self._enum_current_minimizer = minimizer_enum

@property
def available_minimizers(self) -> List[str]:
Expand Down Expand Up @@ -119,7 +126,7 @@ def fit_function(self, fit_function: Callable) -> None:
:return: None
"""
self._fit_function = fit_function
self._update_minimizer(self._name_current_minimizer)
self._update_minimizer(self._enum_current_minimizer)

@property
def fit_object(self):
Expand All @@ -137,7 +144,7 @@ def fit_object(self, fit_object) -> None:
:return: None
"""
self._fit_object = fit_object
self._update_minimizer(self._name_current_minimizer)
self._update_minimizer(self._enum_current_minimizer)

def _fit_function_wrapper(self, real_x=None, flatten: bool = True) -> Callable:
"""
Expand Down
25 changes: 13 additions & 12 deletions src/easyscience/fitting/minimizers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
bumps_engine_imported = True
except ImportError:
# TODO make this a proper message (use logging?)
warnings.warn('Bummps minimization is not available. Probably bumps has not been installed.', ImportWarning, stacklevel=2)
warnings.warn('Bumps minimization is not available. Probably bumps has not been installed.', ImportWarning, stacklevel=2)

dfo_engine_imported = False
try:
Expand Down Expand Up @@ -50,31 +50,32 @@ class AvailableMinimizers(Enum):
DFO = auto()
DFO_leastsq = auto()


# Temporary solution to convert string to enum
def from_string_to_enum(minimizer_name: str) -> AvailableMinimizers:
if minimizer_name == 'lmfit':
if minimizer_name == 'LMFit':
minmizer_enum = AvailableMinimizers.LMFit
elif minimizer_name == 'lmfit-leastsq':
elif minimizer_name == 'LMFit_leastsq':
minmizer_enum = AvailableMinimizers.LMFit_leastsq
elif minimizer_name == 'lmfit-powell':
elif minimizer_name == 'LMFit_powell':
minmizer_enum = AvailableMinimizers.LMFit_powell
elif minimizer_name == 'lmfit-cobyla':
elif minimizer_name == 'LMFit_cobyla':
minmizer_enum = AvailableMinimizers.LMFit_cobyla

elif minimizer_name == 'bumps':
elif minimizer_name == 'Bumps':
minmizer_enum = AvailableMinimizers.Bumps
elif minimizer_name == 'bumps-simplex':
elif minimizer_name == 'Bumps_simplex':
minmizer_enum = AvailableMinimizers.Bumps_simplex
elif minimizer_name == 'bumps-newton':
elif minimizer_name == 'Bumps_newton':
minmizer_enum = AvailableMinimizers.Bumps_newton
elif minimizer_name == 'bumps-lm':
elif minimizer_name == 'Bumps_lm':
minmizer_enum = AvailableMinimizers.Bumps_lm

elif minimizer_name == 'dfo':
elif minimizer_name == 'DFO':
minmizer_enum = AvailableMinimizers.DFO
elif minimizer_name == 'dfo-leastsq':
elif minimizer_name == 'DFO_leastsq':
minmizer_enum = AvailableMinimizers.DFO_leastsq
else:
raise ValueError(f"Invalid minimizer name: {minimizer_name}. The following minimizers are available: {[minimize.name for minimize in AvailableMinimizers]}") # noqa: E501

return minmizer_enum

Expand Down
5 changes: 4 additions & 1 deletion src/easyscience/fitting/minimizers/minimizer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ class MinimizerBase(metaclass=ABCMeta):
wrapping: str = None

def __init__(
self, obj, fit_function: Callable, method: Optional[str] = None
self,
obj, #: BaseObj,
fit_function: Callable,
method: Optional[str] = None,
): # todo after constraint changes, add type hint: obj: BaseObj # noqa: E501
if method not in self.available_methods():
raise FitError(f'Method {method} not available in {self.__class__}')
Expand Down
Loading