Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ jobs:
shell: bash
run: python -m pip install -r requirements.txt

- name: Run Python unit tests
shell: bash
run: PYTHONPATH=$(pwd)/src python -m pytest tests/unit_tests/ --color=yes -n auto

- name: Run Python functional tests
shell: bash
run: PYTHONPATH=$(pwd)/src python -m pytest tests/functional_tests/ --color=yes -n auto
96 changes: 58 additions & 38 deletions src/easydiffraction/analysis/analysis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations
import pandas as pd
import numpy as np
from tabulate import tabulate
from typing import List, Optional, Union

from easydiffraction.utils.formatting import (
paragraph,
Expand Down Expand Up @@ -30,19 +33,25 @@
class Analysis:
_calculator = CalculatorFactory.create_calculator('cryspy')

def __init__(self, project):
def __init__(self, project: Project) -> None:
self.project = project
self.aliases = ConstraintAliases()
self.constraints = ConstraintExpressions()
self.constraints_handler = ConstraintsHandler.get()
self.calculator = Analysis._calculator # Default calculator shared by project
self._calculator_key = 'cryspy' # Added to track the current calculator
self._fit_mode = 'single'
self._calculator_key: str = 'cryspy' # Added to track the current calculator
self._fit_mode: str = 'single'
self.fitter = DiffractionMinimizer('lmfit (leastsq)')

def _get_params_as_dataframe(self, params):
def _get_params_as_dataframe(self, params: List[Union[Descriptor, Parameter]]) -> pd.DataFrame:
"""
Convert a list of parameters to a DataFrame.

Args:
params: List of Descriptor or Parameter objects.

Returns:
A pandas DataFrame containing parameter information.
"""
rows = []
for param in params:
Expand Down Expand Up @@ -75,9 +84,13 @@ def _get_params_as_dataframe(self, params):

return dataframe

def _show_params(self, dataframe, column_headers):
""":
def _show_params(self, dataframe: pd.DataFrame, column_headers: List[str]) -> None:
"""
Display parameters in a tabular format.

Args:
dataframe: The pandas DataFrame containing parameter information.
column_headers: List of column headers to display.
"""
dataframe = dataframe[column_headers]
indices = range(1, len(dataframe) + 1) # Force starting from 1
Expand All @@ -87,7 +100,7 @@ def _show_params(self, dataframe, column_headers):
tablefmt="fancy_outline",
showindex=indices))

def show_all_params(self):
def show_all_params(self) -> None:
sample_models_params = self.project.sample_models.get_all_params()
experiments_params = self.project.experiments.get_all_params()

Expand All @@ -110,7 +123,7 @@ def show_all_params(self):
experiments_dataframe = self._get_params_as_dataframe(experiments_params)
self._show_params(experiments_dataframe, column_headers=column_headers)

def show_fittable_params(self):
def show_fittable_params(self) -> None:
sample_models_params = self.project.sample_models.get_fittable_params()
experiments_params = self.project.experiments.get_fittable_params()

Expand All @@ -135,7 +148,7 @@ def show_fittable_params(self):
experiments_dataframe = self._get_params_as_dataframe(experiments_params)
self._show_params(experiments_dataframe, column_headers=column_headers)

def show_free_params(self):
def show_free_params(self) -> None:
sample_models_params = self.project.sample_models.get_free_params()
experiments_params = self.project.experiments.get_free_params()
free_params = sample_models_params + experiments_params
Expand All @@ -158,7 +171,7 @@ def show_free_params(self):
dataframe = self._get_params_as_dataframe(free_params)
self._show_params(dataframe, column_headers=column_headers)

def how_to_access_parameters(self, show_description=False):
def how_to_access_parameters(self, show_description: bool = False) -> None:
sample_models_params = self.project.sample_models.get_all_params()
experiments_params = self.project.experiments.get_all_params()
params = {'sample_models': sample_models_params,
Expand Down Expand Up @@ -204,21 +217,20 @@ def how_to_access_parameters(self, show_description=False):
tablefmt="fancy_outline",
showindex=indices))


def show_current_calculator(self):
def show_current_calculator(self) -> None:
print(paragraph("Current calculator"))
print(self.current_calculator)

@staticmethod
def show_supported_calculators():
def show_supported_calculators() -> None:
CalculatorFactory.show_supported_calculators()

@property
def current_calculator(self):
def current_calculator(self) -> str:
return self._calculator_key

@current_calculator.setter
def current_calculator(self, calculator_name):
def current_calculator(self, calculator_name: str) -> None:
calculator = CalculatorFactory.create_calculator(calculator_name)
if calculator is None:
return
Expand All @@ -227,30 +239,30 @@ def current_calculator(self, calculator_name):
print(paragraph("Current calculator changed to"))
print(self.current_calculator)

def show_current_minimizer(self):
def show_current_minimizer(self) -> None:
print(paragraph("Current minimizer"))
print(self.current_minimizer)

@staticmethod
def show_available_minimizers():
def show_available_minimizers() -> None:
MinimizerFactory.show_available_minimizers()

@property
def current_minimizer(self):
def current_minimizer(self) -> Optional[str]:
return self.fitter.selection if self.fitter else None

@current_minimizer.setter
def current_minimizer(self, selection):
def current_minimizer(self, selection: str) -> None:
self.fitter = DiffractionMinimizer(selection)
print(paragraph(f"Current minimizer changed to"))
print(self.current_minimizer)

@property
def fit_mode(self):
def fit_mode(self) -> str:
return self._fit_mode

@fit_mode.setter
def fit_mode(self, strategy):
def fit_mode(self, strategy: str) -> None:
if strategy not in ['single', 'joint']:
raise ValueError("Fit mode must be either 'single' or 'joint'")
self._fit_mode = strategy
Expand All @@ -263,7 +275,7 @@ def fit_mode(self, strategy):
print(paragraph("Current fit mode changed to"))
print(self._fit_mode)

def show_available_fit_modes(self):
def show_available_fit_modes(self) -> None:
strategies = [
{
"Strategy": "single",
Expand All @@ -276,18 +288,26 @@ def show_available_fit_modes(self):
print(paragraph("Available fit modes"))
print(tabulate(strategies, headers="keys", tablefmt="fancy_outline", showindex=False))

def show_current_fit_mode(self):
print(paragraph("Current ffit mode"))
def show_current_fit_mode(self) -> None:
print(paragraph("Current fit mode"))
print(self.fit_mode)

def calculate_pattern(self, expt_name):
# Pattern is calculated for the given experiment
def calculate_pattern(self, expt_name: str) -> Optional[np.ndarray]:
"""
Calculate the diffraction pattern for a given experiment.

Args:
expt_name: The name of the experiment.

Returns:
The calculated pattern as a pandas DataFrame.
"""
experiment = self.project.experiments[expt_name]
sample_models = self.project.sample_models
calculated_pattern = self.calculator.calculate_pattern(sample_models, experiment)
return calculated_pattern

def show_constraints(self):
def show_constraints(self) -> None:
constraints_dict = self.constraints._items

if not self.constraints._items:
Expand All @@ -312,7 +332,7 @@ def show_constraints(self):
tablefmt="fancy_outline",
showindex=False))

def _update_uid_map(self):
def _update_uid_map(self) -> None:
"""
Update the UID map for accessing parameters by UID.
This is needed for adding or removing constraints.
Expand All @@ -323,7 +343,7 @@ def _update_uid_map(self):

UidMapHandler.get().set_uid_map(params)

def apply_constraints(self):
def apply_constraints(self) -> None:
if not self.constraints._items:
print(warning(f"No constraints defined."))
return
Expand All @@ -337,7 +357,7 @@ def apply_constraints(self):
self.constraints_handler.set_expressions(self.constraints)
self.constraints_handler.apply(parameters=fittable_params)

def show_calc_chart(self, expt_name, x_min=None, x_max=None):
def show_calc_chart(self, expt_name: str, x_min: Optional[float] = None, x_max: Optional[float] = None) -> None:
self.calculate_pattern(expt_name)

experiment = self.project.experiments[expt_name]
Expand All @@ -354,11 +374,11 @@ def show_calc_chart(self, expt_name, x_min=None, x_max=None):
)

def show_meas_vs_calc_chart(self,
expt_name,
x_min=None,
x_max=None,
show_residual=False,
chart_height=DEFAULT_HEIGHT):
expt_name: str,
x_min: Optional[float] = None,
x_max: Optional[float] = None,
show_residual: bool = False,
chart_height: int = DEFAULT_HEIGHT) -> None:
experiment = self.project.experiments[expt_name]

self.calculate_pattern(expt_name)
Expand Down Expand Up @@ -387,7 +407,7 @@ def show_meas_vs_calc_chart(self,
labels=labels
)

def fit(self):
def fit(self) -> None:
sample_models = self.project.sample_models
if not sample_models:
print("No sample models found in the project. Cannot run fit.")
Expand Down Expand Up @@ -422,15 +442,15 @@ def fit(self):
# After fitting, get the results
self.fit_results = self.fitter.results

def as_cif(self):
def as_cif(self) -> str:
lines = []
lines.append(f"_analysis.calculator_engine {self.current_calculator}")
lines.append(f"_analysis.fitting_engine {self.current_minimizer}")
lines.append(f"_analysis.fit_mode {self.fit_mode}")

return "\n".join(lines)

def show_as_cif(self):
def show_as_cif(self) -> None:
cif_text = self.as_cif()
lines = cif_text.splitlines()
max_width = max(len(line) for line in lines)
Expand Down
22 changes: 13 additions & 9 deletions src/easydiffraction/analysis/calculation.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
from typing import Any, Optional, List
import numpy as np
from .calculators.calculator_factory import CalculatorFactory

from easydiffraction.sample_models.sample_models import SampleModels
from easydiffraction.experiments.experiments import Experiments
from easydiffraction.experiments.experiment import Experiment

class DiffractionCalculator:
"""
Invokes calculation engines for pattern generation.
"""

def __init__(self, engine='cryspy'):
def __init__(self, engine: str = 'cryspy') -> None:
"""
Initialize the DiffractionCalculator with a specified backend engine.

Args:
calculator_type (str): Type of the calculation engine to use.
Supported types: 'crysfml', 'cryspy', 'pdffit'.
Default is 'crysfml'.
engine: Type of the calculation engine to use.
Supported types: 'crysfml', 'cryspy', 'pdffit'.
Default is 'cryspy'.
"""
self.calculator_factory = CalculatorFactory()
self._calculator = self.calculator_factory.create_calculator(engine)

def set_calculator(self, engine):
def set_calculator(self, engine: str) -> None:
"""
Switch to a different calculator engine at runtime.

Args:
engine (str): New calculation engine type to use.
engine: New calculation engine type to use.
"""
self._calculator = self.calculator_factory.create_calculator(engine)

def calculate_structure_factors(self, sample_models, experiments):
def calculate_structure_factors(self, sample_models: SampleModels, experiments: Experiments) -> Optional[List[Any]]:
"""
Calculate HKL intensities (structure factors) for sample models and experiments.

Expand All @@ -40,7 +44,7 @@ def calculate_structure_factors(self, sample_models, experiments):
"""
return self._calculator.calculate_structure_factors(sample_models, experiments)

def calculate_pattern(self, sample_models, experiment):
def calculate_pattern(self, sample_models: SampleModels, experiment: Experiment) -> np.ndarray:
"""
Generate diffraction pattern based on sample models and experiment.

Expand Down
Loading