Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ImageStack Class for loading, saving, and managing image stacks. #2165

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 131 additions & 113 deletions scripts/operations_tests/operations_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,27 @@ def __bool__(self):
return self.status == "pass"




class ImageStack:
def __init__(self, filepath: str):
self.filepath = filepath
self.data = self.load_image_stack()

def load_image_stack(self):
filename_group = FilenameGroup.from_file(self.filepath)
filename_group.find_all_files()
image_stack = loader.load(filename_group=filename_group)
return image_stack

def save(self, filepath: str):
np.savez(filepath, self.data)

@classmethod
def load_from_file(cls, filepath: str) -> ImageStack:
return cls(filepath)


def process_params(param):
"""
Handle parameter values that cannot be encoded natively in json
Expand All @@ -88,7 +109,7 @@ def process_params(param):
return param


def compare_mode():
def compare_mode(args):
for operation, test_case_info in TEST_CASES.items():
print(f"Running tests for {operation}:")
cases = test_case_info["cases"]
Expand All @@ -102,13 +123,14 @@ def compare_mode():
op_class = FILTERS[operation]
op_func = op_class.filter_func
test_case = TestCase(operation, test_name, sub_test_name, test_number, params, op_func)
run_test(test_case)
run_test(test_case, args) # Pass args here
print("\n")

print_compare_mode_results()
print_compare_mode_results(args)



def print_compare_mode_results():
def print_compare_mode_results(args):
print(f"{'=' * 40}RESULTS{'=' * 40}")
failures = 0
passes = 0
Expand All @@ -133,9 +155,10 @@ def print_compare_mode_results():
print(f"{'=' * 42}END{'=' * 42}")


def time_mode(runs):

def time_mode(args, runs):
durations = defaultdict(list)
image_stack = load_image_stack()
image_stack_manager = ImageStackManager(ImageStack.load_from_file(config_manager.load_sample))
for operation, test_case_info in TEST_CASES.items():
print(f"Running tests for {operation}:")
cases = test_case_info["cases"]
Expand All @@ -145,8 +168,8 @@ def time_mode(runs):
params = case["params"] | test_case_info["params"]
op_func = FILTERS[operation].filter_func
for _ in range(runs):
image_stack2 = image_stack.copy()
duration = time_operation(image_stack2, op_func, params)[0]
image_stack_manager.load_image_stack() # Load the image stack for each run
duration = image_stack_manager.time_operation(op_func, params)
durations[test_name].append(duration)

print_time_mode_results(durations)
Expand All @@ -168,114 +191,110 @@ def print_time_mode_results(durations):
avg,
quickest,
slowest,
stdev(times),
np.std(times),
GIT_TOKEN,
COMMIT_DATE,
)
writer.writerow(data)

print(f"{'=' * 42}END{'=' * 42}")


def time_operation(image_stack, op_func, params):
start = time.perf_counter()
image_stack = run_operation(image_stack, op_func, params)
duration = time.perf_counter() - start
return duration, image_stack


def run_test(test_case):
image_stack = load_image_stack()
test_case.duration, new_image_stack = time_operation(image_stack, test_case.op_func, test_case.params)
file_name = config_manager.save_dir / (test_case.test_name + ".npz")

if file_name.is_file():
baseline_image_stack = load_post_operation_image_stack(file_name)
compare_image_stacks(baseline_image_stack, new_image_stack.data, test_case)

if test_case.status == "pass":
print(".", end="")
elif test_case.status == "fail":
print("F", end="")
def run_test(test_case, args):
image_stack_manager = ImageStackManager(ImageStack(config_manager.load_sample))
image_stack_manager.run_test(test_case, args)

class ImageStackManager:
def __init__(self, image_stack: ImageStack):
self.image_stack = image_stack

def time_operation(self, op_func, params):
start = time.perf_counter()
self.image_stack.data = self.run_operation(op_func, params)
duration = time.perf_counter() - start
return duration

def run_operation(self, op_func, params):
op_func(self.image_stack.data, **params)
return self.image_stack.data

def compare_image_stacks(self, baseline_image_stack, new_image_stack, test_case, args):
if not (isinstance(baseline_image_stack, np.ndarray) and isinstance(new_image_stack, np.ndarray)):
test_case.status = "fail"
test_case.message = "new image stack is not an array"
elif baseline_image_stack.shape != new_image_stack.shape:
test_case.status = "fail"
test_case.message = "new image stack is different shape to the baseline"
elif baseline_image_stack.dtype != new_image_stack.dtype:
test_case.status = "fail"
test_case.message = "new image stack is different dtype to the baseline"
elif not np.array_equal(baseline_image_stack, new_image_stack):
test_case.status = "fail"
test_case.message = "arrays are not equal"
if args.gui:
self.gui_compare_image_stacks(baseline_image_stack, new_image_stack)
else:
print("?", end="")
test_case.status = "unknown"
else:
print("X", end="")
test_case.status = "new baseline"
save_image_stack(file_name, new_image_stack)

TEST_CASE_RESULTS.append(test_case)


def run_operation(image_stack, op_func, params):
op_func(image_stack, **params)
return image_stack


def save_image_stack(filepath, image_stack):
np.savez(filepath, image_stack.data)


def load_post_operation_image_stack(filepath):
return np.load(filepath)["arr_0"]


def load_image_stack():
filename_group = FilenameGroup.from_file(config_manager.load_sample)
filename_group.find_all_files()
image_stack = loader.load(filename_group=filename_group)
return image_stack


def compare_image_stacks(baseline_image_stack, new_image_stack, test_case):
if not (isinstance(baseline_image_stack, np.ndarray) and isinstance(new_image_stack, np.ndarray)):
test_case.status = "fail"
test_case.message = "new image stack is not an array"
elif baseline_image_stack.shape != new_image_stack.shape:
test_case.status = "fail"
test_case.message = "new image stack is different shape to the baseline"
elif baseline_image_stack.dtype != new_image_stack.dtype:
test_case.status = "fail"
test_case.message = "new image stack is different dtype to the baseline"
elif not np.array_equal(baseline_image_stack, new_image_stack):
test_case.status = "fail"
test_case.message = "arrays are not equal"
if args.gui:
gui_compare_image_stacks(baseline_image_stack, new_image_stack)
else:
test_case.status = "pass"
test_case.message = "arrays are equal"


def gui_compare_image_stacks(baseline_image_stack, new_image_stack):
from PyQt5.QtWidgets import QApplication, QWidget, QHBoxLayout
from mantidimaging.gui.widgets.mi_image_view.view import MIImageView
app = QApplication([])
win = QWidget()
win.resize(600, 900)
layout = QHBoxLayout()
win.setLayout(layout)

imvs = []
for name, data in (
("Baseline", baseline_image_stack),
("New", new_image_stack),
("Diff", new_image_stack - baseline_image_stack),
):
imv = MIImageView()
imv.name = name
imv.enable_nan_check(True)
imv.setImage(data)
layout.addWidget(imv)
imvs.append(imv)

imvs[0].sigTimeChanged.connect(imvs[1].setCurrentIndex)
imvs[0].sigTimeChanged.connect(imvs[2].setCurrentIndex)

win.show()
app.exec()

test_case.status = "pass"
test_case.message = "arrays are equal"

def run_test(self, test_case, args):
test_case.duration = self.time_operation(test_case.op_func, test_case.params)
file_name = config_manager.save_dir / (test_case.test_name + ".npz")

if file_name.is_file():
baseline_image_stack = np.load(file_name)["arr_0"]
self.compare_image_stacks(baseline_image_stack, self.image_stack.data, test_case, args)

if test_case.status == "pass":
print(".", end="")
elif test_case.status == "fail":
print("F", end="")
else:
print("?", end="")
test_case.status = "unknown"
else:
print("X", end="")
test_case.status = "new baseline"
self.image_stack.save(file_name)

def save_image_stack(self, filepath):
np.savez(filepath, self.image_stack.data)

def load_post_operation_image_stack(self, filepath):
return np.load(filepath)["arr_0"]

def load_image_stack(self):
filename_group = FilenameGroup.from_file(config_manager.load_sample)
filename_group.find_all_files()
image_stack = loader.load(filename_group=filename_group)
return image_stack

def gui_compare_image_stacks(self, baseline_image_stack, new_image_stack):
from PyQt5.QtWidgets import QApplication, QWidget, QHBoxLayout
from mantidimaging.gui.widgets.mi_image_view.view import MIImageView
app = QApplication([])
win = QWidget()
win.resize(600, 900)
layout = QHBoxLayout()
win.setLayout(layout)

imvs = []
for name, data in (
("Baseline", baseline_image_stack),
("New", new_image_stack),
("Diff", new_image_stack - baseline_image_stack),
):
imv = MIImageView()
imv.name = name
imv.enable_nan_check(True)
imv.setImage(data)
layout.addWidget(imv)
imvs.append(imv)

imvs[0].sigTimeChanged.connect(imvs[1].setCurrentIndex)
imvs[0].sigTimeChanged.connect(imvs[2].setCurrentIndex)

win.show()
app.exec()

def create_plots():
df = pd.read_csv("timings.csv", parse_dates=["commit_date"])
Expand Down Expand Up @@ -330,18 +349,17 @@ def main():
parser.add_argument("-k", dest="match", type=str, help="only run tests which match the given substring expression")
parser.add_argument("--gui", dest="gui", action="store_true", help="Show GUI comparison for differences")

global args
args = parser.parse_args()

if args.mode == "time":
time_mode(args.runs)
time_mode(args, args.runs)

if args.mode == "compare":
compare_mode()
elif args.mode == "compare":
compare_mode(args)

if args.graphs:
create_plots()


if __name__ == "__main__":
main()
main()
Loading