Skip to content

Commit

Permalink
Merge pull request #261 from fitbenchmarking/242_sorting
Browse files Browse the repository at this point in the history
Moving data sorting from plotting to data input

(Merging as admin as Tyrone is away, but his requested changes were fixed)
  • Loading branch information
AndrewLister-STFC committed Nov 5, 2019
2 parents 5e0b8c2 + 40437eb commit c36f0de
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 21 deletions.
4 changes: 3 additions & 1 deletion fitbenchmarking/fitbenchmark_one_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,12 @@ def benchmark(controller, minimizers):

if chi_sq < min_chi_sq:
min_chi_sq = chi_sq
index = controller.sorted_index
best_fit = plot_helper.data(name=minimizer,
x=controller.data_x,
y=controller.results,
E=controller.data_e)
E=controller.data_e,
sorted_index=index)

individual_result = \
misc.create_result_entry(problem=controller.problem,
Expand Down
3 changes: 3 additions & 0 deletions fitbenchmarking/fitting/controllers/base_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def _correct_data(self):
if self.data_e is not None:
self.data_e = self.data_e[mask]

# Stores the indices of the sorted data
self.sorted_index = np.argsort(self.data_x)

def prepare(self):
"""
Check that function and minimizer have been set.
Expand Down
25 changes: 8 additions & 17 deletions fitbenchmarking/fitting/plotting/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ class data:
Holds all the data that is used in the plotting process.
"""

def __init__(self, name=None, x=[], y=[], E=[]):
def __init__(self, name=None, x=[], y=[], E=[], sorted_index=None):
"""
Creates a data object.
@param x :: the x data
@param y :: the y data
@param E :: the (y) errors
@param sorted_index :: the sorted indices from the x data
"""

if name is not None:
Expand All @@ -39,28 +40,19 @@ def __init__(self, name=None, x=[], y=[], E=[]):
self.E = np.zeros(len(self.x))
else:
self.E = copy.copy(E)

if sorted_index is not None:
self.x = self.x[sorted_index]
self.y = self.y[sorted_index]
self.z = self.E[sorted_index]

self.showError = False
self.markers = "x"
self.colour = "k"
self.linestyle = '--'
self.z_order = 1
self.linewidth = 1

def order_data(self):
"""
Ensures that the data is in ascending order in x.
Prevents line plots looping back on themselves.
"""

xData = self.x
yData = self.y
eData = self.E

index = np.argsort(xData)
xData = xData[index]
yData = yData[index]
eData = eData[index]


class plot(data):
"""
Expand Down Expand Up @@ -96,7 +88,6 @@ def make_scatter_plot(self, save=""):
self.set_plot_misc()
self.save_plot(save)


plt.close()

def set_plot_misc(self):
Expand Down
2 changes: 0 additions & 2 deletions fitbenchmarking/fitting/plotting/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def make_best_fit_plot(name, raw_data, best_fit, count, figures_dir):
best_fit.colour = 'lime'
best_fit.zorder = 2
best_fit.linewidth = 1.5
best_fit.order_data()
fig.add_data(best_fit)
fig.labels['y'] = "Arbitrary units"
fig.labels['x'] = "Time ($\mu s$)"
Expand Down Expand Up @@ -126,7 +125,6 @@ def make_starting_guess_plot(raw_data, problem, count, figures_dir):
xData = problem.data_x
yData = problem.eval_starting_params(0)
startData = data("Start Guess", xData, yData)
startData.order_data()
startData.colour = "red"
startData.markers = ''
startData.linestyle = "-"
Expand Down
1 change: 1 addition & 0 deletions fitbenchmarking/parsing/base_fitting_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# python3
from itertools import zip_longest as izip_longest


class BaseFittingProblem:
"""
Definition of a base class implementation of the fitting test problem,
Expand Down
2 changes: 1 addition & 1 deletion fitbenchmarking/parsing/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def parse_problem_file(prob_file):

prob_type = determine_problem_type(prob_file)
logger.info("Loading {0} formatted problem definition file {1} | Path: "
"{2}".format(prob_type,os.path.basename(prob_file),prob_file[prob_file.find('fitbenchmarking'):]))
"{2}".format(prob_type, os.path.basename(prob_file), prob_file[prob_file.find('fitbenchmarking'):]))

if prob_type == "NIST":

Expand Down

0 comments on commit c36f0de

Please sign in to comment.