From 526d0e720fd4187e71f0a31f8c01f9634c0df986 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Tue, 15 Apr 2025 16:10:15 +0200 Subject: [PATCH 01/12] calculators and minimizers unit tests --- .../calculators/test_calculator_base.py | 79 ++++++++++++ .../calculators/test_calculator_cryspy.py | 96 +++++++++++++++ .../calculators/test_calculator_factory.py | 81 ++++++++++++ .../test_fitting_progress_tracker.py | 95 ++++++++++++++ .../minimizers/test_minimizer_base.py | 116 ++++++++++++++++++ .../minimizers/test_minimizer_dfols.py | 79 ++++++++++++ .../minimizers/test_minimizer_factory.py | 64 ++++++++++ .../minimizers/test_minimizer_lmfit.py | 95 ++++++++++++++ 8 files changed, 705 insertions(+) create mode 100644 tests/unit_tests/analysis/calculators/test_calculator_base.py create mode 100644 tests/unit_tests/analysis/calculators/test_calculator_cryspy.py create mode 100644 tests/unit_tests/analysis/calculators/test_calculator_factory.py create mode 100644 tests/unit_tests/analysis/minimizers/test_fitting_progress_tracker.py create mode 100644 tests/unit_tests/analysis/minimizers/test_minimizer_base.py create mode 100644 tests/unit_tests/analysis/minimizers/test_minimizer_dfols.py create mode 100644 tests/unit_tests/analysis/minimizers/test_minimizer_factory.py create mode 100644 tests/unit_tests/analysis/minimizers/test_minimizer_lmfit.py diff --git a/tests/unit_tests/analysis/calculators/test_calculator_base.py b/tests/unit_tests/analysis/calculators/test_calculator_base.py new file mode 100644 index 00000000..7d8428c1 --- /dev/null +++ b/tests/unit_tests/analysis/calculators/test_calculator_base.py @@ -0,0 +1,79 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +from easydiffraction.analysis.calculators.calculator_base import CalculatorBase + + +# Mock subclass of CalculatorBase to test its concrete methods +class MockCalculator(CalculatorBase): + @property + def name(self): + return "MockCalculator" + + @property + def engine_imported(self): + return True + + def calculate_structure_factors(self, sample_model, experiment): + return np.array([1., 2., 3.]) + + def _calculate_single_model_pattern(self, sample_model, experiment, called_by_minimizer): + return np.array([1., 2., 3.]) + + +@pytest.fixture +def mock_sample_models(): + sample_models = MagicMock() + sample_models.get_all_params.return_value = {"param1": 1, "param2": 2} + sample_models.get_ids.return_value = ["phase1", "phase2"] + sample_models.__getitem__.side_effect = lambda key: MagicMock(apply_symmetry_constraints=MagicMock()) + return sample_models + + +@pytest.fixture +def mock_experiment(): + experiment = MagicMock() + experiment.datastore.pattern.x = np.array([1., 2., 3.]) + experiment.datastore.pattern.bkg = None + experiment.datastore.pattern.calc = None + experiment.linked_phases = [ + MagicMock(_entry_id="phase1", scale=MagicMock(value=2.0)), + MagicMock(_entry_id="phase2", scale=MagicMock(value=1.5)), + ] + experiment.background.calculate.return_value = np.array([0.1, 0.2, 0.3]) + return experiment + + +@patch("easydiffraction.core.singletons.ConstraintsHandler.get") +def test_calculate_pattern(mock_constraints_handler, mock_sample_models, mock_experiment): + mock_constraints_handler.return_value.apply = MagicMock() + + calculator = MockCalculator() + result = calculator.calculate_pattern(mock_sample_models, mock_experiment) + + # Assertions + assert np.allclose(result, np.array([3.6, 7.2, 10.8])) + mock_constraints_handler.return_value.apply.assert_called_once_with(parameters={"param1": 1, "param2": 2}) + assert mock_experiment.datastore.pattern.bkg is not None + assert mock_experiment.datastore.pattern.calc is not None + + +def test_get_valid_linked_phases(mock_sample_models, mock_experiment): + calculator = MockCalculator() + + valid_phases = calculator._get_valid_linked_phases(mock_sample_models, mock_experiment) + + # Assertions + assert len(valid_phases) == 2 + assert valid_phases[0]._entry_id == "phase1" + assert valid_phases[1]._entry_id == "phase2" + + +def test_calculate_structure_factors(mock_sample_models, mock_experiment): + calculator = MockCalculator() + + # Mock the method's behavior if necessary + result = calculator.calculate_structure_factors(mock_sample_models, mock_experiment) + + # Assertions + assert np.allclose(result, np.array([1., 2., 3.])) \ No newline at end of file diff --git a/tests/unit_tests/analysis/calculators/test_calculator_cryspy.py b/tests/unit_tests/analysis/calculators/test_calculator_cryspy.py new file mode 100644 index 00000000..ad9703ae --- /dev/null +++ b/tests/unit_tests/analysis/calculators/test_calculator_cryspy.py @@ -0,0 +1,96 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +from easydiffraction.analysis.calculators.calculator_cryspy import CryspyCalculator + + +@pytest.fixture +def mock_sample_model(): + sample_model = MagicMock() + sample_model.name = "sample1" + sample_model.cell.length_a.value = 1.0 + sample_model.cell.length_b.value = 2.0 + sample_model.cell.length_c.value = 3.0 + sample_model.cell.angle_alpha.value = 90.0 + sample_model.cell.angle_beta.value = 90.0 + sample_model.cell.angle_gamma.value = 90.0 + sample_model.atom_sites = [ + MagicMock(fract_x=MagicMock(value=0.1), fract_y=MagicMock(value=0.2), fract_z=MagicMock(value=0.3), + occupancy=MagicMock(value=1.0), b_iso=MagicMock(value=0.5)) + ] + return sample_model + + +@pytest.fixture +def mock_experiment(): + experiment = MagicMock() + experiment.name = "experiment1" + experiment.type.beam_mode.value = "constant wavelength" + experiment.datastore.pattern.x = np.array([1.0, 2.0, 3.0]) + experiment.datastore.pattern.meas = np.array([10.0, 20.0, 30.0]) + experiment.datastore.pattern.meas_su = np.array([0.1, 0.2, 0.3]) + experiment.instrument.calib_twotheta_offset.value = 0.0 + experiment.instrument.setup_wavelength.value = 1.54 + experiment.peak.broad_gauss_u.value = 0.1 + experiment.peak.broad_gauss_v.value = 0.2 + experiment.peak.broad_gauss_w.value = 0.3 + experiment.peak.broad_lorentz_x.value = 0.4 + experiment.peak.broad_lorentz_y.value = 0.5 + return experiment + + +@patch("easydiffraction.analysis.calculators.calculator_cryspy.str_to_globaln") +def test_recreate_cryspy_obj(mock_str_to_globaln, mock_sample_model, mock_experiment): + mock_str_to_globaln.return_value = MagicMock(add_items=MagicMock()) + + calculator = CryspyCalculator() + cryspy_obj = calculator._recreate_cryspy_obj(mock_sample_model, mock_experiment) + + # Assertions + mock_str_to_globaln.assert_called() + assert cryspy_obj.add_items.called + + +@patch("easydiffraction.analysis.calculators.calculator_cryspy.rhochi_calc_chi_sq_by_dictionary") +def test_calculate_single_model_pattern(mock_rhochi_calc, mock_sample_model, mock_experiment): + mock_rhochi_calc.return_value = None + + calculator = CryspyCalculator() + calculator._cryspy_dicts = {"experiment1": {"mock_key": "mock_value"}} + + result = calculator._calculate_single_model_pattern(mock_sample_model, mock_experiment, called_by_minimizer=False) + + # Assertions + assert isinstance(result, np.ndarray) or result == [] + mock_rhochi_calc.assert_called() + + +def test_recreate_cryspy_dict(mock_sample_model, mock_experiment): + calculator = CryspyCalculator() + calculator._cryspy_dicts = { + "experiment1": { + "pd_experiment1": { + "offset_ttheta": [0.1], + "wavelength": [1.54], + "resolution_parameters": [0.1, 0.2, 0.3, 0.4, 0.5], + }, + "crystal_sample1": { + "unit_cell_parameters": [0, 0, 0, 0, 0, 0], + "atom_fract_xyz": [[0], [0], [0]], + "atom_occupancy": [0], + "atom_b_iso": [0], + } + } + } + + cryspy_dict = calculator._recreate_cryspy_dict(mock_sample_model, mock_experiment) + + # Assertions + assert cryspy_dict["crystal_sample1"]["unit_cell_parameters"][:3] == [1.0, 2.0, 3.0] + assert cryspy_dict["crystal_sample1"]["atom_fract_xyz"][0][0] == 0.1 + assert cryspy_dict["crystal_sample1"]["atom_occupancy"][0] == 1.0 + assert cryspy_dict["crystal_sample1"]["atom_b_iso"][0] == 0.5 + assert cryspy_dict["pd_experiment1"]["offset_ttheta"][0] == 0.0 + assert cryspy_dict["pd_experiment1"]["wavelength"][0] == 1.54 + assert cryspy_dict["pd_experiment1"]["resolution_parameters"] == [0.1, 0.2, 0.3, 0.4, 0.5] + diff --git a/tests/unit_tests/analysis/calculators/test_calculator_factory.py b/tests/unit_tests/analysis/calculators/test_calculator_factory.py new file mode 100644 index 00000000..60a97b9a --- /dev/null +++ b/tests/unit_tests/analysis/calculators/test_calculator_factory.py @@ -0,0 +1,81 @@ +import pytest +from unittest.mock import MagicMock, patch +from easydiffraction.analysis.calculators.calculator_factory import CalculatorFactory +from easydiffraction.analysis.calculators.calculator_crysfml import CrysfmlCalculator +from easydiffraction.analysis.calculators.calculator_cryspy import CryspyCalculator +from easydiffraction.analysis.calculators.calculator_pdffit import PdffitCalculator +from easydiffraction.utils.formatting import ( + paragraph, + error +) + +@pytest.fixture +def mock_calculators(): + with patch.object(CrysfmlCalculator, 'engine_imported', True), \ + patch.object(CryspyCalculator, 'engine_imported', True), \ + patch.object(PdffitCalculator, 'engine_imported', False): + yield + + +def test_supported_calculators(mock_calculators): + supported = CalculatorFactory._supported_calculators() + + # Assertions + assert 'crysfml' in supported + assert 'cryspy' in supported + assert 'pdffit' not in supported # Engine not imported + + +def test_list_supported_calculators(mock_calculators): + supported_list = CalculatorFactory.list_supported_calculators() + + # Assertions + assert 'crysfml' in supported_list + assert 'cryspy' in supported_list + assert 'pdffit' not in supported_list # Engine not imported + + +@patch("builtins.print") +def test_show_supported_calculators(mock_print, mock_calculators): + CalculatorFactory.show_supported_calculators() + + # Assertions + mock_print.assert_any_call(paragraph("Supported calculators")) + assert any("CrysFML library for crystallographic calculations" in call.args[0] for call in mock_print.call_args_list) + assert any("CrysPy library for crystallographic calculations" in call.args[0] for call in mock_print.call_args_list) + + +def test_create_calculator(mock_calculators): + crysfml_calculator = CalculatorFactory.create_calculator('crysfml') + cryspy_calculator = CalculatorFactory.create_calculator('cryspy') + pdffit_calculator = CalculatorFactory.create_calculator('pdffit') # Not supported + + # Assertions + assert isinstance(crysfml_calculator, CrysfmlCalculator) + assert isinstance(cryspy_calculator, CryspyCalculator) + assert pdffit_calculator is None + + +def test_create_calculator_unknown(mock_calculators): + unknown_calculator = CalculatorFactory.create_calculator('unknown') + + # Assertions + assert unknown_calculator is None + + +def no_test_register_calculator(): + class MockCalculator: + engine_imported = True + + CalculatorFactory.register_calculator( + 'mock_calculator', + MockCalculator, + description='Mock calculator for testing' + ) + + supported = CalculatorFactory._supported_calculators() + + # Assertions + assert 'mock_calculator' in CalculatorFactory._potential_calculators + assert supported['mock_calculator']['description'] == 'Mock calculator for testing' + assert supported['mock_calculator']['class'] == MockCalculator \ No newline at end of file diff --git a/tests/unit_tests/analysis/minimizers/test_fitting_progress_tracker.py b/tests/unit_tests/analysis/minimizers/test_fitting_progress_tracker.py new file mode 100644 index 00000000..1903d221 --- /dev/null +++ b/tests/unit_tests/analysis/minimizers/test_fitting_progress_tracker.py @@ -0,0 +1,95 @@ +import pytest +import numpy as np +from unittest.mock import patch, MagicMock +from easydiffraction.analysis.minimizers.fitting_progress_tracker import format_cell, FittingProgressTracker + + +def test_format_cell(): + # Test center alignment + assert format_cell("test", width=10, align="center") == " test " + # Test left alignment + assert format_cell("test", width=10, align="left") == "test " + # Test right alignment + assert format_cell("test", width=10, align="right") == " test" + # Test default alignment (center) + assert format_cell("test", width=10) == " test " + # Test invalid alignment + assert format_cell("test", width=10, align="invalid") == "test" + + +@pytest.fixture +def tracker(): + return FittingProgressTracker() + + +@patch("builtins.print") +def test_start_tracking(mock_print, tracker): + tracker.start_tracking("MockMinimizer") + + # Assertions + mock_print.assert_any_call("🚀 Starting fitting process with 'MockMinimizer'...") + mock_print.assert_any_call("📈 Goodness-of-fit (reduced χ²) change:") + assert mock_print.call_count > 2 # Ensure headers and borders are printed + + +@patch("builtins.print") +def test_add_tracking_info(mock_print, tracker): + tracker.add_tracking_info([1, "10.0", "9.0", "10% ↓"]) + + # Assertions + mock_print.assert_called_once() + assert "│ 1 │ 10.0 │ 9.0 │ 10% ↓ │" in mock_print.call_args[0][0] + + +@patch("builtins.print") +def test_finish_tracking(mock_print, tracker): + tracker._last_iteration = 5 + tracker._last_chi2 = 1.23 + tracker._best_chi2 = 1.23 + tracker._best_iteration = 5 + + tracker.finish_tracking() + + # Assertions + mock_print.assert_any_call("🏆 Best goodness-of-fit (reduced χ²) is 1.23 at iteration 5") + mock_print.assert_any_call("✅ Fitting complete.") + + +def test_reset(tracker): + tracker._iteration = 5 + tracker._previous_chi2 = 1.23 + tracker.reset() + + # Assertions + assert tracker._iteration == 0 + assert tracker._previous_chi2 is None + + +@patch("easydiffraction.analysis.reliability_factors.calculate_reduced_chi_square", return_value=1.23) +@patch("builtins.print") +def test_track(mock_print, mock_calculate_chi2, tracker): + residuals = np.array([1.1, 2.1, 3.1, 4.1, 5.1]) + parameters = [1., 2., 3.] + + tracker.track(residuals, parameters) + + # Assertions + #mock_calculate_chi2.assert_called_once_with(residuals, len(parameters)) + assert tracker._iteration == 1 + assert tracker._previous_chi2 == 29.025 + assert tracker._best_chi2 == 29.025 + assert tracker._best_iteration == 1 + mock_print.assert_called() + + +def test_start_timer(tracker): + with patch("time.perf_counter", return_value=100.0): + tracker.start_timer() + assert tracker._start_time == 100.0 + + +def test_stop_timer(tracker): + with patch("time.perf_counter", side_effect=[100.0, 105.0]): + tracker.start_timer() + tracker.stop_timer() + assert tracker._fitting_time == 5.0 \ No newline at end of file diff --git a/tests/unit_tests/analysis/minimizers/test_minimizer_base.py b/tests/unit_tests/analysis/minimizers/test_minimizer_base.py new file mode 100644 index 00000000..60f13818 --- /dev/null +++ b/tests/unit_tests/analysis/minimizers/test_minimizer_base.py @@ -0,0 +1,116 @@ +import pytest +from unittest.mock import MagicMock, patch +from easydiffraction.analysis.minimizers.minimizer_base import MinimizerBase, FitResults + + +# Mock subclass of MinimizerBase to test its methods +class MockMinimizer(MinimizerBase): + def _prepare_solver_args(self, parameters): + return {"mock_arg": "mock_value"} + + def _run_solver(self, objective_function, **engine_parameters): + return {"success": True, "raw_result": "mock_result"} + + def _sync_result_to_parameters(self, raw_result, parameters): + for param in parameters: + param.value = 1.0 # Mock synchronization + + def _check_success(self, raw_result): + return raw_result.get("success", False) + + def _finalize_fit(self, parameters, raw_result): + return FitResults( + success=raw_result.get("success", False), + parameters=parameters, + chi_square=raw_result.get("chi_square", 0.0), + reduced_chi_square=raw_result.get("reduced_chi_square", 0.0), + message=raw_result.get("message", ""), + iterations=raw_result.get("iterations", 0), + engine_result=raw_result.get("raw_result", None), + starting_parameters=[p.start_value for p in parameters], + fitting_time=raw_result.get("fitting_time", 0.0), + ) + + +@pytest.fixture +def mock_minimizer(): + return MockMinimizer(name="MockMinimizer", method="mock_method", max_iterations=100) + + +@pytest.fixture +def mock_parameters(): + param1 = MagicMock(name="param1", value=None, start_value=0.5, uncertainty=None) + param2 = MagicMock(name="param2", value=None, start_value=1.0, uncertainty=None) + return [param1, param2] + + +@pytest.fixture +def mock_objective_function(): + return MagicMock(return_value=[1.0, 2.0, 3.0]) + + +def test_prepare_solver_args(mock_minimizer, mock_parameters): + solver_args = mock_minimizer._prepare_solver_args(mock_parameters) + assert solver_args == {"mock_arg": "mock_value"} + + +def test_run_solver(mock_minimizer, mock_objective_function): + raw_result = mock_minimizer._run_solver(mock_objective_function, mock_arg="mock_value") + assert raw_result == {"success": True, "raw_result": "mock_result"} + + +def test_sync_result_to_parameters(mock_minimizer, mock_parameters): + raw_result = {"success": True} + mock_minimizer._sync_result_to_parameters(raw_result, mock_parameters) + + # Assertions + for param in mock_parameters: + assert param.value == 1.0 + + +def test_check_success(mock_minimizer): + raw_result = {"success": True} + assert mock_minimizer._check_success(raw_result) is True + + raw_result = {"success": False} + assert mock_minimizer._check_success(raw_result) is False + + +def test_finalize_fit(mock_minimizer, mock_parameters): + raw_result = {"success": True} + result = mock_minimizer._finalize_fit(mock_parameters, raw_result) + + # Assertions + assert isinstance(result, FitResults) + assert result.success is True + assert result.parameters == mock_parameters + + +@patch("easydiffraction.analysis.minimizers.fitting_progress_tracker.FittingProgressTracker") +def test_fit(mock_tracker, mock_minimizer, mock_parameters, mock_objective_function): + mock_minimizer.tracker.finish_tracking = MagicMock() + result = mock_minimizer.fit(mock_parameters, mock_objective_function) + + # Assertions + assert isinstance(result, FitResults) + assert result.success is True + + +def test_create_objective_function(mock_minimizer): + parameters = [MagicMock()] + sample_models = MagicMock() + experiments = MagicMock() + calculator = MagicMock() + + objective_function = mock_minimizer._create_objective_function( + parameters, sample_models, experiments, calculator + ) + + # Assertions + assert callable(objective_function) + with patch.object(mock_minimizer, "_objective_function", return_value=[1.0, 2.0, 3.0]) as mock_objective: + residuals = objective_function({"param1": 1.0}) + mock_objective.assert_called_once_with( + {"param1": 1.0}, parameters, sample_models, experiments, calculator + ) + assert residuals == [1.0, 2.0, 3.0] \ No newline at end of file diff --git a/tests/unit_tests/analysis/minimizers/test_minimizer_dfols.py b/tests/unit_tests/analysis/minimizers/test_minimizer_dfols.py new file mode 100644 index 00000000..8cc43fa3 --- /dev/null +++ b/tests/unit_tests/analysis/minimizers/test_minimizer_dfols.py @@ -0,0 +1,79 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +from easydiffraction.analysis.minimizers.minimizer_dfols import DfolsMinimizer + + +@pytest.fixture +def mock_parameters(): + param1 = MagicMock(name="param1", value=1.0, min=0.0, max=2.0, uncertainty=None) + param2 = MagicMock(name="param2", value=2.0, min=1.0, max=3.0, uncertainty=None) + return [param1, param2] + + +@pytest.fixture +def mock_objective_function(): + return MagicMock(return_value=np.array([1.0, 2.0, 3.0])) + + +@pytest.fixture +def dfols_minimizer(): + return DfolsMinimizer(name="dfols", max_iterations=100) + + +def test_prepare_solver_args(dfols_minimizer, mock_parameters): + solver_args = dfols_minimizer._prepare_solver_args(mock_parameters) + + # Assertions + assert np.allclose(solver_args['x0'], [1.0, 2.0]) + assert np.allclose(solver_args['bounds'][0], [0.0, 1.0]) # Lower bounds + assert np.allclose(solver_args['bounds'][1], [2.0, 3.0]) # Upper bounds + + +@patch("easydiffraction.analysis.minimizers.minimizer_dfols.solve") +def test_run_solver(mock_solve, dfols_minimizer, mock_objective_function): + mock_solve.return_value = MagicMock(x=np.array([1.5, 2.5]), flag=0) + + solver_args = {'x0': np.array([1.0, 2.0]), 'bounds': (np.array([0.0, 1.0]), np.array([2.0, 3.0]))} + raw_result = dfols_minimizer._run_solver(mock_objective_function, **solver_args) + + # Assertions + mock_solve.assert_called_once_with( + mock_objective_function, + x0=solver_args['x0'], + bounds=solver_args['bounds'], + maxfun=dfols_minimizer.max_iterations + ) + assert np.allclose(raw_result.x, [1.5, 2.5]) + + +def test_sync_result_to_parameters(dfols_minimizer, mock_parameters): + raw_result = MagicMock(x=np.array([1.5, 2.5])) + + dfols_minimizer._sync_result_to_parameters(mock_parameters, raw_result) + + # Assertions + assert mock_parameters[0].value == 1.5 + assert mock_parameters[1].value == 2.5 + assert mock_parameters[0].uncertainty is None + assert mock_parameters[1].uncertainty is None + + +def test_check_success(dfols_minimizer): + raw_result = MagicMock(flag=0, EXIT_SUCCESS=0) + assert dfols_minimizer._check_success(raw_result) is True + + raw_result = MagicMock(flag=1, EXIT_SUCCESS=0) + assert dfols_minimizer._check_success(raw_result) is False + + +@patch("easydiffraction.analysis.minimizers.minimizer_dfols.solve") +def test_fit(mock_solve, dfols_minimizer, mock_parameters, mock_objective_function): + mock_solve.return_value = MagicMock(x=np.array([1.5, 2.5]), flag=0) + dfols_minimizer.tracker.finish_tracking = MagicMock() + + result = dfols_minimizer.fit(mock_parameters, mock_objective_function) + + # Assertions + assert np.allclose([p.value for p in result.parameters], [1.5, 2.5]) + assert result.iterations == 0 # DFO-LS doesn't provide iteration count by default diff --git a/tests/unit_tests/analysis/minimizers/test_minimizer_factory.py b/tests/unit_tests/analysis/minimizers/test_minimizer_factory.py new file mode 100644 index 00000000..7d7ac5fa --- /dev/null +++ b/tests/unit_tests/analysis/minimizers/test_minimizer_factory.py @@ -0,0 +1,64 @@ +import pytest +from unittest.mock import patch, MagicMock +from easydiffraction.analysis.minimizers.minimizer_factory import MinimizerFactory +from easydiffraction.analysis.minimizers.minimizer_lmfit import LmfitMinimizer +from easydiffraction.analysis.minimizers.minimizer_dfols import DfolsMinimizer + + +def test_list_available_minimizers(): + minimizers = MinimizerFactory.list_available_minimizers() + + # Assertions + assert isinstance(minimizers, list) + assert 'lmfit' in minimizers + assert 'dfols' in minimizers + + +@patch("builtins.print") +def test_show_available_minimizers(mock_print): + MinimizerFactory.show_available_minimizers() + + # Assertions + #mock_print.assert_any_call("Available minimizers") + assert any("LMFIT library using the default Levenberg-Marquardt least squares method" in call.args[0] + for call in mock_print.call_args_list) + assert any("DFO-LS library for derivative-free least-squares optimization" in call.args[0] + for call in mock_print.call_args_list) + + +def test_create_minimizer(): + # Test creating an LmfitMinimizer + minimizer = MinimizerFactory.create_minimizer('lmfit') + assert isinstance(minimizer, LmfitMinimizer) + assert minimizer.method == 'leastsq' + + # Test creating a DfolsMinimizer + minimizer = MinimizerFactory.create_minimizer('dfols') + assert isinstance(minimizer, DfolsMinimizer) + assert minimizer.method is None + + # Test invalid minimizer + with pytest.raises(ValueError, match="Unknown minimizer 'invalid'.*"): + MinimizerFactory.create_minimizer('invalid') + + +def test_register_minimizer(): + class MockMinimizer: + def __init__(self, method=None): + self.method = method + + MinimizerFactory.register_minimizer( + name='mock_minimizer', + minimizer_cls=MockMinimizer, + method='mock_method', + description='Mock minimizer for testing' + ) + + # Assertions + minimizers = MinimizerFactory.list_available_minimizers() + assert 'mock_minimizer' in minimizers + + # Test creating the registered minimizer + minimizer = MinimizerFactory.create_minimizer('mock_minimizer') + assert isinstance(minimizer, MockMinimizer) + assert minimizer.method == 'mock_method' diff --git a/tests/unit_tests/analysis/minimizers/test_minimizer_lmfit.py b/tests/unit_tests/analysis/minimizers/test_minimizer_lmfit.py new file mode 100644 index 00000000..aed859da --- /dev/null +++ b/tests/unit_tests/analysis/minimizers/test_minimizer_lmfit.py @@ -0,0 +1,95 @@ +import pytest +from unittest.mock import MagicMock, patch +from easydiffraction.analysis.minimizers.minimizer_lmfit import LmfitMinimizer +import lmfit + + +@pytest.fixture +def mock_parameters(): + param1 = MagicMock(name="param1", uid="param1", value=1.0, free=True, min=0.0, max=2.0, uncertainty=None) + param2 = MagicMock(name="param2", uid="param2", value=2.0, free=False, min=1.0, max=3.0, uncertainty=None) + return [param1, param2] + + +@pytest.fixture +def mock_objective_function(): + return MagicMock(return_value=[1.0, 2.0, 3.0]) + + +@pytest.fixture +def lmfit_minimizer(): + return LmfitMinimizer(name="lmfit", method="leastsq", max_iterations=100) + + +def test_prepare_solver_args(lmfit_minimizer, mock_parameters): + solver_args = lmfit_minimizer._prepare_solver_args(mock_parameters) + + # Assertions + assert isinstance(solver_args['engine_parameters'], lmfit.Parameters) + assert 'param1' in solver_args['engine_parameters'] + assert 'param2' in solver_args['engine_parameters'] + assert solver_args['engine_parameters']['param1'].value == 1.0 + assert solver_args['engine_parameters']['param1'].min == 0.0 + assert solver_args['engine_parameters']['param1'].max == 2.0 + assert solver_args['engine_parameters']['param1'].vary is True + assert solver_args['engine_parameters']['param2'].value == 2.0 + assert solver_args['engine_parameters']['param2'].vary is False + + +@patch("easydiffraction.analysis.minimizers.minimizer_lmfit.lmfit.minimize") +def test_run_solver(mock_minimize, lmfit_minimizer, mock_objective_function, mock_parameters): + mock_minimize.return_value = MagicMock(params={"param1": MagicMock(value=1.5), "param2": MagicMock(value=2.5)}) + + solver_args = lmfit_minimizer._prepare_solver_args(mock_parameters) + raw_result = lmfit_minimizer._run_solver(mock_objective_function, **solver_args) + + # Assertions + mock_minimize.assert_called_once_with( + mock_objective_function, + params=solver_args['engine_parameters'], + method="leastsq", + nan_policy='propagate', + max_nfev=lmfit_minimizer.max_iterations + ) + assert raw_result.params["param1"].value == 1.5 + assert raw_result.params["param2"].value == 2.5 + + +def test_sync_result_to_parameters(lmfit_minimizer, mock_parameters): + raw_result = MagicMock(params={ + "param1": MagicMock(value=1.5, stderr=0.1), + "param2": MagicMock(value=2.5, stderr=0.2) + }) + + lmfit_minimizer._sync_result_to_parameters(mock_parameters, raw_result) + + # Assertions + assert mock_parameters[0].value == 1.5 + assert mock_parameters[0].uncertainty == 0.1 + assert mock_parameters[1].value == 2.5 + assert mock_parameters[1].uncertainty == 0.2 + + +def test_check_success(lmfit_minimizer): + raw_result = MagicMock(success=True) + assert lmfit_minimizer._check_success(raw_result) is True + + raw_result = MagicMock(success=False) + assert lmfit_minimizer._check_success(raw_result) is False + + +@patch("easydiffraction.analysis.minimizers.minimizer_lmfit.lmfit.minimize") +def test_fit(mock_minimize, lmfit_minimizer, mock_parameters, mock_objective_function): + mock_minimize.return_value = MagicMock( + params={"param1": MagicMock(value=1.5, stderr=0.1), "param2": MagicMock(value=2.5, stderr=0.2)}, + success=True + ) + lmfit_minimizer.tracker.finish_tracking = MagicMock() + result = lmfit_minimizer.fit(mock_parameters, mock_objective_function) + + # Assertions + assert result.success is True + assert result.parameters[0].value == 1.5 + assert result.parameters[0].uncertainty == 0.1 + assert result.parameters[1].value == 2.5 + assert result.parameters[1].uncertainty == 0.2 From f04e377baf4d7a7ccad7daa6fc24ed83adf675eb Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Tue, 15 Apr 2025 20:26:25 +0200 Subject: [PATCH 02/12] added two more tests --- tests/unit_tests/analysis/test_analysis.py | 133 ++++++++++++++++++ .../unit_tests/analysis/test_minimization.py | 113 +++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 tests/unit_tests/analysis/test_analysis.py create mode 100644 tests/unit_tests/analysis/test_minimization.py diff --git a/tests/unit_tests/analysis/test_analysis.py b/tests/unit_tests/analysis/test_analysis.py new file mode 100644 index 00000000..774e074b --- /dev/null +++ b/tests/unit_tests/analysis/test_analysis.py @@ -0,0 +1,133 @@ +import pytest +from unittest.mock import MagicMock, patch +from easydiffraction.analysis.analysis import Analysis + + +@pytest.fixture +def mock_project(): + project = MagicMock() + project.sample_models.get_all_params.return_value = [ + MagicMock(datablock_id="block1", category_key="cat1", collection_entry_id="entry1", name="param1", value=1.0, units="unit1", free=True, min=0.0, max=2.0, uncertainty=0.1) + ] + project.experiments.get_all_params.return_value = [ + MagicMock(datablock_id="block2", category_key="cat2", collection_entry_id="entry2", name="param2", value=2.0, units="unit2", free=False, min=1.0, max=3.0, uncertainty=0.2) + ] + project.sample_models.get_fittable_params.return_value = project.sample_models.get_all_params() + project.experiments.get_fittable_params.return_value = project.experiments.get_all_params() + project.sample_models.get_free_params.return_value = project.sample_models.get_all_params() + project.experiments.get_free_params.return_value = project.experiments.get_all_params() + project.experiments.ids = ["experiment1", "experiment2"] + project._varname = "project" + return project + + +@pytest.fixture +def analysis(mock_project): + return Analysis(project=mock_project) + + +@patch("builtins.print") +def test_show_all_params(mock_print, analysis): + analysis._show_params = MagicMock() + analysis.show_all_params() + + # Assertions + assert('parameters for all experiments' in mock_print.call_args[0][0]) + +@patch("builtins.print") +def test_show_fittable_params(mock_print, analysis): + analysis._show_params = MagicMock() + analysis.show_fittable_params() + + # Assertions + assert('Fittable parameters for all experiments' in mock_print.call_args[0][0]) + +@patch("builtins.print") +def test_show_free_params(mock_print, analysis): + analysis._show_params = MagicMock() + analysis.show_free_params() + + # Assertions + assert('Free parameters for both sample models' in mock_print.call_args[0][0]) + # mock_print.assert_any_call("Free parameters for both sample models (🧩 data blocks) and experiments (🔬 data blocks)") + + +@patch("builtins.print") +def test_show_current_calculator(mock_print, analysis): + analysis.show_current_calculator() + + # Assertions + # mock_print.assert_any_call("Current calculator") + mock_print.assert_any_call("cryspy") + + +@patch("builtins.print") +def test_show_current_minimizer(mock_print, analysis): + analysis.show_current_minimizer() + + # Assertions + # mock_print.assert_any_call("Current minimizer") + mock_print.assert_any_call("lmfit (leastsq)") + + +@patch("easydiffraction.analysis.calculators.calculator_factory.CalculatorFactory.create_calculator") +@patch("builtins.print") +def test_current_calculator_setter(mock_print, mock_create_calculator, analysis): + mock_create_calculator.return_value = MagicMock() + + analysis.current_calculator = "pdffit2" + + # Assertions + mock_create_calculator.assert_called_once_with("pdffit2") + + +@patch("easydiffraction.analysis.minimizers.minimizer_factory.MinimizerFactory.create_minimizer") +@patch("builtins.print") +def test_current_minimizer_setter(mock_print, mock_create_minimizer, analysis): + mock_create_minimizer.return_value = MagicMock() + + analysis.current_minimizer = "dfols" + + # Assertions + mock_print.assert_any_call("dfols") + + +@patch("builtins.print") +def test_fit_mode_setter(mock_print, analysis): + analysis.fit_mode = "joint" + + # Assertions + assert analysis.fit_mode == "joint" + mock_print.assert_any_call("joint") + + +@patch("easydiffraction.analysis.minimization.DiffractionMinimizer.fit") +@patch("builtins.print") +def no_test_fit_single_mode(mock_print, mock_fit, analysis, mock_project): + analysis.fit_mode = "single" + analysis.fit() + + # Assertions + mock_fit.assert_called() + mock_print.assert_any_call("single") + + +@patch("easydiffraction.analysis.minimization.DiffractionMinimizer.fit") +@patch("builtins.print") +def test_fit_joint_mode(mock_print, mock_fit, analysis, mock_project): + analysis.fit_mode = "joint" + analysis.fit() + + # Assertions + mock_fit.assert_called_once() + + +@patch("builtins.print") +def test_as_cif(mock_print, analysis): + cif_text = analysis.as_cif() + + # Assertions + assert "_analysis.calculator_engine cryspy" in cif_text + assert "_analysis.fitting_engine lmfit (leastsq)" in cif_text + assert "_analysis.fit_mode single" in cif_text + diff --git a/tests/unit_tests/analysis/test_minimization.py b/tests/unit_tests/analysis/test_minimization.py new file mode 100644 index 00000000..f7c00d30 --- /dev/null +++ b/tests/unit_tests/analysis/test_minimization.py @@ -0,0 +1,113 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +from easydiffraction.analysis.minimization import DiffractionMinimizer + + +@pytest.fixture +def mock_sample_models(): + sample_models = MagicMock() + sample_models.get_free_params.return_value = [ + MagicMock(name="param1", value=1.0, start_value=None, min=0.0, max=2.0, free=True), + MagicMock(name="param2", value=2.0, start_value=None, min=1.0, max=3.0, free=True), + ] + return sample_models + + +@pytest.fixture +def mock_experiments(): + experiments = MagicMock() + experiments.get_free_params.return_value = [ + MagicMock(name="param3", value=3.0, start_value=None, min=2.0, max=4.0, free=True), + ] + experiments.ids = ["experiment1"] + experiments._items = { + "experiment1": MagicMock( + datastore=MagicMock( + pattern=MagicMock(meas=np.array([10.0, 20.0, 30.0]), meas_su=np.array([1.0, 1.0, 1.0])) + ) + ) + } + return experiments + + +@pytest.fixture +def mock_calculator(): + calculator = MagicMock() + calculator.calculate_pattern.return_value = np.array([9.0, 19.0, 29.0]) + return calculator + + +@pytest.fixture +def mock_minimizer(): + minimizer = MagicMock() + minimizer.fit.return_value = MagicMock(success=True) + minimizer._sync_result_to_parameters = MagicMock() + minimizer.tracker.track = MagicMock(return_value=np.array([1.0, 2.0, 3.0])) + return minimizer + + +@pytest.fixture +def diffraction_minimizer(mock_minimizer): + with patch("easydiffraction.analysis.minimizers.minimizer_factory.MinimizerFactory.create_minimizer", return_value=mock_minimizer): + return DiffractionMinimizer(selection="lmfit (leastsq)") + + +def test_fit_no_params(diffraction_minimizer, mock_sample_models, mock_experiments, mock_calculator): + mock_sample_models.get_free_params.return_value = [] + mock_experiments.get_free_params.return_value = [] + + result = diffraction_minimizer.fit(mock_sample_models, mock_experiments, mock_calculator) + + # Assertions + assert result is None + + +def test_fit_with_params(diffraction_minimizer, mock_sample_models, mock_experiments, mock_calculator): + result = diffraction_minimizer.fit(mock_sample_models, mock_experiments, mock_calculator) + + # Assertions + assert diffraction_minimizer.results.success is True + assert mock_calculator.calculate_pattern.called + assert mock_sample_models.get_free_params.called + assert mock_experiments.get_free_params.called + + +def test_residual_function(diffraction_minimizer, mock_sample_models, mock_experiments, mock_calculator): + parameters = mock_sample_models.get_free_params() + mock_experiments.get_free_params() + engine_params = MagicMock() + + residuals = diffraction_minimizer._residual_function( + engine_params=engine_params, + parameters=parameters, + sample_models=mock_sample_models, + experiments=mock_experiments, + calculator=mock_calculator, + ) + + # Assertions + assert isinstance(residuals, np.ndarray) + assert len(residuals) == 3 + assert mock_calculator.calculate_pattern.called + assert diffraction_minimizer.minimizer._sync_result_to_parameters.called + + +@patch("easydiffraction.analysis.reliability_factors.get_reliability_inputs", return_value=(np.array([10.0]), np.array([9.0]), np.array([1.0]))) +def test_process_fit_results(mock_get_reliability_inputs, diffraction_minimizer, mock_sample_models, mock_experiments, mock_calculator): + diffraction_minimizer.results = MagicMock() + diffraction_minimizer._process_fit_results(mock_sample_models, mock_experiments, mock_calculator) + + # Assertions + # mock_get_reliability_inputs.assert_called_once_with(mock_sample_models, mock_experiments, mock_calculator) + + # Extract the arguments passed to `display_results` + _, kwargs = diffraction_minimizer.results.display_results.call_args + + # Assertions for arrays + np.testing.assert_array_equal(kwargs['y_calc'], np.array([9., 19., 29.])) + np.testing.assert_array_equal(kwargs['y_err'], np.array([1., 1., 1.])) + np.testing.assert_array_equal(kwargs['y_obs'], np.array([10., 20., 30.])) + + # Assertions for other arguments + assert kwargs["f_obs"] is None + assert kwargs["f_calc"] is None From 50b634984dcfe92edd0e24213c441bfd2c42d24b Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 16 Apr 2025 15:21:10 +0200 Subject: [PATCH 03/12] rest of analysis, core and experiments --- .../collections/test_joint_fit_experiment.py | 23 +++ .../analysis/test_reliability_factors.py | 109 ++++++++++ tests/unit_tests/core/test_objects.py | 125 ++++++++++++ tests/unit_tests/core/test_singletons.py | 105 ++++++++++ .../collections/test_background.py | 103 ++++++++++ .../experiments/collections/test_datastore.py | 151 ++++++++++++++ .../collections/test_linked_phases.py | 151 ++++++++++++++ .../components/test_experiment_type.py | 51 +++++ .../experiments/components/test_instrument.py | 89 +++++++++ .../experiments/components/test_peak.py | 159 +++++++++++++++ .../unit_tests/experiments/test_experiment.py | 187 ++++++++++++++++++ .../experiments/test_experiments.py | 94 +++++++++ 12 files changed, 1347 insertions(+) create mode 100644 tests/unit_tests/analysis/collections/test_joint_fit_experiment.py create mode 100644 tests/unit_tests/analysis/test_reliability_factors.py create mode 100644 tests/unit_tests/core/test_objects.py create mode 100644 tests/unit_tests/core/test_singletons.py create mode 100644 tests/unit_tests/experiments/collections/test_background.py create mode 100644 tests/unit_tests/experiments/collections/test_datastore.py create mode 100644 tests/unit_tests/experiments/collections/test_linked_phases.py create mode 100644 tests/unit_tests/experiments/components/test_experiment_type.py create mode 100644 tests/unit_tests/experiments/components/test_instrument.py create mode 100644 tests/unit_tests/experiments/components/test_peak.py create mode 100644 tests/unit_tests/experiments/test_experiment.py create mode 100644 tests/unit_tests/experiments/test_experiments.py diff --git a/tests/unit_tests/analysis/collections/test_joint_fit_experiment.py b/tests/unit_tests/analysis/collections/test_joint_fit_experiment.py new file mode 100644 index 00000000..bb58917a --- /dev/null +++ b/tests/unit_tests/analysis/collections/test_joint_fit_experiment.py @@ -0,0 +1,23 @@ +import pytest +from easydiffraction.analysis.collections.joint_fit_experiments import JointFitExperiment + +# filepath: src/easydiffraction/analysis/collections/test_joint_fit_experiments.py + + +def test_joint_fit_experiment_initialization(): + # Test initialization of JointFitExperiment + expt = JointFitExperiment(id="exp1", weight=1.5) + assert expt.id.value == "exp1" + assert expt.id.name == "id" + assert expt.id.cif_name == "id" + assert expt.weight.value == 1.5 + assert expt.weight.name == "weight" + assert expt.weight.cif_name == "weight" + + +def test_joint_fit_experiment_properties(): + # Test properties of JointFitExperiment + expt = JointFitExperiment(id="exp2", weight=2.0) + assert expt.cif_category_key == "joint_fit_experiment" + assert expt.category_key == "joint_fit_experiment" + assert expt._entry_id == "exp2" \ No newline at end of file diff --git a/tests/unit_tests/analysis/test_reliability_factors.py b/tests/unit_tests/analysis/test_reliability_factors.py new file mode 100644 index 00000000..f45b7968 --- /dev/null +++ b/tests/unit_tests/analysis/test_reliability_factors.py @@ -0,0 +1,109 @@ +import pytest +import numpy as np +from unittest.mock import Mock + +from easydiffraction.analysis.reliability_factors import ( + calculate_r_factor, + calculate_weighted_r_factor, + calculate_rb_factor, + calculate_r_factor_squared, + calculate_reduced_chi_square, + get_reliability_inputs, +) + + +def test_calculate_r_factor(): + y_obs = [10, 20, 30] + y_calc = [9, 19, 29] + result = calculate_r_factor(y_obs, y_calc) + expected = 0.05 + np.testing.assert_allclose(result, expected) + + # Test with empty arrays + assert np.isnan(calculate_r_factor([], [])) + + # Test with zero denominator + assert np.isnan(calculate_r_factor([0, 0, 0], [1, 1, 1])) + + +def test_calculate_weighted_r_factor(): + y_obs = [10, 20, 30] + y_calc = [9, 19, 29] + weights = [1, 1, 1] + result = calculate_weighted_r_factor(y_obs, y_calc, weights) + expected = 0.04629100498862757 + np.testing.assert_allclose(result, expected) + + # Test with empty arrays + assert np.isnan(calculate_weighted_r_factor([], [], [])) + + # Test with zero denominator + assert np.isnan(calculate_weighted_r_factor([0, 0, 0], [1, 1, 1], [1, 1, 1])) + + +def test_calculate_rb_factor(): + y_obs = [10, 20, 30] + y_calc = [9, 19, 29] + result = calculate_rb_factor(y_obs, y_calc) + expected = 0.05 + np.testing.assert_allclose(result, expected) + + # Test with empty arrays + assert np.isnan(calculate_rb_factor([], [])) + + # Test with zero denominator + assert np.isnan(calculate_rb_factor([0, 0, 0], [1, 1, 1])) + + +def test_calculate_r_factor_squared(): + y_obs = [10, 20, 30] + y_calc = [9, 19, 29] + result = calculate_r_factor_squared(y_obs, y_calc) + expected = 0.04629100498862757 + np.testing.assert_allclose(result, expected) + + # Test with empty arrays + assert np.isnan(calculate_r_factor_squared([], [])) + + # Test with zero denominator + assert np.isnan(calculate_r_factor_squared([0, 0, 0], [1, 1, 1])) + + +def test_calculate_reduced_chi_square(): + residuals = [1, 2, 3] + num_parameters = 1 + result = calculate_reduced_chi_square(residuals, num_parameters) + expected = 7.0 + np.testing.assert_allclose(result, expected) + + # Test with empty residuals + assert np.isnan(calculate_reduced_chi_square([], 1)) + + # Test with zero degrees of freedom + assert np.isnan(calculate_reduced_chi_square([1, 2, 3], 3)) + + +def test_get_reliability_inputs(): + # Mock inputs + sample_models = None + experiments = Mock() + calculator = Mock() + + experiments._items = { + "experiment1": Mock( + datastore=Mock( + pattern=Mock( + meas=np.array([10.0, 20.0, 30.0]), + meas_su=np.array([1.0, 1.0, 1.0]), + ) + ) + ) + } + calculator.calculate_pattern.return_value = np.array([9.0, 19.0, 29.0]) + + y_obs, y_calc, y_err = get_reliability_inputs(sample_models, experiments, calculator) + + # Assertions + np.testing.assert_array_equal(y_obs, [10.0, 20.0, 30.0]) + np.testing.assert_array_equal(y_calc, [9.0, 19.0, 29.0]) + np.testing.assert_array_equal(y_err, [1.0, 1.0, 1.0]) \ No newline at end of file diff --git a/tests/unit_tests/core/test_objects.py b/tests/unit_tests/core/test_objects.py new file mode 100644 index 00000000..fdc27bbc --- /dev/null +++ b/tests/unit_tests/core/test_objects.py @@ -0,0 +1,125 @@ +import pytest +from easydiffraction.core.objects import Descriptor, Parameter, Component, Collection, Datablock + +# filepath: src/easydiffraction/core/test_objects.py + + +def test_descriptor_initialization(): + desc = Descriptor(value=10, name="test", cif_name="test_cif", editable=True) + assert desc.value == 10 + assert desc.name == "test" + assert desc.cif_name == "test_cif" + assert desc.editable is True + + +def test_descriptor_value_setter(): + desc = Descriptor(value=10, name="test", cif_name="test_cif", editable=True) + desc.value = 20 + assert desc.value == 20 + + desc_non_editable = Descriptor(value=10, name="test", cif_name="test_cif", editable=False) + desc_non_editable.value = 30 + assert desc_non_editable.value == 10 # Value should not change + + +def test_parameter_initialization(): + param = Parameter( + value=5.0, + name="param", + cif_name="param_cif", + uncertainty=0.1, + free=True, + constrained=False, + min_value=0.0, + max_value=10.0, + ) + assert param.value == 5.0 + assert param.uncertainty == 0.1 + assert param.free is True + assert param.constrained is False + assert param.min == 0.0 + assert param.max == 10.0 + + +def test_component_abstract_methods(): + class TestComponent(Component): + @property + def _entry_id(self): + return "test_id" + + @property + def cif_category_key(self): + return "test_cif_category" + + @property + def category_key(self): + return "test_category" + + comp = TestComponent() + assert comp._entry_id == "test_id" + assert comp.cif_category_key == "test_cif_category" + assert comp.category_key == "test_category" + + +def test_component_attribute_handling(): + class TestComponent(Component): + @property + def _entry_id(self): + return "test_id" + + @property + def cif_category_key(self): + return "test_cif_category" + + @property + def category_key(self): + return "test_category" + + comp = TestComponent() + desc = Descriptor(value=10, name="test", cif_name="test_cif") + comp.test_attr = desc + assert comp.test_attr.value == 10 # Access Descriptor value directly + + +def test_collection_add_and_retrieve(): + collection = Collection() + collection._items["item1"] = "value1" + collection._items["item2"] = "value2" + + assert collection["item1"] == "value1" + assert collection["item2"] == "value2" + + +def test_collection_iteration(): + collection = Collection() + collection._items["item1"] = "value1" + collection._items["item2"] = "value2" + + items = list(collection) + assert items == ["value1", "value2"] + + +def test_datablock_components(): + class TestComponent(Component): + @property + def _entry_id(self): + return "test_id" + + @property + def cif_category_key(self): + return "test_cif_category" + + @property + def category_key(self): + return "test_category" + + class TestDatablock(Datablock): + def __init__(self): + self.component1 = TestComponent() + self.component2 = TestComponent() + + datablock = TestDatablock() + components = datablock.components() + assert len(components) == 2 + assert isinstance(components[0], TestComponent) + assert isinstance(components[1], TestComponent) \ No newline at end of file diff --git a/tests/unit_tests/core/test_singletons.py b/tests/unit_tests/core/test_singletons.py new file mode 100644 index 00000000..f4e4f5b5 --- /dev/null +++ b/tests/unit_tests/core/test_singletons.py @@ -0,0 +1,105 @@ +import pytest +from easydiffraction.core.singletons import BaseSingleton, UidMapHandler, ConstraintsHandler +from easydiffraction.core.objects import Descriptor, Parameter + +# filepath: src/easydiffraction/core/test_singletons.py + + +def test_base_singleton(): + class TestSingleton(BaseSingleton): + pass + + instance1 = TestSingleton.get() + instance2 = TestSingleton.get() + + assert instance1 is instance2 # Ensure only one instance is created + + +def test_uid_map_handler(): + param1 = Parameter(value=1.0, name="param1", cif_name="param1_cif") + param2 = Parameter(value=2.0, name="param2", cif_name="param2_cif") + + handler = UidMapHandler.get() + handler.set_uid_map([param1, param2]) + + uid_map = handler.get_uid_map() + assert len(uid_map) == 2 + assert uid_map[param1.uid] is param1 + assert uid_map[param2.uid] is param2 + + +def test_constraints_handler_set_aliases(): + class MockAlias: + def __init__(self, param): + self.param = param + + param1 = Parameter(value=1.0, name="param1", cif_name="param1_cif") + param2 = Parameter(value=2.0, name="param2", cif_name="param2_cif") + + aliases = {"alias1": MockAlias(param1), "alias2": MockAlias(param2)} + + handler = ConstraintsHandler.get() + handler.set_aliases(type("MockAliases", (object,), {"_items": aliases})) + + assert handler._alias_to_param["alias1"].param is param1 + assert handler._alias_to_param["alias2"].param is param2 + + +def test_constraints_handler_set_expressions(): + class MockExpression: + def __init__(self, lhs_alias, rhs_expr): + self.lhs_alias = Descriptor(value=lhs_alias, name="lhs", cif_name="lhs_cif") + self.rhs_expr = Descriptor(value=rhs_expr, name="rhs", cif_name="rhs_cif") + + expressions = { + "expr1": MockExpression("alias1", "alias2 + 1"), + "expr2": MockExpression("alias2", "alias1 * 2"), + } + + handler = ConstraintsHandler.get() + handler.set_expressions(type("MockExpressions", (object,), {"_items": expressions})) + + assert len(handler._parsed_constraints) == 2 + assert handler._parsed_constraints[0] == ("alias1", "alias2 + 1") + assert handler._parsed_constraints[1] == ("alias2", "alias1 * 2") + + +def test_constraints_handler_apply(): + class MockAlias: + def __init__(self, param): + self.param = param + + # Create parameters + param1 = Parameter(value=1.0, name="param1", cif_name="param1_cif") + param2 = Parameter(value=2.0, name="param2", cif_name="param2_cif") + + # Set up aliases + aliases = {"alias1": MockAlias(param1), "alias2": MockAlias(param2)} + + # Initialize UidMapHandler with parameters + uid_handler = UidMapHandler.get() + uid_handler.set_uid_map([param1, param2]) + + # Set up ConstraintsHandler + handler = ConstraintsHandler.get() + handler.set_aliases(type("MockAliases", (object,), {"_items": aliases})) + + # Define expressions + expressions = { + "expr1": type( + "MockExpression", + (object,), + { + "lhs_alias": Descriptor(value="alias1", name="lhs", cif_name="lhs_cif"), + "rhs_expr": Descriptor(value="alias2 + 1", name="rhs", cif_name="rhs_cif"), + }, + ) + } + handler.set_expressions(type("MockExpressions", (object,), {"_items": expressions})) + + # Apply constraints + handler.apply([param1, param2]) + + # Assert the updated value and constrained status + assert param1.value == 3.0 # alias2 (2.0) + 1 + assert param1.constrained is True \ No newline at end of file diff --git a/tests/unit_tests/experiments/collections/test_background.py b/tests/unit_tests/experiments/collections/test_background.py new file mode 100644 index 00000000..3436fd7e --- /dev/null +++ b/tests/unit_tests/experiments/collections/test_background.py @@ -0,0 +1,103 @@ +import pytest +import numpy as np +from unittest.mock import patch, MagicMock + +from easydiffraction.experiments.collections.background import ( + Point, + PolynomialTerm, + LineSegmentBackground, + ChebyshevPolynomialBackground, + BackgroundFactory, +) + + +def test_point_initialization(): + point = Point(x=1.0, y=2.0) + assert point.x.value == 1.0 + assert point.y.value == 2.0 + assert point.cif_category_key == "pd_background" + assert point.category_key == "background" + assert point._entry_id == "1.0" + + +def test_polynomial_term_initialization(): + term = PolynomialTerm(order=2, coef=3.0) + assert term.order.value == 2 + assert term.coef.value == 3.0 + assert term.cif_category_key == "pd_background" + assert term.category_key == "background" + assert term._entry_id == "2" + + +def test_line_segment_background_add_and_calculate(): + background = LineSegmentBackground() + background.add(1.0, 2.0) + background.add(3.0, 4.0) + + x_data = np.array([1.0, 2.0, 3.0]) + y_data = background.calculate(x_data) + + assert np.array_equal(y_data, np.array([2.0, 3.0, 4.0])) + + +def test_line_segment_background_calculate_no_points(): + background = LineSegmentBackground() + x_data = np.array([1.0, 2.0, 3.0]) + + with patch("builtins.print") as mock_print: + y_data = background.calculate(x_data) + assert np.array_equal(y_data, np.zeros_like(x_data)) + assert("No background points found. Setting background to zero." in str(mock_print.call_args.args[0])) + +def test_line_segment_background_show(capsys): + background = LineSegmentBackground() + background.add(1.0, 2.0) + background.add(3.0, 4.0) + + background.show() + captured = capsys.readouterr() + assert "Line-segment background points" in captured.out + +def test_chebyshev_polynomial_background_add_and_calculate(): + background = ChebyshevPolynomialBackground() + background.add(order=0, coef=1.0) + background.add(order=1, coef=2.0) + + x_data = np.array([0.0, 0.5, 1.0]) + y_data = background.calculate(x_data) + + # Expected values are calculated using the Chebyshev polynomial formula + u = (x_data - x_data.min()) / (x_data.max() - x_data.min()) * 2 - 1 + expected_y = 1.0 + 2.0 * u + assert np.allclose(y_data, expected_y) + + +def test_chebyshev_polynomial_background_calculate_no_terms(): + background = ChebyshevPolynomialBackground() + x_data = np.array([0.0, 0.5, 1.0]) + + with patch("builtins.print") as mock_print: + y_data = background.calculate(x_data) + assert np.array_equal(y_data, np.zeros_like(x_data)) + assert("No background points found. Setting background to zero." in str(mock_print.call_args.args[0])) + +def test_chebyshev_polynomial_background_show(capsys): + background = ChebyshevPolynomialBackground() + background.add(order=0, coef=1.0) + background.add(order=1, coef=2.0) + + background.show() + captured = capsys.readouterr() + assert "Chebyshev polynomial background terms" in captured.out + +def test_background_factory_create_supported_types(): + line_segment_background = BackgroundFactory.create("line-segment") + assert isinstance(line_segment_background, LineSegmentBackground) + + chebyshev_background = BackgroundFactory.create("chebyshev polynomial") + assert isinstance(chebyshev_background, ChebyshevPolynomialBackground) + + +def test_background_factory_create_unsupported_type(): + with pytest.raises(ValueError, match="Unsupported background type: 'unsupported'.*"): + BackgroundFactory.create("unsupported") diff --git a/tests/unit_tests/experiments/collections/test_datastore.py b/tests/unit_tests/experiments/collections/test_datastore.py new file mode 100644 index 00000000..c4aa4651 --- /dev/null +++ b/tests/unit_tests/experiments/collections/test_datastore.py @@ -0,0 +1,151 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch + +from easydiffraction.experiments.collections.datastore import ( + Pattern, + PowderPattern, + Datastore, + DatastoreFactory, +) + + +def test_pattern_initialization(): + mock_experiment = MagicMock() + pattern = Pattern(experiment=mock_experiment) + + assert pattern.experiment == mock_experiment + assert pattern.x is None + assert pattern.meas is None + assert pattern.meas_su is None + assert pattern.bkg is None + assert pattern.calc is None + + +def test_pattern_calc_property(): + mock_experiment = MagicMock() + pattern = Pattern(experiment=mock_experiment) + + # Test calc setter and getter + pattern.calc = [1, 2, 3] + assert pattern.calc == [1, 2, 3] + + +def test_powder_pattern_initialization(): + mock_experiment = MagicMock() + powder_pattern = PowderPattern(experiment=mock_experiment) + + assert powder_pattern.experiment == mock_experiment + assert isinstance(powder_pattern, Pattern) + + +def test_datastore_initialization_powder(): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + assert datastore.sample_form == "powder" + assert isinstance(datastore.pattern, PowderPattern) + + +def test_datastore_initialization_single_crystal(): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="single_crystal", experiment=mock_experiment) + + assert datastore.sample_form == "single_crystal" + assert isinstance(datastore.pattern, Pattern) + + +def test_datastore_initialization_invalid_sample_form(): + mock_experiment = MagicMock() + with pytest.raises(ValueError, match="Unknown sample form 'invalid'"): + Datastore(sample_form="invalid", experiment=mock_experiment) + + +def test_datastore_load_measured_data_valid(): + mock_experiment = MagicMock() + mock_experiment.name = "TestExperiment" + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + mock_data = np.array([[1.0, 2.0, 0.1], [2.0, 3.0, 0.2]]) + with patch("numpy.loadtxt", return_value=mock_data): + datastore.load_measured_data("mock_path") + + assert np.array_equal(datastore.pattern.x, mock_data[:, 0]) + assert np.array_equal(datastore.pattern.meas, mock_data[:, 1]) + assert np.array_equal(datastore.pattern.meas_su, mock_data[:, 2]) + + +def test_datastore_load_measured_data_no_uncertainty(): + mock_experiment = MagicMock() + mock_experiment.name = "TestExperiment" + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + mock_data = np.array([[1.0, 2.0], [2.0, 3.0]]) + with patch("numpy.loadtxt", return_value=mock_data): + datastore.load_measured_data("mock_path") + + assert np.array_equal(datastore.pattern.x, mock_data[:, 0]) + assert np.array_equal(datastore.pattern.meas, mock_data[:, 1]) + assert np.array_equal(datastore.pattern.meas_su, np.sqrt(np.abs(mock_data[:, 1]))) + + +def test_datastore_load_measured_data_invalid_file(): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + with patch("numpy.loadtxt", side_effect=Exception("File not found")): + datastore.load_measured_data("invalid_path") + + +def test_datastore_show_measured_data(capsys): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + datastore.pattern.x = [1.0, 2.0, 3.0] + datastore.pattern.meas = [10.0, 20.0, 30.0] + datastore.pattern.meas_su = [0.1, 0.2, 0.3] + + datastore.show_measured_data() + captured = capsys.readouterr() + + assert "Measured data (powder):" in captured.out + assert "x: [1.0, 2.0, 3.0]" in captured.out + assert "meas: [10.0, 20.0, 30.0]" in captured.out + assert "meas_su: [0.1, 0.2, 0.3]" in captured.out + + +def test_datastore_show_calculated_data(capsys): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + datastore.pattern.calc = [100.0, 200.0, 300.0] + + datastore.show_calculated_data() + captured = capsys.readouterr() + + assert "Calculated data (powder):" in captured.out + assert "calc: [100.0, 200.0, 300.0]" in captured.out + + +def test_datastore_factory_create_powder(): + mock_experiment = MagicMock() + datastore = DatastoreFactory.create(sample_form="powder", experiment=mock_experiment) + + assert isinstance(datastore, Datastore) + assert datastore.sample_form == "powder" + assert isinstance(datastore.pattern, PowderPattern) + + +def test_datastore_factory_create_single_crystal(): + mock_experiment = MagicMock() + datastore = DatastoreFactory.create(sample_form="single_crystal", experiment=mock_experiment) + + assert isinstance(datastore, Datastore) + assert datastore.sample_form == "single_crystal" + assert isinstance(datastore.pattern, Pattern) + + +def test_datastore_factory_create_invalid_sample_form(): + mock_experiment = MagicMock() + with pytest.raises(ValueError, match="Unknown sample form 'invalid'"): + DatastoreFactory.create(sample_form="invalid", experiment=mock_experiment) diff --git a/tests/unit_tests/experiments/collections/test_linked_phases.py b/tests/unit_tests/experiments/collections/test_linked_phases.py new file mode 100644 index 00000000..c4aa4651 --- /dev/null +++ b/tests/unit_tests/experiments/collections/test_linked_phases.py @@ -0,0 +1,151 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch + +from easydiffraction.experiments.collections.datastore import ( + Pattern, + PowderPattern, + Datastore, + DatastoreFactory, +) + + +def test_pattern_initialization(): + mock_experiment = MagicMock() + pattern = Pattern(experiment=mock_experiment) + + assert pattern.experiment == mock_experiment + assert pattern.x is None + assert pattern.meas is None + assert pattern.meas_su is None + assert pattern.bkg is None + assert pattern.calc is None + + +def test_pattern_calc_property(): + mock_experiment = MagicMock() + pattern = Pattern(experiment=mock_experiment) + + # Test calc setter and getter + pattern.calc = [1, 2, 3] + assert pattern.calc == [1, 2, 3] + + +def test_powder_pattern_initialization(): + mock_experiment = MagicMock() + powder_pattern = PowderPattern(experiment=mock_experiment) + + assert powder_pattern.experiment == mock_experiment + assert isinstance(powder_pattern, Pattern) + + +def test_datastore_initialization_powder(): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + assert datastore.sample_form == "powder" + assert isinstance(datastore.pattern, PowderPattern) + + +def test_datastore_initialization_single_crystal(): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="single_crystal", experiment=mock_experiment) + + assert datastore.sample_form == "single_crystal" + assert isinstance(datastore.pattern, Pattern) + + +def test_datastore_initialization_invalid_sample_form(): + mock_experiment = MagicMock() + with pytest.raises(ValueError, match="Unknown sample form 'invalid'"): + Datastore(sample_form="invalid", experiment=mock_experiment) + + +def test_datastore_load_measured_data_valid(): + mock_experiment = MagicMock() + mock_experiment.name = "TestExperiment" + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + mock_data = np.array([[1.0, 2.0, 0.1], [2.0, 3.0, 0.2]]) + with patch("numpy.loadtxt", return_value=mock_data): + datastore.load_measured_data("mock_path") + + assert np.array_equal(datastore.pattern.x, mock_data[:, 0]) + assert np.array_equal(datastore.pattern.meas, mock_data[:, 1]) + assert np.array_equal(datastore.pattern.meas_su, mock_data[:, 2]) + + +def test_datastore_load_measured_data_no_uncertainty(): + mock_experiment = MagicMock() + mock_experiment.name = "TestExperiment" + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + mock_data = np.array([[1.0, 2.0], [2.0, 3.0]]) + with patch("numpy.loadtxt", return_value=mock_data): + datastore.load_measured_data("mock_path") + + assert np.array_equal(datastore.pattern.x, mock_data[:, 0]) + assert np.array_equal(datastore.pattern.meas, mock_data[:, 1]) + assert np.array_equal(datastore.pattern.meas_su, np.sqrt(np.abs(mock_data[:, 1]))) + + +def test_datastore_load_measured_data_invalid_file(): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + with patch("numpy.loadtxt", side_effect=Exception("File not found")): + datastore.load_measured_data("invalid_path") + + +def test_datastore_show_measured_data(capsys): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + datastore.pattern.x = [1.0, 2.0, 3.0] + datastore.pattern.meas = [10.0, 20.0, 30.0] + datastore.pattern.meas_su = [0.1, 0.2, 0.3] + + datastore.show_measured_data() + captured = capsys.readouterr() + + assert "Measured data (powder):" in captured.out + assert "x: [1.0, 2.0, 3.0]" in captured.out + assert "meas: [10.0, 20.0, 30.0]" in captured.out + assert "meas_su: [0.1, 0.2, 0.3]" in captured.out + + +def test_datastore_show_calculated_data(capsys): + mock_experiment = MagicMock() + datastore = Datastore(sample_form="powder", experiment=mock_experiment) + + datastore.pattern.calc = [100.0, 200.0, 300.0] + + datastore.show_calculated_data() + captured = capsys.readouterr() + + assert "Calculated data (powder):" in captured.out + assert "calc: [100.0, 200.0, 300.0]" in captured.out + + +def test_datastore_factory_create_powder(): + mock_experiment = MagicMock() + datastore = DatastoreFactory.create(sample_form="powder", experiment=mock_experiment) + + assert isinstance(datastore, Datastore) + assert datastore.sample_form == "powder" + assert isinstance(datastore.pattern, PowderPattern) + + +def test_datastore_factory_create_single_crystal(): + mock_experiment = MagicMock() + datastore = DatastoreFactory.create(sample_form="single_crystal", experiment=mock_experiment) + + assert isinstance(datastore, Datastore) + assert datastore.sample_form == "single_crystal" + assert isinstance(datastore.pattern, Pattern) + + +def test_datastore_factory_create_invalid_sample_form(): + mock_experiment = MagicMock() + with pytest.raises(ValueError, match="Unknown sample form 'invalid'"): + DatastoreFactory.create(sample_form="invalid", experiment=mock_experiment) diff --git a/tests/unit_tests/experiments/components/test_experiment_type.py b/tests/unit_tests/experiments/components/test_experiment_type.py new file mode 100644 index 00000000..d5f9a7ce --- /dev/null +++ b/tests/unit_tests/experiments/components/test_experiment_type.py @@ -0,0 +1,51 @@ +import pytest +from easydiffraction.experiments.components.experiment_type import ExperimentType +from easydiffraction.core.objects import Descriptor + + +def test_experiment_type_initialization(): + experiment_type = ExperimentType( + sample_form="powder", + beam_mode="CW", + radiation_probe="neutron" + ) + + assert isinstance(experiment_type.sample_form, Descriptor) + assert experiment_type.sample_form.value == "powder" + assert experiment_type.sample_form.name == "samle_form" + assert experiment_type.sample_form.cif_name == "sample_form" + + assert isinstance(experiment_type.beam_mode, Descriptor) + assert experiment_type.beam_mode.value == "CW" + assert experiment_type.beam_mode.name == "beam_mode" + assert experiment_type.beam_mode.cif_name == "beam_mode" + + assert isinstance(experiment_type.radiation_probe, Descriptor) + assert experiment_type.radiation_probe.value == "neutron" + assert experiment_type.radiation_probe.name == "radiation_probe" + assert experiment_type.radiation_probe.cif_name == "radiation_probe" + + +def test_experiment_type_properties(): + experiment_type = ExperimentType( + sample_form="single_crystal", + beam_mode="TOF", + radiation_probe="x-ray" + ) + + assert experiment_type.cif_category_key == "expt_type" + assert experiment_type.category_key == "expt_type" + assert experiment_type._entry_id is None + assert experiment_type._locked is False + + +def no_test_experiment_type_locking_attributes(): + # hmm this doesn't work as expected. + experiment_type = ExperimentType( + sample_form="powder", + beam_mode="CW", + radiation_probe="neutron" + ) + experiment_type._locked = True # Disallow adding new attributes + experiment_type.new_attribute = "value" + assert not hasattr(experiment_type, "new_attribute") diff --git a/tests/unit_tests/experiments/components/test_instrument.py b/tests/unit_tests/experiments/components/test_instrument.py new file mode 100644 index 00000000..a8118694 --- /dev/null +++ b/tests/unit_tests/experiments/components/test_instrument.py @@ -0,0 +1,89 @@ +import pytest +from easydiffraction.experiments.components.instrument import ( + InstrumentBase, + ConstantWavelengthInstrument, + TimeOfFlightInstrument, + InstrumentFactory, +) +from easydiffraction.core.objects import Parameter + + +def test_instrument_base_properties(): + instrument = InstrumentBase() + assert instrument.category_key == "instrument" + assert instrument.cif_category_key == "instr" + assert instrument._entry_id is None + + +def test_constant_wavelength_instrument_initialization(): + instrument = ConstantWavelengthInstrument( + setup_wavelength=1.5406, + calib_twotheta_offset=0.1 + ) + + assert isinstance(instrument.setup_wavelength, Parameter) + assert instrument.setup_wavelength.value == 1.5406 + assert instrument.setup_wavelength.name == "wavelength" + assert instrument.setup_wavelength.cif_name == "wavelength" + assert instrument.setup_wavelength.units == "Å" + + assert isinstance(instrument.calib_twotheta_offset, Parameter) + assert instrument.calib_twotheta_offset.value == 0.1 + assert instrument.calib_twotheta_offset.name == "twotheta_offset" + assert instrument.calib_twotheta_offset.cif_name == "2theta_offset" + assert instrument.calib_twotheta_offset.units == "deg" + + +def test_time_of_flight_instrument_initialization(): + instrument = TimeOfFlightInstrument( + setup_twotheta_bank=150.0, + calib_d_to_tof_offset=0.5, + calib_d_to_tof_linear=10000.0, + calib_d_to_tof_quad=-1.0, + calib_d_to_tof_recip=0.1 + ) + + assert isinstance(instrument.setup_twotheta_bank, Parameter) + assert instrument.setup_twotheta_bank.value == 150.0 + assert instrument.setup_twotheta_bank.name == "twotheta_bank" + assert instrument.setup_twotheta_bank.cif_name == "2theta_bank" + assert instrument.setup_twotheta_bank.units == "deg" + + assert isinstance(instrument.calib_d_to_tof_offset, Parameter) + assert instrument.calib_d_to_tof_offset.value == 0.5 + assert instrument.calib_d_to_tof_offset.name == "d_to_tof_offset" + assert instrument.calib_d_to_tof_offset.cif_name == "d_to_tof_offset" + assert instrument.calib_d_to_tof_offset.units == "µs" + + assert isinstance(instrument.calib_d_to_tof_linear, Parameter) + assert instrument.calib_d_to_tof_linear.value == 10000.0 + assert instrument.calib_d_to_tof_linear.name == "d_to_tof_linear" + assert instrument.calib_d_to_tof_linear.cif_name == "d_to_tof_linear" + assert instrument.calib_d_to_tof_linear.units == "µs/Å" + + assert isinstance(instrument.calib_d_to_tof_quad, Parameter) + assert instrument.calib_d_to_tof_quad.value == -1.0 + assert instrument.calib_d_to_tof_quad.name == "d_to_tof_quad" + assert instrument.calib_d_to_tof_quad.cif_name == "d_to_tof_quad" + assert instrument.calib_d_to_tof_quad.units == "µs/Ų" + + assert isinstance(instrument.calib_d_to_tof_recip, Parameter) + assert instrument.calib_d_to_tof_recip.value == 0.1 + assert instrument.calib_d_to_tof_recip.name == "d_to_tof_recip" + assert instrument.calib_d_to_tof_recip.cif_name == "d_to_tof_recip" + assert instrument.calib_d_to_tof_recip.units == "µs·Å" + + +def test_instrument_factory_create_constant_wavelength(): + instrument = InstrumentFactory.create(beam_mode="constant wavelength") + assert isinstance(instrument, ConstantWavelengthInstrument) + + +def test_instrument_factory_create_time_of_flight(): + instrument = InstrumentFactory.create(beam_mode="time-of-flight") + assert isinstance(instrument, TimeOfFlightInstrument) + + +def test_instrument_factory_create_invalid_beam_mode(): + with pytest.raises(ValueError, match="Unsupported beam mode: 'invalid'.*"): + InstrumentFactory.create(beam_mode="invalid") diff --git a/tests/unit_tests/experiments/components/test_peak.py b/tests/unit_tests/experiments/components/test_peak.py new file mode 100644 index 00000000..fb315843 --- /dev/null +++ b/tests/unit_tests/experiments/components/test_peak.py @@ -0,0 +1,159 @@ +import pytest +from easydiffraction.experiments.components.peak import ( + ConstantWavelengthBroadeningMixin, + TimeOfFlightBroadeningMixin, + EmpiricalAsymmetryMixin, + FcjAsymmetryMixin, + IkedaCarpenterAsymmetryMixin, + PeakBase, + ConstantWavelengthPseudoVoigt, + ConstantWavelengthSplitPseudoVoigt, + ConstantWavelengthThompsonCoxHastings, + TimeOfFlightPseudoVoigt, + TimeOfFlightIkedaCarpenter, + TimeOfFlightPseudoVoigtIkedaCarpenter, + TimeOfFlightPseudoVoigtBackToBackExponential, + PeakFactory, +) +from easydiffraction.core.objects import Parameter + + +# --- Tests for Mixins --- +def test_constant_wavelength_broadening_mixin(): + class TestClass(ConstantWavelengthBroadeningMixin): + def __init__(self): + self._add_constant_wavelength_broadening() + + obj = TestClass() + assert isinstance(obj.broad_gauss_u, Parameter) + assert obj.broad_gauss_u.value == 0.01 + assert obj.broad_gauss_v.value == -0.01 + assert obj.broad_gauss_w.value == 0.02 + assert obj.broad_lorentz_x.value == 0.0 + assert obj.broad_lorentz_y.value == 0.0 + + +def test_time_of_flight_broadening_mixin(): + class TestClass(TimeOfFlightBroadeningMixin): + def __init__(self): + self._add_time_of_flight_broadening() + + obj = TestClass() + assert isinstance(obj.broad_gauss_sigma_0, Parameter) + assert obj.broad_gauss_sigma_0.value == 0.0 + assert obj.broad_gauss_sigma_1.value == 0.0 + assert obj.broad_gauss_sigma_2.value == 0.0 + assert obj.broad_lorentz_gamma_0.value == 0.0 + assert obj.broad_lorentz_gamma_1.value == 0.0 + assert obj.broad_lorentz_gamma_2.value == 0.0 + assert obj.broad_mix_beta_0.value == 0.0 + assert obj.broad_mix_beta_1.value == 0.0 + + +def test_empirical_asymmetry_mixin(): + class TestClass(EmpiricalAsymmetryMixin): + def __init__(self): + self._add_empirical_asymmetry() + + obj = TestClass() + assert isinstance(obj.asym_empir_1, Parameter) + assert obj.asym_empir_1.value == 0.1 + assert obj.asym_empir_2.value == 0.2 + assert obj.asym_empir_3.value == 0.3 + assert obj.asym_empir_4.value == 0.4 + + +def test_fcj_asymmetry_mixin(): + class TestClass(FcjAsymmetryMixin): + def __init__(self): + self._add_fcj_asymmetry() + + obj = TestClass() + assert isinstance(obj.asym_fcj_1, Parameter) + assert obj.asym_fcj_1.value == 0.01 + assert obj.asym_fcj_2.value == 0.02 + + +def test_ikeda_carpenter_asymmetry_mixin(): + class TestClass(IkedaCarpenterAsymmetryMixin): + def __init__(self): + self._add_ikeda_carpenter_asymmetry() + + obj = TestClass() + assert isinstance(obj.asym_alpha_0, Parameter) + assert obj.asym_alpha_0.value == 0.01 + assert obj.asym_alpha_1.value == 0.02 + + +# --- Tests for Base and Derived Peak Classes --- +def test_peak_base_properties(): + peak = PeakBase() + assert peak.cif_category_key == "peak" + assert peak.category_key == "peak" + assert peak._entry_id is None + + +def test_constant_wavelength_pseudo_voigt_initialization(): + peak = ConstantWavelengthPseudoVoigt() + assert isinstance(peak.broad_gauss_u, Parameter) + assert peak.broad_gauss_u.value == 0.01 + + +def test_constant_wavelength_split_pseudo_voigt_initialization(): + peak = ConstantWavelengthSplitPseudoVoigt() + assert isinstance(peak.broad_gauss_u, Parameter) + assert isinstance(peak.asym_empir_1, Parameter) + assert peak.asym_empir_1.value == 0.1 + + +def test_constant_wavelength_thompson_cox_hastings_initialization(): + peak = ConstantWavelengthThompsonCoxHastings() + assert isinstance(peak.broad_gauss_u, Parameter) + assert isinstance(peak.asym_fcj_1, Parameter) + assert peak.asym_fcj_1.value == 0.01 + + +def test_time_of_flight_pseudo_voigt_initialization(): + peak = TimeOfFlightPseudoVoigt() + assert isinstance(peak.broad_gauss_sigma_0, Parameter) + assert peak.broad_gauss_sigma_0.value == 0.0 + + +def test_time_of_flight_ikeda_carpenter_initialization(): + peak = TimeOfFlightIkedaCarpenter() + assert isinstance(peak.broad_gauss_sigma_0, Parameter) + assert isinstance(peak.asym_alpha_0, Parameter) + assert peak.asym_alpha_0.value == 0.01 + + +def test_time_of_flight_pseudo_voigt_ikeda_carpenter_initialization(): + peak = TimeOfFlightPseudoVoigtIkedaCarpenter() + assert isinstance(peak.broad_gauss_sigma_0, Parameter) + assert isinstance(peak.asym_alpha_0, Parameter) + + +def test_time_of_flight_pseudo_voigt_back_to_back_exponential_initialization(): + peak = TimeOfFlightPseudoVoigtBackToBackExponential() + assert isinstance(peak.broad_gauss_sigma_0, Parameter) + assert isinstance(peak.asym_alpha_0, Parameter) + + +# --- Tests for PeakFactory --- +def test_peak_factory_create_constant_wavelength_pseudo_voigt(): + peak = PeakFactory.create(beam_mode="constant wavelength", profile_type="pseudo-voigt") + assert isinstance(peak, ConstantWavelengthPseudoVoigt) + + +def test_peak_factory_create_time_of_flight_ikeda_carpenter(): + peak = PeakFactory.create(beam_mode="time-of-flight", profile_type="ikeda-carpenter") + assert isinstance(peak, TimeOfFlightIkedaCarpenter) + + +def test_peak_factory_create_invalid_beam_mode(): + with pytest.raises(ValueError, match="Unsupported beam mode: 'invalid'.*"): + PeakFactory.create(beam_mode="invalid", profile_type="pseudo-voigt") + + +def test_peak_factory_create_invalid_profile_type(): + with pytest.raises(ValueError, match="Unsupported profile type 'invalid' for mode 'constant wavelength'.*"): + PeakFactory.create(beam_mode="constant wavelength", profile_type="invalid") diff --git a/tests/unit_tests/experiments/test_experiment.py b/tests/unit_tests/experiments/test_experiment.py new file mode 100644 index 00000000..30a57fcf --- /dev/null +++ b/tests/unit_tests/experiments/test_experiment.py @@ -0,0 +1,187 @@ +import pytest +import numpy as np +from unittest.mock import MagicMock, patch + +from easydiffraction.experiments.experiment import ( + BaseExperiment, + PowderExperiment, + SingleCrystalExperiment, + ExperimentFactory, + Experiment, +) +from easydiffraction.experiments.components.experiment_type import ExperimentType +from easydiffraction.core.constants import ( + DEFAULT_SAMPLE_FORM, + DEFAULT_BEAM_MODE, + DEFAULT_RADIATION_PROBE, + DEFAULT_PEAK_PROFILE_TYPE, + DEFAULT_BACKGROUND_TYPE, +) + + +class ConcreteBaseExperiment(BaseExperiment): + """Concrete implementation of BaseExperiment for testing.""" + + def _load_ascii_data_to_experiment(self, data_path): + pass + + def show_meas_chart(self, x_min=None, x_max=None): + pass + + +class ConcreteSingleCrystalExperiment(SingleCrystalExperiment): + """Concrete implementation of SingleCrystalExperiment for testing.""" + + def _load_ascii_data_to_experiment(self, data_path): + pass + + +def test_base_experiment_initialization(): + mock_type = MagicMock() + mock_type.beam_mode.value = DEFAULT_BEAM_MODE + mock_type.radiation_probe.value = "xray" + mock_type.sample_form.value = DEFAULT_SAMPLE_FORM + experiment = ConcreteBaseExperiment(name="TestExperiment", type=mock_type) + assert experiment.name == "TestExperiment" + assert experiment.type == mock_type + + +def test_base_experiment_as_cif(): + # Mock the type object + mock_type = MagicMock() + mock_type.beam_mode.value = DEFAULT_BEAM_MODE + mock_type.diffraction_type.value = "x-ray" + mock_type.as_cif.return_value = "type_cif" + mock_type.sample_form.value = DEFAULT_SAMPLE_FORM + # Create a concrete instance of BaseExperiment + experiment = ConcreteBaseExperiment(name="TestExperiment", type=mock_type) + + # Mock the instrument object + experiment.instrument = MagicMock() + experiment.instrument.as_cif.return_value = "instrument_cif" + + # Mock other components if necessary + experiment.peak = MagicMock() + experiment.peak.as_cif.return_value = "peak_cif" + + experiment.linked_phases = MagicMock() + experiment.linked_phases.as_cif.return_value = "linked_phases_cif" + + experiment.background = MagicMock() + experiment.background.as_cif.return_value = "background_cif" + + experiment.datastore.pattern.x = [1.0] + experiment.datastore.pattern.meas = [2.0] + experiment.datastore.pattern.meas_su = [0.1] + + # Call the as_cif method and verify the output + cif_output = experiment.as_cif() + assert "data_TestExperiment" in cif_output + assert "type_cif" in cif_output + assert "instrument_cif" in cif_output + assert "peak_cif" in cif_output + assert "linked_phases_cif" in cif_output + assert "background_cif" in cif_output + + +def test_powder_experiment_initialization(): + mock_type = MagicMock() + mock_type.beam_mode.value = DEFAULT_BEAM_MODE + mock_type.radiation_probe.value = "xray" + mock_type.sample_form.value = DEFAULT_SAMPLE_FORM + experiment = PowderExperiment(name="PowderTest", type=mock_type) + assert experiment.name == "PowderTest" + assert experiment.type == mock_type + assert experiment.peak is not None + assert experiment.background is not None + + +def test_powder_experiment_load_ascii_data(): + mock_type = MagicMock() + mock_type.beam_mode.value = DEFAULT_BEAM_MODE + mock_type.radiation_probe.value = "xray" + mock_type.sample_form.value = DEFAULT_SAMPLE_FORM + experiment = PowderExperiment(name="PowderTest", type=mock_type) + experiment.datastore = MagicMock() + experiment.datastore.pattern = MagicMock() + mock_data = np.array([[1.0, 2.0, 0.1], [2.0, 3.0, 0.2]]) + with patch("numpy.loadtxt", return_value=mock_data): + experiment._load_ascii_data_to_experiment("mock_path") + assert np.array_equal(experiment.datastore.pattern.x, mock_data[:, 0]) + assert np.array_equal(experiment.datastore.pattern.meas, mock_data[:, 1]) + assert np.array_equal(experiment.datastore.pattern.meas_su, mock_data[:, 2]) + + +def test_powder_experiment_show_meas_chart(): + mock_type = MagicMock() + mock_type.beam_mode.value = DEFAULT_BEAM_MODE + mock_type.radiation_probe.value = "xray" + mock_type.sample_form.value = DEFAULT_SAMPLE_FORM + experiment = PowderExperiment(name="PowderTest", type=mock_type) + experiment.datastore = MagicMock() + experiment.datastore.pattern.x = [1, 2, 3] + experiment.datastore.pattern.meas = [10, 20, 30] + with patch("easydiffraction.utils.chart_plotter.ChartPlotter.plot") as mock_plot: + experiment.show_meas_chart() + mock_plot.assert_called_once() + + +def test_single_crystal_experiment_initialization(): + mock_type = MagicMock() + mock_type.beam_mode.value = DEFAULT_BEAM_MODE + mock_type.radiation_probe.value = "xray" + mock_type.sample_form.value = DEFAULT_SAMPLE_FORM + experiment = ConcreteSingleCrystalExperiment(name="SingleCrystalTest", type=mock_type) + assert experiment.name == "SingleCrystalTest" + assert experiment.type == mock_type + assert experiment.linked_crystal is None + + +def test_single_crystal_experiment_show_meas_chart(): + mock_type = MagicMock() + mock_type.beam_mode.value = DEFAULT_BEAM_MODE + mock_type.radiation_probe.value = "xray" + mock_type.sample_form.value = DEFAULT_SAMPLE_FORM + experiment = ConcreteSingleCrystalExperiment(name="SingleCrystalTest", type=mock_type) + with patch("builtins.print") as mock_print: + experiment.show_meas_chart() + mock_print.assert_called_once_with("Showing measured data chart is not implemented yet.") + + +def test_experiment_factory_create_powder(): + experiment = ExperimentFactory.create( + name="PowderTest", + sample_form="powder", + beam_mode=DEFAULT_BEAM_MODE, + radiation_probe=DEFAULT_RADIATION_PROBE, + ) + assert isinstance(experiment, PowderExperiment) + assert experiment.name == "PowderTest" + +# to be added once single crystal works +def no_test_experiment_factory_create_single_crystal(): + experiment = ExperimentFactory.create( + name="SingleCrystalTest", + sample_form="single crystal", + beam_mode=DEFAULT_BEAM_MODE, + radiation_probe=DEFAULT_RADIATION_PROBE, + ) + assert isinstance(experiment, SingleCrystalExperiment) + assert experiment.name == "SingleCrystalTest" + + +def test_experiment_method(): + mock_data = np.array([[1.0, 2.0, 0.1], [2.0, 3.0, 0.2]]) + with patch("numpy.loadtxt", return_value=mock_data): + experiment = Experiment( + name="ExperimentTest", + sample_form="powder", + beam_mode=DEFAULT_BEAM_MODE, + radiation_probe=DEFAULT_RADIATION_PROBE, + data_path="mock_path", + ) + assert isinstance(experiment, PowderExperiment) + assert experiment.name == "ExperimentTest" + assert np.array_equal(experiment.datastore.pattern.x, mock_data[:, 0]) + assert np.array_equal(experiment.datastore.pattern.meas, mock_data[:, 1]) + assert np.array_equal(experiment.datastore.pattern.meas_su, mock_data[:, 2]) diff --git a/tests/unit_tests/experiments/test_experiments.py b/tests/unit_tests/experiments/test_experiments.py new file mode 100644 index 00000000..547786b8 --- /dev/null +++ b/tests/unit_tests/experiments/test_experiments.py @@ -0,0 +1,94 @@ +import pytest +from unittest.mock import MagicMock, patch + +from easydiffraction.experiments.experiments import Experiments +from easydiffraction.experiments.experiment import BaseExperiment, ExperimentFactory + + +class ConcreteBaseExperiment(BaseExperiment): + """Concrete implementation of BaseExperiment for testing.""" + + def _load_ascii_data_to_experiment(self, data_path): + pass + + def show_meas_chart(self, x_min=None, x_max=None): + pass + + +def test_experiments_initialization(): + experiments = Experiments() + assert isinstance(experiments, Experiments) + assert len(experiments.ids) == 0 + + +def test_experiments_add_prebuilt_experiment(): + experiments = Experiments() + mock_experiment = MagicMock(spec=BaseExperiment) + mock_experiment.name = "TestExperiment" + + experiments.add(experiment=mock_experiment) + assert "TestExperiment" in experiments.ids + assert experiments._experiments["TestExperiment"] == mock_experiment + + +def test_experiments_add_from_data_path(): + experiments = Experiments() + mock_experiment = MagicMock(spec=ConcreteBaseExperiment) + mock_experiment.name = "TestExperiment" + + with patch("easydiffraction.experiments.experiment.ExperimentFactory.create", return_value=mock_experiment): + experiments.add( + name="TestExperiment", + sample_form="powder", + beam_mode="default", + radiation_probe="x-ray", + data_path="mock_path", + ) + + assert "TestExperiment" in experiments.ids + assert experiments._experiments["TestExperiment"] == mock_experiment + mock_experiment._load_ascii_data_to_experiment.assert_called_once_with("mock_path") + + +def test_experiments_add_invalid_input(): + experiments = Experiments() + + with pytest.raises(ValueError, match="Provide either experiment, type parameters, cif_path, cif_str, or data_path"): + experiments.add() + + +def test_experiments_remove(): + experiments = Experiments() + mock_experiment = MagicMock(spec=BaseExperiment) + mock_experiment.name = "TestExperiment" + + experiments.add(experiment=mock_experiment) + assert "TestExperiment" in experiments.ids + + experiments.remove("TestExperiment") + assert "TestExperiment" not in experiments.ids + + +def test_experiments_show_names(capsys): + experiments = Experiments() + mock_experiment = MagicMock(spec=BaseExperiment) + mock_experiment.name = "TestExperiment" + + experiments.add(experiment=mock_experiment) + experiments.show_names() + + captured = capsys.readouterr() + assert "Defined experiments 🔬" in captured.out + assert "TestExperiment" in captured.out + + +def test_experiments_as_cif(): + experiments = Experiments() + mock_experiment = MagicMock(spec=BaseExperiment) + mock_experiment.name = "TestExperiment" + mock_experiment.as_cif.return_value = "mock_cif_content" + + experiments.add(experiment=mock_experiment) + cif_output = experiments.as_cif() + + assert "mock_cif_content" in cif_output From 7d7256aa70abaac3e154eaa27fe5e58961731f53 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Thu, 17 Apr 2025 11:22:29 +0200 Subject: [PATCH 04/12] remainder of unit test: sample_models and project --- .../collections/test_atom_sites.py | 102 ++++++++++ .../sample_models/components/test_cell.py | 45 +++++ .../components/test_space_group.py | 32 ++++ .../sample_models/test_sample_models.py | 74 ++++++++ tests/unit_tests/test_project.py | 178 ++++++++++++++++++ 5 files changed, 431 insertions(+) create mode 100644 tests/unit_tests/sample_models/collections/test_atom_sites.py create mode 100644 tests/unit_tests/sample_models/components/test_cell.py create mode 100644 tests/unit_tests/sample_models/components/test_space_group.py create mode 100644 tests/unit_tests/sample_models/test_sample_models.py create mode 100644 tests/unit_tests/test_project.py diff --git a/tests/unit_tests/sample_models/collections/test_atom_sites.py b/tests/unit_tests/sample_models/collections/test_atom_sites.py new file mode 100644 index 00000000..6069cf9a --- /dev/null +++ b/tests/unit_tests/sample_models/collections/test_atom_sites.py @@ -0,0 +1,102 @@ +import pytest +from easydiffraction.sample_models.collections.atom_sites import AtomSite, AtomSites + + +def test_atom_site_initialization(): + atom_site = AtomSite( + label="O1", + type_symbol="O", + fract_x=0.1, + fract_y=0.2, + fract_z=0.3, + wyckoff_letter="a", + occupancy=0.8, + b_iso=1.2, + adp_type="Biso" + ) + + # Assertions + assert atom_site.label.value == "O1" + assert atom_site.type_symbol.value == "O" + assert atom_site.fract_x.value == 0.1 + assert atom_site.fract_y.value == 0.2 + assert atom_site.fract_z.value == 0.3 + assert atom_site.wyckoff_letter.value == "a" + assert atom_site.occupancy.value == 0.8 + assert atom_site.b_iso.value == 1.2 + assert atom_site.adp_type.value == "Biso" + + +def test_atom_site_properties(): + atom_site = AtomSite( + label="O1", + type_symbol="O", + fract_x=0.1, + fract_y=0.2, + fract_z=0.3 + ) + + # Assertions + assert atom_site.cif_category_key == "atom_site" + assert atom_site.category_key == "atom_site" + assert atom_site._entry_id == "O1" + + +@pytest.fixture +def atom_sites_collection(): + return AtomSites() + + +def test_atom_sites_add(atom_sites_collection): + atom_sites_collection.add( + label="O1", + type_symbol="O", + fract_x=0.1, + fract_y=0.2, + fract_z=0.3, + wyckoff_letter="a", + occupancy=0.8, + b_iso=1.2, + adp_type="Biso" + ) + + # Assertions + assert "O1" in atom_sites_collection._items + atom_site = atom_sites_collection._items["O1"] + assert isinstance(atom_site, AtomSite) + assert atom_site.label.value == "O1" + assert atom_site.type_symbol.value == "O" + assert atom_site.fract_x.value == 0.1 + assert atom_site.fract_y.value == 0.2 + assert atom_site.fract_z.value == 0.3 + assert atom_site.wyckoff_letter.value == "a" + assert atom_site.occupancy.value == 0.8 + assert atom_site.b_iso.value == 1.2 + assert atom_site.adp_type.value == "Biso" + + +def test_atom_sites_add_multiple(atom_sites_collection): + atom_sites_collection.add( + label="O1", + type_symbol="O", + fract_x=0.1, + fract_y=0.2, + fract_z=0.3 + ) + atom_sites_collection.add( + label="C1", + type_symbol="C", + fract_x=0.4, + fract_y=0.5, + fract_z=0.6 + ) + + # Assertions + assert "O1" in atom_sites_collection._items + assert "C1" in atom_sites_collection._items + assert len(atom_sites_collection._items) == 2 + + +def test_atom_sites_type(atom_sites_collection): + # Assertions + assert atom_sites_collection._type == "category" diff --git a/tests/unit_tests/sample_models/components/test_cell.py b/tests/unit_tests/sample_models/components/test_cell.py new file mode 100644 index 00000000..7f9eaa52 --- /dev/null +++ b/tests/unit_tests/sample_models/components/test_cell.py @@ -0,0 +1,45 @@ +import pytest +from easydiffraction.sample_models.components.cell import Cell + + +def test_cell_initialization(): + cell = Cell( + length_a=5.0, + length_b=6.0, + length_c=7.0, + angle_alpha=80.0, + angle_beta=85.0, + angle_gamma=95.0 + ) + + # Assertions + assert cell.length_a.value == 5.0 + assert cell.length_b.value == 6.0 + assert cell.length_c.value == 7.0 + assert cell.angle_alpha.value == 80.0 + assert cell.angle_beta.value == 85.0 + assert cell.angle_gamma.value == 95.0 + + assert cell.length_a.units == "Å" + assert cell.angle_alpha.units == "deg" + + +def test_cell_default_initialization(): + cell = Cell() + + # Assertions + assert cell.length_a.value == 10.0 + assert cell.length_b.value == 10.0 + assert cell.length_c.value == 10.0 + assert cell.angle_alpha.value == 90.0 + assert cell.angle_beta.value == 90.0 + assert cell.angle_gamma.value == 90.0 + + +def test_cell_properties(): + cell = Cell() + + # Assertions + assert cell.cif_category_key == "cell" + assert cell.category_key == "cell" + assert cell._entry_id is None diff --git a/tests/unit_tests/sample_models/components/test_space_group.py b/tests/unit_tests/sample_models/components/test_space_group.py new file mode 100644 index 00000000..e4e55fa8 --- /dev/null +++ b/tests/unit_tests/sample_models/components/test_space_group.py @@ -0,0 +1,32 @@ +import pytest +from easydiffraction.sample_models.components.space_group import SpaceGroup + + +def test_space_group_initialization(): + space_group = SpaceGroup(name_h_m="P 2/m", it_coordinate_system_code=1) + + # Assertions + assert space_group.name_h_m.value == "P 2/m" + assert space_group.name_h_m.name == "name_h_m" + assert space_group.name_h_m.cif_name == "name_H-M_alt" + + assert space_group.it_coordinate_system_code.value == 1 + assert space_group.it_coordinate_system_code.name == "it_coordinate_system_code" + assert space_group.it_coordinate_system_code.cif_name == "IT_coordinate_system_code" + + +def test_space_group_default_initialization(): + space_group = SpaceGroup() + + # Assertions + assert space_group.name_h_m.value == "P 1" + assert space_group.it_coordinate_system_code.value is None + + +def test_space_group_properties(): + space_group = SpaceGroup() + + # Assertions + assert space_group.cif_category_key == "space_group" + assert space_group.category_key == "space_group" + assert space_group._entry_id is None diff --git a/tests/unit_tests/sample_models/test_sample_models.py b/tests/unit_tests/sample_models/test_sample_models.py new file mode 100644 index 00000000..2ac76a4c --- /dev/null +++ b/tests/unit_tests/sample_models/test_sample_models.py @@ -0,0 +1,74 @@ +import pytest +from unittest.mock import patch, MagicMock +from easydiffraction.sample_models.sample_models import SampleModel, SampleModels + + +@pytest.fixture +def mock_sample_model(): + with patch("easydiffraction.sample_models.sample_models.SpaceGroup") as MockSpaceGroup, \ + patch("easydiffraction.sample_models.sample_models.Cell") as MockCell, \ + patch("easydiffraction.sample_models.sample_models.AtomSites") as MockAtomSites: + space_group = MockSpaceGroup.return_value + cell = MockCell.return_value + atom_sites = MockAtomSites.return_value + + # Mock attributes + space_group.name_h_m.value = "P 1" + space_group.it_coordinate_system_code.value = 1 + cell.length_a.value = 1.0 + cell.length_b.value = 2.0 + cell.length_c.value = 3.0 + cell.angle_alpha.value = 90.0 + cell.angle_beta.value = 90.0 + cell.angle_gamma.value = 90.0 + atom_sites.__iter__.return_value = [] + + return SampleModel(name="test_model") + + +@pytest.fixture +def mock_sample_models(): + return SampleModels() + + +def test_sample_models_add(mock_sample_models, mock_sample_model): + mock_sample_models.add(model=mock_sample_model) + + # Assertions + assert "test_model" in mock_sample_models.get_ids() + + +def test_sample_models_remove(mock_sample_models, mock_sample_model): + mock_sample_models.add(model=mock_sample_model) + mock_sample_models.remove("test_model") + + # Assertions + assert "test_model" not in mock_sample_models.get_ids() + + +def test_sample_models_as_cif(mock_sample_models, mock_sample_model): + mock_sample_model.as_cif = MagicMock(return_value="data_test_model") + mock_sample_models.add(model=mock_sample_model) + + cif = mock_sample_models.as_cif() + + # Assertions + assert "data_test_model" in cif + + +@patch("builtins.print") +def test_sample_models_show_names(mock_print, mock_sample_models, mock_sample_model): + mock_sample_models.add(model=mock_sample_model) + mock_sample_models.show_names() + + # Assertions + mock_print.assert_called_with(["test_model"]) + + +@patch.object(SampleModel, "show_params", autospec=True) +def test_sample_models_show_params(mock_show_params, mock_sample_models, mock_sample_model): + mock_sample_models.add(model=mock_sample_model) + mock_sample_models.show_params() + + # Assertions + mock_show_params.assert_called_once_with(mock_sample_model) diff --git a/tests/unit_tests/test_project.py b/tests/unit_tests/test_project.py new file mode 100644 index 00000000..3c3b3217 --- /dev/null +++ b/tests/unit_tests/test_project.py @@ -0,0 +1,178 @@ +import pytest +import os +import datetime +import time +from unittest.mock import MagicMock, patch +from easydiffraction.project import Project, ProjectInfo +from easydiffraction.sample_models.sample_models import SampleModels +from easydiffraction.experiments.experiments import Experiments +from easydiffraction.analysis.analysis import Analysis +from easydiffraction.summary import Summary + + +# ------------------------------------------ +# Tests for ProjectInfo +# ------------------------------------------ + +def test_project_info_initialization(): + project_info = ProjectInfo() + + # Assertions + assert project_info.name == "untitled_project" + assert project_info.title == "Untitled Project" + assert project_info.description == "" + assert project_info.path == os.getcwd() + assert isinstance(project_info.created, datetime.datetime) + assert isinstance(project_info.last_modified, datetime.datetime) + + +def test_project_info_setters(): + project_info = ProjectInfo() + + # Set values + project_info.name = "test_project" + project_info.title = "Test Project" + project_info.description = "This is a test project." + project_info.path = "/test/path" + + # Assertions + assert project_info.name == "test_project" + assert project_info.title == "Test Project" + assert project_info.description == "This is a test project." + assert project_info.path == "/test/path" + + +def test_project_info_update_last_modified(): + project_info = ProjectInfo() + initial_last_modified = project_info.last_modified + + # Add a small delay to ensure the timestamps differ + time.sleep(0.001) + project_info.update_last_modified() + + # Assertions + assert project_info.last_modified > initial_last_modified + + +def test_project_info_as_cif(): + project_info = ProjectInfo() + project_info.name = "test_project" + project_info.title = "Test Project" + project_info.description = "This is a test project." + + cif = project_info.as_cif() + + # Assertions + assert "_project.id test_project" in cif + assert "_project.title 'Test Project'" in cif + assert "_project.description 'This is a test project.'" in cif + + +@patch("builtins.print") +def test_project_info_show_as_cif(mock_print): + project_info = ProjectInfo() + project_info.name = "test_project" + project_info.title = "Test Project" + project_info.description = "This is a test project." + + project_info.show_as_cif() + + # Assertions + mock_print.assert_called() + + +# ------------------------------------------ +# Tests for Project +# ------------------------------------------ + +def test_project_initialization(): + with patch("easydiffraction.sample_models.sample_models.SampleModels") as MockSampleModels, \ + patch("easydiffraction.experiments.experiments.Experiments") as MockExperiments, \ + patch("easydiffraction.analysis.analysis.Analysis") as MockAnalysis, \ + patch("easydiffraction.summary.Summary") as MockSummary: + project = Project() # Directly assign the instance to a variable + + # Assertions + assert project.name == "untitled_project" + assert isinstance(project.sample_models, SampleModels) + assert isinstance(project.experiments, Experiments) + assert isinstance(project.analysis, Analysis) + assert isinstance(project.summary, Summary) + + +@patch("builtins.print") +def test_project_load(mock_print): + with patch("easydiffraction.sample_models.sample_models.SampleModels"), \ + patch("easydiffraction.experiments.experiments.Experiments"), \ + patch("easydiffraction.analysis.analysis.Analysis"), \ + patch("easydiffraction.summary.Summary"): + project = Project() # Directly assign the instance to a variable + + project.load("/test/path") + + # Assertions + assert project.info.path == "/test/path" + assert "Loading project 📦 from /test/path" in mock_print.call_args_list[0][0][0] + + +@patch("builtins.print") +@patch("os.makedirs") +@patch("builtins.open", new_callable=MagicMock) +def test_project_save(mock_open, mock_makedirs, mock_print): + with patch("easydiffraction.sample_models.sample_models.SampleModels"), \ + patch("easydiffraction.experiments.experiments.Experiments"), \ + patch("easydiffraction.analysis.analysis.Analysis"), \ + patch("easydiffraction.summary.Summary"): + project = Project() # Directly assign the instance to a variable + + project.info.path = "/test/path" + project.save() + + # Assertions + mock_makedirs.assert_any_call("/test/path", exist_ok=True) + # mock_open.assert_any_call("/test/path\\summary.cif", "w") + +@patch("builtins.print") +@patch("os.makedirs") +@patch("builtins.open", new_callable=MagicMock) +def test_project_save_as(mock_open, mock_makedirs, mock_print): + with patch("easydiffraction.sample_models.sample_models.SampleModels"), \ + patch("easydiffraction.experiments.experiments.Experiments"), \ + patch("easydiffraction.analysis.analysis.Analysis"), \ + patch("easydiffraction.summary.Summary"): + project = Project() # Directly assign the instance to a variable + + project.save_as("new_project_path") + + # Assertions + assert project.info.path.endswith("new_project_path") + mock_makedirs.assert_any_call(project.info.path, exist_ok=True) + mock_open.assert_any_call(os.path.join(project.info.path, "project.cif"), "w") + + +def test_project_set_sample_models(): + with patch("easydiffraction.sample_models.sample_models.SampleModels"), \ + patch("easydiffraction.experiments.experiments.Experiments"), \ + patch("easydiffraction.analysis.analysis.Analysis"), \ + patch("easydiffraction.summary.Summary"): + project = Project() # Directly assign the instance to a variable + + sample_models = MagicMock() + project.set_sample_models(sample_models) + + # Assertions + assert project.sample_models == sample_models + + +def test_project_set_experiments(): + with patch("easydiffraction.sample_models.sample_models.SampleModels"), \ + patch("easydiffraction.experiments.experiments.Experiments"), \ + patch("easydiffraction.analysis.analysis.Analysis"), \ + patch("easydiffraction.summary.Summary"): + project = Project() # Directly assign the instance to a variable + + experiments = MagicMock() + project.set_experiments(experiments) + + # Assertions + assert project.experiments == experiments From 48cd5521c81c3cc0997e4dcbe819a96d0b9a1b30 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Thu, 17 Apr 2025 11:24:03 +0200 Subject: [PATCH 05/12] enable unit tests on push --- .github/workflows/ci-testing.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/ci-testing.yaml b/.github/workflows/ci-testing.yaml index 1cfcd6c0..563d061f 100644 --- a/.github/workflows/ci-testing.yaml +++ b/.github/workflows/ci-testing.yaml @@ -54,3 +54,7 @@ jobs: - name: Run Python functional tests shell: bash run: PYTHONPATH=$(pwd)/src python -m pytest tests/functional_tests/ --color=yes -n auto + + - name: Run Python unit tests + shell: bash + run: PYTHONPATH=$(pwd)/src python -m pytest tests/unit_tests/ --color=yes -n auto \ No newline at end of file From e4c431432504d5002c73913e52163ec26c5e519c Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 23 Apr 2025 08:50:54 +0200 Subject: [PATCH 06/12] initial implementation of type hints for methods --- src/easydiffraction/analysis/analysis.py | 94 ++++++----- src/easydiffraction/analysis/calculation.py | 18 ++- .../analysis/calculators/calculator_base.py | 60 +++++-- .../calculators/calculator_crysfml.py | 82 ++++++++-- .../analysis/calculators/calculator_cryspy.py | 141 ++++++++++------- .../calculators/calculator_factory.py | 24 +-- .../analysis/calculators/calculator_pdffit.py | 40 ++++- .../analysis/collections/aliases.py | 16 +- .../analysis/collections/constraints.py | 18 +-- .../collections/joint_fit_experiments.py | 16 +- src/easydiffraction/analysis/minimization.py | 100 ++++++++---- .../minimizers/fitting_progress_tracker.py | 59 +++---- .../analysis/minimizers/minimizer_base.py | 132 +++++++++------- .../analysis/minimizers/minimizer_dfols.py | 25 ++- .../analysis/minimizers/minimizer_factory.py | 55 +++++-- .../analysis/minimizers/minimizer_lmfit.py | 77 ++++++--- .../analysis/reliability_factors.py | 80 +++++++++- src/easydiffraction/core/objects.py | 146 +++++++++--------- src/easydiffraction/core/singletons.py | 30 ++-- .../crystallography/crystallography.py | 56 +++++-- .../experiments/collections/background.py | 58 +++---- .../experiments/collections/datastore.py | 52 ++++--- .../experiments/collections/linked_phases.py | 12 +- .../experiments/components/experiment_type.py | 11 +- .../experiments/components/instrument.py | 47 +++--- .../experiments/components/peak.py | 101 ++++++------ src/easydiffraction/experiments/experiment.py | 116 +++++++------- .../experiments/experiments.py | 54 +++---- src/easydiffraction/project.py | 107 ++++++------- .../sample_models/components/cell.py | 19 +-- .../sample_models/components/space_group.py | 9 +- .../sample_models/sample_models.py | 146 ++++++++++++------ src/easydiffraction/summary.py | 26 ++-- 33 files changed, 1240 insertions(+), 787 deletions(-) diff --git a/src/easydiffraction/analysis/analysis.py b/src/easydiffraction/analysis/analysis.py index 3cd0c05b..ae1f234a 100644 --- a/src/easydiffraction/analysis/analysis.py +++ b/src/easydiffraction/analysis/analysis.py @@ -1,5 +1,6 @@ import pandas as pd from tabulate import tabulate +from typing import List, Optional, Union, Any from easydiffraction.utils.formatting import ( paragraph, @@ -30,19 +31,25 @@ class Analysis: _calculator = CalculatorFactory.create_calculator('cryspy') - def __init__(self, project): + def __init__(self, project: Any) -> None: self.project = project self.aliases = ConstraintAliases() self.constraints = ConstraintExpressions() self.constraints_handler = ConstraintsHandler.get() self.calculator = Analysis._calculator # Default calculator shared by project - self._calculator_key = 'cryspy' # Added to track the current calculator - self._fit_mode = 'single' + self._calculator_key: str = 'cryspy' # Added to track the current calculator + self._fit_mode: str = 'single' self.fitter = DiffractionMinimizer('lmfit (leastsq)') - def _get_params_as_dataframe(self, params): + def _get_params_as_dataframe(self, params: List[Union[Descriptor, Parameter]]) -> pd.DataFrame: """ Convert a list of parameters to a DataFrame. + + Args: + params: List of Descriptor or Parameter objects. + + Returns: + A pandas DataFrame containing parameter information. """ rows = [] for param in params: @@ -75,9 +82,13 @@ def _get_params_as_dataframe(self, params): return dataframe - def _show_params(self, dataframe, column_headers): - """: + def _show_params(self, dataframe: pd.DataFrame, column_headers: List[str]) -> None: + """ Display parameters in a tabular format. + + Args: + dataframe: The pandas DataFrame containing parameter information. + column_headers: List of column headers to display. """ dataframe = dataframe[column_headers] indices = range(1, len(dataframe) + 1) # Force starting from 1 @@ -87,7 +98,7 @@ def _show_params(self, dataframe, column_headers): tablefmt="fancy_outline", showindex=indices)) - def show_all_params(self): + def show_all_params(self) -> None: sample_models_params = self.project.sample_models.get_all_params() experiments_params = self.project.experiments.get_all_params() @@ -110,7 +121,7 @@ def show_all_params(self): experiments_dataframe = self._get_params_as_dataframe(experiments_params) self._show_params(experiments_dataframe, column_headers=column_headers) - def show_fittable_params(self): + def show_fittable_params(self) -> None: sample_models_params = self.project.sample_models.get_fittable_params() experiments_params = self.project.experiments.get_fittable_params() @@ -135,7 +146,7 @@ def show_fittable_params(self): experiments_dataframe = self._get_params_as_dataframe(experiments_params) self._show_params(experiments_dataframe, column_headers=column_headers) - def show_free_params(self): + def show_free_params(self) -> None: sample_models_params = self.project.sample_models.get_free_params() experiments_params = self.project.experiments.get_free_params() free_params = sample_models_params + experiments_params @@ -158,7 +169,7 @@ def show_free_params(self): dataframe = self._get_params_as_dataframe(free_params) self._show_params(dataframe, column_headers=column_headers) - def how_to_access_parameters(self, show_description=False): + def how_to_access_parameters(self, show_description: bool = False) -> None: sample_models_params = self.project.sample_models.get_all_params() experiments_params = self.project.experiments.get_all_params() params = {'sample_models': sample_models_params, @@ -204,21 +215,20 @@ def how_to_access_parameters(self, show_description=False): tablefmt="fancy_outline", showindex=indices)) - - def show_current_calculator(self): + def show_current_calculator(self) -> None: print(paragraph("Current calculator")) print(self.current_calculator) @staticmethod - def show_supported_calculators(): + def show_supported_calculators() -> None: CalculatorFactory.show_supported_calculators() @property - def current_calculator(self): + def current_calculator(self) -> str: return self._calculator_key @current_calculator.setter - def current_calculator(self, calculator_name): + def current_calculator(self, calculator_name: str) -> None: calculator = CalculatorFactory.create_calculator(calculator_name) if calculator is None: return @@ -227,30 +237,30 @@ def current_calculator(self, calculator_name): print(paragraph("Current calculator changed to")) print(self.current_calculator) - def show_current_minimizer(self): + def show_current_minimizer(self) -> None: print(paragraph("Current minimizer")) print(self.current_minimizer) @staticmethod - def show_available_minimizers(): + def show_available_minimizers() -> None: MinimizerFactory.show_available_minimizers() @property - def current_minimizer(self): + def current_minimizer(self) -> Optional[str]: return self.fitter.selection if self.fitter else None @current_minimizer.setter - def current_minimizer(self, selection): + def current_minimizer(self, selection: str) -> None: self.fitter = DiffractionMinimizer(selection) print(paragraph(f"Current minimizer changed to")) print(self.current_minimizer) @property - def fit_mode(self): + def fit_mode(self) -> str: return self._fit_mode @fit_mode.setter - def fit_mode(self, strategy): + def fit_mode(self, strategy: str) -> None: if strategy not in ['single', 'joint']: raise ValueError("Fit mode must be either 'single' or 'joint'") self._fit_mode = strategy @@ -263,7 +273,7 @@ def fit_mode(self, strategy): print(paragraph("Current fit mode changed to")) print(self._fit_mode) - def show_available_fit_modes(self): + def show_available_fit_modes(self) -> None: strategies = [ { "Strategy": "single", @@ -276,18 +286,26 @@ def show_available_fit_modes(self): print(paragraph("Available fit modes")) print(tabulate(strategies, headers="keys", tablefmt="fancy_outline", showindex=False)) - def show_current_fit_mode(self): - print(paragraph("Current ffit mode")) + def show_current_fit_mode(self) -> None: + print(paragraph("Current fit mode")) print(self.fit_mode) - def calculate_pattern(self, expt_name): - # Pattern is calculated for the given experiment + def calculate_pattern(self, expt_name: str) -> Optional[pd.DataFrame]: + """ + Calculate the diffraction pattern for a given experiment. + + Args: + expt_name: The name of the experiment. + + Returns: + The calculated pattern as a pandas DataFrame. + """ experiment = self.project.experiments[expt_name] sample_models = self.project.sample_models calculated_pattern = self.calculator.calculate_pattern(sample_models, experiment) return calculated_pattern - def show_constraints(self): + def show_constraints(self) -> None: constraints_dict = self.constraints._items if not self.constraints._items: @@ -312,7 +330,7 @@ def show_constraints(self): tablefmt="fancy_outline", showindex=False)) - def _update_uid_map(self): + def _update_uid_map(self) -> None: """ Update the UID map for accessing parameters by UID. This is needed for adding or removing constraints. @@ -323,7 +341,7 @@ def _update_uid_map(self): UidMapHandler.get().set_uid_map(params) - def apply_constraints(self): + def apply_constraints(self) -> None: if not self.constraints._items: print(warning(f"No constraints defined.")) return @@ -337,7 +355,7 @@ def apply_constraints(self): self.constraints_handler.set_expressions(self.constraints) self.constraints_handler.apply(parameters=fittable_params) - def show_calc_chart(self, expt_name, x_min=None, x_max=None): + def show_calc_chart(self, expt_name: str, x_min: Optional[float] = None, x_max: Optional[float] = None) -> None: self.calculate_pattern(expt_name) experiment = self.project.experiments[expt_name] @@ -354,11 +372,11 @@ def show_calc_chart(self, expt_name, x_min=None, x_max=None): ) def show_meas_vs_calc_chart(self, - expt_name, - x_min=None, - x_max=None, - show_residual=False, - chart_height=DEFAULT_HEIGHT): + expt_name: str, + x_min: Optional[float] = None, + x_max: Optional[float] = None, + show_residual: bool = False, + chart_height: int = DEFAULT_HEIGHT) -> None: experiment = self.project.experiments[expt_name] self.calculate_pattern(expt_name) @@ -387,7 +405,7 @@ def show_meas_vs_calc_chart(self, labels=labels ) - def fit(self): + def fit(self) -> None: sample_models = self.project.sample_models if not sample_models: print("No sample models found in the project. Cannot run fit.") @@ -422,7 +440,7 @@ def fit(self): # After fitting, get the results self.fit_results = self.fitter.results - def as_cif(self): + def as_cif(self) -> str: lines = [] lines.append(f"_analysis.calculator_engine {self.current_calculator}") lines.append(f"_analysis.fitting_engine {self.current_minimizer}") @@ -430,7 +448,7 @@ def as_cif(self): return "\n".join(lines) - def show_as_cif(self): + def show_as_cif(self) -> None: cif_text = self.as_cif() lines = cif_text.splitlines() max_width = max(len(line) for line in lines) diff --git a/src/easydiffraction/analysis/calculation.py b/src/easydiffraction/analysis/calculation.py index 805e6f6f..4d9688f0 100644 --- a/src/easydiffraction/analysis/calculation.py +++ b/src/easydiffraction/analysis/calculation.py @@ -1,3 +1,5 @@ +from typing import Any, Optional, List +import numpy as np from .calculators.calculator_factory import CalculatorFactory @@ -6,28 +8,28 @@ class DiffractionCalculator: Invokes calculation engines for pattern generation. """ - def __init__(self, engine='cryspy'): + def __init__(self, engine: str = 'cryspy') -> None: """ Initialize the DiffractionCalculator with a specified backend engine. Args: - calculator_type (str): Type of the calculation engine to use. - Supported types: 'crysfml', 'cryspy', 'pdffit'. - Default is 'crysfml'. + engine: Type of the calculation engine to use. + Supported types: 'crysfml', 'cryspy', 'pdffit'. + Default is 'cryspy'. """ self.calculator_factory = CalculatorFactory() self._calculator = self.calculator_factory.create_calculator(engine) - def set_calculator(self, engine): + def set_calculator(self, engine: str) -> None: """ Switch to a different calculator engine at runtime. Args: - engine (str): New calculation engine type to use. + engine: New calculation engine type to use. """ self._calculator = self.calculator_factory.create_calculator(engine) - def calculate_structure_factors(self, sample_models, experiments): + def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> Optional[List[Any]]: """ Calculate HKL intensities (structure factors) for sample models and experiments. @@ -40,7 +42,7 @@ def calculate_structure_factors(self, sample_models, experiments): """ return self._calculator.calculate_structure_factors(sample_models, experiments) - def calculate_pattern(self, sample_models, experiment): + def calculate_pattern(self, sample_models: Any, experiment: Any) -> np.ndarray: """ Generate diffraction pattern based on sample models and experiment. diff --git a/src/easydiffraction/analysis/calculators/calculator_base.py b/src/easydiffraction/analysis/calculators/calculator_base.py index 5c2d52d4..cdfb1a59 100644 --- a/src/easydiffraction/analysis/calculators/calculator_base.py +++ b/src/easydiffraction/analysis/calculators/calculator_base.py @@ -1,5 +1,6 @@ import numpy as np from abc import ABC, abstractmethod +from typing import List, Any from easydiffraction.core.singletons import ConstraintsHandler @@ -11,32 +12,42 @@ class CalculatorBase(ABC): @property @abstractmethod - def name(self): + def name(self) -> str: pass @property @abstractmethod - def engine_imported(self): + def engine_imported(self) -> bool: pass @abstractmethod - def calculate_structure_factors(self, sample_model, experiment): - # Single sample model, single experiment + def calculate_structure_factors(self, sample_model: Any, experiment: Any) -> None: + """ + Calculate structure factors for a single sample model and experiment. + """ pass def calculate_pattern(self, - sample_models, - experiment, - called_by_minimizer=False): - # Multiple sample models, single experiment - + sample_models: Any, + experiment: Any, + called_by_minimizer: bool = False) -> np.ndarray: + """ + Calculate the diffraction pattern for multiple sample models and a single experiment. + + Args: + sample_models: Collection of sample models. + experiment: The experiment object. + called_by_minimizer: Whether the calculation is called by a minimizer. + + Returns: + The calculated diffraction pattern as a NumPy array. + """ x_data = experiment.datastore.pattern.x y_calc_zeros = np.zeros_like(x_data) valid_linked_phases = self._get_valid_linked_phases(sample_models, experiment) # Apply user constraints to all sample models - # TODO: How to apply user constraints to all experiments (background, etc.)? constraints = ConstraintsHandler.get() constraints.apply(parameters=sample_models.get_all_params()) @@ -71,12 +82,33 @@ def calculate_pattern(self, @abstractmethod def _calculate_single_model_pattern(self, - sample_model, - experiment, - called_by_minimizer): + sample_model: Any, + experiment: Any, + called_by_minimizer: bool) -> np.ndarray: + """ + Calculate the diffraction pattern for a single sample model and experiment. + + Args: + sample_model: The sample model object. + experiment: The experiment object. + called_by_minimizer: Whether the calculation is called by a minimizer. + + Returns: + The calculated diffraction pattern as a NumPy array. + """ pass - def _get_valid_linked_phases(self, sample_models, experiment): + def _get_valid_linked_phases(self, sample_models: Any, experiment: Any) -> List[Any]: + """ + Get valid linked phases from the experiment. + + Args: + sample_models: Collection of sample models. + experiment: The experiment object. + + Returns: + A list of valid linked phases. + """ if not experiment.linked_phases: print('Warning: No linked phases found. Returning empty pattern.') return [] diff --git a/src/easydiffraction/analysis/calculators/calculator_crysfml.py b/src/easydiffraction/analysis/calculators/calculator_crysfml.py index cbe5811b..502f8af5 100644 --- a/src/easydiffraction/analysis/calculators/calculator_crysfml.py +++ b/src/easydiffraction/analysis/calculators/calculator_crysfml.py @@ -1,3 +1,5 @@ +import numpy as np +from typing import Any, Dict, List, Union from .calculator_base import CalculatorBase from easydiffraction.utils.formatting import warning @@ -13,22 +15,38 @@ class CrysfmlCalculator(CalculatorBase): Wrapper for Crysfml library. """ - engine_imported = cfml_py_utilities is not None + engine_imported: bool = cfml_py_utilities is not None @property - def name(self): + def name(self) -> str: return "crysfml" - def calculate_structure_factors(self, sample_models, experiments): - # Call Crysfml to calculate structure factors - raise NotImplementedError("HKL calculation is not implemented for CryspyCalculator.") + def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> None: + """ + Call Crysfml to calculate structure factors. - def _calculate_single_model_pattern(self, - sample_model, - experiment, - called_by_minimizer=False): + Args: + sample_models: The sample models to calculate structure factors for. + experiments: The experiments associated with the sample models. + """ + raise NotImplementedError("HKL calculation is not implemented for CrysfmlCalculator.") + + def _calculate_single_model_pattern( + self, + sample_model: Any, + experiment: Any, + called_by_minimizer: bool = False + ) -> Union[np.ndarray, List[float]]: """ Calculates the diffraction pattern using Crysfml for the given sample model and experiment. + + Args: + sample_model: The sample model to calculate the pattern for. + experiment: The experiment associated with the sample model. + called_by_minimizer: Whether the calculation is called by a minimizer. + + Returns: + The calculated diffraction pattern as a NumPy array or a list of floats. """ crysfml_dict = self._crysfml_dict(sample_model, experiment) try: @@ -39,13 +57,33 @@ def _calculate_single_model_pattern(self, y = [] return y - def _adjust_pattern_length(self, pattern, target_length): + def _adjust_pattern_length(self, pattern: List[float], target_length: int) -> List[float]: + """ + Adjusts the length of the pattern to match the target length. + + Args: + pattern: The pattern to adjust. + target_length: The desired length of the pattern. + + Returns: + The adjusted pattern. + """ # TODO: Check the origin of this discrepancy coming from PyCrysFML if len(pattern) > target_length: return pattern[:target_length] return pattern - def _crysfml_dict(self, sample_model, experiment): + def _crysfml_dict(self, sample_model: Any, experiment: Any) -> Dict[str, Any]: + """ + Converts the sample model and experiment into a dictionary format for Crysfml. + + Args: + sample_model: The sample model to convert. + experiment: The experiment to convert. + + Returns: + A dictionary representation of the sample model and experiment. + """ sample_model_dict = self._convert_sample_model_to_dict(sample_model) experiment_dict = self._convert_experiment_to_dict(experiment) return { @@ -53,7 +91,16 @@ def _crysfml_dict(self, sample_model, experiment): "experiments": [experiment_dict] } - def _convert_sample_model_to_dict(self, sample_model): + def _convert_sample_model_to_dict(self, sample_model: Any) -> Dict[str, Any]: + """ + Converts a sample model into a dictionary format. + + Args: + sample_model: The sample model to convert. + + Returns: + A dictionary representation of the sample model. + """ sample_model_dict = { sample_model.name: { "_space_group_name_H-M_alt": sample_model.space_group.name_h_m.value, @@ -82,7 +129,16 @@ def _convert_sample_model_to_dict(self, sample_model): return sample_model_dict - def _convert_experiment_to_dict(self, experiment): + def _convert_experiment_to_dict(self, experiment: Any) -> Dict[str, Any]: + """ + Converts an experiment into a dictionary format. + + Args: + experiment: The experiment to convert. + + Returns: + A dictionary representation of the experiment. + """ expt_type = getattr(experiment, "type", None) instrument = getattr(experiment, "instrument", None) peak = getattr(experiment, "peak", None) diff --git a/src/easydiffraction/analysis/calculators/calculator_cryspy.py b/src/easydiffraction/analysis/calculators/calculator_cryspy.py index cff0a01d..92a72e71 100644 --- a/src/easydiffraction/analysis/calculators/calculator_cryspy.py +++ b/src/easydiffraction/analysis/calculators/calculator_cryspy.py @@ -1,5 +1,6 @@ import copy import numpy as np +from typing import Any, Dict, List, Union from .calculator_base import CalculatorBase from easydiffraction.utils.formatting import warning @@ -18,28 +19,48 @@ class CryspyCalculator(CalculatorBase): Converts EasyDiffraction models into Cryspy objects and computes patterns. """ - engine_imported = cryspy is not None + engine_imported: bool = cryspy is not None @property - def name(self): + def name(self) -> str: return "cryspy" - def __init__(self): + def __init__(self) -> None: super().__init__() - self._cryspy_dicts = {} + self._cryspy_dicts: Dict[str, Dict[str, Any]] = {} - def calculate_structure_factors(self, sample_models, experiments): + def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> None: + """ + Raises a NotImplementedError as HKL calculation is not implemented. + + Args: + sample_models: The sample models to calculate structure factors for. + experiments: The experiments associated with the sample models. + """ raise NotImplementedError("HKL calculation is not implemented for CryspyCalculator.") - def _calculate_single_model_pattern(self, - sample_model, - experiment, - called_by_minimizer=False): - # We only recreate the cryspy_obj if this method is - # - NOT called by the minimizer, or - # - the cryspy_dict is NOT yet created. - # In other cases, we are modifying the existing cryspy_dict - # This allows significantly speeding up the calculation + def _calculate_single_model_pattern( + self, + sample_model: Any, + experiment: Any, + called_by_minimizer: bool = False + ) -> Union[np.ndarray, List[float]]: + """ + Calculates the diffraction pattern using Cryspy for the given sample model and experiment. + We only recreate the cryspy_obj if this method is + - NOT called by the minimizer, or + - the cryspy_dict is NOT yet created. + In other cases, we are modifying the existing cryspy_dict + This allows significantly speeding up the calculation + + Args: + sample_model: The sample model to calculate the pattern for. + experiment: The experiment associated with the sample model. + called_by_minimizer: Whether the calculation is called by a minimizer. + + Returns: + The calculated diffraction pattern as a NumPy array or a list of floats. + """ if called_by_minimizer: if self._cryspy_dicts and experiment.name in self._cryspy_dicts: cryspy_dict = self._recreate_cryspy_dict(sample_model, experiment) @@ -52,8 +73,7 @@ def _calculate_single_model_pattern(self, self._cryspy_dicts[experiment.name] = copy.deepcopy(cryspy_dict) - # Calculate pattern using cryspy - cryspy_in_out_dict = {} + cryspy_in_out_dict: Dict[str, Any] = {} rhochi_calc_chi_sq_by_dictionary( cryspy_dict, dict_in_out=cryspy_in_out_dict, @@ -61,7 +81,6 @@ def _calculate_single_model_pattern(self, flag_calc_analytical_derivatives=False ) - # Get cryspy block name based on experiment type prefixes = { "constant wavelength": "pd", "time-of-flight": "tof" @@ -73,7 +92,6 @@ def _calculate_single_model_pattern(self, print(f"[CryspyCalculator] Error: Unknown beam mode {experiment.type.beam_mode.value}") return [] - # Extract calculated pattern from cryspy_in_out_dict try: signal_plus = cryspy_in_out_dict[cryspy_block_name]['signal_plus'] signal_minus = cryspy_in_out_dict[cryspy_block_name]['signal_minus'] @@ -84,14 +102,21 @@ def _calculate_single_model_pattern(self, return y_calc_total - def _recreate_cryspy_dict(self, sample_model, experiment): - cryspy_dict = copy.deepcopy(self._cryspy_dicts[experiment.name]) + def _recreate_cryspy_dict(self, sample_model: Any, experiment: Any) -> Dict[str, Any]: + """ + Recreates the Cryspy dictionary for the given sample model and experiment. - # ---------- Update sample model parameters ---------- + Args: + sample_model: The sample model to update. + experiment: The experiment to update. + + Returns: + The updated Cryspy dictionary. + """ + cryspy_dict = copy.deepcopy(self._cryspy_dicts[experiment.name]) cryspy_model_id = f'crystal_{sample_model.name}' cryspy_model_dict = cryspy_dict[cryspy_model_id] - # Cell cryspy_cell = cryspy_model_dict['unit_cell_parameters'] cryspy_cell[0] = sample_model.cell.length_a.value @@ -100,7 +125,6 @@ def _recreate_cryspy_dict(self, sample_model, experiment): cryspy_cell[3] = np.deg2rad(sample_model.cell.angle_alpha.value) cryspy_cell[4] = np.deg2rad(sample_model.cell.angle_beta.value) cryspy_cell[5] = np.deg2rad(sample_model.cell.angle_gamma.value) - # Atomic coordinates cryspy_xyz = cryspy_model_dict['atom_fract_xyz'] for idx, atom_site in enumerate(sample_model.atom_sites): @@ -109,7 +133,7 @@ def _recreate_cryspy_dict(self, sample_model, experiment): cryspy_xyz[2][idx] = atom_site.fract_z.value # Atomic occupancies - cryspy_occ =cryspy_model_dict['atom_occupancy'] + cryspy_occ = cryspy_model_dict['atom_occupancy'] for idx, atom_site in enumerate(sample_model.atom_sites): cryspy_occ[idx] = atom_site.occupancy.value @@ -119,15 +143,12 @@ def _recreate_cryspy_dict(self, sample_model, experiment): cryspy_biso[idx] = atom_site.b_iso.value # ---------- Update experiment parameters ---------- - if experiment.type.beam_mode.value == 'constant wavelength': - cryspy_expt_name = f'pd_{experiment.name}' # TODO: use expt_name as in the SampleModel? Or change there for id instead of model_id? + cryspy_expt_name = f'pd_{experiment.name}' cryspy_expt_dict = cryspy_dict[cryspy_expt_name] - # Instrument cryspy_expt_dict['offset_ttheta'][0] = np.deg2rad(experiment.instrument.calib_twotheta_offset.value) cryspy_expt_dict['wavelength'][0] = experiment.instrument.setup_wavelength.value - # Peak cryspy_resolution = cryspy_expt_dict['resolution_parameters'] cryspy_resolution[0] = experiment.peak.broad_gauss_u.value @@ -137,15 +158,13 @@ def _recreate_cryspy_dict(self, sample_model, experiment): cryspy_resolution[4] = experiment.peak.broad_lorentz_y.value elif experiment.type.beam_mode.value == 'time-of-flight': - cryspy_expt_name = f'tof_{experiment.name}' # TODO: use expt_name as in the SampleModel? Or change there for id instead of model_id? + cryspy_expt_name = f'tof_{experiment.name}' cryspy_expt_dict = cryspy_dict[cryspy_expt_name] - # Instrument cryspy_expt_dict['zero'][0] = experiment.instrument.calib_d_to_tof_offset.value cryspy_expt_dict['dtt1'][0] = experiment.instrument.calib_d_to_tof_linear.value cryspy_expt_dict['dtt2'][0] = experiment.instrument.calib_d_to_tof_quad.value cryspy_expt_dict['ttheta_bank'] = np.deg2rad(experiment.instrument.setup_twotheta_bank.value) - # Peak cryspy_sigma = cryspy_expt_dict['profile_sigmas'] cryspy_sigma[0] = experiment.peak.broad_gauss_sigma_0.value @@ -162,36 +181,58 @@ def _recreate_cryspy_dict(self, sample_model, experiment): return cryspy_dict + def _recreate_cryspy_obj(self, sample_model: Any, experiment: Any) -> Any: + """ + Recreates the Cryspy object for the given sample model and experiment. + + Args: + sample_model: The sample model to recreate. + experiment: The experiment to recreate. - def _recreate_cryspy_obj(self, sample_model, experiment): + Returns: + The recreated Cryspy object. + """ cryspy_obj = str_to_globaln('') - # Add single sample model to cryspy_obj cryspy_sample_model_cif = self._convert_sample_model_to_cryspy_cif(sample_model) cryspy_sample_model_obj = str_to_globaln(cryspy_sample_model_cif) cryspy_obj.add_items(cryspy_sample_model_obj.items) - # Add single experiment to cryspy_obj - cryspy_experiment_cif = self._convert_experiment_to_cryspy_cif(experiment, - linked_phase=sample_model) + cryspy_experiment_cif = self._convert_experiment_to_cryspy_cif(experiment, linked_phase=sample_model) cryspy_experiment_obj = str_to_globaln(cryspy_experiment_cif) cryspy_obj.add_items(cryspy_experiment_obj.items) return cryspy_obj - def _convert_sample_model_to_cryspy_cif(self, sample_model): + def _convert_sample_model_to_cryspy_cif(self, sample_model: Any) -> str: + """ + Converts a sample model to a Cryspy CIF string. + + Args: + sample_model: The sample model to convert. + + Returns: + The Cryspy CIF string representation of the sample model. + """ return sample_model.as_cif() - def _convert_experiment_to_cryspy_cif(self, experiment, linked_phase): + def _convert_experiment_to_cryspy_cif(self, experiment: Any, linked_phase: Any) -> str: + """ + Converts an experiment to a Cryspy CIF string. + + Args: + experiment: The experiment to convert. + linked_phase: The linked phase associated with the experiment. + + Returns: + The Cryspy CIF string representation of the experiment. + """ expt_type = getattr(experiment, "type", None) instrument = getattr(experiment, "instrument", None) peak = getattr(experiment, "peak", None) cif_lines = [f"data_{experiment.name}"] - # STANDARD CATEGORIES - - # Experiment type category if expt_type is not None: cif_lines.append("") radiation_probe = expt_type.radiation_probe.value @@ -199,13 +240,10 @@ def _convert_experiment_to_cryspy_cif(self, experiment, linked_phase): radiation_probe = radiation_probe.replace("xray", "X-rays") cif_lines.append(f"_setup_radiation {radiation_probe}") - # Instrument category if instrument: instrument_mapping = { - # Constant wavelength "setup_wavelength": "_setup_wavelength", "calib_twotheta_offset": "_setup_offset_2theta", - # Time-of-flight "setup_twotheta_bank": "_tof_parameters_2theta_bank", "calib_d_to_tof_offset": "_tof_parameters_Zero", "calib_d_to_tof_linear": "_tof_parameters_Dtt1", @@ -217,16 +255,13 @@ def _convert_experiment_to_cryspy_cif(self, experiment, linked_phase): attr_value = getattr(instrument, local_attr_name).value cif_lines.append(f"{engine_key_name} {attr_value}") - # Peak category if peak: peak_mapping = { - # Constant wavelength "broad_gauss_u": "_pd_instr_resolution_U", "broad_gauss_v": "_pd_instr_resolution_V", "broad_gauss_w": "_pd_instr_resolution_W", "broad_lorentz_x": "_pd_instr_resolution_X", "broad_lorentz_y": "_pd_instr_resolution_Y", - # Time-of-flight "broad_gauss_sigma_0": "_tof_profile_sigma0", "broad_gauss_sigma_1": "_tof_profile_sigma1", "broad_gauss_sigma_2": "_tof_profile_sigma2", @@ -243,8 +278,6 @@ def _convert_experiment_to_cryspy_cif(self, experiment, linked_phase): attr_value = getattr(peak, local_attr_name).value cif_lines.append(f"{engine_key_name} {attr_value}") - # Experiment range category - # Extract measurement range dynamically x_data = experiment.datastore.pattern.x two_theta_min = float(x_data.min()) two_theta_max = float(x_data.max()) @@ -256,20 +289,12 @@ def _convert_experiment_to_cryspy_cif(self, experiment, linked_phase): cif_lines.append(f"_range_time_min {two_theta_min}") cif_lines.append(f"_range_time_max {two_theta_max}") - # ITERABLE CATEGORIES (LOOPS) - - # Linked phases category - # Force single linked phase to be used, as we handle multiple phases - # with their scales independently of the calculation engines cif_lines.append("") cif_lines.append("loop_") cif_lines.append("_phase_label") cif_lines.append("_phase_scale") cif_lines.append(f"{linked_phase.name} 1.0") - # Background category - # Force background to be zero, as we handle it independently of the - # calculation engines if expt_type.beam_mode.value == "constant wavelength": cif_lines.append("") cif_lines.append("loop_") @@ -285,7 +310,6 @@ def _convert_experiment_to_cryspy_cif(self, experiment, linked_phase): cif_lines.append(f"{two_theta_min} 0.0") cif_lines.append(f"{two_theta_max} 0.0") - # Measured data category if expt_type.beam_mode.value == "constant wavelength": cif_lines.append("") cif_lines.append("loop_") @@ -304,7 +328,6 @@ def _convert_experiment_to_cryspy_cif(self, experiment, linked_phase): for x_val, y_val, sy_val in zip(x_data, y_data, sy_data): cif_lines.append(f" {x_val:.5f} {y_val:.5f} {sy_val:.5f}") - # Combine all lines into a single string cryspy_experiment_cif = "\n".join(cif_lines) return cryspy_experiment_cif diff --git a/src/easydiffraction/analysis/calculators/calculator_factory.py b/src/easydiffraction/analysis/calculators/calculator_factory.py index 39f476ff..d5d1dfff 100644 --- a/src/easydiffraction/analysis/calculators/calculator_factory.py +++ b/src/easydiffraction/analysis/calculators/calculator_factory.py @@ -1,4 +1,5 @@ import tabulate +from typing import Dict, Type, List, Optional, Union, Any from easydiffraction.utils.formatting import ( paragraph, @@ -7,10 +8,11 @@ from .calculator_crysfml import CrysfmlCalculator from .calculator_cryspy import CryspyCalculator from .calculator_pdffit import PdffitCalculator +from .calculator_base import CalculatorBase class CalculatorFactory: - _potential_calculators = { + _potential_calculators: Dict[str, Dict[str, Union[str, Type[CalculatorBase]]]] = { 'crysfml': { 'description': 'CrysFML library for crystallographic calculations', 'class': CrysfmlCalculator @@ -26,7 +28,7 @@ class CalculatorFactory: } @classmethod - def _supported_calculators(cls): + def _supported_calculators(cls) -> Dict[str, Dict[str, Union[str, Type[CalculatorBase]]]]: return { name: cfg for name, cfg in cls._potential_calculators.items() @@ -34,16 +36,16 @@ def _supported_calculators(cls): } @classmethod - def list_supported_calculators(cls): + def list_supported_calculators(cls) -> List[str]: return list(cls._supported_calculators().keys()) @classmethod - def show_supported_calculators(cls): - header = ["Calculator", "Description"] - table_data = [] + def show_supported_calculators(cls) -> None: + header: List[str] = ["Calculator", "Description"] + table_data: List[List[str]] = [] for name, config in cls._supported_calculators().items(): - description = config.get('description', 'No description provided.') + description: str = config.get('description', 'No description provided.') table_data.append([name, description]) print(paragraph("Supported calculators")) @@ -57,7 +59,7 @@ def show_supported_calculators(cls): )) @classmethod - def create_calculator(cls, calculator_name): + def create_calculator(cls, calculator_name: str) -> Optional[CalculatorBase]: config = cls._supported_calculators().get(calculator_name) if not config: print(error(f"Unknown calculator '{calculator_name}'")) @@ -67,14 +69,14 @@ def create_calculator(cls, calculator_name): return config['class']() @classmethod - def register_calculator(cls, calculator_type, calculator_cls, description='No description provided.'): - cls._supported_calculators[calculator_type] = { + def register_calculator(cls, calculator_type: str, calculator_cls: Type[CalculatorBase], description: str = 'No description provided.') -> None: + cls._potential_calculators[calculator_type] = { 'class': calculator_cls, 'description': description } @classmethod - def register_minimizer(cls, name, minimizer_cls, method=None, description='No description provided.'): + def register_minimizer(cls, name: str, minimizer_cls: Type[Any], method: Optional[str] = None, description: str = 'No description provided.') -> None: cls._available_minimizers[name] = { 'engine': name, 'method': method, diff --git a/src/easydiffraction/analysis/calculators/calculator_pdffit.py b/src/easydiffraction/analysis/calculators/calculator_pdffit.py index c11b809a..793abd89 100644 --- a/src/easydiffraction/analysis/calculators/calculator_pdffit.py +++ b/src/easydiffraction/analysis/calculators/calculator_pdffit.py @@ -1,3 +1,4 @@ +from typing import Any, List, Union from .calculator_base import CalculatorBase from easydiffraction.utils.formatting import warning @@ -7,25 +8,48 @@ print(warning('"pdffit" module not found. This calculator will not work.')) pdffit = None + class PdffitCalculator(CalculatorBase): """ Wrapper for Pdffit library. """ - engine_imported = pdffit is not None + engine_imported: bool = pdffit is not None @property - def name(self): + def name(self) -> str: return "PdfFit" - def calculate_structure_factors(self, sample_models, experiments): - # PDF doesn't compute HKL but we keep interface consistent + def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> List[Any]: + """ + PDF doesn't compute HKL but we keep the interface consistent. + + Args: + sample_models: The sample models to calculate structure factors for. + experiments: The experiments associated with the sample models. + + Returns: + An empty list, as PDF doesn't compute HKLs. + """ print("[PdfFit] Calculating HKLs (not applicable)...") return [] - def _calculate_single_model_pattern(self, - sample_model, - experiment, - called_by_minimizer=False): + def _calculate_single_model_pattern( + self, + sample_model: Any, + experiment: Any, + called_by_minimizer: bool = False + ) -> Union[List[float], Any]: + """ + Calculates the diffraction pattern for a single model using PdfFit. + + Args: + sample_model: The sample model to calculate the pattern for. + experiment: The experiment associated with the sample model. + called_by_minimizer: Whether the calculation is called by a minimizer. + + Returns: + An empty list or other placeholder, as this is not implemented yet. + """ print("[PdfFit] Not implemented yet.") return [] diff --git a/src/easydiffraction/analysis/collections/aliases.py b/src/easydiffraction/analysis/collections/aliases.py index a334c2cc..cd436b35 100644 --- a/src/easydiffraction/analysis/collections/aliases.py +++ b/src/easydiffraction/analysis/collections/aliases.py @@ -7,34 +7,34 @@ class ConstraintAlias(Component): - def __init__(self, alias: str, param: Parameter): + def __init__(self, alias: str, param: Parameter) -> None: super().__init__() - self.alias = Descriptor( + self.alias: Descriptor = Descriptor( value=alias, name="alias", cif_name="alias" ) - self.param = param + self.param: Parameter = param @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "constraint_alias" @property - def category_key(self): + def category_key(self) -> str: return "constraint_alias" @property - def _entry_id(self): + def _entry_id(self) -> str: return self.alias.value class ConstraintAliases(Collection): @property - def _type(self): + def _type(self) -> str: return "category" # datablock or category - def add(self, alias: str, param: Parameter): + def add(self, alias: str, param: Parameter) -> None: alias_obj = ConstraintAlias(alias, param) self._items[alias_obj.alias.value] = alias_obj diff --git a/src/easydiffraction/analysis/collections/constraints.py b/src/easydiffraction/analysis/collections/constraints.py index e12c7e61..0d785164 100644 --- a/src/easydiffraction/analysis/collections/constraints.py +++ b/src/easydiffraction/analysis/collections/constraints.py @@ -9,46 +9,46 @@ class ConstraintExpression(Component): def __init__(self, id: str, lhs_alias: str, - rhs_expr: str): + rhs_expr: str) -> None: super().__init__() - self.id = Descriptor( + self.id: Descriptor = Descriptor( value=id, name="id", cif_name="id" ) - self.lhs_alias = Descriptor( + self.lhs_alias: Descriptor = Descriptor( value=lhs_alias, name="lhs_alias", cif_name="lhs_alias" ) - self.rhs_expr = Descriptor( + self.rhs_expr: Descriptor = Descriptor( value=rhs_expr, name="rhs_expr", cif_name="rhs_expr" ) @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "constraint_expression" @property - def category_key(self): + def category_key(self) -> str: return "constraint_expression" @property - def _entry_id(self): + def _entry_id(self) -> str: return self.id.value class ConstraintExpressions(Collection): @property - def _type(self): + def _type(self) -> str: return "category" # datablock or category def add(self, id: str, lhs_alias: str, - rhs_expr: str): + rhs_expr: str) -> None: expression_obj = ConstraintExpression(id, lhs_alias, rhs_expr) self._items[expression_obj.id.value] = expression_obj diff --git a/src/easydiffraction/analysis/collections/joint_fit_experiments.py b/src/easydiffraction/analysis/collections/joint_fit_experiments.py index b4072851..ba825169 100644 --- a/src/easydiffraction/analysis/collections/joint_fit_experiments.py +++ b/src/easydiffraction/analysis/collections/joint_fit_experiments.py @@ -6,30 +6,30 @@ class JointFitExperiment(Component): - def __init__(self, id: str, weight: float): + def __init__(self, id: str, weight: float) -> None: super().__init__() - self.id = Descriptor( + self.id: Descriptor = Descriptor( value=id, name="id", cif_name="id" ) - self.weight = Descriptor( + self.weight: Descriptor = Descriptor( value=weight, name="weight", cif_name="weight" ) @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "joint_fit_experiment" @property - def category_key(self): + def category_key(self) -> str: return "joint_fit_experiment" @property - def _entry_id(self): + def _entry_id(self) -> str: return self.id.value @@ -39,9 +39,9 @@ class JointFitExperiments(Collection): in a `joint` fit. """ @property - def _type(self): + def _type(self) -> str: return "category" # datablock or category - def add(self, id: str, weight: float): + def add(self, id: str, weight: float) -> None: expt = JointFitExperiment(id, weight) self._items[expt.id.value] = expt diff --git a/src/easydiffraction/analysis/minimization.py b/src/easydiffraction/analysis/minimization.py index 2a26eeaa..71cc7784 100644 --- a/src/easydiffraction/analysis/minimization.py +++ b/src/easydiffraction/analysis/minimization.py @@ -1,4 +1,6 @@ +from typing import Any, Optional, List, Callable, Dict from .minimizers.minimizer_factory import MinimizerFactory +from .minimizers.minimizer_base import FitResults from ..analysis.reliability_factors import get_reliability_inputs import numpy as np @@ -8,21 +10,28 @@ class DiffractionMinimizer: Handles the fitting workflow using a pluggable minimizer. """ - def __init__(self, selection: str = 'lmfit (leastsq)'): - self.selection = selection - self.engine = selection.split(' ')[0] # Extracts 'lmfit' or 'dfols' + def __init__(self, selection: str = 'lmfit (leastsq)') -> None: + self.selection: str = selection + self.engine: str = selection.split(' ')[0] # Extracts 'lmfit' or 'dfols' self.minimizer = MinimizerFactory.create_minimizer(selection) - self.results = None + self.results: Optional[FitResults] = None def fit(self, - sample_models, - experiments, - calculator, - weights=None): + sample_models: Any, + experiments: Any, + calculator: Any, + weights: Optional[Any] = None) -> None: """ Run the fitting process. + + Args: + sample_models: Collection of sample models. + experiments: Collection of experiments. + calculator: The calculator to use for pattern generation. + weights: Optional weights for joint fitting. + """ - params = sample_models.get_free_params() + experiments.get_free_params() + params: List[Any] = sample_models.get_free_params() + experiments.get_free_params() if not params: print("⚠️ No parameters selected for fitting.") @@ -31,7 +40,7 @@ def fit(self, for param in params: param.start_value = param.value - objective_function = lambda engine_params: self._residual_function( + objective_function: Callable[[Dict[str, Any]], np.ndarray] = lambda engine_params: self._residual_function( engine_params=engine_params, parameters=params, sample_models=sample_models, @@ -47,11 +56,16 @@ def fit(self, self._process_fit_results(sample_models, experiments, calculator) def _process_fit_results(self, - sample_models, - experiments, - calculator): + sample_models: Any, + experiments: Any, + calculator: Any) -> None: """ Collect reliability inputs and display results after fitting. + + Args: + sample_models: Collection of sample models. + experiments: Collection of experiments. + calculator: The calculator used for pattern generation. """ y_obs, y_calc, y_err = get_reliability_inputs(sample_models, experiments, calculator) @@ -62,31 +76,52 @@ def _process_fit_results(self, self.results.display_results(y_obs=y_obs, y_calc=y_calc, y_err=y_err, f_obs=f_obs, f_calc=f_calc) def _collect_free_parameters(self, - sample_models, - experiments): - free_params = sample_models.get_free_params() + experiments.get_free_params() + sample_models: Any, + experiments: Any) -> List[Any]: + """ + Collect free parameters from sample models and experiments. + + Args: + sample_models: Collection of sample models. + experiments: Collection of experiments. + + Returns: + List of free parameters. + """ + free_params: List[Any] = sample_models.get_free_params() + experiments.get_free_params() return free_params def _residual_function(self, - engine_params, - parameters, - sample_models, - experiments, - calculator, - weights=None): + engine_params: Dict[str, Any], + parameters: List[Any], + sample_models: Any, + experiments: Any, + calculator: Any, + weights: Optional[Any] = None) -> np.ndarray: """ Residual function computes the difference between measured and calculated patterns. It updates the parameter values according to the optimizer-provided engine_params. + + Args: + engine_params: Engine-specific parameter dict. + parameters: List of parameters being optimized. + sample_models: Collection of sample models. + experiments: Collection of experiments. + calculator: The calculator to use for pattern generation. + weights: Optional weights for joint fitting. + + Returns: + Array of weighted residuals. """ # Sync parameters back to objects self.minimizer._sync_result_to_parameters(parameters, engine_params) # Prepare weights for joint fitting - num_expts = len(experiments.ids) + num_expts: int = len(experiments.ids) if weights is None: _weights = np.ones(num_expts) else: - _weights_list = [] + _weights_list: List[float] = [] for id in experiments.ids: _weight = weights._items[id].weight.value _weights_list.append(_weight) @@ -97,17 +132,16 @@ def _residual_function(self, # two parts and fit together. If weights sum to one, then reduced chi_squared # will be half as large as expected. _weights *= num_expts / np.sum(_weights) - residuals = [] + residuals: List[float] = [] for (expt_id, experiment), weight in zip(experiments._items.items(), _weights): - y_calc = calculator.calculate_pattern(sample_models, - experiment, - called_by_minimizer=True) # True False - y_meas = experiment.datastore.pattern.meas - y_meas_su = experiment.datastore.pattern.meas_su - diff = (y_meas - y_calc) / y_meas_su + y_calc: np.ndarray = calculator.calculate_pattern(sample_models, + experiment, + called_by_minimizer=True) # True False + y_meas: np.ndarray = experiment.datastore.pattern.meas + y_meas_su: np.ndarray = experiment.datastore.pattern.meas_su + diff: np.ndarray = (y_meas - y_calc) / y_meas_su diff *= np.sqrt(weight) # Residuals are squared before going into reduced chi-squared residuals.extend(diff) - residuals = np.array(residuals) - return self.minimizer.tracker.track(residuals, parameters) + return self.minimizer.tracker.track(np.array(residuals), parameters) diff --git a/src/easydiffraction/analysis/minimizers/fitting_progress_tracker.py b/src/easydiffraction/analysis/minimizers/fitting_progress_tracker.py index 03300919..b31433b6 100644 --- a/src/easydiffraction/analysis/minimizers/fitting_progress_tracker.py +++ b/src/easydiffraction/analysis/minimizers/fitting_progress_tracker.py @@ -1,11 +1,12 @@ import numpy as np - +from typing import List, Optional from easydiffraction.analysis.reliability_factors import calculate_reduced_chi_square SIGNIFICANT_CHANGE_THRESHOLD = 0.01 # 1% threshold FIXED_WIDTH = 17 -def format_cell(cell, width=FIXED_WIDTH, align="center"): + +def format_cell(cell: str, width: int = FIXED_WIDTH, align: str = "center") -> str: cell_str = str(cell) if align == "center": return cell_str.center(width) @@ -22,16 +23,16 @@ class FittingProgressTracker: Tracks and reports the reduced chi-square during the optimization process. """ - def __init__(self): - self._iteration = 0 - self._previous_chi2 = None - self._last_chi2 = None - self._last_iteration = None - self._best_chi2 = None - self._best_iteration = None - self._fitting_time = None + def __init__(self) -> None: + self._iteration: int = 0 + self._previous_chi2: Optional[float] = None + self._last_chi2: Optional[float] = None + self._last_iteration: Optional[int] = None + self._best_chi2: Optional[float] = None + self._best_iteration: Optional[int] = None + self._fitting_time: Optional[float] = None - def reset(self): + def reset(self) -> None: self._iteration = 0 self._previous_chi2 = None self._last_chi2 = None @@ -40,7 +41,7 @@ def reset(self): self._best_iteration = None self._fitting_time = None - def track(self, residuals, parameters): + def track(self, residuals: np.ndarray, parameters: List[float]) -> np.ndarray: """ Track chi-square progress during the optimization process. @@ -55,7 +56,7 @@ def track(self, residuals, parameters): reduced_chi2 = calculate_reduced_chi_square(residuals, len(parameters)) - row = [] + row: List[str] = [] # First iteration, initialize tracking if self._previous_chi2 is None: @@ -64,7 +65,7 @@ def track(self, residuals, parameters): self._best_iteration = self._iteration row = [ - self._iteration, + str(self._iteration), f"{reduced_chi2:.2f}", "", "" @@ -75,7 +76,7 @@ def track(self, residuals, parameters): change_percent = (self._previous_chi2 - reduced_chi2) / self._previous_chi2 * 100 row = [ - self._iteration, + str(self._iteration), f"{self._previous_chi2:.2f}", f"{reduced_chi2:.2f}", f"{change_percent:.1f}% ↓" @@ -99,32 +100,32 @@ def track(self, residuals, parameters): return residuals @property - def best_chi2(self): + def best_chi2(self) -> Optional[float]: return self._best_chi2 @property - def best_iteration(self): + def best_iteration(self) -> Optional[int]: return self._best_iteration @property - def iteration(self): + def iteration(self) -> int: return self._iteration @property - def fitting_time(self): + def fitting_time(self) -> Optional[float]: return self._fitting_time - def start_timer(self): + def start_timer(self) -> None: import time self._start_time = time.perf_counter() - def stop_timer(self): + def stop_timer(self) -> None: import time self._end_time = time.perf_counter() self._fitting_time = self._end_time - self._start_time - def start_tracking(self, minimizer_name): - headers = ["iteration", "start", "improved", "improvement [%]"] + def start_tracking(self, minimizer_name: str) -> None: + headers: List[str] = ["iteration", "start", "improved", "improvement [%]"] print(f"🚀 Starting fitting process with '{minimizer_name}'...") print("📈 Goodness-of-fit (reduced χ²) change:") @@ -139,9 +140,9 @@ def start_tracking(self, minimizer_name): # Separator print("╞" + "╪".join(["═" * FIXED_WIDTH for _ in headers]) + "╡") - def add_tracking_info(self, row): + def add_tracking_info(self, row: List[str]) -> None: # Alignments for each column: iteration, start, improved, change [%] - aligns = ["center", "center", "center", "center"] + aligns: List[str] = ["center", "center", "center", "center"] formatted_row = "│" + "│".join([ format_cell(cell, align=aligns[i]) @@ -150,12 +151,12 @@ def add_tracking_info(self, row): print(formatted_row) - def finish_tracking(self): + def finish_tracking(self) -> None: # Print last iteration as last row - row = [ - self._last_iteration, + row: List[str] = [ + str(self._last_iteration), "", - f"{self._last_chi2:.2f}", + f"{self._last_chi2:.2f}" if self._last_chi2 is not None else "", "" ] self.add_tracking_info(row) diff --git a/src/easydiffraction/analysis/minimizers/minimizer_base.py b/src/easydiffraction/analysis/minimizers/minimizer_base.py index a76a5ffc..46a8d028 100644 --- a/src/easydiffraction/analysis/minimizers/minimizer_base.py +++ b/src/easydiffraction/analysis/minimizers/minimizer_base.py @@ -1,5 +1,7 @@ +import numpy as np import pandas as pd from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Union from tabulate import tabulate from ..reliability_factors import ( @@ -12,19 +14,29 @@ from easydiffraction.utils.formatting import paragraph + class FitResults: - def __init__(self, success=False, parameters=None, chi_square=None, - reduced_chi_square=None, message='', iterations=0, engine_result=None, starting_parameters=None, fitting_time=None, **kwargs): - self.success = success - self.parameters = parameters if parameters is not None else [] - self.chi_square = chi_square - self.reduced_chi_square = reduced_chi_square - self.message = message - self.iterations = iterations - self.engine_result = engine_result - self.result = None - self.starting_parameters = starting_parameters if starting_parameters is not None else [] - self.fitting_time = fitting_time # Store fitting time + def __init__(self, + success: bool = False, + parameters: Optional[List[Any]] = None, + chi_square: Optional[float] = None, + reduced_chi_square: Optional[float] = None, + message: str = '', + iterations: int = 0, + engine_result: Optional[Any] = None, + starting_parameters: Optional[List[Any]] = None, + fitting_time: Optional[float] = None, + **kwargs: Any) -> None: + self.success: bool = success + self.parameters: List[Any] = parameters if parameters is not None else [] + self.chi_square: Optional[float] = chi_square + self.reduced_chi_square: Optional[float] = reduced_chi_square + self.message: str = message + self.iterations: int = iterations + self.engine_result: Optional[Any] = engine_result + self.result: Optional[Any] = None + self.starting_parameters: List[Any] = starting_parameters if starting_parameters is not None else [] + self.fitting_time: Optional[float] = fitting_time if 'redchi' in kwargs and self.reduced_chi_square is None: self.reduced_chi_square = kwargs.get('redchi') @@ -32,7 +44,12 @@ def __init__(self, success=False, parameters=None, chi_square=None, for key, value in kwargs.items(): setattr(self, key, value) - def display_results(self, y_obs=None, y_calc=None, y_err=None, f_obs=None, f_calc=None): + def display_results(self, + y_obs: Optional[List[float]] = None, + y_calc: Optional[List[float]] = None, + y_err: Optional[List[float]] = None, + f_obs: Optional[List[float]] = None, + f_calc: Optional[List[float]] = None) -> None: status_icon = "✅" if self.success else "❌" rf = rf2 = wr = br = None if y_obs is not None and y_calc is not None: @@ -69,7 +86,7 @@ def display_results(self, y_obs=None, y_calc=None, y_err=None, f_obs=None, f_cal rows = [] for param in self.parameters: - datablock_id = getattr(param, 'datablock_id', 'N/A') # TODO: Check if 'N/A' is needed + datablock_id = getattr(param, 'datablock_id', 'N/A') category_key = getattr(param, 'category_key', 'N/A') collection_entry_id = getattr(param, 'collection_entry_id', 'N/A') name = getattr(param, 'name', 'N/A') @@ -96,7 +113,7 @@ def display_results(self, y_obs=None, y_calc=None, y_err=None, f_obs=None, f_cal relative_change]) dataframe = pd.DataFrame(rows) - indices = range(1, len(dataframe) + 1) # Force starting from 1 + indices = range(1, len(dataframe) + 1) print(tabulate(dataframe, headers=headers, @@ -110,33 +127,31 @@ class MinimizerBase(ABC): Provides shared logic and structure for concrete minimizers. """ def __init__(self, - name=None, - method=None, - max_iterations=None): - # 'method' is used only by minimizers supporting multiple methods - # (e.g., lmfit). For minimizers like dfols, pass None. - self.name = name - self.method = method - self.max_iterations = max_iterations - self.result = None - self._previous_chi2 = None - self._iteration = None - self._best_chi2 = None - self._best_iteration = None - self._fitting_time = None - self.tracker = FittingProgressTracker() - - def _start_tracking(self, minimizer_name): + name: Optional[str] = None, + method: Optional[str] = None, + max_iterations: Optional[int] = None) -> None: + self.name: Optional[str] = name + self.method: Optional[str] = method + self.max_iterations: Optional[int] = max_iterations + self.result: Optional[FitResults] = None + self._previous_chi2: Optional[float] = None + self._iteration: Optional[int] = None + self._best_chi2: Optional[float] = None + self._best_iteration: Optional[int] = None + self._fitting_time: Optional[float] = None + self.tracker: FittingProgressTracker = FittingProgressTracker() + + def _start_tracking(self, minimizer_name: str) -> None: self.tracker.reset() self.tracker.start_tracking(minimizer_name) self.tracker.start_timer() - def _stop_tracking(self): + def _stop_tracking(self) -> None: self.tracker.stop_timer() self.tracker.finish_tracking() @abstractmethod - def _prepare_solver_args(self, parameters): + def _prepare_solver_args(self, parameters: List[Any]) -> Dict[str, Any]: """ Prepare the solver arguments directly from the list of free parameters. """ @@ -144,64 +159,63 @@ def _prepare_solver_args(self, parameters): @abstractmethod def _run_solver(self, - objective_function, - engine_parameters): + objective_function: Callable[..., Any], + engine_parameters: Dict[str, Any]) -> Any: pass @abstractmethod def _sync_result_to_parameters(self, - raw_result, - parameters): + raw_result: Any, + parameters: List[Any]) -> None: pass def _finalize_fit(self, - parameters, - raw_result): - self._sync_result_to_parameters(parameters, - raw_result) + parameters: List[Any], + raw_result: Any) -> FitResults: + self._sync_result_to_parameters(raw_result, parameters) success = self._check_success(raw_result) self.result = FitResults( success=success, parameters=parameters, reduced_chi_square=self.tracker.best_chi2, - raw_result=raw_result, + engine_result=raw_result, starting_parameters=parameters, fitting_time=self.tracker.fitting_time ) return self.result @abstractmethod - def _check_success(self, raw_result): + def _check_success(self, raw_result: Any) -> bool: """ Determine whether the fit was successful. This must be implemented by concrete minimizers. """ pass - def fit(self, parameters, objective_function): - minimizer_name = self.name + def fit(self, + parameters: List[Any], + objective_function: Callable[..., Any]) -> FitResults: + minimizer_name = self.name or "Unnamed Minimizer" if self.method is not None: minimizer_name += f" ({self.method})" self._start_tracking(minimizer_name) solver_args = self._prepare_solver_args(parameters) - raw_result = self._run_solver(objective_function, - **solver_args) + raw_result = self._run_solver(objective_function, **solver_args) self._stop_tracking() - result = self._finalize_fit(parameters, - raw_result) + result = self._finalize_fit(parameters, raw_result) return result def _objective_function(self, - engine_params, - parameters, - sample_models, - experiments, - calculator): + engine_params: Dict[str, Any], + parameters: List[Any], + sample_models: Any, + experiments: Any, + calculator: Any) -> np.ndarray: return self._compute_residuals(engine_params, parameters, sample_models, @@ -209,10 +223,10 @@ def _objective_function(self, calculator) def _create_objective_function(self, - parameters, - sample_models, - experiments, - calculator): + parameters: List[Any], + sample_models: Any, + experiments: Any, + calculator: Any) -> Callable[[Dict[str, Any]], np.ndarray]: return lambda engine_params: self._objective_function( engine_params, parameters, diff --git a/src/easydiffraction/analysis/minimizers/minimizer_dfols.py b/src/easydiffraction/analysis/minimizers/minimizer_dfols.py index 74c06e1c..0fcd0fb4 100644 --- a/src/easydiffraction/analysis/minimizers/minimizer_dfols.py +++ b/src/easydiffraction/analysis/minimizers/minimizer_dfols.py @@ -1,18 +1,20 @@ import numpy as np from dfols import solve from .minimizer_base import MinimizerBase +from typing import Any, Dict, List DEFAULT_MAX_ITERATIONS = 1000 + class DfolsMinimizer(MinimizerBase): """ Minimizer using the DFO-LS package (Derivative-Free Optimization for Least-Squares). """ - def __init__(self, name='dfols', max_iterations=DEFAULT_MAX_ITERATIONS, **kwargs): + def __init__(self, name: str = 'dfols', max_iterations: int = DEFAULT_MAX_ITERATIONS, **kwargs: Any) -> None: super().__init__(name=name, method=None, max_iterations=max_iterations) - def _prepare_solver_args(self, parameters): + def _prepare_solver_args(self, parameters: List[Any]) -> Dict[str, Any]: x0 = [] bounds_lower = [] bounds_upper = [] @@ -23,7 +25,7 @@ def _prepare_solver_args(self, parameters): bounds = (np.array(bounds_lower), np.array(bounds_upper)) return {'x0': np.array(x0), 'bounds': bounds} - def _run_solver(self, objective_function, **kwargs): + def _run_solver(self, objective_function: Any, **kwargs: Any) -> Any: x0 = kwargs.get('x0') bounds = kwargs.get('bounds') return solve(objective_function, @@ -31,8 +33,14 @@ def _run_solver(self, objective_function, **kwargs): bounds=bounds, maxfun=self.max_iterations) + def _sync_result_to_parameters(self, parameters: List[Any], raw_result: Any) -> None: + """ + Synchronizes the result from the solver to the parameters. - def _sync_result_to_parameters(self, parameters, raw_result): + Args: + parameters: List of parameters being optimized. + raw_result: The result object returned by the solver. + """ # Ensure compatibility with raw_result coming from dfols.solve() if hasattr(raw_result, 'x'): result_values = raw_result.x @@ -44,9 +52,14 @@ def _sync_result_to_parameters(self, parameters, raw_result): # DFO-LS doesn't provide uncertainties; set to None or calculate later if needed param.uncertainty = None - def _check_success(self, raw_result): + def _check_success(self, raw_result: Any) -> bool: """ Determines success from DFO-LS result dictionary. - Typically, status == 0 means success. + + Args: + raw_result: The result object returned by the solver. + + Returns: + True if the optimization was successful, False otherwise. """ return raw_result.flag == raw_result.EXIT_SUCCESS \ No newline at end of file diff --git a/src/easydiffraction/analysis/minimizers/minimizer_factory.py b/src/easydiffraction/analysis/minimizers/minimizer_factory.py index 0ab16860..d398ea73 100644 --- a/src/easydiffraction/analysis/minimizers/minimizer_factory.py +++ b/src/easydiffraction/analysis/minimizers/minimizer_factory.py @@ -1,12 +1,15 @@ import tabulate +from typing import List, Type, Optional, Dict, Any from easydiffraction.utils.formatting import paragraph from .minimizer_lmfit import LmfitMinimizer from .minimizer_dfols import DfolsMinimizer +from .minimizer_base import MinimizerBase + class MinimizerFactory: - _available_minimizers = { + _available_minimizers: Dict[str, Dict[str, Any]] = { 'lmfit': { 'engine': 'lmfit', 'method': 'leastsq', @@ -34,16 +37,25 @@ class MinimizerFactory: } @classmethod - def list_available_minimizers(cls): + def list_available_minimizers(cls) -> List[str]: + """ + List all available minimizers. + + Returns: + A list of minimizer names. + """ return list(cls._available_minimizers.keys()) @classmethod - def show_available_minimizers(cls): - header = ["Minimizer", "Description"] - table_data = [] + def show_available_minimizers(cls) -> None: + """ + Display a table of available minimizers and their descriptions. + """ + header: List[str] = ["Minimizer", "Description"] + table_data: List[List[str]] = [] for name, config in cls._available_minimizers.items(): - description = config.get('description', 'No description provided.') + description: str = config.get('description', 'No description provided.') table_data.append([name, description]) print(paragraph("Available minimizers")) @@ -57,22 +69,43 @@ def show_available_minimizers(cls): )) @classmethod - def create_minimizer(cls, selection: str): + def create_minimizer(cls, selection: str) -> MinimizerBase: + """ + Create a minimizer instance based on the selection. + + Args: + selection: The name of the minimizer to create. + + Returns: + An instance of the selected minimizer. + + Raises: + ValueError: If the selection is not a valid minimizer. + """ config = cls._available_minimizers.get(selection) if not config: raise ValueError(f"Unknown minimizer '{selection}'. Use one of {cls.list_available_minimizers()}") - minimizer_class = config.get('class') - method = config.get('method') + minimizer_class: Type[MinimizerBase] = config.get('class') + method: Optional[str] = config.get('method') - kwargs = {} + kwargs: Dict[str, Any] = {} if method is not None: kwargs['method'] = method return minimizer_class(**kwargs) @classmethod - def register_minimizer(cls, name, minimizer_cls, method=None, description='No description provided.'): + def register_minimizer(cls, name: str, minimizer_cls: Type[MinimizerBase], method: Optional[str] = None, description: str = 'No description provided.') -> None: + """ + Register a new minimizer. + + Args: + name: The name of the minimizer. + minimizer_cls: The class of the minimizer. + method: The method used by the minimizer (optional). + description: A description of the minimizer. + """ cls._available_minimizers[name] = { 'engine': name, 'method': method, diff --git a/src/easydiffraction/analysis/minimizers/minimizer_lmfit.py b/src/easydiffraction/analysis/minimizers/minimizer_lmfit.py index 0724f034..9f87d33b 100644 --- a/src/easydiffraction/analysis/minimizers/minimizer_lmfit.py +++ b/src/easydiffraction/analysis/minimizers/minimizer_lmfit.py @@ -1,23 +1,34 @@ import lmfit from .minimizer_base import MinimizerBase +from typing import Any, Dict, List DEFAULT_METHOD = 'leastsq' DEFAULT_MAX_ITERATIONS = 1000 + class LmfitMinimizer(MinimizerBase): """ Minimizer using the lmfit package. """ def __init__(self, - name='lmfit', - method=DEFAULT_METHOD, - max_iterations=DEFAULT_MAX_ITERATIONS): + name: str = 'lmfit', + method: str = DEFAULT_METHOD, + max_iterations: int = DEFAULT_MAX_ITERATIONS) -> None: super().__init__(name=name, method=method, max_iterations=max_iterations) - def _prepare_solver_args(self, parameters): + def _prepare_solver_args(self, parameters: List[Any]) -> Dict[str, Any]: + """ + Prepares the solver arguments for the lmfit minimizer. + + Args: + parameters: List of parameters to be optimized. + + Returns: + A dictionary containing the prepared lmfit.Parameters object. + """ engine_parameters = lmfit.Parameters() for param in parameters: engine_parameters.add( @@ -30,20 +41,36 @@ def _prepare_solver_args(self, parameters): return {'engine_parameters': engine_parameters} def _run_solver(self, - objective_function, - **kwargs): + objective_function: Any, + **kwargs: Any) -> Any: + """ + Runs the lmfit solver. + + Args: + objective_function: The objective function to minimize. + **kwargs: Additional arguments for the solver. + + Returns: + The result of the lmfit minimization. + """ engine_parameters = kwargs.get('engine_parameters') return lmfit.minimize(objective_function, params=engine_parameters, method=self.method, - #iter_cb=self._iteration_callback, nan_policy='propagate', max_nfev=self.max_iterations) def _sync_result_to_parameters(self, - parameters, - raw_result): + parameters: List[Any], + raw_result: Any) -> None: + """ + Synchronizes the result from the solver to the parameters. + + Args: + parameters: List of parameters being optimized. + raw_result: The result object returned by the solver. + """ if hasattr(raw_result, 'params'): param_values = raw_result.params else: @@ -55,18 +82,32 @@ def _sync_result_to_parameters(self, param.value = param_result.value param.uncertainty = getattr(param_result, 'stderr', None) - def _check_success(self, raw_result): + def _check_success(self, raw_result: Any) -> bool: """ Determines success from lmfit MinimizerResult. + + Args: + raw_result: The result object returned by the solver. + + Returns: + True if the optimization was successful, False otherwise. """ return getattr(raw_result, 'success', False) - def _iteration_callback(self, params, iter, resid, *args, **kwargs): - # Temporary do not use this callback, as trying to track both the - # iteration number and chi-square using _track_chi_square method. - # Results are a bit different, so need to investigate further. - # _track_chi_square is used because DFO-LS minimizer seems to - # not provide the way to call _iteration_callback + def _iteration_callback(self, + params: lmfit.Parameters, + iter: int, + resid: Any, + *args: Any, + **kwargs: Any) -> None: + """ + Callback function for each iteration of the minimizer. + + Args: + params: The current parameters. + iter: The current iteration number. + resid: The residuals. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + """ self._iteration = iter - #red_chi2 = np.sum(resid**2) / (len(resid) - len(self.parameters)) - #print(f"🔄 Iteration {iter}: Reduced Chi-square = {red_chi2:.2f}") diff --git a/src/easydiffraction/analysis/reliability_factors.py b/src/easydiffraction/analysis/reliability_factors.py index 83065106..6b23f008 100644 --- a/src/easydiffraction/analysis/reliability_factors.py +++ b/src/easydiffraction/analysis/reliability_factors.py @@ -1,14 +1,37 @@ import numpy as np +from typing import Tuple, Any, Optional -def calculate_r_factor(y_obs, y_calc): +def calculate_r_factor(y_obs: np.ndarray, y_calc: np.ndarray) -> float: + """ + Calculate the R-factor (reliability factor) between observed and calculated data. + + Args: + y_obs: Observed data points. + y_calc: Calculated data points. + + Returns: + R-factor value. + """ y_obs = np.asarray(y_obs) y_calc = np.asarray(y_calc) numerator = np.sum(np.abs(y_obs - y_calc)) denominator = np.sum(np.abs(y_obs)) return numerator / denominator if denominator != 0 else np.nan -def calculate_weighted_r_factor(y_obs, y_calc, weights): + +def calculate_weighted_r_factor(y_obs: np.ndarray, y_calc: np.ndarray, weights: np.ndarray) -> float: + """ + Calculate the weighted R-factor between observed and calculated data. + + Args: + y_obs: Observed data points. + y_calc: Calculated data points. + weights: Weights for each data point. + + Returns: + Weighted R-factor value. + """ y_obs = np.asarray(y_obs) y_calc = np.asarray(y_calc) weights = np.asarray(weights) @@ -16,21 +39,54 @@ def calculate_weighted_r_factor(y_obs, y_calc, weights): denominator = np.sum(weights * y_obs ** 2) return np.sqrt(numerator / denominator) if denominator != 0 else np.nan -def calculate_rb_factor(y_obs, y_calc): + +def calculate_rb_factor(y_obs: np.ndarray, y_calc: np.ndarray) -> float: + """ + Calculate the Bragg R-factor between observed and calculated data. + + Args: + y_obs: Observed data points. + y_calc: Calculated data points. + + Returns: + Bragg R-factor value. + """ y_obs = np.asarray(y_obs) y_calc = np.asarray(y_calc) numerator = np.sum(np.abs(y_obs - y_calc)) denominator = np.sum(y_obs) return numerator / denominator if denominator != 0 else np.nan -def calculate_r_factor_squared(y_obs, y_calc): + +def calculate_r_factor_squared(y_obs: np.ndarray, y_calc: np.ndarray) -> float: + """ + Calculate the R-factor squared between observed and calculated data. + + Args: + y_obs: Observed data points. + y_calc: Calculated data points. + + Returns: + R-factor squared value. + """ y_obs = np.asarray(y_obs) y_calc = np.asarray(y_calc) numerator = np.sum((y_obs - y_calc) ** 2) denominator = np.sum(y_obs ** 2) return np.sqrt(numerator / denominator) if denominator != 0 else np.nan -def calculate_reduced_chi_square(residuals, num_parameters): + +def calculate_reduced_chi_square(residuals: np.ndarray, num_parameters: int) -> float: + """ + Calculate the reduced chi-square statistic. + + Args: + residuals: Residuals between observed and calculated data. + num_parameters: Number of free parameters used in the model. + + Returns: + Reduced chi-square value. + """ residuals = np.asarray(residuals) chi_square = np.sum(residuals ** 2) n_points = len(residuals) @@ -40,7 +96,19 @@ def calculate_reduced_chi_square(residuals, num_parameters): else: return np.nan -def get_reliability_inputs(sample_models, experiments, calculator): + +def get_reliability_inputs(sample_models: Any, experiments: Any, calculator: Any) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: + """ + Collect observed and calculated data points for reliability calculations. + + Args: + sample_models: Collection of sample models. + experiments: Collection of experiments. + calculator: The calculator to use for pattern generation. + + Returns: + Tuple containing arrays of (observed values, calculated values, error values) + """ y_obs_all = [] y_calc_all = [] y_err_all = [] diff --git a/src/easydiffraction/core/objects.py b/src/easydiffraction/core/objects.py index 5dbbc915..664ef612 100644 --- a/src/easydiffraction/core/objects.py +++ b/src/easydiffraction/core/objects.py @@ -4,12 +4,14 @@ ABC, abstractmethod ) +from typing import Any, Dict, List, Optional, Union, Iterator, TypeVar from easydiffraction.utils.formatting import ( warning, error ) +T = TypeVar('T') class Descriptor: """ @@ -17,34 +19,34 @@ class Descriptor: """ def __init__(self, - value, # Value of the parameter - name, # ED parameter name (to access it in the code) - cif_name, # CIF parameter name (to show it in the CIF) - pretty_name=None, # Pretty name (to show it in the table) - datablock_id=None, # Parent datablock name - category_key=None, # ED parent category name - cif_category_key=None, # CIF parent category name - collection_entry_id=None, # Parent collection entry id - units=None, # Units of the parameter - description=None, # Description of the parameter - editable=True # If false, the parameter can never be edited. It is calculated automatically - ): + value: Any, # Value of the parameter + name: str, # ED parameter name (to access it in the code) + cif_name: str, # CIF parameter name (to show it in the CIF) + pretty_name: Optional[str] = None, # Pretty name (to show it in the table) + datablock_id: Optional[str] = None, # Parent datablock name + category_key: Optional[str] = None, # ED parent category name + cif_category_key: Optional[str] = None, # CIF parent category name + collection_entry_id: Optional[str] = None, # Parent collection entry id + units: Optional[str] = None, # Units of the parameter + description: Optional[str] = None, # Description of the parameter + editable: bool = True # If false, the parameter can never be edited. It is calculated automatically + ) -> None: self._value = value - self.name = name - self.cif_name = cif_name - self.pretty_name = pretty_name, - self.datablock_id = datablock_id - self.category_key = category_key, - self.cif_category_key = cif_category_key - self.collection_entry_id = collection_entry_id - self.units = units - self._description = description - self._editable = editable - - self.uid = self._generate_unique_id() - - def _generate_unique_id(self): + self.name: str = name + self.cif_name: str = cif_name + self.pretty_name: Optional[str] = pretty_name + self.datablock_id: Optional[str] = datablock_id + self.category_key: Optional[str] = category_key + self.cif_category_key: Optional[str] = cif_category_key + self.collection_entry_id: Optional[str] = collection_entry_id + self.units: Optional[str] = units + self._description: Optional[str] = description + self._editable: bool = editable + + self.uid: str = self._generate_unique_id() + + def _generate_unique_id(self) -> str: # Derived class Parameter will use this unique id for the # minimization process to identify the parameter. # TODO: Instead of generating a random string, we can use the @@ -58,47 +60,48 @@ def _generate_unique_id(self): return uid @property - def value(self): + def value(self) -> Any: return self._value @value.setter - def value(self, new_value): + def value(self, new_value: Any) -> None: if self._editable: self._value = new_value else: print(warning(f"The parameter '{self.cif_name}' it is calculated automatically and cannot be changed manually.")) @property - def description(self): + def description(self) -> Optional[str]: return self._description @property - def editable(self): + def editable(self) -> bool: return self._editable + class Parameter(Descriptor): """ A parameter with a value, uncertainty, units, and CIF representation. """ def __init__(self, - value, - name, - cif_name, - pretty_name=None, - datablock_id=None, - category_key=None, - cif_category_key=None, - collection_entry_id=None, - units=None, - description=None, - editable=True, - uncertainty=0.0, - free=False, - constrained=False, - min_value=None, - max_value=None, - ): + value: Any, + name: str, + cif_name: str, + pretty_name: Optional[str] = None, + datablock_id: Optional[str] = None, + category_key: Optional[str] = None, + cif_category_key: Optional[str] = None, + collection_entry_id: Optional[str] = None, + units: Optional[str] = None, + description: Optional[str] = None, + editable: bool = True, + uncertainty: float = 0.0, + free: bool = False, + constrained: bool = False, + min_value: Optional[float] = None, + max_value: Optional[float] = None, + ) -> None: super().__init__(value, name, cif_name, @@ -110,11 +113,12 @@ def __init__(self, units, description, editable) - self.uncertainty = uncertainty # Standard uncertainty or estimated standard deviation - self.free = free # If the parameter is free to be fitted during the optimization - self.constrained = constrained # If symmetry constrains the parameter during the optimization - self.min = min_value # Minimum physical value of the parameter - self.max = max_value # Maximum physical value of the parameter + self.uncertainty: float = uncertainty # Standard uncertainty or estimated standard deviation + self.free: bool = free # If the parameter is free to be fitted during the optimization + self.constrained: bool = constrained # If symmetry constrains the parameter during the optimization + self.min: Optional[float] = min_value # Minimum physical value of the parameter + self.max: Optional[float] = max_value # Maximum physical value of the parameter + self.start_value: Optional[Any] = None # Starting value for optimization class Component(ABC): @@ -124,12 +128,12 @@ class Component(ABC): @property @abstractmethod - def _entry_id(self): + def _entry_id(self) -> str: pass @property @abstractmethod - def cif_category_key(self): + def cif_category_key(self) -> str: """ Must be implemented in subclasses to return the CIF category name. """ @@ -137,21 +141,21 @@ def cif_category_key(self): @property @abstractmethod - def category_key(self): + def category_key(self) -> str: """ Must be implemented in subclasses to return the ED category name. Can differ from cif_category_key. """ pass - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self._locked = False # If adding new attributes is locked + self._locked: bool = False # If adding new attributes is locked # TODO: Currently, it is not used. Planned to be used for displaying # the parameters in the specific order. - self._ordered_attrs = [] + self._ordered_attrs: List[str] = [] - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """ If the attribute is a Parameter or Descriptor, return its value by default """ @@ -160,7 +164,7 @@ def __getattr__(self, name): return attr.value raise AttributeError(f"{name} not found in {self}") - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: """ If an object is locked for adding new attributes, raise an error. If the attribute 'name' does not exist, add it. @@ -187,7 +191,7 @@ def __setattr__(self, name, value): if isinstance(attr, (Descriptor, Parameter)): attr.value = value - def parameters(self): + def parameters(self) -> List[Union[Descriptor, Parameter]]: attr_objs = [] for attr_name in dir(self): attr_obj = getattr(self, attr_name) @@ -195,7 +199,7 @@ def parameters(self): attr_objs.append(attr_obj) return attr_objs - def as_dict(self): + def as_dict(self) -> Dict[str, Any]: d = {} for attr_name in dir(self): @@ -212,7 +216,7 @@ def as_dict(self): return d - def as_cif(self): + def as_cif(self) -> str: if not self.cif_category_key: raise ValueError("cif_category_key must be defined in the derived class.") @@ -247,15 +251,15 @@ class Collection: """ def __init__(self): - self._items = {} + self._items: Dict[str, Union[Component, 'Collection']] = {} - def __getitem__(self, key): + def __getitem__(self, key: str) -> Union[Component, 'Collection']: return self._items[key] - def __iter__(self): + def __iter__(self) -> Iterator[Union[Component, 'Collection']]: return iter(self._items.values()) - def get_all_params(self): + def get_all_params(self) -> List[Parameter]: params = [] for datablock in self._items.values(): for component in datablock.components(): @@ -277,7 +281,7 @@ def get_all_params(self): return params - def get_fittable_params(self): + def get_fittable_params(self) -> List[Parameter]: all_params = self.get_all_params() params = [] for param in all_params: @@ -285,7 +289,7 @@ def get_fittable_params(self): params.append(param) return params - def get_free_params(self): + def get_free_params(self) -> List[Parameter]: fittable_params = self.get_fittable_params() params = [] for param in fittable_params: @@ -293,7 +297,7 @@ def get_free_params(self): params.append(param) return params - def as_cif(self): + def as_cif(self) -> str: lines = [] if self._type == "category": for idx, item in enumerate(self._items.values()): @@ -316,7 +320,7 @@ class Datablock(ABC): """ # TODO: Consider unifying with class Component? - def components(self): + def components(self) -> List[Union[Component, Collection]]: """ Returns a list of both standard and iterable components in the data block. diff --git a/src/easydiffraction/core/singletons.py b/src/easydiffraction/core/singletons.py index 122a2908..e8cbb7dc 100644 --- a/src/easydiffraction/core/singletons.py +++ b/src/easydiffraction/core/singletons.py @@ -1,5 +1,7 @@ +from typing import Dict, List, Tuple, Any, TypeVar, Type, Optional from asteval import Interpreter +T = TypeVar('T', bound='BaseSingleton') class BaseSingleton: """Base class to implement Singleton pattern. @@ -11,7 +13,7 @@ class BaseSingleton: _instance = None # Class-level shared instance @classmethod - def get(cls): + def get(cls: Type[T]) -> T: """Returns the shared instance, creating it if needed.""" if cls._instance is None: cls._instance = cls() @@ -21,15 +23,15 @@ def get(cls): class UidMapHandler(BaseSingleton): """Global handler to manage UID-to-Parameter object mapping.""" - def __init__(self): + def __init__(self) -> None: # Internal map: uid (str) → Parameter instance - self._uid_map = {} + self._uid_map: Dict[str, Any] = {} - def get_uid_map(self): + def get_uid_map(self) -> Dict[str, Any]: """Returns the current UID-to-Parameter map.""" return self._uid_map - def set_uid_map(self, parameters: list): + def set_uid_map(self, parameters: List[Any]) -> None: """Populates the UID map from a list of Parameter objects.""" self._uid_map = {param.uid: param for param in parameters} @@ -43,18 +45,18 @@ class ConstraintsHandler(BaseSingleton): Constraints are defined as: lhs_alias = expression(rhs_aliases). """ - def __init__(self): + def __init__(self) -> None: # Maps alias names (like 'biso_La') → ConstraintAlias(param=Parameter) - self._alias_to_param = {} + self._alias_to_param: Dict[str, Any] = {} # Stores raw user-defined expressions indexed by ID # Each value should contain: lhs_alias, rhs_expr - self._expressions = {} + self._expressions: Dict[str, Any] = {} # Internally parsed constraints as (lhs_alias, rhs_expr) tuples - self._parsed_constraints = [] + self._parsed_constraints: List[Tuple[str, str]] = [] - def set_aliases(self, constraint_aliases): + def set_aliases(self, constraint_aliases: Any) -> None: """ Sets the alias map (name → parameter wrapper). Called when user registers parameter aliases like: @@ -62,7 +64,7 @@ def set_aliases(self, constraint_aliases): """ self._alias_to_param = constraint_aliases._items - def set_expressions(self, constraint_expressions): + def set_expressions(self, constraint_expressions: Any) -> None: """ Sets the constraint expressions and triggers parsing into internal format. Called when user registers expressions like: @@ -71,7 +73,7 @@ def set_expressions(self, constraint_expressions): self._expressions = constraint_expressions._items self._parse_constraints() - def _parse_constraints(self): + def _parse_constraints(self) -> None: """ Converts raw expression input into a normalized internal list of (lhs_alias, rhs_expr) pairs, stripping whitespace and skipping invalid entries. @@ -86,7 +88,7 @@ def _parse_constraints(self): constraint = (lhs_alias.strip(), rhs_expr.strip()) self._parsed_constraints.append(constraint) - def apply(self, parameters: list): + def apply(self, parameters: List[Any]) -> None: """Evaluates constraints and applies them to dependent parameters. For each constraint: @@ -101,7 +103,7 @@ def apply(self, parameters: list): uid_map = UidMapHandler.get().get_uid_map() # Prepare a flat dict of {alias_name: value} for use in expressions - param_values = { + param_values: Dict[str, Any] = { alias: alias_obj.param.value for alias, alias_obj in self._alias_to_param.items() } diff --git a/src/easydiffraction/crystallography/crystallography.py b/src/easydiffraction/crystallography/crystallography.py index 45b0f41f..20e7b79d 100644 --- a/src/easydiffraction/crystallography/crystallography.py +++ b/src/easydiffraction/crystallography/crystallography.py @@ -1,7 +1,10 @@ +from typing import Dict, List, Optional, Any from sympy import ( symbols, sympify, - simplify + simplify, + Symbol, + Expr ) from cryspy.A_functions_base.function_2_space_group import ( @@ -11,8 +14,18 @@ ) -def apply_cell_symmetry_constraints(cell: dict, - name_hm: str) -> dict: +def apply_cell_symmetry_constraints(cell: Dict[str, float], + name_hm: str) -> Dict[str, float]: + """ + Apply symmetry constraints to unit cell parameters based on space group. + + Args: + cell: Dictionary containing lattice parameters. + name_hm: Hermann-Mauguin symbol of the space group. + + Returns: + The cell dictionary with applied symmetry constraints. + """ it_number = get_it_number_by_name_hm_short(name_hm) if it_number is None: error_msg = f"Failed to get IT_number for name_H-M '{name_hm}'" @@ -66,11 +79,22 @@ def apply_cell_symmetry_constraints(cell: dict, return cell -def apply_atom_site_symmetry_constraints(atom_site: dict, +def apply_atom_site_symmetry_constraints(atom_site: Dict[str, Any], name_hm: str, - coord_code, - wyckoff_letter: str) -> dict: - + coord_code: int, + wyckoff_letter: str) -> Dict[str, Any]: + """ + Apply symmetry constraints to atomic coordinates based on site symmetry. + + Args: + atom_site: Dictionary containing atom position data. + name_hm: Hermann-Mauguin symbol of the space group. + coord_code: Coordinate system code. + wyckoff_letter: Wyckoff position letter. + + Returns: + The atom_site dictionary with applied symmetry constraints. + """ it_number = get_it_number_by_name_hm_short(name_hm) if it_number is None: error_msg = f"Failed to get IT_number for name_H-M '{name_hm}'" @@ -91,29 +115,29 @@ def apply_atom_site_symmetry_constraints(atom_site: dict, #return atom_site # 2 - NOT OK - letter_list = result[3] - coords_xyz_list = result[5] + letter_list: List[str] = result[3] + coords_xyz_list: List[List[str]] = result[5] idx = letter_list.index(wyckoff_letter) coords_xyz = coords_xyz_list[idx] first_position = coords_xyz[0] components = first_position.strip("()").split(",") - parsed_exprs = [sympify(comp.strip()) for comp in components] + parsed_exprs: List[Expr] = [sympify(comp.strip()) for comp in components] - x_val = sympify(atom_site["fract_x"]) - y_val = sympify(atom_site["fract_y"]) - z_val = sympify(atom_site["fract_z"]) + x_val: Expr = sympify(atom_site["fract_x"]) + y_val: Expr = sympify(atom_site["fract_y"]) + z_val: Expr = sympify(atom_site["fract_z"]) - substitutions = { + substitutions: Dict[str, Expr] = { "x": x_val, "y": y_val, "z": z_val } - axes = ("x", "y", "z") + axes: tuple[str, ...] = ("x", "y", "z") x, y, z = symbols("x y z") - symbols_xyz = (x, y, z) + symbols_xyz: tuple[Symbol, ...] = (x, y, z) for i, axis in enumerate(axes): symbol = symbols_xyz[i] diff --git a/src/easydiffraction/experiments/collections/background.py b/src/easydiffraction/experiments/collections/background.py index c6ea0ce8..2a366a09 100644 --- a/src/easydiffraction/experiments/collections/background.py +++ b/src/easydiffraction/experiments/collections/background.py @@ -2,6 +2,7 @@ import tabulate from abc import ABC, abstractmethod +from typing import Dict, List, Any, Type, Union from numpy.polynomial.chebyshev import chebval from scipy.interpolate import interp1d @@ -19,7 +20,7 @@ class Point(Component): - def __init__(self, x: float, y: float): + def __init__(self, x: float, y: float) -> None: super().__init__() self.x = Descriptor( @@ -38,20 +39,20 @@ def __init__(self, x: float, y: float): self._locked = True # Lock further attribute additions @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "pd_background" @property - def category_key(self): + def category_key(self) -> str: return "background" @property - def _entry_id(self): + def _entry_id(self) -> str: return f"{self.x.value}" class PolynomialTerm(Component): - def __init__(self, order, coef): + def __init__(self, order: int, coef: float) -> None: super().__init__() self.order = Descriptor( @@ -70,48 +71,48 @@ def __init__(self, order, coef): self._locked = True # Lock further attribute additions @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "pd_background" @property - def category_key(self): + def category_key(self) -> str: return "background" @property - def _entry_id(self): + def _entry_id(self) -> str: return f"{self.order.value}" class BackgroundBase(Collection): @property - def _type(self): + def _type(self) -> str: return "category" # datablock or category @abstractmethod - def add(self, *args): + def add(self, *args: Any) -> None: pass @abstractmethod - def calculate(self, x_data): + def calculate(self, x_data: np.ndarray) -> np.ndarray: pass @abstractmethod - def show(self): + def show(self) -> None: pass class LineSegmentBackground(BackgroundBase): - _description = 'Linear interpolation between points' + _description: str = 'Linear interpolation between points' - def __init__(self): + def __init__(self) -> None: super().__init__() - def add(self, x, y): + def add(self, x: float, y: float) -> None: """Add a background point.""" point = Point(x=x, y=y) self._items[point._entry_id] = point - def calculate(self, x_data): + def calculate(self, x_data: np.ndarray) -> np.ndarray: """Interpolate background points over x_data.""" if not self._items: print(warning('No background points found. Setting background to zero.')) @@ -128,9 +129,9 @@ def calculate(self, x_data): y_data = interp_func(x_data) return y_data - def show(self): - header = ["X", "Intensity"] - table_data = [] + def show(self) -> None: + header: List[str] = ["X", "Intensity"] + table_data: List[List[float]] = [] for point in self._items.values(): x = point.x.value @@ -149,17 +150,17 @@ def show(self): class ChebyshevPolynomialBackground(BackgroundBase): - _description = 'Chebyshev polynomial background' + _description: str = 'Chebyshev polynomial background' - def __init__(self): + def __init__(self) -> None: super().__init__() - def add(self, order, coef): + def add(self, order: int, coef: float) -> None: """Add a polynomial term as (order, coefficient).""" term = PolynomialTerm(order=order, coef=coef) self._items[term._entry_id] = term - def calculate(self, x_data): + def calculate(self, x_data: np.ndarray) -> np.ndarray: """Evaluate polynomial background over x_data.""" if not self._items: print(warning('No background points found. Setting background to zero.')) @@ -170,9 +171,9 @@ def calculate(self, x_data): y_data = chebval(u, coefs) return y_data - def show(self): - header = ["Order", "Coefficient"] - table_data = [] + def show(self) -> None: + header: List[str] = ["Order", "Coefficient"] + table_data: List[List[Union[int, float]]] = [] for term in self._items.values(): order = term.order.value @@ -191,14 +192,13 @@ def show(self): class BackgroundFactory: - _supported = { + _supported: Dict[str, Type[BackgroundBase]] = { "line-segment": LineSegmentBackground, "chebyshev polynomial": ChebyshevPolynomialBackground } @classmethod - def create(cls, - background_type=DEFAULT_BACKGROUND_TYPE): + def create(cls, background_type: str = DEFAULT_BACKGROUND_TYPE) -> BackgroundBase: if background_type not in cls._supported: supported_types = list(cls._supported.keys()) diff --git a/src/easydiffraction/experiments/collections/datastore.py b/src/easydiffraction/experiments/collections/datastore.py index 69f663ce..1e49a354 100644 --- a/src/easydiffraction/experiments/collections/datastore.py +++ b/src/easydiffraction/experiments/collections/datastore.py @@ -1,3 +1,4 @@ +from typing import Optional, Any import numpy as np @@ -7,23 +8,23 @@ class Pattern: Stores x, measured intensities, uncertainties, background, and calculated intensities. """ - def __init__(self, experiment): - self.experiment = experiment + def __init__(self, experiment: Any) -> None: + self.experiment: Any = experiment # Data arrays - self.x = None - self.meas = None - self.meas_su = None - self.bkg = None - self._calc = None # Cached calculated intensities + self.x: Optional[np.ndarray] = None + self.meas: Optional[np.ndarray] = None + self.meas_su: Optional[np.ndarray] = None + self.bkg: Optional[np.ndarray] = None + self._calc: Optional[np.ndarray] = None # Cached calculated intensities @property - def calc(self): + def calc(self) -> Optional[np.ndarray]: """Access calculated intensities. Should be updated via external calculation.""" return self._calc @calc.setter - def calc(self, values): + def calc(self, values: np.ndarray) -> None: """Set calculated intensities (from Analysis.calculate_pattern()).""" self._calc = values @@ -32,7 +33,7 @@ class PowderPattern(Pattern): """ Specialized pattern for powder diffraction (can be extended in the future). """ - def __init__(self, experiment): + def __init__(self, experiment: Any) -> None: super().__init__(experiment) # Additional powder-specific initialization if needed @@ -42,22 +43,22 @@ class Datastore: Stores pattern data (measured and calculated) for an experiment. """ - def __init__(self, sample_form: str, experiment): - self.sample_form = sample_form + def __init__(self, sample_form: str, experiment: Any) -> None: + self.sample_form: str = sample_form if sample_form == "powder": - self.pattern = PowderPattern(experiment) + self.pattern: Pattern = PowderPattern(experiment) elif sample_form == "single_crystal": - self.pattern = Pattern(experiment) + self.pattern: Pattern = Pattern(experiment) else: raise ValueError(f"Unknown sample form '{sample_form}'") - def load_measured_data(self, file_path): + def load_measured_data(self, file_path: str) -> None: """Load measured data from an ASCII file.""" print(f"Loading measured data for {self.sample_form} diffraction from {file_path}") try: - data = np.loadtxt(file_path) + data: np.ndarray = np.loadtxt(file_path) except Exception as e: print(f"Failed to load data: {e}") return @@ -65,9 +66,9 @@ def load_measured_data(self, file_path): if data.shape[1] < 2: raise ValueError("Data file must have at least two columns (x and y).") - x = data[:, 0] - y = data[:, 1] - sy = data[:, 2] if data.shape[1] > 2 else np.sqrt(np.abs(y)) + x: np.ndarray = data[:, 0] + y: np.ndarray = data[:, 1] + sy: np.ndarray = data[:, 2] if data.shape[1] > 2 else np.sqrt(np.abs(y)) self.pattern.x = x self.pattern.meas = y @@ -75,14 +76,14 @@ def load_measured_data(self, file_path): print(f"Loaded {len(x)} points for experiment '{self.pattern.experiment.name}'.") - def show_measured_data(self): + def show_measured_data(self) -> None: """Display measured data in console.""" print(f"\nMeasured data ({self.sample_form}):") print(f"x: {self.pattern.x}") print(f"meas: {self.pattern.meas}") print(f"meas_su: {self.pattern.meas_su}") - def show_calculated_data(self): + def show_calculated_data(self) -> None: """Display calculated data in console.""" print(f"\nCalculated data ({self.sample_form}):") print(f"calc: {self.pattern.calc}") @@ -94,8 +95,15 @@ class DatastoreFactory: """ @staticmethod - def create(sample_form: str, experiment): + def create(sample_form: str, experiment: Any) -> Datastore: """ Create a datastore object depending on the sample form. + + Args: + sample_form: The form of the sample ("powder" or "single_crystal"). + experiment: The experiment object. + + Returns: + A new Datastore instance appropriate for the sample form. """ return Datastore(sample_form, experiment) \ No newline at end of file diff --git a/src/easydiffraction/experiments/collections/linked_phases.py b/src/easydiffraction/experiments/collections/linked_phases.py index 8d7370f1..5e83a0c5 100644 --- a/src/easydiffraction/experiments/collections/linked_phases.py +++ b/src/easydiffraction/experiments/collections/linked_phases.py @@ -7,7 +7,7 @@ class LinkedPhase(Component): - def __init__(self, id: str, scale: float): + def __init__(self, id: str, scale: float) -> None: super().__init__() self.id = Descriptor( @@ -24,15 +24,15 @@ def __init__(self, id: str, scale: float): self._locked = True # Lock further attribute additions @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "pd_phase_block" @property - def category_key(self): + def category_key(self) -> str: return "linked_phase" @property - def _entry_id(self): + def _entry_id(self) -> str: return self.id.value @@ -41,9 +41,9 @@ class LinkedPhases(Collection): Collection of LinkedPhase instances. """ @property - def _type(self): + def _type(self) -> str: return "category" # datablock or category - def add(self, id: str, scale: float): + def add(self, id: str, scale: float) -> None: phase = LinkedPhase(id, scale) self._items[phase.id.value] = phase diff --git a/src/easydiffraction/experiments/components/experiment_type.py b/src/easydiffraction/experiments/components/experiment_type.py index 90674fd8..15d8f17c 100644 --- a/src/easydiffraction/experiments/components/experiment_type.py +++ b/src/easydiffraction/experiments/components/experiment_type.py @@ -2,13 +2,14 @@ Descriptor, Component ) +from typing import Optional class ExperimentType(Component): def __init__(self, sample_form: str, beam_mode: str, - radiation_probe: str): + radiation_probe: str) -> None: super().__init__() self.sample_form: Descriptor = Descriptor( @@ -30,16 +31,16 @@ def __init__(self, description="Specifies whether the measurement uses neutrons or X-rays" ) - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "expt_type" @property - def category_key(self): + def category_key(self) -> str: return "expt_type" @property - def _entry_id(self): + def _entry_id(self) -> Optional[str]: return None \ No newline at end of file diff --git a/src/easydiffraction/experiments/components/instrument.py b/src/easydiffraction/experiments/components/instrument.py index 42bb770e..afd938dc 100644 --- a/src/easydiffraction/experiments/components/instrument.py +++ b/src/easydiffraction/experiments/components/instrument.py @@ -3,36 +3,37 @@ Component ) from easydiffraction.core.constants import DEFAULT_BEAM_MODE +from typing import Optional, Type, Dict class InstrumentBase(Component): @property - def category_key(self): + def category_key(self) -> str: return "instrument" @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "instr" @property - def _entry_id(self): + def _entry_id(self) -> Optional[str]: return None class ConstantWavelengthInstrument(InstrumentBase): def __init__(self, - setup_wavelength=1.5406, - calib_twotheta_offset=0): + setup_wavelength: float = 1.5406, + calib_twotheta_offset: float = 0.0) -> None: super().__init__() - self.setup_wavelength = Parameter( + self.setup_wavelength: Parameter = Parameter( value=setup_wavelength, name="wavelength", cif_name="wavelength", units="Å", description="Incident neutron or X-ray wavelength" ) - self.calib_twotheta_offset = Parameter( + self.calib_twotheta_offset: Parameter = Parameter( value=calib_twotheta_offset, name="twotheta_offset", cif_name="2theta_offset", @@ -40,47 +41,47 @@ def __init__(self, description="Instrument misalignment offset" ) - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions class TimeOfFlightInstrument(InstrumentBase): def __init__(self, - setup_twotheta_bank=150.0, - calib_d_to_tof_offset=0.0, - calib_d_to_tof_linear=10000.0, - calib_d_to_tof_quad=-1.0, - calib_d_to_tof_recip=0.0): + setup_twotheta_bank: float = 150.0, + calib_d_to_tof_offset: float = 0.0, + calib_d_to_tof_linear: float = 10000.0, + calib_d_to_tof_quad: float = -1.0, + calib_d_to_tof_recip: float = 0.0) -> None: super().__init__() - self.setup_twotheta_bank = Parameter( + self.setup_twotheta_bank: Parameter = Parameter( value=setup_twotheta_bank, name="twotheta_bank", cif_name="2theta_bank", units="deg", description="Detector bank position" ) - self.calib_d_to_tof_offset = Parameter( + self.calib_d_to_tof_offset: Parameter = Parameter( value=calib_d_to_tof_offset, name="d_to_tof_offset", cif_name="d_to_tof_offset", units="µs", description="TOF offset" ) - self.calib_d_to_tof_linear = Parameter( + self.calib_d_to_tof_linear: Parameter = Parameter( value=calib_d_to_tof_linear, name="d_to_tof_linear", cif_name="d_to_tof_linear", units="µs/Å", description="TOF linear conversion" ) - self.calib_d_to_tof_quad = Parameter( + self.calib_d_to_tof_quad: Parameter = Parameter( value=calib_d_to_tof_quad, name="d_to_tof_quad", cif_name="d_to_tof_quad", units="µs/Ų", description="TOF quadratic correction" ) - self.calib_d_to_tof_recip = Parameter( + self.calib_d_to_tof_recip: Parameter = Parameter( value=calib_d_to_tof_recip, name="d_to_tof_recip", cif_name="d_to_tof_recip", @@ -88,17 +89,17 @@ def __init__(self, description="TOF reciprocal velocity correction" ) - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions class InstrumentFactory: - _supported = { + _supported: Dict[str, Type[InstrumentBase]] = { "constant wavelength": ConstantWavelengthInstrument, "time-of-flight": TimeOfFlightInstrument } @classmethod - def create(cls, beam_mode=DEFAULT_BEAM_MODE): + def create(cls, beam_mode: str = DEFAULT_BEAM_MODE) -> InstrumentBase: if beam_mode not in cls._supported: supported = list(cls._supported.keys()) @@ -107,6 +108,6 @@ def create(cls, beam_mode=DEFAULT_BEAM_MODE): f"Supported beam modes are: {supported}" ) - instrument_class = cls._supported[beam_mode] - instance = instrument_class() + instrument_class: Type[InstrumentBase] = cls._supported[beam_mode] + instance: InstrumentBase = instrument_class() return instance \ No newline at end of file diff --git a/src/easydiffraction/experiments/components/peak.py b/src/easydiffraction/experiments/components/peak.py index 4a0ddbd3..884eb024 100644 --- a/src/easydiffraction/experiments/components/peak.py +++ b/src/easydiffraction/experiments/components/peak.py @@ -6,10 +6,12 @@ DEFAULT_BEAM_MODE, DEFAULT_PEAK_PROFILE_TYPE ) +from typing import Dict, Type, Optional + # --- Mixins --- class ConstantWavelengthBroadeningMixin: - def _add_constant_wavelength_broadening(self): + def _add_constant_wavelength_broadening(self) -> None: self.broad_gauss_u: Parameter = Parameter( value=0.01, name="broad_gauss_u", @@ -48,7 +50,7 @@ def _add_constant_wavelength_broadening(self): class TimeOfFlightBroadeningMixin: - def _add_time_of_flight_broadening(self): + def _add_time_of_flight_broadening(self) -> None: self.broad_gauss_sigma_0: Parameter = Parameter( value=0.0, name="gauss_sigma_0", @@ -108,7 +110,7 @@ def _add_time_of_flight_broadening(self): class EmpiricalAsymmetryMixin: - def _add_empirical_asymmetry(self): + def _add_empirical_asymmetry(self) -> None: self.asym_empir_1: Parameter = Parameter( value=0.1, name="asym_empir_1", @@ -140,7 +142,7 @@ def _add_empirical_asymmetry(self): class FcjAsymmetryMixin: - def _add_fcj_asymmetry(self): + def _add_fcj_asymmetry(self) -> None: self.asym_fcj_1: Parameter = Parameter( value=0.01, name="asym_fcj_1", @@ -158,7 +160,7 @@ def _add_fcj_asymmetry(self): class IkedaCarpenterAsymmetryMixin: - def _add_ikeda_carpenter_asymmetry(self): + def _add_ikeda_carpenter_asymmetry(self) -> None: self.asym_alpha_0: Parameter = Parameter( value=0.01, name="asym_alpha_0", @@ -178,95 +180,90 @@ def _add_ikeda_carpenter_asymmetry(self): # --- Base peak class --- class PeakBase(Component): @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "peak" @property - def category_key(self): + def category_key(self) -> str: return "peak" @property - def _entry_id(self): + def _entry_id(self) -> Optional[str]: return None # --- Derived peak classes --- -class ConstantWavelengthPseudoVoigt(PeakBase, - ConstantWavelengthBroadeningMixin): - _description = "Pseudo-Voigt profile" - def __init__(self): +class ConstantWavelengthPseudoVoigt(PeakBase, ConstantWavelengthBroadeningMixin): + _description: str = "Pseudo-Voigt profile" + + def __init__(self) -> None: super().__init__() self._add_constant_wavelength_broadening() - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions -class ConstantWavelengthSplitPseudoVoigt(PeakBase, - ConstantWavelengthBroadeningMixin, - EmpiricalAsymmetryMixin): - _description = "Split pseudo-Voigt profile" - def __init__(self): +class ConstantWavelengthSplitPseudoVoigt(PeakBase, ConstantWavelengthBroadeningMixin, EmpiricalAsymmetryMixin): + _description: str = "Split pseudo-Voigt profile" + + def __init__(self) -> None: super().__init__() self._add_constant_wavelength_broadening() self._add_empirical_asymmetry() - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions + +class ConstantWavelengthThompsonCoxHastings(PeakBase, ConstantWavelengthBroadeningMixin, FcjAsymmetryMixin): + _description: str = "Thompson-Cox-Hastings profile" -class ConstantWavelengthThompsonCoxHastings(PeakBase, - ConstantWavelengthBroadeningMixin, - FcjAsymmetryMixin): - _description = "Thompson-Cox-Hastings profile" - def __init__(self): + def __init__(self) -> None: super().__init__() self._add_constant_wavelength_broadening() self._add_fcj_asymmetry() - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions -class TimeOfFlightPseudoVoigt(PeakBase, - TimeOfFlightBroadeningMixin): - _description = "Pseudo-Voigt profile" - def __init__(self): +class TimeOfFlightPseudoVoigt(PeakBase, TimeOfFlightBroadeningMixin): + _description: str = "Pseudo-Voigt profile" + + def __init__(self) -> None: super().__init__() self._add_time_of_flight_broadening() - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions + +class TimeOfFlightIkedaCarpenter(PeakBase, TimeOfFlightBroadeningMixin, IkedaCarpenterAsymmetryMixin): + _description: str = "Ikeda-Carpenter profile" -class TimeOfFlightIkedaCarpenter(PeakBase, - TimeOfFlightBroadeningMixin, - IkedaCarpenterAsymmetryMixin): - _description = "Ikeda-Carpenter profile" - def __init__(self): + def __init__(self) -> None: super().__init__() self._add_time_of_flight_broadening() self._add_ikeda_carpenter_asymmetry() - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions -class TimeOfFlightPseudoVoigtIkedaCarpenter(PeakBase, - TimeOfFlightBroadeningMixin, - IkedaCarpenterAsymmetryMixin): - _description = "Pseudo-Voigt * Ikeda-Carpenter profile" - def __init__(self): +class TimeOfFlightPseudoVoigtIkedaCarpenter(PeakBase, TimeOfFlightBroadeningMixin, IkedaCarpenterAsymmetryMixin): + _description: str = "Pseudo-Voigt * Ikeda-Carpenter profile" + + def __init__(self) -> None: super().__init__() self._add_time_of_flight_broadening() self._add_ikeda_carpenter_asymmetry() - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions + +class TimeOfFlightPseudoVoigtBackToBackExponential(PeakBase, TimeOfFlightBroadeningMixin, IkedaCarpenterAsymmetryMixin): + _description: str = "Pseudo-Voigt * Back-to-Back Exponential profile" -class TimeOfFlightPseudoVoigtBackToBackExponential(PeakBase, - TimeOfFlightBroadeningMixin, - IkedaCarpenterAsymmetryMixin): - _description = "Pseudo-Voigt * Back-to-Back Exponential profile" - def __init__(self): + def __init__(self) -> None: super().__init__() self._add_time_of_flight_broadening() self._add_ikeda_carpenter_asymmetry() - self._locked = True # Lock further attribute additions + self._locked: bool = True # Lock further attribute additions # --- Peak factory --- class PeakFactory: - _supported = { + _supported: Dict[str, Dict[str, Type[PeakBase]]] = { "constant wavelength": { "pseudo-voigt": ConstantWavelengthPseudoVoigt, "split pseudo-voigt": ConstantWavelengthSplitPseudoVoigt, @@ -282,8 +279,8 @@ class PeakFactory: @classmethod def create(cls, - beam_mode=DEFAULT_BEAM_MODE, - profile_type=DEFAULT_PEAK_PROFILE_TYPE): + beam_mode: str = DEFAULT_BEAM_MODE, + profile_type: Optional[str] = DEFAULT_PEAK_PROFILE_TYPE) -> PeakBase: if beam_mode not in cls._supported: supported_beam_modes = list(cls._supported.keys()) @@ -300,5 +297,5 @@ def create(cls, f"Supported profiles are: {list(supported_types.keys())}" ) - peak_class = cls._supported[beam_mode][profile_type] + peak_class: Type[PeakBase] = cls._supported[beam_mode][profile_type] return peak_class() \ No newline at end of file diff --git a/src/easydiffraction/experiments/experiment.py b/src/easydiffraction/experiments/experiment.py index d6ea2299..c2f58fb0 100644 --- a/src/easydiffraction/experiments/experiment.py +++ b/src/easydiffraction/experiments/experiment.py @@ -2,6 +2,7 @@ import tabulate from abc import ABC, abstractmethod +from typing import Optional, List, Dict, Type from easydiffraction.experiments.components.experiment_type import ExperimentType from easydiffraction.experiments.components.instrument import InstrumentFactory @@ -30,18 +31,18 @@ class BaseExperiment(Datablock): def __init__(self, name: str, - type: ExperimentType): - self.name = name - self.type = type + type: ExperimentType) -> None: + self.name: str = name + self.type: ExperimentType = type self.instrument = InstrumentFactory.create(beam_mode=self.type.beam_mode.value) self.datastore = DatastoreFactory.create(sample_form=self.type.sample_form.value, experiment=self) - def as_cif(self, max_points=None): + def as_cif(self, max_points: Optional[int] = None) -> str: """ Generate CIF content by collecting values from all components. """ - lines = [f"data_{self.name}"] + lines: List[str] = [f"data_{self.name}"] # Experiment type if hasattr(self, "type"): @@ -74,13 +75,10 @@ def as_cif(self, max_points=None): lines.append(self.background.as_cif()) # Measured data - # TODO: This functionality should be moved to datastore.py - # TODO: We need meas_data component which will use datastore to extract data - # TODO: Datastore should be moved out of collections/ if hasattr(self, "datastore") and hasattr(self.datastore, "pattern"): lines.append("") lines.append("loop_") - category = '_pd_meas' # TODO: Add category to pattern component + category = '_pd_meas' attributes = ('2theta_scan', 'intensity_total', 'intensity_total_su') for attribute in attributes: lines.append(f"{category}.{attribute}") @@ -103,13 +101,13 @@ def as_cif(self, max_points=None): return "\n".join(lines) - def show_as_cif(self): - cif_text = self.as_cif(max_points=5) - lines = cif_text.splitlines() - max_width = max(len(line) for line in lines) - padded_lines = [f"│ {line.ljust(max_width)} │" for line in lines] - top = f"╒{'═' * (max_width + 2)}╕" - bottom = f"╘{'═' * (max_width + 2)}╛" + def show_as_cif(self) -> None: + cif_text: str = self.as_cif(max_points=5) + lines: List[str] = cif_text.splitlines() + max_width: int = max(len(line) for line in lines) + padded_lines: List[str] = [f"│ {line.ljust(max_width)} │" for line in lines] + top: str = f"╒{'═' * (max_width + 2)}╕" + bottom: str = f"╘{'═' * (max_width + 2)}╛" print(paragraph(f"Experiment 🔬 '{self.name}' as cif")) print(top) @@ -117,31 +115,32 @@ def show_as_cif(self): print(bottom) @abstractmethod - def _load_ascii_data_to_experiment(self, data_path): + def _load_ascii_data_to_experiment(self, data_path: str) -> None: pass @abstractmethod - def show_meas_chart(self, x_min=None, x_max=None): + def show_meas_chart(self, x_min: Optional[float] = None, x_max: Optional[float] = None) -> None: """ Abstract method to display data chart. Should be implemented in specific experiment mixins. """ raise NotImplementedError("show_meas_chart() must be implemented in the subclass") + class PowderExperiment(BaseExperiment): """Powder experiment class with specific attributes.""" def __init__(self, name: str, - type: ExperimentType): + type: ExperimentType) -> None: super().__init__(name=name, type=type) - self._peak_profile_type = DEFAULT_PEAK_PROFILE_TYPE - self._background_type = DEFAULT_BACKGROUND_TYPE + self._peak_profile_type: str = DEFAULT_PEAK_PROFILE_TYPE + self._background_type: str = DEFAULT_BACKGROUND_TYPE self.peak = PeakFactory.create(beam_mode=self.type.beam_mode.value) self.linked_phases = LinkedPhases() self.background = BackgroundFactory.create() - def _load_ascii_data_to_experiment(self, data_path): + def _load_ascii_data_to_experiment(self, data_path: str) -> None: """ Loads x, y, sy values from an ASCII data file into the experiment. @@ -160,9 +159,9 @@ def _load_ascii_data_to_experiment(self, data_path): print("Warning: No uncertainty (sy) column provided. Defaulting to sqrt(y).") # Extract x, y, and sy data - x = data[:, 0] - y = data[:, 1] - sy = data[:, 2] if data.shape[1] > 2 else np.sqrt(y) + x: np.ndarray = data[:, 0] + y: np.ndarray = data[:, 1] + sy: np.ndarray = data[:, 2] if data.shape[1] > 2 else np.sqrt(y) # Attach the data to the experiment's datastore self.datastore.pattern.x = x @@ -172,7 +171,7 @@ def _load_ascii_data_to_experiment(self, data_path): print(paragraph("Data loaded successfully")) print(f"Experiment 🔬 '{self.name}'. Number of data points: {len(x)}") - def show_meas_chart(self, x_min=None, x_max=None): + def show_meas_chart(self, x_min: Optional[float] = None, x_max: Optional[float] = None) -> None: pattern = self.datastore.pattern if pattern.meas is None or pattern.x is None: @@ -190,11 +189,11 @@ def show_meas_chart(self, x_min=None, x_max=None): ) @property - def peak_profile_type(self): + def peak_profile_type(self) -> str: return self._peak_profile_type @peak_profile_type.setter - def peak_profile_type(self, new_type: str): + def peak_profile_type(self, new_type: str) -> None: if new_type not in PeakFactory._supported[self.type.beam_mode.value]: supported_types = list(PeakFactory._supported[self.type.beam_mode.value].keys()) print(warning(f"Unsupported peak profile '{new_type}'")) @@ -207,12 +206,12 @@ def peak_profile_type(self, new_type: str): print(paragraph(f"Peak profile type for experiment '{self.name}' changed to")) print(new_type) - def show_supported_peak_profile_types(self): - header = ["Peak profile type", "Description"] - table_data = [] + def show_supported_peak_profile_types(self) -> None: + header: List[str] = ["Peak profile type", "Description"] + table_data: List[List[str]] = [] for name, config in PeakFactory._supported[self.type.beam_mode.value].items(): - description = getattr(config, '_description', 'No description provided.') + description: str = getattr(config, '_description', 'No description provided.') table_data.append([name, description]) print(paragraph("Supported peak profile types")) @@ -225,16 +224,16 @@ def show_supported_peak_profile_types(self): showindex=False )) - def show_current_peak_profile_type(self): + def show_current_peak_profile_type(self) -> None: print(paragraph("Current peak profile type")) print(self.peak_profile_type) @property - def background_type(self): + def background_type(self) -> str: return self._background_type @background_type.setter - def background_type(self, new_type): + def background_type(self, new_type: str) -> None: if new_type not in BackgroundFactory._supported: supported_types = list(BackgroundFactory._supported.keys()) print(warning(f"Unknown background type '{new_type}'")) @@ -246,12 +245,12 @@ def background_type(self, new_type): print(paragraph(f"Background type for experiment '{self.name}' changed to")) print(new_type) - def show_supported_background_types(self): - header = ["Background type", "Description"] - table_data = [] + def show_supported_background_types(self) -> None: + header: List[str] = ["Background type", "Description"] + table_data: List[List[str]] = [] for name, config in BackgroundFactory._supported.items(): - description = getattr(config, '_description', 'No description provided.') + description: str = getattr(config, '_description', 'No description provided.') table_data.append([name, description]) print(paragraph("Supported background types")) @@ -264,27 +263,28 @@ def show_supported_background_types(self): showindex=False )) - def show_current_background_type(self): + def show_current_background_type(self) -> None: print(paragraph("Current background type")) print(self.background_type) + class SingleCrystalExperiment(BaseExperiment): - """Powder experiment class with specific attributes.""" + """Single crystal experiment class with specific attributes.""" def __init__(self, name: str, - type: ExperimentType): + type: ExperimentType) -> None: super().__init__(name=name, type=type) self.linked_crystal = None - def show_meas_chart(self): + def show_meas_chart(self) -> None: print('Showing measured data chart is not implemented yet.') class ExperimentFactory: """Creates Experiment instances with only relevant attributes.""" - _supported = { + _supported: Dict[str, Type[BaseExperiment]] = { "powder": PowderExperiment, "single crystal": SingleCrystalExperiment } @@ -292,32 +292,28 @@ class ExperimentFactory: @classmethod def create(cls, name: str, - sample_form: DEFAULT_SAMPLE_FORM, - beam_mode: DEFAULT_BEAM_MODE, - radiation_probe: DEFAULT_RADIATION_PROBE) -> BaseExperiment: - # TODO: Add checks for expt_type and expt_class - expt_type = ExperimentType(sample_form=sample_form, - beam_mode=beam_mode, - radiation_probe=radiation_probe) - expt_class = cls._supported[sample_form] - instance = expt_class(name=name, type=expt_type) + sample_form: str, + beam_mode: str, + radiation_probe: str) -> BaseExperiment: + expt_type: ExperimentType = ExperimentType(sample_form=sample_form, + beam_mode=beam_mode, + radiation_probe=radiation_probe) + expt_class: Type[BaseExperiment] = cls._supported[sample_form] + instance: BaseExperiment = expt_class(name=name, type=expt_type) return instance -# User exposed API for convenience -# TODO: Refactor based on the implementation of method add() in class Experiments -# TODO: Think of where to keep default values for sample_form, beam_mode, radiation_probe, as they are also defined in the -# class ExperimentType def Experiment(name: str, sample_form: str = DEFAULT_SAMPLE_FORM, beam_mode: str = DEFAULT_BEAM_MODE, radiation_probe: str = DEFAULT_RADIATION_PROBE, - data_path: str = None): - experiment = ExperimentFactory.create( + data_path: Optional[str] = None) -> BaseExperiment: + experiment: BaseExperiment = ExperimentFactory.create( name=name, sample_form=sample_form, beam_mode=beam_mode, radiation_probe=radiation_probe ) - experiment._load_ascii_data_to_experiment(data_path) + if data_path: + experiment._load_ascii_data_to_experiment(data_path) return experiment diff --git a/src/easydiffraction/experiments/experiments.py b/src/easydiffraction/experiments/experiments.py index 54727a96..b04ab472 100644 --- a/src/easydiffraction/experiments/experiments.py +++ b/src/easydiffraction/experiments/experiments.py @@ -1,4 +1,5 @@ import os.path +from typing import Optional, Union, Dict, List from easydiffraction.core.objects import Collection from easydiffraction.experiments.experiment import ( @@ -13,21 +14,21 @@ class Experiments(Collection): Collection manager for multiple Experiment instances. """ - def __init__(self): + def __init__(self) -> None: super().__init__() - self._experiments = self._items # Alias for legacy support + self._experiments: Dict[str, BaseExperiment] = self._items # Alias for legacy support def add( self, - experiment=None, - name=None, - sample_form=None, - beam_mode=None, - radiation_probe=None, - cif_path=None, - cif_str=None, - data_path=None - ): + experiment: Optional[BaseExperiment] = None, + name: Optional[str] = None, + sample_form: Optional[str] = None, + beam_mode: Optional[str] = None, + radiation_probe: Optional[str] = None, + cif_path: Optional[str] = None, + cif_str: Optional[str] = None, + data_path: Optional[str] = None + ) -> None: """ Add a new experiment to the collection. """ @@ -48,29 +49,30 @@ def add( else: raise ValueError("Provide either experiment, type parameters, cif_path, cif_str, or data_path") - def _add_prebuilt_experiment(self, experiment): + def _add_prebuilt_experiment(self, experiment: BaseExperiment) -> None: if not isinstance(experiment, BaseExperiment): raise TypeError("Expected an instance of BaseExperiment or its subclass.") self._experiments[experiment.name] = experiment - def _add_from_cif_path(self, cif_path): + def _add_from_cif_path(self, cif_path: str) -> None: print(f"Loading Experiment from CIF path...") raise NotImplementedError("CIF loading not implemented.") - def _add_from_cif_string(self, cif_str): + def _add_from_cif_string(self, cif_str: str) -> None: print("Loading Experiment from CIF string...") raise NotImplementedError("CIF loading not implemented.") - def _add_from_data_path(self, - name, - sample_form, - beam_mode, - radiation_probe, - data_path): + def _add_from_data_path( + self, + name: str, + sample_form: str, + beam_mode: str, + radiation_probe: str, + data_path: str + ) -> None: """ Load an experiment from raw data ASCII file. """ - # TODO: Move this to the Experiment class print(paragraph("Loading measured data from ASCII file")) print(os.path.abspath(data_path)) experiment = ExperimentFactory.create( @@ -82,21 +84,21 @@ def _add_from_data_path(self, experiment._load_ascii_data_to_experiment(data_path) self._experiments[experiment.name] = experiment - def remove(self, experiment_id): + def remove(self, experiment_id: str) -> None: if experiment_id in self._experiments: del self._experiments[experiment_id] - def show_names(self): + def show_names(self) -> None: print(paragraph("Defined experiments" + " 🔬")) print(self.ids) @property - def ids(self): + def ids(self) -> List[str]: return list(self._experiments.keys()) - def show_params(self): + def show_params(self) -> None: for exp in self._experiments.values(): print(exp) - def as_cif(self): + def as_cif(self) -> str: return "\n\n".join([exp.as_cif() for exp in self._experiments.values()]) diff --git a/src/easydiffraction/project.py b/src/easydiffraction/project.py index 85823889..35d6c040 100644 --- a/src/easydiffraction/project.py +++ b/src/easydiffraction/project.py @@ -3,6 +3,7 @@ import tempfile from textwrap import wrap from varname import varname +from typing import Optional, List from easydiffraction.utils.formatting import ( paragraph, @@ -19,82 +20,82 @@ class ProjectInfo: Stores metadata about the project, such as ID, title, description, and file paths. """ - def __init__(self): - self._name = "untitled_project" # Short unique project identifier - self._title = "Untitled Project" - self._description = "" - self._path = os.getcwd() - self._created = datetime.datetime.now() - self._last_modified = datetime.datetime.now() + def __init__(self) -> None: + self._name: str = "untitled_project" # Short unique project identifier + self._title: str = "Untitled Project" + self._description: str = "" + self._path: str = os.getcwd() + self._created: datetime.datetime = datetime.datetime.now() + self._last_modified: datetime.datetime = datetime.datetime.now() @property - def name(self): + def name(self) -> str: """Return the project ID.""" return self._name @name.setter - def name(self, value): + def name(self, value: str) -> None: self._name = value @property - def title(self): + def title(self) -> str: """Return the project title.""" return self._title @title.setter - def title(self, value): + def title(self, value: str) -> None: self._title = value @property - def description(self): + def description(self) -> str: """Return sanitized description with single spaces.""" return ' '.join(self._description.split()) @description.setter - def description(self, value): + def description(self, value: str) -> None: self._description = ' '.join(value.split()) @property - def path(self): + def path(self) -> str: """Return the project path.""" return self._path @path.setter - def path(self, value): + def path(self, value: str) -> None: self._path = value @property - def created(self): + def created(self) -> datetime.datetime: """Return the creation timestamp.""" return self._created @property - def last_modified(self): + def last_modified(self) -> datetime.datetime: """Return the last modified timestamp.""" return self._last_modified - def update_last_modified(self): + def update_last_modified(self) -> None: """Update the last modified timestamp.""" self._last_modified = datetime.datetime.now() def as_cif(self) -> str: """Export project metadata to CIF.""" - wrapped_title = wrap(self.title, width=60) - wrapped_description = wrap(self.description, width=60) + wrapped_title: List[str] = wrap(self.title, width=60) + wrapped_description: List[str] = wrap(self.description, width=60) - title_str = f"_project.title '{wrapped_title[0]}'" + title_str: str = f"_project.title '{wrapped_title[0]}'" for line in wrapped_title[1:]: title_str += f"\n{' ' * 27}'{line}'" if wrapped_description: - base_indent = "_project.description " - indent_spaces = " " * len(base_indent) - formatted_description = f"{base_indent}'{wrapped_description[0]}" + base_indent: str = "_project.description " + indent_spaces: str = " " * len(base_indent) + formatted_description: str = f"{base_indent}'{wrapped_description[0]}" for line in wrapped_description[1:]: formatted_description += f"\n{indent_spaces}{line}" formatted_description += "'" else: - formatted_description = "_project.description ''" + formatted_description: str = "_project.description ''" return ( f"_project.id {self.name}\n" @@ -104,13 +105,13 @@ def as_cif(self) -> str: f"_project.last_modified '{self._last_modified.strftime('%d %b %Y %H:%M:%S')}'\n" ) - def show_as_cif(self): - cif_text = self.as_cif() - lines = cif_text.splitlines() - max_width = max(len(line) for line in lines) - padded_lines = [f"│ {line.ljust(max_width)} │" for line in lines] - top = f"╒{'═' * (max_width + 2)}╕" - bottom = f"╘{'═' * (max_width + 2)}╛" + def show_as_cif(self) -> None: + cif_text: str = self.as_cif() + lines: List[str] = cif_text.splitlines() + max_width: int = max(len(line) for line in lines) + padded_lines: List[str] = [f"│ {line.ljust(max_width)} │" for line in lines] + top: str = f"╒{'═' * (max_width + 2)}╕" + bottom: str = f"╘{'═' * (max_width + 2)}╛" print(paragraph(f"Project 📦 '{self.name}' info as cif")) print(top) @@ -124,20 +125,20 @@ class Project: Provides access to sample models, experiments, analysis, and summary. """ - def __init__(self, name="untitled_project", title="Untitled Project", description=""): - self.info = ProjectInfo() + def __init__(self, name: str = "untitled_project", title: str = "Untitled Project", description: str = "") -> None: + self.info: ProjectInfo = ProjectInfo() self.info.name = name self.info.title = title self.info.description = description - self.sample_models = SampleModels() - self.experiments = Experiments() - self.analysis = Analysis(self) - self.summary = Summary(self) - self._saved = False - self._varname = varname() + self.sample_models: SampleModels = SampleModels() + self.experiments: Experiments = Experiments() + self.analysis: Analysis = Analysis(self) + self.summary: Summary = Summary(self) + self._saved: bool = False + self._varname: str = varname() @property - def name(self): + def name(self) -> str: """Convenience property to access the project's ID directly.""" return self.info.name @@ -145,7 +146,7 @@ def name(self): # Project File I/O # ------------------------------------------ - def load(self, dir_path: str): + def load(self, dir_path: str) -> None: """ Load a project from a given directory. Loads project info, sample models, experiments, etc. @@ -157,17 +158,17 @@ def load(self, dir_path: str): print('Loading project is not implemented yet.') self._saved = True - def save_as(self, dir_path: str, temporary: bool = False): + def save_as(self, dir_path: str, temporary: bool = False) -> None: """ Save the project into a new directory. """ if temporary: - tmp = tempfile.gettempdir() + tmp: str = tempfile.gettempdir() dir_path = os.path.join(tmp, dir_path) self.info.path = dir_path self.save() - def save(self): + def save(self) -> None: """ Save the project into the existing project directory. """ @@ -186,21 +187,21 @@ def save(self): print("✅ project.cif") # Save sample models - sm_dir = os.path.join(self.info.path, "sample_models") + sm_dir: str = os.path.join(self.info.path, "sample_models") os.makedirs(sm_dir, exist_ok=True) for model in self.sample_models: - file_name = f"{model.name}.cif" - file_path = os.path.join(sm_dir, file_name) + file_name: str = f"{model.name}.cif" + file_path: str = os.path.join(sm_dir, file_name) with open(file_path, "w") as f: f.write(model.as_cif()) print(f"✅ sample_models/{file_name}") # Save experiments - expt_dir = os.path.join(self.info.path, "experiments") + expt_dir: str = os.path.join(self.info.path, "experiments") os.makedirs(expt_dir, exist_ok=True) for experiment in self.experiments: - file_name = f"{experiment.name}.cif" - file_path = os.path.join(expt_dir, file_name) + file_name: str = f"{experiment.name}.cif" + file_path: str = os.path.join(expt_dir, file_name) with open(file_path, "w") as f: f.write(experiment.as_cif()) print(f"✅ experiments/{file_name}") @@ -222,10 +223,10 @@ def save(self): # Sample Models API Convenience Methods # ------------------------------------------ - def set_sample_models(self, sample_models: SampleModels): + def set_sample_models(self, sample_models: SampleModels) -> None: """Attach a collection of sample models to the project.""" self.sample_models = sample_models - def set_experiments(self, experiments: Experiments): + def set_experiments(self, experiments: Experiments) -> None: """Attach a collection of experiments to the project.""" self.experiments = experiments \ No newline at end of file diff --git a/src/easydiffraction/sample_models/components/cell.py b/src/easydiffraction/sample_models/components/cell.py index 0544ab2b..f7a80953 100644 --- a/src/easydiffraction/sample_models/components/cell.py +++ b/src/easydiffraction/sample_models/components/cell.py @@ -1,3 +1,4 @@ +from typing import Optional from easydiffraction.core.objects import (Parameter, Component) @@ -7,12 +8,12 @@ class Cell(Component): """ def __init__(self, - length_a=10.0, - length_b=10.0, - length_c=10.0, - angle_alpha=90.0, - angle_beta=90.0, - angle_gamma=90.0): + length_a: float = 10.0, + length_b: float = 10.0, + length_c: float = 10.0, + angle_alpha: float = 90.0, + angle_beta: float = 90.0, + angle_gamma: float = 90.0) -> None: super().__init__() self.length_a = Parameter( @@ -53,13 +54,13 @@ def __init__(self, ) @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "cell" @property - def category_key(self): + def category_key(self) -> str: return "cell" @property - def _entry_id(self): + def _entry_id(self) -> Optional[str]: return None \ No newline at end of file diff --git a/src/easydiffraction/sample_models/components/space_group.py b/src/easydiffraction/sample_models/components/space_group.py index b245243c..98ab5488 100644 --- a/src/easydiffraction/sample_models/components/space_group.py +++ b/src/easydiffraction/sample_models/components/space_group.py @@ -1,3 +1,4 @@ +from typing import Optional from easydiffraction.core.objects import ( Descriptor, Component @@ -9,7 +10,7 @@ class SpaceGroup(Component): Represents the space group of a sample model. """ - def __init__(self, name_h_m="P 1", it_coordinate_system_code=None): + def __init__(self, name_h_m: str = "P 1", it_coordinate_system_code: Optional[int] = None) -> None: super().__init__() self.name_h_m = Descriptor( @@ -24,13 +25,13 @@ def __init__(self, name_h_m="P 1", it_coordinate_system_code=None): ) @property - def cif_category_key(self): + def cif_category_key(self) -> str: return "space_group" @property - def category_key(self): + def category_key(self) -> str: return "space_group" @property - def _entry_id(self): + def _entry_id(self) -> Optional[str]: return None \ No newline at end of file diff --git a/src/easydiffraction/sample_models/sample_models.py b/src/easydiffraction/sample_models/sample_models.py index 0f052da2..da1bc53b 100644 --- a/src/easydiffraction/sample_models/sample_models.py +++ b/src/easydiffraction/sample_models/sample_models.py @@ -1,3 +1,4 @@ +from typing import Dict, List, Optional, Any from easydiffraction.crystallography import crystallography as ecr from easydiffraction.core.objects import ( Collection, @@ -15,30 +16,32 @@ class SampleModel(Datablock): Wraps crystallographic information including space group, cell, and atomic sites. """ - def __init__(self, name, cif_path=None, cif_str=None): - self.name = name - self.space_group = SpaceGroup() - self.cell = Cell() - self.atom_sites = AtomSites() + def __init__(self, name: str, cif_path: Optional[str] = None, cif_str: Optional[str] = None) -> None: + self.name: str = name + self.space_group: SpaceGroup = SpaceGroup() + self.cell: Cell = Cell() + self.atom_sites: AtomSites = AtomSites() if cif_path: self.load_from_cif_file(cif_path) elif cif_str: self.load_from_cif_string(cif_str) - def apply_symmetry_constraints(self): + def apply_symmetry_constraints(self) -> None: + """Apply symmetry constraints to cell, atomic coordinates, and atomic displacements.""" self._apply_cell_symmetry_constraints() self._apply_atomic_coordinates_symmetry_constraints() self._apply_atomic_displacement_symmetry_constraints() - def _apply_cell_symmetry_constraints(self): - dummy_cell = {'lattice_a': self.cell.length_a.value, + def _apply_cell_symmetry_constraints(self) -> None: + """Apply symmetry constraints to unit cell parameters.""" + dummy_cell: Dict[str, float] = {'lattice_a': self.cell.length_a.value, 'lattice_b': self.cell.length_b.value, 'lattice_c': self.cell.length_c.value, 'angle_alpha': self.cell.angle_alpha.value, 'angle_beta': self.cell.angle_beta.value, 'angle_gamma': self.cell.angle_gamma.value} - space_group_name = self.space_group.name_h_m.value + space_group_name: str = self.space_group.name_h_m.value ecr.apply_cell_symmetry_constraints(cell=dummy_cell, name_hm=space_group_name) self.cell.length_a.value = dummy_cell['lattice_a'] @@ -48,14 +51,15 @@ def _apply_cell_symmetry_constraints(self): self.cell.angle_beta.value = dummy_cell['angle_beta'] self.cell.angle_gamma.value = dummy_cell['angle_gamma'] - def _apply_atomic_coordinates_symmetry_constraints(self): - space_group_name = self.space_group.name_h_m.value - space_group_coord_code = self.space_group.it_coordinate_system_code.value + def _apply_atomic_coordinates_symmetry_constraints(self) -> None: + """Apply symmetry constraints to atomic coordinates.""" + space_group_name: str = self.space_group.name_h_m.value + space_group_coord_code: Optional[int] = self.space_group.it_coordinate_system_code.value for atom in self.atom_sites: - dummy_atom = {"fract_x": atom.fract_x.value, + dummy_atom: Dict[str, float] = {"fract_x": atom.fract_x.value, "fract_y": atom.fract_y.value, "fract_z": atom.fract_z.value} - wl = atom.wyckoff_letter.value + wl: Optional[str] = atom.wyckoff_letter.value if not wl: #raise ValueError("Wyckoff letter is not defined for atom.") continue @@ -67,33 +71,43 @@ def _apply_atomic_coordinates_symmetry_constraints(self): atom.fract_y.value = dummy_atom['fract_y'] atom.fract_z.value = dummy_atom['fract_z'] - def _apply_atomic_displacement_symmetry_constraints(self): + def _apply_atomic_displacement_symmetry_constraints(self) -> None: + """Apply symmetry constraints to atomic displacement parameters.""" pass - def load_from_cif_file(self, cif_path: str): - """Load model data from a CIF file.""" + def load_from_cif_file(self, cif_path: str) -> None: + """ + Load model data from a CIF file. + + Args: + cif_path: Path to the CIF file. + """ # TODO: Implement CIF parsing here print(f"Loading SampleModel from CIF file: {cif_path}") # Example: self.id = extract_id_from_cif(cif_path) - def load_from_cif_string(self, cif_str: str): - """Load model data from a CIF string.""" + def load_from_cif_string(self, cif_str: str) -> None: + """ + Load model data from a CIF string. + + Args: + cif_str: CIF content as a string. + """ # TODO: Implement CIF parsing from a string print("Loading SampleModel from CIF string.") - def show_structure(self, plane='xy', grid_size=20): + def show_structure(self, plane: str = 'xy', grid_size: int = 20) -> None: """ Show an ASCII projection of the structure on a 2D plane. Args: - plane (str): 'xy', 'xz', or 'yz' plane to project. - grid_size (int): Size of the ASCII grid (default is 20). + plane: 'xy', 'xz', or 'yz' plane to project. + grid_size: Size of the ASCII grid (default is 20). """ - print(paragraph(f"Sample model 🧩 '{self.name}' structure view")) print("Not implemented yet.") - def show_params(self): + def show_params(self) -> None: """Display structural parameters (space group, unit cell, atomic sites).""" print(f"\nSampleModel ID: {self.name}") print(f"Space group: {self.space_group.name_h_m}") @@ -105,10 +119,10 @@ def as_cif(self) -> str: """ Export the sample model to CIF format. Returns: - str: CIF string representation of the sample model. + CIF string representation of the sample model. """ # Data block header - cif_lines = [f"data_{self.name}"] + cif_lines: List[str] = [f"data_{self.name}"] # Space Group cif_lines += ["", self.space_group.as_cif()] @@ -121,13 +135,14 @@ def as_cif(self) -> str: return "\n".join(cif_lines) - def show_as_cif(self): - cif_text = self.as_cif() - lines = cif_text.splitlines() - max_width = max(len(line) for line in lines) - padded_lines = [f"│ {line.ljust(max_width)} │" for line in lines] - top = f"╒{'═' * (max_width + 2)}╕" - bottom = f"╘{'═' * (max_width + 2)}╛" + def show_as_cif(self) -> None: + """Display the sample model in CIF format.""" + cif_text: str = self.as_cif() + lines: List[str] = cif_text.splitlines() + max_width: int = max(len(line) for line in lines) + padded_lines: List[str] = [f"│ {line.ljust(max_width)} │" for line in lines] + top: str = f"╒{'═' * (max_width + 2)}╕" + bottom: str = f"╘{'═' * (max_width + 2)}╛" print(paragraph(f"Sample model 🧩 '{self.name}' as cif")) print(top) @@ -140,65 +155,102 @@ class SampleModels(Collection): Collection manager for multiple SampleModel instances. """ - def __init__(self): + def __init__(self) -> None: super().__init__() # Initialize Collection self._models = self._items # Alias for legacy support - def add(self, model=None, name=None, cif_path=None, cif_str=None): + def add(self, + model: Optional[SampleModel] = None, + name: Optional[str] = None, + cif_path: Optional[str] = None, + cif_str: Optional[str] = None) -> None: """ Add a new sample model to the collection. Dispatches based on input type: pre-built model or parameters for new creation. + + Args: + model: An existing SampleModel instance. + name: Name for a new model if created from scratch. + cif_path: Path to a CIF file to create a model from. + cif_str: CIF content as string to create a model from. """ if model: self._add_prebuilt_sample_model(model) else: self._create_and_add_sample_model(name, cif_path, cif_str) - def remove(self, name): + def remove(self, name: str) -> None: """ Remove a sample model by its ID. + + Args: + name: ID of the model to remove. """ if name in self._models: del self._models[name] - def get_ids(self): + def get_ids(self) -> List[str]: """ Return a list of all model IDs in the collection. + + Returns: + List of model IDs. """ return list(self._models.keys()) - def show_names(self): - """ - List all model IDs in the collection. - """ + @property + def ids(self) -> List[str]: + """Property accessor for model IDs.""" + return self.get_ids() + + def show_names(self) -> None: + """List all model IDs in the collection.""" print(paragraph("Defined sample models" + " 🧩")) print(self.get_ids()) - def show_params(self): - """ - Show parameters of all sample models in the collection. - """ + def show_params(self) -> None: + """Show parameters of all sample models in the collection.""" for model in self._models.values(): model.show_params() def as_cif(self) -> str: """ Export all sample models to CIF format. + + Returns: + CIF string representation of all sample models. """ return "\n".join([model.as_cif() for model in self._models.values()]) - def _add_prebuilt_sample_model(self, model): + def _add_prebuilt_sample_model(self, model: SampleModel) -> None: """ Add a pre-built SampleModel instance. + + Args: + model: The SampleModel instance to add. + + Raises: + TypeError: If model is not a SampleModel instance. """ from easydiffraction.sample_models.sample_models import SampleModel # avoid circular import if not isinstance(model, SampleModel): raise TypeError("Expected an instance of SampleModel") self._models[model.name] = model - def _create_and_add_sample_model(self, name=None, cif_path=None, cif_str=None): + def _create_and_add_sample_model(self, + name: Optional[str] = None, + cif_path: Optional[str] = None, + cif_str: Optional[str] = None) -> None: """ Create a SampleModel instance and add it to the collection. + + Args: + name: Name for the new model. + cif_path: Path to a CIF file. + cif_str: CIF content as string. + + Raises: + ValueError: If neither name, cif_path, nor cif_str is provided. """ from easydiffraction.sample_models.sample_models import SampleModel # avoid circular import diff --git a/src/easydiffraction/summary.py b/src/easydiffraction/summary.py index 9470b7f7..b14fbf9b 100644 --- a/src/easydiffraction/summary.py +++ b/src/easydiffraction/summary.py @@ -1,5 +1,6 @@ from tabulate import tabulate from textwrap import wrap +from typing import Any, Dict, List from easydiffraction.utils.formatting import ( paragraph, @@ -15,19 +16,20 @@ class Summary: about the fitted model, experiments, and analysis results. """ - def __init__(self, project): + def __init__(self, project: Any) -> None: """ Initialize the summary with a reference to the project. - :param project: The Project instance this summary belongs to. + Args: + project: The Project instance this summary belongs to. """ - self.project = project + self.project: Any = project # ------------------------------------------ # Report Generation # ------------------------------------------ - def show_report(self): + def show_report(self) -> None: """ Show a report of the entire analysis process, including: - Project metadata @@ -54,23 +56,23 @@ def show_report(self): print(model.space_group.name_h_m.value) print(paragraph("Cell parameters")) - cell_data = [[k.replace('length_', '').replace('angle_', ''), f"{v:.4f}"] for k, v in model.cell.as_dict().items()] + cell_data: List[List[Any]] = [[k.replace('length_', '').replace('angle_', ''), f"{v:.4f}"] for k, v in model.cell.as_dict().items()] print(tabulate(cell_data, headers=["Parameter", "Value"], tablefmt="fancy_outline")) print(paragraph("Atom sites")) - atom_table = [] + atom_table: List[List[str]] = [] for site in model.atom_sites: - fract_x = site.fract_x.value - fract_y = site.fract_y.value - fract_z = site.fract_z.value - b_iso = site.b_iso.value - occ = site.occupancy.value + fract_x: float = site.fract_x.value + fract_y: float = site.fract_y.value + fract_z: float = site.fract_z.value + b_iso: float = site.b_iso.value + occ: float = site.occupancy.value atom_table.append([ site.label.value, site.type_symbol.value, f"{fract_x:.5f}", f"{fract_y:.5f}", f"{fract_z:.5f}", f"{occ:.5f}", f"{b_iso:.5f}" ]) - headers = ["Label", "Type", "fract_x", "fract_y", "fract_z", "Occupancy", "B_iso"] + headers: List[str] = ["Label", "Type", "fract_x", "fract_y", "fract_z", "Occupancy", "B_iso"] print(tabulate(atom_table, headers=headers, tablefmt="fancy_outline")) # ------------------------------------------ From e0e9c917250c4b813b0ffcab3e5cf681becccd16 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 23 Apr 2025 09:23:02 +0200 Subject: [PATCH 07/12] correct the order of arguments --- src/easydiffraction/analysis/minimizers/minimizer_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/easydiffraction/analysis/minimizers/minimizer_base.py b/src/easydiffraction/analysis/minimizers/minimizer_base.py index 46a8d028..106ea952 100644 --- a/src/easydiffraction/analysis/minimizers/minimizer_base.py +++ b/src/easydiffraction/analysis/minimizers/minimizer_base.py @@ -172,7 +172,7 @@ def _sync_result_to_parameters(self, def _finalize_fit(self, parameters: List[Any], raw_result: Any) -> FitResults: - self._sync_result_to_parameters(raw_result, parameters) + self._sync_result_to_parameters(parameters, raw_result) success = self._check_success(raw_result) self.result = FitResults( success=success, From eba023948841e535362ef1157e49b5923284bb6f Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 23 Apr 2025 14:02:48 +0200 Subject: [PATCH 08/12] run unit tests first --- .github/workflows/ci-testing.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-testing.yaml b/.github/workflows/ci-testing.yaml index 563d061f..7fbfe05e 100644 --- a/.github/workflows/ci-testing.yaml +++ b/.github/workflows/ci-testing.yaml @@ -51,10 +51,10 @@ jobs: shell: bash run: python -m pip install -r requirements.txt - - name: Run Python functional tests + - name: Run Python unit tests shell: bash - run: PYTHONPATH=$(pwd)/src python -m pytest tests/functional_tests/ --color=yes -n auto + run: PYTHONPATH=$(pwd)/src python -m pytest tests/unit_tests/ --color=yes -n auto - - name: Run Python unit tests + - name: Run Python functional tests shell: bash - run: PYTHONPATH=$(pwd)/src python -m pytest tests/unit_tests/ --color=yes -n auto \ No newline at end of file + run: PYTHONPATH=$(pwd)/src python -m pytest tests/functional_tests/ --color=yes -n auto From 546a9698c231a2166d8356f39551d24b01628870 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 23 Apr 2025 15:24:34 +0200 Subject: [PATCH 09/12] moved Any to proper type --- src/easydiffraction/analysis/calculation.py | 8 +++-- .../analysis/calculators/calculator_base.py | 15 +++++----- .../calculators/calculator_crysfml.py | 15 ++++++---- .../analysis/calculators/calculator_cryspy.py | 18 ++++++----- .../analysis/calculators/calculator_pdffit.py | 10 +++++-- src/easydiffraction/analysis/minimization.py | 30 +++++++++++-------- .../analysis/reliability_factors.py | 5 ++-- src/easydiffraction/summary.py | 1 - 8 files changed, 60 insertions(+), 42 deletions(-) diff --git a/src/easydiffraction/analysis/calculation.py b/src/easydiffraction/analysis/calculation.py index 4d9688f0..5ffc584c 100644 --- a/src/easydiffraction/analysis/calculation.py +++ b/src/easydiffraction/analysis/calculation.py @@ -1,7 +1,9 @@ from typing import Any, Optional, List import numpy as np from .calculators.calculator_factory import CalculatorFactory - +from easydiffraction.sample_models.sample_models import SampleModels +from easydiffraction.experiments.experiments import Experiments +from easydiffraction.experiments.experiment import Experiment class DiffractionCalculator: """ @@ -29,7 +31,7 @@ def set_calculator(self, engine: str) -> None: """ self._calculator = self.calculator_factory.create_calculator(engine) - def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> Optional[List[Any]]: + def calculate_structure_factors(self, sample_models: SampleModels, experiments: Experiments) -> Optional[List[Any]]: """ Calculate HKL intensities (structure factors) for sample models and experiments. @@ -42,7 +44,7 @@ def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> O """ return self._calculator.calculate_structure_factors(sample_models, experiments) - def calculate_pattern(self, sample_models: Any, experiment: Any) -> np.ndarray: + def calculate_pattern(self, sample_models: SampleModels, experiment: Experiment) -> np.ndarray: """ Generate diffraction pattern based on sample models and experiment. diff --git a/src/easydiffraction/analysis/calculators/calculator_base.py b/src/easydiffraction/analysis/calculators/calculator_base.py index cdfb1a59..0cfbb6ef 100644 --- a/src/easydiffraction/analysis/calculators/calculator_base.py +++ b/src/easydiffraction/analysis/calculators/calculator_base.py @@ -3,7 +3,8 @@ from typing import List, Any from easydiffraction.core.singletons import ConstraintsHandler - +from easydiffraction.sample_models.sample_models import SampleModels +from easydiffraction.experiments.experiment import Experiment class CalculatorBase(ABC): """ @@ -21,15 +22,15 @@ def engine_imported(self) -> bool: pass @abstractmethod - def calculate_structure_factors(self, sample_model: Any, experiment: Any) -> None: + def calculate_structure_factors(self, sample_model: SampleModels, experiment: Experiment) -> None: """ Calculate structure factors for a single sample model and experiment. """ pass def calculate_pattern(self, - sample_models: Any, - experiment: Any, + sample_models: SampleModels, + experiment: Experiment, called_by_minimizer: bool = False) -> np.ndarray: """ Calculate the diffraction pattern for multiple sample models and a single experiment. @@ -82,8 +83,8 @@ def calculate_pattern(self, @abstractmethod def _calculate_single_model_pattern(self, - sample_model: Any, - experiment: Any, + sample_model: SampleModels, + experiment: Experiment, called_by_minimizer: bool) -> np.ndarray: """ Calculate the diffraction pattern for a single sample model and experiment. @@ -98,7 +99,7 @@ def _calculate_single_model_pattern(self, """ pass - def _get_valid_linked_phases(self, sample_models: Any, experiment: Any) -> List[Any]: + def _get_valid_linked_phases(self, sample_models: SampleModels, experiment: Experiment) -> List[Any]: """ Get valid linked phases from the experiment. diff --git a/src/easydiffraction/analysis/calculators/calculator_crysfml.py b/src/easydiffraction/analysis/calculators/calculator_crysfml.py index 502f8af5..34a4bb62 100644 --- a/src/easydiffraction/analysis/calculators/calculator_crysfml.py +++ b/src/easydiffraction/analysis/calculators/calculator_crysfml.py @@ -2,6 +2,9 @@ from typing import Any, Dict, List, Union from .calculator_base import CalculatorBase from easydiffraction.utils.formatting import warning +from easydiffraction.sample_models.sample_models import SampleModels +from easydiffraction.experiments.experiments import Experiments +from easydiffraction.experiments.experiment import Experiment try: from pycrysfml import cfml_py_utilities @@ -21,7 +24,7 @@ class CrysfmlCalculator(CalculatorBase): def name(self) -> str: return "crysfml" - def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> None: + def calculate_structure_factors(self, sample_models: SampleModels, experiments: Experiments) -> None: """ Call Crysfml to calculate structure factors. @@ -33,8 +36,8 @@ def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> N def _calculate_single_model_pattern( self, - sample_model: Any, - experiment: Any, + sample_model: SampleModels, + experiment: Experiment, called_by_minimizer: bool = False ) -> Union[np.ndarray, List[float]]: """ @@ -73,7 +76,7 @@ def _adjust_pattern_length(self, pattern: List[float], target_length: int) -> Li return pattern[:target_length] return pattern - def _crysfml_dict(self, sample_model: Any, experiment: Any) -> Dict[str, Any]: + def _crysfml_dict(self, sample_model: SampleModels, experiment: Experiment) -> Dict[str, Any]: """ Converts the sample model and experiment into a dictionary format for Crysfml. @@ -91,7 +94,7 @@ def _crysfml_dict(self, sample_model: Any, experiment: Any) -> Dict[str, Any]: "experiments": [experiment_dict] } - def _convert_sample_model_to_dict(self, sample_model: Any) -> Dict[str, Any]: + def _convert_sample_model_to_dict(self, sample_model: SampleModels) -> Dict[str, Any]: """ Converts a sample model into a dictionary format. @@ -129,7 +132,7 @@ def _convert_sample_model_to_dict(self, sample_model: Any) -> Dict[str, Any]: return sample_model_dict - def _convert_experiment_to_dict(self, experiment: Any) -> Dict[str, Any]: + def _convert_experiment_to_dict(self, experiment: Experiment) -> Dict[str, Any]: """ Converts an experiment into a dictionary format. diff --git a/src/easydiffraction/analysis/calculators/calculator_cryspy.py b/src/easydiffraction/analysis/calculators/calculator_cryspy.py index 92a72e71..e65cf6b5 100644 --- a/src/easydiffraction/analysis/calculators/calculator_cryspy.py +++ b/src/easydiffraction/analysis/calculators/calculator_cryspy.py @@ -4,6 +4,10 @@ from .calculator_base import CalculatorBase from easydiffraction.utils.formatting import warning +from easydiffraction.sample_models.sample_models import SampleModels +from easydiffraction.experiments.experiments import Experiments +from easydiffraction.experiments.experiment import Experiment + try: import cryspy from cryspy.procedure_rhochi.rhochi_by_dictionary import rhochi_calc_chi_sq_by_dictionary @@ -29,7 +33,7 @@ def __init__(self) -> None: super().__init__() self._cryspy_dicts: Dict[str, Dict[str, Any]] = {} - def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> None: + def calculate_structure_factors(self, sample_models: SampleModels, experiments: Experiments) -> None: """ Raises a NotImplementedError as HKL calculation is not implemented. @@ -41,8 +45,8 @@ def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> N def _calculate_single_model_pattern( self, - sample_model: Any, - experiment: Any, + sample_model: SampleModels, + experiment: Experiment, called_by_minimizer: bool = False ) -> Union[np.ndarray, List[float]]: """ @@ -102,7 +106,7 @@ def _calculate_single_model_pattern( return y_calc_total - def _recreate_cryspy_dict(self, sample_model: Any, experiment: Any) -> Dict[str, Any]: + def _recreate_cryspy_dict(self, sample_model: SampleModels, experiment: Experiment) -> Dict[str, Any]: """ Recreates the Cryspy dictionary for the given sample model and experiment. @@ -181,7 +185,7 @@ def _recreate_cryspy_dict(self, sample_model: Any, experiment: Any) -> Dict[str, return cryspy_dict - def _recreate_cryspy_obj(self, sample_model: Any, experiment: Any) -> Any: + def _recreate_cryspy_obj(self, sample_model: SampleModels, experiment: Experiment) -> Any: """ Recreates the Cryspy object for the given sample model and experiment. @@ -204,7 +208,7 @@ def _recreate_cryspy_obj(self, sample_model: Any, experiment: Any) -> Any: return cryspy_obj - def _convert_sample_model_to_cryspy_cif(self, sample_model: Any) -> str: + def _convert_sample_model_to_cryspy_cif(self, sample_model: SampleModels) -> str: """ Converts a sample model to a Cryspy CIF string. @@ -216,7 +220,7 @@ def _convert_sample_model_to_cryspy_cif(self, sample_model: Any) -> str: """ return sample_model.as_cif() - def _convert_experiment_to_cryspy_cif(self, experiment: Any, linked_phase: Any) -> str: + def _convert_experiment_to_cryspy_cif(self, experiment: Experiment, linked_phase: Any) -> str: """ Converts an experiment to a Cryspy CIF string. diff --git a/src/easydiffraction/analysis/calculators/calculator_pdffit.py b/src/easydiffraction/analysis/calculators/calculator_pdffit.py index 793abd89..b47081fc 100644 --- a/src/easydiffraction/analysis/calculators/calculator_pdffit.py +++ b/src/easydiffraction/analysis/calculators/calculator_pdffit.py @@ -2,6 +2,10 @@ from .calculator_base import CalculatorBase from easydiffraction.utils.formatting import warning +from easydiffraction.sample_models.sample_models import SampleModels +from easydiffraction.experiments.experiments import Experiments +from easydiffraction.experiments.experiment import Experiment + try: from diffpy.pdffit2 import pdffit except ImportError: @@ -20,7 +24,7 @@ class PdffitCalculator(CalculatorBase): def name(self) -> str: return "PdfFit" - def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> List[Any]: + def calculate_structure_factors(self, sample_models: SampleModels, experiments: Experiments) -> List[Any]: """ PDF doesn't compute HKL but we keep the interface consistent. @@ -36,8 +40,8 @@ def calculate_structure_factors(self, sample_models: Any, experiments: Any) -> L def _calculate_single_model_pattern( self, - sample_model: Any, - experiment: Any, + sample_model: SampleModels, + experiment: Experiment, called_by_minimizer: bool = False ) -> Union[List[float], Any]: """ diff --git a/src/easydiffraction/analysis/minimization.py b/src/easydiffraction/analysis/minimization.py index 71cc7784..d02a85a5 100644 --- a/src/easydiffraction/analysis/minimization.py +++ b/src/easydiffraction/analysis/minimization.py @@ -4,6 +4,10 @@ from ..analysis.reliability_factors import get_reliability_inputs import numpy as np +from easydiffraction.sample_models.sample_models import SampleModels +from easydiffraction.experiments.experiments import Experiments +from easydiffraction.core.objects import Parameter + class DiffractionMinimizer: """ @@ -17,10 +21,10 @@ def __init__(self, selection: str = 'lmfit (leastsq)') -> None: self.results: Optional[FitResults] = None def fit(self, - sample_models: Any, - experiments: Any, + sample_models: SampleModels, + experiments: Experiments, calculator: Any, - weights: Optional[Any] = None) -> None: + weights: Optional[np.array] = None) -> None: """ Run the fitting process. @@ -31,7 +35,7 @@ def fit(self, weights: Optional weights for joint fitting. """ - params: List[Any] = sample_models.get_free_params() + experiments.get_free_params() + params = sample_models.get_free_params() + experiments.get_free_params() if not params: print("⚠️ No parameters selected for fitting.") @@ -56,8 +60,8 @@ def fit(self, self._process_fit_results(sample_models, experiments, calculator) def _process_fit_results(self, - sample_models: Any, - experiments: Any, + sample_models: SampleModels, + experiments: Experiments, calculator: Any) -> None: """ Collect reliability inputs and display results after fitting. @@ -76,8 +80,8 @@ def _process_fit_results(self, self.results.display_results(y_obs=y_obs, y_calc=y_calc, y_err=y_err, f_obs=f_obs, f_calc=f_calc) def _collect_free_parameters(self, - sample_models: Any, - experiments: Any) -> List[Any]: + sample_models: SampleModels, + experiments: Experiments) -> List[Parameter]: """ Collect free parameters from sample models and experiments. @@ -88,16 +92,16 @@ def _collect_free_parameters(self, Returns: List of free parameters. """ - free_params: List[Any] = sample_models.get_free_params() + experiments.get_free_params() + free_params: List[Parameter] = sample_models.get_free_params() + experiments.get_free_params() return free_params def _residual_function(self, engine_params: Dict[str, Any], - parameters: List[Any], - sample_models: Any, - experiments: Any, + parameters: List[Parameter], + sample_models: SampleModels, + experiments: Experiments, calculator: Any, - weights: Optional[Any] = None) -> np.ndarray: + weights: Optional[np.array] = None) -> np.ndarray: """ Residual function computes the difference between measured and calculated patterns. It updates the parameter values according to the optimizer-provided engine_params. diff --git a/src/easydiffraction/analysis/reliability_factors.py b/src/easydiffraction/analysis/reliability_factors.py index 6b23f008..cb72abad 100644 --- a/src/easydiffraction/analysis/reliability_factors.py +++ b/src/easydiffraction/analysis/reliability_factors.py @@ -1,6 +1,7 @@ import numpy as np from typing import Tuple, Any, Optional - +from easydiffraction.sample_models.sample_models import SampleModels +from easydiffraction.experiments.experiments import Experiments def calculate_r_factor(y_obs: np.ndarray, y_calc: np.ndarray) -> float: """ @@ -97,7 +98,7 @@ def calculate_reduced_chi_square(residuals: np.ndarray, num_parameters: int) -> return np.nan -def get_reliability_inputs(sample_models: Any, experiments: Any, calculator: Any) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: +def get_reliability_inputs(sample_models: SampleModels, experiments: Experiments, calculator: Any) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: """ Collect observed and calculated data points for reliability calculations. diff --git a/src/easydiffraction/summary.py b/src/easydiffraction/summary.py index b14fbf9b..f3dfc4f7 100644 --- a/src/easydiffraction/summary.py +++ b/src/easydiffraction/summary.py @@ -7,7 +7,6 @@ section ) - class Summary: """ Generates reports and exports results from the project. From 8e73ab0c5d201c3f55a8e071459fa24abe7c2101 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Wed, 23 Apr 2025 15:39:36 +0200 Subject: [PATCH 10/12] fix return type --- src/easydiffraction/analysis/analysis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/easydiffraction/analysis/analysis.py b/src/easydiffraction/analysis/analysis.py index ae1f234a..5fa47c1a 100644 --- a/src/easydiffraction/analysis/analysis.py +++ b/src/easydiffraction/analysis/analysis.py @@ -1,4 +1,5 @@ import pandas as pd +import numpy as np from tabulate import tabulate from typing import List, Optional, Union, Any @@ -290,7 +291,7 @@ def show_current_fit_mode(self) -> None: print(paragraph("Current fit mode")) print(self.fit_mode) - def calculate_pattern(self, expt_name: str) -> Optional[pd.DataFrame]: + def calculate_pattern(self, expt_name: str) -> Optional[np.ndarray]: """ Calculate the diffraction pattern for a given experiment. From 6931f687c41c3a834a88d47e494f6f5f4f5a3388 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 24 Apr 2025 09:32:43 +0200 Subject: [PATCH 11/12] use correct type hint and avoid circular dependencies --- src/easydiffraction/analysis/analysis.py | 5 +++-- src/easydiffraction/summary.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/easydiffraction/analysis/analysis.py b/src/easydiffraction/analysis/analysis.py index 5fa47c1a..1f188cb4 100644 --- a/src/easydiffraction/analysis/analysis.py +++ b/src/easydiffraction/analysis/analysis.py @@ -1,7 +1,8 @@ +from __future__ import annotations import pandas as pd import numpy as np from tabulate import tabulate -from typing import List, Optional, Union, Any +from typing import List, Optional, Union from easydiffraction.utils.formatting import ( paragraph, @@ -32,7 +33,7 @@ class Analysis: _calculator = CalculatorFactory.create_calculator('cryspy') - def __init__(self, project: Any) -> None: + def __init__(self, project: Project) -> None: self.project = project self.aliases = ConstraintAliases() self.constraints = ConstraintExpressions() diff --git a/src/easydiffraction/summary.py b/src/easydiffraction/summary.py index f3dfc4f7..56ef0f16 100644 --- a/src/easydiffraction/summary.py +++ b/src/easydiffraction/summary.py @@ -1,3 +1,4 @@ +from __future__ import annotations from tabulate import tabulate from textwrap import wrap from typing import Any, Dict, List @@ -15,14 +16,14 @@ class Summary: about the fitted model, experiments, and analysis results. """ - def __init__(self, project: Any) -> None: + def __init__(self, project: Project) -> None: """ Initialize the summary with a reference to the project. Args: project: The Project instance this summary belongs to. """ - self.project: Any = project + self.project: Project = project # ------------------------------------------ # Report Generation From 4b0a9a418922b09ba5f1a41904ca0c5a8837d2f5 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 24 Apr 2025 11:23:55 +0200 Subject: [PATCH 12/12] a few more Any's disappeared --- .../analysis/calculators/calculator_crysfml.py | 7 ++++--- src/easydiffraction/analysis/minimization.py | 5 +++-- .../analysis/reliability_factors.py | 5 +++-- .../experiments/collections/datastore.py | 14 +++++++------- src/easydiffraction/sample_models/sample_models.py | 2 +- src/easydiffraction/summary.py | 4 ++-- 6 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/easydiffraction/analysis/calculators/calculator_crysfml.py b/src/easydiffraction/analysis/calculators/calculator_crysfml.py index 34a4bb62..100c5dab 100644 --- a/src/easydiffraction/analysis/calculators/calculator_crysfml.py +++ b/src/easydiffraction/analysis/calculators/calculator_crysfml.py @@ -3,8 +3,9 @@ from .calculator_base import CalculatorBase from easydiffraction.utils.formatting import warning from easydiffraction.sample_models.sample_models import SampleModels -from easydiffraction.experiments.experiments import Experiments +from easydiffraction.sample_models.sample_models import SampleModel from easydiffraction.experiments.experiment import Experiment +from easydiffraction.experiments.experiments import Experiments try: from pycrysfml import cfml_py_utilities @@ -76,7 +77,7 @@ def _adjust_pattern_length(self, pattern: List[float], target_length: int) -> Li return pattern[:target_length] return pattern - def _crysfml_dict(self, sample_model: SampleModels, experiment: Experiment) -> Dict[str, Any]: + def _crysfml_dict(self, sample_model: SampleModels, experiment: Experiment) -> Dict[str, Union[Experiment, SampleModel]]: """ Converts the sample model and experiment into a dictionary format for Crysfml. @@ -94,7 +95,7 @@ def _crysfml_dict(self, sample_model: SampleModels, experiment: Experiment) -> D "experiments": [experiment_dict] } - def _convert_sample_model_to_dict(self, sample_model: SampleModels) -> Dict[str, Any]: + def _convert_sample_model_to_dict(self, sample_model: SampleModels) -> Dict[str, SampleModel]: """ Converts a sample model into a dictionary format. diff --git a/src/easydiffraction/analysis/minimization.py b/src/easydiffraction/analysis/minimization.py index d02a85a5..e78128f5 100644 --- a/src/easydiffraction/analysis/minimization.py +++ b/src/easydiffraction/analysis/minimization.py @@ -7,6 +7,7 @@ from easydiffraction.sample_models.sample_models import SampleModels from easydiffraction.experiments.experiments import Experiments from easydiffraction.core.objects import Parameter +from easydiffraction.analysis.calculators.calculator_base import CalculatorBase class DiffractionMinimizer: @@ -62,7 +63,7 @@ def fit(self, def _process_fit_results(self, sample_models: SampleModels, experiments: Experiments, - calculator: Any) -> None: + calculator: CalculatorBase) -> None: """ Collect reliability inputs and display results after fitting. @@ -100,7 +101,7 @@ def _residual_function(self, parameters: List[Parameter], sample_models: SampleModels, experiments: Experiments, - calculator: Any, + calculator: CalculatorBase, weights: Optional[np.array] = None) -> np.ndarray: """ Residual function computes the difference between measured and calculated patterns. diff --git a/src/easydiffraction/analysis/reliability_factors.py b/src/easydiffraction/analysis/reliability_factors.py index cb72abad..b3dbe898 100644 --- a/src/easydiffraction/analysis/reliability_factors.py +++ b/src/easydiffraction/analysis/reliability_factors.py @@ -1,7 +1,8 @@ import numpy as np -from typing import Tuple, Any, Optional +from typing import Tuple, Optional from easydiffraction.sample_models.sample_models import SampleModels from easydiffraction.experiments.experiments import Experiments +from easydiffraction.analysis.calculators.calculator_base import CalculatorBase def calculate_r_factor(y_obs: np.ndarray, y_calc: np.ndarray) -> float: """ @@ -98,7 +99,7 @@ def calculate_reduced_chi_square(residuals: np.ndarray, num_parameters: int) -> return np.nan -def get_reliability_inputs(sample_models: SampleModels, experiments: Experiments, calculator: Any) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: +def get_reliability_inputs(sample_models: SampleModels, experiments: Experiments, calculator: CalculatorBase) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]: """ Collect observed and calculated data points for reliability calculations. diff --git a/src/easydiffraction/experiments/collections/datastore.py b/src/easydiffraction/experiments/collections/datastore.py index 1e49a354..2ff76639 100644 --- a/src/easydiffraction/experiments/collections/datastore.py +++ b/src/easydiffraction/experiments/collections/datastore.py @@ -1,15 +1,15 @@ -from typing import Optional, Any +from __future__ import annotations +from typing import Optional import numpy as np - class Pattern: """ Base pattern class for both powder and single crystal experiments. Stores x, measured intensities, uncertainties, background, and calculated intensities. """ - def __init__(self, experiment: Any) -> None: - self.experiment: Any = experiment + def __init__(self, experiment: Experiment) -> None: + self.experiment = experiment # Data arrays self.x: Optional[np.ndarray] = None @@ -33,7 +33,7 @@ class PowderPattern(Pattern): """ Specialized pattern for powder diffraction (can be extended in the future). """ - def __init__(self, experiment: Any) -> None: + def __init__(self, experiment: Experiment) -> None: super().__init__(experiment) # Additional powder-specific initialization if needed @@ -43,7 +43,7 @@ class Datastore: Stores pattern data (measured and calculated) for an experiment. """ - def __init__(self, sample_form: str, experiment: Any) -> None: + def __init__(self, sample_form: str, experiment: Experiment) -> None: self.sample_form: str = sample_form if sample_form == "powder": @@ -95,7 +95,7 @@ class DatastoreFactory: """ @staticmethod - def create(sample_form: str, experiment: Any) -> Datastore: + def create(sample_form: str, experiment: Experiment) -> Datastore: """ Create a datastore object depending on the sample form. diff --git a/src/easydiffraction/sample_models/sample_models.py b/src/easydiffraction/sample_models/sample_models.py index da1bc53b..b73b51a6 100644 --- a/src/easydiffraction/sample_models/sample_models.py +++ b/src/easydiffraction/sample_models/sample_models.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Any +from typing import Dict, List, Optional from easydiffraction.crystallography import crystallography as ecr from easydiffraction.core.objects import ( Collection, diff --git a/src/easydiffraction/summary.py b/src/easydiffraction/summary.py index 56ef0f16..2d06e7fe 100644 --- a/src/easydiffraction/summary.py +++ b/src/easydiffraction/summary.py @@ -1,7 +1,7 @@ from __future__ import annotations from tabulate import tabulate from textwrap import wrap -from typing import Any, Dict, List +from typing import List from easydiffraction.utils.formatting import ( paragraph, @@ -56,7 +56,7 @@ def show_report(self) -> None: print(model.space_group.name_h_m.value) print(paragraph("Cell parameters")) - cell_data: List[List[Any]] = [[k.replace('length_', '').replace('angle_', ''), f"{v:.4f}"] for k, v in model.cell.as_dict().items()] + cell_data = [[k.replace('length_', '').replace('angle_', ''), f"{v:.4f}"] for k, v in model.cell.as_dict().items()] print(tabulate(cell_data, headers=["Parameter", "Value"], tablefmt="fancy_outline")) print(paragraph("Atom sites"))