forked from npshub/mantid
-
Notifications
You must be signed in to change notification settings - Fork 0
/
fit_information_test.py
215 lines (177 loc) · 9.28 KB
/
fit_information_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# Mantid Repository : https://github.com/mantidproject/mantid
#
# Copyright © 2018 ISIS Rutherford Appleton Laboratory UKRI,
# NScD Oak Ridge National Laboratory, European Spallation Source,
# Institut Laue - Langevin & CSNS, Institute of High Energy Physics, CAS
# SPDX - License - Identifier: GPL - 3.0 +
import unittest
from mantid.api import AnalysisDataService, WorkspaceFactory
from mantid.kernel import FloatTimeSeriesProperty, StringPropertyWithValue
from unittest import mock
from Muon.GUI.Common.contexts.fitting_context import FitInformation
def create_test_workspace(ws_name=None,
                          time_series_logs=None,
                          string_value_logs=None):
    """
    Build a single-spectrum Workspace2D carrying the requested sample logs and
    register it in the AnalysisDataService.

    :param ws_name: Optional name to register the workspace under; when omitted
                    'fitting_context_model_test' is used.
    :param time_series_logs: Optional iterable of (name, entries) pairs. Each
                             entry is either an indexable (time, value) pair or
                             a bare value; bare values are stamped with the
                             fixed time "2000-05-01T12:00:00".
    :param string_value_logs: Optional iterable of (name, value) pairs, each
                              added as a StringPropertyWithValue log.
    :return: The new workspace
    """
    target_name = 'fitting_context_model_test' if ws_name is None else ws_name

    workspace = WorkspaceFactory.create('Workspace2D', 1, 1, 1)
    sample_run = workspace.run()

    series_specs = time_series_logs if time_series_logs is not None else ()
    for log_name, entries in series_specs:
        series = FloatTimeSeriesProperty(log_name)
        for entry in entries:
            try:
                timestamp, reading = entry[0], entry[1]
            except TypeError:
                # Entry is not indexable (a plain number): use a fixed timestamp.
                timestamp, reading = "2000-05-01T12:00:00", entry
            series.addValue(timestamp, reading)
        sample_run.addProperty(log_name, series, replace=True)

    text_specs = string_value_logs if string_value_logs is not None else ()
    for log_name, text in text_specs:
        sample_run.addProperty(log_name,
                               StringPropertyWithValue(log_name, text),
                               replace=True)

    AnalysisDataService.Instance().addOrReplace(target_name, workspace)
    return workspace
class FitInformationTest(unittest.TestCase):
    """Tests for FitInformation: equality, global parameters and log access."""

    def tearDown(self):
        # Tests register workspaces in the ADS; clear it so tests stay independent.
        AnalysisDataService.Instance().clear()

    def test_equality_with_no_globals(self):
        fit_info = FitInformation(mock.MagicMock(), 'MuonGuassOsc',
                                  mock.MagicMock(), mock.MagicMock())
        self.assertEqual(fit_info, fit_info)

    def test_equality_with_globals(self):
        fit_info = FitInformation(mock.MagicMock(), 'MuonGuassOsc',
                                  mock.MagicMock(), mock.MagicMock(), ['A'])
        self.assertEqual(fit_info, fit_info)

    def test_inequality_with_globals(self):
        # Share every other constructor argument so that the global parameters
        # are the ONLY difference between the two objects; with fresh mocks per
        # object they would compare unequal regardless of the globals.
        # Globals are the 5th positional argument: with only four arguments
        # global_parameters is empty (see
        # test_empty_global_parameters_if_none_specified).
        parameter_workspace = mock.MagicMock()
        input_workspace = mock.MagicMock()
        output_workspace_names = mock.MagicMock()
        fit_info1 = FitInformation(parameter_workspace, 'MuonGuassOsc',
                                   input_workspace, output_workspace_names,
                                   ['A'])
        fit_info2 = FitInformation(parameter_workspace, 'MuonGuassOsc',
                                   input_workspace, output_workspace_names,
                                   ['B'])
        self.assertNotEqual(fit_info1, fit_info2)

    def test_empty_global_parameters_if_none_specified(self):
        fit_information_object = FitInformation(mock.MagicMock(),
                                                mock.MagicMock(),
                                                mock.MagicMock(),
                                                mock.MagicMock())
        self.assertEqual([],
                         fit_information_object.parameters.global_parameters)

    def test_global_parameters_are_captured(self):
        fit_information_object = FitInformation(mock.MagicMock(),
                                                mock.MagicMock(),
                                                mock.MagicMock(),
                                                mock.MagicMock(), ['A'])
        self.assertEqual(['A'],
                         fit_information_object.parameters.global_parameters)

    def test_parameters_are_readonly(self):
        fit_info = FitInformation(mock.MagicMock(), mock.MagicMock(),
                                  mock.MagicMock(), mock.MagicMock())
        self.assertRaises(AttributeError, setattr, fit_info, "parameters",
                          mock.MagicMock())

    def test_logs_from_workspace_without_logs_returns_emtpy_list(self):
        fake_ws = create_test_workspace()
        fit = FitInformation(mock.MagicMock(), 'func1', fake_ws.name(),
                             fake_ws.name())

        allowed_logs = fit.log_names()
        self.assertEqual(0, len(allowed_logs))

    def test_logs_for_single_workspace_return_all_time_series_logs(self):
        time_series_logs = (('ts_1', (1., )), ('ts_2', (3., )))
        single_value_logs = (('sv_1', 'val1'), ('sv_2', 'val2'))
        # NOTE(review): single_value_logs is never attached to fake_ws, so the
        # assertNotIn loop below cannot fail as written. Passing
        # string_value_logs=single_value_logs would make it meaningful, but
        # first confirm that log_names() excludes string-valued logs.
        fake_ws = create_test_workspace(time_series_logs=time_series_logs)
        fit = FitInformation(mock.MagicMock(), 'func1', fake_ws.name(),
                             fake_ws.name())

        log_names = fit.log_names()
        for name, _ in time_series_logs:
            self.assertIn(name, log_names,
                          msg="{} not found in log list".format(name))
        for name, _ in single_value_logs:
            self.assertNotIn(name, log_names,
                             msg="{} found in log list".format(name))

    def test_log_names_from_list_of_workspaces_gives_combined_set(self):
        time_series_logs = (('ts_1', (1., )), ('ts_2', (3., )), ('ts_3', [2.]),
                            ('ts_4', [3.]))
        fake1 = create_test_workspace(
            ws_name='fake1', time_series_logs=time_series_logs[:2])
        fake2 = create_test_workspace(
            ws_name='fake2', time_series_logs=time_series_logs[2:])
        fit = FitInformation(mock.MagicMock(), 'func1',
                             [fake1.name(), fake2.name()],
                             [fake1.name(), fake2.name()])

        log_names = fit.log_names()
        self.assertEqual(len(time_series_logs), len(log_names))
        for name, _ in time_series_logs:
            self.assertIn(name, log_names,
                          msg="{} not found in log list".format(name))

    def test_log_names_uses_filter_fn(self):
        time_series_logs = (('ts_1', (1., )), ('ts_2', (3., )), ('ts_3', [2.]),
                            ('ts_4', [3.]))
        fake1 = create_test_workspace(
            ws_name='fake1', time_series_logs=time_series_logs)
        fit = FitInformation(mock.MagicMock(), 'func1', fake1.name(),
                             fake1.name())

        log_names = fit.log_names(lambda log: log.name == 'ts_1')
        self.assertEqual(1, len(log_names))
        self.assertEqual(time_series_logs[0][0], log_names[0])

    def test_has_log_returns_true_if_all_workspaces_have_the_log(self):
        time_series_logs = (('ts_1', (1., )), ('ts_2', (3., )))
        fake1 = create_test_workspace(
            ws_name='fake1', time_series_logs=time_series_logs)
        fake2 = create_test_workspace(
            ws_name='fake2', time_series_logs=time_series_logs)
        fit = FitInformation(mock.MagicMock(), 'func1',
                             [fake1.name(), fake2.name()],
                             [fake1.name(), fake2.name()])

        self.assertTrue(fit.has_log('ts_1'))

    def test_has_log_returns_false_if_all_workspaces_do_not_have_log(self):
        time_series_logs = [('ts_1', (1., ))]
        fake1 = create_test_workspace(
            ws_name='fake1', time_series_logs=time_series_logs)
        fake2 = create_test_workspace(ws_name='fake2')
        fit = FitInformation(mock.MagicMock(), 'func1',
                             [fake1.name(), fake2.name()],
                             [fake1.name(), fake2.name()])

        self.assertFalse(
            fit.has_log('ts_1'),
            msg='All input workspaces should have the requested log')

    def test_string_log_value_from_fit_with_single_workspace(self):
        single_value_logs = [('sv_1', '5')]
        fake1 = create_test_workspace(
            ws_name='fake1', string_value_logs=single_value_logs)
        fit = FitInformation(mock.MagicMock(), 'func1', [fake1.name()],
                             [fake1.name()])

        self.assertEqual(
            float(single_value_logs[0][1]),
            fit.log_value(single_value_logs[0][0]))

    def test_time_series_log_value_from_fit_with_single_workspace_uses_time_average(
            self):
        time_series_logs = \
            [('ts_1', (("2000-05-01T12:00:00", 5.),
                       ("2000-05-01T12:00:10", 20.),
                       ("2000-05-01T12:05:00", 30.)))]
        fake1 = create_test_workspace(ws_name='fake1',
                                      time_series_logs=time_series_logs)
        fit = FitInformation(mock.MagicMock(), 'func1', [fake1.name()],
                             [fake1.name()])

        # 5. held for 10 s, then 20. held for 290 s; the final sample adds no
        # duration within the 300 s span covered by the timestamps.
        time_average = (10 * 5 + 290 * 20) / 300.
        self.assertAlmostEqual(time_average, fit.log_value('ts_1'), places=6)

    def test_time_series_log_value_from_fit_with_multiple_workspaces_uses_average_of_time_average(
            self):
        time_series_logs1 = \
            [('ts_1', (("2000-05-01T12:00:00", 5.),
                       ("2000-05-01T12:00:10", 20.),
                       ("2000-05-01T12:05:00", 30.)))]
        fake1 = create_test_workspace(ws_name='fake1',
                                      time_series_logs=time_series_logs1)
        time_series_logs2 = \
            [('ts_1', (("2000-05-01T12:00:30", 10.),
                       ("2000-05-01T12:01:45", 30.),
                       ("2000-05-01T12:05:00", 40.)))]
        fake2 = create_test_workspace(ws_name='fake2',
                                      time_series_logs=time_series_logs2)
        fit = FitInformation(mock.MagicMock(), 'func1',
                             [fake1.name(), fake2.name()],
                             [fake1.name(), fake2.name()])

        time_average1 = (10 * 5 + 290 * 20) / 300.
        time_average2 = (75 * 10 + 195 * 30) / 270.
        # The expected value is the plain mean of the per-workspace averages.
        all_average = 0.5 * (time_average1 + time_average2)
        self.assertAlmostEqual(all_average, fit.log_value('ts_1'), places=6)
# Allow running this module directly as a verbose test script.
if __name__ == "__main__":
    unittest.main(verbosity=2, buffer=False)