Skip to content

Commit

Permalink
Update ArithmeticFilter to use ps.run_compute_func
Browse files Browse the repository at this point in the history
  • Loading branch information
samtygier-stfc committed Jul 4, 2022
1 parent d04450c commit d703d22
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 27 deletions.
34 changes: 11 additions & 23 deletions mantidimaging/core/operations/arithmetic/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2022 ISIS Rutherford Appleton Laboratory UKRI
# SPDX - License - Identifier: GPL-3.0-or-later
from functools import partial
from typing import Callable, Dict
from typing import Callable, Dict, List

import numpy as np
from PyQt5.QtWidgets import QFormLayout, QWidget, QDoubleSpinBox
Expand All @@ -14,19 +14,6 @@
from mantidimaging.core.parallel import shared as ps


def _arithmetic_func(data: np.ndarray, div_val: float, mult_val: float, add_val: float, sub_val: float):
"""
Process target function for the arithmetic operation.
:param data: The data array.
:param div_val: The division value.
:param mult_val: The multiplication value.
:param add_val: The addition value.
:param sub_val: The subtraction value.
"""
for i in range(len(data)):
data[i] = data[i] / div_val * mult_val + add_val - sub_val


class ArithmeticFilter(BaseFilter):
"""Add, subtract, multiply, or divide all grey values of an image with the given values.
Expand All @@ -37,8 +24,9 @@ class ArithmeticFilter(BaseFilter):
"""
filter_name = "Arithmetic"

@staticmethod
def filter_func(images: ImageStack,
@classmethod
def filter_func(cls,
images: ImageStack,
div_val: float = 1.0,
mult_val: float = 1.0,
add_val: float = 0.0,
Expand All @@ -57,9 +45,15 @@ def filter_func(images: ImageStack,
if div_val == 0 or mult_val == 0:
raise ValueError("Unable to proceed with operation because division/multiplication value is zero.")

_execute(images, div_val, mult_val, add_val, sub_val, progress)
params = {'div': div_val, 'mult': mult_val, 'add': add_val, 'sub': sub_val}
ps.run_compute_func(cls.compute_function, images.data.shape[0], [images.shared_array], params, progress)

return images

@staticmethod
def compute_function(i: int, arrays: List[np.ndarray], params: Dict[str, float]):
arrays[0][i] = arrays[0][i] / params['div'] * params['mult'] + params['add'] - params['sub']

@staticmethod
def register_gui(form: 'QFormLayout', on_change: Callable, view: 'BaseMainWindowView') -> Dict[str, 'QWidget']:
_, mult_input_widget = add_property_to_form('Multiply',
Expand Down Expand Up @@ -112,9 +106,3 @@ def execute_wrapper( # type: ignore
div_val=div_input_widget.value(),
add_val=add_input_widget.value(),
sub_val=sub_input_widget.value())


def _execute(images: ImageStack, div_val: float, mult_val: float, add_val: float, sub_val: float, progress):
arg_list = [div_val, mult_val, add_val, sub_val]
do_arithmetic = ps.create_partial(_arithmetic_func, fwd_function=ps.arithmetic, arg_list=arg_list)
ps.execute(do_arithmetic, [images.shared_array], images.data.shape[0], progress)
4 changes: 0 additions & 4 deletions mantidimaging/core/parallel/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ def return_to_second_at_i(func, data: Union[List[pu.SharedArray], List[pu.Shared
data[1].array[i] = func(data[0].array[i], **kwargs)


def arithmetic(func, data: Union[List[pu.SharedArray], List[pu.SharedArrayProxy]], i, arg_list):
func(data[0].array[i], *arg_list)


def create_partial(func, fwd_function, **kwargs):
"""
Create a partial using functools.partial, to forward the kwargs to the
Expand Down

0 comments on commit d703d22

Please sign in to comment.