diff --git a/README.md b/README.md
index 780cde070..5455f8d60 100644
--- a/README.md
+++ b/README.md
@@ -28,10 +28,10 @@ For help of how to use the command line/terminal, click the hyperlink correspond
 The above step is done to ensure that the compatible version of docutils packages (version 0.12) is installed.
 
-7. Finally, in this terminal, run `example_scripts/example_runScript.py`, located in the fitbenchmarking folder. This example script fit benchmarks Mantid using all the available minimizers. The resulting tables can be found in `example_scripts/results`.
+7. Finally, in this terminal, run `example_scripts/example_runScript_mantid.py`, located in the fitbenchmarking folder. This example script fit benchmarks Mantid using all the available minimizers. The resulting tables can be found in `example_scripts/results`.
 
 ## FitBenchmarking Scipy
-The `example_runScripts.py` file can be changed such that it benchmarks minimizers supported by scipy instead of mantid (details provided in the file itself).
+The `example_runScripts.py` file is designed to benchmark minimizers supported by software/libraries that provide a straightforward cross-platform Python install; as of now this means SciPy (more details are provided in the file itself).
 
 For this to work scipy version 0.17 or higher is needed (which includes needed [curve_fit](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html) support).
 **The Linux distributions we have tested against so far have all included scipy 0.17+ (0.17 is from Feb 2016).**
@@ -54,6 +54,16 @@ Mantid on Windows is shipped with Python. The above steps can also be done from
 terminal, in which case please ensure that you are upgrading against Python installed
 with Mantid, which by default is located in `C:\MantidInstall\bin`.
 
+## FitBenchmarking SasView
+The `example_runScripts_SasView.py` file is designed to benchmark minimizers supported by SasView (Bumps).
+
+In order to do so, Bumps, sasmodels, lxml and sascalc need to be installed. Bumps, sasmodels and lxml can be installed via `pip`. However, as of this writing, sascalc is not an independent package and therefore cannot be installed via `pip`; sascalc is instead bundled with FitBenchmarking under the folder `fitbenchmarking/sas`.
+
+To install Bumps, sasmodels and lxml, run the following commands in a terminal:
+1. `python -m pip install bumps`
+2. `python -m pip install sasmodels`
+3. `python -m pip install lxml`
+
 ## Description
 The tool creates a table/tables that shows a comparison between the different minimizers available in a fitting software (e.g. scipy or mantid), based on their accuracy and/or runtimes.
 An example of a table is:
diff --git a/benchmark_problems/1D_data/data_files/cyl_400_20.txt b/benchmark_problems/1D_data/data_files/cyl_400_20.txt
new file mode 100644
index 000000000..e321e3401
--- /dev/null
+++ b/benchmark_problems/1D_data/data_files/cyl_400_20.txt
@@ -0,0 +1,22 @@
+
+0 -1.#IND
+0.025 125.852
+0.05 53.6662
+0.075 26.0733
+0.1 11.8935
+0.125 4.61714
+0.15 1.29983
+0.175 0.171347
+0.2 0.0417614
+0.225 0.172719
+0.25 0.247876
+0.275 0.20301
+0.3 0.104599
+0.325 0.0285595
+0.35 0.00213344
+0.375 0.0137511
+0.4 0.0312374
+0.425 0.0350328
+0.45 0.0243172
+0.475 0.00923067
+0.5 0.00121297
diff --git a/benchmark_problems/1D_data/data_files/cyl_400_40.txt b/benchmark_problems/1D_data/data_files/cyl_400_40.txt
new file mode 100755
index 000000000..b533fa18e
--- /dev/null
+++ b/benchmark_problems/1D_data/data_files/cyl_400_40.txt
@@ -0,0 +1,56 @@
+
+0 -1.#IND
+0.00925926 1246.59
+0.0185185 612.143
+0.0277778 361.142
+0.037037 211.601
+0.0462963 122.127
+0.0555556 65.2385
+0.0648148 30.8914
+0.0740741 12.4737
+0.0833333 3.51371
+0.0925926 0.721835
+0.101852 0.583607
+0.111111 1.31084
+0.12037 1.9432
+0.12963 1.94286
+0.138889 1.58912
+0.148148 0.987076
+0.157407 0.456678
+0.166667 0.147595
+0.175926 0.027441
+0.185185 0.0999575
+0.194444 0.198717
+0.203704 0.277667
+0.212963 0.288172
+0.222222 0.220056
+0.231481 0.139378
+0.240741 0.0541106
+0.25 0.0140158
+0.259259 0.0132187
+0.268519 0.0336301
+0.277778 0.0672911
+0.287037 0.0788983
+0.296296 0.0764438
+0.305556 0.0555445
+0.314815 0.0280548
+0.324074 0.0111798
+0.333333 0.00156156
+0.342593 0.00830883
+0.351852 0.0186266
+0.361111 0.0275426
+0.37037 0.03192
+0.37963 0.0255329
+0.388889 0.0175216
+0.398148 0.0073075
+0.407407 0.0016631
+0.416667 0.00224153
+0.425926 0.0051335
+0.435185 0.0112914
+0.444444 0.0138209
+0.453704 0.0137453
+0.462963 0.0106682
+0.472222 0.00532472
+0.481481 0.00230646
+0.490741 0.000335344
+0.5 0.00177224
diff --git a/benchmark_problems/1D_data/prob_def_1.txt b/benchmark_problems/1D_data/prob_def_1.txt
new file mode 100644
index 000000000..19cd39202
--- /dev/null
+++ b/benchmark_problems/1D_data/prob_def_1.txt
@@ -0,0 +1,6 @@
+# An example data set for SasView 1D data
+name = 'Problem Def 1'
+input_file = 'cyl_400_20.txt'
+function ='name=cylinder,radius=35.0,length=350.0,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0'
+parameter_ranges = 'radius.range(1,50);length.range(1,500)'
+description = ''
diff --git a/benchmark_problems/1D_data/prob_def_2.txt b/benchmark_problems/1D_data/prob_def_2.txt
new file mode 100644
index 000000000..b5d76e0b3
--- /dev/null
+++ b/benchmark_problems/1D_data/prob_def_2.txt
@@ -0,0 +1,6 @@
+# An example data set for SasView 1D data
+name = 'Problem Def 2'
+input_file = 'cyl_400_40.txt'
+function ='name=cylinder,radius=35.0,length=350.0,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0'
+parameter_ranges = 'radius.range(1,50);length.range(1,500)'
+description = ''
\ No newline at end of file
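For reference (not part of the patch), the problem-definition files above use a simple `key = 'value'` format. A minimal, self-contained sketch of reading one into a dictionary — `parse_prob_def` is a hypothetical helper mirroring the `get_data_problem_entries` method added later in this patch, and the path assumes you run from the repository root:

```python
def parse_prob_def(path):
    entries = {}
    with open(path) as handle:
        for line in handle:
            line = line.partition('#')[0].rstrip()  # discard comments
            if not line:
                continue
            key, value = line.split('=', 1)
            entries[key.strip()] = value.strip().strip("'")
    return entries

entries = parse_prob_def('benchmark_problems/1D_data/prob_def_1.txt')
# entries['function'] -> 'name=cylinder,radius=35.0,length=350.0,...'
equation, starting_values = entries['function'].split(',', 1)
```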
diff --git a/example_scripts/SasView_example.py b/example_scripts/SasView_example.py
new file mode 100644
index 000000000..64684dc13
--- /dev/null
+++ b/example_scripts/SasView_example.py
@@ -0,0 +1,87 @@
+from sasmodels.core import load_model
+from sasmodels.bumps_model import Model, Experiment
+from sasmodels.data import load_data, empty_data1D, Data1D
+
+from sasmodels.models.broad_peak import Iq
+
+from bumps.names import *
+from bumps.fitters import fit
+from bumps.formatnum import format_uncertainty
+
+import matplotlib.pyplot as plt
+
+import os
+
+current_path = os.path.realpath(__file__)
+dir_path = os.path.dirname(current_path)
+main_dir = os.path.dirname(dir_path)
+oneD_data_dir = os.path.join(main_dir, 'benchmark_problems', '1D_data', 'data_files', 'cyl_400_20.txt')
+
+test_data = load_data(oneD_data_dir)
+# We set some errors for demonstration
+test_data.dy = 0.2*test_data.y
+
+data_1D = Data1D(x=test_data.x, y=test_data.y, dy=test_data.dy)
+
+kernel = load_model('cylinder')
+# kernel = load_model('broad_peak')
+
+pars = dict(radius=35,
+            length=350,
+            background=0.0,
+            scale=1.0,
+            sld=4.0,
+            sld_solvent=1.0)
+
+model = Model(kernel, **pars)
+
+# SET THE FITTING PARAMETERS
+model.radius.range(1, 50)
+model.length.range(1, 500)
+
+# M = Experiment(data=data_1D, model=model)
+M = Experiment(data=test_data, model=model)
+
+param_initial = M.parameters()
+radius_initial = param_initial['radius']
+
+problem = FitProblem(M)
+
+print("Initial chisq", problem.chisq_str())
+result = fit(problem, method='dream')
+
+print("Final chisq", problem.chisq_str())
+for k, v, dv in zip(problem.labels(), result.x, result.dx):
+    print(k, ":", format_uncertainty(v, dv))
+
+# problem.plot()
+# plt.show()
\ No newline at end of file
diff --git a/example_scripts/example_runScripts.py b/example_scripts/example_runScripts.py
index 1a3591600..e31922365 100644
--- a/example_scripts/example_runScripts.py
+++ b/example_scripts/example_runScripts.py
@@ -51,6 +51,7 @@
 benchmark_probs_dir = os.path.join(fitbenchmarking_folder,
                                    'benchmark_problems')
 
+
 """
 Modify results_dir to specify where the results of the fit should be saved
 If left as None, they will be saved in a "results" folder in the working dir
diff --git a/example_scripts/example_runScripts_SasView.py b/example_scripts/example_runScripts_SasView.py
new file mode 100644
index 000000000..31da5c2ac
--- /dev/null
+++ b/example_scripts/example_runScripts_SasView.py
@@ -0,0 +1,126 @@
+
+
+from __future__ import (absolute_import, division, print_function)
+import os
+import sys
+
+# Avoid reaching the maximum recursion depth by setting recursion limit
+# This is useful when running multiple data set benchmarking
+# Otherwise recursion limit is reached and the interpreter throws an error
+sys.setrecursionlimit(10000)
+
+# Insert path to where the scripts are located, relative to
+# the example_scripts folder
+current_path = os.path.dirname(os.path.realpath(__file__))
+fitbenchmarking_folder = os.path.abspath(os.path.join(current_path, os.pardir))
+scripts_folder = os.path.join(fitbenchmarking_folder, 'fitbenchmarking')
+sys.path.insert(0, scripts_folder)
+sys.path.insert(1, fitbenchmarking_folder)
+
+try:
+    import bumps
+except ImportError:
+    print('******************************************\n'
+          'Bumps is not yet installed on your computer\n'
+          'To install, type the following command:\n'
+          'python -m pip install bumps\n'
+          '******************************************')
+    sys.exit()
+
+try:
+    import sasmodels.data
+except ImportError:
+    print('******************************************\n'
+          'sasmodels is not yet installed on your computer\n'
+          'To install, type the following command:\n'
+          'python -m pip install sasmodels\n'
+          '******************************************')
+    sys.exit()
+
+try:
+    import sas
+except ImportError:
+    print('******************************************\n'
+          'sas is not yet installed on your computer\n'
+          'To install, clone a version of SasView from https://github.com/SasView/sasview\n'
+          'After that, copy the folder called "sas" from the sub-folder sasview/src to the fitbenchmarking directory\n'
+          '******************************************')
+    sys.exit()
+
+from fitting_benchmarking import do_fitting_benchmark as fitBenchmarking
+from results_output import save_results_tables as printTables
+
+# SPECIFY THE SOFTWARE/PACKAGE CONTAINING THE MINIMIZERS YOU WANT TO BENCHMARK
+# software = 'mantid'
+software = 'sasview'
+software_options = {'software': software}
+
+# User defined minimizers
+custom_minimizers = {"mantid": ["BFGS", "Simplex"],
+                     "scipy": ["lm", "trf", "dogbox"],
+                     "sasview": ["amoeba"]}
+# custom_minimizers = None
+# Available sasview minimizers: "amoeba", "lm", "newton", "de", "pt", "mp"
+
+
+# SPECIFY THE MINIMIZERS YOU WANT TO BENCHMARK, AND AS A MINIMUM FOR THE SOFTWARE YOU SPECIFIED ABOVE
+if len(sys.argv) > 1:
+    # Read custom minimizer options from file
+    software_options['minimizer_options'] = current_path + sys.argv[1]
+elif custom_minimizers:
+    # Custom minimizer options:
+    software_options['minimizer_options'] = custom_minimizers
+else:
+    # Using default minimizers from
+    # fitbenchmarking/fitbenchmarking/minimizers_list_default.json
+    software_options['minimizer_options'] = None
+
+
+# Benchmark problem directories
+benchmark_probs_dir = os.path.join(fitbenchmarking_folder,
+                                   'benchmark_problems')
+
+"""
+Modify results_dir to specify where the results of the fit should be saved
+If left as None, they will be saved in a "results" folder in the working dir
+If the full path is not given results_dir is created relative to the working dir
+"""
+results_dir = None
+
+# Whether to use errors in the fitting process
+use_errors = True
+
+# Parameters controlling how the final tables are colored
+# e.g. lower than 1.1 -> light yellow, higher than 3 -> dark red
+# Change these values to suit your needs
+color_scale = [(1.1, 'ranking-top-1'),
+               (1.33, 'ranking-top-2'),
+               (1.75, 'ranking-med-3'),
+               (3, 'ranking-low-4'),
+               (float('nan'), 'ranking-low-5')]
+
+# ADD WHICH PROBLEM SETS TO TEST AGAINST HERE
+# Do this, in this example file, by selecting sub-folders in benchmark_probs_dir
+# "Muon_data" works for mantid minimizers
+# problem_sets = ["Neutron_data", "NIST/average_difficulty"]
+# problem_sets = ["CUTEst", "Muon_data", "Neutron_data", "NIST/average_difficulty", "NIST/high_difficulty", "NIST/low_difficulty"]
+problem_sets = ["1D_data"]
+for sub_dir in problem_sets:
+    # generate group label/name used for problem set
+    label = sub_dir.replace('/', '_')
+
+    # Problem data directory
+    data_dir = os.path.join(benchmark_probs_dir, sub_dir)
+
+    print('\nRunning the benchmarking on the {} problem set\n'.format(label))
+    results_per_group, results_dir = fitBenchmarking(group_name=label, software_options=software_options,
+                                                     data_dir=data_dir,
+                                                     use_errors=use_errors, results_dir=results_dir)
+
+    print('\nProducing output for the {} problem set\n'.format(label))
+    for idx, group_results in enumerate(results_per_group):
+        # Display the runtime and accuracy results in a table
+        printTables(software_options, group_results,
+                    group_name=label, use_errors=use_errors,
+                    color_scale=color_scale, results_dir=results_dir)
+
+    print('\nCompleted benchmarking for {} problem set\n'.format(sub_dir))
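For orientation (not part of the patch): each `color_scale` entry appears to be an upper bound on a minimizer's accuracy/runtime relative to the best in the row, mapped to a CSS ranking class. A self-contained sketch of that bucketing, with a hypothetical helper name:

```python
import math

def ranking_class(ratio, color_scale):
    # Return the class for the first threshold the ratio falls under;
    # the final NaN entry acts as a catch-all for larger (or NaN) ratios.
    for threshold, css_class in color_scale:
        if math.isnan(threshold) or ratio <= threshold:
            return css_class

color_scale = [(1.1, 'ranking-top-1'), (1.33, 'ranking-top-2'),
               (1.75, 'ranking-med-3'), (3, 'ranking-low-4'),
               (float('nan'), 'ranking-low-5')]
print(ranking_class(1.2, color_scale))  # -> 'ranking-top-2'
```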
diff --git a/example_scripts/example_runScripts_expert.py b/example_scripts/example_runScripts_expert.py
index 8f0b8b714..bb12c1d99 100644
--- a/example_scripts/example_runScripts_expert.py
+++ b/example_scripts/example_runScripts_expert.py
@@ -31,7 +31,7 @@ from resproc import visual_pages
 
 # SPECIFY THE SOFTWARE/PACKAGE CONTAINING THE MINIMIZERS YOU WANT TO BENCHMARK
-software = ['mantid', 'scipy']
+software = ['scipy']
 software_options = {'software': software}
 
 # User defined minimizers
@@ -80,7 +80,7 @@
 # Do this, in this example file, by selecting sub-folders in benchmark_probs_dir
 # "Muon_data" works for mantid minimizers
 # problem_sets = ["Neutron_data", "NIST/average_difficulty"]
-problem_sets = ["Neutron_data"]
+problem_sets = ["CUTEst"]
 for sub_dir in problem_sets:
     # generate group group_name/name used for problem set
     group_name = sub_dir.replace('/', '_')
diff --git a/example_scripts/example_runScripts_mantid.py b/example_scripts/example_runScripts_mantid.py
index 0b5be16f4..7338647f8 100644
--- a/example_scripts/example_runScripts_mantid.py
+++ b/example_scripts/example_runScripts_mantid.py
@@ -37,13 +37,15 @@ from results_output import save_results_tables as printTables
 
 # SPECIFY THE SOFTWARE/PACKAGE CONTAINING THE MINIMIZERS YOU WANT TO BENCHMARK
-software = 'mantid'
+software = "mantid"
 software_options = {'software': software}
 
 # User defined minimizers
-# custom_minimizers = {"mantid": ["BFGS", "Simplex"],
-                     # "scipy": ["lm", "trf", "dogbox"]}
-custom_minimizers = None
+custom_minimizers = {"mantid": ["Simplex"],
+                     "scipy": ["lm", "trf", "dogbox"]}
+# custom_minimizers = None
+# "BFGS" is currently excluded from the mantid list
 
 
 # SPECIFY THE MINIMIZERS YOU WANT TO BENCHMARK, AND AS A MINIMUM FOR THE SOFTWARE YOU SPECIFIED ABOVE
@@ -86,7 +88,7 @@
 # Do this, in this example file, by selecting sub-folders in benchmark_probs_dir
 # "Muon_data" works for mantid minimizers
 # problem_sets = ["CUTEst", "Muon_data", "Neutron_data", "NIST/average_difficulty", "NIST/high_difficulty", "NIST/low_difficulty"]
-problem_sets = ['Muon_data', 'CUTEst']
+problem_sets = ['1D_data']
 for sub_dir in problem_sets:
 
     # generate group label/name used for problem set
diff --git a/fitbenchmarking/fitbenchmark_one_problem.py b/fitbenchmarking/fitbenchmark_one_problem.py
index 4f333b703..97b6e7a3b 100644
--- a/fitbenchmarking/fitbenchmark_one_problem.py
+++ b/fitbenchmarking/fitbenchmark_one_problem.py
@@ -58,8 +58,11 @@ def fitbm_one_prob(user_input, problem):
     for function in function_definitions:
         # Ad hoc exception for running the scipy script
         # scipy does not currently support the GEM problem
-        if 'GEM' in problem.name and user_input.software == 'scipy':
-            break
+        # if 'GEM' in problem.name and user_input.software == 'scipy':
+        #     break
+
+        # if 'GEM' in problem.name and user_input.software == 'sasview':
+        #     break
 
         results_problem, best_fit = \
             fit_one_function_def(user_input.software, problem, data_struct,
@@ -103,5 +106,9 @@ def fit_one_function_def(software, problem, data_struct, function, minimizers,
         from fitting.scipy.main import benchmark
         return benchmark(problem, data_struct, function,
                          minimizers, cost_function)
+    elif software == 'sasview':
+        from fitting.sasview.main import benchmark
+        return benchmark(problem, data_struct, function,
+                         minimizers, cost_function)
     else:
         raise NameError("Sorry, that software is not supported.")
diff --git a/fitbenchmarking/fitting/mantid/externals.py b/fitbenchmarking/fitting/mantid/externals.py
index 760015c11..3c04e06f5 100644
--- a/fitbenchmarking/fitting/mantid/externals.py
+++ b/fitbenchmarking/fitting/mantid/externals.py
@@ -61,9 +61,15 @@ def set_ties(function_object, ties):
 
     @returns :: mantid function object with ties
     """
+
     for idx, ties_per_func in enumerate(ties):
         for tie in ties_per_func:
-            exec "function_object.tie({'f" + str(idx) + "." + tie + "})"
+            # Extract the fully qualified parameter name (e.g. f0.A0) and
+            # fix it at its tied value; tying via function_object.tie() is
+            # kept below for reference.
+            param_str = 'f' + str(idx) + '.' + (tie.split("'"))[0]
+            function_object.fix(param_str)
+            # exec "function_object.tie({'f" + str(idx) + "." + tie + "})"
 
     return function_object
diff --git a/fitbenchmarking/fitting/mantid/func_def.py b/fitbenchmarking/fitting/mantid/func_def.py
index 756d1a797..d1cb16667 100644
--- a/fitbenchmarking/fitting/mantid/func_def.py
+++ b/fitbenchmarking/fitting/mantid/func_def.py
@@ -50,6 +50,8 @@ def function_definitions(problem):
     elif problem_type == 'NIST':
         nb_start_vals = len(problem.starting_values[0][1])
         function_defs = parse_function_definitions(problem, nb_start_vals)
+    elif problem_type == 'SasView'.upper():
+        function_defs = parse_sasview_function_definitions(problem)
     else:
         raise NameError('Currently data types supported are FitBenchmark'
                         ' and nist, data type supplied was {}'.format(problem_type))
@@ -62,7 +64,7 @@
     Helper function that parses the NIST function definitions and
     transforms them into a mantid-readable format.
 
-    @param prob :: object holding the problem information
+    @param problem :: object holding the problem information
 
     @param nb_start_vals :: the number of starting points for a given
                             function definition
@@ -100,6 +102,47 @@
 
     return function_defs
 
+def parse_sasview_function_definitions(problem):
+    """
+    Helper function that parses the SasView function definitions and
+    transforms them into a mantid-readable format.
+
+    @param problem :: object holding the problem information
+
+    @returns :: the formatted function definition (str)
+    """
+
+    function_defs = []
+
+    start_val_str = ''
+
+    param_names = [(param.split('='))[0] for param in problem.starting_values.split(',')]
+    param_values = [(param.split('='))[1] for param in problem.starting_values.split(',')]
+    for name, value in zip(param_names, param_values):
+        start_val_str += ('{0}={1},'.format(name, value))
+    # Eliminate trailing comma
+    start_val_str = start_val_str[:-1]
+    function_defs.append("name=fitFunction,{}".format(start_val_str))
+
+    class fitFunction(IFunction1D):
+        def init(self):
+
+            for param in param_names:
+                self.declareParameter(param)
+
+        def function1D(self, xdata):
+
+            fit_param = ''
+            for param in param_names:
+                fit_param += param + '=' + str(self.getParameterValue(param)) + ','
+            fit_param = fit_param[:-1]
+            return problem.eval_f(xdata, fit_param)
+
+    FunctionFactory.subscribe(fitFunction)
+
+    return function_defs
+
 def extract_problem_type(problem):
     """
     This function gets the problem object and figures out the problem type
diff --git a/fitbenchmarking/fitting/mantid/main.py b/fitbenchmarking/fitting/mantid/main.py
index 9ce1b7c16..4f978433a 100644
--- a/fitbenchmarking/fitting/mantid/main.py
+++ b/fitbenchmarking/fitting/mantid/main.py
@@ -97,7 +97,7 @@ def fit(problem, wks_created, function, minimizer,
                          StartX=problem.start_x, EndX=problem.end_x)
         t_end = time.clock()
     except (RuntimeError, ValueError) as err:
-        logger.error("Warning, fit failed. Going on. Error: " + str(err))
+        logger.warning("Fit failed: " + str(err))
 
     status, fit_wks, fin_function_def, runtime = \
         parse_result(fit_result, t_start, t_end)
diff --git a/fitbenchmarking/fitting/plotting/plots.py b/fitbenchmarking/fitting/plotting/plots.py
index af2d2bd85..e913cc10c 100644
--- a/fitbenchmarking/fitting/plotting/plots.py
+++ b/fitbenchmarking/fitting/plotting/plots.py
@@ -183,6 +183,9 @@ def get_start_guess_data(software, data_struct, function, problem):
         return get_mantid_starting_guess_data(data_struct, function, problem)
     elif software == 'scipy':
         return get_scipy_starting_guess_data(data_struct, function)
+    elif software == 'sasview':
+        return get_sasview_starting_guess_data(data_struct, problem, function)
     else:
         raise NameError("Sorry, that software is not supported.")
@@ -230,3 +233,29 @@ def get_mantid_starting_guess_data(wks_created, function, problem):
         yData = tmp.readY(1)
 
     return xData, yData
+
+def get_sasview_starting_guess_data(data_struct, problem, function):
+    """
+    Gets the data for the starting guess of a sasview fit by evaluating
+    the problem function at the starting parameter values.
+
+    @param data_struct :: object holding the problem data
+    @param problem :: object holding the problem information
+    @param function :: array containing the function information
+
+    @returns :: the x data and the starting-guess y data
+    """
+
+    yData = problem.eval_f(data_struct.x, function[1])
+
+    return data_struct.x, yData
\ No newline at end of file
diff --git a/fitbenchmarking/fitting/prerequisites.py b/fitbenchmarking/fitting/prerequisites.py
index dbef95334..41a36135a 100644
--- a/fitbenchmarking/fitting/prerequisites.py
+++ b/fitbenchmarking/fitting/prerequisites.py
@@ -45,6 +45,8 @@ def prepare_software_prerequisites(software, problem, use_errors):
         return prepare_mantid(problem, use_errors)
     elif software == 'scipy':
         return prepare_scipy(problem, use_errors)
+    elif software == 'sasview':
+        return prepare_sasview(problem, use_errors)
     # elif software == 'your_software':
    #     return prepare_your_software(problem, use_errors)
     else:
@@ -95,3 +97,14 @@ def prepare_scipy(problem, use_errors):
         problem.start_x = - np.inf
         problem.end_x = np.inf
     return data, cost_function, function_definitions
+
+def prepare_sasview(problem, use_errors):
+    """
+    Prepare the data object and function definitions for fitting with
+    the sasview (Bumps) software.
+    """
+    from fitting.sasview.prepare_data import prepare_data
+    from fitting.sasview.func_def import function_definitions
+
+    data_obj, cost_function = prepare_data(problem, use_errors)
+
+    function_definitions = function_definitions(problem)
+
+    return data_obj, cost_function, function_definitions
\ No newline at end of file
diff --git a/fitbenchmarking/fitting/sasview/__init__.py b/fitbenchmarking/fitting/sasview/__init__.py
new file mode 100644
index 000000000..e69de29bb
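On the data side, what `prepare_sasview` above ultimately hands to Bumps is a sasmodels `Data1D` object with errors attached. A self-contained sketch with made-up values (requires sasmodels; the Poisson-style errors mirror `misc_preparations` later in this patch):

```python
import numpy as np
from sasmodels.data import Data1D

x = np.linspace(0.025, 0.5, 20)
y = 100.0 * np.exp(-10.0 * x)               # made-up intensities
data_obj = Data1D(x=x, y=y, dy=np.sqrt(abs(y)))  # Poisson-style errors
print(data_obj.x.shape, data_obj.dy[:3])
```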
diff --git a/fitbenchmarking/fitting/sasview/func_def.py b/fitbenchmarking/fitting/sasview/func_def.py
new file mode 100644
index 000000000..e76814dd6
--- /dev/null
+++ b/fitbenchmarking/fitting/sasview/func_def.py
@@ -0,0 +1,140 @@
+"""
+Methods that prepare the function definitions to be used by the SasView
+(Bumps) fitting software.
+"""
+# Copyright © 2016 ISIS Rutherford Appleton Laboratory, NScD
+# Oak Ridge National Laboratory & European Spallation Source
+#
+# This file is part of Mantid.
+# Mantid is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# Mantid is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+#
+# File change history is stored at: .
+# Code Documentation is available at:
+
+from __future__ import (absolute_import, division, print_function)
+
+from utils.logging_setup import logger
+from sasmodels.core import load_model
+import numpy as np
+import re
+
+
+def function_definitions(problem):
+    """
+    Transforms the prob.equation field into a function that can be
+    understood by the Bumps fitting software.
+
+    @param problem :: object holding the problem information
+
+    @returns :: a function definitions list with functions that
+                Bumps understands
+    """
+
+    problem_type = extract_problem_type(problem)
+
+    if problem_type == 'SasView'.upper():
+        model_name = (problem.equation.split('='))[1]
+        kernel = load_model(model_name)
+        function_defs = [[kernel, problem.starting_values, problem.equation]]
+    elif problem_type == 'FitBenchmark'.upper():
+        function_defs = problem.get_bumps_function()
+    elif problem_type == 'NIST':
+        function_defs = problem.get_function()
+    else:
+        raise NameError('Currently supported data types are FitBenchmark,'
+                        ' NIST and SasView; data type supplied was {}'.format(problem_type))
+
+    return function_defs
+
+def extract_problem_type(problem):
+    """
+    This function gets the problem object and figures out the problem
+    type from the name of the parsing module that created it
+
+    @param problem :: object holding the problem information
+
+    @returns :: the type of the problem in capital letters (e.g. NIST)
+    """
+    problem_file_name = problem.__class__.__module__
+    problem_type = (problem_file_name.split('_')[1]).upper()
+
+    return problem_type
+
+def get_fin_function_def(final_param_values, problem, init_func_def):
+    """
+    Get the final function definition string.
+
+    @param final_param_values :: the final parameter values produced by
+                                 the Bumps fit
+    @param problem :: object holding the problem information
+    @param init_func_def :: the initial function definition string
+
+    @returns :: the final function definition string
+    """
+
+    problem_type = extract_problem_type(problem)
+
+    if not 'name=' in init_func_def:
+        final_param_values = list(final_param_values)
+        params = init_func_def.split("|")[1]
+        params = re.sub(r"[-+]?\d+[.]\d+", lambda m, rep=iter(final_param_values):
+                        str(round(next(rep), 3)), params)
+        fin_function_def = init_func_def.split("|")[0] + " | " + params
+    elif problem_type == 'SasView'.upper():
+        param_names = [(param.split('='))[0] for param in problem.starting_values.split(',')]
+        fin_function_def = problem.equation + ','
+        for name, value in zip(param_names, final_param_values):
+            fin_function_def += name + '=' + str(value) + ','
+        fin_function_def = fin_function_def[:-1]
+    else:
+        final_param_values = list(final_param_values)
+        all_attributes = re.findall(r"BinWidth=\d+[.]\d+", init_func_def)
+        if len(all_attributes) != 0:
+            init_func_def = [init_func_def.replace(attr, '+') for attr in all_attributes][0]
+        fin_function_def = re.sub(r"[-+]?\d+[.]\d+", lambda m, rep=iter(final_param_values):
+                                  str(round(next(rep), 3)), init_func_def)
+        if len(all_attributes) != 0:
+            fin_function_def = [fin_function_def.replace('+', attr) for attr in all_attributes][0]
+
+    return fin_function_def
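A self-contained illustration (not part of the patch) of the SasView branch above: the final parameter values from the fit are zipped back into the comma-separated definition string. The numbers are made up:

```python
equation = 'name=cylinder'
starting_values = 'radius=35.0,length=350.0,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0'
final_param_values = [38.216, 402.95, 0.0, 1.0, 4.0, 1.0]

param_names = [param.split('=')[0] for param in starting_values.split(',')]
fin_function_def = equation + ','
for name, value in zip(param_names, final_param_values):
    fin_function_def += name + '=' + str(value) + ','
print(fin_function_def[:-1])
# -> name=cylinder,radius=38.216,length=402.95,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0
```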
+
+def get_init_function_def(function, problem):
+    """
+    Get the initial function definition string.
+
+    @param function :: array containing the function information
+    @param problem :: object holding the problem information
+
+    @returns :: the initial function definition string
+    """
+
+    problem_type = extract_problem_type(problem)
+
+    if not 'name=' in str(problem.equation):
+        params = function[0].__code__.co_varnames[1:]
+        param_string = ''
+        for idx in range(len(function[1])):
+            param_string += params[idx] + "= " + str(function[1][idx]) + ", "
+        param_string = param_string[:-2]
+        init_function_def = function[2] + " | " + param_string
+    elif problem_type == 'SasView'.upper():
+        init_function_def = problem.equation + ',' + problem.starting_values
+        init_function_def = re.sub(r"(=)([-+]?\d+)([^.\d])", r"\g<1>\g<2>.0\g<3>", init_function_def)
+    else:
+        init_function_def = problem.equation
+        init_function_def = re.sub(r",(\s+)?ties=[(][A-Za-z0-9=.,\s+]+[)]", '', init_function_def)
+        init_function_def = re.sub(r"(=)([-+]?\d+)([^.\d])", r"\g<1>\g<2>.0\g<3>", init_function_def)
+
+    return init_function_def
+
diff --git a/fitbenchmarking/fitting/sasview/main.py b/fitbenchmarking/fitting/sasview/main.py
new file mode 100644
index 000000000..b574a2e6b
--- /dev/null
+++ b/fitbenchmarking/fitting/sasview/main.py
@@ -0,0 +1,171 @@
+"""
+Benchmark fitting functions for the SasView (Bumps) fitting software.
+"""
+# Copyright © 2016 ISIS Rutherford Appleton Laboratory, NScD
+# Oak Ridge National Laboratory & European Spallation Source
+#
+# This file is part of Mantid.
+# Mantid is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# Mantid is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+#
+# File change history is stored at: .
+# Code Documentation is available at:
+from __future__ import (absolute_import, division, print_function)
+
+import numpy as np
+import sys
+import time
+import re
+
+from fitting import misc
+from sasmodels.bumps_model import Experiment, Model
+from bumps.names import *
+from bumps.fitters import fit as bumpsFit
+from fitting.sasview.func_def import get_init_function_def, get_fin_function_def
+from fitting.plotting import plot_helper
+from utils.logging_setup import logger
+
+MAX_FLOAT = sys.float_info.max
+
+
+def benchmark(problem, data, function, minimizers, cost_function):
+    """
+    Fit benchmark one problem, with one function definition and all
+    the selected minimizers, using the SasView fitting software.
+
+    @param problem :: a problem object containing information used in fitting
+    @param data :: object holding the problem data
+    @param function :: the fitted function (model for SasView problems)
+    @param minimizers :: array of minimizers used in fitting
+    @param cost_function :: the cost function used for fitting
+
+    @returns :: nested array of result objects, per minimizer
+                and data object for the best fit
+    """
+    min_chi_sq, best_fit = MAX_FLOAT, None
+    results_problem = []
+
+    for minimizer in minimizers:
+
+        init_func_def = get_init_function_def(function, problem)
+        status, fitted_y, fin_func_def, runtime = \
+            fit(problem, data, function, minimizer, init_func_def)
+        chi_sq, min_chi_sq, best_fit = \
+            chisq(status, data, fitted_y, min_chi_sq, best_fit, minimizer)
+
+        individual_result = \
+            misc.create_result_entry(problem, status, chi_sq, runtime, minimizer,
+                                     init_func_def, fin_func_def)
+
+        results_problem.append(individual_result)
+
+    return results_problem, best_fit
+
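The `fit` helper that follows builds a Bumps `Curve` (or a sasmodels `Experiment`) and runs `bumps.fitters.fit`. A self-contained miniature of the `Curve` branch (requires bumps; the toy model and values are made up):

```python
import numpy as np
from bumps.names import Curve, FitProblem
from bumps.fitters import fit as bumpsFit

def model(x, a, b):
    return a * np.exp(-b * x)

x = np.linspace(0.0, 1.0, 30)
y = model(x, 2.0, 3.0)
# Parameters are passed as keyword starting values, then given ranges
curve = Curve(model, x=x, y=y, dy=0.1 * np.ones_like(y), a=1.0, b=1.0)
curve.a.range(-np.inf, np.inf)
curve.b.range(-np.inf, np.inf)

result = bumpsFit(FitProblem(curve), method='amoeba')
print(result.x, result.success)
```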
+def fit(problem, data, function, minimizer, init_func_def):
+    """
+    Run a fit with one minimizer.
+
+    @param problem :: a problem object containing information used in fitting
+    @param data :: object holding the problem data
+    @param function :: the fitted function (model for SasView problems)
+    @param minimizer :: the minimizer used in the fitting
+    @param init_func_def :: the initial function definition
+
+    @returns :: the status, either success or failure (str), the y-data
+                of the fit, the final function definition and the
+                runtime of the fitting software
+    """
+
+    t_start, t_end = None, None
+    model = function[0]
+
+    if hasattr(model, '__call__'):
+        if isinstance(problem.starting_values, basestring):
+            function_without_ties = re.sub(r",(\s+)?ties=[(][A-Za-z0-9=.,\s+]+[)]", '', problem.equation)
+            function_list = (function_without_ties).split(';')
+            func_params_list = [(func.split(','))[1:] for func in function_list]
+            formatted_param_list = ['f' + str(func_params_list.index(func_params)) + '.' + param.strip()
+                                    for func_params in func_params_list for param in func_params]
+            param_names = [(param.split('='))[0] for param in formatted_param_list if not 'BinWidth' in param]
+            formatted_param_names = [param.replace('.', '_') for param in param_names]
+        else:
+            formatted_param_names = [param[0] for param in problem.starting_values]
+
+        param_values = function[1]
+        param_string = ''
+        for name, value in zip(formatted_param_names, param_values):
+            if not name.endswith('BinWidth'):
+                param_string += "," + name + "=" + str(value)
+
+        exec ('func_wrapper = Curve(model, x=data.x, y=data.y, dy=data.dy' + param_string + ')')
+
+        for name, value in zip(formatted_param_names, param_values):
+            minVal = -np.inf
+            maxVal = np.inf
+            if not name.endswith('BinWidth'):
+                exec ('func_wrapper.' + name + '.range(' + str(minVal) + ',' + str(maxVal) + ')')
+    else:
+        exec ("params = dict(" + problem.starting_values + ")") in locals()
+
+        model_wrapper = Model(model, **params)
+        for range in problem.starting_value_ranges.split(';'):
+            exec ('model_wrapper.' + range)
+        func_wrapper = Experiment(data=data, model=model_wrapper)
+
+    fitProblem = FitProblem(func_wrapper)
+
+    result = None
+    try:
+        t_start = time.clock()
+        result = bumpsFit(fitProblem, method=minimizer)
+        t_end = time.clock()
+    except (RuntimeError, ValueError) as err:
+        logger.warning("Fit failed: " + str(err))
+
+    if result is None:
+        # The fit raised; report a failure instead of crashing on result.x
+        return 'failed', None, None, np.nan
+
+    status = 'success' if result.success is True else 'failed'
+
+    fitted_y = func_wrapper.theory()
+
+    final_param_values = result.x
+
+    fin_func_def = get_fin_function_def(final_param_values, problem, init_func_def)
+
+    runtime = t_end - t_start
+
+    return status, fitted_y, fin_func_def, runtime
+
+def chisq(status, data, fitted_y, min_chi_sq, best_fit, minimizer_name):
+    """
+    Calculates the chi squared and compares it to the minimum chi squared
+    found until now. If the current chi_squared is lower than the minimum,
+    the new value becomes the minimum and the data of the fit is stored
+    in the variable best_fit.
+
+    @param status :: the status of the fit, either success or failure
+    @param data :: object holding the problem data
+    @param fitted_y :: the y-data of the fit
+    @param min_chi_sq :: the minimum chi_squared value
+    @param best_fit :: object where the best fit data is stored
+    @param minimizer_name :: name of the minimizer used in storing the
+                             best_fit data
+
+    @returns :: The chi-squared values, the minimum chi-squared found
+                until now and the best fit data object
+    """
+    if status != 'failed':
+        differences = fitted_y - data.y
+        chi_sq = misc.compute_chisq(differences)
+        if chi_sq < min_chi_sq and not np.isnan(chi_sq):
+            best_fit = plot_helper.data(minimizer_name, data.x, fitted_y)
+            min_chi_sq = chi_sq
+    else:
+        chi_sq = np.nan
+
+    return chi_sq, min_chi_sq, best_fit
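Side note on the `np.isnan` fix above (the original guard was `not chi_sq == np.nan`): comparing against `np.nan` with `==` never succeeds, so that condition was always true. A two-line demonstration:

```python
import numpy as np

print(np.nan == np.nan)   # False: NaN compares unequal to everything
print(np.isnan(np.nan))   # True: the correct way to test for NaN
```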
diff --git a/fitbenchmarking/fitting/sasview/prepare_data.py b/fitbenchmarking/fitting/sasview/prepare_data.py
new file mode 100644
index 000000000..49a5c38be
--- /dev/null
+++ b/fitbenchmarking/fitting/sasview/prepare_data.py
@@ -0,0 +1,91 @@
+"""
+Functions that prepare the data to be in the right format.
+"""
+# Copyright © 2016 ISIS Rutherford Appleton Laboratory, NScD
+# Oak Ridge National Laboratory & European Spallation Source
+#
+# This file is part of Mantid.
+# Mantid is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 3 of the License, or
+# (at your option) any later version.
+#
+# Mantid is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+#
+# File change history is stored at: .
+# Code Documentation is available at:
+from __future__ import (absolute_import, division, print_function)
+
+import numpy as np
+from sasmodels.data import empty_data1D, Data1D
+
+from utils.logging_setup import logger
+
+def prepare_data(problem, use_errors):
+    """
+    Prepare the problem data as a sasmodels Data1D object, with or
+    without measurement errors.
+    """
+    problem = misc_preparations(problem)
+
+    if use_errors:
+        data_obj = Data1D(x=problem.data_x, y=problem.data_y, dy=problem.data_e)
+        cost_function = 'least squares'
+    else:
+        data_obj = Data1D(x=problem.data_x, y=problem.data_y)
+        cost_function = 'unweighted least squares'
+
+    return data_obj, cost_function
+
+def misc_preparations(problem):
+    """
+    Helper function that does some miscellaneous preparation of the data.
+    It calculates the errors, if they are not present in the problem file
+    itself, by assuming a Poisson distribution. Additionally, it applies
+    constraints to the data if such constraints are provided.
+
+    @return :: returns problem object with updated data
+    """
+
+    if problem.data_e is None:
+        problem.data_e = np.sqrt(abs(problem.data_y))
+
+    if problem.start_x is None and problem.end_x is None:
+        pass
+    elif problem.start_x == -np.inf and problem.end_x == np.inf:
+        pass
+    else:
+        problem = apply_x_data_range(problem)
+
+    return problem
+
+
+def apply_x_data_range(problem):
+    """
+    Crop the data to fit within the specified start_x and end_x values, if
+    these are provided; otherwise return the problem object unaltered.
+    Bumps does not take start_x and end_x directly, so the data has to be
+    cropped here before it is handed to the fitting software.
+
+    @return :: Modified problem object where data have been cropped
+    """
+
+    if problem.start_x is None or problem.end_x is None:
+        return problem
+
+    start_x_diff = problem.data_x - problem.start_x
+    end_x_diff = problem.data_x - problem.end_x
+    start_idx = np.where(start_x_diff >= 0, start_x_diff, np.inf).argmin()
+    end_idx = np.where(end_x_diff <= 0, end_x_diff, -np.inf).argmax()
+
+    problem.data_x = np.array(problem.data_x)[start_idx:end_idx + 1]
+    problem.data_y = np.array(problem.data_y)[start_idx:end_idx + 1]
+    problem.data_e = np.array(problem.data_e)[start_idx:end_idx + 1]
+    # Avoid zero weights in the fit
+    problem.data_e[problem.data_e == 0] = 0.00000001
+    return problem
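A self-contained numeric check (not part of the patch) of the two preparations above — Poisson-style errors and index-based cropping. The values are made up:

```python
import numpy as np

data_x = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
data_y = np.array([4.0, 9.0, 16.0, 25.0, 36.0])

# Poisson-style errors when the file provides none
data_e = np.sqrt(abs(data_y))            # -> [2. 3. 4. 5. 6.]

# Crop to start_x=15, end_x=45 the same way apply_x_data_range does
start_idx = np.where(data_x - 15 >= 0, data_x - 15, np.inf).argmin()
end_idx = np.where(data_x - 45 <= 0, data_x - 45, -np.inf).argmax()
print(data_x[start_idx:end_idx + 1])     # -> [20. 30. 40.]
```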
diff --git a/fitbenchmarking/fitting/sasview/tests/__init__.py b/fitbenchmarking/fitting/sasview/tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fitbenchmarking/fitting/sasview/tests/test_func_def.py b/fitbenchmarking/fitting/sasview/tests/test_func_def.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fitbenchmarking/fitting/sasview/tests/test_main.py b/fitbenchmarking/fitting/sasview/tests/test_main.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fitbenchmarking/fitting/sasview/tests/test_prepare_data.py b/fitbenchmarking/fitting/sasview/tests/test_prepare_data.py
new file mode 100644
index 000000000..f4b708f17
--- /dev/null
+++ b/fitbenchmarking/fitting/sasview/tests/test_prepare_data.py
@@ -0,0 +1,151 @@
+from __future__ import (absolute_import, division, print_function)
+
+import unittest
+import os
+import numpy as np
+
+import sys
+test_dir = os.path.dirname(os.path.realpath(__file__))
+parent_dir = os.path.dirname(os.path.normpath(test_dir))
+parent_dir = os.path.dirname(os.path.normpath(parent_dir))
+main_dir = os.path.dirname(os.path.normpath(parent_dir))
+sys.path.insert(0, main_dir)
+
+try:
+    import sasmodels.data
+except ImportError:
+    print('******************************************\n'
+          'sasmodels is not yet installed on your computer\n'
+          'To install, type the following command:\n'
+          'python -m pip install sasmodels\n'
+          '******************************************')
+    sys.exit()
+
+from fitting.sasview.prepare_data import prepare_data
+from fitting.sasview.prepare_data import misc_preparations
+from fitting.sasview.prepare_data import apply_x_data_range
+
+from parsing.parse_nist_data import FittingProblem
+from mock_problem_files.get_problem_files import get_file
+
+
+class SasViewTests(unittest.TestCase):
+
+    def NIST_problem(self):
+        """
+        Helper function.
+        Sets up the problem object for the nist problem file Misra1a.dat
+        """
+
+        data_pattern = np.array([[10.07, 77.6],
+                                 [14.73, 114.9],
+                                 [17.94, 141.1],
+                                 [23.93, 190.8],
+                                 [29.61, 239.9],
+                                 [35.18, 289.0],
+                                 [40.02, 332.8],
+                                 [44.82, 378.4],
+                                 [50.76, 434.8],
+                                 [55.05, 477.3],
+                                 [61.01, 536.8],
+                                 [66.40, 593.1],
+                                 [75.47, 689.1],
+                                 [81.78, 760.0]])
+
+        fname = get_file('NIST_Misra1a.dat')
+        prob = FittingProblem(fname)
+        prob.name = 'Misra1a'
+        prob.equation = 'b1*(1-exp(-b2*x))'
+        prob.starting_values = [['b1', [500.0, 250.0]],
+                                ['b2', [0.0001, 0.0005]]]
+        prob.data_x = data_pattern[:, 1]
+        prob.data_y = data_pattern[:, 0]
+
+        return prob
+
+    def get_expected_data(self):
+
+        problem = self.NIST_problem()
+        data_x = problem.data_x
+        data_y = problem.data_y
+        data_e = np.sqrt(abs(data_y))
+
+        return np.array([data_x, data_y, data_e])
+
+    def test_prepareData_use_errors_true(self):
+
+        problem = self.NIST_problem()
+        use_errors = True
+
+        data, cost_function = prepare_data(problem, use_errors)
+        expected_data = self.get_expected_data()
+
+        np.testing.assert_array_equal(expected_data[0], data.x)
+        np.testing.assert_array_equal(expected_data[1], data.y)
+        np.testing.assert_array_equal(expected_data[2], data.dy)
+        self.assertEqual("least squares", cost_function)
+
+    def test_prepareData_use_errors_false(self):
+
+        problem = self.NIST_problem()
+        use_errors = False
+
+        data, cost_function = prepare_data(problem, use_errors)
+        data_x, data_y, _ = self.get_expected_data()
+        expected_data = np.array([data_x, data_y])
+
+        np.testing.assert_array_equal(expected_data[0], data.x)
+        np.testing.assert_array_equal(expected_data[1], data.y)
+        self.assertEqual("unweighted least squares", cost_function)
+
+    def test_miscPreparations_no_errors_no_limits(self):
+
+        problem = self.NIST_problem()
+
+        problem = \
+            misc_preparations(problem)
+        expected_data = self.get_expected_data()
+
+        np.testing.assert_equal(expected_data[0], problem.data_x)
+        np.testing.assert_equal(expected_data[1], problem.data_y)
+        np.testing.assert_equal(expected_data[2], problem.data_e)
+
+    def test_miscPreparations_uneq_length_constrained(self):
+
+        problem = self.NIST_problem()
+        problem.data_x = np.append(problem.data_x, 0)
+        problem.start_x = 115
+        problem.end_x = 540
+
+        problem = \
+            misc_preparations(problem)
+        expected_data = self.get_expected_data()
+        expected = [[None] * 3] * 10
+        expected[0] = np.array(expected_data[0])[2:11]
+        expected[1] = np.array(expected_data[1])[2:11]
+
+        np.testing.assert_equal(expected[0], problem.data_x)
+        np.testing.assert_equal(expected[1], problem.data_y)
+
+    def test_applyXRange_return_data(self):
+
+        problem = self.NIST_problem()
+        problem.start_x = 115
+        problem.end_x = 540
+        problem.data_e = np.sqrt(abs(problem.data_y))
+
+        problem = \
+            apply_x_data_range(problem)
+        expected_data = self.get_expected_data()
+        expected = [[None] * 3] * 10
+        expected[0] = np.array(expected_data[0])[2:11]
+        expected[1] = np.array(expected_data[1])[2:11]
+        expected[2] = np.array(expected_data[2])[2:11]
+
+        np.testing.assert_equal(expected[0], problem.data_x)
+        np.testing.assert_equal(expected[1], problem.data_y)
+        np.testing.assert_equal(expected[2], problem.data_e)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/fitbenchmarking/fitting/scipy/func_def.py b/fitbenchmarking/fitting/scipy/func_def.py
index 66651c205..497d768b9 100644
--- a/fitbenchmarking/fitting/scipy/func_def.py
+++ b/fitbenchmarking/fitting/scipy/func_def.py
@@ -34,7 +34,7 @@ def function_definitions(problem):
     """
     problem_type = extract_problem_type(problem)
 
-    if problem_type == 'NIST' or problem_type == 'FitBenchmark'.upper():
+    if problem_type == 'NIST' or problem_type == 'FitBenchmark'.upper() or problem_type == 'SasView'.upper():
         return problem.get_function()
     else:
         RuntimeError("Your problem type is not supported yet!")
@@ -51,37 +51,50 @@ def get_fin_function_def(init_function_def, func_callable, popt):
 
     @returns :: the final function definition string
     """
-    if not 'name=' in str(func_callable):
+    if not 'name=' in init_function_def:
         popt = list(popt)
         params = init_function_def.split("|")[1]
-        params = re.sub(r"[-+]?\d+\.\d+", lambda m, rep=iter(popt):
+        params = re.sub(r"[-+]?\d+[.]\d+", lambda m, rep=iter(popt):
                         str(round(next(rep), 3)), params)
         fin_function_def = init_function_def.split("|")[0] + " | " + params
     else:
-        fin_function_def = str(func_callable)
+        all_attributes = re.findall(r",[\s+]?ties=[(][A-Za-z0-9=.,\s+]+[)]", init_function_def)
+        if len(all_attributes) != 0:
+            init_function_def = [init_function_def.replace(attr, '+') for attr in all_attributes][0]
+        fin_function_def = re.sub(r"[-+]?\d+[.]\d+", lambda m, rep=iter(popt):
+                                  str(round(next(rep), 3)), init_function_def)
+        if len(all_attributes) != 0:
+            fin_function_def = [fin_function_def.replace('+', attr) for attr in all_attributes][0]
 
     return fin_function_def
 
 
-def get_init_function_def(function, mantid_definition):
+def get_init_function_def(function, problem):
     """
     Get the initial function definition string.
 
     @param function :: array containing the function information
-    @param mantid_definition :: the string containing the function
-                                definition in mantid format
+    @param problem :: object holding the problem information, including
+                      the function definition in mantid/sasview format
 
     @returns :: the initial function definition string
     """
-    if not 'name=' in str(function[0]):
+
+    problem_type = extract_problem_type(problem)
+
+    if not 'name=' in str(problem.equation):
         params = function[0].__code__.co_varnames[1:]
         param_string = ''
         for idx in range(len(function[1])):
             param_string += params[idx] + "= " + str(function[1][idx]) + ", "
         param_string = param_string[:-2]
         init_function_def = function[2] + " | " + param_string
+    elif problem_type == 'SasView'.upper():
+        init_function_def = problem.equation + ',' + problem.starting_values
+        init_function_def = re.sub(r"(=)([-+]?\d+)([^.\d])", r"\g<1>\g<2>.0\g<3>", init_function_def)
     else:
-        init_function_def = mantid_definition
+        init_function_def = problem.equation
+        init_function_def = re.sub(r",(\s+)?ties=[(][A-Za-z0-9=.,\s+]+[)]", '', init_function_def)
+        init_function_def = re.sub(r"(=)([-+]?\d+)([^.\d])", r"\g<1>\g<2>.0\g<3>", init_function_def)
 
     return init_function_def
diff --git a/fitbenchmarking/fitting/scipy/main.py b/fitbenchmarking/fitting/scipy/main.py
index 61e9510b2..8a7fb5787 100644
--- a/fitbenchmarking/fitting/scipy/main.py
+++ b/fitbenchmarking/fitting/scipy/main.py
@@ -56,11 +56,12 @@ def benchmark(problem, data, function, minimizers, cost_function):
 
     for minimizer in minimizers:
 
-        init_function_def = get_init_function_def(function, problem.equation)
+        init_function_def = get_init_function_def(function, problem)
         status, fitted_y, fin_function_def, runtime = \
             fit(data, function, minimizer, cost_function, init_function_def)
         chi_sq, min_chi_sq, best_fit = \
             chisq(status, data, fitted_y, min_chi_sq, best_fit, minimizer)
+
         individual_result = \
             misc.create_result_entry(problem, status, chi_sq, runtime, minimizer,
                                      init_function_def, fin_function_def)
@@ -96,7 +97,7 @@ def fit(data, function, minimizer, cost_function, init_function_def):
                                minimizer, cost_function)
         t_end = time.clock()
     except(RuntimeError, ValueError) as err:
-        logger.error("Warning, fit failed. Going on. Error: " + str(err))
+        logger.warning("Fit failed: " + str(err))
 
     fin_def = None
     if not popt is None:
@@ -147,6 +148,7 @@ def execute_fit(function, data, initial_params, minimizer, cost_function):
     @returns :: array of final variables after the fit was performed
     """
     popt, pcov = None, None
+
     try:
         if cost_function == 'least squares':
             popt, pcov = curve_fit(f=function.__call__,
@@ -158,6 +160,7 @@ def execute_fit(function, data, initial_params, minimizer, cost_function):
                                    p0=initial_params, method=minimizer, maxfev=500)
     except(IndexError) as err:
         logger.error('Index out of bound. Going on.')
+
     return popt
 
 
@@ -181,8 +184,8 @@ def get_fittedy(function, data_x, popt):
     Gets the fitted y data corresponding to given x values.
     """
     try:
-        fitted_y = function.__call__(data_x)
-    except:
         fitted_y = function(data_x, *popt)
+    except:
+        fitted_y = function.__call__(data_x)
 
     return fitted_y
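The reordered try/except in `get_fittedy` above prefers the plain-function convention `f(x, *popt)` and falls back to callable objects that carry their own parameters. A self-contained illustration with toy stand-ins (not part of the patch):

```python
import numpy as np

def plain_f(x, a, b):            # NIST-style: parameters passed explicitly
    return a * x + b

class CallableModel(object):     # wrapper-style: parameters held internally
    def __init__(self, a, b):
        self.a, self.b = a, b
    def __call__(self, x):
        return self.a * x + self.b

x, popt = np.array([1.0, 2.0]), [3.0, 4.0]

for function in (plain_f, CallableModel(3.0, 4.0)):
    try:
        fitted_y = function(x, *popt)        # works for plain_f
    except TypeError:
        fitted_y = function.__call__(x)      # works for CallableModel
    print(fitted_y)                          # -> [ 7. 10.] both times
```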
"name=LinearBackground,A0=0,A1=0;name=BackToBackExponential,I=597.076,A=1,B=0.05,X0=24027.5,S=22.9096" + fin_func_def_expected = "name=LinearBackground,A0=0.0,A1=0.0;name=BackToBackExponential,I=-22.868,A=1,B=0.001,X0=710.042,S=3.588" self.assertEqual(fin_func_def_expected, fin_func_def) diff --git a/fitbenchmarking/fitting/scipy/tests/test_prepare_data.py b/fitbenchmarking/fitting/scipy/tests/test_prepare_data.py index 6e2b5e3b3..039e21bef 100644 --- a/fitbenchmarking/fitting/scipy/tests/test_prepare_data.py +++ b/fitbenchmarking/fitting/scipy/tests/test_prepare_data.py @@ -66,7 +66,6 @@ def test_prepareData_use_errors_true(self): problem = self.NIST_problem() use_errors = True - print(type(problem.data_x)) data, cost_function = prepare_data(problem, use_errors) expected_data = self.get_expected_data() @@ -92,7 +91,7 @@ def test_miscPreparations_no_errors_no_limits(self): problem = \ misc_preparations(problem, problem.data_x, problem.data_y, problem.data_e) - # print(problem.data_x) + expected_data = self.get_expected_data() np.testing.assert_equal(expected_data[0], problem.data_x) @@ -128,7 +127,6 @@ def test_applyXRange_return_data(self): expected_data = self.get_expected_data() expected = [ [ None ] * 3 ] * 10 expected[0] = np.array(expected_data[0])[2:11] - # print(expected[0]) expected[1] = np.array(expected_data[1])[2:11] expected[2] = np.array(expected_data[2])[2:11] diff --git a/fitbenchmarking/minimizers_list_default.json b/fitbenchmarking/minimizers_list_default.json index 2b357d275..77f5d05d5 100644 --- a/fitbenchmarking/minimizers_list_default.json +++ b/fitbenchmarking/minimizers_list_default.json @@ -5,5 +5,6 @@ "Levenberg-Marquardt", "Levenberg-MarquardtMD", "Simplex","SteepestDescent", "Trust Region"], - "scipy" : ["lm", "trf", "dogbox"] + "scipy" : ["lm", "trf", "dogbox"], + "sasview" : ["amoeba", "lm", "newton", "de", "pt", "mp"] } diff --git a/fitbenchmarking/mock_problem_files/SV_cyl_400_20.txt b/fitbenchmarking/mock_problem_files/SV_cyl_400_20.txt new file mode 100644 index 000000000..e321e3401 --- /dev/null +++ b/fitbenchmarking/mock_problem_files/SV_cyl_400_20.txt @@ -0,0 +1,22 @@ + +0 -1.#IND +0.025 125.852 +0.05 53.6662 +0.075 26.0733 +0.1 11.8935 +0.125 4.61714 +0.15 1.29983 +0.175 0.171347 +0.2 0.0417614 +0.225 0.172719 +0.25 0.247876 +0.275 0.20301 +0.3 0.104599 +0.325 0.0285595 +0.35 0.00213344 +0.375 0.0137511 +0.4 0.0312374 +0.425 0.0350328 +0.45 0.0243172 +0.475 0.00923067 +0.5 0.00121297 diff --git a/fitbenchmarking/mock_problem_files/SV_prob_def_1.txt b/fitbenchmarking/mock_problem_files/SV_prob_def_1.txt new file mode 100644 index 000000000..693d32088 --- /dev/null +++ b/fitbenchmarking/mock_problem_files/SV_prob_def_1.txt @@ -0,0 +1,6 @@ +# An example data set for SasView 1D data +name = 'Problem Def 1' +input_file = 'SV_cyl_400_20.txt' +function ='name=cylinder,radius=35.0,length=350.0,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0' +parameter_ranges = 'radius.range(1,50);length.range(1,500)' +description = '' diff --git a/fitbenchmarking/parsing/base_fitting_problem.py b/fitbenchmarking/parsing/base_fitting_problem.py index 81a89a783..ab454de62 100644 --- a/fitbenchmarking/parsing/base_fitting_problem.py +++ b/fitbenchmarking/parsing/base_fitting_problem.py @@ -58,6 +58,7 @@ def __init__(self, fname): # Initialize contents of file. Here included to reduce I/O, i.e. 
diff --git a/fitbenchmarking/parsing/base_fitting_problem.py b/fitbenchmarking/parsing/base_fitting_problem.py
index 81a89a783..ab454de62 100644
--- a/fitbenchmarking/parsing/base_fitting_problem.py
+++ b/fitbenchmarking/parsing/base_fitting_problem.py
@@ -58,6 +58,7 @@ def __init__(self, fname):
         # Initialize contents of file. Here included to reduce I/O, i.e.
         # read file content once and then process as needed
         self._contents = None
+        self._starting_value_ranges = None
 
     def read_file(self):
         self._contents = open(self.fname, "r")
@@ -145,6 +146,14 @@ def contents(self):
     def contents(self, value):
         self._contents = value
 
+    @property
+    def starting_value_ranges(self):
+        return self._starting_value_ranges
+
+    @starting_value_ranges.setter
+    def starting_value_ranges(self, value):
+        self._starting_value_ranges = value
+
     def __new__(cls, *args, **kwargs):
         if cls is BaseFittingProblem:
             raise TypeError("Base class {} may not be instantiated".format(cls))
diff --git a/fitbenchmarking/parsing/fitbenchmark_data_functions.py b/fitbenchmarking/parsing/fitbenchmark_data_functions.py
index 475e15df8..a8039c1a9 100644
--- a/fitbenchmarking/parsing/fitbenchmark_data_functions.py
+++ b/fitbenchmarking/parsing/fitbenchmark_data_functions.py
@@ -25,7 +25,7 @@
 from __future__ import (absolute_import, division, print_function)
 
 import numpy as np
-
+import re
 from fitting.mantid.externals import gen_func_obj, set_ties
 from utils.logging_setup import logger
 
@@ -50,11 +50,37 @@ def fitbenchmark_func_definitions(functions_string):
         fit_function = make_fitbenchmark_fit_function(name, fit_function, params_set)
 
     fit_function = set_ties(fit_function, ties)
-
     function_defs = [[fit_function, params]]
 
     return function_defs
 
+def get_fit_function_without_kwargs(fit_function, functions_string):
+    """
+    Wrap a mantid fit function in a plain Python function whose signature
+    lists every parameter by name; a function handed to Bumps must not
+    take *args or **kwargs.
+
+    @param fit_function :: the mantid fit function object to wrap
+    @param functions_string :: the function definition string from the
+                               problem definition file
+
+    @returns :: function definition list containing the wrapper function
+                and its starting parameter values
+    """
+
+    functions_string = re.sub(r",(\s+)?ties=[(][A-Za-z0-9=.,\s+]+[)]", '', functions_string)
+    function_list = (functions_string).split(';')
+    func_params_list = [((func.split(','))[1:]) for func in function_list]
+    formatted_param_list = ['f' + str(func_params_list.index(func_params)) + '.' + param.strip()
+                            for func_params in func_params_list for param in func_params]
+    param_names = [(param.split('='))[0] for param in formatted_param_list]
+    param_values = [(param.split('='))[1] for param in formatted_param_list if
+                    not (param.split('='))[0].endswith('BinWidth')]
+    new_param_names = [param.replace('.', '_') for param in param_names]
+
+    param_names_string = ''
+    for param in new_param_names:
+        if not param.endswith('BinWidth'):
+            param_names_string += ',' + param
+
+    exec ('def bumps_function(x' + param_names_string + '):\n    return fit_function.__call__(x' + param_names_string + ')') in locals()
+
+    return [[bumps_function, param_values]]
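To see what the `exec` above generates (a self-contained sketch, not part of the patch): for a definition with parameters `f0.A0` and `f0.A1` it builds a wrapper whose signature names each parameter explicitly, which is what Bumps introspects. The stand-in `fit_function` below is hypothetical:

```python
def fit_function(x, *args):             # stand-in for the mantid callable
    a0, a1 = args
    return a0 + a1 * x

# What get_fit_function_without_kwargs effectively produces:
exec('def bumps_function(x, f0_A0, f0_A1):\n'
     '    return fit_function(x, f0_A0, f0_A1)')
print(bumps_function(2.0, 1.0, 3.0))    # -> 7.0
```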
 
def get_all_fitbenchmark_func_names(functions_string):
    """
@@ -179,6 +205,7 @@ def get_fitbenchmark_ties(param_set, ties):
             else:
                 tie = param_set[start + 1:comma]
                 ties_per_function.append(tie.replace("=", "': "))
             if comma == -1:
                 break
             start = comma + 1
diff --git a/fitbenchmarking/parsing/parse.py b/fitbenchmarking/parsing/parse.py
index cb0e7bf50..695a3fbd6 100644
--- a/fitbenchmarking/parsing/parse.py
+++ b/fitbenchmarking/parsing/parse.py
@@ -29,6 +29,7 @@
 from utils.logging_setup import logger
 from parsing import parse_nist_data
 
+
 def parse_problem_file(prob_file):
     """
     Helper function that loads the problem file and populates the fitting
@@ -52,6 +53,10 @@
     """
         from parsing import parse_fitbenchmark_data
         problem = parse_fitbenchmark_data.FittingProblem(prob_file)
+    elif prob_type == 'SasView':
+
+        from parsing import parse_sasview_data
+        problem = parse_sasview_data.FittingProblem(prob_file)
 
     check_problem_attributes(problem)
 
@@ -84,7 +89,10 @@ def determine_problem_type(prob_file):
         prob_type = "NIST"
     elif "#" in fline:
         # A comment in the first line, from which we assume the format is:
-        prob_type = "FitBenchmark"
+        if "SasView" in fline:
+            prob_type = "SasView"
+        else:
+            prob_type = "FitBenchmark"
     else:
         raise RuntimeError("Data type supplied currently not supported")
diff --git a/fitbenchmarking/parsing/parse_fitbenchmark_data.py b/fitbenchmarking/parsing/parse_fitbenchmark_data.py
index 539802e56..386ed55c0 100644
--- a/fitbenchmarking/parsing/parse_fitbenchmark_data.py
+++ b/fitbenchmarking/parsing/parse_fitbenchmark_data.py
@@ -25,7 +25,7 @@
 import numpy as np

 from utils.logging_setup import logger
-from parsing.fitbenchmark_data_functions import fitbenchmark_func_definitions
+from parsing.fitbenchmark_data_functions import fitbenchmark_func_definitions, get_fit_function_without_kwargs

 class FittingProblem(base_fitting_problem.BaseFittingProblem):
@@ -59,7 +59,7 @@ def __init__(self, fname):
         #String containing the function name(s) and the starting parameter values for each function
         self._equation = entries['function']

-        self._starting_values = None
         if 'fit_parameters' in entries:
             self._start_x = entries['fit_parameters']['StartX']
             self._end_x = entries['fit_parameters']['EndX']
@@ -70,27 +70,45 @@ def eval_f(self, x, param_list):
         """
         Function evaluation method

-        :param x: x data values
-        :param param_list:
-        :return: y data values
+        @param x :: x data values
+        @param param_list :: parameter values
+        @returns :: y data values evaluated from the function used in the problem
         """
         function = (fitbenchmark_func_definitions(self._equation))[0][0]

-        param_statements = param_list.split(',')
-        param_name_and_value = [param.split('=') for param in param_statements]
+        param_values_string = ''
+        for param in param_list:
+            param_values_string += ',' + str(param)

-        for param in param_name_and_value:
-            function[param[0]] = float(param[1])
+        y_values = eval('function(x' + param_values_string + ')')

-        return function(x)
+        return y_values

     def get_function(self):
+        """
+        @returns :: function definition list containing the function and its starting parameter values
+        """

         function = fitbenchmark_func_definitions(self._equation)

         return function

+    def get_bumps_function(self):
+        """
+        Prepare a function definition list that is acceptable by the Bumps fitting module.
+        The function to be used in Bumps fitting must not have *args or **kwargs in its declaration.
+
+        @returns :: function definition list containing the function without
+                    any *args or **kwargs, and its starting parameter values
+        """
+
+        function = fitbenchmark_func_definitions(self._equation)[0][0]
+
+        bumps_function_def = get_fit_function_without_kwargs(function, self._equation)
+        return bumps_function_def
+
     def get_data_file(self, full_path_of_fitting_def_file, data_file_name):
         """
         Find/create the (full) path to a data_file specified in a FitBenchmark definition file, where
diff --git a/fitbenchmarking/parsing/parse_nist_data.py b/fitbenchmarking/parsing/parse_nist_data.py
index a03afd60b..06214acba 100644
--- a/fitbenchmarking/parsing/parse_nist_data.py
+++ b/fitbenchmarking/parsing/parse_nist_data.py
@@ -104,8 +104,9 @@ def eval_f(self, x, param_list):
         Function evaluation method

         :param x: x data values
-        :param param_list:
-        :return: y data values
+        :param param_list: parameter value(s)
+
+        :return: y data values evaluated from the function of the problem
         """

         param_string = ''
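
The wrapper returned by `get_bumps_function` can then be handed to Bumps. A hedged sketch, assuming Bumps is installed and `problem` is a parsed FitBenchmark problem; `Curve`, `FitProblem` and `fit` are standard Bumps entry points, though the exact driver wiring in FitBenchmarking may differ:

```python
from bumps.names import Curve, FitProblem
from bumps.fitters import fit

# 'problem' is a hypothetical parsed FitBenchmark problem instance
[[bumps_func, start_vals]] = problem.get_bumps_function()

# Curve introspects the wrapper's named arguments, which is why the
# generated function must not hide its parameters behind *args/**kwargs
code = bumps_func.__code__
names = code.co_varnames[1:code.co_argcount]   # skip the leading 'x'
initial = dict(zip(names, map(float, start_vals)))

model = Curve(bumps_func, problem.data_x, problem.data_y, **initial)
result = fit(FitProblem(model), method='amoeba')
print(result.x)   # fitted parameter values
```
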
diff --git a/fitbenchmarking/parsing/parse_sasview_data.py b/fitbenchmarking/parsing/parse_sasview_data.py
new file mode 100644
index 000000000..dfb775009
--- /dev/null
+++ b/fitbenchmarking/parsing/parse_sasview_data.py
@@ -0,0 +1,155 @@
+from __future__ import (absolute_import, division, print_function)
+
+import os
+import numpy as np
+import re
+from parsing import base_fitting_problem
+from sasmodels.data import load_data, empty_data1D
+from sasmodels.core import load_model
+from sasmodels.bumps_model import Experiment, Model
+
+from utils.logging_setup import logger
+
+class FittingProblem(base_fitting_problem.BaseFittingProblem):
+    """
+    Definition of the SasView problem class, which provides the
+    methods for parsing a SasView formatted FitBenchmarking
+    problem definition file
+
+    Types of data:
+        - strings: name, type, equation
+        - floats: start_x, end_x, ref_residual_sum_sq
+        - numpy arrays: data_x, data_y, data_e
+        - arrays: starting_values
+    """
+    def __init__(self, fname):
+
+        super(FittingProblem, self).__init__(fname)
+        super(FittingProblem, self).read_file()
+
+        entries = self.get_data_problem_entries(self.contents)
+        data_file_path = self.get_data_file(self.fname, entries['input_file'])
+
+        data_obj = load_data(data_file_path)
+
+        self._data_x = data_obj.x
+        self._data_y = data_obj.y
+
+        self._start_x, self._end_x = self.get_start_x_and_end_x(self._data_x)
+
+        self._name = entries['name']
+        self._equation = (entries['function'].split(',', 1))[0]
+
+        self._starting_values = (entries['function'].split(',', 1))[1]
+        self._starting_value_ranges = entries['parameter_ranges']
+
+        super(FittingProblem, self).close_file()
+
+    def eval_f(self, x, *param_list):
+        """
+        Function evaluation method
+
+        @param x :: x data values
+        @param *param_list :: parameter value(s), either a single
+                              'name=value,...' string or bare positional values
+
+        @returns :: the y data values evaluated from the model
+        """
+        data = empty_data1D(x)
+        model = load_model((self._equation.split('='))[1])
+
+        param_names = [(param.split('='))[0] for param in self.starting_values.split(',')]
+        if len(param_list) == 1 and isinstance(param_list[0], basestring):
+            exec ("params = dict(" + param_list[0] + ")")
+        else:
+            param_string = ''
+            for name, value in zip(param_names, param_list):
+                param_string += name + '=' + str(value) + ','
+            param_string = param_string[:-1]
+            exec ("params = dict(" + param_string + ")")
+
+        model_wrapper = Model(model, **params)
+        for param_range in self.starting_value_ranges.split(';'):
+            # e.g. 'radius.range(1,50)' applied to the Bumps model wrapper
+            exec ('model_wrapper.' + param_range)
+        func_wrapper = Experiment(data=data, model=model_wrapper)
+
+        return func_wrapper.theory()
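
Both calling conventions of `eval_f` in one hedged sketch (the definition-file path is hypothetical, and `numpy` is assumed imported as `np`):

```python
problem = FittingProblem('path/to/sasview_prob_def.txt')   # hypothetical file

# a single 'name=value,...' string, as the unit tests use ...
y1 = problem.eval_f(problem.data_x[:10],
                    'radius=35,length=350,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0')

# ... or bare positional values, zipped against the parsed parameter names
y2 = problem.eval_f(problem.data_x[:10], 35.0, 350.0, 0.0, 1.0, 4.0, 1.0)

assert np.allclose(y1, y2)
```
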
+    def get_function(self):
+        """
+        @returns :: function definition list containing the model and its starting parameter values
+        """
+
+        param_values = [(param.split('='))[1] for param in self.starting_values.split(',')]
+        param_values = np.array([param_values], dtype=np.float64)
+
+        function_defs = []
+
+        for param in param_values:
+            function_defs.append([self.eval_f, param])
+
+        return function_defs
+
+    def get_data_file(self, full_path_of_fitting_def_file, data_file_name):
+        """
+        Find the (full) path to a data_file specified in a FitBenchmark definition file.
+        The data_file is searched for in the directory of the definition file and in
+        the subfolders of that directory.
+
+        @param full_path_of_fitting_def_file :: (full) path of a FitBenchmark definition file
+        @param data_file_name :: the name of the data file as specified in the FitBenchmark definition file
+
+        @returns :: (full) path to a data file (str). Returns None if not found
+        """
+        data_file = None
+        # find or search for path for data_file_name
+        for root, dirs, files in os.walk(os.path.dirname(full_path_of_fitting_def_file)):
+            for name in files:
+                if data_file_name == name:
+                    data_file = os.path.join(root, data_file_name)
+
+        if data_file is None:
+            logger.error("Data file {} not found".format(data_file_name))
+
+        return data_file
+
+    def get_data_problem_entries(self, fname):
+        """
+        Get the problem entries from a fitbenchmark problem definition file.
+
+        @param fname :: path to the fitbenchmark problem definition file
+
+        @returns :: a dictionary with all the entries of the problem file
+        """
+
+        entries = {}
+        for line in fname:
+            # Discard comments
+            line = line.partition('#')[0]
+            line = line.rstrip()
+            if not line:
+                continue
+
+            lhs, rhs = line.split("=", 1)
+            entries[lhs.strip()] = eval(rhs.strip())
+
+        return entries
+
+    def get_start_x_and_end_x(self, x_data):
+        """
+        Get the start and end value of x from the list of x values.
+
+        @param x_data :: list containing x values
+        @return :: the start and end values of the x data
+        """
+
+        sorted_x_data = sorted(x_data)
+
+        start_x = sorted_x_data[0]
+        end_x = sorted_x_data[-1]
+
+        return start_x, end_x
diff --git a/fitbenchmarking/parsing/tests/test_parse_fitbenchmark_data.py b/fitbenchmarking/parsing/tests/test_parse_fitbenchmark_data.py
index b9b7b425e..507438f56 100644
--- a/fitbenchmarking/parsing/tests/test_parse_fitbenchmark_data.py
+++ b/fitbenchmarking/parsing/tests/test_parse_fitbenchmark_data.py
@@ -68,9 +68,9 @@ def test_eval_f(self):

         problem = FittingProblem(fname)

-        y_values = problem.eval_f(problem.data_x[:10], 'f0.A1=100,f1.A=100')
+        y_values = problem.eval_f(problem.data_x[:10], [10, 100, 597.076, 1.0, 0.05, 24027.5, 22.9096])

-        y_values_expected = np.array([600059.4, 600178.1, 600296.9, 600415.6, 600534.4, 600653.1, 600771.9, 600890.6, 601009.4, 601128.1])
+        y_values_expected = np.array([600069.4, 600188.1, 600306.9, 600425.6, 600544.4, 600663.1, 600781.9, 600900.6, 601019.4, 601138.1])

         np.testing.assert_array_equal(y_values_expected, y_values)
diff --git a/fitbenchmarking/parsing/tests/test_parse_nist_data.py b/fitbenchmarking/parsing/tests/test_parse_nist_data.py
index 50e02b3b1..08b1c6801 100644
--- a/fitbenchmarking/parsing/tests/test_parse_nist_data.py
+++ b/fitbenchmarking/parsing/tests/test_parse_nist_data.py
@@ -161,10 +161,10 @@ def test_getFunction_returns_correct_function(self):

         param_array = function[0][1:]

-        param_array_expeacted = [[500.0, 0.0001], 'b1*(1-np.exp(-b2*x))']
+        param_array_expected = [[500.0, 0.0001], 'b1*(1-np.exp(-b2*x))']

         np.testing.assert_allclose(y_values_expected, y_values, rtol=1e-5, atol=0)
-        self.assertListEqual(param_array_expeacted, param_array)
+        self.assertListEqual(param_array_expected, param_array)

     def test_ParseProblemFileNIST_returns_correct_problem_object(self):
diff --git a/fitbenchmarking/parsing/tests/test_parse_sasview_data.py b/fitbenchmarking/parsing/tests/test_parse_sasview_data.py
new file mode 100644
index 000000000..2cf93ed5e
--- /dev/null
+++ b/fitbenchmarking/parsing/tests/test_parse_sasview_data.py
@@ -0,0 +1,191 @@
+from __future__ import (absolute_import, division, print_function)
+
+import unittest
+import os
+import numpy as np
+import json
+
+# Delete the sys.path manipulation below when automated tests are enabled
+import sys
+test_dir = os.path.dirname(os.path.realpath(__file__))
+parent_dir = os.path.dirname(os.path.normpath(test_dir))
+main_dir = os.path.dirname(os.path.normpath(parent_dir))
+fb_dir = os.path.dirname(os.path.normpath(main_dir))
+sys.path.insert(0, main_dir)
+sys.path.insert(1, fb_dir)
+
+try:
+    import sasmodels.data
+except ImportError:
+    print('******************************************\n'
+          'sasmodels is not yet installed on your computer\n'
+          'To install, type the following command:\n'
+          'python -m pip install sasmodels\n'
+          '******************************************')
+    sys.exit()
+
+from fitting.mantid.externals import store_main_problem_data
+from parsing.parse import parse_problem_file
+from parsing.parse import check_problem_attributes
+from parsing.parse import determine_problem_type
+from parsing.parse_sasview_data import FittingProblem
+from mock_problem_files.get_problem_files import get_file
+
+
+class ParseSasViewTests(unittest.TestCase):
+
+    def get_bench_prob_dir(self):
+
+        prob_path = get_file('SV_cyl_400_20.txt')
+        bench_prob_dir = os.path.dirname(prob_path)
+
+        return bench_prob_dir
+
+    def expected_SasView_problem_entries(self):
+
+        entries = {}
+        entries['name'] = "Problem Def 1"
+        entries['input_file'] = "SV_cyl_400_20.txt"
+        entries['function'] = ("name=cylinder,radius=35.0,length=350.0,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0")
+        entries['parameter_ranges'] = ("radius.range(1,50);length.range(1,500)")
+        entries['description'] = ''
+
+        return entries
+
+    def expected_sasview_problem(self):
+
+        bench_prob_dir = self.get_bench_prob_dir()
+        entries = self.expected_SasView_problem_entries()
+        fname = get_file('SV_prob_def_1.txt')
+        problem = FittingProblem(fname)
+        problem.name = entries['name']
+        problem.equation = (entries['function'].split(',', 1))[0]
+        problem.starting_values = (entries['function'].split(',', 1))[1]
+        problem.starting_value_ranges = entries['parameter_ranges']
+        data_file = os.path.join(bench_prob_dir, entries['input_file'])
+        store_main_problem_data(data_file, problem)
+
+        return problem
+
+    def test_eval_f(self):
+
+        fname = get_file('SV_prob_def_1.txt')
+
+        problem = FittingProblem(fname)
+
+        y_values = problem.eval_f(problem.data_x[:10], 'radius=35,length=350,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0')
+
+        y_values_expected = np.array([3.34929172e+02, 9.61325700e+01, 1.92262557e+01, 1.21868330e+00, 7.96766671e-01, 1.27924690e+00, 4.70120551e-01, 1.96891429e-02, 1.54944490e-01, 1.69598579e-01])
+
+        np.testing.assert_allclose(y_values_expected, y_values, rtol=1e-5, atol=0)
+
+    def test_getFunction_returns_correct_function(self):
+
+        fname = get_file('SV_prob_def_1.txt')
+
+        problem = FittingProblem(fname)
+
+        function = problem.get_function()
+
+        function_obj_str = function[0][0]
+
+        y_values = function_obj_str(problem.data_x[:10], 'radius=35,length=350,background=0.0,scale=1.0,sld=4.0,sld_solvent=1.0')
+
+        y_values_expected = np.array(
+            [3.34929172e+02, 9.61325700e+01, 1.92262557e+01, 1.21868330e+00, 7.96766671e-01, 1.27924690e+00,
+             4.70120551e-01, 1.96891429e-02, 1.54944490e-01, 1.69598579e-01])
+
+        param_array = function[0][1]
+
+        param_array_expected = np.array([35., 350., 0., 1., 4., 1.])
+
+        np.testing.assert_allclose(y_values_expected, y_values, rtol=1e-5, atol=0)
+        np.testing.assert_array_equal(param_array_expected, param_array)
+
+    def test_ParseProblemFileSasView_returns_correct_problem_object(self):
+
+        fname = get_file('SV_prob_def_1.txt')
+
+        problem = parse_problem_file(fname)
+        problem_expected = self.expected_sasview_problem()
+
+        self.assertEqual(problem_expected.name, problem.name)
+        self.assertEqual(problem_expected.equation, problem.equation)
+        self.assertEqual(problem_expected.starting_values,
+                         problem.starting_values)
+        self.assertEqual(problem_expected.start_x, problem.start_x)
+        self.assertEqual(problem_expected.end_x, problem.end_x)
+
+    def test_getDataFilesDir_return_data_files_path(self):
+
+        fname = get_file('SV_prob_def_1.txt')
+        input_file = 'SV_cyl_400_20.txt'
+
+        bench_prob_dir = self.get_bench_prob_dir()
+        prob = FittingProblem(fname)
+        data_file = prob.get_data_file(fname, input_file)
+        data_file_expected = os.path.join(bench_prob_dir, input_file)
+
+        self.assertEqual(data_file_expected, data_file)
+
+    def test_getSasViewDataProblemEntries_return_problem_entries(self):
+
+        fname = get_file('SV_prob_def_1.txt')
+        prob = FittingProblem(fname)
+        with open(fname) as probf:
+            entries = prob.get_data_problem_entries(probf)
+        entries_expected = self.expected_SasView_problem_entries()
+
+        self.assertEqual(entries_expected['name'], entries['name'])
+        self.assertEqual(entries_expected['input_file'], entries['input_file'])
+        self.assertEqual(entries_expected['function'], entries['function'])
+        self.assertEqual(entries_expected['parameter_ranges'],
+                         entries['parameter_ranges'])
+        self.assertEqual(entries_expected['description'],
+                         entries['description'])
+
+    def test_storeMiscProbData(self):
+        fname = get_file('SV_prob_def_1.txt')
+        problem = FittingProblem(fname)
+        entries = self.expected_SasView_problem_entries()
+
+        self.assertEqual(entries['name'], problem.name)
+        self.assertEqual((entries['function'].split(',', 1))[0], problem.equation)
+        self.assertEqual((entries['function'].split(',', 1))[1], problem.starting_values)
+        self.assertEqual(entries['parameter_ranges'], problem.starting_value_ranges)
+
+    def test_checkingAttributesAssertion(self):
+        fname = get_file('SV_prob_def_1.txt')
+        prob = FittingProblem(fname)
+        check_problem_attributes(prob)
+
+    def test_checkingDetermineProblemType(self):
+        f = open("RandomData.txt", "w+")
+        for i in range(10):
+            f.write("This is line %d\r\n" % (i + 1))
+        f.close()
+        with self.assertRaises(RuntimeError):
+            determine_problem_type("RandomData.txt")
+        os.remove("RandomData.txt")
+
+    def test_get_start_x_and_end_x(self):
+
+        fname = get_file('SV_prob_def_1.txt')
+        prob = FittingProblem(fname)
+
+        x_data = prob.data_x
+
+        expected_start_x = 0.025
+        expected_end_x = 0.5
+
+        start_x, end_x = prob.get_start_x_and_end_x(x_data)
+
+        self.assertEqual(expected_start_x, start_x)
+        self.assertEqual(expected_end_x, end_x)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/sas/README.md b/sas/README.md
new file mode 100644
index 000000000..4f77ab3c4
--- /dev/null
+++ b/sas/README.md
@@ -0,0 +1,4 @@
+# `sascalc`
+`sascalc` is a part of [SasView](https://github.com/SasView/sasview). It contains modules that are used by the independent package `sasmodels`. As of this writing, `sascalc` is not yet an independent package that can be installed using `pip`.
+
+As a result, `sascalc` is currently included here temporarily, so that SasView/Bumps fitting can be run headless in FitBenchmarking.
diff --git a/sas/__init__.py b/sas/__init__.py
new file mode 100755
index 000000000..261f1cfb2
--- /dev/null
+++ b/sas/__init__.py
@@ -0,0 +1,57 @@
+__all__ = ['get_app_dir', 'get_user_dir',
+           'get_local_config', 'get_custom_config']
+
+_APP_DIR = None
+def get_app_dir():
+    """
+    The directory where the sasview application is found.
+
+    Returns the path to sasview if running in place or installed with setup.
+    If the application is frozen, returns the parent directory of the
+    application resources such as test files and images.
+    """
+    global _APP_DIR
+    if not _APP_DIR:
+        from ._config import find_app_dir
+        _APP_DIR = find_app_dir()
+    return _APP_DIR
+
+_USER_DIR = None
+def get_user_dir():
+    """
+    The directory where the per-user configuration is stored.
+
+    Returns ~/.sasview, creating it if it does not already exist.
+    """
+    global _USER_DIR
+    if not _USER_DIR:
+        from ._config import make_user_dir
+        _USER_DIR = make_user_dir()
+    return _USER_DIR
+
+def make_custom_config_path():
+    from ._config import make_custom_config_path as _make_path
+    return _make_path(get_user_dir())
+
+_CUSTOM_CONFIG = None
+def get_custom_config():
+    """
+    Set up the custom config dir and cat file, and return the loaded module
+    """
+    global _CUSTOM_CONFIG
+    if not _CUSTOM_CONFIG:
+        from ._config import setup_custom_config
+        _CUSTOM_CONFIG = setup_custom_config(get_app_dir(), get_user_dir())
+    return _CUSTOM_CONFIG
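
The `sas` package getters above all follow the same lazy module-level cache pattern; a hedged usage sketch (assumes the bundled `sas` folder is importable):

```python
import sas

user_dir = sas.get_user_dir()          # ~/.sasview, created on the first call
custom_cfg = sas.get_custom_config()   # custom_config.py copied/loaded once

# later calls short-circuit on the cached module-level value
assert sas.get_user_dir() == user_dir
```
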
+ """ + global _LOCAL_CONFIG + if not _LOCAL_CONFIG: + from ._config import load_local_config + _LOCAL_CONFIG = load_local_config(get_app_dir()) + return _LOCAL_CONFIG diff --git a/sas/_config.py b/sas/_config.py new file mode 100755 index 000000000..d857b52a6 --- /dev/null +++ b/sas/_config.py @@ -0,0 +1,121 @@ +# Setup and find Custom config dir +from __future__ import print_function + +import sys +import os +from os.path import exists, expanduser, dirname, realpath, join as joinpath +import logging +import shutil + +from sasmodels.custom import load_module_from_path + +logger = logging.getLogger(__name__) + +def dirn(path, n): + """ + Return the directory n up from the current path + """ + path = realpath(path) + for _ in range(n): + path = dirname(path) + return path + +def find_app_dir(): + """ + Locate the parent directory of the sasview resources. For the normal + application this will be the directory containing sasview.py. For the + frozen application this will be the path where the resources are installed. + """ + # We are starting out with the following info: + # __file__ = .../sas/__init__.pyc + # Check if the path .../sas/sasview exists, and use it as the + # app directory. This will only be the case if the app is not frozen. + path = joinpath(dirname(__file__), 'sasview') + if exists(path): + return path + + # If we are running frozen, then root is a parent directory + if sys.platform == 'darwin': + # Here is the path to the file on the mac: + # .../Sasview.app/Contents/Resources/lib/python2.7/site-packages.zip/sas/__init__.pyc + # We want the path to the Resources directory. + path = dirn(__file__, 5) + elif os.name == 'nt': + # Here is the path to the file on windows: + # ../Sasview/library.zip/sas/__init__.pyc + # We want the path to the Sasview directory. + path = dirn(__file__, 3) + else: + raise RuntimeError("Couldn't find the app directory") + return path + +def make_user_dir(): + """ + Create the user directory ~/.sasview if it doesn't already exist. + """ + path = joinpath(expanduser("~"),'.sasview') + if not exists(path): + os.mkdir(path) + return path + +def load_local_config(app_dir): + logger = logging.getLogger(__name__) + filename = 'local_config.py' + path = os.path.join(app_dir, filename) + try: + module = load_module_from_path('sas.local_config', path) + logger.info("GuiManager loaded %s", path) + return module + except Exception as exc: + logger.critical("Error loading %s: %s", path, exc) + sys.exit() + +def make_custom_config_path(user_dir): + """ + The location of the cusstom config file. 
+
+    Returns ~/.sasview/config/custom_config.py
+    """
+    dirname = os.path.join(user_dir, 'config')
+    # If the directory doesn't exist, create it
+    if not os.path.exists(dirname):
+        os.makedirs(dirname)
+    path = os.path.join(dirname, "custom_config.py")
+    return path
+
+def setup_custom_config(app_dir, user_dir):
+    path = make_custom_config_path(user_dir)
+    #logger.info("custom config path %s", path)
+    if not os.path.isfile(path):
+        try:
+            # if the custom config file does not exist, copy the default from
+            # the app dir
+            shutil.copyfile(os.path.join(app_dir, "custom_config.py"), path)
+        except Exception:
+            logger.error("Could not copy default custom config.")
+
+    custom_config = load_custom_config(path)
+
+    # Add SAS_OPENCL if it doesn't exist in the config file,
+    # to support backward compatibility
+    if not hasattr(custom_config, "SAS_OPENCL"):
+        custom_config.SAS_OPENCL = None
+        try:
+            open(path, "a+").write("SAS_OPENCL = \"None\"\n")
+        except Exception:
+            logger.error("Could not update custom config with SAS_OPENCL.")
+
+    return custom_config
+
+def load_custom_config(path):
+    if os.path.exists(path):
+        try:
+            module = load_module_from_path('sas.custom_config', path)
+            logger.info("GuiManager loaded %s", path)
+            return module
+        except Exception as exc:
+            logger.error("Error loading %s: %s", path, exc)
+
+    from sas.sasview import custom_config
+    logger.info("GuiManager custom_config defaults to sas.sasview.custom_config")
+    return custom_config
diff --git a/sas/logger_config.py b/sas/logger_config.py
new file mode 100755
index 000000000..22fe7bb22
--- /dev/null
+++ b/sas/logger_config.py
@@ -0,0 +1,84 @@
+from __future__ import print_function
+
+import logging
+import logging.config
+import os
+import os.path
+
+
+'''
+Module that manages the global logging
+'''
+
+
+class SetupLogger(object):
+    '''
+    Called at the beginning of run.py or sasview.py
+    '''
+
+    def __init__(self, logger_name):
+        self._find_config_file()
+        self.name = logger_name
+
+    def config_production(self):
+        logger = logging.getLogger(self.name)
+        if not logger.root.handlers:
+            self._read_config_file()
+            logging.captureWarnings(True)
+            logger = logging.getLogger(self.name)
+        return logger
+
+    def config_development(self):
+        '''
+        Read the config file and force all loggers and handlers to DEBUG
+        '''
+        self._read_config_file()
+        logger = logging.getLogger(self.name)
+        self._update_all_logs_to_debug(logger)
+        logging.captureWarnings(True)
+        return logger
+
+    def _read_config_file(self):
+        if self.config_file is not None:
+            logging.config.fileConfig(self.config_file)
+
+    def _update_all_logs_to_debug(self, logger):
+        '''
+        This updates all loggers and respective handlers to DEBUG
+        '''
+        for handler in logger.handlers or logger.parent.handlers:
+            handler.setLevel(logging.DEBUG)
+        for name, _ in logging.Logger.manager.loggerDict.items():
+            logging.getLogger(name).setLevel(logging.DEBUG)
+
+    def _find_config_file(self, filename="logging.ini"):
+        '''
+        The config file is searched for in:
+            Debug:     ./sasview/
+            Packaging: sas/sasview/
+        Packaging / production does not work well with absolute paths,
+        thus the multiple candidate paths below
+        '''
+        places_to_look_for_conf_file = [
+            os.path.join(os.path.abspath(os.path.dirname(__file__)), filename),
+            filename,
+            os.path.join("sas", "sasview", filename),
+            os.path.join(os.getcwd(), "sas", "sasview", filename),
+        ]
+
+        # To avoid the exception on OSX:
+        #     NotImplementedError: resource_filename() only supported for .egg, not .zip
+        try:
+            import pkg_resources
+            places_to_look_for_conf_file.append(
+                pkg_resources.resource_filename(__name__, filename))
+        except ImportError:
pass + except NotImplementedError: + pass + + for filepath in places_to_look_for_conf_file: + if os.path.exists(filepath): + self.config_file = filepath + return + print("ERROR: Logging.ini not found...") + self.config_file = None diff --git a/sas/logging.ini b/sas/logging.ini new file mode 100755 index 000000000..9a98225bd --- /dev/null +++ b/sas/logging.ini @@ -0,0 +1,73 @@ + +############################################################################### +################################### LOGGING ################################### +############################################################################### +# Main logger for SASView + +# SEE: https://docs.python.org/2/library/logging.html#logrecord-attributes +[formatters] +keys=simple,detailed + +[formatter_simple] +#format=%(asctime)s - %(name)s - %(levelname)s - %(message)s +#format=%(asctime)s - %(levelname)s : %(name)s:%(pathname)s:%(lineno)4d: %(message)s +format=%(asctime)s - %(levelname)s : %(name)s:%(lineno)4d: %(message)s +datefmt=%H:%M:%S + +[formatter_detailed] +#format=%(asctime)s : %(levelname)s : %(name)s: %(lineno)d: %(message)s +format=%(asctime)s : %(levelname)s : %(name)s (%(filename)s:%(lineno)s) :: %(message)s + +############################################################################### +# Handlers + +[handlers] +keys=console,log_file + +[handler_console] +class=logging.StreamHandler +formatter=simple +level=WARNING +args=tuple() + +[handler_log_file] +class=logging.FileHandler +level=DEBUG +formatter=detailed +args=(os.path.join(os.path.expanduser("~"),'sasview.log'),"a") + +############################################################################### +# Loggers + +[loggers] +keys=root,saspr,sasgui,sascalc,sasmodels + +[logger_root] +level=DEBUG +formatter=default +handlers=console,log_file + +[logger_sasmodels] +level=INFO +qualname=sas.models +handlers=console,log_file +propagate=0 + +[logger_saspr] +level=INFO +qualname=sas.pr +handlers=console,log_file +propagate=0 + +[logger_sasgui] +level=DEBUG +qualname=sas.sasgui +handlers=console,log_file +propagate=0 + +[logger_sascalc] +level=INFO +qualname=sas.sascalc +handlers=console,log_file +propagate=0 + diff --git a/sas/sascalc/__init__.py b/sas/sascalc/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/sas/sascalc/calculator/BaseComponent.py b/sas/sascalc/calculator/BaseComponent.py new file mode 100755 index 000000000..62ffe3ea5 --- /dev/null +++ b/sas/sascalc/calculator/BaseComponent.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python + +""" +Provide base functionality for all model components +""" + +# imports +import copy +from collections import OrderedDict + +import numpy as np +#TO DO: that about a way to make the parameter +#is self return if it is fittable or not + +class BaseComponent: + """ + Basic model component + + Since version 0.5.0, basic operations are no longer supported. + """ + + def __init__(self): + """ Initialization""" + + ## Name of the model + self.name = "BaseComponent" + + ## Parameters to be accessed by client + self.params = {} + self.details = {} + ## Dictionary used to store the dispersity/averaging + # parameters of dispersed/averaged parameters. 
+ self.dispersion = {} + # string containing information about the model such as the equation + #of the given model, exception or possible use + self.description = '' + #list of parameter that can be fitted + self.fixed = [] + #list of non-fittable parameter + self.non_fittable = [] + ## parameters with orientation + self.orientation_params = [] + ## magnetic parameters + self.magnetic_params = [] + ## store dispersity reference + self._persistency_dict = {} + ## independent parameter name and unit [string] + self.input_name = "Q" + self.input_unit = "A^{-1}" + ## output name and unit [string] + self.output_name = "Intensity" + self.output_unit = "cm^{-1}" + + self.is_multiplicity_model = False + self.is_structure_factor = False + self.is_form_factor = False + + def __str__(self): + """ + :return: string representatio + """ + return self.name + + def is_fittable(self, par_name): + """ + Check if a given parameter is fittable or not + + :param par_name: the parameter name to check + + """ + return par_name.lower() in self.fixed + #For the future + #return self.params[str(par_name)].is_fittable() + + def run(self, x): + """ + run 1d + """ + return NotImplemented + + def runXY(self, x): + """ + run 2d + """ + return NotImplemented + + def calculate_ER(self): + """ + Calculate effective radius + """ + return NotImplemented + + def calculate_VR(self): + """ + Calculate volume fraction ratio + """ + return NotImplemented + + def evalDistribution(self, qdist): + """ + Evaluate a distribution of q-values. + + * For 1D, a numpy array is expected as input: :: + + evalDistribution(q) + + where q is a numpy array. + + + * For 2D, a list of numpy arrays are expected: [qx_prime,qy_prime], + where 1D arrays, :: + + qx_prime = [ qx[0], qx[1], qx[2], ....] + + and :: + + qy_prime = [ qy[0], qy[1], qy[2], ....] + + Then get :: + + q = np.sqrt(qx_prime^2+qy_prime^2) + + that is a qr in 1D array; :: + + q = [q[0], q[1], q[2], ....] + + .. note:: Due to 2D speed issue, no anisotropic scattering + is supported for python models, thus C-models should have + their own evalDistribution methods. + + The method is then called the following way: :: + + evalDistribution(q) + + where q is a numpy array. + + :param qdist: ndarray of scalar q-values or list [qx,qy] where qx,qy are 1D ndarrays + """ + if qdist.__class__.__name__ == 'list': + # Check whether we have a list of ndarrays [qx,qy] + if len(qdist)!=2 or \ + qdist[0].__class__.__name__ != 'ndarray' or \ + qdist[1].__class__.__name__ != 'ndarray': + msg = "evalDistribution expects a list of 2 ndarrays" + raise RuntimeError(msg) + + # Extract qx and qy for code clarity + qx = qdist[0] + qy = qdist[1] + + # calculate q_r component for 2D isotropic + q = np.sqrt(qx**2+qy**2) + # vectorize the model function runXY + v_model = np.vectorize(self.runXY, otypes=[float]) + # calculate the scattering + iq_array = v_model(q) + + return iq_array + + elif qdist.__class__.__name__ == 'ndarray': + # We have a simple 1D distribution of q-values + v_model = np.vectorize(self.runXY, otypes=[float]) + iq_array = v_model(qdist) + return iq_array + + else: + mesg = "evalDistribution is expecting an ndarray of scalar q-values" + mesg += " or a list [qx,qy] where qx,qy are 2D ndarrays." + raise RuntimeError(mesg) + + + + def clone(self): + """ Returns a new object identical to the current object """ + obj = copy.deepcopy(self) + return self._clone(obj) + + def _clone(self, obj): + """ + Internal utility function to copy the internal + data members to a fresh copy. 
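
As an aside on the `evalDistribution` contract documented above, a hedged usage sketch for a hypothetical concrete subclass (`MyModel`) that implements `runXY`:

```python
import numpy as np

model = MyModel()                     # hypothetical BaseComponent subclass
q = np.linspace(0.001, 0.5, 100)
iq_1d = model.evalDistribution(q)     # 1D: vectorised runXY over scalar q values

qx, qy = np.meshgrid(q, q)
# 2D: a list [qx, qy] of 1D ndarrays; evaluated at q = sqrt(qx**2 + qy**2)
iq_2d = model.evalDistribution([qx.flatten(), qy.flatten()])
```
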
+ """ + obj.params = copy.deepcopy(self.params) + obj.details = copy.deepcopy(self.details) + obj.dispersion = copy.deepcopy(self.dispersion) + obj._persistency_dict = copy.deepcopy( self._persistency_dict) + return obj + + def set_dispersion(self, parameter, dispersion): + """ + model dispersions + """ + ##Not Implemented + return None + + def getProfile(self): + """ + Get SLD profile + + : return: (z, beta) where z is a list of depth of the transition points + beta is a list of the corresponding SLD values + """ + #Not Implemented + return None, None + + def setParam(self, name, value): + """ + Set the value of a model parameter + + :param name: name of the parameter + :param value: value of the parameter + + """ + # Look for dispersion parameters + toks = name.split('.') + if len(toks)==2: + for item in self.dispersion.keys(): + if item.lower()==toks[0].lower(): + for par in self.dispersion[item]: + if par.lower() == toks[1].lower(): + self.dispersion[item][par] = value + return + else: + # Look for standard parameter + for item in self.params.keys(): + if item.lower()==name.lower(): + self.params[item] = value + return + + raise ValueError("Model does not contain parameter %s" % name) + + def getParam(self, name): + """ + Set the value of a model parameter + :param name: name of the parameter + + """ + # Look for dispersion parameters + toks = name.split('.') + if len(toks)==2: + for item in self.dispersion.keys(): + if item.lower()==toks[0].lower(): + for par in self.dispersion[item]: + if par.lower() == toks[1].lower(): + return self.dispersion[item][par] + else: + # Look for standard parameter + for item in self.params.keys(): + if item.lower()==name.lower(): + return self.params[item] + + raise ValueError("Model does not contain parameter %s" % name) + + def getParamList(self): + """ + Return a list of all available parameters for the model + """ + list = _ordered_keys(self.params) + # WARNING: Extending the list with the dispersion parameters + list.extend(self.getDispParamList()) + return list + + def getDispParamList(self): + """ + Return a list of all available parameters for the model + """ + list = [] + for item in _ordered_keys(self.dispersion): + for p in _ordered_keys(self.dispersion[item]): + if p not in ['type']: + list.append('%s.%s' % (item.lower(), p.lower())) + + return list + + # Old-style methods that are no longer used + def setParamWithToken(self, name, value, token, member): + """ + set Param With Token + """ + return NotImplemented + def getParamWithToken(self, name, token, member): + """ + get Param With Token + """ + return NotImplemented + + def getParamListWithToken(self, token, member): + """ + get Param List With Token + """ + return NotImplemented + def __add__(self, other): + """ + add + """ + raise ValueError("Model operation are no longer supported") + def __sub__(self, other): + """ + sub + """ + raise ValueError("Model operation are no longer supported") + def __mul__(self, other): + """ + mul + """ + raise ValueError("Model operation are no longer supported") + def __div__(self, other): + """ + div + """ + raise ValueError("Model operation are no longer supported") + + +def _ordered_keys(d): + keys = list(d.keys()) + if not isinstance(d, OrderedDict): + keys.sort() + return keys diff --git a/sas/sascalc/calculator/__init__.py b/sas/sascalc/calculator/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/sas/sascalc/calculator/_sld2i.so b/sas/sascalc/calculator/_sld2i.so new file mode 100755 index 000000000..ef21b53f3 Binary files 
/dev/null and b/sas/sascalc/calculator/_sld2i.so differ diff --git a/sas/sascalc/calculator/c_extensions/libfunc.c b/sas/sascalc/calculator/c_extensions/libfunc.c new file mode 100755 index 000000000..ce6ccee06 --- /dev/null +++ b/sas/sascalc/calculator/c_extensions/libfunc.c @@ -0,0 +1,313 @@ +// by jcho + +#include + +#include "libfunc.h" + +#include + + + +//used in Si func + +int factorial(int i) { + + int k, j; + if (i<2){ + return 1; + } + + k=1; + + for(j=1;j= pi*6.2/4.0){ + double out_sin = 0.0; + double out_cos = 0.0; + out = pi/2.0; + + for (i=0; i0) { + if (m_max < 1.0e-32){ + uu = sqrt(sqrt(in_spin * out_spin)) * uu; + dd = sqrt(sqrt((1.0 - in_spin) * (1.0 - out_spin))) * dd; + } + } + else if (fabs(m_max)< 1.0e-32 && fabs(m_phi)< 1.0e-32 && fabs(m_theta)< 1.0e-32){ + uu = sqrt(sqrt(in_spin * out_spin)) * uu; + dd = sqrt(sqrt((1.0 - in_spin) * (1.0 - out_spin))) * dd; + } else { + + //These are needed because of the precision of inputs + if (in_spin < 0.0) in_spin = 0.0; + if (in_spin > 1.0) in_spin = 1.0; + if (out_spin < 0.0) out_spin = 0.0; + if (out_spin > 1.0) out_spin = 1.0; + + if (q_x == 0.0) q_angle = pi / 2.0; + else q_angle = atan(q_y/q_x); + if (q_y < 0.0 && q_x < 0.0) q_angle -= pi; + else if (q_y > 0.0 && q_x < 0.0) q_angle += pi; + + q_angle = pi/2.0 - q_angle; + if (q_angle > pi) q_angle -= 2.0 * pi; + else if (q_angle < -pi) q_angle += 2.0 * pi; + + if (fabs(q_x) < 1.0e-16 && fabs(q_y) < 1.0e-16){ + m_perp = 0.0; + } + else { + m_perp = m_max; + } + if (is_angle > 0){ + m_phi *= pi/180.0; + m_theta *= pi/180.0; + mx = m_perp * cos(m_theta) * cos(m_phi); + my = m_perp * sin(m_theta); + mz = -(m_perp * cos(m_theta) * sin(m_phi)); + } + else{ + mx = m_perp; + my = m_phi; + mz = m_theta; + } + //ToDo: simplify these steps + // m_perp1 -m_perp2 + m_perp_x = (mx) * cos(q_angle); + m_perp_x -= (my) * sin(q_angle); + m_perp_y = m_perp_x; + m_perp_x *= cos(-q_angle); + m_perp_y *= sin(-q_angle); + m_perp_z = mz; + + m_sigma_x = (m_perp_x * cos(-s_theta) - m_perp_y * sin(-s_theta)); + m_sigma_y = (m_perp_x * sin(-s_theta) + m_perp_y * cos(-s_theta)); + m_sigma_z = (m_perp_z); + + //Find b + uu -= m_sigma_x; + dd += m_sigma_x; + re_ud = m_sigma_y; + re_du = m_sigma_y; + im_ud = m_sigma_z; + im_du = -m_sigma_z; + + uu = sqrt(sqrt(in_spin * out_spin)) * uu; + dd = sqrt(sqrt((1.0 - in_spin) * (1.0 - out_spin))) * dd; + + re_ud = sqrt(sqrt(in_spin * (1.0 - out_spin))) * re_ud; + im_ud = sqrt(sqrt(in_spin * (1.0 - out_spin))) * im_ud; + re_du = sqrt(sqrt((1.0 - in_spin) * out_spin)) * re_du; + im_du = sqrt(sqrt((1.0 - in_spin) * out_spin)) * im_du; + } + p_sld->uu = uu; + p_sld->dd = dd; + p_sld->re_ud = re_ud; + p_sld->im_ud = im_ud; + p_sld->re_du = re_du; + p_sld->im_du = im_du; +} + + +/** Modifications below by kieranrcampbell@gmail.com + Institut Laue-Langevin, July 2012 +**/ + +/** +Wojtek's comment Mar 22 2016: The remaing code can mostly likely be deleated +Keeping it in order to check if it is not breaking anything +**/ + +/* +#define ITMAX 100 +#define EPS 3.0e-7 +#define FPMIN 1.0e-30 + +void gser(float *gamser, float a, float x, float *gln) { + int n; + float sum,del,ap; + + *gln = lgamma(a); + if(x <= 0.0) { + if (x < 0.0) printf("Error: x less than 0 in routine gser"); + *gamser = 0.0; + return; + } else { + ap = a; + del = sum = 1.0/a; + + for(n=1;n<=ITMAX;n++) { + ++ap; + del *= x/ap; + sum += del; + if(fabs(del) < fabs(sum)*EPS) { + *gamser = sum * exp(-x + a * log(x) - (*gln)); + return; + } + } + printf("a too large, ITMAX too small in routine gser"); + 
return; + + } + +} +*/ +/** + Implements the incomplete gamma function Q(a,x) evaluated by its continued fraction + representation +**/ +/* +void gcf(float *gammcf, float a, float x, float *gln) { + int i; + float an,b,c,d,del,h; + + *gln = lgamma(a); + b = x+1.0-a; + c = 1.0/FPMIN; + d = 1.0/b; + h=d; + for (i=1;i <= ITMAX; i++) { + an = -i*(i-a); + b += 2.0; + d = an*d + b; + if (fabs(d) < FPMIN) d = FPMIN; + c = b+an/c; + if (fabs(c) < FPMIN) c = FPMIN; + d = 1.0/d; + del = d*c; + h += del; + if (fabs(del-1.0) < EPS) break; + } + if (i > ITMAX) printf("a too large, ITMAX too small in gcf"); + *gammcf = exp(-x+a*log(x)-(*gln))*h; + return; +} +*/ +/** + Represents incomplete error function, P(a,x) +**/ +/* +float gammp(float a, float x) { + float gamser,gammcf,gln; + if(x < 0.0 || a <= 0.0) printf("Invalid arguments in routine gammp"); + if (x < (a+1.0)) { + gser(&gamser,a,x,&gln); + return gamser; + } else { + gcf(&gammcf,a,x,&gln); + return 1.0 - gammcf; + } +} +*/ +/** + Implementation of the error function, erf(x) +**/ +/* +float erff(float x) { + return x < 0.0 ? -gammp(0.5,x*x) : gammp(0.5,x*x); +} +*/ diff --git a/sas/sascalc/calculator/c_extensions/libfunc.h b/sas/sascalc/calculator/c_extensions/libfunc.h new file mode 100755 index 000000000..ef6687afb --- /dev/null +++ b/sas/sascalc/calculator/c_extensions/libfunc.h @@ -0,0 +1,31 @@ +#if !defined(o_h) +#define libfunc_h +typedef struct { + double uu; + double dd; + double re_ud; + double im_ud; + double re_du; + double im_du; +} polar_sld; + +int factorial(int i); + +double Si(double x); + +double sinc(double x); + +//double gamln(double x); + +void cal_msld(polar_sld*, int isangle, double qx, double qy, double bn, double m01, double mtheta1, + double mphi1, double spinfraci, double spinfracf, double spintheta); + +//void gser(float *gamser, float a, float x, float *gln); + +//void gcf(float *gammcf, float a, float x, float *gln); + +//float gammp(float a,float x); + +//float erff(float x); + +#endif diff --git a/sas/sascalc/calculator/c_extensions/librefl.c b/sas/sascalc/calculator/c_extensions/librefl.c new file mode 100755 index 000000000..42da9919c --- /dev/null +++ b/sas/sascalc/calculator/c_extensions/librefl.c @@ -0,0 +1,400 @@ +// The original code, of which work was not DANSE funded, +// was provided by J. Cho. 
+// And modified to fit sansmodels/sansview: JC + +#include +#include "librefl.h" +#include +#include +#if defined _MSC_VER || defined __TINYCC__ +#define NEED_ERF +#endif + + + +#if defined(NEED_ERF) +/* erf.c - public domain implementation of error function erf(3m) + +reference - Haruhiko Okumura: C-gengo niyoru saishin algorithm jiten + (New Algorithm handbook in C language) (Gijyutsu hyouron + sha, Tokyo, 1991) p.227 [in Japanese] */ + + +#ifdef __TINYCC__ +# ifdef isnan +# undef isnan +# endif +# ifdef isfinite +# undef isfinite +# endif +# define isnan(x) (x != x) +# define isfinite(x) (x != INFINITY && x != -INFINITY) +#elif defined _WIN32 +# include +# if !defined __MINGW32__ || defined __NO_ISOCEXT +# ifndef isnan +# define isnan(x) _isnan(x) +# endif +# ifndef isinf +# define isinf(x) (!_finite(x) && !_isnan(x)) +# endif +# ifndef isfinite +# define isfinite(x) _finite(x) +# endif +# endif +#endif + +static double q_gamma(double, double, double); + +/* Incomplete gamma function + 1 / Gamma(a) * Int_0^x exp(-t) t^(a-1) dt */ +static double p_gamma(double a, double x, double loggamma_a) +{ + int k; + double result, term, previous; + + if (x >= 1 + a) return 1 - q_gamma(a, x, loggamma_a); + if (x == 0) return 0; + result = term = exp(a * log(x) - x - loggamma_a) / a; + for (k = 1; k < 1000; k++) { + term *= x / (a + k); + previous = result; result += term; + if (result == previous) return result; + } + fprintf(stderr, "erf.c:%d:p_gamma() could not converge.", __LINE__); + return result; +} + +/* Incomplete gamma function + 1 / Gamma(a) * Int_x^inf exp(-t) t^(a-1) dt */ +static double q_gamma(double a, double x, double loggamma_a) +{ + int k; + double result, w, temp, previous; + double la = 1, lb = 1 + x - a; /* Laguerre polynomial */ + + if (x < 1 + a) return 1 - p_gamma(a, x, loggamma_a); + w = exp(a * log(x) - x - loggamma_a); + result = w / lb; + for (k = 2; k < 1000; k++) { + temp = ((k - 1 - a) * (lb - la) + (k + x) * lb) / k; + la = lb; lb = temp; + w *= (k - 1 - a) / k; + temp = w / (la * lb); + previous = result; result += temp; + if (result == previous) return result; + } + fprintf(stderr, "erf.c:%d:q_gamma() could not converge.", __LINE__); + return result; +} + +#define LOG_PI_OVER_2 0.572364942924700087071713675675 /* log_e(PI)/2 */ + +double erf(double x) +{ + if (!isfinite(x)) { + if (isnan(x)) return x; /* erf(NaN) = NaN */ + return (x>0 ? 1.0 : -1.0); /* erf(+-inf) = +-1.0 */ + } + if (x >= 0) return p_gamma(0.5, x * x, LOG_PI_OVER_2); + else return - p_gamma(0.5, x * x, LOG_PI_OVER_2); +} + +double erfc(double x) +{ + if (!isfinite(x)) { + if (isnan(x)) return x; /* erfc(NaN) = NaN */ + return (x>0 ? 
0.0 : 2.0); /* erfc(+-inf) = 0.0, 2.0 */ + } + if (x >= 0) return q_gamma(0.5, x * x, LOG_PI_OVER_2); + else return 1 + p_gamma(0.5, x * x, LOG_PI_OVER_2); +} +#endif // NEED_ERF + +void cassign(Cplx *x, double real, double imag) +{ + x->re = real; + x->im = imag; +} + + +void cplx_add(Cplx *z, Cplx x, Cplx y) +{ + z->re = x.re + y.re; + z->im = x.im + y.im; +} + +void rcmult(Cplx *z, double x, Cplx y) +{ + z->re = x*y.re; + z->im = x*y.im; +} + +void cplx_sub(Cplx *z, Cplx x, Cplx y) +{ + z->re = x.re - y.re; + z->im = x.im - y.im; +} + + +void cplx_mult(Cplx *z, Cplx x, Cplx y) +{ + z->re = x.re*y.re - x.im*y.im; + z->im = x.re*y.im + x.im*y.re; +} + +void cplx_div(Cplx *z, Cplx x, Cplx y) +{ + z->re = (x.re*y.re + x.im*y.im)/(y.re*y.re + y.im*y.im); + z->im = (x.im*y.re - x.re*y.im)/(y.re*y.re + y.im*y.im); +} + +void cplx_exp(Cplx *z, Cplx b) +{ + double br,bi; + br=b.re; + bi=b.im; + z->re = exp(br)*cos(bi); + z->im = exp(br)*sin(bi); +} + + +void cplx_sqrt(Cplx *c, Cplx z) //see Schaum`s Math Handbook p. 22, 6.6 and 6.10 +{ + double zr,zi,x,y,r,w; + + zr=z.re; + zi=z.im; + + if (zr==0.0 && zi==0.0) + { + c->re=0.0; + c->im=0.0; + } else { + x=fabs(zr); + y=fabs(zi); + if (x>y) + { + r=y/x; + w=sqrt(x)*sqrt(0.5*(1.0+sqrt(1.0+r*r))); + } else { + r=x/y; + w=sqrt(y)*sqrt(0.5*(r+sqrt(1.0+r*r))); + } + if (zr >=0.0) + { + c->re=w; + c->im=zi/(2.0*w); + } else { + c->im=(zi >= 0) ? w : -w; + c->re=zi/(2.0*c->im); + } + } +} + +void cplx_cos(Cplx *z, Cplx b) +{ + // cos(b) = (e^bi + e^-bi)/2 + // = (e^b.im e^-i bi.re) + e^-b.im e^i b.re)/2 + // = (e^b.im cos(-b.re) + e^b.im sin(-b.re) i)/2 + (e^-b.im cos(b.re) + e^-b.im sin(b.re) i)/2 + // = e^b.im cos(b.re)/2 - e^b.im sin(b.re)/2 i + 1/e^b.im cos(b.re)/2 + 1/e^b.im sin(b.re)/2 i + // = (e^b.im + 1/e^b.im)/2 cos(b.re) + (-e^b.im + 1/e^b.im)/2 sin(b.re) i + // = cosh(b.im) cos(b.re) - sinh(b.im) sin(b.re) i + double exp_b_im = exp(b.im); + z->re = 0.5*(+exp_b_im + 1.0/exp_b_im) * cos(b.re); + z->im = -0.5*(exp_b_im - 1.0/exp_b_im) * sin(b.re); +} + +// normalized and modified erf +// | +// 1 + __ - - - - +// | _ +// | _ +// | __ +// 0 + - - - +// |-------------+------------+-- +// 0 center n_sub ---> +// ind +// +// n_sub = total no. 
of bins(or sublayers) +// ind = x position: 0 to max +// nu = max x to integration +double err_mod_func(double n_sub, double ind, double nu) +{ + double center, func; + if (nu == 0.0) + nu = 1e-14; + if (n_sub == 0.0) + n_sub = 1.0; + + + //ind = (n_sub-1.0)/2.0-1.0 +ind; + center = n_sub/2.0; + // transform it so that min(ind) = 0 + ind -= center; + // normalize by max limit + ind /= center; + // divide by sqrt(2) to get Gaussian func + nu /= sqrt(2.0); + ind *= nu; + // re-scale and normalize it so that max(erf)=1, min(erf)=0 + func = erf(ind)/erf(nu)/2.0; + // shift it by +0.5 in y-direction so that min(erf) = 0 + func += 0.5; + + return func; +} +double linearfunc(double n_sub, double ind, double nu) +{ + double bin_size, func; + if (n_sub == 0.0) + n_sub = 1.0; + + bin_size = 1.0/n_sub; //size of each sub-layer + // rescale + ind *= bin_size; + func = ind; + + return func; +} +// use the right hand side from the center of power func +double power_r(double n_sub, double ind, double nu) +{ + double bin_size,func; + if (nu == 0.0) + nu = 1e-14; + if (n_sub == 0.0) + n_sub = 1.0; + + bin_size = 1.0/n_sub; //size of each sub-layer + // rescale + ind *= bin_size; + func = pow(ind, nu); + + return func; +} +// use the left hand side from the center of power func +double power_l(double n_sub, double ind, double nu) +{ + double bin_size, func; + if (nu == 0.0) + nu = 1e-14; + if (n_sub == 0.0) + n_sub = 1.0; + + bin_size = 1.0/n_sub; //size of each sub-layer + // rescale + ind *= bin_size; + func = 1.0-pow((1.0-ind),nu); + + return func; +} +// use 1-exp func from x=0 to x=1 +double exp_r(double n_sub, double ind, double nu) +{ + double bin_size, func; + if (nu == 0.0) + nu = 1e-14; + if (n_sub == 0.0) + n_sub = 1.0; + + bin_size = 1.0/n_sub; //size of each sub-layer + // rescale + ind *= bin_size; + // modify func so that func(0) =0 and func(max)=1 + func = 1.0-exp(-nu*ind); + // normalize by its max + func /= (1.0-exp(-nu)); + + return func; +} + +// use the left hand side mirror image of exp func +double exp_l(double n_sub, double ind, double nu) +{ + double bin_size, func; + if (nu == 0.0) + nu = 1e-14; + if (n_sub == 0.0) + n_sub = 1.0; + + bin_size = 1.0/n_sub; //size of each sub-layer + // rescale + ind *= bin_size; + // modify func + func = exp(-nu*(1.0-ind))-exp(-nu); + // normalize by its max + func /= (1.0-exp(-nu)); + + return func; +} + +// To select function called +// At nu = 0 (singular point), call line function +double intersldfunc(int fun_type, double n_sub, double i, double nu, double sld_l, double sld_r) +{ + double sld_i, func; + // this condition protects an error from the singular point + if (nu == 0.0){ + nu = 1e-13; + } + // select func + switch(fun_type){ + case 1 : + func = power_r(n_sub, i, nu); + break; + case 2 : + func = power_l(n_sub, i, nu); + break; + case 3 : + func = exp_r(n_sub, i, nu); + break; + case 4 : + func = exp_l(n_sub, i, nu); + break; + case 5 : + func = linearfunc(n_sub, i, nu); + break; + default: + func = err_mod_func(n_sub, i, nu); + break; + } + // compute sld + if (sld_r>sld_l){ + sld_i = (sld_r-sld_l)*func+sld_l; //sld_cal(sld[i],sld[i+1],n_sub,dz,thick); + } + else if (sld_rsld_l){ + sld_i = (sld_r-sld_l)*func+sld_l; //sld_cal(sld[i],sld[i+1],n_sub,dz,thick); + } + else if (sld_r +#include +#include "sld2i.h" +#include "libfunc.h" +#include "librefl.h" +/** + * Constructor for GenI + * + * binning + * //@param qx: array of Qx values + * //@param qy: array of Qy values + * //@param qz: array of Qz values + * @param x: array of x values + * 
@param y: array of y values + * @param z: array of z values + * @param sldn: array of sld n + * @param mx: array of sld mx + * @param my: array of sld my + * @param mz: array of sld mz + * @param in_spin: ratio of up spin in Iin + * @param out_spin: ratio of up spin in Iout + * @param s_theta: angle (from x-axis) of the up spin in degree + */ +void initGenI(GenI* this, int is_avg, int npix, double* x, double* y, double* z, double* sldn, + double* mx, double* my, double* mz, double* voli, + double in_spin, double out_spin, + double s_theta) { + this->is_avg = is_avg; + this->n_pix = npix; + this->x_val = x; + this->y_val = y; + this->z_val = z; + this->sldn_val = sldn; + this->mx_val = mx; + this->my_val = my; + this->mz_val = mz; + this->vol_pix = voli; + this->inspin = in_spin; + this->outspin = out_spin; + this->stheta = s_theta; +} + +/** + * Compute 2D anisotropic + */ +void genicomXY(GenI* this, int npoints, double *qx, double *qy, double *I_out){ + //npoints is given negative for angular averaging + // Assumes that q doesn't have qz component and sld_n is all real + //double q = 0.0; + //double Pi = 4.0*atan(1.0); + polar_sld b_sld; + double qr = 0.0; + Cplx iqr; + Cplx ephase; + Cplx comp_sld; + + Cplx sumj_uu; + Cplx sumj_ud; + Cplx sumj_du; + Cplx sumj_dd; + Cplx temp_fi; + + double count = 0.0; + int i, j; + + cassign(&iqr, 0.0, 0.0); + cassign(&ephase, 0.0, 0.0); + cassign(&comp_sld, 0.0, 0.0); + + //Assume that pixel volumes are given in vol_pix in A^3 unit + //int x_size = 0; //in Ang + //int y_size = 0; //in Ang + //int z_size = 0; //in Ang + + // Loop over q-values and multiply apply matrix + + //printf("npoints: %d, npix: %d\n", npoints, this->n_pix); + for(i=0; in_pix; j++){ + if (this->sldn_val[j]!=0.0 + ||this->mx_val[j]!=0.0 + ||this->my_val[j]!=0.0 + ||this->mz_val[j]!=0.0) + { + // printf("i,j: %d,%d\n", i,j); + //anisotropic + cassign(&temp_fi, 0.0, 0.0); + cal_msld(&b_sld, 0, qx[i], qy[i], this->sldn_val[j], + this->mx_val[j], this->my_val[j], this->mz_val[j], + this->inspin, this->outspin, this->stheta); + qr = (qx[i]*this->x_val[j] + qy[i]*this->y_val[j]); + cassign(&iqr, 0.0, qr); + cplx_exp(&ephase, iqr); + + //Let's multiply pixel(atomic) volume here + rcmult(&ephase, this->vol_pix[j], ephase); + //up_up + if (this->inspin > 0.0 && this->outspin > 0.0){ + cassign(&comp_sld, b_sld.uu, 0.0); + cplx_mult(&temp_fi, comp_sld, ephase); + cplx_add(&sumj_uu, sumj_uu, temp_fi); + } + //down_down + if (this->inspin < 1.0 && this->outspin < 1.0){ + cassign(&comp_sld, b_sld.dd, 0.0); + cplx_mult(&temp_fi, comp_sld, ephase); + cplx_add(&sumj_dd, sumj_dd, temp_fi); + } + //up_down + if (this->inspin > 0.0 && this->outspin < 1.0){ + cassign(&comp_sld, b_sld.re_ud, b_sld.im_ud); + cplx_mult(&temp_fi, comp_sld, ephase); + cplx_add(&sumj_ud, sumj_ud, temp_fi); + } + //down_up + if (this->inspin < 1.0 && this->outspin > 0.0){ + cassign(&comp_sld, b_sld.re_du, b_sld.im_du); + cplx_mult(&temp_fi, comp_sld, ephase); + cplx_add(&sumj_du, sumj_du, temp_fi); + } + + if (i == 0){ + count += this->vol_pix[j]; + } + } + } + //printf("aa%d=%g %g %d\n", i, (sumj_uu.re*sumj_uu.re + sumj_uu.im*sumj_uu.im), (sumj_dd.re*sumj_dd.re + sumj_dd.im*sumj_dd.im), count); + + I_out[i] = (sumj_uu.re*sumj_uu.re + sumj_uu.im*sumj_uu.im); + I_out[i] += (sumj_ud.re*sumj_ud.re + sumj_ud.im*sumj_ud.im); + I_out[i] += (sumj_du.re*sumj_du.re + sumj_du.im*sumj_du.im); + I_out[i] += (sumj_dd.re*sumj_dd.re + sumj_dd.im*sumj_dd.im); + + I_out[i] *= (1.0E+8 / count); //in cm (unit) / number; //to be multiplied by 
vol_pix + } + //printf("count = %d %g %g %g %g\n", count, this->sldn_val[0],this->mx_val[0], this->my_val[0], this->mz_val[0]); +} +/** + * Compute 1D isotropic + * Isotropic: Assumes all slds are real (no magnetic) + * Also assumes there is no polarization: No dependency on spin + */ +void genicom(GenI* this, int npoints, double *q, double *I_out){ + //npoints is given negative for angular averaging + // Assumes that q doesn't have qz component and sld_n is all real + //double Pi = 4.0*atan(1.0); + double qr = 0.0; + double sumj; + double sld_j = 0.0; + double count = 0.0; + int i, j, k; + + //Assume that pixel volumes are given in vol_pix in A^3 unit + // Loop over q-values and multiply apply matrix + for(i=0; in_pix; j++){ + //Isotropic: Assumes all slds are real (no magnetic) + //Also assumes there is no polarization: No dependency on spin + if (this->is_avg == 1){ + // approximation for a spherical symmetric particle + qr = sqrt(this->x_val[j]*this->x_val[j]+this->y_val[j]*this->y_val[j]+this->z_val[j]*this->z_val[j])*q[i]; + if (qr > 0.0){ + qr = sin(qr) / qr; + sumj += this->sldn_val[j] * this->vol_pix[j] * qr; + } + else{ + sumj += this->sldn_val[j] * this->vol_pix[j]; + } + } + else{ + //full calculation + //pragma omp parallel for + for(k=0; kn_pix; k++){ + sld_j = this->sldn_val[j] * this->sldn_val[k] * this->vol_pix[j] * this->vol_pix[k]; + qr = (this->x_val[j]-this->x_val[k])*(this->x_val[j]-this->x_val[k])+ + (this->y_val[j]-this->y_val[k])*(this->y_val[j]-this->y_val[k])+ + (this->z_val[j]-this->z_val[k])*(this->z_val[j]-this->z_val[k]); + qr = sqrt(qr) * q[i]; + if (qr > 0.0){ + sumj += sld_j*sin(qr)/qr; + } + else{ + sumj += sld_j; + } + } + } + if (i == 0){ + count += this->vol_pix[j]; + } + } + I_out[i] = sumj; + if (this->is_avg == 1) { + I_out[i] *= sumj; + } + I_out[i] *= (1.0E+8 / count); //in cm (unit) / number; //to be multiplied by vol_pix + } + //printf("count = %d %g %g %g %g\n", count, sldn_val[0],mx_val[0], my_val[0], mz_val[0]); +} diff --git a/sas/sascalc/calculator/c_extensions/sld2i.h b/sas/sascalc/calculator/c_extensions/sld2i.h new file mode 100755 index 000000000..987744726 --- /dev/null +++ b/sas/sascalc/calculator/c_extensions/sld2i.h @@ -0,0 +1,37 @@ +/** +Computes the (magnetic) scattering form sld (n and m) profile + */ +#ifndef SLD2I_CLASS_H +#define SLD2I_CLASS_H + +/** + * Base class + */ +typedef struct { + // vectors + int is_avg; + int n_pix; + double* x_val; + double* y_val; + double* z_val; + double* sldn_val; + double* mx_val; + double* my_val; + double* mz_val; + double* vol_pix; + // spin ratios + double inspin; + double outspin; + double stheta; +} GenI; + +// Constructor +void initGenI(GenI*, int is_avg, int npix, double* x, double* y, double* z, + double* sldn, double* mx, double* my, double* mz, double* voli, + double in_spin, double out_spin, + double s_theta); +// compute function +void genicomXY(GenI*, int npoints, double* qx, double* qy, double *I_out); +void genicom(GenI*, int npoints, double* q, double *I_out); + +#endif diff --git a/sas/sascalc/calculator/c_extensions/sld2i_module.c b/sas/sascalc/calculator/c_extensions/sld2i_module.c new file mode 100755 index 000000000..052cf0a0d --- /dev/null +++ b/sas/sascalc/calculator/c_extensions/sld2i_module.c @@ -0,0 +1,213 @@ +/** + SLD2I module to perform point and I calculations + */ +#include + +//#define Py_LIMITED_API 0x03020000 +#include + +#include "sld2i.h" + +#if PY_MAJOR_VERSION < 3 +typedef void (*PyCapsule_Destructor)(PyObject *); +typedef void 
(*PyCObject_Destructor)(void *); +#define PyCapsule_New(pointer, name, destructor) (PyCObject_FromVoidPtr(pointer, (PyCObject_Destructor)destructor)) +#define PyCapsule_GetPointer(capsule, name) (PyCObject_AsVoidPtr(capsule)) +#endif + +// Vector binding glue +#if (PY_VERSION_HEX > 0x03000000) && !defined(Py_LIMITED_API) + // Assuming that a view into a writable vector points to a + // non-changing pointer for the duration of the C call, capture + // the view pointer and immediately free the view. + #define VECTOR(VEC_obj, VEC_buf, VEC_len) do { \ + Py_buffer VEC_view; \ + int VEC_err = PyObject_GetBuffer(VEC_obj, &VEC_view, PyBUF_WRITABLE|PyBUF_FORMAT); \ + if (VEC_err < 0 || sizeof(*VEC_buf) != VEC_view.itemsize) return NULL; \ + VEC_buf = (typeof(VEC_buf))VEC_view.buf; \ + VEC_len = VEC_view.len/sizeof(*VEC_buf); \ + PyBuffer_Release(&VEC_view); \ + } while (0) +#else + #define VECTOR(VEC_obj, VEC_buf, VEC_len) do { \ + int VEC_err = PyObject_AsWriteBuffer(VEC_obj, (void **)(&VEC_buf), &VEC_len); \ + if (VEC_err < 0) return NULL; \ + VEC_len /= sizeof(*VEC_buf); \ + } while (0) +#endif + +/** + * Delete a GenI object + */ +void +del_sld2i(PyObject *obj){ +#if PY_MAJOR_VERSION < 3 + GenI* sld2i = (GenI *)obj; +#else + GenI* sld2i = (GenI *)(PyCapsule_GetPointer(obj, "GenI")); +#endif + PyMem_Free((void *)sld2i); +} + +/** + * Create a GenI as a python object by supplying arrays + */ +PyObject * new_GenI(PyObject *self, PyObject *args) { + PyObject *x_val_obj; + PyObject *y_val_obj; + PyObject *z_val_obj; + PyObject *sldn_val_obj; + PyObject *mx_val_obj; + PyObject *my_val_obj; + PyObject *mz_val_obj; + PyObject *vol_pix_obj; + Py_ssize_t n_x, n_y, n_z, n_sld, n_mx, n_my, n_mz, n_vol_pix; + int is_avg; + double* x_val; + double* y_val; + double* z_val; + double* sldn_val; + double* mx_val; + double* my_val; + double* mz_val; + double* vol_pix; + double inspin; + double outspin; + double stheta; + PyObject *obj; + GenI* sld2i; + + //printf("new GenI\n"); + if (!PyArg_ParseTuple(args, "iOOOOOOOOddd", &is_avg, &x_val_obj, &y_val_obj, &z_val_obj, &sldn_val_obj, &mx_val_obj, &my_val_obj, &mz_val_obj, &vol_pix_obj, &inspin, &outspin, &stheta)) return NULL; + VECTOR(x_val_obj, x_val, n_x); + VECTOR(y_val_obj, y_val, n_y); + VECTOR(z_val_obj, z_val, n_z); + VECTOR(sldn_val_obj, sldn_val, n_sld); + VECTOR(mx_val_obj, mx_val, n_mx); + VECTOR(my_val_obj, my_val, n_my); + VECTOR(mz_val_obj, mz_val, n_mz); + VECTOR(vol_pix_obj, vol_pix, n_vol_pix); + sld2i = PyMem_Malloc(sizeof(GenI)); + //printf("sldi:%p\n", sld2i); + if (sld2i != NULL) { + initGenI(sld2i,is_avg,(int)n_x,x_val,y_val,z_val,sldn_val,mx_val,my_val,mz_val,vol_pix,inspin,outspin,stheta); + } + obj = PyCapsule_New(sld2i, "GenI", del_sld2i); + //printf("constructed %p\n", obj); + return obj; +} + +/** + * GenI the given input (2D) according to a given object + */ +PyObject * genicom_inputXY(PyObject *self, PyObject *args) { + PyObject *gen_obj; + PyObject *qx_obj; + PyObject *qy_obj; + PyObject *I_out_obj; + Py_ssize_t n_qx, n_qy, n_out; + double *qx; + double *qy; + double *I_out; + GenI* sld2i; + + //printf("in genicom_inputXY\n"); + if (!PyArg_ParseTuple(args, "OOOO", &gen_obj, &qx_obj, &qy_obj, &I_out_obj)) return NULL; + sld2i = (GenI *)PyCapsule_GetPointer(gen_obj, "GenI"); + VECTOR(qx_obj, qx, n_qx); + VECTOR(qy_obj, qy, n_qy); + VECTOR(I_out_obj, I_out, n_out); + //printf("qx, qy, I_out: %d %d %d, %d %d %d\n", qx, qy, I_out, n_qx, n_qy, n_out); + + // Sanity check + //if(n_q!=n_out) return Py_BuildValue("i",-1); + + genicomXY(sld2i, 
(int)n_qx, qx, qy, I_out); + //printf("done calc\n"); + //return PyCObject_FromVoidPtr(s, del_genicom); + return Py_BuildValue("i",1); +} + +/** + * GenI the given 1D input according to a given object + */ +PyObject * genicom_input(PyObject *self, PyObject *args) { + PyObject *gen_obj; + PyObject *q_obj; + PyObject *I_out_obj; + Py_ssize_t n_q, n_out; + double *q; + double *I_out; + GenI *sld2i; + + if (!PyArg_ParseTuple(args, "OOO", &gen_obj, &q_obj, &I_out_obj)) return NULL; + sld2i = (GenI *)PyCapsule_GetPointer(gen_obj, "GenI"); + VECTOR(q_obj, q, n_q); + VECTOR(I_out_obj, I_out, n_out); + + // Sanity check + //if (n_q!=n_out) return Py_BuildValue("i",-1); + + genicom(sld2i, (int)n_q, q, I_out); + return Py_BuildValue("i",1); +} + +/** + * Define module methods + */ +static PyMethodDef module_methods[] = { + {"new_GenI", (PyCFunction)new_GenI, METH_VARARGS, + "Create a new GenI object"}, + {"genicom",(PyCFunction)genicom_input, METH_VARARGS, + "genicom the given 1d input arrays"}, + {"genicomXY",(PyCFunction)genicom_inputXY, METH_VARARGS, + "genicomXY the given 2d input arrays"}, + {NULL} +}; + +#define MODULE_DOC "Sld2i C Library" +#define MODULE_NAME "_sld2i" +#define MODULE_INIT2 init_sld2i +#define MODULE_INIT3 PyInit__sld2i +#define MODULE_METHODS module_methods + +/* ==== boilerplate python 2/3 interface bootstrap ==== */ + + +#if defined(WIN32) && !defined(__MINGW32__) + #define DLL_EXPORT __declspec(dllexport) +#else + #define DLL_EXPORT +#endif + +#if PY_MAJOR_VERSION >= 3 + + DLL_EXPORT PyMODINIT_FUNC MODULE_INIT3(void) + { + static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + MODULE_NAME, /* m_name */ + MODULE_DOC, /* m_doc */ + -1, /* m_size */ + MODULE_METHODS, /* m_methods */ + NULL, /* m_reload */ + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL, /* m_free */ + }; + return PyModule_Create(&moduledef); + } + +#else /* !PY_MAJOR_VERSION >= 3 */ + + DLL_EXPORT PyMODINIT_FUNC MODULE_INIT2(void) + { + Py_InitModule4(MODULE_NAME, + MODULE_METHODS, + MODULE_DOC, + 0, + PYTHON_API_VERSION + ); + } + +#endif /* !PY_MAJOR_VERSION >= 3 */ diff --git a/sas/sascalc/calculator/instrument.py b/sas/sascalc/calculator/instrument.py new file mode 100755 index 000000000..d8b1d57f4 --- /dev/null +++ b/sas/sascalc/calculator/instrument.py @@ -0,0 +1,388 @@ +""" +This module is a small tool to allow user to +control instrumental parameters +""" +import numpy as np + +# defaults in cgs unit +_SAMPLE_A_SIZE = [1.27] +_SOURCE_A_SIZE = [3.81] +_SAMPLE_DISTANCE = [1627, 0] +_SAMPLE_OFFSET = [0, 0] +_SAMPLE_SIZE = [2.54] +_SAMPLE_THICKNESS = 0.2 +_D_DISTANCE = [1000, 0] +_D_SIZE = [128, 128] +_D_PIX_SIZE = [0.5, 0.5] + +_MIN = 0.0 +_MAX = 50.0 +_INTENSITY = 368428 +_WAVE_LENGTH = 6.0 +_WAVE_SPREAD = 0.125 +_MASS = 1.67492729E-24 # [gr] +_LAMBDA_ARRAY = [[0, 1e+16], [_INTENSITY, _INTENSITY]] + + +class Aperture(object): + """ + An object class that defines the aperture variables + """ + def __init__(self): + + # assumes that all aligned at the centers + # aperture_size [diameter] for pinhole, [dx, dy] for rectangular + self.sample_size = _SAMPLE_A_SIZE + self.source_size = _SOURCE_A_SIZE + self.sample_distance = _SAMPLE_DISTANCE + + def set_source_size(self, size=[]): + """ + Set the source aperture size + """ + if len(size) == 0: + self.source_size = 0.0 + else: + self.source_size = size + validate(size[0]) + + def set_sample_size(self, size=[]): + """ + Set the sample aperture size + """ + if len(size) == 0: + self.sample_size = 0.0 + else: + self.sample_size = size + 
validate(size[0]) + + def set_sample_distance(self, distance=[]): + """ + Set the sample aperture distance + """ + if len(distance) == 0: + self.sample_distance = 0.0 + else: + self.sample_distance = distance + validate(distance[0]) + + +class Sample(object): + """ + An object class that defines the sample variables + """ + def __init__(self): + + # assumes that all aligned at the centers + # source2sample or sample2detector distance + self.distance = _SAMPLE_OFFSET + self.size = _SAMPLE_SIZE + self.thickness = _SAMPLE_THICKNESS + + def set_size(self, size=[]): + """ + Set the sample size + """ + if len(size) == 0: + self.size = 0.0 + else: + self.size = size + validate(size[0]) + + def set_thickness(self, thickness=0.0): + """ + Set the sample thickness + """ + self.thickness = thickness + validate(thickness) + + def set_distance(self, distance=[]): + """ + Set the sample distance + """ + if len(distance) == 0: + self.distance = 0.0 + else: + self.distance = distance + if distance[0] != 0.0: + validate(distance[0]) + + +class Detector(object): + """ + An object class that defines the detector variables + """ + def __init__(self): + + # assumes that all aligned at the centers + # source2sample or sample2detector distance + self.distance = _D_DISTANCE + self.size = _D_SIZE + self.pix_size = _D_PIX_SIZE + + def set_size(self, size=[]): + """ + Set the detector size + """ + if len(size) == 0: + self.size = 0 + else: + self.size = size + validate(size[0]) + + def set_pix_size(self, size=[]): + """ + Set the detector pix_size + """ + if len(size) == 0: + self.pix_size = 0 + else: + self.pix_size = size + validate(size[0]) + + def set_distance(self, distance=[]): + """ + Set the detector distance + """ + if len(distance) == 0: + self.distance = 0 + else: + self.distance = distance + validate(distance[0]) + + +class Neutron(object): + """ + An object that defines the wavelength variables + """ + def __init__(self): + + # neutron mass in cgs unit + self.mass = _MASS + + # wavelength + self.wavelength = _WAVE_LENGTH + # wavelength spread (FWHM) + self.wavelength_spread = _WAVE_SPREAD + # wavelength spectrum + self.spectrum = self.get_default_spectrum() + # intensity in counts/sec + self.intensity = np.interp(self.wavelength, + self.spectrum[0], + self.spectrum[1], + 0.0, + 0.0) + # min max range of the spectrum + self.min = min(self.spectrum[0]) + self.max = max(self.spectrum[0]) + # wavelength band + self.band = [self.min, self.max] + + # default unit of the thickness + self.wavelength_unit = 'A' + + def set_full_band(self): + """ + set band to default value + """ + self.band = self.spectrum + + def set_spectrum(self, spectrum): + """ + Set spectrum + + :param spectrum: numpy array + """ + self.spectrum = spectrum + self.setup_spectrum() + + def setup_spectrum(self): + """ + To set the wavelength spectrum, and intensity, assumes + wavelength is already within the spectrum + """ + spectrum = self.spectrum + intensity = np.interp(self.wavelength, + spectrum[0], + spectrum[1], + 0.0, + 0.0) + self.set_intensity(intensity) + # min max range of the spectrum + self.min = min(self.spectrum[0]) + self.max = max(self.spectrum[0]) + # set default band + self.set_band([self.min, self.max]) + + def set_band(self, band=[]): + """ + To set the wavelength band + + :param band: array of [min, max] + """ + # check if the wavelength is in range + if min(band) < self.min or max(band) > self.max: + raise ValueError("band out of range") + self.band = band + + def set_intensity(self, intensity=368428): + """ + Sets the 
intensity in counts/sec
+        """
+        self.intensity = intensity
+        validate(intensity)
+
+    def set_wavelength(self, wavelength=_WAVE_LENGTH):
+        """
+        Sets the wavelength
+        """
+        # check if the wavelength is in range
+        if wavelength < min(self.band) or wavelength > max(self.band):
+            raise ValueError("wavelength out of range")
+        self.wavelength = wavelength
+        validate(wavelength)
+        self.intensity = np.interp(self.wavelength,
+                                   self.spectrum[0],
+                                   self.spectrum[1],
+                                   0.0,
+                                   0.0)
+
+    def set_mass(self, mass=_MASS):
+        """
+        Sets the neutron mass
+        """
+        self.mass = mass
+        validate(mass)
+
+    def set_wavelength_spread(self, spread=_WAVE_SPREAD):
+        """
+        Sets the wavelength spread
+        """
+        self.wavelength_spread = spread
+        if spread != 0.0:
+            validate(spread)
+
+    def get_intensity(self):
+        """
+        To get the value of intensity
+        """
+        return self.intensity
+
+    def get_wavelength(self):
+        """
+        To get the value of wavelength
+        """
+        return self.wavelength
+
+    def get_mass(self):
+        """
+        To get the neutron mass
+        """
+        return self.mass
+
+    def get_wavelength_spread(self):
+        """
+        To get the value of wavelength spread
+        """
+        return self.wavelength_spread
+
+    def get_ramdom_value(self):
+        """
+        To get the value of the wavelength
+        """
+        return self.wavelength
+
+    def get_spectrum(self):
+        """
+        To get the wavelength spectrum
+        """
+        return self.spectrum
+
+    def get_default_spectrum(self):
+        """
+        get default spectrum
+        """
+        return np.array(_LAMBDA_ARRAY)
+
+    def get_band(self):
+        """
+        To get the wavelength band
+        """
+        return self.band
+
+    def plot_spectrum(self):
+        """
+        To plot the wavelength spectrum
+        : requirement: matplotlib.pyplot
+        """
+        try:
+            import matplotlib.pyplot as plt
+        except ImportError:
+            raise RuntimeError("Can't import matplotlib required to plot...")
+        plt.plot(self.spectrum[0], self.spectrum[1], linewidth=2, color='r')
+        plt.legend(['Spectrum'], loc='best')
+        plt.show()
+
+
+class TOF(Neutron):
+    """
+    TOF: make list of wavelengths and wavelength spreads
+    """
+    def __init__(self):
+        """
+        Init
+        """
+        Neutron.__init__(self)
+        #self.switch = switch
+        self.wavelength_list = [self.wavelength]
+        self.wavelength_spread_list = [self.wavelength_spread]
+        self.intensity_list = self.get_intensity_list()
+
+    def get_intensity_list(self):
+        """
+        get list of the intensities wrt wavelength_list
+        """
+        out = np.interp(self.wavelength_list,
+                        self.spectrum[0],
+                        self.spectrum[1],
+                        0.0,
+                        0.0)
+        return out
+
+    def get_wave_list(self):
+        """
+        Get wavelength and wavelength_spread list
+        """
+        return self.wavelength_list, self.wavelength_spread_list
+
+    def set_wave_list(self, wavelength=[]):
+        """
+        Set wavelength list
+
+        :param wavelength: list of wavelengths
+        """
+        self.wavelength_list = wavelength
+
+    def set_wave_spread_list(self, wavelength_spread=[]):
+        """
+        Set wavelength_spread list
+
+        :param wavelength_spread: list of wavelength spreads
+        """
+        self.wavelength_spread_list = wavelength_spread
+
+
+def validate(value=None):
+    """
+    Check if the value is a float >= 0.0
+
+    :return value: True / False
+    """
+    try:
+        val = float(value)
+        if val >= 0:
+            val = True
+        else:
+            val = False
+    except (TypeError, ValueError):
+        val = False
+    return val
diff --git a/sas/sascalc/calculator/kiessig_calculator.py b/sas/sascalc/calculator/kiessig_calculator.py
new file mode 100755
index 000000000..026881b32
--- /dev/null
+++ b/sas/sascalc/calculator/kiessig_calculator.py
@@ -0,0 +1,64 @@
+"""
+This module is a small tool to allow the user to quickly
+determine the size value in real space from the
+fringe width in q space.
+""" +from math import pi, fabs +_DQ_DEFAULT = 0.05 + + +class KiessigThicknessCalculator(object): + """ + compute thickness from the fringe width of data + """ + def __init__(self): + + # dq value + self.deltaq = _DQ_DEFAULT + # thickenss value + self.thickness = None + # unit of the thickness + self.thickness_unit = 'A' + + def set_deltaq(self, dq=None): + """ + Receive deltaQ value + + :param dq: q fringe width in 1/A unit + """ + # set dq + self.deltaq = dq + + def get_deltaq(self): + """ + return deltaQ value in 1/A unit + """ + # return dq + return self.deltaq + + def compute_thickness(self): + """ + Calculate thickness. + + :return: the thickness. + """ + # check if it is float + try: + dq = float(self.deltaq) + except: + return None + # check if delta_q is zero + if dq == 0.0 or dq is None: + return None + else: + # calculate thickness + thickness = 2*pi/fabs(dq) + # return thickness value + return thickness + + def get_thickness_unit(self): + """ + :return: the thickness unit. + """ + # unit of thickness + return self.thickness_unit diff --git a/sas/sascalc/calculator/resolution_calculator.py b/sas/sascalc/calculator/resolution_calculator.py new file mode 100755 index 000000000..5acf6cd2c --- /dev/null +++ b/sas/sascalc/calculator/resolution_calculator.py @@ -0,0 +1,1169 @@ +""" +This object is a small tool to allow user to quickly +determine the variance in q from the +instrumental parameters. +""" +import sys +from math import pi, sqrt +import math +import logging + +import numpy as np + +from .instrument import Sample +from .instrument import Detector +from .instrument import TOF as Neutron +from .instrument import Aperture + +logger = logging.getLogger(__name__) + +#Plank's constant in cgs unit +_PLANK_H = 6.62606896E-27 +#Gravitational acc. in cgs unit +_GRAVITY = 981.0 + + +class ResolutionCalculator(object): + """ + compute resolution in 2D + """ + def __init__(self): + + # wavelength + self.wave = Neutron() + # sample + self.sample = Sample() + # aperture + self.aperture = Aperture() + # detector + self.detector = Detector() + # 2d image of the resolution + self.image = [] + self.image_lam = [] + # resolutions + # lamda in r-direction + self.sigma_lamd = 0 + # x-dir (no lamda) + self.sigma_1 = 0 + #y-dir (no lamda) + self.sigma_2 = 0 + # 1D total + self.sigma_1d = 0 + self.gravity_phi = None + # q min and max + self.qx_min = -0.3 + self.qx_max = 0.3 + self.qy_min = -0.3 + self.qy_max = 0.3 + # q min and max of the detector + self.detector_qx_min = -0.3 + self.detector_qx_max = 0.3 + self.detector_qy_min = -0.3 + self.detector_qy_max = 0.3 + # possible max qrange + self.qxmin_limit = 0 + self.qxmax_limit = 0 + self.qymin_limit = 0 + self.qymax_limit = 0 + + # plots + self.plot = None + # instrumental params defaults + self.mass = 0 + self.intensity = 0 + self.wavelength = 0 + self.wavelength_spread = 0 + self.source_aperture_size = [] + self.source2sample_distance = [] + self.sample2sample_distance = [] + self.sample_aperture_size = [] + self.sample2detector_distance = [] + self.detector_pix_size = [] + self.detector_size = [] + self.get_all_instrument_params() + # max q range for all lambdas + self.qxrange = [] + self.qyrange = [] + + def compute_and_plot(self, qx_value, qy_value, qx_min, qx_max, + qy_min, qy_max, coord='cartesian'): + """ + Compute the resolution + : qx_value: x component of q + : qy_value: y component of q + """ + # make sure to update all the variables need. + # except lambda, dlambda, and intensity + self.get_all_instrument_params() + # wavelength etc. 
+        lamda_list, dlamb_list = self.get_wave_list()
+        intens_list = []
+        sig1_list = []
+        sig2_list = []
+        sigr_list = []
+        sigma1d_list = []
+        num_lamda = len(lamda_list)
+        for num in range(num_lamda):
+            lam = lamda_list[num]
+            # wavelength spread
+            dlam = dlamb_list[num]
+            intens = self.setup_tof(lam, dlam)
+            intens_list.append(intens)
+            # check if tof
+            if num_lamda > 1:
+                tof = True
+            else:
+                tof = False
+            # compute 2d resolution
+            _, _, sigma_1, sigma_2, sigma_r, sigma1d = \
+                self.compute(lam, dlam, qx_value, qy_value, coord, tof)
+            # make image
+            image = self.get_image(qx_value, qy_value, sigma_1, sigma_2,
+                                   sigma_r, qx_min, qx_max, qy_min, qy_max,
+                                   coord, False)
+            if qx_min > self.qx_min:
+                qx_min = self.qx_min
+            if qx_max < self.qx_max:
+                qx_max = self.qx_max
+            if qy_min > self.qy_min:
+                qy_min = self.qy_min
+            if qy_max < self.qy_max:
+                qy_max = self.qy_max
+
+            # set max qranges
+            self.qxrange = [qx_min, qx_max]
+            self.qyrange = [qy_min, qy_max]
+            sig1_list.append(sigma_1)
+            sig2_list.append(sigma_2)
+            sigr_list.append(sigma_r)
+            sigma1d_list.append(sigma1d)
+        # redraw image in global 2d q-space.
+        self.image_lam = []
+        total_intensity = 0
+        sigma_1 = 0
+        sigma_r = 0
+        sigma_2 = 0
+        sigma1d = 0
+        for ind in range(num_lamda):
+            lam = lamda_list[ind]
+            dlam = dlamb_list[ind]
+            intens = self.setup_tof(lam, dlam)
+            out = self.get_image(qx_value, qy_value, sig1_list[ind],
+                                 sig2_list[ind], sigr_list[ind],
+                                 qx_min, qx_max, qy_min, qy_max, coord)
+            # this is the case of q being outside the detector
+            #if numpy.all(out==0.0):
+            #    continue
+            image = out
+            # set variance as sigmas
+            sigma_1 += sig1_list[ind] * sig1_list[ind] * self.intensity
+            sigma_r += sigr_list[ind] * sigr_list[ind] * self.intensity
+            sigma_2 += sig2_list[ind] * sig2_list[ind] * self.intensity
+            sigma1d += sigma1d_list[ind] * sigma1d_list[ind] * self.intensity
+            total_intensity += self.intensity
+
+        if total_intensity != 0:
+            # average variance
+            image_out = image / total_intensity
+            sigma_1 = sigma_1 / total_intensity
+            sigma_r = sigma_r / total_intensity
+            sigma_2 = sigma_2 / total_intensity
+            sigma1d = sigma1d / total_intensity
+            # set sigmas
+            self.sigma_1 = sqrt(sigma_1)
+            self.sigma_lamd = sqrt(sigma_r)
+            self.sigma_2 = sqrt(sigma_2)
+            self.sigma_1d = sqrt(sigma1d)
+            # rescale
+            max_im_val = 1
+            if max_im_val > 0:
+                image_out /= max_im_val
+        else:
+            image_out = image * 0.0
+            # Don't calculate sigmas nor set self.sigmas!
+            sigma_1 = 0
+            sigma_r = 0
+            sigma_2 = 0
+            sigma1d = 0
+        if len(self.image) > 0:
+            self.image += image_out
+        else:
+            self.image = image_out
+
+        # plot image
+        return self.plot_image(self.image)
+
+    def setup_tof(self, wavelength, wavelength_spread):
+        """
+        Setup all parameters in instrument
+
+        :param wavelength: wavelength to set
+        :param wavelength_spread: wavelength spread to set
+        """
+
+        # set wave.wavelength
+        self.set_wavelength(wavelength)
+        self.set_wavelength_spread(wavelength_spread)
+        self.intensity = self.wave.get_intensity()
+
+        if wavelength == 0:
+            msg = "Can't compute the resolution: the wavelength is zero..."
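+            # compute() takes knot = 2*pi/lamb, so a zero wavelength would
+            # divide by zero further down.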
+ raise RuntimeError(msg) + return self.intensity + + def compute(self, wavelength, wavelength_spread, qx_value, qy_value, + coord='cartesian', tof=False): + """ + Compute the Q resoltuion in || and + direction of 2D + : qx_value: x component of q + : qy_value: y component of q + """ + coord = 'cartesian' + lamb = wavelength + lamb_spread = wavelength_spread + # the shape of wavelength distribution + + if tof: + # rectangular + tof_factor = 2 + else: + # triangular + tof_factor = 1 + # Find polar values + qr_value, phi = self._get_polar_value(qx_value, qy_value) + # vacuum wave transfer + knot = 2*pi/lamb + # scattering angle theta; always true for plane detector + # aligned vertically to the ko direction + if qr_value > knot: + theta = pi/2 + else: + theta = math.asin(qr_value/knot) + # source aperture size + rone = self.source_aperture_size + # sample aperture size + rtwo = self.sample_aperture_size + # detector pixel size + rthree = self.detector_pix_size + # source to sample(aperture) distance + l_ssa = self.source2sample_distance[0] + # sample(aperture) to detector distance + l_sad = self.sample2detector_distance[0] + # sample (aperture) to sample distance + l_sas = self.sample2sample_distance[0] + # source to sample distance + l_one = l_ssa + l_sas + # sample to detector distance + l_two = l_sad - l_sas + + # Sample offset correction for l_one and Lp on variance calculation + l1_cor = (l_ssa * l_two) / (l_sas + l_two) + lp_cor = (l_ssa * l_two) / (l_one + l_two) + # the radial distance to the pixel from the center of the detector + radius = math.tan(theta) * l_two + #Lp = l_one*l_two/(l_one+l_two) + # default polar coordinate + comp1 = 'radial' + comp2 = 'phi' + # in the case of the cartesian coordinate + if coord == 'cartesian': + comp1 = 'x' + comp2 = 'y' + + # sigma in the radial/x direction + # for source aperture + sigma_1 = self.get_variance(rone, l1_cor, phi, comp1) + # for sample apperture + sigma_1 += self.get_variance(rtwo, lp_cor, phi, comp1) + # for detector pix + sigma_1 += self.get_variance(rthree, l_two, phi, comp1) + # for gravity term for 1d + sigma_1grav1d = self.get_variance_gravity(l_ssa, l_sad, lamb, + lamb_spread, phi, comp1, 'on') / tof_factor + # for wavelength spread + # reserve for 1d calculation + A_value = self._cal_A_value(lamb, l_ssa, l_sad) + sigma_wave_1, sigma_wave_1_1d = self.get_variance_wave(A_value, + radius, l_two, lamb_spread, + phi, 'radial', 'on') + sigma_wave_1 /= tof_factor + sigma_wave_1_1d /= tof_factor + # for 1d + variance_1d_1 = (sigma_1 + sigma_1grav1d) / 2 + sigma_wave_1_1d + # normalize + variance_1d_1 = knot * knot * variance_1d_1 / 12 + + # for 2d + #sigma_1 += sigma_wave_1 + # normalize + sigma_1 = knot * sqrt(sigma_1 / 12) + sigma_r = knot * sqrt(sigma_wave_1 / (tof_factor *12)) + # sigma in the phi/y direction + # for source apperture + sigma_2 = self.get_variance(rone, l1_cor, phi, comp2) + + # for sample apperture + sigma_2 += self.get_variance(rtwo, lp_cor, phi, comp2) + + # for detector pix + sigma_2 += self.get_variance(rthree, l_two, phi, comp2) + + # for gravity term for 1d + sigma_2grav1d = self.get_variance_gravity(l_ssa, l_sad, lamb, + lamb_spread, phi, comp2, 'on') / tof_factor + + # for wavelength spread + # reserve for 1d calculation + sigma_wave_2, sigma_wave_2_1d = self.get_variance_wave(A_value, + radius, l_two, lamb_spread, + phi, 'phi', 'on') + sigma_wave_2 /= tof_factor + sigma_wave_2_1d /= tof_factor + # for 1d + variance_1d_2 = (sigma_2 + sigma_2grav1d) / 2 + sigma_wave_2_1d + # normalize + variance_1d_2 = 
knot * knot * variance_1d_2 / 12 + + # for 2d + #sigma_2 = knot*sqrt(sigma_2/12) + #sigma_2 += sigma_wave_2 + # normalize + sigma_2 = knot * sqrt(sigma_2 / 12) + sigma1d = sqrt(variance_1d_1 + variance_1d_2) + # set sigmas + self.sigma_1 = sigma_1 + self.sigma_lamd = sigma_r + self.sigma_2 = sigma_2 + self.sigma_1d = sigma1d + return qr_value, phi, sigma_1, sigma_2, sigma_r, sigma1d + + def _within_detector_range(self, qx_value, qy_value): + """ + check if qvalues are within detector range + """ + # detector range + detector_qx_min = self.detector_qx_min + detector_qx_max = self.detector_qx_max + detector_qy_min = self.detector_qy_min + detector_qy_max = self.detector_qy_max + if self.qxmin_limit > detector_qx_min: + self.qxmin_limit = detector_qx_min + if self.qxmax_limit < detector_qx_max: + self.qxmax_limit = detector_qx_max + if self.qymin_limit > detector_qy_min: + self.qymin_limit = detector_qy_min + if self.qymax_limit < detector_qy_max: + self.qymax_limit = detector_qy_max + if qx_value < detector_qx_min or qx_value > detector_qx_max: + return False + if qy_value < detector_qy_min or qy_value > detector_qy_max: + return False + return True + + def get_image(self, qx_value, qy_value, sigma_1, sigma_2, sigma_r, + qx_min, qx_max, qy_min, qy_max, + coord='cartesian', full_cal=True): + """ + Get the resolution in polar coordinate ready to plot + : qx_value: qx_value value + : qy_value: qy_value value + : sigma_1: variance in r direction + : sigma_2: variance in phi direction + : coord: coordinate system of image, 'polar' or 'cartesian' + """ + # Get qx_max and qy_max... + self._get_detector_qxqy_pixels() + + qr_value, phi = self._get_polar_value(qx_value, qy_value) + + # Check whether the q value is within the detector range + if qx_min < self.qx_min: + self.qx_min = qx_min + #raise ValueError(msg) + if qx_max > self.qx_max: + self.qx_max = qx_max + #raise ValueError(msg) + if qy_min < self.qy_min: + self.qy_min = qy_min + #raise ValueError(msg) + if qy_max > self.qy_max: + self.qy_max = qy_max + #raise ValueError(msg) + if not full_cal: + return None + + # Make an empty graph in the detector scale + dx_size = (self.qx_max - self.qx_min) / (1000 - 1) + dy_size = (self.qy_max - self.qy_min) / (1000 - 1) + x_val = np.arange(self.qx_min, self.qx_max, dx_size) + y_val = np.arange(self.qy_max, self.qy_min, -dy_size) + q_1, q_2 = np.meshgrid(x_val, y_val) + #q_phi = numpy.arctan(q_1,q_2) + # check whether polar or cartesian + if coord == 'polar': + # Find polar values + qr_value, phi = self._get_polar_value(qx_value, qy_value) + q_1, q_2 = self._rotate_z(q_1, q_2, phi) + qc_1 = qr_value + qc_2 = 0.0 + # Calculate the 2D Gaussian distribution image + image = self._gaussian2d_polar(q_1, q_2, qc_1, qc_2, + sigma_1, sigma_2, sigma_r) + else: + # catesian coordinate + # qx_center + qc_1 = qx_value + # qy_center + qc_2 = qy_value + + # Calculate the 2D Gaussian distribution image + image = self._gaussian2d(q_1, q_2, qc_1, qc_2, + sigma_1, sigma_2, sigma_r) + # out side of detector + if not self._within_detector_range(qx_value, qy_value): + image *= 0.0 + self.intensity = 0.0 + #return self.image + + # Add it if there are more than one inputs. 
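+        # (TOF case: the images from the individual wavelengths accumulate here.)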
+ if len(self.image_lam) > 0: + self.image_lam += image * self.intensity + else: + self.image_lam = image * self.intensity + + return self.image_lam + + def plot_image(self, image): + """ + Plot image using pyplot + : image: 2d resolution image + + : return plt: pylab object + """ + import matplotlib.pyplot as plt + + self.plot = plt + plt.xlabel('$\\rm{Q}_{x} [A^{-1}]$') + plt.ylabel('$\\rm{Q}_{y} [A^{-1}]$') + # Max value of the image + # max = numpy.max(image) + qx_min, qx_max, qy_min, qy_max = self.get_detector_qrange() + + # Image + im = plt.imshow(image, + extent=[qx_min, qx_max, qy_min, qy_max]) + + # bilinear interpolation to make it smoother + im.set_interpolation('bilinear') + + return plt + + def reset_image(self): + """ + Reset image to default (=[]) + """ + self.image = [] + + def get_variance(self, size=[], distance=0, phi=0, comp='radial'): + """ + Get the variance when the slit/pinhole size is given + : size: list that can be one(diameter for circular) or two components(lengths for rectangular) + : distance: [z, x] where z along the incident beam, x // qx_value + : comp: direction of the sigma; can be 'phi', 'y', 'x', and 'radial' + + : return variance: sigma^2 + """ + # check the length of size (list) + len_size = len(size) + + # define sigma component direction + if comp == 'radial': + phi_x = math.cos(phi) + phi_y = math.sin(phi) + elif comp == 'phi': + phi_x = math.sin(phi) + phi_y = math.cos(phi) + elif comp == 'x': + phi_x = 1 + phi_y = 0 + elif comp == 'y': + phi_x = 0 + phi_y = 1 + else: + phi_x = 0 + phi_y = 0 + # calculate each component + # for pinhole w/ radius = size[0]/2 + if len_size == 1: + x_comp = (0.5 * size[0]) * sqrt(3) + y_comp = 0 + # for rectangular slit + elif len_size == 2: + x_comp = size[0] * phi_x + y_comp = size[1] * phi_y + # otherwise + else: + raise ValueError(" Improper input...") + # get them squared + sigma = x_comp * x_comp + sigma += y_comp * y_comp + # normalize by distance + sigma /= (distance * distance) + + return sigma + + def get_variance_wave(self, A_value, radius, distance, spread, phi, + comp='radial', switch='on'): + """ + Get the variance when the wavelength spread is given + + : radius: the radial distance from the beam center to the pix of q + : distance: sample to detector distance + : spread: wavelength spread (ratio) + : comp: direction of the sigma; can be 'phi', 'y', 'x', and 'radial' + + : return variance: sigma^2 for 2d, sigma^2 for 1d [tuple] + """ + if switch.lower() == 'off': + return 0, 0 + # check the singular point + if distance == 0 or comp == 'phi': + return 0, 0 + else: + # calculate sigma^2 for 1d + sigma1d = 2 * math.pow(radius/distance*spread, 2) + if comp == 'x': + sigma1d *= (math.cos(phi)*math.cos(phi)) + elif comp == 'y': + sigma1d *= (math.sin(phi)*math.sin(phi)) + else: + sigma1d *= 1 + # sigma^2 for 2d + # shift the coordinate due to the gravitational shift + rad_x = radius * math.cos(phi) + rad_y = A_value - radius * math.sin(phi) + radius = math.sqrt(rad_x * rad_x + rad_y * rad_y) + # new phi + phi = math.atan2(-rad_y, rad_x) + self.gravity_phi = phi + # calculate sigma^2 + sigma = 2 * math.pow(radius/distance*spread, 2) + if comp == 'x': + sigma *= (math.cos(phi)*math.cos(phi)) + elif comp == 'y': + sigma *= (math.sin(phi)*math.sin(phi)) + else: + sigma *= 1 + + return sigma, sigma1d + + def get_variance_gravity(self, s_distance, d_distance, wavelength, spread, + phi, comp='radial', switch='on'): + """ + Get the variance from gravity when the wavelength spread is given + + : s_distance: source 
to sample distance
+        : d_distance: sample to detector distance
+        : wavelength: wavelength
+        : spread: wavelength spread (ratio)
+        : comp: direction of the sigma; can be 'phi', 'y', 'x', and 'radial'
+
+        : return variance: sigma^2
+        """
+        if switch.lower() == 'off':
+            return 0
+        if self.mass == 0.0:
+            return 0
+        # check the singular point
+        if d_distance == 0 or comp == 'x':
+            return 0
+        else:
+            a_value = self._cal_A_value(None, s_distance, d_distance)
+            # calculate sigma^2
+            sigma = math.pow(a_value / d_distance, 2)
+            sigma *= math.pow(wavelength, 4)
+            sigma *= math.pow(spread, 2)
+            sigma *= 8
+            return sigma
+
+    def _cal_A_value(self, lamda, s_distance, d_distance):
+        """
+        Calculate A value for gravity
+
+        : s_distance: source to sample distance
+        : d_distance: sample to detector distance
+        """
+        # neutron mass in cgs unit
+        self.mass = self.get_neutron_mass()
+        # Planck constant in cgs unit
+        h_constant = _PLANK_H
+        # gravity in cgs unit
+        gravy = _GRAVITY
+        # m/h
+        m_over_h = self.mass / h_constant
+        # A value
+        a_value = d_distance * (s_distance + d_distance)
+        a_value *= math.pow(m_over_h / 2, 2)
+        a_value *= gravy
+        # unit correction (1/cm to 1/A) for A and d_distance below
+        a_value *= 1.0E-16
+        # if lamda is given (broad meaning of A), return 4 * lamda^2 * A
+        if lamda is not None:
+            a_value *= (4 * lamda * lamda)
+        return a_value
+
+    def get_intensity(self):
+        """
+        Get intensity
+        """
+        return self.wave.intensity
+
+    def get_wavelength(self):
+        """
+        Get wavelength
+        """
+        return self.wave.wavelength
+
+    def get_default_spectrum(self):
+        """
+        Get default_spectrum
+        """
+        return self.wave.get_default_spectrum()
+
+    def get_spectrum(self):
+        """
+        Get spectrum
+        """
+        return self.wave.get_spectrum()
+
+    def get_wavelength_spread(self):
+        """
+        Get wavelength spread
+        """
+        return self.wave.wavelength_spread
+
+    def get_neutron_mass(self):
+        """
+        Get Neutron mass
+        """
+        return self.wave.mass
+
+    def get_source_aperture_size(self):
+        """
+        Get source aperture size
+        """
+        return self.aperture.source_size
+
+    def get_sample_aperture_size(self):
+        """
+        Get sample aperture size
+        """
+        return self.aperture.sample_size
+
+    def get_detector_pix_size(self):
+        """
+        Get detector pixel size
+        """
+        return self.detector.pix_size
+
+    def get_detector_size(self):
+        """
+        Get detector size
+        """
+        return self.detector.size
+
+    def get_source2sample_distance(self):
+        """
+        Get source2sample distance
+        """
+        return self.aperture.sample_distance
+
+    def get_sample2sample_distance(self):
+        """
+        Get sample_slit2sample distance
+        """
+        return self.sample.distance
+
+    def get_sample2detector_distance(self):
+        """
+        Get sample2detector distance
+        """
+        return self.detector.distance
+
+    def set_intensity(self, intensity):
+        """
+        Set intensity
+        """
+        self.wave.set_intensity(intensity)
+
+    def set_wave(self, wavelength):
+        """
+        Set wavelength list or wavelength
+        """
+        if wavelength.__class__.__name__ == 'list':
+            self.wave.set_wave_list(wavelength)
+        elif wavelength.__class__.__name__ == 'float':
+            self.wave.set_wave_list([wavelength])
+            #self.set_wavelength(wavelength)
+        else:
+            raise TypeError("invalid wavelength---should be list or float")
+
+    def set_wave_spread(self, wavelength_spread):
+        """
+        Set wavelength spread list or wavelength spread
+        """
+        if wavelength_spread.__class__.__name__ == 'list':
+            self.wave.set_wave_spread_list(wavelength_spread)
+        elif wavelength_spread.__class__.__name__ == 'float':
+            self.wave.set_wave_spread_list([wavelength_spread])
+        else:
+            raise
TypeError("invalid wavelength spread---should be list or float") + + def set_wavelength(self, wavelength): + """ + Set wavelength + """ + self.wavelength = wavelength + self.wave.set_wavelength(wavelength) + + def set_spectrum(self, spectrum): + """ + Set spectrum + """ + self.spectrum = spectrum + self.wave.set_spectrum(spectrum) + + def set_wavelength_spread(self, wavelength_spread): + """ + Set wavelength spread + """ + self.wavelength_spread = wavelength_spread + self.wave.set_wavelength_spread(wavelength_spread) + + def set_wave_list(self, wavelength_list, wavelengthspread_list): + """ + Set wavelength and its spread list + """ + self.wave.set_wave_list(wavelength_list) + self.wave.set_wave_spread_list(wavelengthspread_list) + + def get_wave_list(self): + """ + Set wavelength spread + """ + return self.wave.get_wave_list() + + def get_intensity_list(self): + """ + Set wavelength spread + """ + return self.wave.get_intensity_list() + + def set_source_aperture_size(self, size): + """ + Set source aperture size + + : param size: [dia_value] or [x_value, y_value] + """ + if len(size) < 1 or len(size) > 2: + raise RuntimeError("The length of the size must be one or two.") + self.aperture.set_source_size(size) + + def set_neutron_mass(self, mass): + """ + Set Neutron mass + """ + self.wave.set_mass(mass) + self.mass = mass + + def set_sample_aperture_size(self, size): + """ + Set sample aperture size + + : param size: [dia_value] or [xheight_value, yheight_value] + """ + if len(size) < 1 or len(size) > 2: + raise RuntimeError("The length of the size must be one or two.") + self.aperture.set_sample_size(size) + + def set_detector_pix_size(self, size): + """ + Set detector pixel size + """ + self.detector.set_pix_size(size) + + def set_detector_size(self, size): + """ + Set detector size in number of pixels + : param size: [pixel_nums] or [x_pix_num, yx_pix_num] + """ + self.detector.set_size(size) + + def set_source2sample_distance(self, distance): + """ + Set detector source2sample_distance + + : param distance: [distance, x_offset] + """ + if len(distance) < 1 or len(distance) > 2: + raise RuntimeError("The length of the size must be one or two.") + self.aperture.set_sample_distance(distance) + + def set_sample2sample_distance(self, distance): + """ + Set detector sample_slit2sample_distance + + : param distance: [distance, x_offset] + """ + if len(distance) < 1 or len(distance) > 2: + raise RuntimeError("The length of the size must be one or two.") + self.sample.set_distance(distance) + + def set_sample2detector_distance(self, distance): + """ + Set detector sample2detector_distance + + : param distance: [distance, x_offset] + """ + if len(distance) < 1 or len(distance) > 2: + raise RuntimeError("The length of the size must be one or two.") + self.detector.set_distance(distance) + + def get_all_instrument_params(self): + """ + Get all instrumental parameters + """ + self.mass = self.get_neutron_mass() + self.spectrum = self.get_spectrum() + self.source_aperture_size = self.get_source_aperture_size() + self.sample_aperture_size = self.get_sample_aperture_size() + self.detector_pix_size = self.get_detector_pix_size() + self.detector_size = self.get_detector_size() + self.source2sample_distance = self.get_source2sample_distance() + self.sample2sample_distance = self.get_sample2sample_distance() + self.sample2detector_distance = self.get_sample2detector_distance() + + def get_detector_qrange(self): + """ + get max detector q ranges + + : return: qx_min, qx_max, qy_min, qy_max tuple + """ + if 
len(self.qxrange) != 2 or len(self.qyrange) != 2: + return None + qx_min = self.qxrange[0] + qx_max = self.qxrange[1] + qy_min = self.qyrange[0] + qy_max = self.qyrange[1] + + return qx_min, qx_max, qy_min, qy_max + + def _rotate_z(self, x_value, y_value, theta=0.0): + """ + Rotate x-y cordinate around z-axis by theta + : x_value: numpy array of x values + : y_value: numpy array of y values + : theta: angle to rotate by in rad + + :return: x_prime, y-prime + """ + # rotate by theta + x_prime = x_value * math.cos(theta) + y_value * math.sin(theta) + y_prime = -x_value * math.sin(theta) + y_value * math.cos(theta) + + return x_prime, y_prime + + def _gaussian2d(self, x_val, y_val, x0_val, y0_val, + sigma_x, sigma_y, sigma_r): + """ + Calculate 2D Gaussian distribution + : x_val: x value + : y_val: y value + : x0_val: mean value in x-axis + : y0_val: mean value in y-axis + : sigma_x: variance in x-direction + : sigma_y: variance in y-direction + + : return: gaussian (value) + """ + # phi values at each points (not at the center) + x_value = x_val - x0_val + y_value = y_val - y0_val + phi_i = np.arctan2(y_val, x_val) + + # phi correction due to the gravity shift (in phi) + phi_0 = math.atan2(y0_val, x0_val) + phi_i = phi_i - phi_0 + self.gravity_phi + + sin_phi = np.sin(self.gravity_phi) + cos_phi = np.cos(self.gravity_phi) + + x_p = x_value * cos_phi + y_value * sin_phi + y_p = -x_value * sin_phi + y_value * cos_phi + + new_sig_x = sqrt(sigma_r * sigma_r / (sigma_x * sigma_x) + 1) + new_sig_y = sqrt(sigma_r * sigma_r / (sigma_y * sigma_y) + 1) + new_x = x_p * cos_phi / new_sig_x - y_p * sin_phi + new_x /= sigma_x + new_y = x_p * sin_phi / new_sig_y + y_p * cos_phi + new_y /= sigma_y + + nu_value = -0.5 * (new_x * new_x + new_y * new_y) + + gaussian = np.exp(nu_value) + # normalizing factor correction + gaussian /= gaussian.sum() + + return gaussian + + def _gaussian2d_polar(self, x_val, y_val, x0_val, y0_val, + sigma_x, sigma_y, sigma_r): + """ + Calculate 2D Gaussian distribution for polar coodinate + : x_val: x value + : y_val: y value + : x0_val: mean value in x-axis + : y0_val: mean value in y-axis + : sigma_x: variance in r-direction + : sigma_y: variance in phi-direction + : sigma_r: wavelength variance in r-direction + + : return: gaussian (value) + """ + sigma_x = sqrt(sigma_x * sigma_x + sigma_r * sigma_r) + # call gaussian1d + gaussian = self._gaussian1d(x_val, x0_val, sigma_x) + gaussian *= self._gaussian1d(y_val, y0_val, sigma_y) + + # normalizing factor correction + if sigma_x != 0 and sigma_y != 0: + gaussian *= sqrt(2 * pi) + return gaussian + + def _gaussian1d(self, value, mean, sigma): + """ + Calculate 1D Gaussian distribution + : value: value + : mean: mean value + : sigma: variance + + : return: gaussian (value) + """ + # default + gaussian = 1.0 + if sigma != 0: + # get exponent + nu_value = (value - mean) / sigma + nu_value *= nu_value + nu_value *= -0.5 + gaussian *= np.exp(nu_value) + gaussian /= sigma + # normalize + gaussian /= sqrt(2 * pi) + + return gaussian + + def _atan_phi(self, qy_value, qx_value): + """ + Find the angle phi of q on the detector plane for qx_value, qy_value given + : qx_value: x component of q + : qy_value: y component of q + + : return phi: the azimuthal angle of q on x-y plane + """ + phi = math.atan2(qy_value, qx_value) + return phi + + def _get_detector_qxqy_pixels(self): + """ + Get the pixel positions of the detector in the qx_value-qy_value space + """ + + # update all param values + self.get_all_instrument_params() + + # wavelength + 
wavelength = self.wave.wavelength
+        # Gravity correction
+        delta_y = self._get_beamcenter_drop()  # in cm
+
+        # detector_pix size
+        detector_pix_size = self.detector_pix_size
+        # Square or circular pixel
+        if len(detector_pix_size) == 1:
+            pix_x_size = detector_pix_size[0]
+            pix_y_size = detector_pix_size[0]
+        # rectangular pixel
+        elif len(detector_pix_size) == 2:
+            pix_x_size = detector_pix_size[0]
+            pix_y_size = detector_pix_size[1]
+        else:
+            raise ValueError(" Input value format error...")
+        # Sample to detector distance = sample slit to detector
+        # minus sample offset
+        sample2detector_distance = self.sample2detector_distance[0] - \
+                                   self.sample2sample_distance[0]
+        # detector offset in x-direction
+        detector_offset = 0
+        try:
+            detector_offset = self.sample2detector_distance[1]
+        except IndexError as exc:
+            logger.error(exc)
+
+        # detector size in [no of pix_x, no of pix_y]
+        detector_pix_nums_x = self.detector_size[0]
+
+        # get pix_y if it exists, otherwise take it from [0]
+        try:
+            detector_pix_nums_y = self.detector_size[1]
+        except IndexError:
+            detector_pix_nums_y = self.detector_size[0]
+
+        # detector offset in pix number
+        offset_x = detector_offset / pix_x_size
+        offset_y = delta_y / pix_y_size
+
+        # beam center position in pix number (start from 0)
+        center_x, center_y = self._get_beamcenter_position(detector_pix_nums_x,
+                                                           detector_pix_nums_y,
+                                                           offset_x, offset_y)
+        # distance [cm] from the beam center on detector plane
+        detector_ind_x = np.arange(detector_pix_nums_x)
+        detector_ind_y = np.arange(detector_pix_nums_y)
+
+        # shift 0.5 pixel so that pix position is at the center of the pixel
+        detector_ind_x = detector_ind_x + 0.5
+        detector_ind_y = detector_ind_y + 0.5
+
+        # the relative position from the beam center
+        detector_ind_x = detector_ind_x - center_x
+        detector_ind_y = detector_ind_y - center_y
+
+        # unit correction in cm
+        detector_ind_x = detector_ind_x * pix_x_size
+        detector_ind_y = detector_ind_y * pix_y_size
+
+        qx_value = np.zeros(len(detector_ind_x))
+        qy_value = np.zeros(len(detector_ind_y))
+        i = 0
+
+        for indx in detector_ind_x:
+            qx_value[i] = self._get_qx(indx, sample2detector_distance, wavelength)
+            i += 1
+        i = 0
+        for indy in detector_ind_y:
+            qy_value[i] = self._get_qx(indy, sample2detector_distance, wavelength)
+            i += 1
+
+        # qx_value and qy_value values in array
+        qx_value = qx_value.repeat(detector_pix_nums_y)
+        qx_value = qx_value.reshape(detector_pix_nums_x, detector_pix_nums_y)
+        qy_value = qy_value.repeat(detector_pix_nums_x)
+        qy_value = qy_value.reshape(detector_pix_nums_y, detector_pix_nums_x)
+        qy_value = qy_value.transpose()
+
+        # q min and max values among the centers of the pixels
+        self.qx_min = np.min(qx_value)
+        self.qx_max = np.max(qx_value)
+        self.qy_min = np.min(qy_value)
+        self.qy_max = np.max(qy_value)
+
+        # Appr. min and max values of the detector display limits,
+        # i.e., edges of the last pixels.
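+        # Extend each limit by half a pixel, converted to q at the
+        # sample-to-detector distance, so the range covers pixel edges
+        # rather than pixel centers.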
+        self.qy_min += self._get_qx(-0.5 * pix_y_size,
+                                    sample2detector_distance, wavelength)
+        self.qy_max += self._get_qx(0.5 * pix_y_size,
+                                    sample2detector_distance, wavelength)
+        #if self.qx_min == self.qx_max:
+        self.qx_min += self._get_qx(-0.5 * pix_x_size,
+                                    sample2detector_distance, wavelength)
+        self.qx_max += self._get_qx(0.5 * pix_x_size,
+                                    sample2detector_distance, wavelength)
+
+        # min and max values of detector
+        self.detector_qx_min = self.qx_min
+        self.detector_qx_max = self.qx_max
+        self.detector_qy_min = self.qy_min
+        self.detector_qy_max = self.qy_max
+
+        # try to set it as a Data2D otherwise pass (not required for now)
+        output = None
+        try:
+            from sas.sascalc.dataloader.data_info import Data2D
+            output = Data2D()
+            inten = np.zeros_like(qx_value)
+            output.data = inten
+            output.qx_data = qx_value
+            output.qy_data = qy_value
+        except Exception as exc:
+            logger.error(exc)
+
+        return output
+
+    def _get_qx(self, dx_size, det_dist, wavelength):
+        """
+        :param dx_size: x-distance from beam center [cm]
+        :param det_dist: sample to detector distance [cm]
+
+        :return: q-value at the given position
+        """
+        # Distance from beam center in the plane of detector
+        plane_dist = dx_size
+        # full scattering angle on the x-axis
+        theta = np.arctan(plane_dist / det_dist)
+        qx_value = (2.0 * pi / wavelength) * np.sin(theta)
+        return qx_value
+
+    def _get_polar_value(self, qx_value, qy_value):
+        """
+        Find qr_value and phi from qx_value and qy_value values
+
+        : return qr_value, phi
+        """
+        # find |q| on detector plane
+        qr_value = sqrt(qx_value*qx_value + qy_value*qy_value)
+        # find angle phi
+        phi = self._atan_phi(qy_value, qx_value)
+
+        return qr_value, phi
+
+    def _get_beamcenter_position(self, num_x, num_y, offset_x, offset_y):
+        """
+        :param num_x: number of pixels in x-direction
+        :param num_y: number of pixels in y-direction
+        :param offset_x: detector offset in x-direction in pix number
+        :param offset_y: detector offset in y-direction in pix number
+
+        :return: pix number; pos_x, pos_y in pix index
+        """
+        # beam center position
+        pos_x = num_x / 2
+        pos_y = num_y / 2
+
+        # correction for offset
+        pos_x += offset_x
+        # correction for gravity that is always negative
+        pos_y -= offset_y
+
+        return pos_x, pos_y
+
+    def _get_beamcenter_drop(self):
+        """
+        Get the beam center drop (delta y) in y direction due to gravity
+
+        :return delta y: the beam center drop in cm
+        """
+        # Check if mass == 0 (X-ray).
+        if self.mass == 0:
+            return 0
+        # Convert unit from A to cm
+        unit_cm = 1e-08
+        # Velocity of neutron in horizontal direction (~ actual velocity)
+        velocity = _PLANK_H / (self.mass * self.wave.wavelength * unit_cm)
+        # Compute delta y
+        delta_y = 0.5
+        delta_y *= _GRAVITY
+        sampletodetector = self.sample2detector_distance[0] - \
+                           self.sample2sample_distance[0]
+        delta_y *= sampletodetector
+        delta_y *= (self.source2sample_distance[0] + self.sample2detector_distance[0])
+        delta_y /= (velocity * velocity)
+
+        return delta_y
diff --git a/sas/sascalc/calculator/sas_gen.py b/sas/sascalc/calculator/sas_gen.py
new file mode 100755
index 000000000..9f7e6a029
--- /dev/null
+++ b/sas/sascalc/calculator/sas_gen.py
@@ -0,0 +1,1128 @@
+# pylint: disable=invalid-name
+"""
+SAS generic computation and sld file readers
+"""
+from __future__ import print_function
+
+import os
+import sys
+import copy
+import logging
+
+from periodictable import formula
+from periodictable import nsf
+import numpy as np
+
+from .
import _sld2i +from .BaseComponent import BaseComponent + +logger = logging.getLogger(__name__) + +if sys.version_info[0] < 3: + def decode(s): + return s +else: + def decode(s): + return s.decode() if isinstance(s, bytes) else s + +MFACTOR_AM = 2.853E-12 +MFACTOR_MT = 2.3164E-9 +METER2ANG = 1.0E+10 +#Avogadro constant [1/mol] +NA = 6.02214129e+23 + +def mag2sld(mag, v_unit=None): + """ + Convert magnetization to magnatic SLD + sldm = Dm * mag where Dm = gamma * classical elec. radius/(2*Bohr magneton) + Dm ~ 2.853E-12 [A^(-2)] ==> Shouldn't be 2.90636E-12 [A^(-2)]??? + """ + if v_unit == "A/m": + factor = MFACTOR_AM + elif v_unit == "mT": + factor = MFACTOR_MT + else: + raise ValueError("Invalid valueunit") + sld_m = factor * mag + return sld_m + +def transform_center(pos_x, pos_y, pos_z): + """ + re-center + :return: posx, posy, posz [arrays] + """ + posx = pos_x - (min(pos_x) + max(pos_x)) / 2.0 + posy = pos_y - (min(pos_y) + max(pos_y)) / 2.0 + posz = pos_z - (min(pos_z) + max(pos_z)) / 2.0 + return posx, posy, posz + +class GenSAS(BaseComponent): + """ + Generic SAS computation Model based on sld (n & m) arrays + """ + def __init__(self): + """ + Init + :Params sld_data: MagSLD object + """ + # Initialize BaseComponent + BaseComponent.__init__(self) + self.sld_data = None + self.data_pos_unit = None + self.data_x = None + self.data_y = None + self.data_z = None + self.data_sldn = None + self.data_mx = None + self.data_my = None + self.data_mz = None + self.data_vol = None #[A^3] + self.is_avg = False + ## Name of the model + self.name = "GenSAS" + ## Define parameters + self.params = {} + self.params['scale'] = 1.0 + self.params['background'] = 0.0 + self.params['solvent_SLD'] = 0.0 + self.params['total_volume'] = 1.0 + self.params['Up_frac_in'] = 1.0 + self.params['Up_frac_out'] = 1.0 + self.params['Up_theta'] = 0.0 + self.description = 'GenSAS' + ## Parameter details [units, min, max] + self.details = {} + self.details['scale'] = ['', 0.0, np.inf] + self.details['background'] = ['[1/cm]', 0.0, np.inf] + self.details['solvent_SLD'] = ['1/A^(2)', -np.inf, np.inf] + self.details['total_volume'] = ['A^(3)', 0.0, np.inf] + self.details['Up_frac_in'] = ['[u/(u+d)]', 0.0, 1.0] + self.details['Up_frac_out'] = ['[u/(u+d)]', 0.0, 1.0] + self.details['Up_theta'] = ['[deg]', -np.inf, np.inf] + # fixed parameters + self.fixed = [] + + def set_pixel_volumes(self, volume): + """ + Set the volume of a pixel in (A^3) unit + :Param volume: pixel volume [float] + """ + if self.data_vol is None: + raise TypeError("data_vol is missing") + self.data_vol = volume + + def set_is_avg(self, is_avg=False): + """ + Sets is_avg: [bool] + """ + self.is_avg = is_avg + + def _gen(self, qx, qy): + """ + Evaluate the function + :Param x: array of x-values + :Param y: array of y-values + :Param i: array of initial i-value + :return: function value + """ + pos_x = self.data_x + pos_y = self.data_y + pos_z = self.data_z + if self.is_avg is None: + pos_x, pos_y, pos_z = transform_center(pos_x, pos_y, pos_z) + sldn = copy.deepcopy(self.data_sldn) + sldn -= self.params['solvent_SLD'] + # **** WARNING **** new_GenI holds pointers to numpy vectors + # be sure that they are contiguous double precision arrays and make + # sure the GC doesn't eat them before genicom is called. 
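+        # (The arrays stored by set_sld_data() go through _vec(), defined
+        #  below, which forces contiguous double-precision buffers.)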
+ # TODO: rewrite so that the parameters are passed directly to genicom + args = ( + (1 if self.is_avg else 0), + pos_x, pos_y, pos_z, + sldn, self.data_mx, self.data_my, + self.data_mz, self.data_vol, + self.params['Up_frac_in'], + self.params['Up_frac_out'], + self.params['Up_theta']) + model = _sld2i.new_GenI(*args) + if len(qy): + qx, qy = _vec(qx), _vec(qy) + I_out = np.empty_like(qx) + #print("npoints", qx.shape, "npixels", pos_x.shape) + _sld2i.genicomXY(model, qx, qy, I_out) + #print("I_out after", I_out) + else: + qx = _vec(qx) + I_out = np.empty_like(qx) + _sld2i.genicom(model, qx, I_out) + vol_correction = self.data_total_volume / self.params['total_volume'] + result = (self.params['scale'] * vol_correction * I_out + + self.params['background']) + return result + + def set_sld_data(self, sld_data=None): + """ + Sets sld_data + """ + self.sld_data = sld_data + self.data_pos_unit = sld_data.pos_unit + self.data_x = _vec(sld_data.pos_x) + self.data_y = _vec(sld_data.pos_y) + self.data_z = _vec(sld_data.pos_z) + self.data_sldn = _vec(sld_data.sld_n) + self.data_mx = _vec(sld_data.sld_mx) + self.data_my = _vec(sld_data.sld_my) + self.data_mz = _vec(sld_data.sld_mz) + self.data_vol = _vec(sld_data.vol_pix) + self.data_total_volume = sum(sld_data.vol_pix) + self.params['total_volume'] = sum(sld_data.vol_pix) + + def getProfile(self): + """ + Get SLD profile + : return: sld_data + """ + return self.sld_data + + def run(self, x=0.0): + """ + Evaluate the model + :param x: simple value + :return: (I value) + """ + if isinstance(x, list): + if len(x[1]) > 0: + msg = "Not a 1D." + raise ValueError(msg) + # 1D I is found at y =0 in the 2D pattern + out = self._gen(x[0], []) + return out + else: + msg = "Q must be given as list of qx's and qy's" + raise ValueError(msg) + + def runXY(self, x=0.0): + """ + Evaluate the model + :param x: simple value + :return: I value + :Use this runXY() for the computation + """ + if isinstance(x, list): + return self._gen(x[0], x[1]) + else: + msg = "Q must be given as list of qx's and qy's" + raise ValueError(msg) + + def evalDistribution(self, qdist): + """ + Evaluate a distribution of q-values. + + :param qdist: ndarray of scalar q-values (for 1D) or list [qx,qy] + where qx,qy are 1D ndarrays (for 2D). + """ + if isinstance(qdist, list): + return self.run(qdist) if len(qdist[1]) < 1 else self.runXY(qdist) + else: + mesg = "evalDistribution is expecting an ndarray of " + mesg += "a list [qx,qy] where qx,qy are arrays." 
+ raise RuntimeError(mesg) + +def _vec(v): + return np.ascontiguousarray(v, 'd') + +class OMF2SLD(object): + """ + Convert OMFData to MAgData + """ + def __init__(self): + """ + Init + """ + self.pos_x = None + self.pos_y = None + self.pos_z = None + self.mx = None + self.my = None + self.mz = None + self.sld_n = None + self.vol_pix = None + self.output = None + self.omfdata = None + + def set_data(self, omfdata, shape='rectangular'): + """ + Set all data + """ + self.omfdata = omfdata + length = int(omfdata.xnodes * omfdata.ynodes * omfdata.znodes) + pos_x = np.arange(omfdata.xmin, + omfdata.xnodes*omfdata.xstepsize + omfdata.xmin, + omfdata.xstepsize) + pos_y = np.arange(omfdata.ymin, + omfdata.ynodes*omfdata.ystepsize + omfdata.ymin, + omfdata.ystepsize) + pos_z = np.arange(omfdata.zmin, + omfdata.znodes*omfdata.zstepsize + omfdata.zmin, + omfdata.zstepsize) + self.pos_x = np.tile(pos_x, int(omfdata.ynodes * omfdata.znodes)) + self.pos_y = pos_y.repeat(int(omfdata.xnodes)) + self.pos_y = np.tile(self.pos_y, int(omfdata.znodes)) + self.pos_z = pos_z.repeat(int(omfdata.xnodes * omfdata.ynodes)) + self.mx = omfdata.mx + self.my = omfdata.my + self.mz = omfdata.mz + self.sld_n = np.zeros(length) + + if omfdata.mx is None: + self.mx = np.zeros(length) + if omfdata.my is None: + self.my = np.zeros(length) + if omfdata.mz is None: + self.mz = np.zeros(length) + + self._check_data_length(length) + self.remove_null_points(False, False) + mask = np.ones(len(self.sld_n), dtype=bool) + if shape.lower() == 'ellipsoid': + try: + # Pixel (step) size included + x_c = max(self.pos_x) + min(self.pos_x) + y_c = max(self.pos_y) + min(self.pos_y) + z_c = max(self.pos_z) + min(self.pos_z) + x_d = max(self.pos_x) - min(self.pos_x) + y_d = max(self.pos_y) - min(self.pos_y) + z_d = max(self.pos_z) - min(self.pos_z) + x_r = (x_d + omfdata.xstepsize) / 2.0 + y_r = (y_d + omfdata.ystepsize) / 2.0 + z_r = (z_d + omfdata.zstepsize) / 2.0 + x_dir2 = ((self.pos_x - x_c / 2.0) / x_r) + x_dir2 *= x_dir2 + y_dir2 = ((self.pos_y - y_c / 2.0) / y_r) + y_dir2 *= y_dir2 + z_dir2 = ((self.pos_z - z_c / 2.0) / z_r) + z_dir2 *= z_dir2 + mask = (x_dir2 + y_dir2 + z_dir2) <= 1.0 + except Exception as exc: + logger.error(exc) + self.output = MagSLD(self.pos_x[mask], self.pos_y[mask], + self.pos_z[mask], self.sld_n[mask], + self.mx[mask], self.my[mask], self.mz[mask]) + self.output.set_pix_type('pixel') + self.output.set_pixel_symbols('pixel') + + def get_omfdata(self): + """ + Return all data + """ + return self.omfdata + + def get_output(self): + """ + Return output + """ + return self.output + + def _check_data_length(self, length): + """ + Check if the data lengths are consistent + :Params length: data length + """ + parts = (self.pos_x, self.pos_y, self.pos_z, self.mx, self.my, self.mz) + if any(len(v) != length for v in parts): + raise ValueError("Error: Inconsistent data length.") + + def remove_null_points(self, remove=False, recenter=False): + """ + Removes any mx, my, and mz = 0 points + """ + if remove: + is_nonzero = (np.fabs(self.mx) + np.fabs(self.my) + + np.fabs(self.mz)).nonzero() + if len(is_nonzero[0]) > 0: + self.pos_x = self.pos_x[is_nonzero] + self.pos_y = self.pos_y[is_nonzero] + self.pos_z = self.pos_z[is_nonzero] + self.sld_n = self.sld_n[is_nonzero] + self.mx = self.mx[is_nonzero] + self.my = self.my[is_nonzero] + self.mz = self.mz[is_nonzero] + if recenter: + self.pos_x -= (min(self.pos_x) + max(self.pos_x)) / 2.0 + self.pos_y -= (min(self.pos_y) + max(self.pos_y)) / 2.0 + self.pos_z -= (min(self.pos_z) + 
max(self.pos_z)) / 2.0 + + def get_magsld(self): + """ + return MagSLD + """ + return self.output + + +class OMFReader(object): + """ + Class to load omf/ascii files (3 columns w/header). + """ + ## File type + type_name = "OMF ASCII" + + ## Wildcards + type = ["OMF files (*.OMF, *.omf)|*.omf"] + ## List of allowed extensions + ext = ['.omf', '.OMF'] + + def read(self, path): + """ + Load data file + :param path: file path + :return: x, y, z, sld_n, sld_mx, sld_my, sld_mz + """ + desc = "" + mx = np.zeros(0) + my = np.zeros(0) + mz = np.zeros(0) + try: + input_f = open(path, 'rb') + buff = decode(input_f.read()) + lines = buff.split('\n') + input_f.close() + output = OMFData() + valueunit = None + for line in lines: + line = line.strip() + # Read data + if line and not line.startswith('#'): + try: + toks = line.split() + _mx = float(toks[0]) + _my = float(toks[1]) + _mz = float(toks[2]) + _mx = mag2sld(_mx, valueunit) + _my = mag2sld(_my, valueunit) + _mz = mag2sld(_mz, valueunit) + mx = np.append(mx, _mx) + my = np.append(my, _my) + mz = np.append(mz, _mz) + except Exception as exc: + # Skip non-data lines + logger.error(str(exc)+" when processing %r"%line) + #Reading Header; Segment count ignored + s_line = line.split(":", 1) + if s_line[0].lower().count("oommf") > 0: + oommf = s_line[1].lstrip() + if s_line[0].lower().count("title") > 0: + title = s_line[1].lstrip() + if s_line[0].lower().count("desc") > 0: + desc += s_line[1].lstrip() + desc += '\n' + if s_line[0].lower().count("meshtype") > 0: + meshtype = s_line[1].lstrip() + if s_line[0].lower().count("meshunit") > 0: + meshunit = s_line[1].lstrip() + if meshunit.count("m") < 1: + msg = "Error: \n" + msg += "We accept only m as meshunit" + raise ValueError(msg) + if s_line[0].lower().count("xbase") > 0: + xbase = s_line[1].lstrip() + if s_line[0].lower().count("ybase") > 0: + ybase = s_line[1].lstrip() + if s_line[0].lower().count("zbase") > 0: + zbase = s_line[1].lstrip() + if s_line[0].lower().count("xstepsize") > 0: + xstepsize = s_line[1].lstrip() + if s_line[0].lower().count("ystepsize") > 0: + ystepsize = s_line[1].lstrip() + if s_line[0].lower().count("zstepsize") > 0: + zstepsize = s_line[1].lstrip() + if s_line[0].lower().count("xnodes") > 0: + xnodes = s_line[1].lstrip() + if s_line[0].lower().count("ynodes") > 0: + ynodes = s_line[1].lstrip() + if s_line[0].lower().count("znodes") > 0: + znodes = s_line[1].lstrip() + if s_line[0].lower().count("xmin") > 0: + xmin = s_line[1].lstrip() + if s_line[0].lower().count("ymin") > 0: + ymin = s_line[1].lstrip() + if s_line[0].lower().count("zmin") > 0: + zmin = s_line[1].lstrip() + if s_line[0].lower().count("xmax") > 0: + xmax = s_line[1].lstrip() + if s_line[0].lower().count("ymax") > 0: + ymax = s_line[1].lstrip() + if s_line[0].lower().count("zmax") > 0: + zmax = s_line[1].lstrip() + if s_line[0].lower().count("valueunit") > 0: + valueunit = s_line[1].lstrip().rstrip() + if s_line[0].lower().count("valuemultiplier") > 0: + valuemultiplier = s_line[1].lstrip() + if s_line[0].lower().count("valuerangeminmag") > 0: + valuerangeminmag = s_line[1].lstrip() + if s_line[0].lower().count("valuerangemaxmag") > 0: + valuerangemaxmag = s_line[1].lstrip() + if s_line[0].lower().count("end") > 0: + output.filename = os.path.basename(path) + output.oommf = oommf + output.title = title + output.desc = desc + output.meshtype = meshtype + output.xbase = float(xbase) * METER2ANG + output.ybase = float(ybase) * METER2ANG + output.zbase = float(zbase) * METER2ANG + output.xstepsize = 
float(xstepsize) * METER2ANG + output.ystepsize = float(ystepsize) * METER2ANG + output.zstepsize = float(zstepsize) * METER2ANG + output.xnodes = float(xnodes) + output.ynodes = float(ynodes) + output.znodes = float(znodes) + output.xmin = float(xmin) * METER2ANG + output.ymin = float(ymin) * METER2ANG + output.zmin = float(zmin) * METER2ANG + output.xmax = float(xmax) * METER2ANG + output.ymax = float(ymax) * METER2ANG + output.zmax = float(zmax) * METER2ANG + output.valuemultiplier = valuemultiplier + output.valuerangeminmag = mag2sld(float(valuerangeminmag), \ + valueunit) + output.valuerangemaxmag = mag2sld(float(valuerangemaxmag), \ + valueunit) + output.set_m(mx, my, mz) + return output + except Exception: + msg = "%s is not supported: \n" % path + msg += "We accept only Text format OMF file." + raise RuntimeError(msg) + +class PDBReader(object): + """ + PDB reader class: limited for reading the lines starting with 'ATOM' + """ + type_name = "PDB" + ## Wildcards + type = ["pdb files (*.PDB, *.pdb)|*.pdb"] + ## List of allowed extensions + ext = ['.pdb', '.PDB'] + + def read(self, path): + """ + Load data file + + :param path: file path + :return: MagSLD + :raise RuntimeError: when the file can't be opened + """ + pos_x = np.zeros(0) + pos_y = np.zeros(0) + pos_z = np.zeros(0) + sld_n = np.zeros(0) + sld_mx = np.zeros(0) + sld_my = np.zeros(0) + sld_mz = np.zeros(0) + vol_pix = np.zeros(0) + pix_symbol = np.zeros(0) + x_line = [] + y_line = [] + z_line = [] + x_lines = [] + y_lines = [] + z_lines = [] + try: + input_f = open(path, 'rb') + buff = decode(input_f.read()) + lines = buff.split('\n') + input_f.close() + num = 0 + for line in lines: + try: + # check if line starts with "ATOM" + if line[0:6].strip().count('ATM') > 0 or \ + line[0:6].strip() == 'ATOM': + # define fields of interest + atom_name = line[12:16].strip() + try: + float(line[12]) + atom_name = atom_name[1].upper() + except Exception: + if len(atom_name) == 4: + atom_name = atom_name[0].upper() + elif line[12] != ' ': + atom_name = atom_name[0].upper() + \ + atom_name[1].lower() + else: + atom_name = atom_name[0].upper() + _pos_x = float(line[30:38].strip()) + _pos_y = float(line[38:46].strip()) + _pos_z = float(line[46:54].strip()) + pos_x = np.append(pos_x, _pos_x) + pos_y = np.append(pos_y, _pos_y) + pos_z = np.append(pos_z, _pos_z) + try: + val = nsf.neutron_sld(atom_name)[0] + # sld in Ang^-2 unit + val *= 1.0e-6 + sld_n = np.append(sld_n, val) + atom = formula(atom_name) + # cm to A units + vol = 1.0e+24 * atom.mass / atom.density / NA + vol_pix = np.append(vol_pix, vol) + except Exception: + logger.error("Error: set the sld of %s to zero"% atom_name) + sld_n = np.append(sld_n, 0.0) + sld_mx = np.append(sld_mx, 0) + sld_my = np.append(sld_my, 0) + sld_mz = np.append(sld_mz, 0) + pix_symbol = np.append(pix_symbol, atom_name) + elif line[0:6].strip().count('CONECT') > 0: + toks = line.split() + num = int(toks[1]) - 1 + val_list = [] + for val in toks[2:]: + try: + int_val = int(val) + except Exception: + break + if int_val == 0: + break + val_list.append(int_val) + #need val_list ordered + for val in val_list: + index = val - 1 + if (pos_x[index], pos_x[num]) in x_line and \ + (pos_y[index], pos_y[num]) in y_line and \ + (pos_z[index], pos_z[num]) in z_line: + continue + x_line.append((pos_x[num], pos_x[index])) + y_line.append((pos_y[num], pos_y[index])) + z_line.append((pos_z[num], pos_z[index])) + if len(x_line) > 0: + x_lines.append(x_line) + y_lines.append(y_line) + z_lines.append(z_line) + except Exception 
as exc: + logger.error(exc) + + output = MagSLD(pos_x, pos_y, pos_z, sld_n, sld_mx, sld_my, sld_mz) + output.set_conect_lines(x_line, y_line, z_line) + output.filename = os.path.basename(path) + output.set_pix_type('atom') + output.set_pixel_symbols(pix_symbol) + output.set_nodes() + output.set_pixel_volumes(vol_pix) + output.sld_unit = '1/A^(2)' + return output + except Exception: + raise RuntimeError("%s is not a sld file" % path) + + def write(self, path, data): + """ + Write + """ + print("Not implemented... ") + +class SLDReader(object): + """ + Class to load ascii files (7 columns). + """ + ## File type + type_name = "SLD ASCII" + ## Wildcards + type = ["sld files (*.SLD, *.sld)|*.sld", + "txt files (*.TXT, *.txt)|*.txt", + "all files (*.*)|*.*"] + ## List of allowed extensions + ext = ['.sld', '.SLD', '.txt', '.TXT', '.*'] + def read(self, path): + """ + Load data file + :param path: file path + :return MagSLD: x, y, z, sld_n, sld_mx, sld_my, sld_mz + :raise RuntimeError: when the file can't be opened + :raise ValueError: when the length of the data vectors are inconsistent + """ + try: + pos_x = np.zeros(0) + pos_y = np.zeros(0) + pos_z = np.zeros(0) + sld_n = np.zeros(0) + sld_mx = np.zeros(0) + sld_my = np.zeros(0) + sld_mz = np.zeros(0) + try: + # Use numpy to speed up loading + input_f = np.loadtxt(path, dtype='float', skiprows=1, + ndmin=1, unpack=True) + pos_x = np.array(input_f[0]) + pos_y = np.array(input_f[1]) + pos_z = np.array(input_f[2]) + sld_n = np.array(input_f[3]) + sld_mx = np.array(input_f[4]) + sld_my = np.array(input_f[5]) + sld_mz = np.array(input_f[6]) + ncols = len(input_f) + if ncols == 8: + vol_pix = np.array(input_f[7]) + elif ncols == 7: + vol_pix = None + except Exception: + # For older version of numpy + input_f = open(path, 'rb') + buff = decode(input_f.read()) + lines = buff.split('\n') + input_f.close() + for line in lines: + toks = line.split() + try: + _pos_x = float(toks[0]) + _pos_y = float(toks[1]) + _pos_z = float(toks[2]) + _sld_n = float(toks[3]) + _sld_mx = float(toks[4]) + _sld_my = float(toks[5]) + _sld_mz = float(toks[6]) + pos_x = np.append(pos_x, _pos_x) + pos_y = np.append(pos_y, _pos_y) + pos_z = np.append(pos_z, _pos_z) + sld_n = np.append(sld_n, _sld_n) + sld_mx = np.append(sld_mx, _sld_mx) + sld_my = np.append(sld_my, _sld_my) + sld_mz = np.append(sld_mz, _sld_mz) + try: + _vol_pix = float(toks[7]) + vol_pix = np.append(vol_pix, _vol_pix) + except Exception as exc: + vol_pix = None + except Exception as exc: + # Skip non-data lines + logger.error(exc) + output = MagSLD(pos_x, pos_y, pos_z, sld_n, + sld_mx, sld_my, sld_mz) + output.filename = os.path.basename(path) + output.set_pix_type('pixel') + output.set_pixel_symbols('pixel') + if vol_pix is not None: + output.set_pixel_volumes(vol_pix) + return output + except Exception: + raise RuntimeError("%s is not a sld file" % path) + + def write(self, path, data): + """ + Write sld file + :Param path: file path + :Param data: MagSLD data object + """ + if path is None: + raise ValueError("Missing the file path.") + if data is None: + raise ValueError("Missing the data to save.") + x_val = data.pos_x + y_val = data.pos_y + z_val = data.pos_z + vol_pix = data.vol_pix + length = len(x_val) + sld_n = data.sld_n + if sld_n is None: + sld_n = np.zeros(length) + sld_mx = data.sld_mx + if sld_mx is None: + sld_mx = np.zeros(length) + sld_my = np.zeros(length) + sld_mz = np.zeros(length) + else: + sld_my = data.sld_my + sld_mz = data.sld_mz + out = open(path, 'w') + # First Line: Column names + 
out.write("X Y Z SLDN SLDMx SLDMy SLDMz VOLUMEpix") + for ind in range(length): + out.write("\n%g %g %g %g %g %g %g %g" % \ + (x_val[ind], y_val[ind], z_val[ind], sld_n[ind], + sld_mx[ind], sld_my[ind], sld_mz[ind], vol_pix[ind])) + out.close() + + +class OMFData(object): + """ + OMF Data. + """ + _meshunit = "A" + _valueunit = "A^(-2)" + def __init__(self): + """ + Init for mag SLD + """ + self.filename = 'default' + self.oommf = '' + self.title = '' + self.desc = '' + self.meshtype = '' + self.meshunit = self._meshunit + self.valueunit = self._valueunit + self.xbase = 0.0 + self.ybase = 0.0 + self.zbase = 0.0 + self.xstepsize = 6.0 + self.ystepsize = 6.0 + self.zstepsize = 6.0 + self.xnodes = 10.0 + self.ynodes = 10.0 + self.znodes = 10.0 + self.xmin = 0.0 + self.ymin = 0.0 + self.zmin = 0.0 + self.xmax = 60.0 + self.ymax = 60.0 + self.zmax = 60.0 + self.mx = None + self.my = None + self.mz = None + self.valuemultiplier = 1. + self.valuerangeminmag = 0 + self.valuerangemaxmag = 0 + + def __str__(self): + """ + doc strings + """ + _str = "Type: %s\n" % self.__class__.__name__ + _str += "File: %s\n" % self.filename + _str += "OOMMF: %s\n" % self.oommf + _str += "Title: %s\n" % self.title + _str += "Desc: %s\n" % self.desc + _str += "meshtype: %s\n" % self.meshtype + _str += "meshunit: %s\n" % str(self.meshunit) + _str += "xbase: %s [%s]\n" % (str(self.xbase), self.meshunit) + _str += "ybase: %s [%s]\n" % (str(self.ybase), self.meshunit) + _str += "zbase: %s [%s]\n" % (str(self.zbase), self.meshunit) + _str += "xstepsize: %s [%s]\n" % (str(self.xstepsize), + self.meshunit) + _str += "ystepsize: %s [%s]\n" % (str(self.ystepsize), + self.meshunit) + _str += "zstepsize: %s [%s]\n" % (str(self.zstepsize), + self.meshunit) + _str += "xnodes: %s\n" % str(self.xnodes) + _str += "ynodes: %s\n" % str(self.ynodes) + _str += "znodes: %s\n" % str(self.znodes) + _str += "xmin: %s [%s]\n" % (str(self.xmin), self.meshunit) + _str += "ymin: %s [%s]\n" % (str(self.ymin), self.meshunit) + _str += "zmin: %s [%s]\n" % (str(self.zmin), self.meshunit) + _str += "xmax: %s [%s]\n" % (str(self.xmax), self.meshunit) + _str += "ymax: %s [%s]\n" % (str(self.ymax), self.meshunit) + _str += "zmax: %s [%s]\n" % (str(self.zmax), self.meshunit) + _str += "valueunit: %s\n" % self.valueunit + _str += "valuemultiplier: %s\n" % str(self.valuemultiplier) + _str += "ValueRangeMinMag:%s [%s]\n" % (str(self.valuerangeminmag), + self.valueunit) + _str += "ValueRangeMaxMag:%s [%s]\n" % (str(self.valuerangemaxmag), + self.valueunit) + return _str + + def set_m(self, mx, my, mz): + """ + Set the Mx, My, Mz values + """ + self.mx = mx + self.my = my + self.mz = mz + +class MagSLD(object): + """ + Magnetic SLD. 
+ """ + pos_x = None + pos_y = None + pos_z = None + sld_n = None + sld_mx = None + sld_my = None + sld_mz = None + # Units + _pos_unit = 'A' + _sld_unit = '1/A^(2)' + _pix_type = 'pixel' + + def __init__(self, pos_x, pos_y, pos_z, sld_n=None, + sld_mx=None, sld_my=None, sld_mz=None, vol_pix=None): + """ + Init for mag SLD + :params : All should be numpy 1D array + """ + self.is_data = True + self.filename = '' + self.xstepsize = 6.0 + self.ystepsize = 6.0 + self.zstepsize = 6.0 + self.xnodes = 10.0 + self.ynodes = 10.0 + self.znodes = 10.0 + self.has_stepsize = False + self.has_conect = False + self.pos_unit = self._pos_unit + self.sld_unit = self._sld_unit + self.pix_type = 'pixel' + self.pos_x = pos_x + self.pos_y = pos_y + self.pos_z = pos_z + self.sld_n = sld_n + self.line_x = None + self.line_y = None + self.line_z = None + self.sld_mx = sld_mx + self.sld_my = sld_my + self.sld_mz = sld_mz + self.vol_pix = vol_pix + self.sld_m = None + self.sld_phi = None + self.sld_theta = None + self.pix_symbol = None + if sld_mx is not None and sld_my is not None and sld_mz is not None: + self.set_sldms(sld_mx, sld_my, sld_mz) + self.set_nodes() + + def __str__(self): + """ + doc strings + """ + _str = "Type: %s\n" % self.__class__.__name__ + _str += "File: %s\n" % self.filename + _str += "Axis_unit: %s\n" % self.pos_unit + _str += "SLD_unit: %s\n" % self.sld_unit + return _str + + def set_pix_type(self, pix_type): + """ + Set pixel type + :Param pix_type: string, 'pixel' or 'atom' + """ + self.pix_type = pix_type + + def set_sldn(self, sld_n): + """ + Sets neutron SLD + """ + if sld_n.__class__.__name__ == 'float': + if self.is_data: + # For data, put the value to only the pixels w non-zero M + is_nonzero = (np.fabs(self.sld_mx) + + np.fabs(self.sld_my) + + np.fabs(self.sld_mz)).nonzero() + self.sld_n = np.zeros(len(self.pos_x)) + if len(self.sld_n[is_nonzero]) > 0: + self.sld_n[is_nonzero] = sld_n + else: + self.sld_n.fill(sld_n) + else: + # For non-data, put the value to all the pixels + self.sld_n = np.ones(len(self.pos_x)) * sld_n + else: + self.sld_n = sld_n + + def set_sldms(self, sld_mx, sld_my, sld_mz): + r""" + Sets mx, my, mz and abs(m). 
+ """ # Note: escaping + if sld_mx.__class__.__name__ == 'float': + self.sld_mx = np.ones(len(self.pos_x)) * sld_mx + else: + self.sld_mx = sld_mx + if sld_my.__class__.__name__ == 'float': + self.sld_my = np.ones(len(self.pos_x)) * sld_my + else: + self.sld_my = sld_my + if sld_mz.__class__.__name__ == 'float': + self.sld_mz = np.ones(len(self.pos_x)) * sld_mz + else: + self.sld_mz = sld_mz + + sld_m = np.sqrt(sld_mx * sld_mx + sld_my * sld_my + \ + sld_mz * sld_mz) + self.sld_m = sld_m + + def set_pixel_symbols(self, symbol='pixel'): + """ + Set pixel + :Params pixel: str; pixel or atomic symbol, or array of strings + """ + if self.sld_n is None: + return + if symbol.__class__.__name__ == 'str': + self.pix_symbol = np.repeat(symbol, len(self.sld_n)) + else: + self.pix_symbol = symbol + + def set_pixel_volumes(self, vol): + """ + Set pixel volumes + :Params pixel: str; pixel or atomic symbol, or array of strings + """ + if self.sld_n is None: + return + if vol.__class__.__name__ == 'ndarray': + self.vol_pix = vol + elif vol.__class__.__name__.count('float') > 0: + self.vol_pix = np.repeat(vol, len(self.sld_n)) + else: + self.vol_pix = None + + def get_sldn(self): + """ + Returns nuclear sld + """ + return self.sld_n + + def set_nodes(self): + """ + Set xnodes, ynodes, and znodes + """ + self.set_stepsize() + if self.pix_type == 'pixel': + try: + xdist = (max(self.pos_x) - min(self.pos_x)) / self.xstepsize + ydist = (max(self.pos_y) - min(self.pos_y)) / self.ystepsize + zdist = (max(self.pos_z) - min(self.pos_z)) / self.zstepsize + self.xnodes = int(xdist) + 1 + self.ynodes = int(ydist) + 1 + self.znodes = int(zdist) + 1 + except Exception: + self.xnodes = None + self.ynodes = None + self.znodes = None + else: + self.xnodes = None + self.ynodes = None + self.znodes = None + + def set_stepsize(self): + """ + Set xtepsize, ystepsize, and zstepsize + """ + if self.pix_type == 'pixel': + try: + xpos_pre = self.pos_x[0] + ypos_pre = self.pos_y[0] + zpos_pre = self.pos_z[0] + for x_pos in self.pos_x: + if xpos_pre != x_pos: + self.xstepsize = np.fabs(x_pos - xpos_pre) + break + for y_pos in self.pos_y: + if ypos_pre != y_pos: + self.ystepsize = np.fabs(y_pos - ypos_pre) + break + for z_pos in self.pos_z: + if zpos_pre != z_pos: + self.zstepsize = np.fabs(z_pos - zpos_pre) + break + #default pix volume + self.vol_pix = np.ones(len(self.pos_x)) + vol = self.xstepsize * self.ystepsize * self.zstepsize + self.set_pixel_volumes(vol) + self.has_stepsize = True + except Exception: + self.xstepsize = None + self.ystepsize = None + self.zstepsize = None + self.vol_pix = None + self.has_stepsize = False + else: + self.xstepsize = None + self.ystepsize = None + self.zstepsize = None + self.has_stepsize = True + return self.xstepsize, self.ystepsize, self.zstepsize + + def set_conect_lines(self, line_x, line_y, line_z): + """ + Set bonding line data if taken from pdb + """ + if line_x.__class__.__name__ != 'list' or len(line_x) < 1: + return + if line_y.__class__.__name__ != 'list' or len(line_y) < 1: + return + if line_z.__class__.__name__ != 'list' or len(line_z) < 1: + return + self.has_conect = True + self.line_x = line_x + self.line_y = line_y + self.line_z = line_z + +def _get_data_path(*path_parts): + from os.path import realpath, join as joinpath, dirname, abspath + # in sas/sascalc/calculator; want sas/sasview/test + return joinpath(dirname(realpath(__file__)), + '..', '..', 'sasview', 'test', *path_parts) + +def test_load(): + """ + Test code + """ + from mpl_toolkits.mplot3d import Axes3D + 
tfpath = _get_data_path("1d_data", "CoreXY_ShellZ.txt") + ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf") + if not os.path.isfile(tfpath) or not os.path.isfile(ofpath): + raise ValueError("file(s) not found: %r, %r"%(tfpath, ofpath)) + reader = SLDReader() + oreader = OMFReader() + output = reader.read(tfpath) + ooutput = oreader.read(ofpath) + foutput = OMF2SLD() + foutput.set_data(ooutput) + + import matplotlib.pyplot as plt + fig = plt.figure() + ax = Axes3D(fig) + ax.plot(output.pos_x, output.pos_y, output.pos_z, '.', c="g", + alpha=0.7, markeredgecolor='gray', rasterized=True) + gap = 7 + max_mx = max(output.sld_mx) + max_my = max(output.sld_my) + max_mz = max(output.sld_mz) + max_m = max(max_mx, max_my, max_mz) + x2 = output.pos_x+output.sld_mx/max_m * gap + y2 = output.pos_y+output.sld_my/max_m * gap + z2 = output.pos_z+output.sld_mz/max_m * gap + x_arrow = np.column_stack((output.pos_x, x2)) + y_arrow = np.column_stack((output.pos_y, y2)) + z_arrow = np.column_stack((output.pos_z, z2)) + unit_x2 = output.sld_mx / max_m + unit_y2 = output.sld_my / max_m + unit_z2 = output.sld_mz / max_m + color_x = np.fabs(unit_x2 * 0.8) + color_y = np.fabs(unit_y2 * 0.8) + color_z = np.fabs(unit_z2 * 0.8) + colors = np.column_stack((color_x, color_y, color_z)) + plt.show() + +def test_save(): + ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf") + if not os.path.isfile(ofpath): + raise ValueError("file(s) not found: %r"%(ofpath,)) + oreader = OMFReader() + omfdata = oreader.read(ofpath) + omf2sld = OMF2SLD() + omf2sld.set_data(omfdata) + writer = SLDReader() + writer.write("out.txt", omf2sld.output) + +def test(): + """ + Test code + """ + ofpath = _get_data_path("coordinate_data", "A_Raw_Example-1.omf") + if not os.path.isfile(ofpath): + raise ValueError("file(s) not found: %r"%(ofpath,)) + oreader = OMFReader() + omfdata = oreader.read(ofpath) + omf2sld = OMF2SLD() + omf2sld.set_data(omfdata) + model = GenSAS() + model.set_sld_data(omf2sld.output) + x = np.linspace(0, 0.1, 11)[1:] + return model.runXY([x, x]) + +if __name__ == "__main__": + #test_load() + #test_save() + #print(test()) + test() diff --git a/sas/sascalc/calculator/slit_length_calculator.py b/sas/sascalc/calculator/slit_length_calculator.py new file mode 100755 index 000000000..649fbffb2 --- /dev/null +++ b/sas/sascalc/calculator/slit_length_calculator.py @@ -0,0 +1,108 @@ +""" +This module is a small tool to allow user to quickly +determine the slit length value of data. +""" + + +class SlitlengthCalculator(object): + """ + compute slit length from SAXSess beam profile (1st col. Q , 2nd col. I , + and 3rd col. dI.: don't need the 3rd) + """ + def __init__(self): + + # x data + self.x = None + # y data + self.y = None + # default slit length + self.slit_length = 0.0 + + # The unit is unknown from SAXSess profile: + # It seems 1/nm but it could be not fixed, + # so users should be notified to determine the unit by themselves. + self.slit_length_unit = "unknown" + + def set_data(self, x=None, y=None): + """ + Receive two vector x, y and prepare the slit calculator for + computation. + + :param x: array + :param y: array + """ + self.x = x + self.y = y + + def calculate_slit_length(self): + """ + Calculate slit length. + + :return: the slit length calculated value. 
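+
+        The returned value is the x position at which the profile falls to
+        half of the averaged peak height, i.e. the half width at half
+        maximum of the beam profile.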
+ """ + # None data do nothing + if self.y is None or self.x is None: + return + # set local variable + y = self.y + x = self.x + + # find max y + max_y = y.max() + + # initial values + y_sum = 0.0 + y_max = 0.0 + ind = 0 + + # sum 10 or more y values until getting max_y, + while True: + if ind >= 10 and y_max == max_y: + break + y_sum = y_sum + y[ind] + if y[ind] > y_max: + y_max = y[ind] + ind += 1 + + # find the average value/2 of the top values + y_half = y_sum/(2.0*ind) + + # defaults + y_half_d = 0.0 + ind = 0 + # find indices where it crosses y = y_half. + while True: + # no need to check when ind == 0 + ind += 1 + # y value and ind just after passed the spot of the half height + y_half_d = y[ind] + if y[ind] < y_half: + break + + # y value and ind just before passed the spot of the half height + y_half_u = y[ind - 1] + + # get corresponding x values + x_half_d = x[ind] + x_half_u = x[ind - 1] + + # calculate x at y = y_half using linear interpolation + if y_half_u == y_half_d: + x_half = (x_half_d + x_half_u)/2.0 + else: + x_half = ((x_half_u * (y_half - y_half_d) + + x_half_d * (y_half_u - y_half)) + / (y_half_u - y_half_d)) + + # Our slit length is half width, so just give half beam value + slit_length = x_half + + # set slit_length + self.slit_length = slit_length + return self.slit_length + + def get_slit_length_unit(self): + """ + :return: the slit length unit. + """ + return self.slit_length_unit diff --git a/sas/sascalc/corfunc/LICENSE.TXT b/sas/sascalc/corfunc/LICENSE.TXT new file mode 100755 index 000000000..f5d1a884f --- /dev/null +++ b/sas/sascalc/corfunc/LICENSE.TXT @@ -0,0 +1,23 @@ +From https://github.com/rprospero/corfunc-py + +The MIT License (MIT) + +Copyright (c) 2016 rprospero + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/sas/sascalc/corfunc/__init__.py b/sas/sascalc/corfunc/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/sas/sascalc/corfunc/corfunc_calculator.py b/sas/sascalc/corfunc/corfunc_calculator.py new file mode 100755 index 000000000..6ceca0b99 --- /dev/null +++ b/sas/sascalc/corfunc/corfunc_calculator.py @@ -0,0 +1,286 @@ +""" +This module implements corfunc +""" +import warnings +import numpy as np +from scipy.optimize import curve_fit +from scipy.interpolate import interp1d +from scipy.fftpack import dct +from scipy.signal import argrelextrema +from numpy.linalg import lstsq +from sas.sascalc.dataloader.data_info import Data1D +from sas.sascalc.corfunc.transform_thread import FourierThread +from sas.sascalc.corfunc.transform_thread import HilbertThread + +class CorfuncCalculator(object): + + class _Interpolator(object): + """ + Interpolates between curve f and curve g over the range start:stop and + caches the result of the function when it's called + + :param f: The first curve to interpolate + :param g: The second curve to interpolate + :param start: The value at which to start the interpolation + :param stop: The value at which to stop the interpolation + """ + def __init__(self, f, g, start, stop): + self.f = f + self.g = g + self.start = start + self.stop = stop + self._lastx = [] + self._lasty = [] + + def __call__(self, x): + # If input is a single number, evaluate the function at that number + # and return a single number + if type(x) == float or type(x) == int: + return self._smoothed_function(np.array([x]))[0] + # If input is a list, and is different to the last input, evaluate + # the function at each point. If the input is the same as last time + # the function was called, return the result that was calculated + # last time instead of explicity evaluating the function again. + elif self._lastx == [] or x.tolist() != self._lastx.tolist(): + self._lasty = self._smoothed_function(x) + self._lastx = x + return self._lasty + + def _smoothed_function(self,x): + ys = np.zeros(x.shape) + ys[x <= self.start] = self.f(x[x <= self.start]) + ys[x >= self.stop] = self.g(x[x >= self.stop]) + with warnings.catch_warnings(): + # Ignore divide by zero error + warnings.simplefilter('ignore') + h = 1/(1+(x-self.stop)**2/(self.start-x)**2) + mask = np.logical_and(x > self.start, x < self.stop) + ys[mask] = h[mask]*self.g(x[mask])+(1-h[mask])*self.f(x[mask]) + return ys + + + def __init__(self, data=None, lowerq=None, upperq=None, scale=1): + """ + Initialize the class. + + :param data: Data of the type DataLoader.Data1D + :param lowerq: The Q value to use as the boundary for + Guinier extrapolation + :param upperq: A tuple of the form (lower, upper). 
+ Values between lower and upper will be used for Porod extrapolation + :param scale: Scaling factor for I(q) + """ + self._data = None + self.set_data(data, scale) + self.lowerq = lowerq + self.upperq = upperq + self.background = self.compute_background() + self._transform_thread = None + + def set_data(self, data, scale=1): + """ + Prepares the data for analysis + + :return: new_data = data * scale - background + """ + if data is None: + return + # Only process data of the class Data1D + if not issubclass(data.__class__, Data1D): + raise ValueError("Data must be of the type DataLoader.Data1D") + + # Prepare the data + new_data = Data1D(x=data.x, y=data.y) + new_data *= scale + + # Ensure the errors are set correctly + if new_data.dy is None or len(new_data.x) != len(new_data.dy) or \ + (min(new_data.dy) == 0 and max(new_data.dy) == 0): + new_data.dy = np.ones(len(new_data.x)) + + self._data = new_data + + def compute_background(self, upperq=None): + """ + Compute the background level from the Porod region of the data + """ + if self._data is None: return 0 + elif upperq is None and self.upperq is not None: upperq = self.upperq + elif upperq is None and self.upperq is None: return 0 + q = self._data.x + mask = np.logical_and(q > upperq[0], q < upperq[1]) + _, _, bg = self._fit_porod(q[mask], self._data.y[mask]) + + return bg + + def compute_extrapolation(self): + """ + Extrapolate and interpolate scattering data + + :return: The extrapolated data + """ + q = self._data.x + iq = self._data.y + + params, s2 = self._fit_data(q, iq) + # Extrapolate to 100*Qmax in experimental data + qs = np.arange(0, q[-1]*100, (q[1]-q[0])) + iqs = s2(qs) + + extrapolation = Data1D(qs, iqs) + + return params, extrapolation, s2 + + def compute_transform(self, extrapolation, trans_type, background=None, + completefn=None, updatefn=None): + """ + Transform an extrapolated scattering curve into a correlation function. 
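+        The transform runs asynchronously: the chosen thread (FourierThread
+        or HilbertThread) is queued here, and the result is passed to the
+        completefn callback when the calculation finishes.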
+ + :param extrapolation: The extrapolated data + :param background: The background value (if not provided, previously + calculated value will be used) + :param extrap_fn: A callable function representing the extraoplated data + :param completefn: The function to call when the transform calculation + is complete + :param updatefn: The function to call to update the GUI with the status + of the transform calculation + :return: The transformed data + """ + if self._transform_thread is not None: + if self._transform_thread.isrunning(): return + + if background is None: background = self.background + + if trans_type == 'fourier': + self._transform_thread = FourierThread(self._data, extrapolation, + background, completefn=completefn, + updatefn=updatefn) + elif trans_type == 'hilbert': + self._transform_thread = HilbertThread(self._data, extrapolation, + background, completefn=completefn, updatefn=updatefn) + else: + err = ("Incorrect transform type supplied, must be 'fourier'", + " or 'hilbert'") + raise ValueError(err) + + self._transform_thread.queue() + + def transform_isrunning(self): + if self._transform_thread is None: return False + return self._transform_thread.isrunning() + + def stop_transform(self): + if self._transform_thread.isrunning(): + self._transform_thread.stop() + + def extract_parameters(self, transformed_data): + """ + Extract the interesting measurements from a correlation function + + :param transformed_data: Fourier transformation of the extrapolated data + """ + # Calculate indexes of maxima and minima + x = transformed_data.x + y = transformed_data.y + maxs = argrelextrema(y, np.greater)[0] + mins = argrelextrema(y, np.less)[0] + + # If there are no maxima, return None + if len(maxs) == 0: + return None + + GammaMin = y[mins[0]] # The value at the first minimum + + ddy = (y[:-2]+y[2:]-2*y[1:-1])/(x[2:]-x[:-2])**2 # 2nd derivative of y + dy = (y[2:]-y[:-2])/(x[2:]-x[:-2]) # 1st derivative of y + # Find where the second derivative goes to zero + zeros = argrelextrema(np.abs(ddy), np.less)[0] + # locate the first inflection point + linear_point = zeros[0] + + # Try to calculate slope around linear_point using 80 data points + lower = linear_point - 40 + upper = linear_point + 40 + + # If too few data points to the left, use linear_point*2 data points + if lower < 0: + lower = 0 + upper = linear_point * 2 + # If too few to right, use 2*(dy.size - linear_point) data points + elif upper > len(dy): + upper = len(dy) + width = len(dy) - linear_point + lower = 2*linear_point - dy.size + + m = np.mean(dy[lower:upper]) # Linear slope + b = y[1:-1][linear_point]-m*x[1:-1][linear_point] # Linear intercept + + Lc = (GammaMin-b)/m # Hard block thickness + + # Find the data points where the graph is linear to within 1% + mask = np.where(np.abs((y-(m*x+b))/y) < 0.01)[0] + if len(mask) == 0: # Return garbage for bad fits + return { 'max': self._round_sig_figs(x[maxs[0]], 6) } + dtr = x[mask[0]] # Beginning of Linear Section + d0 = x[mask[-1]] # End of Linear Section + GammaMax = y[mask[-1]] + A = np.abs(GammaMin/GammaMax) # Normalized depth of minimum + + params = { + 'max': x[maxs[0]], + 'dtr': dtr, + 'Lc': Lc, + 'd0': d0, + 'A': A, + 'fill': Lc/x[maxs[0]] + } + + return params + + + def _porod(self, q, K, sigma, bg): + """Equation for the Porod region of the data""" + return bg + (K*q**(-4))*np.exp(-q**2*sigma**2) + + def _fit_guinier(self, q, iq): + """Fit the Guinier region of the curve""" + A = np.vstack([q**2, np.ones(q.shape)]).T + return lstsq(A, np.log(iq)) + + def 
_fit_porod(self, q, iq): + """Fit the Porod region of the curve""" + fitp = curve_fit(lambda q, k, sig, bg: self._porod(q, k, sig, bg)*q**2, + q, iq*q**2, bounds=([-np.inf, 0, -np.inf], [np.inf, np.inf, np.inf]))[0] + k, sigma, bg = fitp + return k, sigma, bg + + def _fit_data(self, q, iq): + """ + Given a data set, extrapolate out to large q with Porod and + to q=0 with Guinier + """ + mask = np.logical_and(q > self.upperq[0], q < self.upperq[1]) + + # Returns an array where the 1st and 2nd elements are the values of k + # and sigma for the best-fit Porod function + k, sigma, _ = self._fit_porod(q[mask], iq[mask]) + bg = self.background + + # Smooths between the best-fit porod function and the data to produce a + # better fitting curve + data = interp1d(q, iq) + s1 = self._Interpolator(data, + lambda x: self._porod(x, k, sigma, bg), self.upperq[0], q[-1]) + + mask = np.logical_and(q < self.lowerq, 0 < q) + + # Returns parameters for the best-fit Guinier function + g = self._fit_guinier(q[mask], iq[mask])[0] + + # Smooths between the best-fit Guinier function and the Porod curve + s2 = self._Interpolator((lambda x: (np.exp(g[1]+g[0]*x**2))), s1, q[0], + self.lowerq) + + params = {'A': g[1], 'B': g[0], 'K': k, 'sigma': sigma} + + return params, s2 diff --git a/sas/sascalc/corfunc/transform_thread.py b/sas/sascalc/corfunc/transform_thread.py new file mode 100755 index 000000000..5ec3ace81 --- /dev/null +++ b/sas/sascalc/corfunc/transform_thread.py @@ -0,0 +1,112 @@ +from sas.sascalc.data_util.calcthread import CalcThread +from sas.sascalc.dataloader.data_info import Data1D +from scipy.fftpack import dct +from scipy.integrate import trapz, cumtrapz +import numpy as np +from time import sleep + +class FourierThread(CalcThread): + def __init__(self, raw_data, extrapolated_data, bg, updatefn=None, + completefn=None): + CalcThread.__init__(self, updatefn=updatefn, completefn=completefn) + self.data = raw_data + self.background = bg + self.extrapolation = extrapolated_data + + def check_if_cancelled(self): + if self.isquit(): + self.update("Fourier transform cancelled.") + self.complete(transforms=None) + return True + return False + + def compute(self): + qs = self.extrapolation.x + iqs = self.extrapolation.y + q = self.data.x + background = self.background + + xs = np.pi*np.arange(len(qs),dtype=np.float32)/(q[1]-q[0])/len(qs) + + self.ready(delay=0.0) + self.update(msg="Fourier transform in progress.") + self.ready(delay=0.0) + + if self.check_if_cancelled(): return + try: + # ----- 1D Correlation Function ----- + gamma1 = dct((iqs-background)*qs**2) + Q = gamma1.max() + gamma1 /= Q + + if self.check_if_cancelled(): return + + # ----- 3D Correlation Function ----- + # gamma3(R) = 1/R int_{0}^{R} gamma1(x) dx + # trapz uses the trapezium rule to calculate the integral + mask = xs <= 1000.0 # Only calculate gamma3 up to x=1000 (as this is all that's plotted) + # gamma3 = [trapz(gamma1[:n], xs[:n])/xs[n-1] for n in range(2, len(xs[mask]) + 1)]j + # gamma3.insert(0, 1.0) # Gamma_3(0) is defined as 1 + n = len(xs[mask]) + gamma3 = cumtrapz(gamma1[:n], xs[:n])/xs[1:n] + gamma3 = np.hstack((1.0, gamma3)) # Gamma_3(0) is defined as 1 + + if self.check_if_cancelled(): return + + # ----- Interface Distribution function ----- + idf = dct(-qs**4 * (iqs-background)) + + if self.check_if_cancelled(): return + + # Manually calculate IDF(0.0), since scipy DCT tends to give us a + # very large negative value. 
+ # IDF(x) = int_0^inf q^4 * I(q) * cos(q*x) * dq + # => IDF(0) = int_0^inf q^4 * I(q) * dq + idf[0] = trapz(-qs**4 * (iqs-background), qs) + idf /= Q # Normalise using scattering invariant + + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.error(e) + + self.update(msg="Fourier transform failed.") + self.complete(transforms=None) + return + if self.isquit(): + return + self.update(msg="Fourier transform completed.") + + transform1 = Data1D(xs, gamma1) + transform3 = Data1D(xs[xs <= 1000], gamma3) + idf = Data1D(xs, idf) + + transforms = (transform1, transform3, idf) + + self.complete(transforms=transforms) + +class HilbertThread(CalcThread): + def __init__(self, raw_data, extrapolated_data, bg, updatefn=None, + completefn=None): + CalcThread.__init__(self, updatefn=updatefn, completefn=completefn) + self.data = raw_data + self.background = bg + self.extrapolation = extrapolated_data + + def compute(self): + qs = self.extrapolation.x + iqs = self.extrapolation.y + q = self.data.x + background = self.background + + self.ready(delay=0.0) + self.update(msg="Starting Hilbert transform.") + self.ready(delay=0.0) + if self.isquit(): + return + + # TODO: Implement hilbert transform + + self.update(msg="Hilbert transform completed.") + + self.complete(transforms=None) diff --git a/sas/sascalc/data_util/__init__.py b/sas/sascalc/data_util/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/sas/sascalc/data_util/calcthread.py b/sas/sascalc/data_util/calcthread.py new file mode 100755 index 000000000..bea2fe37d --- /dev/null +++ b/sas/sascalc/data_util/calcthread.py @@ -0,0 +1,330 @@ +# This program is public domain + +## \file +# \brief Abstract class for defining calculation threads. +# +from __future__ import print_function + +import traceback +import sys +import logging +try: + import _thread as thread +except ImportError: # CRUFT: python 2 support + import thread + +if sys.platform.count("darwin") > 0: + import time + stime = time.time() + + def clock(): + return time.time() - stime + + def sleep(t): + return time.sleep(t) +else: + from time import clock + from time import sleep + +logger = logging.getLogger(__name__) + + +class CalcThread: + """Threaded calculation class. Inherit from here and specialize + the compute() method to perform the appropriate operations for the + class. + + If you specialize the __init__ method be sure to call + CalcThread.__init__, passing it the keyword arguments for + yieldtime, worktime, update and complete. + + When defining the compute() method you need to include code which + allows the GUI to run. They are as follows: :: + + self.isquit() # call frequently to check for interrupts + self.update(kw=...) # call when the GUI could be updated + self.complete(kw=...) # call before exiting compute() + + The update() and complete() calls accept field=value keyword + arguments which are passed to the called function. complete() + should be called before exiting the GUI function. A KeyboardInterrupt + event is triggered if the GUI signals that the computation should + be halted. + + The following documentation should be included in the description + of the derived class. + + The user of this class will call the following: :: + + thread = Work(...,kw=...) # prepare the work thread. + thread.queue(...,kw=...) # queue a work unit + thread.requeue(...,kw=...) # replace work unit on the end of queue + thread.reset(...,kw=...) 
# reset the queue to the given work unit + thread.stop() # clear the queue and halt + thread.interrupt() # halt the current work unit but continue + thread.ready(delay=0.) # request an update signal after delay + thread.isrunning() # returns true if compute() is running + + Use queue() when all work must be done. Use requeue() when intermediate + work items don't need to be done (e.g., in response to a mouse move + event). Use reset() when the current item doesn't need to be completed + before the new event (e.g., in response to a mouse release event). Use + stop() to halt the current and pending computations (e.g., in response to + a stop button). + + The methods queue(), requeue() and reset() are proxies for the compute() + method in the subclass. Look there for a description of the arguments. + The compute() method can be called directly to run the computation in + the main thread, but it should not be called if isrunning() returns true. + + The constructor accepts additional keywords yieldtime=0.01 and + worktime=0.01 which determine the cooperative multitasking + behaviour. Yield time is the duration of the sleep period + required to give other processes a chance to run. Work time + is the duration between sleep periods. + + Notifying the GUI thread of work in progress and work complete + is done with updatefn=updatefn and completefn=completefn arguments + to the constructor. Details of the parameters to the functions + depend on the particular calculation class, but they will all + be passed as keyword arguments. Details of how the functions + should be implemented vary from framework to framework. + + For wx, something like the following is needed:: + + import wx, wx.lib.newevent + (CalcCompleteEvent, EVT_CALC_COMPLETE) = wx.lib.newevent.NewEvent() + + # methods in the main window class of your application + def __init__(): + ... + # Prepare the calculation in the GUI thread. + self.work = Work(completefn=self.CalcComplete) + self.Bind(EVT_CALC_COMPLETE, self.OnCalcComplete) + ... + # Bind work queue to a menu event. + self.Bind(wx.EVT_MENU, self.OnCalcStart, id=idCALCSTART) + ... + + def OnCalcStart(self,event): + # Start the work thread from the GUI thread. + self.work.queue(...work unit parameters...) + + def CalcComplete(self,**kwargs): + # Generate CalcComplete event in the calculation thread. + # kwargs contains field1, field2, etc. as defined by + # the Work thread class. + event = CalcCompleteEvent(**kwargs) + wx.PostEvent(self, event) + + def OnCalcComplete(self,event): + # Process CalcComplete event in GUI thread. + # Use values from event.field1, event.field2 etc. as + # defined by the Work thread class to show the results. + ... + """ + + def __init__(self, completefn=None, updatefn=None, + yieldtime=0.01, worktime=0.01, + exception_handler=None): + """Prepare the calculator""" + self.yieldtime = yieldtime + self.worktime = worktime + self.completefn = completefn + self.updatefn = updatefn + self.exception_handler = exception_handler + self._interrupting = False + self._running = False + self._queue = [] + self._lock = thread.allocate_lock() + self._delay = 1e6 + + def queue(self,*args,**kwargs): + """Add a work unit to the end of the queue. 
See the compute()
+        method for details of the arguments to the work unit."""
+        self._lock.acquire()
+        self._queue.append((args, kwargs))
+        # Cannot do start_new_thread call within the lock
+        self._lock.release()
+        if not self._running:
+            self._time_for_update = clock() + 1e6
+            thread.start_new_thread(self._run, ())
+
+    def requeue(self, *args, **kwargs):
+        """Replace the work unit on the end of the queue.  See the compute()
+        method for details of the arguments to the work unit."""
+        self._lock.acquire()
+        self._queue = self._queue[:-1]
+        self._lock.release()
+        self.queue(*args, **kwargs)
+
+    def reset(self, *args, **kwargs):
+        """Clear the queue and start a new work unit.  See the compute()
+        method for details of the arguments to the work unit."""
+        self.stop()
+        self.queue(*args, **kwargs)
+
+    def stop(self):
+        """Clear the queue and stop the thread.  New items may be
+        queued after stop.  To stop just the current work item and
+        continue the rest of the queue, call the interrupt() method."""
+        self._lock.acquire()
+        self._interrupting = True
+        self._queue = []
+        self._lock.release()
+
+    def interrupt(self):
+        """Stop the current work item.  To clear the work queue as
+        well, call the stop() method."""
+        self._lock.acquire()
+        self._interrupting = True
+        self._lock.release()
+
+    def isrunning(self):
+        return self._running
+
+    def ready(self, delay=0.):
+        """Ready for another update after delay=t seconds.  Call
+        this for threads which can show intermediate results from
+        long calculations."""
+        self._delay = delay
+        self._lock.acquire()
+        self._time_for_update = clock() + delay
+        # print "setting _time_for_update to ", self._time_for_update
+        self._lock.release()
+
+    def isquit(self):
+        """Check for interrupts.  Should be called frequently to
+        provide user responsiveness.  Also yields to other running
+        threads, which is required for good performance on OS X."""
+
+        # Only called from within the running thread so no need to lock
+        if self._running and self.yieldtime > 0 \
+                and clock() > self._time_for_nap:
+            sleep(self.yieldtime)
+            self._time_for_nap = clock() + self.worktime
+        if self._interrupting:
+            raise KeyboardInterrupt
+
+    def update(self, **kwargs):
+        """Update GUI with the latest results from the current work unit."""
+        if self.updatefn is not None and clock() > self._time_for_update:
+            self._lock.acquire()
+            self._time_for_update = clock() + self._delay
+            self._lock.release()
+            self._time_for_update += 1e6  # No more updates
+
+            self.updatefn(**kwargs)
+            sleep(self.yieldtime)
+            if self._interrupting:
+                raise KeyboardInterrupt
+        else:
+            self.isquit()
+        return
+
+    def complete(self, **kwargs):
+        """Update the GUI with the completed results from a work unit."""
+        if self.completefn is not None:
+            self.completefn(**kwargs)
+            sleep(self.yieldtime)
+        return
+
+    def compute(self, *args, **kwargs):
+        """Perform a work unit.  The subclass will provide details of
+        the arguments."""
+        raise NotImplementedError("Calculation thread needs compute method")
+
+    def exception(self):
+        """
+        An exception occurred during computation, so call the exception
+        handler if there is one.  If not, then log the exception and continue.
+        """
+        # If we have an exception handler, let it try to handle the exception.
+        # If it fails, fall through to log the failure to handle the
+        # exception (the original exception will be lost).  If there is no
+        # exception handler, just log the exception in compute that we are
+        # responding to.
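+        # (The handler is invoked with the (type, value, traceback) triple
+        # from sys.exc_info().)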
+ if self.exception_handler: + try: + self.exception_handler(*sys.exc_info()) + return + except Exception: + pass + logger.error(traceback.format_exc()) + #print 'CalcThread exception', + + def _run(self): + """Internal function to manage the thread.""" + # The code for condition wait in the threading package is + # implemented using polling. I'll accept for now that the + # authors of this code are clever enough that polling is + # difficult to avoid. Rather than polling, I will exit the + # thread when the queue is empty and start a new thread when + # there is more work to be done. + while 1: + self._lock.acquire() + self._time_for_nap = clock() + self.worktime + self._running = True + if self._queue == []: + break + self._interrupting = False + args, kwargs = self._queue[0] + self._queue = self._queue[1:] + self._lock.release() + try: + self.compute(*args, **kwargs) + except KeyboardInterrupt: + pass + except: + self.exception() + self._running = False + + +# ====================================================================== +# Demonstration of calcthread in action +class CalcDemo(CalcThread): + """Example of a calculation thread.""" + def compute(self, n): + total = 0. + for i in range(n): + self.update(i=i) + for j in range(n): + self.isquit() + total += j + self.complete(total=total) + + +class CalcCommandline: + """ + Test method + """ + def __init__(self, n=20000): + print(thread.get_ident()) + self.starttime = clock() + self.done = False + self.work = CalcDemo(completefn=self.complete, + updatefn=self.update, yieldtime=0.001) + self.work2 = CalcDemo(completefn=self.complete, + updatefn=self.update) + self.work3 = CalcDemo(completefn=self.complete, + updatefn=self.update) + self.work.queue(n) + self.work2.queue(n) + self.work3.queue(n) + print("Expect updates from Main every second and from thread every 2.5 seconds") + print("") + self.work.ready(.5) + while not self.done: + sleep(1) + print("Main thread %d at %.2f" % (thread.get_ident(), + clock() - self.starttime)) + + def update(self, i=0): + print("Update i=%d from thread %d at %.2f" % (i, thread.get_ident(), + clock() - self.starttime)) + self.work.ready(2.5) + + def complete(self, total=0.0): + print("Complete total=%g from thread %d at %.2f" % (total, + thread.get_ident(), + clock() - self.starttime)) + self.done = True diff --git a/sas/sascalc/data_util/err1d.py b/sas/sascalc/data_util/err1d.py new file mode 100755 index 000000000..7b7f649b9 --- /dev/null +++ b/sas/sascalc/data_util/err1d.py @@ -0,0 +1,155 @@ +# This program is public domain +""" +Error propogation algorithms for simple arithmetic + +Warning: like the underlying numpy library, the inplace operations +may return values of the wrong type if some of the arguments are +integers, so be sure to create them with floating point inputs. 
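+
+For example, using the division rule varZ = (varX/X**2 + varY/Y**2) * Z**2
+with scalar inputs (values invented for illustration)::
+
+    >>> div(2.0, 0.04, 4.0, 0.16)
+    (0.5, 0.005)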
+""" +from __future__ import division # Get true division +import numpy as np + + +def div(X, varX, Y, varY): + """Division with error propagation""" + # Direct algorithm: + # Z = X/Y + # varZ = (varX/X**2 + varY/Y**2) * Z**2 + # = (varX + varY * Z**2) / Y**2 + # Indirect algorithm to minimize intermediates + Z = X/Y # truediv => Z is a float + varZ = Z**2 # Z is a float => varZ is a float + varZ *= varY + varZ += varX + T = Y**2 # Doesn't matter if T is float or int + varZ /= T + return Z, varZ + + +def mul(X, varX, Y, varY): + """Multiplication with error propagation""" + # Direct algorithm: + Z = X * Y + varZ = Y**2 * varX + X**2 * varY + # Indirect algorithm won't ensure floating point results + # varZ = Y**2 + # varZ *= varX + # Z = X**2 # Using Z to hold the temporary + # Z *= varY + # varZ += Z + # Z[:] = X + # Z *= Y + return Z, varZ + + +def sub(X, varX, Y, varY): + """Subtraction with error propagation""" + Z = X - Y + varZ = varX + varY + return Z, varZ + + +def add(X, varX, Y, varY): + """Addition with error propagation""" + Z = X + Y + varZ = varX + varY + return Z, varZ + + +def exp(X, varX): + """Exponentiation with error propagation""" + Z = np.exp(X) + varZ = varX * Z**2 + return Z, varZ + + +def log(X, varX): + """Logarithm with error propagation""" + Z = np.log(X) + varZ = varX / X**2 + return Z, varZ + +# Confirm this formula before using it +# def pow(X,varX, Y,varY): +# Z = X**Y +# varZ = (Y**2 * varX/X**2 + varY * np.log(X)**2) * Z**2 +# return Z,varZ +# + + +def pow(X, varX, n): + """X**n with error propagation""" + # Direct algorithm + # Z = X**n + # varZ = n*n * varX/X**2 * Z**2 + # Indirect algorithm to minimize intermediates + Z = X**n + varZ = varX / X + varZ /= X + varZ *= Z + varZ *= Z + varZ *= n**2 + return Z, varZ + + +def div_inplace(X, varX, Y, varY): + """In-place division with error propagation""" + # Z = X/Y + # varZ = (varX + varY * (X/Y)**2) / Y**2 = (varX + varY * Z**2) / Y**2 + X /= Y # X now has Z = X/Y + T = X**2 # create T with Z**2 + T *= varY # T now has varY * Z**2 + varX += T # varX now has varX + varY*Z**2 + del T # may want to use T[:] = Y for vectors + T = Y # reuse T for Y + T **= 2 # T now has Y**2 + varX /= T # varX now has varZ + return X, varX + + +def mul_inplace(X, varX, Y, varY): + """In-place multiplication with error propagation""" + # Z = X * Y + # varZ = Y**2 * varX + X**2 * varY + T = Y**2 # create T with Y**2 + varX *= T # varX now has Y**2 * varX + del T # may want to use T[:] = X for vectors + T = X # reuse T for X**2 * varY + T **=2 # T now has X**2 + T *= varY # T now has X**2 * varY + varX += T # varX now has varZ + X *= Y # X now has Z + return X, varX + + +def sub_inplace(X, varX, Y, varY): + """In-place subtraction with error propagation""" + # Z = X - Y + # varZ = varX + varY + X -= Y + varX += varY + return X, varX + + +def add_inplace(X, varX, Y, varY): + """In-place addition with error propagation""" + # Z = X + Y + # varZ = varX + varY + X += Y + varX += varY + return X, varX + + +def pow_inplace(X, varX, n): + """In-place X**n with error propagation""" + # Direct algorithm + # Z = X**n + # varZ = abs(n) * varX/X**2 * Z**2 + # Indirect algorithm to minimize intermediates + varX /= X + varX /= X # varX now has varX/X**2 + X **= n # X now has Z = X**n + varX *= X + varX *= X # varX now has varX/X**2 * Z**2 + varX *= n**2 # varX now has varZ + return X, varX diff --git a/sas/sascalc/data_util/formatnum.py b/sas/sascalc/data_util/formatnum.py new file mode 100755 index 000000000..debbeccb3 --- /dev/null +++ 
b/sas/sascalc/data_util/formatnum.py @@ -0,0 +1,441 @@ +# This program is public domain +# Author: Paul Kienzle +""" +Format values and uncertainties nicely for printing. + +:func:`format_uncertainty_pm` produces the expanded format v +/- err. + +:func:`format_uncertainty_compact` produces the compact format v(##), +where the number in parenthesis is the uncertainty in the last two digits of v. + +:func:`format_uncertainty` uses the compact format by default, but this +can be changed to use the expanded +/- format by setting +format_uncertainty.compact to False. + +The formatted string uses only the number of digits warranted by +the uncertainty in the measurement. + +If the uncertainty is 0 or not otherwise provided, the simple +%g floating point format option is used. + +Infinite and indefinite numbers are represented as inf and NaN. + +Example:: + + >>> v,dv = 757.2356,0.01032 + >>> print format_uncertainty_pm(v,dv) + 757.236 +/- 0.010 + >>> print format_uncertainty_compact(v,dv) + 757.236(10) + >>> print format_uncertainty(v,dv) + 757.236(10) + >>> format_uncertainty.compact = False + >>> print format_uncertainty(v,dv) + 757.236 +/- 0.010 + +UncertaintyFormatter() returns a private formatter with its own +formatter.compact flag. +""" +from __future__ import division, print_function + +import math +import numpy as np +__all__ = ['format_uncertainty', 'format_uncertainty_pm', + 'format_uncertainty_compact'] + +# Coordinating scales across a set of numbers is not supported. For easy +# comparison a set of numbers should be shown in the same scale. One could +# force this from the outside by adding scale parameter (either 10**n, n, or +# a string representing the desired SI prefix) and having a separate routine +# which computes the scale given a set of values. + +# Coordinating scales with units offers its own problems. Again, the user +# may want to force particular units. This can be done by outside of the +# formatting routines by scaling the numbers to the appropriate units then +# forcing them to print with scale 10**0. If this is a common operation, +# however, it may want to happen inside. + +# The value e is currently formatted into the number. Alternatively this +# scale factor could be returned so that the user can choose the appropriate +# SI prefix when printing the units. This gets tricky when talking about +# composite units such as 2.3e-3 m**2 -> 2300 mm**2, and with volumes +# such as 1 g/cm**3 -> 1 kg/L. + + +def format_uncertainty_pm(value, uncertainty): + """ + Given *value* v and *uncertainty* dv, return a string v +/- dv. + """ + return _format_uncertainty(value, uncertainty, compact=False) + + +def format_uncertainty_compact(value, uncertainty): + """ + Given *value* v and *uncertainty* dv, return the compact + representation v(##), where ## are the first two digits of + the uncertainty. + """ + return _format_uncertainty(value, uncertainty, compact=True) + + +class UncertaintyFormatter: + """ + Value and uncertainty formatter. + + The *formatter* instance will use either the expanded v +/- dv form + or the compact v(##) form depending on whether *formatter.compact* is + True or False. The default is True. + """ + compact = True + + def __call__(self, value, uncertainty): + """ + Given *value* and *uncertainty*, return a string representation. 
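+
+        Uses the compact v(##) form when *self.compact* is True and the
+        expanded v +/- dv form otherwise.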
+ """ + return _format_uncertainty(value, uncertainty, self.compact) +format_uncertainty = UncertaintyFormatter() + + +def _format_uncertainty(value, uncertainty, compact): + """ + Implementation of both the compact and the +/- formats. + """ + # Handle indefinite value + if np.isinf(value): + return "inf" if value > 0 else "-inf" + if np.isnan(value): + return "NaN" + + # Handle indefinite uncertainty + if uncertainty is None or uncertainty <= 0 or np.isnan(uncertainty): + return "%g" % value + if np.isinf(uncertainty): + if compact: + return "%.2g(inf)" % value + else: + return "%.2g +/- inf" % value + + # Handle zero and negative values + sign = "-" if value < 0 else "" + value = abs(value) + + # Determine scale of value and error + err_place = int(math.floor(math.log10(uncertainty))) + if value == 0: + val_place = err_place - 1 + else: + val_place = int(math.floor(math.log10(value))) + + if err_place > val_place: + # Degenerate case: error bigger than value + # The mantissa is 0.#(##)e#, 0.0#(##)e# or 0.00#(##)e# + val_place = err_place + 2 + elif err_place == val_place: + # Degenerate case: error and value the same order of magnitude + # The value is ##(##)e#, #.#(##)e# or 0.##(##)e# + val_place = err_place + 1 + elif err_place <= 1 and val_place >= -3: + # Normal case: nice numbers and errors + # The value is ###.###(##) + val_place = 0 + else: + # Extreme cases: zeros before value or after error + # The value is ###.###(##)e#, ##.####(##)e# or #.#####(##)e# + pass + + # Force engineering notation, with exponent a multiple of 3 + val_place = int(math.floor(val_place / 3.)) * 3 + + # Format the result + digits_after_decimal = abs(val_place - err_place + 1) + val_str = "%.*f" % (digits_after_decimal, value / 10.**val_place) + exp_str = "e%d" % val_place if val_place != 0 else "" + if compact: + err_str = "(%2d)" % int(uncertainty / 10.**(err_place - 1) + 0.5) + result = "".join((sign, val_str, err_str, exp_str)) + else: + err_str = "%.*f" % (digits_after_decimal, uncertainty / 10.**val_place) + result = "".join((sign, val_str, exp_str + " +/- ", err_str, exp_str)) + return result + + +def test_compact(): + # Oops... 
renamed function after writing tests + value_str = format_uncertainty_compact + + # val_place > err_place + assert value_str(1235670,766000) == "1.24(77)e6" + assert value_str(123567.,76600) == "124(77)e3" + assert value_str(12356.7,7660) == "12.4(77)e3" + assert value_str(1235.67,766) == "1.24(77)e3" + assert value_str(123.567,76.6) == "124(77)" + assert value_str(12.3567,7.66) == "12.4(77)" + assert value_str(1.23567,.766) == "1.24(77)" + assert value_str(.123567,.0766) == "0.124(77)" + assert value_str(.0123567,.00766) == "0.0124(77)" + assert value_str(.00123567,.000766) == "0.00124(77)" + assert value_str(.000123567,.0000766) == "124(77)e-6" + assert value_str(.0000123567,.00000766) == "12.4(77)e-6" + assert value_str(.00000123567,.000000766) == "1.24(77)e-6" + assert value_str(.000000123567,.0000000766) == "124(77)e-9" + assert value_str(.00000123567,.0000000766) == "1.236(77)e-6" + assert value_str(.0000123567,.0000000766) == "12.357(77)e-6" + assert value_str(.000123567,.0000000766) == "123.567(77)e-6" + assert value_str(.00123567,.000000766) == "0.00123567(77)" + assert value_str(.0123567,.00000766) == "0.0123567(77)" + assert value_str(.123567,.0000766) == "0.123567(77)" + assert value_str(1.23567,.000766) == "1.23567(77)" + assert value_str(12.3567,.00766) == "12.3567(77)" + assert value_str(123.567,.0764) == "123.567(76)" + assert value_str(1235.67,.764) == "1235.67(76)" + assert value_str(12356.7,7.64) == "12356.7(76)" + assert value_str(123567,76.4) == "123567(76)" + assert value_str(1235670,764) == "1.23567(76)e6" + assert value_str(12356700,764) == "12.35670(76)e6" + assert value_str(123567000,764) == "123.56700(76)e6" + assert value_str(123567000,7640) == "123.5670(76)e6" + assert value_str(1235670000,76400) == "1.235670(76)e9" + + # val_place == err_place + assert value_str(123567,764000) == "0.12(76)e6" + assert value_str(12356.7,76400) == "12(76)e3" + assert value_str(1235.67,7640) == "1.2(76)e3" + assert value_str(123.567,764) == "0.12(76)e3" + assert value_str(12.3567,76.4) == "12(76)" + assert value_str(1.23567,7.64) == "1.2(76)" + assert value_str(.123567,.764) == "0.12(76)" + assert value_str(.0123567,.0764) == "12(76)e-3" + assert value_str(.00123567,.00764) == "1.2(76)e-3" + assert value_str(.000123567,.000764) == "0.12(76)e-3" + + # val_place == err_place-1 + assert value_str(123567,7640000) == "0.1(76)e6" + assert value_str(12356.7,764000) == "0.01(76)e6" + assert value_str(1235.67,76400) == "0.001(76)e6" + assert value_str(123.567,7640) == "0.1(76)e3" + assert value_str(12.3567,764) == "0.01(76)e3" + assert value_str(1.23567,76.4) == "0.001(76)e3" + assert value_str(.123567,7.64) == "0.1(76)" + assert value_str(.0123567,.764) == "0.01(76)" + assert value_str(.00123567,.0764) == "0.001(76)" + assert value_str(.000123567,.00764) == "0.1(76)e-3" + + # val_place == err_place-2 + assert value_str(12356700,7640000000) == "0.0(76)e9" + assert value_str(1235670,764000000) == "0.00(76)e9" + assert value_str(123567,76400000) == "0.000(76)e9" + assert value_str(12356,7640000) == "0.0(76)e6" + assert value_str(1235,764000) == "0.00(76)e6" + assert value_str(123,76400) == "0.000(76)e6" + assert value_str(12,7640) == "0.0(76)e3" + assert value_str(1,764) == "0.00(76)e3" + assert value_str(0.1,76.4) == "0.000(76)e3" + assert value_str(0.01,7.64) == "0.0(76)" + assert value_str(0.001,0.764) == "0.00(76)" + assert value_str(0.0001,0.0764) == "0.000(76)" + assert value_str(0.00001,0.00764) == "0.0(76)e-3" + + # val_place == err_place-3 + assert value_str(12356700,76400000000) 
== "0.000(76)e12" + assert value_str(1235670,7640000000) == "0.0(76)e9" + assert value_str(123567,764000000) == "0.00(76)e9" + assert value_str(12356,76400000) == "0.000(76)e9" + assert value_str(1235,7640000) == "0.0(76)e6" + assert value_str(123,764000) == "0.00(76)e6" + assert value_str(12,76400) == "0.000(76)e6" + assert value_str(1,7640) == "0.0(76)e3" + assert value_str(0.1,764) == "0.00(76)e3" + assert value_str(0.01,76.4) == "0.000(76)e3" + assert value_str(0.001,7.64) == "0.0(76)" + assert value_str(0.0001,0.764) == "0.00(76)" + assert value_str(0.00001,0.0764) == "0.000(76)" + assert value_str(0.000001,0.00764) == "0.0(76)e-3" + + # Zero values + assert value_str(0,7640000) == "0.0(76)e6" + assert value_str(0, 764000) == "0.00(76)e6" + assert value_str(0, 76400) == "0.000(76)e6" + assert value_str(0, 7640) == "0.0(76)e3" + assert value_str(0, 764) == "0.00(76)e3" + assert value_str(0, 76.4) == "0.000(76)e3" + assert value_str(0, 7.64) == "0.0(76)" + assert value_str(0, 0.764) == "0.00(76)" + assert value_str(0, 0.0764) == "0.000(76)" + assert value_str(0, 0.00764) == "0.0(76)e-3" + assert value_str(0, 0.000764) == "0.00(76)e-3" + assert value_str(0, 0.0000764) == "0.000(76)e-3" + + # negative values + assert value_str(-1235670,765000) == "-1.24(77)e6" + assert value_str(-1.23567,.766) == "-1.24(77)" + assert value_str(-.00000123567,.0000000766) == "-1.236(77)e-6" + assert value_str(-12356.7,7.64) == "-12356.7(76)" + assert value_str(-123.567,764) == "-0.12(76)e3" + assert value_str(-1235.67,76400) == "-0.001(76)e6" + assert value_str(-.000123567,.00764) == "-0.1(76)e-3" + assert value_str(-12356,7640000) == "-0.0(76)e6" + assert value_str(-12,76400) == "-0.000(76)e6" + assert value_str(-0.0001,0.764) == "-0.00(76)" + + # non-finite values + assert value_str(-np.inf,None) == "-inf" + assert value_str(np.inf,None) == "inf" + assert value_str(np.NaN,None) == "NaN" + + # bad or missing uncertainty + assert value_str(-1.23567,np.NaN) == "-1.23567" + assert value_str(-1.23567,-np.inf) == "-1.23567" + assert value_str(-1.23567,-0.1) == "-1.23567" + assert value_str(-1.23567,0) == "-1.23567" + assert value_str(-1.23567,None) == "-1.23567" + assert value_str(-1.23567,np.inf) == "-1.2(inf)" + +def test_pm(): + # Oops... 
renamed function after writing tests + value_str = format_uncertainty_pm + + # val_place > err_place + assert value_str(1235670,766000) == "1.24e6 +/- 0.77e6" + assert value_str(123567., 76600) == "124e3 +/- 77e3" + assert value_str(12356.7, 7660) == "12.4e3 +/- 7.7e3" + assert value_str(1235.67, 766) == "1.24e3 +/- 0.77e3" + assert value_str(123.567, 76.6) == "124 +/- 77" + assert value_str(12.3567, 7.66) == "12.4 +/- 7.7" + assert value_str(1.23567, .766) == "1.24 +/- 0.77" + assert value_str(.123567, .0766) == "0.124 +/- 0.077" + assert value_str(.0123567, .00766) == "0.0124 +/- 0.0077" + assert value_str(.00123567, .000766) == "0.00124 +/- 0.00077" + assert value_str(.000123567, .0000766) == "124e-6 +/- 77e-6" + assert value_str(.0000123567, .00000766) == "12.4e-6 +/- 7.7e-6" + assert value_str(.00000123567, .000000766) == "1.24e-6 +/- 0.77e-6" + assert value_str(.000000123567,.0000000766) == "124e-9 +/- 77e-9" + assert value_str(.00000123567, .0000000766) == "1.236e-6 +/- 0.077e-6" + assert value_str(.0000123567, .0000000766) == "12.357e-6 +/- 0.077e-6" + assert value_str(.000123567, .0000000766) == "123.567e-6 +/- 0.077e-6" + assert value_str(.00123567, .000000766) == "0.00123567 +/- 0.00000077" + assert value_str(.0123567, .00000766) == "0.0123567 +/- 0.0000077" + assert value_str(.123567, .0000766) == "0.123567 +/- 0.000077" + assert value_str(1.23567, .000766) == "1.23567 +/- 0.00077" + assert value_str(12.3567, .00766) == "12.3567 +/- 0.0077" + assert value_str(123.567, .0764) == "123.567 +/- 0.076" + assert value_str(1235.67, .764) == "1235.67 +/- 0.76" + assert value_str(12356.7, 7.64) == "12356.7 +/- 7.6" + assert value_str(123567, 76.4) == "123567 +/- 76" + assert value_str(1235670, 764) == "1.23567e6 +/- 0.00076e6" + assert value_str(12356700, 764) == "12.35670e6 +/- 0.00076e6" + assert value_str(123567000, 764) == "123.56700e6 +/- 0.00076e6" + assert value_str(123567000,7640) == "123.5670e6 +/- 0.0076e6" + assert value_str(1235670000,76400) == "1.235670e9 +/- 0.000076e9" + + # val_place == err_place + assert value_str(123567,764000) == "0.12e6 +/- 0.76e6" + assert value_str(12356.7,76400) == "12e3 +/- 76e3" + assert value_str(1235.67,7640) == "1.2e3 +/- 7.6e3" + assert value_str(123.567,764) == "0.12e3 +/- 0.76e3" + assert value_str(12.3567,76.4) == "12 +/- 76" + assert value_str(1.23567,7.64) == "1.2 +/- 7.6" + assert value_str(.123567,.764) == "0.12 +/- 0.76" + assert value_str(.0123567,.0764) == "12e-3 +/- 76e-3" + assert value_str(.00123567,.00764) == "1.2e-3 +/- 7.6e-3" + assert value_str(.000123567,.000764) == "0.12e-3 +/- 0.76e-3" + + # val_place == err_place-1 + assert value_str(123567,7640000) == "0.1e6 +/- 7.6e6" + assert value_str(12356.7,764000) == "0.01e6 +/- 0.76e6" + assert value_str(1235.67,76400) == "0.001e6 +/- 0.076e6" + assert value_str(123.567,7640) == "0.1e3 +/- 7.6e3" + assert value_str(12.3567,764) == "0.01e3 +/- 0.76e3" + assert value_str(1.23567,76.4) == "0.001e3 +/- 0.076e3" + assert value_str(.123567,7.64) == "0.1 +/- 7.6" + assert value_str(.0123567,.764) == "0.01 +/- 0.76" + assert value_str(.00123567,.0764) == "0.001 +/- 0.076" + assert value_str(.000123567,.00764) == "0.1e-3 +/- 7.6e-3" + + # val_place == err_place-2 + assert value_str(12356700,7640000000) == "0.0e9 +/- 7.6e9" + assert value_str(1235670,764000000) == "0.00e9 +/- 0.76e9" + assert value_str(123567,76400000) == "0.000e9 +/- 0.076e9" + assert value_str(12356,7640000) == "0.0e6 +/- 7.6e6" + assert value_str(1235,764000) == "0.00e6 +/- 0.76e6" + assert value_str(123,76400) == 
"0.000e6 +/- 0.076e6" + assert value_str(12,7640) == "0.0e3 +/- 7.6e3" + assert value_str(1,764) == "0.00e3 +/- 0.76e3" + assert value_str(0.1,76.4) == "0.000e3 +/- 0.076e3" + assert value_str(0.01,7.64) == "0.0 +/- 7.6" + assert value_str(0.001,0.764) == "0.00 +/- 0.76" + assert value_str(0.0001,0.0764) == "0.000 +/- 0.076" + assert value_str(0.00001,0.00764) == "0.0e-3 +/- 7.6e-3" + + # val_place == err_place-3 + assert value_str(12356700,76400000000) == "0.000e12 +/- 0.076e12" + assert value_str(1235670,7640000000) == "0.0e9 +/- 7.6e9" + assert value_str(123567,764000000) == "0.00e9 +/- 0.76e9" + assert value_str(12356,76400000) == "0.000e9 +/- 0.076e9" + assert value_str(1235,7640000) == "0.0e6 +/- 7.6e6" + assert value_str(123,764000) == "0.00e6 +/- 0.76e6" + assert value_str(12,76400) == "0.000e6 +/- 0.076e6" + assert value_str(1,7640) == "0.0e3 +/- 7.6e3" + assert value_str(0.1,764) == "0.00e3 +/- 0.76e3" + assert value_str(0.01,76.4) == "0.000e3 +/- 0.076e3" + assert value_str(0.001,7.64) == "0.0 +/- 7.6" + assert value_str(0.0001,0.764) == "0.00 +/- 0.76" + assert value_str(0.00001,0.0764) == "0.000 +/- 0.076" + assert value_str(0.000001,0.00764) == "0.0e-3 +/- 7.6e-3" + + # Zero values + assert value_str(0,7640000) == "0.0e6 +/- 7.6e6" + assert value_str(0, 764000) == "0.00e6 +/- 0.76e6" + assert value_str(0, 76400) == "0.000e6 +/- 0.076e6" + assert value_str(0, 7640) == "0.0e3 +/- 7.6e3" + assert value_str(0, 764) == "0.00e3 +/- 0.76e3" + assert value_str(0, 76.4) == "0.000e3 +/- 0.076e3" + assert value_str(0, 7.64) == "0.0 +/- 7.6" + assert value_str(0, 0.764) == "0.00 +/- 0.76" + assert value_str(0, 0.0764) == "0.000 +/- 0.076" + assert value_str(0, 0.00764) == "0.0e-3 +/- 7.6e-3" + assert value_str(0, 0.000764) == "0.00e-3 +/- 0.76e-3" + assert value_str(0, 0.0000764) == "0.000e-3 +/- 0.076e-3" + + # negative values + assert value_str(-1235670,766000) == "-1.24e6 +/- 0.77e6" + assert value_str(-1.23567,.766) == "-1.24 +/- 0.77" + assert value_str(-.00000123567,.0000000766) == "-1.236e-6 +/- 0.077e-6" + assert value_str(-12356.7,7.64) == "-12356.7 +/- 7.6" + assert value_str(-123.567,764) == "-0.12e3 +/- 0.76e3" + assert value_str(-1235.67,76400) == "-0.001e6 +/- 0.076e6" + assert value_str(-.000123567,.00764) == "-0.1e-3 +/- 7.6e-3" + assert value_str(-12356,7640000) == "-0.0e6 +/- 7.6e6" + assert value_str(-12,76400) == "-0.000e6 +/- 0.076e6" + assert value_str(-0.0001,0.764) == "-0.00 +/- 0.76" + + # non-finite values + assert value_str(-np.inf,None) == "-inf" + assert value_str(np.inf,None) == "inf" + assert value_str(np.NaN,None) == "NaN" + + # bad or missing uncertainty + assert value_str(-1.23567,np.NaN) == "-1.23567" + assert value_str(-1.23567,-np.inf) == "-1.23567" + assert value_str(-1.23567,-0.1) == "-1.23567" + assert value_str(-1.23567,0) == "-1.23567" + assert value_str(-1.23567,None) == "-1.23567" + assert value_str(-1.23567,np.inf) == "-1.2 +/- inf" + +def test_default(): + # Check that the default is the compact format + assert format_uncertainty(-1.23567,0.766) == "-1.24(77)" + +def main(): + """ + Run all tests. 
+ + This is equivalent to "nosetests --with-doctest" + """ + test_compact() + test_pm() + test_default() + + import doctest + doctest.testmod() + +if __name__ == "__main__": main() diff --git a/sas/sascalc/data_util/nxsunit.py b/sas/sascalc/data_util/nxsunit.py new file mode 100755 index 000000000..9fa80b6f5 --- /dev/null +++ b/sas/sascalc/data_util/nxsunit.py @@ -0,0 +1,216 @@ +# This program is public domain +# Author: Paul Kienzle +""" +Define unit conversion support for NeXus style units. + +The unit format is somewhat complicated. There are variant spellings +and incorrect capitalization to worry about, as well as forms such as +"mili*metre" and "1e-7 seconds". + +This is a minimal implementation of units including only what I happen to +need now. It does not support the complete dimensional analysis provided +by the package udunits on which NeXus is based, or even the units used +in the NeXus definition files. + +Unlike other units packages, this package does not carry the units along with +the value but merely provides a conversion function for transforming values. + +Usage example:: + + import nxsunit + u = nxsunit.Converter('mili*metre') # Units stored in mm + v = u(3000,'m') # Convert the value 3000 mm into meters + +NeXus example:: + + # Load sample orientation in radians regardless of how it is stored. + # 1. Open the path + file.openpath('/entry1/sample/sample_orientation') + # 2. scan the attributes, retrieving 'units' + units = [for attr,value in file.attrs() if attr == 'units'] + # 3. set up the converter (assumes that units actually exists) + u = nxsunit.Converter(units[0]) + # 4. read the data and convert to the correct units + v = u(file.read(),'radians') + +This is a standalone module, not relying on either DANSE or NeXus, and +can be used for other unit conversion tasks. + +Note: minutes are used for angle and seconds are used for time. We +cannot tell what the correct interpretation is without knowing something +about the fields themselves. If this becomes an issue, we will need to +allow the application to set the dimension for the unit rather than +inferring the dimension from an example unit. +""" + +# TODO: Add udunits to NAPI rather than reimplementing it in python +# TODO: Alternatively, parse the udunits database directly +# UDUnits: +# http://www.unidata.ucar.edu/software/udunits/udunits-1/udunits.txt + +from __future__ import division +import math + +__all__ = ['Converter'] + +# Limited form of units for returning objects of a specific type. +# Maybe want to do full units handling with e.g., pyre's +# unit class. For now lets keep it simple. Note that +def _build_metric_units(unit,abbr): + """ + Construct standard SI names for the given unit. + Builds e.g., + s, ns + second, nanosecond, nano*second + seconds, nanoseconds + Includes prefixes for femto through peta. + + Ack! Allows, e.g., Coulomb and coulomb even though Coulomb is not + a unit because some NeXus files store it that way! + + Returns a dictionary of names and scales. 
+ """ + prefix = dict(peta=1e15,tera=1e12,giga=1e9,mega=1e6,kilo=1e3, + deci=1e-1,centi=1e-2,milli=1e-3,mili=1e-3,micro=1e-6, + nano=1e-9,pico=1e-12,femto=1e-15) + short_prefix = dict(P=1e15,T=1e12,G=1e9,M=1e6,k=1e3, + d=1e-1,c=1e-2,m=1e-3,u=1e-6, + n=1e-9,p=1e-12,f=1e-15) + map = {abbr:1} + map.update([(P+abbr,scale) for (P,scale) in short_prefix.items()]) + for name in [unit,unit.capitalize()]: + map.update({name:1,name+'s':1}) + map.update([(P+name,scale) for (P,scale) in prefix.items()]) + map.update([(P+'*'+name,scale) for (P,scale) in prefix.items()]) + map.update([(P+name+'s',scale) for (P,scale) in prefix.items()]) + return map + +def _build_plural_units(**kw): + """ + Construct names for the given units. Builds singular and plural form. + """ + map = {} + map.update([(name,scale) for name,scale in kw.items()]) + map.update([(name+'s',scale) for name,scale in kw.items()]) + return map + +def _caret_optional(s): + """ + Strip '^' from unit names. + + * WARNING * this will incorrect transform 10^3 to 103. + """ + s.update((k.replace('^',''),v) + for k, v in list(s.items()) + if '^' in k) + +def _build_all_units(): + distance = _build_metric_units('meter','m') + distance.update(_build_metric_units('metre','m')) + distance.update(_build_plural_units(micron=1e-6, Angstrom=1e-10)) + distance.update({'A':1e-10, 'Ang':1e-10}) + + # Note: minutes are used for angle + time = _build_metric_units('second','s') + time.update(_build_plural_units(hour=3600,day=24*3600,week=7*24*3600)) + + # Note: seconds are used for time + angle = _build_plural_units(degree=1, minute=1/60., + arcminute=1/60., arcsecond=1/3600., radian=180/math.pi) + angle.update(deg=1, arcmin=1/60., arcsec=1/3600., rad=180/math.pi) + + frequency = _build_metric_units('hertz','Hz') + frequency.update(_build_metric_units('Hertz','Hz')) + frequency.update(_build_plural_units(rpm=1/60.)) + + # Note: degrees are used for angle + # Note: temperature needs an offset as well as a scale + temperature = _build_metric_units('kelvin','K') + temperature.update(_build_metric_units('Kelvin','K')) + temperature.update(_build_metric_units('Celcius', 'C')) + temperature.update(_build_metric_units('celcius', 'C')) + + charge = _build_metric_units('coulomb','C') + charge.update({'microAmp*hour':0.0036}) + + sld = { '10^-6 Angstrom^-2': 1e-6, 'Angstrom^-2': 1 } + Q = { 'invA': 1, 'invAng': 1, 'invAngstroms': 1, '1/A': 1, + '1/Angstrom': 1, '1/angstrom': 1, 'A^{-1}': 1, 'cm^{-1}': 1e-8, + '10^-3 Angstrom^-1': 1e-3, '1/cm': 1e-8, '1/m': 1e-10, + 'nm^{-1}': 1, 'nm^-1': 0.1, '1/nm': 0.1, 'n_m^-1': 0.1 } + + _caret_optional(sld) + _caret_optional(Q) + + dims = [distance, time, angle, frequency, temperature, charge, sld, Q] + return dims + +class Converter(object): + """ + Unit converter for NeXus style units. + """ + # Define the units, using both American and European spelling. + scalemap = None + scalebase = 1 + dims = _build_all_units() + + # Note: a.u. stands for arbitrary units, which should return the default + # units for that particular dimension. + # Note: don't have support for dimensionless units. 
+    unknown = {None:1, '???':1, '': 1, 'a.u.': 1, 'Counts': 1, 'counts': 1}
+
+    def __init__(self, name):
+        self.base = name
+        for map in self.dims:
+            if name in map:
+                self.scalemap = map
+                self.scalebase = self.scalemap[name]
+                return
+        if name in self.unknown:
+            return # default scalemap and scalebase correspond to unknown
+        else:
+            raise KeyError("Unknown unit %s"%name)
+
+    def scale(self, units=""):
+        if units == "" or self.scalemap is None: return 1
+        return self.scalebase/self.scalemap[units]
+
+    def __call__(self, value, units=""):
+        # Note: calculating a*1 rather than simply returning a would produce
+        # an unnecessary copy of the array, which in the case of the raw
+        # counts array would be bad.  Sometimes copying and other times
+        # not copying is also bad, but copy on modify semantics isn't
+        # supported.
+        if units == "" or self.scalemap is None: return value
+        try:
+            return value * (self.scalebase/self.scalemap[units])
+        except KeyError:
+            possible_units = ", ".join(str(k) for k in self.scalemap.keys())
+            raise KeyError("%s not in %s"%(units,possible_units))
+
+def _check(expect,get):
+    if expect != get:
+        raise ValueError("Expected %s but got %s"%(expect, get))
+    #print expect,"==",get
+
+def test():
+    _check(1,Converter('n_m^-1')(10,'invA')) # 10 nm^-1 = 1 inv Angstroms
+    _check(2,Converter('mm')(2000,'m')) # 2000 mm -> 2 m
+    _check(2.011e10,Converter('1/A')(2.011,"1/m")) # 2.011 1/A -> 2.011 * 10^10 1/m
+    _check(0.003,Converter('microseconds')(3,units='ms')) # 3 us -> 0.003 ms
+    _check(45,Converter('nanokelvin')(45)) # 45 nK -> 45 nK
+    _check(0.5,Converter('seconds')(1800,units='hours')) # 1800 s -> 0.5 hr
+    _check(123,Converter('a.u.')(123,units='mm')) # arbitrary units always return the same value
+    _check(123,Converter('a.u.')(123,units='s')) # arbitrary units always return the same value
+    _check(123,Converter('a.u.')(123,units='')) # arbitrary units always return the same value
+    try:
+        Converter('help')
+    except KeyError:
+        pass
+    else:
+        raise Exception("unknown unit did not raise an error")
+
+    # TODO: more tests
+
+if __name__ == "__main__":
+    test()
diff --git a/sas/sascalc/data_util/odict.py b/sas/sascalc/data_util/odict.py
new file mode 100755
index 000000000..697114879
--- /dev/null
+++ b/sas/sascalc/data_util/odict.py
@@ -0,0 +1,1399 @@
+# odict.py
+# An Ordered Dictionary object
+# Copyright (C) 2005 Nicola Larosa, Michael Foord
+# E-mail: nico AT tekNico DOT net, fuzzyman AT voidspace DOT org DOT uk
+
+# This software is licensed under the terms of the BSD license.
+# http://www.voidspace.org.uk/python/license.shtml
+# Basically you're free to copy, modify, distribute and relicense it,
+# So long as you keep a copy of the license with it.
+
+# Documentation at http://www.voidspace.org.uk/python/odict.html
+# For information about bugfixes, updates and support, please join the
+# Pythonutils mailing list:
+# http://groups.google.com/group/pythonutils/
+# Comments, suggestions and bug reports welcome.
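+# Editor's sketch of the ordering behaviour that plain (pre-2.7) dicts lack,
+# using only methods defined below:
+#
+#     d = OrderedDict(((1, 'a'), (3, 'c')))
+#     d.insert(1, 2, 'b')   # d -> OrderedDict([(1, 'a'), (2, 'b'), (3, 'c')])
+#     d[0:2]                # -> OrderedDict([(1, 'a'), (2, 'b')])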
+ +"""A dict that keeps keys in insertion order""" +from __future__ import generators + +__author__ = ('Nicola Larosa ,' + 'Michael Foord ') + +__docformat__ = "restructuredtext en" + +__revision__ = '$Id: odict.py 58 2008-09-02 14:09:54Z farrowch $' + +__version__ = '0.2.2' + +__all__ = ['OrderedDict', 'SequenceOrderedDict'] + +import sys +INTP_VER = sys.version_info[:2] +if INTP_VER < (2, 2): + raise RuntimeError("Python v.2.2 or later required") + +import types, warnings + +class OrderedDict(dict): + """ + A class of dictionary that keeps the insertion order of keys. + + All appropriate methods return keys, items, or values in an ordered way. + + All normal dictionary methods are available. Update and comparison is + restricted to other OrderedDict objects. + + Various sequence methods are available, including the ability to explicitly + mutate the key ordering. + + __contains__ tests: + + >>> d = OrderedDict(((1, 3),)) + >>> 1 in d + 1 + >>> 4 in d + 0 + + __getitem__ tests: + + >>> OrderedDict(((1, 3), (3, 2), (2, 1)))[2] + 1 + >>> OrderedDict(((1, 3), (3, 2), (2, 1)))[4] + Traceback (most recent call last): + KeyError: 4 + + __len__ tests: + + >>> len(OrderedDict()) + 0 + >>> len(OrderedDict(((1, 3), (3, 2), (2, 1)))) + 3 + + get tests: + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.get(1) + 3 + >>> d.get(4) is None + 1 + >>> d.get(4, 5) + 5 + >>> d + OrderedDict([(1, 3), (3, 2), (2, 1)]) + + has_key tests: + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.has_key(1) + 1 + >>> d.has_key(4) + 0 + """ + + def __init__(self, init_val=(), strict=False): + """ + Create a new ordered dictionary. Cannot init from a normal dict, + nor from kwargs, since items order is undefined in those cases. + + If the ``strict`` keyword argument is ``True`` (``False`` is the + default) then when doing slice assignment - the ``OrderedDict`` you are + assigning from *must not* contain any keys in the remaining dict. + + >>> OrderedDict() + OrderedDict([]) + >>> OrderedDict({1: 1}) + Traceback (most recent call last): + TypeError: undefined order, cannot get items from dict + >>> OrderedDict({1: 1}.items()) + OrderedDict([(1, 1)]) + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d + OrderedDict([(1, 3), (3, 2), (2, 1)]) + >>> OrderedDict(d) + OrderedDict([(1, 3), (3, 2), (2, 1)]) + """ + self.strict = strict + dict.__init__(self) + if isinstance(init_val, OrderedDict): + self._sequence = init_val.keys() + dict.update(self, init_val) + elif isinstance(init_val, dict): + # we lose compatibility with other ordered dict types this way + raise TypeError('undefined order, cannot get items from dict') + else: + self._sequence = [] + self.update(init_val) + +### Special methods ### + + def __delitem__(self, key): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> del d[3] + >>> d + OrderedDict([(1, 3), (2, 1)]) + >>> del d[3] + Traceback (most recent call last): + KeyError: 3 + >>> d[3] = 2 + >>> d + OrderedDict([(1, 3), (2, 1), (3, 2)]) + >>> del d[0:1] + >>> d + OrderedDict([(2, 1), (3, 2)]) + """ + if isinstance(key, types.SliceType): + # FIXME: efficiency? 
+ keys = self._sequence[key] + for entry in keys: + dict.__delitem__(self, entry) + del self._sequence[key] + else: + # do the dict.__delitem__ *first* as it raises + # the more appropriate error + dict.__delitem__(self, key) + self._sequence.remove(key) + + def __eq__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d == OrderedDict(d) + True + >>> d == OrderedDict(((1, 3), (2, 1), (3, 2))) + False + >>> d == OrderedDict(((1, 0), (3, 2), (2, 1))) + False + >>> d == OrderedDict(((0, 3), (3, 2), (2, 1))) + False + >>> d == dict(d) + False + >>> d == False + False + """ + if isinstance(other, OrderedDict): + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() == other.items()) + else: + return False + + def __lt__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> c = OrderedDict(((0, 3), (3, 2), (2, 1))) + >>> c < d + True + >>> d < c + False + >>> d < dict(c) + Traceback (most recent call last): + TypeError: Can only compare with other OrderedDicts + """ + if not isinstance(other, OrderedDict): + raise TypeError('Can only compare with other OrderedDicts') + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() < other.items()) + + def __le__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> c = OrderedDict(((0, 3), (3, 2), (2, 1))) + >>> e = OrderedDict(d) + >>> c <= d + True + >>> d <= c + False + >>> d <= dict(c) + Traceback (most recent call last): + TypeError: Can only compare with other OrderedDicts + >>> d <= e + True + """ + if not isinstance(other, OrderedDict): + raise TypeError('Can only compare with other OrderedDicts') + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() <= other.items()) + + def __ne__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d != OrderedDict(d) + False + >>> d != OrderedDict(((1, 3), (2, 1), (3, 2))) + True + >>> d != OrderedDict(((1, 0), (3, 2), (2, 1))) + True + >>> d == OrderedDict(((0, 3), (3, 2), (2, 1))) + False + >>> d != dict(d) + True + >>> d != False + True + """ + if isinstance(other, OrderedDict): + # FIXME: efficiency? + # Generate both item lists for each compare + return not (self.items() == other.items()) + else: + return True + + def __gt__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> c = OrderedDict(((0, 3), (3, 2), (2, 1))) + >>> d > c + True + >>> c > d + False + >>> d > dict(c) + Traceback (most recent call last): + TypeError: Can only compare with other OrderedDicts + """ + if not isinstance(other, OrderedDict): + raise TypeError('Can only compare with other OrderedDicts') + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() > other.items()) + + def __ge__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> c = OrderedDict(((0, 3), (3, 2), (2, 1))) + >>> e = OrderedDict(d) + >>> c >= d + False + >>> d >= c + True + >>> d >= dict(c) + Traceback (most recent call last): + TypeError: Can only compare with other OrderedDicts + >>> e >= d + True + """ + if not isinstance(other, OrderedDict): + raise TypeError('Can only compare with other OrderedDicts') + # FIXME: efficiency? 
+ # Generate both item lists for each compare + return (self.items() >= other.items()) + + def __repr__(self): + """ + Used for __repr__ and __str__ + + >>> r1 = repr(OrderedDict((('a', 'b'), ('c', 'd'), ('e', 'f')))) + >>> r1 + "OrderedDict([('a', 'b'), ('c', 'd'), ('e', 'f')])" + >>> r2 = repr(OrderedDict((('a', 'b'), ('e', 'f'), ('c', 'd')))) + >>> r2 + "OrderedDict([('a', 'b'), ('e', 'f'), ('c', 'd')])" + >>> r1 == str(OrderedDict((('a', 'b'), ('c', 'd'), ('e', 'f')))) + True + >>> r2 == str(OrderedDict((('a', 'b'), ('e', 'f'), ('c', 'd')))) + True + """ + return '%s([%s])' % (self.__class__.__name__, ', '.join( + ['(%r, %r)' % (key, self[key]) for key in self._sequence])) + + def __setitem__(self, key, val): + """ + Allows slice assignment, so long as the slice is an OrderedDict + >>> d = OrderedDict() + >>> d['a'] = 'b' + >>> d['b'] = 'a' + >>> d[3] = 12 + >>> d + OrderedDict([('a', 'b'), ('b', 'a'), (3, 12)]) + >>> d[:] = OrderedDict(((1, 2), (2, 3), (3, 4))) + >>> d + OrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d[::2] = OrderedDict(((7, 8), (9, 10))) + >>> d + OrderedDict([(7, 8), (2, 3), (9, 10)]) + >>> d = OrderedDict(((0, 1), (1, 2), (2, 3), (3, 4))) + >>> d[1:3] = OrderedDict(((1, 2), (5, 6), (7, 8))) + >>> d + OrderedDict([(0, 1), (1, 2), (5, 6), (7, 8), (3, 4)]) + >>> d = OrderedDict(((0, 1), (1, 2), (2, 3), (3, 4)), strict=True) + >>> d[1:3] = OrderedDict(((1, 2), (5, 6), (7, 8))) + >>> d + OrderedDict([(0, 1), (1, 2), (5, 6), (7, 8), (3, 4)]) + + >>> a = OrderedDict(((0, 1), (1, 2), (2, 3)), strict=True) + >>> a[3] = 4 + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[::1] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[:2] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]) + Traceback (most recent call last): + ValueError: slice assignment must be from unique keys + >>> a = OrderedDict(((0, 1), (1, 2), (2, 3))) + >>> a[3] = 4 + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[::1] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[:2] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[::-1] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a + OrderedDict([(3, 4), (2, 3), (1, 2), (0, 1)]) + + >>> d = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> d[:1] = 3 + Traceback (most recent call last): + TypeError: slice assignment requires an OrderedDict + + >>> d = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> d[:1] = OrderedDict([(9, 8)]) + >>> d + OrderedDict([(9, 8), (1, 2), (2, 3), (3, 4)]) + """ + if isinstance(key, types.SliceType): + if not isinstance(val, OrderedDict): + # FIXME: allow a list of tuples? + raise TypeError('slice assignment requires an OrderedDict') + keys = self._sequence[key] + # NOTE: Could use ``range(*key.indices(len(self._sequence)))`` + indexes = range(len(self._sequence))[key] + if key.step is None: + # NOTE: new slice may not be the same size as the one being + # overwritten ! + # NOTE: What is the algorithm for an impossible slice? + # e.g. d[5:3] + pos = key.start or 0 + del self[key] + newkeys = val.keys() + for k in newkeys: + if k in self: + if self.strict: + raise ValueError('slice assignment must be from ' + 'unique keys') + else: + # NOTE: This removes duplicate keys *first* + # so start position might have changed? 
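+                            # Editor's note: ``pos`` was computed before
+                            # this deletion, so removing a duplicate key
+                            # that sat before ``pos`` shifts the splice
+                            # one slot to the right of where it started.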
+ del self[k] + self._sequence = (self._sequence[:pos] + newkeys + + self._sequence[pos:]) + dict.update(self, val) + else: + # extended slice - length of new slice must be the same + # as the one being replaced + if len(keys) != len(val): + raise ValueError('attempt to assign sequence of size %s ' + 'to extended slice of size %s' % (len(val), len(keys))) + # FIXME: efficiency? + del self[key] + item_list = zip(indexes, val.items()) + # smallest indexes first - higher indexes not guaranteed to + # exist + item_list.sort() + for pos, (newkey, newval) in item_list: + if self.strict and newkey in self: + raise ValueError('slice assignment must be from unique' + ' keys') + self.insert(pos, newkey, newval) + else: + if key not in self: + self._sequence.append(key) + dict.__setitem__(self, key, val) + + def __getitem__(self, key): + """ + Allows slicing. Returns an OrderedDict if you slice. + >>> b = OrderedDict([(7, 0), (6, 1), (5, 2), (4, 3), (3, 4), (2, 5), (1, 6)]) + >>> b[::-1] + OrderedDict([(1, 6), (2, 5), (3, 4), (4, 3), (5, 2), (6, 1), (7, 0)]) + >>> b[2:5] + OrderedDict([(5, 2), (4, 3), (3, 4)]) + >>> type(b[2:4]) + + """ + if isinstance(key, types.SliceType): + # FIXME: does this raise the error we want? + keys = self._sequence[key] + # FIXME: efficiency? + return OrderedDict([(entry, self[entry]) for entry in keys]) + else: + return dict.__getitem__(self, key) + + __str__ = __repr__ + + def __setattr__(self, name, value): + """ + Implemented so that accesses to ``sequence`` raise a warning and are + diverted to the new ``setkeys`` method. + """ + if name == 'sequence': + warnings.warn('Use of the sequence attribute is deprecated.' + ' Use the keys method instead.', DeprecationWarning) + # NOTE: doesn't return anything + self.setkeys(value) + else: + # FIXME: do we want to allow arbitrary setting of attributes? + # Or do we want to manage it? + object.__setattr__(self, name, value) + + def __getattr__(self, name): + """ + Implemented so that access to ``sequence`` raises a warning. + + >>> d = OrderedDict() + >>> d.sequence + [] + """ + if name == 'sequence': + warnings.warn('Use of the sequence attribute is deprecated.' + ' Use the keys method instead.', DeprecationWarning) + # NOTE: Still (currently) returns a direct reference. Need to + # because code that uses sequence will expect to be able to + # mutate it in place. + return self._sequence + else: + # raise the appropriate error + raise AttributeError("OrderedDict has no '%s' attribute" % name) + + def __deepcopy__(self, memo): + """ + To allow deepcopy to work with OrderedDict. + + >>> from copy import deepcopy + >>> a = OrderedDict([(1, 1), (2, 2), (3, 3)]) + >>> a['test'] = {} + >>> b = deepcopy(a) + >>> b == a + True + >>> b is a + False + >>> a['test'] is b['test'] + False + """ + from copy import deepcopy + return self.__class__(deepcopy(self.items(), memo), self.strict) + + +### Read-only methods ### + + def copy(self): + """ + >>> OrderedDict(((1, 3), (3, 2), (2, 1))).copy() + OrderedDict([(1, 3), (3, 2), (2, 1)]) + """ + return OrderedDict(self) + + def items(self): + """ + ``items`` returns a list of tuples representing all the + ``(key, value)`` pairs in the dictionary. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.items() + [(1, 3), (3, 2), (2, 1)] + >>> d.clear() + >>> d.items() + [] + """ + return zip(self._sequence, self.values()) + + def keys(self): + """ + Return a list of keys in the ``OrderedDict``. 
+ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.keys() + [1, 3, 2] + """ + return self._sequence[:] + + def values(self, values=None): + """ + Return a list of all the values in the OrderedDict. + + Optionally you can pass in a list of values, which will replace the + current list. The value list must be the same len as the OrderedDict. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.values() + [3, 2, 1] + """ + return [self[key] for key in self._sequence] + + def iteritems(self): + """ + >>> ii = OrderedDict(((1, 3), (3, 2), (2, 1))).iteritems() + >>> ii.next() + (1, 3) + >>> ii.next() + (3, 2) + >>> ii.next() + (2, 1) + >>> ii.next() + Traceback (most recent call last): + StopIteration + """ + def make_iter(self=self): + keys = self.iterkeys() + while True: + key = keys.next() + yield (key, self[key]) + return make_iter() + + def iterkeys(self): + """ + >>> ii = OrderedDict(((1, 3), (3, 2), (2, 1))).iterkeys() + >>> ii.next() + 1 + >>> ii.next() + 3 + >>> ii.next() + 2 + >>> ii.next() + Traceback (most recent call last): + StopIteration + """ + return iter(self._sequence) + + __iter__ = iterkeys + + def itervalues(self): + """ + >>> iv = OrderedDict(((1, 3), (3, 2), (2, 1))).itervalues() + >>> iv.next() + 3 + >>> iv.next() + 2 + >>> iv.next() + 1 + >>> iv.next() + Traceback (most recent call last): + StopIteration + """ + def make_iter(self=self): + keys = self.iterkeys() + while True: + yield self[keys.next()] + return make_iter() + +### Read-write methods ### + + def clear(self): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.clear() + >>> d + OrderedDict([]) + """ + dict.clear(self) + self._sequence = [] + + def pop(self, key, *args): + """ + No dict.pop in Python 2.2, gotta reimplement it + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.pop(3) + 2 + >>> d + OrderedDict([(1, 3), (2, 1)]) + >>> d.pop(4) + Traceback (most recent call last): + KeyError: 4 + >>> d.pop(4, 0) + 0 + >>> d.pop(4, 0, 1) + Traceback (most recent call last): + TypeError: pop expected at most 2 arguments, got 3 + """ + if len(args) > 1: + raise TypeError('pop expected at most 2 arguments, got %s' % + (len(args) + 1)) + if key in self: + val = self[key] + del self[key] + else: + try: + val = args[0] + except IndexError: + raise KeyError(key) + return val + + def popitem(self, i=-1): + """ + Delete and return an item specified by index, not a random one as in + dict. The index is -1 by default (the last item). 
+ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.popitem() + (2, 1) + >>> d + OrderedDict([(1, 3), (3, 2)]) + >>> d.popitem(0) + (1, 3) + >>> OrderedDict().popitem() + Traceback (most recent call last): + KeyError: 'popitem(): dictionary is empty' + >>> d.popitem(2) + Traceback (most recent call last): + IndexError: popitem(): index 2 not valid + """ + if not self._sequence: + raise KeyError('popitem(): dictionary is empty') + try: + key = self._sequence[i] + except IndexError: + raise IndexError('popitem(): index %s not valid' % i) + return (key, self.pop(key)) + + def setdefault(self, key, defval = None): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.setdefault(1) + 3 + >>> d.setdefault(4) is None + True + >>> d + OrderedDict([(1, 3), (3, 2), (2, 1), (4, None)]) + >>> d.setdefault(5, 0) + 0 + >>> d + OrderedDict([(1, 3), (3, 2), (2, 1), (4, None), (5, 0)]) + """ + if key in self: + return self[key] + else: + self[key] = defval + return defval + + def update(self, from_od): + """ + Update from another OrderedDict or sequence of (key, value) pairs + + >>> d = OrderedDict(((1, 0), (0, 1))) + >>> d.update(OrderedDict(((1, 3), (3, 2), (2, 1)))) + >>> d + OrderedDict([(1, 3), (0, 1), (3, 2), (2, 1)]) + >>> d.update({4: 4}) + Traceback (most recent call last): + TypeError: undefined order, cannot get items from dict + >>> d.update((4, 4)) + Traceback (most recent call last): + TypeError: cannot convert dictionary update sequence element "4" to a 2-item sequence + """ + if isinstance(from_od, OrderedDict): + for key, val in from_od.items(): + self[key] = val + elif isinstance(from_od, dict): + # we lose compatibility with other ordered dict types this way + raise TypeError('undefined order, cannot get items from dict') + else: + # FIXME: efficiency? + # sequence of 2-item sequences, or error + for item in from_od: + try: + key, val = item + except TypeError: + raise TypeError('cannot convert dictionary update' + ' sequence element "%s" to a 2-item sequence' % item) + self[key] = val + + def rename(self, old_key, new_key): + """ + Rename the key for a given value, without modifying sequence order. + + For the case where new_key already exists this raise an exception, + since if new_key exists, it is ambiguous as to what happens to the + associated values, and the position of new_key in the sequence. + + >>> od = OrderedDict() + >>> od['a'] = 1 + >>> od['b'] = 2 + >>> od.items() + [('a', 1), ('b', 2)] + >>> od.rename('b', 'c') + >>> od.items() + [('a', 1), ('c', 2)] + >>> od.rename('c', 'a') + Traceback (most recent call last): + ValueError: New key already exists: 'a' + >>> od.rename('d', 'b') + Traceback (most recent call last): + KeyError: 'd' + """ + if new_key == old_key: + # no-op + return + if new_key in self: + raise ValueError("New key already exists: %r" % new_key) + # rename sequence entry + value = self[old_key] + old_idx = self._sequence.index(old_key) + self._sequence[old_idx] = new_key + # rename internal dict entry + dict.__delitem__(self, old_key) + dict.__setitem__(self, new_key, value) + + def setitems(self, items): + """ + This method allows you to set the items in the dict. + + It takes a list of tuples - of the same sort returned by the ``items`` + method. 
+ + >>> d = OrderedDict() + >>> d.setitems(((3, 1), (2, 3), (1, 2))) + >>> d + OrderedDict([(3, 1), (2, 3), (1, 2)]) + """ + self.clear() + # FIXME: this allows you to pass in an OrderedDict as well :-) + self.update(items) + + def setkeys(self, keys): + """ + ``setkeys`` all ows you to pass in a new list of keys which will + replace the current set. This must contain the same set of keys, but + need not be in the same order. + + If you pass in new keys that don't match, a ``KeyError`` will be + raised. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.keys() + [1, 3, 2] + >>> d.setkeys((1, 2, 3)) + >>> d + OrderedDict([(1, 3), (2, 1), (3, 2)]) + >>> d.setkeys(['a', 'b', 'c']) + Traceback (most recent call last): + KeyError: 'Keylist is not the same as current keylist.' + """ + # FIXME: Efficiency? (use set for Python 2.4 :-) + # NOTE: list(keys) rather than keys[:] because keys[:] returns + # a tuple, if keys is a tuple. + kcopy = list(keys) + kcopy.sort() + self._sequence.sort() + if kcopy != self._sequence: + raise KeyError('Keylist is not the same as current keylist.') + # NOTE: This makes the _sequence attribute a new object, instead + # of changing it in place. + # FIXME: efficiency? + self._sequence = list(keys) + + def setvalues(self, values): + """ + You can pass in a list of values, which will replace the + current list. The value list must be the same len as the OrderedDict. + + (Or a ``ValueError`` is raised.) + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.setvalues((1, 2, 3)) + >>> d + OrderedDict([(1, 1), (3, 2), (2, 3)]) + >>> d.setvalues([6]) + Traceback (most recent call last): + ValueError: Value list is not the same length as the OrderedDict. + """ + if len(values) != len(self): + # FIXME: correct error to raise? + raise ValueError('Value list is not the same length as the ' + 'OrderedDict.') + self.update(zip(self, values)) + +### Sequence Methods ### + + def index(self, key): + """ + Return the position of the specified key in the OrderedDict. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.index(3) + 1 + >>> d.index(4) + Traceback (most recent call last): + ValueError: list.index(x): x not in list + """ + return self._sequence.index(key) + + def insert(self, index, key, value): + """ + Takes ``index``, ``key``, and ``value`` as arguments. + + Sets ``key`` to ``value``, so that ``key`` is at position ``index`` in + the OrderedDict. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.insert(0, 4, 0) + >>> d + OrderedDict([(4, 0), (1, 3), (3, 2), (2, 1)]) + >>> d.insert(0, 2, 1) + >>> d + OrderedDict([(2, 1), (4, 0), (1, 3), (3, 2)]) + >>> d.insert(8, 8, 1) + >>> d + OrderedDict([(2, 1), (4, 0), (1, 3), (3, 2), (8, 1)]) + """ + if key in self: + # FIXME: efficiency? + del self[key] + self._sequence.insert(index, key) + dict.__setitem__(self, key, value) + + def reverse(self): + """ + Reverse the order of the OrderedDict. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.reverse() + >>> d + OrderedDict([(2, 1), (3, 2), (1, 3)]) + """ + self._sequence.reverse() + + def sort(self, *args, **kwargs): + """ + Sort the key order in the OrderedDict. + + This method takes the same arguments as the ``list.sort`` method on + your version of Python. + + >>> d = OrderedDict(((4, 1), (2, 2), (3, 3), (1, 4))) + >>> d.sort() + >>> d + OrderedDict([(1, 4), (2, 2), (3, 3), (4, 1)]) + """ + self._sequence.sort(*args, **kwargs) + +class Keys(object): + # FIXME: should this object be a subclass of list? 
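+    # Editor's note: SequenceOrderedDict (below) exposes an instance of this
+    # class as its ``keys`` attribute, so ``d.keys()`` still works as a
+    # method call while ``d.keys[0]`` and ``d.keys[:]`` index the key order.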
+ """ + Custom object for accessing the keys of an OrderedDict. + + Can be called like the normal ``OrderedDict.keys`` method, but also + supports indexing and sequence methods. + """ + + def __init__(self, main): + self._main = main + + def __call__(self): + """Pretend to be the keys method.""" + return self._main._keys() + + def __getitem__(self, index): + """Fetch the key at position i.""" + # NOTE: this automatically supports slicing :-) + return self._main._sequence[index] + + def __setitem__(self, index, name): + """ + You cannot assign to keys, but you can do slice assignment to re-order + them. + + You can only do slice assignment if the new set of keys is a reordering + of the original set. + """ + if isinstance(index, types.SliceType): + # FIXME: efficiency? + # check length is the same + indexes = range(len(self._main._sequence))[index] + if len(indexes) != len(name): + raise ValueError('attempt to assign sequence of size %s ' + 'to slice of size %s' % (len(name), len(indexes))) + # check they are the same keys + # FIXME: Use set + old_keys = self._main._sequence[index] + new_keys = list(name) + old_keys.sort() + new_keys.sort() + if old_keys != new_keys: + raise KeyError('Keylist is not the same as current keylist.') + orig_vals = [self._main[k] for k in name] + del self._main[index] + vals = zip(indexes, name, orig_vals) + vals.sort() + for i, k, v in vals: + if self._main.strict and k in self._main: + raise ValueError('slice assignment must be from ' + 'unique keys') + self._main.insert(i, k, v) + else: + raise ValueError('Cannot assign to keys') + + ### following methods pinched from UserList and adapted ### + def __repr__(self): return repr(self._main._sequence) + + # FIXME: do we need to check if we are comparing with another ``Keys`` + # object? (like the __cast method of UserList) + def __lt__(self, other): return self._main._sequence < other + def __le__(self, other): return self._main._sequence <= other + def __eq__(self, other): return self._main._sequence == other + def __ne__(self, other): return self._main._sequence != other + def __gt__(self, other): return self._main._sequence > other + def __ge__(self, other): return self._main._sequence >= other + # FIXME: do we need __cmp__ as well as rich comparisons? 
+ def __cmp__(self, other): return cmp(self._main._sequence, other) + + def __contains__(self, item): return item in self._main._sequence + def __len__(self): return len(self._main._sequence) + def __iter__(self): return self._main.iterkeys() + def count(self, item): return self._main._sequence.count(item) + def index(self, item, *args): return self._main._sequence.index(item, *args) + def reverse(self): self._main._sequence.reverse() + def sort(self, *args, **kwds): self._main._sequence.sort(*args, **kwds) + def __mul__(self, n): return self._main._sequence*n + __rmul__ = __mul__ + def __add__(self, other): return self._main._sequence + other + def __radd__(self, other): return other + self._main._sequence + + ## following methods not implemented for keys ## + def __delitem__(self, i): raise TypeError('Can\'t delete items from keys') + def __iadd__(self, other): raise TypeError('Can\'t add in place to keys') + def __imul__(self, n): raise TypeError('Can\'t multiply keys in place') + def append(self, item): raise TypeError('Can\'t append items to keys') + def insert(self, i, item): raise TypeError('Can\'t insert items into keys') + def pop(self, i=-1): raise TypeError('Can\'t pop items from keys') + def remove(self, item): raise TypeError('Can\'t remove items from keys') + def extend(self, other): raise TypeError('Can\'t extend keys') + +class Items(object): + """ + Custom object for accessing the items of an OrderedDict. + + Can be called like the normal ``OrderedDict.items`` method, but also + supports indexing and sequence methods. + """ + + def __init__(self, main): + self._main = main + + def __call__(self): + """Pretend to be the items method.""" + return self._main._items() + + def __getitem__(self, index): + """Fetch the item at position i.""" + if isinstance(index, types.SliceType): + # fetching a slice returns an OrderedDict + return self._main[index].items() + key = self._main._sequence[index] + return (key, self._main[key]) + + def __setitem__(self, index, item): + """Set item at position i to item.""" + if isinstance(index, types.SliceType): + # NOTE: item must be an iterable (list of tuples) + self._main[index] = OrderedDict(item) + else: + # FIXME: Does this raise a sensible error? + orig = self._main.keys[index] + key, value = item + if self._main.strict and key in self and (key != orig): + raise ValueError('slice assignment must be from ' + 'unique keys') + # delete the current one + del self._main[self._main._sequence[index]] + self._main.insert(index, key, value) + + def __delitem__(self, i): + """Delete the item at position i.""" + key = self._main._sequence[i] + if isinstance(i, types.SliceType): + for k in key: + # FIXME: efficiency? + del self._main[k] + else: + del self._main[key] + + ### following methods pinched from UserList and adapted ### + def __repr__(self): return repr(self._main.items()) + + # FIXME: do we need to check if we are comparing with another ``Items`` + # object? 
(like the __cast method of UserList)
+    def __lt__(self, other): return self._main.items() < other
+    def __le__(self, other): return self._main.items() <= other
+    def __eq__(self, other): return self._main.items() == other
+    def __ne__(self, other): return self._main.items() != other
+    def __gt__(self, other): return self._main.items() > other
+    def __ge__(self, other): return self._main.items() >= other
+    def __cmp__(self, other): return cmp(self._main.items(), other)
+
+    def __contains__(self, item): return item in self._main.items()
+    def __len__(self): return len(self._main._sequence) # easier :-)
+    def __iter__(self): return self._main.iteritems()
+    def count(self, item): return self._main.items().count(item)
+    def index(self, item, *args): return self._main.items().index(item, *args)
+    def reverse(self): self._main.reverse()
+    def sort(self, *args, **kwds): self._main.sort(*args, **kwds)
+    def __mul__(self, n): return self._main.items()*n
+    __rmul__ = __mul__
+    def __add__(self, other): return self._main.items() + other
+    def __radd__(self, other): return other + self._main.items()
+
+    def append(self, item):
+        """Add an item to the end."""
+        # FIXME: this is only append if the key isn't already present
+        key, value = item
+        self._main[key] = value
+
+    def insert(self, i, item):
+        key, value = item
+        self._main.insert(i, key, value)
+
+    def pop(self, i=-1):
+        key = self._main._sequence[i]
+        return (key, self._main.pop(key))
+
+    def remove(self, item):
+        key, value = item
+        try:
+            assert value == self._main[key]
+        except (KeyError, AssertionError):
+            raise ValueError('ValueError: list.remove(x): x not in list')
+        else:
+            del self._main[key]
+
+    def extend(self, other):
+        # FIXME: is only a true extend if none of the keys already present
+        for item in other:
+            key, value = item
+            self._main[key] = value
+
+    def __iadd__(self, other):
+        self.extend(other)
+
+    ## following methods not implemented for items ##
+
+    def __imul__(self, n): raise TypeError('Can\'t multiply items in place')
+
+class Values(object):
+    """
+    Custom object for accessing the values of an OrderedDict.
+
+    Can be called like the normal ``OrderedDict.values`` method, but also
+    supports indexing and sequence methods.
+    """
+
+    def __init__(self, main):
+        self._main = main
+
+    def __call__(self):
+        """Pretend to be the values method."""
+        return self._main._values()
+
+    def __getitem__(self, index):
+        """Fetch the value at position i."""
+        if isinstance(index, types.SliceType):
+            return [self._main[key] for key in self._main._sequence[index]]
+        else:
+            return self._main[self._main._sequence[index]]
+
+    def __setitem__(self, index, value):
+        """
+        Set the value at position i to value.
+
+        You can only do slice assignment to values if you supply a sequence of
+        equal length to the slice you are replacing.
+        """
+        if isinstance(index, types.SliceType):
+            keys = self._main._sequence[index]
+            if len(keys) != len(value):
+                # NOTE: the original raised this with the undefined name
+                # ``name``; the parameter here is ``value``.
+                raise ValueError('attempt to assign sequence of size %s '
+                    'to slice of size %s' % (len(value), len(keys)))
+            # FIXME: efficiency?
Would be better to calculate the indexes + # directly from the slice object + # NOTE: the new keys can collide with existing keys (or even + # contain duplicates) - these will overwrite + for key, val in zip(keys, value): + self._main[key] = val + else: + self._main[self._main._sequence[index]] = value + + ### following methods pinched from UserList and adapted ### + def __repr__(self): return repr(self._main.values()) + + # FIXME: do we need to check if we are comparing with another ``Values`` + # object? (like the __cast method of UserList) + def __lt__(self, other): return self._main.values() < other + def __le__(self, other): return self._main.values() <= other + def __eq__(self, other): return self._main.values() == other + def __ne__(self, other): return self._main.values() != other + def __gt__(self, other): return self._main.values() > other + def __ge__(self, other): return self._main.values() >= other + def __cmp__(self, other): return cmp(self._main.values(), other) + + def __contains__(self, item): return item in self._main.values() + def __len__(self): return len(self._main._sequence) # easier :-) + def __iter__(self): return self._main.itervalues() + def count(self, item): return self._main.values().count(item) + def index(self, item, *args): return self._main.values().index(item, *args) + + def reverse(self): + """Reverse the values""" + vals = self._main.values() + vals.reverse() + # FIXME: efficiency + self[:] = vals + + def sort(self, *args, **kwds): + """Sort the values.""" + vals = self._main.values() + vals.sort(*args, **kwds) + self[:] = vals + + def __mul__(self, n): return self._main.values()*n + __rmul__ = __mul__ + def __add__(self, other): return self._main.values() + other + def __radd__(self, other): return other + self._main.values() + + ## following methods not implemented for values ## + def __delitem__(self, i): raise TypeError('Can\'t delete items from values') + def __iadd__(self, other): raise TypeError('Can\'t add in place to values') + def __imul__(self, n): raise TypeError('Can\'t multiply values in place') + def append(self, item): raise TypeError('Can\'t append items to values') + def insert(self, i, item): raise TypeError('Can\'t insert items into values') + def pop(self, i=-1): raise TypeError('Can\'t pop items from values') + def remove(self, item): raise TypeError('Can\'t remove items from values') + def extend(self, other): raise TypeError('Can\'t extend values') + +class SequenceOrderedDict(OrderedDict): + """ + Experimental version of OrderedDict that has a custom object for ``keys``, + ``values``, and ``items``. + + These are callable sequence objects that work as methods, or can be + manipulated directly as sequences. + + Test for ``keys``, ``items`` and ``values``. 
+ + >>> d = SequenceOrderedDict(((1, 2), (2, 3), (3, 4))) + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d.keys + [1, 2, 3] + >>> d.keys() + [1, 2, 3] + >>> d.setkeys((3, 2, 1)) + >>> d + SequenceOrderedDict([(3, 4), (2, 3), (1, 2)]) + >>> d.setkeys((1, 2, 3)) + >>> d.keys[0] + 1 + >>> d.keys[:] + [1, 2, 3] + >>> d.keys[-1] + 3 + >>> d.keys[-2] + 2 + >>> d.keys[0:2] = [2, 1] + >>> d + SequenceOrderedDict([(2, 3), (1, 2), (3, 4)]) + >>> d.keys.reverse() + >>> d.keys + [3, 1, 2] + >>> d.keys = [1, 2, 3] + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d.keys = [3, 1, 2] + >>> d + SequenceOrderedDict([(3, 4), (1, 2), (2, 3)]) + >>> a = SequenceOrderedDict() + >>> b = SequenceOrderedDict() + >>> a.keys == b.keys + 1 + >>> a['a'] = 3 + >>> a.keys == b.keys + 0 + >>> b['a'] = 3 + >>> a.keys == b.keys + 1 + >>> b['b'] = 3 + >>> a.keys == b.keys + 0 + >>> a.keys > b.keys + 0 + >>> a.keys < b.keys + 1 + >>> 'a' in a.keys + 1 + >>> len(b.keys) + 2 + >>> 'c' in d.keys + 0 + >>> 1 in d.keys + 1 + >>> [v for v in d.keys] + [3, 1, 2] + >>> d.keys.sort() + >>> d.keys + [1, 2, 3] + >>> d = SequenceOrderedDict(((1, 2), (2, 3), (3, 4)), strict=True) + >>> d.keys[::-1] = [1, 2, 3] + >>> d + SequenceOrderedDict([(3, 4), (2, 3), (1, 2)]) + >>> d.keys[:2] + [3, 2] + >>> d.keys[:2] = [1, 3] + Traceback (most recent call last): + KeyError: 'Keylist is not the same as current keylist.' + + >>> d = SequenceOrderedDict(((1, 2), (2, 3), (3, 4))) + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d.values + [2, 3, 4] + >>> d.values() + [2, 3, 4] + >>> d.setvalues((4, 3, 2)) + >>> d + SequenceOrderedDict([(1, 4), (2, 3), (3, 2)]) + >>> d.values[::-1] + [2, 3, 4] + >>> d.values[0] + 4 + >>> d.values[-2] + 3 + >>> del d.values[0] + Traceback (most recent call last): + TypeError: Can't delete items from values + >>> d.values[::2] = [2, 4] + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> 7 in d.values + 0 + >>> len(d.values) + 3 + >>> [val for val in d.values] + [2, 3, 4] + >>> d.values[-1] = 2 + >>> d.values.count(2) + 2 + >>> d.values.index(2) + 0 + >>> d.values[-1] = 7 + >>> d.values + [2, 3, 7] + >>> d.values.reverse() + >>> d.values + [7, 3, 2] + >>> d.values.sort() + >>> d.values + [2, 3, 7] + >>> d.values.append('anything') + Traceback (most recent call last): + TypeError: Can't append items to values + >>> d.values = (1, 2, 3) + >>> d + SequenceOrderedDict([(1, 1), (2, 2), (3, 3)]) + + >>> d = SequenceOrderedDict(((1, 2), (2, 3), (3, 4))) + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d.items() + [(1, 2), (2, 3), (3, 4)] + >>> d.setitems([(3, 4), (2 ,3), (1, 2)]) + >>> d + SequenceOrderedDict([(3, 4), (2, 3), (1, 2)]) + >>> d.items[0] + (3, 4) + >>> d.items[:-1] + [(3, 4), (2, 3)] + >>> d.items[1] = (6, 3) + >>> d.items + [(3, 4), (6, 3), (1, 2)] + >>> d.items[1:2] = [(9, 9)] + >>> d + SequenceOrderedDict([(3, 4), (9, 9), (1, 2)]) + >>> del d.items[1:2] + >>> d + SequenceOrderedDict([(3, 4), (1, 2)]) + >>> (3, 4) in d.items + 1 + >>> (4, 3) in d.items + 0 + >>> len(d.items) + 2 + >>> [v for v in d.items] + [(3, 4), (1, 2)] + >>> d.items.count((3, 4)) + 1 + >>> d.items.index((1, 2)) + 1 + >>> d.items.index((2, 1)) + Traceback (most recent call last): + ValueError: list.index(x): x not in list + >>> d.items.reverse() + >>> d.items + [(1, 2), (3, 4)] + >>> d.items.reverse() + >>> d.items.sort() + >>> d.items + [(1, 2), (3, 4)] + >>> d.items.append((5, 6)) + >>> d.items + [(1, 2), (3, 4), (5, 6)] + >>> d.items.insert(0, (0, 0)) + >>> d.items + [(0, 0), 
(1, 2), (3, 4), (5, 6)] + >>> d.items.insert(-1, (7, 8)) + >>> d.items + [(0, 0), (1, 2), (3, 4), (7, 8), (5, 6)] + >>> d.items.pop() + (5, 6) + >>> d.items + [(0, 0), (1, 2), (3, 4), (7, 8)] + >>> d.items.remove((1, 2)) + >>> d.items + [(0, 0), (3, 4), (7, 8)] + >>> d.items.extend([(1, 2), (5, 6)]) + >>> d.items + [(0, 0), (3, 4), (7, 8), (1, 2), (5, 6)] + """ + + def __init__(self, init_val=(), strict=True): + OrderedDict.__init__(self, init_val, strict=strict) + self._keys = self.keys + self._values = self.values + self._items = self.items + self.keys = Keys(self) + self.values = Values(self) + self.items = Items(self) + self._att_dict = { + 'keys': self.setkeys, + 'items': self.setitems, + 'values': self.setvalues, + } + + def __setattr__(self, name, value): + """Protect keys, items, and values.""" + if not '_att_dict' in self.__dict__: + object.__setattr__(self, name, value) + else: + try: + fun = self._att_dict[name] + except KeyError: + OrderedDict.__setattr__(self, name, value) + else: + fun(value) + +if __name__ == '__main__': + if INTP_VER < (2, 3): + raise RuntimeError("Tests require Python v.2.3 or later") + # turn off warnings for tests + warnings.filterwarnings('ignore') + # run the code tests in doctest format + import doctest + m = sys.modules.get('__main__') + globs = m.__dict__.copy() + globs.update({ + 'INTP_VER': INTP_VER, + }) + doctest.testmod(m, globs=globs) + diff --git a/sas/sascalc/data_util/ordereddict.py b/sas/sascalc/data_util/ordereddict.py new file mode 100755 index 000000000..c7eda05fb --- /dev/null +++ b/sas/sascalc/data_util/ordereddict.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +"""Backport from python2.7 to python <= 2.6.""" + +from itertools import repeat as _repeat, chain as _chain, starmap as _starmap + +try: + from itertools import izip_longest as _zip_longest +except ImportError: + + from itertools import izip + + def _zip_longest(*args, **kwds): + # izip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D- + fillvalue = kwds.get('fillvalue') + def sentinel(counter = ([fillvalue]*(len(args)-1)).pop): + yield counter() # yields the fillvalue, or raises IndexError + fillers = _repeat(fillvalue) + iters = [_chain(it, sentinel(), fillers) for it in args] + try: + for tup in izip(*iters): + yield tup + except IndexError: + pass + +class OrderedDict(dict): + + def __init__(self, *args, **kwds): + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + if not hasattr(self, '_keys'): + self._keys = [] + self.update(*args, **kwds) + + def clear(self): + del self._keys[:] + dict.clear(self) + + def __setitem__(self, key, value): + if key not in self: + self._keys.append(key) + dict.__setitem__(self, key, value) + + def __delitem__(self, key): + dict.__delitem__(self, key) + self._keys.remove(key) + + def __iter__(self): + return iter(self._keys) + + def __reversed__(self): + return reversed(self._keys) + + def popitem(self): + if not self: + raise KeyError('dictionary is empty') + key = self._keys.pop() + value = dict.pop(self, key) + return key, value + + def __reduce__(self): + items = [[k, self[k]] for k in self] + inst_dict = vars(self).copy() + inst_dict.pop('_keys', None) + return (self.__class__, (items,), inst_dict) + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + def update(self, other=(), **kwds): + if hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + 
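# Editor's note: keyword arguments are applied after ``other``; under
+        # Python 2 their iteration order is arbitrary, so passing several
+        # *new* keys via kwds yields an unspecified insertion order.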
for key, value in kwds.items(): + self[key] = value + + __marker = object() + + def pop(self, key, default=__marker): + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def keys(self): + return list(self) + + def values(self): + return [self[key] for key in self] + + def items(self): + return [(key, self[key]) for key in self] + + def __repr__(self): + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, list(self.items())) + + def copy(self): + return self.__class__(self) + + @classmethod + def fromkeys(cls, iterable, value=None): + d = cls() + for key in iterable: + d[key] = value + return d + + def __eq__(self, other): + if isinstance(other, OrderedDict): + return all(p==q for p, q in _zip_longest(self.items(), other.items())) + return dict.__eq__(self, other) + + + +# End class OrderedDict + diff --git a/sas/sascalc/data_util/ordereddicttest.py b/sas/sascalc/data_util/ordereddicttest.py new file mode 100755 index 000000000..f732f7906 --- /dev/null +++ b/sas/sascalc/data_util/ordereddicttest.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python + +from random import shuffle +import copy +import inspect +import pickle +import unittest + + +from ordereddict import OrderedDict + + +class TestOrderedDict(unittest.TestCase): + + def test_init(self): + self.assertRaises(TypeError, OrderedDict, ([('a', 1), ('b', 2)], None)) + # too many args + pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] + self.assertEqual(sorted(OrderedDict(dict(pairs)).items()), pairs) # dict input + self.assertEqual(sorted(OrderedDict(**dict(pairs)).items()), pairs) # kwds input + self.assertEqual(list(OrderedDict(pairs).items()), pairs) # pairs input + self.assertEqual(list(OrderedDict([('a', 1), ('b', 2), ('c', 9), ('d', 4)], + c=3, e=5).items()), pairs) # mixed input + + # make sure no positional args conflict with possible kwdargs + self.assertEqual(inspect.getargspec(OrderedDict.__dict__['__init__'])[0], + ['self']) + + # Make sure that direct calls to __init__ do not clear previous contents + d = OrderedDict([('a', 1), ('b', 2), ('c', 3), ('d', 44), ('e', 55)]) + d.__init__([('e', 5), ('f', 6)], g=7, d=4) + self.assertEqual(list(d.items()), + [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)]) + + def test_update(self): + self.assertRaises(TypeError, OrderedDict().update, [('a', 1), ('b', + 2)], None) # too many args + pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)] + od = OrderedDict() + od.update(dict(pairs)) + self.assertEqual(sorted(od.items()), pairs) # dict input + od = OrderedDict() + od.update(**dict(pairs)) + self.assertEqual(sorted(od.items()), pairs) # kwds input + od = OrderedDict() + od.update(pairs) + self.assertEqual(list(od.items()), pairs) # pairs input + od = OrderedDict() + od.update([('a', 1), ('b', 2), ('c', 9), ('d', 4)], c=3, e=5) + self.assertEqual(list(od.items()), pairs) # mixed input + + # Make sure that direct calls to update do not clear previous contents + # add that updates items are not moved to the end + d = OrderedDict([('a', 1), ('b', 2), ('c', 3), ('d', 44), ('e', 55)]) + d.update([('e', 5), ('f', 6)], g=7, d=4) + self.assertEqual(list(d.items()), + [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5), ('f', 6), ('g', 7)]) + + def test_clear(self): + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + self.assertEqual(len(od), len(pairs)) + 
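+        # clear() must empty both the dict contents and the internal
+        # _keys list that records insertion order.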
od.clear() + self.assertEqual(len(od), 0) + + def test_delitem(self): + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + od = OrderedDict(pairs) + del od['a'] + self.assert_('a' not in od) + self.assertRaises(KeyError, od.__delitem__, 'a') + self.assertEqual(list(od.items()), pairs[:2] + pairs[3:]) + + def test_setitem(self): + od = OrderedDict([('d', 1), ('b', 2), ('c', 3), ('a', 4), ('e', 5)]) + od['c'] = 10 # existing element + od['f'] = 20 # new element + self.assertEqual(list(od.items()), + [('d', 1), ('b', 2), ('c', 10), ('a', 4), ('e', 5), ('f', 20)]) + + def test_iterators(self): + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + self.assertEqual(list(od), [t[0] for t in pairs]) + self.assertEqual(list(od.keys()), [t[0] for t in pairs]) + self.assertEqual(list(od.values()), [t[1] for t in pairs]) + self.assertEqual(list(od.items()), pairs) + self.assertEqual(list(reversed(od)), + [t[0] for t in reversed(pairs)]) + + def test_popitem(self): + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + while pairs: + self.assertEqual(od.popitem(), pairs.pop()) + self.assertRaises(KeyError, od.popitem) + self.assertEqual(len(od), 0) + + def test_pop(self): + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + shuffle(pairs) + while pairs: + k, v = pairs.pop() + self.assertEqual(od.pop(k), v) + self.assertRaises(KeyError, od.pop, 'xyz') + self.assertEqual(len(od), 0) + self.assertEqual(od.pop(k, 12345), 12345) + + def test_equality(self): + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od1 = OrderedDict(pairs) + od2 = OrderedDict(pairs) + self.assertEqual(od1, od2) # same order implies equality + pairs = pairs[2:] + pairs[:2] + od2 = OrderedDict(pairs) + self.assertNotEqual(od1, od2) # different order implies inequality + # comparison to regular dict is not order sensitive + self.assertEqual(od1, dict(od2)) + self.assertEqual(dict(od2), od1) + # different length implied inequality + self.assertNotEqual(od1, OrderedDict(pairs[:-1])) + + def test_copying(self): + # Check that ordered dicts are copyable, deepcopyable, picklable, + # and have a repr/eval round-trip + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + od = OrderedDict(pairs) + update_test = OrderedDict() + update_test.update(od) + for i, dup in enumerate([ + od.copy(), + copy.copy(od), + copy.deepcopy(od), + pickle.loads(pickle.dumps(od, 0)), + pickle.loads(pickle.dumps(od, 1)), + pickle.loads(pickle.dumps(od, 2)), + pickle.loads(pickle.dumps(od, -1)), + eval(repr(od)), + update_test, + OrderedDict(od), + ]): + self.assert_(dup is not od) + self.assertEquals(dup, od) + self.assertEquals(list(dup.items()), list(od.items())) + self.assertEquals(len(dup), len(od)) + self.assertEquals(type(dup), type(od)) + + def test_repr(self): + od = OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)]) + self.assertEqual(repr(od), + "OrderedDict([('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)])") + self.assertEqual(eval(repr(od)), od) + self.assertEqual(repr(OrderedDict()), "OrderedDict()") + + def test_setdefault(self): + pairs = [('c', 1), ('b', 2), ('a', 3), ('d', 4), ('e', 5), ('f', 6)] + shuffle(pairs) + od = OrderedDict(pairs) + pair_order = list(od.items()) + self.assertEqual(od.setdefault('a', 10), 3) + # make sure order didn't change + 
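+        # (setdefault on an existing key returns the stored value and
+        # must not reorder the keys)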
self.assertEqual(list(od.items()), pair_order) + self.assertEqual(od.setdefault('x', 10), 10) + # make sure 'x' is added to the end + self.assertEqual(list(od.items())[-1], ('x', 10)) + + def test_reinsert(self): + # Given insert a, insert b, delete a, re-insert a, + # verify that a is now later than b. + od = OrderedDict() + od['a'] = 1 + od['b'] = 2 + del od['a'] + od['a'] = 1 + self.assertEqual(list(od.items()), [('b', 2), ('a', 1)]) + +if __name__ == "__main__": + + unittest.main() + diff --git a/sas/sascalc/data_util/pathutils.py b/sas/sascalc/data_util/pathutils.py new file mode 100755 index 000000000..5a1385a2b --- /dev/null +++ b/sas/sascalc/data_util/pathutils.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +""" +Utilities for path manipulation. Not to be confused with the pathutils module +from the pythonutils package (http://groups.google.com/group/pythonutils). +""" + +# NOTE: If enough of _that_ pathutils functionality is required, we can switch +# this module for that one. + +# TODO: Make algorithm more robust and complete; consider using abspath. + +__all__ = ['relpath'] + +from os.path import join +from os.path import sep + +def relpath(p1, p2): + """Compute the relative path of p1 with respect to p2.""" + + def commonpath(L1, L2, common=[]): + if len(L1) < 1: return (common, L1, L2) + if len(L2) < 1: return (common, L1, L2) + if L1[0] != L2[0]: return (common, L1, L2) + return commonpath(L1[1:], L2[1:], common=common+[L1[0]]) + + # if the strings are equal, then return "." + if p1 == p2: return "." + (common,L1,L2) = commonpath(p2.split(sep), p1.split(sep)) + # if there is nothing in common, then return an empty string + if not common: return "" + # otherwise, replace the common pieces with "../" (or "..\") + p = [(".."+sep) * len(L1)] + L2 + return join(*p) + +def test(): + p1 = sep.join(["a","b","c","d"]) + p2 = sep.join(["a","b","c1","d1"]) + p3 = sep.join(["a","b","c","d","e"]) + p4 = sep.join(["a","b","c","d1","e"]) + p5 = sep.join(["w","x","y","z"]) + + assert relpath(p1, p1) == "." + assert relpath(p2, p1) == sep.join(["..", "..", "c1", "d1"]) + assert relpath(p3, p1) == "e" + assert relpath(p4, p1) == sep.join(["..", "d1", "e"]) + assert relpath(p5, p1) == "" + +if __name__ == '__main__': + test() diff --git a/sas/sascalc/data_util/registry.py b/sas/sascalc/data_util/registry.py new file mode 100755 index 000000000..f20077f56 --- /dev/null +++ b/sas/sascalc/data_util/registry.py @@ -0,0 +1,150 @@ +""" +File extension registry. + +This provides routines for opening files based on extension, +and registers the built-in file extensions. +""" +from __future__ import print_function + +from sas.sascalc.dataloader.loader_exceptions import NoKnownLoaderException + + +class ExtensionRegistry(object): + """ + Associate a file loader with an extension. + + Note that there may be multiple loaders for the same extension. + + Example: :: + + registry = ExtensionRegistry() + + # Add an association by setting an element + registry['.zip'] = unzip + + # Multiple extensions for one loader + registry['.tgz'] = untar + registry['.tar.gz'] = untar + + # Generic extensions to use after trying more specific extensions; + # these will be checked after the more specific extensions fail. 
+        registry['.gz'] = gunzip
+
+        # Multiple loaders for one extension
+        registry['.cx'] = cx1
+        registry['.cx'] = cx2
+        registry['.cx'] = cx3
+
+        # Show registered extensions
+        print(registry.extensions())
+
+        # Can also register a format name for explicit control from caller
+        registry['cx3'] = cx3
+        print(registry.formats())
+
+        # Retrieve loaders for a file name
+        registry.lookup('hello.cx') -> [cx3, cx2, cx1]
+
+        # Run loader on a filename
+        registry.load('hello.cx') ->
+            try:
+                return cx3('hello.cx')
+            except:
+                try:
+                    return cx2('hello.cx')
+                except:
+                    return cx1('hello.cx')
+
+        # Load in a specific format ignoring extension
+        registry.load('hello.cx', format='cx3') ->
+            return cx3('hello.cx')
+    """
+    def __init__(self, **kw):
+        self.loaders = {}
+
+    def __setitem__(self, ext, loader):
+        if ext not in self.loaders:
+            self.loaders[ext] = []
+        self.loaders[ext].insert(0, loader)
+
+    def __getitem__(self, ext):
+        return self.loaders[ext]
+
+    def __contains__(self, ext):
+        return ext in self.loaders
+
+    def formats(self):
+        """
+        Return a sorted list of the registered formats.
+        """
+        names = [a for a in self.loaders.keys() if not a.startswith('.')]
+        names.sort()
+        return names
+
+    def extensions(self):
+        """
+        Return a sorted list of registered extensions.
+        """
+        exts = [a for a in self.loaders.keys() if a.startswith('.')]
+        exts.sort()
+        return exts
+
+    def lookup(self, path):
+        """
+        Return the loader associated with the file type of path.
+
+        :param path: Data file path
+        :raises ValueError: When no loaders are found for the file.
+        :return: List of available readers for the file extension
+        """
+        # Find matching extensions
+        extlist = [ext for ext in self.extensions() if path.endswith(ext)]
+        # Sort matching extensions by decreasing order of length, so that
+        # loaders for more specific extensions (e.g. '.tar.gz') are tried
+        # before generic ones (e.g. '.gz')
+        extlist.sort(key=len, reverse=True)
+        # Combine loaders for matching extensions into one big list
+        loaders = []
+        for L in [self.loaders[ext] for ext in extlist]:
+            loaders.extend(L)
+        # Remove duplicates if they exist
+        if len(loaders) != len(set(loaders)):
+            result = []
+            for L in loaders:
+                if L not in result:
+                    result.append(L)
+            loaders = result
+        # Raise an error if there are no matching extensions
+        if len(loaders) == 0:
+            raise ValueError("Unknown file type for " + path)
+        return loaders
+
+    def load(self, path, format=None):
+        """
+        Call the loader for the file type of path.
+
+        :raises ValueError: if no loader is available.
+        :raises KeyError: if format is not available.
+
+        May raise a loader-defined exception if loader fails.
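+
+        A sketch of typical use (the loader names are hypothetical)::
+
+            registry = ExtensionRegistry()
+            registry['.csv'] = load_csv                     # by extension
+            registry['csv'] = load_csv                      # by format name
+            data = registry.load('data.csv')                # dispatch on extension
+            data = registry.load('data.dat', format='csv')  # explicit format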
+        """
+        loaders = []
+        error_message = "No loaders found for " + path
+        if format is None:
+            try:
+                loaders = self.lookup(path)
+            except ValueError as e:
+                error_message = str(e)
+        else:
+            try:
+                loaders = self.loaders[format]
+            except KeyError as e:
+                error_message = str(e)
+        last_exc = None
+        for fn in loaders:
+            try:
+                return fn(path)
+            except Exception as e:
+                last_exc = e  # give other loaders a chance to succeed
+        # If we get here it is because all loaders failed
+        if last_exc is not None and len(loaders) != 0:
+            # The file has associated loader(s) and they've all failed
+            raise last_exc
+        # No loader matched; the exceptions bound above are out of scope
+        # here, so reuse the saved message for the generic exception.
+        raise NoKnownLoaderException(error_message)
diff --git a/sas/sascalc/data_util/release_notes.txt b/sas/sascalc/data_util/release_notes.txt
new file mode 100755
index 000000000..d85e9ddb5
--- /dev/null
+++ b/sas/sascalc/data_util/release_notes.txt
@@ -0,0 +1,60 @@
+Release Notes
+=============
+
+Package name: data_util 0.1.5
+
+1- Version 0.1.5
+   - Release date: 28/4/2010
+   - Added OrderedDict (backport from 2.7)
+   - Improved NeXus unit support
+   - Number/uncertainty formats fixes
+
+   Version 0.1.2
+   - Release date: 7/4/2009
+   - Added calcthread
+
+   Version 0.1.1
+   - Release date: 4/21/2009
+
+   Version 0.1
+   - Release date: 8/1/2008
+   - Contains useful data handling utilities from reflectometry group.
+
+2- Downloading and Installing
+
+   2.1- System Requirements:
+        - Python version >= 2.4 should be running on the system
+
+   2.2- Installing:
+        - Get the code from svn://danse.us/common/releases/data_util-0.1.1
+        - The following modules are required:
+          * numpy
+
+3- Known Issues
+
+   3.1- All systems:
+        - None
+
+   3.2- Windows:
+        - None
+
+   3.3- Linux:
+        - None
+
+4- Troubleshooting
+
+   - None
+
+5- Frequently Asked Questions
+
+   - None
+
+6- Other Resources
+
+   - None
+
diff --git a/sas/sascalc/data_util/uncertainty.py b/sas/sascalc/data_util/uncertainty.py
new file mode 100755
index 000000000..13e059253
--- /dev/null
+++ b/sas/sascalc/data_util/uncertainty.py
@@ -0,0 +1,317 @@
+r"""
+Uncertainty propagation class for arithmetic, log and exp.
+
+Based on scalars or numpy vectors, this class allows you to store and
+manipulate values+uncertainties, with propagation of gaussian error for
+addition, subtraction, multiplication, division, power, exp and log.
+
+Storage properties are determined by the numbers used to set the value
+and uncertainty.  Be sure to use floating point uncertainty vectors
+for inplace operations since numpy does not do automatic type conversion.
+Normal operations can use mixed integer and floating point.  In-place
+operations such as *a \*= b* create at most one extra copy for each
+operation.  By contrast, *c = a\*b* uses four intermediate vectors, so it
+shouldn't be used for huge arrays.
+"""
+
+from __future__ import division
+
+import numpy as np
+
+from . import err1d
+from .formatnum import format_uncertainty
+
+__all__ = ['Uncertainty']
+
+# TODO: rename to Measurement and add support for units?
+# TODO: C implementation of *,/,**?
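+
+# A minimal usage sketch, mirroring the test() cases at the bottom of this
+# module: variances add under +/-, and combine as x_b**2*var_a + x_a**2*var_b
+# under multiplication (gaussian propagation via err1d).
+#
+#     >>> a, b = Uncertainty(5., 3.), Uncertainty(4., 2.)
+#     >>> (a + b).variance        # 3 + 2
+#     5.0
+#     >>> (a * b).variance        # 4**2*3 + 5**2*2
+#     98.0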
+class Uncertainty(object):
+    # Make standard deviation available
+    def _getdx(self): return np.sqrt(self.variance)
+    def _setdx(self, dx):
+        # Direct operation
+        #    variance = dx**2
+        # Indirect operation to avoid temporaries
+        self.variance[:] = dx
+        self.variance **= 2
+    dx = property(_getdx, _setdx, doc="standard deviation")
+
+    # Constructor
+    def __init__(self, x, variance=None):
+        self.x, self.variance = x, variance
+
+    # Numpy array slicing operations
+    def __len__(self):
+        return len(self.x)
+    def __getitem__(self, key):
+        return Uncertainty(self.x[key], self.variance[key])
+    def __setitem__(self, key, value):
+        self.x[key] = value.x
+        self.variance[key] = value.variance
+    def __delitem__(self, key):
+        del self.x[key]
+        del self.variance[key]
+    #def __iter__(self): pass # Not sure we need iter
+
+    # Normal operations: may be of mixed type
+    def __add__(self, other):
+        if isinstance(other, Uncertainty):
+            return Uncertainty(*err1d.add(self.x, self.variance, other.x, other.variance))
+        else:
+            return Uncertainty(self.x+other, self.variance+0)  # Force copy
+    def __sub__(self, other):
+        if isinstance(other, Uncertainty):
+            return Uncertainty(*err1d.sub(self.x, self.variance, other.x, other.variance))
+        else:
+            return Uncertainty(self.x-other, self.variance+0)  # Force copy
+    def __mul__(self, other):
+        if isinstance(other, Uncertainty):
+            return Uncertainty(*err1d.mul(self.x, self.variance, other.x, other.variance))
+        else:
+            return Uncertainty(self.x*other, self.variance*other**2)
+    def __truediv__(self, other):
+        if isinstance(other, Uncertainty):
+            return Uncertainty(*err1d.div(self.x, self.variance, other.x, other.variance))
+        else:
+            return Uncertainty(self.x/other, self.variance/other**2)
+    def __pow__(self, other):
+        if isinstance(other, Uncertainty):
+            # Haven't calculated variance in (a+/-da) ** (b+/-db)
+            return NotImplemented
+        else:
+            return Uncertainty(*err1d.pow(self.x, self.variance, other))
+
+    # Reverse operations
+    def __radd__(self, other):
+        return Uncertainty(self.x+other, self.variance+0)  # Force copy
+    def __rsub__(self, other):
+        return Uncertainty(other-self.x, self.variance+0)
+    def __rmul__(self, other):
+        return Uncertainty(self.x*other, self.variance*other**2)
+    def __rtruediv__(self, other):
+        x, variance = err1d.pow(self.x, self.variance, -1)
+        return Uncertainty(x*other, variance*other**2)
+    def __rpow__(self, other): return NotImplemented
+
+    # In-place operations: may be of mixed type
+    def __iadd__(self, other):
+        if isinstance(other, Uncertainty):
+            self.x, self.variance \
+                = err1d.add_inplace(self.x, self.variance, other.x, other.variance)
+        else:
+            self.x += other
+        return self
+    def __isub__(self, other):
+        if isinstance(other, Uncertainty):
+            self.x, self.variance \
+                = err1d.sub_inplace(self.x, self.variance, other.x, other.variance)
+        else:
+            self.x -= other
+        return self
+    def __imul__(self, other):
+        if isinstance(other, Uncertainty):
+            self.x, self.variance \
+                = err1d.mul_inplace(self.x, self.variance, other.x, other.variance)
+        else:
+            self.x *= other
+            self.variance *= other**2
+        return self
+    def __itruediv__(self, other):
+        if isinstance(other, Uncertainty):
+            self.x, self.variance \
+                = err1d.div_inplace(self.x, self.variance, other.x, other.variance)
+        else:
+            self.x /= other
+            self.variance /= other**2
+        return self
+    def __ipow__(self, other):
+        if isinstance(other, Uncertainty):
+            # Haven't calculated variance in (a+/-da) ** (b+/-db)
+            return NotImplemented
+        else:
+            self.x, self.variance = err1d.pow_inplace(self.x, self.variance, other)
+        return self
+
+    # Use true division instead of integer division
+    def __div__(self, other): return self.__truediv__(other)
+    def __rdiv__(self, other): return self.__rtruediv__(other)
+    def __idiv__(self, other): return self.__itruediv__(other)
+
+    # Unary ops
+    def __neg__(self):
+        return Uncertainty(-self.x, self.variance)
+    def __pos__(self):
+        return self
+    def __abs__(self):
+        return Uncertainty(np.abs(self.x), self.variance)
+
+    def __str__(self):
+        #return str(self.x)+" +/- "+str(np.sqrt(self.variance))
+        if np.isscalar(self.x):
+            return format_uncertainty(self.x, np.sqrt(self.variance))
+        else:
+            return [format_uncertainty(v, dv)
+                    for v, dv in zip(self.x, np.sqrt(self.variance))]
+    def __repr__(self):
+        return "Uncertainty(%s,%s)" % (str(self.x), str(self.variance))
+
+    # Not implemented
+    def __floordiv__(self, other): return NotImplemented
+    def __mod__(self, other): return NotImplemented
+    def __divmod__(self, other): return NotImplemented
+    def __lshift__(self, other): return NotImplemented
+    def __rshift__(self, other): return NotImplemented
+    def __and__(self, other): return NotImplemented
+    def __xor__(self, other): return NotImplemented
+    def __or__(self, other): return NotImplemented
+
+    def __rfloordiv__(self, other): return NotImplemented
+    def __rmod__(self, other): return NotImplemented
+    def __rdivmod__(self, other): return NotImplemented
+    def __rlshift__(self, other): return NotImplemented
+    def __rrshift__(self, other): return NotImplemented
+    def __rand__(self, other): return NotImplemented
+    def __rxor__(self, other): return NotImplemented
+    def __ror__(self, other): return NotImplemented
+
+    def __ifloordiv__(self, other): return NotImplemented
+    def __imod__(self, other): return NotImplemented
+    def __idivmod__(self, other): return NotImplemented
+    def __ilshift__(self, other): return NotImplemented
+    def __irshift__(self, other): return NotImplemented
+    def __iand__(self, other): return NotImplemented
+    def __ixor__(self, other): return NotImplemented
+    def __ior__(self, other): return NotImplemented
+
+    def __invert__(self): return NotImplemented  # For ~x
+    def __complex__(self): return NotImplemented
+    def __int__(self): return NotImplemented
+    def __long__(self): return NotImplemented
+    def __float__(self): return NotImplemented
+    def __oct__(self): return NotImplemented
+    def __hex__(self): return NotImplemented
+    def __index__(self): return NotImplemented
+    def __coerce__(self): return NotImplemented
+
+    def log(self):
+        return Uncertainty(*err1d.log(self.x, self.variance))
+
+    def exp(self):
+        return Uncertainty(*err1d.exp(self.x, self.variance))
+
+# Module-level helpers: operate on the value passed in (there is no
+# enclosing class here, so referencing self would be a NameError)
+def log(val): return val.log()
+def exp(val): return val.exp()
+
+def test():
+    a = Uncertainty(5, 3)
+    b = Uncertainty(4, 2)
+
+    # Scalar operations
+    z = a+4
+    assert z.x == 5+4 and z.variance == 3
+    z = a-4
+    assert z.x == 5-4 and z.variance == 3
+    z = a*4
+    assert z.x == 5*4 and z.variance == 3*4**2
+    z = a/4
+    assert z.x == 5./4 and z.variance == 3./4**2
+
+    # Reverse scalar operations
+    z = 4+a
+    assert z.x == 4+5 and z.variance == 3
+    z = 4-a
+    assert z.x == 4-5 and z.variance == 3
+    z = 4*a
+    assert z.x == 4*5 and z.variance == 3*4**2
+    z = 4/a
+    assert z.x == 4./5 and abs(z.variance - 3./5**4 * 4**2) < 1e-15
+
+    # Power operations
+    z = a**2
+    assert z.x == 5**2 and z.variance == 4*3*5**2
+    z = a**1
+    assert z.x == 5**1 and z.variance == 3
+    z = a**0
+    assert z.x == 5**0 and z.variance == 0
+    z = a**-1
+    assert
z.x == 5**-1 and abs(z.variance - 3./5**4) < 1e-15 + + # Binary operations + z = a+b + assert z.x == 5+4 and z.variance == 3+2 + z = a-b + assert z.x == 5-4 and z.variance == 3+2 + z = a*b + assert z.x == 5*4 and z.variance == (5**2*2 + 4**2*3) + z = a/b + assert z.x == 5./4 and abs(z.variance - (3./5**2 + 2./4**2)*(5./4)**2) < 1e-15 + + # ===== Inplace operations ===== + # Scalar operations + y = a+0; y += 4 + z = a+4 + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + y = a+0; y -= 4 + z = a-4 + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + y = a+0; y *= 4 + z = a*4 + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + y = a+0; y /= 4 + z = a/4 + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + + # Power operations + y = a+0; y **= 4 + z = a**4 + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + + # Binary operations + y = a+0; y += b + z = a+b + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + y = a+0; y -= b + z = a-b + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + y = a+0; y *= b + z = a*b + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + y = a+0; y /= b + z = a/b + assert y.x == z.x and abs(y.variance-z.variance) < 1e-15 + + + # =============== vector operations ================ + # Slicing + z = Uncertainty(np.array([1,2,3,4,5]),np.array([2,1,2,3,2])) + assert z[2].x == 3 and z[2].variance == 2 + assert (z[2:4].x == [3,4]).all() + assert (z[2:4].variance == [2,3]).all() + z[2:4] = Uncertainty(np.array([8,7]),np.array([4,5])) + assert z[2].x == 8 and z[2].variance == 4 + A = Uncertainty(np.array([a.x]*2),np.array([a.variance]*2)) + B = Uncertainty(np.array([b.x]*2),np.array([b.variance]*2)) + + # TODO complete tests of copy and inplace operations for vectors and slices. + + # Binary operations + z = A+B + assert (z.x == 5+4).all() and (z.variance == 3+2).all() + z = A-B + assert (z.x == 5-4).all() and (z.variance == 3+2).all() + z = A*B + assert (z.x == 5*4).all() and (z.variance == (5**2*2 + 4**2*3)).all() + z = A/B + assert (z.x == 5./4).all() + assert (abs(z.variance - (3./5**2 + 2./4**2)*(5./4)**2) < 1e-15).all() + + # printing; note that sqrt(3) ~ 1.7 + assert str(Uncertainty(5,3)) == "5.0(17)" + assert str(Uncertainty(15,3)) == "15.0(17)" + assert str(Uncertainty(151.23356,0.324185**2)) == "151.23(32)" + +if __name__ == "__main__": test() diff --git a/sas/sascalc/dataloader/__init__.py b/sas/sascalc/dataloader/__init__.py new file mode 100755 index 000000000..6f797d39e --- /dev/null +++ b/sas/sascalc/dataloader/__init__.py @@ -0,0 +1,3 @@ +from .data_info import * +from .manipulations import * +from .readers import * diff --git a/sas/sascalc/dataloader/data_info.py b/sas/sascalc/dataloader/data_info.py new file mode 100755 index 000000000..e74596fec --- /dev/null +++ b/sas/sascalc/dataloader/data_info.py @@ -0,0 +1,1219 @@ +""" + Module that contains classes to hold information read from + reduced data files. + + A good description of the data members can be found in + the CanSAS 1D XML data format: + + http://www.smallangles.net/wgwiki/index.php/cansas1d_documentation +""" +##################################################################### +#This software was developed by the University of Tennessee as part of the +#Distributed Data Analysis of Neutron Scattering Experiments (DANSE) +#project funded by the US National Science Foundation. 
+#See the license text in license.txt +#copyright 2008, University of Tennessee +###################################################################### + +from __future__ import print_function + +#TODO: Keep track of data manipulation in the 'process' data structure. +#TODO: This module should be independent of plottables. We should write +# an adapter class for plottables when needed. + +#from sas.guitools.plottables import Data1D as plottable_1D +from sas.sascalc.data_util.uncertainty import Uncertainty +import numpy as np +import math + +class plottable_1D(object): + """ + Data1D is a place holder for 1D plottables. + """ + # The presence of these should be mutually + # exclusive with the presence of Qdev (dx) + x = None + y = None + dx = None + dy = None + ## Slit smearing length + dxl = None + ## Slit smearing width + dxw = None + ## SESANS specific params (wavelengths for spin echo length calculation) + lam = None + dlam = None + + # Units + _xaxis = '' + _xunit = '' + _yaxis = '' + _yunit = '' + + def __init__(self, x, y, dx=None, dy=None, dxl=None, dxw=None, lam=None, dlam=None): + self.x = np.asarray(x) + self.y = np.asarray(y) + if dx is not None: + self.dx = np.asarray(dx) + if dy is not None: + self.dy = np.asarray(dy) + if dxl is not None: + self.dxl = np.asarray(dxl) + if dxw is not None: + self.dxw = np.asarray(dxw) + if lam is not None: + self.lam = np.asarray(lam) + if dlam is not None: + self.dlam = np.asarray(dlam) + + def xaxis(self, label, unit): + """ + set the x axis label and unit + """ + self._xaxis = label + self._xunit = unit + + def yaxis(self, label, unit): + """ + set the y axis label and unit + """ + self._yaxis = label + self._yunit = unit + + +class plottable_2D(object): + """ + Data2D is a place holder for 2D plottables. + """ + xmin = None + xmax = None + ymin = None + ymax = None + data = None + qx_data = None + qy_data = None + q_data = None + err_data = None + dqx_data = None + dqy_data = None + mask = None + + # Units + _xaxis = '' + _xunit = '' + _yaxis = '' + _yunit = '' + _zaxis = '' + _zunit = '' + + def __init__(self, data=None, err_data=None, qx_data=None, + qy_data=None, q_data=None, mask=None, + dqx_data=None, dqy_data=None): + self.data = np.asarray(data) + self.qx_data = np.asarray(qx_data) + self.qy_data = np.asarray(qy_data) + self.q_data = np.asarray(q_data) + self.mask = np.asarray(mask) + self.err_data = np.asarray(err_data) + if dqx_data is not None: + self.dqx_data = np.asarray(dqx_data) + if dqy_data is not None: + self.dqy_data = np.asarray(dqy_data) + + def xaxis(self, label, unit): + """ + set the x axis label and unit + """ + self._xaxis = label + self._xunit = unit + + def yaxis(self, label, unit): + """ + set the y axis label and unit + """ + self._yaxis = label + self._yunit = unit + + def zaxis(self, label, unit): + """ + set the z axis label and unit + """ + self._zaxis = label + self._zunit = unit + + +class Vector(object): + """ + Vector class to hold multi-dimensional objects + """ + ## x component + x = None + ## y component + y = None + ## z component + z = None + + def __init__(self, x=None, y=None, z=None): + """ + Initialization. Components that are not + set a set to None by default. 
+ + :param x: x component + :param y: y component + :param z: z component + """ + self.x = x + self.y = y + self.z = z + + def __str__(self): + msg = "x = %s\ty = %s\tz = %s" % (str(self.x), str(self.y), str(self.z)) + return msg + + +class Detector(object): + """ + Class to hold detector information + """ + ## Name of the instrument [string] + name = None + ## Sample to detector distance [float] [mm] + distance = None + distance_unit = 'mm' + ## Offset of this detector position in X, Y, + #(and Z if necessary) [Vector] [mm] + offset = None + offset_unit = 'm' + ## Orientation (rotation) of this detector in roll, + # pitch, and yaw [Vector] [degrees] + orientation = None + orientation_unit = 'degree' + ## Center of the beam on the detector in X and Y + #(and Z if necessary) [Vector] [mm] + beam_center = None + beam_center_unit = 'mm' + ## Pixel size in X, Y, (and Z if necessary) [Vector] [mm] + pixel_size = None + pixel_size_unit = 'mm' + ## Slit length of the instrument for this detector.[float] [mm] + slit_length = None + slit_length_unit = 'mm' + + def __init__(self): + """ + Initialize class attribute that are objects... + """ + self.offset = Vector() + self.orientation = Vector() + self.beam_center = Vector() + self.pixel_size = Vector() + + def __str__(self): + _str = "Detector:\n" + _str += " Name: %s\n" % self.name + _str += " Distance: %s [%s]\n" % \ + (str(self.distance), str(self.distance_unit)) + _str += " Offset: %s [%s]\n" % \ + (str(self.offset), str(self.offset_unit)) + _str += " Orientation: %s [%s]\n" % \ + (str(self.orientation), str(self.orientation_unit)) + _str += " Beam center: %s [%s]\n" % \ + (str(self.beam_center), str(self.beam_center_unit)) + _str += " Pixel size: %s [%s]\n" % \ + (str(self.pixel_size), str(self.pixel_size_unit)) + _str += " Slit length: %s [%s]\n" % \ + (str(self.slit_length), str(self.slit_length_unit)) + return _str + + +class Aperture(object): + ## Name + name = None + ## Type + type = None + ## Size name + size_name = None + ## Aperture size [Vector] + size = None + size_unit = 'mm' + ## Aperture distance [float] + distance = None + distance_unit = 'mm' + + def __init__(self): + self.size = Vector() + + +class Collimation(object): + """ + Class to hold collimation information + """ + ## Name + name = None + ## Length [float] [mm] + length = None + length_unit = 'mm' + ## Aperture + aperture = None + + def __init__(self): + self.aperture = [] + + def __str__(self): + _str = "Collimation:\n" + _str += " Length: %s [%s]\n" % \ + (str(self.length), str(self.length_unit)) + for item in self.aperture: + _str += " Aperture size:%s [%s]\n" % \ + (str(item.size), str(item.size_unit)) + _str += " Aperture_dist:%s [%s]\n" % \ + (str(item.distance), str(item.distance_unit)) + return _str + + +class Source(object): + """ + Class to hold source information + """ + ## Name + name = None + ## Radiation type [string] + radiation = None + ## Beam size name + beam_size_name = None + ## Beam size [Vector] [mm] + beam_size = None + beam_size_unit = 'mm' + ## Beam shape [string] + beam_shape = None + ## Wavelength [float] [Angstrom] + wavelength = None + wavelength_unit = 'A' + ## Minimum wavelength [float] [Angstrom] + wavelength_min = None + wavelength_min_unit = 'nm' + ## Maximum wavelength [float] [Angstrom] + wavelength_max = None + wavelength_max_unit = 'nm' + ## Wavelength spread [float] [Angstrom] + wavelength_spread = None + wavelength_spread_unit = 'percent' + + def __init__(self): + self.beam_size = Vector() + + def __str__(self): + _str = "Source:\n" 
+ _str += " Radiation: %s\n" % str(self.radiation) + _str += " Shape: %s\n" % str(self.beam_shape) + _str += " Wavelength: %s [%s]\n" % \ + (str(self.wavelength), str(self.wavelength_unit)) + _str += " Waveln_min: %s [%s]\n" % \ + (str(self.wavelength_min), str(self.wavelength_min_unit)) + _str += " Waveln_max: %s [%s]\n" % \ + (str(self.wavelength_max), str(self.wavelength_max_unit)) + _str += " Waveln_spread:%s [%s]\n" % \ + (str(self.wavelength_spread), str(self.wavelength_spread_unit)) + _str += " Beam_size: %s [%s]\n" % \ + (str(self.beam_size), str(self.beam_size_unit)) + return _str + + +""" +Definitions of radiation types +""" +NEUTRON = 'neutron' +XRAY = 'x-ray' +MUON = 'muon' +ELECTRON = 'electron' + + +class Sample(object): + """ + Class to hold the sample description + """ + ## Short name for sample + name = '' + ## ID + ID = '' + ## Thickness [float] [mm] + thickness = None + thickness_unit = 'mm' + ## Transmission [float] [fraction] + transmission = None + ## Temperature [float] [No Default] + temperature = None + temperature_unit = None + ## Position [Vector] [mm] + position = None + position_unit = 'mm' + ## Orientation [Vector] [degrees] + orientation = None + orientation_unit = 'degree' + ## Details + details = None + ## SESANS zacceptance + zacceptance = (0,"") + yacceptance = (0,"") + + def __init__(self): + self.position = Vector() + self.orientation = Vector() + self.details = [] + + def __str__(self): + _str = "Sample:\n" + _str += " ID: %s\n" % str(self.ID) + _str += " Transmission: %s\n" % str(self.transmission) + _str += " Thickness: %s [%s]\n" % \ + (str(self.thickness), str(self.thickness_unit)) + _str += " Temperature: %s [%s]\n" % \ + (str(self.temperature), str(self.temperature_unit)) + _str += " Position: %s [%s]\n" % \ + (str(self.position), str(self.position_unit)) + _str += " Orientation: %s [%s]\n" % \ + (str(self.orientation), str(self.orientation_unit)) + + _str += " Details:\n" + for item in self.details: + _str += " %s\n" % item + + return _str + + +class Process(object): + """ + Class that holds information about the processes + performed on the data. + """ + name = '' + date = '' + description = '' + term = None + notes = None + + def __init__(self): + self.term = [] + self.notes = [] + + def is_empty(self): + """ + Return True if the object is empty + """ + return len(self.name) == 0 and len(self.date) == 0 and len(self.description) == 0 \ + and len(self.term) == 0 and len(self.notes) == 0 + + def single_line_desc(self): + """ + Return a single line string representing the process + """ + return "%s %s %s" % (self.name, self.date, self.description) + + def __str__(self): + _str = "Process:\n" + _str += " Name: %s\n" % self.name + _str += " Date: %s\n" % self.date + _str += " Description: %s\n" % self.description + for item in self.term: + _str += " Term: %s\n" % item + for item in self.notes: + _str += " Note: %s\n" % item + return _str + + +class TransmissionSpectrum(object): + """ + Class that holds information about transmission spectrum + for white beams and spallation sources. 
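+
+    The wavelength, transmission and transmission_deviation attributes are
+    parallel lists: entry *i* of each describes one point of the spectrum.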
+ """ + name = '' + timestamp = '' + ## Wavelength (float) [A] + wavelength = None + wavelength_unit = 'A' + ## Transmission (float) [unit less] + transmission = None + transmission_unit = '' + ## Transmission Deviation (float) [unit less] + transmission_deviation = None + transmission_deviation_unit = '' + + def __init__(self): + self.wavelength = [] + self.transmission = [] + self.transmission_deviation = [] + + def __str__(self): + _str = "Transmission Spectrum:\n" + _str += " Name: \t{0}\n".format(self.name) + _str += " Timestamp: \t{0}\n".format(self.timestamp) + _str += " Wavelength unit: \t{0}\n".format(self.wavelength_unit) + _str += " Transmission unit:\t{0}\n".format(self.transmission_unit) + _str += " Trans. Dev. unit: \t{0}\n".format(\ + self.transmission_deviation_unit) + length_list = [len(self.wavelength), len(self.transmission), \ + len(self.transmission_deviation)] + _str += " Number of Pts: \t{0}\n".format(max(length_list)) + return _str + + +class DataInfo(object): + """ + Class to hold the data read from a file. + It includes four blocks of data for the + instrument description, the sample description, + the data itself and any other meta data. + """ + ## Title + title = '' + ## Run number + run = None + ## Run name + run_name = None + ## File name + filename = '' + ## Notes + notes = None + ## Processes (Action on the data) + process = None + ## Instrument name + instrument = '' + ## Detector information + detector = None + ## Sample information + sample = None + ## Source information + source = None + ## Collimation information + collimation = None + ## Transmission Spectrum INfo + trans_spectrum = None + ## Additional meta-data + meta_data = None + ## Loading errors + errors = None + ## SESANS data check + isSesans = None + + + def __init__(self): + """ + Initialization + """ + ## Title + self.title = '' + ## Run number + self.run = [] + self.run_name = {} + ## File name + self.filename = '' + ## Notes + self.notes = [] + ## Processes (Action on the data) + self.process = [] + ## Instrument name + self.instrument = '' + ## Detector information + self.detector = [] + ## Sample information + self.sample = Sample() + ## Source information + self.source = Source() + ## Collimation information + self.collimation = [] + ## Transmission Spectrum + self.trans_spectrum = [] + ## Additional meta-data + self.meta_data = {} + ## Loading errors + self.errors = [] + ## SESANS data check + self.isSesans = False + + def append_empty_process(self): + """ + """ + self.process.append(Process()) + + def add_notes(self, message=""): + """ + Add notes to datainfo + """ + self.notes.append(message) + + def __str__(self): + """ + Nice printout + """ + _str = "File: %s\n" % self.filename + _str += "Title: %s\n" % self.title + _str += "Run: %s\n" % str(self.run) + _str += "SESANS: %s\n" % str(self.isSesans) + _str += "Instrument: %s\n" % str(self.instrument) + _str += "%s\n" % str(self.sample) + _str += "%s\n" % str(self.source) + for item in self.detector: + _str += "%s\n" % str(item) + for item in self.collimation: + _str += "%s\n" % str(item) + for item in self.process: + _str += "%s\n" % str(item) + for item in self.notes: + _str += "%s\n" % str(item) + for item in self.trans_spectrum: + _str += "%s\n" % str(item) + return _str + + # Private method to perform operation. 
Not implemented for DataInfo, + # but should be implemented for each data class inherited from DataInfo + # that holds actual data (ex.: Data1D) + def _perform_operation(self, other, operation): + """ + Private method to perform operation. Not implemented for DataInfo, + but should be implemented for each data class inherited from DataInfo + that holds actual data (ex.: Data1D) + """ + return NotImplemented + + def _perform_union(self, other): + """ + Private method to perform union operation. Not implemented for DataInfo, + but should be implemented for each data class inherited from DataInfo + that holds actual data (ex.: Data1D) + """ + return NotImplemented + + def __add__(self, other): + """ + Add two data sets + + :param other: data set to add to the current one + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + def operation(a, b): + return a + b + return self._perform_operation(other, operation) + + def __radd__(self, other): + """ + Add two data sets + + :param other: data set to add to the current one + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + def operation(a, b): + return b + a + return self._perform_operation(other, operation) + + def __sub__(self, other): + """ + Subtract two data sets + + :param other: data set to subtract from the current one + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + def operation(a, b): + return a - b + return self._perform_operation(other, operation) + + def __rsub__(self, other): + """ + Subtract two data sets + + :param other: data set to subtract from the current one + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + def operation(a, b): + return b - a + return self._perform_operation(other, operation) + + def __mul__(self, other): + """ + Multiply two data sets + + :param other: data set to subtract from the current one + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + def operation(a, b): + return a * b + return self._perform_operation(other, operation) + + def __rmul__(self, other): + """ + Multiply two data sets + + :param other: data set to subtract from the current one + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + def operation(a, b): + return b * a + return self._perform_operation(other, operation) + + def __div__(self, other): + """ + Divided a data set by another + + :param other: data set that the current one is divided by + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + def operation(a, b): + return a/b + return self._perform_operation(other, operation) + + def __rdiv__(self, other): + """ + Divided a data set by another + + :param other: data set that the current one is divided by + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + def operation(a, b): + return b/a + return self._perform_operation(other, operation) + + def __or__(self, other): + """ + Union a data set with another + + :param other: data set to be unified + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + return self._perform_union(other) + + def __ror__(self, other): + """ + Union a data set with another + + :param other: data set to be unified + :return: new data set + :raise ValueError: raised when two data sets are incompatible + """ + return 
self._perform_union(other) + +class Data1D(plottable_1D, DataInfo): + """ + 1D data class + """ + def __init__(self, x=None, y=None, dx=None, dy=None, lam=None, dlam=None, isSesans=None): + DataInfo.__init__(self) + plottable_1D.__init__(self, x, y, dx, dy,None, None, lam, dlam) + self.isSesans = isSesans + try: + if self.isSesans: # the data is SESANS + self.x_unit = 'A' + self.y_unit = 'pol' + elif not self.isSesans: # the data is SANS + self.x_unit = '1/A' + self.y_unit = '1/cm' + except: # the data is not recognized/supported, and the user is notified + raise TypeError('data not recognized, check documentation for supported 1D data formats') + + def __str__(self): + """ + Nice printout + """ + _str = "%s\n" % DataInfo.__str__(self) + _str += "Data:\n" + _str += " Type: %s\n" % self.__class__.__name__ + _str += " X-axis: %s\t[%s]\n" % (self._xaxis, self._xunit) + _str += " Y-axis: %s\t[%s]\n" % (self._yaxis, self._yunit) + _str += " Length: %g\n" % len(self.x) + return _str + + def is_slit_smeared(self): + """ + Check whether the data has slit smearing information + :return: True is slit smearing info is present, False otherwise + """ + def _check(v): + if (v.__class__ == list or v.__class__ == np.ndarray) \ + and len(v) > 0 and min(v) > 0: + return True + return False + return _check(self.dxl) or _check(self.dxw) + + def clone_without_data(self, length=0, clone=None): + """ + Clone the current object, without copying the data (which + will be filled out by a subsequent operation). + The data arrays will be initialized to zero. + + :param length: length of the data array to be initialized + :param clone: if provided, the data will be copied to clone + """ + from copy import deepcopy + + if clone is None or not issubclass(clone.__class__, Data1D): + x = np.zeros(length) + dx = np.zeros(length) + y = np.zeros(length) + dy = np.zeros(length) + lam = np.zeros(length) + dlam = np.zeros(length) + clone = Data1D(x, y, lam=lam, dx=dx, dy=dy, dlam=dlam) + + clone.title = self.title + clone.run = self.run + clone.filename = self.filename + clone.instrument = self.instrument + clone.notes = deepcopy(self.notes) + clone.process = deepcopy(self.process) + clone.detector = deepcopy(self.detector) + clone.sample = deepcopy(self.sample) + clone.source = deepcopy(self.source) + clone.collimation = deepcopy(self.collimation) + clone.trans_spectrum = deepcopy(self.trans_spectrum) + clone.meta_data = deepcopy(self.meta_data) + clone.errors = deepcopy(self.errors) + + return clone + + def _validity_check(self, other): + """ + Checks that the data lengths are compatible. + Checks that the x vectors are compatible. + Returns errors vectors equal to original + errors vectors if they were present or vectors + of zeros when none was found. 
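+
+        x-values are compared point by point with a relative tolerance of
+        0.01 (1%); grids that differ by more than this raise ValueError.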
+ + :param other: other data set for operation + :return: dy for self, dy for other [numpy arrays] + :raise ValueError: when lengths are not compatible + """ + dy_other = None + if isinstance(other, Data1D): + # Check that data lengths are the same + if len(self.x) != len(other.x) or \ + len(self.y) != len(other.y): + msg = "Unable to perform operation: data length are not equal" + raise ValueError(msg) + # Here we could also extrapolate between data points + TOLERANCE = 0.01 + for i in range(len(self.x)): + if math.fabs((self.x[i] - other.x[i])/self.x[i]) > TOLERANCE: + msg = "Incompatible data sets: x-values do not match" + raise ValueError(msg) + + # Check that the other data set has errors, otherwise + # create zero vector + dy_other = other.dy + if other.dy is None or (len(other.dy) != len(other.y)): + dy_other = np.zeros(len(other.y)) + + # Check that we have errors, otherwise create zero vector + dy = self.dy + if self.dy is None or (len(self.dy) != len(self.y)): + dy = np.zeros(len(self.y)) + + return dy, dy_other + + def _perform_operation(self, other, operation): + """ + """ + # First, check the data compatibility + dy, dy_other = self._validity_check(other) + result = self.clone_without_data(len(self.x)) + if self.dxw is None: + result.dxw = None + else: + result.dxw = np.zeros(len(self.x)) + if self.dxl is None: + result.dxl = None + else: + result.dxl = np.zeros(len(self.x)) + + for i in range(len(self.x)): + result.x[i] = self.x[i] + if self.dx is not None and len(self.x) == len(self.dx): + result.dx[i] = self.dx[i] + if self.dxw is not None and len(self.x) == len(self.dxw): + result.dxw[i] = self.dxw[i] + if self.dxl is not None and len(self.x) == len(self.dxl): + result.dxl[i] = self.dxl[i] + + a = Uncertainty(self.y[i], dy[i]**2) + if isinstance(other, Data1D): + b = Uncertainty(other.y[i], dy_other[i]**2) + if other.dx is not None: + result.dx[i] *= self.dx[i] + result.dx[i] += (other.dx[i]**2) + result.dx[i] /= 2 + result.dx[i] = math.sqrt(result.dx[i]) + if result.dxl is not None and other.dxl is not None: + result.dxl[i] *= self.dxl[i] + result.dxl[i] += (other.dxl[i]**2) + result.dxl[i] /= 2 + result.dxl[i] = math.sqrt(result.dxl[i]) + else: + b = other + + output = operation(a, b) + result.y[i] = output.x + result.dy[i] = math.sqrt(math.fabs(output.variance)) + return result + + def _validity_check_union(self, other): + """ + Checks that the data lengths are compatible. + Checks that the x vectors are compatible. + Returns errors vectors equal to original + errors vectors if they were present or vectors + of zeros when none was found. 
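+
+        Only the data type is checked here; lengths and x-grids may differ,
+        since _perform_union concatenates the two sets and re-sorts in x.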
+ + :param other: other data set for operation + :return: bool + :raise ValueError: when data types are not compatible + """ + if not isinstance(other, Data1D): + msg = "Unable to perform operation: different types of data set" + raise ValueError(msg) + return True + + def _perform_union(self, other): + """ + """ + # First, check the data compatibility + self._validity_check_union(other) + result = self.clone_without_data(len(self.x) + len(other.x)) + if self.dy is None or other.dy is None: + result.dy = None + else: + result.dy = np.zeros(len(self.x) + len(other.x)) + if self.dx is None or other.dx is None: + result.dx = None + else: + result.dx = np.zeros(len(self.x) + len(other.x)) + if self.dxw is None or other.dxw is None: + result.dxw = None + else: + result.dxw = np.zeros(len(self.x) + len(other.x)) + if self.dxl is None or other.dxl is None: + result.dxl = None + else: + result.dxl = np.zeros(len(self.x) + len(other.x)) + + result.x = np.append(self.x, other.x) + #argsorting + ind = np.argsort(result.x) + result.x = result.x[ind] + result.y = np.append(self.y, other.y) + result.y = result.y[ind] + if result.dy is not None: + result.dy = np.append(self.dy, other.dy) + result.dy = result.dy[ind] + if result.dx is not None: + result.dx = np.append(self.dx, other.dx) + result.dx = result.dx[ind] + if result.dxw is not None: + result.dxw = np.append(self.dxw, other.dxw) + result.dxw = result.dxw[ind] + if result.dxl is not None: + result.dxl = np.append(self.dxl, other.dxl) + result.dxl = result.dxl[ind] + return result + + +class Data2D(plottable_2D, DataInfo): + """ + 2D data class + """ + ## Units for Q-values + Q_unit = '1/A' + ## Units for I(Q) values + I_unit = '1/cm' + ## Vector of Q-values at the center of each bin in x + x_bins = None + ## Vector of Q-values at the center of each bin in y + y_bins = None + ## No 2D SESANS data as of yet. Always set it to False + isSesans = False + + def __init__(self, data=None, err_data=None, qx_data=None, + qy_data=None, q_data=None, mask=None, + dqx_data=None, dqy_data=None): + DataInfo.__init__(self) + plottable_2D.__init__(self, data, err_data, qx_data, + qy_data, q_data, mask, dqx_data, dqy_data) + self.y_bins = [] + self.x_bins = [] + + if len(self.detector) > 0: + raise RuntimeError("Data2D: Detector bank already filled at init") + + def __str__(self): + _str = "%s\n" % DataInfo.__str__(self) + _str += "Data:\n" + _str += " Type: %s\n" % self.__class__.__name__ + _str += " X-axis: %s\t[%s]\n" % (self._xaxis, self._xunit) + _str += " Y-axis: %s\t[%s]\n" % (self._yaxis, self._yunit) + _str += " Z-axis: %s\t[%s]\n" % (self._zaxis, self._zunit) + _str += " Length: %g \n" % (len(self.data)) + _str += " Shape: (%d, %d)\n" % (len(self.y_bins), len(self.x_bins)) + return _str + + def clone_without_data(self, length=0, clone=None): + """ + Clone the current object, without copying the data (which + will be filled out by a subsequent operation). + The data arrays will be initialized to zero. 
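+
+        Resolution arrays (dqx_data, dqy_data) are left unset here; the
+        operation filling the clone allocates them when the source data
+        carries resolution information.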
+ + :param length: length of the data array to be initialized + :param clone: if provided, the data will be copied to clone + """ + from copy import deepcopy + + if clone is None or not issubclass(clone.__class__, Data2D): + data = np.zeros(length) + err_data = np.zeros(length) + qx_data = np.zeros(length) + qy_data = np.zeros(length) + q_data = np.zeros(length) + mask = np.zeros(length) + dqx_data = None + dqy_data = None + clone = Data2D(data=data, err_data=err_data, + qx_data=qx_data, qy_data=qy_data, + q_data=q_data, mask=mask) + + clone._xaxis = self._xaxis + clone._yaxis = self._yaxis + clone._zaxis = self._zaxis + clone._xunit = self._xunit + clone._yunit = self._yunit + clone._zunit = self._zunit + clone.x_bins = self.x_bins + clone.y_bins = self.y_bins + + clone.title = self.title + clone.run = self.run + clone.filename = self.filename + clone.instrument = self.instrument + clone.notes = deepcopy(self.notes) + clone.process = deepcopy(self.process) + clone.detector = deepcopy(self.detector) + clone.sample = deepcopy(self.sample) + clone.source = deepcopy(self.source) + clone.collimation = deepcopy(self.collimation) + clone.trans_spectrum = deepcopy(self.trans_spectrum) + clone.meta_data = deepcopy(self.meta_data) + clone.errors = deepcopy(self.errors) + + return clone + + def _validity_check(self, other): + """ + Checks that the data lengths are compatible. + Checks that the x vectors are compatible. + Returns errors vectors equal to original + errors vectors if they were present or vectors + of zeros when none was found. + + :param other: other data set for operation + :return: dy for self, dy for other [numpy arrays] + :raise ValueError: when lengths are not compatible + """ + err_other = None + TOLERANCE = 0.01 + if isinstance(other, Data2D): + # Check that data lengths are the same + if len(self.data) != len(other.data) or \ + len(self.qx_data) != len(other.qx_data) or \ + len(self.qy_data) != len(other.qy_data): + msg = "Unable to perform operation: data length are not equal" + raise ValueError(msg) + for ind in range(len(self.data)): + if math.fabs((self.qx_data[ind] - other.qx_data[ind])/self.qx_data[ind]) > TOLERANCE: + msg = "Incompatible data sets: qx-values do not match: %s %s" % (self.qx_data[ind], other.qx_data[ind]) + raise ValueError(msg) + if math.fabs((self.qy_data[ind] - other.qy_data[ind])/self.qy_data[ind]) > TOLERANCE: + msg = "Incompatible data sets: qy-values do not match: %s %s" % (self.qy_data[ind], other.qy_data[ind]) + raise ValueError(msg) + + # Check that the scales match + err_other = other.err_data + if other.err_data is None or \ + (len(other.err_data) != len(other.data)): + err_other = np.zeros(len(other.data)) + + # Check that we have errors, otherwise create zero vector + err = self.err_data + if self.err_data is None or \ + (len(self.err_data) != len(self.data)): + err = np.zeros(len(other.data)) + return err, err_other + + def _perform_operation(self, other, operation): + """ + Perform 2D operations between data sets + + :param other: other data set + :param operation: function defining the operation + """ + # First, check the data compatibility + dy, dy_other = self._validity_check(other) + result = self.clone_without_data(np.size(self.data)) + if self.dqx_data is None or self.dqy_data is None: + result.dqx_data = None + result.dqy_data = None + else: + result.dqx_data = np.zeros(len(self.data)) + result.dqy_data = np.zeros(len(self.data)) + for i in range(np.size(self.data)): + result.data[i] = self.data[i] + if self.err_data is not None and 
\ + np.size(self.data) == np.size(self.err_data): + result.err_data[i] = self.err_data[i] + if self.dqx_data is not None: + result.dqx_data[i] = self.dqx_data[i] + if self.dqy_data is not None: + result.dqy_data[i] = self.dqy_data[i] + result.qx_data[i] = self.qx_data[i] + result.qy_data[i] = self.qy_data[i] + result.q_data[i] = self.q_data[i] + result.mask[i] = self.mask[i] + + a = Uncertainty(self.data[i], dy[i]**2) + if isinstance(other, Data2D): + b = Uncertainty(other.data[i], dy_other[i]**2) + if other.dqx_data is not None and \ + result.dqx_data is not None: + result.dqx_data[i] *= self.dqx_data[i] + result.dqx_data[i] += (other.dqx_data[i]**2) + result.dqx_data[i] /= 2 + result.dqx_data[i] = math.sqrt(result.dqx_data[i]) + if other.dqy_data is not None and \ + result.dqy_data is not None: + result.dqy_data[i] *= self.dqy_data[i] + result.dqy_data[i] += (other.dqy_data[i]**2) + result.dqy_data[i] /= 2 + result.dqy_data[i] = math.sqrt(result.dqy_data[i]) + else: + b = other + output = operation(a, b) + result.data[i] = output.x + result.err_data[i] = math.sqrt(math.fabs(output.variance)) + return result + + def _validity_check_union(self, other): + """ + Checks that the data lengths are compatible. + Checks that the x vectors are compatible. + Returns errors vectors equal to original + errors vectors if they were present or vectors + of zeros when none was found. + + :param other: other data set for operation + :return: bool + :raise ValueError: when data types are not compatible + """ + if not isinstance(other, Data2D): + msg = "Unable to perform operation: different types of data set" + raise ValueError(msg) + return True + + def _perform_union(self, other): + """ + Perform 2D operations between data sets + + :param other: other data set + :param operation: function defining the operation + """ + # First, check the data compatibility + self._validity_check_union(other) + result = self.clone_without_data(np.size(self.data) + \ + np.size(other.data)) + result.xmin = self.xmin + result.xmax = self.xmax + result.ymin = self.ymin + result.ymax = self.ymax + if self.dqx_data is None or self.dqy_data is None or \ + other.dqx_data is None or other.dqy_data is None: + result.dqx_data = None + result.dqy_data = None + else: + result.dqx_data = np.zeros(len(self.data) + \ + np.size(other.data)) + result.dqy_data = np.zeros(len(self.data) + \ + np.size(other.data)) + + result.data = np.append(self.data, other.data) + result.qx_data = np.append(self.qx_data, other.qx_data) + result.qy_data = np.append(self.qy_data, other.qy_data) + result.q_data = np.append(self.q_data, other.q_data) + result.mask = np.append(self.mask, other.mask) + if result.err_data is not None: + result.err_data = np.append(self.err_data, other.err_data) + if self.dqx_data is not None: + result.dqx_data = np.append(self.dqx_data, other.dqx_data) + if self.dqy_data is not None: + result.dqy_data = np.append(self.dqy_data, other.dqy_data) + + return result + + +def combine_data_info_with_plottable(data, datainfo): + """ + A function that combines the DataInfo data in self.current_datainto with a + plottable_1D or 2D data object. 
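+
+    A sketch of the intended call pattern (variable names are illustrative)::
+
+        raw = plottable_1D(x, y, dy=dy)   # arrays parsed from a file
+        info = DataInfo()                 # metadata collected separately
+        dataset = combine_data_info_with_plottable(raw, info)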
+ + :param data: A plottable_1D or plottable_2D data object + :return: A fully specified Data1D or Data2D object + """ + + final_dataset = None + if isinstance(data, plottable_1D): + final_dataset = Data1D(data.x, data.y, isSesans=datainfo.isSesans) + final_dataset.dx = data.dx + final_dataset.dy = data.dy + final_dataset.dxl = data.dxl + final_dataset.dxw = data.dxw + final_dataset.x_unit = data._xunit + final_dataset.y_unit = data._yunit + final_dataset.xaxis(data._xaxis, data._xunit) + final_dataset.yaxis(data._yaxis, data._yunit) + elif isinstance(data, plottable_2D): + final_dataset = Data2D(data.data, data.err_data, data.qx_data, + data.qy_data, data.q_data, data.mask, + data.dqx_data, data.dqy_data) + final_dataset.xaxis(data._xaxis, data._xunit) + final_dataset.yaxis(data._yaxis, data._yunit) + final_dataset.zaxis(data._zaxis, data._zunit) + else: + return_string = ("Should Never Happen: _combine_data_info_with_plottabl" + "e input is not a plottable1d or plottable2d data " + "object") + return return_string + + if hasattr(data, "xmax"): + final_dataset.xmax = data.xmax + if hasattr(data, "ymax"): + final_dataset.ymax = data.ymax + if hasattr(data, "xmin"): + final_dataset.xmin = data.xmin + if hasattr(data, "ymin"): + final_dataset.ymin = data.ymin + final_dataset.isSesans = datainfo.isSesans + final_dataset.title = datainfo.title + final_dataset.run = datainfo.run + final_dataset.run_name = datainfo.run_name + final_dataset.filename = datainfo.filename + final_dataset.notes = datainfo.notes + final_dataset.process = datainfo.process + final_dataset.instrument = datainfo.instrument + final_dataset.detector = datainfo.detector + final_dataset.sample = datainfo.sample + final_dataset.source = datainfo.source + final_dataset.collimation = datainfo.collimation + final_dataset.trans_spectrum = datainfo.trans_spectrum + final_dataset.meta_data = datainfo.meta_data + final_dataset.errors = datainfo.errors + return final_dataset diff --git a/sas/sascalc/dataloader/file_reader_base_class.py b/sas/sascalc/dataloader/file_reader_base_class.py new file mode 100755 index 000000000..700f7ff78 --- /dev/null +++ b/sas/sascalc/dataloader/file_reader_base_class.py @@ -0,0 +1,500 @@ +""" +This is the base file reader class most file readers should inherit from. +All generic functionality required for a file loader/reader is built into this +class +""" + +import os +import sys +import math +import logging +from abc import abstractmethod + +import numpy as np +from .loader_exceptions import NoKnownLoaderException, FileContentsException,\ + DataReaderException, DefaultReaderException +from .data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D,\ + combine_data_info_with_plottable +from sas.sascalc.data_util.nxsunit import Converter + +logger = logging.getLogger(__name__) + +if sys.version_info[0] < 3: + def decode(s): + return s +else: + def decode(s): + return s.decode() if isinstance(s, bytes) else s + +# Data 1D fields for iterative purposes +FIELDS_1D = ('x', 'y', 'dx', 'dy', 'dxl', 'dxw') +# Data 2D fields for iterative purposes +FIELDS_2D = ('data', 'qx_data', 'qy_data', 'q_data', 'err_data', + 'dqx_data', 'dqy_data', 'mask') +DEPRECATION_MESSAGE = ("\rThe extension of this file suggests the data set migh" + "t not be fully reduced. Support for the reader associat" + "ed with this file type has been removed. 
An attempt to " + "load the file was made, but, should it be successful, " + "SasView cannot guarantee the accuracy of the data.") + + +class FileReader(object): + # String to describe the type of data this reader can load + type_name = "ASCII" + # Wildcards to display + type = ["Text files (*.txt|*.TXT)"] + # List of allowed extensions + ext = ['.txt'] + # Deprecated extensions + deprecated_extensions = ['.asc'] + # Bypass extension check and try to load anyway + allow_all = False + # Able to import the unit converter + has_converter = True + # Default value of zero + _ZERO = 1e-16 + + def __init__(self): + # List of Data1D and Data2D objects to be sent back to data_loader + self.output = [] + # Current plottable_(1D/2D) object being loaded in + self.current_dataset = None + # Current DataInfo object being loaded in + self.current_datainfo = None + # File path sent to reader + self.filepath = None + # Open file handle + self.f_open = None + + def read(self, filepath): + """ + Basic file reader + + :param filepath: The full or relative path to a file to be loaded + """ + self.filepath = filepath + if os.path.isfile(filepath): + basename, extension = os.path.splitext(os.path.basename(filepath)) + self.extension = extension.lower() + # If the file type is not allowed, return nothing + if self.extension in self.ext or self.allow_all: + # Try to load the file, but raise an error if unable to. + try: + self.f_open = open(filepath, 'rb') + self.get_file_contents() + + except DataReaderException as e: + self.handle_error_message(e.message) + except OSError as e: + # If the file cannot be opened + msg = "Unable to open file: {}\n".format(filepath) + msg += e.message + self.handle_error_message(msg) + finally: + # Close the file handle if it is open + if not self.f_open.closed: + self.f_open.close() + if any(filepath.lower().endswith(ext) for ext in + self.deprecated_extensions): + self.handle_error_message(DEPRECATION_MESSAGE) + if len(self.output) > 0: + # Sort the data that's been loaded + self.convert_data_units() + self.sort_data() + else: + msg = "Unable to find file at: {}\n".format(filepath) + msg += "Please check your file path and try again." + self.handle_error_message(msg) + + # Return a list of parsed entries that data_loader can manage + final_data = self.output + self.reset_state() + return final_data + + def reset_state(self): + """ + Resets the class state to a base case when loading a new data file so previous + data files do not appear a second time + """ + self.current_datainfo = None + self.current_dataset = None + self.filepath = None + self.ind = None + self.output = [] + + def nextline(self): + """ + Returns the next line in the file as a string. + """ + #return self.f_open.readline() + return decode(self.f_open.readline()) + + def nextlines(self): + """ + Returns the next line in the file as a string. + """ + for line in self.f_open: + #yield line + yield decode(line) + + def readall(self): + """ + Returns the entire file as a string. + """ + return decode(self.f_open.read()) + + def handle_error_message(self, msg): + """ + Generic error handler to add an error to the current datainfo to + propagate the error up the error chain. 
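+
+        The message is appended to the most recent output data set if one
+        exists, otherwise to the current DataInfo; when neither is
+        available the message is logged and a NoKnownLoaderException is
+        raised.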
+ :param msg: Error message + """ + if len(self.output) > 0: + self.output[-1].errors.append(msg) + elif isinstance(self.current_datainfo, DataInfo): + self.current_datainfo.errors.append(msg) + else: + logger.warning(msg) + raise NoKnownLoaderException(msg) + + def send_to_output(self): + """ + Helper that automatically combines the info and set and then appends it + to output + """ + data_obj = combine_data_info_with_plottable(self.current_dataset, + self.current_datainfo) + self.output.append(data_obj) + + def sort_data(self): + """ + Sort 1D data along the X axis for consistency + """ + for data in self.output: + if isinstance(data, Data1D): + # Normalize the units for + data.x_unit = self.format_unit(data.x_unit) + data._xunit = data.x_unit + data.y_unit = self.format_unit(data.y_unit) + data._yunit = data.y_unit + # Sort data by increasing x and remove 1st point + ind = np.lexsort((data.y, data.x)) + data.x = self._reorder_1d_array(data.x, ind) + data.y = self._reorder_1d_array(data.y, ind) + if data.dx is not None: + if len(data.dx) == 0: + data.dx = None + continue + data.dx = self._reorder_1d_array(data.dx, ind) + if data.dxl is not None: + data.dxl = self._reorder_1d_array(data.dxl, ind) + if data.dxw is not None: + data.dxw = self._reorder_1d_array(data.dxw, ind) + if data.dy is not None: + if len(data.dy) == 0: + data.dy = None + continue + data.dy = self._reorder_1d_array(data.dy, ind) + if data.lam is not None: + data.lam = self._reorder_1d_array(data.lam, ind) + if data.dlam is not None: + data.dlam = self._reorder_1d_array(data.dlam, ind) + data = self._remove_nans_in_data(data) + if len(data.x) > 0: + data.xmin = np.min(data.x) + data.xmax = np.max(data.x) + data.ymin = np.min(data.y) + data.ymax = np.max(data.y) + elif isinstance(data, Data2D): + # Normalize the units for + data.Q_unit = self.format_unit(data.Q_unit) + data.I_unit = self.format_unit(data.I_unit) + data._xunit = data.Q_unit + data._yunit = data.Q_unit + data._zunit = data.I_unit + data.data = data.data.astype(np.float64) + data.qx_data = data.qx_data.astype(np.float64) + data.xmin = np.min(data.qx_data) + data.xmax = np.max(data.qx_data) + data.qy_data = data.qy_data.astype(np.float64) + data.ymin = np.min(data.qy_data) + data.ymax = np.max(data.qy_data) + data.q_data = np.sqrt(data.qx_data * data.qx_data + + data.qy_data * data.qy_data) + if data.err_data is not None: + data.err_data = data.err_data.astype(np.float64) + if data.dqx_data is not None: + data.dqx_data = data.dqx_data.astype(np.float64) + if data.dqy_data is not None: + data.dqy_data = data.dqy_data.astype(np.float64) + if data.mask is not None: + data.mask = data.mask.astype(dtype=bool) + + if len(data.data.shape) == 2: + n_rows, n_cols = data.data.shape + data.y_bins = data.qy_data[0::int(n_cols)] + data.x_bins = data.qx_data[:int(n_cols)] + data.data = data.data.flatten() + data = self._remove_nans_in_data(data) + if len(data.data) > 0: + data.xmin = np.min(data.qx_data) + data.xmax = np.max(data.qx_data) + data.ymin = np.min(data.qy_data) + data.ymax = np.max(data.qy_data) + + @staticmethod + def _reorder_1d_array(array, ind): + """ + Reorders a 1D array based on the indices passed as ind + :param array: Array to be reordered + :param ind: Indices used to reorder array + :return: reordered array + """ + array = np.asarray(array, dtype=np.float64) + return array[ind] + + @staticmethod + def _remove_nans_in_data(data): + """ + Remove data points where nan is loaded + :param data: 1D or 2D data object + :return: data with nan points removed 
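+        A point is kept only if every populated field (x, y, dx, ... or the
+        corresponding 2D fields) is finite at that index.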
+ """ + if isinstance(data, Data1D): + fields = FIELDS_1D + elif isinstance(data, Data2D): + fields = FIELDS_2D + else: + return data + # Make array of good points - all others will be removed + good = np.isfinite(getattr(data, fields[0])) + for name in fields[1:]: + array = getattr(data, name) + if array is not None: + # Update good points only if not already changed + good &= np.isfinite(array) + if not np.all(good): + for name in fields: + array = getattr(data, name) + if array is not None: + setattr(data, name, array[good]) + return data + + @staticmethod + def set_default_1d_units(data): + """ + Set the x and y axes to the default 1D units + :param data: 1D data set + :return: + """ + data.xaxis(r"\rm{Q}", '1/A') + data.yaxis(r"\rm{Intensity}", "1/cm") + return data + + @staticmethod + def set_default_2d_units(data): + """ + Set the x and y axes to the default 2D units + :param data: 2D data set + :return: + """ + data.xaxis("\\rm{Q_{x}}", '1/A') + data.yaxis("\\rm{Q_{y}}", '1/A') + data.zaxis("\\rm{Intensity}", "1/cm") + return data + + def convert_data_units(self, default_q_unit="1/A"): + """ + Converts al; data to the sasview default of units of A^{-1} for Q and + cm^{-1} for I. + :param default_q_unit: The default Q unit used by Sasview + """ + convert_q = True + new_output = [] + for data in self.output: + if data.isSesans: + new_output.append(data) + continue + try: + file_x_unit = data._xunit + data_conv_x = Converter(file_x_unit) + except KeyError: + logger.info("Unrecognized Q units in data file. No data " + "conversion attempted") + convert_q = False + try: + + if isinstance(data, Data1D): + if convert_q: + data.x = data_conv_x(data.x, units=default_q_unit) + data._xunit = default_q_unit + data.x_unit = default_q_unit + if data.dx is not None: + data.dx = data_conv_x(data.dx, + units=default_q_unit) + if data.dxl is not None: + data.dxl = data_conv_x(data.dxl, + units=default_q_unit) + if data.dxw is not None: + data.dxw = data_conv_x(data.dxw, + units=default_q_unit) + elif isinstance(data, Data2D): + if convert_q: + data.qx_data = data_conv_x(data.qx_data, + units=default_q_unit) + if data.dqx_data is not None: + data.dqx_data = data_conv_x(data.dqx_data, + units=default_q_unit) + try: + file_y_unit = data._yunit + data_conv_y = Converter(file_y_unit) + data.qy_data = data_conv_y(data.qy_data, + units=default_q_unit) + if data.dqy_data is not None: + data.dqy_data = data_conv_y(data.dqy_data, + units=default_q_unit) + except KeyError: + logger.info("Unrecognized Qy units in data file. No" + " data conversion attempted") + except KeyError: + message = "Unable to convert Q units from {0} to 1/A." 
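+                # str.format returns a new string, so the formatted message
+                # must be assigned back before it is stored on the data set.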
+                message = message.format(default_q_unit)
+                data.errors.append(message)
+            new_output.append(data)
+        self.output = new_output
+
+    def format_unit(self, unit=None):
+        """
+        Format units in a common way
+        :param unit: unit string, e.g. '1/A'
+        :return: the reformatted unit string
+        """
+        if unit:
+            split = unit.split("/")
+            if len(split) == 1:
+                return unit
+            elif split[0] == '1':
+                return "{0}^".format(split[1]) + "{-1}"
+            else:
+                return "{0}*{1}^".format(split[0], split[1]) + "{-1}"
+
+    def set_all_to_none(self):
+        """
+        Set all mutable values to None for error handling purposes
+        """
+        self.current_dataset = None
+        self.current_datainfo = None
+        self.output = []
+
+    def data_cleanup(self):
+        """
+        Clean up the data sets and refresh everything
+        :return: None
+        """
+        self.remove_empty_q_values()
+        self.send_to_output()  # Combine datasets with DataInfo
+        self.current_datainfo = DataInfo()  # Reset DataInfo
+
+    def remove_empty_q_values(self):
+        """
+        Remove any point where Q == 0
+        """
+        if isinstance(self.current_dataset, plottable_1D):
+            # Booleans for resolutions
+            has_error_dx = self.current_dataset.dx is not None
+            has_error_dxl = self.current_dataset.dxl is not None
+            has_error_dxw = self.current_dataset.dxw is not None
+            has_error_dy = self.current_dataset.dy is not None
+            # Create arrays of zeros for non-existent resolutions
+            if has_error_dxw and not has_error_dxl:
+                array_size = self.current_dataset.dxw.size - 1
+                self.current_dataset.dxl = np.append(self.current_dataset.dxl,
+                                                     np.zeros([array_size]))
+                has_error_dxl = True
+            elif has_error_dxl and not has_error_dxw:
+                array_size = self.current_dataset.dxl.size - 1
+                self.current_dataset.dxw = np.append(self.current_dataset.dxw,
+                                                     np.zeros([array_size]))
+                has_error_dxw = True
+            elif not has_error_dxl and not has_error_dxw and not has_error_dx:
+                array_size = self.current_dataset.x.size - 1
+                self.current_dataset.dx = np.append(self.current_dataset.dx,
+                                                    np.zeros([array_size]))
+                has_error_dx = True
+            if not has_error_dy:
+                array_size = self.current_dataset.y.size - 1
+                self.current_dataset.dy = np.append(self.current_dataset.dy,
+                                                    np.zeros([array_size]))
+                has_error_dy = True
+
+            # Remove points where q = 0
+            x = self.current_dataset.x
+            self.current_dataset.x = self.current_dataset.x[x != 0]
+            self.current_dataset.y = self.current_dataset.y[x != 0]
+            if has_error_dy:
+                self.current_dataset.dy = self.current_dataset.dy[x != 0]
+            if has_error_dx:
+                self.current_dataset.dx = self.current_dataset.dx[x != 0]
+            if has_error_dxl:
+                self.current_dataset.dxl = self.current_dataset.dxl[x != 0]
+            if has_error_dxw:
+                self.current_dataset.dxw = self.current_dataset.dxw[x != 0]
+        elif isinstance(self.current_dataset, plottable_2D):
+            has_error_dqx = self.current_dataset.dqx_data is not None
+            has_error_dqy = self.current_dataset.dqy_data is not None
+            has_error_dy = self.current_dataset.err_data is not None
+            has_mask = self.current_dataset.mask is not None
+            x = self.current_dataset.qx_data
+            self.current_dataset.data = self.current_dataset.data[x != 0]
+            self.current_dataset.qx_data = self.current_dataset.qx_data[x != 0]
+            self.current_dataset.qy_data = self.current_dataset.qy_data[x != 0]
+            self.current_dataset.q_data = np.sqrt(
+                np.square(self.current_dataset.qx_data) + np.square(
+                    self.current_dataset.qy_data))
+            if has_error_dy:
+                self.current_dataset.err_data = self.current_dataset.err_data[
+                    x != 0]
+            if has_error_dqx:
+                self.current_dataset.dqx_data = self.current_dataset.dqx_data[
+                    x != 0]
+            if has_error_dqy:
+                self.current_dataset.dqy_data = self.current_dataset.dqy_data[
+                    x != 0]
+            if 
has_mask: + self.current_dataset.mask = self.current_dataset.mask[x != 0] + + def reset_data_list(self, no_lines=0): + """ + Reset the plottable_1D object + """ + # Initialize data sets with arrays the maximum possible size + x = np.zeros(no_lines) + y = np.zeros(no_lines) + dx = np.zeros(no_lines) + dy = np.zeros(no_lines) + self.current_dataset = plottable_1D(x, y, dx, dy) + + @staticmethod + def splitline(line): + """ + Splits a line into pieces based on common delimiters + :param line: A single line of text + :return: list of values + """ + # Initial try for CSV (split on ,) + toks = line.split(',') + # Now try SCSV (split on ;) + if len(toks) < 2: + toks = line.split(';') + # Now go for whitespace + if len(toks) < 2: + toks = line.split() + return toks + + @abstractmethod + def get_file_contents(self): + """ + Reader specific class to access the contents of the file + All reader classes that inherit from FileReader must implement + """ + pass diff --git a/sas/sascalc/dataloader/loader.py b/sas/sascalc/dataloader/loader.py new file mode 100755 index 000000000..b3bd63d65 --- /dev/null +++ b/sas/sascalc/dataloader/loader.py @@ -0,0 +1,441 @@ +""" + File handler to support different file extensions. + Uses reflectometer registry utility. + + The default readers are found in the 'readers' sub-module + and registered by default at initialization time. + + To add a new default reader, one must register it in + the register_readers method found in readers/__init__.py. + + A utility method (find_plugins) is available to inspect + a directory (for instance, a user plug-in directory) and + look for new readers/writers. +""" +##################################################################### +# This software was developed by the University of Tennessee as part of the +# Distributed Data Analysis of Neutron Scattering Experiments (DANSE) +# project funded by the US National Science Foundation. +# See the license text in license.txt +# copyright 2008, University of Tennessee +###################################################################### + +import os +import sys +import logging +import time +from zipfile import ZipFile + +from sas.sascalc.data_util.registry import ExtensionRegistry + +# Default readers are defined in the readers sub-module +from . import readers +from .loader_exceptions import NoKnownLoaderException, FileContentsException,\ + DefaultReaderException +from .readers import ascii_reader +from .readers import cansas_reader +from .readers import cansas_reader_HDF5 + +logger = logging.getLogger(__name__) + + +class Registry(ExtensionRegistry): + """ + Registry class for file format extensions. + Readers and writers are supported. + """ + def __init__(self): + super(Registry, self).__init__() + + # Writers + self.writers = {} + + # List of wildcards + self.wildcards = ['All (*.*)|*.*'] + + # Creation time, for testing + self._created = time.time() + + # Register default readers + readers.read_associations(self) + + def load(self, path, format=None): + """ + Call the loader for the file type of path. + + :param path: file path + :param format: explicit extension, to force the use + of a particular reader + + Defaults to the ascii (multi-column), cansas XML, and cansas NeXuS + readers if no reader was registered for the file's extension. 
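+
+        A minimal usage sketch (the file path is hypothetical)::
+
+            registry = Registry()
+            data_list = registry.load("path/to/data.xml")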
+ """ + # Gets set to a string if the file has an associated reader that fails + msg_from_reader = None + try: + return super(Registry, self).load(path, format=format) + #except Exception: raise # for debugging, don't use fallback loader + except NoKnownLoaderException as nkl_e: + pass # Try the ASCII reader + except FileContentsException as fc_exc: + # File has an associated reader but it failed. + # Save the error message to display later, but try the 3 default loaders + msg_from_reader = fc_exc.message + except Exception: + pass + + # File has no associated reader, or the associated reader failed. + # Try the ASCII reader + try: + ascii_loader = ascii_reader.Reader() + return ascii_loader.read(path) + except NoKnownLoaderException: + pass # Try the Cansas XML reader + except DefaultReaderException: + pass # Loader specific error to try the cansas XML reader + except FileContentsException as e: + if msg_from_reader is None: + raise RuntimeError(e.message) + + # ASCII reader failed - try CanSAS xML reader + try: + cansas_loader = cansas_reader.Reader() + return cansas_loader.read(path) + except NoKnownLoaderException: + pass # Try the NXcanSAS reader + except DefaultReaderException: + pass # Loader specific error to try the NXcanSAS reader + except FileContentsException as e: + if msg_from_reader is None: + raise RuntimeError(e.message) + except Exception: + pass + + # CanSAS XML reader failed - try NXcanSAS reader + try: + cansas_nexus_loader = cansas_reader_HDF5.Reader() + return cansas_nexus_loader.read(path) + except DefaultReaderException as e: + logging.error("No default loader can load the data") + # No known reader available. Give up and throw an error + if msg_from_reader is None: + msg = "\nUnknown data format: {}.\nThe file is not a ".format(path) + msg += "known format that can be loaded by SasView.\n" + raise NoKnownLoaderException(msg) + else: + # Associated reader and default readers all failed. + # Show error message from associated reader + raise RuntimeError(msg_from_reader) + except FileContentsException as e: + err_msg = msg_from_reader if msg_from_reader is not None else e.message + raise RuntimeError(err_msg) + + def find_plugins(self, dir): + """ + Find readers in a given directory. This method + can be used to inspect user plug-in directories to + find new readers/writers. + + :param dir: directory to search into + :return: number of readers found + """ + readers_found = 0 + temp_path = os.path.abspath(dir) + if not os.path.isdir(temp_path): + temp_path = os.path.join(os.getcwd(), dir) + if not os.path.isdir(temp_path): + temp_path = os.path.join(os.path.dirname(__file__), dir) + if not os.path.isdir(temp_path): + temp_path = os.path.join(os.path.dirname(sys.path[0]), dir) + + dir = temp_path + # Check whether the directory exists + if not os.path.isdir(dir): + msg = "DataLoader couldn't locate DataLoader plugin folder." 
+ msg += """ "%s" does not exist""" % dir + logger.warning(msg) + return readers_found + + for item in os.listdir(dir): + full_path = os.path.join(dir, item) + if os.path.isfile(full_path): + + # Process python files + if item.endswith('.py'): + toks = os.path.splitext(os.path.basename(item)) + try: + sys.path.insert(0, os.path.abspath(dir)) + module = __import__(toks[0], globals(), locals()) + if self._identify_plugin(module): + readers_found += 1 + except: + msg = "Loader: Error importing " + msg += "%s\n %s" % (item, sys.exc_value) + logger.error(msg) + + # Process zip files + elif item.endswith('.zip'): + try: + # Find the modules in the zip file + zfile = ZipFile(item) + nlist = zfile.namelist() + + sys.path.insert(0, item) + for mfile in nlist: + try: + # Change OS path to python path + fullname = mfile.replace('/', '.') + fullname = os.path.splitext(fullname)[0] + module = __import__(fullname, globals(), + locals(), [""]) + if self._identify_plugin(module): + readers_found += 1 + except: + msg = "Loader: Error importing" + msg += " %s\n %s" % (mfile, sys.exc_value) + logger.error(msg) + + except: + msg = "Loader: Error importing " + msg += " %s\n %s" % (item, sys.exc_value) + logger.error(msg) + + return readers_found + + def associate_file_type(self, ext, module): + """ + Look into a module to find whether it contains a + Reader class. If so, APPEND it to readers and (potentially) + to the list of writers for the given extension + + :param ext: file extension [string] + :param module: module object + """ + reader_found = False + + if hasattr(module, "Reader"): + try: + # Find supported extensions + loader = module.Reader() + if ext not in self.loaders: + self.loaders[ext] = [] + # Append the new reader to the list + self.loaders[ext].append(loader.read) + + reader_found = True + + # Keep track of wildcards + type_name = module.__name__ + if hasattr(loader, 'type_name'): + type_name = loader.type_name + + wcard = "%s files (*%s)|*%s" % (type_name, ext.lower(), + ext.lower()) + if wcard not in self.wildcards: + self.wildcards.append(wcard) + + # Check whether writing is supported + if hasattr(loader, 'write'): + if ext not in self.writers: + self.writers[ext] = [] + # Append the new writer to the list + self.writers[ext].append(loader.write) + + except: + msg = "Loader: Error accessing" + msg += " Reader in %s\n %s" % (module.__name__, sys.exc_value) + logger.error(msg) + return reader_found + + def associate_file_reader(self, ext, loader): + """ + Append a reader object to readers + + :param ext: file extension [string] + :param module: reader object + """ + reader_found = False + + try: + # Find supported extensions + if ext not in self.loaders: + self.loaders[ext] = [] + # Append the new reader to the list + self.loaders[ext].append(loader.read) + + reader_found = True + + # Keep track of wildcards + if hasattr(loader, 'type_name'): + type_name = loader.type_name + + wcard = "%s files (*%s)|*%s" % (type_name, ext.lower(), + ext.lower()) + if wcard not in self.wildcards: + self.wildcards.append(wcard) + + except: + msg = "Loader: Error accessing Reader " + msg += "in %s\n %s" % (loader.__name__, sys.exc_value) + logger.error(msg) + return reader_found + + def _identify_plugin(self, module): + """ + Look into a module to find whether it contains a + Reader class. If so, add it to readers and (potentially) + to the list of writers. 
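+        Readers discovered at run time are inserted at the front of the
+        loader list, so they take precedence over the built-in defaults.
+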
+        :param module: module object
+
+        """
+        reader_found = False
+
+        if hasattr(module, "Reader"):
+            try:
+                # Find supported extensions
+                loader = module.Reader()
+                for ext in loader.ext:
+                    if ext not in self.loaders:
+                        self.loaders[ext] = []
+                    # When finding a reader at run time,
+                    # treat this reader as the new default
+                    self.loaders[ext].insert(0, loader.read)
+
+                    reader_found = True
+
+                    # Keep track of wildcards
+                    type_name = module.__name__
+                    if hasattr(loader, 'type_name'):
+                        type_name = loader.type_name
+                    wcard = "%s files (*%s)|*%s" % (type_name, ext.lower(),
+                                                    ext.lower())
+                    if wcard not in self.wildcards:
+                        self.wildcards.append(wcard)
+
+                # Check whether writing is supported
+                if hasattr(loader, 'write'):
+                    for ext in loader.ext:
+                        if ext not in self.writers:
+                            self.writers[ext] = []
+                        self.writers[ext].insert(0, loader.write)
+
+            except Exception:
+                msg = "Loader: Error accessing Reader"
+                msg += " in %s\n  %s" % (module.__name__, sys.exc_info()[1])
+                logger.error(msg)
+        return reader_found
+
+    def lookup_writers(self, path):
+        """
+        :return: the writers associated with the file type of path.
+        :raise ValueError: if the file type is not known.
+        """
+        # Find matching extensions
+        extlist = [ext for ext in self.extensions() if path.endswith(ext)]
+        # Sort matching extensions by decreasing order of length
+        extlist.sort(key=len, reverse=True)
+        # Combine writers for matching extensions into one big list
+        writers = []
+        for L in [self.writers[ext] for ext in extlist]:
+            writers.extend(L)
+        # Remove duplicates if they exist
+        if len(writers) != len(set(writers)):
+            result = []
+            for L in writers:
+                if L not in result:
+                    result.append(L)
+            writers = result
+        # Raise an error if there are no matching extensions
+        if len(writers) == 0:
+            raise ValueError("Unknown file type for " + path)
+        # All done
+        return writers
+
+    def save(self, path, data, format=None):
+        """
+        Call the writer for the file type of path.
+
+        Raises ValueError if no writer is available.
+        Raises KeyError if format is not available.
+        May raise a writer-defined exception if the writer fails.
+        """
+        if format is None:
+            writers = self.lookup_writers(path)
+        else:
+            writers = self.writers[format]
+        for fn in writers:
+            try:
+                return fn(path, data)
+            except Exception as exc:
+                msg = "Saving file {} using the {} writer failed.\n".format(
+                    path, type(fn).__name__)
+                msg += str(exc)
+                logger.exception(msg)  # give other writers a chance to succeed
+
+
+class Loader(object):
+    """
+    Utility class to use the Registry as a singleton.
+    """
+    ## Registry instance
+    __registry = Registry()
+
+    def associate_file_type(self, ext, module):
+        """
+        Look into a module to find whether it contains a
+        Reader class. 
If so, append it to readers and (potentially) + to the list of writers for the given extension + + :param ext: file extension [string] + :param module: module object + """ + return self.__registry.associate_file_type(ext, module) + + def associate_file_reader(self, ext, loader): + """ + Append a reader object to readers + + :param ext: file extension [string] + :param module: reader object + """ + return self.__registry.associate_file_reader(ext, loader) + + def load(self, file, format=None): + """ + Load a file + + :param file: file name (path) + :param format: specified format to use (optional) + :return: DataInfo object + """ + return self.__registry.load(file, format) + + def save(self, file, data, format): + """ + Save a DataInfo object to file + :param file: file name (path) + :param data: DataInfo object + :param format: format to write the data in + """ + return self.__registry.save(file, data, format) + + def _get_registry_creation_time(self): + """ + Internal method used to test the uniqueness + of the registry object + """ + return self.__registry._created + + def find_plugins(self, directory): + """ + Find plugins in a given directory + + :param dir: directory to look into to find new readers/writers + """ + return self.__registry.find_plugins(directory) + + def get_wildcards(self): + """ + Return the list of wildcards + """ + return self.__registry.wildcards diff --git a/sas/sascalc/dataloader/loader_exceptions.py b/sas/sascalc/dataloader/loader_exceptions.py new file mode 100755 index 000000000..840e6dd21 --- /dev/null +++ b/sas/sascalc/dataloader/loader_exceptions.py @@ -0,0 +1,41 @@ +""" +Exceptions specific to loading data. +""" + + +class NoKnownLoaderException(Exception): + """ + Exception for files with no associated reader based on the file + extension of the loaded file. This exception should only be thrown by + loader.py. + """ + def __init__(self, e=None): + self.message = e + + +class DefaultReaderException(Exception): + """ + Exception for files with no associated reader. This should be thrown by + default readers only to tell Loader to try the next reader. + """ + def __init__(self, e=None): + self.message = e + + +class FileContentsException(Exception): + """ + Exception for files with an associated reader, but with no loadable data. + This is useful for catching loader or file format issues. + """ + def __init__(self, e=None): + self.message = e + + +class DataReaderException(Exception): + """ + Exception for files that were able to mostly load, but had minor issues + along the way. + Any exceptions of this type should be put into the datainfo.errors + """ + def __init__(self, e=None): + self.message = e diff --git a/sas/sascalc/dataloader/manipulations.py b/sas/sascalc/dataloader/manipulations.py new file mode 100755 index 000000000..7b5f32a8d --- /dev/null +++ b/sas/sascalc/dataloader/manipulations.py @@ -0,0 +1,1172 @@ +from __future__ import division +""" +Data manipulations for 2D data sets. +Using the meta data information, various types of averaging +are performed in Q-space + +To test this module use: +``` +cd test +PYTHONPATH=../src/ python2 -m sasdataloader.test.utest_averaging DataInfoTests.test_sectorphi_quarter +``` +""" +##################################################################### +# This software was developed by the University of Tennessee as part of the +# Distributed Data Analysis of Neutron Scattering Experiments (DANSE) +# project funded by the US National Science Foundation. 
+# See the license text in license.txt +# copyright 2008, University of Tennessee +###################################################################### + + +# TODO: copy the meta data from the 2D object to the resulting 1D object +import math +import numpy as np +import sys + +#from data_info import plottable_2D +from .data_info import Data1D + + +def get_q(dx, dy, det_dist, wavelength): + """ + :param dx: x-distance from beam center [mm] + :param dy: y-distance from beam center [mm] + :return: q-value at the given position + """ + # Distance from beam center in the plane of detector + plane_dist = math.sqrt(dx * dx + dy * dy) + # Half of the scattering angle + theta = 0.5 * math.atan(plane_dist / det_dist) + return (4.0 * math.pi / wavelength) * math.sin(theta) + + +def get_q_compo(dx, dy, det_dist, wavelength, compo=None): + """ + This reduces tiny error at very large q. + Implementation of this func is not started yet.<--ToDo + """ + if dy == 0: + if dx >= 0: + angle_xy = 0 + else: + angle_xy = math.pi + else: + angle_xy = math.atan(dx / dy) + + if compo == "x": + out = get_q(dx, dy, det_dist, wavelength) * math.cos(angle_xy) + elif compo == "y": + out = get_q(dx, dy, det_dist, wavelength) * math.sin(angle_xy) + else: + out = get_q(dx, dy, det_dist, wavelength) + return out + + +def flip_phi(phi): + """ + Correct phi to within the 0 <= to <= 2pi range + + :return: phi in >=0 and <=2Pi + """ + Pi = math.pi + if phi < 0: + phi_out = phi + (2 * Pi) + elif phi > (2 * Pi): + phi_out = phi - (2 * Pi) + else: + phi_out = phi + return phi_out + +def get_pixel_fraction_square(x, xmin, xmax): + """ + Return the fraction of the length + from xmin to x.:: + + A B + +-----------+---------+ + xmin x xmax + + :param x: x-value + :param xmin: minimum x for the length considered + :param xmax: minimum x for the length considered + :return: (x-xmin)/(xmax-xmin) when xmin < x < xmax + + """ + if x <= xmin: + return 0.0 + if x > xmin and x < xmax: + return (x - xmin) / (xmax - xmin) + else: + return 1.0 + +def get_intercept(q, q_0, q_1): + """ + Returns the fraction of the side at which the + q-value intercept the pixel, None otherwise. + The values returned is the fraction ON THE SIDE + OF THE LOWEST Q. 
:: + + A B + +-----------+--------+ <--- pixel size + 0 1 + Q_0 -------- Q ----- Q_1 <--- equivalent Q range + if Q_1 > Q_0, A is returned + if Q_1 < Q_0, B is returned + if Q is outside the range of [Q_0, Q_1], None is returned + + """ + if q_1 > q_0: + if q > q_0 and q <= q_1: + return (q - q_0) / (q_1 - q_0) + else: + if q > q_1 and q <= q_0: + return (q - q_1) / (q_0 - q_1) + return None + +def get_pixel_fraction(qmax, q_00, q_01, q_10, q_11): + """ + Returns the fraction of the pixel defined by + the four corners (q_00, q_01, q_10, q_11) that + has q < qmax.:: + + q_01 q_11 + y=1 +--------------+ + | | + | | + | | + y=0 +--------------+ + q_00 q_10 + + x=0 x=1 + + """ + # y side for x = minx + x_0 = get_intercept(qmax, q_00, q_01) + # y side for x = maxx + x_1 = get_intercept(qmax, q_10, q_11) + + # x side for y = miny + y_0 = get_intercept(qmax, q_00, q_10) + # x side for y = maxy + y_1 = get_intercept(qmax, q_01, q_11) + + # surface fraction for a 1x1 pixel + frac_max = 0 + + if x_0 and x_1: + frac_max = (x_0 + x_1) / 2.0 + elif y_0 and y_1: + frac_max = (y_0 + y_1) / 2.0 + elif x_0 and y_0: + if q_00 < q_10: + frac_max = x_0 * y_0 / 2.0 + else: + frac_max = 1.0 - x_0 * y_0 / 2.0 + elif x_0 and y_1: + if q_00 < q_10: + frac_max = x_0 * y_1 / 2.0 + else: + frac_max = 1.0 - x_0 * y_1 / 2.0 + elif x_1 and y_0: + if q_00 > q_10: + frac_max = x_1 * y_0 / 2.0 + else: + frac_max = 1.0 - x_1 * y_0 / 2.0 + elif x_1 and y_1: + if q_00 < q_10: + frac_max = 1.0 - (1.0 - x_1) * (1.0 - y_1) / 2.0 + else: + frac_max = (1.0 - x_1) * (1.0 - y_1) / 2.0 + + # If we make it here, there is no intercept between + # this pixel and the constant-q ring. We only need + # to know if we have to include it or exclude it. + elif (q_00 + q_01 + q_10 + q_11) / 4.0 < qmax: + frac_max = 1.0 + + return frac_max + +def get_dq_data(data2D): + ''' + Get the dq for resolution averaging + The pinholes and det. pix contribution present + in both direction of the 2D which must be subtracted when + converting to 1D: dq_overlap should calculated ideally at + q = 0. Note This method works on only pinhole geometry. + Extrapolate dqx(r) and dqy(phi) at q = 0, and take an average. + ''' + z_max = max(data2D.q_data) + z_min = min(data2D.q_data) + dqx_at_z_max = data2D.dqx_data[np.argmax(data2D.q_data)] + dqx_at_z_min = data2D.dqx_data[np.argmin(data2D.q_data)] + dqy_at_z_max = data2D.dqy_data[np.argmax(data2D.q_data)] + dqy_at_z_min = data2D.dqy_data[np.argmin(data2D.q_data)] + # Find qdx at q = 0 + dq_overlap_x = (dqx_at_z_min * z_max - dqx_at_z_max * z_min) / (z_max - z_min) + # when extrapolation goes wrong + if dq_overlap_x > min(data2D.dqx_data): + dq_overlap_x = min(data2D.dqx_data) + dq_overlap_x *= dq_overlap_x + # Find qdx at q = 0 + dq_overlap_y = (dqy_at_z_min * z_max - dqy_at_z_max * z_min) / (z_max - z_min) + # when extrapolation goes wrong + if dq_overlap_y > min(data2D.dqy_data): + dq_overlap_y = min(data2D.dqy_data) + # get dq at q=0. 
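+    # (squared so that the x and y contributions can be averaged in
+    # quadrature below)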
+ dq_overlap_y *= dq_overlap_y + + dq_overlap = np.sqrt((dq_overlap_x + dq_overlap_y) / 2.0) + # Final protection of dq + if dq_overlap < 0: + dq_overlap = dqy_at_z_min + dqx_data = data2D.dqx_data[np.isfinite(data2D.data)] + dqy_data = data2D.dqy_data[np.isfinite( + data2D.data)] - dq_overlap + # def; dqx_data = dq_r dqy_data = dq_phi + # Convert dq 2D to 1D here + dq_data = np.sqrt(dqx_data**2 + dqx_data**2) + return dq_data + +################################################################################ + +def reader2D_converter(data2d=None): + """ + convert old 2d format opened by IhorReader or danse_reader + to new Data2D format + This is mainly used by the Readers + + :param data2d: 2d array of Data2D object + :return: 1d arrays of Data2D object + + """ + if data2d.data is None or data2d.x_bins is None or data2d.y_bins is None: + raise ValueError("Can't convert this data: data=None...") + new_x = np.tile(data2d.x_bins, (len(data2d.y_bins), 1)) + new_y = np.tile(data2d.y_bins, (len(data2d.x_bins), 1)) + new_y = new_y.swapaxes(0, 1) + + new_data = data2d.data.flatten() + qx_data = new_x.flatten() + qy_data = new_y.flatten() + q_data = np.sqrt(qx_data * qx_data + qy_data * qy_data) + if data2d.err_data is None or np.any(data2d.err_data <= 0): + new_err_data = np.sqrt(np.abs(new_data)) + else: + new_err_data = data2d.err_data.flatten() + mask = np.ones(len(new_data), dtype=bool) + + # TODO: make sense of the following two lines... + #from sas.sascalc.dataloader.data_info import Data2D + #output = Data2D() + output = data2d + output.data = new_data + output.err_data = new_err_data + output.qx_data = qx_data + output.qy_data = qy_data + output.q_data = q_data + output.mask = mask + + return output + +################################################################################ + +class Binning(object): + ''' + This class just creates a binning object + either linear or log + ''' + + def __init__(self, min_value, max_value, n_bins, base=None): + ''' + if base is None: Linear binning + ''' + self.min = min_value if min_value > 0 else 0.0001 + self.max = max_value + self.n_bins = n_bins + self.base = base + + def get_bin_index(self, value): + ''' + The general formula logarithm binning is: + bin = floor(N * (log(x) - log(min)) / (log(max) - log(min))) + ''' + if self.base: + temp_x = self.n_bins * (math.log(value, self.base) - math.log(self.min, self.base)) + temp_y = math.log(self.max, self.base) - math.log(self.min, self.base) + else: + temp_x = self.n_bins * (value - self.min) + temp_y = self.max - self.min + # Bin index calulation + return int(math.floor(temp_x / temp_y)) + + +################################################################################ + +class _Slab(object): + """ + Compute average I(Q) for a region of interest + """ + + def __init__(self, x_min=0.0, x_max=0.0, y_min=0.0, + y_max=0.0, bin_width=0.001): + # Minimum Qx value [A-1] + self.x_min = x_min + # Maximum Qx value [A-1] + self.x_max = x_max + # Minimum Qy value [A-1] + self.y_min = y_min + # Maximum Qy value [A-1] + self.y_max = y_max + # Bin width (step size) [A-1] + self.bin_width = bin_width + # If True, I(|Q|) will be return, otherwise, + # negative q-values are allowed + self.fold = False + + def __call__(self, data2D): + return NotImplemented + + def _avg(self, data2D, maj): + """ + Compute average I(Q_maj) for a region of interest. + The major axis is defined as the axis of Q_maj. + The minor axis is the axis that we average over. 
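+        ``maj`` selects the major axis and must be either 'x' or 'y';
+        any other value raises a RuntimeError.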
+ + :param data2D: Data2D object + :param maj_min: min value on the major axis + :return: Data1D object + """ + if len(data2D.detector) > 1: + msg = "_Slab._avg: invalid number of " + msg += " detectors: %g" % len(data2D.detector) + raise RuntimeError(msg) + + # Get data + data = data2D.data[np.isfinite(data2D.data)] + err_data = data2D.err_data[np.isfinite(data2D.data)] + qx_data = data2D.qx_data[np.isfinite(data2D.data)] + qy_data = data2D.qy_data[np.isfinite(data2D.data)] + + # Build array of Q intervals + if maj == 'x': + if self.fold: + x_min = 0 + else: + x_min = self.x_min + nbins = int(math.ceil((self.x_max - x_min) / self.bin_width)) + elif maj == 'y': + if self.fold: + y_min = 0 + else: + y_min = self.y_min + nbins = int(math.ceil((self.y_max - y_min) / self.bin_width)) + else: + raise RuntimeError("_Slab._avg: unrecognized axis %s" % str(maj)) + + x = np.zeros(nbins) + y = np.zeros(nbins) + err_y = np.zeros(nbins) + y_counts = np.zeros(nbins) + + # Average pixelsize in q space + for npts in range(len(data)): + # default frac + frac_x = 0 + frac_y = 0 + # get ROI + if self.x_min <= qx_data[npts] and self.x_max > qx_data[npts]: + frac_x = 1 + if self.y_min <= qy_data[npts] and self.y_max > qy_data[npts]: + frac_y = 1 + frac = frac_x * frac_y + + if frac == 0: + continue + # binning: find axis of q + if maj == 'x': + q_value = qx_data[npts] + min_value = x_min + if maj == 'y': + q_value = qy_data[npts] + min_value = y_min + if self.fold and q_value < 0: + q_value = -q_value + # bin + i_q = int(math.ceil((q_value - min_value) / self.bin_width)) - 1 + + # skip outside of max bins + if i_q < 0 or i_q >= nbins: + continue + + # TODO: find better definition of x[i_q] based on q_data + # min_value + (i_q + 1) * self.bin_width / 2.0 + x[i_q] += frac * q_value + y[i_q] += frac * data[npts] + + if err_data is None or err_data[npts] == 0.0: + if data[npts] < 0: + data[npts] = -data[npts] + err_y[i_q] += frac * frac * data[npts] + else: + err_y[i_q] += frac * frac * err_data[npts] * err_data[npts] + y_counts[i_q] += frac + + # Average the sums + for n in range(nbins): + err_y[n] = math.sqrt(err_y[n]) + + err_y = err_y / y_counts + y = y / y_counts + x = x / y_counts + idx = (np.isfinite(y) & np.isfinite(x)) + + if not idx.any(): + msg = "Average Error: No points inside ROI to average..." + raise ValueError(msg) + return Data1D(x=x[idx], y=y[idx], dy=err_y[idx]) + + +class SlabY(_Slab): + """ + Compute average I(Qy) for a region of interest + """ + + def __call__(self, data2D): + """ + Compute average I(Qy) for a region of interest + + :param data2D: Data2D object + :return: Data1D object + """ + return self._avg(data2D, 'y') + + +class SlabX(_Slab): + """ + Compute average I(Qx) for a region of interest + """ + + def __call__(self, data2D): + """ + Compute average I(Qx) for a region of interest + :param data2D: Data2D object + :return: Data1D object + """ + return self._avg(data2D, 'x') + +################################################################################ + +class Boxsum(object): + """ + Perform the sum of counts in a 2D region of interest. 
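+    A pixel at (qx, qy) contributes only when x_min <= qx < x_max and
+    y_min <= qy < y_max.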
+ """ + + def __init__(self, x_min=0.0, x_max=0.0, y_min=0.0, y_max=0.0): + # Minimum Qx value [A-1] + self.x_min = x_min + # Maximum Qx value [A-1] + self.x_max = x_max + # Minimum Qy value [A-1] + self.y_min = y_min + # Maximum Qy value [A-1] + self.y_max = y_max + + def __call__(self, data2D): + """ + Perform the sum in the region of interest + + :param data2D: Data2D object + :return: number of counts, error on number of counts, + number of points summed + """ + y, err_y, y_counts = self._sum(data2D) + + # Average the sums + counts = 0 if y_counts == 0 else y + error = 0 if y_counts == 0 else math.sqrt(err_y) + + # Added y_counts to return, SMK & PDB, 04/03/2013 + return counts, error, y_counts + + def _sum(self, data2D): + """ + Perform the sum in the region of interest + + :param data2D: Data2D object + :return: number of counts, + error on number of counts, number of entries summed + """ + if len(data2D.detector) > 1: + msg = "Circular averaging: invalid number " + msg += "of detectors: %g" % len(data2D.detector) + raise RuntimeError(msg) + # Get data + data = data2D.data[np.isfinite(data2D.data)] + err_data = data2D.err_data[np.isfinite(data2D.data)] + qx_data = data2D.qx_data[np.isfinite(data2D.data)] + qy_data = data2D.qy_data[np.isfinite(data2D.data)] + + y = 0.0 + err_y = 0.0 + y_counts = 0.0 + + # Average pixelsize in q space + for npts in range(len(data)): + # default frac + frac_x = 0 + frac_y = 0 + + # get min and max at each points + qx = qx_data[npts] + qy = qy_data[npts] + + # get the ROI + if self.x_min <= qx and self.x_max > qx: + frac_x = 1 + if self.y_min <= qy and self.y_max > qy: + frac_y = 1 + # Find the fraction along each directions + frac = frac_x * frac_y + if frac == 0: + continue + y += frac * data[npts] + if err_data is None or err_data[npts] == 0.0: + if data[npts] < 0: + data[npts] = -data[npts] + err_y += frac * frac * data[npts] + else: + err_y += frac * frac * err_data[npts] * err_data[npts] + y_counts += frac + return y, err_y, y_counts + + +class Boxavg(Boxsum): + """ + Perform the average of counts in a 2D region of interest. 
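+    The returned value is the Boxsum result divided by the number of
+    points inside the box. A usage sketch (the limits are illustrative
+    only)::
+
+        box = Boxavg(x_min=-0.01, x_max=0.01, y_min=-0.01, y_max=0.01)
+        average, error = box(data2D)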
+ """ + + def __init__(self, x_min=0.0, x_max=0.0, y_min=0.0, y_max=0.0): + super(Boxavg, self).__init__(x_min=x_min, x_max=x_max, + y_min=y_min, y_max=y_max) + + def __call__(self, data2D): + """ + Perform the sum in the region of interest + + :param data2D: Data2D object + :return: average counts, error on average counts + + """ + y, err_y, y_counts = self._sum(data2D) + + # Average the sums + counts = 0 if y_counts == 0 else y / y_counts + error = 0 if y_counts == 0 else math.sqrt(err_y) / y_counts + + return counts, error + +################################################################################ + +class CircularAverage(object): + """ + Perform circular averaging on 2D data + + The data returned is the distribution of counts + as a function of Q + """ + + def __init__(self, r_min=0.0, r_max=0.0, bin_width=0.0005): + # Minimum radius included in the average [A-1] + self.r_min = r_min + # Maximum radius included in the average [A-1] + self.r_max = r_max + # Bin width (step size) [A-1] + self.bin_width = bin_width + + def __call__(self, data2D, ismask=False): + """ + Perform circular averaging on the data + + :param data2D: Data2D object + :return: Data1D object + """ + # Get data W/ finite values + data = data2D.data[np.isfinite(data2D.data)] + q_data = data2D.q_data[np.isfinite(data2D.data)] + err_data = data2D.err_data[np.isfinite(data2D.data)] + mask_data = data2D.mask[np.isfinite(data2D.data)] + + dq_data = None + if data2D.dqx_data is not None and data2D.dqy_data is not None: + dq_data = get_dq_data(data2D) + + if len(q_data) == 0: + msg = "Circular averaging: invalid q_data: %g" % data2D.q_data + raise RuntimeError(msg) + + # Build array of Q intervals + nbins = int(math.ceil((self.r_max - self.r_min) / self.bin_width)) + + x = np.zeros(nbins) + y = np.zeros(nbins) + err_y = np.zeros(nbins) + err_x = np.zeros(nbins) + y_counts = np.zeros(nbins) + + for npt in range(len(data)): + + if ismask and not mask_data[npt]: + continue + + frac = 0 + + # q-value at the pixel (j,i) + q_value = q_data[npt] + data_n = data[npt] + + # No need to calculate the frac when all data are within range + if self.r_min >= self.r_max: + raise ValueError("Limit Error: min > max") + + if self.r_min <= q_value and q_value <= self.r_max: + frac = 1 + if frac == 0: + continue + i_q = int(math.floor((q_value - self.r_min) / self.bin_width)) + + # Take care of the edge case at phi = 2pi. + if i_q == nbins: + i_q = nbins - 1 + y[i_q] += frac * data_n + # Take dqs from data to get the q_average + x[i_q] += frac * q_value + if err_data is None or err_data[npt] == 0.0: + if data_n < 0: + data_n = -data_n + err_y[i_q] += frac * frac * data_n + else: + err_y[i_q] += frac * frac * err_data[npt] * err_data[npt] + if dq_data is not None: + # To be consistent with dq calculation in 1d reduction, + # we need just the averages (not quadratures) because + # it should not depend on the number of the q points + # in the qr bins. + err_x[i_q] += frac * dq_data[npt] + else: + err_x = None + y_counts[i_q] += frac + + # Average the sums + for n in range(nbins): + if err_y[n] < 0: + err_y[n] = -err_y[n] + err_y[n] = math.sqrt(err_y[n]) + # if err_x is not None: + # err_x[n] = math.sqrt(err_x[n]) + + err_y = err_y / y_counts + err_y[err_y == 0] = np.average(err_y) + y = y / y_counts + x = x / y_counts + idx = (np.isfinite(y)) & (np.isfinite(x)) + + if err_x is not None: + d_x = err_x[idx] / y_counts[idx] + else: + d_x = None + + if not idx.any(): + msg = "Average Error: No points inside ROI to average..." 
+ raise ValueError(msg) + + return Data1D(x=x[idx], y=y[idx], dy=err_y[idx], dx=d_x) + +################################################################################ + +class Ring(object): + """ + Defines a ring on a 2D data set. + The ring is defined by r_min, r_max, and + the position of the center of the ring. + + The data returned is the distribution of counts + around the ring as a function of phi. + + Phi_min and phi_max should be defined between 0 and 2*pi + in anti-clockwise starting from the x- axis on the left-hand side + """ + # Todo: remove center. + + def __init__(self, r_min=0, r_max=0, center_x=0, center_y=0, nbins=36): + # Minimum radius + self.r_min = r_min + # Maximum radius + self.r_max = r_max + # Center of the ring in x + self.center_x = center_x + # Center of the ring in y + self.center_y = center_y + # Number of angular bins + self.nbins_phi = nbins + + def __call__(self, data2D): + """ + Apply the ring to the data set. + Returns the angular distribution for a given q range + + :param data2D: Data2D object + + :return: Data1D object + """ + if data2D.__class__.__name__ not in ["Data2D", "plottable_2D"]: + raise RuntimeError("Ring averaging only take plottable_2D objects") + + Pi = math.pi + + # Get data + data = data2D.data[np.isfinite(data2D.data)] + q_data = data2D.q_data[np.isfinite(data2D.data)] + err_data = data2D.err_data[np.isfinite(data2D.data)] + qx_data = data2D.qx_data[np.isfinite(data2D.data)] + qy_data = data2D.qy_data[np.isfinite(data2D.data)] + + # Set space for 1d outputs + phi_bins = np.zeros(self.nbins_phi) + phi_counts = np.zeros(self.nbins_phi) + phi_values = np.zeros(self.nbins_phi) + phi_err = np.zeros(self.nbins_phi) + + # Shift to apply to calculated phi values in order + # to center first bin at zero + phi_shift = Pi / self.nbins_phi + + for npt in range(len(data)): + frac = 0 + # q-value at the point (npt) + q_value = q_data[npt] + data_n = data[npt] + + # phi-value at the point (npt) + phi_value = math.atan2(qy_data[npt], qx_data[npt]) + Pi + + if self.r_min <= q_value and q_value <= self.r_max: + frac = 1 + if frac == 0: + continue + # binning + i_phi = int(math.floor((self.nbins_phi) * + (phi_value + phi_shift) / (2 * Pi))) + + # Take care of the edge case at phi = 2pi. + if i_phi >= self.nbins_phi: + i_phi = 0 + phi_bins[i_phi] += frac * data[npt] + + if err_data is None or err_data[npt] == 0.0: + if data_n < 0: + data_n = -data_n + phi_err[i_phi] += frac * frac * math.fabs(data_n) + else: + phi_err[i_phi] += frac * frac * err_data[npt] * err_data[npt] + phi_counts[i_phi] += frac + + for i in range(self.nbins_phi): + phi_bins[i] = phi_bins[i] / phi_counts[i] + phi_err[i] = math.sqrt(phi_err[i]) / phi_counts[i] + phi_values[i] = 2.0 * math.pi / self.nbins_phi * (1.0 * i) + + idx = (np.isfinite(phi_bins)) + + if not idx.any(): + msg = "Average Error: No points inside ROI to average..." + raise ValueError(msg) + # elif len(phi_bins[idx])!= self.nbins_phi: + # print "resulted",self.nbins_phi- len(phi_bins[idx]) + #,"empty bin(s) due to tight binning..." + return Data1D(x=phi_values[idx], y=phi_bins[idx], dy=phi_err[idx]) + + +class _Sector(object): + """ + Defines a sector region on a 2D data set. + The sector is defined by r_min, r_max, phi_min, phi_max, + and the position of the center of the ring + where phi_min and phi_max are defined by the right + and left lines wrt central line + and phi_max could be less than phi_min. 
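+    When phi_max is less than phi_min the sector is taken to wrap through
+    2*pi. A usage sketch (the limits are illustrative only)::
+
+        sector = SectorQ(r_min=0.005, r_max=0.01,
+                         phi_min=0, phi_max=math.pi / 3)
+        data1d = sector(data2D)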
+ + Phi is defined between 0 and 2*pi in anti-clockwise + starting from the x- axis on the left-hand side + """ + + def __init__(self, r_min, r_max, phi_min=0, phi_max=2 * math.pi, nbins=20, + base=None): + ''' + :param base: must be a valid base for an algorithm, i.e., + a positive number + ''' + self.r_min = r_min + self.r_max = r_max + self.phi_min = phi_min + self.phi_max = phi_max + self.nbins = nbins + self.base = base + + def _agv(self, data2D, run='phi'): + """ + Perform sector averaging. + + :param data2D: Data2D object + :param run: define the varying parameter ('phi' , 'q' , or 'q2') + + :return: Data1D object + """ + if data2D.__class__.__name__ not in ["Data2D", "plottable_2D"]: + raise RuntimeError("Ring averaging only take plottable_2D objects") + + # Get the all data & info + data = data2D.data[np.isfinite(data2D.data)] + q_data = data2D.q_data[np.isfinite(data2D.data)] + err_data = data2D.err_data[np.isfinite(data2D.data)] + qx_data = data2D.qx_data[np.isfinite(data2D.data)] + qy_data = data2D.qy_data[np.isfinite(data2D.data)] + + dq_data = None + if data2D.dqx_data is not None and data2D.dqy_data is not None: + dq_data = get_dq_data(data2D) + + # set space for 1d outputs + x = np.zeros(self.nbins) + y = np.zeros(self.nbins) + y_err = np.zeros(self.nbins) + x_err = np.zeros(self.nbins) + y_counts = np.zeros(self.nbins) # Cycle counts (for the mean) + + # Get the min and max into the region: 0 <= phi < 2Pi + phi_min = flip_phi(self.phi_min) + phi_max = flip_phi(self.phi_max) + + # binning object + if run.lower() == 'phi': + binning = Binning(self.phi_min, self.phi_max, self.nbins, self.base) + else: + binning = Binning(self.r_min, self.r_max, self.nbins, self.base) + + for n in range(len(data)): + + # q-value at the pixel (j,i) + q_value = q_data[n] + data_n = data[n] + + # Is pixel within range? + is_in = False + + # phi-value of the pixel (j,i) + phi_value = math.atan2(qy_data[n], qx_data[n]) + math.pi + + # No need to calculate: data outside of the radius + if self.r_min > q_value or q_value > self.r_max: + continue + + # In case of two ROIs (symmetric major and minor regions)(for 'q2') + if run.lower() == 'q2': + # For minor sector wing + # Calculate the minor wing phis + phi_min_minor = flip_phi(phi_min - math.pi) + phi_max_minor = flip_phi(phi_max - math.pi) + # Check if phis of the minor ring is within 0 to 2pi + if phi_min_minor > phi_max_minor: + is_in = (phi_value > phi_min_minor or + phi_value < phi_max_minor) + else: + is_in = (phi_value > phi_min_minor and + phi_value < phi_max_minor) + + # For all cases(i.e.,for 'q', 'q2', and 'phi') + # Find pixels within ROI + if phi_min > phi_max: + is_in = is_in or (phi_value > phi_min or + phi_value < phi_max) + else: + is_in = is_in or (phi_value >= phi_min and + phi_value < phi_max) + + # data oustide of the phi range + if not is_in: + continue + + # Get the binning index + if run.lower() == 'phi': + i_bin = binning.get_bin_index(phi_value) + else: + i_bin = binning.get_bin_index(q_value) + + # Take care of the edge case at phi = 2pi. + if i_bin == self.nbins: + i_bin = self.nbins - 1 + + # Get the total y + y[i_bin] += data_n + x[i_bin] += q_value + if err_data[n] is None or err_data[n] == 0.0: + if data_n < 0: + data_n = -data_n + y_err[i_bin] += data_n + else: + y_err[i_bin] += err_data[n]**2 + + if dq_data is not None: + # To be consistent with dq calculation in 1d reduction, + # we need just the averages (not quadratures) because + # it should not depend on the number of the q points + # in the qr bins. 
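+                    # dq_data here is the combined (qx, qy) resolution
+                    # computed by get_dq_data() above.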
+ x_err[i_bin] += dq_data[n] + else: + x_err = None + y_counts[i_bin] += 1 + + # Organize the results + for i in range(self.nbins): + y[i] = y[i] / y_counts[i] + y_err[i] = math.sqrt(y_err[i]) / y_counts[i] + + # The type of averaging: phi,q2, or q + # Calculate x[i]should be at the center of the bin + if run.lower() == 'phi': + x[i] = (self.phi_max - self.phi_min) / self.nbins * \ + (1.0 * i + 0.5) + self.phi_min + else: + # We take the center of ring area, not radius. + # This is more accurate than taking the radial center of ring. + # delta_r = (self.r_max - self.r_min) / self.nbins + # r_inner = self.r_min + delta_r * i + # r_outer = r_inner + delta_r + # x[i] = math.sqrt((r_inner * r_inner + r_outer * r_outer) / 2) + x[i] = x[i] / y_counts[i] + y_err[y_err == 0] = np.average(y_err) + idx = (np.isfinite(y) & np.isfinite(y_err)) + if x_err is not None: + d_x = x_err[idx] / y_counts[idx] + else: + d_x = None + if not idx.any(): + msg = "Average Error: No points inside sector of ROI to average..." + raise ValueError(msg) + # elif len(y[idx])!= self.nbins: + # print "resulted",self.nbins- len(y[idx]), + # "empty bin(s) due to tight binning..." + return Data1D(x=x[idx], y=y[idx], dy=y_err[idx], dx=d_x) + + +class SectorPhi(_Sector): + """ + Sector average as a function of phi. + I(phi) is return and the data is averaged over Q. + + A sector is defined by r_min, r_max, phi_min, phi_max. + The number of bin in phi also has to be defined. + """ + + def __call__(self, data2D): + """ + Perform sector average and return I(phi). + + :param data2D: Data2D object + :return: Data1D object + """ + return self._agv(data2D, 'phi') + + +class SectorQ(_Sector): + """ + Sector average as a function of Q for both symatric wings. + I(Q) is return and the data is averaged over phi. + + A sector is defined by r_min, r_max, phi_min, phi_max. + r_min, r_max, phi_min, phi_max >0. + The number of bin in Q also has to be defined. + """ + + def __call__(self, data2D): + """ + Perform sector average and return I(Q). + + :param data2D: Data2D object + + :return: Data1D object + """ + return self._agv(data2D, 'q2') + +################################################################################ + +class Ringcut(object): + """ + Defines a ring on a 2D data set. + The ring is defined by r_min, r_max, and + the position of the center of the ring. + + The data returned is the region inside the ring + + Phi_min and phi_max should be defined between 0 and 2*pi + in anti-clockwise starting from the x- axis on the left-hand side + """ + + def __init__(self, r_min=0, r_max=0, center_x=0, center_y=0): + # Minimum radius + self.r_min = r_min + # Maximum radius + self.r_max = r_max + # Center of the ring in x + self.center_x = center_x + # Center of the ring in y + self.center_y = center_y + + def __call__(self, data2D): + """ + Apply the ring to the data set. + Returns the angular distribution for a given q range + + :param data2D: Data2D object + + :return: index array in the range + """ + if data2D.__class__.__name__ not in ["Data2D", "plottable_2D"]: + raise RuntimeError("Ring cut only take plottable_2D objects") + + # Get data + qx_data = data2D.qx_data + qy_data = data2D.qy_data + q_data = np.sqrt(qx_data * qx_data + qy_data * qy_data) + + # check whether or not the data point is inside ROI + out = (self.r_min <= q_data) & (self.r_max >= q_data) + return out + +################################################################################ + +class Boxcut(object): + """ + Find a rectangular 2D region of interest. 
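+    Calling an instance returns a boolean mask with True for the points
+    inside the box.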
+    """
+
+    def __init__(self, x_min=0.0, x_max=0.0, y_min=0.0, y_max=0.0):
+        # Minimum Qx value [A-1]
+        self.x_min = x_min
+        # Maximum Qx value [A-1]
+        self.x_max = x_max
+        # Minimum Qy value [A-1]
+        self.y_min = y_min
+        # Maximum Qy value [A-1]
+        self.y_max = y_max
+
+    def __call__(self, data2D):
+        """
+        Find a rectangular 2D region of interest.
+
+        :param data2D: Data2D object
+        :return: mask, 1d array (len = len(data))
+            with True where the data points are inside the ROI,
+            otherwise False
+        """
+        mask = self._find(data2D)
+
+        return mask
+
+    def _find(self, data2D):
+        """
+        Find a rectangular 2D region of interest.
+
+        :param data2D: Data2D object
+
+        :return: out, 1d array (length = len(data))
+            with True where the data points are inside the ROI,
+            otherwise False
+        """
+        if data2D.__class__.__name__ not in ["Data2D", "plottable_2D"]:
+            raise RuntimeError("Boxcut only takes plottable_2D objects")
+        # Get qx_ and qy_data
+        qx_data = data2D.qx_data
+        qy_data = data2D.qy_data
+
+        # check whether or not the data point is inside the ROI
+        outx = (self.x_min <= qx_data) & (self.x_max > qx_data)
+        outy = (self.y_min <= qy_data) & (self.y_max > qy_data)
+
+        return outx & outy
+
+################################################################################
+
+class Sectorcut(object):
+    """
+    Defines a sector (major + minor) region on a 2D data set.
+    The sector is defined by phi_min and phi_max,
+    where phi_min and phi_max are defined by the right
+    and left lines with respect to the central line.
+
+    Phi_min and phi_max are given in radians,
+    and (phi_max - phi_min) should not be larger than pi
+    """
+
+    def __init__(self, phi_min=0, phi_max=math.pi):
+        self.phi_min = phi_min
+        self.phi_max = phi_max
+
+    def __call__(self, data2D):
+        """
+        Find the sector region of interest.
+
+        :param data2D: Data2D object
+
+        :return: mask, 1d array (len = len(data))
+            with True where the data points are inside the ROI,
+            otherwise False
+        """
+        mask = self._find(data2D)
+
+        return mask
+
+    def _find(self, data2D):
+        """
+        Find the sector region of interest.
+
+        :param data2D: Data2D object
+
+        :return: out, 1d array (length = len(data))
+            with True where the data points are inside the ROI,
+            otherwise False
+        """
+        if data2D.__class__.__name__ not in ["Data2D", "plottable_2D"]:
+            raise RuntimeError("Sectorcut only takes plottable_2D objects")
+        Pi = math.pi
+        # Get data
+        qx_data = data2D.qx_data
+        qy_data = data2D.qy_data
+
+        # get phi from data
+        phi_data = np.arctan2(qy_data, qx_data)
+
+        # Get the min and max into the region: -pi <= phi < Pi
+        phi_min_major = flip_phi(self.phi_min + Pi) - Pi
+        phi_max_major = flip_phi(self.phi_max + Pi) - Pi
+        # check for the major sector
+        if phi_min_major > phi_max_major:
+            out_major = (phi_min_major <= phi_data) + \
+                (phi_max_major > phi_data)
+        else:
+            out_major = (phi_min_major <= phi_data) & (
+                phi_max_major > phi_data)
+
+        # minor sector
+        # Get the min and max into the region: -pi <= phi < Pi
+        phi_min_minor = flip_phi(self.phi_min) - Pi
+        phi_max_minor = flip_phi(self.phi_max) - Pi
+
+        # check for the minor sector
+        if phi_min_minor > phi_max_minor:
+            out_minor = (phi_min_minor <= phi_data) + \
+                (phi_max_minor >= phi_data)
+        else:
+            out_minor = (phi_min_minor <= phi_data) & \
+                (phi_max_minor >= phi_data)
+        out = out_major + out_minor
+
+        return out
diff --git a/sas/sascalc/dataloader/readers/__init__.py b/sas/sascalc/dataloader/readers/__init__.py
new file mode 100755
index 000000000..a297c80e1
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/__init__.py
@@ -0,0 +1,2 @@
+# Method to associate extensions to default readers
+from .associations import read_associations
diff --git a/sas/sascalc/dataloader/readers/abs_reader.py b/sas/sascalc/dataloader/readers/abs_reader.py
new file mode 100755
index 000000000..1ff0f3ef5
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/abs_reader.py
@@ -0,0 +1,238 @@
+"""
+    IGOR 1D data reader
+"""
+#####################################################################
+# This software was developed by the University of Tennessee as part of the
+# Distributed Data Analysis of Neutron Scattering Experiments (DANSE)
+# project funded by the US National Science Foundation.
+# See the license text in license.txt +# copyright 2008, University of Tennessee +###################################################################### + +import logging + +import numpy as np + +from sas.sascalc.data_util.nxsunit import Converter +from ..file_reader_base_class import FileReader +from ..data_info import DataInfo, plottable_1D, Data1D, Detector +from ..loader_exceptions import FileContentsException, DefaultReaderException + +logger = logging.getLogger(__name__) + + +class Reader(FileReader): + """ + Class to load IGOR reduced .ABS files + """ + # File type + type_name = "IGOR 1D" + # Wildcards + type = ["IGOR 1D files (*.abs)|*.abs", "IGOR 1D USANS files (*.cor)|*.cor"] + # List of allowed extensions + ext = ['.abs', '.cor'] + + def get_file_contents(self): + """ + Get the contents of the file + + :raise RuntimeError: when the file can't be opened + :raise ValueError: when the length of the data vectors are inconsistent + """ + buff = self.readall() + filepath = self.f_open.name + lines = buff.splitlines() + self.output = [] + self.current_datainfo = DataInfo() + self.current_datainfo.filename = filepath + detector = Detector() + data_line = 0 + x_index = 4 + self.reset_data_list(len(lines)) + self.current_datainfo.detector.append(detector) + self.current_datainfo.filename = filepath + + is_info = False + is_center = False + is_data_started = False + + base_q_unit = '1/A' + base_i_unit = '1/cm' + data_conv_q = Converter(base_q_unit) + data_conv_i = Converter(base_i_unit) + + for line in lines: + # Information line 1 + if line.find(".bt5") > 0: + x_index = 0 + if is_info: + is_info = False + line_toks = line.split() + + # Wavelength in Angstrom + try: + value = float(line_toks[1]) + if self.current_datainfo.source.wavelength_unit != 'A': + conv = Converter('A') + self.current_datainfo.source.wavelength = conv(value, + units=self.current_datainfo.source.wavelength_unit) + else: + self.current_datainfo.source.wavelength = value + except KeyError: + msg = "ABSReader cannot read wavelength from %s" % filepath + self.current_datainfo.errors.append(msg) + + # Detector distance in meters + try: + value = float(line_toks[3]) + if detector.distance_unit != 'm': + conv = Converter('m') + detector.distance = conv(value, + units=detector.distance_unit) + else: + detector.distance = value + except Exception: + msg = "ABSReader cannot read SDD from %s" % filepath + self.current_datainfo.errors.append(msg) + + # Transmission + try: + self.current_datainfo.sample.transmission = \ + float(line_toks[4]) + except ValueError: + # Transmission isn't always in the header + pass + + # Sample thickness in mm + try: + # ABS writer adds 'C' with no space to the end of the + # thickness column. Remove it if it is there before + # converting the thickness. 
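+                # e.g. a thickness token "0.2C" is read as 0.2, while
+                # "0.2" converts directly (illustrative header values).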
+                if line_toks[5][-1] not in '0123456789.':
+                    value = float(line_toks[5][:-1])
+                else:
+                    value = float(line_toks[5])
+                if self.current_datainfo.sample.thickness_unit != 'cm':
+                    conv = Converter('cm')
+                    self.current_datainfo.sample.thickness = conv(value,
+                        units=self.current_datainfo.sample.thickness_unit)
+                else:
+                    self.current_datainfo.sample.thickness = value
+            except ValueError:
+                # Thickness is not a mandatory entry
+                pass
+
+            # MON CNT  LAMBDA  DET ANG  DET DIST  TRANS  THICK  AVE  STEP
+            if line.count("LAMBDA") > 0:
+                is_info = True
+
+            # Find center info line
+            if is_center:
+                is_center = False
+                line_toks = line.split()
+                # Center in bin number
+                center_x = float(line_toks[0])
+                center_y = float(line_toks[1])
+
+                # Bin size
+                if detector.pixel_size_unit != 'mm':
+                    conv = Converter('mm')
+                    detector.pixel_size.x = conv(5.08,
+                        units=detector.pixel_size_unit)
+                    detector.pixel_size.y = conv(5.08,
+                        units=detector.pixel_size_unit)
+                else:
+                    detector.pixel_size.x = 5.08
+                    detector.pixel_size.y = 5.08
+
+                # Store beam center in distance units
+                # Det 640 x 640 mm
+                if detector.beam_center_unit != 'mm':
+                    conv = Converter('mm')
+                    detector.beam_center.x = conv(center_x * 5.08,
+                        units=detector.beam_center_unit)
+                    detector.beam_center.y = conv(center_y * 5.08,
+                        units=detector.beam_center_unit)
+                else:
+                    detector.beam_center.x = center_x * 5.08
+                    detector.beam_center.y = center_y * 5.08
+
+                # Detector type
+                try:
+                    detector.name = line_toks[7]
+                except Exception:
+                    # Detector name is not a mandatory entry
+                    pass
+
+            # BCENT(X,Y)  A1(mm)  A2(mm)  A1A2DIST(m)  DL/L  BSTOP(mm)  DET_TYP
+            if line.count("BCENT") > 0:
+                is_center = True
+
+            # Parse the data
+            if is_data_started:
+                toks = line.split()
+
+                try:
+                    _x = float(toks[x_index])
+                    _y = float(toks[1])
+                    _dy = float(toks[2])
+                    _dx = float(toks[3])
+
+                    if data_conv_q is not None:
+                        _x = data_conv_q(_x, units=base_q_unit)
+                        _dx = data_conv_q(_dx, units=base_q_unit)
+
+                    if data_conv_i is not None:
+                        _y = data_conv_i(_y, units=base_i_unit)
+                        _dy = data_conv_i(_dy, units=base_i_unit)
+
+                    self.current_dataset.x[data_line] = _x
+                    self.current_dataset.y[data_line] = _y
+                    self.current_dataset.dy[data_line] = _dy
+                    if _dx > 0:
+                        self.current_dataset.dx[data_line] = _dx
+                    else:
+                        if data_line == 0:
+                            self.current_dataset.dx = None
+                            self.current_dataset.dxl = np.zeros(len(lines))
+                            self.current_dataset.dxw = np.zeros(len(lines))
+                        self.current_dataset.dxl[data_line] = abs(_dx)
+                        self.current_dataset.dxw[data_line] = 0
+                    data_line += 1
+
+                except ValueError:
+                    # Could not read this data line. If we are here
+                    # it is because we are in the data section. Just
+                    # skip it.
+                    pass
+
+            # SANS Data:
+            # The 6 columns are | Q (1/A) | I(Q) (1/cm) | std. dev.
+            # I(Q) (1/cm) | sigmaQ | meanQ | ShadowFactor|
+            # USANS Data:
+            # EMP LEVEL: ; BKG LEVEL:
+            if line.startswith("The 6 columns") or line.startswith("EMP LEVEL"):
+                is_data_started = True
+
+        self.remove_empty_q_values()
+
+        # Sanity check
+        if not len(self.current_dataset.y) == len(self.current_dataset.dy):
+            self.set_all_to_none()
+            msg = "abs_reader: y and dy have different lengths"
+            raise ValueError(msg)
+        # If the data length is zero, consider this as
+        # though we were not able to read the file.
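+        # e.g. a file whose header parses but whose data section yields
+        # no numeric rows ends up with len(x) == 0 and is rejected here
+        # (illustrative failure mode).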
+        if len(self.current_dataset.x) == 0:
+            self.set_all_to_none()
+            raise ValueError("abs_reader: could not load file")
+
+        self.current_dataset = self.set_default_1d_units(self.current_dataset)
+        if data_conv_q is not None:
+            self.current_dataset.xaxis("\\rm{Q}", base_q_unit)
+        if data_conv_i is not None:
+            self.current_dataset.yaxis("\\rm{Intensity}", base_i_unit)
+
+        # Store loading process information
+        self.current_datainfo.meta_data['loader'] = self.type_name
+        self.send_to_output()
diff --git a/sas/sascalc/dataloader/readers/anton_paar_saxs_reader.py b/sas/sascalc/dataloader/readers/anton_paar_saxs_reader.py
new file mode 100755
index 000000000..eabb3970c
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/anton_paar_saxs_reader.py
@@ -0,0 +1,175 @@
+"""
+    Anton Paar SAXSess data reader for reading .pdh files with embedded
+    XML metadata.
+"""
+
+import numpy as np
+import re
+import os
+import sys
+
+from sas.sascalc.dataloader.readers.xml_reader import XMLreader
+from sas.sascalc.dataloader.data_info import plottable_1D, Data1D, DataInfo, Sample, Source
+from sas.sascalc.dataloader.data_info import Process, Aperture, Collimation, TransmissionSpectrum, Detector
+from sas.sascalc.dataloader.loader_exceptions import FileContentsException, DataReaderException
+
+class Reader(XMLreader):
+    """
+    A class for reading in Anton Paar .pdh files
+    """
+
+    ## Logged warnings or messages
+    logging = None
+    ## List of errors for the current data set
+    errors = None
+    ## Raw file contents to be processed
+    raw_data = None
+    ## For recursion and saving purposes, remember parent objects
+    parent_list = None
+    ## Data type name
+    type_name = "Anton Paar SAXSess"
+    ## Wildcards
+    type = ["Anton Paar SAXSess Files (*.pdh)|*.pdh"]
+    ## List of allowed extensions
+    ext = ['.pdh', '.PDH']
+    ## Flag to bypass extension check
+    allow_all = False
+
+    def reset_state(self):
+        self.current_dataset = plottable_1D(np.empty(0), np.empty(0), np.empty(0), np.empty(0))
+        self.current_datainfo = DataInfo()
+        self.datasets = []
+        self.raw_data = None
+        self.errors = set()
+        self.logging = []
+        self.output = []
+        self.detector = Detector()
+        self.collimation = Collimation()
+        self.aperture = Aperture()
+        self.process = Process()
+        self.source = Source()
+        self.sample = Sample()
+        self.trans_spectrum = TransmissionSpectrum()
+        self.upper = 5
+        self.lower = 5
+
+    def get_file_contents(self):
+        """
+        This is the general read method that all SasView data_loaders must have.
+
+        :param filename: A path for an XML formatted Anton Paar SAXSess data file.
+        :return: List of Data1D objects or a list of errors.
+ """ + + ## Reinitialize the class when loading a new data file to reset all class variables + self.reset_state() + buff = self.readall() + self.raw_data = buff.splitlines() + self.read_data() + + def read_data(self): + correctly_loaded = True + error_message = "" + + q_unit = "1/nm" + i_unit = "1/um^2" + try: + self.current_datainfo.title = self.raw_data[0] + self.current_datainfo.meta_data["Keywords"] = self.raw_data[1] + line3 = self.raw_data[2].split() + line4 = self.raw_data[3].split() + line5 = self.raw_data[4].split() + self.data_points = int(line3[0]) + self.lower = 5 + self.upper = self.lower + self.data_points + self.source.radiation = 'x-ray' + normal = float(line4[3]) + self.current_datainfo.source.radiation = "x-ray" + self.current_datainfo.source.name = "Anton Paar SAXSess Instrument" + self.current_datainfo.source.wavelength = float(line4[4]) + xvals = [] + yvals = [] + dyvals = [] + for i in range(self.lower, self.upper): + index = i - self.lower + data = self.raw_data[i].split() + xvals.insert(index, normal * float(data[0])) + yvals.insert(index, normal * float(data[1])) + dyvals.insert(index, normal * float(data[2])) + except Exception as e: + error_message = "Couldn't load {}.\n".format(self.f_open.name) + error_message += e.message + raise FileContentsException(error_message) + self.current_dataset.x = np.append(self.current_dataset.x, xvals) + self.current_dataset.y = np.append(self.current_dataset.y, yvals) + self.current_dataset.dy = np.append(self.current_dataset.dy, dyvals) + if self.data_points != self.current_dataset.x.size: + error_message += "Not all data points could be loaded.\n" + correctly_loaded = False + if self.current_dataset.x.size != self.current_dataset.y.size: + error_message += "The x and y data sets are not the same size.\n" + correctly_loaded = False + if self.current_dataset.y.size != self.current_dataset.dy.size: + error_message += "The y and dy datasets are not the same size.\n" + correctly_loaded = False + + self.current_dataset.xaxis("Q", q_unit) + self.current_dataset.yaxis("Intensity", i_unit) + xml_intermediate = self.raw_data[self.upper:] + xml = ''.join(xml_intermediate) + try: + self.set_xml_string(xml) + dom = self.xmlroot.xpath('/fileinfo') + self._parse_child(dom) + except Exception as e: + # Data loaded but XML metadata has an error + error_message += "Data points have been loaded but there was an " + error_message += "error reading XML metadata: " + e.message + correctly_loaded = False + self.send_to_output() + if not correctly_loaded: + raise DataReaderException(error_message) + + def _parse_child(self, dom, parent=''): + """ + Recursive method for stepping through the embedded XML + :param dom: XML node with or without children + """ + for node in dom: + tagname = node.tag + value = node.text + attr = node.attrib + key = attr.get("key", '') + if len(node.getchildren()) > 1: + self._parse_child(node, key) + if key == "SampleDetector": + self.current_datainfo.detector.append(self.detector) + self.detector = Detector() + else: + if key == "value": + if parent == "Wavelength": + self.current_datainfo.source.wavelength = value + elif parent == "SampleDetector": + self.detector.distance = value + elif parent == "Temperature": + self.current_datainfo.sample.temperature = value + elif parent == "CounterSlitLength": + self.detector.slit_length = value + elif key == "unit": + value = value.replace("_", "") + if parent == "Wavelength": + self.current_datainfo.source.wavelength_unit = value + elif parent == "SampleDetector": + 
self.detector.distance_unit = value
+                    elif parent == "X":
+                        self.current_dataset.xaxis(self.current_dataset._xaxis, value)
+                    elif parent == "Y":
+                        self.current_dataset.yaxis(self.current_dataset._yaxis, value)
+                    elif parent == "Temperature":
+                        self.current_datainfo.sample.temperature_unit = value
+                    elif parent == "CounterSlitLength":
+                        self.detector.slit_length_unit = value
+                elif key == "quantity":
+                    if parent == "X":
+                        self.current_dataset.xaxis(value, self.current_dataset._xunit)
+                    elif parent == "Y":
+                        self.current_dataset.yaxis(value, self.current_dataset._yunit)
diff --git a/sas/sascalc/dataloader/readers/ascii_reader.py b/sas/sascalc/dataloader/readers/ascii_reader.py
new file mode 100755
index 000000000..c4fd62fb4
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/ascii_reader.py
@@ -0,0 +1,163 @@
+"""
+    Generic multi-column ASCII data reader
+"""
+############################################################################
+# This software was developed by the University of Tennessee as part of the
+# Distributed Data Analysis of Neutron Scattering Experiments (DANSE)
+# project funded by the US National Science Foundation.
+# If you use DANSE applications to do scientific research that leads to
+# publication, we ask that you acknowledge the use of the software with the
+# following sentence:
+# This work benefited from DANSE software developed under NSF award DMR-0520547.
+# copyright 2008, University of Tennessee
+#############################################################################
+
+import logging
+from sas.sascalc.dataloader.file_reader_base_class import FileReader
+from sas.sascalc.dataloader.data_info import DataInfo, plottable_1D
+from sas.sascalc.dataloader.loader_exceptions import FileContentsException,\
+    DefaultReaderException
+
+logger = logging.getLogger(__name__)
+
+
+class Reader(FileReader):
+    """
+    Class to load ascii files (2, 3 or 4 columns).
+    """
+    # File type
+    type_name = "ASCII"
+    # Wildcards
+    type = ["ASCII files (*.txt)|*.txt",
+            "ASCII files (*.dat)|*.dat",
+            "ASCII files (*.abs)|*.abs",
+            "CSV files (*.csv)|*.csv"]
+    # List of allowed extensions
+    ext = ['.txt', '.dat', '.abs', '.csv']
+    # Flag to bypass extension check
+    allow_all = True
+    # Minimum number of consecutive numeric lines required before a
+    # block is treated as data
+    min_data_pts = 5
+
+    def get_file_contents(self):
+        """
+        Get the contents of the file
+        """
+
+        buff = self.readall()
+        filepath = self.f_open.name
+        lines = buff.splitlines()
+        self.output = []
+        self.current_datainfo = DataInfo()
+        self.current_datainfo.filename = filepath
+        self.reset_data_list(len(lines))
+
+        # The first good line of data will define whether
+        # we have 2-, 3-, or 4-column ascii data
+        has_error_dx = None
+        has_error_dy = None
+
+        # Initialize counters for data lines and header lines.
+        is_data = False
+        # A run of at least min_data_pts numeric lines is considered
+        # actual data.
+        # Number of data candidate lines in the current block
+        candidate_lines = 0
+        # Total number of data candidate lines in previous blocks
+        candidate_lines_previous = 0
+        # Current line number
+        line_no = 0
+        # minimum required number of columns of data
+        lentoks = 2
+        for line in lines:
+            toks = self.splitline(line.strip())
+            # To remember the number of columns in the current line of data
+            new_lentoks = len(toks)
+            try:
+                if new_lentoks == 0:
+                    # If the line is blank, skip it and continue on,
+                    # in case of breaks within data sets.
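+                    # e.g. (illustrative rows):
+                    #     0.1  2.30  0.05  0.002   -> x, y, dy, dx
+                    #     <blank line: skipped, candidate count kept>
+                    #     0.2  1.90  0.04  0.002   -> same data set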
+                    continue
+                elif new_lentoks != lentoks and is_data:
+                    # If a footer is found, break the loop and save the data
+                    break
+                elif new_lentoks != lentoks and not is_data:
+                    # If header lines are numerical, restart the
+                    # candidate count
+                    candidate_lines = 0
+                    self.reset_data_list(len(lines) - line_no)
+
+                self.current_dataset.x[candidate_lines] = float(toks[0])
+
+                if new_lentoks > 1:
+                    self.current_dataset.y[candidate_lines] = float(toks[1])
+
+                # If a 3rd column is present, consider it dy
+                if new_lentoks > 2:
+                    self.current_dataset.dy[candidate_lines] = \
+                        float(toks[2])
+                    has_error_dy = True
+
+                # If a 4th column is present, consider it dx
+                if new_lentoks > 3:
+                    self.current_dataset.dx[candidate_lines] = \
+                        float(toks[3])
+                    has_error_dx = True
+
+                candidate_lines += 1
+                # Once min_data_pts lines have accumulated, treat the
+                # block as the data set
+                if candidate_lines >= self.min_data_pts:
+                    is_data = True
+
+                if is_data and new_lentoks >= 8:
+                    msg = "This data looks like 2D ASCII data. Use the file "
+                    msg += "converter tool to convert it to NXcanSAS."
+                    raise FileContentsException(msg)
+
+                # To remember the # of columns on the current line
+                # for the next line of data
+                lentoks = new_lentoks
+                line_no += 1
+            except ValueError:
+                # ValueError is raised when a non-numeric string is
+                # converted to float. If we were in data and hit a
+                # non-number, stop reading.
+                if is_data:
+                    break
+                # Delete the previously stored lines of data candidates
+                # if the block is not data
+                self.reset_data_list(len(lines) - line_no)
+                lentoks = 2
+                has_error_dx = None
+                has_error_dy = None
+                # Reset # of lines of data candidates
+                candidate_lines = 0
+
+        if not is_data:
+            self.set_all_to_none()
+            if self.extension in self.ext:
+                msg = "ASCII Reader error: Fewer than five Q data points found "
+                msg += "in {}.".format(filepath)
+                raise FileContentsException(msg)
+            else:
+                msg = "ASCII Reader could not load the file {}".format(filepath)
+                raise DefaultReaderException(msg)
+        # Sanity check
+        if has_error_dy and not len(self.current_dataset.y) == \
+                len(self.current_dataset.dy):
+            msg = "ASCII Reader error: Number of I and dI data points are"
+            msg += " different in {}.".format(filepath)
+            # TODO: Add error to self.current_datainfo.errors instead?
+            self.set_all_to_none()
+            raise FileContentsException(msg)
+        if has_error_dx and not len(self.current_dataset.x) == \
+                len(self.current_dataset.dx):
+            msg = "ASCII Reader error: Number of Q and dQ data points are"
+            msg += " different in {}.".format(filepath)
+            # TODO: Add error to self.current_datainfo.errors instead?
+            self.set_all_to_none()
+            raise FileContentsException(msg)
+
+        self.remove_empty_q_values()
+        self.current_dataset = self.set_default_1d_units(self.current_dataset)
+
+        # Store loading process information
+        self.current_datainfo.meta_data['loader'] = self.type_name
+        self.send_to_output()
diff --git a/sas/sascalc/dataloader/readers/associations.py b/sas/sascalc/dataloader/readers/associations.py
new file mode 100755
index 000000000..ccaf45681
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/associations.py
@@ -0,0 +1,59 @@
+"""
+Module to associate default readers to file extensions.
+A hard-coded mapping gives the default reader for each file extension.
+The readers are tried in the order they appear when reading a file.
+"""
+############################################################################
+#This software was developed by the University of Tennessee as part of the
+#Distributed Data Analysis of Neutron Scattering Experiments (DANSE)
+#project funded by the US National Science Foundation.
+#If you use DANSE applications to do scientific research that leads to
+#publication, we ask that you acknowledge the use of the software with the
+#following sentence:
+#This work benefited from DANSE software developed under NSF award DMR-0520547.
+#copyright 2009, University of Tennessee
+#############################################################################
+import sys
+import logging
+
+logger = logging.getLogger(__name__)
+
+FILE_ASSOCIATIONS = {
+    ".xml": "cansas_reader",
+    ".ses": "sesans_reader",
+    ".h5": "cansas_reader_HDF5",
+    ".nxs": "cansas_reader_HDF5",
+    ".txt": "ascii_reader",
+    ".dat": "red2d_reader",
+    ".abs": "abs_reader",
+    ".cor": "abs_reader",
+    ".sans": "danse_reader",
+    ".pdh": "anton_paar_saxs_reader"
+}
+
+
+def read_associations(loader, settings=FILE_ASSOCIATIONS):
+    """
+    Use the specified mapping to associate default readers with
+    file extensions.
+
+    :param loader: Loader object
+    :param settings: dictionary mapping file extensions to reader
+        module names
+    """
+    # For each FileType entry, get the associated reader and extension
+    for ext, reader in settings.items():
+        if reader is not None and ext is not None:
+            # Associate the extension with a particular reader
+            # TODO: Modify the Register code to be case-insensitive
+            # FIXME: Remove exec statements
+            # and remove the extra line below.
+            try:
+                exec("from . import %s" % reader)
+                exec("loader.associate_file_type('%s', %s)"
+                     % (ext.lower(), reader))
+                exec("loader.associate_file_type('%s', %s)"
+                     % (ext.upper(), reader))
+            except:
+                msg = "read_associations: skipping association"
+                msg += " for %s\n %s" % (ext.lower(), sys.exc_value)
+                logger.error(msg)
diff --git a/sas/sascalc/dataloader/readers/cansas_constants.py b/sas/sascalc/dataloader/readers/cansas_constants.py
new file mode 100755
index 000000000..58b6dacf5
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/cansas_constants.py
@@ -0,0 +1,448 @@
+"""
+Information relating to the CanSAS data format. These constants are used in
+the cansas_reader.py file to read in any version of the cansas format.
+"""
+class CansasConstants(object):
+    """
+    The base class to define where all of the data is to be saved by
+    cansas_reader.py.
+    """
+    names = ''
+    format = ''
+
+    def __init__(self):
+        self.names = self.CANSAS_NS
+        self.format = self.CANSAS_FORMAT
+
+    def iterate_namespace(self, namespace):
+        """
+        Method to iterate through a cansas constants tree based on a list of
+        names
+
+        :param namespace: A list of names that match the tree structure of
+            cansas_constants
+        """
+        # The current level to look through in cansas_constants.
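+        # e.g. namespace = ["SASentry", "SASsample", "thickness"] walks
+        # down to SASSAMPLE_THICK and yields ns_datatype == "float"
+        # (an illustrative path through the constants defined below).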
+        return_me = CurrentLevel()
+        return_me.current_level = self.CANSAS_FORMAT.get("SASentry")
+        # Defaults for variable and datatype
+        return_me.ns_datatype = "content"
+        return_me.ns_optional = True
+        for name in namespace:
+            try:
+                if name != "SASentry":
+                    return_me.current_level = \
+                        return_me.current_level.get("children").get(name, "")
+                    if return_me.current_level == "":
+                        return_me.current_level = \
+                            return_me.current_level.get("", "")
+                    cl_datatype = return_me.current_level.get("storeas", "")
+                    cl_units_optional = \
+                        return_me.current_level.get("units_optional", "")
+                    # Work out where and how to store the variable for the
+                    # given namespace. The CANSAS_CONSTANTS tree is
+                    # hierarchical, so if there is no value at this level,
+                    # inherit from the parent.
+                    return_me.ns_datatype = cl_datatype if cl_datatype != "" \
+                        else return_me.ns_datatype
+                    return_me.ns_optional = cl_units_optional if \
+                        cl_units_optional != return_me.ns_optional \
+                        else return_me.ns_optional
+            except AttributeError:
+                return_me.ns_datatype = "content"
+                return_me.ns_optional = True
+        return return_me
+
+    def get_namespace_map(self):
+        """
+        Helper method to get the names namespace list
+        """
+        return self.names
+
+    # CANSAS_NS holds the base namespace and default schema file information
+    CANSAS_NS = {"1.0" : {"ns" : "cansas1d/1.0",
+                          "schema" : "cansas1d_v1_0.xsd"
+                         },
+                 "1.1" : {"ns" : "urn:cansas1d:1.1",
+                          "schema" : "cansas1d_v1_1.xsd"
+                         }
+                }
+
+    # The constants below hold information on where to store the CanSAS data
+    # when loaded in using sasview
+    ANY = {"storeas" : "content"}
+    TITLE = {}
+    SASNOTE = {}
+    SASPROCESS_TERM = {"attributes" : {"unit" : {}, "name" : {}}}
+    SASPROCESS_SASPROCESSNOTE = {"children" : {"" : ANY}}
+    SASPROCESS = {"children" : {"name" : {},
+                                "date" : {},
+                                "description" : {},
+                                "term" : SASPROCESS_TERM,
+                                "SASprocessnote" : SASPROCESS_SASPROCESSNOTE,
+                                "" : ANY
+                               },
+                 }
+    RUN = {"attributes" : {"name" : {}}}
+    SASDATA_IDATA_Q = {"units_optional" : False,
+                       "storeas" : "float",
+                       "unit" : "x_unit",
+                       "attributes" : {"unit" : {"storeas" : "content"}},
+                      }
+    SASDATA_IDATA_I = {"units_optional" : False,
+                       "storeas" : "float",
+                       "unit" : "y_unit",
+                       "attributes" : {"unit" : {"storeas" : "content"}},
+                      }
+    SASDATA_IDATA_IDEV = {"units_optional" : False,
+                          "storeas" : "float",
+                          "unit" : "y_unit",
+                          "attributes" : {"unit" : {"storeas" : "content"}},
+                         }
+    SASDATA_IDATA_QDEV = {"units_optional" : False,
+                          "storeas" : "float",
+                          "unit" : "x_unit",
+                          "attributes" : {"unit" : {"storeas" : "content"}},
+                         }
+    SASDATA_IDATA_DQL = {"units_optional" : False,
+                         "storeas" : "float",
+                         "unit" : "x_unit",
+                         "attributes" : {"unit" : {"storeas" : "content"}},
+                        }
+    SASDATA_IDATA_DQW = {"units_optional" : False,
+                         "storeas" : "float",
+                         "unit" : "x_unit",
+                         "attributes" : {"unit" : {"storeas" : "content"}},
+                        }
+    SASDATA_IDATA_QMEAN = {"unit" : "x_unit",
+                           "attributes" : {"unit" : {}},
+                          }
+    SASDATA_IDATA_SHADOWFACTOR = {}
+    SASDATA_IDATA = {"attributes" : {"name" : {}, "timestamp" : {"storeas" : "timestamp"}},
+                     "children" : {"Q" : SASDATA_IDATA_Q,
+                                   "I" : SASDATA_IDATA_I,
+                                   "Idev" : SASDATA_IDATA_IDEV,
+                                   "Qdev" : SASDATA_IDATA_QDEV,
+                                   "dQw" : SASDATA_IDATA_DQW,
+                                   "dQl" : SASDATA_IDATA_DQL,
+                                   "Qmean" : SASDATA_IDATA_QMEAN,
+                                   "Shadowfactor" : SASDATA_IDATA_SHADOWFACTOR,
+                                   "" : ANY
+                                  }
+                    }
+    SASDATA = {"attributes" : {"name" : {}},
+               "variable" : None,
+               "children" : {"Idata" : SASDATA_IDATA,
+                             "Sesans": {"storeas": "content"},
+                             "zacceptance": {"storeas": "float"},
+                             "yacceptance": {"storeas": "float"},
+                             "" : ANY
+                            }
+              }
+    SASTRANSSPEC_TDATA_LAMDBA = \
{"storeas" : "float", + "unit" : "wavelength_unit", + "attributes" : {"unit" : {"storeas" : "content"}} + } + SASTRANSSPEC_TDATA_T = {"storeas" : "float", + "unit" : "transmission_unit", + "attributes" : {"unit" : {"storeas" : "content"}} + } + SASTRANSSPEC_TDATA_TDEV = {"storeas" : "float", + "unit" : "transmission_deviation_unit", + "attributes" : {"unit" :{"storeas" : "content"}} + } + SASTRANSSPEC_TDATA = {"children" : {"Lambda" : SASTRANSSPEC_TDATA_LAMDBA, + "T" : SASTRANSSPEC_TDATA_T, + "Tdev" : SASTRANSSPEC_TDATA_TDEV, + "" : ANY, + } + } + SASTRANSSPEC = {"children" : {"Tdata" : SASTRANSSPEC_TDATA, + "" : ANY, + }, + "attributes" : {"name" :{}, "timestamp" : {},} + } + SASSAMPLE_THICK = {"unit" : "thickness_unit", + "storeas" : "float", + "attributes" : {"unit" :{}}, + } + SASSAMPLE_TRANS = {"storeas" : "float",} + SASSAMPLE_TEMP = {"unit" : "temperature_unit", + "storeas" : "float", + "attributes" :{"unit" :{}}, + } + SASSAMPLE_POS_ATTR = {"unit" : {}} + SASSAMPLE_POS_X = {"unit" : "position_unit", + "storeas" : "float", + "attributes" : SASSAMPLE_POS_ATTR + } + SASSAMPLE_POS_Y = {"unit" : "position_unit", + "storeas" : "float", + "attributes" : SASSAMPLE_POS_ATTR + } + SASSAMPLE_POS_Z = {"unit" : "position_unit", + "storeas" : "float", + "attributes" : SASSAMPLE_POS_ATTR + } + SASSAMPLE_POS = {"children" : {"x" : SASSAMPLE_POS_X, + "y" : SASSAMPLE_POS_Y, + "z" : SASSAMPLE_POS_Z, + }, + } + SASSAMPLE_ORIENT_ATTR = {"unit" :{}} + SASSAMPLE_ORIENT_ROLL = {"unit" : "orientation_unit", + "storeas" : "float", + "attributes" : SASSAMPLE_ORIENT_ATTR + } + SASSAMPLE_ORIENT_PITCH = {"unit" : "orientation_unit", + "storeas" : "float", + "attributes" : SASSAMPLE_ORIENT_ATTR + } + SASSAMPLE_ORIENT_YAW = {"unit" : "orientation_unit", + "storeas" : "float", + "attributes" : SASSAMPLE_ORIENT_ATTR + } + SASSAMPLE_ORIENT = {"children" : {"roll" : SASSAMPLE_ORIENT_ROLL, + "pitch" : SASSAMPLE_ORIENT_PITCH, + "yaw" : SASSAMPLE_ORIENT_YAW, + }, + } + SASSAMPLE = {"attributes" : + {"name" : {},}, + "children" : {"ID" : {}, + "thickness" : SASSAMPLE_THICK, + "transmission" : SASSAMPLE_TRANS, + "temperature" : SASSAMPLE_TEMP, + "position" : SASSAMPLE_POS, + "orientation" : SASSAMPLE_ORIENT, + "details" : {}, + "" : ANY + }, + } + SASINSTR_SRC_BEAMSIZE_ATTR = {"unit" : ""} + SASINSTR_SRC_BEAMSIZE_X = {"unit" : "beam_size_unit", + "storeas" : "float", + "attributes" : SASINSTR_SRC_BEAMSIZE_ATTR + } + SASINSTR_SRC_BEAMSIZE_Y = {"unit" : "beam_size_unit", + "storeas" : "float", + "attributes" : SASINSTR_SRC_BEAMSIZE_ATTR + } + SASINSTR_SRC_BEAMSIZE_Z = {"unit" : "beam_size_unit", + "storeas" : "float", + "attributes" : SASINSTR_SRC_BEAMSIZE_ATTR + } + SASINSTR_SRC_BEAMSIZE = {"attributes" : {"name" : {}}, + "children" : {"x" : SASINSTR_SRC_BEAMSIZE_X, + "y" : SASINSTR_SRC_BEAMSIZE_Y, + "z" : SASINSTR_SRC_BEAMSIZE_Z, + } + } + SASINSTR_SRC_WL = {"unit" : "wavelength_unit", + "storeas" : "float", + "attributes" : {"unit" :{}, + } + } + SASINSTR_SRC_WL_MIN = {"unit" : "wavelength_min_unit", + "storeas" : "float", + "attributes" : {"unit" :{"storeas" : "content"},} + } + SASINSTR_SRC_WL_MAX = {"unit" : "wavelength_max_unit", + "storeas" : "float", + "attributes" : {"unit" :{"storeas" : "content"},} + } + SASINSTR_SRC_WL_SPR = {"unit" : "wavelength_spread_unit", + "storeas" : "float", + "attributes" : {"unit" : {"storeas" : "content"},} + } + SASINSTR_SRC = {"attributes" : {"name" : {}}, + "children" : {"radiation" : {}, + "beam_size" : SASINSTR_SRC_BEAMSIZE, + "beam_shape" : {}, + "wavelength" : SASINSTR_SRC_WL, 
+ "wavelength_min" : SASINSTR_SRC_WL_MIN, + "wavelength_max" : SASINSTR_SRC_WL_MAX, + "wavelength_spread" : SASINSTR_SRC_WL_SPR, + }, + } + SASINSTR_COLL_APER_ATTR = {"unit" : {}} + SASINSTR_COLL_APER_X = {"unit" : "size_unit", + "storeas" : "float", + "attributes" : SASINSTR_COLL_APER_ATTR + } + SASINSTR_COLL_APER_Y = {"unit" : "size_unit", + "storeas" : "float", + "attributes" : SASINSTR_COLL_APER_ATTR + } + SASINSTR_COLL_APER_Z = {"unit" : "size_unit", + "storeas" : "float", + "attributes" : SASINSTR_COLL_APER_ATTR + } + SASINSTR_COLL_APER_SIZE = {"attributes" : {"unit" : {}}, + "children" : {"storeas" : "float", + "x" : SASINSTR_COLL_APER_X, + "y" : SASINSTR_COLL_APER_Y, + "z" : SASINSTR_COLL_APER_Z, + } + } + SASINSTR_COLL_APER_DIST = {"storeas" : "float", + "attributes" : {"unit" : {}}, + "unit" : "distance_unit", + } + SASINSTR_COLL_APER = {"attributes" : {"name" : {}, "type" : {}, }, + "children" : {"size" : SASINSTR_COLL_APER_SIZE, + "distance" : SASINSTR_COLL_APER_DIST + } + } + SASINSTR_COLL = {"attributes" : {"name" : {}}, + "children" : + {"length" : + {"unit" : "length_unit", + "storeas" : "float", + "attributes" : {"storeas" : "content", "unit" : {}}, + }, + "aperture" : SASINSTR_COLL_APER, + }, + } + SASINSTR_DET_SDD = {"storeas" : "float", + "unit" : "distance_unit", + "attributes" : {"unit" :{}}, + } + SASINSTR_DET_OFF_ATTR = {"unit" : {"storeas" : "content" }} + SASINSTR_DET_OFF_X = {"storeas" : "float", + "unit" : "offset_unit", + "attributes" : SASINSTR_DET_OFF_ATTR + } + SASINSTR_DET_OFF_Y = {"storeas" : "float", + "unit" : "offset_unit", + "attributes" : SASINSTR_DET_OFF_ATTR + } + SASINSTR_DET_OFF_Z = {"storeas" : "float", + "unit" : "offset_unit", + "attributes" : SASINSTR_DET_OFF_ATTR + } + SASINSTR_DET_OFF = {"children" : {"x" : SASINSTR_DET_OFF_X, + "y" : SASINSTR_DET_OFF_Y, + "z" : SASINSTR_DET_OFF_Z, + } + } + SASINSTR_DET_OR_ATTR = {} + SASINSTR_DET_OR_ROLL = {"storeas" : "float", + "unit" : "orientation_unit", + "attributes" : SASINSTR_DET_OR_ATTR + } + SASINSTR_DET_OR_PITCH = {"storeas" : "float", + "unit" : "orientation_unit", + "attributes" : SASINSTR_DET_OR_ATTR + } + SASINSTR_DET_OR_YAW = {"storeas" : "float", + "unit" : "orientation_unit", + "attributes" : SASINSTR_DET_OR_ATTR + } + SASINSTR_DET_OR = {"children" : {"roll" : SASINSTR_DET_OR_ROLL, + "pitch" : SASINSTR_DET_OR_PITCH, + "yaw" : SASINSTR_DET_OR_YAW, + } + } + SASINSTR_DET_BC_X = {"storeas" : "float", + "unit" : "beam_center_unit", + "attributes" : {"storeas" : "content"} + } + SASINSTR_DET_BC_Y = {"storeas" : "float", + "unit" : "beam_center_unit", + "attributes" : {"storeas" : "content"} + } + SASINSTR_DET_BC_Z = {"storeas" : "float", + "unit" : "beam_center_unit", + "attributes" : {"storeas" : "content"} + } + SASINSTR_DET_BC = {"children" : {"x" : SASINSTR_DET_BC_X, + "y" : SASINSTR_DET_BC_Y, + "z" : SASINSTR_DET_BC_Z,} + } + SASINSTR_DET_PIXEL_X = {"storeas" : "float", + "unit" : "pixel_size_unit", + "attributes" : {"storeas" : "content" } + } + SASINSTR_DET_PIXEL_Y = {"storeas" : "float", + "unit" : "pixel_size_unit", + "attributes" : {"storeas" : "content"} + } + SASINSTR_DET_PIXEL_Z = {"storeas" : "float", + "unit" : "pixel_size_unit", + "attributes" : {"storeas" : "content"} + } + SASINSTR_DET_PIXEL = {"children" : {"x" : SASINSTR_DET_PIXEL_X, + "y" : SASINSTR_DET_PIXEL_Y, + "z" : SASINSTR_DET_PIXEL_Z, + } + } + SASINSTR_DET_SLIT = {"storeas" : "float", + "unit" : "slit_length_unit", + "attributes" : {"unit" : {}} + } + SASINSTR_DET = {"attributes" : {"name" : {"storeas" : 
"content"}}, + "children" : {"name" : {"storeas" : "content"}, + "SDD" : SASINSTR_DET_SDD, + "offset" : SASINSTR_DET_OFF, + "orientation" : SASINSTR_DET_OR, + "beam_center" : SASINSTR_DET_BC, + "pixel_size" : SASINSTR_DET_PIXEL, + "slit_length" : SASINSTR_DET_SLIT, + } + } + SASINSTR = {"children" : + {"name" : {}, + "SASsource" : SASINSTR_SRC, + "SAScollimation" : SASINSTR_COLL, + "SASdetector" : SASINSTR_DET, + }, + } + CANSAS_FORMAT = {"SASentry" : + {"units_optional" : True, + "storeas" : "content", + "attributes" : {"name" : {}}, + "children" : {"Title" : TITLE, + "Run" : RUN, + "SASdata" : SASDATA, + "SAStransmission_spectrum" : SASTRANSSPEC, + "SASsample" : SASSAMPLE, + "SASinstrument" : SASINSTR, + "SASprocess" : SASPROCESS, + "SASnote" : SASNOTE, + "" : ANY, + } + } + } + + +class CurrentLevel(object): + """ + A helper class to hold information on where you are in the constants tree + """ + + current_level = '' + ns_datatype = '' + ns_optional = True + + def __init__(self): + self.current_level = {} + self.ns_datatype = "content" + self.ns_optional = True + + def get_current_level(self): + """ + Helper method to get the current_level map + """ + return self.current_level + + def get_data_type(self): + """ + Helper method to get the ns_datatype label + """ + return self.ns_datatype + + def get_variable(self): + """ + Helper method to get the ns_variable label + """ + return self.ns_variable diff --git a/sas/sascalc/dataloader/readers/cansas_reader.py b/sas/sascalc/dataloader/readers/cansas_reader.py new file mode 100755 index 000000000..755c27a33 --- /dev/null +++ b/sas/sascalc/dataloader/readers/cansas_reader.py @@ -0,0 +1,1335 @@ +import logging +import os +import sys +import datetime +import inspect + +import numpy as np + +# The following 2 imports *ARE* used. Do not remove either. +import xml.dom.minidom +from xml.dom.minidom import parseString + +from lxml import etree + +from sas.sascalc.data_util.nxsunit import Converter + +# For saving individual sections of data +from ..data_info import Data1D, Data2D, DataInfo, plottable_1D, plottable_2D, \ + Collimation, TransmissionSpectrum, Detector, Process, Aperture +from ..loader_exceptions import FileContentsException, DefaultReaderException, \ + DataReaderException +from . import xml_reader +from .xml_reader import XMLreader +from .cansas_constants import CansasConstants + +logger = logging.getLogger(__name__) + +PREPROCESS = "xmlpreprocess" +ENCODING = "encoding" +RUN_NAME_DEFAULT = "None" +INVALID_SCHEMA_PATH_1_1 = "{0}/sas/sascalc/dataloader/readers/schema/cansas1d_invalid_v1_1.xsd" +INVALID_SCHEMA_PATH_1_0 = "{0}/sas/sascalc/dataloader/readers/schema/cansas1d_invalid_v1_0.xsd" +INVALID_XML = "\n\nThe loaded xml file, {0} does not fully meet the CanSAS v1.x specification. 
SasView loaded " + \
+              "as much of the data as possible.\n\n"
+
+CONSTANTS = CansasConstants()
+CANSAS_FORMAT = CONSTANTS.format
+CANSAS_NS = CONSTANTS.names
+ALLOW_ALL = True
+
+class Reader(XMLreader):
+    cansas_version = "1.0"
+    base_ns = "{cansas1d/1.0}"
+    cansas_defaults = None
+    type_name = "canSAS"
+    invalid = True
+    frm = ""
+    # Log messages and errors
+    logging = None
+    errors = set()
+    # Namespace hierarchy for current xml_file object
+    names = None
+    ns_list = None
+    # Temporary storage location for loading multiple data sets in a single file
+    current_data1d = None
+    data = None
+    # Wildcards
+    type = ["XML files (*.xml)|*.xml", "SasView Save Files (*.svs)|*.svs"]
+    # List of allowed extensions
+    ext = ['.xml', '.svs']
+    # Flag to bypass extension check
+    allow_all = True
+
+    def reset_state(self):
+        """
+        Resets the class state to a base case when loading a new data file
+        so previous data files do not appear a second time
+        """
+        super(Reader, self).reset_state()
+        self.data = []
+        self.process = Process()
+        self.transspectrum = TransmissionSpectrum()
+        self.aperture = Aperture()
+        self.collimation = Collimation()
+        self.detector = Detector()
+        self.names = []
+        self.cansas_defaults = {}
+        self.ns_list = None
+        self.logging = []
+        self.encoding = None
+
+    def _read(self, xml_file, schema_path="", invalid=True):
+        if schema_path != "" or not invalid:
+            # read has been called from self.get_file_contents because the
+            # xml file doesn't conform to the schema
+            _, self.extension = os.path.splitext(os.path.basename(xml_file))
+            return self.get_file_contents(xml_file=xml_file, schema_path=schema_path, invalid=invalid)
+
+        # Otherwise, read has been called by the data loader - file_reader_base_class handles this
+        return super(XMLreader, self).read(xml_file)
+
+    def get_file_contents(self):
+        return self._get_file_contents(xml_file=None, schema_path="", invalid=True)
+
+    def _get_file_contents(self, xml_file=None, schema_path="", invalid=True):
+        # Reset everything since we're loading a new file
+        self.reset_state()
+        self.invalid = invalid
+        if xml_file is None:
+            xml_file = self.f_open.name
+        # We don't use f_open since lxml handles opening/closing files
+        try:
+            # Raises FileContentsException
+            self.load_file_and_schema(xml_file, schema_path)
+            # Parse each SASentry
+            entry_list = self.xmlroot.xpath('/ns:SASroot/ns:SASentry',
+                                            namespaces={
+                                                'ns': self.cansas_defaults.get(
+                                                    "ns")
+                                            })
+            self.is_cansas(self.extension)
+            self.set_processing_instructions()
+            for entry in entry_list:
+                self._parse_entry(entry)
+                self.data_cleanup()
+        except FileContentsException as fc_exc:
+            # File doesn't meet schema - try loading with a less strict schema
+            base_name = xml_reader.__file__
+            base_name = base_name.replace("\\", "/")
+            base = base_name.split("/sas/")[0]
+            if self.cansas_version == "1.1":
+                invalid_schema = INVALID_SCHEMA_PATH_1_1.format(base, self.cansas_defaults.get("schema"))
+            else:
+                invalid_schema = INVALID_SCHEMA_PATH_1_0.format(base, self.cansas_defaults.get("schema"))
+            self.set_schema(invalid_schema)
+            if self.invalid:
+                try:
+                    # Load data with the less strict schema
+                    self._get_file_contents(xml_file, invalid_schema, False)
+
+                    # File can still be read but doesn't match the schema,
+                    # so raise an exception
+                    self.load_file_and_schema(xml_file)  # Reload the strict schema so we can find where the errors are in the file
+                    invalid_xml = self.find_invalid_xml()
+                    if invalid_xml != "":
+                        basename, _ = os.path.splitext(
+                            os.path.basename(self.f_open.name))
+                        invalid_xml = INVALID_XML.format(basename +
self.extension) + invalid_xml + raise DataReaderException(invalid_xml) # Handled by base class + except FileContentsException as fc_exc: + msg = "CanSAS Reader could not load the file {}".format(xml_file) + if fc_exc.message is not None: # Propagate error messages from earlier + msg = fc_exc.message + if not self.extension in self.ext: # If the file has no associated loader + raise DefaultReaderException(msg) + raise FileContentsException(msg) + pass + else: + raise fc_exc + except Exception as e: # Convert all other exceptions to FileContentsExceptions + raise FileContentsException(str(e)) + finally: + if not self.f_open.closed: + self.f_open.close() + + def load_file_and_schema(self, xml_file, schema_path=""): + base_name = xml_reader.__file__ + base_name = base_name.replace("\\", "/") + base = base_name.split("/sas/")[0] + + # Try and parse the XML file + try: + self.set_xml_file(xml_file) + except etree.XMLSyntaxError: # File isn't valid XML so can't be loaded + msg = "SasView cannot load {}.\nInvalid XML syntax".format(xml_file) + raise FileContentsException(msg) + + self.cansas_version = self.xmlroot.get("version", "1.0") + self.cansas_defaults = CANSAS_NS.get(self.cansas_version, "1.0") + + if schema_path == "": + schema_path = "{}/sas/sascalc/dataloader/readers/schema/{}".format( + base, self.cansas_defaults.get("schema").replace("\\", "/") + ) + self.set_schema(schema_path) + + def is_cansas(self, ext="xml"): + """ + Checks to see if the XML file is a CanSAS file + + :param ext: The file extension of the data file + :raises FileContentsException: Raised if XML file isn't valid CanSAS + """ + if self.validate_xml(): # Check file is valid XML + name = "{http://www.w3.org/2001/XMLSchema-instance}schemaLocation" + value = self.xmlroot.get(name) + # Check schema CanSAS version matches file CanSAS version + if CANSAS_NS.get(self.cansas_version).get("ns") == value.rsplit(" ")[0]: + return True + if ext == "svs": + return True # Why is this required? + # If we get to this point then file isn't valid CanSAS + logger.warning("File doesn't meet CanSAS schema. 
Trying to load anyway.") + raise FileContentsException("The file is not valid CanSAS") + + def _parse_entry(self, dom, recurse=False): + if not self._is_call_local() and not recurse: + self.reset_state() + if not recurse: + self.current_datainfo = DataInfo() + # Raises FileContentsException if file doesn't meet CanSAS schema + self.invalid = False + # Look for a SASentry + self.data = [] + self.parent_class = "SASentry" + self.names.append("SASentry") + self.current_datainfo.meta_data["loader"] = "CanSAS XML 1D" + self.current_datainfo.meta_data[ + PREPROCESS] = self.processing_instructions + if self._is_call_local() and not recurse: + basename, _ = os.path.splitext(os.path.basename(self.f_open.name)) + self.current_datainfo.filename = basename + self.extension + # Create an empty dataset if no data has been passed to the reader + if self.current_dataset is None: + self._initialize_new_data_set(dom) + self.base_ns = "{" + CANSAS_NS.get(self.cansas_version).get("ns") + "}" + + # Loop through each child in the parent element + for node in dom: + attr = node.attrib + name = attr.get("name", "") + type = attr.get("type", "") + # Get the element name and set the current names level + tagname = node.tag.replace(self.base_ns, "") + tagname_original = tagname + # Skip this iteration when loading in save state information + if tagname in ["fitting_plug_in", "pr_inversion", "invariant", "corfunc"]: + continue + # Get where to store content + self.names.append(tagname_original) + self.ns_list = CONSTANTS.iterate_namespace(self.names) + # If the element is a child element, recurse + if len(node.getchildren()) > 0: + self.parent_class = tagname_original + if tagname == 'SASdata': + self._initialize_new_data_set(node) + if isinstance(self.current_dataset, plottable_2D): + x_bins = attr.get("x_bins", "") + y_bins = attr.get("y_bins", "") + if x_bins is not "" and y_bins is not "": + self.current_dataset.shape = (x_bins, y_bins) + else: + self.current_dataset.shape = () + # Recurse to access data within the group + self._parse_entry(node, recurse=True) + if tagname == "SASsample": + self.current_datainfo.sample.name = name + elif tagname == "beam_size": + self.current_datainfo.source.beam_size_name = name + elif tagname == "SAScollimation": + self.collimation.name = name + elif tagname == "aperture": + self.aperture.name = name + self.aperture.type = type + self._add_intermediate() + else: + # TODO: Clean this up to make it faster (fewer if/elifs) + if isinstance(self.current_dataset, plottable_2D): + data_point = node.text + unit = attr.get('unit', '') + else: + data_point, unit = self._get_node_value(node, tagname) + + # If this is a dataset, store the data appropriately + if tagname == 'Run': + self.current_datainfo.run_name[data_point] = name + self.current_datainfo.run.append(data_point) + elif tagname == 'Title': + self.current_datainfo.title = data_point + elif tagname == 'SASnote': + self.current_datainfo.notes.append(data_point) + + # I and Q points + elif tagname == 'I' and isinstance(self.current_dataset, plottable_1D): + self.current_dataset.yaxis("Intensity", unit) + self.current_dataset.y = np.append(self.current_dataset.y, data_point) + elif tagname == 'Idev' and isinstance(self.current_dataset, plottable_1D): + self.current_dataset.dy = np.append(self.current_dataset.dy, data_point) + elif tagname == 'Q': + self.current_dataset.xaxis("Q", unit) + self.current_dataset.x = np.append(self.current_dataset.x, data_point) + elif tagname == 'Qdev': + self.current_dataset.dx = 
np.append(self.current_dataset.dx, data_point) + elif tagname == 'dQw': + self.current_dataset.dxw = np.append(self.current_dataset.dxw, data_point) + elif tagname == 'dQl': + self.current_dataset.dxl = np.append(self.current_dataset.dxl, data_point) + elif tagname == 'Qmean': + pass + elif tagname == 'Shadowfactor': + pass + elif tagname == 'Sesans': + self.current_datainfo.isSesans = bool(data_point) + self.current_dataset.xaxis(attr.get('x_axis'), + attr.get('x_unit')) + self.current_dataset.yaxis(attr.get('y_axis'), + attr.get('y_unit')) + elif tagname == 'yacceptance': + self.current_datainfo.sample.yacceptance = (data_point, unit) + elif tagname == 'zacceptance': + self.current_datainfo.sample.zacceptance = (data_point, unit) + + # I and Qx, Qy - 2D data + elif tagname == 'I' and isinstance(self.current_dataset, plottable_2D): + self.current_dataset.yaxis("Intensity", unit) + self.current_dataset.data = np.fromstring(data_point, dtype=float, sep=",") + elif tagname == 'Idev' and isinstance(self.current_dataset, plottable_2D): + self.current_dataset.err_data = np.fromstring(data_point, dtype=float, sep=",") + elif tagname == 'Qx': + self.current_dataset.xaxis("Qx", unit) + self.current_dataset.qx_data = np.fromstring(data_point, dtype=float, sep=",") + elif tagname == 'Qy': + self.current_dataset.yaxis("Qy", unit) + self.current_dataset.qy_data = np.fromstring(data_point, dtype=float, sep=",") + elif tagname == 'Qxdev': + self.current_dataset.xaxis("Qxdev", unit) + self.current_dataset.dqx_data = np.fromstring(data_point, dtype=float, sep=",") + elif tagname == 'Qydev': + self.current_dataset.yaxis("Qydev", unit) + self.current_dataset.dqy_data = np.fromstring(data_point, dtype=float, sep=",") + elif tagname == 'Mask': + inter = [item == "1" for item in data_point.split(",")] + self.current_dataset.mask = np.asarray(inter, dtype=bool) + + # Sample Information + elif tagname == 'ID' and self.parent_class == 'SASsample': + self.current_datainfo.sample.ID = data_point + elif tagname == 'Title' and self.parent_class == 'SASsample': + self.current_datainfo.sample.name = data_point + elif tagname == 'thickness' and self.parent_class == 'SASsample': + self.current_datainfo.sample.thickness = data_point + self.current_datainfo.sample.thickness_unit = unit + elif tagname == 'transmission' and self.parent_class == 'SASsample': + self.current_datainfo.sample.transmission = data_point + elif tagname == 'temperature' and self.parent_class == 'SASsample': + self.current_datainfo.sample.temperature = data_point + self.current_datainfo.sample.temperature_unit = unit + elif tagname == 'details' and self.parent_class == 'SASsample': + self.current_datainfo.sample.details.append(data_point) + elif tagname == 'x' and self.parent_class == 'position': + self.current_datainfo.sample.position.x = data_point + self.current_datainfo.sample.position_unit = unit + elif tagname == 'y' and self.parent_class == 'position': + self.current_datainfo.sample.position.y = data_point + self.current_datainfo.sample.position_unit = unit + elif tagname == 'z' and self.parent_class == 'position': + self.current_datainfo.sample.position.z = data_point + self.current_datainfo.sample.position_unit = unit + elif tagname == 'roll' and self.parent_class == 'orientation' and 'SASsample' in self.names: + self.current_datainfo.sample.orientation.x = data_point + self.current_datainfo.sample.orientation_unit = unit + elif tagname == 'pitch' and self.parent_class == 'orientation' and 'SASsample' in self.names: + 
self.current_datainfo.sample.orientation.y = data_point + self.current_datainfo.sample.orientation_unit = unit + elif tagname == 'yaw' and self.parent_class == 'orientation' and 'SASsample' in self.names: + self.current_datainfo.sample.orientation.z = data_point + self.current_datainfo.sample.orientation_unit = unit + + # Instrumental Information + elif tagname == 'name' and self.parent_class == 'SASinstrument': + self.current_datainfo.instrument = data_point + + # Detector Information + elif tagname == 'name' and self.parent_class == 'SASdetector': + self.detector.name = data_point + elif tagname == 'SDD' and self.parent_class == 'SASdetector': + self.detector.distance = data_point + self.detector.distance_unit = unit + elif tagname == 'slit_length' and self.parent_class == 'SASdetector': + self.detector.slit_length = data_point + self.detector.slit_length_unit = unit + elif tagname == 'x' and self.parent_class == 'offset': + self.detector.offset.x = data_point + self.detector.offset_unit = unit + elif tagname == 'y' and self.parent_class == 'offset': + self.detector.offset.y = data_point + self.detector.offset_unit = unit + elif tagname == 'z' and self.parent_class == 'offset': + self.detector.offset.z = data_point + self.detector.offset_unit = unit + elif tagname == 'x' and self.parent_class == 'beam_center': + self.detector.beam_center.x = data_point + self.detector.beam_center_unit = unit + elif tagname == 'y' and self.parent_class == 'beam_center': + self.detector.beam_center.y = data_point + self.detector.beam_center_unit = unit + elif tagname == 'z' and self.parent_class == 'beam_center': + self.detector.beam_center.z = data_point + self.detector.beam_center_unit = unit + elif tagname == 'x' and self.parent_class == 'pixel_size': + self.detector.pixel_size.x = data_point + self.detector.pixel_size_unit = unit + elif tagname == 'y' and self.parent_class == 'pixel_size': + self.detector.pixel_size.y = data_point + self.detector.pixel_size_unit = unit + elif tagname == 'z' and self.parent_class == 'pixel_size': + self.detector.pixel_size.z = data_point + self.detector.pixel_size_unit = unit + elif tagname == 'roll' and self.parent_class == 'orientation' and 'SASdetector' in self.names: + self.detector.orientation.x = data_point + self.detector.orientation_unit = unit + elif tagname == 'pitch' and self.parent_class == 'orientation' and 'SASdetector' in self.names: + self.detector.orientation.y = data_point + self.detector.orientation_unit = unit + elif tagname == 'yaw' and self.parent_class == 'orientation' and 'SASdetector' in self.names: + self.detector.orientation.z = data_point + self.detector.orientation_unit = unit + + # Collimation and Aperture + elif tagname == 'length' and self.parent_class == 'SAScollimation': + self.collimation.length = data_point + self.collimation.length_unit = unit + elif tagname == 'name' and self.parent_class == 'SAScollimation': + self.collimation.name = data_point + elif tagname == 'distance' and self.parent_class == 'aperture': + self.aperture.distance = data_point + self.aperture.distance_unit = unit + elif tagname == 'x' and self.parent_class == 'size': + self.aperture.size.x = data_point + self.collimation.size_unit = unit + elif tagname == 'y' and self.parent_class == 'size': + self.aperture.size.y = data_point + self.collimation.size_unit = unit + elif tagname == 'z' and self.parent_class == 'size': + self.aperture.size.z = data_point + self.collimation.size_unit = unit + + # Process Information + elif tagname == 'name' and self.parent_class == 
'SASprocess': + self.process.name = data_point + elif tagname == 'description' and self.parent_class == 'SASprocess': + self.process.description = data_point + elif tagname == 'date' and self.parent_class == 'SASprocess': + try: + self.process.date = datetime.datetime.fromtimestamp(data_point) + except: + self.process.date = data_point + elif tagname == 'SASprocessnote': + self.process.notes.append(data_point) + elif tagname == 'term' and self.parent_class == 'SASprocess': + unit = attr.get("unit", "") + dic = { "name": name, "value": data_point, "unit": unit } + self.process.term.append(dic) + + # Transmission Spectrum + elif tagname == 'T' and self.parent_class == 'Tdata': + self.transspectrum.transmission = np.append(self.transspectrum.transmission, data_point) + self.transspectrum.transmission_unit = unit + elif tagname == 'Tdev' and self.parent_class == 'Tdata': + self.transspectrum.transmission_deviation = np.append(self.transspectrum.transmission_deviation, data_point) + self.transspectrum.transmission_deviation_unit = unit + elif tagname == 'Lambda' and self.parent_class == 'Tdata': + self.transspectrum.wavelength = np.append(self.transspectrum.wavelength, data_point) + self.transspectrum.wavelength_unit = unit + + # Source Information + elif tagname == 'wavelength' and (self.parent_class == 'SASsource' or self.parent_class == 'SASData'): + self.current_datainfo.source.wavelength = data_point + self.current_datainfo.source.wavelength_unit = unit + elif tagname == 'wavelength_min' and self.parent_class == 'SASsource': + self.current_datainfo.source.wavelength_min = data_point + self.current_datainfo.source.wavelength_min_unit = unit + elif tagname == 'wavelength_max' and self.parent_class == 'SASsource': + self.current_datainfo.source.wavelength_max = data_point + self.current_datainfo.source.wavelength_max_unit = unit + elif tagname == 'wavelength_spread' and self.parent_class == 'SASsource': + self.current_datainfo.source.wavelength_spread = data_point + self.current_datainfo.source.wavelength_spread_unit = unit + elif tagname == 'x' and self.parent_class == 'beam_size': + self.current_datainfo.source.beam_size.x = data_point + self.current_datainfo.source.beam_size_unit = unit + elif tagname == 'y' and self.parent_class == 'beam_size': + self.current_datainfo.source.beam_size.y = data_point + self.current_datainfo.source.beam_size_unit = unit + elif tagname == 'z' and self.parent_class == 'pixel_size': + self.current_datainfo.source.data_point.z = data_point + self.current_datainfo.source.beam_size_unit = unit + elif tagname == 'radiation' and self.parent_class == 'SASsource': + self.current_datainfo.source.radiation = data_point + elif tagname == 'beam_shape' and self.parent_class == 'SASsource': + self.current_datainfo.source.beam_shape = data_point + + # Everything else goes in meta_data + else: + new_key = self._create_unique_key(self.current_datainfo.meta_data, tagname) + self.current_datainfo.meta_data[new_key] = data_point + + self.names.remove(tagname_original) + length = 0 + if len(self.names) > 1: + length = len(self.names) - 1 + self.parent_class = self.names[length] + if not self._is_call_local() and not recurse: + self.frm = "" + self.current_datainfo.errors = set() + for error in self.errors: + self.current_datainfo.errors.add(error) + self.data_cleanup() + self.sort_data() + self.reset_data_list() + return self.output[0], None + + def _is_call_local(self): + if self.frm == "": + inter = inspect.stack() + self.frm = inter[2] + mod_name = self.frm[1].replace("\\", 
"/").replace(".pyc", "") + mod_name = mod_name.replace(".py", "") + mod = mod_name.split("sas/") + mod_name = mod[1] + if mod_name != "sascalc/dataloader/readers/cansas_reader": + return False + return True + + def _add_intermediate(self): + """ + This method stores any intermediate objects within the final data set after fully reading the set. + """ + if self.parent_class == 'SASprocess': + self.current_datainfo.process.append(self.process) + self.process = Process() + elif self.parent_class == 'SASdetector': + self.current_datainfo.detector.append(self.detector) + self.detector = Detector() + elif self.parent_class == 'SAStransmission_spectrum': + self.current_datainfo.trans_spectrum.append(self.transspectrum) + self.transspectrum = TransmissionSpectrum() + elif self.parent_class == 'SAScollimation': + self.current_datainfo.collimation.append(self.collimation) + self.collimation = Collimation() + elif self.parent_class == 'aperture': + self.collimation.aperture.append(self.aperture) + self.aperture = Aperture() + elif self.parent_class == 'SASdata': + self.data.append(self.current_dataset) + + def _get_node_value(self, node, tagname): + """ + Get the value of a node and any applicable units + + :param node: The XML node to get the value of + :param tagname: The tagname of the node + """ + #Get the text from the node and convert all whitespace to spaces + units = '' + node_value = node.text + if node_value is not None: + node_value = ' '.join(node_value.split()) + else: + node_value = "" + + # If the value is a float, compile with units. + if self.ns_list.ns_datatype == "float": + # If an empty value is given, set as zero. + if node_value is None or node_value.isspace() \ + or node_value.lower() == "nan": + node_value = "0.0" + #Convert the value to the base units + node_value, units = self._unit_conversion(node, tagname, node_value) + + # If the value is a timestamp, convert to a datetime object + elif self.ns_list.ns_datatype == "timestamp": + if node_value is None or node_value.isspace(): + pass + else: + try: + node_value = \ + datetime.datetime.fromtimestamp(node_value) + except ValueError: + node_value = None + return node_value, units + + def _unit_conversion(self, node, tagname, node_value): + """ + A unit converter method used to convert the data included in the file + to the default units listed in data_info + + :param node: XML node + :param tagname: name of the node + :param node_value: The value of the current dom node + """ + attr = node.attrib + value_unit = '' + err_msg = None + default_unit = None + if not isinstance(node_value, float): + node_value = float(node_value) + if 'unit' in attr and attr.get('unit') is not None: + try: + unit = attr['unit'] + # Split the units to retain backwards compatibility with + # projects, analyses, and saved data from v4.1.0 + unit_list = unit.split("|") + if len(unit_list) > 1: + local_unit = unit_list[1] + else: + local_unit = unit + unitname = self.ns_list.current_level.get("unit", "") + if "SASdetector" in self.names: + save_in = "detector" + elif "aperture" in self.names: + save_in = "aperture" + elif "SAScollimation" in self.names: + save_in = "collimation" + elif "SAStransmission_spectrum" in self.names: + save_in = "transspectrum" + elif "SASdata" in self.names: + x = np.zeros(1) + y = np.zeros(1) + self.current_data1d = Data1D(x, y) + save_in = "current_data1d" + elif "SASsource" in self.names: + save_in = "current_datainfo.source" + elif "SASsample" in self.names: + save_in = "current_datainfo.sample" + elif "SASprocess" in 
self.names: + save_in = "process" + else: + save_in = "current_datainfo" + default_unit = getattrchain(self, '.'.join((save_in, unitname))) + if (local_unit and default_unit + and local_unit.lower() != default_unit.lower() + and local_unit.lower() != "none"): + # Check local units - bad units raise KeyError + #print("loading", tagname, node_value, local_unit, default_unit) + data_conv_q = Converter(local_unit) + value_unit = default_unit + node_value = data_conv_q(node_value, units=default_unit) + else: + value_unit = local_unit + except KeyError: + # Do not throw an error for loading Sesans data in cansas xml + # This is a temporary fix. + if local_unit != "A" and local_unit != 'pol': + err_msg = "CanSAS reader: unexpected " + err_msg += "\"{0}\" unit [{1}]; " + err_msg = err_msg.format(tagname, local_unit) + err_msg += "expecting [{0}]".format(default_unit) + value_unit = local_unit + except Exception: + err_msg = "CanSAS reader: unknown error converting " + err_msg += "\"{0}\" unit [{1}]" + err_msg = err_msg.format(tagname, local_unit) + value_unit = local_unit + elif 'unit' in attr: + value_unit = attr['unit'] + if err_msg: + self.errors.add(err_msg) + return node_value, value_unit + + def _initialize_new_data_set(self, node=None): + if node is not None: + for child in node: + if child.tag.replace(self.base_ns, "") == "Idata": + for i_child in child: + if i_child.tag.replace(self.base_ns, "") == "Qx": + self.current_dataset = plottable_2D() + return + self.current_dataset = plottable_1D(np.array(0), np.array(0)) + + ## Writing Methods + def write(self, filename, datainfo): + """ + Write the content of a Data1D as a CanSAS XML file + + :param filename: name of the file to write + :param datainfo: Data1D object + """ + # Create XML document + doc, _ = self._to_xml_doc(datainfo) + # Write the file + file_ref = open(filename, 'wb') + if self.encoding is None: + self.encoding = "UTF-8" + doc.write(file_ref, encoding=self.encoding, + pretty_print=True, xml_declaration=True) + file_ref.close() + + def _to_xml_doc(self, datainfo): + """ + Create an XML document to contain the content of a Data1D + + :param datainfo: Data1D object + """ + is_2d = False + if issubclass(datainfo.__class__, Data2D): + is_2d = True + + # Get PIs and create root element + pi_string = self._get_pi_string() + # Define namespaces and create SASroot object + main_node = self._create_main_node() + # Create ElementTree, append SASroot and apply processing instructions + base_string = pi_string + self.to_string(main_node) + base_element = self.create_element_from_string(base_string) + doc = self.create_tree(base_element) + # Create SASentry Element + entry_node = self.create_element("SASentry") + root = doc.getroot() + root.append(entry_node) + + # Add Title to SASentry + self.write_node(entry_node, "Title", datainfo.title) + # Add Run to SASentry + self._write_run_names(datainfo, entry_node) + # Add Data info to SASEntry + if is_2d: + self._write_data_2d(datainfo, entry_node) + else: + self._write_data(datainfo, entry_node) + # Transmission Spectrum Info + # TODO: fix the writer to linearize all data, including T_spectrum + # self._write_trans_spectrum(datainfo, entry_node) + # Sample info + self._write_sample_info(datainfo, entry_node) + # Instrument info + instr = self._write_instrument(datainfo, entry_node) + # Source + self._write_source(datainfo, instr) + # Collimation + self._write_collimation(datainfo, instr) + # Detectors + self._write_detectors(datainfo, instr) + # Processes info + 
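# Each datainfo.process entry is written below as its own SASprocess + # node; term dictionaries are serialized as attributed <term> children. + 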
self._write_process_notes(datainfo, entry_node) + # Note info + self._write_notes(datainfo, entry_node) + # Return the document, and the SASentry node associated with + # the data we just wrote + # If the calling function was not the cansas reader, return a minidom + # object rather than an lxml object. + self.frm = inspect.stack()[1] + doc, entry_node = self._check_origin(entry_node, doc) + return doc, entry_node + + def write_node(self, parent, name, value, attr=None): + """ + :param doc: document DOM + :param parent: parent node + :param name: tag of the element + :param value: value of the child text node + :param attr: attribute dictionary + + :return: True if something was appended, otherwise False + """ + if value is not None: + parent = self.ebuilder(parent, name, value, attr) + return True + return False + + def _get_pi_string(self): + """ + Creates the processing instructions header for writing to file + """ + pis = self.return_processing_instructions() + if len(pis) > 0: + pi_tree = self.create_tree(pis[0]) + i = 1 + for i in range(1, len(pis) - 1): + pi_tree = self.append(pis[i], pi_tree) + pi_string = self.to_string(pi_tree) + else: + pi_string = "" + return pi_string + + def _create_main_node(self): + """ + Creates the primary xml header used when writing to file + """ + xsi = "http://www.w3.org/2001/XMLSchema-instance" + version = self.cansas_version + n_s = CANSAS_NS.get(version).get("ns") + if version == "1.1": + url = "http://www.cansas.org/formats/1.1/" + else: + url = "http://svn.smallangles.net/svn/canSAS/1dwg/trunk/" + schema_location = "{0} {1}cansas1d.xsd".format(n_s, url) + attrib = {"{" + xsi + "}schemaLocation" : schema_location, + "version" : version} + nsmap = {'xsi' : xsi, None: n_s} + + main_node = self.create_element("{" + n_s + "}SASroot", + attrib=attrib, nsmap=nsmap) + return main_node + + def _write_run_names(self, datainfo, entry_node): + """ + Writes the run names to the XML file + + :param datainfo: The Data1D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + """ + if datainfo.run is None or datainfo.run == []: + datainfo.run.append(RUN_NAME_DEFAULT) + datainfo.run_name[RUN_NAME_DEFAULT] = RUN_NAME_DEFAULT + for item in datainfo.run: + runname = {} + if item in datainfo.run_name and \ + len(str(datainfo.run_name[item])) > 1: + runname = {'name': datainfo.run_name[item]} + self.write_node(entry_node, "Run", item, runname) + + def _write_data(self, datainfo, entry_node): + """ + Writes 1D I and Q data to the XML file + + :param datainfo: The Data1D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + """ + node = self.create_element("SASdata") + self.append(node, entry_node) + + for i in range(len(datainfo.x)): + point = self.create_element("Idata") + node.append(point) + self.write_node(point, "Q", datainfo.x[i], + {'unit': datainfo._xunit}) + if len(datainfo.y) >= i: + self.write_node(point, "I", datainfo.y[i], + {'unit': datainfo._yunit}) + if datainfo.dy is not None and len(datainfo.dy) > i: + self.write_node(point, "Idev", datainfo.dy[i], + {'unit': datainfo._yunit}) + if datainfo.dx is not None and len(datainfo.dx) > i: + self.write_node(point, "Qdev", datainfo.dx[i], + {'unit': datainfo._xunit}) + if datainfo.dxw is not None and len(datainfo.dxw) > i: + self.write_node(point, "dQw", datainfo.dxw[i], + {'unit': datainfo._xunit}) + if datainfo.dxl is not None and len(datainfo.dxl) > i: + self.write_node(point, "dQl", datainfo.dxl[i], + 
{'unit': datainfo._xunit}) + if datainfo.isSesans: + sesans_attrib = {'x_axis': datainfo._xaxis, + 'y_axis': datainfo._yaxis, + 'x_unit': datainfo.x_unit, + 'y_unit': datainfo.y_unit} + sesans = self.create_element("Sesans", attrib=sesans_attrib) + sesans.text = str(datainfo.isSesans) + entry_node.append(sesans) + self.write_node(entry_node, "yacceptance", datainfo.sample.yacceptance[0], + {'unit': datainfo.sample.yacceptance[1]}) + self.write_node(entry_node, "zacceptance", datainfo.sample.zacceptance[0], + {'unit': datainfo.sample.zacceptance[1]}) + + + def _write_data_2d(self, datainfo, entry_node): + """ + Writes 2D data to the XML file + + :param datainfo: The Data2D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + """ + attr = {} + if datainfo.data.shape: + attr["x_bins"] = str(len(datainfo.x_bins)) + attr["y_bins"] = str(len(datainfo.y_bins)) + node = self.create_element("SASdata", attr) + self.append(node, entry_node) + + point = self.create_element("Idata") + node.append(point) + qx = ','.join(str(v) for v in datainfo.qx_data) + qy = ','.join(str(v) for v in datainfo.qy_data) + intensity = ','.join(str(v) for v in datainfo.data) + + self.write_node(point, "Qx", qx, + {'unit': datainfo._xunit}) + self.write_node(point, "Qy", qy, + {'unit': datainfo._yunit}) + self.write_node(point, "I", intensity, + {'unit': datainfo._zunit}) + if datainfo.err_data is not None: + err = ','.join(str(v) for v in datainfo.err_data) + self.write_node(point, "Idev", err, + {'unit': datainfo._zunit}) + if datainfo.dqy_data is not None: + dqy = ','.join(str(v) for v in datainfo.dqy_data) + self.write_node(point, "Qydev", dqy, + {'unit': datainfo._yunit}) + if datainfo.dqx_data is not None: + dqx = ','.join(str(v) for v in datainfo.dqx_data) + self.write_node(point, "Qxdev", dqx, + {'unit': datainfo._xunit}) + if datainfo.mask is not None: + mask = ','.join("1" if v else "0" for v in datainfo.mask) + self.write_node(point, "Mask", mask) + + def _write_trans_spectrum(self, datainfo, entry_node): + """ + Writes the transmission spectrum data to the XML file + + :param datainfo: The Data1D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + """ + for i in range(len(datainfo.trans_spectrum)): + spectrum = datainfo.trans_spectrum[i] + node = self.create_element("SAStransmission_spectrum", + {"name" : spectrum.name}) + self.append(node, entry_node) + if isinstance(spectrum.timestamp, datetime.datetime): + node.setAttribute("timestamp", spectrum.timestamp) + for i in range(len(spectrum.wavelength)): + point = self.create_element("Tdata") + node.append(point) + self.write_node(point, "Lambda", spectrum.wavelength[i], + {'unit': spectrum.wavelength_unit}) + self.write_node(point, "T", spectrum.transmission[i], + {'unit': spectrum.transmission_unit}) + if spectrum.transmission_deviation is not None \ + and len(spectrum.transmission_deviation) >= i: + self.write_node(point, "Tdev", + spectrum.transmission_deviation[i], + {'unit': + spectrum.transmission_deviation_unit}) + + def _write_sample_info(self, datainfo, entry_node): + """ + Writes the sample information to the XML file + + :param datainfo: The Data1D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + """ + sample = self.create_element("SASsample") + if datainfo.sample.name is not None: + self.write_attribute(sample, "name", + str(datainfo.sample.name)) + self.append(sample, entry_node) + 
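# The scalar sample fields are written first; the grouped position and + # orientation vectors that follow are only appended when at least one + # component was actually written. + 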
self.write_node(sample, "ID", str(datainfo.sample.ID)) + self.write_node(sample, "thickness", datainfo.sample.thickness, + {"unit": datainfo.sample.thickness_unit}) + self.write_node(sample, "transmission", datainfo.sample.transmission) + self.write_node(sample, "temperature", datainfo.sample.temperature, + {"unit": datainfo.sample.temperature_unit}) + + pos = self.create_element("position") + written = self.write_node(pos, + "x", + datainfo.sample.position.x, + {"unit": datainfo.sample.position_unit}) + written = written | self.write_node( \ + pos, "y", datainfo.sample.position.y, + {"unit": datainfo.sample.position_unit}) + written = written | self.write_node( \ + pos, "z", datainfo.sample.position.z, + {"unit": datainfo.sample.position_unit}) + if written: + self.append(pos, sample) + + ori = self.create_element("orientation") + written = self.write_node(ori, "roll", + datainfo.sample.orientation.x, + {"unit": datainfo.sample.orientation_unit}) + written = written | self.write_node( \ + ori, "pitch", datainfo.sample.orientation.y, + {"unit": datainfo.sample.orientation_unit}) + written = written | self.write_node( \ + ori, "yaw", datainfo.sample.orientation.z, + {"unit": datainfo.sample.orientation_unit}) + if written: + self.append(ori, sample) + + for item in datainfo.sample.details: + self.write_node(sample, "details", item) + + def _write_instrument(self, datainfo, entry_node): + """ + Writes the instrumental information to the XML file + + :param datainfo: The Data1D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + """ + instr = self.create_element("SASinstrument") + self.append(instr, entry_node) + self.write_node(instr, "name", datainfo.instrument) + return instr + + def _write_source(self, datainfo, instr): + """ + Writes the source information to the XML file + + :param datainfo: The Data1D object the information is coming from + :param instr: instrument node to be appended to + """ + source = self.create_element("SASsource") + if datainfo.source.name is not None: + self.write_attribute(source, "name", + str(datainfo.source.name)) + self.append(source, instr) + if datainfo.source.radiation is None or datainfo.source.radiation == '': + datainfo.source.radiation = "neutron" + self.write_node(source, "radiation", datainfo.source.radiation) + + size = self.create_element("beam_size") + if datainfo.source.beam_size_name is not None: + self.write_attribute(size, "name", + str(datainfo.source.beam_size_name)) + written = self.write_node( \ + size, "x", datainfo.source.beam_size.x, + {"unit": datainfo.source.beam_size_unit}) + written = written | self.write_node( \ + size, "y", datainfo.source.beam_size.y, + {"unit": datainfo.source.beam_size_unit}) + written = written | self.write_node( \ + size, "z", datainfo.source.beam_size.z, + {"unit": datainfo.source.beam_size_unit}) + if written: + self.append(size, source) + + self.write_node(source, "beam_shape", datainfo.source.beam_shape) + self.write_node(source, "wavelength", + datainfo.source.wavelength, + {"unit": datainfo.source.wavelength_unit}) + self.write_node(source, "wavelength_min", + datainfo.source.wavelength_min, + {"unit": datainfo.source.wavelength_min_unit}) + self.write_node(source, "wavelength_max", + datainfo.source.wavelength_max, + {"unit": datainfo.source.wavelength_max_unit}) + self.write_node(source, "wavelength_spread", + datainfo.source.wavelength_spread, + {"unit": datainfo.source.wavelength_spread_unit}) + + def _write_collimation(self, datainfo, 
instr): + """ + Writes the collimation information to the XML file + + :param datainfo: The Data1D object the information is coming from + :param instr: lxml node ElementTree object to be appended to + """ + if datainfo.collimation == [] or datainfo.collimation is None: + coll = Collimation() + datainfo.collimation.append(coll) + for item in datainfo.collimation: + coll = self.create_element("SAScollimation") + if item.name is not None: + self.write_attribute(coll, "name", str(item.name)) + self.append(coll, instr) + + self.write_node(coll, "length", item.length, + {"unit": item.length_unit}) + + for aperture in item.aperture: + apert = self.create_element("aperture") + if aperture.name is not None: + self.write_attribute(apert, "name", str(aperture.name)) + if aperture.type is not None: + self.write_attribute(apert, "type", str(aperture.type)) + self.append(apert, coll) + + size = self.create_element("size") + if aperture.size_name is not None: + self.write_attribute(size, "name", + str(aperture.size_name)) + written = self.write_node(size, "x", aperture.size.x, + {"unit": aperture.size_unit}) + written = written | self.write_node( \ + size, "y", aperture.size.y, + {"unit": aperture.size_unit}) + written = written | self.write_node( \ + size, "z", aperture.size.z, + {"unit": aperture.size_unit}) + if written: + self.append(size, apert) + + self.write_node(apert, "distance", aperture.distance, + {"unit": aperture.distance_unit}) + + def _write_detectors(self, datainfo, instr): + """ + Writes the detector information to the XML file + + :param datainfo: The Data1D object the information is coming from + :param inst: lxml instrument node to be appended to + """ + if datainfo.detector is None or datainfo.detector == []: + det = Detector() + det.name = "" + datainfo.detector.append(det) + + for item in datainfo.detector: + det = self.create_element("SASdetector") + written = self.write_node(det, "name", item.name) + written = written | self.write_node(det, "SDD", item.distance, + {"unit": item.distance_unit}) + if written: + self.append(det, instr) + + off = self.create_element("offset") + written = self.write_node(off, "x", item.offset.x, + {"unit": item.offset_unit}) + written = written | self.write_node(off, "y", item.offset.y, + {"unit": item.offset_unit}) + written = written | self.write_node(off, "z", item.offset.z, + {"unit": item.offset_unit}) + if written: + self.append(off, det) + + ori = self.create_element("orientation") + written = self.write_node(ori, "roll", item.orientation.x, + {"unit": item.orientation_unit}) + written = written | self.write_node(ori, "pitch", + item.orientation.y, + {"unit": item.orientation_unit}) + written = written | self.write_node(ori, "yaw", + item.orientation.z, + {"unit": item.orientation_unit}) + if written: + self.append(ori, det) + + center = self.create_element("beam_center") + written = self.write_node(center, "x", item.beam_center.x, + {"unit": item.beam_center_unit}) + written = written | self.write_node(center, "y", + item.beam_center.y, + {"unit": item.beam_center_unit}) + written = written | self.write_node(center, "z", + item.beam_center.z, + {"unit": item.beam_center_unit}) + if written: + self.append(center, det) + + pix = self.create_element("pixel_size") + written = self.write_node(pix, "x", item.pixel_size.x, + {"unit": item.pixel_size_unit}) + written = written | self.write_node(pix, "y", item.pixel_size.y, + {"unit": item.pixel_size_unit}) + written = written | self.write_node(pix, "z", item.pixel_size.z, + {"unit": 
item.pixel_size_unit}) + if written: + self.append(pix, det) + self.write_node(det, "slit_length", item.slit_length, + {"unit": item.slit_length_unit}) + + + def _write_process_notes(self, datainfo, entry_node): + """ + Writes the process notes to the XML file + + :param datainfo: The Data1D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + + """ + for item in datainfo.process: + node = self.create_element("SASprocess") + self.append(node, entry_node) + self.write_node(node, "name", item.name) + self.write_node(node, "date", item.date) + self.write_node(node, "description", item.description) + for term in item.term: + if isinstance(term, list): + value = term['value'] + del term['value'] + elif isinstance(term, dict): + value = term.get("value") + del term['value'] + else: + value = term + self.write_node(node, "term", value, term) + for note in item.notes: + self.write_node(node, "SASprocessnote", note) + if len(item.notes) == 0: + self.write_node(node, "SASprocessnote", "") + + def _write_notes(self, datainfo, entry_node): + """ + Writes the notes to the XML file and creates an empty note if none + exist + + :param datainfo: The Data1D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + + """ + if len(datainfo.notes) == 0: + node = self.create_element("SASnote") + self.append(node, entry_node) + else: + for item in datainfo.notes: + node = self.create_element("SASnote") + self.write_text(node, item) + self.append(node, entry_node) + + def _check_origin(self, entry_node, doc): + """ + Return the document, and the SASentry node associated with + the data we just wrote. + If the calling function was not the cansas reader, return a minidom + object rather than an lxml object. + + :param entry_node: lxml node ElementTree object to be appended to + :param doc: entire xml tree + """ + if not self.frm: + self.frm = inspect.stack()[1] + mod_name = self.frm[1].replace("\\", "/").replace(".pyc", "") + mod_name = mod_name.replace(".py", "") + mod = mod_name.split("sas/") + mod_name = mod[1] + if mod_name != "sascalc/dataloader/readers/cansas_reader": + string = self.to_string(doc, pretty_print=False) + doc = parseString(string) + node_name = entry_node.tag + node_list = doc.getElementsByTagName(node_name) + entry_node = node_list.item(0) + return doc, entry_node + + # DO NOT REMOVE - used in saving and loading panel states. + def _store_float(self, location, node, variable, storage, optional=True): + """ + Get the content of a xpath location and store + the result. Check that the units are compatible + with the destination. The value is expected to + be a float. + + The xpath location might or might not exist. 
+ If it does not exist, nothing is done + + :param location: xpath location to fetch + :param node: node to read the data from + :param variable: name of the data member to store it in [string] + :param storage: data object that has the 'variable' data member + :param optional: if True, no exception will be raised + if unit conversion can't be done + + :raise ValueError: raised when the units are not recognized + """ + entry = get_content(location, node) + try: + value = float(entry.text) + except ValueError: + value = None + + if value is not None: + # If the entry has units, check to see that they are + # compatible with what we currently have in the data object + units = entry.get('unit') + if units is not None: + toks = variable.split('.') + local_unit = getattr(storage, toks[0]+"_unit") + if local_unit is not None and units.lower() != local_unit.lower(): + try: + conv = Converter(units) + setattrchain(storage, variable, conv(value, units=local_unit)) + except Exception: + _, exc_value, _ = sys.exc_info() + err_mess = "CanSAS reader: could not convert" + err_mess += " %s unit [%s]; expecting [%s]\n %s" \ + % (variable, units, local_unit, exc_value) + self.errors.add(err_mess) + if optional: + logger.info(err_mess) + else: + raise ValueError(err_mess) + else: + setattrchain(storage, variable, value) + else: + setattrchain(storage, variable, value) + + # DO NOT REMOVE - used in saving and loading panel states. + def _store_content(self, location, node, variable, storage): + """ + Get the content of a xpath location and store + the result. The value is treated as a string. + + The xpath location might or might not exist. + If it does not exist, nothing is done + + :param location: xpath location to fetch + :param node: node to read the data from + :param variable: name of the data member to store it in [string] + :param storage: data object that has the 'variable' data member + + :return: return a list of errors + """ + entry = get_content(location, node) + if entry is not None and entry.text is not None: + setattrchain(storage, variable, entry.text.strip()) + +# DO NOT REMOVE Called by outside packages: +# sas.sasgui.perspectives.invariant.invariant_state +# sas.sasgui.perspectives.fitting.pagestate +def get_content(location, node): + """ + Get the first instance of the content of a xpath location. 
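+ + Illustrative usage (a sketch only; assumes a parsed lxml entry node + and relies on the 'ns' namespace mapping applied below): + + title_node = get_content('ns:Title', entry_node) + title = title_node.text if title_node is not None else None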
+ + :param location: xpath location + :param node: node to start at + + :return: Element, or None + """ + nodes = node.xpath(location, + namespaces={'ns': CANSAS_NS.get("1.0").get("ns")}) + if len(nodes) > 0: + return nodes[0] + else: + return None + +# DO NOT REMOVE Called by outside packages: +# sas.sasgui.perspectives.fitting.pagestate +def write_node(doc, parent, name, value, attr=None): + """ + :param doc: document DOM + :param parent: parent node + :param name: tag of the element + :param value: value of the child text node + :param attr: attribute dictionary + + :return: True if something was appended, otherwise False + """ + if attr is None: + attr = {} + if value is not None: + node = doc.createElement(name) + node.appendChild(doc.createTextNode(str(value))) + for item in attr: + node.setAttribute(item, attr[item]) + parent.appendChild(node) + return True + return False + +def getattrchain(obj, chain, default=None): + """Like getattr, but the attr may contain multiple parts separated by '.'""" + for part in chain.split('.'): + if hasattr(obj, part): + obj = getattr(obj, part, None) + else: + return default + return obj + +def setattrchain(obj, chain, value): + """Like setattr, but the attr may contain multiple parts separated by '.'""" + parts = list(chain.split('.')) + for part in parts[:-1]: + obj = getattr(obj, part, None) + if obj is None: + raise ValueError("missing parent object "+part) + setattr(obj, parts[-1], value) diff --git a/sas/sascalc/dataloader/readers/cansas_reader_HDF5.py b/sas/sascalc/dataloader/readers/cansas_reader_HDF5.py new file mode 100755 index 000000000..ddf5ca05d --- /dev/null +++ b/sas/sascalc/dataloader/readers/cansas_reader_HDF5.py @@ -0,0 +1,733 @@ +""" + NXcanSAS data reader for reading HDF5 formatted CanSAS files. +""" + +import h5py +import numpy as np +import re +import os +import sys + +from ..data_info import plottable_1D, plottable_2D,\ + Data1D, Data2D, DataInfo, Process, Aperture, Collimation, \ + TransmissionSpectrum, Detector +from ..loader_exceptions import FileContentsException, DefaultReaderException +from ..file_reader_base_class import FileReader, decode + +try: + basestring +except NameError: # CRUFT: python 2 support + basestring = str + + +def h5attr(node, key, default=None): + return decode(node.attrs.get(key, default)) + + +class Reader(FileReader): + """ + A class for reading in NXcanSAS data files. The current implementation has + been tested to load data generated by multiple facilities, all of which are + known to produce NXcanSAS standards-compliant data. Any number of data sets + may be present within the file and any dimensionality of data may be used. + Currently 1D and 2D SAS data sets are supported, but the reader should be + immediately extensible to SESANS data. + + Any number of SASdata groups may be present in a SASentry and the data + within each SASdata group can be a single 1D I(Q), multi-framed 1D I(Q), + 2D I(Qx, Qy) or multi-framed 2D I(Qx, Qy). + + :Dependencies: + The NXcanSAS HDF5 reader requires h5py >= v2.5.0. + """ + + # CanSAS version + cansas_version = 2.0 + # Data type name + type_name = "NXcanSAS" + # Wildcards + type = ["NXcanSAS HDF5 Files (*.h5)|*.h5|"] + # List of allowed extensions + ext = ['.h5', '.H5'] + # Flag to bypass extension check + allow_all = True + + def get_file_contents(self): + """ + This is the general read method that all SasView data_loaders must have. + + :param filename: A path for an HDF5 formatted CanSAS 2D data file. + :return: List of Data1D/2D objects and/or a list of errors. 
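+ + A minimal usage sketch (illustrative only; read() is the entry point + assumed to be provided by the FileReader base class): + + reader = Reader() + data_sets = reader.read("/path/to/file.h5")  # list of Data1D/Data2D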
+ """ + # Reinitialize when loading a new data file to reset all class variables + self.reset_state() + + filename = self.f_open.name + self.f_open.close() # IO handled by h5py + + # Check that the file exists + if os.path.isfile(filename): + basename = os.path.basename(filename) + _, extension = os.path.splitext(basename) + # If the file type is not allowed, return empty list + if extension in self.ext or self.allow_all: + # Load the data file + try: + self.raw_data = h5py.File(filename, 'r') + except Exception as e: + if extension not in self.ext: + msg = "NXcanSAS Reader could not load file {}".format( + basename + extension) + raise DefaultReaderException(msg) + raise FileContentsException(e.message) + try: + # Read in all child elements of top level SASroot + self.read_children(self.raw_data, []) + # Add the last data set to the list of outputs + self.add_data_set() + except Exception as exc: + raise FileContentsException(exc.message) + finally: + # Close the data file + self.raw_data.close() + + for data_set in self.output: + if isinstance(data_set, Data1D): + if data_set.x.size < 5: + exception = FileContentsException( + "Fewer than 5 data points found.") + data_set.errors.append(exception) + + def reset_state(self): + """ + Create the reader object and define initial states for class variables + """ + super(Reader, self).reset_state() + self.data1d = [] + self.data2d = [] + self.raw_data = None + self.multi_frame = False + self.data_frames = [] + self.data_uncertainty_frames = [] + self.errors = [] + self.logging = [] + self.q_names = [] + self.mask_name = u'' + self.i_name = u'' + self.i_node = u'' + self.i_uncertainties_name = u'' + self.q_uncertainty_names = [] + self.q_resolution_names = [] + self.parent_class = u'' + self.detector = Detector() + self.collimation = Collimation() + self.aperture = Aperture() + self.process = Process() + self.trans_spectrum = TransmissionSpectrum() + + def read_children(self, data, parent_list): + """ + A recursive method for stepping through the hierarchical data file. 
+ + :param data: h5py Group object of any kind + :param parent: h5py Group parent name + """ + + # Loop through each element of the parent and process accordingly + for key in data.keys(): + # Get all information for the current key + value = data.get(key) + class_name = h5attr(value, u'canSAS_class') + if isinstance(class_name, (list, tuple, np.ndarray)): + class_name = class_name[0] + if class_name is None: + class_name = h5attr(value, u'NX_class') + if class_name is not None: + class_prog = re.compile(class_name) + else: + class_prog = re.compile(value.name) + + if isinstance(value, h5py.Group): + # Set parent class before recursion + last_parent_class = self.parent_class + self.parent_class = class_name + parent_list.append(key) + # If a new sasentry, store the current data sets and create + # a fresh Data1D/2D object + if class_prog.match(u'SASentry'): + self.add_data_set(key) + elif class_prog.match(u'SASdata'): + self._find_data_attributes(value) + self._initialize_new_data_set(value) + # Recursion step to access data within the group + self.read_children(value, parent_list) + self.add_intermediate() + # Reset parent class when returning from recursive method + self.parent_class = last_parent_class + parent_list.remove(key) + + elif isinstance(value, h5py.Dataset): + # If this is a dataset, store the data appropriately + data_set = value.value + unit = self._get_unit(value) + + for data_point in data_set: + if isinstance(data_point, np.ndarray): + if data_point.dtype.char == 'S': + data_point = decode(bytes(data_point)) + else: + data_point = decode(data_point) + # Top Level Meta Data + if key == u'definition': + if isinstance(data_set, basestring): + self.current_datainfo.meta_data['reader'] = data_set + break + else: + self.current_datainfo.meta_data[ + 'reader'] = data_point + # Run + elif key == u'run': + try: + run_name = h5attr(value, 'name') + run_dict = {data_set: run_name} + self.current_datainfo.run_name = run_dict + except Exception: + pass + if isinstance(data_set, basestring): + self.current_datainfo.run.append(data_set) + break + else: + self.current_datainfo.run.append(data_point) + # Title + elif key == u'title': + if isinstance(data_set, basestring): + self.current_datainfo.title = data_set + break + else: + self.current_datainfo.title = data_point + # Note + elif key == u'SASnote': + self.current_datainfo.notes.append(data_set) + break + # Sample Information + elif self.parent_class == u'SASsample': + self.process_sample(data_point, key) + # Instrumental Information + elif (key == u'name' + and self.parent_class == u'SASinstrument'): + self.current_datainfo.instrument = data_point + # Detector + elif self.parent_class == u'SASdetector': + self.process_detector(data_point, key, unit) + # Collimation + elif self.parent_class == u'SAScollimation': + self.process_collimation(data_point, key, unit) + # Aperture + elif self.parent_class == u'SASaperture': + self.process_aperture(data_point, key) + # Process Information + elif self.parent_class == u'SASprocess': # CanSAS 2.0 + self.process_process(data_point, key) + # Source + elif self.parent_class == u'SASsource': + self.process_source(data_point, key, unit) + # Everything else goes in meta_data + elif self.parent_class == u'SASdata': + if isinstance(self.current_dataset, plottable_2D): + self.process_2d_data_object(data_set, key, unit) + else: + self.process_1d_data_object(data_set, key, unit) + + break + elif self.parent_class == u'SAStransmission_spectrum': + self.process_trans_spectrum(data_set, key) + break + else: 
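+ # Unrecognized entries fall through to meta_data under a unique + # key so no information from the file is silently dropped.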
+ new_key = self._create_unique_key( + self.current_datainfo.meta_data, key) + self.current_datainfo.meta_data[new_key] = data_point + + else: + # I don't know if this is reachable code + self.errors.append("ShouldNeverHappenException") + + def process_1d_data_object(self, data_set, key, unit): + """ + SASdata processor method for 1d data items + :param data_set: data from HDF5 file + :param key: canSAS_class attribute + :param unit: unit attribute + """ + if key == self.i_name: + if self.multi_frame: + for x in range(0, data_set.shape[0]): + self.data_frames.append(data_set[x].flatten()) + else: + self.current_dataset.y = data_set.flatten() + self.current_dataset.yaxis("Intensity", unit) + elif key == self.i_uncertainties_name: + if self.multi_frame: + for x in range(0, data_set.shape[0]): + self.data_uncertainty_frames.append(data_set[x].flatten()) + self.current_dataset.dy = data_set.flatten() + elif key in self.q_names: + self.current_dataset.xaxis("Q", unit) + self.current_dataset.x = data_set.flatten() + elif key in self.q_resolution_names: + if (len(self.q_resolution_names) > 1 + and np.where(self.q_resolution_names == key)[0] == 0): + self.current_dataset.dxw = data_set.flatten() + elif (len(self.q_resolution_names) > 1 + and np.where(self.q_resolution_names == key)[0] == 1): + self.current_dataset.dxl = data_set.flatten() + else: + self.current_dataset.dx = data_set.flatten() + elif key in self.q_uncertainty_names: + if (len(self.q_uncertainty_names) > 1 + and np.where(self.q_uncertainty_names == key)[0] == 0): + self.current_dataset.dxw = data_set.flatten() + elif (len(self.q_uncertainty_names) > 1 + and np.where(self.q_uncertainty_names == key)[0] == 1): + self.current_dataset.dxl = data_set.flatten() + else: + self.current_dataset.dx = data_set.flatten() + elif key == self.mask_name: + self.current_dataset.mask = data_set.flatten() + elif key == u'wavelength': + self.current_datainfo.source.wavelength = data_set[0] + self.current_datainfo.source.wavelength_unit = unit + + def process_2d_data_object(self, data_set, key, unit): + """ + SASdata processor method for 2d data items + :param data_set: data from HDF5 file + :param key: canSAS_class attribute + :param unit: unit attribute + """ + if key == self.i_name: + self.current_dataset.data = data_set + self.current_dataset.zaxis("Intensity", unit) + elif key == self.i_uncertainties_name: + self.current_dataset.err_data = data_set.flatten() + elif key in self.q_names: + self.current_dataset.xaxis("Q_x", unit) + self.current_dataset.yaxis("Q_y", unit) + if self.q_names[0] == self.q_names[1]: + # All q data in a single array + self.current_dataset.qx_data = data_set[0] + self.current_dataset.qy_data = data_set[1] + elif self.q_names.index(key) == 0: + self.current_dataset.qx_data = data_set + elif self.q_names.index(key) == 1: + self.current_dataset.qy_data = data_set + elif key in self.q_uncertainty_names or key in self.q_resolution_names: + if ((self.q_uncertainty_names[0] == self.q_uncertainty_names[1]) or + (self.q_resolution_names[0] == self.q_resolution_names[1])): + # All q data in a single array + self.current_dataset.dqx_data = data_set[0].flatten() + self.current_dataset.dqy_data = data_set[1].flatten() + elif (self.q_uncertainty_names.index(key) == 0 or + self.q_resolution_names.index(key) == 0): + self.current_dataset.dqx_data = data_set.flatten() + elif (self.q_uncertainty_names.index(key) == 1 or + self.q_resolution_names.index(key) == 1): + self.current_dataset.dqy_data = data_set.flatten() + self.current_dataset.yaxis("Q_y", unit) + elif key == self.mask_name: + self.current_dataset.mask = data_set.flatten() + elif key == u'Qy': + self.current_dataset.yaxis("Q_y", unit) + 
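# The explicit Qx/Qy/Qxdev/Qydev checks below handle older files that + # store each axis as a separately named dataset. + 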
self.current_dataset.qy_data = data_set.flatten() + elif key == u'Qydev': + self.current_dataset.dqy_data = data_set.flatten() + elif key == u'Qx': + self.current_dataset.xaxis("Q_x", unit) + self.current_dataset.qx_data = data_set.flatten() + elif key == u'Qxdev': + self.current_dataset.dqx_data = data_set.flatten() + + def process_trans_spectrum(self, data_set, key): + """ + SAStransmission_spectrum processor + :param data_set: data from HDF5 file + :param key: canSAS_class attribute + """ + if key == u'T': + self.trans_spectrum.transmission = data_set.flatten() + elif key == u'Tdev': + self.trans_spectrum.transmission_deviation = data_set.flatten() + elif key == u'lambda': + self.trans_spectrum.wavelength = data_set.flatten() + + def process_sample(self, data_point, key): + """ + SASsample processor + :param data_point: Single point from an HDF5 data file + :param key: class name data_point was taken from + """ + if key == u'Title': + self.current_datainfo.sample.name = data_point + elif key == u'name': + self.current_datainfo.sample.name = data_point + elif key == u'ID': + self.current_datainfo.sample.name = data_point + elif key == u'thickness': + self.current_datainfo.sample.thickness = data_point + elif key == u'temperature': + self.current_datainfo.sample.temperature = data_point + elif key == u'transmission': + self.current_datainfo.sample.transmission = data_point + elif key == u'x_position': + self.current_datainfo.sample.position.x = data_point + elif key == u'y_position': + self.current_datainfo.sample.position.y = data_point + elif key == u'pitch': + self.current_datainfo.sample.orientation.x = data_point + elif key == u'yaw': + self.current_datainfo.sample.orientation.y = data_point + elif key == u'roll': + self.current_datainfo.sample.orientation.z = data_point + elif key == u'details': + self.current_datainfo.sample.details.append(data_point) + + def process_detector(self, data_point, key, unit): + """ + SASdetector processor + :param data_point: Single point from an HDF5 data file + :param key: class name data_point was taken from + :param unit: unit attribute from data set + """ + if key == u'name': + self.detector.name = data_point + elif key == u'SDD': + self.detector.distance = float(data_point) + self.detector.distance_unit = unit + elif key == u'slit_length': + self.detector.slit_length = float(data_point) + self.detector.slit_length_unit = unit + elif key == u'x_position': + self.detector.offset.x = float(data_point) + self.detector.offset_unit = unit + elif key == u'y_position': + self.detector.offset.y = float(data_point) + self.detector.offset_unit = unit + elif key == u'pitch': + self.detector.orientation.x = float(data_point) + self.detector.orientation_unit = unit + elif key == u'roll': + self.detector.orientation.z = float(data_point) + self.detector.orientation_unit = unit + elif key == u'yaw': + self.detector.orientation.y = float(data_point) + self.detector.orientation_unit = unit + elif key == u'beam_center_x': + self.detector.beam_center.x = float(data_point) + self.detector.beam_center_unit = unit + elif key == u'beam_center_y': + self.detector.beam_center.y = float(data_point) + self.detector.beam_center_unit = unit + elif key == u'x_pixel_size': + self.detector.pixel_size.x = float(data_point) + self.detector.pixel_size_unit = unit + elif key == u'y_pixel_size': + self.detector.pixel_size.y = float(data_point) + self.detector.pixel_size_unit = unit + + def process_collimation(self, data_point, key, unit): + """ + SAScollimation processor + :param 
data_point: Single point from an HDF5 data file + :param key: class name data_point was taken from + :param unit: unit attribute from data set + """ + if key == u'distance': + self.collimation.length = data_point + self.collimation.length_unit = unit + elif key == u'name': + self.collimation.name = data_point + + def process_aperture(self, data_point, key): + """ + SASaperture processor + :param data_point: Single point from an HDF5 data file + :param key: class name data_point was taken from + """ + if key == u'shape': + self.aperture.shape = data_point + elif key == u'x_gap': + self.aperture.size.x = data_point + elif key == u'y_gap': + self.aperture.size.y = data_point + + def process_source(self, data_point, key, unit): + """ + SASsource processor + :param data_point: Single point from an HDF5 data file + :param key: class name data_point was taken from + :param unit: unit attribute from data set + """ + if key == u'incident_wavelength': + self.current_datainfo.source.wavelength = data_point + self.current_datainfo.source.wavelength_unit = unit + elif key == u'wavelength_max': + self.current_datainfo.source.wavelength_max = data_point + self.current_datainfo.source.wavelength_max_unit = unit + elif key == u'wavelength_min': + self.current_datainfo.source.wavelength_min = data_point + self.current_datainfo.source.wavelength_min_unit = unit + elif key == u'incident_wavelength_spread': + self.current_datainfo.source.wavelength_spread = data_point + self.current_datainfo.source.wavelength_spread_unit = unit + elif key == u'beam_size_x': + self.current_datainfo.source.beam_size.x = data_point + self.current_datainfo.source.beam_size_unit = unit + elif key == u'beam_size_y': + self.current_datainfo.source.beam_size.y = data_point + self.current_datainfo.source.beam_size_unit = unit + elif key == u'beam_shape': + self.current_datainfo.source.beam_shape = data_point + elif key == u'radiation': + self.current_datainfo.source.radiation = data_point + + def process_process(self, data_point, key): + """ + SASprocess processor + :param data_point: Single point from an HDF5 data file + :param key: class name data_point was taken from + """ + term_match = re.compile(u'^term[0-9]+$') + if key == u'Title': # CanSAS 2.0 + self.process.name = data_point + elif key == u'name': # NXcanSAS + self.process.name = data_point + elif key == u'description': + self.process.description = data_point + elif key == u'date': + self.process.date = data_point + elif term_match.match(key): + self.process.term.append(data_point) + else: + self.process.notes.append(data_point) + + def add_intermediate(self): + """ + This method stores any intermediate objects within the final data set + after fully reading the set. 
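+ + Which object is appended depends on the group that just closed: + process, detector, transmission spectrum, collimation, aperture, or + the current 1D/2D data set.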
+ + :param parent: The NXclass name for the h5py Group object that just + finished being processed + """ + + if self.parent_class == u'SASprocess': + self.current_datainfo.process.append(self.process) + self.process = Process() + elif self.parent_class == u'SASdetector': + self.current_datainfo.detector.append(self.detector) + self.detector = Detector() + elif self.parent_class == u'SAStransmission_spectrum': + self.current_datainfo.trans_spectrum.append(self.trans_spectrum) + self.trans_spectrum = TransmissionSpectrum() + elif self.parent_class == u'SAScollimation': + self.current_datainfo.collimation.append(self.collimation) + self.collimation = Collimation() + elif self.parent_class == u'SASaperture': + self.collimation.aperture.append(self.aperture) + self.aperture = Aperture() + elif self.parent_class == u'SASdata': + if isinstance(self.current_dataset, plottable_2D): + self.data2d.append(self.current_dataset) + elif isinstance(self.current_dataset, plottable_1D): + if self.multi_frame: + for x in range(0, len(self.data_frames)): + self.current_dataset.y = self.data_frames[x] + if len(self.data_uncertainty_frames) > x: + self.current_dataset.dy = \ + self.data_uncertainty_frames[x] + self.data1d.append(self.current_dataset) + else: + self.data1d.append(self.current_dataset) + + def final_data_cleanup(self): + """ + Does some final cleanup and formatting on self.current_datainfo and + all data1D and data2D objects and then combines the data and info into + Data1D and Data2D objects + """ + # Type cast data arrays to float64 + if len(self.current_datainfo.trans_spectrum) > 0: + spectrum_list = [] + for spectrum in self.current_datainfo.trans_spectrum: + spectrum.transmission = spectrum.transmission.astype(np.float64) + spectrum.transmission_deviation = \ + spectrum.transmission_deviation.astype(np.float64) + spectrum.wavelength = spectrum.wavelength.astype(np.float64) + if len(spectrum.transmission) > 0: + spectrum_list.append(spectrum) + self.current_datainfo.trans_spectrum = spectrum_list + + # Append errors to dataset and reset class errors + self.current_datainfo.errors = self.errors + self.errors = [] + + # Combine all plottables with datainfo and append each to output + # Type cast data arrays to float64 and find min/max as appropriate + for dataset in self.data2d: + # Calculate the actual Q matrix + try: + if dataset.q_data.size <= 1: + dataset.q_data = np.sqrt(dataset.qx_data + * dataset.qx_data + + dataset.qy_data + * dataset.qy_data).flatten() + except Exception: + dataset.q_data = None + + if dataset.data.ndim == 2: + dataset.y_bins = np.unique(dataset.qy_data.flatten()) + dataset.x_bins = np.unique(dataset.qx_data.flatten()) + dataset.data = dataset.data.flatten() + dataset.qx_data = dataset.qx_data.flatten() + dataset.qy_data = dataset.qy_data.flatten() + + try: + iter(dataset.mask) + dataset.mask = np.invert(np.asarray(dataset.mask, dtype=bool)) + except TypeError: + dataset.mask = np.ones(dataset.data.shape, dtype=bool) + self.current_dataset = dataset + self.send_to_output() + + for dataset in self.data1d: + self.current_dataset = dataset + self.send_to_output() + + def add_data_set(self, key=""): + """ + Adds the current_dataset to the list of outputs after performing final + processing on the data and then calls a private method to generate a + new data set. 
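+ + read_children() calls this with the group name whenever a new + SASentry is encountered, and get_file_contents() calls it once more + after parsing to flush the final data set.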
+ + :param key: NeXus group name for current tree level + """ + + if self.current_datainfo and self.current_dataset: + self.final_data_cleanup() + self.data_frames = [] + self.data_uncertainty_frames = [] + self.data1d = [] + self.data2d = [] + self.current_datainfo = DataInfo() + + def _initialize_new_data_set(self, value=None): + """ + A private class method to generate a new 1D or 2D data object based on + the type of data within the set. Outside methods should call + add_data_set() to be sure any existing data is stored properly. + + :param parent_list: List of names of parent elements + """ + if self._is_2d_not_multi_frame(value): + self.current_dataset = plottable_2D() + else: + x = np.array(0) + y = np.array(0) + self.current_dataset = plottable_1D(x, y) + self.current_datainfo.filename = self.raw_data.filename + + @staticmethod + def as_list_or_array(iterable): + """ + Return value as a list if not already a list or array. + :param iterable: + :return: + """ + if not (isinstance(iterable, np.ndarray) or isinstance(iterable, list)): + iterable = iterable.split(",") if isinstance(iterable, basestring)\ + else [iterable] + return iterable + + def _find_data_attributes(self, value): + """ + A class to find the indices for Q, the name of the Qdev and Idev, and + the name of the mask. + :param value: SASdata/NXdata HDF5 Group + """ + # Initialize values to base types + self.mask_name = u'' + self.i_name = u'' + self.i_node = u'' + self.i_uncertainties_name = u'' + self.q_names = [] + self.q_uncertainty_names = [] + self.q_resolution_names = [] + # Get attributes + attrs = value.attrs + signal = attrs.get("signal", "I") + i_axes = attrs.get("I_axes", ["Q"]) + q_indices = attrs.get("Q_indices", [0]) + i_axes = self.as_list_or_array(i_axes) + keys = value.keys() + # Assign attributes to appropriate class variables + self.q_names = [i_axes[int(v)] for v in self.as_list_or_array(q_indices)] + self.mask_name = attrs.get("mask") + self.i_name = signal + self.i_node = value.get(self.i_name) + for item in self.q_names: + if item in keys: + q_vals = value.get(item) + if q_vals.attrs.get("uncertainties") is not None: + self.q_uncertainty_names = q_vals.attrs.get("uncertainties") + elif q_vals.attrs.get("uncertainty") is not None: + self.q_uncertainty_names = q_vals.attrs.get("uncertainty") + if isinstance(self.q_uncertainty_names, basestring): + self.q_uncertainty_names = self.q_uncertainty_names.split(",") + if q_vals.attrs.get("resolutions") is not None: + self.q_resolution_names = q_vals.attrs.get("resolutions") + if isinstance(self.q_resolution_names, basestring): + self.q_resolution_names = self.q_resolution_names.split(",") + if self.i_name in keys: + i_vals = value.get(self.i_name) + self.i_uncertainties_name = i_vals.attrs.get("uncertainties") + if self.i_uncertainties_name is None: + self.i_uncertainties_name = i_vals.attrs.get("uncertainty") + + def _is_2d_not_multi_frame(self, value, i_base="", q_base=""): + """ + A private class to determine if the data set is 1d or 2d. 
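+ + Heuristic: a multi-frame 1D set has a multi-dimensional intensity + with a one-dimensional Q, while a true 2D set has a multi-dimensional + intensity that does not match the multi-frame pattern.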
+ + :param value: Nexus/NXcanSAS data group + :param basename: Approximate name of an entry to search for + :return: True if 2D, otherwise false + """ + i_basename = i_base if i_base != "" else self.i_name + i_vals = value.get(i_basename) + q_basename = q_base if q_base != "" else self.q_names + q_vals = value.get(q_basename[0]) + self.multi_frame = (i_vals is not None and q_vals is not None + and len(i_vals.shape) != 1 + and len(q_vals.shape) == 1) + return (i_vals is not None and len(i_vals.shape) != 1 + and not self.multi_frame) + + def _create_unique_key(self, dictionary, name, numb=0): + """ + Create a unique key value for any dictionary to prevent overwriting + Recurses until a unique key value is found. + + :param dictionary: A dictionary with any number of entries + :param name: The index of the item to be added to dictionary + :param numb: The number to be appended to the name, starts at 0 + :return: The new name for the dictionary entry + """ + if dictionary.get(name) is not None: + numb += 1 + name = name.split("_")[0] + name += "_{0}".format(numb) + name = self._create_unique_key(dictionary, name, numb) + return name + + def _get_unit(self, value): + """ + Find the unit for a particular value within the h5py dictionary + + :param value: attribute dictionary for a particular value set + :return: unit for the value passed to the method + """ + unit = h5attr(value, u'units') + if unit is None: + unit = h5attr(value, u'unit') + return unit diff --git a/sas/sascalc/dataloader/readers/danse_reader.py b/sas/sascalc/dataloader/readers/danse_reader.py new file mode 100755 index 000000000..ac72544c7 --- /dev/null +++ b/sas/sascalc/dataloader/readers/danse_reader.py @@ -0,0 +1,207 @@ +""" + DANSE/SANS file reader +""" +############################################################################ +#This software was developed by the University of Tennessee as part of the +#Distributed Data Analysis of Neutron Scattering Experiments (DANSE) +#project funded by the US National Science Foundation. +#If you use DANSE applications to do scientific research that leads to +#publication, we ask that you acknowledge the use of the software with the +#following sentence: +#This work benefited from DANSE software developed under NSF award DMR-0520547. 
+#copyright 2008, University of Tennessee +############################################################################# +import math +import os +import logging + +import numpy as np + +from ..data_info import plottable_2D, DataInfo, Detector +from ..manipulations import reader2D_converter +from ..file_reader_base_class import FileReader +from ..loader_exceptions import FileContentsException, DataReaderException + +logger = logging.getLogger(__name__) + +# Look for unit converter +has_converter = True +try: + from sas.sascalc.data_util.nxsunit import Converter +except: + has_converter = False + + +class Reader(FileReader): + """ + Example data manipulation + """ + ## File type + type_name = "DANSE" + ## Wildcards + type = ["DANSE files (*.sans)|*.sans"] + ## Extension + ext = ['.sans', '.SANS'] + + def get_file_contents(self): + self.current_datainfo = DataInfo() + self.current_dataset = plottable_2D() + self.output = [] + + loaded_correctly = True + error_message = "" + + # defaults + # wavelength in Angstrom + wavelength = 10.0 + # Distance in meter + distance = 11.0 + # Pixel number of center in x + center_x = 65 + # Pixel number of center in y + center_y = 65 + # Pixel size [mm] + pixel = 5.0 + # Size in x, in pixels + size_x = 128 + # Size in y, in pixels + size_y = 128 + # Format version + fversion = 1.0 + + self.current_datainfo.filename = os.path.basename(self.f_open.name) + detector = Detector() + self.current_datainfo.detector.append(detector) + + self.current_dataset.data = np.zeros([size_x, size_y]) + self.current_dataset.err_data = np.zeros([size_x, size_y]) + + read_on = True + data_start_line = 1 + while read_on: + line = self.nextline() + data_start_line += 1 + if line.find("DATA:") >= 0: + read_on = False + break + toks = line.split(':') + try: + if toks[0] == "FORMATVERSION": + fversion = float(toks[1]) + elif toks[0] == "WAVELENGTH": + wavelength = float(toks[1]) + elif toks[0] == "DISTANCE": + distance = float(toks[1]) + elif toks[0] == "CENTER_X": + center_x = float(toks[1]) + elif toks[0] == "CENTER_Y": + center_y = float(toks[1]) + elif toks[0] == "PIXELSIZE": + pixel = float(toks[1]) + elif toks[0] == "SIZE_X": + size_x = int(toks[1]) + elif toks[0] == "SIZE_Y": + size_y = int(toks[1]) + except ValueError as e: + error_message += "Unable to parse {}. Default value used.\n".format(toks[0]) + loaded_correctly = False + + # Read the data + data = [] + error = [] + if not fversion >= 1.0: + msg = "danse_reader can't read this file {}".format(self.f_open.name) + raise FileContentsException(msg) + + for line_num, data_str in enumerate(self.nextlines()): + toks = data_str.split() + try: + val = float(toks[0]) + err = float(toks[1]) + data.append(val) + error.append(err) + except ValueError as exc: + msg = "Unable to parse line {}: {}".format(line_num + data_start_line, data_str.strip()) + raise FileContentsException(msg) + + num_pts = size_x * size_y + if len(data) < num_pts: + msg = "Not enough data points provided. Expected {} but got {}".format( + size_x * size_y, len(data)) + raise FileContentsException(msg) + elif len(data) > num_pts: + error_message += ("Too many data points provided. Expected {0} but" + " got {1}. 
Only the first {0} will be used.\n").format(num_pts, len(data)) + loaded_correctly = False + data = data[:num_pts] + error = error[:num_pts] + + # Qx and Qy vectors + theta = pixel / distance / 100.0 + i_x = np.arange(size_x) + theta = (i_x - center_x + 1) * pixel / distance / 100.0 + x_vals = 4.0 * np.pi / wavelength * np.sin(theta / 2.0) + xmin = x_vals.min() + xmax = x_vals.max() + + i_y = np.arange(size_y) + theta = (i_y - center_y + 1) * pixel / distance / 100.0 + y_vals = 4.0 * np.pi / wavelength * np.sin(theta / 2.0) + ymin = y_vals.min() + ymax = y_vals.max() + + self.current_dataset.data = np.array(data, dtype=np.float64).reshape((size_y, size_x)) + if fversion > 1.0: + self.current_dataset.err_data = np.array(error, dtype=np.float64).reshape((size_y, size_x)) + + # Store all data + # Store wavelength + if has_converter and self.current_datainfo.source.wavelength_unit != 'A': + conv = Converter('A') + wavelength = conv(wavelength, + units=self.current_datainfo.source.wavelength_unit) + self.current_datainfo.source.wavelength = wavelength + + # Store distance + if has_converter and detector.distance_unit != 'm': + conv = Converter('m') + distance = conv(distance, units=detector.distance_unit) + detector.distance = distance + + # Store pixel size + if has_converter and detector.pixel_size_unit != 'mm': + conv = Converter('mm') + pixel = conv(pixel, units=detector.pixel_size_unit) + detector.pixel_size.x = pixel + detector.pixel_size.y = pixel + + # Store beam center in distance units + detector.beam_center.x = center_x * pixel + detector.beam_center.y = center_y * pixel + + self.current_dataset = self.set_default_2d_units(self.current_dataset) + self.current_dataset.x_bins = x_vals + self.current_dataset.y_bins = y_vals + + # Reshape data + x_vals = np.tile(x_vals, (size_y, 1)).flatten() + y_vals = np.tile(y_vals, (size_x, 1)).T.flatten() + if (np.all(self.current_dataset.err_data == None) + or np.any(self.current_dataset.err_data <= 0)): + new_err_data = np.sqrt(np.abs(self.current_dataset.data)) + else: + new_err_data = self.current_dataset.err_data.flatten() + + self.current_dataset.err_data = new_err_data + self.current_dataset.qx_data = x_vals + self.current_dataset.qy_data = y_vals + self.current_dataset.q_data = np.sqrt(x_vals**2 + y_vals**2) + self.current_dataset.mask = np.ones(len(x_vals), dtype=bool) + + # Store loading process information + self.current_datainfo.meta_data['loader'] = self.type_name + + self.send_to_output() + + if not loaded_correctly: + raise DataReaderException(error_message) diff --git a/sas/sascalc/dataloader/readers/red2d_reader.py b/sas/sascalc/dataloader/readers/red2d_reader.py new file mode 100755 index 000000000..39ee51b8f --- /dev/null +++ b/sas/sascalc/dataloader/readers/red2d_reader.py @@ -0,0 +1,324 @@ +""" + TXT/IGOR 2D Q Map file reader +""" +##################################################################### +#This software was developed by the University of Tennessee as part of the +#Distributed Data Analysis of Neutron Scattering Experiments (DANSE) +#project funded by the US National Science Foundation. 
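For orientation, the q-axis arithmetic used in the DANSE reader above can be exercised on its own: the pixel offset from the beam center is scaled by the pixel pitch over the detector distance to get a scattering angle, and then q = 4π/λ · sin(θ/2). This sketch reuses the reader's default geometry (wavelength 10 Å, 11 m distance, pixel size 5.0, 128-pixel axis); the variable names are illustrative:

```python
import numpy as np

# Defaults from the DANSE reader above (illustrative values)
wavelength = 10.0   # Angstrom
distance = 11.0     # m
center_x = 65       # pixel index of beam center
pixel = 5.0         # pixel size
size_x = 128        # pixels along x

# Same arithmetic as the reader: pixel offset -> angle -> q
i_x = np.arange(size_x)
theta = (i_x - center_x + 1) * pixel / distance / 100.0
qx = 4.0 * np.pi / wavelength * np.sin(theta / 2.0)
print(qx.min(), qx.max())
```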
+#See the license text in license.txt +#copyright 2008, University of Tennessee +###################################################################### +import os +import math +import time + +import numpy as np + +from sas.sascalc.data_util.nxsunit import Converter + +from ..data_info import plottable_2D, DataInfo, Detector +from ..file_reader_base_class import FileReader +from ..loader_exceptions import FileContentsException + + +def check_point(x_point): + """ + check point validity + """ + # set zero for non_floats + try: + return float(x_point) + except Exception: + return 0 + + +class Reader(FileReader): + """ Simple data reader for Igor data files """ + ## File type + type_name = "IGOR/DAT 2D Q_map" + ## Wildcards + type = ["IGOR/DAT 2D file in Q_map (*.dat)|*.DAT"] + ## Extension + ext = ['.DAT', '.dat'] + + def write(self, filename, data): + """ + Write to .dat + + :param filename: file name to write + :param data: data2D + """ + # Write the file + try: + fd = open(filename, 'w') + t = time.localtime() + time_str = time.strftime("%H:%M on %b %d %y", t) + + header_str = "Data columns are Qx - Qy - I(Qx,Qy)\n\nASCII data" + header_str += " created at %s \n\n" % time_str + # simple 2D header + fd.write(header_str) + # write qx qy I values + for i in range(len(data.data)): + fd.write("%g %g %g\n" % (data.qx_data[i], + data.qy_data[i], + data.data[i])) + finally: + fd.close() + + def get_file_contents(self): + # Read file + buf = self.readall() + self.f_open.close() + # Instantiate data object + self.current_dataset = plottable_2D() + self.current_datainfo = DataInfo() + self.current_datainfo.filename = os.path.basename(self.f_open.name) + self.current_datainfo.detector.append(Detector()) + + # Get content + data_started = False + + ## Defaults + lines = buf.split('\n') + x = [] + y = [] + + wavelength = None + distance = None + transmission = None + + pixel_x = None + pixel_y = None + + is_info = False + is_center = False + + # Remove the last lines before the for loop if the lines are empty + # to calculate the exact number of data points + count = 0 + while (len(lines[len(lines) - (count + 1)].lstrip().rstrip()) < 1): + del lines[len(lines) - (count + 1)] + count = count + 1 + + #Read Header and find the dimensions of 2D data + line_num = 0 + # Old version NIST files: 0 + ver = 0 + for line in lines: + line_num += 1 + ## Reading the header applies only to IGOR/NIST 2D q_map data files + # Find setup info line + if is_info: + is_info = False + line_toks = line.split() + # Wavelength in Angstrom + try: + wavelength = float(line_toks[1]) + # Wavelength is stored in angstroms; convert if necessary + if self.current_datainfo.source.wavelength_unit != 'A': + conv = Converter('A') + wavelength = conv(wavelength, + units=self.current_datainfo.source.wavelength_unit) + except Exception: + pass # Not required + try: + distance = float(line_toks[3]) + # Distance is stored in meters; convert if necessary + if self.current_datainfo.detector[0].distance_unit != 'm': + conv = Converter('m') + distance = conv(distance, + units=self.current_datainfo.detector[0].distance_unit) + except Exception: + pass # Not required + + try: + transmission = float(line_toks[4]) + except Exception: + pass # Not required + + if line.count("LAMBDA") > 0: + is_info = True + + # Find center info line + if is_center: + is_center = False + line_toks = line.split() + # Center in bin number + center_x = float(line_toks[0]) + center_y = float(line_toks[1]) + + if line.count("BCENT") > 0: + is_center = True + # Check version + 
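+            # A "Data columns" line that mentions err(I) marks the newer
+            # file layout (ver = 1); the extra error column shifts the
+            # optional qz/dqx/dqy column indices read out further below.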
if line.count("Data columns") > 0: + if line.count("err(I)") > 0: + ver = 1 + # Find data start + if line.count("ASCII data") > 0: + data_started = True + continue + + ## Read and get data. + if data_started: + line_toks = line.split() + if len(line_toks) == 0: + #empty line + continue + # the number of columns must be stayed same + col_num = len(line_toks) + break + + # Make numpy array to remove header lines using index + lines_array = np.array(lines) + + # index for lines_array + lines_index = np.arange(len(lines)) + + # get the data lines + data_lines = lines_array[lines_index >= (line_num - 1)] + # Now we get the total number of rows (i.e., # of data points) + row_num = len(data_lines) + # make it as list again to control the separators + data_list = " ".join(data_lines.tolist()) + # split all data to one big list w/" "separator + data_list = data_list.split() + + # Check if the size is consistent with data, otherwise + #try the tab(\t) separator + # (this may be removed once get the confidence + #the former working all cases). + if len(data_list) != (len(data_lines)) * col_num: + data_list = "\t".join(data_lines.tolist()) + data_list = data_list.split() + + # Change it(string) into float + #data_list = map(float,data_list) + data_list1 = list(map(check_point, data_list)) + + # numpy array form + data_array = np.array(data_list1) + # Redimesion based on the row_num and col_num, + #otherwise raise an error. + try: + data_point = data_array.reshape(row_num, col_num).transpose() + except Exception: + msg = "red2d_reader can't read this file: Incorrect number of data points provided." + raise FileContentsException(msg) + ## Get the all data: Let's HARDcoding; Todo find better way + # Defaults + dqx_data = np.zeros(0) + dqy_data = np.zeros(0) + err_data = np.ones(row_num) + qz_data = np.zeros(row_num) + mask = np.ones(row_num, dtype=bool) + # Get from the array + qx_data = data_point[0] + qy_data = data_point[1] + data = data_point[2] + if ver == 1: + if col_num > (2 + ver): + err_data = data_point[(2 + ver)] + if col_num > (3 + ver): + qz_data = data_point[(3 + ver)] + if col_num > (4 + ver): + dqx_data = data_point[(4 + ver)] + if col_num > (5 + ver): + dqy_data = data_point[(5 + ver)] + #if col_num > (6 + ver): mask[data_point[(6 + ver)] < 1] = False + q_data = np.sqrt(qx_data*qx_data+qy_data*qy_data+qz_data*qz_data) + + # Extra protection(it is needed for some data files): + # If all mask elements are False, put all True + if not mask.any(): + mask[mask == False] = True + + # Store limits of the image in q space + xmin = np.min(qx_data) + xmax = np.max(qx_data) + ymin = np.min(qy_data) + ymax = np.max(qy_data) + + ## calculate the range of the qx and qy_data + x_size = math.fabs(xmax - xmin) + y_size = math.fabs(ymax - ymin) + + # calculate the number of pixels in the each axes + npix_y = math.floor(math.sqrt(len(data))) + npix_x = math.floor(len(data) / npix_y) + + # calculate the size of bins + xstep = x_size / (npix_x - 1) + ystep = y_size / (npix_y - 1) + + # store x and y axis bin centers in q space + x_bins = np.arange(xmin, xmax + xstep, xstep) + y_bins = np.arange(ymin, ymax + ystep, ystep) + + # get the limits of q values + xmin = xmin - xstep / 2 + xmax = xmax + xstep / 2 + ymin = ymin - ystep / 2 + ymax = ymax + ystep / 2 + + #Store data in outputs + #TODO: Check the lengths + self.current_dataset.data = data + if (err_data == 1).all(): + self.current_dataset.err_data = np.sqrt(np.abs(data)) + self.current_dataset.err_data[self.current_dataset.err_data == 0.0] = 1.0 + else: 
+            self.current_dataset.err_data = err_data
+
+        self.current_dataset.qx_data = qx_data
+        self.current_dataset.qy_data = qy_data
+        self.current_dataset.q_data = q_data
+        self.current_dataset.mask = mask
+
+        self.current_dataset.x_bins = x_bins
+        self.current_dataset.y_bins = y_bins
+
+        self.current_dataset.xmin = xmin
+        self.current_dataset.xmax = xmax
+        self.current_dataset.ymin = ymin
+        self.current_dataset.ymax = ymax
+
+        self.current_datainfo.source.wavelength = wavelength
+
+        # Store pixel size in mm
+        self.current_datainfo.detector[0].pixel_size.x = pixel_x
+        self.current_datainfo.detector[0].pixel_size.y = pixel_y
+
+        # Store the sample-to-detector distance
+        self.current_datainfo.detector[0].distance = distance
+
+        # Optional resolution data: if all dq values are 0, do not pass them on
+        if len(dqx_data) == len(qx_data) and dqx_data.any():
+            # If there is no dqx_data, do not pass dqy_data either
+            # (1-axis dq is not supported yet).
+            if len(dqy_data) == len(qy_data) and dqy_data.any():
+                # dq parallel/perpendicular is not supported; for newer
+                # versions, transfer the components to Cartesian coordinates.
+                if ver != 1:
+                    diag = np.sqrt(qx_data * qx_data + qy_data * qy_data)
+                    cos_th = qx_data / diag
+                    sin_th = qy_data / diag
+                    self.current_dataset.dqx_data = np.sqrt(
+                        (dqx_data * cos_th) * (dqx_data * cos_th)
+                        + (dqy_data * sin_th) * (dqy_data * sin_th))
+                    self.current_dataset.dqy_data = np.sqrt(
+                        (dqx_data * sin_th) * (dqx_data * sin_th)
+                        + (dqy_data * cos_th) * (dqy_data * cos_th))
+                else:
+                    self.current_dataset.dqx_data = dqx_data
+                    self.current_dataset.dqy_data = dqy_data
+
+        # Units of axes
+        self.current_dataset = self.set_default_2d_units(self.current_dataset)
+
+        # Store loading process information
+        self.current_datainfo.meta_data['loader'] = self.type_name
+
+        self.send_to_output()
diff --git a/sas/sascalc/dataloader/readers/schema/cansas1d_invalid_v1_0.xsd b/sas/sascalc/dataloader/readers/schema/cansas1d_invalid_v1_0.xsd
new file mode 100755
index 000000000..c98b9843b
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/schema/cansas1d_invalid_v1_0.xsd
@@ -0,0 +1,98 @@
+[XML schema markup (98 lines) stripped during extraction; not recoverable here]
diff --git a/sas/sascalc/dataloader/readers/schema/cansas1d_invalid_v1_1.xsd b/sas/sascalc/dataloader/readers/schema/cansas1d_invalid_v1_1.xsd
new file mode 100755
index 000000000..914caed00
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/schema/cansas1d_invalid_v1_1.xsd
@@ -0,0 +1,98 @@
+[XML schema markup (98 lines) stripped during extraction; not recoverable here]
diff --git a/sas/sascalc/dataloader/readers/schema/cansas1d_v1_0.xsd b/sas/sascalc/dataloader/readers/schema/cansas1d_v1_0.xsd
new file mode 100755
index 000000000..78244a82a
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/schema/cansas1d_v1_0.xsd
@@ -0,0 +1,238 @@
+[XML schema markup (238 lines) stripped during extraction; not recoverable here]
\ No newline at end of file
diff --git a/sas/sascalc/dataloader/readers/schema/cansas1d_v1_1.xsd b/sas/sascalc/dataloader/readers/schema/cansas1d_v1_1.xsd
new file mode 100755
index 000000000..c376e590b
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/schema/cansas1d_v1_1.xsd
@@ -0,0 +1,271 @@
+[XML schema markup (271 lines) stripped during extraction; not recoverable here]
diff --git a/sas/sascalc/dataloader/readers/sesans_reader.py b/sas/sascalc/dataloader/readers/sesans_reader.py
new file mode 100755
index 000000000..47c0c41c8
--- /dev/null
+++ b/sas/sascalc/dataloader/readers/sesans_reader.py
@@ -0,0 +1,175 @@
+"""
+    SESANS reader (based on the ASCII reader)
+
+    Reader for the .ses / .sesans file format
+
+    Jurrian Bakker
+"""
+import os
+
+import numpy as np
+
+from ..file_reader_base_class import FileReader
+from ..data_info import plottable_1D, DataInfo
+from ..loader_exceptions import FileContentsException
+
+# Check whether we have a converter available
+has_converter = True
+try:
+    from sas.sascalc.data_util.nxsunit import Converter
+except ImportError:
+    has_converter = False
+_ZERO = 1e-16
+
+class Reader(FileReader):
+    """
+    Class to load SESANS files (6 columns).
+    """
+    # File type
+    type_name = "SESANS"
+
+    ## Wildcards
+    type = ["SESANS files (*.ses)|*.ses",
+            "SESANS files (*.sesans)|*.sesans"]
+    # List of allowed extensions
+    ext = ['.ses', '.SES', '.sesans', '.SESANS']
+
+    # Flag to bypass extension check
+    allow_all = True
+
+    def get_file_contents(self):
+        self.current_datainfo = DataInfo()
+        self.current_dataset = plottable_1D(np.array([]), np.array([]))
+        self.current_datainfo.isSesans = True
+        self.output = []
+
+        line = self.nextline()
+        params = {}
+        while line and not line.startswith("BEGIN_DATA"):
+            terms = line.split()
+            if len(terms) >= 2:
+                params[terms[0]] = " ".join(terms[1:])
+            line = self.nextline()
+        self.params = params
+
+        if "FileFormatVersion" not in self.params:
+            raise FileContentsException("SES file missing FileFormatVersion")
+        if float(self.params["FileFormatVersion"]) >= 2.0:
+            raise FileContentsException("SASView only supports SES version 1")
+
+        if "SpinEchoLength_unit" not in self.params:
+            raise FileContentsException("SpinEchoLength has no units")
+        if "Wavelength_unit" not in self.params:
+            raise FileContentsException("Wavelength has no units")
+        if params["SpinEchoLength_unit"] != params["Wavelength_unit"]:
+            raise FileContentsException(
+                "The spin echo data has rudely used "
+                "different units for the spin echo length "
+                "and the wavelength. 
While sasview could " + "handle this instance, it is a violation " + "of the file format and will not be " + "handled by other software.") + + headers = self.nextline().split() + + self._insist_header(headers, "SpinEchoLength") + self._insist_header(headers, "Depolarisation") + self._insist_header(headers, "Depolarisation_error") + self._insist_header(headers, "Wavelength") + + data = np.loadtxt(self.f_open) + + if data.shape[1] != len(headers): + raise FileContentsException( + "File has {} headers, but {} columns".format( + len(headers), + data.shape[1])) + + if not data.size: + raise FileContentsException("{} is empty".format(self.filepath)) + x = data[:, headers.index("SpinEchoLength")] + if "SpinEchoLength_error" in headers: + dx = data[:, headers.index("SpinEchoLength_error")] + else: + dx = x * 0.05 + lam = data[:, headers.index("Wavelength")] + if "Wavelength_error" in headers: + dlam = data[:, headers.index("Wavelength_error")] + else: + dlam = lam * 0.05 + y = data[:, headers.index("Depolarisation")] + dy = data[:, headers.index("Depolarisation_error")] + + lam_unit = self._unit_fetch("Wavelength") + x, x_unit = self._unit_conversion(x, "A", + self._unit_fetch( + "SpinEchoLength")) + dx, dx_unit = self._unit_conversion( + dx, lam_unit, + self._unit_fetch("SpinEchoLength")) + dlam, dlam_unit = self._unit_conversion( + dlam, lam_unit, + self._unit_fetch("Wavelength")) + y_unit = self._unit_fetch("Depolarisation") + + self.current_dataset.x = x + self.current_dataset.y = y + self.current_dataset.lam = lam + self.current_dataset.dy = dy + self.current_dataset.dx = dx + self.current_dataset.dlam = dlam + self.current_datainfo.isSesans = True + + self.current_datainfo._yunit = y_unit + self.current_datainfo._xunit = x_unit + self.current_datainfo.source.wavelength_unit = lam_unit + self.current_datainfo.source.wavelength = lam + self.current_datainfo.filename = os.path.basename(self.f_open.name) + self.current_dataset.xaxis(r"\rm{z}", x_unit) + # Adjust label to ln P/(lam^2 t), remove lam column refs + self.current_dataset.yaxis(r"\rm{ln(P)/(t \lambda^2)}", y_unit) + # Store loading process information + self.current_datainfo.meta_data['loader'] = self.type_name + self.current_datainfo.sample.name = params["Sample"] + self.current_datainfo.sample.ID = params["DataFileTitle"] + self.current_datainfo.sample.thickness = self._unit_conversion( + float(params["Thickness"]), "cm", + self._unit_fetch("Thickness"))[0] + + self.current_datainfo.sample.zacceptance = ( + float(params["Theta_zmax"]), + self._unit_fetch("Theta_zmax")) + + self.current_datainfo.sample.yacceptance = ( + float(params["Theta_ymax"]), + self._unit_fetch("Theta_ymax")) + + self.send_to_output() + + @staticmethod + def _insist_header(headers, name): + if name not in headers: + raise FileContentsException( + "Missing {} column in spin echo data".format(name)) + + @staticmethod + def _unit_conversion(value, value_unit, default_unit): + """ + Performs unit conversion on a measurement. 
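+        Returns a (value, unit) tuple; when the optional nxsunit Converter
+        is unavailable, or the two units already match, the value is passed
+        through unchanged together with its original unit.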
+ + :param value: The magnitude of the measurement + :param value_unit: a string containing the final desired unit + :param default_unit: string with the units of the original measurement + :return: The magnitude of the measurement in the new units + """ + # (float, string, string) -> float + if has_converter and value_unit != default_unit: + data_conv_q = Converter(default_unit) + value = data_conv_q(value, units=value_unit) + new_unit = default_unit + else: + new_unit = value_unit + return value, new_unit + + def _unit_fetch(self, unit): + return self.params[unit+"_unit"] diff --git a/sas/sascalc/dataloader/readers/tiff_reader.py b/sas/sascalc/dataloader/readers/tiff_reader.py new file mode 100755 index 000000000..9dd42d348 --- /dev/null +++ b/sas/sascalc/dataloader/readers/tiff_reader.py @@ -0,0 +1,108 @@ +##################################################################### +#This software was developed by the University of Tennessee as part of the +#Distributed Data Analysis of Neutron Scattering Experiments (DANSE) +#project funded by the US National Science Foundation. +#See the license text in license.txt +#copyright 2008, University of Tennessee +###################################################################### +""" + Image reader. Untested. +""" +#TODO: load and check data and orientation of the image (needs rendering) +import math +import logging +import os +import numpy as np +from sas.sascalc.dataloader.data_info import Data2D +from sas.sascalc.dataloader.manipulations import reader2D_converter + +logger = logging.getLogger(__name__) + +class Reader: + """ + Example data manipulation + """ + ## File type + type_name = "TIF" + ## Wildcards + type = ["TIF files (*.tif)|*.tif", + "TIFF files (*.tiff)|*.tiff", + ] + ## Extension + ext = ['.tif', '.tiff'] + + def read(self, filename=None): + """ + Open and read the data in a file + + :param file: path of the file + """ + try: + import Image + import TiffImagePlugin + Image._initialized=2 + except: + msg = "tiff_reader: could not load file. Missing Image module." 
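+            # "Image" is the legacy PIL import; modern Pillow installs
+            # expose the same module as "from PIL import Image".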
+ raise RuntimeError(msg) + + # Instantiate data object + output = Data2D() + output.filename = os.path.basename(filename) + + # Read in the image + try: + im = Image.open(filename) + except: + raise RuntimeError("cannot open %s"%(filename)) + data = im.getdata() + + # Initiazed the output data object + output.data = np.zeros([im.size[0], im.size[1]]) + output.err_data = np.zeros([im.size[0], im.size[1]]) + output.mask = np.ones([im.size[0], im.size[1]], dtype=bool) + + # Initialize + x_vals = [] + y_vals = [] + + # x and y vectors + for i_x in range(im.size[0]): + x_vals.append(i_x) + + itot = 0 + for i_y in range(im.size[1]): + y_vals.append(i_y) + + for val in data: + try: + value = float(val) + except: + logger.error("tiff_reader: had to skip a non-float point") + continue + + # Get bin number + if math.fmod(itot, im.size[0]) == 0: + i_x = 0 + i_y += 1 + else: + i_x += 1 + + output.data[im.size[1] - 1 - i_y][i_x] = value + + itot += 1 + + output.xbins = im.size[0] + output.ybins = im.size[1] + output.x_bins = x_vals + output.y_bins = y_vals + output.qx_data = np.array(x_vals) + output.qy_data = np.array(y_vals) + output.xmin = 0 + output.xmax = im.size[0] - 1 + output.ymin = 0 + output.ymax = im.size[0] - 1 + + # Store loading process information + output.meta_data['loader'] = self.type_name + output = reader2D_converter(output) + return output diff --git a/sas/sascalc/dataloader/readers/xml_reader.py b/sas/sascalc/dataloader/readers/xml_reader.py new file mode 100755 index 000000000..0f9587d19 --- /dev/null +++ b/sas/sascalc/dataloader/readers/xml_reader.py @@ -0,0 +1,314 @@ +""" + Generic XML read and write utility + + Usage: Either extend xml_reader or add as a class variable. +""" +############################################################################ +#This software was developed by the University of Tennessee as part of the +#Distributed Data Analysis of Neutron Scattering Experiments (DANSE) +#project funded by the US National Science Foundation. +#If you use DANSE applications to do scientific research that leads to +#publication, we ask that you acknowledge the use of the software with the +#following sentence: +#This work benefited from DANSE software developed under NSF award DMR-0520547. +#copyright 2008,2009 University of Tennessee +############################################################################# + +import logging + +from lxml import etree +from lxml.builder import E + +from ..file_reader_base_class import FileReader, decode + +logger = logging.getLogger(__name__) + +PARSER = etree.ETCompatXMLParser(remove_comments=True, remove_pis=False) + +class XMLreader(FileReader): + """ + Generic XML read and write class. Mostly helper functions. + Makes reading/writing XML a bit easier than calling lxml libraries directly. + + :Dependencies: + This class requires lxml 2.3 or higher. 
+ """ + + xml = None + xmldoc = None + xmlroot = None + schema = None + schemadoc = None + encoding = None + processing_instructions = None + + def __init__(self, xml=None, schema=None): + self.xml = xml + self.schema = schema + self.processing_instructions = {} + if xml is not None: + self.set_xml_file(xml) + else: + self.xmldoc = None + self.xmlroot = None + if schema is not None: + self.set_schema(schema) + else: + self.schemadoc = None + + def reader(self): + """ + Read in an XML file into memory and return an lxml dictionary + """ + if self.validate_xml(): + self.xmldoc = etree.parse(self.xml, parser=PARSER) + else: + raise etree.XMLSchemaValidateError(self, self.find_invalid_xml()) + return self.xmldoc + + def set_xml_file(self, xml): + """ + Set the XML file and parse + """ + try: + self.xml = xml + self.xmldoc = etree.parse(self.xml, parser=PARSER) + self.xmlroot = self.xmldoc.getroot() + except etree.XMLSyntaxError as xml_error: + logger.info(xml_error) + raise xml_error + except Exception: + self.xml = None + self.xmldoc = None + self.xmlroot = None + + def set_xml_string(self, tag_soup): + """ + Set an XML string as the working XML. + + :param tag_soup: XML formatted string + """ + try: + self.xml = tag_soup + self.xmldoc = tag_soup + self.xmlroot = etree.fromstring(tag_soup) + except etree.XMLSyntaxError as xml_error: + logger.info(xml_error) + raise xml_error + except Exception as exc: + self.xml = None + self.xmldoc = None + self.xmlroot = None + raise exc + + def set_schema(self, schema): + """ + Set the schema file and parse + """ + try: + self.schema = schema + self.schemadoc = etree.parse(self.schema, parser=PARSER) + except etree.XMLSyntaxError as xml_error: + logger.info(xml_error) + except Exception: + self.schema = None + self.schemadoc = None + + def validate_xml(self): + """ + Checks to see if the XML file meets the schema + """ + valid = True + if self.schema is not None: + self.parse_schema_and_doc() + schema_check = etree.XMLSchema(self.schemadoc) + valid = schema_check.validate(self.xmldoc) + return valid + + def find_invalid_xml(self): + """ + Finds the first offending element that should not be present in XML file + """ + first_error = "" + self.parse_schema_and_doc() + schema = etree.XMLSchema(self.schemadoc) + try: + first_error = schema.assertValid(self.xmldoc) + except etree.DocumentInvalid as err: + # Suppress errors for <'any'> elements + if "##other" in str(err): + return first_error + first_error = str(err) + return first_error + + def parse_schema_and_doc(self): + """ + Creates a dictionary of the parsed schema and xml files. 
+ """ + self.set_xml_file(self.xml) + self.set_schema(self.schema) + + def to_string(self, elem, pretty_print=False, encoding=None): + """ + Converts an etree element into a string + """ + return decode(etree.tostring(elem, pretty_print=pretty_print, + encoding=encoding)) + + def break_processing_instructions(self, string, dic): + """ + Method to break a processing instruction string apart and add to a dict + + :param string: A processing instruction as a string + :param dic: The dictionary to save the PIs to + """ + pi_string = string.replace("", "") + split = pi_string.split(" ", 1) + pi_name = split[0] + attr = split[1] + new_pi_name = self._create_unique_key(dic, pi_name) + dic[new_pi_name] = attr + return dic + + def set_processing_instructions(self): + """ + Take out all processing instructions and create a dictionary from them + If there is a default encoding, the value is also saved + """ + dic = {} + proc_instr = self.xmlroot.getprevious() + while proc_instr is not None: + pi_string = self.to_string(proc_instr) + if "?>\n\n " + " .").format(current_line + 1) + err_msg += " Instead got '{}'.".format(all_lines[current_line]) + raise ValueError(err_msg) + + if width > len(qx) or height > len(qy): + err_msg = "File incorrectly formatted.\n" + err_msg += ("Line {} says to use {}x{} points. " + "Only {}x{} provided.").format(current_line + 1, width, + height, len(qx), len(qy)) + raise ValueError(err_msg) + + # More qx and/or qy points can be provided than are actually used + qx = qx[:width] + qy = qy[:height] + + current_line += 1 + # iflag = 1 => Only intensity data (not dealt with here) + # iflag = 2 => q axis and intensity data + # iflag = 3 => q axis, intensity and error data + try: + iflag = int(all_lines[current_line].strip()[0]) + if iflag <= 0 or iflag > 3: raise ValueError() + except: + err_msg = "File incorrectly formatted.\n" + iflag = all_lines[current_line].strip()[0] + err_msg += ("Expected iflag on line {} to be 1, 2 or 3. " + "Instead got '{}'.").format(current_line+1, iflag) + raise ValueError(err_msg) + + current_line += 1 + + try: + current_line, I = _load_points(all_lines, current_line, + width * height) + dI = np.zeros(width*height) + + # Load error data if it's provided + if iflag == 3: + _, dI = _load_points(all_lines, current_line, width*height) + except Exception as e: + err_msg = "File incorrectly formatted.\n" + if str(e).find("list index") != -1: + err_msg += ("Incorrect number of data points. Expected {}" + " intensity").format(width * height) + if iflag == 3: + err_msg += " and error" + err_msg += " points." + else: + err_msg += str(e) + raise ValueError(err_msg) + + # Format data for use with Data2D + qx = list(qx) * height + qy = np.array([[y] * width for y in qy]).flatten() + + data = Data2D(qx_data=qx, qy_data=qy, data=I, err_data=dI) + + return data diff --git a/sas/sascalc/file_converter/bsl_loader.py b/sas/sascalc/file_converter/bsl_loader.py new file mode 100755 index 000000000..7025ec22d --- /dev/null +++ b/sas/sascalc/file_converter/bsl_loader.py @@ -0,0 +1,131 @@ +from sas.sascalc.file_converter._bsl_loader import CLoader +from sas.sascalc.dataloader.data_info import Data2D +from copy import deepcopy +import os +import numpy as np + +class BSLParsingError(Exception): + pass + +class BSLLoader(CLoader): + """ + Loads 2D SAS data from a BSL file. + CLoader is a C extension (found in c_ext/bsl_loader.c) + + See http://www.diamond.ac.uk/Beamlines/Soft-Condensed-Matter/small-angle/SAXS-Software/CCP13/BSL.html + for more info on the BSL file format. 
+ """ + + def __init__(self, filename): + """ + Parses the BSL header file and sets instance variables apropriately + + :param filename: Path to the BSL header file + """ + header_file = open(filename, 'r') + data_info = {} + is_valid = True + err_msg = "" + + [folder, filename] = os.path.split(filename) + # SAS data will be in file Xnn001.mdd + sasdata_filename = filename.replace('000.', '001.') + if sasdata_filename == filename: + err_msg = ("Invalid header filename {}.\nShould be of the format " + "Xnn000.XXX where X is any alphanumeric character and n is any" + " digit.").format(filename) + raise BSLParsingError(err_msg) + + # First 2 lines are headers + header_file.readline() + header_file.readline() + + while True: + metadata = header_file.readline().strip() + metadata = metadata.split() + data_filename = header_file.readline().strip() + + if len(metadata) != 10: + is_valid = False + err_msg = "Invalid header file: {}".format(filename) + break + + if data_filename != sasdata_filename: + last_file = (metadata[9] == '0') + if last_file: # Reached last file we have metadata for + is_valid = False + err_msg = "No metadata for {} found in header file: {}" + err_msg = err_msg.format(sasdata_filename, filename) + break + continue + try: + data_info = { + 'filename': os.path.join(folder, data_filename), + 'pixels': int(metadata[0]), + 'rasters': int(metadata[1]), + 'frames': int(metadata[2]), + 'swap_bytes': int(metadata[3]) + } + except Exception: + is_valid = False + err_msg = "Invalid metadata in header file for {}" + err_msg = err_msg.format(sasdata_filename) + break + + header_file.close() + if not is_valid: + raise BSLParsingError(err_msg) + + CLoader.__init__(self, data_info['filename'], data_info['frames'], + data_info['pixels'], data_info['rasters'], data_info['swap_bytes']) + + def load_frames(self, frames): + frame_data = [] + # Prepare axis values (arbitrary scale) + x = self.n_rasters * range(1, self.n_pixels+1) + y = [self.n_pixels * [i] for i in range(1, self.n_rasters+1)] + y = np.reshape(y, (1, self.n_pixels*self.n_rasters))[0] + x_bins = x[:self.n_pixels] + y_bins = y[0::self.n_pixels] + + for frame in frames: + self.frame = frame + raw_frame_data = self.load_data() + data2d = Data2D(data=raw_frame_data, qx_data=x, qy_data=y) + data2d.x_bins = x_bins + data2d.y_bins = y_bins + data2d.Q_unit = '' # Using arbitrary units + frame_data.append(data2d) + + return frame_data + + + def __setattr__(self, name, value): + if name == 'filename': + return self.set_filename(value) + elif name == 'n_frames': + return self.set_n_frames(value) + elif name == 'frame': + return self.set_frame(value) + elif name == 'n_pixels': + return self.set_n_pixels(value) + elif name == 'n_rasters': + return self.set_n_rasters(value) + elif name == 'swap_bytes': + return self.set_swap_bytes(value) + return CLoader.__setattr__(self, name, value) + + def __getattr__(self, name): + if name == 'filename': + return self.get_filename() + elif name == 'n_frames': + return self.get_n_frames() + elif name == 'frame': + return self.get_frame() + elif name == 'n_pixels': + return self.get_n_pixels() + elif name == 'n_rasters': + return self.get_n_rasters() + elif name == 'swap_bytes': + return self.get_swap_bytes() + return CLoader.__getattr__(self, name) diff --git a/sas/sascalc/file_converter/c_ext/bsl_loader.c b/sas/sascalc/file_converter/c_ext/bsl_loader.c new file mode 100755 index 000000000..17572ef15 --- /dev/null +++ b/sas/sascalc/file_converter/c_ext/bsl_loader.c @@ -0,0 +1,345 @@ +#include +#include + 
+//#define Py_LIMITED_API 0x03020000 +#include +#include +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +#include "bsl_loader.h" + +typedef struct { + PyObject_HEAD + CLoader_params params; +} CLoader; + +static PyObject *CLoader_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { + CLoader *self; + + self = (CLoader *)type->tp_alloc(type, 0); + + return (PyObject *)self; +} + +static PyObject *CLoader_init(CLoader *self, PyObject *args, PyObject *kwds) { + const char *filename; + int n_frames; + int n_pixels; + int n_rasters; + int swap_bytes; + + if (self != NULL) { + if (!PyArg_ParseTuple(args, "siiii", &filename, &n_frames, &n_pixels, &n_rasters, &swap_bytes)) + Py_RETURN_NONE; + if (!(self->params.filename = malloc(strlen(filename) + 1))) + Py_RETURN_NONE; + strcpy(self->params.filename, filename); + self->params.n_frames = n_frames; + self->params.n_pixels = n_pixels; + self->params.n_rasters = n_rasters; + self->params.swap_bytes = swap_bytes; + } + + return 0; +} + +static void CLoader_dealloc(CLoader *self) { + free(self->params.filename); + Py_TYPE(self)->tp_free((PyObject *)self); +} + +static PyObject *to_string(CLoader *self, PyObject *params) { + char str[100]; + sprintf(str, + "Filename: %s\nn_frames: %d\nframe: %d\nn_pixels: %d\nn_rasters: %d\nswap_bytes: %d", + self->params.filename, + self->params.n_frames, + self->params.frame, + self->params.n_pixels, + self->params.n_rasters, + self->params.swap_bytes); + return Py_BuildValue("s", str); +} + +/* ----- Setters and Getters ----- */ + +static PyObject *get_filename(CLoader *self, PyObject *args) { + return Py_BuildValue("s", self->params.filename); +} + +static PyObject *set_filename(CLoader *self, PyObject *args) { + const char *new_filename; + if (!PyArg_ParseTuple(args, "s", &new_filename)) + return NULL; + strcpy(self->params.filename, new_filename); + + return Py_BuildValue("s", self->params.filename); +} + +static PyObject *get_n_frames(CLoader *self, PyObject *args) { + return Py_BuildValue("i", self->params.n_frames); +} + +static PyObject *set_n_frames(CLoader *self, PyObject *args) { + int new_frames; + if (!PyArg_ParseTuple(args, "i", &new_frames)) + return NULL; + self->params.n_frames = new_frames; + + return Py_BuildValue("i", self->params.n_frames); +} + +static PyObject *get_frame(CLoader *self, PyObject *args) { + return Py_BuildValue("i", self->params.frame); +} + +static PyObject *set_frame(CLoader *self, PyObject *args) { + int new_frame; + if (!PyArg_ParseTuple(args, "i", &new_frame)) + return NULL; + self->params.frame = new_frame; + + return Py_BuildValue("i", self->params.frame); +} + +static PyObject *get_n_pixels(CLoader *self, PyObject *args) { + return Py_BuildValue("i", self->params.n_pixels); +} + +static PyObject *set_n_pixels(CLoader *self, PyObject *args) { + int new_pixels; + if (!PyArg_ParseTuple(args, "i", &new_pixels)) + return NULL; + self->params.n_pixels = new_pixels; + + return Py_BuildValue("i", self->params.n_pixels); +} + +static PyObject *get_n_rasters(CLoader *self, PyObject *args) { + return Py_BuildValue("i", self->params.n_rasters); +} + +static PyObject *set_n_rasters(CLoader *self, PyObject *args) { + int new_rasters; + if (!PyArg_ParseTuple(args, "i", &new_rasters)) + return NULL; + self->params.n_rasters = new_rasters; + + return Py_BuildValue("i", self->params.n_rasters); +} + +static PyObject *get_swap_bytes(CLoader *self, PyObject *args) { + return Py_BuildValue("i", self->params.swap_bytes); +} + +static PyObject *set_swap_bytes(CLoader *self, 
PyObject *args) { + int new_swap; + if (!PyArg_ParseTuple(args, "i", &new_swap)) + return NULL; + self->params.swap_bytes = new_swap; + + return Py_BuildValue("i", self->params.swap_bytes); +} + +/* ----- Instance Methods ----- */ + +float reverse_float(const float in_float) { + // Reverse the order of the bytes of a float + float retval; + char *to_convert = (char *)&in_float; + char *return_float = (char *)&retval; + + return_float[0] = to_convert[3]; + return_float[1] = to_convert[2]; + return_float[2] = to_convert[1]; + return_float[3] = to_convert[0]; + + return retval; +} + +static PyObject *load_data(CLoader *self, PyObject *args) { + int raster; + int pixel; + int frame_pos; + npy_intp size[2] = {self->params.n_rasters, self->params.n_pixels}; + float cur_val; + FILE *input_file; + PyArrayObject *data; + + // Create a new numpy array to store the data in + data = (PyArrayObject *)PyArray_SimpleNew(2, size, NPY_FLOAT); + + // Attempt to open the file specified + input_file = fopen(self->params.filename, "rb"); + if (!input_file) { + // BSL filenames are 10 characters long + // Filename validity checked in bsl_loader.py + size_t filename_start = strlen(self->params.filename) - 10; + char *filename = self->params.filename + filename_start; + char *err_msg = (char *)malloc(sizeof(char) * 32); + + sprintf(err_msg, "Unable to open file: %s", filename); + + PyErr_SetString(PyExc_RuntimeError, err_msg); + free(err_msg); + return NULL; + } + + // Move the file cursor the the position where the data we're interested + // in begins + frame_pos = self->params.n_pixels * self->params.n_rasters * self->params.frame; + fseek(input_file, frame_pos*sizeof(float), SEEK_SET); + + for (raster = 0; raster < self->params.n_rasters; raster++) { + for (pixel = 0; pixel < self->params.n_pixels; pixel++) { + // Try reading the file + if (fread(&cur_val, sizeof(float), 1, input_file) == 0) { + PyErr_SetString(PyExc_RuntimeError, "Error reading file or EOF reached."); + return NULL; + } + + // Swap the order of the bytes read, if specified that we should do + // so in the header file + if (self->params.swap_bytes == 0) + cur_val = reverse_float(cur_val); + + // Add the read value to the numpy array + PyArray_SETITEM(data, PyArray_GETPTR2(data, raster, pixel), PyFloat_FromDouble(cur_val)); + } + } + + fclose(input_file); + + return Py_BuildValue("N", data); +} + +/* ----- Class Registration ----- */ + +static PyMethodDef CLoader_methods[] = { + { "to_string", (PyCFunction)to_string, METH_VARARGS, "Print the objects params" }, + { "get_filename", (PyCFunction)get_filename, METH_VARARGS, "Get the filename" }, + { "set_filename", (PyCFunction)set_filename, METH_VARARGS, "Set the filename" }, + { "get_n_frames", (PyCFunction)get_n_frames, METH_VARARGS, "Get n_frames" }, + { "set_n_frames", (PyCFunction)set_n_frames, METH_VARARGS, "Set n_frames" }, + { "get_frame", (PyCFunction)get_frame, METH_VARARGS, "Get the frame that will be loaded" }, + { "set_frame", (PyCFunction)set_frame, METH_VARARGS, "Set the frame that will be loaded" }, + { "get_n_pixels", (PyCFunction)get_n_pixels, METH_VARARGS, "Get n_pixels" }, + { "set_n_pixels", (PyCFunction)set_n_pixels, METH_VARARGS, "Set n_pixels" }, + { "get_n_rasters", (PyCFunction)get_n_rasters, METH_VARARGS, "Get n_rasters" }, + { "set_n_rasters", (PyCFunction)set_n_rasters, METH_VARARGS, "Set n_rasters" }, + { "get_swap_bytes", (PyCFunction)get_swap_bytes, METH_VARARGS, "Get swap_bytes" }, + { "set_swap_bytes", (PyCFunction)set_swap_bytes, METH_VARARGS, "Set swap_bytes" }, 
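+    /* load_data is the only method above that touches the data file itself;
+       the getters/setters only manipulate the in-memory CLoader_params. */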
+ { "load_data", (PyCFunction)load_data, METH_VARARGS, "Load the data into a numpy array" }, + {NULL} +}; + +static PyMemberDef CLoader_members[] = { + {NULL} +}; + +static PyTypeObject CLoaderType = { + //PyObject_HEAD_INIT(NULL) + //0, /*ob_size*/ + PyVarObject_HEAD_INIT(NULL, 0) + "CLoader", /*tp_name*/ + sizeof(CLoader), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)CLoader_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "CLoader objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + CLoader_methods, /* tp_methods */ + CLoader_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)CLoader_init, /* tp_init */ + 0, /* tp_alloc */ + CLoader_new, /* tp_new */ +}; + +/** + * Function used to add the model class to a module + * @param module: module to add the class to + */ +void addCLoader(PyObject *module) +{ + if (PyType_Ready(&CLoaderType) < 0) + return; + Py_INCREF(&CLoaderType); + PyModule_AddObject(module, "CLoader", (PyObject *)&CLoaderType); +} + +#define MODULE_DOC "C module for loading bsl." +#define MODULE_NAME "_bsl_loader" +#define MODULE_INIT2 init_bsl_loader +#define MODULE_INIT3 PyInit__bsl_loader +#define MODULE_METHODS module_methods + +/* ==== boilerplate python 2/3 interface bootstrap ==== */ + + +#if defined(WIN32) && !defined(__MINGW32__) + #define DLL_EXPORT __declspec(dllexport) +#else + #define DLL_EXPORT +#endif + +#if PY_MAJOR_VERSION >= 3 + + static PyMethodDef module_methods[] = { + {NULL} + }; + + DLL_EXPORT PyMODINIT_FUNC MODULE_INIT3(void) + { + static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + MODULE_NAME, /* m_name */ + MODULE_DOC, /* m_doc */ + -1, /* m_size */ + MODULE_METHODS, /* m_methods */ + NULL, /* m_reload */ + NULL, /* m_traverse */ + NULL, /* m_clear */ + NULL, /* m_free */ + }; + PyObject* m = PyModule_Create(&moduledef); + import_array(); + addCLoader(m); + return m; + } + +#else /* !PY_MAJOR_VERSION >= 3 */ + + DLL_EXPORT PyMODINIT_FUNC MODULE_INIT2(void) + { + PyObject* m = Py_InitModule(MODULE_NAME, NULL); + import_array(); + addCLoader(m); + } + +#endif /* !PY_MAJOR_VERSION >= 3 */ diff --git a/sas/sascalc/file_converter/c_ext/bsl_loader.h b/sas/sascalc/file_converter/c_ext/bsl_loader.h new file mode 100755 index 000000000..ca8602523 --- /dev/null +++ b/sas/sascalc/file_converter/c_ext/bsl_loader.h @@ -0,0 +1,19 @@ +#ifndef bsl_loader_h +#define bsl_loader_h + +typedef struct { + // File to load + char *filename; + // Number of frames in the file + int n_frames; + // Frame to load + int frame; + // Number of pixels in the file + int n_pixels; + // Number of rasters in the file + int n_rasters; + // Whether or not the bytes are in reverse order + int swap_bytes; +} CLoader_params; + +#endif diff --git a/sas/sascalc/file_converter/cansas_writer.py b/sas/sascalc/file_converter/cansas_writer.py new file mode 100755 index 000000000..b519d4de9 --- /dev/null +++ b/sas/sascalc/file_converter/cansas_writer.py @@ -0,0 +1,109 @@ +from sas.sascalc.dataloader.readers.cansas_reader import Reader as 
CansasReader +from sas.sascalc.dataloader.data_info import Data1D + +import inspect + +class CansasWriter(CansasReader): + + def write(self, filename, frame_data, sasentry_attrs=None): + """ + Write the content of a Data1D as a CanSAS XML file + + :param filename: name of the file to write + :param datainfo: Data1D object + """ + # Create XML document + doc, _ = self._to_xml_doc(frame_data, sasentry_attrs) + # Write the file + file_ref = open(filename, 'w') + if self.encoding is None: + self.encoding = "UTF-8" + doc.write(file_ref, encoding=self.encoding, + pretty_print=True, xml_declaration=True) + file_ref.close() + + + def _to_xml_doc(self, frame_data, sasentry_attrs=None): + """ + Create an XML document to contain the content of an array of Data1Ds + + :param frame_data: An array of Data1D objects + """ + valid_class = all([issubclass(data.__class__, Data1D) for data in frame_data]) + if not valid_class: + raise RuntimeError("The cansas writer expects an array of " + "Data1D instances") + + # Get PIs and create root element + pi_string = self._get_pi_string() + # Define namespaces and create SASroot object + main_node = self._create_main_node() + # Create ElementTree, append SASroot and apply processing instructions + base_string = pi_string + self.to_string(main_node) + base_element = self.create_element_from_string(base_string) + doc = self.create_tree(base_element) + # Create SASentry Element + entry_node = self.create_element("SASentry", sasentry_attrs) + root = doc.getroot() + root.append(entry_node) + + # Use the first element in the array for writing metadata + datainfo = frame_data[0] + # Add Title to SASentry + self.write_node(entry_node, "Title", datainfo.title) + # Add Run to SASentry + self._write_run_names(datainfo, entry_node) + # Add Data info to SASEntry + for data_info in frame_data: + self._write_data(data_info, entry_node) + # Transmission Spectrum Info + self._write_trans_spectrum(datainfo, entry_node) + # Sample info + self._write_sample_info(datainfo, entry_node) + # Instrument info + instr = self._write_instrument(datainfo, entry_node) + # Source + self._write_source(datainfo, instr) + # Collimation + self._write_collimation(datainfo, instr) + # Detectors + self._write_detectors(datainfo, instr) + # Processes info + self._write_process_notes(datainfo, entry_node) + # Note info + self._write_notes(datainfo, entry_node) + # Return the document, and the SASentry node associated with + # the data we just wrote + + return doc, entry_node + + def _write_data(self, datainfo, entry_node): + """ + Writes the I and Q data to the XML file + + :param datainfo: The Data1D object the information is coming from + :param entry_node: lxml node ElementTree object to be appended to + """ + node = self.create_element("SASdata") + self.append(node, entry_node) + + for i in range(len(datainfo.x)): + point = self.create_element("Idata") + node.append(point) + self.write_node(point, "Q", datainfo.x[i], + {'unit': datainfo.x_unit}) + if len(datainfo.y) >= i: + self.write_node(point, "I", datainfo.y[i], + {'unit': datainfo.y_unit}) + if datainfo.dy is not None and len(datainfo.dy) > i: + self.write_node(point, "Idev", datainfo.dy[i], + {'unit': datainfo.y_unit}) + if datainfo.dx is not None and len(datainfo.dx) > i: + self.write_node(point, "Qdev", datainfo.dx[i], + {'unit': datainfo.x_unit}) + if datainfo.dxw is not None and len(datainfo.dxw) > i: + self.write_node(point, "dQw", datainfo.dxw[i], + {'unit': datainfo.x_unit}) + if datainfo.dxl is not None and len(datainfo.dxl) > i: + 
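+                # dxl is the slit-length smearing resolution component,
+                # written as dQl (dxw/dQw above is the slit-width counterpart).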
self.write_node(point, "dQl", datainfo.dxl[i], + {'unit': datainfo.x_unit}) diff --git a/sas/sascalc/file_converter/nxcansas_writer.py b/sas/sascalc/file_converter/nxcansas_writer.py new file mode 100755 index 000000000..28d331954 --- /dev/null +++ b/sas/sascalc/file_converter/nxcansas_writer.py @@ -0,0 +1,368 @@ +""" + NXcanSAS 1/2D data reader for writing HDF5 formatted NXcanSAS files. +""" + +import h5py +import numpy as np +import re +import os + +from sas.sascalc.dataloader.readers.cansas_reader_HDF5 import Reader +from sas.sascalc.dataloader.data_info import Data1D, Data2D + +class NXcanSASWriter(Reader): + """ + A class for writing in NXcanSAS data files. Any number of data sets may be + written to the file. Currently 1D and 2D SAS data sets are supported + + NXcanSAS spec: http://download.nexusformat.org/sphinx/classes/contributed_definitions/NXcanSAS.html + + :Dependencies: + The NXcanSAS writer requires h5py => v2.5.0 or later. + """ + + def write(self, dataset, filename): + """ + Write an array of Data1d or Data2D objects to an NXcanSAS file, as + one SASEntry with multiple SASData elements. The metadata of the first + elememt in the array will be written as the SASentry metadata + (detector, instrument, sample, etc). + + :param dataset: A list of Data1D or Data2D objects to write + :param filename: Where to write the NXcanSAS file + """ + + def _h5_string(string): + """ + Convert a string to a numpy string in a numpy array. This way it is + written to the HDF5 file as a fixed length ASCII string and is + compatible with the Reader read() method. + """ + if isinstance(string, np.ndarray): + return string + elif not isinstance(string, str): + string = str(string) + + return np.array([np.string_(string)]) + + def _write_h5_string(entry, value, key): + entry[key] = _h5_string(value) + + def _h5_float(x): + if not (isinstance(x, list)): + x = [x] + return np.array(x, dtype=np.float32) + + def _write_h5_float(entry, value, key): + entry.create_dataset(key, data=_h5_float(value)) + + def _write_h5_vector(entry, vector, names=['x_position', 'y_position'], + units=None, write_fn=_write_h5_string): + """ + Write a vector to an h5 entry + + :param entry: The H5Py entry to write to + :param vector: The Vector to write + :param names: What to call the x,y and z components of the vector + when writing to the H5Py entry + :param units: The units of the vector (optional) + :param write_fn: A function to convert the value to the required + format and write it to the H5Py entry, of the form + f(entry, value, name) (optional) + """ + if len(names) < 2: + raise ValueError("Length of names must be >= 2.") + + if vector.x is not None: + write_fn(entry, vector.x, names[0]) + if units is not None: + entry[names[0]].attrs['units'] = units + if vector.y is not None: + write_fn(entry, vector.y, names[1]) + if units is not None: + entry[names[1]].attrs['units'] = units + if len(names) == 3 and vector.z is not None: + write_fn(entry, vector.z, names[2]) + if units is not None: + entry[names[2]].attrs['units'] = units + + valid_data = all([isinstance(d, (Data1D, Data2D)) for d in dataset]) + if not valid_data: + raise ValueError("All entries of dataset must be Data1D or Data2D" + "objects") + + # Get run name and number from first Data object + data_info = dataset[0] + run_number = '' + run_name = '' + if len(data_info.run) > 0: + run_number = data_info.run[0] + if len(data_info.run_name) > 0: + run_name = data_info.run_name[run_number] + + f = h5py.File(filename, 'w') + sasentry = 
f.create_group('sasentry01') + sasentry['definition'] = _h5_string('NXcanSAS') + sasentry['run'] = _h5_string(run_number) + sasentry['run'].attrs['name'] = run_name + sasentry['title'] = _h5_string(data_info.title) + sasentry.attrs['canSAS_class'] = 'SASentry' + sasentry.attrs['version'] = '1.0' + + for i, data_obj in enumerate(dataset): + data_entry = sasentry.create_group("sasdata{0:0=2d}".format(i+1)) + data_entry.attrs['canSAS_class'] = 'SASdata' + if isinstance(data_obj, Data1D): + self._write_1d_data(data_obj, data_entry) + elif isinstance(data_obj, Data2D): + self._write_2d_data(data_obj, data_entry) + + data_info = dataset[0] + # Sample metadata + sample_entry = sasentry.create_group('sassample') + sample_entry.attrs['canSAS_class'] = 'SASsample' + sample_entry['ID'] = _h5_string(data_info.sample.name) + sample_attrs = ['thickness', 'temperature', 'transmission'] + for key in sample_attrs: + if getattr(data_info.sample, key) is not None: + sample_entry.create_dataset(key, + data=_h5_float(getattr(data_info.sample, key))) + _write_h5_vector(sample_entry, data_info.sample.position) + # NXcanSAS doesn't save information about pitch, only roll + # and yaw. The _write_h5_vector method writes vector.y, but we + # need to write vector.z for yaw + data_info.sample.orientation.y = data_info.sample.orientation.z + _write_h5_vector(sample_entry, data_info.sample.orientation, + names=['polar_angle', 'azimuthal_angle']) + if data_info.sample.details is not None\ + and data_info.sample.details != []: + details = None + if len(data_info.sample.details) > 1: + details = [np.string_(d) for d in data_info.sample.details] + details = np.array(details) + elif data_info.sample.details != []: + details = _h5_string(data_info.sample.details[0]) + if details is not None: + sample_entry.create_dataset('details', data=details) + + # Instrument metadata + instrument_entry = sasentry.create_group('sasinstrument') + instrument_entry.attrs['canSAS_class'] = 'SASinstrument' + instrument_entry['name'] = _h5_string(data_info.instrument) + + # Source metadata + source_entry = instrument_entry.create_group('sassource') + source_entry.attrs['canSAS_class'] = 'SASsource' + if data_info.source.radiation is None: + source_entry['radiation'] = _h5_string('neutron') + else: + source_entry['radiation'] = _h5_string(data_info.source.radiation) + if data_info.source.beam_shape is not None: + source_entry['beam_shape'] = _h5_string(data_info.source.beam_shape) + wavelength_keys = { 'wavelength': 'incident_wavelength', + 'wavelength_min':'wavelength_min', + 'wavelength_max': 'wavelength_max', + 'wavelength_spread': 'incident_wavelength_spread' } + for sasname, nxname in wavelength_keys.items(): + value = getattr(data_info.source, sasname) + units = getattr(data_info.source, sasname + '_unit') + if value is not None: + source_entry[nxname] = _h5_float(value) + source_entry[nxname].attrs['units'] = units + _write_h5_vector(source_entry, data_info.source.beam_size, + names=['beam_size_x', 'beam_size_y'], + units=data_info.source.beam_size_unit, write_fn=_write_h5_float) + + # Collimation metadata + if len(data_info.collimation) > 0: + for i, coll_info in enumerate(data_info.collimation): + collimation_entry = instrument_entry.create_group( + 'sascollimation{0:0=2d}'.format(i + 1)) + collimation_entry.attrs['canSAS_class'] = 'SAScollimation' + if coll_info.length is not None: + _write_h5_float(collimation_entry, coll_info.length, 'SDD') + collimation_entry['SDD'].attrs['units'] =\ + coll_info.length_unit + if coll_info.name is not 
None: + collimation_entry['name'] = _h5_string(coll_info.name) + else: + # Create a blank one - at least 1 collimation required by format + instrument_entry.create_group('sascollimation01') + + # Detector metadata + if len(data_info.detector) > 0: + i = 1 + for i, det_info in enumerate(data_info.detector): + detector_entry = instrument_entry.create_group( + 'sasdetector{0:0=2d}'.format(i + 1)) + detector_entry.attrs['canSAS_class'] = 'SASdetector' + if det_info.distance is not None: + _write_h5_float(detector_entry, det_info.distance, 'SDD') + detector_entry['SDD'].attrs['units'] =\ + det_info.distance_unit + if det_info.name is not None: + detector_entry['name'] = _h5_string(det_info.name) + else: + detector_entry['name'] = _h5_string('') + if det_info.slit_length is not None: + _write_h5_float(detector_entry, det_info.slit_length, + 'slit_length') + detector_entry['slit_length'].attrs['units'] =\ + det_info.slit_length_unit + _write_h5_vector(detector_entry, det_info.offset) + # NXcanSAS doesn't save information about pitch, only roll + # and yaw. The _write_h5_vector method writes vector.y, but we + # need to write vector.z for yaw + det_info.orientation.y = det_info.orientation.z + _write_h5_vector(detector_entry, det_info.orientation, + names=['polar_angle', 'azimuthal_angle']) + _write_h5_vector(detector_entry, det_info.beam_center, + names=['beam_center_x', 'beam_center_y'], + write_fn=_write_h5_float, units=det_info.beam_center_unit) + _write_h5_vector(detector_entry, det_info.pixel_size, + names=['x_pixel_size', 'y_pixel_size'], + write_fn=_write_h5_float, units=det_info.pixel_size_unit) + else: + # Create a blank one - at least 1 detector required by format + detector_entry = instrument_entry.create_group('sasdetector01') + detector_entry.attrs['canSAS_class'] = 'SASdetector' + detector_entry.attrs['name'] = '' + + # Process meta data + for i, process in enumerate(data_info.process): + process_entry = sasentry.create_group('sasprocess{0:0=2d}'.format( + i + 1)) + process_entry.attrs['canSAS_class'] = 'SASprocess' + if process.name: + name = _h5_string(process.name) + process_entry.create_dataset('name', data=name) + if process.date: + date = _h5_string(process.date) + process_entry.create_dataset('date', data=date) + if process.description: + desc = _h5_string(process.description) + process_entry.create_dataset('description', data=desc) + for j, term in enumerate(process.term): + # Don't save empty terms + if term: + h5_term = _h5_string(term) + process_entry.create_dataset('term{0:0=2d}'.format( + j + 1), data=h5_term) + for j, note in enumerate(process.notes): + # Don't save empty notes + if note: + h5_note = _h5_string(note) + process_entry.create_dataset('note{0:0=2d}'.format( + j + 1), data=h5_note) + + # Transmission Spectrum + for i, trans in enumerate(data_info.trans_spectrum): + trans_entry = sasentry.create_group( + 'sastransmission_spectrum{0:0=2d}'.format(i + 1)) + trans_entry.attrs['canSAS_class'] = 'SAStransmission_spectrum' + trans_entry.attrs['signal'] = 'T' + trans_entry.attrs['T_axes'] = 'T' + trans_entry.attrs['name'] = trans.name + if trans.timestamp is not '': + trans_entry.attrs['timestamp'] = trans.timestamp + transmission = trans_entry.create_dataset('T', + data=trans.transmission) + transmission.attrs['unertainties'] = 'Tdev' + trans_entry.create_dataset('Tdev', + data=trans.transmission_deviation) + trans_entry.create_dataset('lambda', data=trans.wavelength) + + note_entry = sasentry.create_group('sasnote'.format(i)) + note_entry.attrs['canSAS_class'] = 
'SASnote' + notes = None + if len(data_info.notes) > 1: + notes = [np.string_(n) for n in data_info.notes] + notes = np.array(notes) + elif data_info.notes != []: + notes = _h5_string(data_info.notes[0]) + if notes is not None: + note_entry.create_dataset('SASnote', data=notes) + + f.close() + + def _write_1d_data(self, data_obj, data_entry): + """ + Writes the contents of a Data1D object to a SASdata h5py Group + + :param data_obj: A Data1D object to write to the file + :param data_entry: A h5py Group object representing the SASdata + """ + data_entry.attrs['signal'] = 'I' + data_entry.attrs['I_axes'] = 'Q' + data_entry.attrs['Q_indices'] = [0] + q_entry = data_entry.create_dataset('Q', data=data_obj.x) + q_entry.attrs['units'] = data_obj.x_unit + i_entry = data_entry.create_dataset('I', data=data_obj.y) + i_entry.attrs['units'] = data_obj.y_unit + if data_obj.dy is not None: + i_entry.attrs['uncertainties'] = 'Idev' + i_dev_entry = data_entry.create_dataset('Idev', data=data_obj.dy) + i_dev_entry.attrs['units'] = data_obj.y_unit + if data_obj.dx is not None: + q_entry.attrs['resolutions'] = 'dQ' + dq_entry = data_entry.create_dataset('dQ', data=data_obj.dx) + dq_entry.attrs['units'] = data_obj.x_unit + elif data_obj.dxl is not None: + q_entry.attrs['resolutions'] = ['dQl','dQw'] + dql_entry = data_entry.create_dataset('dQl', data=data_obj.dxl) + dql_entry.attrs['units'] = data_obj.x_unit + dqw_entry = data_entry.create_dataset('dQw', data=data_obj.dxw) + dqw_entry.attrs['units'] = data_obj.x_unit + + def _write_2d_data(self, data, data_entry): + """ + Writes the contents of a Data2D object to a SASdata h5py Group + + :param data: A Data2D object to write to the file + :param data_entry: A h5py Group object representing the SASdata + """ + data_entry.attrs['signal'] = 'I' + data_entry.attrs['I_axes'] = 'Qx,Qy' + data_entry.attrs['Q_indices'] = [0,1] + + (n_rows, n_cols) = (len(data.y_bins), len(data.x_bins)) + + if (n_rows == 0 and n_cols == 0) or (n_cols*n_rows != data.data.size): + # Calculate rows and columns, assuming detector is square + # Same logic as used in PlotPanel.py _get_bins + n_cols = int(np.floor(np.sqrt(len(data.qy_data)))) + n_rows = int(np.floor(len(data.qy_data) / n_cols)) + + if n_rows * n_cols != len(data.qy_data): + raise ValueError("Unable to calculate dimensions of 2D data") + + intensity = np.reshape(data.data, (n_rows, n_cols)) + qx = np.reshape(data.qx_data, (n_rows, n_cols)) + qy = np.reshape(data.qy_data, (n_rows, n_cols)) + + i_entry = data_entry.create_dataset('I', data=intensity) + i_entry.attrs['units'] = data.I_unit + qx_entry = data_entry.create_dataset('Qx', data=qx) + qx_entry.attrs['units'] = data.Q_unit + qy_entry = data_entry.create_dataset('Qy', data=qy) + qy_entry.attrs['units'] = data.Q_unit + if (data.err_data is not None + and not all(v is None for v in data.err_data)): + d_i = np.reshape(data.err_data, (n_rows, n_cols)) + i_entry.attrs['uncertainties'] = 'Idev' + i_dev_entry = data_entry.create_dataset('Idev', data=d_i) + i_dev_entry.attrs['units'] = data.I_unit + if (data.dqx_data is not None + and not all(v is None for v in data.dqx_data)): + qx_entry.attrs['resolutions'] = 'dQx' + dqx_entry = data_entry.create_dataset('dQx', data=data.dqx_data) + dqx_entry.attrs['units'] = data.Q_unit + if (data.dqy_data is not None + and not all(v is None for v in data.dqy_data)): + qy_entry.attrs['resolutions'] = 'dQy' + dqy_entry = data_entry.create_dataset('dQy', data=data.dqy_data) + dqy_entry.attrs['units'] = data.Q_unit + if data.mask is not None 
and not all(v is None for v in data.mask): + data_entry.attrs['mask'] = "mask" + mask = np.invert(np.asarray(data.mask, dtype=bool)) + data_entry.create_dataset('mask', data=mask) diff --git a/sas/sascalc/file_converter/otoko_loader.py b/sas/sascalc/file_converter/otoko_loader.py new file mode 100755 index 000000000..3c747c4e8 --- /dev/null +++ b/sas/sascalc/file_converter/otoko_loader.py @@ -0,0 +1,150 @@ +""" +Here we handle loading of "OTOKO" data (for more info about this format see +the comment in load_otoko_data). Given the paths of header and data files, we +aim to load the data into numpy arrays for use later. +""" + +import itertools +import os +import struct +import numpy as np + +class CStyleStruct: + """A nice and easy way to get "C-style struct" functionality.""" + def __init__(self, **kwds): + self.__dict__.update(kwds) + +class OTOKOParsingError(Exception): + pass + +class OTOKOData: + def __init__(self, q_axis, data_axis): + self.q_axis = q_axis + self.data_axis = data_axis + +class OTOKOLoader(object): + + def __init__(self, qaxis_path, data_path): + self.qaxis_path = qaxis_path + self.data_path = data_path + + def load_otoko_data(self): + """ + Loads "OTOKO" data, which is a format that stores each axis separately. + An axis is represented by a "header" file, which in turn will give details + of one or more binary files where the actual data is stored. + + Given the paths of two header files, this function will load each axis in + turn. If loading is successful then an instance of the OTOKOData class + will be returned, else an exception will be raised. + + For more information on the OTOKO file format, please see: + http://www.diamond.ac.uk/Home/Beamlines/small-angle/SAXS-Software/CCP13/ + XOTOKO.html + """ + q_axis = self._load_otoko_axis(self.qaxis_path) + data_axis = self._load_otoko_axis(self.data_path) + + return OTOKOData(q_axis, data_axis) + + def _load_otoko_axis(self, header_path): + """ + Loads an "OTOKO" axis, given the header file path. Essentially, the + header file contains information about the data in the form of integer + "indicators", as well as the names of each of the binary files which are + assumed to be in the same directory as the header. + """ + if not os.path.exists(header_path): + raise OTOKOParsingError("The header file %s does not exist." % header_path) + + binary_file_info_list = [] + total_frames = 0 + header_dir = os.path.dirname(os.path.abspath(header_path)) + + with open(header_path, "r") as header_file: + lines = header_file.readlines() + if len(lines) < 4: + raise OTOKOParsingError("Expected more lines in %s." % header_path) + + info = lines[0] + lines[1] + + def pairwise(iterable): + """ + s -> (s0,s1), (s2,s3), (s4, s5), ... + From http://stackoverflow.com/a/5389547/778572 + """ + a = iter(iterable) + return itertools.izip(a, a) + + for indicators, filename in pairwise(lines[2:]): + indicators = indicators.split() + + if len(indicators) != 10: + raise OTOKOParsingError( + "Expected 10 integer indicators on line 3 of %s." \ + % header_path) + if not all([i.isdigit() for i in indicators]): + raise OTOKOParsingError( + "Expected all indicators on line 3 of %s to be integers." \ + % header_path) + + binary_file_info = CStyleStruct( + # The indicators at indices 4 to 8 are always zero since they + # have been reserved for future use by the format. 
Also, the + # "last_file" indicator seems to be there for legacy reasons, + # as it doesn't appear to be something we have to bother + # enforcing correct use of; we just define the last file as + # being the last file in the list. + file_path = os.path.join(header_dir, filename.strip()), + n_channels = int(indicators[0]), + n_frames = int(indicators[1]), + dimensions = int(indicators[2]), + swap_bytes = int(indicators[3]) == 0, + last_file = int(indicators[9]) == 0 # We don't use this. + ) + if binary_file_info.dimensions != 1: + msg = "File {} has {} dimensions, expected 1. Is it a BSL file?" + raise OTOKOParsingError(msg.format(filename.strip(), + binary_file_info.dimensions)) + + binary_file_info_list.append(binary_file_info) + + total_frames += binary_file_info.n_frames + + # Check that all binary files are listed in the header as having the same + # number of channels, since I don't think CorFunc can handle ragged data. + all_n_channels = [info.n_channels for info in binary_file_info_list] + if not all(all_n_channels[0] == c for c in all_n_channels): + raise OTOKOParsingError( + "Expected all binary files listed in %s to have the same number of channels." % header_path) + + data = np.zeros(shape=(total_frames, all_n_channels[0])) + frames_so_far = 0 + + for info in binary_file_info_list: + if not os.path.exists(info.file_path): + raise OTOKOParsingError( + "The data file %s does not exist." % info.file_path) + + with open(info.file_path, "rb") as binary_file: + # Ideally we'd like to use numpy's fromfile() to read in binary + # data, but we are forced to roll our own float-by-float file + # reader because of the rules imposed on us by the file format; + # namely, if the swap indicator flag has been raised then the bytes + # of each float occur in reverse order. + for frame in range(info.n_frames): + for channel in range(info.n_channels): + b = bytes(binary_file.read(4)) + if info.swap_bytes: + b = b[::-1] # "Extended slice" syntax, used to reverse. 
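+                    # Aside: the same byte-swapped read can be written with
+                    # numpy's explicit byte-order support; a minimal sketch
+                    # (names are illustrative, not part of this loader):
+                    #
+                    #     import struct
+                    #     import numpy as np
+                    #     native = struct.pack('f', 1.0)  # native byte order
+                    #     on_disk = native[::-1]          # swapped, as on file
+                    #     v1 = struct.unpack('f', on_disk[::-1])[0]
+                    #     v2 = np.frombuffer(on_disk, np.float32).byteswap()[0]
+                    #     assert v1 == v2 == 1.0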
+ value = struct.unpack('f', b)[0] + data[frames_so_far + frame][channel] = value + + frames_so_far += info.n_frames + + return CStyleStruct( + header_path = header_path, + data = data, + binary_file_info_list = binary_file_info_list, + header_info = info + ) diff --git a/sas/sascalc/file_converter/red2d_writer.py b/sas/sascalc/file_converter/red2d_writer.py new file mode 100755 index 000000000..8de558b16 --- /dev/null +++ b/sas/sascalc/file_converter/red2d_writer.py @@ -0,0 +1,36 @@ +import os +import time +from sas.sascalc.dataloader.readers.red2d_reader import Reader as Red2DReader + +class Red2DWriter(Red2DReader): + + def write(self, filename, data, thread): + """ + Write to .dat + + :param filename: file name to write + :param data: data2D + """ + # Write the file + fd = open(filename, 'w') + t = time.localtime() + time_str = time.strftime("%H:%M on %b %d %y", t) + + header_str = "Data columns are Qx - Qy - I(Qx,Qy)\n\nASCII data" + header_str += " created at %s \n\n" % time_str + # simple 2D header + fd.write(header_str) + # write qx qy I values + for i in range(len(data.data)): + if thread.isquit(): + fd.close() + os.remove(filename) + return False + + fd.write("%g %g %g\n" % (data.qx_data[i], + data.qy_data[i], + data.data[i])) + + fd.close() + + return True diff --git a/sas/sascalc/fit/AbstractFitEngine.py b/sas/sascalc/fit/AbstractFitEngine.py new file mode 100755 index 000000000..e82a583fc --- /dev/null +++ b/sas/sascalc/fit/AbstractFitEngine.py @@ -0,0 +1,630 @@ +from __future__ import print_function + +import copy +#import logging +import sys +import math +import numpy as np + +from sas.sascalc.dataloader.data_info import Data1D +from sas.sascalc.dataloader.data_info import Data2D +_SMALLVALUE = 1.0e-10 + +class FitHandler(object): + """ + Abstract interface for fit thread handler. + + The methods in this class are called by the optimizer as the fit + progresses. + + Note that it is up to the optimizer to call the fit handler correctly, + reporting all status changes and maintaining the 'done' flag. + """ + done = False + """True when the fit job is complete""" + result = None + """The current best result of the fit""" + + def improvement(self): + """ + Called when a result is observed which is better than previous + results from the fit. + + result is a FitResult object, with parameters, #calls and fitness. + """ + def error(self, msg): + """ + Model had an error; print traceback + """ + def progress(self, current, expected): + """ + Called each cycle of the fit, reporting the current and the + expected amount of work. The meaning of these values is + optimizer dependent, but they can be converted into a percent + complete using (100*current)//expected. + + Progress is updated each iteration of the fit, whatever that + means for the particular optimization algorithm. It is called + after any calls to improvement for the iteration so that the + update handler can control I/O bandwidth by suppressing + intermediate improvements until the fit is complete. + """ + def finalize(self): + """ + Fit is complete; best results are reported + """ + def abort(self): + """ + Fit was aborted. + """ + + # TODO: not sure how these are used, but they are needed for running the fit + def update_fit(self, last=False): pass + def set_result(self, result=None): self.result = result + +class Model: + """ + Fit wrapper for SAS models. 
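+
+    A minimal usage sketch (``_StubSasModel`` is an illustrative stand-in
+    for any object exposing ``name``, ``getParam``, ``setParam`` and
+    ``evalDistribution``, which is all this wrapper requires)::
+
+        import numpy as np
+
+        class _StubSasModel(object):
+            name = 'stub'
+            def __init__(self):
+                self._pars = {'radius': 20.0}
+            def getParam(self, name):
+                return self._pars[name]
+            def setParam(self, name, value):
+                self._pars[name] = value
+            def evalDistribution(self, x):
+                return self._pars['radius'] * np.asarray(x)
+
+        wrapped = Model(_StubSasModel())
+        wrapped.set_params(['radius'], [40.0])
+        theory = wrapped(np.linspace(0.01, 0.5, 5))  # evalDistribution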
+ """ + def __init__(self, sas_model, sas_data=None, **kw): + """ + :param sas_model: the sas model to wrap for fitting + + """ + self.model = sas_model + self.name = sas_model.name + self.data = sas_data + + def get_params(self, fitparams): + """ + return a list of value of parameter to fit + + :param fitparams: list of parameters name to fit + + """ + return [self.model.getParam(k) for k in fitparams] + + def set_params(self, paramlist, params): + """ + Set value for parameters to fit + + :param params: list of value for parameters to fit + + """ + for k,v in zip(paramlist, params): + self.model.setParam(k,v) + + def set(self, **kw): + self.set_params(*zip(*kw.items())) + + def eval(self, x): + """ + Override eval method of model. + + :param x: the x value used to compute a function + """ + try: + return self.model.evalDistribution(x) + except: + raise + + def eval_derivs(self, x, pars=[]): + """ + Evaluate the model and derivatives wrt pars at x. + + pars is a list of the names of the parameters for which derivatives + are desired. + + This method needs to be specialized in the model to evaluate the + model function. Alternatively, the model can implement is own + version of residuals which calculates the residuals directly + instead of calling eval. + """ + raise NotImplementedError('no derivatives available') + + def __call__(self, x): + return self.eval(x) + +class FitData1D(Data1D): + """ + Wrapper class for SAS data + FitData1D inherits from DataLoader.data_info.Data1D. Implements + a way to get residuals from data. + """ + def __init__(self, x, y, dx=None, dy=None, smearer=None, data=None, lam=None, dlam=None): + """ + :param smearer: is an object of class QSmearer or SlitSmearer + that will smear the theory data (slit smearing or resolution + smearing) when set. + + The proper way to set the smearing object would be to + do the following: :: + + from sas.sascalc.fit.qsmearing import smear_selection + smearer = smear_selection(some_data) + fitdata1d = FitData1D( x= [1,3,..,], + y= [3,4,..,8], + dx=None, + dy=[1,2...], smearer= smearer) + + :Note: that some_data _HAS_ to be of + class DataLoader.data_info.Data1D + Setting it back to None will turn smearing off. + + """ + Data1D.__init__(self, x=x, y=y, dx=dx, dy=dy, lam=lam, dlam=dlam) + self.num_points = len(x) + self.sas_data = data + self.smearer = smearer + self._first_unsmeared_bin = None + self._last_unsmeared_bin = None + # Check error bar; if no error bar found, set it constant(=1) + # TODO: Should provide an option for users to set it like percent, + # constant, or dy data + if dy is None or dy == [] or dy.all() == 0: + self.dy = np.ones(len(y)) + else: + self.dy = np.asarray(dy).copy() + + ## Min Q-value + #Skip the Q=0 point, especially when y(q=0)=None at x[0]. + if min(self.x) == 0.0 and self.x[0] == 0 and\ + not np.isfinite(self.y[0]): + self.qmin = min(self.x[self.x != 0]) + else: + self.qmin = min(self.x) + ## Max Q-value + self.qmax = max(self.x) + + # Range used for input to smearing + self._qmin_unsmeared = self.qmin + self._qmax_unsmeared = self.qmax + # Identify the bin range for the unsmeared and smeared spaces + self.idx = (self.x >= self.qmin) & (self.x <= self.qmax) + self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \ + & (self.x <= self._qmax_unsmeared) + + def set_fit_range(self, qmin=None, qmax=None): + """ to set the fit range""" + # Skip Q=0 point, (especially for y(q=0)=None at x[0]). + # ToDo: Find better way to do it. 
+        if qmin == 0.0 and not np.isfinite(self.y[0]):
+            self.qmin = min(self.x[self.x != 0])
+        elif qmin is not None:
+            self.qmin = qmin
+        if qmax is not None:
+            self.qmax = qmax
+        # Determine the range needed in unsmeared-Q to cover
+        # the smeared Q range
+        self._qmin_unsmeared = self.qmin
+        self._qmax_unsmeared = self.qmax
+
+        self._first_unsmeared_bin = 0
+        self._last_unsmeared_bin = len(self.x) - 1
+
+        if self.smearer is not None:
+            self._first_unsmeared_bin, self._last_unsmeared_bin = \
+                self.smearer.get_bin_range(self.qmin, self.qmax)
+            self._qmin_unsmeared = self.x[self._first_unsmeared_bin]
+            self._qmax_unsmeared = self.x[self._last_unsmeared_bin]
+
+        # Identify the bin range for the unsmeared and smeared spaces
+        self.idx = (self.x >= self.qmin) & (self.x <= self.qmax)
+        ## points with zero error cannot participate in the fit
+        self.idx = self.idx & (self.dy != 0)
+        self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \
+            & (self.x <= self._qmax_unsmeared)
+
+    def get_fit_range(self):
+        """
+        Return the range of data.x to fit
+        """
+        return self.qmin, self.qmax
+
+    def size(self):
+        """
+        Number of measurement points in data set after masking, etc.
+        """
+        return len(self.x)
+
+    def residuals(self, fn):
+        """
+        Compute residuals.
+
+        If self.smearer has been set, use it to smear
+        the theory data before computing chi squared.
+
+        :param fn: function that returns the model value
+
+        :return: residuals
+        """
+        # Compute theory data f(x)
+        fx = np.zeros(len(self.x))
+        fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared])
+
+        ## Smear theory data
+        if self.smearer is not None:
+            fx = self.smearer(fx, self._first_unsmeared_bin,
+                              self._last_unsmeared_bin)
+        ## Sanity check
+        if np.size(self.dy) != np.size(fx):
+            msg = "FitData1D: invalid error array "
+            msg += "%d <> %d" % (np.size(self.dy), np.size(fx))
+            raise RuntimeError(msg)
+        return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx], fx[self.idx]
+
+    def residuals_deriv(self, model, pars=[]):
+        """
+        :return: residuals derivatives.
+
+        :note: in this case just return an empty list
+        """
+        return []
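+
+# Example (sketch): computing FitData1D residuals for a small synthetic
+# data set; a model that reproduces the data gives a chi-squared of ~0:
+#
+#     x = np.linspace(0.01, 0.5, 20)
+#     y = 3.0 * x + 0.1
+#     fit_data = FitData1D(x=x, y=y, dy=0.05 * np.ones_like(y))
+#     fit_data.set_fit_range(qmin=0.05, qmax=0.4)
+#     residuals, theory = fit_data.residuals(lambda q: 3.0 * q + 0.1)
+#     chisq = np.sum(residuals ** 2)
+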
+
+class FitData2D(Data2D):
+    """
+    Wrapper class for SAS data
+    """
+    def __init__(self, sas_data2d, data=None, err_data=None):
+        Data2D.__init__(self, data=data, err_data=err_data)
+        # Data can be initialized with a sas plottable or with vectors.
+        self.res_err_image = []
+        self.num_points = 0  # will be set by set_data
+        self.idx = []
+        self.qmin = None
+        self.qmax = None
+        self.smearer = None
+        self.radius = 0
+        self.res_err_data = []
+        self.sas_data = sas_data2d
+        self.set_data(sas_data2d)
+
+    def set_data(self, sas_data2d, qmin=None, qmax=None):
+        """
+        Determine the correct qx_data and qy_data within range to fit
+        """
+        self.data = sas_data2d.data
+        self.err_data = sas_data2d.err_data
+        self.qx_data = sas_data2d.qx_data
+        self.qy_data = sas_data2d.qy_data
+        self.mask = sas_data2d.mask
+
+        x_max = max(math.fabs(sas_data2d.xmin), math.fabs(sas_data2d.xmax))
+        y_max = max(math.fabs(sas_data2d.ymin), math.fabs(sas_data2d.ymax))
+
+        ## fitting range; honour an explicitly passed range instead of
+        ## silently ignoring the qmin/qmax arguments
+        self.qmin = 1e-16 if qmin is None else qmin
+        self.qmax = math.sqrt(x_max * x_max + y_max * y_max) \
+            if qmax is None else qmax
+        ## new error image for fitting purposes
+        if self.err_data is None or self.err_data == []:
+            self.res_err_data = np.ones(len(self.data))
+        else:
+            self.res_err_data = copy.deepcopy(self.err_data)
+
+        self.radius = np.sqrt(self.qx_data**2 + self.qy_data**2)
+
+        # Note: points with mask == True are kept for fitting
+        self.idx = ((self.qmin <= self.radius) &
+                    (self.radius <= self.qmax))
+        self.idx = (self.idx) & (self.mask)
+        self.idx = (self.idx) & (np.isfinite(self.data))
+        self.num_points = np.sum(self.idx)
+
+    def set_smearer(self, smearer):
+        """
+        Set smearer
+        """
+        if smearer is None:
+            return
+        self.smearer = smearer
+        self.smearer.set_index(self.idx)
+        self.smearer.get_data()
+
+    def set_fit_range(self, qmin=None, qmax=None):
+        """
+        To set the fit range
+        """
+        if qmin == 0.0:
+            self.qmin = 1e-16
+        elif qmin is not None:
+            self.qmin = qmin
+        if qmax is not None:
+            self.qmax = qmax
+        self.radius = np.sqrt(self.qx_data**2 + self.qy_data**2)
+        self.idx = ((self.qmin <= self.radius) &
+                    (self.radius <= self.qmax))
+        self.idx = (self.idx) & (self.mask)
+        self.idx = (self.idx) & (np.isfinite(self.data))
+        self.idx = (self.idx) & (self.res_err_data != 0)
+
+    def get_fit_range(self):
+        """
+        Return the range of data.x to fit
+        """
+        return self.qmin, self.qmax
+
+    def size(self):
+        """
+        Number of measurement points in data set after masking, etc.
+        """
+        return np.sum(self.idx)
+
+    def residuals(self, fn):
+        """
+        Return the residuals
+        """
+        if self.smearer is not None:
+            fn.set_index(self.idx)
+            gn = fn.get_value()
+        else:
+            gn = fn([self.qx_data[self.idx],
+                     self.qy_data[self.idx]])
+        # use only the data points within the ROI range
+        res = (self.data[self.idx] - gn) / self.res_err_data[self.idx]
+
+        return res, gn
+
+    def residuals_deriv(self, model, pars=[]):
+        """
+        :return: residuals derivatives.
+
+        :note: in this case just return an empty list
+
+        """
+        return []
+
+
+class FitAbort(Exception):
+    """
+    Exception raised to stop the fit
+    """
+
+
+class FitEngine:
+    def __init__(self):
+        """
+        Base class for the fit engine
+        """
+        # Dictionary of FitArrange elements (fit problems)
+        self.fit_arrange_dict = {}
+        self.fitter_id = None
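+
+    # Typical flow (sketch; ``sas_model`` and ``data1d`` stand in for a
+    # sasmodels-style model and a loaded 1D data set):
+    #
+    #     engine = BumpsFit()        # a concrete FitEngine subclass
+    #     engine.set_model(sas_model, id=1, pars=['radius'])
+    #     engine.set_data(data1d, id=1, qmin=0.01, qmax=0.4)
+    #     engine.select_problem_for_fit(id=1, value=1)
+    #     results = engine.fit()
+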
+    def set_model(self, model, id, pars=[], constraints=[], data=None):
+        """
+        Set a model for a given id in the fit engine.
+
+        :param model: sas.models type
+        :param id: the key of the fitArrange dictionary where the model is
+            saved as a value
+        :param pars: the list of parameters to fit
+        :param constraints: list of tuples
+            (name of parameter, expression for the parameter);
+            the expression must be a string to constrain two different
+            parameters.
+            Example:
+            we want to fit two models M1 and M2, both with parameters A
+            and B:
+            ``constraints = [(M1.A, 'M2.B + 2'), (M1.B, 'M2.A * 5'), ...]``
+
+
+        :note: pars must contain only names of existing model parameters
+
+        """
+        if not pars:
+            raise ValueError("no fitting parameters")
+
+        if model is None:
+            raise ValueError("no model to fit")
+
+        if not issubclass(model.__class__, Model):
+            model = Model(model, data)
+
+        sasmodel = model.model
+        available_parameters = sasmodel.getParamList()
+        for p in pars:
+            if p not in available_parameters:
+                raise ValueError("parameter %s not available in model %s; "
+                                 "use one of [%s] instead"
+                                 % (p, sasmodel.name,
+                                    ", ".join(available_parameters)))
+
+        if id not in self.fit_arrange_dict:
+            self.fit_arrange_dict[id] = FitArrange()
+
+        self.fit_arrange_dict[id].set_model(model)
+        self.fit_arrange_dict[id].pars = pars
+        self.fit_arrange_dict[id].vals = [sasmodel.getParam(name)
+                                          for name in pars]
+        self.fit_arrange_dict[id].constraints = constraints
+
+    def set_data(self, data, id, smearer=None, qmin=None, qmax=None):
+        """
+        Receives a plottable, creates a list of data to fit, sets the data
+        in a FitArrange object and adds that object to a dictionary
+        with key id.
+
+        :param data: data added
+        :param id: unique key corresponding to a fitArrange object with data
+        """
+        if data.__class__.__name__ == 'Data2D':
+            fitdata = FitData2D(sas_data2d=data, data=data.data,
+                                err_data=data.err_data)
+        else:
+            fitdata = FitData1D(x=data.x, y=data.y,
+                                dx=data.dx, dy=data.dy, smearer=smearer)
+        fitdata.sas_data = data
+
+        fitdata.set_fit_range(qmin=qmin, qmax=qmax)
+        # A fitArrange may already have been created with only a model at id
+        if id in self.fit_arrange_dict:
+            self.fit_arrange_dict[id].add_data(fitdata)
+        else:
+            # no fitArrange object has been created with this id
+            fitproblem = FitArrange()
+            fitproblem.add_data(fitdata)
+            self.fit_arrange_dict[id] = fitproblem
+
+    def get_model(self, id):
+        """
+        :param id: key in the dictionary containing the model to return
+
+        :return: the model at this id, or None if no FitArrange element was
+            created with this id
+        """
+        if id in self.fit_arrange_dict:
+            return self.fit_arrange_dict[id].get_model()
+        else:
+            return None
+
+    def remove_fit_problem(self, id):
+        """Remove the fitArrange stored at id"""
+        if id in self.fit_arrange_dict:
+            del self.fit_arrange_dict[id]
+
+    def select_problem_for_fit(self, id, value):
+        """
+        Select the (model, data) pair at position id in the dictionary
+        and set its self.selected value to value
+
+        :param value: the value to allow fitting;
+            can only be one or zero
+        """
+        if id in self.fit_arrange_dict:
+            self.fit_arrange_dict[id].set_to_fit(value)
+
+    def get_problem_to_fit(self, id):
+        """
+        Return the self.selected value of the fit problem at id
+
+        :param id: the id of the problem
+        """
+        if id in self.fit_arrange_dict:
+            return self.fit_arrange_dict[id].get_to_fit()
+
+
+class FitArrange:
+    def __init__(self):
+        """
+        Class FitArrange contains a set of data for a given model
+        to perform a fit. FitArrange must contain exactly one model
+        and at least one data set for the fit to be performed.
+ + model: the model selected by the user + Ldata: a list of data what the user wants to fit + + """ + self.model = None + self.data_list = [] + self.pars = [] + self.vals = [] + self.selected = 0 + + def set_model(self, model): + """ + set_model save a copy of the model + + :param model: the model being set + """ + self.model = model + + def add_data(self, data): + """ + add_data fill a self.data_list with data to fit + + :param data: Data to add in the list + """ + if not data in self.data_list: + self.data_list.append(data) + + def get_model(self): + """ + :return: saved model + """ + return self.model + + def get_data(self): + """ + :return: list of data data_list + """ + return self.data_list[0] + + def remove_data(self, data): + """ + Remove one element from the list + + :param data: Data to remove from data_list + """ + if data in self.data_list: + self.data_list.remove(data) + + def set_to_fit(self, value=0): + """ + set self.selected to 0 or 1 for other values raise an exception + + :param value: integer between 0 or 1 + """ + self.selected = value + + def get_to_fit(self): + """ + return self.selected value + """ + return self.selected + +class FResult(object): + """ + Storing fit result + """ + def __init__(self, model=None, param_list=None, data=None): + self.calls = None + self.fitness = None + self.chisqr = None + self.pvec = [] + self.cov = [] + self.info = None + self.mesg = None + self.success = None + self.stderr = None + self.residuals = [] + self.index = [] + self.model = model + self.data = data + self.theory = [] + self.param_list = param_list + self.iterations = 0 + self.inputs = [] + self.fitter_id = None + if self.model is not None and self.data is not None: + self.inputs = [(self.model, self.data)] + + def set_model(self, model): + """ + """ + self.model = model + + def set_fitness(self, fitness): + """ + """ + self.fitness = fitness + + def __str__(self): + """ + """ + if self.pvec is None and self.model is None and self.param_list is None: + return "No results" + + sasmodel = self.model.model + pars = enumerate(sasmodel.getParamList()) + msg1 = "[Iteration #: %s ]" % self.iterations + msg3 = "=== goodness of fit: %s ===" % (str(self.fitness)) + msg2 = ["P%-3d %s......|.....%s" % (i, v, sasmodel.getParam(v)) + for i,v in pars if v in self.param_list] + msg = [msg1, msg3] + msg2 + return "\n".join(msg) + + def print_summary(self): + """ + """ + print(str(self)) diff --git a/sas/sascalc/fit/BumpsFitting.py b/sas/sascalc/fit/BumpsFitting.py new file mode 100755 index 000000000..e61734089 --- /dev/null +++ b/sas/sascalc/fit/BumpsFitting.py @@ -0,0 +1,381 @@ +""" +BumpsFitting module runs the bumps optimizer. 
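+
+A sketch of selecting a different optimizer before running a fit (this
+mirrors the FIT_CONFIG handling below; ``DEFit`` is the bumps
+differential-evolution fitter)::
+
+    from bumps import fitters
+    try:
+        from bumps.options import FIT_CONFIG
+        FIT_CONFIG.selected_id = fitters.DEFit.id
+    except ImportError:             # older bumps releases
+        fitters.FIT_DEFAULT = 'de'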
+""" +import os +from datetime import timedelta, datetime +import traceback + +import numpy as np + +from bumps import fitters +try: + from bumps.options import FIT_CONFIG + # Default bumps to use the Levenberg-Marquardt optimizer + FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id + def get_fitter(): + return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values +except ImportError: + # CRUFT: Bumps changed its handling of fit options around 0.7.5.6 + # Default bumps to use the Levenberg-Marquardt optimizer + fitters.FIT_DEFAULT = 'lm' + def get_fitter(): + fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT] + return fitopts.fitclass, fitopts.options.copy() + + +from bumps.mapper import SerialMapper, MPMapper +from bumps import parameter +from bumps.fitproblem import FitProblem + + +from sas.sascalc.fit.AbstractFitEngine import FitEngine +from sas.sascalc.fit.AbstractFitEngine import FResult +from sas.sascalc.fit.expression import compile_constraints + +class Progress(object): + def __init__(self, history, max_step, pars, dof): + remaining_time = int(history.time[0]*(float(max_step)/history.step[0]-1)) + # Depending on the time remaining, either display the expected + # time of completion, or the amount of time remaining. Use precision + # appropriate for the duration. + if remaining_time >= 1800: + completion_time = datetime.now() + timedelta(seconds=remaining_time) + if remaining_time >= 36000: + time = completion_time.strftime('%Y-%m-%d %H:%M') + else: + time = completion_time.strftime('%H:%M') + else: + if remaining_time >= 3600: + time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60) + elif remaining_time >= 60: + time = '%dm %ds'%(remaining_time//60, remaining_time%60) + else: + time = '%ds'%remaining_time + chisq = "%.3g"%(2*history.value[0]/dof) + step = "%d of %d"%(history.step[0], max_step) + header = "=== Steps: %s chisq: %s ETA: %s\n"%(step, chisq, time) + parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | ")) + for i, (k, v) in enumerate(zip(pars, history.point[0]))] + self.msg = "".join([header]+parameters) + + def __str__(self): + return self.msg + + +class BumpsMonitor(object): + def __init__(self, handler, max_step, pars, dof): + self.handler = handler + self.max_step = max_step + self.pars = pars + self.dof = dof + + def config_history(self, history): + history.requires(time=1, value=2, point=1, step=1) + + def __call__(self, history): + if self.handler is None: return + self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof)) + self.handler.progress(history.step[0], self.max_step) + if len(history.step) > 1 and history.step[1] > history.step[0]: + self.handler.improvement() + self.handler.update_fit() + +class ConvergenceMonitor(object): + """ + ConvergenceMonitor contains population summary statistics to show progress + of the fit. This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or + just a list [ (best, ) ] if population size is 1. 
+ """ + def __init__(self): + self.convergence = [] + + def config_history(self, history): + history.requires(value=1, population_values=1) + + def __call__(self, history): + best = history.value[0] + try: + p = history.population_values[0] + n, p = len(p), np.sort(p) + QI, Qmid = int(0.2*n), int(0.5*n) + self.convergence.append((best, p[0], p[QI], p[Qmid], p[-1-QI], p[-1])) + except Exception: + self.convergence.append((best, best, best, best, best, best)) + + +# Note: currently using bumps parameters for each parameter object so that +# a SasFitness can be used directly in bumps with the usual semantics. +# The disadvantage of this technique is that we need to copy every parameter +# back into the model each time the function is evaluated. We could instead +# define reference parameters for each sas parameter, but then we would not +# be able to express constraints using python expressions in the usual way +# from bumps, and would instead need to use string expressions. +class SasFitness(object): + """ + Wrap SAS model as a bumps fitness object + """ + def __init__(self, model, data, fitted=[], constraints={}, + initial_values=None, **kw): + self.name = model.name + self.model = model.model + self.data = data + if self.data.smearer is not None: + self.data.smearer.model = self.model + self._define_pars() + self._init_pars(kw) + if initial_values is not None: + self._reset_pars(fitted, initial_values) + self.constraints = dict(constraints) + self.set_fitted(fitted) + self.update() + + def _reset_pars(self, names, values): + for k, v in zip(names, values): + self._pars[k].value = v + + def _define_pars(self): + self._pars = {} + for k in self.model.getParamList(): + name = ".".join((self.name, k)) + value = self.model.getParam(k) + bounds = self.model.details.get(k, ["", None, None])[1:3] + self._pars[k] = parameter.Parameter(value=value, bounds=bounds, + fixed=True, name=name) + #print parameter.summarize(self._pars.values()) + + def _init_pars(self, kw): + for k, v in kw.items(): + # dispersion parameters initialized with _field instead of .field + if k.endswith('_width'): + k = k[:-6]+'.width' + elif k.endswith('_npts'): + k = k[:-5]+'.npts' + elif k.endswith('_nsigmas'): + k = k[:-7]+'.nsigmas' + elif k.endswith('_type'): + k = k[:-5]+'.type' + if k not in self._pars: + formatted_pars = ", ".join(sorted(self._pars.keys())) + raise KeyError("invalid parameter %r for %s--use one of: %s" + %(k, self.model, formatted_pars)) + if '.' in k and not k.endswith('.width'): + self.model.setParam(k, v) + elif isinstance(v, parameter.BaseParameter): + self._pars[k] = v + elif isinstance(v, (tuple, list)): + low, high = v + self._pars[k].value = (low+high)/2 + self._pars[k].range(low, high) + else: + self._pars[k].value = v + + def set_fitted(self, param_list): + """ + Flag a set of parameters as fitted parameters. 
+ """ + for k, p in self._pars.items(): + p.fixed = (k not in param_list or k in self.constraints) + self.fitted_par_names = [k for k in param_list if k not in self.constraints] + self.computed_par_names = [k for k in param_list if k in self.constraints] + self.fitted_pars = [self._pars[k] for k in self.fitted_par_names] + self.computed_pars = [self._pars[k] for k in self.computed_par_names] + + # ===== Fitness interface ==== + def parameters(self): + return self._pars + + def update(self): + for k, v in self._pars.items(): + #print "updating",k,v,v.value + self.model.setParam(k, v.value) + self._dirty = True + + def _recalculate(self): + if self._dirty: + self._residuals, self._theory \ + = self.data.residuals(self.model.evalDistribution) + self._dirty = False + + def numpoints(self): + return np.sum(self.data.idx) # number of fitted points + + def nllf(self): + return 0.5*np.sum(self.residuals()**2) + + def theory(self): + self._recalculate() + return self._theory + + def residuals(self): + self._recalculate() + return self._residuals + + # Not implementing the data methods for now: + # + # resynth_data/restore_data/save/plot + +class ParameterExpressions(object): + def __init__(self, models): + self.models = models + self._setup() + + def _setup(self): + exprs = {} + for M in self.models: + exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items()) + if exprs: + symtab = dict((".".join((M.name, k)), p) + for M in self.models + for k, p in M.parameters().items()) + self.update = compile_constraints(symtab, exprs) + else: + self.update = lambda: 0 + + def __call__(self): + self.update() + + def __getstate__(self): + return self.models + + def __setstate__(self, state): + self.models = state + self._setup() + +class BumpsFit(FitEngine): + """ + Fit a model using bumps. + """ + def __init__(self): + """ + Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements + with Uid as keys + """ + FitEngine.__init__(self) + self.curr_thread = None + + def fit(self, msg_q=None, + q=None, handler=None, curr_thread=None, + ftol=1.49012e-8, reset_flag=False): + # Build collection of bumps fitness calculators + models = [SasFitness(model=M.get_model(), + data=M.get_data(), + constraints=M.constraints, + fitted=M.pars, + initial_values=M.vals if reset_flag else None) + for M in self.fit_arrange_dict.values() + if M.get_to_fit()] + if len(models) == 0: + raise RuntimeError("Nothing to fit") + problem = FitProblem(models) + + # TODO: need better handling of parameter expressions and bounds constraints + # so that they are applied during polydispersity calculations. This + # will remove the immediate need for the setp_hook in bumps, though + # bumps may still need something similar, such as a sane class structure + # which allows a subclass to override setp. 
+ problem.setp_hook = ParameterExpressions(models) + + # Run the fit + result = run_bumps(problem, handler, curr_thread) + if handler is not None: + handler.update_fit(last=True) + + # TODO: shouldn't reference internal parameters of fit problem + varying = problem._parameters + # collect the results + all_results = [] + for M in problem.models: + fitness = M.fitness + fitted_index = [varying.index(p) for p in fitness.fitted_pars] + param_list = fitness.fitted_par_names + fitness.computed_par_names + R = FResult(model=fitness.model, data=fitness.data, + param_list=param_list) + R.theory = fitness.theory() + R.residuals = fitness.residuals() + R.index = fitness.data.idx + R.fitter_id = self.fitter_id + # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown + R.success = result['success'] + if R.success: + if result['stderr'] is None: + R.stderr = np.NaN*np.ones(len(param_list)) + else: + R.stderr = np.hstack((result['stderr'][fitted_index], + np.NaN*np.ones(len(fitness.computed_pars)))) + R.pvec = np.hstack((result['value'][fitted_index], + [p.value for p in fitness.computed_pars])) + R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index)) + else: + R.stderr = np.NaN*np.ones(len(param_list)) + R.pvec = np.asarray([p.value for p in fitness.fitted_pars+fitness.computed_pars]) + R.fitness = np.NaN + R.convergence = result['convergence'] + if result['uncertainty'] is not None: + R.uncertainty_state = result['uncertainty'] + all_results.append(R) + all_results[0].mesg = result['errors'] + + if q is not None: + q.put(all_results) + return q + else: + return all_results + +def run_bumps(problem, handler, curr_thread): + def abort_test(): + if curr_thread is None: return False + try: curr_thread.isquit() + except KeyboardInterrupt: + if handler is not None: + handler.stop("Fitting: Terminated!!!") + return True + return False + + fitclass, options = get_fitter() + steps = options.get('steps', 0) + if steps == 0: + pop = options.get('pop', 0)*len(problem._parameters) + samples = options.get('samples', 0) + steps = (samples+pop-1)/pop if pop != 0 else samples + max_step = steps + options.get('burn', 0) + pars = [p.name for p in problem._parameters] + #x0 = np.asarray([p.value for p in problem._parameters]) + options['monitors'] = [ + BumpsMonitor(handler, max_step, pars, problem.dof), + ConvergenceMonitor(), + ] + fitdriver = fitters.FitDriver(fitclass, problem=problem, + abort_test=abort_test, **options) + omp_threads = int(os.environ.get('OMP_NUM_THREADS', '0')) + mapper = MPMapper if omp_threads == 1 else SerialMapper + fitdriver.mapper = mapper.start_mapper(problem, None) + #import time; T0 = time.time() + try: + best, fbest = fitdriver.fit() + errors = [] + except Exception as exc: + best, fbest = None, np.NaN + errors = [str(exc), traceback.format_exc()] + finally: + mapper.stop_mapper(fitdriver.mapper) + + + convergence_list = options['monitors'][-1].convergence + convergence = (2*np.asarray(convergence_list)/problem.dof + if convergence_list else np.empty((0, 1), 'd')) + + success = best is not None + try: + stderr = fitdriver.stderr() if success else None + except Exception as exc: + errors.append(str(exc)) + errors.append(traceback.format_exc()) + stderr = None + return { + 'value': best if success else None, + 'stderr': stderr, + 'success': success, + 'convergence': convergence, + 'uncertainty': getattr(fitdriver.fitter, 'state', None), + 'errors': '\n'.join(errors), + } diff --git a/sas/sascalc/fit/Loader.py b/sas/sascalc/fit/Loader.py new file mode 100755 
index 000000000..f2a99877e
--- /dev/null
+++ b/sas/sascalc/fit/Loader.py
@@ -0,0 +1,87 @@
+from __future__ import print_function
+
+# class Load, to load any kind of data file
+import numpy as np
+
+class Load:
+    """
+    This class loads values from a given file or from values supplied by
+    the user
+    """
+    def __init__(self, x=None, y=None, dx=None, dy=None):
+        raise NotImplementedError("a code search shows that this code is not active, and you are not seeing this message")
+        # variables to store loaded values
+        self.x = x
+        self.y = y
+        self.dx = dx
+        self.dy = dy
+        self.filename = None
+
+    def set_filename(self, path=None):
+        """
+        Store the path in a variable. If the user doesn't give a path as a
+        parameter, a pop-up window appears to select the file.
+
+        :param path: the path given by the user
+
+        """
+        self.filename = path
+
+    def get_filename(self):
+        """Return the file's path"""
+        return self.filename
+
+    def set_values(self):
+        """Store the values loaded from file in local variables"""
+        if self.filename is not None:
+            input_f = open(self.filename, 'r')
+            buff = input_f.read()
+            lines = buff.split('\n')
+            self.x = []
+            self.y = []
+            self.dx = []
+            self.dy = []
+            for line in lines:
+                try:
+                    toks = line.split()
+                    x = float(toks[0])
+                    y = float(toks[1])
+                    dy = float(toks[2])
+
+                    self.x.append(x)
+                    self.y.append(y)
+                    self.dy.append(dy)
+                    self.dx = np.zeros(len(self.x))
+                except:
+                    print("READ ERROR", line)
+            # Sanity check
+            if not len(self.x) == len(self.dx):
+                raise ValueError("x and dx have different lengths")
+            if not len(self.y) == len(self.dy):
+                raise ValueError("y and dy have different lengths")
+
+    def get_values(self):
+        """Return x, y, dx, dy"""
+        return self.x, self.y, self.dx, self.dy
+
+    def load_data(self, data):
+        """Fill a plottable with the loaded values"""
+        data.x = self.x
+        data.y = self.y
+        data.dx = self.dx
+        data.dy = self.dy
+
+
+if __name__ == "__main__":
+    load = Load()
+    load.set_filename("testdata_line.txt")
+    print(load.get_filename())
+    load.set_values()
+    print(load.get_values())
diff --git a/sas/sascalc/fit/MultiplicationModel.py b/sas/sascalc/fit/MultiplicationModel.py
new file mode 100755
index 000000000..375549526
--- /dev/null
+++ b/sas/sascalc/fit/MultiplicationModel.py
@@ -0,0 +1,335 @@
+import copy
+
+import numpy as np
+
+from sas.sascalc.calculator.BaseComponent import BaseComponent
+
+class MultiplicationModel(BaseComponent):
+    r"""
+    Use for P(Q)\*S(Q); the function call must be in the order P(Q) and
+    then S(Q).
+    The model parameters are combined from both models, P(Q) and S(Q),
+    except 1) 'radius_effective' of S(Q), which is calculated from P(Q)
+    via calculate_ER(), and 2) 'scale' in the P model, which is
+    synchronized with 'volfraction' in S; P*S is then multiplied by a new
+    parameter, 'scale_factor'.
+    Polydispersity applies only to P(Q), not to S(Q).
+
+    .. note:: P(Q) refers to a 'form factor' model while S(Q) refers to a
+       'structure factor' model.
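+
+    A usage sketch (``sphere`` and ``hardsphere`` stand for any two
+    components exposing the params/details interface used below)::
+
+        product = MultiplicationModel(p_model=sphere, s_model=hardsphere)
+        product.setParam('volfraction', 0.2)  # also syncs 'scale' of P(Q)
+        iq = product.run(0.05)  # scale_factor * P(q) * S(q) + background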
+ """ + def __init__(self, p_model, s_model ): + BaseComponent.__init__(self) + """ + :param p_model: form factor, P(Q) + :param s_model: structure factor, S(Q) + """ + + ## Setting model name model description + self.description = "" + self.name = p_model.name +" * "+ s_model.name + self.description= self.name + "\n" + self.fill_description(p_model, s_model) + + ## Define parameters + self.params = {} + + ## Parameter details [units, min, max] + self.details = {} + + ## Define parameters to exclude from multiplication model + self.excluded_params={'radius_effective','scale','background'} + + ##models + self.p_model = p_model + self.s_model = s_model + self.magnetic_params = [] + ## dispersion + self._set_dispersion() + ## Define parameters + self._set_params() + ## New parameter:Scaling factor + self.params['scale_factor'] = 1 + self.params['background'] = 0 + + ## Parameter details [units, min, max] + self._set_details() + self.details['scale_factor'] = ['', 0.0, np.inf] + self.details['background'] = ['',-np.inf,np.inf] + + #list of parameter that can be fitted + self._set_fixed_params() + ## parameters with orientation + for item in self.p_model.orientation_params: + self.orientation_params.append(item) + for item in self.p_model.magnetic_params: + self.magnetic_params.append(item) + for item in self.s_model.orientation_params: + if not item in self.orientation_params: + self.orientation_params.append(item) + # get multiplicity if model provide it, else 1. + try: + multiplicity = p_model.multiplicity + except AttributeError: + multiplicity = 1 + ## functional multiplicity of the model + self.multiplicity = multiplicity + + # non-fittable parameters + self.non_fittable = p_model.non_fittable + self.multiplicity_info = [] + self.fun_list = [] + if self.non_fittable > 1: + try: + self.multiplicity_info = p_model.multiplicity_info + self.fun_list = p_model.fun_list + self.is_multiplicity_model = True + except AttributeError: + pass + else: + self.is_multiplicity_model = False + self.multiplicity_info = [0] + + def _clone(self, obj): + """ + Internal utility function to copy the internal data members to a + fresh copy. + """ + obj.params = copy.deepcopy(self.params) + obj.description = copy.deepcopy(self.description) + obj.details = copy.deepcopy(self.details) + obj.dispersion = copy.deepcopy(self.dispersion) + obj.p_model = self.p_model.clone() + obj.s_model = self.s_model.clone() + #obj = copy.deepcopy(self) + return obj + + + def _set_dispersion(self): + """ + combine the two models' dispersions. Polydispersity should not be + applied to s_model + """ + ##set dispersion only from p_model + for name , value in self.p_model.dispersion.items(): + self.dispersion[name] = value + + def getProfile(self): + """ + Get SLD profile of p_model if exists + + :return: (r, beta) where r is a list of radius of the transition points\ + beta is a list of the corresponding SLD values + + .. note:: This works only for func_shell num = 2 (exp function). + """ + try: + x, y = self.p_model.getProfile() + except: + x = None + y = None + + return x, y + + def _set_params(self): + """ + Concatenate the parameters of the two models to create + these model parameters + """ + + for name , value in self.p_model.params.items(): + if not name in self.params.keys() and name not in self.excluded_params: + self.params[name] = value + + for name , value in self.s_model.params.items(): + #Remove the radius_effective from the (P*S) model parameters. 
+            if not name in self.params.keys() and name not in self.excluded_params:
+                self.params[name] = value
+
+        # Set scale and radius_effective on the P and S models as
+        # initialization, since running P*S evaluates P and S separately.
+        self._set_backgrounds()
+        self._set_scale_factor()
+        self._set_radius_effective()
+
+    def _set_details(self):
+        """
+        Concatenate details of the two models to create
+        this model's details
+        """
+        for name, detail in self.p_model.details.items():
+            if name not in self.excluded_params:
+                self.details[name] = detail
+
+        for name, detail in self.s_model.details.items():
+            if not name in self.details.keys() and name not in self.excluded_params:
+                self.details[name] = detail
+
+    def _set_backgrounds(self):
+        """
+        Set component backgrounds to zero
+        """
+        if 'background' in self.p_model.params:
+            self.p_model.setParam('background', 0)
+        if 'background' in self.s_model.params:
+            self.s_model.setParam('background', 0)
+
+    def _set_scale_factor(self):
+        """
+        Set scale=volfraction for the P model
+        """
+        value = self.params['volfraction']
+        if value is not None:
+            factor = self.p_model.calculate_VR()
+            if factor is None or factor == NotImplemented or factor == 0.0:
+                val = value
+            else:
+                val = value / factor
+            self.p_model.setParam('scale', value)
+            self.s_model.setParam('volfraction', val)
+
+    def _set_radius_effective(self):
+        """
+        Set the effective radius on the S(Q) model
+        """
+        if not 'radius_effective' in self.s_model.params.keys():
+            return
+        effective_radius = self.p_model.calculate_ER()
+        # Reset the radius_effective of s_model just before the run
+        if effective_radius is not None and effective_radius != NotImplemented:
+            self.s_model.setParam('radius_effective', effective_radius)
+
+    def setParam(self, name, value):
+        """
+        Set the value of a model parameter
+
+        :param name: name of the parameter
+        :param value: value of the parameter
+        """
+        # set param on the P*S model
+        self._setParamHelper(name, value)
+
+        ## setParam on the p model
+        # set 'scale' in P(Q) equal to volfraction
+        if name == 'volfraction':
+            self._set_scale_factor()
+        elif name in self.p_model.getParamList() and name not in self.excluded_params:
+            self.p_model.setParam(name, value)
+
+        ## setParam on the s model
+        # This is a little redundant; TODO: find a better way
+        self._set_radius_effective()
+        if name in self.s_model.getParamList() and name not in self.excluded_params:
+            if name != 'volfraction':
+                self.s_model.setParam(name, value)
+
+    def _setParamHelper(self, name, value):
+        """
+        Helper function for setParam
+        """
+        # Look for dispersion parameters
+        toks = name.split('.')
+        if len(toks) == 2:
+            for item in self.dispersion.keys():
+                if item.lower() == toks[0].lower():
+                    for par in self.dispersion[item]:
+                        if par.lower() == toks[1].lower():
+                            self.dispersion[item][par] = value
+                            return
+        else:
+            # Look for a standard parameter
+            for item in self.params.keys():
+                if item.lower() == name.lower():
+                    self.params[item] = value
+                    return
+
+        raise ValueError("Model does not contain parameter %s" % name)
+
+    def _set_fixed_params(self):
+        """
+        Fill the self.fixed list with the p_model fixed list
+        """
+        for item in self.p_model.fixed:
+            self.fixed.append(item)
+
+        self.fixed.sort()
+
+    def run(self, x=0.0):
+        """
+        Evaluate the model
+
+        :param x: input q-value (float or [float, float] as [r, theta])
+        :return: scattering function value
+        """
+        # set effective radius and scaling factor before the run
+        self._set_radius_effective()
+        self._set_scale_factor()
+        return
self.params['scale_factor'] * self.p_model.run(x) * \ + self.s_model.run(x) + self.params['background'] + + def runXY(self, x = 0.0): + """ + Evaluate the model + + :param x: input q-value (float or [float, float] as [qx, qy]) + :return: scattering function value + """ + # set effective radius and scaling factor before run + self._set_radius_effective() + self._set_scale_factor() + out = self.params['scale_factor'] * self.p_model.runXY(x) * \ + self.s_model.runXY(x) + self.params['background'] + return out + + ## Now (May27,10) directly uses the model eval function + ## instead of the for-loop in Base Component. + def evalDistribution(self, x = []): + """ + Evaluate the model in cartesian coordinates + + :param x: input q[], or [qx[], qy[]] + :return: scattering function P(q[]) + """ + # set effective radius and scaling factor before run + self._set_radius_effective() + self._set_scale_factor() + out = self.params['scale_factor'] * self.p_model.evalDistribution(x) * \ + self.s_model.evalDistribution(x) + self.params['background'] + return out + + def set_dispersion(self, parameter, dispersion): + """ + Set the dispersion object for a model parameter + + :param parameter: name of the parameter [string] + :dispersion: dispersion object of type DispersionModel + """ + value = None + try: + if parameter in self.p_model.dispersion.keys(): + value = self.p_model.set_dispersion(parameter, dispersion) + self._set_dispersion() + return value + except: + raise + + def fill_description(self, p_model, s_model): + """ + Fill the description for P(Q)*S(Q) + """ + description = "" + description += "Note:1) The radius_effective (effective radius) of %s \n"%\ + (s_model.name) + description += " is automatically calculated " + description += "from size parameters (radius...).\n" + description += " 2) For non-spherical shape, " + description += "this approximation is valid \n" + description += " only for limited systems. " + description += "Thus, use it at your own risk.\n" + description += "See %s description and %s description \n"% \ + ( p_model.name, s_model.name ) + description += " for details of individual models." + self.description += description diff --git a/sas/sascalc/fit/__init__.py b/sas/sascalc/fit/__init__.py new file mode 100755 index 000000000..03d374576 --- /dev/null +++ b/sas/sascalc/fit/__init__.py @@ -0,0 +1 @@ +from .AbstractFitEngine import FitHandler \ No newline at end of file diff --git a/sas/sascalc/fit/expression.py b/sas/sascalc/fit/expression.py new file mode 100755 index 000000000..f913319b3 --- /dev/null +++ b/sas/sascalc/fit/expression.py @@ -0,0 +1,405 @@ +# This program is public domain +""" +Parameter expression evaluator. + +For systems in which constraints are expressed as string expressions rather +than python code, :func:`compile_constraints` can construct an expression +evaluator that substitutes the computed values of the expressions into the +parameters. + +The compiler requires a symbol table, an expression set and a context. +The symbol table maps strings containing fully qualified names such as +'M1.c[3].full_width' to parameter objects with a 'value' property that +can be queried and set. The expression set maps symbol names from the +symbol table to string expressions. The context provides additional symbols +for the expressions in addition to the usual mathematical functions and +constants. + +The expressions are compiled and interpreted by python, with only minimal +effort to make sure that they don't contain bad code. 
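+
+A minimal usage sketch (``Par`` is a stand-in class; symbol-table entries
+only need a settable ``value`` attribute)::
+
+    class Par(object):
+        def __init__(self, value=0.0):
+            self.value = value
+
+    pars = {'M1.radius': Par(2.0), 'M1.length': Par(0.0)}
+    exprs = {'M1.length': '10*M1.radius'}
+    constraints = compile_constraints(pars, exprs)
+    constraints()    # sets pars['M1.length'].value to 20.0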
The resulting +constraints function returns 0 so it can be used directly in a fit problem +definition. + +Extracting the symbol table from the model depends on the structure of the +model. If fitness.parameters() is set correctly, then this should simply +be a matter of walking the parameter data, remembering the path to each +parameter in the symbol table. For compactness, dictionary elements should +be referenced by .name rather than ["name"]. Model name can be used as the +top level. + +Getting the parameter expressions applied correctly is challenging. +The following monkey patch works by overriding model_update in FitProblem +so that after setp(p) is called and, the constraints expression can be +applied before telling the underlying fitness function that the model +is out of date:: + + # Override model update so that parameter constraints are applied + problem._model_update = problem.model_update + def model_update(): + constraints() + problem._model_update() + problem.model_update = model_update + +Ideally, this interface will change +""" +from __future__ import print_function + +import math +import re + +# simple pattern which matches symbols. Note that it will also match +# invalid substrings such as a3...9, but given syntactically correct +# input it will only match symbols. +_symbol_pattern = re.compile('([a-zA-Z_][a-zA-Z_0-9.]*)') + +def _symbols(expr,symtab): + """ + Given an expression string and a symbol table, return the set of symbols + used in the expression. Symbols are only returned once even if they + occur multiple times. The return value is a set with the elements in + no particular order. + + This is the first step in computing a dependency graph. + """ + matches = [m.group(0) for m in _symbol_pattern.finditer(expr)] + return set([symtab[m] for m in matches if m in symtab]) + +def _substitute(expr,mapping): + """ + Replace all occurrences of symbol s with mapping[s] for s in mapping. + """ + # Find the symbols and the mapping + matches = [(m.start(),m.end(),mapping[m.group(1)]) + for m in _symbol_pattern.finditer(expr) + if m.group(1) in mapping] + + # Split the expression in to pieces, with new symbols replacing old + pieces = [] + offset = 0 + for start,end,text in matches: + pieces += [expr[offset:start],text] + offset = end + pieces.append(expr[offset:]) + + # Join the pieces and return them + return "".join(pieces) + +def _find_dependencies(symtab, exprs): + """ + Returns a list of pair-wise dependencies from the parameter expressions. + + For example, if p3 = p1+p2, then find_dependencies([p1,p2,p3]) will + return [(p3,p1),(p3,p2)]. For base expressions without dependencies, + such as p4 = 2*pi, this should return [(p4, None)] + """ + deps = [(target,source) + for target,expr in exprs.items() + for source in _symbols_or_none(expr,symtab)] + return deps + +# Hack to deal with expressions without dependencies --- return a fake +# dependency of None. 
+# The better solution is fix order_dependencies so that it takes a +# dictionary of {symbol: dependency_list}, for which no dependencies +# is simply []; fix in parameter_mapping as well +def _symbols_or_none(expr,symtab): + syms = _symbols(expr,symtab) + return syms if len(syms) else [None] + +def _parameter_mapping(pairs): + """ + Find the parameter substitution we need so that expressions can + be evaluated without having to traverse a chain of + model.layer.parameter.value + """ + left,right = zip(*pairs) + pars = list(sorted(p for p in set(left+right) if p is not None)) + definition = dict( ('P%d'%i,p) for i,p in enumerate(pars) ) + # p is None when there is an expression with no dependencies + substitution = dict( (p,'P%d.value'%i) + for i,p in enumerate(sorted(pars)) + if p is not None) + return definition, substitution + +def no_constraints(): + """ + This parameter set has no constraints between the parameters. + """ + pass + +def compile_constraints(symtab, exprs, context={}): + """ + Build and return a function to evaluate all parameter expressions in + the proper order. + + Input: + + *symtab* is the symbol table for the model: { 'name': parameter } + + *exprs* is the set of computed symbols: { 'name': 'expression' } + + *context* is any additional context needed to evaluate the expression + + Return: + + updater function which sets parameter.value for each expression + + Raises: + + AssertionError - model, parameter or function is missing + + SyntaxError - improper expression syntax + + ValueError - expressions have circular dependencies + + This function is not terribly sophisticated, and it would be easy to + trick. However it handles the common cases cleanly and generates + reasonable messages for the common errors. + + This code has not been fully audited for security. While we have + removed the builtins and the ability to import modules, there may + be other vectors for users to perform more than simple function + evaluations. Unauthenticated users should not be running this code. + + Parameter names are assumed to contain only _.a-zA-Z0-9#[] + + Both names are provided for inverse functions, e.g., acos and arccos. + + Should try running the function to identify syntax errors before + running it in a fit. + + Use help(fn) to see the code generated for the returned function fn. + dis.dis(fn) will show the corresponding python vm instructions. + """ + + # Sort the parameters in the order they need to be evaluated + deps = _find_dependencies(symtab, exprs) + if deps == []: return no_constraints + order = order_dependencies(deps) + + + # Rather than using the full path to the parameters in the parameter + # expressions, instead use Pn, and substitute Pn.value for each occurrence + # of the parameter in the expression. 
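+    # For instance (sketch):
+    #     _substitute('M1.radius + 2*M1.length',
+    #                 {'M1.radius': 'P0.value', 'M1.length': 'P1.value'})
+    # returns 'P0.value + 2*P1.value'.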
+    names = list(sorted(symtab.keys()))
+    parameters = dict(('P%d'%i, symtab[k]) for i, k in enumerate(names))
+    mapping = dict((k, 'P%d.value'%i) for i, k in enumerate(names))
+
+    # Initialize dictionary with available functions
+    globals = {}
+    globals.update(math.__dict__)
+    globals.update(dict(arcsin=math.asin, arccos=math.acos,
+                        arctan=math.atan, arctan2=math.atan2))
+    globals.update(context)
+    globals.update(parameters)
+    globals['id'] = id
+    locals = {}
+
+    # Define the constraints function
+    assignments = ["=".join((p, exprs[p])) for p in order]
+    code = [_substitute(s, mapping) for s in assignments]
+    functiondef = """
+def eval_expressions():
+    '''
+    %s
+    '''
+    %s
+    return 0
+"""%("\n    ".join(assignments), "\n    ".join(code))
+
+    #print("Function: "+functiondef)
+    exec(functiondef, globals, locals)
+    retfn = locals['eval_expressions']
+
+    # Remove garbage added to globals by exec
+    globals.pop('__doc__', None)
+    globals.pop('__name__', None)
+    globals.pop('__file__', None)
+    globals.pop('__builtins__')
+    #print(globals.keys())
+
+    return retfn
+
+def order_dependencies(pairs):
+    """
+    Order elements from pairs so that b comes before a in the
+    ordered list for all pairs (a,b).
+    """
+    #print("order_dependencies", pairs)
+    emptyset = set()
+    order = []
+
+    # Break pairs into left set and right set
+    left, right = [set(s) for s in zip(*pairs)] if pairs != [] else ([], [])
+    while pairs != []:
+        #print("within", pairs)
+        # Find which items only occur on the right
+        independent = right - left
+        if independent == emptyset:
+            cycleset = ", ".join(str(s) for s in left)
+            raise ValueError("Cyclic dependencies amongst %s"%cycleset)
+
+        # The possibly resolvable items are those that depend on the independents
+        dependent = set([a for a, b in pairs if b in independent])
+        pairs = [(a, b) for a, b in pairs if b not in independent]
+        if pairs == []:
+            resolved = dependent
+        else:
+            left, right = [set(s) for s in zip(*pairs)]
+            resolved = dependent - left
+        #print("independent", independent, "dependent", dependent, "resolvable", resolved)
+        order += resolved
+        #print("new order", order)
+    order.reverse()
+    return order
+
+# ========= Test code ========
+def _check(msg, pairs):
+    """
+    Verify that the ordering returned by order_dependencies(pairs) contains
+    the items from the pairs, and that it satisfies the partial ordering
+    the pairs define.
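+    For example, given pairs [(2, 1)], the result must list 1 before 2.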
+ """ + left,right = zip(*pairs) if pairs != [] else ([],[]) + items = set(left) + n = order_dependencies(pairs) + if set(n) != items or len(n) != len(items): + n.sort() + items = list(items); items.sort() + raise ValueError("%s expect %s to contain %s for %s"%(msg,n,items,pairs)) + for lo,hi in pairs: + if lo in n and hi in n and n.index(lo) >= n.index(hi): + raise ValueError("%s expect %s before %s in %s for %s"%(msg,lo,hi,n,pairs)) + +def test_deps(): + import numpy as np + + # Null case + _check("test empty",[]) + + # Some dependencies + _check("test1",[(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)]) + _check("test1 renumbered",[(6,1),(7,3),(7,4),(6,7),(5,7),(3,2)]) + _check("test1 numpy",np.array([(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)])) + + # No dependencies + _check("test2",[(4,1),(3,2),(8,4)]) + + # Cycle test + pairs = [(1,4),(4,3),(4,5),(5,1)] + try: + n = order_dependencies(pairs) + except ValueError: + pass + else: + raise ValueError("test3 expect ValueError exception for %s"%(pairs,)) + + # large test for gross speed check + A = np.random.randint(4000,size=(1000,2)) + A[:,1] += 4000 # Avoid cycles + _check("test-large",A) + + # depth tests + k = 200 + A = np.array([range(0,k),range(1,k+1)]).T + _check("depth-1",A) + + A = np.array([range(1,k+1),range(0,k)]).T + _check("depth-2",A) + +def test_expr(): + import inspect, dis + import math + + symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4} + expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b' + + # Check symbol lookup + assert _symbols(expr, symtab) == set([1,2,3]) + + # Check symbol rename + assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b' + assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q' + + + # Check dependency builder + # Fake parameter class + class Parameter: + def __init__(self, name, value=0, expression=''): + self.path = name + self.value = value + self.expression = expression + def iscomputed(self): return (self.expression != '') + def __repr__(self): return self.path + def world(*pars): + symtab = dict((p.path,p) for p in pars) + exprs = dict((p.path,p.expression) for p in pars if p.iscomputed()) + return symtab, exprs + p1 = Parameter('G0.sigma',5) + p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1') + p3 = Parameter('M1.G1',6) + p4 = Parameter('constant',expression='2*pi*35') + # Simple chain + assert set(_find_dependencies(*world(p1,p2,p3))) == set([(p2.path,p1),(p2.path,p3)]) + # Constant expression + assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)]) + # No dependencies + assert set(_find_dependencies(*world(p1,p3))) == set([]) + + # Check function builder + fn = compile_constraints(*world(p1,p2,p3)) + + # Inspect the resulting function + if 0: + print(inspect.getdoc(fn)) + print(dis.dis(fn)) + + # Evaluate the function and see if it updates the + # target value as expected + fn() + expected = 2*math.pi*math.sin(5/.1875) + 6 + assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected) + + # Check empty dependency set doesn't crash + fn = compile_constraints(*world(p1,p3)) + fn() + + # Check that constants are evaluated properly + fn = compile_constraints(*world(p4)) + fn() + assert p4.value == 2*math.pi*35 + + # Check additional context example; this also tests multiple + # expressions + class Table: + Si = 2.09 + values = {'Si': 2.07} + tbl = Table() + p5 = Parameter('lookup',expression="tbl.Si") + fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl)) + fn() + assert p5.value == 2.09,"Value for %s was 
%s"%(p5.expression,p5.value) + p5.expression = "tbl.values['Si']" + fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl)) + fn() + assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value) + + + # Verify that we capture invalid expressions + for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2', + 'piddle', + '5; import sys; print "p0wned"', + '__import__("sys").argv']: + try: + p6 = Parameter('broken',expression=expr) + fn = compile_constraints(*world(p6)) + fn() + except Exception as msg: + #print(msg) + pass + else: + raise "Failed to raise error for %s"%expr + +if __name__ == "__main__": + test_expr() + test_deps() diff --git a/sas/sascalc/fit/models.py b/sas/sascalc/fit/models.py new file mode 100755 index 000000000..3f9b4aa99 --- /dev/null +++ b/sas/sascalc/fit/models.py @@ -0,0 +1,336 @@ +""" + Utilities to manage models +""" +from __future__ import print_function + +import os +import sys +import time +import datetime +import logging +import traceback +import py_compile +import shutil + +from sasmodels.sasview_model import load_custom_model, load_standard_models + +from sas import get_user_dir + +# Explicitly import from the pluginmodel module so that py2exe +# places it in the distribution. The Model1DPlugin class is used +# as the base class of plug-in models. +from .pluginmodel import Model1DPlugin + +logger = logging.getLogger(__name__) + + +PLUGIN_DIR = 'plugin_models' +PLUGIN_LOG = os.path.join(get_user_dir(), PLUGIN_DIR, "plugins.log") +PLUGIN_NAME_BASE = '[plug-in] ' + + +def plugin_log(message): + """ + Log a message in a file located in the user's home directory + """ + out = open(PLUGIN_LOG, 'a') + now = time.time() + stamp = datetime.datetime.fromtimestamp(now).strftime('%Y-%m-%d %H:%M:%S') + out.write("%s: %s\n" % (stamp, message)) + out.close() + + +def _check_plugin(model, name): + """ + Do some checking before model adding plugins in the list + + :param model: class model to add into the plugin list + :param name:name of the module plugin + + :return model: model if valid model or None if not valid + + """ + #Check if the plugin is of type Model1DPlugin + if not issubclass(model, Model1DPlugin): + msg = "Plugin %s must be of type Model1DPlugin \n" % str(name) + plugin_log(msg) + return None + if model.__name__ != "Model": + msg = "Plugin %s class name must be Model \n" % str(name) + plugin_log(msg) + return None + try: + new_instance = model() + except Exception: + msg = "Plugin %s error in __init__ \n\t: %s %s\n" % (str(name), + str(sys.exc_type), + sys.exc_info()[1]) + plugin_log(msg) + return None + + if hasattr(new_instance, "function"): + try: + value = new_instance.function() + except Exception: + msg = "Plugin %s: error writing function \n\t :%s %s\n " % \ + (str(name), str(sys.exc_type), sys.exc_info()[1]) + plugin_log(msg) + return None + else: + msg = "Plugin %s needs a method called function \n" % str(name) + plugin_log(msg) + return None + return model + + +def find_plugins_dir(): + """ + Find path of the plugins directory. + The plugin directory is located in the user's home directory. + """ + path = os.path.join(os.path.expanduser("~"), '.sasview', PLUGIN_DIR) + + # TODO: trigger initialization of plugins dir from installer or startup + # If the plugin directory doesn't exist, create it + if not os.path.isdir(path): + os.makedirs(path) + # TODO: should we be checking for new default models every time? 
+    # TODO: restore support for default plugins
+    #initialize_plugins_dir(path)
+    return path
+
+
+def initialize_plugins_dir(path):
+    # TODO: There are no default plugins
+    # TODO: Default plugins directory is in sasgui, but models.py is in sascalc
+    # TODO: Move default plugins beside sample data files
+    # TODO: Should not look for defaults above the root of the sasview install
+
+    # Walk up the tree looking for default plugin_models directory
+    base = os.path.abspath(os.path.dirname(__file__))
+    for _ in range(12):
+        default_plugins_path = os.path.join(base, PLUGIN_DIR)
+        if os.path.isdir(default_plugins_path):
+            break
+        base, _ = os.path.split(base)
+    else:
+        logger.error("default plugins directory not found")
+        return
+
+    # Copy files from default plugins to the .sasview directory
+    # This may include c files, depending on the example.
+    # Note: files are never replaced, even if the default plugins are updated
+    for filename in os.listdir(default_plugins_path):
+        # skip __init__.py and all pyc files
+        if filename == "__init__.py" or filename.endswith('.pyc'):
+            continue
+        source = os.path.join(default_plugins_path, filename)
+        target = os.path.join(path, filename)
+        if os.path.isfile(source) and not os.path.isfile(target):
+            shutil.copy(source, target)
+
+
+class ReportProblem(object):
+    """
+    Class to check for problems with specific values
+    """
+    def __nonzero__(self):
+        type, value, tb = sys.exc_info()
+        if type is not None and issubclass(type, py_compile.PyCompileError):
+            print("Problem with", repr(value))
+            raise  # re-raise the current exception with its traceback
+        return 1
+    __bool__ = __nonzero__  # Python 3 truth protocol
+
+report_problem = ReportProblem()
+
+
+def compile_file(dir):
+    """
+    Byte-compile the .py files in the given directory
+    """
+    try:
+        import compileall
+        compileall.compile_dir(dir=dir, ddir=dir, force=0,
+                               quiet=report_problem)
+    except Exception:
+        return sys.exc_info()[1]
+    return None
+
+
+def find_plugin_models():
+    """
+    Find custom models
+    """
+    # List of plugin objects
+    plugins_dir = find_plugins_dir()
+    # Go through files in plug-in directory
+    if not os.path.isdir(plugins_dir):
+        msg = "SasView couldn't locate Model plugin folder %r." % plugins_dir
+        logger.warning(msg)
+        return {}
+
+    plugin_log("looking for models in: %s" % plugins_dir)
+    # compile_file(plugins_dir)  #always recompile the folder plugin
+    logger.info("plugin model dir: %s", plugins_dir)
+
+    plugins = {}
+    for filename in os.listdir(plugins_dir):
+        name, ext = os.path.splitext(filename)
+        if ext == '.py' and name != '__init__':
+            path = os.path.abspath(os.path.join(plugins_dir, filename))
+            try:
+                model = load_custom_model(path)
+                # TODO: add [plug-in] tag to model name in sasview_model
+                if not model.name.startswith(PLUGIN_NAME_BASE):
+                    model.name = PLUGIN_NAME_BASE + model.name
+                plugins[model.name] = model
+            except Exception:
+                msg = traceback.format_exc()
+                msg += "\nwhile accessing model in %r" % path
+                plugin_log(msg)
+                logger.warning("Failed to load plugin %r. See %s for details",
+                               path, PLUGIN_LOG)
+
+    return plugins
+
+
+class ModelManagerBase(object):
+    """
+    Base class for the model manager
+    """
+    #: mutable dictionary of models, continually updated to reflect the
+    #: current set of plugins
+    model_dictionary = None  # type: Dict[str, Model]
+    #: constant list of standard models
+    standard_models = None  # type: Dict[str, Model]
+    #: list of plugin models reset each time the plugin directory is queried
+    plugin_models = None  # type: Dict[str, Model]
+    #: timestamp on the plugin directory at the last plugin update
+    last_time_dir_modified = 0  # type: int
+
+    def __init__(self):
+        # the model dictionary is allocated at the start and updated to
+        # reflect the current list of models.  Be sure to clear it rather
+        # than reassign to it.
+        self.model_dictionary = {}
+
+        # Build list automagically from sasmodels package
+        self.standard_models = {model.name: model
+                                for model in load_standard_models()}
+        # Look for plugins
+        self.plugins_reset()
+
+    def _is_plugin_dir_changed(self):
+        """
+        Check the last time the plugin dir changed; return True if the
+        directory was modified, else False.
+        """
+        is_modified = False
+        plugin_dir = find_plugins_dir()
+        if os.path.isdir(plugin_dir):
+            mod_time = os.path.getmtime(plugin_dir)
+            if self.last_time_dir_modified != mod_time:
+                is_modified = True
+                self.last_time_dir_modified = mod_time
+
+        return is_modified
+
+    def composable_models(self):
+        """
+        Return the list of standard models that can be used in sum/multiply.
+        """
+        # TODO: should scan plugin models in addition to standard models
+        # and update model_editor so that it doesn't add plugins to the list
+        return [model.name for model in self.standard_models.values()
+                if not model.is_multiplicity_model]
+
+    def plugins_update(self):
+        """
+        Return a dictionary of models if new models were added,
+        else return an empty dictionary.
+        """
+        return self.plugins_reset()
+        #if self._is_plugin_dir_changed():
+        #    return self.plugins_reset()
+        #else:
+        #    return {}
+
+    def plugins_reset(self):
+        """
+        Reload the plugin models and return the updated model dictionary.
+        """
+        self.plugin_models = find_plugin_models()
+        self.model_dictionary.clear()
+        self.model_dictionary.update(self.standard_models)
+        self.model_dictionary.update(self.plugin_models)
+        return self.get_model_list()
+
+    def get_model_list(self):
+        """
+        Return a dictionary of classified models.
+
+        *Structure Factors* are the structure factor models
+        *Multi-Functions* are the multiplicity models
+        *Plugin Models* are the plugin models
+
+        Note that a model can be both a plugin and a structure factor or
+        multiplicity model.
+        """
+        ## Model_list now only contains attribute lists not category list.
+        ## Eventually this should be in one master list -- read in category
+        ## list then pull those models that exist and get attributes then add
+        ## to list ..and if model does not exist remove from list as now
+        ## and update json file.
+        ##
+        ## -PDB April 26, 2014
+
+        # Classify models
+        structure_factors = []
+        form_factors = []
+        multiplicity_models = []
+        for model in self.model_dictionary.values():
+            # Old style models don't have is_structure_factor attribute
+            if getattr(model, 'is_structure_factor', False):
+                structure_factors.append(model)
+            if getattr(model, 'is_form_factor', False):
+                form_factors.append(model)
+            if model.is_multiplicity_model:
+                multiplicity_models.append(model)
+        plugin_models = list(self.plugin_models.values())
+
+        return {
+            "Structure Factors": structure_factors,
+            "Form Factors": form_factors,
+            "Plugin Models": plugin_models,
+            "Multi-Functions": multiplicity_models,
+        }
+
+
+class ModelManager(object):
+    """
+    Manage the list of available models
+    """
+    base = None  # type: ModelManagerBase
+
+    def __init__(self):
+        if ModelManager.base is None:
+            ModelManager.base = ModelManagerBase()
+
+    def cat_model_list(self):
+        return list(self.base.standard_models.values())
+
+    def update(self):
+        return self.base.plugins_update()
+
+    def plugins_reset(self):
+        return self.base.plugins_reset()
+
+    def get_model_list(self):
+        return self.base.get_model_list()
+
+    def composable_models(self):
+        return self.base.composable_models()
+
+    def get_model_dictionary(self):
+        return self.base.model_dictionary
diff --git a/sas/sascalc/fit/pagestate.py b/sas/sascalc/fit/pagestate.py
new file mode 100755
index 000000000..58ba76983
--- /dev/null
+++ b/sas/sascalc/fit/pagestate.py
@@ -0,0 +1,1392 @@
+"""
+Class that holds a fit page state
+"""
+# TODO: Refactor code so we don't need to use getattr/setattr
+################################################################################
+# This software was developed by the University of Tennessee as part of the
+# Distributed Data Analysis of Neutron Scattering Experiments (DANSE)
+# project funded by the US National Science Foundation.
+#
+# See the license text in license.txt
+#
+# copyright 2009, University of Tennessee
+################################################################################
+import time
+import re
+import os
+import sys
+import copy
+import logging
+import numpy as np
+import traceback
+
+import xml.dom.minidom
+from xml.dom.minidom import parseString
+from xml.dom.minidom import getDOMImplementation
+from lxml import etree
+
+from sasmodels import convert
+import sasmodels.weights
+
+from sas.sasview import __version__ as SASVIEW_VERSION
+
+import sas.sascalc.dataloader
+from sas.sascalc.dataloader.readers.cansas_reader import Reader as CansasReader
+from sas.sascalc.dataloader.readers.cansas_reader import get_content, write_node
+from sas.sascalc.dataloader.data_info import Data2D, Collimation, Detector
+from sas.sascalc.dataloader.data_info import Process, Aperture
+
+logger = logging.getLogger(__name__)
+
+# Information to read/write state as xml
+FITTING_NODE_NAME = 'fitting_plug_in'
+CANSAS_NS = {"ns": "cansas1d/1.0"}
+
+CUSTOM_MODEL = 'Plugin Models'
+CUSTOM_MODEL_OLD = 'Customized Models'
+
+LIST_OF_DATA_ATTRIBUTES = [["is_data", "is_data", "bool"],
+                           ["group_id", "data_group_id", "string"],
+                           ["data_name", "data_name", "string"],
+                           ["data_id", "data_id", "string"],
+                           ["name", "name", "string"]]
+LIST_OF_STATE_ATTRIBUTES = [["qmin", "qmin", "float"],
+                            ["qmax", "qmax", "float"],
+                            ["npts", "npts", "float"],
+                            ["categorycombobox", "categorycombobox", "string"],
+                            ["formfactorcombobox", "formfactorcombobox",
+                             "string"],
+                            ["structurecombobox", "structurecombobox",
+                             "string"],
+                            ["multi_factor", "multi_factor", "float"],
+                            ["magnetic_on", "magnetic_on", "bool"],
+                            ["enable_smearer", "enable_smearer", "bool"],
+                            ["disable_smearer", "disable_smearer", "bool"],
+                            ["pinhole_smearer", "pinhole_smearer", "bool"],
+                            ["slit_smearer", "slit_smearer", "bool"],
+                            ["enable_disp", "enable_disp", "bool"],
+                            ["disable_disp", "disable_disp", "bool"],
+                            ["dI_noweight", "dI_noweight", "bool"],
+                            ["dI_didata", "dI_didata", "bool"],
+                            ["dI_sqrdata", "dI_sqrdata", "bool"],
+                            ["dI_idata", "dI_idata", "bool"],
+                            ["enable2D", "enable2D", "bool"],
+                            ["cb1", "cb1", "bool"],
+                            ["tcChi", "tcChi", "float"],
+                            ["dq_l", "dq_l", "float"],
+                            ["dq_r", "dq_r", "float"],
+                            ["dx_percent", "dx_percent", "float"],
+                            ["dxl", "dxl", "float"],
+                            ["dxw", "dxw", "float"]]
+
+LIST_OF_MODEL_ATTRIBUTES = [["values", "values"],
+                            ["weights", "weights"]]
+
+DISPERSION_LIST = [["disp_obj_dict", "disp_obj_dict", "string"]]
+
+LIST_OF_STATE_PARAMETERS = [["parameters", "parameters"],
+                            ["str_parameters", "str_parameters"],
+                            ["orientation_parameters", "orientation_params"],
+                            ["dispersity_parameters",
+                             "orientation_params_disp"],
+                            ["fixed_param", "fixed_param"],
+                            ["fittable_param", "fittable_param"]]
+LIST_OF_DATA_2D_ATTR = [["xmin", "xmin", "float"],
+                        ["xmax", "xmax", "float"],
+                        ["ymin", "ymin", "float"],
+                        ["ymax", "ymax", "float"],
+                        ["_xaxis", "_xaxis", "string"],
+                        ["_xunit", "_xunit", "string"],
+                        ["_yaxis", "_yaxis", "string"],
+                        ["_yunit", "_yunit", "string"],
+                        ["_zaxis", "_zaxis", "string"],
+                        ["_zunit", "_zunit", "string"]]
+LIST_OF_DATA_2D_VALUES = [["qx_data", "qx_data", "float"],
+                          ["qy_data", "qy_data", "float"],
+                          ["dqx_data", "dqx_data", "float"],
+                          ["dqy_data", "dqy_data", "float"],
+                          ["data", "data", "float"],
+                          ["q_data", "q_data", "float"],
+                          ["err_data", "err_data", "float"],
+                          ["mask", "mask", "bool"]]
+
+
+def parse_entry_helper(node, item):
+    """
+    Extract a value from an XML node and convert it to the declared type.
+
+    :param node: node from which the value is read
+    :param item: list of three strings: the first two are the names of the
+        data attribute and the third is the type of the value of that
+        attribute ("string", "bool" or "float").
+
+    :return: the converted value, or None if it cannot be parsed
+    """
+    if node is not None:
+        if item[2] == "string":
+            return str(node.get(item[0]).strip())
+        elif item[2] == "bool":
+            try:
+                return node.get(item[0]).strip() == "True"
+            except Exception:
+                return None
+        else:
+            try:
+                return float(node.get(item[0]))
+            except Exception:
+                return None
+
+
+class PageState(object):
+    """
+    Contains information to reconstruct a page of the fitpanel.
+    """
+    def __init__(self, model=None, data=None):
+        """
+        Initialize the current state
+
+        :param model: a selected model within a page
+        :param data: the data set associated with this page
+
+        """
+        self.file = None
+        # Time of state creation
+        self.timestamp = time.time()
+        # Data member to store the dispersion object created
+        self.disp_obj_dict = {}
+        # ------------------------
+        # Data used for fitting
+        self.data = data
+        # model data
+        self.theory_data = None
+        # Is 2D
+        self.is_2D = False
+        self.images = None
+
+        # save additional information on data that dataloader.reader
+        # does not read
+        self.is_data = None
+        self.data_name = ""
+
+        if self.data is not None:
+            self.data_name = self.data.name
+        self.data_id = None
+        if self.data is not None and hasattr(self.data, "id"):
+            self.data_id = self.data.id
+        self.data_group_id = None
+        if self.data is not None and hasattr(self.data, "group_id"):
+            self.data_group_id = self.data.group_id
+
+        # reset=True changes the state of existing buttons
+        self.reset = False
+
+        # flag to allow data2D plot
+        self.enable2D = False
+        # model on which the fit would be performed
+        self.model = model
+        self.m_name = None
+        # list of process done to model
+        self.process = []
+        # fit page manager
+        self.manager = None
+        # Event_owner is the owner of model event
+        self.event_owner = None
+        # page name
+        self.page_name = ""
+        # Contains link between model, its parameters, and panel organization
+        self.parameters = []
+        # String parameter list that cannot be fitted
+        self.str_parameters = []
+        # Contains list of parameters that cannot be fitted and reference to
+        # panel objects
+        self.fixed_param = []
+        # Contains list of parameters with dispersity and reference to
+        # panel objects
+        self.fittable_param = []
+        # orientation parameters
+        self.orientation_params = []
+        # orientation parameters for Gaussian dispersity
+        self.orientation_params_disp = []
+        self.dq_l = None
+        self.dq_r = None
+        self.dx_percent = None
+        self.dx_old = False
+        self.dxl = None
+        self.dxw = None
+        # list of dispersion parameters
+        self.disp_list = []
+        if self.model is not None:
+            self.disp_list = self.model.getDispParamList()
+
+        self.disp_cb_dict = {}
+        self.values = {}
+        self.weights = {}
+
+        # contains link between a model and selected parameters to fit
+        self.param_toFit = []
+        # save the state of the context menu
+        self.saved_states = {}
+        # save selection of combobox
+        self.formfactorcombobox = None
+        self.categorycombobox = None
+        self.structurecombobox = None
+
+        # radio box to select type of model
+        # self.shape_rbutton = False
+        # self.shape_indep_rbutton = False
+        # self.struct_rbutton = False
+        # self.plugin_rbutton = False
+        # the index of the current selection
+        self.disp_box = 0
+        # Q range
+        self.qmin = 0.001
+        self.qmax = 0.1
+        # reset data range
+        self.qmax_x = None
+        self.qmin_x = None
+
+        self.npts = None
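+        # The attributes below mirror widgets on the fit page (comboboxes,
+        # check buttons, smearing and weighting options); exactly one of
+        # each enable/disable pair defaults to True.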
+        self.name = ""
+        self.multi_factor = None
+        self.magnetic_on = False
+        # enable smearing state
+        self.enable_smearer = False
+        self.disable_smearer = True
+        self.pinhole_smearer = False
+        self.slit_smearer = False
+        # weighting options
+        self.dI_noweight = False
+        self.dI_didata = True
+        self.dI_sqrdata = False
+        self.dI_idata = False
+        # dispersity selection
+        self.enable_disp = False
+        self.disable_disp = True
+
+        # state of the 'select all' check button
+        self.cb1 = False
+        # store value of chisqr
+        self.tcChi = None
+        self.version = (1, 0, 0)
+
+    def clone(self):
+        """
+        Create a new copy of the current object
+        """
+        model = None
+        if self.model is not None:
+            model = self.model.clone()
+            model.name = self.model.name
+        obj = PageState(model=model)
+        obj.file = copy.deepcopy(self.file)
+        obj.data = copy.deepcopy(self.data)
+        if self.data is not None:
+            self.data_name = self.data.name
+        obj.data_name = self.data_name
+        obj.is_data = self.is_data
+
+        obj.categorycombobox = self.categorycombobox
+        obj.formfactorcombobox = self.formfactorcombobox
+        obj.structurecombobox = self.structurecombobox
+
+        # obj.shape_rbutton = self.shape_rbutton
+        # obj.shape_indep_rbutton = self.shape_indep_rbutton
+        # obj.struct_rbutton = self.struct_rbutton
+        # obj.plugin_rbutton = self.plugin_rbutton
+
+        obj.manager = self.manager
+        obj.event_owner = self.event_owner
+        obj.disp_list = copy.deepcopy(self.disp_list)
+
+        obj.enable2D = copy.deepcopy(self.enable2D)
+        obj.parameters = copy.deepcopy(self.parameters)
+        obj.str_parameters = copy.deepcopy(self.str_parameters)
+        obj.fixed_param = copy.deepcopy(self.fixed_param)
+        obj.fittable_param = copy.deepcopy(self.fittable_param)
+        obj.orientation_params = copy.deepcopy(self.orientation_params)
+        obj.orientation_params_disp = \
+            copy.deepcopy(self.orientation_params_disp)
+        obj.enable_disp = copy.deepcopy(self.enable_disp)
+        obj.disable_disp = copy.deepcopy(self.disable_disp)
+        obj.tcChi = self.tcChi
+
+        if len(self.disp_obj_dict) > 0:
+            for k, v in self.disp_obj_dict.items():
+                obj.disp_obj_dict[k] = v
+        if len(self.disp_cb_dict) > 0:
+            for k, v in self.disp_cb_dict.items():
+                obj.disp_cb_dict[k] = v
+        if len(self.values) > 0:
+            for k, v in self.values.items():
+                obj.values[k] = v
+        if len(self.weights) > 0:
+            for k, v in self.weights.items():
+                obj.weights[k] = v
+        obj.enable_smearer = copy.deepcopy(self.enable_smearer)
+        obj.disable_smearer = copy.deepcopy(self.disable_smearer)
+        obj.pinhole_smearer = copy.deepcopy(self.pinhole_smearer)
+        obj.slit_smearer = copy.deepcopy(self.slit_smearer)
+        obj.dI_noweight = copy.deepcopy(self.dI_noweight)
+        obj.dI_didata = copy.deepcopy(self.dI_didata)
+        obj.dI_sqrdata = copy.deepcopy(self.dI_sqrdata)
+        obj.dI_idata = copy.deepcopy(self.dI_idata)
+        obj.dq_l = copy.deepcopy(self.dq_l)
+        obj.dq_r = copy.deepcopy(self.dq_r)
+        obj.dx_percent = copy.deepcopy(self.dx_percent)
+        obj.dx_old = copy.deepcopy(self.dx_old)
+        obj.dxl = copy.deepcopy(self.dxl)
+        obj.dxw = copy.deepcopy(self.dxw)
+        obj.disp_box = copy.deepcopy(self.disp_box)
+        obj.qmin = copy.deepcopy(self.qmin)
+        obj.qmax = copy.deepcopy(self.qmax)
+        obj.multi_factor = self.multi_factor
+        obj.magnetic_on = self.magnetic_on
+        obj.npts = copy.deepcopy(self.npts)
+        obj.cb1 = copy.deepcopy(self.cb1)
+        obj.version = copy.deepcopy(self.version)
+
+        for name, state in self.saved_states.items():
+            copy_name = copy.deepcopy(name)
+            copy_state = state.clone()
+            obj.saved_states[copy_name] = copy_state
+        return obj
+
+    def _old_first_model(self):
+        """
+        Handle save states from 4.0.1 and before, where the first item in
+        the selection boxes of category, formfactor and structurefactor
+        was not saved.
+        :return: None
+        """
+        if self.categorycombobox == CUSTOM_MODEL_OLD:
+            self.categorycombobox = CUSTOM_MODEL
+        if self.formfactorcombobox == '':
+            FIRST_FORM = {
+                'Shapes' : 'BCCrystalModel',
+                'Uncategorized' : 'LineModel',
+                'StructureFactor' : 'HardsphereStructure',
+                'Ellipsoid' : 'core_shell_ellipsoid',
+                'Lamellae' : 'lamellar',
+                'Paracrystal' : 'bcc_paracrystal',
+                'Parallelepiped' : 'core_shell_parallelepiped',
+                'Shape Independent' : 'be_polyelectrolyte',
+                'Sphere' : 'adsorbed_layer',
+                'Structure Factor' : 'hardsphere',
+                CUSTOM_MODEL : ''
+            }
+            if self.categorycombobox == '':
+                if len(self.parameters) == 3:
+                    self.categorycombobox = "Shape-Independent"
+                    self.formfactorcombobox = 'PowerLawAbsModel'
+                elif len(self.parameters) == 9:
+                    self.categorycombobox = 'Cylinder'
+                    self.formfactorcombobox = 'barbell'
+                else:
+                    msg = "Save state does not have enough information to load"
+                    msg += " all of the data."
+                    logger.warning(msg=msg)
+            else:
+                self.formfactorcombobox = FIRST_FORM[self.categorycombobox]
+
+    @staticmethod
+    def param_remap_to_sasmodels_convert(params, is_string=False):
+        """
+        Remaps the parameters for sasmodels conversion
+
+        :param params: list of parameters (likely self.parameters)
+        :param is_string: if True, store the parameter values as strings
+        :return: remapped dictionary of parameters
+        """
+        p = dict()
+        for fittable, name, value, _, uncert, lower, upper, units in params:
+            if not value:
+                value = np.nan
+            if not uncert or uncert[1] == '' or uncert[1] == 'None':
+                uncert[0] = False
+                uncert[1] = np.nan
+            if not upper or upper[1] == '' or upper[1] == 'None':
+                upper[0] = False
+                upper[1] = np.nan
+            if not lower or lower[1] == '' or lower[1] == 'None':
+                lower[0] = False
+                lower[1] = np.nan
+            if is_string:
+                p[name] = str(value)
+            else:
+                p[name] = float(value)
+            p[name + ".fittable"] = bool(fittable)
+            p[name + ".std"] = float(uncert[1])
+            p[name + ".upper"] = float(upper[1])
+            p[name + ".lower"] = float(lower[1])
+            p[name + ".units"] = units
+        return p
+
+    @staticmethod
+    def param_remap_from_sasmodels_convert(params):
+        """
+        Converts a {name: value} map back to a list of parameter rows
+        :param params: parameter map returned from sasmodels
+        :return: remapped list of parameters
+        """
+        p_map = []
+        for name, info in params.items():
+            if ".fittable" in name or ".std" in name or ".upper" in name or \
+                    ".lower" in name or ".units" in name:
+                pass
+            else:
+                fittable = params.get(name + ".fittable", True)
+                std = params.get(name + ".std", '0.0')
+                upper = params.get(name + ".upper", 'inf')
+                lower = params.get(name + ".lower", '-inf')
+                units = params.get(name + ".units")
+                if std is not None and std is not np.nan:
+                    std = [True, str(std)]
+                else:
+                    std = [False, '']
+                if lower is not None and lower is not np.nan:
+                    lower = [True, str(lower)]
+                else:
+                    lower = [True, '-inf']
+                if upper is not None and upper is not np.nan:
+                    upper = [True, str(upper)]
+                else:
+                    upper = [True, 'inf']
+                param_list = [bool(fittable), str(name), str(info),
+                              "+/-", std, lower, upper, str(units)]
+                p_map.append(param_list)
+        return p_map
+
+    def _convert_to_sasmodels(self):
+        """
+        Convert parameters to a form usable by the sasmodels converter
+
+        :return: None
+        """
+        # Create conversion dictionary to send to sasmodels
+        self._old_first_model()
+        p = self.param_remap_to_sasmodels_convert(self.parameters)
+        structurefactor, params = convert.convert_model(self.structurecombobox,
+                                                        p, False, self.version)
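+        # p now maps flattened names to values, e.g. p['radius'] and
+        # p['radius.lower'] ('radius' is an illustrative parameter name).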
+        formfactor, params = convert.convert_model(self.formfactorcombobox,
+                                                   params, False, self.version)
+        if len(self.str_parameters) > 0:
+            str_pars = self.param_remap_to_sasmodels_convert(
+                self.str_parameters, True)
+            formfactor, str_params = convert.convert_model(
+                self.formfactorcombobox, str_pars, False, self.version)
+            for key, value in str_params.items():
+                params[key] = value
+
+        if self.formfactorcombobox == 'SphericalSLDModel':
+            self.multi_factor += 1
+        self.formfactorcombobox = formfactor
+        self.structurecombobox = structurefactor
+        self.parameters = []
+        self.parameters = self.param_remap_from_sasmodels_convert(params)
+
+    def _repr_helper(self, list, rep):
+        """
+        Helper method to print a state
+        """
+        for item in list:
+            rep += "parameter name: %s \n" % str(item[1])
+            rep += "value: %s\n" % str(item[2])
+            rep += "selected: %s\n" % str(item[0])
+            rep += "error displayed : %s \n" % str(item[4][0])
+            rep += "error value:%s \n" % str(item[4][1])
+            rep += "minimum displayed : %s \n" % str(item[5][0])
+            rep += "minimum value : %s \n" % str(item[5][1])
+            rep += "maximum displayed : %s \n" % str(item[6][0])
+            rep += "maximum value : %s \n" % str(item[6][1])
+            rep += "parameter unit: %s\n\n" % str(item[7])
+        return rep
+
+    def __repr__(self):
+        """
+        output string for printing
+        """
+        rep = "\nState name: %s\n" % self.file
+        t = time.localtime(self.timestamp)
+        time_str = time.strftime("%b %d %Y %H:%M:%S ", t)
+
+        rep += "State created: %s\n" % time_str
+        rep += "State form factor combobox selection: %s\n" % \
+               self.formfactorcombobox
+        rep += "State structure factor combobox selection: %s\n" % \
+               self.structurecombobox
+        rep += "is data : %s\n" % self.is_data
+        rep += "data's name : %s\n" % self.data_name
+        rep += "data's id : %s\n" % self.data_id
+        if self.model is not None:
+            m_name = self.model.__class__.__name__
+            if m_name == 'Model':
+                m_name = self.m_name
+            rep += "model name : %s\n" % m_name
+        else:
+            rep += "model name : None\n"
+        rep += "multi_factor : %s\n" % str(self.multi_factor)
+        rep += "magnetic_on : %s\n" % str(self.magnetic_on)
+        rep += "model type (Category) selected: %s\n" % self.categorycombobox
+        rep += "data : %s\n" % str(self.data)
+        rep += "Plotting Range: min: %s, max: %s, steps: %s\n" % \
+               (str(self.qmin), str(self.qmax), str(self.npts))
+        rep += "Dispersion selection : %s\n" % str(self.disp_box)
+        rep += "Smearing enable : %s\n" % str(self.enable_smearer)
+        rep += "Smearing disable : %s\n" % str(self.disable_smearer)
+        rep += "Pinhole smearer enable : %s\n" % str(self.pinhole_smearer)
+        rep += "Slit smearer enable : %s\n" % str(self.slit_smearer)
+        rep += "Dispersity enable : %s\n" % str(self.enable_disp)
+        rep += "Dispersity disable : %s\n" % str(self.disable_disp)
+
+        rep += "dI_noweight : %s\n" % str(self.dI_noweight)
+        rep += "dI_didata : %s\n" % str(self.dI_didata)
+        rep += "dI_sqrdata : %s\n" % str(self.dI_sqrdata)
+        rep += "dI_idata : %s\n" % str(self.dI_idata)
+
+        rep += "2D enable : %s\n" % str(self.enable2D)
+        rep += "All parameters checkbox selected: %s\n" % self.cb1
+        rep += "Value of Chisqr : %s\n" % str(self.tcChi)
+        rep += "dq_l : %s\n" % self.dq_l
+        rep += "dq_r : %s\n" % self.dq_r
+        rep += "dx_percent : %s\n" % str(self.dx_percent)
+        rep += "dxl : %s\n" % str(self.dxl)
+        rep += "dxw : %s\n" % str(self.dxw)
+        rep += "model : %s\n\n" % str(self.model)
+        temp_parameters = []
+        temp_fittable_param = []
+        if self.data.__class__.__name__ == "Data2D":
+            self.is_2D = True
+        else:
+            self.is_2D = False
+        if self.data is not None:
+            if not self.is_2D:
+                for item in self.parameters:
+                    if item not in self.orientation_params:
+                        temp_parameters.append(item)
+                for item in self.fittable_param:
+                    if item not in self.orientation_params_disp:
+                        temp_fittable_param.append(item)
+            else:
+                temp_parameters = self.parameters
+                temp_fittable_param = self.fittable_param
+
+            rep += "number parameters(self.parameters): %s\n" % \
+                   len(temp_parameters)
+            rep = self._repr_helper(list=temp_parameters, rep=rep)
+            rep += "number str_parameters(self.str_parameters): %s\n" % \
+                   len(self.str_parameters)
+            rep = self._repr_helper(list=self.str_parameters, rep=rep)
+            rep += "number fittable_param(self.fittable_param): %s\n" % \
+                   len(temp_fittable_param)
+            rep = self._repr_helper(list=temp_fittable_param, rep=rep)
+        return rep
+
+    def _get_report_string(self):
+        """
+        Get the values (strings) from __repr__ for the report
+        """
+        # Dictionary of the report strings
+        repo_time = ""
+        model_name = ""
+        title = ""
+        title_name = ""
+        file_name = ""
+        param_string = ""
+        paramval_string = ""
+        chi2_string = ""
+        q_range = ""
+        strings = self.__repr__()
+        fixed_parameter = False
+        lines = strings.split('\n')
+        # get all string values from __repr__()
+        for line in lines:
+            # Skip lines which are not key: value pairs, which includes
+            # blank lines and freeform notes in SASNotes fields.
+            if not ':' in line:
+                #msg = "Report string expected 'name: value' but got %r" % line
+                #logger.error(msg)
+                continue
+
+            name, value = [s.strip() for s in line.split(":", 1)]
+            if name == "State created":
+                repo_time = value
+            elif name == "parameter name":
+                val_name = value.split(".")
+                if len(val_name) > 1:
+                    if val_name[1].count("width"):
+                        param_string += value + ','
+                    else:
+                        continue
+                else:
+                    param_string += value + ','
+            elif name == "value":
+                param_string += value + ','
+            elif name == "selected":
+                # remember if it is fixed when reporting error value
+                fixed_parameter = (value == u'False')
+            elif name == "error value":
+                if fixed_parameter:
+                    param_string += '(fixed),'
+                else:
+                    param_string += value + ','
+            elif name == "parameter unit":
+                param_string += value + ':'
+            elif name == "Value of Chisqr":
+                chi2 = ("Chi2/Npts = " + value)
+                chi2_string = CENTRE % chi2
+            elif name == "Title":
+                if len(value.strip()) == 0:
+                    continue
+                title = (value + " [" + repo_time + "] [SasView v" +
+                         SASVIEW_VERSION + "]")
+                title_name = HEADER % title
+            elif name == "data":
+                try:
+                    # parsing "data : File: filename [mmm dd hh:mm]"
+                    name = value.split(':', 1)[1].strip()
+                    file_value = "File name:" + name
+                    # Truncate the string so print doesn't complain of
+                    # being outside margins
+                    if sys.platform != "win32":
+                        MAX_STRING_LENGTH = 50
+                        if len(file_value) > MAX_STRING_LENGTH:
+                            file_value = "File name:.." + file_value[-MAX_STRING_LENGTH+10:]
+                    file_name = CENTRE % file_value
+                    if len(title) == 0:
+                        title = name + " [" + repo_time + "]"
+                        title_name = HEADER % title
+                except Exception:
+                    msg = "While parsing 'data: ...'\n"
+                    logger.error(msg + traceback.format_exc())
+            elif name == "model name":
+                try:
+                    modelname = "Model name:" + value
+                except Exception:
+                    modelname = "Model name:" + " NAN"
+                model_name = CENTRE % modelname
+
+            elif name == "Plotting Range":
+                try:
+                    parts = value.split(':')
+                    q_range = parts[0] + " = " + parts[1] \
+                              + " = " + parts[2].split(",")[0]
+                    q_name = ("Q Range: " + q_range)
+                    q_range = CENTRE % q_name
+                except Exception:
+                    msg = "While parsing 'Plotting Range: ...'\n"
+                    logger.error(msg + traceback.format_exc())
+
+        paramval = ""
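+        # param_string accumulated above is a ':'-separated list of
+        # records, each of the form "name,value,error,unit"; split it
+        # back apart to build the report lines.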
+        for lines in param_string.split(":"):
+            line = lines.split(",")
+            if len(lines) > 0:
+                param = line[0]
+                param += " = " + line[1]
+                if len(line[2].split()) > 0 and not line[2].count("None"):
+                    param += " +- " + line[2]
+                if len(line[3].split()) > 0 and not line[3].count("None"):
+                    param += " " + line[3]
+                if not paramval.count(param):
+                    paramval += param + "\n"
+                    paramval_string += CENTRE % param + "\n"
+
+        text_string = "\n\n\n%s\n\n%s\n%s\n%s\n\n%s" % \
+                      (title, file_name, q_name, chi2, paramval)
+
+        title_name = self._check_html_format(title_name)
+        file_name = self._check_html_format(file_name)
+        title = self._check_html_format(title)
+
+        html_string = title_name + "\n" + file_name + \
+                      "\n" + model_name + \
+                      "\n" + q_range + \
+                      "\n" + chi2_string + \
+                      "\n" + ELINE + \
+                      "\n" + paramval_string + \
+                      "\n" + ELINE + \
+                      "\n" + FEET_1 % title
+
+        return html_string, text_string, title
+
+    def _check_html_format(self, name):
+        """
+        Escape '%' so it does not break later html formatting
+        """
+        if name.count('%'):
+            name = name.replace('%', '&#37')
+
+        return name
+
+    def report(self, fig_urls):
+        """
+        Build the strings for the report dialog panel
+
+        :param fig_urls: list of urls for the report figures
+        """
+        # get the strings for report
+        html_str, text_str, title = self._get_report_string()
+        # Allow 2 figures to append
+        # Constrain image width on OSX and linux so print doesn't complain
+        # of being outside margins
+        if sys.platform == "win32":
+            image_links = [FEET_2 % fig for fig in fig_urls]
+        else:
+            image_links = [FEET_2_unix % fig for fig in fig_urls]
+        # final report html strings
+        report_str = html_str + ELINE.join(image_links)
+        report_str += FEET_3
+        return report_str, text_str
+
+    def _to_xml_helper(self, thelist, element, newdoc):
+        """
+        Helper method to create the xml elements for saving state
+        """
+        for item in thelist:
+            sub_element = newdoc.createElement('parameter')
+            sub_element.setAttribute('name', str(item[1]))
+            sub_element.setAttribute('value', str(item[2]))
+            sub_element.setAttribute('selected_to_fit', str(item[0]))
+            sub_element.setAttribute('error_displayed', str(item[4][0]))
+            sub_element.setAttribute('error_value', str(item[4][1]))
+            sub_element.setAttribute('minimum_displayed', str(item[5][0]))
+            sub_element.setAttribute('minimum_value', str(item[5][1]))
+            sub_element.setAttribute('maximum_displayed', str(item[6][0]))
+            sub_element.setAttribute('maximum_value', str(item[6][1]))
+            sub_element.setAttribute('unit', str(item[7]))
+            element.appendChild(sub_element)
+
+    def to_xml(self, file="fitting_state.fitv", doc=None,
+               entry_node=None, batch_fit_state=None):
+        """
+        Writes the state of the fit panel to file, as XML.
+
+        Compatible with standalone writing, or appending to an
+        already existing XML document.  In that case, the XML document is
+        required.  An optional entry node in the XML document may also be given.
+
+        :param file: file to write to
+        :param doc: XML document object [optional]
+        :param entry_node: XML node within the XML document at which we
+            will append the data [optional]
+        :param batch_fit_state: simultaneous fit state
+        """
+        # Check whether we have to write a standalone XML file
+        if doc is None:
+            impl = getDOMImplementation()
+            doc_type = impl.createDocumentType(FITTING_NODE_NAME, "1.0", "1.0")
+            newdoc = impl.createDocument(None, FITTING_NODE_NAME, doc_type)
+            top_element = newdoc.documentElement
+        else:
+            # We are appending to an existing document
+            newdoc = doc
+            try:
+                top_element = newdoc.createElement(FITTING_NODE_NAME)
+            except Exception:
+                string = etree.tostring(doc, pretty_print=True)
+                newdoc = parseString(string)
+                top_element = newdoc.createElement(FITTING_NODE_NAME)
+            if entry_node is None:
+                newdoc.documentElement.appendChild(top_element)
+            else:
+                try:
+                    entry_node.appendChild(top_element)
+                except Exception:
+                    node_name = entry_node.tag
+                    node_list = newdoc.getElementsByTagName(node_name)
+                    entry_node = node_list.item(0)
+                    entry_node.appendChild(top_element)
+
+        attr = newdoc.createAttribute("version")
+        attr.nodeValue = SASVIEW_VERSION
+        # attr.nodeValue = '1.0'
+        top_element.setAttributeNode(attr)
+
+        # File name
+        element = newdoc.createElement("filename")
+        if self.file is not None:
+            element.appendChild(newdoc.createTextNode(str(self.file)))
+        else:
+            element.appendChild(newdoc.createTextNode(str(file)))
+        top_element.appendChild(element)
+
+        element = newdoc.createElement("timestamp")
+        element.appendChild(newdoc.createTextNode(time.ctime(self.timestamp)))
+        attr = newdoc.createAttribute("epoch")
+        attr.nodeValue = str(self.timestamp)
+        element.setAttributeNode(attr)
+        top_element.appendChild(element)
+
+        # Inputs
+        inputs = newdoc.createElement("Attributes")
+        top_element.appendChild(inputs)
+
+        if self.data is not None and hasattr(self.data, "group_id"):
+            self.data_group_id = self.data.group_id
+        if self.data is not None and hasattr(self.data, "is_data"):
+            self.is_data = self.data.is_data
+        if self.data is not None:
+            self.data_name = self.data.name
+        if self.data is not None and hasattr(self.data, "id"):
+            self.data_id = self.data.id
+
+        for item in LIST_OF_DATA_ATTRIBUTES:
+            element = newdoc.createElement(item[0])
+            element.setAttribute(item[0], str(getattr(self, item[1])))
+            inputs.appendChild(element)
+
+        for item in LIST_OF_STATE_ATTRIBUTES:
+            element = newdoc.createElement(item[0])
+            element.setAttribute(item[0], str(getattr(self, item[1])))
+            inputs.appendChild(element)
+
+        # For self.values = { disp_param_name: [vals, ...], ...}
+        # and for self.weights = { disp_param_name: [weights, ...], ...}
+        for item in LIST_OF_MODEL_ATTRIBUTES:
+            element = newdoc.createElement(item[0])
+            value_list = getattr(self, item[1])
+            for key, value in value_list.items():
+                sub_element = newdoc.createElement(key)
+                sub_element.setAttribute('name', str(key))
+                for val in value:
+                    sub_element.appendChild(newdoc.createTextNode(str(val)))
+
+                element.appendChild(sub_element)
+            inputs.appendChild(element)
+
+        # Create doc for the dictionary of self.disp_obj_dict
+        for tagname, varname, tagtype in DISPERSION_LIST:
+            element = newdoc.createElement(tagname)
+            value_list = getattr(self, varname)
+            for key, value in value_list.items():
+                sub_element = newdoc.createElement(key)
+                sub_element.setAttribute('name', str(key))
+                sub_element.setAttribute('value', str(value))
+                element.appendChild(sub_element)
+            inputs.appendChild(element)
+
+        for item in LIST_OF_STATE_PARAMETERS:
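+            # item is a [tag_name, attribute_name] pair; each parameter in
+            # the named list is written as a <parameter> child element.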
+            element = newdoc.createElement(item[0])
+            self._to_xml_helper(thelist=getattr(self, item[1]),
+                                element=element, newdoc=newdoc)
+            inputs.appendChild(element)
+
+        # Combined and Simultaneous Fit Parameters
+        if batch_fit_state is not None:
+            batch_combo = newdoc.createElement('simultaneous_fit')
+            top_element.appendChild(batch_combo)
+
+            # Simultaneous Fit Number For Linking Later
+            element = newdoc.createElement('sim_fit_number')
+            element.setAttribute('fit_number', str(batch_fit_state.fit_page_no))
+            batch_combo.appendChild(element)
+
+            # Save constraints
+            constraints = newdoc.createElement('constraints')
+            batch_combo.appendChild(constraints)
+            for constraint in batch_fit_state.constraints_list:
+                if constraint.model_cbox.GetValue() != "":
+                    # model_cbox, param_cbox, egal_txt, constraint,
+                    # btRemove, sizer
+                    doc_cons = newdoc.createElement('constraint')
+                    doc_cons.setAttribute('model_cbox',
+                                          str(constraint.model_cbox.GetValue()))
+                    doc_cons.setAttribute('param_cbox',
+                                          str(constraint.param_cbox.GetValue()))
+                    doc_cons.setAttribute('egal_txt',
+                                          str(constraint.egal_txt.GetLabel()))
+                    doc_cons.setAttribute('constraint',
+                                          str(constraint.constraint.GetValue()))
+                    constraints.appendChild(doc_cons)
+
+            # Save all models
+            models = newdoc.createElement('model_list')
+            batch_combo.appendChild(models)
+            for model in batch_fit_state.model_list:
+                doc_model = newdoc.createElement('model_list_item')
+                doc_model.setAttribute('checked', str(model[0].GetValue()))
+                keys = list(model[1].keys())
+                doc_model.setAttribute('name', str(keys[0]))
+                values = model[1].get(keys[0])
+                doc_model.setAttribute('fit_number', str(model[2]))
+                doc_model.setAttribute('fit_page_source', str(model[3]))
+                doc_model.setAttribute('model_name', str(values.model.id))
+                models.appendChild(doc_model)
+
+            # Select All Checkbox
+            element = newdoc.createElement('select_all')
+            if batch_fit_state.select_all:
+                element.setAttribute('checked', 'True')
+            else:
+                element.setAttribute('checked', 'False')
+            batch_combo.appendChild(element)
+
+        # Save the file
+        if doc is None:
+            fd = open(file, 'w')
+            fd.write(newdoc.toprettyxml())
+            fd.close()
+            return None
+        else:
+            return newdoc
+
+    def _from_xml_helper(self, node, list):
+        """
+        Helper function to read the saved state back from xml nodes
+        """
+        for item in node:
+            name = item.get('name')
+            value = item.get('value')
+            selected_to_fit = (item.get('selected_to_fit') == "True")
+            error_displayed = (item.get('error_displayed') == "True")
+            error_value = item.get('error_value')
+            minimum_displayed = (item.get('minimum_displayed') == "True")
+            minimum_value = item.get('minimum_value')
+            maximum_displayed = (item.get('maximum_displayed') == "True")
+            maximum_value = item.get('maximum_value')
+            unit = item.get('unit')
+            list.append([selected_to_fit, name, value, "+/-",
+                         [error_displayed, error_value],
+                         [minimum_displayed, minimum_value],
+                         [maximum_displayed, maximum_value], unit])
+
+    def from_xml(self, file=None, node=None):
+        """
+        Load fitting state from a file
+
+        :param file: .fitv file
+        :param node: node of a XML document to read from
+        """
+        if file is not None:
+            msg = "PageState no longer supports non-CanSAS"
+            msg += " format for fitting files"
+            raise RuntimeError(msg)
+
+        if node.get('version'):
+            # Get the version for model conversion purposes
+            x = re.sub(r'[^\d.]', '', node.get('version'))
+            self.version = tuple(int(e) for e in x.split("."))
+            # The tuple must be at least 3 items long
+            while len(self.version) < 3:
+                ver_list = list(self.version)
+                ver_list.append(0)
+                self.version = tuple(ver_list)
+
+        # Get file name
+        entry = get_content('ns:filename', node)
+        if entry is not None and entry.text:
+            self.file = entry.text.strip()
+        else:
+            self.file = ''
+
+        # Get time stamp
+        entry = get_content('ns:timestamp', node)
+        if entry is not None and entry.get('epoch'):
+            try:
+                self.timestamp = float(entry.get('epoch'))
+            except Exception:
+                msg = "PageState.fromXML: Could not"
+                msg += " read timestamp\n %s" % sys.exc_info()[1]
+                logger.error(msg)
+
+        if entry is not None:
+            # Parse fitting attributes
+            entry = get_content('ns:Attributes', node)
+            for item in LIST_OF_DATA_ATTRIBUTES:
+                node = get_content('ns:%s' % item[0], entry)
+                setattr(self, item[0], parse_entry_helper(node, item))
+
+            dx_old_node = get_content('ns:%s' % 'dx_min', entry)
+            for item in LIST_OF_STATE_ATTRIBUTES:
+                if item[0] == "dx_percent" and dx_old_node is not None:
+                    dxmin = ["dx_min", "dx_min", "float"]
+                    setattr(self, item[0], parse_entry_helper(dx_old_node,
+                                                              dxmin))
+                    self.dx_old = True
+                else:
+                    node = get_content('ns:%s' % item[0], entry)
+                    setattr(self, item[0], parse_entry_helper(node, item))
+
+            for item in LIST_OF_STATE_PARAMETERS:
+                node = get_content("ns:%s" % item[0], entry)
+                self._from_xml_helper(node=node,
+                                      list=getattr(self, item[1]))
+
+            # Recover disp_obj_dict from xml file
+            self.disp_obj_dict = {}
+            for tagname, varname, tagtype in DISPERSION_LIST:
+                node = get_content("ns:%s" % tagname, entry)
+                for attr in node:
+                    parameter = str(attr.get('name'))
+                    value = attr.get('value')
+                    if value.startswith("<"):
+                        try:
+                            # value is the repr of a dispersion class,
+                            # e.g. "<class 'sasmodels.weights....'>"
+                            cls_name = value[1:].split()[0].split('.')[-1]
+                            cls = getattr(sasmodels.weights, cls_name)
+                            value = cls.type
+                        except Exception:
+                            base = "unable to load distribution %r for %s"
+                            logger.error(base, value, parameter)
+                            continue
+                    disp_obj_dict = getattr(self, varname)
+                    disp_obj_dict[parameter] = value
+
+            # Recover self.values and self.weights dictionaries, if they exist
+            for tagname, varname in LIST_OF_MODEL_ATTRIBUTES:
+                node = get_content("ns:%s" % tagname, entry)
+                dic = {}
+                for par in node:
+                    name = par.get('name')
+                    values = par.text.split()
+                    # Keep only the lines containing numbers
+                    value_list = []
+                    for line in values:
+                        try:
+                            val = float(line)
+                            value_list.append(val)
+                        except Exception:
+                            # pass if line is empty (it happens)
+                            msg = ("Error reading %r from %s %s\n"
+                                   % (line, tagname, name))
+                            logger.error(msg + traceback.format_exc())
+                    dic[name] = np.array(value_list)
+                setattr(self, varname, dic)
+
+class SimFitPageState(object):
+    """
+    State of the simultaneous fit page for saving purposes
+    """
+
+    def __init__(self):
+        # Sim Fit Page Number
+        self.fit_page_no = None
+        # Select all data
+        self.select_all = False
+        # Data sets sent to fit page
+        self.model_list = []
+        # Data sets to be fit
+        self.model_to_fit = []
+        # Number of constraints
+        self.no_constraint = 0
+        # Dictionary of constraints
+        self.constraint_dict = {}
+        # List of constraints
+        self.constraints_list = []
+
+    def __repr__(self):
+        # TODO: should use __str__, not __repr__ (similarly for PageState)
+        # TODO: could use a nicer representation
+        repr = """\
+fit page number : %(fit_page_no)s
+select all : %(select_all)s
+model_list : %(model_list)s
+model to fit : %(model_to_fit)s
+number of constraints : %(no_constraint)s
+constraint dict : %(constraint_dict)s
+constraints list : %(constraints_list)s
+"""%self.__dict__
+        return repr
+
+class Reader(CansasReader):
+    """
+    Class to load a .fitv fitting file
+    """
+    # File type
+    type_name = "Fitting"
+
+    # Wildcards
+    type = ["Fitting files (*.fitv)|*.fitv",
+            "SASView file (*.svs)|*.svs"]
+    # List of allowed extensions
+    ext = ['.fitv', '.FITV', '.svs', '.SVS']
+
+    def __init__(self, call_back=None, cansas=True):
+        """
+        Initialize the call-back method to be called
+        after we load a file
+
+        :param call_back: call-back method
+        :param cansas: True = files will be written/read in CanSAS format,
+            False = the fit state is written directly (non-CanSAS)
+
+        """
+        CansasReader.__init__(self)
+        # Call back method to be executed after a file is read
+        self.call_back = call_back
+        # CanSAS format flag
+        self.cansas = cansas
+        self.state = None
+        # batch fitting params for saving
+        self.batchfit_params = []
+
+    def get_state(self):
+        return self.state
+
+    def read(self, path):
+        """
+        Load a new fitting state from file
+
+        :param path: file path
+
+        """
+        if self.cansas:
+            return self._read_cansas(path)
+
+    def _parse_state(self, entry):
+        """
+        Read a fit result from an XML node
+
+        :param entry: XML node to read from
+        :return: PageState object
+        """
+        # Create an empty state
+        state = None
+        # Locate the fitting node
+        try:
+            nodes = entry.xpath('ns:%s' % FITTING_NODE_NAME,
+                                namespaces=CANSAS_NS)
+            if nodes:
+                # Create an empty state
+                state = PageState()
+                state.from_xml(node=nodes[0])
+
+        except Exception:
+            logger.info("XML document does not contain fitting information.\n"
+                        + traceback.format_exc())
+
+        return state
+
+    def _parse_simfit_state(self, entry):
+        """
+        Parses the saved data for a simultaneous fit
+        :param entry: XML object to read from
+        :return: SimFitPageState object for a simultaneous fit, or None
+        """
+        nodes = entry.xpath('ns:%s' % FITTING_NODE_NAME,
+                            namespaces=CANSAS_NS)
+        if nodes:
+            simfitstate = nodes[0].xpath('ns:simultaneous_fit',
+                                         namespaces=CANSAS_NS)
+            if simfitstate:
+                sim_fit_state = SimFitPageState()
+                simfitstate_0 = simfitstate[0]
+                all = simfitstate_0.xpath('ns:select_all',
+                                          namespaces=CANSAS_NS)
+                atts = all[0].attrib
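+                # 'checked' was serialized by to_xml as the string
+                # 'True'/'False', so compare the text rather than using
+                # bool(), which is True for any non-empty string.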
+                checked = atts.get('checked')
+                sim_fit_state.select_all = (checked == 'True')
+                model_list = simfitstate_0.xpath('ns:model_list',
+                                                 namespaces=CANSAS_NS)
+                model_list_items = model_list[0].xpath('ns:model_list_item',
+                                                       namespaces=CANSAS_NS)
+                for model in model_list_items:
+                    attrs = model.attrib
+                    sim_fit_state.model_list.append(attrs)
+
+                constraints = simfitstate_0.xpath('ns:constraints',
+                                                  namespaces=CANSAS_NS)
+                constraint_list = constraints[0].xpath('ns:constraint',
+                                                       namespaces=CANSAS_NS)
+                for constraint in constraint_list:
+                    attrs = constraint.attrib
+                    sim_fit_state.constraints_list.append(attrs)
+
+                return sim_fit_state
+        else:
+            return None
+
+    def _parse_save_state_entry(self, dom):
+        """
+        Parse a SASentry
+
+        :param dom: SASentry node
+
+        :return: Data1D/Data2D object
+
+        """
+        return self._parse_entry(dom)
+
+    def _read_cansas(self, path):
+        """
+        Load data and fitting information from a CanSAS XML file.
+
+        :param path: file path
+        :return: Data1D object if a single SASentry was found,
+            or a list of Data1D objects if multiple entries were found,
+            or None if nothing was found
+        :raise RuntimeError: when the file can't be opened
+        :raise ValueError: when the lengths of the data vectors are inconsistent
+        """
+        output = []
+        simfitstate = None
+        basename = os.path.basename(path)
+        root, extension = os.path.splitext(basename)
+        ext = extension.lower()
+        try:
+            if os.path.isfile(path):
+                if ext in self.ext or ext == '.xml':
+                    tree = etree.parse(path, parser=etree.ETCompatXMLParser())
+                    # Check the format version number
+                    # Specifying the namespace will take care of the file
+                    # format version
+                    root = tree.getroot()
+                    entry_list = root.xpath('ns:SASentry',
+                                            namespaces=CANSAS_NS)
+                    for entry in entry_list:
+                        fitstate = self._parse_state(entry)
+                        # state could be None when .svs file is loaded
+                        # in this case, skip appending to output
+                        if fitstate is not None:
+                            try:
+                                sas_entry, _ = self._parse_save_state_entry(
+                                    entry)
+                            except:
+                                raise
+                            sas_entry.meta_data['fitstate'] = fitstate
+                            sas_entry.filename = fitstate.file
+                            output.append(sas_entry)
+
+            else:
+                self.call_back(format=ext)
+                raise RuntimeError("%s is not a file" % path)
+
+            # Return output consistent with the loader's api
+            if len(output) == 0:
+                self.call_back(state=None, datainfo=None, format=ext)
+                return None
+            else:
+                for data in output:
+                    # Call back to post the new state
+                    state = data.meta_data['fitstate']
+                    t = time.localtime(state.timestamp)
+                    time_str = time.strftime("%b %d %H:%M", t)
+                    # Check that no time stamp is already appended
+                    max_char = state.file.find("[")
+                    if max_char < 0:
+                        max_char = len(state.file)
+                    original_fname = state.file[0:max_char]
+                    state.file = original_fname + ' [' + time_str + ']'
+
+                    if state is not None and state.is_data is not None:
+                        data.is_data = state.is_data
+
+                    data.filename = state.file
+                    state.data = data
+                    state.data.name = data.filename  # state.data_name
+                    state.data.id = state.data_id
+                    if state.is_data is not None:
+                        state.data.is_data = state.is_data
+                    if data.run_name is not None and len(data.run_name) != 0:
+                        if isinstance(data.run_name, dict):
+                            # Note: key order in dict is not guaranteed
+                            name = list(data.run_name.keys())[0]
+                        else:
+                            name = data.run_name
+                    else:
+                        name = original_fname
+                    state.data.group_id = name
+                    state.version = fitstate.version
+                    # store state in fitting
+                    self.call_back(state=state, datainfo=data, format=ext)
+                    self.state = state
+                    simfitstate = self._parse_simfit_state(entry)
self._parse_simfit_state(entry) + if simfitstate is not None: + self.call_back(state=simfitstate) + + return output + except: + self.call_back(format=ext) + raise + + def write(self, filename, datainfo=None, fitstate=None): + """ + Write the content of a Data1D as a CanSAS XML file only for standalone + + :param filename: name of the file to write + :param datainfo: Data1D object + :param fitstate: PageState object + + """ + # Sanity check + if self.cansas: + # Add fitting information to the XML document + doc = self.write_toXML(datainfo, fitstate) + # Write the XML document + else: + doc = fitstate.to_xml(file=filename) + + # Save the document no matter the type + fd = open(filename, 'w') + fd.write(doc.toprettyxml()) + fd.close() + + def write_toXML(self, datainfo=None, state=None, batchfit=None): + """ + Write toXML, a helper for write(), + could be used by guimanager._on_save() + + : return: xml doc + """ + + self.batchfit_params = batchfit + if state.data is None or not state.data.is_data: + return None + # make sure title and data run are filled. + if state.data.title is None or state.data.title == '': + state.data.title = state.data.name + if state.data.run_name is None or state.data.run_name == {}: + state.data.run = [str(state.data.name)] + state.data.run_name[0] = state.data.name + + data = state.data + doc, sasentry = self._to_xml_doc(data) + + if state is not None: + doc = state.to_xml(doc=doc, file=data.filename, entry_node=sasentry, + batch_fit_state=self.batchfit_params) + + return doc + +# Simple html report template +HEADER = "\n" +HEADER += "\n" +HEADER += " \n" +HEADER += "\n" +HEADER += "\n" +HEADER += "\n" +HEADER += "
+HEADER += "<p align='center'><b>%s</b></p>\n"
+HEADER += "<p>&nbsp;</p>\n"
+PARA = "<p>%s</p>\n"
+CENTRE = "<p align='center'>%s</p>\n"
+FEET_1 = \
+"""
+<p>&nbsp;</p>
+<p align='center'><b><i>Graph</i></b></p>
+<p>&nbsp;</p>
+<p align='center'><b>Model Computation</b></p>
+<p align='center'>Data: "%s"</p>
+"""
+FEET_2 = \
+"""<img src="%s">
+"""
+FEET_2_unix = \
+"""<img src="%s">
+"""
+FEET_3 = \
+"""</body>
+</html>
+"""
+ELINE = "<p>&nbsp;</p>\n"
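The snippet below is an illustrative sketch, not part of this changeset: it
shows how the report template strings defined above might be assembled into a
page. The title, summary text, image path and output file name are
placeholder values; only the data-file name comes from this patch.

    report = HEADER % "Fit Report"
    report += PARA % "Fitting parameters and goodness-of-fit summary."
    report += FEET_1 % "cyl_400_20.txt"
    report += FEET_2 % "fit_graph.png"
    report += FEET_3
    with open("report.html", "w") as handle:
        handle.write(report)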
diff --git a/sas/sascalc/fit/pluginmodel.py b/sas/sascalc/fit/pluginmodel.py
new file mode 100755
index 000000000..b547d0606
--- /dev/null
+++ b/sas/sascalc/fit/pluginmodel.py
@@ -0,0 +1,66 @@
+from sas.sascalc.calculator.BaseComponent import BaseComponent
+import math
+
+class Model1DPlugin(BaseComponent):
+    is_multiplicity_model = False
+
+    ## Name of the model
+    def __init__(self, name="Plugin Model"):
+        """ Initialization """
+        BaseComponent.__init__(self)
+        self.name = name
+        self.details = {}
+        self.params = {}
+        self.description = 'plugin model'
+
+    def function(self, x):
+        """
+        Function to be implemented by the plug-in writer
+        """
+        return x
+
+    def run(self, x=0.0):
+        """
+        Evaluate the model
+
+        :param x: input x, or [x, phi] [radian]
+
+        :return: function value
+
+        """
+        if isinstance(x, list):
+            x_val = x[0] * math.cos(x[1])
+            y_val = x[0] * math.sin(x[1])
+            return self.function(x_val) * self.function(y_val)
+        elif isinstance(x, tuple):
+            raise ValueError("Tuples are not allowed as input to BaseComponent models")
+        else:
+            return self.function(x)
+
+    def runXY(self, x=0.0):
+        """
+        Evaluate the model
+
+        :param x: input x, or [x, y]
+
+        :return: function value
+
+        """
+        if isinstance(x, list):
+            return self.function(x[0]) * self.function(x[1])
+        elif isinstance(x, tuple):
+            raise ValueError("Tuples are not allowed as input to BaseComponent models")
+        else:
+            return self.function(x)
+
+    def set_details(self):
+        """
+        Set default details
+        """
+        if not self.params:
+            return {}
+
+        for key in self.params.keys():
+            self.details[key] = ['', None, None]
\ No newline at end of file
diff --git a/sas/sascalc/fit/qsmearing.py b/sas/sascalc/fit/qsmearing.py
new file mode 100755
index 000000000..321a4b55c
--- /dev/null
+++ b/sas/sascalc/fit/qsmearing.py
@@ -0,0 +1,254 @@
+"""
+    Handle Q smearing
+"""
+#####################################################################
+#This software was developed by the University of Tennessee as part of the
+#Distributed Data Analysis of Neutron Scattering Experiments (DANSE)
+#project funded by the US National Science Foundation.
+#See the license text in license.txt
+#copyright 2008, University of Tennessee
+######################################################################
+import math
+import logging
+import sys
+
+import numpy as np  # type: ignore
+from numpy import pi, exp  # type: ignore
+
+from sasmodels.resolution import Slit1D, Pinhole1D
+from sasmodels.sesans import SesansTransform
+from sasmodels.resolution2d import Pinhole2D
+
+from sas.sascalc.data_util.nxsunit import Converter
+
+def smear_selection(data, model=None):
+    """
+    Creates the right type of smearer according
+    to the data.
+    The canSAS format has a rule that either
+    slit smearing data OR resolution smearing data
+    is available.
+
+    For the present purpose, we choose the one that
+    has non-zero data. If both slit and resolution
+    smearing arrays are filled with good data
+    (which should not happen), then we choose the
+    resolution smearing data.
+
+    :param data: Data1D object
+    :param model: sas.model instance
+    """
+    # Sanity check.
If we are not dealing with a SAS Data1D + # object, just return None + # This checks for 2D data (does not throw exception because fail is common) + if data.__class__.__name__ not in ['Data1D', 'Theory1D']: + if data is None: + return None + elif data.dqx_data is None or data.dqy_data is None: + return None + return PySmear2D(data) + # This checks for 1D data with smearing info in the data itself (again, fail is likely; no exceptions) + if not hasattr(data, "dx") and not hasattr(data, "dxl")\ + and not hasattr(data, "dxw"): + return None + + # Look for resolution smearing data + # This is the code that checks for SESANS data; it looks for the file loader + # TODO: change other sanity checks to check for file loader instead of data structure? + _found_sesans = False + #if data.dx is not None and data.meta_data['loader']=='SESANS': + if data.dx is not None and data.isSesans: + #if data.dx[0] > 0.0: + if np.size(data.dx[data.dx <= 0]) == 0: + _found_sesans = True + # if data.dx[0] <= 0.0: + if np.size(data.dx[data.dx <= 0]) > 0: + raise ValueError('one or more of your dx values are negative, please check the data file!') + + if _found_sesans: + # Pre-compute the Hankel matrix (H) + SElength = Converter(data._xunit)(data.x, "A") + + theta_max = Converter("radians")(data.sample.zacceptance)[0] + q_max = 2 * np.pi / np.max(data.source.wavelength) * np.sin(theta_max) + zaccept = Converter("1/A")(q_max, "1/" + data.source.wavelength_unit), + + Rmax = 10000000 + hankel = SesansTransform(data.x, SElength, + data.source.wavelength, + zaccept, Rmax) + # Then return the actual transform, as if it were a smearing function + return PySmear(hankel, model, offset=0) + + _found_resolution = False + if data.dx is not None and len(data.dx) == len(data.x): + + # Check that we have non-zero data + if data.dx[0] > 0.0: + _found_resolution = True + #print "_found_resolution",_found_resolution + #print "data1D.dx[0]",data1D.dx[0],data1D.dxl[0] + # If we found resolution smearing data, return a QSmearer + if _found_resolution: + return pinhole_smear(data, model) + + # Look for slit smearing data + _found_slit = False + if data.dxl is not None and len(data.dxl) == len(data.x) \ + and data.dxw is not None and len(data.dxw) == len(data.x): + + # Check that we have non-zero data + if data.dxl[0] > 0.0 or data.dxw[0] > 0.0: + _found_slit = True + + # Sanity check: all data should be the same as a function of Q + for item in data.dxl: + if data.dxl[0] != item: + _found_resolution = False + break + + for item in data.dxw: + if data.dxw[0] != item: + _found_resolution = False + break + # If we found slit smearing data, return a slit smearer + if _found_slit: + return slit_smear(data, model) + return None + + +class PySmear(object): + """ + Wrapper for pure python sasmodels resolution functions. + """ + def __init__(self, resolution, model, offset=None): + self.model = model + self.resolution = resolution + if offset is None: + offset = np.searchsorted(self.resolution.q_calc, self.resolution.q[0]) + self.offset = offset + + def apply(self, iq_in, first_bin=0, last_bin=None): + """ + Apply the resolution function to the data. + Note that this is called with iq_in matching data.x, but with + iq_in[first_bin:last_bin] set to theory values for these bins, + and the remainder left undefined. The first_bin, last_bin values + should be those returned from get_bin_range. + The returned value is of the same length as iq_in, with the range + first_bin:last_bin set to the resolution smeared values. 
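+
+        A minimal usage sketch (`data`, `model`, and the q-range are
+        assumed to be set up elsewhere; illustrative only)::
+
+            smearer = smear_selection(data, model)
+            # smear_selection returns None when no resolution info is found
+            first, last = smearer.get_bin_range(q_min=0.01, q_max=0.4)
+            iq_in = np.zeros_like(data.x)
+            iq_in[first:last + 1] = model.evalDistribution(data.x[first:last + 1])
+            smeared = smearer.apply(iq_in, first, last)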
+ """ + if last_bin is None: last_bin = len(iq_in) + start, end = first_bin + self.offset, last_bin + self.offset + q_calc = self.resolution.q_calc + iq_calc = np.empty_like(q_calc) + if start > 0: + iq_calc[:start] = self.model.evalDistribution(q_calc[:start]) + if end+1 < len(q_calc): + iq_calc[end+1:] = self.model.evalDistribution(q_calc[end+1:]) + iq_calc[start:end+1] = iq_in[first_bin:last_bin+1] + smeared = self.resolution.apply(iq_calc) + return smeared + __call__ = apply + + def get_bin_range(self, q_min=None, q_max=None): + """ + For a given q_min, q_max, find the corresponding indices in the data. + Returns first, last. + Note that these are indexes into q from the data, not the q_calc + needed by the resolution function. Note also that these are the + indices, not the range limits. That is, the complete range will be + q[first:last+1]. + """ + q = self.resolution.q + first = np.searchsorted(q, q_min) + last = np.searchsorted(q, q_max) + return first, min(last,len(q)-1) + +def slit_smear(data, model=None): + q = data.x + width = data.dxw if data.dxw is not None else 0 + height = data.dxl if data.dxl is not None else 0 + # TODO: width and height seem to be reversed + return PySmear(Slit1D(q, height, width), model) + +def pinhole_smear(data, model=None): + q = data.x + width = data.dx if data.dx is not None else 0 + return PySmear(Pinhole1D(q, width), model) + + +class PySmear2D(object): + """ + Q smearing class for SAS 2d pinhole data + """ + + def __init__(self, data=None, model=None): + self.data = data + self.model = model + self.accuracy = 'Low' + self.limit = 3.0 + self.index = None + self.coords = 'polar' + self.smearer = True + + def set_accuracy(self, accuracy='Low'): + """ + Set accuracy. + + :param accuracy: string + """ + self.accuracy = accuracy + + def set_smearer(self, smearer=True): + """ + Set whether or not smearer will be used + + :param smearer: smear object + + """ + self.smearer = smearer + + def set_data(self, data=None): + """ + Set data. + + :param data: DataLoader.Data_info type + """ + self.data = data + + def set_model(self, model=None): + """ + Set model. + + :param model: sas.models instance + """ + self.model = model + + def set_index(self, index=None): + """ + Set index. + + :param index: 1d arrays + """ + self.index = index + + def get_value(self): + """ + Over sampling of r_nbins times phi_nbins, calculate Gaussian weights, + then find smeared intensity + """ + if self.smearer: + res = Pinhole2D(data=self.data, index=self.index, + nsigma=3.0, accuracy=self.accuracy, + coords=self.coords) + val = self.model.evalDistribution(res.q_calc) + return res.apply(val) + else: + index = self.index if self.index is not None else slice(None) + qx_data = self.data.qx_data[index] + qy_data = self.data.qy_data[index] + q_calc = [qx_data, qy_data] + val = self.model.evalDistribution(q_calc) + return val + diff --git a/sas/sascalc/invariant/__init__.py b/sas/sascalc/invariant/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/sas/sascalc/invariant/invariant.py b/sas/sascalc/invariant/invariant.py new file mode 100755 index 000000000..0a9a99267 --- /dev/null +++ b/sas/sascalc/invariant/invariant.py @@ -0,0 +1,962 @@ +# pylint: disable=invalid-name +##################################################################### +#This software was developed by the University of Tennessee as part of the +#Distributed Data Analysis of Neutron Scattering Experiments (DANSE) +#project funded by the US National Science Foundation. 
+#See the license text in license.txt +#copyright 2010, University of Tennessee +###################################################################### + +""" +This module implements invariant and its related computations. + +:author: Gervaise B. Alina/UTK +:author: Mathieu Doucet/UTK +:author: Jae Cho/UTK + +""" +import math +import numpy as np + +from sas.sascalc.dataloader.data_info import Data1D as LoaderData1D + +# The minimum q-value to be used when extrapolating +Q_MINIMUM = 1e-5 + +# The maximum q-value to be used when extrapolating +Q_MAXIMUM = 10 + +# Number of steps in the extrapolation +INTEGRATION_NSTEPS = 1000 + +class Transform(object): + """ + Define interface that need to compute a function or an inverse + function given some x, y + """ + + def linearize_data(self, data): + """ + Linearize data so that a linear fit can be performed. + Filter out the data that can't be transformed. + + :param data: LoadData1D instance + + """ + # Check that the vector lengths are equal + assert len(data.x) == len(data.y) + if data.dy is not None: + assert len(data.x) == len(data.dy) + dy = data.dy + else: + dy = np.ones(len(data.y)) + + # Transform the data + data_points = zip(data.x, data.y, dy) + + output_points = [(self.linearize_q_value(p[0]), + math.log(p[1]), + p[2] / p[1]) for p in data_points if p[0] > 0 and \ + p[1] > 0 and p[2] > 0] + + x_out, y_out, dy_out = zip(*output_points) + + # Create Data1D object + x_out = np.asarray(x_out) + y_out = np.asarray(y_out) + dy_out = np.asarray(dy_out) + linear_data = LoaderData1D(x=x_out, y=y_out, dy=dy_out) + + return linear_data + + def get_allowed_bins(self, data): + """ + Goes through the data points and returns a list of boolean values + to indicate whether each points is allowed by the model or not. + + :param data: Data1D object + """ + return [p[0] > 0 and p[1] > 0 and p[2] > 0 for p in zip(data.x, data.y, + data.dy)] + + def linearize_q_value(self, value): + """ + Transform the input q-value for linearization + """ + return NotImplemented + + def extract_model_parameters(self, constant, slope, dconstant=0, dslope=0): + """ + set private member + """ + return NotImplemented + + def evaluate_model(self, x): + """ + Returns an array f(x) values where f is the Transform function. 
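+        Subclasses supply the concrete form; for example, Guinier
+        (below) returns scale * exp(-(radius*x)**2 / 3) for each x.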
+ """ + return NotImplemented + + def evaluate_model_errors(self, x): + """ + Returns an array of I(q) errors + """ + return NotImplemented + +class Guinier(Transform): + """ + class of type Transform that performs operations related to guinier + function + """ + def __init__(self, scale=1, radius=60): + Transform.__init__(self) + self.scale = scale + self.radius = radius + ## Uncertainty of scale parameter + self.dscale = 0 + ## Unvertainty of radius parameter + self.dradius = 0 + + def linearize_q_value(self, value): + """ + Transform the input q-value for linearization + + :param value: q-value + + :return: q*q + """ + return value * value + + def extract_model_parameters(self, constant, slope, dconstant=0, dslope=0): + """ + assign new value to the scale and the radius + """ + self.scale = math.exp(constant) + if slope > 0: + slope = 0.0 + self.radius = math.sqrt(-3 * slope) + # Errors + self.dscale = math.exp(constant) * dconstant + if slope == 0.0: + n_zero = -1.0e-24 + self.dradius = -3.0 / 2.0 / math.sqrt(-3 * n_zero) * dslope + else: + self.dradius = -3.0 / 2.0 / math.sqrt(-3 * slope) * dslope + + return [self.radius, self.scale], [self.dradius, self.dscale] + + def evaluate_model(self, x): + """ + return F(x)= scale* e-((radius*x)**2/3) + """ + return self._guinier(x) + + def evaluate_model_errors(self, x): + """ + Returns the error on I(q) for the given array of q-values + + :param x: array of q-values + """ + p1 = np.array([self.dscale * math.exp(-((self.radius * q) ** 2 / 3)) \ + for q in x]) + p2 = np.array([self.scale * math.exp(-((self.radius * q) ** 2 / 3))\ + * (-(q ** 2 / 3)) * 2 * self.radius * self.dradius for q in x]) + diq2 = p1 * p1 + p2 * p2 + return np.array([math.sqrt(err) for err in diq2]) + + def _guinier(self, x): + """ + Retrieve the guinier function after apply an inverse guinier function + to x + Compute a F(x) = scale* e-((radius*x)**2/3). 
+ + :param x: a vector of q values + :param scale: the scale value + :param radius: the guinier radius value + + :return: F(x) + """ + # transform the radius of coming from the inverse guinier function to a + # a radius of a guinier function + if self.radius <= 0: + msg = "Rg expected positive value, but got %s" % self.radius + raise ValueError(msg) + value = np.array([math.exp(-((self.radius * i) ** 2 / 3)) for i in x]) + return self.scale * value + +class PowerLaw(Transform): + """ + class of type transform that perform operation related to power_law + function + """ + def __init__(self, scale=1, power=4): + Transform.__init__(self) + self.scale = scale + self.power = power + self.dscale = 0.0 + self.dpower = 0.0 + + def linearize_q_value(self, value): + """ + Transform the input q-value for linearization + + :param value: q-value + + :return: log(q) + """ + return math.log(value) + + def extract_model_parameters(self, constant, slope, dconstant=0, dslope=0): + """ + Assign new value to the scale and the power + """ + self.power = -slope + self.scale = math.exp(constant) + + # Errors + self.dscale = math.exp(constant) * dconstant + self.dpower = -dslope + + return [self.power, self.scale], [self.dpower, self.dscale] + + def evaluate_model(self, x): + """ + given a scale and a radius transform x, y using a power_law + function + """ + return self._power_law(x) + + def evaluate_model_errors(self, x): + """ + Returns the error on I(q) for the given array of q-values + :param x: array of q-values + """ + p1 = np.array([self.dscale * math.pow(q, -self.power) for q in x]) + p2 = np.array([self.scale * self.power * math.pow(q, -self.power - 1)\ + * self.dpower for q in x]) + diq2 = p1 * p1 + p2 * p2 + return np.array([math.sqrt(err) for err in diq2]) + + def _power_law(self, x): + """ + F(x) = scale* (x)^(-power) + when power= 4. the model is porod + else power_law + The model has three parameters: :: + 1. x: a vector of q values + 2. power: power of the function + 3. scale : scale factor value + + :param x: array + :return: F(x) + """ + if self.power <= 0: + msg = "Power_law function expected positive power," + msg += " but got %s" % self.power + raise ValueError(msg) + if self.scale <= 0: + msg = "scale expected positive value, but got %s" % self.scale + raise ValueError(msg) + + value = np.array([math.pow(i, -self.power) for i in x]) + return self.scale * value + +class Extrapolator(object): + """ + Extrapolate I(q) distribution using a given model + """ + def __init__(self, data, model=None): + """ + Determine a and b given a linear equation y = ax + b + + If a model is given, it will be used to linearize the data before + the extrapolation is performed. If None, + a simple linear fit will be done. 
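+
+        For example, with a Guinier transform the fit is performed on
+        (q*q, log I), so the fitted slope corresponds to -radius**2/3
+        (sketch only; assumes `data` is a loaded Data1D)::
+
+            extrapolator = Extrapolator(data, model=Guinier())
+            (slope, constant), errors = extrapolator.fit(qmin=0.001, qmax=0.02)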
+ + :param data: data containing x and y such as y = ax + b + :param model: optional Transform object + """ + self.data = data + self.model = model + + # Set qmin as the lowest non-zero value + self.qmin = Q_MINIMUM + for q_value in self.data.x: + if q_value > 0: + self.qmin = q_value + break + self.qmax = max(self.data.x) + + def fit(self, power=None, qmin=None, qmax=None): + """ + Fit data for y = ax + b return a and b + + :param power: a fixed, otherwise None + :param qmin: Minimum Q-value + :param qmax: Maximum Q-value + """ + if qmin is None: + qmin = self.qmin + if qmax is None: + qmax = self.qmax + + # Identify the bin range for the fit + idx = (self.data.x >= qmin) & (self.data.x <= qmax) + + fx = np.zeros(len(self.data.x)) + + # Uncertainty + if type(self.data.dy) == np.ndarray and \ + len(self.data.dy) == len(self.data.x) and \ + np.all(self.data.dy > 0): + sigma = self.data.dy + else: + sigma = np.ones(len(self.data.x)) + + # Compute theory data f(x) + fx[idx] = self.data.y[idx] + + # Linearize the data + if self.model is not None: + linearized_data = self.model.linearize_data(\ + LoaderData1D(self.data.x[idx], + fx[idx], + dy=sigma[idx])) + else: + linearized_data = LoaderData1D(self.data.x[idx], + fx[idx], + dy=sigma[idx]) + + ##power is given only for function = power_law + if power is not None: + sigma2 = linearized_data.dy * linearized_data.dy + a = -(power) + b = (np.sum(linearized_data.y / sigma2) \ + - a * np.sum(linearized_data.x / sigma2)) / np.sum(1.0 / sigma2) + + + deltas = linearized_data.x * a + \ + np.ones(len(linearized_data.x)) * b - linearized_data.y + residuals = np.sum(deltas * deltas / sigma2) + + err = math.fabs(residuals) / np.sum(1.0 / sigma2) + return [a, b], [0, math.sqrt(err)] + else: + A = np.vstack([linearized_data.x / linearized_data.dy, 1.0 / linearized_data.dy]).T + (p, residuals, _, _) = np.linalg.lstsq(A, linearized_data.y / linearized_data.dy) + + # Get the covariance matrix, defined as inv_cov = a_transposed * a + err = np.zeros(2) + try: + inv_cov = np.dot(A.transpose(), A) + cov = np.linalg.pinv(inv_cov) + err_matrix = math.fabs(residuals) * cov + err = [math.sqrt(err_matrix[0][0]), math.sqrt(err_matrix[1][1])] + except: + err = [-1.0, -1.0] + + return p, err + + +class InvariantCalculator(object): + """ + Compute invariant if data is given. + Can provide volume fraction and surface area if the user provides + Porod constant and contrast values. + + :precondition: the user must send a data of type DataLoader.Data1D + the user provide background and scale values. + + :note: Some computations depends on each others. + """ + def __init__(self, data, background=0, scale=1): + """ + Initialize variables. + + :param data: data must be of type DataLoader.Data1D + :param background: Background value. The data will be corrected + before processing + :param scale: Scaling factor for I(q). The data will be corrected + before processing + """ + # Background and scale should be private data member if the only way to + # change them are by instantiating a new object. + self._background = background + self._scale = scale + # slit height for smeared data + self._smeared = None + # The data should be private + self._data = self._get_data(data) + # get the dxl if the data is smeared: This is done only once on init. 
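+        # (dxl is the slit height dQl; when it is present, _get_qstar
+        # integrates dxl*x*y rather than x*x*y; see its docstring.)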
+ if self._data.dxl is not None and self._data.dxl.all() > 0: + # assumes constant dxl + self._smeared = self._data.dxl[0] + + # Since there are multiple variants of Q*, you should force the + # user to use the get method and keep Q* a private data member + self._qstar = None + + # You should keep the error on Q* so you can reuse it without + # recomputing the whole thing. + self._qstar_err = 0 + + # Extrapolation parameters + self._low_extrapolation_npts = 4 + self._low_extrapolation_function = Guinier() + self._low_extrapolation_power = None + self._low_extrapolation_power_fitted = None + + self._high_extrapolation_npts = 4 + self._high_extrapolation_function = PowerLaw() + self._high_extrapolation_power = None + self._high_extrapolation_power_fitted = None + + # Extrapolation range + self._low_q_limit = Q_MINIMUM + + def _get_data(self, data): + """ + :note: this function must be call before computing any type + of invariant + + :return: new data = self._scale *data - self._background + """ + if not issubclass(data.__class__, LoaderData1D): + #Process only data that inherited from DataLoader.Data_info.Data1D + raise ValueError("Data must be of type DataLoader.Data1D") + #from copy import deepcopy + new_data = (self._scale * data) - self._background + + # Check that the vector lengths are equal + assert len(new_data.x) == len(new_data.y) + + # Verify that the errors are set correctly + if new_data.dy is None or len(new_data.x) != len(new_data.dy) or \ + (min(new_data.dy) == 0 and max(new_data.dy) == 0): + new_data.dy = np.ones(len(new_data.x)) + return new_data + + def _fit(self, model, qmin=Q_MINIMUM, qmax=Q_MAXIMUM, power=None): + """ + fit data with function using + data = self._get_data() + fx = Functor(data , function) + y = data.y + slope, constant = linalg.lstsq(y,fx) + + :param qmin: data first q value to consider during the fit + :param qmax: data last q value to consider during the fit + :param power : power value to consider for power-law + :param function: the function to use during the fit + + :return a: the scale of the function + :return b: the other parameter of the function for guinier will be radius + for power_law will be the power value + """ + extrapolator = Extrapolator(data=self._data, model=model) + p, dp = extrapolator.fit(power=power, qmin=qmin, qmax=qmax) + + return model.extract_model_parameters(constant=p[1], slope=p[0], + dconstant=dp[1], dslope=dp[0]) + + def _get_qstar(self, data): + """ + Compute invariant for pinhole data. + This invariant is given by: :: + + q_star = x0**2 *y0 *dx0 +x1**2 *y1 *dx1 + + ..+ xn**2 *yn *dxn for non smeared data + + q_star = dxl0 *x0 *y0 *dx0 +dxl1 *x1 *y1 *dx1 + + ..+ dlxn *xn *yn *dxn for smeared data + + where n >= len(data.x)-1 + dxl = slit height dQl + dxi = 1/2*(xi+1 - xi) + (xi - xi-1) + dx0 = (x1 - x0)/2 + dxn = (xn - xn-1)/2 + + :param data: the data to use to compute invariant. + + :return q_star: invariant value for pinhole data. q_star > 0 + """ + if len(data.x) <= 1 or len(data.y) <= 1 or len(data.x) != len(data.y): + msg = "Length x and y must be equal" + msg += " and greater than 1; got x=%s, y=%s" % (len(data.x), len(data.y)) + raise ValueError(msg) + else: + # Take care of smeared data + if self._smeared is None: + gx = data.x * data.x + # assumes that len(x) == len(dxl). 
+ else: + gx = data.dxl * data.x + + n = len(data.x) - 1 + #compute the first delta q + dx0 = (data.x[1] - data.x[0]) / 2 + #compute the last delta q + dxn = (data.x[n] - data.x[n - 1]) / 2 + total = 0 + total += gx[0] * data.y[0] * dx0 + total += gx[n] * data.y[n] * dxn + + if len(data.x) == 2: + return total + else: + #iterate between for element different + #from the first and the last + for i in range(1, n - 1): + dxi = (data.x[i + 1] - data.x[i - 1]) / 2 + total += gx[i] * data.y[i] * dxi + return total + + def _get_qstar_uncertainty(self, data): + """ + Compute invariant uncertainty with with pinhole data. + This uncertainty is given as follow: :: + + dq_star = math.sqrt[(x0**2*(dy0)*dx0)**2 + + (x1**2 *(dy1)*dx1)**2 + ..+ (xn**2 *(dyn)*dxn)**2 ] + where n >= len(data.x)-1 + dxi = 1/2*(xi+1 - xi) + (xi - xi-1) + dx0 = (x1 - x0)/2 + dxn = (xn - xn-1)/2 + dyn: error on dy + + :param data: + :note: if data doesn't contain dy assume dy= math.sqrt(data.y) + """ + if len(data.x) <= 1 or len(data.y) <= 1 or \ + len(data.x) != len(data.y) or \ + (data.dy is not None and (len(data.dy) != len(data.y))): + msg = "Length of data.x and data.y must be equal" + msg += " and greater than 1; got x=%s, y=%s" % (len(data.x), len(data.y)) + raise ValueError(msg) + else: + #Create error for data without dy error + if data.dy is None: + dy = math.sqrt(data.y) + else: + dy = data.dy + # Take care of smeared data + if self._smeared is None: + gx = data.x * data.x + # assumes that len(x) == len(dxl). + else: + gx = data.dxl * data.x + + n = len(data.x) - 1 + #compute the first delta + dx0 = (data.x[1] - data.x[0]) / 2 + #compute the last delta + dxn = (data.x[n] - data.x[n - 1]) / 2 + total = 0 + total += (gx[0] * dy[0] * dx0) ** 2 + total += (gx[n] * dy[n] * dxn) ** 2 + if len(data.x) == 2: + return math.sqrt(total) + else: + #iterate between for element different + #from the first and the last + for i in range(1, n - 1): + dxi = (data.x[i + 1] - data.x[i - 1]) / 2 + total += (gx[i] * dy[i] * dxi) ** 2 + return math.sqrt(total) + + def _get_extrapolated_data(self, model, npts=INTEGRATION_NSTEPS, + q_start=Q_MINIMUM, q_end=Q_MAXIMUM): + """ + :return: extrapolate data create from data + """ + #create new Data1D to compute the invariant + q = np.linspace(start=q_start, + stop=q_end, + num=npts, + endpoint=True) + iq = model.evaluate_model(q) + diq = model.evaluate_model_errors(q) + + result_data = LoaderData1D(x=q, y=iq, dy=diq) + if self._smeared is not None: + result_data.dxl = self._smeared * np.ones(len(q)) + return result_data + + def get_data(self): + """ + :return: self._data + """ + return self._data + + def get_extrapolation_power(self, range='high'): + """ + :return: the fitted power for power law function for a given + extrapolation range + """ + if range == 'low': + return self._low_extrapolation_power_fitted + return self._high_extrapolation_power_fitted + + def get_qstar_low(self): + """ + Compute the invariant for extrapolated data at low q range. + + Implementation: + data = self._get_extra_data_low() + return self._get_qstar() + + :return q_star: the invariant for data extrapolated at low q. 
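+
+        A usage sketch (illustrative only)::
+
+            inv = InvariantCalculator(data)
+            inv.set_extrapolation('low', npts=4, function='guinier')
+            q_star_low, err_low = inv.get_qstar_low()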
+ """ + # Data boundaries for fitting + qmin = self._data.x[0] + qmax = self._data.x[int(self._low_extrapolation_npts - 1)] + + # Extrapolate the low-Q data + p, _ = self._fit(model=self._low_extrapolation_function, + qmin=qmin, + qmax=qmax, + power=self._low_extrapolation_power) + self._low_extrapolation_power_fitted = p[0] + + # Distribution starting point + self._low_q_limit = Q_MINIMUM + if Q_MINIMUM >= qmin: + self._low_q_limit = qmin / 10 + + data = self._get_extrapolated_data(\ + model=self._low_extrapolation_function, + npts=INTEGRATION_NSTEPS, + q_start=self._low_q_limit, q_end=qmin) + + # Systematic error + # If we have smearing, the shape of the I(q) distribution at low Q will + # may not be a Guinier or simple power law. The following is + # a conservative estimation for the systematic error. + err = qmin * qmin * math.fabs((qmin - self._low_q_limit) * \ + (data.y[0] - data.y[INTEGRATION_NSTEPS - 1])) + return self._get_qstar(data), self._get_qstar_uncertainty(data) + err + + def get_qstar_high(self): + """ + Compute the invariant for extrapolated data at high q range. + + Implementation: + data = self._get_extra_data_high() + return self._get_qstar() + + :return q_star: the invariant for data extrapolated at high q. + """ + # Data boundaries for fitting + x_len = len(self._data.x) - 1 + qmin = self._data.x[int(x_len - (self._high_extrapolation_npts - 1))] + qmax = self._data.x[int(x_len)] + + # fit the data with a model to get the appropriate parameters + p, _ = self._fit(model=self._high_extrapolation_function, + qmin=qmin, + qmax=qmax, + power=self._high_extrapolation_power) + self._high_extrapolation_power_fitted = p[0] + + #create new Data1D to compute the invariant + data = self._get_extrapolated_data(\ + model=self._high_extrapolation_function, + npts=INTEGRATION_NSTEPS, + q_start=qmax, q_end=Q_MAXIMUM) + + return self._get_qstar(data), self._get_qstar_uncertainty(data) + + def get_extra_data_low(self, npts_in=None, q_start=None, npts=20): + """ + Returns the extrapolated data used for the loew-Q invariant calculation. + By default, the distribution will cover the data points used for the + extrapolation. The number of overlap points is a parameter (npts_in). + By default, the maximum q-value of the distribution will be + the minimum q-value used when extrapolating for the purpose of the + invariant calculation. + + :param npts_in: number of data points for which + the extrapolated data overlap + :param q_start: is the minimum value to uses for extrapolated data + :param npts: the number of points in the extrapolated distribution + + """ + # Get extrapolation range + if q_start is None: + q_start = self._low_q_limit + + if npts_in is None: + npts_in = self._low_extrapolation_npts + q_end = self._data.x[max(0, int(npts_in - 1))] + + if q_start >= q_end: + return np.zeros(0), np.zeros(0) + + return self._get_extrapolated_data(\ + model=self._low_extrapolation_function, + npts=npts, + q_start=q_start, q_end=q_end) + + def get_extra_data_high(self, npts_in=None, q_end=Q_MAXIMUM, npts=20): + """ + Returns the extrapolated data used for the high-Q invariant calculation. + By default, the distribution will cover the data points used for the + extrapolation. The number of overlap points is a parameter (npts_in). + By default, the maximum q-value of the distribution will be Q_MAXIMUM, + the maximum q-value used when extrapolating for the purpose of the + invariant calculation. 
+ + :param npts_in: number of data points for which the + extrapolated data overlap + :param q_end: is the maximum value to uses for extrapolated data + :param npts: the number of points in the extrapolated distribution + """ + # Get extrapolation range + if npts_in is None: + npts_in = int(self._high_extrapolation_npts) + _npts = len(self._data.x) + q_start = self._data.x[min(_npts, int(_npts - npts_in))] + + if q_start >= q_end: + return np.zeros(0), np.zeros(0) + + return self._get_extrapolated_data(\ + model=self._high_extrapolation_function, + npts=npts, + q_start=q_start, q_end=q_end) + + def set_extrapolation(self, range, npts=4, function=None, power=None): + """ + Set the extrapolation parameters for the high or low Q-range. + Note that this does not turn extrapolation on or off. + + :param range: a keyword set the type of extrapolation . type string + :param npts: the numbers of q points of data to consider + for extrapolation + :param function: a keyword to select the function to use + for extrapolation. + of type string. + :param power: an power to apply power_low function + + """ + range = range.lower() + if range not in ['high', 'low']: + raise ValueError("Extrapolation range should be 'high' or 'low'") + function = function.lower() + if function not in ['power_law', 'guinier']: + msg = "Extrapolation function should be 'guinier' or 'power_law'" + raise ValueError(msg) + + if range == 'high': + if function != 'power_law': + msg = "Extrapolation only allows a power law at high Q" + raise ValueError(msg) + self._high_extrapolation_npts = npts + self._high_extrapolation_power = power + self._high_extrapolation_power_fitted = power + else: + if function == 'power_law': + self._low_extrapolation_function = PowerLaw() + else: + self._low_extrapolation_function = Guinier() + self._low_extrapolation_npts = npts + self._low_extrapolation_power = power + self._low_extrapolation_power_fitted = power + + def get_qstar(self, extrapolation=None): + """ + Compute the invariant of the local copy of data. + + :param extrapolation: string to apply optional extrapolation + + :return q_star: invariant of the data within data's q range + + :warning: When using setting data to Data1D , + the user is responsible of + checking that the scale and the background are + properly apply to the data + + """ + self._qstar = self._get_qstar(self._data) + self._qstar_err = self._get_qstar_uncertainty(self._data) + + if extrapolation is None: + return self._qstar + + # Compute invariant plus invariant of extrapolated data + extrapolation = extrapolation.lower() + if extrapolation == "low": + qs_low, dqs_low = self.get_qstar_low() + qs_hi, dqs_hi = 0, 0 + + elif extrapolation == "high": + qs_low, dqs_low = 0, 0 + qs_hi, dqs_hi = self.get_qstar_high() + + elif extrapolation == "both": + qs_low, dqs_low = self.get_qstar_low() + qs_hi, dqs_hi = self.get_qstar_high() + + self._qstar += qs_low + qs_hi + self._qstar_err = math.sqrt(self._qstar_err * self._qstar_err \ + + dqs_low * dqs_low + dqs_hi * dqs_hi) + + return self._qstar + + def get_surface(self, contrast, porod_const, extrapolation=None): + """ + Compute the specific surface from the data. 
+ + Implementation:: + + V = self.get_volume_fraction(contrast, extrapolation) + + Compute the surface given by: + surface = (2*pi *V(1- V)*porod_const)/ q_star + + :param contrast: contrast value to compute the volume + :param porod_const: Porod constant to compute the surface + :param extrapolation: string to apply optional extrapolation + + :return: specific surface + """ + # Compute the volume + volume = self.get_volume_fraction(contrast, extrapolation) + return 2 * math.pi * volume * (1 - volume) * \ + float(porod_const) / self._qstar + + def get_volume_fraction(self, contrast, extrapolation=None): + """ + Compute volume fraction is deduced as follow: :: + + q_star = 2*(pi*contrast)**2* volume( 1- volume) + for k = 10^(-8)*q_star/(2*(pi*|contrast|)**2) + we get 2 values of volume: + with 1 - 4 * k >= 0 + volume1 = (1- sqrt(1- 4*k))/2 + volume2 = (1+ sqrt(1- 4*k))/2 + + q_star: the invariant value included extrapolation is applied + unit 1/A^(3)*1/cm + q_star = self.get_qstar() + + the result returned will be 0 <= volume <= 1 + + :param contrast: contrast value provides by the user of type float. + contrast unit is 1/A^(2)= 10^(16)cm^(2) + :param extrapolation: string to apply optional extrapolation + + :return: volume fraction + + :note: volume fraction must have no unit + """ + if contrast <= 0: + raise ValueError("The contrast parameter must be greater than zero") + + # Make sure Q star is up to date + self.get_qstar(extrapolation) + + if self._qstar <= 0: + msg = "Invalid invariant: Invariant Q* must be greater than zero" + raise RuntimeError(msg) + + # Compute intermediate constant + k = 1.e-8 * self._qstar / (2 * (math.pi * math.fabs(float(contrast))) ** 2) + # Check discriminant value + discrim = 1 - 4 * k + + # Compute volume fraction + if discrim < 0: + msg = "Could not compute the volume fraction: negative discriminant" + raise RuntimeError(msg) + elif discrim == 0: + return 1 / 2 + else: + volume1 = 0.5 * (1 - math.sqrt(discrim)) + volume2 = 0.5 * (1 + math.sqrt(discrim)) + + if 0 <= volume1 and volume1 <= 1: + return volume1 + elif 0 <= volume2 and volume2 <= 1: + return volume2 + msg = "Could not compute the volume fraction: inconsistent results" + raise RuntimeError(msg) + + def get_qstar_with_error(self, extrapolation=None): + """ + Compute the invariant uncertainty. + This uncertainty computation depends on whether or not the data is + smeared. + + :param extrapolation: string to apply optional extrapolation + + :return: invariant, the invariant uncertainty + """ + self.get_qstar(extrapolation) + return self._qstar, self._qstar_err + + def get_volume_fraction_with_error(self, contrast, extrapolation=None): + """ + Compute uncertainty on volume value as well as the volume fraction + This uncertainty is given by the following equation: :: + + dV = 0.5 * (4*k* dq_star) /(2* math.sqrt(1-k* q_star)) + + for k = 10^(-8)*q_star/(2*(pi*|contrast|)**2) + + q_star: the invariant value including extrapolated value if existing + dq_star: the invariant uncertainty + dV: the volume uncertainty + + The uncertainty will be set to -1 if it can't be computed. 
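+
+        Usage sketch (using an InvariantCalculator instance `inv` created
+        as above; the contrast value, in 1/A^2, is illustrative)::
+
+            volume, dv = inv.get_volume_fraction_with_error(contrast=2.6e-6)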
+ + :param contrast: contrast value + :param extrapolation: string to apply optional extrapolation + + :return: V, dV = volume fraction, error on volume fraction + """ + volume = self.get_volume_fraction(contrast, extrapolation) + + # Compute error + k = 1.e-8 * self._qstar / (2 * (math.pi * math.fabs(float(contrast))) ** 2) + # Check value inside the sqrt function + value = 1 - k * self._qstar + if (value) <= 0: + uncertainty = -1 + # Compute uncertainty + uncertainty = math.fabs((0.5 * 4 * k * \ + self._qstar_err) / (2 * math.sqrt(1 - k * self._qstar))) + + return volume, uncertainty + + def get_surface_with_error(self, contrast, porod_const, extrapolation=None): + """ + Compute uncertainty of the surface value as well as the surface value. + The uncertainty is given as follow: :: + + dS = porod_const *2*pi[( dV -2*V*dV)/q_star + + dq_star(v-v**2) + + q_star: the invariant value + dq_star: the invariant uncertainty + V: the volume fraction value + dV: the volume uncertainty + + :param contrast: contrast value + :param porod_const: porod constant value + :param extrapolation: string to apply optional extrapolation + + :return S, dS: the surface, with its uncertainty + """ + # We get the volume fraction, with error + # get_volume_fraction_with_error calls get_volume_fraction + # get_volume_fraction calls get_qstar + # which computes Qstar and dQstar + v, dv = self.get_volume_fraction_with_error(contrast, extrapolation) + + s = self.get_surface(contrast=contrast, porod_const=porod_const, + extrapolation=extrapolation) + ds = porod_const * 2 * math.pi * ((dv - 2 * v * dv) / self._qstar\ + + self._qstar_err * (v - v ** 2)) + + return s, ds diff --git a/sas/sascalc/invariant/invariant_mapper.py b/sas/sascalc/invariant/invariant_mapper.py new file mode 100755 index 000000000..2c7cda833 --- /dev/null +++ b/sas/sascalc/invariant/invariant_mapper.py @@ -0,0 +1,48 @@ +""" +This module is a wrapper to a map function. 
It allows to loop through +different invariant objects to call the same function +""" + + +def get_qstar(inv, extrapolation=None): + """ + Get invariant value (Q*) + """ + return inv.get_qstar(extrapolation) + +def get_qstar_with_error(inv, extrapolation=None): + """ + Get invariant value with uncertainty + """ + return inv.get_qstar_with_error(extrapolation) + +def get_volume_fraction(inv, contrast, extrapolation=None): + """ + Get volume fraction + """ + return inv.get_volume_fraction(contrast, extrapolation) + +def get_volume_fraction_with_error(inv, contrast, extrapolation=None): + """ + Get volume fraction with uncertainty + """ + return inv.get_volume_fraction_with_error(contrast, + extrapolation) + +def get_surface(inv, contrast, porod_const, extrapolation=None): + """ + Get surface with uncertainty + """ + return inv.get_surface(contrast=contrast, + porod_const=porod_const, + extrapolation=extrapolation) + +def get_surface_with_error(inv, contrast, + porod_const, extrapolation=None): + """ + Get surface with uncertainty + """ + return inv.get_surface_with_error(contrast=contrast, + porod_const=porod_const, + extrapolation=extrapolation) + diff --git a/sas/sascalc/pr/Doxyfile b/sas/sascalc/pr/Doxyfile new file mode 100755 index 000000000..a851e45f1 --- /dev/null +++ b/sas/sascalc/pr/Doxyfile @@ -0,0 +1,263 @@ +# Doxyfile 1.5.1-p1 + +#--------------------------------------------------------------------------- +# Project related configuration options +#--------------------------------------------------------------------------- +PROJECT_NAME = saspr +PROJECT_NUMBER = +OUTPUT_DIRECTORY = doc +CREATE_SUBDIRS = NO +OUTPUT_LANGUAGE = English +USE_WINDOWS_ENCODING = NO +BRIEF_MEMBER_DESC = YES +REPEAT_BRIEF = YES +ABBREVIATE_BRIEF = "The $name class" \ + "The $name widget" \ + "The $name file" \ + is \ + provides \ + specifies \ + contains \ + represents \ + a \ + an \ + the +ALWAYS_DETAILED_SEC = NO +INLINE_INHERITED_MEMB = NO +FULL_PATH_NAMES = YES +STRIP_FROM_PATH = +STRIP_FROM_INC_PATH = +SHORT_NAMES = NO +JAVADOC_AUTOBRIEF = NO +MULTILINE_CPP_IS_BRIEF = NO +DETAILS_AT_TOP = NO +INHERIT_DOCS = YES +SEPARATE_MEMBER_PAGES = NO +TAB_SIZE = 4 +ALIASES = +OPTIMIZE_OUTPUT_FOR_C = NO +OPTIMIZE_OUTPUT_JAVA = YES +BUILTIN_STL_SUPPORT = NO +DISTRIBUTE_GROUP_DOC = NO +SUBGROUPING = YES +#--------------------------------------------------------------------------- +# Build related configuration options +#--------------------------------------------------------------------------- +EXTRACT_ALL = NO +EXTRACT_PRIVATE = NO +EXTRACT_STATIC = NO +EXTRACT_LOCAL_CLASSES = YES +EXTRACT_LOCAL_METHODS = NO +HIDE_UNDOC_MEMBERS = NO +HIDE_UNDOC_CLASSES = NO +HIDE_FRIEND_COMPOUNDS = NO +HIDE_IN_BODY_DOCS = NO +INTERNAL_DOCS = NO +CASE_SENSE_NAMES = YES +HIDE_SCOPE_NAMES = NO +SHOW_INCLUDE_FILES = YES +INLINE_INFO = YES +SORT_MEMBER_DOCS = YES +SORT_BRIEF_DOCS = NO +SORT_BY_SCOPE_NAME = NO +GENERATE_TODOLIST = YES +GENERATE_TESTLIST = YES +GENERATE_BUGLIST = YES +GENERATE_DEPRECATEDLIST= YES +ENABLED_SECTIONS = +MAX_INITIALIZER_LINES = 30 +SHOW_USED_FILES = YES +SHOW_DIRECTORIES = YES +FILE_VERSION_FILTER = +#--------------------------------------------------------------------------- +# configuration options related to warning and progress messages +#--------------------------------------------------------------------------- +QUIET = NO +WARNINGS = YES +WARN_IF_UNDOCUMENTED = YES +WARN_IF_DOC_ERROR = YES +WARN_NO_PARAMDOC = NO +WARN_FORMAT = "$file:$line: $text" +WARN_LOGFILE = 
+#--------------------------------------------------------------------------- +# configuration options related to the input files +#--------------------------------------------------------------------------- +INPUT = . +FILE_PATTERNS = *.c \ + *.cc \ + *.cxx \ + *.cpp \ + *.c++ \ + *.d \ + *.java \ + *.ii \ + *.ixx \ + *.ipp \ + *.i++ \ + *.inl \ + *.h \ + *.hh \ + *.hxx \ + *.hpp \ + *.h++ \ + *.idl \ + *.odl \ + *.cs \ + *.php \ + *.php3 \ + *.inc \ + *.m \ + *.mm \ + *.dox \ + *.py +RECURSIVE = YES +EXCLUDE = build \ + dist \ + test +EXCLUDE_SYMLINKS = NO +EXCLUDE_PATTERNS = +EXAMPLE_PATH = +EXAMPLE_PATTERNS = * +EXAMPLE_RECURSIVE = NO +IMAGE_PATH = doc +INPUT_FILTER = +FILTER_PATTERNS = +FILTER_SOURCE_FILES = NO +#--------------------------------------------------------------------------- +# configuration options related to source browsing +#--------------------------------------------------------------------------- +SOURCE_BROWSER = NO +INLINE_SOURCES = NO +STRIP_CODE_COMMENTS = YES +REFERENCED_BY_RELATION = NO +REFERENCES_RELATION = NO +REFERENCES_LINK_SOURCE = YES +USE_HTAGS = NO +VERBATIM_HEADERS = YES +#--------------------------------------------------------------------------- +# configuration options related to the alphabetical class index +#--------------------------------------------------------------------------- +ALPHABETICAL_INDEX = NO +COLS_IN_ALPHA_INDEX = 5 +IGNORE_PREFIX = +#--------------------------------------------------------------------------- +# configuration options related to the HTML output +#--------------------------------------------------------------------------- +GENERATE_HTML = YES +HTML_OUTPUT = html +HTML_FILE_EXTENSION = .html +HTML_HEADER = +HTML_FOOTER = +HTML_STYLESHEET = +HTML_ALIGN_MEMBERS = YES +GENERATE_HTMLHELP = NO +CHM_FILE = +HHC_LOCATION = +GENERATE_CHI = NO +BINARY_TOC = NO +TOC_EXPAND = NO +DISABLE_INDEX = NO +ENUM_VALUES_PER_LINE = 4 +GENERATE_TREEVIEW = NO +TREEVIEW_WIDTH = 250 +#--------------------------------------------------------------------------- +# configuration options related to the LaTeX output +#--------------------------------------------------------------------------- +GENERATE_LATEX = NO +LATEX_OUTPUT = latex +LATEX_CMD_NAME = latex +MAKEINDEX_CMD_NAME = makeindex +COMPACT_LATEX = NO +PAPER_TYPE = a4wide +EXTRA_PACKAGES = +LATEX_HEADER = +PDF_HYPERLINKS = NO +USE_PDFLATEX = YES +LATEX_BATCHMODE = NO +LATEX_HIDE_INDICES = NO +#--------------------------------------------------------------------------- +# configuration options related to the RTF output +#--------------------------------------------------------------------------- +GENERATE_RTF = NO +RTF_OUTPUT = rtf +COMPACT_RTF = NO +RTF_HYPERLINKS = NO +RTF_STYLESHEET_FILE = +RTF_EXTENSIONS_FILE = +#--------------------------------------------------------------------------- +# configuration options related to the man page output +#--------------------------------------------------------------------------- +GENERATE_MAN = NO +MAN_OUTPUT = man +MAN_EXTENSION = .3 +MAN_LINKS = NO +#--------------------------------------------------------------------------- +# configuration options related to the XML output +#--------------------------------------------------------------------------- +GENERATE_XML = NO +XML_OUTPUT = xml +XML_SCHEMA = +XML_DTD = +XML_PROGRAMLISTING = YES +#--------------------------------------------------------------------------- +# configuration options for the AutoGen Definitions output 
+#--------------------------------------------------------------------------- +GENERATE_AUTOGEN_DEF = NO +#--------------------------------------------------------------------------- +# configuration options related to the Perl module output +#--------------------------------------------------------------------------- +GENERATE_PERLMOD = NO +PERLMOD_LATEX = NO +PERLMOD_PRETTY = YES +PERLMOD_MAKEVAR_PREFIX = +#--------------------------------------------------------------------------- +# Configuration options related to the preprocessor +#--------------------------------------------------------------------------- +ENABLE_PREPROCESSING = YES +MACRO_EXPANSION = NO +EXPAND_ONLY_PREDEF = NO +SEARCH_INCLUDES = YES +INCLUDE_PATH = +INCLUDE_FILE_PATTERNS = +PREDEFINED = +EXPAND_AS_DEFINED = +SKIP_FUNCTION_MACROS = YES +#--------------------------------------------------------------------------- +# Configuration::additions related to external references +#--------------------------------------------------------------------------- +TAGFILES = +GENERATE_TAGFILE = +ALLEXTERNALS = NO +EXTERNAL_GROUPS = YES +PERL_PATH = /usr/bin/perl +#--------------------------------------------------------------------------- +# Configuration options related to the dot tool +#--------------------------------------------------------------------------- +CLASS_DIAGRAMS = YES +HIDE_UNDOC_RELATIONS = YES +HAVE_DOT = YES +CLASS_GRAPH = YES +COLLABORATION_GRAPH = NO +GROUP_GRAPHS = YES +UML_LOOK = YES +TEMPLATE_RELATIONS = YES +INCLUDE_GRAPH = YES +INCLUDED_BY_GRAPH = YES +CALL_GRAPH = YES +CALLER_GRAPH = YES +GRAPHICAL_HIERARCHY = YES +DIRECTORY_GRAPH = YES +DOT_IMAGE_FORMAT = png +DOT_PATH = +DOTFILE_DIRS = +MAX_DOT_GRAPH_WIDTH = 1024 +MAX_DOT_GRAPH_HEIGHT = 1024 +MAX_DOT_GRAPH_DEPTH = 1000 +DOT_TRANSPARENT = NO +DOT_MULTI_TARGETS = NO +GENERATE_LEGEND = YES +DOT_CLEANUP = YES +#--------------------------------------------------------------------------- +# Configuration::additions related to the search engine +#--------------------------------------------------------------------------- +SEARCHENGINE = NO diff --git a/sas/sascalc/pr/__init__.py b/sas/sascalc/pr/__init__.py new file mode 100755 index 000000000..c4e5ed507 --- /dev/null +++ b/sas/sascalc/pr/__init__.py @@ -0,0 +1,106 @@ +""" + P(r) inversion for SAS +""" +## \mainpage P(r) inversion for SAS +# +# \section intro_sec Introduction +# This module provides calculations to transform scattering intensity data +# I(q) into distance distribution function P(r). A description of the +# technique can be found elsewhere [1-5]. The module is useable as a +# standalone application but its functionality is meant to be presented +# to end-users through the user interface developed as part of the SAS +# flagship application. +# +# Procedure: We will follow the procedure of Moore [1]. +# +# [1] P.B. Moore, J.Appl. Cryst (1980) 13, 168-175. +# +# [2] O. Glatter, J.Appl. Cryst (1977) 10, 415-421. +# +# [3] D.I. Svergun, J.Appl. Cryst (1991) 24, 485-492. +# +# [4] D.I. Svergun, J.Appl. Cryst (1992) 25, 495-503. +# +# [5] S. Hansen and J. Skov Pedersen, J.Appl. Cryst (1991) 24, 541-548. +# +## \subsection class Class Diagram: +# The following shows a partial class diagram with the main attributes +# and methods of the invertor. 
+# +# \image html architecture.png +# +# \section install_sec Installation +# +# \subsection obtain Obtaining the Code +# +# The code is available here: +# \verbatim +#$ svn co svn://danse.us/sas/pr_inversion +# \endverbatim +# +# \subsection depends External Dependencies +# scipy, numpy +# +# \subsection build Building the code +# The standard python package can be built with distutils. +# \verbatim +#$ python setup.py build +#$ python setup.py install +# \endverbatim +# +# +# \subsection Tutorial +# To create an inversion object: +# \verbatim +#from sas.sascalc.pr.invertor import Invertor +# invertor = Invertor() +# \endverbatim +# +# To set the maximum distance between any two points: +# \verbatim +# invertor.d_max = 160.0 +# \endverbatim +# +# To set the regularization constant: +# \verbatim +# invertor.alpha = 0.0007 +# \endverbatim +# +# To set the q, I(q) and error on I(q): +# \verbatim +# invertor.x = q_vector +# invertor.y = Iq_vector +# invertor.err = dIq_vector +# \endverbatim +# +# To perform the inversion. In this example, we choose +# a P(r) expension wit 10 base functions. +# \verbatim +# c_out, c_cov = invertor.invert(10) +# \endverbatim +# The c_out and c_cov are the set of coefficients and the covariance +# matrix for those coefficients, respectively. +# +# To get P(r): +# \verbatim +# r = 10.0 +# pr = invertor.pr(c_out, r) +# \endverbatim +# Alternatively, one can get P(r) with the error on P(r): +# \verbatim +# r = 10.0 +# pr, dpr = invertor.pr_err(c_out, c_cov, r) +# \endverbatim +# +# To get the output I(q) from the set of coefficients found: +# \verbatim +# q = 0.001 +# iq = invertor.iq(c_out, q) +# \endverbatim +# +# Examples are available as unit tests under sas.pr_inversion.test. +# +# \section help_sec Contact Info +# Code and Documentation produced as part of the DANSE project. + +__author__ = 'University of Tennessee' diff --git a/sas/sascalc/pr/_pr_inversion.so b/sas/sascalc/pr/_pr_inversion.so new file mode 100755 index 000000000..d63255b32 Binary files /dev/null and b/sas/sascalc/pr/_pr_inversion.so differ diff --git a/sas/sascalc/pr/c_extensions/Cinvertor.c b/sas/sascalc/pr/c_extensions/Cinvertor.c new file mode 100755 index 000000000..dac038633 --- /dev/null +++ b/sas/sascalc/pr/c_extensions/Cinvertor.c @@ -0,0 +1,1177 @@ +/** + * C implementation of the P(r) inversion + * Cinvertor is the base class for the Invertor class + * and provides the underlying computations. + * + */ +#include +#include +#include +#include + +//#define Py_LIMITED_API 0x03050000 +#include +#include + +// Vector binding glue +#if (PY_VERSION_HEX > 0x03000000) && !defined(Py_LIMITED_API) + // Assuming that a view into a writable vector points to a + // non-changing pointer for the duration of the C call, capture + // the view pointer and immediately free the view. 
+ #define VECTOR(VEC_obj, VEC_buf, VEC_len) do { \ + Py_buffer VEC_view; \ + int VEC_err = PyObject_GetBuffer(VEC_obj, &VEC_view, PyBUF_WRITABLE|PyBUF_FORMAT); \ + if (VEC_err < 0 || sizeof(*VEC_buf) != VEC_view.itemsize) return NULL; \ + VEC_buf = (typeof(VEC_buf))VEC_view.buf; \ + VEC_len = VEC_view.len/sizeof(*VEC_buf); \ + PyBuffer_Release(&VEC_view); \ + } while (0) +#else + #define VECTOR(VEC_obj, VEC_buf, VEC_len) do { \ + int VEC_err = PyObject_AsWriteBuffer(VEC_obj, (void **)(&VEC_buf), &VEC_len); \ + if (VEC_err < 0) return NULL; \ + VEC_len /= sizeof(*VEC_buf); \ + } while (0) +#endif + +#include "invertor.h" + +/// Error object for raised exceptions +PyObject * CinvertorError; + +// Class definition +/** + * C implementation of the P(r) inversion + * Cinvertor is the base class for the Invertor class + * and provides the underlying computations. + * + */ +typedef struct { + PyObject_HEAD + /// Internal data structure + Invertor_params params; +} Cinvertor; + + +static void +Cinvertor_dealloc(Cinvertor* self) +{ + invertor_dealloc(&(self->params)); + + Py_TYPE(self)->tp_free((PyObject*)self); + +} + +static PyObject * +Cinvertor_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + Cinvertor *self; + + self = (Cinvertor *)type->tp_alloc(type, 0); + + return (PyObject *)self; +} + +static int +Cinvertor_init(Cinvertor *self, PyObject *args, PyObject *kwds) +{ + if (self != NULL) { + // Create parameters + invertor_init(&(self->params)); + } + return 0; +} + +static PyMemberDef Cinvertor_members[] = { + //{"params", T_OBJECT, offsetof(Cinvertor, params), 0, + // "Parameters"}, + {NULL} /* Sentinel */ +}; + +const char set_x_doc[] = + "Function to set the x data\n" + "Takes an array of doubles as input.\n" + " @return: number of entries found"; + +/** + * Function to set the x data + * Takes an array of doubles as input + * Returns the number of entries found + */ +static PyObject * set_x(Cinvertor *self, PyObject *args) { + PyObject *data_obj; + Py_ssize_t ndata; + double *data; + int i; + + if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL; + VECTOR(data_obj,data,ndata); + + free(self->params.x); + self->params.x = (double*) malloc(ndata*sizeof(double)); + + if(self->params.x==NULL) { + PyErr_SetString(CinvertorError, + "Cinvertor.set_x: problem allocating memory."); + return NULL; + } + + for (i=0; iparams.x[i] = data[i]; + } + + //self->params.x = data; + self->params.npoints = (int)ndata; + return Py_BuildValue("i", self->params.npoints); +} + +const char get_x_doc[] = + "Function to get the x data\n" + "Takes an array of doubles as input.\n" + " @return: number of entries found"; + +static PyObject * get_x(Cinvertor *self, PyObject *args) { + PyObject *data_obj; + Py_ssize_t ndata; + double *data; + int i; + + if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL; + VECTOR(data_obj, data, ndata); + + // Check that the input array is large enough + if (ndata < self->params.npoints) { + PyErr_SetString(CinvertorError, + "Cinvertor.get_x: input array too short for data."); + return NULL; + } + + for(i=0; iparams.npoints; i++){ + data[i] = self->params.x[i]; + } + + return Py_BuildValue("i", self->params.npoints); +} + +const char set_y_doc[] = + "Function to set the y data\n" + "Takes an array of doubles as input.\n" + " @return: number of entries found"; + +/** + * Function to set the y data + * Takes an array of doubles as input + * Returns the number of entries found + */ +static PyObject * set_y(Cinvertor *self, PyObject *args) { + PyObject *data_obj; + 
+    Py_ssize_t ndata;
+    double *data;
+    int i;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj,data,ndata);
+
+    free(self->params.y);
+    self->params.y = (double*) malloc(ndata*sizeof(double));
+
+    if(self->params.y==NULL) {
+        PyErr_SetString(CinvertorError,
+            "Cinvertor.set_y: problem allocating memory.");
+        return NULL;
+    }
+
+    for (i=0; i<ndata; i++) {
+        self->params.y[i] = data[i];
+    }
+
+    //self->params.y = data;
+    self->params.ny = (int)ndata;
+    return Py_BuildValue("i", self->params.ny);
+}
+
+const char get_y_doc[] =
+    "Function to get the y data\n"
+    "Takes an array of doubles as input.\n"
+    " @return: number of entries found";
+
+static PyObject * get_y(Cinvertor *self, PyObject *args) {
+    PyObject *data_obj;
+    Py_ssize_t ndata;
+    double *data;
+    int i;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj, data, ndata);
+
+    // Check that the input array is large enough
+    if (ndata < self->params.ny) {
+        PyErr_SetString(CinvertorError,
+            "Cinvertor.get_y: input array too short for data.");
+        return NULL;
+    }
+
+    for (i=0; i<self->params.ny; i++) {
+        data[i] = self->params.y[i];
+    }
+
+    return Py_BuildValue("i", self->params.npoints);
+}
+
+const char set_err_doc[] =
+    "Function to set the err data\n"
+    "Takes an array of doubles as input.\n"
+    " @return: number of entries found";
+
+/**
+ * Function to set the err data
+ * Takes an array of doubles as input
+ * Returns the number of entries found
+ */
+static PyObject * set_err(Cinvertor *self, PyObject *args) {
+    PyObject *data_obj;
+    Py_ssize_t ndata;
+    double *data;
+    int i;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj,data,ndata);
+
+    free(self->params.err);
+    self->params.err = (double*) malloc(ndata*sizeof(double));
+
+    if(self->params.err==NULL) {
+        PyErr_SetString(CinvertorError,
+            "Cinvertor.set_err: problem allocating memory.");
+        return NULL;
+    }
+
+    for (i=0; i<ndata; i++) {
+        self->params.err[i] = data[i];
+    }
+
+    //self->params.err = data;
+    self->params.nerr = (int)ndata;
+    return Py_BuildValue("i", self->params.nerr);
+}
+
+const char get_err_doc[] =
+    "Function to get the err data\n"
+    "Takes an array of doubles as input.\n"
+    " @return: number of entries found";
+
+static PyObject * get_err(Cinvertor *self, PyObject *args) {
+    PyObject *data_obj;
+    Py_ssize_t ndata;
+    double *data;
+    int i;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj, data, ndata);
+
+    // Check that the input array is large enough
+    if (ndata < self->params.nerr) {
+        PyErr_SetString(CinvertorError,
+            "Cinvertor.get_err: input array too short for data.");
+        return NULL;
+    }
+
+    for (i=0; i<self->params.nerr; i++) {
+        data[i] = self->params.err[i];
+    }
+
+    return Py_BuildValue("i", self->params.npoints);
+}
+
+const char is_valid_doc[] =
+    "Check the validity of the stored data\n"
+    " @return: Returns the number of points if it's all good, -1 otherwise";
+
+/**
+ * Check the validity of the stored data
+ * Returns the number of points if it's all good, -1 otherwise
+ */
+static PyObject * is_valid(Cinvertor *self, PyObject *args) {
+    if(self->params.npoints==self->params.ny &&
+        self->params.npoints==self->params.nerr) {
+        return Py_BuildValue("i", self->params.npoints);
+    } else {
+        return Py_BuildValue("i", -1);
+    }
+}
+
+const char set_est_bck_doc[] =
+    "Sets background flag\n";
+
+/**
+ * Sets the background estimation flag
+ */
+static PyObject * set_est_bck(Cinvertor *self, PyObject *args) {
+    int est_bck;
+
+    if (!PyArg_ParseTuple(args, "i", &est_bck)) return NULL;
+    self->params.est_bck = est_bck;
+    return Py_BuildValue("i", self->params.est_bck);
+}
+
+const char get_est_bck_doc[] =
+    "Gets background flag\n";
+
+/**
+ * Gets the background estimation flag
+ */
+static PyObject * get_est_bck(Cinvertor *self, PyObject *args) {
+    return Py_BuildValue("i", self->params.est_bck);
+}
+
+const char set_dmax_doc[] =
+    "Sets the maximum distance\n";
+
+/**
+ * Sets the maximum distance
+ */
+static PyObject * set_dmax(Cinvertor *self, PyObject *args) {
+    double d_max;
+
+    if (!PyArg_ParseTuple(args, "d", &d_max)) return NULL;
+    self->params.d_max = d_max;
+    return Py_BuildValue("d", self->params.d_max);
+}
+
+const char get_dmax_doc[] =
+    "Gets the maximum distance\n";
+
+/**
+ * Gets the maximum distance
+ */
+static PyObject * get_dmax(Cinvertor *self, PyObject *args) {
+    return Py_BuildValue("d", self->params.d_max);
+}
+
+const char set_slit_height_doc[] =
+    "Sets the slit height in units of q [A-1]\n";
+
+/**
+ * Sets the slit height
+ */
+static PyObject * set_slit_height(Cinvertor *self, PyObject *args) {
+    double slit_height;
+
+    if (!PyArg_ParseTuple(args, "d", &slit_height)) return NULL;
+    self->params.slit_height = slit_height;
+    return Py_BuildValue("d", self->params.slit_height);
+}
+
+const char get_slit_height_doc[] =
+    "Gets the slit height\n";
+
+/**
+ * Gets the slit height
+ */
+static PyObject * get_slit_height(Cinvertor *self, PyObject *args) {
+    return Py_BuildValue("d", self->params.slit_height);
+}
+
+const char set_slit_width_doc[] =
+    "Sets the slit width in units of q [A-1]\n";
+
+/**
+ * Sets the slit width
+ */
+static PyObject * set_slit_width(Cinvertor *self, PyObject *args) {
+    double slit_width;
+
+    if (!PyArg_ParseTuple(args, "d", &slit_width)) return NULL;
+    self->params.slit_width = slit_width;
+    return Py_BuildValue("d", self->params.slit_width);
+}
+
+const char get_slit_width_doc[] =
+    "Gets the slit width\n";
+
+/**
+ * Gets the slit width
+ */
+static PyObject * get_slit_width(Cinvertor *self, PyObject *args) {
+    return Py_BuildValue("d", self->params.slit_width);
+}
+
+
+const char set_qmin_doc[] =
+    "Sets the minimum q\n";
+
+/**
+ * Sets the minimum q
+ */
+static PyObject * set_qmin(Cinvertor *self, PyObject *args) {
+    double q_min;
+
+    if (!PyArg_ParseTuple(args, "d", &q_min)) return NULL;
+    self->params.q_min = q_min;
+    return Py_BuildValue("d", self->params.q_min);
+}
+
+const char get_qmin_doc[] =
+    "Gets the minimum q\n";
+
+/**
+ * Gets the minimum q
+ */
+static PyObject * get_qmin(Cinvertor *self, PyObject *args) {
+    return Py_BuildValue("d", self->params.q_min);
+}
+
+const char set_qmax_doc[] =
+    "Sets the maximum q\n";
+
+/**
+ * Sets the maximum q
+ */
+static PyObject * set_qmax(Cinvertor *self, PyObject *args) {
+    double q_max;
+
+    if (!PyArg_ParseTuple(args, "d", &q_max)) return NULL;
+    self->params.q_max = q_max;
+    return Py_BuildValue("d", self->params.q_max);
+}
+
+const char get_qmax_doc[] =
+    "Gets the maximum q\n";
+
+/**
+ * Gets the maximum q
+ */
+static PyObject * get_qmax(Cinvertor *self, PyObject *args) {
+    return Py_BuildValue("d", self->params.q_max);
+}
+
+const char set_alpha_doc[] =
+    "Sets the alpha parameter\n";
+
+static PyObject * set_alpha(Cinvertor *self, PyObject *args) {
+    double alpha;
+
+    if (!PyArg_ParseTuple(args, "d", &alpha)) return NULL;
+    self->params.alpha = alpha;
+    return Py_BuildValue("d", self->params.alpha);
+}
+
+const char get_alpha_doc[] =
+    "Gets the alpha parameter\n";
+
+/**
+ * Gets the alpha parameter
+ */
+static PyObject * get_alpha(Cinvertor *self, PyObject *args) {
+    return Py_BuildValue("d", self->params.alpha);
+}
Py_BuildValue("d", self->params.alpha); +} + +const char get_nx_doc[] = + "Gets the number of x points\n"; + +/** + * Gets the number of x points + */ +static PyObject * get_nx(Cinvertor *self, PyObject *args) { + return Py_BuildValue("i", self->params.npoints); +} + +const char get_ny_doc[] = + "Gets the number of y points\n"; + +/** + * Gets the number of y points + */ +static PyObject * get_ny(Cinvertor *self, PyObject *args) { + return Py_BuildValue("i", self->params.ny); +} + +const char get_nerr_doc[] = + "Gets the number of err points\n"; + +/** + * Gets the number of error points + */ +static PyObject * get_nerr(Cinvertor *self, PyObject *args) { + return Py_BuildValue("i", self->params.nerr); +} + + +const char residuals_doc[] = + "Function to call to evaluate the residuals\n" + "for P(r) inversion\n" + " @param args: input parameters\n" + " @return: list of residuals"; + +/** + * Function to call to evaluate the residuals + * @param args: input parameters + * @return: list of residuals + */ +static PyObject * residuals(Cinvertor *self, PyObject *args) { + double *pars; + PyObject* residuals; + int i; + double residual, diff; + // Regularization factor + double regterm = 0.0; + // Number of slices in regularization term estimate + int nslice = 25; + + PyObject *data_obj; + Py_ssize_t npars; + + if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL; + + VECTOR(data_obj,pars,npars); + + // PyList of residuals + // Should create this list only once and refill it + residuals = PyList_New(self->params.npoints); + + regterm = reg_term(pars, self->params.d_max, (int)npars, nslice); + + for(i=0; iparams.npoints; i++) { + diff = self->params.y[i] - iq(pars, self->params.d_max, (int)npars, self->params.x[i]); + residual = diff*diff / (self->params.err[i]*self->params.err[i]); + + // regularization term + residual += self->params.alpha * regterm; + + if (PyList_SetItem(residuals, i, Py_BuildValue("d",residual) ) < 0){ + PyErr_SetString(CinvertorError, + "Cinvertor.residuals: error setting residual."); + return NULL; + }; + } + return residuals; +} + +const char pr_residuals_doc[] = + "Function to call to evaluate the residuals\n" + "for P(r) minimization (for testing purposes)\n" + " @param args: input parameters\n" + " @return: list of residuals"; + +/** + * Function to call to evaluate the residuals + * for P(r) minimization (for testing purposes) + * @param args: input parameters + * @return: list of residuals + */ +static PyObject * pr_residuals(Cinvertor *self, PyObject *args) { + double *pars; + PyObject* residuals; + int i; + double residual, diff; + // Regularization factor + double regterm = 0.0; + // Number of slices in regularization term estimate + int nslice = 25; + + PyObject *data_obj; + Py_ssize_t npars; + + if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL; + + VECTOR(data_obj,pars,npars); + + // Should create this list only once and refill it + residuals = PyList_New(self->params.npoints); + + regterm = reg_term(pars, self->params.d_max, (int)npars, nslice); + + + for(i=0; iparams.npoints; i++) { + diff = self->params.y[i] - pr(pars, self->params.d_max, (int)npars, self->params.x[i]); + residual = diff*diff / (self->params.err[i]*self->params.err[i]); + + // regularization term + residual += self->params.alpha * regterm; + + if (PyList_SetItem(residuals, i, Py_BuildValue("d",residual) ) < 0){ + PyErr_SetString(CinvertorError, + "Cinvertor.residuals: error setting residual."); + return NULL; + }; + } + return residuals; +} + +const char get_iq_doc[] = + "Function 
+    " @param args: c-parameters, and q\n"
+    " @return: I(q)";
+
+/**
+ * Function to call to evaluate the scattering intensity
+ * @param args: c-parameters, and q
+ * @return: I(q)
+ */
+static PyObject * get_iq(Cinvertor *self, PyObject *args) {
+    double *pars;
+    double q, iq_value;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+
+    if (!PyArg_ParseTuple(args, "Od", &data_obj, &q)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    iq_value = iq(pars, self->params.d_max, (int)npars, q);
+    return Py_BuildValue("f", iq_value);
+}
+
+const char get_iq_smeared_doc[] =
+    "Function to call to evaluate the scattering intensity.\n"
+    "The scattering intensity is slit-smeared.\n"
+    " @param args: c-parameters, and q\n"
+    " @return: I(q)";
+
+/**
+ * Function to call to evaluate the scattering intensity
+ * The scattering intensity is slit-smeared.
+ * @param args: c-parameters, and q
+ * @return: I(q)
+ */
+static PyObject * get_iq_smeared(Cinvertor *self, PyObject *args) {
+    double *pars;
+    double q, iq_value;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+
+    if (!PyArg_ParseTuple(args, "Od", &data_obj, &q)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    iq_value = iq_smeared(pars, self->params.d_max, (int)npars,
+                          self->params.slit_height, self->params.slit_width,
+                          q, 21);
+    return Py_BuildValue("f", iq_value);
+}
+
+const char get_pr_doc[] =
+    "Function to call to evaluate P(r)\n"
+    " @param args: c-parameters and r\n"
+    " @return: P(r)";
+
+/**
+ * Function to call to evaluate P(r)
+ * @param args: c-parameters and r
+ * @return: P(r)
+ */
+static PyObject * get_pr(Cinvertor *self, PyObject *args) {
+    double *pars;
+    double r, pr_value;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+
+    if (!PyArg_ParseTuple(args, "Od", &data_obj, &r)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    pr_value = pr(pars, self->params.d_max, (int)npars, r);
+    return Py_BuildValue("f", pr_value);
+}
+
+const char get_pr_err_doc[] =
+    "Function to call to evaluate P(r) with errors\n"
+    " @param args: c-parameters and r\n"
+    " @return: (P(r),dP(r))";
+
+/**
+ * Function to call to evaluate P(r) with errors
+ * @param args: c-parameters and r
+ * @return: (P(r), dP(r))
+ */
+static PyObject * get_pr_err(Cinvertor *self, PyObject *args) {
+    double *pars;
+    double *pars_err;
+    double pr_err_value;
+    double r, pr_value;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+    PyObject *err_obj;
+    Py_ssize_t npars2;
+
+    if (!PyArg_ParseTuple(args, "OOd", &data_obj, &err_obj, &r)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    if (err_obj == Py_None) {
+        pr_value = pr(pars, self->params.d_max, (int)npars, r);
+        pr_err_value = 0.0;
+    } else {
+        VECTOR(err_obj,pars_err,npars2);
+        pr_err(pars, pars_err, self->params.d_max, (int)npars, r, &pr_value, &pr_err_value);
+    }
+    return Py_BuildValue("ff", pr_value, pr_err_value);
+}
+
+const char basefunc_ft_doc[] =
+    "Returns the value of the nth Fourier transformed base function\n"
+    " @param args: c-parameters, n and q\n"
+    " @return: nth Fourier transformed base function, evaluated at q";
+
+static PyObject * basefunc_ft(Cinvertor *self, PyObject *args) {
+    double d_max, q;
+    int n;
+
+    if (!PyArg_ParseTuple(args, "did", &d_max, &n, &q)) return NULL;
+    return Py_BuildValue("f", ortho_transformed(d_max, n, q));
+
+}
+
+const char oscillations_doc[] =
+    "Returns the value of the oscillation figure of merit for\n"
+    "the given set of coefficients. For a sphere, the oscillation\n"
+    "figure of merit is 1.1.\n"
+    " @param args: c-parameters\n"
+    " @return: oscillation figure of merit";
+
+static PyObject * oscillations(Cinvertor *self, PyObject *args) {
+    double *pars;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+    double oscill, norm;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    oscill = reg_term(pars, self->params.d_max, (int)npars, 100);
+    norm = int_p2(pars, self->params.d_max, (int)npars, 100);
+    return Py_BuildValue("f", sqrt(oscill/norm)/acos(-1.0)*self->params.d_max );
+
+}
+
+const char get_peaks_doc[] =
+    "Returns the number of peaks in the output P(r) distribution\n"
+    "for the given set of coefficients.\n"
+    " @param args: c-parameters\n"
+    " @return: number of P(r) peaks";
+
+static PyObject * get_peaks(Cinvertor *self, PyObject *args) {
+    double *pars;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+    int count;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    count = npeaks(pars, self->params.d_max, (int)npars, 100);
+
+    return Py_BuildValue("i", count );
+
+}
+
+const char get_positive_doc[] =
+    "Returns the fraction of P(r) that is positive over\n"
+    "the full range of r for the given set of coefficients.\n"
+    " @param args: c-parameters\n"
+    " @return: fraction of P(r) that is positive";
+
+static PyObject * get_positive(Cinvertor *self, PyObject *args) {
+    double *pars;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+    double fraction;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    fraction = positive_integral(pars, self->params.d_max, (int)npars, 100);
+
+    return Py_BuildValue("f", fraction );
+
+}
+
+const char get_pos_err_doc[] =
+    "Returns the fraction of P(r) that is 1 standard deviation\n"
+    "above zero over the full range of r for the given set of coefficients.\n"
+    " @param args: c-parameters\n"
+    " @return: fraction of P(r) that is at least 1-sigma above zero";
+
+static PyObject * get_pos_err(Cinvertor *self, PyObject *args) {
+    double *pars;
+    double *pars_err;
+    PyObject *data_obj;
+    PyObject *err_obj;
+    Py_ssize_t npars;
+    Py_ssize_t npars2;
+    double fraction;
+
+    if (!PyArg_ParseTuple(args, "OO", &data_obj, &err_obj)) return NULL;
+    VECTOR(data_obj,pars,npars);
+    VECTOR(err_obj,pars_err,npars2);
+
+    fraction = positive_errors(pars, pars_err, self->params.d_max, (int)npars, 51);
+
+    return Py_BuildValue("f", fraction );
+
+}
+
+const char get_rg_doc[] =
+    "Returns the value of the radius of gyration Rg.\n"
+    " @param args: c-parameters\n"
+    " @return: Rg";
+
+static PyObject * get_rg(Cinvertor *self, PyObject *args) {
+    double *pars;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+    double value;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    value = rg(pars, self->params.d_max, (int)npars, 101);
+
+    return Py_BuildValue("f", value );
+
+}
+
+const char get_iq0_doc[] =
+    "Returns the value of I(q=0).\n"
+    " @param args: c-parameters\n"
+    " @return: I(q=0)";
+
+static PyObject * get_iq0(Cinvertor *self, PyObject *args) {
+    double *pars;
+    PyObject *data_obj;
+    Py_ssize_t npars;
+    double value;
+
+    if (!PyArg_ParseTuple(args, "O", &data_obj)) return NULL;
+    VECTOR(data_obj,pars,npars);
+
+    value = 4.0*acos(-1.0)*int_pr(pars, self->params.d_max, (int)npars, 101);
+
+    return Py_BuildValue("f", value );
+
+}
+
+/**
+ * Check whether a q-value is within acceptable limits
+ * Return 1 if accepted, 0 if rejected.
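+ * For example (illustrative): with q_min = 0.01 and q_max = 0.25, a point
+ * at q = 0.005 or q = 0.3 is rejected while q = 0.1 is accepted. The
+ * defaults q_min = q_max = -1.0 set in invertor_init() disable both bounds.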
+ */
+int accept_q(Cinvertor *self, double q) {
+    if (self->params.q_min>0 && q<self->params.q_min) return 0;
+    if (self->params.q_max>0 && q>self->params.q_max) return 0;
+    return 1;
+}
+
+const char get_matrix_doc[] =
+    " Returns A matrix and b vector for least square problem.\n"
+    " @param nfunc: number of base functions\n"
+    " @param nr: number of r-points used when evaluating reg term.\n"
+    " @param a: A array to fill\n"
+    " @param b: b vector to fill\n"
+    " @return: 0";
+
+static PyObject * get_matrix(Cinvertor *self, PyObject *args) {
+    double *a;
+    double *b;
+    PyObject *a_obj;
+    PyObject *b_obj;
+    Py_ssize_t n_a;
+    Py_ssize_t n_b;
+    // Number of bins for regularization term evaluation
+    int nr, nfunc;
+    int i, j, i_r;
+    double r, sqrt_alpha, pi;
+    double tmp;
+    int offset;
+
+    if (!PyArg_ParseTuple(args, "iiOO", &nfunc, &nr, &a_obj, &b_obj)) return NULL;
+    VECTOR(a_obj,a,n_a);
+    VECTOR(b_obj,b,n_b);
+
+    assert(n_b>=nfunc);
+    assert(n_a>=nfunc*(nr+self->params.npoints));
+
+    sqrt_alpha = sqrt(self->params.alpha);
+    pi = acos(-1.0);
+    offset = (self->params.est_bck==1) ? 0 : 1;
+
+    for (j=0; j<nfunc; j++) {
+        for (i=0; i<self->params.npoints; i++) {
+            if (self->params.err[i]==0.0) {
+                PyErr_SetString(CinvertorError,
+                    "Cinvertor.get_matrix: Some I(Q) points have no error.");
+                return NULL;
+            }
+            if (accept_q(self, self->params.x[i])){
+                if (self->params.est_bck==1 && j==0) {
+                    a[i*nfunc+j] = 1.0/self->params.err[i];
+                } else {
+                    if (self->params.slit_width>0 || self->params.slit_height>0) {
+                        a[i*nfunc+j] = ortho_transformed_smeared(self->params.d_max,
+                            j+offset, self->params.slit_height, self->params.slit_width,
+                            self->params.x[i], 21)/self->params.err[i];
+                    } else {
+                        a[i*nfunc+j] = ortho_transformed(self->params.d_max, j+offset, self->params.x[i])/self->params.err[i];
+                    }
+                }
+            }
+        }
+        for (i_r=0; i_r<nr; i_r++) {
+            if (self->params.est_bck==1 && j==0) {
+                a[(i_r+self->params.npoints)*nfunc+j] = 0.0;
+            } else {
+                r = self->params.d_max/nr*i_r;
+                tmp = pi*(j+offset)/self->params.d_max;
+                a[(i_r+self->params.npoints)*nfunc+j] = sqrt_alpha * 1.0/nr*self->params.d_max*2.0*
+                    (2.0*pi*(j+offset)/self->params.d_max*cos(pi*(j+offset)*r/self->params.d_max) +
+                    tmp*tmp*r * sin(pi*(j+offset)*r/self->params.d_max));
+            }
+        }
+    }
+
+    for (i=0; i<self->params.npoints; i++) {
+        if (accept_q(self, self->params.x[i])){
+            b[i] = self->params.y[i]/self->params.err[i];
+        }
+    }
+
+    return Py_BuildValue("i", 0);
+
+}
+
+const char get_invcov_matrix_doc[] =
+    " Compute the inverse covariance matrix, defined as inv_cov = a_transposed x a.\n"
+    " @param nfunc: number of base functions\n"
+    " @param nr: number of r-points used when evaluating reg term.\n"
+    " @param a: A array to fill\n"
+    " @param inv_cov: inverse covariance array to be filled\n"
+    " @return: 0";
+
+static PyObject * get_invcov_matrix(Cinvertor *self, PyObject *args) {
+    double *a;
+    PyObject *a_obj;
+    Py_ssize_t n_a;
+    double *inv_cov;
+    PyObject *cov_obj;
+    Py_ssize_t n_cov;
+    int nr, nfunc;
+    int i, j, k;
+
+    if (!PyArg_ParseTuple(args, "iiOO", &nfunc, &nr, &a_obj, &cov_obj)) return NULL;
+    VECTOR(a_obj,a,n_a);
+    VECTOR(cov_obj,inv_cov,n_cov);
+
+    assert(n_cov>=nfunc*nfunc);
+    assert(n_a>=nfunc*(nr+self->params.npoints));
+
+    for (i=0; i<nfunc; i++) {
+        for (j=0; j<nfunc; j++) {
+            inv_cov[i*nfunc+j] = 0.0;
+            for (k=0; k<nr+self->params.npoints; k++) {
+                inv_cov[i*nfunc+j] += a[k*nfunc+i]*a[k*nfunc+j];
+            }
+        }
+    }
+    return Py_BuildValue("i", 0);
+}
+
+const char get_reg_size_doc[] =
+    " Compute the size of the signal and regularization terms for a given A matrix.\n"
+    " @param nfunc: number of base functions\n"
+    " @param nr: number of r-points used when evaluating reg term.\n"
+    " @param a: A array to fill\n"
+    " @return: (sum_sig, sum_reg)";
+
+static PyObject * get_reg_size(Cinvertor *self, PyObject *args) {
+    double *a;
+    PyObject *a_obj;
+    Py_ssize_t n_a;
+    int nr, nfunc;
+    int i, j;
+    double sum_sig, sum_reg;
+
+    if (!PyArg_ParseTuple(args, "iiO", &nfunc, &nr, &a_obj)) return NULL;
+    VECTOR(a_obj,a,n_a);
+
+    assert(n_a>=nfunc*(nr+self->params.npoints));
+
+    sum_sig = 0.0;
+    sum_reg = 0.0;
+    for (j=0; j<nfunc; j++){
+        for (i=0; i<self->params.npoints; i++){
+            if (accept_q(self, self->params.x[i])==1)
+                sum_sig += (a[i*nfunc+j])*(a[i*nfunc+j]);
+        }
+        for (i=0; i<nr; i++){
+            sum_reg += (a[(i+self->params.npoints)*nfunc+j])*(a[(i+self->params.npoints)*nfunc+j]);
+        }
+    }
+    return Py_BuildValue("ff", sum_sig, sum_reg);
+}
+
+const char eeeget_qmin_doc[] = "\
+This is a multiline doc string.\n\
+\n\
+This is the second line.";
+const char eeeset_qmin_doc[] =
+    "This is a multiline doc string.\n"
+    "\n"
+    "This is the second line.";
+
+static PyMethodDef Cinvertor_methods[] = {
+    {"residuals", (PyCFunction)residuals, METH_VARARGS, residuals_doc},
+    {"pr_residuals", (PyCFunction)pr_residuals, METH_VARARGS, pr_residuals_doc},
+    {"set_x", (PyCFunction)set_x, METH_VARARGS, set_x_doc},
+    {"get_x", (PyCFunction)get_x, METH_VARARGS, get_x_doc},
+    {"set_y", (PyCFunction)set_y, METH_VARARGS, set_y_doc},
+    {"get_y", (PyCFunction)get_y, METH_VARARGS, get_y_doc},
+    {"set_err", (PyCFunction)set_err, METH_VARARGS, set_err_doc},
+    {"get_err", (PyCFunction)get_err, METH_VARARGS, get_err_doc},
+    {"set_dmax", (PyCFunction)set_dmax, METH_VARARGS, set_dmax_doc},
+    {"get_dmax", (PyCFunction)get_dmax, METH_VARARGS, get_dmax_doc},
+    {"set_qmin", (PyCFunction)set_qmin, METH_VARARGS, set_qmin_doc},
+    {"get_qmin", (PyCFunction)get_qmin, METH_VARARGS, get_qmin_doc},
+    {"set_qmax", (PyCFunction)set_qmax, METH_VARARGS, set_qmax_doc},
+    {"get_qmax", (PyCFunction)get_qmax, METH_VARARGS, get_qmax_doc},
+    {"set_alpha", (PyCFunction)set_alpha, METH_VARARGS, set_alpha_doc},
+    {"get_alpha", (PyCFunction)get_alpha, METH_VARARGS, get_alpha_doc},
+    {"set_slit_width", (PyCFunction)set_slit_width, METH_VARARGS, set_slit_width_doc},
+    {"get_slit_width", (PyCFunction)get_slit_width, METH_VARARGS, get_slit_width_doc},
+    {"set_slit_height", (PyCFunction)set_slit_height, METH_VARARGS, set_slit_height_doc},
+    {"get_slit_height", (PyCFunction)get_slit_height, METH_VARARGS, get_slit_height_doc},
+    {"set_est_bck", (PyCFunction)set_est_bck, METH_VARARGS, set_est_bck_doc},
+    {"get_est_bck", (PyCFunction)get_est_bck, METH_VARARGS, get_est_bck_doc},
+    {"get_nx", (PyCFunction)get_nx, METH_VARARGS, get_nx_doc},
+    {"get_ny", (PyCFunction)get_ny, METH_VARARGS, get_ny_doc},
+    {"get_nerr", (PyCFunction)get_nerr, METH_VARARGS, get_nerr_doc},
+    {"iq", (PyCFunction)get_iq, METH_VARARGS, get_iq_doc},
+    {"iq_smeared", (PyCFunction)get_iq_smeared, METH_VARARGS, get_iq_smeared_doc},
+    {"pr", (PyCFunction)get_pr, METH_VARARGS, get_pr_doc},
+    {"get_pr_err", (PyCFunction)get_pr_err, METH_VARARGS, get_pr_err_doc},
+    {"is_valid", (PyCFunction)is_valid, METH_VARARGS, is_valid_doc},
+    {"basefunc_ft", (PyCFunction)basefunc_ft, METH_VARARGS, basefunc_ft_doc},
+    {"oscillations", (PyCFunction)oscillations, METH_VARARGS, oscillations_doc},
+    {"get_peaks", (PyCFunction)get_peaks, METH_VARARGS, get_peaks_doc},
+    {"get_positive", (PyCFunction)get_positive, METH_VARARGS, get_positive_doc},
+    {"get_pos_err", (PyCFunction)get_pos_err, METH_VARARGS, get_pos_err_doc},
+    {"rg", (PyCFunction)get_rg, METH_VARARGS, get_rg_doc},
+    {"iq0", (PyCFunction)get_iq0, METH_VARARGS, get_iq0_doc},
+    {"_get_matrix",
(PyCFunction)get_matrix, METH_VARARGS, get_matrix_doc}, + {"_get_invcov_matrix", (PyCFunction)get_invcov_matrix, METH_VARARGS, get_invcov_matrix_doc}, + {"_get_reg_size", (PyCFunction)get_reg_size, METH_VARARGS, get_reg_size_doc}, + + {NULL} +}; + +static PyTypeObject CinvertorType = { + //PyObject_HEAD_INIT(NULL) + //0, /*ob_size*/ + PyVarObject_HEAD_INIT(NULL, 0) + "Cinvertor", /*tp_name*/ + sizeof(Cinvertor), /*tp_basicsize*/ + 0, /*tp_itemsize*/ + (destructor)Cinvertor_dealloc, /*tp_dealloc*/ + 0, /*tp_print*/ + 0, /*tp_getattr*/ + 0, /*tp_setattr*/ + 0, /*tp_compare*/ + 0, /*tp_repr*/ + 0, /*tp_as_number*/ + 0, /*tp_as_sequence*/ + 0, /*tp_as_mapping*/ + 0, /*tp_hash */ + 0, /*tp_call*/ + 0, /*tp_str*/ + 0, /*tp_getattro*/ + 0, /*tp_setattro*/ + 0, /*tp_as_buffer*/ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /*tp_flags*/ + "Cinvertor objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + Cinvertor_methods, /* tp_methods */ + Cinvertor_members, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)Cinvertor_init, /* tp_init */ + 0, /* tp_alloc */ + Cinvertor_new, /* tp_new */ +}; + + +static PyMethodDef module_methods[] = { + {NULL} +}; + +/** + * Function used to add the model class to a module + * @param module: module to add the class to + */ +void addCinvertor(PyObject *module) { + PyObject *d; + + if (PyType_Ready(&CinvertorType) < 0) + return; + + Py_INCREF(&CinvertorType); + PyModule_AddObject(module, "Cinvertor", (PyObject *)&CinvertorType); + + d = PyModule_GetDict(module); + CinvertorError = PyErr_NewException("sas.sascalc.pr.invertor.Cinvertor.InvertorError", PyExc_RuntimeError, NULL); + PyDict_SetItemString(d, "CinvertorError", CinvertorError); +} + + +#define MODULE_DOC "C extension module for inversion to P(r)." 
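+/*
+ * Illustrative Python-side usage of the compiled module (a sketch; in
+ * SasView the class is normally reached through the higher-level
+ * sas.sascalc.pr.invertor.Invertor wrapper shown in the tutorial above):
+ *
+ *   from sas.sascalc.pr._pr_inversion import Cinvertor
+ *   c = Cinvertor()
+ *   c.set_dmax(160.0)
+ *   c.set_alpha(0.0007)
+ */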
+#define MODULE_NAME "_pr_inversion"
+#define MODULE_INIT2 init_pr_inversion
+#define MODULE_INIT3 PyInit__pr_inversion
+#define MODULE_METHODS module_methods
+
+/* ==== boilerplate python 2/3 interface bootstrap ==== */
+
+
+#if defined(WIN32) && !defined(__MINGW32__)
+    #define DLL_EXPORT __declspec(dllexport)
+#else
+    #define DLL_EXPORT
+#endif
+
+#if PY_MAJOR_VERSION >= 3
+
+  DLL_EXPORT PyMODINIT_FUNC MODULE_INIT3(void)
+  {
+    static struct PyModuleDef moduledef = {
+      PyModuleDef_HEAD_INIT,
+      MODULE_NAME,         /* m_name */
+      MODULE_DOC,          /* m_doc */
+      -1,                  /* m_size */
+      MODULE_METHODS,      /* m_methods */
+      NULL,                /* m_reload */
+      NULL,                /* m_traverse */
+      NULL,                /* m_clear */
+      NULL,                /* m_free */
+    };
+    PyObject* m = PyModule_Create(&moduledef);
+    addCinvertor(m);
+    return m;
+  }
+
+#else /* !PY_MAJOR_VERSION >= 3 */
+
+  DLL_EXPORT PyMODINIT_FUNC MODULE_INIT2(void)
+  {
+    PyObject* m = Py_InitModule4(MODULE_NAME,
+                                 MODULE_METHODS,
+                                 MODULE_DOC,
+                                 0,
+                                 PYTHON_API_VERSION
+                                 );
+    addCinvertor(m);
+  }
+
+#endif /* !PY_MAJOR_VERSION >= 3 */
diff --git a/sas/sascalc/pr/c_extensions/invertor.c b/sas/sascalc/pr/c_extensions/invertor.c
new file mode 100755
index 000000000..b6c5372c4
--- /dev/null
+++ b/sas/sascalc/pr/c_extensions/invertor.c
@@ -0,0 +1,314 @@
+#include <math.h>
+#include "invertor.h"
+#include <memory.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+double pi = 3.1416;
+
+/**
+ * Deallocate memory
+ */
+void invertor_dealloc(Invertor_params *pars) {
+    free(pars->x);
+    free(pars->y);
+    free(pars->err);
+}
+
+void invertor_init(Invertor_params *pars) {
+    pars->d_max = 180;
+    pars->q_min = -1.0;
+    pars->q_max = -1.0;
+    pars->est_bck = 0;
+}
+
+
+/**
+ * P(r) of a sphere, for test purposes
+ *
+ * @param R: radius of the sphere
+ * @param r: distance, in the same units as the radius
+ * @return: P(r)
+ */
+double pr_sphere(double R, double r) {
+    if (r <= 2.0*R) {
+        return 12.0* pow(0.5*r/R, 2.0) * pow(1.0-0.5*r/R, 2.0) * ( 2.0 + 0.5*r/R );
+    } else {
+        return 0.0;
+    }
+}
+
+/**
+ * Orthogonal functions:
+ * B(r) = 2r sin(pi*nr/d)
+ *
+ */
+double ortho(double d_max, int n, double r) {
+    return 2.0*r*sin(pi*n*r/d_max);
+}
+
+/**
+ * Fourier transform of the nth orthogonal function
+ *
+ */
+double ortho_transformed(double d_max, int n, double q) {
+    return 8.0*pow(pi, 2.0)/q * d_max * n * pow(-1.0, n+1)
+        *sin(q*d_max) / ( pow(pi*n, 2.0) - pow(q*d_max, 2.0) );
+}
+
+/**
+ * Slit-smeared Fourier transform of the nth orthogonal function.
+ * Smearing follows Lake, Acta Cryst. (1967) 23, 191.
+ */
+double ortho_transformed_smeared(double d_max, int n, double height, double width, double q, int npts) {
+    double sum, y, z;
+    int i, j, n_height, n_width;
+    double count_w;
+    double fnpts;
+    sum = 0.0;
+    fnpts = (float)npts-1.0;
+
+    // Check for zero slit size
+    n_height = (height>0) ? npts : 1;
+    n_width  = (width>0)  ? npts : 1;
+
+    count_w = 0.0;
+
+    for(j=0; j<n_height; j++) {
+        if(height>0){
+            z = height/fnpts*(float)j;
+        } else {
+            z = 0.0;
+        }
+
+        for(i=0; i<n_width; i++) {
+            if(width>0){
+                y = -width/2.0+width/fnpts*(float)i;
+            } else {
+                y = 0.0;
+            }
+            if (((q-y)*(q-y)+z*z)<=0.0) continue;
+            count_w += 1.0;
+            sum += ortho_transformed(d_max, n, sqrt((q-y)*(q-y)+z*z));
+        }
+    }
+    return sum/count_w;
+}
+
+/**
+ * First derivative of the orthogonal function, dB(r)/dr
+ *
+ */
+double ortho_derived(double d_max, int n, double r) {
+    return 2.0*sin(pi*n*r/d_max) + 2.0*r*cos(pi*n*r/d_max);
+}
+
+/**
+ * Scattering intensity calculated from the expansion.
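+ * Concretely, I(q) is built as the weighted sum of the Fourier-transformed
+ * base functions computed by ortho_transformed() above, with the expansion
+ * coefficients c as weights (a sketch: I(q) = sum_n c_n * FT[B_n](q)).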
+ */ +double iq(double *pars, double d_max, int n_c, double q) { + double sum = 0.0; + int i; + for (i=0; i0) { + *pr_value_err = sqrt(sum_err); + } else { + *pr_value_err = sum; + } +} + +/** + * dP(r)/dr calculated from the expansion. + */ +double dprdr(double *pars, double d_max, int n_c, double r) { + double sum = 0.0; + int i; + for (i=0; i0) count += 1; + slope = -1; + } + previous = value; + } + return count; +} + +/** + * Get the fraction of the integral of P(r) over the whole range + * of r that is above zero. + * A valid P(r) is define as being positive for all r. + */ +double positive_integral(double *pars, double d_max, int n_c, int nslice) { + double r; + double value; + int i; + double sum_pos = 0.0; + double sum = 0.0; + + for (i=0; i0.0) sum_pos += value; + sum += fabs(value); + } + return sum_pos/sum; +} + +/** + * Get the fraction of the integral of P(r) over the whole range + * of r that is at least one sigma above zero. + */ +double positive_errors(double *pars, double *err, double d_max, int n_c, int nslice) { + double r; + int i; + double sum_pos = 0.0; + double sum = 0.0; + double pr_val; + double pr_val_err; + + for (i=0; ipr_val_err) sum_pos += pr_val; + sum += fabs(pr_val); + + + } + return sum_pos/sum; +} + +/** + * R_g radius of gyration calculation + * + * R_g**2 = integral[r**2 * p(r) dr] / (2.0 * integral[p(r) dr]) + */ +double rg(double *pars, double d_max, int n_c, int nslice) { + double sum_r2 = 0.0; + double sum = 0.0; + double r; + double value; + int i; + for (i=0; i= self.qmin) & (self.x <= self.qmax) + self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \ + & (self.x <= self._qmax_unsmeared) + + def set_fit_range(self, qmin=None, qmax=None): + """ to set the fit range""" + # Skip Q=0 point, (especially for y(q=0)=None at x[0]). + # ToDo: Find better way to do it. + if qmin == 0.0 and not np.isfinite(self.y[qmin]): + self.qmin = min(self.x[self.x != 0]) + elif qmin is not None: + self.qmin = qmin + if qmax is not None: + self.qmax = qmax + # Determine the range needed in unsmeared-Q to cover + # the smeared Q range + self._qmin_unsmeared = self.qmin + self._qmax_unsmeared = self.qmax + + self._first_unsmeared_bin = 0 + self._last_unsmeared_bin = len(self.x) - 1 + + if self.smearer is not None: + self._first_unsmeared_bin, self._last_unsmeared_bin = \ + self.smearer.get_bin_range(self.qmin, self.qmax) + self._qmin_unsmeared = self.x[self._first_unsmeared_bin] + self._qmax_unsmeared = self.x[self._last_unsmeared_bin] + + # Identify the bin range for the unsmeared and smeared spaces + self.idx = (self.x >= self.qmin) & (self.x <= self.qmax) + ## zero error can not participate for fitting + self.idx = self.idx & (self.dy != 0) + self.idx_unsmeared = (self.x >= self._qmin_unsmeared) \ + & (self.x <= self._qmax_unsmeared) + + def get_fit_range(self): + """ + Return the range of data.x to fit + """ + return self.qmin, self.qmax + + def size(self): + """ + Number of measurement points in data set after masking, etc. + """ + return len(self.x) + + def residuals(self, fn): + """ + Compute residuals. + + If self.smearer has been set, use if to smear + the data before computing chi squared. 
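+
+        Residuals are computed as (y - f(x)) / dy over the fitted index
+        range, so sum(residuals**2) is the chi-squared statistic being
+        minimized.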
+ + :param fn: function that return model value + + :return: residuals + """ + # Compute theory data f(x) + fx = np.zeros(len(self.x)) + fx[self.idx_unsmeared] = fn(self.x[self.idx_unsmeared]) + + ## Smear theory data + if self.smearer is not None: + fx = self.smearer(fx, self._first_unsmeared_bin, + self._last_unsmeared_bin) + ## Sanity check + if np.size(self.dy) != np.size(fx): + msg = "FitData1D: invalid error array " + msg += "%d <> %d" % (np.shape(self.dy), np.size(fx)) + raise RuntimeError(msg) + return (self.y[self.idx] - fx[self.idx]) / self.dy[self.idx], fx[self.idx] + + def residuals_deriv(self, model, pars=[]): + """ + :return: residuals derivatives . + + :note: in this case just return empty array + """ + return [] + + +class FitData2D(Data2D): + """ + Wrapper class for SAS data + """ + def __init__(self, sas_data2d, data=None, err_data=None): + Data2D.__init__(self, data=data, err_data=err_data) + # Data can be initialized with a sas plottable or with vectors. + self.res_err_image = [] + self.num_points = 0 # will be set by set_data + self.idx = [] + self.qmin = None + self.qmax = None + self.smearer = None + self.radius = 0 + self.res_err_data = [] + self.sas_data = sas_data2d + self.set_data(sas_data2d) + + def set_data(self, sas_data2d, qmin=None, qmax=None): + """ + Determine the correct qx_data and qy_data within range to fit + """ + self.data = sas_data2d.data + self.err_data = sas_data2d.err_data + self.qx_data = sas_data2d.qx_data + self.qy_data = sas_data2d.qy_data + self.mask = sas_data2d.mask + + x_max = max(math.fabs(sas_data2d.xmin), math.fabs(sas_data2d.xmax)) + y_max = max(math.fabs(sas_data2d.ymin), math.fabs(sas_data2d.ymax)) + + ## fitting range + if qmin is None: + self.qmin = 1e-16 + if qmax is None: + self.qmax = math.sqrt(x_max * x_max + y_max * y_max) + ## new error image for fitting purpose + if self.err_data is None or self.err_data == []: + self.res_err_data = np.ones(len(self.data)) + else: + self.res_err_data = copy.deepcopy(self.err_data) + #self.res_err_data[self.res_err_data==0]=1 + + self.radius = np.sqrt(self.qx_data**2 + self.qy_data**2) + + # Note: mask = True: for MASK while mask = False for NOT to mask + self.idx = ((self.qmin <= self.radius) &\ + (self.radius <= self.qmax)) + self.idx = (self.idx) & (self.mask) + self.idx = (self.idx) & (np.isfinite(self.data)) + self.num_points = np.sum(self.idx) + + def set_smearer(self, smearer): + """ + Set smearer + """ + if smearer is None: + return + self.smearer = smearer + self.smearer.set_index(self.idx) + self.smearer.get_data() + + def set_fit_range(self, qmin=None, qmax=None): + """ + To set the fit range + """ + if qmin == 0.0: + self.qmin = 1e-16 + elif qmin is not None: + self.qmin = qmin + if qmax is not None: + self.qmax = qmax + self.radius = np.sqrt(self.qx_data**2 + self.qy_data**2) + self.idx = ((self.qmin <= self.radius) &\ + (self.radius <= self.qmax)) + self.idx = (self.idx) & (self.mask) + self.idx = (self.idx) & (np.isfinite(self.data)) + self.idx = (self.idx) & (self.res_err_data != 0) + + def get_fit_range(self): + """ + return the range of data.x to fit + """ + return self.qmin, self.qmax + + def size(self): + """ + Number of measurement points in data set after masking, etc. 
+ """ + return np.sum(self.idx) + + def residuals(self, fn): + """ + return the residuals + """ + if self.smearer is not None: + fn.set_index(self.idx) + # Get necessary data from self.data and set the data for smearing + fn.get_data() + + gn = fn.get_value() + else: + gn = fn([self.qx_data[self.idx], + self.qy_data[self.idx]]) + # use only the data point within ROI range + res = (self.data[self.idx] - gn) / self.res_err_data[self.idx] + + return res, gn + + def residuals_deriv(self, model, pars=[]): + """ + :return: residuals derivatives . + + :note: in this case just return empty array + + """ + return [] + + +class FitAbort(Exception): + """ + Exception raise to stop the fit + """ + #pass + #print"Creating fit abort Exception" + + + +class FitEngine: + def __init__(self): + """ + Base class for the fit engine + """ + #Dictionnary of fitArrange element (fit problems) + self.fit_arrange_dict = {} + self.fitter_id = None + + def set_model(self, model, id, pars=[], constraints=[], data=None): + """ + set a model on a given in the fit engine. + + :param model: sas.models type + :param id: is the key of the fitArrange dictionary where model is saved as a value + :param pars: the list of parameters to fit + :param constraints: list of + tuple (name of parameter, value of parameters) + the value of parameter must be a string to constraint 2 different + parameters. + Example: + we want to fit 2 model M1 and M2 both have parameters A and B. + constraints can be ``constraints = [(M1.A, M2.B+2), (M1.B= M2.A *5),...,]`` + + + :note: pars must contains only name of existing model's parameters + + """ + if not pars: + raise ValueError("no fitting parameters") + + if model is None: + raise ValueError("no model to fit") + + if not issubclass(model.__class__, Model): + model = Model(model, data) + + sasmodel = model.model + available_parameters = sasmodel.getParamList() + for p in pars: + if p not in available_parameters: + raise ValueError("parameter %s not available in model %s; use one of [%s] instead" + %(p, sasmodel.name, ", ".join(available_parameters))) + + if id not in self.fit_arrange_dict: + self.fit_arrange_dict[id] = FitArrange() + + self.fit_arrange_dict[id].set_model(model) + self.fit_arrange_dict[id].pars = pars + self.fit_arrange_dict[id].vals = [sasmodel.getParam(name) for name in pars] + self.fit_arrange_dict[id].constraints = constraints + + def set_data(self, data, id, smearer=None, qmin=None, qmax=None): + """ + Receives plottable, creates a list of data to fit,set data + in a FitArrange object and adds that object in a dictionary + with key id. 
+ + :param data: data added + :param id: unique key corresponding to a fitArrange object with data + """ + if data.__class__.__name__ == 'Data2D': + fitdata = FitData2D(sas_data2d=data, data=data.data, + err_data=data.err_data) + else: + fitdata = FitData1D(x=data.x, y=data.y, + dx=data.dx, dy=data.dy, smearer=smearer) + fitdata.sas_data = data + + fitdata.set_fit_range(qmin=qmin, qmax=qmax) + #A fitArrange is already created but contains model only at id + if id in self.fit_arrange_dict: + self.fit_arrange_dict[id].add_data(fitdata) + else: + #no fitArrange object has been create with this id + fitproblem = FitArrange() + fitproblem.add_data(fitdata) + self.fit_arrange_dict[id] = fitproblem + + def get_model(self, id): + """ + :param id: id is key in the dictionary containing the model to return + + :return: a model at this id or None if no FitArrange element was + created with this id + """ + if id in self.fit_arrange_dict: + return self.fit_arrange_dict[id].get_model() + else: + return None + + def remove_fit_problem(self, id): + """remove fitarrange in id""" + if id in self.fit_arrange_dict: + del self.fit_arrange_dict[id] + + def select_problem_for_fit(self, id, value): + """ + select a couple of model and data at the id position in dictionary + and set in self.selected value to value + + :param value: the value to allow fitting. + can only have the value one or zero + """ + if id in self.fit_arrange_dict: + self.fit_arrange_dict[id].set_to_fit(value) + + def get_problem_to_fit(self, id): + """ + return the self.selected value of the fit problem of id + + :param id: the id of the problem + """ + if id in self.fit_arrange_dict: + self.fit_arrange_dict[id].get_to_fit() + + +class FitArrange: + def __init__(self): + """ + Class FitArrange contains a set of data for a given model + to perform the Fit.FitArrange must contain exactly one model + and at least one data for the fit to be performed. 
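+
+        A typical lifetime, using the methods defined below, is set_model()
+        once, add_data() one or more times, then set_to_fit(1) to mark the
+        problem as active.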
+ + model: the model selected by the user + Ldata: a list of data what the user wants to fit + + """ + self.model = None + self.data_list = [] + self.pars = [] + self.vals = [] + self.selected = 0 + + def set_model(self, model): + """ + set_model save a copy of the model + + :param model: the model being set + """ + self.model = model + + def add_data(self, data): + """ + add_data fill a self.data_list with data to fit + + :param data: Data to add in the list + """ + if not data in self.data_list: + self.data_list.append(data) + + def get_model(self): + """ + :return: saved model + """ + return self.model + + def get_data(self): + """ + :return: list of data data_list + """ + return self.data_list[0] + + def remove_data(self, data): + """ + Remove one element from the list + + :param data: Data to remove from data_list + """ + if data in self.data_list: + self.data_list.remove(data) + + def set_to_fit(self, value=0): + """ + set self.selected to 0 or 1 for other values raise an exception + + :param value: integer between 0 or 1 + """ + self.selected = value + + def get_to_fit(self): + """ + return self.selected value + """ + return self.selected + +class FResult(object): + """ + Storing fit result + """ + def __init__(self, model=None, param_list=None, data=None): + self.calls = None + self.fitness = None + self.chisqr = None + self.pvec = [] + self.cov = [] + self.info = None + self.mesg = None + self.success = None + self.stderr = None + self.residuals = [] + self.index = [] + self.model = model + self.data = data + self.theory = [] + self.param_list = param_list + self.iterations = 0 + self.inputs = [] + self.fitter_id = None + if self.model is not None and self.data is not None: + self.inputs = [(self.model, self.data)] + + def set_model(self, model): + """ + """ + self.model = model + + def set_fitness(self, fitness): + """ + """ + self.fitness = fitness + + def __str__(self): + """ + """ + if self.pvec is None and self.model is None and self.param_list is None: + return "No results" + + sasmodel = self.model.model + pars = enumerate(sasmodel.getParamList()) + msg1 = "[Iteration #: %s ]" % self.iterations + msg3 = "=== goodness of fit: %s ===" % (str(self.fitness)) + msg2 = ["P%-3d %s......|.....%s" % (i, v, sasmodel.getParam(v)) + for i,v in pars if v in self.param_list] + msg = [msg1, msg3] + msg2 + return "\n".join(msg) + + def print_summary(self): + """ + """ + print(str(self)) diff --git a/sas/sascalc/pr/fit/BumpsFitting.py b/sas/sascalc/pr/fit/BumpsFitting.py new file mode 100755 index 000000000..5a99c2c24 --- /dev/null +++ b/sas/sascalc/pr/fit/BumpsFitting.py @@ -0,0 +1,367 @@ +""" +BumpsFitting module runs the bumps optimizer. 
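+
+Typical usage (an illustrative sketch based on the FitEngine API in this
+module, not a tested example):
+
+    fitter = BumpsFit()
+    fitter.set_data(data, id=1)
+    fitter.set_model(model, id=1, pars=['radius', 'length'])
+    fitter.select_problem_for_fit(id=1, value=1)
+    results = fitter.fit()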
+""" +from __future__ import division + +import os +from datetime import timedelta, datetime + +import numpy as np + +from bumps import fitters +try: + from bumps.options import FIT_CONFIG + # Default bumps to use the Levenberg-Marquardt optimizer + FIT_CONFIG.selected_id = fitters.LevenbergMarquardtFit.id + def get_fitter(): + return FIT_CONFIG.selected_fitter, FIT_CONFIG.selected_values +except: + # CRUFT: Bumps changed its handling of fit options around 0.7.5.6 + # Default bumps to use the Levenberg-Marquardt optimizer + fitters.FIT_DEFAULT = 'lm' + def get_fitter(): + fitopts = fitters.FIT_OPTIONS[fitters.FIT_DEFAULT] + return fitopts.fitclass, fitopts.options.copy() + + +from bumps.mapper import SerialMapper, MPMapper +from bumps import parameter +from bumps.fitproblem import FitProblem + + +from sas.sascalc.fit.AbstractFitEngine import FitEngine +from sas.sascalc.fit.AbstractFitEngine import FResult +from sas.sascalc.fit.expression import compile_constraints + +class Progress(object): + def __init__(self, history, max_step, pars, dof): + remaining_time = int(history.time[0]*(max_step/history.step[0]-1)) + # Depending on the time remaining, either display the expected + # time of completion, or the amount of time remaining. Use precision + # appropriate for the duration. + if remaining_time >= 1800: + completion_time = datetime.now() + timedelta(seconds=remaining_time) + if remaining_time >= 36000: + time = completion_time.strftime('%Y-%m-%d %H:%M') + else: + time = completion_time.strftime('%H:%M') + else: + if remaining_time >= 3600: + time = '%dh %dm'%(remaining_time//3600, (remaining_time%3600)//60) + elif remaining_time >= 60: + time = '%dm %ds'%(remaining_time//60, remaining_time%60) + else: + time = '%ds'%remaining_time + chisq = "%.3g"%(2*history.value[0]/dof) + step = "%d of %d"%(history.step[0], max_step) + header = "=== Steps: %s chisq: %s ETA: %s\n"%(step, chisq, time) + parameters = ["%15s: %-10.3g%s"%(k,v,("\n" if i%3==2 else " | ")) + for i,(k,v) in enumerate(zip(pars,history.point[0]))] + self.msg = "".join([header]+parameters) + + def __str__(self): + return self.msg + + +class BumpsMonitor(object): + def __init__(self, handler, max_step, pars, dof): + self.handler = handler + self.max_step = max_step + self.pars = pars + self.dof = dof + + def config_history(self, history): + history.requires(time=1, value=2, point=1, step=1) + + def __call__(self, history): + if self.handler is None: return + self.handler.set_result(Progress(history, self.max_step, self.pars, self.dof)) + self.handler.progress(history.step[0], self.max_step) + if len(history.step)>1 and history.step[1] > history.step[0]: + self.handler.improvement() + self.handler.update_fit() + +class ConvergenceMonitor(object): + """ + ConvergenceMonitor contains population summary statistics to show progress + of the fit. This is a list [ (best, 0%, 25%, 50%, 75%, 100%) ] or + just a list [ (best, ) ] if population size is 1. + """ + def __init__(self): + self.convergence = [] + + def config_history(self, history): + history.requires(value=1, population_values=1) + + def __call__(self, history): + best = history.value[0] + try: + p = history.population_values[0] + n,p = len(p), np.sort(p) + QI,Qmid, = int(0.2*n),int(0.5*n) + self.convergence.append((best, p[0],p[QI],p[Qmid],p[-1-QI],p[-1])) + except: + self.convergence.append((best, best,best,best,best,best)) + + +# Note: currently using bumps parameters for each parameter object so that +# a SasFitness can be used directly in bumps with the usual semantics. 
+# The disadvantage of this technique is that we need to copy every parameter +# back into the model each time the function is evaluated. We could instead +# define reference parameters for each sas parameter, but then we would not +# be able to express constraints using python expressions in the usual way +# from bumps, and would instead need to use string expressions. +class SasFitness(object): + """ + Wrap SAS model as a bumps fitness object + """ + def __init__(self, model, data, fitted=[], constraints={}, + initial_values=None, **kw): + self.name = model.name + self.model = model.model + self.data = data + if self.data.smearer is not None: + self.data.smearer.model = self.model + self._define_pars() + self._init_pars(kw) + if initial_values is not None: + self._reset_pars(fitted, initial_values) + self.constraints = dict(constraints) + self.set_fitted(fitted) + self.update() + + def _reset_pars(self, names, values): + for k,v in zip(names, values): + self._pars[k].value = v + + def _define_pars(self): + self._pars = {} + for k in self.model.getParamList(): + name = ".".join((self.name,k)) + value = self.model.getParam(k) + bounds = self.model.details.get(k,["",None,None])[1:3] + self._pars[k] = parameter.Parameter(value=value, bounds=bounds, + fixed=True, name=name) + #print parameter.summarize(self._pars.values()) + + def _init_pars(self, kw): + for k,v in kw.items(): + # dispersion parameters initialized with _field instead of .field + if k.endswith('_width'): k = k[:-6]+'.width' + elif k.endswith('_npts'): k = k[:-5]+'.npts' + elif k.endswith('_nsigmas'): k = k[:-7]+'.nsigmas' + elif k.endswith('_type'): k = k[:-5]+'.type' + if k not in self._pars: + formatted_pars = ", ".join(sorted(self._pars.keys())) + raise KeyError("invalid parameter %r for %s--use one of: %s" + %(k, self.model, formatted_pars)) + if '.' in k and not k.endswith('.width'): + self.model.setParam(k, v) + elif isinstance(v, parameter.BaseParameter): + self._pars[k] = v + elif isinstance(v, (tuple,list)): + low, high = v + self._pars[k].value = (low+high)/2 + self._pars[k].range(low,high) + else: + self._pars[k].value = v + + def set_fitted(self, param_list): + """ + Flag a set of parameters as fitted parameters. 
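+
+        Parameters that also appear in self.constraints stay fixed for the
+        optimizer; their values are recomputed from the constraint
+        expressions instead of being varied directly.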
+ """ + for k,p in self._pars.items(): + p.fixed = (k not in param_list or k in self.constraints) + self.fitted_par_names = [k for k in param_list if k not in self.constraints] + self.computed_par_names = [k for k in param_list if k in self.constraints] + self.fitted_pars = [self._pars[k] for k in self.fitted_par_names] + self.computed_pars = [self._pars[k] for k in self.computed_par_names] + + # ===== Fitness interface ==== + def parameters(self): + return self._pars + + def update(self): + for k,v in self._pars.items(): + #print "updating",k,v,v.value + self.model.setParam(k,v.value) + self._dirty = True + + def _recalculate(self): + if self._dirty: + self._residuals, self._theory \ + = self.data.residuals(self.model.evalDistribution) + self._dirty = False + + def numpoints(self): + return np.sum(self.data.idx) # number of fitted points + + def nllf(self): + return 0.5*np.sum(self.residuals()**2) + + def theory(self): + self._recalculate() + return self._theory + + def residuals(self): + self._recalculate() + return self._residuals + + # Not implementing the data methods for now: + # + # resynth_data/restore_data/save/plot + +class ParameterExpressions(object): + def __init__(self, models): + self.models = models + self._setup() + + def _setup(self): + exprs = {} + for M in self.models: + exprs.update((".".join((M.name, k)), v) for k, v in M.constraints.items()) + if exprs: + symtab = dict((".".join((M.name, k)), p) + for M in self.models + for k,p in M.parameters().items()) + self.update = compile_constraints(symtab, exprs) + else: + self.update = lambda: 0 + + def __call__(self): + self.update() + + def __getstate__(self): + return self.models + + def __setstate__(self, state): + self.models = state + self._setup() + +class BumpsFit(FitEngine): + """ + Fit a model using bumps. + """ + def __init__(self): + """ + Creates a dictionary (self.fit_arrange_dict={})of FitArrange elements + with Uid as keys + """ + FitEngine.__init__(self) + self.curr_thread = None + + def fit(self, msg_q=None, + q=None, handler=None, curr_thread=None, + ftol=1.49012e-8, reset_flag=False): + # Build collection of bumps fitness calculators + models = [SasFitness(model=M.get_model(), + data=M.get_data(), + constraints=M.constraints, + fitted=M.pars, + initial_values=M.vals if reset_flag else None) + for M in self.fit_arrange_dict.values() + if M.get_to_fit()] + if len(models) == 0: + raise RuntimeError("Nothing to fit") + problem = FitProblem(models) + + # TODO: need better handling of parameter expressions and bounds constraints + # so that they are applied during polydispersity calculations. This + # will remove the immediate need for the setp_hook in bumps, though + # bumps may still need something similar, such as a sane class structure + # which allows a subclass to override setp. 
+ problem.setp_hook = ParameterExpressions(models) + + # Run the fit + result = run_bumps(problem, handler, curr_thread) + if handler is not None: + handler.update_fit(last=True) + + # TODO: shouldn't reference internal parameters of fit problem + varying = problem._parameters + # collect the results + all_results = [] + for M in problem.models: + fitness = M.fitness + fitted_index = [varying.index(p) for p in fitness.fitted_pars] + param_list = fitness.fitted_par_names + fitness.computed_par_names + R = FResult(model=fitness.model, data=fitness.data, + param_list=param_list) + R.theory = fitness.theory() + R.residuals = fitness.residuals() + R.index = fitness.data.idx + R.fitter_id = self.fitter_id + # TODO: should scale stderr by sqrt(chisq/DOF) if dy is unknown + R.success = result['success'] + if R.success: + R.stderr = np.hstack((result['stderr'][fitted_index], + np.NaN*np.ones(len(fitness.computed_pars)))) + R.pvec = np.hstack((result['value'][fitted_index], + [p.value for p in fitness.computed_pars])) + R.fitness = np.sum(R.residuals**2)/(fitness.numpoints() - len(fitted_index)) + else: + R.stderr = np.NaN*np.ones(len(param_list)) + R.pvec = np.asarray( [p.value for p in fitness.fitted_pars+fitness.computed_pars]) + R.fitness = np.NaN + R.convergence = result['convergence'] + if result['uncertainty'] is not None: + R.uncertainty_state = result['uncertainty'] + all_results.append(R) + + if q is not None: + q.put(all_results) + return q + else: + return all_results + +def run_bumps(problem, handler, curr_thread): + def abort_test(): + if curr_thread is None: return False + try: curr_thread.isquit() + except KeyboardInterrupt: + if handler is not None: + handler.stop("Fitting: Terminated!!!") + return True + return False + + fitclass, options = get_fitter() + steps = options.get('steps', 0) + if steps == 0: + pop = options.get('pop',0)*len(problem._parameters) + samples = options.get('samples', 0) + steps = (samples+pop-1)/pop if pop != 0 else samples + max_step = steps + options.get('burn', 0) + pars = [p.name for p in problem._parameters] + #x0 = np.asarray([p.value for p in problem._parameters]) + options['monitors'] = [ + BumpsMonitor(handler, max_step, pars, problem.dof), + ConvergenceMonitor(), + ] + fitdriver = fitters.FitDriver(fitclass, problem=problem, + abort_test=abort_test, **options) + omp_threads = int(os.environ.get('OMP_NUM_THREADS','0')) + mapper = MPMapper if omp_threads == 1 else SerialMapper + fitdriver.mapper = mapper.start_mapper(problem, None) + #import time; T0 = time.time() + try: + best, fbest = fitdriver.fit() + except: + import traceback; traceback.print_exc() + raise + finally: + mapper.stop_mapper(fitdriver.mapper) + + + convergence_list = options['monitors'][-1].convergence + convergence = (2*np.asarray(convergence_list)/problem.dof + if convergence_list else np.empty((0,1),'d')) + + success = best is not None + return { + 'value': best if success else None, + 'stderr': fitdriver.stderr() if success else None, + 'success': success, + 'convergence': convergence, + 'uncertainty': getattr(fitdriver.fitter, 'state', None), + } + diff --git a/sas/sascalc/pr/fit/Loader.py b/sas/sascalc/pr/fit/Loader.py new file mode 100755 index 000000000..91d8ed214 --- /dev/null +++ b/sas/sascalc/pr/fit/Loader.py @@ -0,0 +1,88 @@ +""" +class Loader to load any kind of file +""" + +from __future__ import print_function + +import numpy as np + +class Load: + """ + This class is loading values from given file or value giving by the user + """ + def __init__(self, x=None, y=None, 
dx=None, dy=None): + raise NotImplementedError("a code search shows that this code is not active, and you are not seeing this message") + # variable to store loaded values + self.x = x + self.y = y + self.dx = dx + self.dy = dy + self.filename = None + + def set_filename(self, path=None): + """ + Store path into a variable.If the user doesn't give + a path as a parameter a pop-up + window appears to select the file. + + :param path: the path given by the user + + """ + self.filename = path + + def get_filename(self): + """ return the file's path""" + return self.filename + + def set_values(self): + """ Store the values loaded from file in local variables""" + if self.filename is not None: + input_f = open(self.filename, 'r') + buff = input_f.read() + lines = buff.split('\n') + self.x = [] + self.y = [] + self.dx = [] + self.dy = [] + for line in lines: + try: + toks = line.split() + x = float(toks[0]) + y = float(toks[1]) + dy = float(toks[2]) + + self.x.append(x) + self.y.append(y) + self.dy.append(dy) + self.dx = np.zeros(len(self.x)) + except: + print("READ ERROR", line) + # Sanity check + if not len(self.x) == len(self.dx): + raise ValueError("x and dx have different length") + if not len(self.y) == len(self.dy): + raise ValueError("y and dy have different length") + + + def get_values(self): + """ Return x, y, dx, dy""" + return self.x, self.y, self.dx, self.dy + + def load_data(self, data): + """ Return plottable""" + #load data + data.x = self.x + data.y = self.y + data.dx = self.dx + data.dy = self.dy + #Load its View class + #plottable.reset_view() + + +if __name__ == "__main__": + load = Load() + load.set_filename("testdata_line.txt") + print(load.get_filename()) + load.set_values() + print(load.get_values()) + diff --git a/sas/sascalc/pr/fit/__init__.py b/sas/sascalc/pr/fit/__init__.py new file mode 100755 index 000000000..03d374576 --- /dev/null +++ b/sas/sascalc/pr/fit/__init__.py @@ -0,0 +1 @@ +from .AbstractFitEngine import FitHandler \ No newline at end of file diff --git a/sas/sascalc/pr/fit/expression.py b/sas/sascalc/pr/fit/expression.py new file mode 100755 index 000000000..3c36c1bf2 --- /dev/null +++ b/sas/sascalc/pr/fit/expression.py @@ -0,0 +1,405 @@ +from __future__ import print_function + +# This program is public domain +""" +Parameter expression evaluator. + +For systems in which constraints are expressed as string expressions rather +than python code, :func:`compile_constraints` can construct an expression +evaluator that substitutes the computed values of the expressions into the +parameters. + +The compiler requires a symbol table, an expression set and a context. +The symbol table maps strings containing fully qualified names such as +'M1.c[3].full_width' to parameter objects with a 'value' property that +can be queried and set. The expression set maps symbol names from the +symbol table to string expressions. The context provides additional symbols +for the expressions in addition to the usual mathematical functions and +constants. + +The expressions are compiled and interpreted by python, with only minimal +effort to make sure that they don't contain bad code. The resulting +constraints function returns 0 so it can be used directly in a fit problem +definition. + +Extracting the symbol table from the model depends on the structure of the +model. If fitness.parameters() is set correctly, then this should simply +be a matter of walking the parameter data, remembering the path to each +parameter in the symbol table. 
For compactness, dictionary elements should +be referenced by .name rather than ["name"]. Model name can be used as the +top level. + +Getting the parameter expressions applied correctly is challenging. +The following monkey patch works by overriding model_update in FitProblem +so that after setp(p) is called and, the constraints expression can be +applied before telling the underlying fitness function that the model +is out of date:: + + # Override model update so that parameter constraints are applied + problem._model_update = problem.model_update + def model_update(): + constraints() + problem._model_update() + problem.model_update = model_update + +Ideally, this interface will change +""" +import math +import re + +# simple pattern which matches symbols. Note that it will also match +# invalid substrings such as a3...9, but given syntactically correct +# input it will only match symbols. +_symbol_pattern = re.compile('([a-zA-Z_][a-zA-Z_0-9.]*)') + +def _symbols(expr,symtab): + """ + Given an expression string and a symbol table, return the set of symbols + used in the expression. Symbols are only returned once even if they + occur multiple times. The return value is a set with the elements in + no particular order. + + This is the first step in computing a dependency graph. + """ + matches = [m.group(0) for m in _symbol_pattern.finditer(expr)] + return set([symtab[m] for m in matches if m in symtab]) + +def _substitute(expr,mapping): + """ + Replace all occurrences of symbol s with mapping[s] for s in mapping. + """ + # Find the symbols and the mapping + matches = [(m.start(),m.end(),mapping[m.group(1)]) + for m in _symbol_pattern.finditer(expr) + if m.group(1) in mapping] + + # Split the expression in to pieces, with new symbols replacing old + pieces = [] + offset = 0 + for start,end,text in matches: + pieces += [expr[offset:start],text] + offset = end + pieces.append(expr[offset:]) + + # Join the pieces and return them + return "".join(pieces) + +def _find_dependencies(symtab, exprs): + """ + Returns a list of pair-wise dependencies from the parameter expressions. + + For example, if p3 = p1+p2, then find_dependencies([p1,p2,p3]) will + return [(p3,p1),(p3,p2)]. For base expressions without dependencies, + such as p4 = 2*pi, this should return [(p4, None)] + """ + deps = [(target,source) + for target,expr in exprs.items() + for source in _symbols_or_none(expr,symtab)] + return deps + +# Hack to deal with expressions without dependencies --- return a fake +# dependency of None. +# The better solution is fix order_dependencies so that it takes a +# dictionary of {symbol: dependency_list}, for which no dependencies +# is simply []; fix in parameter_mapping as well +def _symbols_or_none(expr,symtab): + syms = _symbols(expr,symtab) + return syms if len(syms) else [None] + +def _parameter_mapping(pairs): + """ + Find the parameter substitution we need so that expressions can + be evaluated without having to traverse a chain of + model.layer.parameter.value + """ + left,right = zip(*pairs) + pars = list(sorted(p for p in set(left+right) if p is not None)) + definition = dict( ('P%d'%i,p) for i,p in enumerate(pars) ) + # p is None when there is an expression with no dependencies + substitution = dict( (p,'P%d.value'%i) + for i,p in enumerate(sorted(pars)) + if p is not None) + return definition, substitution + +def no_constraints(): + """ + This parameter set has no constraints between the parameters. 
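+
+    compile_constraints returns this function when the expression set has
+    no dependencies, so callers can always invoke the result as a no-op.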
+ """ + pass + +def compile_constraints(symtab, exprs, context={}): + """ + Build and return a function to evaluate all parameter expressions in + the proper order. + + Input: + + *symtab* is the symbol table for the model: { 'name': parameter } + + *exprs* is the set of computed symbols: { 'name': 'expression' } + + *context* is any additional context needed to evaluate the expression + + Return: + + updater function which sets parameter.value for each expression + + Raises: + + AssertionError - model, parameter or function is missing + + SyntaxError - improper expression syntax + + ValueError - expressions have circular dependencies + + This function is not terribly sophisticated, and it would be easy to + trick. However it handles the common cases cleanly and generates + reasonable messages for the common errors. + + This code has not been fully audited for security. While we have + removed the builtins and the ability to import modules, there may + be other vectors for users to perform more than simple function + evaluations. Unauthenticated users should not be running this code. + + Parameter names are assumed to contain only _.a-zA-Z0-9#[] + + Both names are provided for inverse functions, e.g., acos and arccos. + + Should try running the function to identify syntax errors before + running it in a fit. + + Use help(fn) to see the code generated for the returned function fn. + dis.dis(fn) will show the corresponding python vm instructions. + """ + + # Sort the parameters in the order they need to be evaluated + deps = _find_dependencies(symtab, exprs) + if deps == []: return no_constraints + order = order_dependencies(deps) + + + # Rather than using the full path to the parameters in the parameter + # expressions, instead use Pn, and substitute Pn.value for each occurrence + # of the parameter in the expression. + names = list(sorted(symtab.keys())) + parameters = dict(('P%d'%i, symtab[k]) for i,k in enumerate(names)) + mapping = dict((k, 'P%d.value'%i) for i,k in enumerate(names)) + + + # Initialize dictionary with available functions + globals = {} + globals.update(math.__dict__) + globals.update(dict(arcsin=math.asin,arccos=math.acos, + arctan=math.atan,arctan2=math.atan2)) + globals.update(context) + globals.update(parameters) + globals['id'] = id + locals = {} + + # Define the constraints function + assignments = ["=".join((p,exprs[p])) for p in order] + code = [_substitute(s, mapping) for s in assignments] + functiondef = """ +def eval_expressions(): + ''' + %s + ''' + %s + return 0 +"""%("\n ".join(assignments),"\n ".join(code)) + + #print("Function: "+functiondef) + exec functiondef in globals,locals + retfn = locals['eval_expressions'] + + # Remove garbage added to globals by exec + globals.pop('__doc__',None) + globals.pop('__name__',None) + globals.pop('__file__',None) + globals.pop('__builtins__') + #print globals.keys() + + return retfn + +def order_dependencies(pairs): + """ + Order elements from pairs so that b comes before a in the + ordered list for all pairs (a,b). 
+ """ + #print "order_dependencies",pairs + emptyset = set() + order = [] + + # Break pairs into left set and right set + left,right = [set(s) for s in zip(*pairs)] if pairs != [] else ([],[]) + while pairs != []: + #print "within",pairs + # Find which items only occur on the right + independent = right - left + if independent == emptyset: + cycleset = ", ".join(str(s) for s in left) + raise ValueError("Cyclic dependencies amongst %s"%cycleset) + + # The possibly resolvable items are those that depend on the independents + dependent = set([a for a,b in pairs if b in independent]) + pairs = [(a,b) for a,b in pairs if b not in independent] + if pairs == []: + resolved = dependent + else: + left,right = [set(s) for s in zip(*pairs)] + resolved = dependent - left + #print "independent",independent,"dependent",dependent,"resolvable",resolved + order += resolved + #print "new order",order + order.reverse() + return order + +# ========= Test code ======== +def _check(msg,pairs): + """ + Verify that the list n contains the given items, and that the list + satisfies the partial ordering given by the pairs in partial order. + """ + left,right = zip(*pairs) if pairs != [] else ([],[]) + items = set(left) + n = order_dependencies(pairs) + if set(n) != items or len(n) != len(items): + n.sort() + items = list(items); items.sort() + raise ValueError("%s expect %s to contain %s for %s"%(msg,n,items,pairs)) + for lo,hi in pairs: + if lo in n and hi in n and n.index(lo) >= n.index(hi): + raise ValueError("%s expect %s before %s in %s for %s"%(msg,lo,hi,n,pairs)) + +def test_deps(): + import numpy as np + + # Null case + _check("test empty",[]) + + # Some dependencies + _check("test1",[(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)]) + _check("test1 renumbered",[(6,1),(7,3),(7,4),(6,7),(5,7),(3,2)]) + _check("test1 numpy",np.array([(2,7),(1,5),(1,4),(2,1),(3,1),(5,6)])) + + # No dependencies + _check("test2",[(4,1),(3,2),(8,4)]) + + # Cycle test + pairs = [(1,4),(4,3),(4,5),(5,1)] + try: + n = order_dependencies(pairs) + except ValueError: + pass + else: + raise Exception("test3 expect ValueError exception for %s"%(pairs,)) + + # large test for gross speed check + A = np.random.randint(4000,size=(1000,2)) + A[:,1] += 4000 # Avoid cycles + _check("test-large",A) + + # depth tests + k = 200 + A = np.array([range(0,k),range(1,k+1)]).T + _check("depth-1",A) + + A = np.array([range(1,k+1),range(0,k)]).T + _check("depth-2",A) + +def test_expr(): + import inspect, dis + import math + + symtab = {'a.b.x':1, 'a.c':2, 'a.b':3, 'b.x':4} + expr = 'a.b.x + sin(4*pi*a.c) + a.b.x/a.b' + + # Check symbol lookup + assert _symbols(expr, symtab) == set([1,2,3]) + + # Check symbol rename + assert _substitute(expr,{'a.b.x':'Q'}) == 'Q + sin(4*pi*a.c) + Q/a.b' + assert _substitute(expr,{'a.b':'Q'}) == 'a.b.x + sin(4*pi*a.c) + a.b.x/Q' + + + # Check dependency builder + # Fake parameter class + class Parameter: + def __init__(self, name, value=0, expression=''): + self.path = name + self.value = value + self.expression = expression + def iscomputed(self): return (self.expression != '') + def __repr__(self): return self.path + def world(*pars): + symtab = dict((p.path,p) for p in pars) + exprs = dict((p.path,p.expression) for p in pars if p.iscomputed()) + return symtab, exprs + p1 = Parameter('G0.sigma',5) + p2 = Parameter('other',expression='2*pi*sin(G0.sigma/.1875) + M1.G1') + p3 = Parameter('M1.G1',6) + p4 = Parameter('constant',expression='2*pi*35') + # Simple chain + assert set(_find_dependencies(*world(p1,p2,p3))) == 
set([(p2.path,p1),(p2.path,p3)]) + # Constant expression + assert set(_find_dependencies(*world(p1,p4))) == set([(p4.path,None)]) + # No dependencies + assert set(_find_dependencies(*world(p1,p3))) == set([]) + + # Check function builder + fn = compile_constraints(*world(p1,p2,p3)) + + # Inspect the resulting function + if 0: + print(inspect.getdoc(fn)) + print(dis.dis(fn)) + + # Evaluate the function and see if it updates the + # target value as expected + fn() + expected = 2*math.pi*math.sin(5/.1875) + 6 + assert p2.value == expected,"Value was %s, not %s"%(p2.value,expected) + + # Check empty dependency set doesn't crash + fn = compile_constraints(*world(p1,p3)) + fn() + + # Check that constants are evaluated properly + fn = compile_constraints(*world(p4)) + fn() + assert p4.value == 2*math.pi*35 + + # Check additional context example; this also tests multiple + # expressions + class Table: + Si = 2.09 + values = {'Si': 2.07} + tbl = Table() + p5 = Parameter('lookup',expression="tbl.Si") + fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl)) + fn() + assert p5.value == 2.09,"Value for %s was %s"%(p5.expression,p5.value) + p5.expression = "tbl.values['Si']" + fn = compile_constraints(*world(p1,p2,p3,p5),context=dict(tbl=tbl)) + fn() + assert p5.value == 2.07,"Value for %s was %s"%(p5.expression,p5.value) + + + # Verify that we capture invalid expressions + for expr in ['G4.cage', 'M0.cage', 'M1.G1 + *2', + 'piddle', + '5; import sys; print "p0wned"', + '__import__("sys").argv']: + try: + p6 = Parameter('broken',expression=expr) + fn = compile_constraints(*world(p6)) + fn() + except Exception as msg: + #print(msg) + pass + else: + raise "Failed to raise error for %s"%expr + +if __name__ == "__main__": + test_expr() + test_deps() diff --git a/sas/sascalc/pr/invertor.py b/sas/sascalc/pr/invertor.py new file mode 100755 index 000000000..07afd4a56 --- /dev/null +++ b/sas/sascalc/pr/invertor.py @@ -0,0 +1,756 @@ +# pylint: disable=invalid-name +""" +Module to perform P(r) inversion. +The module contains the Invertor class. + +FIXME: The way the Invertor interacts with its C component should be cleaned up +""" +from __future__ import division + +import numpy as np +import sys +import math +import time +import copy +import os +import re +import logging +from numpy.linalg import lstsq +from scipy import optimize +from sas.sascalc.pr._pr_inversion import Cinvertor + +logger = logging.getLogger(__name__) + +def help(): + """ + Provide general online help text + Future work: extend this function to allow topic selection + """ + info_txt = "The inversion approach is based on Moore, J. Appl. Cryst. " + info_txt += "(1980) 13, 168-175.\n\n" + info_txt += "P(r) is set to be equal to an expansion of base functions " + info_txt += "of the type " + info_txt += "phi_n(r) = 2*r*sin(pi*n*r/D_max). The coefficient of each " + info_txt += "base functions " + info_txt += "in the expansion is found by performing a least square fit " + info_txt += "with the " + info_txt += "following fit function:\n\n" + info_txt += "chi**2 = sum_i[ I_meas(q_i) - I_th(q_i) ]**2/error**2 +" + info_txt += "Reg_term\n\n" + info_txt += "where I_meas(q) is the measured scattering intensity and " + info_txt += "I_th(q) is " + info_txt += "the prediction from the Fourier transform of the P(r) " + info_txt += "expansion. " + info_txt += "The Reg_term term is a regularization term set to the second" + info_txt += " derivative " + info_txt += "d**2P(r)/dr**2 integrated over r. 
It is used to produce " + info_txt += "a smooth P(r) output.\n\n" + info_txt += "The following are user inputs:\n\n" + info_txt += " - Number of terms: the number of base functions in the P(r)" + info_txt += " expansion.\n\n" + info_txt += " - Regularization constant: a multiplicative constant " + info_txt += "to set the size of " + info_txt += "the regularization term.\n\n" + info_txt += " - Maximum distance: the maximum distance between any " + info_txt += "two points in the system.\n" + + return info_txt + + +class Invertor(Cinvertor): + """ + Invertor class to perform P(r) inversion + + The problem is solved by posing the problem as Ax = b, + where x is the set of coefficients we are looking for. + + Npts is the number of points. + + In the following i refers to the ith base function coefficient. + The matrix has its entries j in its first Npts rows set to :: + + A[j][i] = (Fourier transformed base function for point j) + + We then choose a number of r-points, n_r, to evaluate the second + derivative of P(r) at. This is used as our regularization term. + For a vector r of length n_r, the following n_r rows are set to :: + + A[j+Npts][i] = (2nd derivative of P(r), d**2(P(r))/d(r)**2, + evaluated at r[j]) + + The vector b has its first Npts entries set to :: + + b[j] = (I(q) observed for point j) + + The following n_r entries are set to zero. + + The result is found by using scipy.linalg.basic.lstsq to invert + the matrix and find the coefficients x. + + Methods inherited from Cinvertor: + + * ``get_peaks(pars)``: returns the number of P(r) peaks + * ``oscillations(pars)``: returns the oscillation parameters for the output P(r) + * ``get_positive(pars)``: returns the fraction of P(r) that is above zero + * ``get_pos_err(pars)``: returns the fraction of P(r) that is 1-sigma above zero + """ + ## Chisqr of the last computation + chi2 = 0 + ## Time elapsed for last computation + elapsed = 0 + ## Alpha to get the reg term the same size as the signal + suggested_alpha = 0 + ## Last number of base functions used + nfunc = 10 + ## Last output values + out = None + ## Last errors on output values + cov = None + ## Background value + background = 0 + ## Information dictionary for application use + info = {} + + def __init__(self): + Cinvertor.__init__(self) + + def __setstate__(self, state): + """ + restore the state of invertor for pickle + """ + (self.__dict__, self.alpha, self.d_max, + self.q_min, self.q_max, + self.x, self.y, + self.err, self.est_bck, + self.slit_height, self.slit_width) = state + + def __reduce_ex__(self, proto): + """ + Overwrite the __reduce_ex__ + """ + + state = (self.__dict__, + self.alpha, self.d_max, + self.q_min, self.q_max, + self.x, self.y, + self.err, self.est_bck, + self.slit_height, self.slit_width, + ) + return (Invertor, tuple(), state, None, None) + + def __setattr__(self, name, value): + """ + Set the value of an attribute. + Access the parent class methods for + x, y, err, d_max, q_min, q_max and alpha + """ + if name == 'x': + if 0.0 in value: + msg = "Invertor: one of your q-values is zero. " + msg += "Delete that entry before proceeding" + raise ValueError(msg) + return self.set_x(value) + elif name == 'y': + return self.set_y(value) + elif name == 'err': + value2 = abs(value) + return self.set_err(value2) + elif name == 'd_max': + if value <= 0.0: + msg = "Invertor: d_max must be greater than zero." 
+ msg += "Correct that entry before proceeding" + raise ValueError(msg) + return self.set_dmax(value) + elif name == 'q_min': + if value is None: + return self.set_qmin(-1.0) + return self.set_qmin(value) + elif name == 'q_max': + if value is None: + return self.set_qmax(-1.0) + return self.set_qmax(value) + elif name == 'alpha': + return self.set_alpha(value) + elif name == 'slit_height': + return self.set_slit_height(value) + elif name == 'slit_width': + return self.set_slit_width(value) + elif name == 'est_bck': + if value == True: + return self.set_est_bck(1) + elif value == False: + return self.set_est_bck(0) + else: + raise ValueError("Invertor: est_bck can only be True or False") + + return Cinvertor.__setattr__(self, name, value) + + def __getattr__(self, name): + """ + Return the value of an attribute + """ + #import numpy + if name == 'x': + out = np.ones(self.get_nx()) + self.get_x(out) + return out + elif name == 'y': + out = np.ones(self.get_ny()) + self.get_y(out) + return out + elif name == 'err': + out = np.ones(self.get_nerr()) + self.get_err(out) + return out + elif name == 'd_max': + return self.get_dmax() + elif name == 'q_min': + qmin = self.get_qmin() + if qmin < 0: + return None + return qmin + elif name == 'q_max': + qmax = self.get_qmax() + if qmax < 0: + return None + return qmax + elif name == 'alpha': + return self.get_alpha() + elif name == 'slit_height': + return self.get_slit_height() + elif name == 'slit_width': + return self.get_slit_width() + elif name == 'est_bck': + value = self.get_est_bck() + return value == 1 + elif name in self.__dict__: + return self.__dict__[name] + return None + + def clone(self): + """ + Return a clone of this instance + """ + #import copy + + invertor = Invertor() + invertor.chi2 = self.chi2 + invertor.elapsed = self.elapsed + invertor.nfunc = self.nfunc + invertor.alpha = self.alpha + invertor.d_max = self.d_max + invertor.q_min = self.q_min + invertor.q_max = self.q_max + + invertor.x = self.x + invertor.y = self.y + invertor.err = self.err + invertor.est_bck = self.est_bck + invertor.background = self.background + invertor.slit_height = self.slit_height + invertor.slit_width = self.slit_width + + invertor.info = copy.deepcopy(self.info) + + return invertor + + def invert(self, nfunc=10, nr=20): + """ + Perform inversion to P(r) + + The problem is solved by posing the problem as Ax = b, + where x is the set of coefficients we are looking for. + + Npts is the number of points. + + In the following i refers to the ith base function coefficient. + The matrix has its entries j in its first Npts rows set to :: + + A[i][j] = (Fourier transformed base function for point j) + + We then choose a number of r-points, n_r, to evaluate the second + derivative of P(r) at. This is used as our regularization term. + For a vector r of length n_r, the following n_r rows are set to :: + + A[i+Npts][j] = (2nd derivative of P(r), d**2(P(r))/d(r)**2, evaluated at r[j]) + + The vector b has its first Npts entries set to :: + + b[j] = (I(q) observed for point j) + + The following n_r entries are set to zero. + + The result is found by using scipy.linalg.basic.lstsq to invert + the matrix and find the coefficients x. + + :param nfunc: number of base functions to use. + :param nr: number of r points to evaluate the 2nd derivative at for the reg. term. 
+ :return: c_out, c_cov - the coefficients with covariance matrix + """ + # Reset the background value before proceeding + # self.background = 0.0 + if not self.est_bck: + self.y -= self.background + out, cov = self.lstsq(nfunc, nr=nr) + if not self.est_bck: + self.y += self.background + return out, cov + + def iq(self, out, q): + """ + Function to call to evaluate the scattering intensity + + :param args: c-parameters, and q + :return: I(q) + + """ + return Cinvertor.iq(self, out, q) + self.background + + def invert_optimize(self, nfunc=10, nr=20): + """ + Slower version of the P(r) inversion that uses scipy.optimize.leastsq. + + This probably produce more reliable results, but is much slower. + The minimization function is set to + sum_i[ (I_obs(q_i) - I_theo(q_i))/err**2 ] + alpha * reg_term, + where the reg_term is given by Svergun: it is the integral of + the square of the first derivative + of P(r), d(P(r))/dr, integrated over the full range of r. + + :param nfunc: number of base functions to use. + :param nr: number of r points to evaluate the 2nd derivative at + for the reg. term. + + :return: c_out, c_cov - the coefficients with covariance matrix + + """ + self.nfunc = nfunc + # First, check that the current data is valid + if self.is_valid() <= 0: + msg = "Invertor.invert: Data array are of different length" + raise RuntimeError(msg) + + p = np.ones(nfunc) + t_0 = time.time() + out, cov_x, _, _, _ = optimize.leastsq(self.residuals, p, full_output=1) + + # Compute chi^2 + res = self.residuals(out) + chisqr = 0 + for i in range(len(res)): + chisqr += res[i] + + self.chi2 = chisqr + + # Store computation time + self.elapsed = time.time() - t_0 + + if cov_x is None: + cov_x = np.ones([nfunc, nfunc]) + cov_x *= math.fabs(chisqr) + return out, cov_x + + def pr_fit(self, nfunc=5): + """ + This is a direct fit to a given P(r). It assumes that the y data + is set to some P(r) distribution that we are trying to reproduce + with a set of base functions. + + This method is provided as a test. + """ + # First, check that the current data is valid + if self.is_valid() <= 0: + msg = "Invertor.invert: Data arrays are of different length" + raise RuntimeError(msg) + + p = np.ones(nfunc) + t_0 = time.time() + out, cov_x, _, _, _ = optimize.leastsq(self.pr_residuals, p, full_output=1) + + # Compute chi^2 + res = self.pr_residuals(out) + chisqr = 0 + for i in range(len(res)): + chisqr += res[i] + + self.chisqr = chisqr + + # Store computation time + self.elapsed = time.time() - t_0 + + return out, cov_x + + def pr_err(self, c, c_cov, r): + """ + Returns the value of P(r) for a given r, and base function + coefficients, with error. + + :param c: base function coefficients + :param c_cov: covariance matrice of the base function coefficients + :param r: r-value to evaluate P(r) at + + :return: P(r) + + """ + return self.get_pr_err(c, c_cov, r) + + def _accept_q(self, q): + """ + Check q-value against user-defined range + """ + if self.q_min is not None and q < self.q_min: + return False + if self.q_max is not None and q > self.q_max: + return False + return True + + def lstsq(self, nfunc=5, nr=20): + """ + The problem is solved by posing the problem as Ax = b, + where x is the set of coefficients we are looking for. + + Npts is the number of points. + + In the following i refers to the ith base function coefficient. 
+ The matrix has its entries j in its first Npts rows set to :: + + A[i][j] = (Fourier transformed base function for point j) + + We then choose a number of r-points, n_r, to evaluate the second + derivative of P(r) at. This is used as our regularization term. + For a vector r of length n_r, the following n_r rows are set to :: + + A[i+Npts][j] = (2nd derivative of P(r), d**2(P(r))/d(r)**2, + evaluated at r[j]) + + The vector b has its first Npts entries set to :: + + b[j] = (I(q) observed for point j) + + The following n_r entries are set to zero. + + The result is found by using scipy.linalg.basic.lstsq to invert + the matrix and find the coefficients x. + + :param nfunc: number of base functions to use. + :param nr: number of r points to evaluate the 2nd derivative at for the reg. term. + + If the result does not allow us to compute the covariance matrix, + a matrix filled with zeros will be returned. + + """ + # Note: To make sure an array is contiguous: + # blah = np.ascontiguousarray(blah_original) + # ... before passing it to C + + if self.is_valid() < 0: + msg = "Invertor: invalid data; incompatible data lengths." + raise RuntimeError(msg) + + self.nfunc = nfunc + # a -- An M x N matrix. + # b -- An M x nrhs matrix or M vector. + npts = len(self.x) + nq = nr + sqrt_alpha = math.sqrt(math.fabs(self.alpha)) + if sqrt_alpha < 0.0: + nq = 0 + + # If we need to fit the background, add a term + if self.est_bck: + nfunc_0 = nfunc + nfunc += 1 + + a = np.zeros([npts + nq, nfunc]) + b = np.zeros(npts + nq) + err = np.zeros([nfunc, nfunc]) + + # Construct the a matrix and b vector that represent the problem + t_0 = time.time() + try: + self._get_matrix(nfunc, nq, a, b) + except Exception as exc: + raise RuntimeError("Invertor: could not invert I(Q)\n %s" % str(exc)) + + # Perform the inversion (least square fit) + c, chi2, _, _ = lstsq(a, b, rcond=-1) + # Sanity check + try: + float(chi2) + except: + chi2 = -1.0 + self.chi2 = chi2 + + inv_cov = np.zeros([nfunc, nfunc]) + # Get the covariance matrix, defined as inv_cov = a_transposed * a + self._get_invcov_matrix(nfunc, nr, a, inv_cov) + + # Compute the reg term size for the output + sum_sig, sum_reg = self._get_reg_size(nfunc, nr, a) + + if math.fabs(self.alpha) > 0: + new_alpha = sum_sig / (sum_reg / self.alpha) + else: + new_alpha = 0.0 + self.suggested_alpha = new_alpha + + try: + cov = np.linalg.pinv(inv_cov) + err = math.fabs(chi2 / (npts - nfunc)) * cov + except Exception as exc: + # We were not able to estimate the errors + # Return an empty error matrix + logger.error(exc) + + # Keep a copy of the last output + if not self.est_bck: + self.out = c + self.cov = err + else: + self.background = c[0] + + err_0 = np.zeros([nfunc, nfunc]) + c_0 = np.zeros(nfunc) + + for i in range(nfunc_0): + c_0[i] = c[i + 1] + for j in range(nfunc_0): + err_0[i][j] = err[i + 1][j + 1] + + self.out = c_0 + self.cov = err_0 + + # Store computation time + self.elapsed = time.time() - t_0 + + return self.out, self.cov + + def estimate_numterms(self, isquit_func=None): + """ + Returns a reasonable guess for the + number of terms + + :param isquit_func: + reference to thread function to call to check whether the computation needs to + be stopped. 
+ + :return: number of terms, alpha, message + + """ + from .num_term import NTermEstimator + estimator = NTermEstimator(self.clone()) + try: + return estimator.num_terms(isquit_func) + except Exception as exc: + # If we fail, estimate alpha and return the default + # number of terms + best_alpha, _, _ = self.estimate_alpha(self.nfunc) + logger.warning("Invertor.estimate_numterms: %s" % exc) + return self.nfunc, best_alpha, "Could not estimate number of terms" + + def estimate_alpha(self, nfunc): + """ + Returns a reasonable guess for the + regularization constant alpha + + :param nfunc: number of terms to use in the expansion. + + :return: alpha, message, elapsed + + where alpha is the estimate for alpha, + message is a message for the user, + elapsed is the computation time + """ + #import time + try: + pr = self.clone() + + # T_0 for computation time + starttime = time.time() + elapsed = 0 + + # If the current alpha is zero, try + # another value + if pr.alpha <= 0: + pr.alpha = 0.0001 + + # Perform inversion to find the largest alpha + out, _ = pr.invert(nfunc) + elapsed = time.time() - starttime + initial_alpha = pr.alpha + initial_peaks = pr.get_peaks(out) + + # Try the inversion with the estimated alpha + pr.alpha = pr.suggested_alpha + out, _ = pr.invert(nfunc) + + npeaks = pr.get_peaks(out) + # if more than one peak to start with + # just return the estimate + if npeaks > 1: + #message = "Your P(r) is not smooth, + #please check your inversion parameters" + message = None + return pr.suggested_alpha, message, elapsed + else: + + # Look at smaller values + # We assume that for the suggested alpha, we have 1 peak + # if not, send a message to change parameters + alpha = pr.suggested_alpha + best_alpha = pr.suggested_alpha + found = False + for i in range(10): + pr.alpha = (0.33) ** (i + 1) * alpha + out, _ = pr.invert(nfunc) + + peaks = pr.get_peaks(out) + if peaks > 1: + found = True + break + best_alpha = pr.alpha + + # If we didn't find a turning point for alpha and + # the initial alpha already had only one peak, + # just return that + if not found and initial_peaks == 1 and \ + initial_alpha < best_alpha: + best_alpha = initial_alpha + + # Check whether the size makes sense + message = '' + + if not found: + message = None + elif best_alpha >= 0.5 * pr.suggested_alpha: + # best alpha is too big, return a + # reasonable value + message = "The estimated alpha for your system is too " + message += "large. " + message += "Try increasing your maximum distance." + + return best_alpha, message, elapsed + + except Exception as exc: + message = "Invertor.estimate_alpha: %s" % exc + return 0, message, elapsed + + def to_file(self, path, npts=100): + """ + Save the state to a file that will be readable + by SliceView. 
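+
+        The output begins with '#key=value' header lines (d_max, nfunc,
+        alpha, chi2, elapsed, qmin, qmax, slit_height, slit_width,
+        background, has_bck, alpha_estimate and the C_i coefficients),
+        followed by one 'r P(r) err' row per evaluated point.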
+ + :param path: path of the file to write + :param npts: number of P(r) points to be written + + """ + file = open(path, 'w') + file.write("#d_max=%g\n" % self.d_max) + file.write("#nfunc=%g\n" % self.nfunc) + file.write("#alpha=%g\n" % self.alpha) + file.write("#chi2=%g\n" % self.chi2) + file.write("#elapsed=%g\n" % self.elapsed) + file.write("#qmin=%s\n" % str(self.q_min)) + file.write("#qmax=%s\n" % str(self.q_max)) + file.write("#slit_height=%g\n" % self.slit_height) + file.write("#slit_width=%g\n" % self.slit_width) + file.write("#background=%g\n" % self.background) + if self.est_bck: + file.write("#has_bck=1\n") + else: + file.write("#has_bck=0\n") + file.write("#alpha_estimate=%g\n" % self.suggested_alpha) + if self.out is not None: + if len(self.out) == len(self.cov): + for i in range(len(self.out)): + file.write("#C_%i=%s+-%s\n" % (i, str(self.out[i]), + str(self.cov[i][i]))) + file.write(" \n") + r = np.arange(0.0, self.d_max, self.d_max / npts) + + for r_i in r: + (value, err) = self.pr_err(self.out, self.cov, r_i) + file.write("%g %g %g\n" % (r_i, value, err)) + + file.close() + + def from_file(self, path): + """ + Load the state of the Invertor from a file, + to be able to generate P(r) from a set of + parameters. + + :param path: path of the file to load + + """ + #import os + #import re + if os.path.isfile(path): + try: + fd = open(path, 'r') + + buff = fd.read() + lines = buff.split('\n') + for line in lines: + if line.startswith('#d_max='): + toks = line.split('=') + self.d_max = float(toks[1]) + elif line.startswith('#nfunc='): + toks = line.split('=') + self.nfunc = int(toks[1]) + self.out = np.zeros(self.nfunc) + self.cov = np.zeros([self.nfunc, self.nfunc]) + elif line.startswith('#alpha='): + toks = line.split('=') + self.alpha = float(toks[1]) + elif line.startswith('#chi2='): + toks = line.split('=') + self.chi2 = float(toks[1]) + elif line.startswith('#elapsed='): + toks = line.split('=') + self.elapsed = float(toks[1]) + elif line.startswith('#alpha_estimate='): + toks = line.split('=') + self.suggested_alpha = float(toks[1]) + elif line.startswith('#qmin='): + toks = line.split('=') + try: + self.q_min = float(toks[1]) + except: + self.q_min = None + elif line.startswith('#qmax='): + toks = line.split('=') + try: + self.q_max = float(toks[1]) + except: + self.q_max = None + elif line.startswith('#slit_height='): + toks = line.split('=') + self.slit_height = float(toks[1]) + elif line.startswith('#slit_width='): + toks = line.split('=') + self.slit_width = float(toks[1]) + elif line.startswith('#background='): + toks = line.split('=') + self.background = float(toks[1]) + elif line.startswith('#has_bck='): + toks = line.split('=') + self.est_bck = int(toks[1]) == 1 + + # Now read in the parameters + elif line.startswith('#C_'): + toks = line.split('=') + p = re.compile('#C_([0-9]+)') + m = p.search(toks[0]) + toks2 = toks[1].split('+-') + i = int(m.group(1)) + self.out[i] = float(toks2[0]) + + self.cov[i][i] = float(toks2[1]) + + except Exception as exc: + msg = "Invertor.from_file: corrupted file\n%s" % exc + raise RuntimeError(msg) + else: + msg = "Invertor.from_file: '%s' is not a file" % str(path) + raise RuntimeError(msg) diff --git a/sas/sascalc/pr/num_term.py b/sas/sascalc/pr/num_term.py new file mode 100755 index 000000000..06f1e5e4d --- /dev/null +++ b/sas/sascalc/pr/num_term.py @@ -0,0 +1,200 @@ +from __future__ import print_function, division + +import math +import numpy as np +import copy +import sys +import logging +from sas.sascalc.pr.invertor import 
Invertor + +logger = logging.getLogger(__name__) + +class NTermEstimator(object): + """ + """ + def __init__(self, invertor): + """ + """ + self.invertor = invertor + self.nterm_min = 10 + self.nterm_max = len(self.invertor.x) + if self.nterm_max > 50: + self.nterm_max = 50 + self.isquit_func = None + + self.osc_list = [] + self.err_list = [] + self.alpha_list = [] + self.mess_list = [] + self.dataset = [] + + def is_odd(self, n): + """ + """ + return bool(n % 2) + + def sort_osc(self): + """ + """ + #import copy + osc = copy.deepcopy(self.dataset) + lis = [] + for i in range(len(osc)): + osc.sort() + re = osc.pop(0) + lis.append(re) + return lis + + def median_osc(self): + """ + """ + osc = self.sort_osc() + dv = len(osc) + med = 0.5*dv + odd = self.is_odd(dv) + medi = 0 + for i in range(dv): + if odd: + medi = osc[int(med)] + else: + medi = osc[int(med) - 1] + return medi + + def get0_out(self): + """ + """ + inver = self.invertor + self.osc_list = [] + self.err_list = [] + self.alpha_list = [] + for k in range(self.nterm_min, self.nterm_max, 1): + if self.isquit_func is not None: + self.isquit_func() + best_alpha, message, _ = inver.estimate_alpha(k) + inver.alpha = best_alpha + inver.out, inver.cov = inver.lstsq(k) + osc = inver.oscillations(inver.out) + err = inver.get_pos_err(inver.out, inver.cov) + if osc > 10.0: + break + self.osc_list.append(osc) + self.err_list.append(err) + self.alpha_list.append(inver.alpha) + self.mess_list.append(message) + + new_osc1 = [] + new_osc2 = [] + new_osc3 = [] + flag9 = False + flag8 = False + for i in range(len(self.err_list)): + if self.err_list[i] <= 1.0 and self.err_list[i] >= 0.9: + new_osc1.append(self.osc_list[i]) + flag9 = True + if self.err_list[i] < 0.9 and self.err_list[i] >= 0.8: + new_osc2.append(self.osc_list[i]) + flag8 = True + if self.err_list[i] < 0.8 and self.err_list[i] >= 0.7: + new_osc3.append(self.osc_list[i]) + + if flag9: + self.dataset = new_osc1 + elif flag8: + self.dataset = new_osc2 + else: + self.dataset = new_osc3 + + return self.dataset + + def ls_osc(self): + """ + """ + # Generate data + self.get0_out() + med = self.median_osc() + + #TODO: check 1 + ls_osc = self.dataset + ls = [] + for i in range(len(ls_osc)): + if int(med) == int(ls_osc[i]): + ls.append(ls_osc[i]) + return ls + + def compare_err(self): + """ + """ + ls = self.ls_osc() + nt_ls = [] + for i in range(len(ls)): + r = ls[i] + n = self.osc_list.index(r) + 10 + nt_ls.append(n) + return nt_ls + + def num_terms(self, isquit_func=None): + """ + """ + try: + self.isquit_func = isquit_func + nts = self.compare_err() + div = len(nts) + tem = 0.5*div + if self.is_odd(div): + nt = nts[int(tem)] + else: + nt = nts[int(tem) - 1] + return nt, self.alpha_list[nt - 10], self.mess_list[nt - 10] + except: + #TODO: check the logic above and make sure it doesn't + # rely on the try-except. 
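+            # Fall back to the minimum term count with the current alpha
+            # and an empty message; compare_err() can, e.g., return an
+            # empty list, which makes the indexing above raise.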
+ return self.nterm_min, self.invertor.alpha, '' + + +#For testing +def load(path): + # Read the data from the data file + data_x = np.zeros(0) + data_y = np.zeros(0) + data_err = np.zeros(0) + scale = None + min_err = 0.0 + if path is not None: + input_f = open(path, 'r') + buff = input_f.read() + lines = buff.split('\n') + for line in lines: + try: + toks = line.split() + test_x = float(toks[0]) + test_y = float(toks[1]) + if len(toks) > 2: + err = float(toks[2]) + else: + if scale is None: + scale = 0.05 * math.sqrt(test_y) + #scale = 0.05/math.sqrt(y) + min_err = 0.01 * y + err = scale * math.sqrt(test_y) + min_err + #err = 0 + + data_x = np.append(data_x, test_x) + data_y = np.append(data_y, test_y) + data_err = np.append(data_err, err) + except: + logger.error(sys.exc_value) + + return data_x, data_y, data_err + + +if __name__ == "__main__": + invert = Invertor() + x, y, erro = load("test/Cyl_A_D102.txt") + invert.d_max = 102.0 + invert.nfunc = 10 + invert.x = x + invert.y = y + invert.err = erro + # Testing estimator + est = NTermEstimator(invert) + print(est.num_terms()) diff --git a/sas/sascalc/pr/release_notes.txt b/sas/sascalc/pr/release_notes.txt new file mode 100755 index 000000000..523c22bf2 --- /dev/null +++ b/sas/sascalc/pr/release_notes.txt @@ -0,0 +1,64 @@ +Release Notes +============= + +Package name: pr_inversion 0.2 + +1- Version 0.2.2 + - Release date: 7/4/2009 + - Minor fixes + + Version 0.2.1 + - Release date: 7/8/2008 + - Minor change to clone method + + Version 0.2.0 + + - Release date: 7/4/2008 + - Number of term estimator + - Faster implementation + + Version 0.1.0 + + - Release date: 5/28/2008 + - Provide python module sas.pr + +2- Downloading and Installing + + 2.1- System Requirements: + - Python version >= 2.4 should be running on the system + + 2.2- Installing: + - Get the code from svn://danse.us/sas/releases/pr_inversion-0.1.0 + - The following modules are required: + * matplotlib + * numpy + * scipy + +3- Known Issues + + 3.1- All systems: + - None + + 3.2- Windows: + - None + + 3.3- Linux: + - None + +4- Troubleshooting + + - None + +5- Frequently Asked Questions + + - None + +6- Other Resources + + - See: http://danse.chem.utk.edu/prview.html + + + + + + diff --git a/sas/sascalc/realspace/VolumeCanvas.py b/sas/sascalc/realspace/VolumeCanvas.py new file mode 100755 index 000000000..61552283f --- /dev/null +++ b/sas/sascalc/realspace/VolumeCanvas.py @@ -0,0 +1,792 @@ +#!/usr/bin/env python +""" Volume Canvas + + Simulation canvas for real-space simulation of SAS scattering intensity. + The user can create an arrangement of basic shapes and estimate I(q) and + I(q_x, q_y). Error estimates on the simulation are also available. + + Example: + + import sas.sascalc.realspace.VolumeCanvas as VolumeCanvas + canvas = VolumeCanvas.VolumeCanvas() + canvas.setParam('lores_density', 0.01) + + sphere = SphereDescriptor() + handle = canvas.addObject(sphere) + + output, error = canvas.getIqError(q=0.1) + output, error = canvas.getIq2DError(0.1, 0.1) + + or alternatively: + iq = canvas.run(0.1) + i2_2D = canvas.run([0.1, 1.57]) + +""" + +from sas.sascalc.calculator.BaseComponent import BaseComponent +from sas.sascalc.simulation.pointsmodelpy import pointsmodelpy +from sas.sascalc.simulation.geoshapespy import geoshapespy + + +import os.path, math + +class ShapeDescriptor(object): + """ + Class to hold the information about a shape + The descriptor holds a dictionary of parameters. + + Note: if shape parameters are accessed directly + from outside VolumeCanvas. 
The getPr method + should be called before evaluating I(q). + + """ + def __init__(self): + """ + Initialization + """ + ## Real space object + self.shapeObject = None + ## Parameters of the object + self.params = {} + self.params["center"] = [0, 0, 0] + # Orientation are angular offsets in degrees with respect to X, Y, Z + self.params["orientation"] = [0, 0, 0] + # Default to lores shape + self.params['is_lores'] = True + self.params['order'] = 0 + + def create(self): + """ + Create an instance of the shape + """ + # Set center + x0 = self.params["center"][0] + y0 = self.params["center"][1] + z0 = self.params["center"][2] + geoshapespy.set_center(self.shapeObject, x0, y0, z0) + + # Set orientation + x0 = self.params["orientation"][0] + y0 = self.params["orientation"][1] + z0 = self.params["orientation"][2] + geoshapespy.set_orientation(self.shapeObject, x0, y0, z0) + +class SphereDescriptor(ShapeDescriptor): + """ + Descriptor for a sphere + + The parameters are: + - radius [Angstroem] [default = 20 A] + - Contrast [A-2] [default = 1 A-2] + + """ + def __init__(self): + """ + Initialization + """ + ShapeDescriptor.__init__(self) + # Default parameters + self.params["type"] = "sphere" + # Radius of the sphere + self.params["radius"] = 20.0 + # Constrast parameter + self.params["contrast"] = 1.0 + + def create(self): + """ + Create an instance of the shape + @return: instance of the shape + """ + self.shapeObject = geoshapespy.new_sphere(\ + self.params["radius"]) + + ShapeDescriptor.create(self) + return self.shapeObject + +class CylinderDescriptor(ShapeDescriptor): + """ + Descriptor for a cylinder + Orientation: Default cylinder is along Y + + Parameters: + - Length [default = 40 A] + - Radius [default = 10 A] + - Contrast [default = 1 A-2] + """ + def __init__(self): + """ + Initialization + """ + ShapeDescriptor.__init__(self) + # Default parameters + self.params["type"] = "cylinder" + # Length of the cylinder + self.params["length"] = 40.0 + # Radius of the cylinder + self.params["radius"] = 10.0 + # Constrast parameter + self.params["contrast"] = 1.0 + + def create(self): + """ + Create an instance of the shape + @return: instance of the shape + """ + self.shapeObject = geoshapespy.new_cylinder(\ + self.params["radius"], self.params["length"]) + + ShapeDescriptor.create(self) + return self.shapeObject + + +class EllipsoidDescriptor(ShapeDescriptor): + """ + Descriptor for an ellipsoid + + Parameters: + - Radius_x along the x-axis [default = 30 A] + - Radius_y along the y-axis [default = 20 A] + - Radius_z along the z-axis [default = 10 A] + - contrast [default = 1 A-2] + """ + def __init__(self): + """ + Initialization + """ + ShapeDescriptor.__init__(self) + # Default parameters + self.params["type"] = "ellipsoid" + self.params["radius_x"] = 30.0 + self.params["radius_y"] = 20.0 + self.params["radius_z"] = 10.0 + self.params["contrast"] = 1.0 + + def create(self): + """ + Create an instance of the shape + @return: instance of the shape + """ + self.shapeObject = geoshapespy.new_ellipsoid(\ + self.params["radius_x"], self.params["radius_y"], + self.params["radius_z"]) + + ShapeDescriptor.create(self) + return self.shapeObject + +class HelixDescriptor(ShapeDescriptor): + """ + Descriptor for an helix + + Parameters: + -radius_helix: the radius of the helix [default = 10 A] + -radius_tube: radius of the "tube" that forms the helix [default = 3 A] + -pitch: distance between two consecutive turns of the helix [default = 34 A] + -turns: number of turns of the helix [default = 3] + 
-contrast: contrast parameter [default = 1 A-2] + """ + def __init__(self): + """ + Initialization + """ + ShapeDescriptor.__init__(self) + # Default parameters + self.params["type"] = "singlehelix" + self.params["radius_helix"] = 10.0 + self.params["radius_tube"] = 3.0 + self.params["pitch"] = 34.0 + self.params["turns"] = 3.0 + self.params["contrast"] = 1.0 + + def create(self): + """ + Create an instance of the shape + @return: instance of the shape + """ + self.shapeObject = geoshapespy.new_singlehelix(\ + self.params["radius_helix"], self.params["radius_tube"], + self.params["pitch"], self.params["turns"]) + + ShapeDescriptor.create(self) + return self.shapeObject + +class PDBDescriptor(ShapeDescriptor): + """ + Descriptor for a PDB set of points + + Parameter: + - file = name of the PDB file + """ + def __init__(self, filename): + """ + Initialization + @param filename: name of the PDB file to load + """ + ShapeDescriptor.__init__(self) + # Default parameters + self.params["type"] = "pdb" + self.params["file"] = filename + self.params['is_lores'] = False + + def create(self): + """ + Create an instance of the shape + @return: instance of the shape + """ + self.shapeObject = pointsmodelpy.new_pdbmodel() + pointsmodelpy.pdbmodel_add(self.shapeObject, self.params['file']) + + #ShapeDescriptor.create(self) + return self.shapeObject + +# Define a dictionary for the shape until we find +# a better way to create them +shape_dict = {'sphere':SphereDescriptor, + 'cylinder':CylinderDescriptor, + 'ellipsoid':EllipsoidDescriptor, + 'singlehelix':HelixDescriptor} + +class VolumeCanvas(BaseComponent): + """ + Class representing an empty space volume to add + geometrical object to. + + For 1D I(q) simulation, getPr() is called internally for the + first call to getIq(). + + """ + + def __init__(self): + """ + Initialization + """ + BaseComponent.__init__(self) + + ## Maximum value of q reachable + self.params['q_max'] = 0.1 + self.params['lores_density'] = 0.1 + self.params['scale'] = 1.0 + self.params['background'] = 0.0 + + self.lores_model = pointsmodelpy.new_loresmodel(self.params['lores_density']) + self.complex_model = pointsmodelpy.new_complexmodel() + self.shapes = {} + self.shapecount = 0 + self.points = None + self.npts = 0 + self.hasPr = False + + def _model_changed(self): + """ + Reset internal data members to reflect the fact that the + real-space model has changed + """ + self.hasPr = False + self.points = None + + def addObject(self, shapeDesc, id=None): + """ + Adds a real-space object to the canvas. + + @param shapeDesc: object to add to the canvas [ShapeDescriptor] + @param id: string handle for the object [string] [optional] + @return: string handle for the object + """ + # If the handle is not provided, create one + if id is None: + id = shapeDesc.params["type"]+str(self.shapecount) + + # Self the order number + shapeDesc.params['order'] = self.shapecount + # Store the shape in a dictionary entry associated + # with the handle + self.shapes[id] = shapeDesc + self.shapecount += 1 + + # model changed, need to recalculate P(r) + self._model_changed() + + return id + + + def add(self, shape, id=None): + """ + The intend of this method is to eventually be able to use it + as a factory for the canvas and unify the simulation with the + analytical solutions. For instance, if one adds a cylinder and + it is the only shape on the canvas, the analytical solution + could be called. If multiple shapes are involved, then + simulation has to be performed. 
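+
+        Accepted values for shape are the names registered in shape_dict
+        ('sphere', 'cylinder', 'ellipsoid', 'singlehelix') or the path of
+        an existing PDB file, which is loaded through PDBDescriptor.
+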
+ + This function is deprecated, use addObject(). + + @param shape: name of the object to add to the canvas [string] + @param id: string handle for the object [string] [optional] + @return: string handle for the object + """ + # If the handle is not provided, create one + if id is None: + id = "shape"+str(self.shapecount) + + # shapeDesc = ShapeDescriptor(shape.lower()) + if shape.lower() in shape_dict: + shapeDesc = shape_dict[shape.lower()]() + elif os.path.isfile(shape): + # A valid filename was supplier, create a PDB object + shapeDesc = PDBDescriptor(shape) + else: + raise ValueError("VolumeCanvas.add: Unknown shape %s" % shape) + + return self.addObject(shapeDesc, id) + + def delete(self, id): + """ + Delete a shape. The ID for the shape is required. + @param id: string handle for the object [string] [optional] + """ + + if id in self.shapes: + del self.shapes[id] + else: + raise KeyError("VolumeCanvas.delete: could not find shape ID") + + # model changed, need to recalculate P(r) + self._model_changed() + + + def setParam(self, name, value): + """ + Function to set the value of a parameter. + Both VolumeCanvas parameters and shape parameters + are accessible. + + Note: if shape parameters are accessed directly + from outside VolumeCanvas. The getPr method + should be called before evaluating I(q). + + TODO: implemented a check method to protect + against that. + + @param name: name of the parameter to change + @param value: value to give the parameter + """ + + # Lowercase for case insensitivity + name = name.lower() + + # Look for shape access + toks = name.split('.') + + # If a shape identifier was given, look the shape up + # in the dictionary + if len(toks): + if toks[0] in self.shapes: + # The shape was found, now look for the parameter + if toks[1] in self.shapes[toks[0]].params: + # The parameter was found, now change it + self.shapes[toks[0]].params[toks[1]] = value + self._model_changed() + else: + raise ValueError("Could not find parameter %s" % name) + else: + raise ValueError("Could not find shape %s" % toks[0]) + + else: + # If we are not accessing the parameters of a + # shape, see if the parameter is part of this object + BaseComponent.setParam(self, name, value) + self._model_changed() + + def getParam(self, name): + """ + @param name: name of the parameter to change + """ + #TODO: clean this up + + # Lowercase for case insensitivity + name = name.lower() + + # Look for sub-model access + toks = name.split('.') + if len(toks) == 1: + try: + value = self.params[toks[0]] + except KeyError: + raise ValueError("VolumeCanvas.getParam: Could not find" + " %s" % name) + if isinstance(value, ShapeDescriptor): + raise ValueError("VolumeCanvas.getParam: Cannot get parameter" + " value.") + else: + return value + + elif len(toks) == 2: + try: + shapeinstance = self.shapes[toks[0]] + except KeyError: + raise ValueError("VolumeCanvas.getParam: Could not find " + "%s" % name) + + if not toks[1] in shapeinstance.params: + raise ValueError("VolumeCanvas.getParam: Could not find " + "%s" % name) + + return shapeinstance.params[toks[1]] + + else: + raise ValueError("VolumeCanvas.getParam: Could not find %s" % name) + + def getParamList(self, shapeid=None): + """ + return a full list of all available parameters from + self.params.keys(). If a key in self.params is a instance + of ShapeDescriptor, extend the return list to: + [param1,param2,shapeid.param1,shapeid.param2.......] 
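+
+        For example, a canvas holding a single default sphere (handle
+        'sphere0') yields entries such as 'q_max', 'lores_density',
+        'scale', 'background', 'sphere0.radius' and 'sphere0.contrast'.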
+ + If shapeid is provided, return the list of parameters that + belongs to that shape id only : [shapeid.param1, shapeid.param2...] + """ + + param_list = [] + if shapeid is None: + for key1 in self.params: + #value1 = self.params[key1] + param_list.append(key1) + for key2 in self.shapes: + value2 = self.shapes[key2] + header = key2 + '.' + for key3 in value2.params: + fullname = header + key3 + param_list.append(fullname) + + else: + if not shapeid in self.shapes: + raise ValueError("VolumeCanvas: getParamList: Could not find " + "%s" % shapeid) + + header = shapeid + '.' + param_list = [header + param for param in self.shapes[shapeid].params] + return param_list + + def getShapeList(self): + """ + Return a list of the shapes + """ + return self.shapes.keys() + + def _addSingleShape(self, shapeDesc): + """ + create shapeobject based on shapeDesc + @param shapeDesc: shape description + """ + # Create the object model + shapeDesc.create() + + if shapeDesc.params['is_lores']: + # Add the shape to the lores_model + pointsmodelpy.lores_add(self.lores_model, + shapeDesc.shapeObject, + shapeDesc.params['contrast']) + + def _createVolumeFromList(self): + """ + Create a new lores model with all the shapes in our internal list + Whenever we change a parameter of a shape, we have to re-create + the whole thing. + + Items with higher 'order' number take precedence for regions + of space that are shared with other objects. Points in the + overlapping region belonging to objects with lower 'order' + will be ignored. + + Items are added in decreasing 'order' number. + The item with the highest 'order' will be added *first*. + [That conventions is prescribed by the realSpaceModeling module] + """ + + # Create empty model + self.lores_model = \ + pointsmodelpy.new_loresmodel(self.params['lores_density']) + + # Create empty complex model + self.complex_model = pointsmodelpy.new_complexmodel() + + # Order the object first + obj_list = [] + + for shape in self.shapes: + order = self.shapes[shape].params['order'] + # find where to place it in the list + stored = False + + for i in range(len(obj_list)): + if obj_list[i][0] > order: + obj_list.insert(i, [order, shape]) + stored = True + break + + if not stored: + obj_list.append([order, shape]) + + # Add each shape + len_list = len(obj_list) + for i in range(len_list-1, -1, -1): + shapedesc = self.shapes[obj_list[i][1]] + self._addSingleShape(shapedesc) + + return 0 + + def getPr(self): + """ + Calculate P(r) from the objects on the canvas. + This method should always be called after the shapes + on the VolumeCanvas have changed. + + @return: calculation output flag + """ + # To find a complete example of the correct call order: + # In LORES2, in actionclass.py, method CalculateAction._get_iq() + + # If there are not shapes, do nothing + if len(self.shapes) == 0: + self._model_changed() + return 0 + + # generate space filling points from shape list + self._createVolumeFromList() + + self.points = pointsmodelpy.new_point3dvec() + + pointsmodelpy.complexmodel_add(self.complex_model, + self.lores_model, "LORES") + for shape in self.shapes: + if not self.shapes[shape].params['is_lores']: + pointsmodelpy.complexmodel_add(self.complex_model, + self.shapes[shape].shapeObject, "PDB") + + #pointsmodelpy.get_lorespoints(self.lores_model, self.points) + self.npts = pointsmodelpy.get_complexpoints(self.complex_model, self.points) + + # expecting the rmax is a positive float or 0. The maximum distance. 
+ #rmax = pointsmodelpy.get_lores_pr(self.lores_model, self.points) + + rmax = pointsmodelpy.get_complex_pr(self.complex_model, self.points) + self.hasPr = True + + return rmax + + def run(self, q=0): + """ + Returns the value of I(q) for a given q-value + @param q: q-value ([float] or [list]) ([A-1] or [[A-1], [rad]]) + @return: I(q) [float] [cm-1] + """ + # Check for 1D q length + if q.__class__.__name__ == 'int' \ + or q.__class__.__name__ == 'float': + return self.getIq(q) + # Check for 2D q-value + elif q.__class__.__name__ == 'list': + # Compute (Qx, Qy) from (Q, phi) + # Phi is in radian and Q-values are in A-1 + qx = q[0]*math.cos(q[1]) + qy = q[0]*math.sin(q[1]) + return self.getIq2D(qx, qy) + # Through an exception if it's not a + # type we recognize + else: + raise ValueError("run(q): bad type for q") + + def runXY(self, q=0): + """ + Standard run command for the canvas. + Redirects to the correct method + according to the input type. + @param q: q-value [float] or [list] [A-1] + @return: I(q) [float] [cm-1] + """ + # Check for 1D q length + if q.__class__.__name__ == 'int' \ + or q.__class__.__name__ == 'float': + return self.getIq(q) + # Check for 2D q-value + elif q.__class__.__name__ == 'list': + return self.getIq2D(q[0], q[1]) + # Through an exception if it's not a + # type we recognize + else: + raise ValueError("runXY(q): bad type for q") + + def _create_modelObject(self): + """ + Create the simulation model obejct from the list + of shapes. + + This method needs to be called each time a parameter + changes because of the way the underlying library + was (badly) written. It is impossible to change a + parameter, or remove a shape without having to + refill the space points. + + TODO: improve that. + """ + # To find a complete example of the correct call order: + # In LORES2, in actionclass.py, method CalculateAction._get_iq() + + # If there are not shapes, do nothing + if len(self.shapes) == 0: + self._model_changed() + return 0 + + # generate space filling points from shape list + self._createVolumeFromList() + + self.points = pointsmodelpy.new_point3dvec() + + pointsmodelpy.complexmodel_add(self.complex_model, + self.lores_model, "LORES") + for shape in self.shapes: + if not self.shapes[shape].params['is_lores']: + pointsmodelpy.complexmodel_add(self.complex_model, + self.shapes[shape].shapeObject, "PDB") + + #pointsmodelpy.get_lorespoints(self.lores_model, self.points) + self.npts = pointsmodelpy.get_complexpoints(self.complex_model, self.points) + + + def getIq2D(self, qx, qy): + """ + Returns simulate I(q) for given q_x and q_y values. 
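+
+        The space-filling points are generated on the first call and are
+        reused until a shape or parameter changes.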
+ @param qx: q_x [A-1] + @param qy: q_y [A-1] + @return: I(q) [cm-1] + """ + + # If this is the first simulation call, we need to generate the + # space points + if self.points is None: + self._create_modelObject() + + # Protect against empty model + if self.points is None: + return 0 + + # Evalute I(q) + norm = 1.0e8/self.params['lores_density']*self.params['scale'] + return norm*pointsmodelpy.get_complex_iq_2D(self.complex_model, self.points, qx, qy)\ + + self.params['background'] + + def write_pr(self, filename): + """ + Write P(r) to an output file + @param filename: file name for P(r) output + """ + if not self.hasPr: + self.getPr() + + pointsmodelpy.outputPR(self.complex_model, filename) + + def getPrData(self): + """ + Write P(r) to an output file + @param filename: file name for P(r) output + """ + if not self.hasPr: + self.getPr() + + return pointsmodelpy.get_pr(self.complex_model) + + def getIq(self, q): + """ + Returns the value of I(q) for a given q-value + + This method should remain internal to the class + and the run() method should be used instead. + + @param q: q-value [float] + @return: I(q) [float] + """ + + if not self.hasPr: + self.getPr() + + # By dividing by the density instead of the actuall V/N, + # we have an uncertainty of +-1 on N because the number + # of points chosen for the simulation is int(density*volume). + # Propagation of error gives: + # delta(1/density^2) = 2*(1/density^2)/N + # where N is stored in self.npts + + norm = 1.0e8/self.params['lores_density']*self.params['scale'] + #return norm*pointsmodelpy.get_lores_i(self.lores_model, q) + return norm*pointsmodelpy.get_complex_i(self.complex_model, q)\ + + self.params['background'] + + def getError(self, q): + """ + Returns the error of I(q) for a given q-value + @param q: q-value [float] + @return: I(q) [float] + """ + + if not self.hasPr: + self.getPr() + + # By dividing by the density instead of the actual V/N, + # we have an uncertainty of +-1 on N because the number + # of points chosen for the simulation is int(density*volume). + # Propagation of error gives: + # delta(1/density^2) = 2*(1/density^2)/N + # where N is stored in self.npts + + norm = 1.0e8/self.params['lores_density']*self.params['scale'] + #return norm*pointsmodelpy.get_lores_i(self.lores_model, q) + return norm*pointsmodelpy.get_complex_i_error(self.complex_model, q)\ + + self.params['background'] + + def getIqError(self, q): + """ + Return the simulated value along with its estimated + error for a given q-value + + Propagation of errors is used to evaluate the + uncertainty. + + @param q: q-value [float] + @return: mean, error [float, float] + """ + val = self.getIq(q) + # Simulation error (statistical) + err = self.getError(q) + # Error on V/N + simerr = 2*val/self.npts + return val, err+simerr + + def getIq2DError(self, qx, qy): + """ + Return the simulated value along with its estimated + error for a given q-value + + Propagation of errors is used to evaluate the + uncertainty. 
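+
+        The returned error combines the statistical simulation error,
+        scaled by the empirical factor of 10 noted below, with the
+        uncertainty on V/N.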
+
+        @param qx: qx-value [float]
+        @param qy: qy-value [float]
+        @return: mean, error [float, float]
+        """
+        self._create_modelObject()
+
+        norm = 1.0e8/self.params['lores_density']*self.params['scale']
+        val = norm*pointsmodelpy.get_complex_iq_2D(self.complex_model, self.points, qx, qy)\
+              + self.params['background']
+
+        # Simulation error (statistical)
+        norm = 1.0e8/self.params['lores_density']*self.params['scale'] \
+             * math.pow(self.npts/self.params['lores_density'], 1.0/3.0)/self.npts
+        err = norm*pointsmodelpy.get_complex_iq_2D_err(self.complex_model, self.points, qx, qy)
+        # Error on V/N
+        simerr = 2*val/self.npts
+
+        # The error used for the position is over-simplified.
+        # The actual error was empirically found to be about
+        # an order of magnitude larger.
+        return val, 10.0*err+simerr
diff --git a/sas/sascalc/realspace/__init__.py b/sas/sascalc/realspace/__init__.py
new file mode 100755
index 000000000..4ed499120
--- /dev/null
+++ b/sas/sascalc/realspace/__init__.py
@@ -0,0 +1,79 @@
+"""
+    Real-Space Modeling for SAS
+"""
+## \mainpage Real-Space Modeling for SAS
+#
+# \section intro_sec Introduction
+# This module provides SAS scattering intensity simulation
+# based on real-space modeling.
+#
+# Documentation can be found here:
+#    http://danse.us/trac/sas/wiki/RealSpaceModeling
+#
+# \section install_sec Installation
+#
+# \subsection obtain Obtaining the Code
+#
+# The code is available here:
+# \verbatim
+#$ svn co svn://danse.us/sas/realSpaceModeling
+#$ svn co svn://danse.us/sas/RealSpaceTopLayer
+# \endverbatim
+#
+# \subsection depends External Dependencies
+# None
+#
+# \subsection build Building the code
+# The standard python package can be built with distutils.
+# From the realSpaceModeling directory:
+# \verbatim
+#$ python setup.py install
+# \endverbatim
+#
+# From the RealSpaceTopLayer/src directory:
+# \verbatim
+#$ python setup.py install
+# \endverbatim
+#
+# \section overview_sec Package Overview
+#
+# \subsection class Class Diagram:
+# \image html real-space-class-diagram.png
+#
+# \subsection behav Behavior Enumeration:
+# \image html enum.png
+#
+# \subsection Tutorial
+# To create an empty canvas:
+# \verbatim
+# import sas.realspace.VolumeCanvas as VolumeCanvas
+# canvas = VolumeCanvas.VolumeCanvas()
+# \endverbatim
+#
+# To set the simulation point density:
+# \verbatim
+# canvas.setParam('lores_density', 0.01)
+# \endverbatim
+#
+# To add an object:
+# \verbatim
+# sphere = VolumeCanvas.SphereDescriptor()
+# handle = canvas.addObject(sphere)
+# canvas.setParam('%s.radius' % handle, 15.0)
+# \endverbatim
+#
+# To evaluate the scattering intensity at a given q:
+# \verbatim
+# output, error = canvas.getIqError(q=0.1)
+# output, error = canvas.getIq2DError(qx=0.1, qy=0.1)
+# \endverbatim
+#
+# To get the value of a parameter:
+# \verbatim
+# canvas.getParam('scale')
+# \endverbatim
+#
+# Examples are available as unit tests under sas.realspace.test.
+#
+# \section help_sec Contact Info
+# Code and Documentation by Jing Zhou as part of the DANSE project.
diff --git a/sas/sascalc/simulation/__init__.py b/sas/sascalc/simulation/__init__.py
new file mode 100755
index 000000000..e69de29bb
diff --git a/sas/sascalc/simulation/analmodelpy/Make.mm b/sas/sascalc/simulation/analmodelpy/Make.mm
new file mode 100755
index 000000000..c6afae425
--- /dev/null
+++ b/sas/sascalc/simulation/analmodelpy/Make.mm
@@ -0,0 +1,47 @@
+# -*- Makefile -*-
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Michael A.G.
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = analmodelpy + +# directory structure + +BUILD_DIRS = \ + libanalmodelpy \ + analmodelpymodule \ + analmodelpy \ + +OTHER_DIRS = \ + tests \ + examples + +RECURSE_DIRS = $(BUILD_DIRS) $(OTHER_DIRS) + +#-------------------------------------------------------------------------- +# + +all: + BLD_ACTION="all" $(MM) recurse + +distclean:: + BLD_ACTION="distclean" $(MM) recurse + +clean:: + BLD_ACTION="clean" $(MM) recurse + +tidy:: + BLD_ACTION="tidy" $(MM) recurse + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpy/Make.mm b/sas/sascalc/simulation/analmodelpy/analmodelpy/Make.mm new file mode 100755 index 000000000..6b5732ddb --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpy/Make.mm @@ -0,0 +1,35 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PROJECT = analmodelpy +PACKAGE = analmodelpy + +#-------------------------------------------------------------------------- +# + +all: export + + +#-------------------------------------------------------------------------- +# +# export + +EXPORT_PYTHON_MODULES = \ + __init__.py + + +export:: export-python-modules + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpy/__init__.py b/sas/sascalc/simulation/analmodelpy/analmodelpy/__init__.py new file mode 100755 index 000000000..5c8ebf583 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpy/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +def copyright(): + return "analmodelpy pyre module: Copyright (c) 1998-2005 Michael A.G. Aivazis"; + + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/Make.mm b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/Make.mm new file mode 100755 index 000000000..b5c8f35c3 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/Make.mm @@ -0,0 +1,32 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = SASsimulation +PACKAGE = analmodelpymodule +MODULE = analmodelpy + +include std-pythonmodule.def +include local.def + +PROJ_CXX_SRCLIB = -lanalmodelpy + +PROJ_SRCS = \ + bindings.cc \ + exceptions.cc \ + misc.cc + + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/analmodelpymodule.cc b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/analmodelpymodule.cc new file mode 100755 index 000000000..a55a1b852 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/analmodelpymodule.cc @@ -0,0 +1,52 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include + +#include + +#include "exceptions.h" +#include "bindings.h" + + +char pyanalmodelpy_module__doc__[] = ""; + +// Initialization function for the module (*must* be called initanalmodelpy) +extern "C" +void +initanalmodelpy() +{ + // create the module and add the functions + PyObject * m = Py_InitModule4( + "analmodelpy", pyanalmodelpy_methods, + pyanalmodelpy_module__doc__, 0, PYTHON_API_VERSION); + + // get its dictionary + PyObject * d = PyModule_GetDict(m); + + // check for errors + if (PyErr_Occurred()) { + Py_FatalError("can't initialize module analmodelpy"); + } + + // install the module exceptions + pyanalmodelpy_runtimeError = PyErr_NewException("analmodelpy.runtime", 0, 0); + PyDict_SetItemString(d, "RuntimeException", pyanalmodelpy_runtimeError); + + return; +} + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/bindings.cc b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/bindings.cc new file mode 100755 index 000000000..7c5fe7c9b --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/bindings.cc @@ -0,0 +1,44 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. 
Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +#include "bindings.h" + +#include "misc.h" // miscellaneous methods + +// the method table + +struct PyMethodDef pyanalmodelpy_methods[] = { + + // new analmodel + {pyanalmodelpy_new_analmodel__name__, pyanalmodelpy_new_analmodel, + METH_VARARGS, pyanalmodelpy_new_analmodel__doc__}, + + //analmodel method: CalculateIQ + {pyanalmodelpy_CalculateIQ__name__, pyanalmodelpy_CalculateIQ, + METH_VARARGS, pyanalmodelpy_CalculateIQ__doc__}, + + {pyanalmodelpy_copyright__name__, pyanalmodelpy_copyright, + METH_VARARGS, pyanalmodelpy_copyright__doc__}, + + +// Sentinel + {0, 0} +}; + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/bindings.h b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/bindings.h new file mode 100755 index 000000000..6ca3ec062 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/bindings.h @@ -0,0 +1,26 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pyanalmodelpy_bindings_h) +#define pyanalmodelpy_bindings_h + +// the method table + +extern struct PyMethodDef pyanalmodelpy_methods[]; + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/exceptions.cc b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/exceptions.cc new file mode 100755 index 000000000..38e9df6bc --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/exceptions.cc @@ -0,0 +1,22 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +PyObject *pyanalmodelpy_runtimeError = 0; + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/exceptions.h b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/exceptions.h new file mode 100755 index 000000000..119bff582 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/exceptions.h @@ -0,0 +1,26 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. 
Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pyanalmodelpy_exceptions_h) +#define pyanalmodelpy_exceptions_h + +// exceptions + +extern PyObject *pyanalmodelpy_runtimeError; + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/local.def b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/local.def new file mode 100755 index 000000000..6a12ef933 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/local.def @@ -0,0 +1,25 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# C++ + + PROJ_CXX_INCLUDES = ../libanalmodelpy \ + ../../iqPy/libiqPy \ + ../../iqPy/libiqPy/tnt \ + ../../geoshapespy/libgeoshapespy \ + ../../pointsmodelpy/libpointsmodelpy + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/misc.cc b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/misc.cc new file mode 100755 index 000000000..f825ce158 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/misc.cc @@ -0,0 +1,91 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +#include "misc.h" +#include "analytical_model.h" +#include "geo_shape.h" +#include "iq.h" + +// copyright + +char pyanalmodelpy_copyright__doc__[] = ""; +char pyanalmodelpy_copyright__name__[] = "copyright"; + +static char pyanalmodelpy_copyright_note[] = + "analmodelpy python module: Copyright (c) 1998-2005 Michael A.G. 
Aivazis"; + + +PyObject * pyanalmodelpy_copyright(PyObject *, PyObject *) +{ + return Py_BuildValue("s", pyanalmodelpy_copyright_note); +} + +//extern char pyanalmodelpy_GetRadius__name__[] = ""; +//extern char pyanalmodelpy_GetRadius__doc__[] = "GetRadius"; +//PyObject * pyanalmodelpy_GetRadius(PyObject *, PyObject *args){ +// double r = Sphere::GetRadius(); +// +// return PyBuildValue("d", r); +//} + + + +// analytical_model class constructor AnalyticalModel(Sphere &) + +char pyanalmodelpy_new_analmodel__doc__[] = ""; +char pyanalmodelpy_new_analmodel__name__[] = "new_analmodel"; + +PyObject * pyanalmodelpy_new_analmodel(PyObject *, PyObject *args) +{ + PyObject *pyshape = 0; + int ok = PyArg_ParseTuple(args, "O", &pyshape); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pyshape); + + GeoShape *shape = static_cast(temp); + + AnalyticalModel *newanal = new AnalyticalModel(*shape); + + return PyCObject_FromVoidPtr(newanal, NULL); +} + +//AnalyticalModel method: CalculateIQ(IQ *) +char pyanalmodelpy_CalculateIQ__doc__[] = ""; +char pyanalmodelpy_CalculateIQ__name__[] = "CalculateIQ"; + +PyObject * pyanalmodelpy_CalculateIQ(PyObject *, PyObject *args) +{ + PyObject *pyanal = 0, *pyiq = 0; + int ok = PyArg_ParseTuple(args, "OO", &pyanal, &pyiq); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pyanal); + void *temp2 = PyCObject_AsVoidPtr(pyiq); + + AnalyticalModel * thisanal = static_cast(temp); + IQ * thisiq = static_cast(temp2); + + thisanal->CalculateIQ(thisiq); + + return Py_BuildValue("i",0); + +} + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/analmodelpy/analmodelpymodule/misc.h b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/misc.h new file mode 100755 index 000000000..f8ab124c3 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/analmodelpymodule/misc.h @@ -0,0 +1,47 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pyanalmodelpy_misc_h) +#define pyanalmodelpy_misc_h + +// copyright +extern char pyanalmodelpy_copyright__name__[]; +extern char pyanalmodelpy_copyright__doc__[]; +extern "C" +PyObject * pyanalmodelpy_copyright(PyObject *, PyObject *); + +extern char pyanalmodelpy_GetRadius__name__[]; +extern char pyanalmodelpy_GetRadius__doc__[]; +extern "C" +PyObject * pyanalmodelpy_GetRadius(PyObject *, PyObject *); + + +// Analytical Model constructor +extern char pyanalmodelpy_new_analmodel__name__[]; +extern char pyanalmodelpy_new_analmodel__doc__[]; +extern "C" +PyObject * pyanalmodelpy_new_analmodel(PyObject *, PyObject *); + +// AnalyticalModel method: CalculateIQ(IQ *) +extern char pyanalmodelpy_CalculateIQ__name__[]; +extern char pyanalmodelpy_CalculateIQ__doc__[]; +extern "C" +PyObject * pyanalmodelpy_CalculateIQ(PyObject *, PyObject *); + + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/analmodelpy/examples/Make.mm b/sas/sascalc/simulation/analmodelpy/examples/Make.mm new file mode 100755 index 000000000..11abd568a --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/examples/Make.mm @@ -0,0 +1,32 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis
+# California Institute of Technology
+# (C) 1998-2005 All Rights Reserved
+#
+#
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+
+PROJECT = analmodelpy
+PACKAGE = examples
+
+#--------------------------------------------------------------------------
+#
+
+all: clean
+
+release: clean
+	cvs release .
+
+update: clean
+	cvs update .
+
+
+# version
+# $Id$
+
+# End of file
diff --git a/sas/sascalc/simulation/analmodelpy/libanalmodelpy/Make.mm b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/Make.mm
new file mode 100755
index 000000000..bac074149
--- /dev/null
+++ b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/Make.mm
@@ -0,0 +1,66 @@
+# -*- Makefile -*-
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Michael A.G. Aivazis
+# California Institute of Technology
+# (C) 1998-2005 All Rights Reserved
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+
+include local.def
+
+PROJECT = analmodelpy
+PACKAGE = libanalmodelpy
+
+PROJ_SAR = $(BLD_LIBDIR)/$(PACKAGE).$(EXT_SAR)
+PROJ_DLL = $(BLD_BINDIR)/$(PACKAGE).$(EXT_SO)
+PROJ_TMPDIR = $(BLD_TMPDIR)/$(PROJECT)/$(PACKAGE)
+PROJ_CLEAN += $(PROJ_SAR) $(PROJ_DLL)
+PROJ_LIBRARIES = -lgeoshapespy -liqPy
+#PROJ_LIBRARIES = -L/home/jingz1/dv/tools/pythia-0.8/lib -lgeoshapespy -liqPy
+
+PROJ_SRCS = \
+    analytical_model.cc
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# build the library
+
+all: $(PROJ_SAR) export
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ifeq (Win32, ${findstring Win32, $(PLATFORM_ID)})
+
+# build the shared object
+$(PROJ_SAR): product_dirs $(PROJ_OBJS)
+	$(CXX) $(LCXXFLAGS) -o $(PROJ_DLL) \
+	-Wl,--out-implib=$(PROJ_SAR) $(PROJ_OBJS)
+
+# export
+export:: export-headers export-libraries export-binaries
+
+else
+
+# build the shared object
+$(PROJ_SAR): product_dirs $(PROJ_OBJS)
+	$(CXX) $(LCXXFLAGS) -o $(PROJ_SAR) $(PROJ_OBJS) $(PROJ_LIBRARIES)
+
+# export
+export:: export-headers export-libraries
+
+endif
+
+EXPORT_HEADERS = \
+    analytical_model.h \
+    sas_model.h
+
+EXPORT_LIBS = $(PROJ_SAR)
+EXPORT_BINS = $(PROJ_DLL)
+
+# version
+# $Id$
+
+#
+# End of file
diff --git a/sas/sascalc/simulation/analmodelpy/libanalmodelpy/analytical_model.cc b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/analytical_model.cc
new file mode 100755
index 000000000..08a716927
--- /dev/null
+++ b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/analytical_model.cc
@@ -0,0 +1,29 @@
+/** \file analytical_model.cc */
+
+#include "sphere.h"
+#include "hollow_sphere.h"
+#include "analytical_model.h"
+
+AnalyticalModel::AnalyticalModel(const GeoShape &geo_shape)
+{
+  switch (geo_shape.GetShapeType()){
+    case SPHERE:
+      shape_ = new Sphere(static_cast<const Sphere&>(geo_shape));
+      break;
+    case HOLLOWSPHERE:
+      shape_ = new HollowSphere(static_cast<const HollowSphere&>(geo_shape));
+      break;
+    case CYLINDER:
+      break;
+  }
+}
+
+AnalyticalModel::~AnalyticalModel()
+{
+  if (shape_ != NULL) delete shape_;
+}
+
+void AnalyticalModel::CalculateIQ(IQ * iq)
+{
+  shape_->GetFormFactor(iq);
+}
diff --git a/sas/sascalc/simulation/analmodelpy/libanalmodelpy/analytical_model.h b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/analytical_model.h
new file mode 100755
index 000000000..2369f6463
--- /dev/null
+++ b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/analytical_model.h
@@ -0,0 +1,23 @@
+/** \file analytical_model.h class AnalyticalModel:SASModel */
+
+#ifndef ANALYTICALMODEL_H
+#define
ANALYTICALMODEL_H + +#include "sas_model.h" +#include "geo_shape.h" +#include "iq.h" + +class AnalyticalModel : public SASModel{ + + public: + AnalyticalModel(const GeoShape &); + ~AnalyticalModel(); + + void CalculateIQ(IQ *); + + private: + AnalyticalModel(); + GeoShape *shape_; +}; + +#endif diff --git a/sas/sascalc/simulation/analmodelpy/libanalmodelpy/local.def b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/local.def new file mode 100755 index 000000000..29f02d383 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/local.def @@ -0,0 +1,31 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# +# Local (project) definitions +# + +# C++ + +PROJ_CXX_INCLUDES += ../../iqPy/libiqPy/tnt \ + ../../iqPy/libiqPy \ + ../../geoshapespy/libgeoshapespy \ + ../../pointsmodelpy/libpointsmodelpy + + PROJ_CXX_FLAGS += $(CXX_SOFLAGS) + PROJ_LCXX_FLAGS += $(LCXX_SARFLAGS) + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/libanalmodelpy/sas_model.h b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/sas_model.h new file mode 100755 index 000000000..923c49ba4 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/libanalmodelpy/sas_model.h @@ -0,0 +1,15 @@ +/** \file sas_model.h class SASModel virtual base class */ + +#ifndef SASMODEL_H +#define SASMODEL_H + +#include "iq.h" + +class SASModel{ + public: + + virtual void CalculateIQ(IQ *iq) = 0; + +}; + +#endif diff --git a/sas/sascalc/simulation/analmodelpy/tests/Make.mm b/sas/sascalc/simulation/analmodelpy/tests/Make.mm new file mode 100755 index 000000000..58e9a00c2 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/tests/Make.mm @@ -0,0 +1,52 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PROJ_CXX_INCLUDES += ../../iqPy/libiqPy \ + ../../iqPy/libiqPy/tnt \ + ../libanalmodelpy \ + ../../geoshapespy/libgeoshapespy + +PROJECT = SASsimulation +PACKAGE = tests + +PROJ_CLEAN += $(PROJ_CPPTESTS) + +PROJ_PYTESTS = testanal_model.py +PROJ_CPPTESTS = testanalytical_model +PROJ_TESTS = $(PROJ_PYTESTS) $(PROJ_CPPTESTS) +PROJ_LIBRARIES = -L$(BLD_LIBDIR) -lanalmodelpy -liqPy -lgeoshapespy + + +#-------------------------------------------------------------------------- +# + +all: $(PROJ_TESTS) + +test: + for test in $(PROJ_TESTS) ; do $${test}; done + +release: tidy + cvs release . + +update: clean + cvs update . + +#-------------------------------------------------------------------------- +# + +testanalytical_model: testanalytical_model.cc $(BLD_LIBDIR)/libanalmodelpy.$(EXT_SAR) + $(CXX) $(CXXFLAGS) $(LCXXFLAGS) -o $@ testanalytical_model.cc $(PROJ_LIBRARIES) + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/tests/signon.py b/sas/sascalc/simulation/analmodelpy/tests/signon.py new file mode 100755 index 000000000..d16860ea8 --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/tests/signon.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +if __name__ == "__main__": + + import analmodelpy + from analmodelpy import analmodelpy as analmodelpymodule + + print("copyright information:") + print(" ", analmodelpy.copyright()) + print(" ", analmodelpymodule.copyright()) + + print() + print("module information:") + print(" file:", analmodelpymodule.__file__) + print(" doc:", analmodelpymodule.__doc__) + print(" contents:", dir(analmodelpymodule)) + + print() + print(analmodelpymodule.hello()) + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/tests/testanal_model.py b/sas/sascalc/simulation/analmodelpy/tests/testanal_model.py new file mode 100755 index 000000000..fe451dd2e --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/tests/testanal_model.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +from __future__ import print_function + + +if __name__ == "__main__": + + from SASsimulation import analmodelpy as analmodelpymodule + from SASsimulation import iqPy + from SASsimulation import geoshapespy + + print("copyright information:") + print(" ", analmodelpymodule.copyright()) + + print() + print("module information:") + print(" file:", analmodelpymodule.__file__) + print(" doc:", analmodelpymodule.__doc__) + print(" contents:", dir(analmodelpymodule)) + + a = geoshapespy.new_sphere(1.0) + iq = iqPy.new_iq(10,0.001, 0.3) + anal = analmodelpymodule.new_analmodel(a) + analmodelpymodule.CalculateIQ(anal,iq) + iqPy.OutputIQ(iq,"out.iq") + + + + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/analmodelpy/tests/testanalytical_model.cc b/sas/sascalc/simulation/analmodelpy/tests/testanalytical_model.cc new file mode 100755 index 000000000..24757ffbc --- /dev/null +++ b/sas/sascalc/simulation/analmodelpy/tests/testanalytical_model.cc @@ -0,0 +1,22 @@ +#include +#include +#include "analytical_model.h" +#include "sphere.h" + +using namespace std; + +void TestAnalyticalModel() { + Sphere sphere(1.0); + AnalyticalModel am(sphere); + + IQ iq1(10,0.001, 0.3); + am.CalculateIQ(&iq1); + + for (int i = 0; i < iq1.iq_data.dim1(); ++i) + cout << iq1.iq_data[i][0]<< " " << iq1.iq_data[i][1] < +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = geoshapespy + +# directory structure + +BUILD_DIRS = \ + libgeoshapespy \ + geoshapespymodule \ + geoshapespy \ + +OTHER_DIRS = \ + tests \ + examples + +RECURSE_DIRS = $(BUILD_DIRS) $(OTHER_DIRS) + +#-------------------------------------------------------------------------- +# + +all: + BLD_ACTION="all" $(MM) recurse + +distclean:: + BLD_ACTION="distclean" $(MM) recurse + +clean:: + BLD_ACTION="clean" $(MM) recurse + +tidy:: + BLD_ACTION="tidy" $(MM) recurse + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/examples/Make.mm b/sas/sascalc/simulation/geoshapespy/examples/Make.mm new file mode 100755 index 000000000..5697d899b --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/examples/Make.mm @@ -0,0 +1,32 @@ +# -*- Makefile -*- +# +# 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = geoshapespy +PACKAGE = examples + +#-------------------------------------------------------------------------- +# + +all: clean + +release: clean + cvs release . + +update: clean + cvs update . + + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespy/Make.mm b/sas/sascalc/simulation/geoshapespy/geoshapespy/Make.mm new file mode 100755 index 000000000..56ea8b3fb --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespy/Make.mm @@ -0,0 +1,35 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PROJECT = geoshapespy +PACKAGE = geoshapespy + +#-------------------------------------------------------------------------- +# + +all: export + + +#-------------------------------------------------------------------------- +# +# export + +EXPORT_PYTHON_MODULES = \ + __init__.py + + +export:: export-python-modules + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespy/__init__.py b/sas/sascalc/simulation/geoshapespy/geoshapespy/__init__.py new file mode 100755 index 000000000..f1ea44911 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespy/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +def copyright(): + return "geoshapespy pyre module: Copyright (c) 1998-2005 Michael A.G. Aivazis"; + + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/Make.mm b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/Make.mm new file mode 100755 index 000000000..a98c696e2 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/Make.mm @@ -0,0 +1,32 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = SASsimulation +PACKAGE = geoshapespymodule +MODULE = geoshapespy + +include std-pythonmodule.def +include local.def + +PROJ_CXX_SRCLIB = -lgeoshapespy + +PROJ_SRCS = \ + bindings.cc \ + exceptions.cc \ + misc.cc + + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/bindings.cc b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/bindings.cc new file mode 100755 index 000000000..cb50e00e2 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/bindings.cc @@ -0,0 +1,63 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. 
Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +#include "bindings.h" + +#include "misc.h" // miscellaneous methods + +// the method table + +struct PyMethodDef pygeoshapespy_methods[] = { + + //geoshapes methods: set_orientation; set_center + {pygeoshapespy_set_orientation__name__, pygeoshapespy_set_orientation, + METH_VARARGS, pygeoshapespy_set_orientation__doc__}, + + {pygeoshapespy_set_center__name__, pygeoshapespy_set_center, + METH_VARARGS, pygeoshapespy_set_center__doc__}, + + // new sphere + {pyanalmodelpy_new_sphere__name__, pyanalmodelpy_new_sphere, + METH_VARARGS, pyanalmodelpy_new_sphere__doc__}, + + // new cylinder + {pyanalmodelpy_new_cylinder__name__, pyanalmodelpy_new_cylinder, + METH_VARARGS, pyanalmodelpy_new_cylinder__doc__}, + + // new ellipsoid + {pyanalmodelpy_new_ellipsoid__name__, pyanalmodelpy_new_ellipsoid, + METH_VARARGS, pyanalmodelpy_new_ellipsoid__doc__}, + + // new hollowsphere + {pyanalmodelpy_new_hollowsphere__name__, pyanalmodelpy_new_hollowsphere, + METH_VARARGS, pyanalmodelpy_new_hollowsphere__doc__}, + + // new singlehelix + {pyanalmodelpy_new_singlehelix__name__, pyanalmodelpy_new_singlehelix, + METH_VARARGS, pyanalmodelpy_new_singlehelix__doc__}, + + {pygeoshapespy_copyright__name__, pygeoshapespy_copyright, + METH_VARARGS, pygeoshapespy_copyright__doc__}, + + +// Sentinel + {0, 0} +}; + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/bindings.h b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/bindings.h new file mode 100755 index 000000000..4daaadc1d --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/bindings.h @@ -0,0 +1,26 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pygeoshapespy_bindings_h) +#define pygeoshapespy_bindings_h + +// the method table + +extern struct PyMethodDef pygeoshapespy_methods[]; + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/exceptions.cc b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/exceptions.cc new file mode 100755 index 000000000..70208076b --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/exceptions.cc @@ -0,0 +1,22 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +PyObject *pygeoshapespy_runtimeError = 0; + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/exceptions.h b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/exceptions.h new file mode 100755 index 000000000..445c7f204 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/exceptions.h @@ -0,0 +1,26 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. 
Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pygeoshapespy_exceptions_h) +#define pygeoshapespy_exceptions_h + +// exceptions + +extern PyObject *pygeoshapespy_runtimeError; + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/geoshapespymodule.cc b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/geoshapespymodule.cc new file mode 100755 index 000000000..9553b4ad0 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/geoshapespymodule.cc @@ -0,0 +1,44 @@ +// -*- C++ -*- + +#include + + +#include "exceptions.h" +#include "bindings.h" +#include "myutil.h" + + +char pygeoshapespy_module__doc__[] = ""; + +// Initialization function for the module (*must* be called initgeoshapespy) +extern "C" +void +initgeoshapespy() +{ + // create the module and add the functions + PyObject * m = Py_InitModule4( + "geoshapespy", pygeoshapespy_methods, + pygeoshapespy_module__doc__, 0, PYTHON_API_VERSION); + + // get its dictionary + PyObject * d = PyModule_GetDict(m); + + // check for errors + if (PyErr_Occurred()) { + Py_FatalError("can't initialize module geoshapespy"); + } + + // install the module exceptions + pygeoshapespy_runtimeError = PyErr_NewException("geoshapespy.runtime", 0, 0); + PyDict_SetItemString(d, "RuntimeException", pygeoshapespy_runtimeError); + + // Seed the random number generator + seed_rnd(); + + return; +} + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/local.def b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/local.def new file mode 100755 index 000000000..e376f3b61 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/local.def @@ -0,0 +1,24 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# C++ + + PROJ_CXX_INCLUDES = ../libgeoshapespy \ + ../../iqPy/libiqPy \ + ../../iqPy/libiqPy/tnt \ + ../../pointsmodelpy/libpointsmodelpy + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/misc.cc b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/misc.cc new file mode 100755 index 000000000..bd88aa1c1 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/misc.cc @@ -0,0 +1,190 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +#include "misc.h" +#include "sphere.h" +#include "cylinder.h" +#include "ellipsoid.h" +#include "hollow_sphere.h" +#include "single_helix.h" + +// copyright + +char pygeoshapespy_copyright__doc__[] = ""; +char pygeoshapespy_copyright__name__[] = "copyright"; + +static char pygeoshapespy_copyright_note[] = + "geoshapespy python module: Copyright (c) 1998-2005 Michael A.G. 
Aivazis"; + + +PyObject * pygeoshapespy_copyright(PyObject *, PyObject *) +{ + return Py_BuildValue("s", pygeoshapespy_copyright_note); +} + +//GeoShape methods +char pygeoshapespy_set_orientation__name__[] = "set_orientation"; +char pygeoshapespy_set_orientation__doc__[] = "Set the rotation angles"; + +PyObject * pygeoshapespy_set_orientation(PyObject *, PyObject *args){ + PyObject *pyshape = 0; + double angX=0,angY=0,angZ=0; + + int ok = PyArg_ParseTuple(args, "Oddd", &pyshape,&angX,&angY,&angZ); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pyshape); + + GeoShape *shape = static_cast(temp); + + shape->SetOrientation(angX,angY,angZ); + + return Py_BuildValue("i", 0); +} + +char pygeoshapespy_set_center__name__[] = "set_center"; +char pygeoshapespy_set_center__doc__[] = "new center for points translation"; + +PyObject * pygeoshapespy_set_center(PyObject *, PyObject *args){ + + PyObject *pyshape = 0; + double tranX=0,tranY=0,tranZ=0; + + int ok = PyArg_ParseTuple(args, "Oddd", &pyshape,&tranX,&tranY,&tranZ); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pyshape); + + GeoShape *shape = static_cast(temp); + + shape->SetCenter(tranX,tranY,tranZ); + + return Py_BuildValue("i", 0); +} + +//Sphere constructor +char pyanalmodelpy_new_sphere__name__[] = "new_sphere"; +char pyanalmodelpy_new_sphere__doc__[] = "sphere constructor"; + +PyObject * pyanalmodelpy_new_sphere(PyObject *, PyObject *args){ + double r; + int ok = PyArg_ParseTuple(args,"d",&r); + if(!ok) return 0; + + Sphere *newsph = new Sphere(r); + + return PyCObject_FromVoidPtr(newsph, PyDelSphere); +} + +static void PyDelSphere(void *ptr){ + Sphere * oldsph = static_cast(ptr); + delete oldsph; + + return; +} + +//Cylinder constructor +char pyanalmodelpy_new_cylinder__name__[] = "new_cylinder"; +char pyanalmodelpy_new_cylinder__doc__[] = "cylinder constructor"; + +PyObject * pyanalmodelpy_new_cylinder(PyObject *, PyObject *args){ + double r,h; + int ok = PyArg_ParseTuple(args,"dd",&r,&h); + if(!ok) return 0; + + Cylinder *newcyl = new Cylinder(r,h); + + return PyCObject_FromVoidPtr(newcyl, PyDelCylinder); +} + +static void PyDelCylinder(void *ptr){ + Cylinder * oldcyl = static_cast(ptr); + delete oldcyl; + + return; +} + +//Ellipsoid constructor +char pyanalmodelpy_new_ellipsoid__name__[] = "new_ellipsoid"; +char pyanalmodelpy_new_ellipsoid__doc__[] = "ellipsoid constructor"; + +PyObject * pyanalmodelpy_new_ellipsoid(PyObject *, PyObject *args){ + double rx,ry,rz; + int ok = PyArg_ParseTuple(args,"ddd",&rx,&ry,&rz); + if(!ok) return 0; + + Ellipsoid *newelli = new Ellipsoid(rx,ry,rz); + + return PyCObject_FromVoidPtr(newelli, PyDelEllipsoid); +} + +static void PyDelEllipsoid(void *ptr){ + Ellipsoid * oldelli = static_cast(ptr); + delete oldelli; + + return; +} + +//Hollow Sphere constructor & methods +char pyanalmodelpy_new_hollowsphere__name__[] = "new_hollowsphere"; +char pyanalmodelpy_new_hollowsphere__doc__[] = ""; + +PyObject * pyanalmodelpy_new_hollowsphere(PyObject *, PyObject *args) +{ + double r, th; + int ok = PyArg_ParseTuple(args,"dd",&r, &th); + if(!ok) return 0; + + HollowSphere *newhosph = new HollowSphere(r,th); + + return PyCObject_FromVoidPtr(newhosph, PyDelHollowSphere); + +} + +static void PyDelHollowSphere(void *ptr) +{ + HollowSphere * oldhosph = static_cast(ptr); + delete oldhosph; + return; +} + +//Single Helix constructor & methods +char pyanalmodelpy_new_singlehelix__name__[] = "new_singlehelix"; +char pyanalmodelpy_new_singlehelix__doc__[] = ""; + +PyObject * 
pyanalmodelpy_new_singlehelix(PyObject *, PyObject *args) +{ + double hr,tr,pitch,turns; + int ok = PyArg_ParseTuple(args,"dddd",&hr,&tr,&pitch,&turns); + if(!ok) return 0; + + SingleHelix *newsinhel = new SingleHelix(hr,tr,pitch,turns); + + return PyCObject_FromVoidPtr(newsinhel, PyDelSingleHelix); + +} + +static void PyDelSingleHelix(void *ptr) +{ + SingleHelix * oldsinhel = static_cast(ptr); + delete oldsinhel; + return; +} + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/geoshapespy/geoshapespymodule/misc.h b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/misc.h new file mode 100755 index 000000000..1bb0594f7 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/geoshapespymodule/misc.h @@ -0,0 +1,79 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pygeoshapespy_misc_h) +#define pygeoshapespy_misc_h + +// copyright +extern char pygeoshapespy_copyright__name__[]; +extern char pygeoshapespy_copyright__doc__[]; +extern "C" +PyObject * pygeoshapespy_copyright(PyObject *, PyObject *); + +//GeoShape methods +extern char pygeoshapespy_set_orientation__name__[]; +extern char pygeoshapespy_set_orientation__doc__[]; +extern "C" +PyObject * pygeoshapespy_set_orientation(PyObject *, PyObject *); + +extern char pygeoshapespy_set_center__name__[]; +extern char pygeoshapespy_set_center__doc__[]; +extern "C" +PyObject * pygeoshapespy_set_center(PyObject *, PyObject *); + +//Sphere constructor & methods +extern char pyanalmodelpy_new_sphere__name__[]; +extern char pyanalmodelpy_new_sphere__doc__[]; +extern "C" +PyObject * pyanalmodelpy_new_sphere(PyObject *, PyObject *); + +static void PyDelSphere(void *); + +//Cylinder constructor & methods +extern char pyanalmodelpy_new_cylinder__name__[]; +extern char pyanalmodelpy_new_cylinder__doc__[]; +extern "C" +PyObject * pyanalmodelpy_new_cylinder(PyObject *, PyObject *); + +static void PyDelCylinder(void *); + +//Ellipsoid constructor & method +extern char pyanalmodelpy_new_ellipsoid__name__[]; +extern char pyanalmodelpy_new_ellipsoid__doc__[]; +extern "C" +PyObject * pyanalmodelpy_new_ellipsoid(PyObject *, PyObject *); + +static void PyDelEllipsoid(void *); + +//Hollow Sphere constructor & methods +extern char pyanalmodelpy_new_hollowsphere__name__[]; +extern char pyanalmodelpy_new_hollowsphere__doc__[]; +extern "C" +PyObject * pyanalmodelpy_new_hollowsphere(PyObject *, PyObject *); + +static void PyDelHollowSphere(void *); + +//Single Helix constructor & methods +extern char pyanalmodelpy_new_singlehelix__name__[]; +extern char pyanalmodelpy_new_singlehelix__doc__[]; +extern "C" +PyObject * pyanalmodelpy_new_singlehelix(PyObject *, PyObject *); + +static void PyDelSingleHelix(void *); + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Make.mm b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Make.mm new file mode 100755 index 000000000..bab493a1c --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Make.mm @@ -0,0 +1,79 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +include local.def + +PROJECT = geoshapespy +PACKAGE = libgeoshapespy + +PROJ_SAR = $(BLD_LIBDIR)/$(PACKAGE).$(EXT_SAR) +PROJ_DLL = $(BLD_BINDIR)/$(PACKAGE).$(EXT_SO) +PROJ_TMPDIR = $(BLD_TMPDIR)/$(PROJECT)/$(PACKAGE) +PROJ_CLEAN += $(PROJ_SAR) $(PROJ_DLL) + +PROJ_SRCS = \ + geo_shape.cc \ + sphere.cc \ + hollow_sphere.cc \ + cylinder.cc \ + ellipsoid.cc \ + single_helix.cc \ + Point3D.cc \ + myutil.cc + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# build the library + +all: $(PROJ_SAR) export + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +ifeq (Win32, ${findstring Win32, $(PLATFORM_ID)}) + +# build the shared object +$(PROJ_SAR): product_dirs $(PROJ_OBJS) + $(CXX) $(LCXXFLAGS) -o $(PROJ_DLL) \ + -Wl,--out-implib=$(PROJ_SAR) $(PROJ_OBJS) + +# export +export:: export-headers export-libraries export-binaries + +else + +# build the shared object +$(PROJ_SAR): product_dirs $(PROJ_OBJS) + $(CXX) $(LCXXFLAGS) -o $(PROJ_SAR) $(PROJ_OBJS) + +# export +export:: export-headers export-libraries + +endif + +EXPORT_HEADERS = \ + geo_shape.h \ + sphere.h \ + hollow_sphere.h \ + cylinder.h \ + ellipsoid.h \ + single_helix.h \ + Point3D.h \ + myutil.h + + +EXPORT_LIBS = $(PROJ_SAR) +EXPORT_BINS = $(PROJ_DLL) + + +# version +# $Id$ + +# +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Point3D.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Point3D.cc new file mode 100755 index 000000000..2ba89bf62 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Point3D.cc @@ -0,0 +1,220 @@ +//Point3D.cpp + +#include +#include +#include +#include "Point3D.h" + +using namespace std; + +Point3D::Point3D(double a, double b, double c, double sld) +{ + x = a; + y = b; + z = c; + sld_ = sld; +} + +ostream& operator<< (ostream& os, const Point3D& p) +{ + os << "(" << p.getX() << ", " << p.getY() << ", " << p.getZ() << "," << p.getSLD() << ")"; + + return os; +} + +double Point3D::distanceToPoint(const Point3D &p) const +{ + double dx=x-p.x; + double dy=y-p.y; + double dz=z-p.z; + return sqrt(dx*dx+dy*dy+dz*dz); +} + +double Point3D::distanceToLine(const Point3D &p1, const Point3D &p2, bool * pIsOutside /* 0 */) const +{ + double u = ((x-p1.x)*(p2.x-p1.x) + (y-p1.y)*(p2.y-p1.y) + (z-p1.z)*(p2.z-p1.z))/(p1.distanceToPoint(p2)*p1.distanceToPoint(p2)); + + if(pIsOutside != 0) { + if ( u < 0 || u > 1) + *pIsOutside=true; + else + *pIsOutside=false; + } + + double interX=p1.x+u*(p2.x-p1.x); + double interY=p1.y+u*(p2.y-p1.y); + double interZ=p1.z+u*(p2.z-p1.z); + + return sqrt((x-interX)*(x-interX)+(y-interY)*(y-interY)+(z-interZ)*(z-interZ)); +} + +// p1 and p2 determine a line, they are two end points +Point3D Point3D::getInterPoint(const Point3D &p1, const Point3D &p2, bool * pIsOutside /* 0 */) const +{ + double u = ((x-p1.x)*(p2.x-p1.x) + (y-p1.y)*(p2.y-p1.y) + (z-p1.z)*(p2.z-p1.z))/(p1.distanceToPoint(p2)*p1.distanceToPoint(p2)); + + if(pIsOutside != 0) { + + if ( u < 0 || u > 1) + *pIsOutside=true; + else + *pIsOutside=false; + } + + return Point3D(p1.x+u*(p2.x-p1.x), p1.y+u*(p2.y-p1.y), p1.z+u*(p2.z-p1.z)); +} + +void Point3D::set(double x1, double y1, double z1) +{ + x = x1; + y = y1; + z = z1; +} + +double Point3D::norm() const +{ + return sqrt(x * x + y * y + z * z); +} + +double Point3D::normalize() +{ + double v = norm(); + + 
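+    // Guard against the zero vector: leave (0,0,0) unchanged
+    // and return the zero length to the caller.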
if(v != 0) { + x /= v; + y /= v; + z /= v; + } + + return v; +} + +Point3D Point3D::normVector() const +{ + Point3D p; + double v = norm(); + + if(v != 0) { + p.x = x / v; + p.y = y / v; + p.z = z / v; + } + + return p; +} + +double Point3D::dotProduct(const Point3D &p) const +{ + return x * p.x + y * p.y + z * p.z; +} + +Point3D Point3D::minus(const Point3D &p) const +{ + return Point3D(x - p.x, y - p.y, z - p.z); +} + +Point3D Point3D::plus(const Point3D &p) const +{ + return Point3D(x + p.x, y + p.y, z + p.z); +} + +Point3D& Point3D::operator=(const Point3D &p) +{ + x = p.x; + y = p.y; + z = p.z; + sld_ = p.sld_; + + return *this; +} + +void Point3D::scale(double s) +{ + x *= s; + y *= s; + z *= s; +} + +Point3D Point3D::multiplyProduct(const Point3D &p) +{ + return Point3D(y * p.z - z * p.y, z * p.x - x * p.z, x * p.y - y * p.x); +} + +void Point3D::Transform(const vector &orien, const vector ¢er){ + if(orien[1]) RotateY(Degree2Radian(orien[1])); + if(orien[0]) RotateX(Degree2Radian(orien[0])); + if(orien[2]) RotateZ(Degree2Radian(orien[2])); + + Translate(center[0],center[1],center[2]); +} + +void Point3D::RotateX(const double ang_x) +{ + double sinA_ = sin(ang_x); + double cosA_ = cos(ang_x); + + //x doesn't change; + double y_new= y*cosA_ - z*sinA_; + double z_new= y*sinA_ + z*cosA_; + + y = y_new; + z = z_new; +} + +void Point3D::RotateY(const double ang_y) +{ + double sinA_ = sin(ang_y); + double cosA_ = cos(ang_y); + + double x_new = z*sinA_ + x*cosA_; + //y doesn't change + double z_new = z*cosA_ - x*sinA_; + + x = x_new; + z = z_new; +} + +void Point3D::RotateZ(const double ang_z) +{ + double sinA_ = sin(ang_z); + double cosA_ = cos(ang_z); + + double x_new = x*cosA_ - y*sinA_; + double y_new = x*sinA_ + y*cosA_; + //z doesn't change + + x = x_new; + y = y_new; +} + +void Point3D::Translate(const double trans_x, const double trans_y, const double trans_z) +{ + double x_new = x + trans_x; + double y_new = y + trans_y; + double z_new = z + trans_z; + + x = x_new; + y = y_new; + z = z_new; +} + +double Point3D::Degree2Radian(const double degree) +{ + return degree/180*pi; +} + +void Point3D::TransformMatrix(const vector &rotmatrix, const vector ¢er) +{ + if (rotmatrix.size() != 9 || center.size() != 3) + throw std::runtime_error("The size for rotation matrix vector has to be 9."); + + double xold = x; + double yold = y; + double zold = z; + + x = rotmatrix[0]*xold + rotmatrix[1]*yold + rotmatrix[2]*zold; + y = rotmatrix[3]*xold + rotmatrix[4]*yold + rotmatrix[5]*zold; + z = rotmatrix[6]*xold + rotmatrix[7]*yold + rotmatrix[8]*zold; + + Translate(center[0],center[1],center[2]); +} diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Point3D.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Point3D.h new file mode 100755 index 000000000..6a4ee66aa --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/Point3D.h @@ -0,0 +1,87 @@ +//class point3D 1/29/2004 Jing +//properties: copy point coordinates +// calculate point to point distance +// shortest distance from a point to a line + + +#ifndef __POINT3D_ +#define __POINT3D_ + +#include +#include + +using namespace std; + +const double pi = 3.1415926; + +class Point3D +{ + public: + Point3D() {} + + //assign values with SLD to a point + Point3D(double a, double b, double c, double sld = 0); + + // output + friend std::ostream& operator<<(std::ostream&, const Point3D&); + + //distance to a point + double distanceToPoint(const Point3D &p) const; + + //distance to a line + double distanceToLine(const Point3D 
&p1, const Point3D &p2, bool *pv = 0) const; + + // get the point lying on the axis from a point + Point3D getInterPoint(const Point3D &p1, const Point3D &p2, bool *pv = 0) const; + + // normalization + Point3D normVector() const; + + // get length + double norm() const; + + double normalize(); + + // assignment operator + Point3D& operator=(const Point3D &p); + + // scale this point with s + void scale(double s); + + // multiplication product of two vectors + Point3D multiplyProduct(const Point3D &p); + + // p0 - p + Point3D minus(const Point3D &p) const; + + // p0 + p + Point3D plus(const Point3D &p) const; + + // dot product + double dotProduct(const Point3D &p) const; + + //if you do not care if the point falls into the range of line,you can directly use distanceToLine(p1,p2) + void set(double x1, double y1, double z1); + + double getX() const { return x; } + double getY() const { return y; } + double getZ() const { return z; } + double getSLD() const { return sld_; } + + //Transformation + void Transform(const vector &orien, const vector ¢er); + void TransformMatrix(const vector &rotmatrix, const vector ¢er); + private: + double x, y, z; + double sld_; + + void RotateX(const double ang_x); + void RotateY(const double ang_y); + void RotateZ(const double ang_z); + void Translate(const double trans_x, const double trans_y, const double trans_z); + double Degree2Radian(const double degree); +}; + +typedef std::vector Point3DVector; + +#endif diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/cylinder.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/cylinder.cc new file mode 100755 index 000000000..b52749ace --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/cylinder.cc @@ -0,0 +1,108 @@ +/** \file Cylinder.cc */ +#include +#include +#include "cylinder.h" +#include "myutil.h" + +Cylinder::Cylinder() +{ + r_ = 0; +} + +Cylinder::Cylinder(double radius,double length) +{ + r_ = radius; + l_ = length; + topcenter_ = Point3D(0,length/2,0); + botcenter_ = Point3D(0,-length/2,0); +} + +Cylinder::~Cylinder() +{ +} + +void Cylinder::SetRadius(double r) +{ + r_ = r; +} + +void Cylinder::SetLength(double l) +{ + l_ = l; +} + +double Cylinder::GetRadius() +{ + return r_; +} + +double Cylinder::GetLength() +{ + return l_; +} + +double Cylinder::GetMaxRadius() +{ + double maxr = sqrt(4*r_*r_ + l_*l_)/2; + return maxr; +} + +ShapeType Cylinder::GetShapeType() const +{ + return CYLINDER; +} + +double Cylinder::GetVolume() +{ + double V = pi * square(r_) * l_; + return V; +} + +void Cylinder::GetFormFactor(IQ * iq) +{ + /** number of I for output, equal to the number of rows of array IQ*/ + /** to be finished */ +} + +Point3D Cylinder::GetAPoint(double sld) +{ + /** cylinder long axis is along Y to match vtk actor */ + static int max_try = 100; + for (int i = 0; i < max_try; ++i) { + double x = (ran1()-0.5) * 2 * r_; + double z = (ran1()-0.5) * 2 * r_; + double y = (ran1()-0.5) * l_; + + Point3D apoint(x,y,z,sld); + //check the cross section on xy plane within a sphere at (0,) + if (apoint.distanceToPoint(Point3D(0,y,0)) <= r_ ) + return apoint; + } + + std::cerr << "Max try " + << max_try + << " is reached while generating a point in cylinder" << std::endl; + return Point3D(0, 0, 0); +} + +bool Cylinder::IsInside(const Point3D& point) const +{ + bool isOutside = false; + double distToLine = point.distanceToLine(GetTopCenter(),GetBotCenter(),&isOutside); + + return (distToLine <= r_ && !isOutside); +} + +Point3D Cylinder::GetTopCenter() const +{ + Point3D 
new_center(topcenter_); + new_center.Transform(GetOrientation(),GetCenter()); + return new_center; +} + +Point3D Cylinder::GetBotCenter() const +{ + Point3D new_center(botcenter_); + new_center.Transform(GetOrientation(),GetCenter()); + return new_center; +} diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/cylinder.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/cylinder.h new file mode 100755 index 000000000..02e084952 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/cylinder.h @@ -0,0 +1,63 @@ +/** \file cylinder.h class Cylinder:GeoShape */ + +#ifndef Cylinder_H +#define Cylinder_H + +#include "geo_shape.h" + +/** class Cylinder, subclass of GeoShape */ +class Cylinder : public GeoShape { + public: + + /** initialize */ + Cylinder(); + + /** constructor with radius initialization */ + Cylinder(double radius,double length); + + ~Cylinder(); + + /** set parameter radius */ + void SetRadius(double r); + + /** set parameter length */ + void SetLength(double l); + + /** get the radius */ + double GetRadius(); + + /** get the length */ + double GetLength(); + + /** get the radius of the sphere to cover this shape */ + double GetMaxRadius(); + + /** get the volume */ + double GetVolume(); + + /** calculate the cylinder form factor, no scale, no background*/ + void GetFormFactor(IQ * iq); + + /** using a equation to check whether a point with XYZ lies + within the cylinder with center (0,0,0) + */ + Point3D GetAPoint(double sld); + + /** check whether a point is inside the cylinder at any position + in the 3D space + */ + bool IsInside(const Point3D& point) const; + + ShapeType GetShapeType() const; + + protected: + Point3D GetTopCenter() const; + Point3D GetBotCenter() const; + + private: + double r_, l_; + Point3D topcenter_,botcenter_; + +}; + +#endif diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/ellipsoid.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/ellipsoid.cc new file mode 100755 index 000000000..08bf9bba7 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/ellipsoid.cc @@ -0,0 +1,165 @@ +/** \file Ellipsoid.cc */ +#include +#include +#include +//MS library does not define min max in algorithm +//#include "minmax.h" +#include "ellipsoid.h" +#include "myutil.h" + +using namespace std; + +Ellipsoid::Ellipsoid() +{ + rx_ = 0; + ry_ = 0; + rz_ = 0; +} + +Ellipsoid::Ellipsoid(double rx, double ry, double rz) +{ + rx_ = rx; + ry_ = ry; + rz_ = rz; + xedge_plus_ = Point3D(rx,0,0); + xedge_minus_ = Point3D(-rx,0,0); + yedge_plus_ = Point3D(0,ry,0); + yedge_minus_ = Point3D(0,-ry,0); + zedge_plus_ = Point3D(0,0,rz); + zedge_minus_ = Point3D(0,0,-rz); +} + +Ellipsoid::~Ellipsoid() +{ +} + +void Ellipsoid::SetRadii(double rx, double ry, double rz) +{ + rx_ = rx; + ry_ = ry; + rz_ = rz; +} + +double Ellipsoid::GetRadiusX() +{ + return rx_; +} + +double Ellipsoid::GetRadiusY() +{ + return ry_; +} + +double Ellipsoid::GetRadiusZ() +{ + return rz_; +} + +double Ellipsoid::GetMaxRadius() +{ + double maxr = max(max(rx_,ry_),max(ry_,rz_)); + return maxr; +} + +ShapeType Ellipsoid::GetShapeType() const +{ + return ELLIPSOID; +} + +double Ellipsoid::GetVolume() +{ + double V = (4./3.)*pi*rx_*ry_*rz_; + return V; +} + +void Ellipsoid::GetFormFactor(IQ * iq) +{ + /** number of I for output, equal to the number of rows of array IQ*/ + /** to be finished */ +} + +Point3D Ellipsoid::GetAPoint(double sld) +{ + static int max_try = 100; + for (int i = 0; i < max_try; ++i) { + double x = (ran1()-0.5) * 2 * rx_; + double y = 
(ran1()-0.5) * 2 * ry_; + double z = (ran1()-0.5) * 2 * rz_; + + if ((square(x/rx_) + square(y/ry_) + square(z/rz_)) <= 1){ + Point3D apoint(x,y,z,sld); + return apoint; + } + } + + std::cerr << "Max try " + << max_try + << " is reached while generating a point in ellipsoid" << std::endl; + return Point3D(0, 0, 0); +} + +bool Ellipsoid::IsInside(const Point3D& point) const +{ + //x, y, z axis are internal axis + bool isOutsideX = false; + Point3D pointOnX = point.getInterPoint(GetXaxisPlusEdge(),GetXaxisMinusEdge(),&isOutsideX); + bool isOutsideY = false; + Point3D pointOnY = point.getInterPoint(GetYaxisPlusEdge(),GetYaxisMinusEdge(),&isOutsideY); + bool isOutsideZ = false; + Point3D pointOnZ = point.getInterPoint(GetZaxisPlusEdge(),GetZaxisMinusEdge(),&isOutsideZ); + + if (isOutsideX || isOutsideY || isOutsideZ){ + //one is outside axis is true -> the point is not inside + return false; + } + else{ + Point3D pcenter = GetCenterP(); + double distX = pointOnX.distanceToPoint(pcenter); + double distY = pointOnY.distanceToPoint(pcenter); + double distZ = pointOnZ.distanceToPoint(pcenter); + return ((square(distX/rx_)+square(distY/ry_)+square(distZ/rz_)) <= 1); + } + +} + +Point3D Ellipsoid::GetXaxisPlusEdge() const +{ + Point3D new_edge(xedge_plus_); + new_edge.Transform(GetOrientation(),GetCenter()); + return new_edge; +} + +Point3D Ellipsoid::GetXaxisMinusEdge() const +{ + Point3D new_edge(xedge_minus_); + new_edge.Transform(GetOrientation(),GetCenter()); + return new_edge; +} + +Point3D Ellipsoid::GetYaxisPlusEdge() const +{ + Point3D new_edge(yedge_plus_); + new_edge.Transform(GetOrientation(),GetCenter()); + return new_edge; +} + +Point3D Ellipsoid::GetYaxisMinusEdge() const +{ + Point3D new_edge(yedge_minus_); + new_edge.Transform(GetOrientation(),GetCenter()); + return new_edge; +} + +Point3D Ellipsoid::GetZaxisPlusEdge() const +{ + Point3D new_edge(zedge_plus_); + new_edge.Transform(GetOrientation(),GetCenter()); + return new_edge; +} + +Point3D Ellipsoid::GetZaxisMinusEdge() const +{ + Point3D new_edge(zedge_minus_); + new_edge.Transform(GetOrientation(),GetCenter()); + return new_edge; +} diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/ellipsoid.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/ellipsoid.h new file mode 100755 index 000000000..99568a057 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/ellipsoid.h @@ -0,0 +1,65 @@ +/** \file cylinder.h class Cylinder:GeoShape */ + +#ifndef Ellipsoid_H +#define Ellipsoid_H + +#include "geo_shape.h" + +/** class Ellipsoid, subclass of GeoShape */ +class Ellipsoid : public GeoShape { + public: + + /** initialize */ + Ellipsoid(); + + /** constructor with radii initialization */ + Ellipsoid(double rx,double ry, double rz); + + ~Ellipsoid(); + + /** set parameter radii */ + void SetRadii(double rx, double ry, double rz); + + /** get the radii */ + double GetRadiusX(); + double GetRadiusY(); + double GetRadiusZ(); + + /** get the radius of the sphere to cover this shape */ + double GetMaxRadius(); + + /** get the volume */ + double GetVolume(); + + /** calculate the cylinder form factor, no scale, no background*/ + void GetFormFactor(IQ * iq); + + /** using a equation to check whether a point with XYZ lies + within the cylinder with center (0,0,0) + */ + Point3D GetAPoint(double sld); + + /** check whether a point is inside the cylinder at any position + in the 3D space + */ + bool IsInside(const Point3D& point) const; + + ShapeType GetShapeType() const; + + protected: + //get a point on three axis 
for function IsInside,using the edge point + Point3D GetXaxisPlusEdge() const; + Point3D GetXaxisMinusEdge() const; + Point3D GetYaxisPlusEdge() const; + Point3D GetYaxisMinusEdge() const; + Point3D GetZaxisPlusEdge() const; + Point3D GetZaxisMinusEdge() const; + + private: + double rx_,ry_,rz_; + Point3D xedge_plus_,xedge_minus_; + Point3D yedge_plus_,yedge_minus_; + Point3D zedge_plus_,zedge_minus_; +}; + +#endif diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/geo_shape.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/geo_shape.cc new file mode 100755 index 000000000..677d7c783 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/geo_shape.cc @@ -0,0 +1,26 @@ +#include "geo_shape.h" + +void GeoShape::SetOrientation(double angX,double angY, double angZ){ + orientation_[0]=angX; + orientation_[1]=angY; + orientation_[2]=angZ; +} + +void GeoShape::SetCenter(double cenX,double cenY,double cenZ){ + center_[0]=cenX; + center_[1]=cenY; + center_[2]=cenZ; +} + +vector GeoShape::GetOrientation() const { + return orientation_; +} + +vector GeoShape::GetCenter() const { + return center_; +} + +Point3D GeoShape::GetCenterP() const { + Point3D p(center_[0],center_[1],center_[2]); + return p; +} diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/geo_shape.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/geo_shape.h new file mode 100755 index 000000000..fef7aa097 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/geo_shape.h @@ -0,0 +1,66 @@ +/*! \file a abstract class GeoShape +*/ +#ifndef GEOSHAPE_H +#define GEOSHAPE_H + +#include "iq.h" +#include "Point3D.h" +#include + +using namespace std; + +#define triple(x) ((x) * (x) * (x)) +#define square(x) ((x) * (x)) + +enum ShapeType{ SPHERE, HOLLOWSPHERE, CYLINDER , ELLIPSOID,SINGLEHELIX}; + +/** + * class GeoShape, abstract class, parent class for Sphere, Cylinder .... 
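+ *
+ * A minimal concrete subclass only needs to supply the pure virtual
+ * methods declared below. Illustrative sketch (UnitSphere is a
+ * hypothetical example, not one of the shapes added in this change):
+ *
+ *   class UnitSphere : public GeoShape {
+ *    public:
+ *     void GetFormFactor(IQ * iq) { }  // analytic I(q) would go here
+ *     Point3D GetAPoint(double sld) { return Point3D(0, 0, 0, sld); }
+ *     bool IsInside(const Point3D& p) const
+ *       { return p.distanceToPoint(GetCenterP()) <= 1.0; }
+ *     double GetVolume() { return 4.0 * pi / 3.0; }
+ *     ShapeType GetShapeType() const { return SPHERE; }
+ *     double GetMaxRadius() { return 1.0; }
+ *   };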
+ */
+
+class GeoShape{
+ public:
+
+  GeoShape(){
+    vector<double> init(3,0.0);
+    orientation_ = init;
+    center_ = init;
+  }
+
+  virtual ~GeoShape() {}
+
+  /** calculate the form factor for a simple shape */
+  virtual void GetFormFactor(IQ * iq) = 0;
+
+  /** Get a point that is within the simple shape*/
+  virtual Point3D GetAPoint(double sld) = 0;
+
+  /** check whether a point is inside the shape*/
+  virtual bool IsInside(const Point3D& point) const = 0;
+
+  virtual double GetVolume() = 0;
+
+  virtual ShapeType GetShapeType() const = 0;
+
+  /** get the radius of the sphere to cover the shape*/
+  virtual double GetMaxRadius() = 0;
+
+  void SetOrientation(double angX, double angY, double angZ);
+
+  void SetCenter(double cenX, double cenY, double cenZ);
+
+  vector<double> GetOrientation() const;
+
+  vector<double> GetCenter() const;
+
+  Point3D GetCenterP() const;
+
+ private:
+  vector<double> orientation_;
+
+ protected:
+  vector<double> center_;
+
+};
+
+#endif
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/hollow_sphere.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/hollow_sphere.cc
new file mode 100755
index 000000000..d3b319857
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/hollow_sphere.cc
@@ -0,0 +1,83 @@
+/** \file hollowsphere.cc */
+#include <cmath>
+#include <cassert>
+#include "hollow_sphere.h"
+
+HollowSphere::HollowSphere()
+{
+  ro_ = 0;
+  th_ = 0;
+}
+
+HollowSphere::HollowSphere(double radius, double thickness)
+{
+  ro_ = radius;
+  th_ = thickness;
+}
+
+double HollowSphere::GetMaxRadius()
+{
+  double maxr = ro_;
+  return maxr;
+}
+
+void HollowSphere::GetFormFactor(IQ * iq)
+{
+  /** number of I for output, equal to the number of rows of array IQ*/
+  int numI = iq->iq_data.dim1();
+  double qmin = iq->GetQmin();
+  double qmax = iq->GetQmax();
+
+  assert(numI > 0);
+  assert(qmin > 0);
+  assert(qmax > 0);
+  assert( qmax > qmin );
+
+  double logMin = log10(qmin);
+  double z = logMin;
+  double logMax = log10(qmax);
+  double delta = (logMax - logMin) / (numI-1);
+
+  //not finished yet, need to find the part which is equal to 1
+  for(int i = 0; i < numI; ++i) {
+
+    /** q is stepped uniformly in log10 space; convert log10(q) back to q */
+    double q = pow(10.0, z);
+
+    double ri_ = ro_ - th_ ;
+
+    double bes1 = 3.0 * (sin(q*ri_) - q*ri_*cos(q*ri_)) / triple(q) / triple(ri_);
+    double bes2 = 3.0 * (sin(q*ro_) - q*ro_*cos(q*ro_)) / triple(q) / triple(ro_);
+    /** shell amplitude: each sphere term is weighted by its own volume
+        (bes2 is evaluated at ro_, bes1 at ri_) */
+    double bes = (triple(ro_)*bes2 - triple(ri_)*bes1)/(triple(ro_) - triple(ri_));
+    /** double f is the temp for I, should be equal to one when q is 0*/
+    double f = bes * bes;
+
+    /** IQ[i][0] is Q,Q starts from qmin (non zero),q=0 handle separately IQ[i][1] is I */
+    iq->iq_data[i][0]= q;
+    iq->iq_data[i][1]= f;
+
+    z += delta;
+  }
+
+}
+
+Point3D HollowSphere::GetAPoint(double sld)
+{
+  return Point3D(0,0,0);
+}
+
+double HollowSphere::GetVolume()
+{
+  return 0;
+}
+
+bool HollowSphere::IsInside(const Point3D& point) const
+{
+  return true;
+}
+
+ShapeType HollowSphere::GetShapeType() const
+{
+  return HOLLOWSPHERE;
+}
+
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/hollow_sphere.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/hollow_sphere.h
new file mode 100755
index 000000000..5568f7a87
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/hollow_sphere.h
@@ -0,0 +1,43 @@
+/** \file hollowsphere.h class HollowSphere:GeoShape */
+
+#ifndef HOLLOWSPHERE_H
+#define HOLLOWSPHERE_H
+
+#include "geo_shape.h"
+
+/** class HollowSphere, subclass of GeoShape */
+class HollowSphere : public GeoShape {
+ public:
+
+  /** initialize */
+  HollowSphere();
+
+  /** constructor with
radius&length initialization */ + HollowSphere(double radius, double thickness); + + /** get the radius of the sphere to cover this shape */ + double GetMaxRadius(); + + /** calculate the sphere form factor, no scale, no background*/ + void GetFormFactor(IQ * iq); + + /** using a equation to check whether a point with XYZ lies + within the sphere with center (0,0,0) + */ + Point3D GetAPoint(double sld); + + /** check whether point is inside hollow sphere*/ + bool IsInside(const Point3D& point) const; + + double GetVolume(); + + //get the shape type, return sphere + ShapeType GetShapeType() const; + + private: + //outer radius, and the thickness of the shell + double ro_, th_; + +}; + +#endif diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/local.def b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/local.def new file mode 100755 index 000000000..c36373364 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/local.def @@ -0,0 +1,30 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# +# Local (project) definitions +# + +# C++ + +PROJ_CXX_INCLUDES += ../../iqPy/libiqPy \ + ../../iqPy/libiqPy/tnt \ + ../../pointsmodelpy/libpointsmodelpy + + PROJ_CXX_FLAGS += $(CXX_SOFLAGS) + PROJ_LCXX_FLAGS += $(LCXX_SARFLAGS) -liqPy + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/minmax.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/minmax.h new file mode 100755 index 000000000..313e6a714 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/minmax.h @@ -0,0 +1,23 @@ +#ifndef GUARD_minmax_H +#define GUARD_minmax_H +#ifdef _MSC_VER +// needed to cope with bug in MS library: +// it fails to define min/max + +template inline T max(const T& a, const T& b) +{ + + return (a > b) ? a : b; + +} + +template inline T min(const T& a, const T& b) +{ + + return (a < b) ? 
a : b;
+
+}
+
+#endif
+
+#endif
\ No newline at end of file
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/myutil.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/myutil.cc
new file mode 100755
index 000000000..02a1dd2ae
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/myutil.cc
@@ -0,0 +1,16 @@
+/** \file myutil.cc */
+#include <cstdlib>
+#include <ctime>
+#include "myutil.h"
+
+using namespace std;
+
+void seed_rnd() {
+  srand((unsigned)time(NULL));
+}
+
+double ran1()
+{
+  return (double)rand() / (double) RAND_MAX;
+}
+
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/myutil.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/myutil.h
new file mode 100755
index 000000000..c31d86d16
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/myutil.h
@@ -0,0 +1,9 @@
+/** \file myutil.h */
+#ifndef _MYUTIL_H
+#define _MYUTIL_H
+
+/** return a random float number in [0,1] */
+double ran1();
+void seed_rnd();
+
+#endif
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/single_helix.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/single_helix.cc
new file mode 100755
index 000000000..71c6d2412
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/single_helix.cc
@@ -0,0 +1,140 @@
+/** \file SingleHelix.cc */
+#include <cmath>
+#include <iostream>
+#include "single_helix.h"
+#include "myutil.h"
+
+SingleHelix::SingleHelix()
+{
+  hr_ = 0;
+  tr_ = 0;
+  pitch_ = 0;
+  turns_ = 0;
+}
+
+SingleHelix::SingleHelix(double helix_radius,double tube_radius,double pitch,double turns)
+{
+  hr_ = helix_radius;
+  tr_ = tube_radius;
+  pitch_ = pitch;
+  turns_ = turns;
+}
+
+SingleHelix::~SingleHelix()
+{
+}
+
+void SingleHelix::SetHelixRadius(double hr)
+{
+  hr_ = hr;
+}
+
+void SingleHelix::SetTubeRadius(double tr)
+{
+  tr_ = tr;
+}
+
+void SingleHelix::SetPitch(double pitch)
+{
+  pitch_ = pitch;
+}
+
+void SingleHelix::SetTurns(double turns)
+{
+  turns_ = turns;
+}
+
+double SingleHelix::GetHelixRadius()
+{
+  return hr_;
+}
+
+double SingleHelix::GetTubeRadius()
+{
+  return tr_;
+}
+
+double SingleHelix::GetPitch()
+{
+  return pitch_;
+}
+
+double SingleHelix::GetTurns()
+{
+  return turns_;
+}
+
+double SingleHelix::GetMaxRadius()
+{
+  // put helix into a cylinder
+  double r_ = hr_ + tr_;
+  double l_ = pitch_*turns_ + 2*tr_;
+  double maxr = sqrt(4*r_*r_ + l_*l_)/2;
+  return maxr;
+}
+
+ShapeType SingleHelix::GetShapeType() const
+{
+  return SINGLEHELIX;
+}
+
+double SingleHelix::GetVolume()
+{
+  double V = pi*square(tr_)*sqrt(square(2*pi*hr_)+square(turns_*pitch_));
+  return V;
+}
+
+void SingleHelix::GetFormFactor(IQ * iq)
+{
+  /** number of I for output, equal to the number of rows of array IQ*/
+  /** to be finished */
+}
+
+Point3D SingleHelix::GetAPoint(double sld)
+{
+  int max_try = 100;
+  int i = 0;
+  double point1 = 0, point2 = 0, point3 = 0;
+
+  do{
+    i++;
+    if (i > max_try){
+      //report the failure and bail out with a default point instead of
+      //falling through with a candidate that failed the acceptance test
+      std::cerr << "Max try "
+                << max_try
+                << " is reached while generating a point in helix" << std::endl;
+      return Point3D(0,0,0);
+    }
+    zr_ = tr_*sqrt(1+square((pitch_/(2*pi*hr_))));
+    point1 = (ran1()-0.5)*2*tr_;
+    point2 = (ran1()-0.5)*2*zr_;
+    point3 = (ran1()-0.5)*4*pi*turns_;
+    //accept only points inside the tube cross section and above the helix
+    //start (point1 is scaled by tr_, mirroring the test in IsInside)
+  } while (((square(point1/tr_)+square(point2/zr_))>1) || (point2+pitch_*point3/(2*pi)<0));
+
+  double x = (point1 + hr_)*cos(point3);
+  double y = (point1 + hr_)*sin(point3);
+  //"-pitch_*turns_/2" is just to correctly center the helix
+  double z = point2 + pitch_*point3/(2*pi)-pitch_*turns_/2;
+
+  Point3D apoint(x,y,z,sld);
+  return apoint;
+}
+
+bool SingleHelix::IsInside(const Point3D& point) const
+{
+  double x = point.getX()-center_[0];
+  double y = point.getY()-center_[1];
+  double z = point.getZ()-center_[2];
+
+  double p3 = atan(y/x);
+  double p2 = z-(pitch_*p3)/(2*pi);
+  double p1 = x/cos(p3)-hr_;
+
+  //reject points outside the tube cross section or below the helix start,
+  //mirroring the acceptance test used in GetAPoint
+  if ((square(p1/tr_)+square(p2/zr_))>1 || p2+pitch_*p3/(2*pi) < 0)
+    return false;
+  else
+    return true;
+}
+
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/single_helix.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/single_helix.h
new file mode 100755
index 000000000..dc6c5d885
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/single_helix.h
@@ -0,0 +1,75 @@
+/** \file single_helix.h class SingleHelix:GeoShape */
+
+#ifndef SingleHelix_H
+#define SingleHelix_H
+
+#include "geo_shape.h"
+
+/** class SingleHelix, subclass of GeoShape */
+class SingleHelix : public GeoShape {
+ public:
+
+  /** initialize */
+  SingleHelix();
+
+  /** constructor with radius initialization */
+  SingleHelix(double helix_radius,double tube_radius,double pitch,double turns);
+
+  ~SingleHelix();
+
+  /** set parameter helix_radius */
+  void SetHelixRadius(double hr);
+
+  /** set parameter tube_radius */
+  void SetTubeRadius(double tr);
+
+  /** set parameter pitch */
+  void SetPitch(double pitch);
+
+  /** set parameter number of turns */
+  void SetTurns(double turns);
+
+  /** get the helix_radius */
+  double GetHelixRadius();
+
+  /** get the tube_radius */
+  double GetTubeRadius();
+
+  /** get the pitch */
+  double GetPitch();
+
+  /** get the number of turns */
+  double GetTurns();
+
+  /** get the radius of the sphere to cover this shape */
+  double GetMaxRadius();
+
+  /** get the volume */
+  double GetVolume();
+
+  /** calculate the single helix form factor, no scale, no background*/
+  void GetFormFactor(IQ * iq);
+
+  /** using an equation to check whether a point with XYZ lies
+      within the helix with center (0,0,0)
+  */
+  Point3D GetAPoint(double sld);
+
+  /** check whether a point is inside the helix at any position
+      in the 3D space
+  */
+  bool IsInside(const Point3D& point) const;
+
+  ShapeType GetShapeType() const;
+
+ protected:
+  //Point3D GetTopCenter() const;
+  //Point3D GetBotCenter() const;
+
+ private:
+  double hr_,tr_, pitch_, turns_,zr_;
+  Point3D topcenter_,botcenter_;
+
+};
+
+#endif
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/sphere.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/sphere.cc
new file mode 100755
index 000000000..3b7ef082f
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/sphere.cc
@@ -0,0 +1,106 @@
+/** \file Sphere.cc */
+#include <cmath>
+#include <cassert>
+#include "sphere.h"
+#include "myutil.h"
+
+Sphere::Sphere()
+{
+  r_ = 0;
+}
+
+Sphere::Sphere(double radius)
+{
+  r_ = radius;
+}
+
+Sphere::~Sphere()
+{
+}
+
+void Sphere::SetRadius(double r)
+{
+  r_ = r;
+}
+
+double Sphere::GetRadius()
+{
+  return r_;
+}
+
+double Sphere::GetMaxRadius()
+{
+  double maxr = r_;
+  return maxr;
+}
+
+ShapeType Sphere::GetShapeType() const
+{
+  return SPHERE;
+}
+
+double Sphere::GetVolume()
+{
+  double V = 4 * pi / 3 * triple(r_);
+  return V;
+}
+
+void Sphere::GetFormFactor(IQ * iq)
+{
+  /** number of I for output, equal to the number of rows of array IQ*/
+  int numI = iq->iq_data.dim1();
+  double qmin = iq->GetQmin();
+  double qmax = iq->GetQmax();
+
+  assert(numI > 0);
+  assert(qmin > 0);
+  assert(qmax > 0);
+  assert( qmax > qmin );
+
+  double logMin = log10(qmin);
+  double z = logMin;
+  double logMax = log10(qmax);
+  double delta = (logMax - logMin) / (numI-1);
+
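+  // q is stepped uniformly in log10 space: z holds log10(q) and is
+  // mapped back to q = 10^z at the top of the loop below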
+  for(int i = 0; i < numI; ++i) {
+
+    /** convert log10(q) back to q */
+    double q = pow(10.0, z);
+
+    double bes = 3.0 * (sin(q*r_) - q*r_*cos(q*r_)) / triple(q) / triple(r_);
+
+    /** double f is the temp for I, should be equal to one when q is 0*/
+    double f = bes * bes;
+
+    /** IQ[i][0] is Q,Q starts from qmin (non zero),q=0 handle separately IQ[i][1] is I */
+    iq->iq_data[i][0]= q;
+    iq->iq_data[i][1]= f;
+
+    z += delta;
+  }
+}
+
+Point3D Sphere::GetAPoint(double sld)
+{
+  static int max_try = 100;
+  for (int i = 0; i < max_try; ++i) {
+    double x = (ran1()-0.5) * 2 * r_;
+    double y = (ran1()-0.5) * 2 * r_;
+    double z = (ran1()-0.5) * 2 * r_;
+
+    Point3D apoint(x,y,z,sld);
+    if (apoint.distanceToPoint(Point3D(0,0,0)) <= r_ ) //dist to origin gives sphere shape
+      return apoint;
+  }
+
+  std::cerr << "Max try "
+            << max_try
+            << " is reached while generating a point in sphere" << std::endl;
+  return Point3D(0, 0, 0);
+}
+
+bool Sphere::IsInside(const Point3D& point) const
+{
+  return point.distanceToPoint(Point3D(center_[0], center_[1], center_[2])) <= r_;
+}
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/sphere.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/sphere.h
new file mode 100755
index 000000000..f2d5b8f54
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/sphere.h
@@ -0,0 +1,52 @@
+/** \file sphere.h class Sphere:GeoShape */
+
+#ifndef SPHERE_H
+#define SPHERE_H
+
+#include "geo_shape.h"
+
+/** class Sphere, subclass of GeoShape */
+class Sphere : public GeoShape {
+ public:
+
+  /** initialize */
+  Sphere();
+
+  /** constructor with radius initialization */
+  Sphere(double radius);
+
+  ~Sphere();
+
+  /** set parameter radius */
+  void SetRadius(double r);
+
+  /** get the radius */
+  double GetRadius();
+
+  /** Get the radius of the sphere to cover this shape */
+  double GetMaxRadius();
+
+  /** get the volume */
+  double GetVolume();
+
+  /** calculate the sphere form factor, no scale, no background*/
+  void GetFormFactor(IQ * iq);
+
+  /** using an equation to check whether a point with XYZ lies
+      within the sphere with center (0,0,0)
+  */
+  Point3D GetAPoint(double sld);
+
+  /** check whether a point is inside the sphere at any position
+      in the 3D space
+  */
+  bool IsInside(const Point3D& point) const;
+
+  ShapeType GetShapeType() const;
+
+ private:
+  double r_;
+
+};
+
+#endif
diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/transformation.cc b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/transformation.cc
new file mode 100755
index 000000000..bc9a8ea24
--- /dev/null
+++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/transformation.cc
@@ -0,0 +1,73 @@
+#include "transformation.h"
+#include <cmath>
+#include <stdexcept>
+
+void RotateX(const double ang_x,Point3D &a_point)
+{
+  double sinA_ = sin(ang_x);
+  double cosA_ = cos(ang_x);
+  double pointX_ = a_point.getX();
+  double pointY_ = a_point.getY();
+  double pointZ_ = a_point.getZ();
+
+  double x_new_ = pointX_;
+  double y_new_ = pointY_*cosA_ - pointZ_*sinA_;
+  //rotation about x: z' = y*sin + z*cos (the minus sign here was a typo)
+  double z_new_ = pointY_*sinA_ + pointZ_*cosA_;
+
+  a_point.set(x_new_, y_new_, z_new_);
+}
+
+void RotateY(const double ang_y,Point3D &a_point)
+{
+  double sinA_ = sin(ang_y);
+  double cosA_ = cos(ang_y);
+  double pointX_ = a_point.getX();
+  double pointY_ = a_point.getY();
+  double pointZ_ = a_point.getZ();
+
+  double x_new_ = pointX_*cosA_ + pointZ_*sinA_;
+  double y_new_ = pointY_;
+  double z_new_ = -pointX_*sinA_ + pointZ_*cosA_;
+
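+  //write the rotated coordinates back into the point
+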
a_point.set(x_new_, y_new_, z_new_); +} + +void RotateZ(const double ang_z,Point3D &a_point) +{ + double sinA_= sin(ang_z); + double cosA_ = cos(ang_z); + double pointX_ = a_point.getX(); + double pointY_ = a_point.getY(); + double pointZ_ = a_point.getZ(); + + double x_new_ = pointX_*cosA_ - pointY_*sinA_; + double y_new_ = pointX_*sinA_+ pointY_*cosA_; + double z_new_ = pointZ_; + + a_point.set(x_new_, y_new_, z_new_); +} + +void Translate(const double trans_x, const double trans_y, const double trans_z, Point3D &a_point) +{ + double x_new_ = a_point.getX() + trans_x; + double y_new_ = a_point.getY() + trans_y; + double z_new_ = a_point.getZ() + trans_z; + + a_point.set(x_new_, y_new_, z_new_); +} + +void RotateMatrix(const vector &rotmatrix, Point3D &a_point) +{ + if (rotmatrix.size() != 9) + throw std::runtime_error("The size for rotation matrix vector has to be 9."); + + double xold = a_point.getX(); + double yold = a_point.getY(); + double zold = a_point.getZ(); + + double x_new_ = rotmatrix[0]*xold + rotmatrix[1]*yold + rotmatrix[2]*zold; + double y_new_ = rotmatrix[3]*xold + rotmatrix[4]*yold + rotmatrix[5]*zold; + double z_new_ = rotmatrix[6]*xold + rotmatrix[7]*yold + rotmatrix[8]*zold; + + a_point.set(x_new_, y_new_, z_new_); +} diff --git a/sas/sascalc/simulation/geoshapespy/libgeoshapespy/transformation.h b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/transformation.h new file mode 100755 index 000000000..1f988651b --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/libgeoshapespy/transformation.h @@ -0,0 +1,21 @@ +/** \file transformation.h: a few functions to do the transformation */ + +#ifndef TRANSFORMATION_H +//Transform operation on a Point3D +//Not in use, similar transformation are implemented as Point3D class functions + +#define TRANSFORMATION_H + +#include +#include "Point3D.h" + +void RotateX(const double ang_x,Point3D &); +void RotateY(const double ang_y,Point3D &); +void RotateZ(const double ang_z,Point3D &); +void Translate(const double trans_x, const double trans_y, const double trans_z, Point3D &); + +//rotate a point by given a rotation 3X3 matrix +//vec[0]=R00, vec[1]=R01.......vec[8]=R22 +void RotateMatrix(const vector &, Point3D &); + +#endif diff --git a/sas/sascalc/simulation/geoshapespy/tests/Make.mm b/sas/sascalc/simulation/geoshapespy/tests/Make.mm new file mode 100755 index 000000000..a4fcaec58 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/tests/Make.mm @@ -0,0 +1,59 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PROJECT = geoshapespy +PACKAGE = tests + +PROJ_CLEAN += $(PROJ_CPPTESTS) + +PROJ_CXX_INCLUDES += ../../iqPy/libiqPy \ + ../../iqPy/libiqPy/tnt \ + ../libgeoshapespy \ + ../../pointsmodelpy/libpointsmodelpy + +PROJ_PYTESTS = +PROJ_CPPTESTS = testpoint + +PROJ_TESTS = $(PROJ_PYTESTS) $(PROJ_CPPTESTS) +PROJ_LIBRARIES = -L$(BLD_LIBDIR) -lgeoshapespy -liqPy -lpointsmodelpy + + +#-------------------------------------------------------------------------- +# + +all: $(PROJ_TESTS) + +test: + for test in $(PROJ_TESTS) ; do $${test}; done + +release: tidy + cvs release . + +update: clean + cvs update . 
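+
+# Note: the drivers below all link against the shared libraries staged in
+# $(BLD_LIBDIR); only testpoint (listed in PROJ_CPPTESTS above) is built
+# and run by the default targets.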
+ +#-------------------------------------------------------------------------- +# + +testshapes: testshapes.cc $(BLD_LIBDIR)/libgeoshapespy.$(EXT_SAR) + $(CXX) $(CXXFLAGS) $(LCXXFLAGS) -o $@ testshapes.cc $(PROJ_LIBRARIES) + +testpoint: testPoint.cc $(BLD_LIBDIR)/libgeoshapespy.$(EXT_SAR) + $(CXX) $(CXXFLAGS) $(LCXXFLAGS) -o $@ testPoint.cc $(PROJ_LIBRARIES) + +testtransformation: testtransformation.cc $(BLD_LIBDIR)/libgeoshapespy.$(EXT_SAR) + $(CXX) $(CXXFLAGS) $(LCXXFLAGS) -o $@ testtransformation.cc $(PROJ_LIBRARIES) + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/tests/testPoint.cc b/sas/sascalc/simulation/geoshapespy/tests/testPoint.cc new file mode 100755 index 000000000..4ef7abb7f --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/tests/testPoint.cc @@ -0,0 +1,40 @@ +#include +#include +#include "Point3D.h" + + +using namespace std; + +int main(){ + + cout << "test 1:initialize a point, set orientation, set center" + << " and perform transform" < orient(3),center(3); + for (int i = 0; i!=3; ++i){ + center[i] = 0; + } + orient[0]=10; + orient[1]=0; + orient[2]=30; + + apoint.Transform(orient,center); + + cout << apoint.getX() <<" "< rotmatrix(9,1); + for (size_t i = 0; i != rotmatrix.size(); ++i){ + cout << rotmatrix[i] <<" "; + } + + Point3D pp(1,1,1); + pp.TransformMatrix(rotmatrix, center); + + cout << pp.getX() <<" "< + +int main(){ + Sphere sph1(10); + sph1.SetOrientation(10,20,30); + sph1.SetCenter(1,5,10); + + vector v1 = sph1.GetOrientation(); + cout << v1[0] < +#include +#include +#include "sphere.h" +#include "cylinder.h" +#include "hollow_sphere.h" +#include "ellipsoid.h" +#include "single_helix.h" +#include "iq.h" + +using namespace std; + +void TestGetFormFactor_sphere() { + Sphere sphere(1.0); + + IQ iq1(10,0.001, 0.3); + sphere.GetFormFactor(&iq1); + + for (int i = 0; i< iq1.iq_data.dim1(); i++) + cout << iq1.iq_data[i][0]<< " " << iq1.iq_data[i][1] < ori; + ori.push_back(10); + ori.push_back(10); + ori.push_back(20); + + vector cen; + cen.push_back(20); + cen.push_back(20); + cen.push_back(20); + + p1.Transform(ori,cen); + cout << "is the p1 still inside after orientation?" 
<< c1.IsInside(p1) << endl; + cout << "p1 p2 distance: " < +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +from __future__ import print_function + + +if __name__ == "__main__": + + from SASsimulation import geoshapespy + + print() + print("module information:") + print(" file:", geoshapespy.__file__) + print(" doc:", geoshapespy.__doc__) + print(" contents:", dir(geoshapespy)) + + sp = geoshapespy.new_sphere(10) +# geoshapespy.set_orientation(sp,10,20,10) + cy = geoshapespy.new_cylinder(2,6) + + el = geoshapespy.new_ellipsoid(25,15,10) + + hs = geoshapespy.new_hollowsphere(10,2) + + sh = geoshapespy.new_singlehelix(10,2,30,2) + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/geoshapespy/tests/testsphere.cc b/sas/sascalc/simulation/geoshapespy/tests/testsphere.cc new file mode 100755 index 000000000..835a97444 --- /dev/null +++ b/sas/sascalc/simulation/geoshapespy/tests/testsphere.cc @@ -0,0 +1,36 @@ +#include +#include +#include "sphere.h" +#include "iq.h" + +using namespace std; + +void TestGetAnalyticalIQ() { + Sphere sphere(1.0); + + IQ iq1(10,0.001, 0.3); + sphere.GetFormFactor(&iq1); + + for (int i = 0; i< iq1.iq_data.dim1(); i++) + cout << iq1.iq_data[i][0]<< " " << iq1.iq_data[i][1] < +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = SASsimulation + +# directory structure + +BUILD_DIRS = \ + libiqPy \ + iqPymodule \ + iqPy \ + +OTHER_DIRS = \ + tests \ + examples + +RECURSE_DIRS = $(BUILD_DIRS) $(OTHER_DIRS) + +#-------------------------------------------------------------------------- +# + +all: + BLD_ACTION="all" $(MM) recurse + +distclean:: + BLD_ACTION="distclean" $(MM) recurse + +clean:: + BLD_ACTION="clean" $(MM) recurse + +tidy:: + BLD_ACTION="tidy" $(MM) recurse + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/iqPy/examples/Make.mm b/sas/sascalc/simulation/iqPy/examples/Make.mm new file mode 100755 index 000000000..830c03b5c --- /dev/null +++ b/sas/sascalc/simulation/iqPy/examples/Make.mm @@ -0,0 +1,32 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = SASsimulation +PACKAGE = examples + +#-------------------------------------------------------------------------- +# + +all: clean + +release: clean + cvs release . + +update: clean + cvs update . + + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/iqPy/iqPy/Make.mm b/sas/sascalc/simulation/iqPy/iqPy/Make.mm new file mode 100755 index 000000000..8b8ac655f --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPy/Make.mm @@ -0,0 +1,35 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PROJECT = SASsimulation +PACKAGE = iqPy + +#-------------------------------------------------------------------------- +# + +all: export + + +#-------------------------------------------------------------------------- +# +# export + +EXPORT_PYTHON_MODULES = \ + __init__.py + + +export:: export-python-modules + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/iqPy/iqPy/__init__.py b/sas/sascalc/simulation/iqPy/iqPy/__init__.py new file mode 100755 index 000000000..5223ebec7 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPy/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +def copyright(): + return "iqPy pyre module: Copyright (c) 1998-2005 Michael A.G. Aivazis"; + + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/Make.mm b/sas/sascalc/simulation/iqPy/iqPymodule/Make.mm new file mode 100755 index 000000000..d2dd983cb --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/Make.mm @@ -0,0 +1,31 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = SASsimulation +PACKAGE = iqPymodule +MODULE = iqPy + +include std-pythonmodule.def +include local.def + +PROJ_CXX_SRCLIB = -liqPy +PROJ_SRCS = \ + bindings.cc \ + exceptions.cc \ + misc.cc + + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/bindings.cc b/sas/sascalc/simulation/iqPy/iqPymodule/bindings.cc new file mode 100755 index 000000000..cfa3ab956 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/bindings.cc @@ -0,0 +1,44 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +#include "bindings.h" + +#include "misc.h" // miscellaneous methods + +// the method table + +struct PyMethodDef pyiqPy_methods[] = { + + // new_iq + {pyiqPy_new_iq__name__, pyiqPy_new_iq, + METH_VARARGS, "iq(int numI,double qmin, double qmax)->new IQ object"}, + + //OutputIQ + {pyiqPy_OutputIQ__name__, pyiqPy_OutputIQ, + METH_VARARGS, pyiqPy_OutputIQ__doc__}, + + {pyiqPy_copyright__name__, pyiqPy_copyright, + METH_VARARGS, pyiqPy_copyright__doc__}, + + +// Sentinel + {0, 0} +}; + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/bindings.h b/sas/sascalc/simulation/iqPy/iqPymodule/bindings.h new file mode 100755 index 000000000..0f27d3b2b --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/bindings.h @@ -0,0 +1,26 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. 
Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pyiqPy_bindings_h) +#define pyiqPy_bindings_h + +// the method table + +extern struct PyMethodDef pyiqPy_methods[]; + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/exceptions.cc b/sas/sascalc/simulation/iqPy/iqPymodule/exceptions.cc new file mode 100755 index 000000000..30fd7566a --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/exceptions.cc @@ -0,0 +1,22 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +PyObject *pyiqPy_runtimeError = 0; + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/exceptions.h b/sas/sascalc/simulation/iqPy/iqPymodule/exceptions.h new file mode 100755 index 000000000..77a4f382d --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/exceptions.h @@ -0,0 +1,26 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pyiqPy_exceptions_h) +#define pyiqPy_exceptions_h + +// exceptions + +extern PyObject *pyiqPy_runtimeError; + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/iqPymodule.cc b/sas/sascalc/simulation/iqPy/iqPymodule/iqPymodule.cc new file mode 100755 index 000000000..8f68df2bd --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/iqPymodule.cc @@ -0,0 +1,52 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include + +#include + +#include "exceptions.h" +#include "bindings.h" + + +char pyiqPy_module__doc__[] = ""; + +// Initialization function for the module (*must* be called initiqPy) +extern "C" +void +initiqPy() +{ + // create the module and add the functions + PyObject * m = Py_InitModule4( + "iqPy", pyiqPy_methods, + pyiqPy_module__doc__, 0, PYTHON_API_VERSION); + + // get its dictionary + PyObject * d = PyModule_GetDict(m); + + // check for errors + if (PyErr_Occurred()) { + Py_FatalError("can't initialize module iqPy"); + } + + // install the module exceptions + pyiqPy_runtimeError = PyErr_NewException("iqPy.runtime", 0, 0); + PyDict_SetItemString(d, "RuntimeException", pyiqPy_runtimeError); + + return; +} + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/local.def b/sas/sascalc/simulation/iqPy/iqPymodule/local.def new file mode 100755 index 000000000..7bfd75dcb --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/local.def @@ -0,0 +1,21 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# C++ + + PROJ_CXX_INCLUDES = ../libiqPy/tnt \ + ../libiqPy +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/misc.cc b/sas/sascalc/simulation/iqPy/iqPymodule/misc.cc new file mode 100755 index 000000000..a10f9d8be --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/misc.cc @@ -0,0 +1,82 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +#include "misc.h" +#include "iq.h" + + +// copyright + +char pyiqPy_copyright__doc__[] = ""; +char pyiqPy_copyright__name__[] = "copyright"; + +static char pyiqPy_copyright_note[] = + "iqPy python module: Copyright (c) 1998-2005 Michael A.G. Aivazis"; + + +PyObject * pyiqPy_copyright(PyObject *, PyObject *) +{ + return Py_BuildValue("s", pyiqPy_copyright_note); +} + +// iq(int numI,double qmin,double qmax) +char pyiqPy_new_iq__doc__[] = "wrap class iq in C++"; +char pyiqPy_new_iq__name__[] = "new_iq"; + +PyObject * pyiqPy_new_iq(PyObject *, PyObject *args) +{ + int py_numI; + double py_qmin, py_qmax; + int ok = PyArg_ParseTuple(args,"idd",&py_numI, &py_qmin, &py_qmax); + if(!ok) return 0; + + IQ *iq = new IQ(py_numI,py_qmin,py_qmax); + + return PyCObject_FromVoidPtr(iq, NULL); +} + +//output iq to file +extern char pyiqPy_OutputIQ__name__[] = "OutputIQ"; +extern char pyiqPy_OutputIQ__doc__[] = ""; + +PyObject * pyiqPy_OutputIQ(PyObject *, PyObject *args){ + PyObject *pyiq = 0; + char *outfile; + int ok = PyArg_ParseTuple(args,"Os", &pyiq, &outfile); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pyiq); + + IQ * thisiq = static_cast(temp); + + thisiq->OutputIQ(outfile); + + return Py_BuildValue("i",0); +} + +//release the iq object +static void PyDeliq(void *ptr) +{ + std::cout<<"Called PyDeliq()\n"; //Good to see once + IQ * oldiq = static_cast(ptr); + delete oldiq; + return; +} + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/iqPy/iqPymodule/misc.h b/sas/sascalc/simulation/iqPy/iqPymodule/misc.h new file mode 100755 index 000000000..0c0478a18 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/iqPymodule/misc.h @@ -0,0 +1,43 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. 
Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pyiqPy_misc_h) +#define pyiqPy_misc_h + +// copyright +extern char pyiqPy_copyright__name__[]; +extern char pyiqPy_copyright__doc__[]; +extern "C" +PyObject * pyiqPy_copyright(PyObject *, PyObject *); + +// iq constructor, iq(int, double, double) +extern char pyiqPy_new_iq__name__[]; +extern char pyiqPy_new_iq__doc__[]; +extern "C" +PyObject * pyiqPy_new_iq(PyObject *, PyObject *); + +//output iq to file +extern char pyiqPy_OutputIQ__name__[]; +extern char pyiqPy_OutputIQ__doc__[]; +extern "C" +PyObject * pyiqPy_OutputIQ(PyObject *, PyObject *); + +//release the iq object +static void PyDeliq(void *ptr); + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/iqPy/libiqPy/Make.mm b/sas/sascalc/simulation/iqPy/libiqPy/Make.mm new file mode 100755 index 000000000..bc1c53d7d --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/Make.mm @@ -0,0 +1,64 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +include local.def + +PROJECT = iqPy +PACKAGE = libiqPy + +PROJ_SAR = $(BLD_LIBDIR)/$(PACKAGE).$(EXT_SAR) +PROJ_DLL = $(BLD_BINDIR)/$(PACKAGE).$(EXT_SO) +PROJ_TMPDIR = $(BLD_TMPDIR)/$(PROJECT)/$(PACKAGE) +PROJ_CLEAN += $(PROJ_SAR) $(PROJ_DLL) + +PROJ_SRCS = \ + iq.cc + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# build the library + +all: $(PROJ_SAR) export + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +ifeq (Win32, ${findstring Win32, $(PLATFORM_ID)}) + +# build the shared object +$(PROJ_SAR): product_dirs $(PROJ_OBJS) + $(CXX) $(LCXXFLAGS) -o $(PROJ_DLL) \ + -Wl,--out-implib=$(PROJ_SAR) $(PROJ_OBJS) + +# export +export:: export-headers export-libraries export-binaries + +else + +# build the shared object +$(PROJ_SAR): product_dirs $(PROJ_OBJS) + $(CXX) $(LCXXFLAGS) -o $(PROJ_SAR) $(PROJ_OBJS) + +# export +export:: export-headers export-libraries + +endif + +EXPORT_HEADERS = \ + iq.h + +EXPORT_LIBS = $(PROJ_SAR) +EXPORT_BINS = $(PROJ_DLL) + + +# version +# $Id$ + +# +# End of file diff --git a/sas/sascalc/simulation/iqPy/libiqPy/iq.cc b/sas/sascalc/simulation/iqPy/libiqPy/iq.cc new file mode 100755 index 000000000..276b1be13 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/iq.cc @@ -0,0 +1,80 @@ +/** \file iq.cc */ + +#include "iq.h" +#include +#include + +using namespace std; + +IQ::IQ(int numI){ + + numI_ = numI; + + Array2D iq1(numI, 2, 0.0); + + iq_data = iq1; +} + +IQ::IQ(int numI,double qmin, double qmax){ + + numI_ = numI; + qmin_ = qmin; + qmax_ = qmax; + + Array2D iq1(numI, 2, 0.0); + + iq_data = iq1; +} + +void IQ::SetQmin(double qmin){ + qmin_ = qmin; + +} + +void IQ::SetQmax(double qmax){ + qmax_ = qmax; +} + +void IQ::SetContrast(double delrho){ + delrho_ = delrho; +} + +void IQ::SetVolFrac(double vol_frac){ + vol_frac_ = vol_frac; +} + +double IQ::GetQmin(){ + return qmin_; +} + +double IQ::GetQmax(){ + return qmax_; +} + +double IQ::GetContrast(){ + return delrho_; +} + +double IQ::GetVolFrac(){ + return vol_frac_; +} + +int IQ::GetNumI(){ + return numI_; +} + +void IQ::OutputIQ(string fiq){ + ofstream outfile(fiq.c_str()); + if 
(!outfile) { + cerr << "error: unable to open output file: " + << outfile << endl; + exit(1); + } + for (int i = 0; i < iq_data.dim1(); ++i){ + outfile << iq_data[i][0] << " " << iq_data[i][1] << endl; + } + // fprintf(fp,"%15lf%15lf%15lf\n", (j+1)*rstep, cor[j]/cormax, "0"); + + + +} diff --git a/sas/sascalc/simulation/iqPy/libiqPy/iq.h b/sas/sascalc/simulation/iqPy/libiqPy/iq.h new file mode 100755 index 000000000..264451fff --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/iq.h @@ -0,0 +1,45 @@ +/** \file iq.h class IQ */ + +#ifndef IQ_H +#define IQ_H + +#include +/**tnt: template numerical toolkit, http://math.nist.gov/tnt/ */ +#include "tnt/tnt.h" +using namespace TNT; + +class IQ{ + + public: + IQ(int numI); + IQ(int numI,double qmin, double qmax); + + void SetQmin(double qmin); + void SetQmax(double qmax); + void SetContrast(double delrho); + void SetVolFrac(double vol_frac); + void SetIQArray(Array2D iq_array); + + double GetQmin(); + double GetQmax(); + double GetContrast(); + double GetVolFrac(); + int GetNumI(); + + void OutputIQ(std::string fiq); + + Array2D iq_data; + + private: + IQ(); + double qmin_; + double qmax_; + double delrho_; + int numI_; + double vol_frac_; + + +}; + + +#endif diff --git a/sas/sascalc/simulation/iqPy/libiqPy/local.def b/sas/sascalc/simulation/iqPy/libiqPy/local.def new file mode 100755 index 000000000..66c0028b9 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/local.def @@ -0,0 +1,27 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# +# Local (project) definitions +# + +# C++ + +PROJ_CXX_INCLUDES += tnt + PROJ_CXX_FLAGS += $(CXX_SOFLAGS) + PROJ_LCXX_FLAGS += $(LCXX_SARFLAGS) + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt.h new file mode 100755 index 000000000..92463e08a --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt.h @@ -0,0 +1,64 @@ +/* +* +* Template Numerical Toolkit (TNT): Linear Algebra Module +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + +#ifndef TNT_H +#define TNT_H + + + +//--------------------------------------------------------------------- +// Define this macro if you want TNT to track some of the out-of-bounds +// indexing. This can encur a small run-time overhead, but is recommended +// while developing code. It can be turned off for production runs. 
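+// For example, a build can enable the checks either by uncommenting the
+// define below or by passing the macro on the compile line (hypothetical
+// invocation):
+//
+//     g++ -DTNT_BOUNDS_CHECK -c iq.cc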
+// +// #define TNT_BOUNDS_CHECK +//--------------------------------------------------------------------- +// + +//#define TNT_BOUNDS_CHECK + + + +#include "tnt_version.h" +#include "tnt_math_utils.h" +#include "tnt_array1d.h" +#include "tnt_array2d.h" +#include "tnt_array3d.h" +#include "tnt_array1d_utils.h" +#include "tnt_array2d_utils.h" +#include "tnt_array3d_utils.h" + +#include "tnt_fortran_array1d.h" +#include "tnt_fortran_array2d.h" +#include "tnt_fortran_array3d.h" +#include "tnt_fortran_array1d_utils.h" +#include "tnt_fortran_array2d_utils.h" +#include "tnt_fortran_array3d_utils.h" + +#include "tnt_sparse_matrix_csr.h" + +#include "tnt_stopwatch.h" +#include "tnt_subscript.h" +#include "tnt_vec.h" +#include "tnt_cmat.h" + + +#endif +// TNT_H diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_array1d.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_array1d.h new file mode 100755 index 000000000..858df5798 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_array1d.h @@ -0,0 +1,278 @@ +/* +* +* Template Numerical Toolkit (TNT) +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + + +#ifndef TNT_ARRAY1D_H +#define TNT_ARRAY1D_H + +//#include +#include + +#ifdef TNT_BOUNDS_CHECK +#include +#endif + + +#include "tnt_i_refvec.h" + +namespace TNT +{ + +template +class Array1D +{ + + private: + + /* ... */ + i_refvec v_; + int n_; + T* data_; /* this normally points to v_.begin(), but + * could also point to a portion (subvector) + * of v_. + */ + + void copy_(T* p, const T* q, int len) const; + void set_(T* begin, T* end, const T& val); + + + public: + + typedef T value_type; + + + Array1D(); + explicit Array1D(int n); + Array1D(int n, const T &a); + Array1D(int n, T *a); + inline Array1D(const Array1D &A); + inline operator T*(); + inline operator const T*(); + inline Array1D & operator=(const T &a); + inline Array1D & operator=(const Array1D &A); + inline Array1D & ref(const Array1D &A); + Array1D copy() const; + Array1D & inject(const Array1D & A); + inline T& operator[](int i); + inline const T& operator[](int i) const; + inline int dim1() const; + inline int dim() const; + ~Array1D(); + + + /* ... extended interface ... 
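+      (ref_count reports how many Array1D instances currently share the
+      underlying data; subarray returns another such shared view rather
+      than a deep copy)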
*/ + + inline int ref_count() const; + inline Array1D subarray(int i0, int i1); + +}; + + + + +template +Array1D::Array1D() : v_(), n_(0), data_(0) {} + +template +Array1D::Array1D(const Array1D &A) : v_(A.v_), n_(A.n_), + data_(A.data_) +{ +#ifdef TNT_DEBUG + std::cout << "Created Array1D(const Array1D &A) \n"; +#endif + +} + + +template +Array1D::Array1D(int n) : v_(n), n_(n), data_(v_.begin()) +{ +#ifdef TNT_DEBUG + std::cout << "Created Array1D(int n) \n"; +#endif +} + +template +Array1D::Array1D(int n, const T &val) : v_(n), n_(n), data_(v_.begin()) +{ +#ifdef TNT_DEBUG + std::cout << "Created Array1D(int n, const T& val) \n"; +#endif + set_(data_, data_+ n, val); + +} + +template +Array1D::Array1D(int n, T *a) : v_(a), n_(n) , data_(v_.begin()) +{ +#ifdef TNT_DEBUG + std::cout << "Created Array1D(int n, T* a) \n"; +#endif +} + +template +inline Array1D::operator T*() +{ + return &(v_[0]); +} + + +template +inline Array1D::operator const T*() +{ + return &(v_[0]); +} + + + +template +inline T& Array1D::operator[](int i) +{ +#ifdef TNT_BOUNDS_CHECK + assert(i>= 0); + assert(i < n_); +#endif + return data_[i]; +} + +template +inline const T& Array1D::operator[](int i) const +{ +#ifdef TNT_BOUNDS_CHECK + assert(i>= 0); + assert(i < n_); +#endif + return data_[i]; +} + + + + +template +Array1D & Array1D::operator=(const T &a) +{ + set_(data_, data_+n_, a); + return *this; +} + +template +Array1D Array1D::copy() const +{ + Array1D A( n_); + copy_(A.data_, data_, n_); + + return A; +} + + +template +Array1D & Array1D::inject(const Array1D &A) +{ + if (A.n_ == n_) + copy_(data_, A.data_, n_); + + return *this; +} + + + + + +template +Array1D & Array1D::ref(const Array1D &A) +{ + if (this != &A) + { + v_ = A.v_; /* operator= handles the reference counting. */ + n_ = A.n_; + data_ = A.data_; + + } + return *this; +} + +template +Array1D & Array1D::operator=(const Array1D &A) +{ + return ref(A); +} + +template +inline int Array1D::dim1() const { return n_; } + +template +inline int Array1D::dim() const { return n_; } + +template +Array1D::~Array1D() {} + + +/* ............................ exented interface ......................*/ + +template +inline int Array1D::ref_count() const +{ + return v_.ref_count(); +} + +template +inline Array1D Array1D::subarray(int i0, int i1) +{ + if ((i0 > 0) && (i1 < n_) || (i0 <= i1)) + { + Array1D X(*this); /* create a new instance of this array. 
*/ + X.n_ = i1-i0+1; + X.data_ += i0; + + return X; + } + else + { + return Array1D(); + } +} + + +/* private internal functions */ + + +template +void Array1D::set_(T* begin, T* end, const T& a) +{ + for (T* p=begin; p +void Array1D::copy_(T* p, const T* q, int len) const +{ + T *end = p + len; + while (p +#include + +namespace TNT +{ + + +template +std::ostream& operator<<(std::ostream &s, const Array1D &A) +{ + int N=A.dim1(); + +#ifdef TNT_DEBUG + s << "addr: " << (void *) &A[0] << "\n"; +#endif + s << N << "\n"; + for (int j=0; j +std::istream& operator>>(std::istream &s, Array1D &A) +{ + int N; + s >> N; + + Array1D B(N); + for (int i=0; i> B[i]; + A = B; + return s; +} + + + +template +Array1D operator+(const Array1D &A, const Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() != n ) + return Array1D(); + + else + { + Array1D C(n); + + for (int i=0; i +Array1D operator-(const Array1D &A, const Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() != n ) + return Array1D(); + + else + { + Array1D C(n); + + for (int i=0; i +Array1D operator*(const Array1D &A, const Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() != n ) + return Array1D(); + + else + { + Array1D C(n); + + for (int i=0; i +Array1D operator/(const Array1D &A, const Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() != n ) + return Array1D(); + + else + { + Array1D C(n); + + for (int i=0; i +Array1D& operator+=(Array1D &A, const Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() == n) + { + for (int i=0; i +Array1D& operator-=(Array1D &A, const Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() == n) + { + for (int i=0; i +Array1D& operator*=(Array1D &A, const Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() == n) + { + for (int i=0; i +Array1D& operator/=(Array1D &A, const Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() == n) + { + for (int i=0; i +#include +#ifdef TNT_BOUNDS_CHECK +#include +#endif + +#include "tnt_array1d.h" + +namespace TNT +{ + +template +class Array2D +{ + + + private: + + + + Array1D data_; + Array1D v_; + int m_; + int n_; + + public: + + typedef T value_type; + Array2D(); + Array2D(int m, int n); + Array2D(int m, int n, T *a); + Array2D(int m, int n, const T &a); + inline Array2D(const Array2D &A); + inline operator T**(); + inline operator const T**(); + inline Array2D & operator=(const T &a); + inline Array2D & operator=(const Array2D &A); + inline Array2D & ref(const Array2D &A); + Array2D copy() const; + Array2D & inject(const Array2D & A); + inline T* operator[](int i); + inline const T* operator[](int i) const; + inline int dim1() const; + inline int dim2() const; + ~Array2D(); + + /* extended interface (not part of the standard) */ + + + inline int ref_count(); + inline int ref_count_data(); + inline int ref_count_dim1(); + Array2D subarray(int i0, int i1, int j0, int j1); + +}; + + +template +Array2D::Array2D() : data_(), v_(), m_(0), n_(0) {} + +template +Array2D::Array2D(const Array2D &A) : data_(A.data_), v_(A.v_), + m_(A.m_), n_(A.n_) {} + + + + +template +Array2D::Array2D(int m, int n) : data_(m*n), v_(m), m_(m), n_(n) +{ + if (m>0 && n>0) + { + T* p = &(data_[0]); + for (int i=0; i +Array2D::Array2D(int m, int n, const T &val) : data_(m*n), v_(m), + m_(m), n_(n) +{ + if (m>0 && n>0) + { + data_ = val; + T* p = &(data_[0]); + for (int i=0; i +Array2D::Array2D(int m, int n, T *a) : data_(m*n, a), v_(m), m_(m), n_(n) +{ + if (m>0 && n>0) + { + T* p = &(data_[0]); + + for (int i=0; i +inline T* Array2D::operator[](int i) +{ +#ifdef TNT_BOUNDS_CHECK + assert(i >= 0); + 
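+  //a valid row index satisfies 0 <= i < m_; these checks compile in only
+  //when TNT_BOUNDS_CHECK is defined
+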
assert(i < m_); +#endif + +return v_[i]; + +} + + +template +inline const T* Array2D::operator[](int i) const +{ +#ifdef TNT_BOUNDS_CHECK + assert(i >= 0); + assert(i < m_); +#endif + +return v_[i]; + +} + +template +Array2D & Array2D::operator=(const T &a) +{ + /* non-optimzied, but will work with subarrays in future verions */ + + for (int i=0; i +Array2D Array2D::copy() const +{ + Array2D A(m_, n_); + + for (int i=0; i +Array2D & Array2D::inject(const Array2D &A) +{ + if (A.m_ == m_ && A.n_ == n_) + { + for (int i=0; i +Array2D & Array2D::ref(const Array2D &A) +{ + if (this != &A) + { + v_ = A.v_; + data_ = A.data_; + m_ = A.m_; + n_ = A.n_; + + } + return *this; +} + + + +template +Array2D & Array2D::operator=(const Array2D &A) +{ + return ref(A); +} + +template +inline int Array2D::dim1() const { return m_; } + +template +inline int Array2D::dim2() const { return n_; } + + +template +Array2D::~Array2D() {} + + + + +template +inline Array2D::operator T**() +{ + return &(v_[0]); +} +template +inline Array2D::operator const T**() +{ + return &(v_[0]); +} + +/* ............... extended interface ............... */ +/** + Create a new view to a subarray defined by the boundaries + [i0][i0] and [i1][j1]. The size of the subarray is + (i1-i0) by (j1-j0). If either of these lengths are zero + or negative, the subarray view is null. + +*/ +template +Array2D Array2D::subarray(int i0, int i1, int j0, int j1) +{ + Array2D A; + int m = i1-i0+1; + int n = j1-j0+1; + + /* if either length is zero or negative, this is an invalide + subarray. return a null view. + */ + if (m<1 || n<1) + return A; + + A.data_ = data_; + A.m_ = m; + A.n_ = n; + A.v_ = Array1D(m); + T* p = &(data_[0]) + i0 * n_ + j0; + for (int i=0; i +inline int Array2D::ref_count() +{ + return ref_count_data(); +} + + + +template +inline int Array2D::ref_count_data() +{ + return data_.ref_count(); +} + +template +inline int Array2D::ref_count_dim1() +{ + return v_.ref_count(); +} + + + + +} /* namespace TNT */ + +#endif +/* TNT_ARRAY2D_H */ + diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_array2d_utils.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_array2d_utils.h new file mode 100755 index 000000000..7041ed378 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_array2d_utils.h @@ -0,0 +1,287 @@ +/* +* +* Template Numerical Toolkit (TNT) +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. 
+* +*/ + + +#ifndef TNT_ARRAY2D_UTILS_H +#define TNT_ARRAY2D_UTILS_H + +#include +#include + +namespace TNT +{ + + +template +std::ostream& operator<<(std::ostream &s, const Array2D &A) +{ + int M=A.dim1(); + int N=A.dim2(); + + s << M << " " << N << "\n"; + + for (int i=0; i +std::istream& operator>>(std::istream &s, Array2D &A) +{ + + int M, N; + + s >> M >> N; + + Array2D B(M,N); + + for (int i=0; i> B[i][j]; + } + + A = B; + return s; +} + + +template +Array2D operator+(const Array2D &A, const Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() != m || B.dim2() != n ) + return Array2D(); + + else + { + Array2D C(m,n); + + for (int i=0; i +Array2D operator-(const Array2D &A, const Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() != m || B.dim2() != n ) + return Array2D(); + + else + { + Array2D C(m,n); + + for (int i=0; i +Array2D operator*(const Array2D &A, const Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() != m || B.dim2() != n ) + return Array2D(); + + else + { + Array2D C(m,n); + + for (int i=0; i +Array2D operator/(const Array2D &A, const Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() != m || B.dim2() != n ) + return Array2D(); + + else + { + Array2D C(m,n); + + for (int i=0; i +Array2D& operator+=(Array2D &A, const Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() == m || B.dim2() == n ) + { + for (int i=0; i +Array2D& operator-=(Array2D &A, const Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() == m || B.dim2() == n ) + { + for (int i=0; i +Array2D& operator*=(Array2D &A, const Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() == m || B.dim2() == n ) + { + for (int i=0; i +Array2D& operator/=(Array2D &A, const Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() == m || B.dim2() == n ) + { + for (int i=0; i +Array2D matmult(const Array2D &A, const Array2D &B) +{ + if (A.dim2() != B.dim1()) + return Array2D(); + + int M = A.dim1(); + int N = A.dim2(); + int K = B.dim2(); + + Array2D C(M,K); + + for (int i=0; i +#include +#ifdef TNT_BOUNDS_CHECK +#include +#endif + +#include "tnt_array1d.h" +#include "tnt_array2d.h" + +namespace TNT +{ + +template +class Array3D +{ + + + private: + Array1D data_; + Array2D v_; + int m_; + int n_; + int g_; + + + public: + + typedef T value_type; + + Array3D(); + Array3D(int m, int n, int g); + Array3D(int m, int n, int g, T val); + Array3D(int m, int n, int g, T *a); + + inline operator T***(); + inline operator const T***(); + inline Array3D(const Array3D &A); + inline Array3D & operator=(const T &a); + inline Array3D & operator=(const Array3D &A); + inline Array3D & ref(const Array3D &A); + Array3D copy() const; + Array3D & inject(const Array3D & A); + + inline T** operator[](int i); + inline const T* const * operator[](int i) const; + inline int dim1() const; + inline int dim2() const; + inline int dim3() const; + ~Array3D(); + + /* extended interface */ + + inline int ref_count(){ return data_.ref_count(); } + Array3D subarray(int i0, int i1, int j0, int j1, + int k0, int k1); +}; + +template +Array3D::Array3D() : data_(), v_(), m_(0), n_(0) {} + +template +Array3D::Array3D(const Array3D &A) : data_(A.data_), + v_(A.v_), m_(A.m_), n_(A.n_), g_(A.g_) +{ +} + + + +template +Array3D::Array3D(int m, int n, int g) : data_(m*n*g), v_(m,n), + m_(m), n_(n), g_(g) +{ + + if (m>0 && n>0 && g>0) + { + T* p = & (data_[0]); + int ng = n_*g_; + + for (int i=0; i 
+Array3D::Array3D(int m, int n, int g, T val) : data_(m*n*g, val), + v_(m,n), m_(m), n_(n), g_(g) +{ + if (m>0 && n>0 && g>0) + { + + T* p = & (data_[0]); + int ng = n_*g_; + + for (int i=0; i +Array3D::Array3D(int m, int n, int g, T* a) : + data_(m*n*g, a), v_(m,n), m_(m), n_(n), g_(g) +{ + + if (m>0 && n>0 && g>0) + { + T* p = & (data_[0]); + int ng = n_*g_; + + for (int i=0; i +inline T** Array3D::operator[](int i) +{ +#ifdef TNT_BOUNDS_CHECK + assert(i >= 0); + assert(i < m_); +#endif + +return v_[i]; + +} + +template +inline const T* const * Array3D::operator[](int i) const +{ return v_[i]; } + +template +Array3D & Array3D::operator=(const T &a) +{ + for (int i=0; i +Array3D Array3D::copy() const +{ + Array3D A(m_, n_, g_); + for (int i=0; i +Array3D & Array3D::inject(const Array3D &A) +{ + if (A.m_ == m_ && A.n_ == n_ && A.g_ == g_) + + for (int i=0; i +Array3D & Array3D::ref(const Array3D &A) +{ + if (this != &A) + { + m_ = A.m_; + n_ = A.n_; + g_ = A.g_; + v_ = A.v_; + data_ = A.data_; + } + return *this; +} + +template +Array3D & Array3D::operator=(const Array3D &A) +{ + return ref(A); +} + + +template +inline int Array3D::dim1() const { return m_; } + +template +inline int Array3D::dim2() const { return n_; } + +template +inline int Array3D::dim3() const { return g_; } + + + +template +Array3D::~Array3D() {} + +template +inline Array3D::operator T***() +{ + return v_; +} + + +template +inline Array3D::operator const T***() +{ + return v_; +} + +/* extended interface */ +template +Array3D Array3D::subarray(int i0, int i1, int j0, + int j1, int k0, int k1) +{ + + /* check that ranges are valid. */ + if (!( 0 <= i0 && i0 <= i1 && i1 < m_ && + 0 <= j0 && j0 <= j1 && j1 < n_ && + 0 <= k0 && k0 <= k1 && k1 < g_)) + return Array3D(); /* null array */ + + + Array3D A; + A.data_ = data_; + A.m_ = i1-i0+1; + A.n_ = j1-j0+1; + A.g_ = k1-k0+1; + A.v_ = Array2D(A.m_,A.n_); + T* p = &(data_[0]) + i0*n_*g_ + j0*g_ + k0; + + for (int i=0; i +#include + +namespace TNT +{ + + +template +std::ostream& operator<<(std::ostream &s, const Array3D &A) +{ + int M=A.dim1(); + int N=A.dim2(); + int K=A.dim3(); + + s << M << " " << N << " " << K << "\n"; + + for (int i=0; i +std::istream& operator>>(std::istream &s, Array3D &A) +{ + + int M, N, K; + + s >> M >> N >> K; + + Array3D B(M,N,K); + + for (int i=0; i> B[i][j][k]; + + A = B; + return s; +} + + + +template +Array3D operator+(const Array3D &A, const Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() != m || B.dim2() != n || B.dim3() != p ) + return Array3D(); + + else + { + Array3D C(m,n,p); + + for (int i=0; i +Array3D operator-(const Array3D &A, const Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() != m || B.dim2() != n || B.dim3() != p ) + return Array3D(); + + else + { + Array3D C(m,n,p); + + for (int i=0; i +Array3D operator*(const Array3D &A, const Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() != m || B.dim2() != n || B.dim3() != p ) + return Array3D(); + + else + { + Array3D C(m,n,p); + + for (int i=0; i +Array3D operator/(const Array3D &A, const Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() != m || B.dim2() != n || B.dim3() != p ) + return Array3D(); + + else + { + Array3D C(m,n,p); + + for (int i=0; i +Array3D& operator+=(Array3D &A, const Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() == m && B.dim2() == n && B.dim3() == p 
) + { + for (int i=0; i +Array3D& operator-=(Array3D &A, const Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() == m && B.dim2() == n && B.dim3() == p ) + { + for (int i=0; i +Array3D& operator*=(Array3D &A, const Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() == m && B.dim2() == n && B.dim3() == p ) + { + for (int i=0; i +Array3D& operator/=(Array3D &A, const Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() == m && B.dim2() == n && B.dim3() == p ) + { + for (int i=0; i +#include +#include +#include + +namespace TNT +{ + + +template +class Matrix +{ + + + public: + + typedef Subscript size_type; + typedef T value_type; + typedef T element_type; + typedef T* pointer; + typedef T* iterator; + typedef T& reference; + typedef const T* const_iterator; + typedef const T& const_reference; + + Subscript lbound() const { return 1;} + + protected: + Subscript m_; + Subscript n_; + Subscript mn_; // total size + T* v_; + T** row_; + T* vm1_ ; // these point to the same data, but are 1-based + T** rowm1_; + + // internal helper function to create the array + // of row pointers + + void initialize(Subscript M, Subscript N) + { + mn_ = M*N; + m_ = M; + n_ = N; + + v_ = new T[mn_]; + row_ = new T*[M]; + rowm1_ = new T*[M]; + + assert(v_ != NULL); + assert(row_ != NULL); + assert(rowm1_ != NULL); + + T* p = v_; + vm1_ = v_ - 1; + for (Subscript i=0; i &A) + { + initialize(A.m_, A.n_); + copy(A.v_); + } + + Matrix(Subscript M, Subscript N, const T& value = T()) + { + initialize(M,N); + set(value); + } + + Matrix(Subscript M, Subscript N, const T* v) + { + initialize(M,N); + copy(v); + } + + Matrix(Subscript M, Subscript N, const char *s) + { + initialize(M,N); + //std::istrstream ins(s); + std::istringstream ins(s); + + Subscript i, j; + + for (i=0; i> row_[i][j]; + } + + // destructor + // + ~Matrix() + { + destroy(); + } + + + // reallocating + // + Matrix& newsize(Subscript M, Subscript N) + { + if (num_rows() == M && num_cols() == N) + return *this; + + destroy(); + initialize(M,N); + + return *this; + } + + + + + // assignments + // + Matrix& operator=(const Matrix &A) + { + if (v_ == A.v_) + return *this; + + if (m_ == A.m_ && n_ == A.n_) // no need to re-alloc + copy(A.v_); + + else + { + destroy(); + initialize(A.m_, A.n_); + copy(A.v_); + } + + return *this; + } + + Matrix& operator=(const T& scalar) + { + set(scalar); + return *this; + } + + + Subscript dim(Subscript d) const + { +#ifdef TNT_BOUNDS_CHECK + assert( d >= 1); + assert( d <= 2); +#endif + return (d==1) ? m_ : ((d==2) ? 
n_ : 0); + } + + Subscript num_rows() const { return m_; } + Subscript num_cols() const { return n_; } + + + + + inline T* operator[](Subscript i) + { +#ifdef TNT_BOUNDS_CHECK + assert(0<=i); + assert(i < m_) ; +#endif + return row_[i]; + } + + inline const T* operator[](Subscript i) const + { +#ifdef TNT_BOUNDS_CHECK + assert(0<=i); + assert(i < m_) ; +#endif + return row_[i]; + } + + inline reference operator()(Subscript i) + { +#ifdef TNT_BOUNDS_CHECK + assert(1<=i); + assert(i <= mn_) ; +#endif + return vm1_[i]; + } + + inline const_reference operator()(Subscript i) const + { +#ifdef TNT_BOUNDS_CHECK + assert(1<=i); + assert(i <= mn_) ; +#endif + return vm1_[i]; + } + + + + inline reference operator()(Subscript i, Subscript j) + { +#ifdef TNT_BOUNDS_CHECK + assert(1<=i); + assert(i <= m_) ; + assert(1<=j); + assert(j <= n_); +#endif + return rowm1_[i][j]; + } + + + + inline const_reference operator() (Subscript i, Subscript j) const + { +#ifdef TNT_BOUNDS_CHECK + assert(1<=i); + assert(i <= m_) ; + assert(1<=j); + assert(j <= n_); +#endif + return rowm1_[i][j]; + } + + + + +}; + + +/* *************************** I/O ********************************/ + +template +std::ostream& operator<<(std::ostream &s, const Matrix &A) +{ + Subscript M=A.num_rows(); + Subscript N=A.num_cols(); + + s << M << " " << N << "\n"; + + for (Subscript i=0; i +std::istream& operator>>(std::istream &s, Matrix &A) +{ + + Subscript M, N; + + s >> M >> N; + + if ( !(M == A.num_rows() && N == A.num_cols() )) + { + A.newsize(M,N); + } + + + for (Subscript i=0; i> A[i][j]; + } + + + return s; +} + +// *******************[ basic matrix algorithms ]*************************** + + +template +Matrix operator+(const Matrix &A, + const Matrix &B) +{ + Subscript M = A.num_rows(); + Subscript N = A.num_cols(); + + assert(M==B.num_rows()); + assert(N==B.num_cols()); + + Matrix tmp(M,N); + Subscript i,j; + + for (i=0; i +Matrix operator-(const Matrix &A, + const Matrix &B) +{ + Subscript M = A.num_rows(); + Subscript N = A.num_cols(); + + assert(M==B.num_rows()); + assert(N==B.num_cols()); + + Matrix tmp(M,N); + Subscript i,j; + + for (i=0; i +Matrix mult_element(const Matrix &A, + const Matrix &B) +{ + Subscript M = A.num_rows(); + Subscript N = A.num_cols(); + + assert(M==B.num_rows()); + assert(N==B.num_cols()); + + Matrix tmp(M,N); + Subscript i,j; + + for (i=0; i +Matrix transpose(const Matrix &A) +{ + Subscript M = A.num_rows(); + Subscript N = A.num_cols(); + + Matrix S(N,M); + Subscript i, j; + + for (i=0; i +inline Matrix matmult(const Matrix &A, + const Matrix &B) +{ + +#ifdef TNT_BOUNDS_CHECK + assert(A.num_cols() == B.num_rows()); +#endif + + Subscript M = A.num_rows(); + Subscript N = A.num_cols(); + Subscript K = B.num_cols(); + + Matrix tmp(M,K); + T sum; + + for (Subscript i=0; i +inline Matrix operator*(const Matrix &A, + const Matrix &B) +{ + return matmult(A,B); +} + +template +inline int matmult(Matrix& C, const Matrix &A, + const Matrix &B) +{ + + assert(A.num_cols() == B.num_rows()); + + Subscript M = A.num_rows(); + Subscript N = A.num_cols(); + Subscript K = B.num_cols(); + + C.newsize(M,K); + + T sum; + + const T* row_i; + const T* col_k; + + for (Subscript i=0; i +Vector matmult(const Matrix &A, const Vector &x) +{ + +#ifdef TNT_BOUNDS_CHECK + assert(A.num_cols() == x.dim()); +#endif + + Subscript M = A.num_rows(); + Subscript N = A.num_cols(); + + Vector tmp(M); + T sum; + + for (Subscript i=0; i +inline Vector operator*(const Matrix &A, const Vector &x) +{ + return matmult(A,x); +} + +} // namespace 
TNT + +#endif +// CMAT_H diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array1d.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array1d.h new file mode 100755 index 000000000..ad3bba0c0 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array1d.h @@ -0,0 +1,267 @@ +/* +* +* Template Numerical Toolkit (TNT) +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + + +#ifndef TNT_FORTRAN_ARRAY1D_H +#define TNT_FORTRAN_ARRAY1D_H + +#include +#include + +#ifdef TNT_BOUNDS_CHECK +#include +#endif + + +#include "tnt_i_refvec.h" + +namespace TNT +{ + +template +class Fortran_Array1D +{ + + private: + + i_refvec v_; + int n_; + T* data_; /* this normally points to v_.begin(), but + * could also point to a portion (subvector) + * of v_. + */ + + void initialize_(int n); + void copy_(T* p, const T* q, int len) const; + void set_(T* begin, T* end, const T& val); + + + public: + + typedef T value_type; + + + Fortran_Array1D(); + explicit Fortran_Array1D(int n); + Fortran_Array1D(int n, const T &a); + Fortran_Array1D(int n, T *a); + inline Fortran_Array1D(const Fortran_Array1D &A); + inline Fortran_Array1D & operator=(const T &a); + inline Fortran_Array1D & operator=(const Fortran_Array1D &A); + inline Fortran_Array1D & ref(const Fortran_Array1D &A); + Fortran_Array1D copy() const; + Fortran_Array1D & inject(const Fortran_Array1D & A); + inline T& operator()(int i); + inline const T& operator()(int i) const; + inline int dim1() const; + inline int dim() const; + ~Fortran_Array1D(); + + + /* ... extended interface ... 
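+ (ref_count() reports how many arrays currently share this storage;
+ subarray() returns a view onto the same storage, not a copy)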
*/ + + inline int ref_count() const; + inline Fortran_Array1D subarray(int i0, int i1); + +}; + + + + +template +Fortran_Array1D::Fortran_Array1D() : v_(), n_(0), data_(0) {} + +template +Fortran_Array1D::Fortran_Array1D(const Fortran_Array1D &A) : v_(A.v_), n_(A.n_), + data_(A.data_) +{ +#ifdef TNT_DEBUG + std::cout << "Created Fortran_Array1D(const Fortran_Array1D &A) \n"; +#endif + +} + + +template +Fortran_Array1D::Fortran_Array1D(int n) : v_(n), n_(n), data_(v_.begin()) +{ +#ifdef TNT_DEBUG + std::cout << "Created Fortran_Array1D(int n) \n"; +#endif +} + +template +Fortran_Array1D::Fortran_Array1D(int n, const T &val) : v_(n), n_(n), data_(v_.begin()) +{ +#ifdef TNT_DEBUG + std::cout << "Created Fortran_Array1D(int n, const T& val) \n"; +#endif + set_(data_, data_+ n, val); + +} + +template +Fortran_Array1D::Fortran_Array1D(int n, T *a) : v_(a), n_(n) , data_(v_.begin()) +{ +#ifdef TNT_DEBUG + std::cout << "Created Fortran_Array1D(int n, T* a) \n"; +#endif +} + +template +inline T& Fortran_Array1D::operator()(int i) +{ +#ifdef TNT_BOUNDS_CHECK + assert(i>= 1); + assert(i <= n_); +#endif + return data_[i-1]; +} + +template +inline const T& Fortran_Array1D::operator()(int i) const +{ +#ifdef TNT_BOUNDS_CHECK + assert(i>= 1); + assert(i <= n_); +#endif + return data_[i-1]; +} + + + + +template +Fortran_Array1D & Fortran_Array1D::operator=(const T &a) +{ + set_(data_, data_+n_, a); + return *this; +} + +template +Fortran_Array1D Fortran_Array1D::copy() const +{ + Fortran_Array1D A( n_); + copy_(A.data_, data_, n_); + + return A; +} + + +template +Fortran_Array1D & Fortran_Array1D::inject(const Fortran_Array1D &A) +{ + if (A.n_ == n_) + copy_(data_, A.data_, n_); + + return *this; +} + + + + + +template +Fortran_Array1D & Fortran_Array1D::ref(const Fortran_Array1D &A) +{ + if (this != &A) + { + v_ = A.v_; /* operator= handles the reference counting. */ + n_ = A.n_; + data_ = A.data_; + + } + return *this; +} + +template +Fortran_Array1D & Fortran_Array1D::operator=(const Fortran_Array1D &A) +{ + return ref(A); +} + +template +inline int Fortran_Array1D::dim1() const { return n_; } + +template +inline int Fortran_Array1D::dim() const { return n_; } + +template +Fortran_Array1D::~Fortran_Array1D() {} + + +/* ............................ exented interface ......................*/ + +template +inline int Fortran_Array1D::ref_count() const +{ + return v_.ref_count(); +} + +template +inline Fortran_Array1D Fortran_Array1D::subarray(int i0, int i1) +{ +#ifdef TNT_DEBUG + std::cout << "entered subarray. \n"; +#endif + if ((i0 > 0) && (i1 < n_) || (i0 <= i1)) + { + Fortran_Array1D X(*this); /* create a new instance of this array. */ + X.n_ = i1-i0+1; + X.data_ += i0; + + return X; + } + else + { +#ifdef TNT_DEBUG + std::cout << "subarray: null return.\n"; +#endif + return Fortran_Array1D(); + } +} + + +/* private internal functions */ + + +template +void Fortran_Array1D::set_(T* begin, T* end, const T& a) +{ + for (T* p=begin; p +void Fortran_Array1D::copy_(T* p, const T* q, int len) const +{ + T *end = p + len; + while (p + +namespace TNT +{ + + +/** + Write an array to a character outstream. Output format is one that can + be read back in via the in-stream operator: one integer + denoting the array dimension (n), followed by n elements, + one per line. 
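+
+ For example, a 3-element array is written as
+
+ 3
+ 0.5
+ 1.5
+ 2.5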
+ +*/ +template +std::ostream& operator<<(std::ostream &s, const Fortran_Array1D &A) +{ + int N=A.dim1(); + + s << N << "\n"; + for (int j=1; j<=N; j++) + { + s << A(j) << "\n"; + } + s << "\n"; + + return s; +} + +/** + Read an array from a character stream. Input format + is one integer, denoting the dimension (n), followed + by n whitespace-separated elments. Newlines are ignored + +
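+ (so the stream "3 0.5 1.5 2.5" reads into a fresh 3-element array)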

+ Note: the array being read into references new memory + storage. If the intent is to fill an existing conformant + array, use cin >> B; A.inject(B) ); + instead or read the elements in one-a-time by hand. + + @param s the charater to read from (typically std::in) + @param A the array to read into. +*/ +template +std::istream& operator>>(std::istream &s, Fortran_Array1D &A) +{ + int N; + s >> N; + + Fortran_Array1D B(N); + for (int i=1; i<=N; i++) + s >> B(i); + A = B; + return s; +} + + +template +Fortran_Array1D operator+(const Fortran_Array1D &A, const Fortran_Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() != n ) + return Fortran_Array1D(); + + else + { + Fortran_Array1D C(n); + + for (int i=1; i<=n; i++) + { + C(i) = A(i) + B(i); + } + return C; + } +} + + + +template +Fortran_Array1D operator-(const Fortran_Array1D &A, const Fortran_Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() != n ) + return Fortran_Array1D(); + + else + { + Fortran_Array1D C(n); + + for (int i=1; i<=n; i++) + { + C(i) = A(i) - B(i); + } + return C; + } +} + + +template +Fortran_Array1D operator*(const Fortran_Array1D &A, const Fortran_Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() != n ) + return Fortran_Array1D(); + + else + { + Fortran_Array1D C(n); + + for (int i=1; i<=n; i++) + { + C(i) = A(i) * B(i); + } + return C; + } +} + + +template +Fortran_Array1D operator/(const Fortran_Array1D &A, const Fortran_Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() != n ) + return Fortran_Array1D(); + + else + { + Fortran_Array1D C(n); + + for (int i=1; i<=n; i++) + { + C(i) = A(i) / B(i); + } + return C; + } +} + + + + + + + + + +template +Fortran_Array1D& operator+=(Fortran_Array1D &A, const Fortran_Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() == n) + { + for (int i=1; i<=n; i++) + { + A(i) += B(i); + } + } + return A; +} + + + + +template +Fortran_Array1D& operator-=(Fortran_Array1D &A, const Fortran_Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() == n) + { + for (int i=1; i<=n; i++) + { + A(i) -= B(i); + } + } + return A; +} + + + +template +Fortran_Array1D& operator*=(Fortran_Array1D &A, const Fortran_Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() == n) + { + for (int i=1; i<=n; i++) + { + A(i) *= B(i); + } + } + return A; +} + + + + +template +Fortran_Array1D& operator/=(Fortran_Array1D &A, const Fortran_Array1D &B) +{ + int n = A.dim1(); + + if (B.dim1() == n) + { + for (int i=1; i<=n; i++) + { + A(i) /= B(i); + } + } + return A; +} + + +} // namespace TNT + +#endif diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array2d.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array2d.h new file mode 100755 index 000000000..f3075366d --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array2d.h @@ -0,0 +1,225 @@ +/* +* +* Template Numerical Toolkit (TNT): Two-dimensional Fortran numerical array +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. 
+* +*/ + + + +#ifndef TNT_FORTRAN_ARRAY2D_H +#define TNT_FORTRAN_ARRAY2D_H + +#include +#include + +#ifdef TNT_BOUNDS_CHECK +#include +#endif + +#include "tnt_i_refvec.h" + +namespace TNT +{ + +template +class Fortran_Array2D +{ + + + private: + i_refvec v_; + int m_; + int n_; + T* data_; + + + void initialize_(int n); + void copy_(T* p, const T* q, int len); + void set_(T* begin, T* end, const T& val); + + public: + + typedef T value_type; + + Fortran_Array2D(); + Fortran_Array2D(int m, int n); + Fortran_Array2D(int m, int n, T *a); + Fortran_Array2D(int m, int n, const T &a); + inline Fortran_Array2D(const Fortran_Array2D &A); + inline Fortran_Array2D & operator=(const T &a); + inline Fortran_Array2D & operator=(const Fortran_Array2D &A); + inline Fortran_Array2D & ref(const Fortran_Array2D &A); + Fortran_Array2D copy() const; + Fortran_Array2D & inject(const Fortran_Array2D & A); + inline T& operator()(int i, int j); + inline const T& operator()(int i, int j) const ; + inline int dim1() const; + inline int dim2() const; + ~Fortran_Array2D(); + + /* extended interface */ + + inline int ref_count() const; + +}; + +template +Fortran_Array2D::Fortran_Array2D() : v_(), m_(0), n_(0), data_(0) {} + + +template +Fortran_Array2D::Fortran_Array2D(const Fortran_Array2D &A) : v_(A.v_), + m_(A.m_), n_(A.n_), data_(A.data_) {} + + + +template +Fortran_Array2D::Fortran_Array2D(int m, int n) : v_(m*n), m_(m), n_(n), + data_(v_.begin()) {} + +template +Fortran_Array2D::Fortran_Array2D(int m, int n, const T &val) : + v_(m*n), m_(m), n_(n), data_(v_.begin()) +{ + set_(data_, data_+m*n, val); +} + + +template +Fortran_Array2D::Fortran_Array2D(int m, int n, T *a) : v_(a), + m_(m), n_(n), data_(v_.begin()) {} + + + + +template +inline T& Fortran_Array2D::operator()(int i, int j) +{ +#ifdef TNT_BOUNDS_CHECK + assert(i >= 1); + assert(i <= m_); + assert(j >= 1); + assert(j <= n_); +#endif + + return v_[ (j-1)*m_ + (i-1) ]; + +} + +template +inline const T& Fortran_Array2D::operator()(int i, int j) const +{ +#ifdef TNT_BOUNDS_CHECK + assert(i >= 1); + assert(i <= m_); + assert(j >= 1); + assert(j <= n_); +#endif + + return v_[ (j-1)*m_ + (i-1) ]; + +} + + +template +Fortran_Array2D & Fortran_Array2D::operator=(const T &a) +{ + set_(data_, data_+m_*n_, a); + return *this; +} + +template +Fortran_Array2D Fortran_Array2D::copy() const +{ + + Fortran_Array2D B(m_,n_); + + B.inject(*this); + return B; +} + + +template +Fortran_Array2D & Fortran_Array2D::inject(const Fortran_Array2D &A) +{ + if (m_ == A.m_ && n_ == A.n_) + copy_(data_, A.data_, m_*n_); + + return *this; +} + + + +template +Fortran_Array2D & Fortran_Array2D::ref(const Fortran_Array2D &A) +{ + if (this != &A) + { + v_ = A.v_; + m_ = A.m_; + n_ = A.n_; + data_ = A.data_; + } + return *this; +} + +template +Fortran_Array2D & Fortran_Array2D::operator=(const Fortran_Array2D &A) +{ + return ref(A); +} + +template +inline int Fortran_Array2D::dim1() const { return m_; } + +template +inline int Fortran_Array2D::dim2() const { return n_; } + + +template +Fortran_Array2D::~Fortran_Array2D() +{ +} + +template +inline int Fortran_Array2D::ref_count() const { return v_.ref_count(); } + + + + +template +void Fortran_Array2D::set_(T* begin, T* end, const T& a) +{ + for (T* p=begin; p +void Fortran_Array2D::copy_(T* p, const T* q, int len) +{ + T *end = p + len; + while (p + +namespace TNT +{ + + +template +std::ostream& operator<<(std::ostream &s, const Fortran_Array2D &A) +{ + int M=A.dim1(); + int N=A.dim2(); + + s << M << " " << N << "\n"; + + for (int i=1; 
i<=M; i++) + { + for (int j=1; j<=N; j++) + { + s << A(i,j) << " "; + } + s << "\n"; + } + + + return s; +} + +template +std::istream& operator>>(std::istream &s, Fortran_Array2D &A) +{ + + int M, N; + + s >> M >> N; + + Fortran_Array2D B(M,N); + + for (int i=1; i<=M; i++) + for (int j=1; j<=N; j++) + { + s >> B(i,j); + } + + A = B; + return s; +} + + + + +template +Fortran_Array2D operator+(const Fortran_Array2D &A, const Fortran_Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() != m || B.dim2() != n ) + return Fortran_Array2D(); + + else + { + Fortran_Array2D C(m,n); + + for (int i=1; i<=m; i++) + { + for (int j=1; j<=n; j++) + C(i,j) = A(i,j) + B(i,j); + } + return C; + } +} + +template +Fortran_Array2D operator-(const Fortran_Array2D &A, const Fortran_Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() != m || B.dim2() != n ) + return Fortran_Array2D(); + + else + { + Fortran_Array2D C(m,n); + + for (int i=1; i<=m; i++) + { + for (int j=1; j<=n; j++) + C(i,j) = A(i,j) - B(i,j); + } + return C; + } +} + + +template +Fortran_Array2D operator*(const Fortran_Array2D &A, const Fortran_Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() != m || B.dim2() != n ) + return Fortran_Array2D(); + + else + { + Fortran_Array2D C(m,n); + + for (int i=1; i<=m; i++) + { + for (int j=1; j<=n; j++) + C(i,j) = A(i,j) * B(i,j); + } + return C; + } +} + + +template +Fortran_Array2D operator/(const Fortran_Array2D &A, const Fortran_Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() != m || B.dim2() != n ) + return Fortran_Array2D(); + + else + { + Fortran_Array2D C(m,n); + + for (int i=1; i<=m; i++) + { + for (int j=1; j<=n; j++) + C(i,j) = A(i,j) / B(i,j); + } + return C; + } +} + + + +template +Fortran_Array2D& operator+=(Fortran_Array2D &A, const Fortran_Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() == m || B.dim2() == n ) + { + for (int i=1; i<=m; i++) + { + for (int j=1; j<=n; j++) + A(i,j) += B(i,j); + } + } + return A; +} + +template +Fortran_Array2D& operator-=(Fortran_Array2D &A, const Fortran_Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() == m || B.dim2() == n ) + { + for (int i=1; i<=m; i++) + { + for (int j=1; j<=n; j++) + A(i,j) -= B(i,j); + } + } + return A; +} + +template +Fortran_Array2D& operator*=(Fortran_Array2D &A, const Fortran_Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() == m || B.dim2() == n ) + { + for (int i=1; i<=m; i++) + { + for (int j=1; j<=n; j++) + A(i,j) *= B(i,j); + } + } + return A; +} + +template +Fortran_Array2D& operator/=(Fortran_Array2D &A, const Fortran_Array2D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + + if (B.dim1() == m || B.dim2() == n ) + { + for (int i=1; i<=m; i++) + { + for (int j=1; j<=n; j++) + A(i,j) /= B(i,j); + } + } + return A; +} + +} // namespace TNT + +#endif diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array3d.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array3d.h new file mode 100755 index 000000000..e51affba4 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array3d.h @@ -0,0 +1,223 @@ +/* +* +* Template Numerical Toolkit (TNT): Three-dimensional Fortran numerical array +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course 
+* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + + +#ifndef TNT_FORTRAN_ARRAY3D_H +#define TNT_FORTRAN_ARRAY3D_H + +#include +#include +#ifdef TNT_BOUNDS_CHECK +#include +#endif +#include "tnt_i_refvec.h" + +namespace TNT +{ + +template +class Fortran_Array3D +{ + + + private: + + + i_refvec v_; + int m_; + int n_; + int k_; + T* data_; + + public: + + typedef T value_type; + + Fortran_Array3D(); + Fortran_Array3D(int m, int n, int k); + Fortran_Array3D(int m, int n, int k, T *a); + Fortran_Array3D(int m, int n, int k, const T &a); + inline Fortran_Array3D(const Fortran_Array3D &A); + inline Fortran_Array3D & operator=(const T &a); + inline Fortran_Array3D & operator=(const Fortran_Array3D &A); + inline Fortran_Array3D & ref(const Fortran_Array3D &A); + Fortran_Array3D copy() const; + Fortran_Array3D & inject(const Fortran_Array3D & A); + inline T& operator()(int i, int j, int k); + inline const T& operator()(int i, int j, int k) const ; + inline int dim1() const; + inline int dim2() const; + inline int dim3() const; + inline int ref_count() const; + ~Fortran_Array3D(); + + +}; + +template +Fortran_Array3D::Fortran_Array3D() : v_(), m_(0), n_(0), k_(0), data_(0) {} + + +template +Fortran_Array3D::Fortran_Array3D(const Fortran_Array3D &A) : + v_(A.v_), m_(A.m_), n_(A.n_), k_(A.k_), data_(A.data_) {} + + + +template +Fortran_Array3D::Fortran_Array3D(int m, int n, int k) : + v_(m*n*k), m_(m), n_(n), k_(k), data_(v_.begin()) {} + + + +template +Fortran_Array3D::Fortran_Array3D(int m, int n, int k, const T &val) : + v_(m*n*k), m_(m), n_(n), k_(k), data_(v_.begin()) +{ + for (T* p = data_; p < data_ + m*n*k; p++) + *p = val; +} + +template +Fortran_Array3D::Fortran_Array3D(int m, int n, int k, T *a) : + v_(a), m_(m), n_(n), k_(k), data_(v_.begin()) {} + + + + +template +inline T& Fortran_Array3D::operator()(int i, int j, int k) +{ +#ifdef TNT_BOUNDS_CHECK + assert(i >= 1); + assert(i <= m_); + assert(j >= 1); + assert(j <= n_); + assert(k >= 1); + assert(k <= k_); +#endif + + return data_[(k-1)*m_*n_ + (j-1) * m_ + i-1]; + +} + +template +inline const T& Fortran_Array3D::operator()(int i, int j, int k) const +{ +#ifdef TNT_BOUNDS_CHECK + assert(i >= 1); + assert(i <= m_); + assert(j >= 1); + assert(j <= n_); + assert(k >= 1); + assert(k <= k_); +#endif + + return data_[(k-1)*m_*n_ + (j-1) * m_ + i-1]; +} + + +template +Fortran_Array3D & Fortran_Array3D::operator=(const T &a) +{ + + T *end = data_ + m_*n_*k_; + + for (T *p=data_; p != end; *p++ = a); + + return *this; +} + +template +Fortran_Array3D Fortran_Array3D::copy() const +{ + + Fortran_Array3D B(m_, n_, k_); + B.inject(*this); + return B; + +} + + +template +Fortran_Array3D & Fortran_Array3D::inject(const Fortran_Array3D &A) +{ + + if (m_ == A.m_ && n_ == A.n_ && k_ == A.k_) + { + T *p = data_; + T *end = data_ + m_*n_*k_; + const T* q = A.data_; + for (; p < end; *p++ = *q++); + } + return *this; +} + + + + +template +Fortran_Array3D & Fortran_Array3D::ref(const Fortran_Array3D &A) +{ + + if (this != &A) + { + v_ = A.v_; + m_ = A.m_; + n_ = A.n_; + k_ = A.k_; + data_ = A.data_; + } + return *this; +} + +template +Fortran_Array3D & Fortran_Array3D::operator=(const Fortran_Array3D &A) +{ + return ref(A); +} + +template 
+inline int Fortran_Array3D::dim1() const { return m_; } + +template +inline int Fortran_Array3D::dim2() const { return n_; } + +template +inline int Fortran_Array3D::dim3() const { return k_; } + + +template +inline int Fortran_Array3D::ref_count() const +{ + return v_.ref_count(); +} + +template +Fortran_Array3D::~Fortran_Array3D() +{ +} + + +} /* namespace TNT */ + +#endif +/* TNT_FORTRAN_ARRAY3D_H */ + diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array3d_utils.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array3d_utils.h new file mode 100755 index 000000000..a13a275dc --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_fortran_array3d_utils.h @@ -0,0 +1,249 @@ +/* +* +* Template Numerical Toolkit (TNT) +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + +#ifndef TNT_FORTRAN_ARRAY3D_UTILS_H +#define TNT_FORTRAN_ARRAY3D_UTILS_H + +#include +#include + +namespace TNT +{ + + +template +std::ostream& operator<<(std::ostream &s, const Fortran_Array3D &A) +{ + int M=A.dim1(); + int N=A.dim2(); + int K=A.dim3(); + + s << M << " " << N << " " << K << "\n"; + + for (int i=1; i<=M; i++) + { + for (int j=1; j<=N; j++) + { + for (int k=1; k<=K; k++) + s << A(i,j,k) << " "; + s << "\n"; + } + s << "\n"; + } + + + return s; +} + +template +std::istream& operator>>(std::istream &s, Fortran_Array3D &A) +{ + + int M, N, K; + + s >> M >> N >> K; + + Fortran_Array3D B(M,N,K); + + for (int i=1; i<=M; i++) + for (int j=1; j<=N; j++) + for (int k=1; k<=K; k++) + s >> B(i,j,k); + + A = B; + return s; +} + + +template +Fortran_Array3D operator+(const Fortran_Array3D &A, const Fortran_Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() != m || B.dim2() != n || B.dim3() != p ) + return Fortran_Array3D(); + + else + { + Fortran_Array3D C(m,n,p); + + for (int i=1; i<=m; i++) + for (int j=1; j<=n; j++) + for (int k=1; k<=p; k++) + C(i,j,k) = A(i,j,k)+ B(i,j,k); + + return C; + } +} + + +template +Fortran_Array3D operator-(const Fortran_Array3D &A, const Fortran_Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() != m || B.dim2() != n || B.dim3() != p ) + return Fortran_Array3D(); + + else + { + Fortran_Array3D C(m,n,p); + + for (int i=1; i<=m; i++) + for (int j=1; j<=n; j++) + for (int k=1; k<=p; k++) + C(i,j,k) = A(i,j,k)- B(i,j,k); + + return C; + } +} + + +template +Fortran_Array3D operator*(const Fortran_Array3D &A, const Fortran_Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() != m || B.dim2() != n || B.dim3() != p ) + return Fortran_Array3D(); + + else + { + Fortran_Array3D C(m,n,p); + + for (int i=1; i<=m; i++) + for (int j=1; j<=n; j++) + for (int k=1; k<=p; k++) + C(i,j,k) = A(i,j,k)* B(i,j,k); + + return C; + } +} + + +template +Fortran_Array3D operator/(const Fortran_Array3D &A, const Fortran_Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = 
A.dim3(); + + if (B.dim1() != m || B.dim2() != n || B.dim3() != p ) + return Fortran_Array3D(); + + else + { + Fortran_Array3D C(m,n,p); + + for (int i=1; i<=m; i++) + for (int j=1; j<=n; j++) + for (int k=1; k<=p; k++) + C(i,j,k) = A(i,j,k)/ B(i,j,k); + + return C; + } +} + + +template +Fortran_Array3D& operator+=(Fortran_Array3D &A, const Fortran_Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() == m && B.dim2() == n && B.dim3() == p ) + { + for (int i=1; i<=m; i++) + for (int j=1; j<=n; j++) + for (int k=1; k<=p; k++) + A(i,j,k) += B(i,j,k); + } + + return A; +} + + +template +Fortran_Array3D& operator-=(Fortran_Array3D &A, const Fortran_Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() == m && B.dim2() == n && B.dim3() == p ) + { + for (int i=1; i<=m; i++) + for (int j=1; j<=n; j++) + for (int k=1; k<=p; k++) + A(i,j,k) -= B(i,j,k); + } + + return A; +} + + +template +Fortran_Array3D& operator*=(Fortran_Array3D &A, const Fortran_Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() == m && B.dim2() == n && B.dim3() == p ) + { + for (int i=1; i<=m; i++) + for (int j=1; j<=n; j++) + for (int k=1; k<=p; k++) + A(i,j,k) *= B(i,j,k); + } + + return A; +} + + +template +Fortran_Array3D& operator/=(Fortran_Array3D &A, const Fortran_Array3D &B) +{ + int m = A.dim1(); + int n = A.dim2(); + int p = A.dim3(); + + if (B.dim1() == m && B.dim2() == n && B.dim3() == p ) + { + for (int i=1; i<=m; i++) + for (int j=1; j<=n; j++) + for (int k=1; k<=p; k++) + A(i,j,k) /= B(i,j,k); + } + + return A; +} + + +} // namespace TNT + +#endif diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_i_refvec.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_i_refvec.h new file mode 100755 index 000000000..5a67eb578 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_i_refvec.h @@ -0,0 +1,243 @@ +/* +* +* Template Numerical Toolkit (TNT) +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + + +#ifndef TNT_I_REFVEC_H +#define TNT_I_REFVEC_H + +#include +#include + +#ifdef TNT_BOUNDS_CHECK +#include +#endif + +#ifndef NULL +#define NULL 0 +#endif + +namespace TNT +{ +/* + Internal representation of ref-counted array. The TNT + arrays all use this building block. + +
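+
+ A minimal sharing sketch (illustrative, via the public Array1D wrapper):
+
+ TNT::Array1D<double> A(5, 1.0); // fresh storage, A.ref_count() == 1
+ TNT::Array1D<double> B(A); // B references A's storage, ref_count() == 2
+ B[0] = 9.0; // the write is visible through A[0] as well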

+ If an array block is created by TNT, then every time + an assignment is made, the left-hand-side reference + is decreased by one, and the right-hand-side refernce + count is increased by one. If the array block was + external to TNT, the refernce count is a NULL pointer + regardless of how many references are made, since the + memory is not freed by TNT. + + + +*/ +template +class i_refvec +{ + + + private: + T* data_; + int *ref_count_; + + + public: + + i_refvec(); + explicit i_refvec(int n); + inline i_refvec(T* data); + inline i_refvec(const i_refvec &v); + inline T* begin(); + inline const T* begin() const; + inline T& operator[](int i); + inline const T& operator[](int i) const; + inline i_refvec & operator=(const i_refvec &V); + void copy_(T* p, const T* q, const T* e); + void set_(T* p, const T* b, const T* e); + inline int ref_count() const; + inline int is_null() const; + inline void destroy(); + ~i_refvec(); + +}; + +template +void i_refvec::copy_(T* p, const T* q, const T* e) +{ + for (T* t=p; q +i_refvec::i_refvec() : data_(NULL), ref_count_(NULL) {} + +/** + In case n is 0 or negative, it does NOT call new. +*/ +template +i_refvec::i_refvec(int n) : data_(NULL), ref_count_(NULL) +{ + if (n >= 1) + { +#ifdef TNT_DEBUG + std::cout << "new data storage.\n"; +#endif + data_ = new T[n]; + ref_count_ = new int; + *ref_count_ = 1; + } +} + +template +inline i_refvec::i_refvec(const i_refvec &V): data_(V.data_), + ref_count_(V.ref_count_) +{ + if (V.ref_count_ != NULL) + (*(V.ref_count_))++; +} + + +template +i_refvec::i_refvec(T* data) : data_(data), ref_count_(NULL) {} + +template +inline T* i_refvec::begin() +{ + return data_; +} + +template +inline const T& i_refvec::operator[](int i) const +{ + return data_[i]; +} + +template +inline T& i_refvec::operator[](int i) +{ + return data_[i]; +} + + +template +inline const T* i_refvec::begin() const +{ + return data_; +} + + + +template +i_refvec & i_refvec::operator=(const i_refvec &V) +{ + if (this == &V) + return *this; + + + if (ref_count_ != NULL) + { + (*ref_count_) --; + if ((*ref_count_) == 0) + destroy(); + } + + data_ = V.data_; + ref_count_ = V.ref_count_; + + if (V.ref_count_ != NULL) + (*(V.ref_count_))++; + + return *this; +} + +template +void i_refvec::destroy() +{ + if (ref_count_ != NULL) + { +#ifdef TNT_DEBUG + std::cout << "destorying data... \n"; +#endif + delete ref_count_; + +#ifdef TNT_DEBUG + std::cout << "deleted ref_count_ ...\n"; +#endif + if (data_ != NULL) + delete []data_; +#ifdef TNT_DEBUG + std::cout << "deleted data_[] ...\n"; +#endif + data_ = NULL; + } +} + +/* +* return 1 is vector is empty, 0 otherwise +* +* if is_null() is false and ref_count() is 0, then +* +*/ +template +int i_refvec::is_null() const +{ + return (data_ == NULL ? 1 : 0); +} + +/* +* returns -1 if data is external, +* returns 0 if a is NULL array, +* otherwise returns the positive number of vectors sharing +* this data space. +*/ +template +int i_refvec::ref_count() const +{ + if (data_ == NULL) + return 0; + else + return (ref_count_ != NULL ? 
*ref_count_ : -1) ; +} + +template +i_refvec::~i_refvec() +{ + if (ref_count_ != NULL) + { + (*ref_count_)--; + + if (*ref_count_ == 0) + destroy(); + } +} + + +} /* namespace TNT */ + + + + + +#endif +/* TNT_I_REFVEC_H */ + diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_math_utils.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_math_utils.h new file mode 100755 index 000000000..f9c1c91ee --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_math_utils.h @@ -0,0 +1,34 @@ +#ifndef MATH_UTILS_H +#define MATH_UTILS_H + +/* needed for fabs, sqrt() below */ +#include + + + +namespace TNT +{ +/** + @returns hypotenuse of real (non-complex) scalars a and b by + avoiding underflow/overflow + using (a * sqrt( 1 + (b/a) * (b/a))), rather than + sqrt(a*a + b*b). +*/ +template +Real hypot(const Real &a, const Real &b) +{ + + if (a== 0) + return abs(b); + else + { + Real c = b/a; + return fabs(a) * sqrt(1 + c*c); + } +} +} /* TNT namespace */ + + + +#endif +/* MATH_UTILS_H */ diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_sparse_matrix_csr.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_sparse_matrix_csr.h new file mode 100755 index 000000000..0d4fde1c2 --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_sparse_matrix_csr.h @@ -0,0 +1,103 @@ +/* +* +* Template Numerical Toolkit (TNT) +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + +#ifndef TNT_SPARSE_MATRIX_CSR_H +#define TNT_SPARSE_MATRIX_CSR_H + +#include "tnt_array1d.h" + +namespace TNT +{ + + +/** + Read-only view of a sparse matrix in compressed-row storage + format. Neither array elements (nonzeros) nor sparsity + structure can be modified. If modifications are required, + create a new view. + +

+ Index values begin at 0. + +
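+
+ A worked example (illustrative only): the 3 x 3 matrix
+
+ [ 1 0 2 ]
+ [ 0 3 0 ]
+ [ 4 0 5 ]
+
+ has nz = 5, val = {1,2,3,4,5}, colind = {0,2,1,0,2} and
+ rowptr = {0,2,3,5}, so the nonzeros of row i occupy positions
+ rowptr[i] through rowptr[i+1]-1 of val.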

+ Storage requirements: An (m x n) matrix with + nz nonzeros requires no more than ((T+I)*nz + M*I) + bytes, where T is the size of data elements and + I is the size of integers. + + +*/ +template +class Sparse_Matrix_CompRow { + +private: + Array1D val_; // data values (nz_ elements) + Array1D rowptr_; // row_ptr (dim_[0]+1 elements) + Array1D colind_; // col_ind (nz_ elements) + + int dim1_; // number of rows + int dim2_; // number of cols + +public: + + Sparse_Matrix_CompRow(const Sparse_Matrix_CompRow &S); + Sparse_Matrix_CompRow(int M, int N, int nz, const T *val, + const int *r, const int *c); + + + + inline const T& val(int i) const { return val_[i]; } + inline const int& row_ptr(int i) const { return rowptr_[i]; } + inline const int& col_ind(int i) const { return colind_[i];} + + inline int dim1() const {return dim1_;} + inline int dim2() const {return dim2_;} + int NumNonzeros() const {return val_.dim1();} + + + Sparse_Matrix_CompRow& operator=( + const Sparse_Matrix_CompRow &R); + + + +}; + +/** + Construct a read-only view of existing sparse matrix in + compressed-row storage format. + + @param M the number of rows of sparse matrix + @param N the number of columns of sparse matrix + @param nz the number of nonzeros + @param val a contiguous list of nonzero values + @param r row-pointers: r[i] denotes the begining position of row i + (i.e. the ith row begins at val[row[i]]). + @param c column-indices: c[i] denotes the column location of val[i] +*/ +template +Sparse_Matrix_CompRow::Sparse_Matrix_CompRow(int M, int N, int nz, + const T *val, const int *r, const int *c) : val_(nz,val), + rowptr_(M, r), colind_(nz, c), dim1_(M), dim2_(N) {} + + +} +// namespace TNT + +#endif diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_stopwatch.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_stopwatch.h new file mode 100755 index 000000000..8dc5d23ac --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_stopwatch.h @@ -0,0 +1,95 @@ +/* +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. 
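+
+ Usage sketch (illustrative): Stopwatch measures time via clock(), i.e.
+ processor time on most platforms rather than wall-clock time:
+
+ TNT::Stopwatch w; w.start(); /* ...work... */ double secs = w.stop();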
+* +*/ + + + +#ifndef STOPWATCH_H +#define STOPWATCH_H + +// for clock() and CLOCKS_PER_SEC +#include + + +namespace TNT +{ + +inline static double seconds(void) +{ + const double secs_per_tick = 1.0 / CLOCKS_PER_SEC; + return ( (double) clock() ) * secs_per_tick; +} + +class Stopwatch { + private: + int running_; + double start_time_; + double total_; + + public: + inline Stopwatch(); + inline void start(); + inline double stop(); + inline double read(); + inline void resume(); + inline int running(); +}; + +inline Stopwatch::Stopwatch() : running_(0), start_time_(0.0), total_(0.0) {} + +void Stopwatch::start() +{ + running_ = 1; + total_ = 0.0; + start_time_ = seconds(); +} + +double Stopwatch::stop() +{ + if (running_) + { + total_ += (seconds() - start_time_); + running_ = 0; + } + return total_; +} + +inline void Stopwatch::resume() +{ + if (!running_) + { + start_time_ = seconds(); + running_ = 1; + } +} + + +inline double Stopwatch::read() +{ + if (running_) + { + stop(); + resume(); + } + return total_; +} + + +} /* TNT namespace */ +#endif + + + diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_subscript.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_subscript.h new file mode 100755 index 000000000..d8fe1200e --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_subscript.h @@ -0,0 +1,54 @@ +/* +* +* Template Numerical Toolkit (TNT) +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + +#ifndef TNT_SUBSCRPT_H +#define TNT_SUBSCRPT_H + + +//--------------------------------------------------------------------- +// This definition describes the default TNT data type used for +// indexing into TNT matrices and vectors. The data type should +// be wide enough to index into large arrays. It defaults to an +// "int", but can be overriden at compile time redefining TNT_SUBSCRIPT_TYPE, +// e.g. +// +// c++ -DTNT_SUBSCRIPT_TYPE='unsigned int' ... +// +//--------------------------------------------------------------------- +// + +#ifndef TNT_SUBSCRIPT_TYPE +#define TNT_SUBSCRIPT_TYPE int +#endif + +namespace TNT +{ + typedef TNT_SUBSCRIPT_TYPE Subscript; +} /* namespace TNT */ + + +// () indexing in TNT means 1-offset, i.e. x(1) and A(1,1) are the +// first elements. This offset is left as a macro for future +// purposes, but should not be changed in the current release. 
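+// (e.g. for a TNT::Vector<double> x, x(1) and x[0] name the same element:
+// only the ()-style accessors honor this offset, while [] stays 0-based)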
+// +// +#define TNT_BASE_OFFSET (1) + +#endif diff --git a/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_vec.h b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_vec.h new file mode 100755 index 000000000..3455d79cc --- /dev/null +++ b/sas/sascalc/simulation/iqPy/libiqPy/tnt/tnt_vec.h @@ -0,0 +1,404 @@ +/* +* +* Template Numerical Toolkit (TNT) +* +* Mathematical and Computational Sciences Division +* National Institute of Technology, +* Gaithersburg, MD USA +* +* +* This software was developed at the National Institute of Standards and +* Technology (NIST) by employees of the Federal Government in the course +* of their official duties. Pursuant to title 17 Section 105 of the +* United States Code, this software is not subject to copyright protection +* and is in the public domain. NIST assumes no responsibility whatsoever for +* its use by other parties, and makes no guarantees, expressed or implied, +* about its quality, reliability, or any other characteristic. +* +*/ + + + +#ifndef TNT_VEC_H +#define TNT_VEC_H + +#include "tnt_subscript.h" +#include +#include +#include +#include + +namespace TNT +{ + +/** + [Deprecatred] Value-based vector class from pre-1.0 + TNT version. Kept here for backward compatiblity, but should + use the newer TNT::Array1D classes instead. + +*/ + +template +class Vector +{ + + + public: + + typedef Subscript size_type; + typedef T value_type; + typedef T element_type; + typedef T* pointer; + typedef T* iterator; + typedef T& reference; + typedef const T* const_iterator; + typedef const T& const_reference; + + Subscript lbound() const { return 1;} + + protected: + T* v_; + T* vm1_; // pointer adjustment for optimzied 1-offset indexing + Subscript n_; + + // internal helper function to create the array + // of row pointers + + void initialize(Subscript N) + { + // adjust pointers so that they are 1-offset: + // v_[] is the internal contiguous array, it is still 0-offset + // + assert(v_ == NULL); + v_ = new T[N]; + assert(v_ != NULL); + vm1_ = v_-1; + n_ = N; + } + + void copy(const T* v) + { + Subscript N = n_; + Subscript i; + +#ifdef TNT_UNROLL_LOOPS + Subscript Nmod4 = N & 3; + Subscript N4 = N - Nmod4; + + for (i=0; i &A) : v_(0), vm1_(0), n_(0) + { + initialize(A.n_); + copy(A.v_); + } + + Vector(Subscript N, const T& value = T()) : v_(0), vm1_(0), n_(0) + { + initialize(N); + set(value); + } + + Vector(Subscript N, const T* v) : v_(0), vm1_(0), n_(0) + { + initialize(N); + copy(v); + } + + Vector(Subscript N, char *s) : v_(0), vm1_(0), n_(0) + { + initialize(N); + std::istringstream ins(s); + + Subscript i; + + for (i=0; i> v_[i]; + } + + + // methods + // + Vector& newsize(Subscript N) + { + if (n_ == N) return *this; + + destroy(); + initialize(N); + + return *this; + } + + + // assignments + // + Vector& operator=(const Vector &A) + { + if (v_ == A.v_) + return *this; + + if (n_ == A.n_) // no need to re-alloc + copy(A.v_); + + else + { + destroy(); + initialize(A.n_); + copy(A.v_); + } + + return *this; + } + + Vector& operator=(const T& scalar) + { + set(scalar); + return *this; + } + + inline Subscript dim() const + { + return n_; + } + + inline Subscript size() const + { + return n_; + } + + + inline reference operator()(Subscript i) + { +#ifdef TNT_BOUNDS_CHECK + assert(1<=i); + assert(i <= n_) ; +#endif + return vm1_[i]; + } + + inline const_reference operator() (Subscript i) const + { +#ifdef TNT_BOUNDS_CHECK + assert(1<=i); + assert(i <= n_) ; +#endif + return vm1_[i]; + } + + inline reference operator[](Subscript i) + { +#ifdef TNT_BOUNDS_CHECK + 
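+ // [] indexing is 0-offset, unlike the 1-offset () accessors above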
assert(0<=i); + assert(i < n_) ; +#endif + return v_[i]; + } + + inline const_reference operator[](Subscript i) const + { +#ifdef TNT_BOUNDS_CHECK + assert(0<=i); + + + + + + + assert(i < n_) ; +#endif + return v_[i]; + } + + + +}; + + +/* *************************** I/O ********************************/ + +template +std::ostream& operator<<(std::ostream &s, const Vector &A) +{ + Subscript N=A.dim(); + + s << N << "\n"; + + for (Subscript i=0; i +std::istream & operator>>(std::istream &s, Vector &A) +{ + + Subscript N; + + s >> N; + + if ( !(N == A.size() )) + { + A.newsize(N); + } + + + for (Subscript i=0; i> A[i]; + + + return s; +} + +// *******************[ basic matrix algorithms ]*************************** + + +template +Vector operator+(const Vector &A, + const Vector &B) +{ + Subscript N = A.dim(); + + assert(N==B.dim()); + + Vector tmp(N); + Subscript i; + + for (i=0; i +Vector operator-(const Vector &A, + const Vector &B) +{ + Subscript N = A.dim(); + + assert(N==B.dim()); + + Vector tmp(N); + Subscript i; + + for (i=0; i +Vector operator*(const Vector &A, + const Vector &B) +{ + Subscript N = A.dim(); + + assert(N==B.dim()); + + Vector tmp(N); + Subscript i; + + for (i=0; i +T dot_prod(const Vector &A, const Vector &B) +{ + Subscript N = A.dim(); + assert(N == B.dim()); + + Subscript i; + T sum = 0; + + for (i=0; i +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PROJECT = iqPy +PACKAGE = tests + +PROJ_CLEAN += $(PROJ_CPPTESTS) + +PROJ_PYTESTS = testiq.py +PROJ_CPPTESTS = testiq +PROJ_TESTS = $(PROJ_PYTESTS) $(PROJ_CPPTESTS) +PROJ_LIBRARIES = -L$(BLD_LIBDIR) -liqPy +PROJ_CXX_INCLUDES = ../libiqPy/tnt \ + ../libiqPy + +#-------------------------------------------------------------------------- +# + +all: $(PROJ_TESTS) + +test: + for test in $(PROJ_TESTS) ; do $${test}; done + +release: tidy + cvs release . + +update: clean + cvs update . + +#-------------------------------------------------------------------------- +# + +testiq: testiq.cc $(BLD_LIBDIR)/libiqPy.$(EXT_SAR) + $(CXX) $(CXXFLAGS) $(LCXXFLAGS) -o $@ testiq.cc $(PROJ_LIBRARIES) + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/iqPy/tests/signon.py b/sas/sascalc/simulation/iqPy/tests/signon.py new file mode 100755 index 000000000..10541433c --- /dev/null +++ b/sas/sascalc/simulation/iqPy/tests/signon.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +from __future__ import print_function + + +if __name__ == "__main__": + + import iqPy + from iqPy import iqPy as iqPymodule + + print("copyright information:") + print(" ", iqPy.copyright()) + print(" ", iqPymodule.copyright()) + + print() + print("module information:") + print(" file:", iqPymodule.__file__) + print(" doc:", iqPymodule.__doc__) + print(" contents:", dir(iqPymodule)) + + print() + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/iqPy/tests/testiq.cc b/sas/sascalc/simulation/iqPy/tests/testiq.cc new file mode 100755 index 000000000..fc707ab0d --- /dev/null +++ b/sas/sascalc/simulation/iqPy/tests/testiq.cc @@ -0,0 +1,34 @@ +#include "iq.h" +#include + +using namespace std; + +int main(){ + + cout << " Generating a empty iq 2D array with size of 10" << endl; + IQ iq1(10); + + for (int i = 0; i< iq1.iq_data.dim1(); i++) + cout << iq1.iq_data[i][0]<< " " << iq1.iq_data[i][0] < +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +if __name__ == "__main__": + + from SASsimulation import iqPy + iqPy.new_iq(10,0.01,0.4) + + print("pass.") + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/Make.mm b/sas/sascalc/simulation/pointsmodelpy/Make.mm new file mode 100755 index 000000000..59bd44356 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/Make.mm @@ -0,0 +1,47 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = pointsmodelpy + +# directory structure + +BUILD_DIRS = \ + libpointsmodelpy \ + pointsmodelpymodule \ + pointsmodelpy \ + +OTHER_DIRS = \ + tests \ + examples + +RECURSE_DIRS = $(BUILD_DIRS) $(OTHER_DIRS) + +#-------------------------------------------------------------------------- +# + +all: + BLD_ACTION="all" $(MM) recurse + +distclean:: + BLD_ACTION="distclean" $(MM) recurse + +clean:: + BLD_ACTION="clean" $(MM) recurse + +tidy:: + BLD_ACTION="tidy" $(MM) recurse + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/examples/Make.mm b/sas/sascalc/simulation/pointsmodelpy/examples/Make.mm new file mode 100755 index 000000000..e51e71bec --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/examples/Make.mm @@ -0,0 +1,32 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = pointsmodelpy +PACKAGE = examples + +#-------------------------------------------------------------------------- +# + +all: clean + +release: clean + cvs release . + +update: clean + cvs update . 
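+# NB: the release/update targets above require a CVS working copy and the
+# legacy 'mm' build harness; they will not run from a plain git checkout.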
+ + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/Make.mm b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/Make.mm new file mode 100755 index 000000000..8b93407aa --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/Make.mm @@ -0,0 +1,69 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +include local.def + +PROJECT = pointsmodelpy +PACKAGE = libpointsmodelpy + +PROJ_SAR = $(BLD_LIBDIR)/$(PACKAGE).$(EXT_SAR) +PROJ_DLL = $(BLD_BINDIR)/$(PACKAGE).$(EXT_SO) +PROJ_TMPDIR = $(BLD_TMPDIR)/$(PROJECT)/$(PACKAGE) +PROJ_CLEAN += $(PROJ_SAR) $(PROJ_DLL) + +PROJ_SRCS = \ + lores_model.cc \ + points_model.cc \ + pdb_model.cc \ + complex_model.cc + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# build the library + +all: $(PROJ_SAR) export + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +ifeq (Win32, ${findstring Win32, $(PLATFORM_ID)}) + +# build the shared object +$(PROJ_SAR): product_dirs $(PROJ_OBJS) + $(CXX) $(LCXXFLAGS) -o $(PROJ_DLL) \ + -Wl,--out-implib=$(PROJ_SAR) $(PROJ_OBJS) + +# export +export:: export-headers export-libraries export-binaries + +else + +# build the shared object +$(PROJ_SAR): product_dirs $(PROJ_OBJS) + $(CXX) $(LCXXFLAGS) -o $(PROJ_SAR) $(PROJ_OBJS) -L$(BLD_LIBDIR) + +# export +export:: export-headers export-libraries + +endif + +EXPORT_HEADERS = \ + lores_model.h \ + points_model.h \ + complex_model.h + +EXPORT_LIBS = $(PROJ_SAR) +EXPORT_BINS = $(PROJ_DLL) + + +# version +# $Id$ + +# +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/complex_model.cc b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/complex_model.cc new file mode 100755 index 000000000..8a3645d5c --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/complex_model.cc @@ -0,0 +1,97 @@ +/** \file complex_model.cc */ + +#include "complex_model.h" +#include +#include + +using namespace std; + +ComplexModel::ComplexModel(){ + +} + +void ComplexModel::Add(PointsModel *pm) +{ + models_.push_back(pm); +} + +double ComplexModel::GetDimBound() +{ + //Get the vector of centers of pointsmodel instances + //and a vector of individual boundary + vector::iterator itr; + vector veccenter; + vector bounds; + + for(itr = models_.begin(); itr != models_.end(); ++itr){ + if ((*itr)->GetDimBound() != 0){ + Point3D pp((*itr)->GetCenter()[0],(*itr)->GetCenter()[1],(*itr)->GetCenter()[2]); + veccenter.push_back(pp); + bounds.push_back((*itr)->GetDimBound()); + } + } + + //max bound + double maxbound = *max_element(bounds.begin(),bounds.end()); + + //max distance + vector vecdist; + size_t num = veccenter.size(); + + if (num > 1){ + for (size_t i = 0; i != num; ++i){ + for (size_t j= 1; j != num; ++j){ + double dist = veccenter[i].distanceToPoint(veccenter[j]); + vecdist.push_back(dist); + } + } + } + else{ + vecdist.push_back(maxbound); + } + + double maxcenterdist = *max_element(vecdist.begin(),vecdist.end()); + + double finalbound = maxbound + maxcenterdist; + + return finalbound; +} + +vector ComplexModel::GetCenter() +{ + double sumx = 0, sumy = 0, sumz = 0; + size_t num = 0; + + vector::iterator itr; + for(itr = models_.begin(); itr != models_.end(); ++itr){ + sumx += 
(*itr)->GetCenter()[0]; + sumy += (*itr)->GetCenter()[1]; + sumz += (*itr)->GetCenter()[2]; + ++num; + } + + vector v(3); + v[0] = sumx/num; + v[1] = sumy/num; + v[2] = sumz/num; + center_ = v; + return center_; +} + +int ComplexModel::GetPoints(Point3DVector &vp) +{ + if (vp.size() != 0){ + throw runtime_error("GetPoints(Point3DVector &VP):VP has to be empty"); + } + + vector::iterator itr; + + for(itr = models_.begin(); itr != models_.end(); ++itr){ + vector temp; + (*itr)->GetPoints(temp); + if (temp.size() != 0){ + vp.insert(vp.end(),temp.begin(),temp.end()); + } + } + return vp.size(); +} diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/complex_model.h b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/complex_model.h new file mode 100755 index 000000000..db1a64ed7 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/complex_model.h @@ -0,0 +1,36 @@ +/** \file complex_model.h child class of PointsModel */ + +#ifndef COMPLEX_MODEL_H +#define COMPLEX_MODEL_H + +#include "points_model.h" +#include +#include + +/** + * Class ComplexModel : container class for LORESModel & PDBModel + * The main functionality is to merge points from instances of + * LORESModel & PDBModel + */ + +class ComplexModel : public PointsModel { + public: + ComplexModel(); + + //add PointsModel instance + void Add(PointsModel *); + + //Parse all coordinates from ATOM section + //of the PDB file into vector of points + int GetPoints(Point3DVector &); + + //Get distance boundary for the pointsmodel instances + double GetDimBound(); + + vector GetCenter(); + + private: + vector models_; +}; + +#endif diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/local.def b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/local.def new file mode 100755 index 000000000..e2782e7c7 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/local.def @@ -0,0 +1,32 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# +# Local (project) definitions +# + +# C++ + +PROJ_CXX_INCLUDES += ../../iqPy/libiqPy \ + ../../iqPy/libiqPy/tnt \ + ../../geoshapespy/libgeoshapespy \ + ../../analmodelpy/libanalmodelpy \ + ./ + + PROJ_CXX_FLAGS += $(CXX_SOFLAGS) + PROJ_LCXX_FLAGS += $(LCXX_SARFLAGS) -liqPy -lgeoshapespy + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/lores_model.cc b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/lores_model.cc new file mode 100755 index 000000000..0afde6f2e --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/lores_model.cc @@ -0,0 +1,237 @@ +/** \file lores_model.cc */ + +#include +#include +#include "lores_model.h" +#include "sphere.h" +#include "hollow_sphere.h" +#include "cylinder.h" +#include "ellipsoid.h" +#include "single_helix.h" +#include "Point3D.h" + +LORESModel::LORESModel(double density) +{ + density_ = density; +} + +GeoShape* LORESModel::GetGeoShape(GeoShape& geo_shape) +{ + GeoShape* shape = NULL; + + switch (geo_shape.GetShapeType()){ + case SPHERE: + shape = new Sphere(static_cast(geo_shape)); + break; + case HOLLOWSPHERE: + shape = new HollowSphere(static_cast(geo_shape)); + break; + case CYLINDER: + shape = new Cylinder(static_cast(geo_shape)); + break; + case ELLIPSOID: + shape = new Ellipsoid(static_cast(geo_shape)); + break; + case SINGLEHELIX: + shape = new SingleHelix(static_cast(geo_shape)); + break; + + } + + return shape; +} + +LORESModel::~LORESModel() +{ + for (RealSpaceShapeCollection::iterator it = shapes_.begin(); + it != shapes_.end(); ++it) { + delete *it; + } +} + +void LORESModel::SetDensity(double density) +{ + density_ = density; +} + +double LORESModel::GetDensity() +{ + return density_; +} + +// Add a shape into LORES Model +void LORESModel::Add(GeoShape& geo_shape, + double sld) { + GeoShape* shape = GetGeoShape(geo_shape); + assert(shape != NULL); + + RealSpaceShape* real_shape = new RealSpaceShape(shape); + FillPoints(real_shape, sld); + shapes_.push_back(real_shape); +} + +// Delete ith shape in shapes_ list +// If we have 3 shapes in our model, the index starts +// from 0, which means we need to call Delete(0) to delete +// the first shape, and call Delete(1) to delete the second +// shape, etc.. +void LORESModel::Delete(size_t i) { + if (i >= shapes_.size()) { + std::cerr << "Delete shapes out of scope" << std::endl; + return; + } + + RealSpaceShape* real_shape = shapes_[i]; + if (i + 1 < shapes_.size()) { + // if it is not the last shape, we have to distribute its points + // to the shapes after i-th shape. 
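+  // For example, with shapes {A, B, C} added in that order, Delete(0)
+  // re-offers each of A's points to B and then C (via DistributePoint),
+  // keeps a point in the first shape whose IsInside() accepts it, and
+  // drops any point that falls inside neither.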
+ const Point3DVector& points = real_shape->points; + for (Point3DVector::const_iterator pit = points.begin(); + pit != points.end(); ++pit) + DistributePoint(*pit, i + 1); + } + + shapes_.erase(shapes_.begin() + i); + delete real_shape; +} + +// Get the points in the realspaceshapecollection +int LORESModel::GetPoints(Point3DVector &pp) +{ + if (pp.size() != 0){ + throw runtime_error("GetPoints(Point3DVector &VP):VP has to be empty"); + } + + if (shapes_.size() != 0){ + for (size_t j = 0; j points.begin(),shapes_[j]->points.end()); + } + } + return pp.size(); +} + +//Write points to a file, mainly for testing right now +void LORESModel::WritePoints2File(Point3DVector &vp){ + ofstream outfile("test.coor"); + for(size_t i=0; i vec_center; + for (size_t m = 0; m < shapes_.size(); ++m){ + assert(shapes_[m]->shape != NULL); + vector center(3); + center = shapes_[m]->shape->GetCenter(); + Point3D p_center; + p_center.set(center[0],center[1],center[2]); + vec_center.push_back(p_center); + } + size_t vecsize = vec_center.size(); + + //get the maximum distance among centers + double max_cen_dist; + if (vecsize == 1){ + max_cen_dist = 0; + } + else{ + vector vecdist; + for (size_t m1=0; m1 maxradii; + for (size_t n = 0; n < shapes_.size(); ++n){ + assert(shapes_[n]->shape != NULL); + double maxradius = shapes_[n]->shape->GetMaxRadius(); + maxradii.push_back(maxradius); + } + double max_maxradius = *max_element(maxradii.begin(),maxradii.end()); + + return 2*(max_cen_dist/2 + max_maxradius); +} + + +// Distribute points of the shape we are going to delete +void LORESModel::DistributePoint(const Point3D& point, size_t i) +{ + for (size_t k = i; k < shapes_.size(); ++k) { + assert(shapes_[k]->shape != NULL); + if (shapes_[k]->shape->IsInside(point)) { + shapes_[k]->points.push_back(point); + return; + } + } +} + +void LORESModel::FillPoints(RealSpaceShape* real_shape, double sld) +{ + assert(real_shape != NULL); + + GeoShape* shape = real_shape->shape; + assert(shape != NULL); + + int npoints = static_cast(density_ * shape->GetVolume()); + + for (int i = 0; i < npoints; ++i){ + Point3D apoint = shape->GetAPoint(sld); + apoint.Transform(shape->GetOrientation(),shape->GetCenter()); + if (!IsInside(apoint)){ + real_shape->points.push_back(apoint); + } + } + +} + +bool LORESModel::IsInside(const Point3D& point) +{ + for (RealSpaceShapeCollection::const_iterator it = shapes_.begin(); + it != shapes_.end(); ++it) { + const GeoShape* shape = (*it)->shape; + assert(shape != NULL); + + if (shape->IsInside(point)) { + return true; + } + } + + return false; +} + +vector LORESModel::GetCenter() +{ + //get the vector of centers from the list of shapes + size_t numshapes = 0; + double sumx = 0, sumy = 0, sumz = 0; + for (size_t m = 0; m < shapes_.size(); ++m){ + assert(shapes_[m]->shape != NULL); + vector center(3); + center = shapes_[m]->shape->GetCenter(); + + sumx += center[0]; + sumy += center[1]; + sumz += center[2]; + + ++numshapes; + } + + vector shapescenter(3); + shapescenter[0]= sumx/numshapes; + shapescenter[1]= sumy/numshapes; + shapescenter[2]= sumz/numshapes; + + center_ = shapescenter; + return center_; +} diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/lores_model.h b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/lores_model.h new file mode 100755 index 000000000..8d5bd4861 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/lores_model.h @@ -0,0 +1,73 @@ +/** \file lores_model.h child class of PointsModel*/ + +#ifndef LORESMODEL_H +#define 
LORESMODEL_H + +#include +#include "points_model.h" +#include "geo_shape.h" +//temporary +#include + +/** + * Class LORESModel, low resolution models + */ + +class LORESModel : public PointsModel { + private: + typedef vector PointsVector; + + struct RealSpaceShape { + GeoShape* shape; + PointsVector points; + + RealSpaceShape(GeoShape* s = NULL) : shape(s) { + } + + ~RealSpaceShape() { + if (shape) delete shape; + } + }; + + typedef deque RealSpaceShapeCollection; + + public: + LORESModel(double density = 1.0); + ~LORESModel(); + + // Change density + void SetDensity(double density); + double GetDensity(); + + // Add new shape + void Add(GeoShape& geo_shape, double sld = 1); + + // Delete ith shape at shapes_ + void Delete(size_t i); + + int GetPoints(Point3DVector &); + //Write points to a file, mainly for testing right now + void WritePoints2File(Point3DVector &); + + //get the maximum possible dimension + double GetDimBound(); + + //will be used in determining the maximum distance for + //P(r) calculation for a complex model (merge several + //pointsmodel instance together + vector GetCenter(); + + protected: + GeoShape* GetGeoShape(GeoShape& geo_shape); + void FillPoints(RealSpaceShape* real_shape, double sld); + bool IsInside(const Point3D& point); + void DistributePoint(const Point3D& point, size_t i); + + private: + RealSpaceShapeCollection shapes_; + int npoints_; + double density_; + +}; + +#endif diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/pdb_model.cc b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/pdb_model.cc new file mode 100755 index 000000000..34696c921 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/pdb_model.cc @@ -0,0 +1,123 @@ +/** \file pdb_model.cc */ + +#include "pdb_model.h" +#include +#include +#include +#include +#include + +using namespace std; + +typedef map::const_iterator mymapitr; + +PDBModel::PDBModel(){ + res_sld_["ALA"] = 1.645; res_sld_["ARG"] = 3.466; + res_sld_["ASN"] = 3.456; res_sld_["ASP"] = 3.845; + res_sld_["CYS"] = 1.930; res_sld_["GLU"] = 3.762; + res_sld_["GLN"] = 3.373; res_sld_["GLY"] = 1.728; + res_sld_["HIS"] = 4.959; res_sld_["LLE"] = 1.396; + res_sld_["LEU"] = 1.396; res_sld_["LYS"] = 1.586; + res_sld_["MET"] = 1.763; res_sld_["PHE"] = 4.139; + res_sld_["PRO"] = 2.227; res_sld_["SER"] = 2.225; + res_sld_["THR"] = 2.142; res_sld_["TRP"] = 6.035; + res_sld_["TYR"] = 4.719; res_sld_["VAL"] = 1.479; + + x_max_ = -10000; x_min_ = 10000; + y_max_ = -10000; y_min_ = 10000; + z_max_ = -10000; z_min_ = 10000; +} + +void PDBModel::AddPDB(const string &pdbfile) +{ + pdbnames_.push_back(pdbfile); +} + +void PDBModel::AddPDB(const char* pdbfile) +{ + string sname(pdbfile); + pdbnames_.push_back(sname); +} + +int PDBModel::GetPoints(Point3DVector &vp) +{ + if (vp.size() != 0){ + throw runtime_error("PDBModel::GetPoints(Point3DVector &VP):VP has to be empty"); + } + vector::iterator itr; + for (itr = pdbnames_.begin(); itr!=pdbnames_.end(); ++itr){ + ifstream infile(itr->c_str()); + if (!infile){ + cerr << "error: unable to open input file: " + << infile << endl; + } + string line; + while (getline(infile,line)){ + size_t len = line.size(); + string header; + + //for a line with size >=4, retrieve the header + //if the header == "ATOM"" + //make sure the length of the line >=54, then + //retrieve the corresponding coordinates,convert str to double + //then assign them to a point3d,which will be append to vp with a SLD + if (len >= 4){ + for (size_t i=0; i !=4; ++i) header += line[i]; + if 
(header.compare("ATOM") == 0 && len >53){ + string strx,stry,strz; + double x = 0, y = 0, z = 0; + for (size_t j = 30; j!=38; ++j){ + strx += line[j]; + stry += line[j+8]; + strz += line[j+16]; + } + x = atof(strx.c_str()); + y = atof(stry.c_str()); + z = atof(strz.c_str()); + + //find residue type, and assign SLD + string resname; + for (size_t k = 17; k!=20; ++k) resname +=line[k]; + mymapitr itr = res_sld_.find("ALA"); + + //apoint(x,y,z,sld) + Point3D apoint(x,y,z,itr->second); + vp.push_back(apoint); + + //save x y z' max & min to calculate the size boundary + x_max_ = x > x_max_ ? x : x_max_; + x_min_ = x < x_min_ ? x : x_min_; + y_max_ = y > y_max_ ? y : y_max_; + y_min_ = y < y_min_ ? y : y_min_; + z_max_ = z > z_max_ ? z : z_max_; + z_min_ = z < z_min_ ? z : z_min_; + } + } + } + + infile.close(); + } + return vp.size(); +} + +double PDBModel::GetDimBound() +{ + if (pdbnames_.size() == 0 ) + return 0; + + //cout < PDBModel::GetCenter() +{ + + vector cen(3); + cen[0]= (x_max_+x_min_)/2; + cen[1]= (y_max_+y_min_)/2; + cen[2]= (z_max_+z_min_)/2; + center_ = cen; + + return center_; +} diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/pdb_model.h b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/pdb_model.h new file mode 100755 index 000000000..a6e2f6ac6 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/pdb_model.h @@ -0,0 +1,43 @@ +/** \file pdb_model.h child class of PointsModel */ + +#ifndef PDB_MODEL_H +#define PDB_MODEL_H + +#include "points_model.h" +#include +#include + +/** + * Class PDBModel, PDB models + */ + +class PDBModel : public PointsModel { + public: + PDBModel(); + + //add pdb file + void AddPDB(const string &); + void AddPDB(const char*); + + //Parse all coordinates from ATOM section + //of the PDB file into vector of points + int GetPoints(Point3DVector &); + + //will be used in determining the maximum distance for + //P(r) calculation for a complex model (merge several + //pointsmodel instance together + vector GetCenter(); + + //get the maximum possible dimension= dist between XYZmax XYZmin + //May improve to directly get the dimension from unit cell + //if the technique is crystallography + double GetDimBound(); + + private: + vector pdbnames_; + map res_sld_; + double x_max_,x_min_,y_max_,y_min_,z_max_,z_min_; +}; + +#endif + diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/points_model.cc b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/points_model.cc new file mode 100755 index 000000000..4b014e186 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/points_model.cc @@ -0,0 +1,461 @@ +/** \file points_model.cc */ + +#include +#include +#include +#include +//#include +#include +#include "points_model.h" +#include "Point3D.h" + +PointsModel::PointsModel() +{ + r_grids_num_ = 2000; + rmax_ = 0; + cormax_ = 0; + rstep_ = 0; +} + +void PointsModel::CalculateIQ(IQ *iq) +{ + //fourier transform of the returned Array2D from ddFunction() + int nIpoints = iq->GetNumI(); + double qstep = (iq->GetQmax()) / (nIpoints-1); + vector fint(nIpoints, 0); + + //I(0) is calculated seperately + int num_rstep = pr_.dim1(); + + for (int k = 1; k from ddFunction() + int num_rstep = pr_.dim1(); + + double r =0; + double debeye = 0; + double fadd = 0; + double Irelative = 0; + + //I(0) is calculated seperately + if (q == 0){ + //I(0) + double Izero = 0; + for (int i = 0; i < num_rstep; ++i) + Izero += pr_[i][1]; + Irelative = Izero; + } + else { + for (int i = 1; i < num_rstep; ++i){ + r = 
i*rstep_; //r should start from 1* rstep + double qr = q*r; + debeye = sin(qr)/qr; + fadd = pr_[i][1]*debeye; + Irelative = Irelative + fadd; + } + } + return Irelative; +} + +double PointsModel::CalculateIQError(double q) +{ + //fourier transform of the returned Array2D from ddFunction() + int num_rstep = pr_.dim1(); + + double r =0; + double debeye = 0; + double fadd = 0; + double Irelative = 0; + + //I(0) is calculated seperately + for (int i = 1; i < num_rstep; ++i){ + r = i*rstep_; //r should start from 1* rstep + double qr = q*r; + debeye = sin(qr)/qr; + fadd = fabs(pr_[i][2])*debeye*debeye + + rstep_*rstep_/4.0/r/r*(cos(qr)*cos(qr) + debeye*debeye); + Irelative = Irelative + fadd; + } + return sqrt(Irelative); +} + +//pass in a vector of points, and calculate the P(r) +double PointsModel::DistDistribution(const vector &vp) +{ + //get r axis:0,rstep,2rstep,3rstep......d_bound + int sizeofpr = r_grids_num_ + 1; //+1 just for overflow prevention + + double d_bound = GetDimBound(); + rstep_ = CalculateRstep(r_grids_num_,d_bound); + + Array2D pr(sizeofpr, 3); //third column is left for error for the future + pr = 0; + + for (int i = 1; i != sizeofpr; ++i) + pr[i][0] = pr[i-1][0] + rstep_ ; //column 1: distance + + int size = vp.size(); + + for (int i1 = 0; i1 < size - 1; ++i1) { + for (int i2 = i1 + 1; i2 < size; ++i2) { + //dist_.push_back(vp[i1].distanceToPoint(vp[i2])); + //product_sld_.push_back(vp[i1].getSLD() * vp[i2].getSLD()); + double a_dist = vp[i1].distanceToPoint(vp[i2]); + double its_sld = vp[i1].getSLD() * vp[i2].getSLD(); + + //save maximum distance + if (a_dist>rmax_) { + rmax_ = a_dist; + } + //insert into pr array + int l = int(floor(a_dist/rstep_)); + + //cout << "i1,i2,l,a_dist"<= sizeofpr) { + cerr << "one distance is out of range: " << l < cormax_) cormax_ = pr[l][1]; + } + } + } + + //normalize Pr + for (int j = 0; j != sizeofpr; ++j){ //final column2 for P(r) + //pr[j][1] = pr[j][1]/cormax_; + + // 'Size' is the number of space points, without double counting (excluding + // overlapping regions between shapes). The volume of the combined shape + // is given by V = size * (sum of all sub-volumes) / (Total number of points) + // V = size / (lores_density) + + // - To transform the integral to a sum, we need to give a weight + // to each entry equal to the average space volume of a point (w = V/N = 1/lores_density). + // The final output, I(q), should therefore be multiplied by V*V/N*N. + // Since we will be interested in P(r)/V, we only need to multiply by 1/N*(V/N) = 1/N/lores_density. + // We don't have access to lores_density from this class; we will therefore apply + // this correction externally. + // + // - Since the loop goes through half the points, multiply by 2. + // TODO: have access to lores_density from this class. 
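+    // (A worked example of the bookkeeping above, assuming N = 10,000
+    //  points at lores_density = 0.1 point/A^3, i.e. V = N/density =
+    //  100,000 A^3: the 2/N applied below covers the factor 2 for the
+    //  half loop plus one factor 1/N; per the note above, the caller
+    //  still owes the remaining V/N = 1/lores_density = 10 A^3.)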
+ // + pr[j][1] = 2.0*pr[j][1]/size; + pr[j][2] = 4.0*pr[j][2]/size/size; + } + pr_ = pr; + + return rmax_; +} + +Array2D PointsModel::GetPr() +{ + return pr_; +} + + +double PointsModel::CalculateRstep(int num_grids, double rmax) +{ + assert(num_grids > 0); + + double rstep; + rstep = rmax / num_grids; + + return rstep; +} + +void PointsModel::OutputPR(const string &fpr){ + ofstream outfile(fpr.c_str()); + if (!outfile) { + cerr << "error: unable to open output file: " + << outfile << endl; + exit(1); + } + + double sum = 0.0; + double r_stepsize = 1.0; + if (pr_.dim1()>2) r_stepsize = pr_[1][0] - pr_[0][0]; + + for (int i = 0; i < pr_.dim1(); ++i){ + sum += pr_[i][1]*r_stepsize; + } + + for (int i = 0; i < pr_.dim1(); ++i){ + if (pr_[i][1]==0) continue; + outfile << pr_[i][0] << " " << (pr_[i][1]/sum) << endl; + } +} + +void PointsModel::OutputPDB(const vector &vp,const char *fpr){ + FILE *outfile=NULL; + outfile = fopen(fpr,"w+"); + if (!outfile) { + cerr << "error: unable to open output file: " + << outfile << endl; + exit(1); + } + int size = vp.size(); + int index = 0; + for (int i = 0; i < size; ++i){ + ++index; + fprintf(outfile,"ATOM%7d C%24.3lf%8.3lf%8.3lf%6.3lf\n", \ + index,vp[i].getX(),vp[i].getY(),vp[i].getZ(),vp[i].getSLD()); + } + fclose(outfile); +} + +PointsModel::~PointsModel() +{ +} + +void PointsModel::DistDistributionXY(const vector &vp) +{ + //the max box get from 3D should be more than enough for 2D,but doesn't hurt + double d_bound = GetDimBound(); + + //using 1A for rstep, so the total bins is the max distance for the object + int sizeofpr = ceil(d_bound) + 1; //+1 just for overflow prevention + rstep_ = 1; + + Array2D pr_xy(sizeofpr,sizeofpr); //2D histogram + + //the max frequency in the correlation histogram + double cormax_xy_ = 0; + + //initialization + pr_xy = 0; + + for (int i = 1; i != sizeofpr; ++i){ + pr_xy[i][0] = pr_xy[i-1][0] + rstep_ ; //column 1: distance + } + + int size = vp.size(); + + for (int i1 = 0; i1 < size - 1; ++i1) { + for (int i2 = i1 + 1; i2 < size; ++i2) { + int jx = int(floor(fabs(vp[i1].getX()-vp[i2].getX())/rstep_)); + int jy = int(floor(fabs(vp[i1].getY()-vp[i2].getY())/rstep_)); + //the sld for the pair of points + double its_sld = vp[i1].getSLD()*vp[i2].getSLD(); + + //overflow check + if ((jx >= sizeofpr) || (jy >= sizeofpr)) + { + cerr << "one distance is out of range: " < cormax_xy_ ) cormax_xy_ = pr_xy[jx][jy]; + } + } + } + + //normalize Pr_xy + for (int m = 0; m != sizeofpr; ++m){ //final column2 for P(r) + for (int n = 0; n != sizeofpr; ++n){ + pr_xy[m][n] = pr_xy[m][n]/cormax_xy_; + //cout << "m n:"<GetNumI(); + double qstep = (iq->GetQmax()) / (nIpoints-1); + vector fint(nIpoints, 0); + double Izero = 0; + + //number of bins on x and y axis + int size_r = pr_xy_.dim1(); + //rstep is set to one, otherwise should be cos(phi)*rstep + double cosphi = cos(phi); + double sinphi = sin(phi); + + for(int k = 1; k != nIpoints; ++k){ + double q = k * qstep; + double tmp = cos(q*(cosphi+sinphi)); + + for(int i=0; i!=size_r; ++i){ + for(int j = 0; j!=size_r; ++j){ + fint[k] += pr_xy_[i][j]*tmp; + } + } + } + + for(int i=0; i!=size_r; ++i){ + for(int j = 0; j!=size_r; ++j){ + Izero += pr_xy_[i][j]; + } + } + fint[0] = Izero; + + //assign I(Q) with normalization + for(int j = 0; j < nIpoints; ++j){ + (*iq).iq_data[j][0] = j * qstep; + (*iq).iq_data[j][1] = fint[j] / Izero; + } +} + +vector PointsModel::GetCenter() +{ + vector vp(3,0); + return vp; +} + +double PointsModel::CalculateIQ_2D(double qx, double qy) +{ + //for each (Qx,Qy) on 
2D detector, calculate I + double q = sqrt(qx*qx+qy*qy); + double I = 0; + + double cosphi = qx/q; + double sinphi = qy/q; + double tmp = cos(q*(cosphi+sinphi)); + + //loop through P(r) on xy plane + int size_r = pr_xy_.dim1(); + for(int i=-size_r+1; i!=size_r; ++i){ + for(int j = -size_r+1; j!=size_r; ++j){ + //rstep is set to one, left out from calculation + I += pr_xy_[abs(i)][abs(j)]*cos(q*(cosphi*i+sinphi*j)); + } + } + + //return I, without normalization + return I; +} + +/* + * 2D simulation for oriented systems + * The beam direction is assumed to be in the z direction. + * + * @param points: vector of space points + * @param qx: qx [A-1] + * @param qy: qy [A-1] + * @return: I(qx, qy) for the system described by the space points [cm-1] + * + */ +double PointsModel::CalculateIQ_2D(const vector&points, double qx, double qy){ + /* + * TODO: the vector of points should really be part of the class + * This is a design flaw inherited from the original programmer. + */ + + int size = points.size(); + + double cos_term = 0; + double sin_term = 0; + for (int i = 0; i < size; i++) { + //the sld for the pair of points + + double phase = qx*points[i].getX() + qy*points[i].getY(); + + cos_term += cos(phase) * points[i].getSLD(); + sin_term += sin(phase) * points[i].getSLD(); + + } + + // P(q) = 1/V I(q) = (V/N)^2 (1/V) (cos_term^2 + sin_term^2) + // We divide by N here and we will multiply by the density later. + + return (cos_term*cos_term + sin_term*sin_term)/size; +} + +double PointsModel::CalculateIQ_2D_Error(const vector&points, double qx, double qy){ + + int size = points.size(); + + double delta_x, delta_y; + double q_t2 = qx*qx + qy*qy; + double cos_term = 0; + double sin_term = 0; + double cos_err = 0; + double sin_err = 0; + + // Estimate the error on the position of each point + // in x or y as V^(1/3)/N + + for (int i = 0; i < size; i++) { + + + //the sld for the pair of points + + double phase = qx*points[i].getX() + qy*points[i].getY(); + double sld_fac = points[i].getSLD() * points[i].getSLD(); + + cos_term += cos(phase) * points[i].getSLD(); + sin_term += sin(phase) * points[i].getSLD(); + + sin_err += cos(phase) * cos(phase) * sld_fac; + cos_err += sin(phase) * sin(phase) * sld_fac; + + } + + // P(q) = 1/V I(q) = (V/N)^2 (1/V) (cos_term^2 + sin_term^2) + // We divide by N here and we will multiply by the density later. + + // We will need to multiply this error by V^(1/3)/N. + // We don't have access to V from within this class. 
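+  // Propagating through I = (cos_term^2 + sin_term^2)/size with
+  // dI/d(cos_term) = 2*cos_term/size and dI/d(sin_term) = 2*sin_term/size
+  // yields the 2*sqrt(...)/size combination returned below; cos_err and
+  // sin_err accumulate the per-point phase sensitivities weighted by sld^2.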
+ return 2*sqrt(cos_term*cos_term*cos_err*cos_err + sin_term*sin_term*sin_err*sin_err)/size; +} + diff --git a/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/points_model.h b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/points_model.h new file mode 100755 index 000000000..daab8e0c0 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/libpointsmodelpy/points_model.h @@ -0,0 +1,61 @@ +/** \file points_model.h child class of SASmodel*/ + +#ifndef POINTSMODEL_H +#define POINTSMODEL_H + +#include +#include "tnt/tnt.h" +#include +#include "Point3D.h" +#include "iq.h" +#include "sas_model.h" + +using namespace std; + +class PointsModel : public SASModel{ + public: + PointsModel(); + virtual ~PointsModel(); + + void CalculateIQ(IQ *iq); + double CalculateIQ(double q); + double CalculateIQError(double q); + + // Old lengthy 2D simulation (unchecked) + void CalculateIQ_2D(IQ *iq,double phi); + double CalculateIQ_2D(double qx, double qy); + + // Fast 2D simulation + double CalculateIQ_2D(const vector&, double qx, double qy); + double CalculateIQ_2D_Error(const vector&, double qx, double qy); + + //given a set of points, calculate distance correlation + //function, and return the max dist + double DistDistribution(const vector&); + void DistDistributionXY(const vector&); + + Array2D GetPr(); + void OutputPR(const std::string &); + void OutputPR_XY(const std::string &); + + virtual int GetPoints(Point3DVector &) = 0; + void OutputPDB(const vector&,const char*); + + //will be used in calculating P(r), the maximum distance + virtual double GetDimBound() = 0; + //will be used to determin the maximum distance for + //several pointsmodel instances + virtual vector GetCenter(); + + protected: + double CalculateRstep(int num_points, double rmax); + double rmax_, rstep_,cormax_,cormax_xy_; + int r_grids_num_; + vector center_; + + private: + Array2D pr_,pr_xy_; + vector product_sld_; +}; + +#endif diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpy/Make.mm b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpy/Make.mm new file mode 100755 index 000000000..ba5da6472 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpy/Make.mm @@ -0,0 +1,35 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +PROJECT = pointsmodelpy +PACKAGE = pointsmodelpy + +#-------------------------------------------------------------------------- +# + +all: export + + +#-------------------------------------------------------------------------- +# +# export + +EXPORT_PYTHON_MODULES = \ + __init__.py + + +export:: export-python-modules + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpy/__init__.py b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpy/__init__.py new file mode 100755 index 000000000..0c1f1db8b --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpy/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +def copyright(): + return "pointsmodelpy pyre module: Copyright (c) 1998-2005 Michael A.G. Aivazis"; + + +# version +__id__ = "$Id$" + +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/Make.mm b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/Make.mm new file mode 100755 index 000000000..980f1cd98 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/Make.mm @@ -0,0 +1,32 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +PROJECT = SASsimulation +PACKAGE = pointsmodelpymodule +MODULE = pointsmodelpy + +include std-pythonmodule.def +include local.def + +PROJ_CXX_SRCLIB = -lpointsmodelpy -lgeoshapespy -liqPy + +PROJ_SRCS = \ + bindings.cc \ + exceptions.cc \ + misc.cc + + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/bindings.cc b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/bindings.cc new file mode 100755 index 000000000..981f445b2 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/bindings.cc @@ -0,0 +1,166 @@ +// -*- C++ -*- +// @copyright: University of Tennessee, for the DANSE project + +//#include +#include + +#include "bindings.h" + +#include "misc.h" // miscellaneous methods + +// the method table + +struct PyMethodDef pypointsmodelpy_methods[] = { + + // new_loresmodel + {pypointsmodelpy_new_loresmodel__name__, pypointsmodelpy_new_loresmodel, + METH_VARARGS, pypointsmodelpy_new_loresmodel__doc__}, + + // LORESModel method:add(geoshape,sld) + {pypointsmodelpy_lores_add__name__, pypointsmodelpy_lores_add, + METH_VARARGS, pypointsmodelpy_lores_add__doc__}, + + // LORESModel method:GetPoints(vector &) + {pypointsmodelpy_get_lorespoints__name__, pypointsmodelpy_get_lorespoints, + METH_VARARGS, pypointsmodelpy_get_lorespoints__doc__}, + + // new_pdbmodel + {pypointsmodelpy_new_pdbmodel__name__, pypointsmodelpy_new_pdbmodel, + METH_VARARGS, pypointsmodelpy_new_pdbmodel__doc__}, + + // PDBModel method: AddPDB(const char*) + {pypointsmodelpy_pdbmodel_add__name__, pypointsmodelpy_pdbmodel_add, + METH_VARARGS, pypointsmodelpy_pdbmodel_add__doc__}, + + // PDBModel method: GetPoints(Point3DVector &) + {pypointsmodelpy_get_pdbpoints__name__, pypointsmodelpy_get_pdbpoints, + METH_VARARGS, pypointsmodelpy_get_pdbpoints__doc__}, + + // new_complexmodel + {pypointsmodelpy_new_complexmodel__name__, pypointsmodelpy_new_complexmodel, + METH_VARARGS, pypointsmodelpy_new_complexmodel__doc__}, + + // ComplexModel method: Add(PointsModel *) + {pypointsmodelpy_complexmodel_add__name__, pypointsmodelpy_complexmodel_add, + METH_VARARGS, pypointsmodelpy_complexmodel_add__doc__}, + + // ComplexModel method: GetPoints(Point3DVector &) + {pypointsmodelpy_get_complexpoints__name__, pypointsmodelpy_get_complexpoints, + METH_VARARGS, pypointsmodelpy_get_complexpoints__doc__}, + + //new_point3dvec + {pypointsmodelpy_new_point3dvec__name__, pypointsmodelpy_new_point3dvec, + METH_VARARGS, pypointsmodelpy_new_point3dvec__doc__}, + + //fillpoints + //{pypointsmodelpy_fillpoints__name__, pypointsmodelpy_fillpoints, + 
//METH_VARARGS, pypointsmodelpy_fillpoints__doc__}, + + //distdistribution calculation for LORES model + {pypointsmodelpy_get_lores_pr__name__, pypointsmodelpy_get_lores_pr, + METH_VARARGS, pypointsmodelpy_get_lores_pr__doc__}, + + //distdistribution 2D (on xy plane) + {pypointsmodelpy_distdistribution_xy__name__, pypointsmodelpy_distdistribution_xy, + METH_VARARGS, pypointsmodelpy_distdistribution_xy__doc__}, + + //distdistribution calculation for PDB model + {pypointsmodelpy_get_pdb_pr__name__, pypointsmodelpy_get_pdb_pr, + METH_VARARGS, pypointsmodelpy_get_pdb_pr__doc__}, + + //distdistribution calculation on XY plane for PDB model + {pypointsmodelpy_get_pdb_pr_xy__name__, pypointsmodelpy_get_pdb_pr_xy, + METH_VARARGS, pypointsmodelpy_get_pdb_pr_xy__doc__}, + + //distdistribution calculation for Complex model + {pypointsmodelpy_get_complex_pr__name__, pypointsmodelpy_get_complex_pr, + METH_VARARGS, pypointsmodelpy_get_complex_pr__doc__}, + + //calculateIQ + {pypointsmodelpy_get_lores_iq__name__, pypointsmodelpy_get_lores_iq, + METH_VARARGS, pypointsmodelpy_get_lores_iq__doc__}, + + //calculateI(single Q) + {pypointsmodelpy_get_lores_i__name__, pypointsmodelpy_get_lores_i, + METH_VARARGS, pypointsmodelpy_get_lores_i__doc__}, + + //calculateI(single Q) + {pypointsmodelpy_get_complex_i__name__, pypointsmodelpy_get_complex_i, + METH_VARARGS, pypointsmodelpy_get_complex_i__doc__}, + + //calculateI(single Q) + {pypointsmodelpy_get_complex_i_error__name__, pypointsmodelpy_get_complex_i_error, + METH_VARARGS, pypointsmodelpy_get_complex_i_error__doc__}, + + //calculateIQ 2D + {pypointsmodelpy_calculateIQ_2D__name__, pypointsmodelpy_calculateIQ_2D, + METH_VARARGS, pypointsmodelpy_calculateIQ_2D__doc__}, + + //calculateIQ 2D(points, Qx,Qy) + {pypointsmodelpy_calculateI_Qvxy__name__, pypointsmodelpy_calculateI_Qvxy, + METH_VARARGS, pypointsmodelpy_calculateI_Qvxy__doc__}, + + //calculateIQ 2D(Qx,Qy) + {pypointsmodelpy_calculateI_Qxy__name__, pypointsmodelpy_calculateI_Qxy, + METH_VARARGS, pypointsmodelpy_calculateI_Qxy__doc__}, + + //PDBModel calculateIQ + {pypointsmodelpy_get_pdb_iq__name__, pypointsmodelpy_get_pdb_iq, + METH_VARARGS, pypointsmodelpy_get_pdb_iq__doc__}, + + //PDBModel calculateIQ(Qx,Qy) + {pypointsmodelpy_get_pdb_Iqxy__name__, pypointsmodelpy_get_pdb_Iqxy, + METH_VARARGS, pypointsmodelpy_get_pdb_Iqxy__doc__}, + + //PDBModel calculateIQ(pts,Qx,Qy) + {pypointsmodelpy_get_pdb_Iqvxy__name__, pypointsmodelpy_get_pdb_Iqvxy, + METH_VARARGS, pypointsmodelpy_get_pdb_Iqvxy__doc__}, + + //ComplexModel calculateIQ(pts,Qx,Qy) + {pypointsmodelpy_get_complex_Iqxy__name__, pypointsmodelpy_get_complex_Iqxy, + METH_VARARGS, pypointsmodelpy_get_complex_Iqxy__doc__}, + + //ComplexModel calculateIQ_2D_Error(pts,Qx,Qy) + {pypointsmodelpy_get_complex_Iqxy_err__name__, pypointsmodelpy_get_complex_Iqxy_err, + METH_VARARGS, pypointsmodelpy_get_complex_Iqxy_err__doc__}, + + //ComplexModel calculateIQ + {pypointsmodelpy_get_complex_iq__name__, pypointsmodelpy_get_complex_iq, + METH_VARARGS, pypointsmodelpy_get_complex_iq__doc__}, + + //outputPR + {pypointsmodelpy_outputPR__name__, pypointsmodelpy_outputPR, + METH_VARARGS, pypointsmodelpy_outputPR__doc__}, + + //getPR + {pypointsmodelpy_getPR__name__, pypointsmodelpy_getPR, + METH_VARARGS, pypointsmodelpy_getPR__doc__}, + + //outputPR_xy + {pypointsmodelpy_outputPR_xy__name__, pypointsmodelpy_outputPR_xy, + METH_VARARGS, pypointsmodelpy_outputPR_xy__doc__}, + + //PDBModel outputPR + {pypointsmodelpy_save_pdb_pr__name__, pypointsmodelpy_save_pdb_pr, + 
METH_VARARGS, pypointsmodelpy_save_pdb_pr__doc__}, + + //ComplexModel outputPR + {pypointsmodelpy_save_complex_pr__name__, pypointsmodelpy_save_complex_pr, + METH_VARARGS, pypointsmodelpy_save_complex_pr__doc__}, + + //outputPDB + {pypointsmodelpy_outputPDB__name__, pypointsmodelpy_outputPDB, + METH_VARARGS, pypointsmodelpy_outputPDB__doc__}, + + {pypointsmodelpy_copyright__name__, pypointsmodelpy_copyright, + METH_VARARGS, pypointsmodelpy_copyright__doc__}, + + +// Sentinel + {0, 0} +}; + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/bindings.h b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/bindings.h new file mode 100755 index 000000000..785b0cace --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/bindings.h @@ -0,0 +1,26 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pypointsmodelpy_bindings_h) +#define pypointsmodelpy_bindings_h + +// the method table + +extern struct PyMethodDef pypointsmodelpy_methods[]; + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/exceptions.cc b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/exceptions.cc new file mode 100755 index 000000000..5af3d780a --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/exceptions.cc @@ -0,0 +1,22 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +//#include +#include + +PyObject *pypointsmodelpy_runtimeError = 0; + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/exceptions.h b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/exceptions.h new file mode 100755 index 000000000..1cbed0a5e --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/exceptions.h @@ -0,0 +1,26 @@ +// -*- C++ -*- +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// +// Michael A.G. Aivazis +// California Institute of Technology +// (C) 1998-2005 All Rights Reserved +// +// +// +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// + +#if !defined(pypointsmodelpy_exceptions_h) +#define pypointsmodelpy_exceptions_h + +// exceptions + +extern PyObject *pypointsmodelpy_runtimeError; + +#endif + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/local.def b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/local.def new file mode 100755 index 000000000..181a06cb1 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/local.def @@ -0,0 +1,26 @@ +# -*- Makefile -*- +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Michael A.G. 
Aivazis +# California Institute of Technology +# (C) 1998-2005 All Rights Reserved +# +# +# +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# + +# C++ + + PROJ_CXX_INCLUDES = .. \ + ../libpointsmodelpy \ + ../../analmodelpy/libanalmodelpy \ + ../../geoshapespy/libgeoshapespy \ + ../../iqPy/libiqPy \ + ../../iqPy/libiqPy/tnt + +# version +# $Id$ + +# End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/misc.cc b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/misc.cc new file mode 100755 index 000000000..91ed99e93 --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/misc.cc @@ -0,0 +1,806 @@ +// -*- C++ -*- +#include + +#include +#include +#include +#include "Point3D.h" +#include "misc.h" +#include "lores_model.h" +#include "pdb_model.h" +#include "complex_model.h" +#include "geo_shape.h" +#include "iq.h" + +// copyright + +char pypointsmodelpy_copyright__doc__[] = ""; +char pypointsmodelpy_copyright__name__[] = "copyright"; + +static char pypointsmodelpy_copyright_note[] = + "pointsmodelpy python module: Copyright (c) 2007 University of Tennessee"; + + +PyObject * pypointsmodelpy_copyright(PyObject *, PyObject *) +{ + return Py_BuildValue("s", pypointsmodelpy_copyright_note); +} + +// new_loresmodel +//wrapper for LORESModel constructor LORESModel(double density) + +char pypointsmodelpy_new_loresmodel__doc__[] = "Low-resolution shapes:real space geometric complex models"; +char pypointsmodelpy_new_loresmodel__name__[] = "new_loresmodel"; + +PyObject * pypointsmodelpy_new_loresmodel(PyObject *, PyObject *args) +{ + double density = 0; + + int ok = PyArg_ParseTuple(args, "d",&density); + if(!ok) return NULL; + + LORESModel *newlores = new LORESModel(density); + return PyCObject_FromVoidPtr(newlores, PyDelLores); +} + +void PyDelLores(void *ptr){ + LORESModel * oldlores = static_cast(ptr); + delete oldlores; + return; +} + +//LORESModel methods add(GeoShape &, double sld) +char pypointsmodelpy_lores_add__name__[] = "lores_add"; +char pypointsmodelpy_lores_add__doc__[] = "loresmodel method:add(Geoshape &,sld)"; + +PyObject * pypointsmodelpy_lores_add(PyObject *, PyObject *args){ + double sld = 1; + PyObject *pyloresmodel = 0, *pyshape = 0; + int ok = PyArg_ParseTuple(args, "OOd", &pyloresmodel, &pyshape, &sld); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pyloresmodel); + void *temp2 = PyCObject_AsVoidPtr(pyshape); + + LORESModel * thislores = static_cast(temp); + GeoShape * thisshape = static_cast(temp2); + + thislores->Add(*thisshape, sld); + + return Py_BuildValue("i", 0); +} + +//LORESModel methods GetPoints(vector &) +char pypointsmodelpy_get_lorespoints__name__[] = "get_lorespoints"; +char pypointsmodelpy_get_lorespoints__doc__[] = "get the points from the lores model"; + +PyObject * pypointsmodelpy_get_lorespoints(PyObject *, PyObject *args){ + PyObject *pyloresmodel = 0, *pypoint3dvec = 0; + int ok = PyArg_ParseTuple(args, "OO", &pyloresmodel, &pypoint3dvec); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pyloresmodel); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + + LORESModel * thislores = static_cast(temp); + vector * thisvec = static_cast *>(temp2); + + int npts = thislores->GetPoints(*thisvec); + //temporary + thislores->WritePoints2File(*thisvec); + return Py_BuildValue("i", npts); +} + +// new_pdbmodel +//wrapper for PDBModel constructor PDBModel() + +char pypointsmodelpy_new_pdbmodel__doc__[] = "PDB model: contain atomic coordinate from PDB file 
& Scattering length density"; +char pypointsmodelpy_new_pdbmodel__name__[] = "new_pdbmodel"; + +PyObject * pypointsmodelpy_new_pdbmodel(PyObject *, PyObject *args) +{ + PDBModel *newpdb = new PDBModel(); + return PyCObject_FromVoidPtr(newpdb, PyDelPDB); +} + +void PyDelPDB(void *ptr){ + PDBModel * oldpdb = static_cast(ptr); + delete oldpdb; + return; +} + +//PDBModel methods AddPDB(char * pdbfile) +char pypointsmodelpy_pdbmodel_add__name__[] = "pdbmodel_add"; +char pypointsmodelpy_pdbmodel_add__doc__[] = "Add a structure from PDB"; + +PyObject * pypointsmodelpy_pdbmodel_add(PyObject *, PyObject *args){ + PyObject *pypdbmodel = 0; + char * pdbfile; + + int ok = PyArg_ParseTuple(args, "Os", &pypdbmodel, &pdbfile); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pypdbmodel); + + PDBModel * thispdb = static_cast(temp); + + thispdb->AddPDB(pdbfile); + + return Py_BuildValue("i", 0); +} + +//PDBModel methods GetPoints(Point3DVector &) +char pypointsmodelpy_get_pdbpoints__name__[] = "get_pdbpoints"; +char pypointsmodelpy_get_pdbpoints__doc__[] = "Get atomic points from pdb with SLD"; + +PyObject * pypointsmodelpy_get_pdbpoints(PyObject *, PyObject *args){ + PyObject *pypdbmodel = 0, *pypoint3dvec = 0; + int ok = PyArg_ParseTuple(args, "OO", &pypdbmodel, &pypoint3dvec); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pypdbmodel); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + + PDBModel * thispdb = static_cast(temp); + vector * thisvec = static_cast *>(temp2); + + int npts = thispdb->GetPoints(*thisvec); + + return Py_BuildValue("i", npts); +} + +// new_complexmodel +//wrapper for ComplexModel constructor ComplexModel() + +char pypointsmodelpy_new_complexmodel__doc__[] = "COMPLEX model: contain LORES and PDB models"; +char pypointsmodelpy_new_complexmodel__name__[] = "new_complexmodel"; + +PyObject * pypointsmodelpy_new_complexmodel(PyObject *, PyObject *args) +{ + ComplexModel *newcomplex = new ComplexModel(); + return PyCObject_FromVoidPtr(newcomplex, PyDelComplex); +} + +void PyDelComplex(void *ptr){ + ComplexModel * oldcomplex = static_cast(ptr); + delete oldcomplex; + return; +} + +//ComplexModel methods Add(PointsModel *) +char pypointsmodelpy_complexmodel_add__name__[] = "complexmodel_add"; +char pypointsmodelpy_complexmodel_add__doc__[] = "Add LORES model or PDB Model,type has to be specified (either PDB or LORES)"; + +PyObject * pypointsmodelpy_complexmodel_add(PyObject *, PyObject *args){ + PyObject *pycomplexmodel = 0, *pymodel = 0; + char * modeltype; + + int ok = PyArg_ParseTuple(args, "OOs", &pycomplexmodel,&pymodel, &modeltype); + if(!ok) return NULL; + + void *temp2 = PyCObject_AsVoidPtr(pycomplexmodel); + ComplexModel *thiscomplex = static_cast(temp2); + + void *temp = PyCObject_AsVoidPtr(pymodel); + if (strcmp(modeltype,"LORES") == 0){ + LORESModel * thislores = static_cast(temp); + thiscomplex->Add(thislores); + } + else if (strcmp(modeltype,"PDB") == 0){ + PDBModel * thispdb = static_cast(temp); + thiscomplex->Add(thispdb); + } + else{ + throw runtime_error("The model type is either PDB or LORES"); + } + + return Py_BuildValue("i", 0); +} + +//ComplexModel methods GetPoints(Point3DVector &) +char pypointsmodelpy_get_complexpoints__name__[] = "get_complexpoints"; +char pypointsmodelpy_get_complexpoints__doc__[] = "Get points from complex model (container for LORES & PDB model)"; + +PyObject * pypointsmodelpy_get_complexpoints(PyObject *, PyObject *args){ + PyObject *pycomplexmodel = 0, *pypoint3dvec = 0; + int ok = PyArg_ParseTuple(args, "OO", 
&pycomplexmodel, &pypoint3dvec); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pycomplexmodel); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + + ComplexModel * thiscomplex = static_cast(temp); + vector * thisvec = static_cast *>(temp2); + + int npts = thiscomplex->GetPoints(*thisvec); + + return Py_BuildValue("i", npts); +} + +//create a new vector that holds of class Point3D objects +char pypointsmodelpy_new_point3dvec__doc__[] = ""; +char pypointsmodelpy_new_point3dvec__name__[] = "new_point3dvec"; + +PyObject * pypointsmodelpy_new_point3dvec(PyObject *, PyObject *args) +{ + PyObject *pyvec = 0; + + vector *newvec = new vector(); + + return PyCObject_FromVoidPtr(newvec, PyDelPoint3DVec); +} + +void PyDelPoint3DVec(void *ptr) +{ + vector * oldvec = static_cast *>(ptr); + delete oldvec; + return; + +} + +//LORESModel method distribution(point3dvec) +char pypointsmodelpy_get_lores_pr__name__[] = "get_lores_pr"; +char pypointsmodelpy_get_lores_pr__doc__[] = "calculate distance distribution function"; + +PyObject * pypointsmodelpy_get_lores_pr(PyObject *, PyObject *args) +{ + PyObject *pymodel = 0, *pypoint3dvec = 0; + int ok = PyArg_ParseTuple(args, "OO", &pymodel, &pypoint3dvec); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pymodel); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + vector * thisvec = static_cast *>(temp2); + + LORESModel * thislores = static_cast(temp); + double rmax = thislores->DistDistribution(*thisvec); + + return Py_BuildValue("d", rmax); +} + +//LORESModel method distribution_xy(point3dvec) +char pypointsmodelpy_distdistribution_xy__name__[] = "distdistribution_xy"; +char pypointsmodelpy_distdistribution_xy__doc__[] = "calculate distance distribution function on XY plane"; + +PyObject * pypointsmodelpy_distdistribution_xy(PyObject *, PyObject *args) +{ + PyObject *pymodel = 0, *pypoint3dvec = 0; + int ok = PyArg_ParseTuple(args, "OO", &pymodel, &pypoint3dvec); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pymodel); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + + LORESModel * thislores = static_cast(temp); + + Py_BEGIN_ALLOW_THREADS + vector * thisvec = static_cast *>(temp2); + thislores->DistDistributionXY(*thisvec); + Py_END_ALLOW_THREADS + + return Py_BuildValue("i", 0); +} + +//PDBModel method distribution_xy(point3dvec) +char pypointsmodelpy_get_pdb_pr_xy__name__[] = "get_pdb_pr_xy"; +char pypointsmodelpy_get_pdb_pr_xy__doc__[] = "calculate distance distribution function on XY plane"; + +PyObject * pypointsmodelpy_get_pdb_pr_xy(PyObject *, PyObject *args) +{ + PyObject *pymodel = 0, *pypoint3dvec = 0; + int ok = PyArg_ParseTuple(args, "OO", &pymodel, &pypoint3dvec); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pymodel); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + vector * thisvec = static_cast *>(temp2); + + PDBModel * thispdb = static_cast(temp); + thispdb->DistDistributionXY(*thisvec); + + return Py_BuildValue("i", 0); +} + +//PDBModel method distribution(point3dvec) +char pypointsmodelpy_get_pdb_pr__name__[] = "get_pdb_pr"; +char pypointsmodelpy_get_pdb_pr__doc__[] = "calculate distance distribution function"; + +PyObject * pypointsmodelpy_get_pdb_pr(PyObject *, PyObject *args) +{ + PyObject *pymodel = 0, *pypoint3dvec = 0; + int ok = PyArg_ParseTuple(args, "OO", &pymodel, &pypoint3dvec); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pymodel); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + vector * thisvec = static_cast *>(temp2); + + Py_BEGIN_ALLOW_THREADS + 
PDBModel * thispdb = static_cast(temp); + thispdb->DistDistribution(*thisvec); + Py_END_ALLOW_THREADS + + return Py_BuildValue("i", 0); +} + +//ComplexModel method distribution(point3dvec) +char pypointsmodelpy_get_complex_pr__name__[] = "get_complex_pr"; +char pypointsmodelpy_get_complex_pr__doc__[] = "calculate distance distribution function"; + +PyObject * pypointsmodelpy_get_complex_pr(PyObject *, PyObject *args) +{ + PyObject *pymodel = 0, *pypoint3dvec = 0; + int ok = PyArg_ParseTuple(args, "OO", &pymodel, &pypoint3dvec); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pymodel); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + vector * thisvec = static_cast *>(temp2); + + ComplexModel * thiscomplex = static_cast(temp); + Py_BEGIN_ALLOW_THREADS + thiscomplex->DistDistribution(*thisvec); + Py_END_ALLOW_THREADS + return Py_BuildValue("i", 0); +} + +//LORESModel method CalculateIQ(iq) +char pypointsmodelpy_get_lores_iq__name__[] = "get_lores_iq"; +char pypointsmodelpy_get_lores_iq__doc__[] = "calculate scattering intensity"; + +PyObject * pypointsmodelpy_get_lores_iq(PyObject *, PyObject *args) +{ + PyObject *pylores = 0, *pyiq = 0; + int ok = PyArg_ParseTuple(args, "OO", &pylores, &pyiq); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pylores); + void *temp2 = PyCObject_AsVoidPtr(pyiq); + + LORESModel * thislores = static_cast(temp); + IQ * thisiq = static_cast(temp2); + + Py_BEGIN_ALLOW_THREADS + thislores->CalculateIQ(thisiq); + Py_END_ALLOW_THREADS + + return Py_BuildValue("i",0); +} + +//LORESModel method CalculateIQ(q) +char pypointsmodelpy_get_lores_i__name__[] = "get_lores_i"; +char pypointsmodelpy_get_lores_i__doc__[] = "calculate averaged scattering intensity from a single q"; + +PyObject * pypointsmodelpy_get_lores_i(PyObject *, PyObject *args) +{ + PyObject *pylores = 0; + double q = 0; + int ok = PyArg_ParseTuple(args, "Od", &pylores, &q); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pylores); + + LORESModel * thislores = static_cast(temp); + + double I = 0.0; + Py_BEGIN_ALLOW_THREADS + I = thislores->CalculateIQ(q); + Py_END_ALLOW_THREADS + + return Py_BuildValue("d",I); +} + +// method calculateIQ_2D(iq) +char pypointsmodelpy_calculateIQ_2D__name__[] = "calculateIQ_2D"; +char pypointsmodelpy_calculateIQ_2D__doc__[] = "calculate scattering intensity"; + +PyObject * pypointsmodelpy_calculateIQ_2D(PyObject *, PyObject *args) +{ + PyObject *pylores = 0, *pyiq = 0; + double theta = 0; + int ok = PyArg_ParseTuple(args, "OOd", &pylores, &pyiq,&theta); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pylores); + void *temp2 = PyCObject_AsVoidPtr(pyiq); + + LORESModel * thislores = static_cast(temp); + IQ * thisiq = static_cast(temp2); + + Py_BEGIN_ALLOW_THREADS + thislores->CalculateIQ_2D(thisiq,theta); + Py_END_ALLOW_THREADS + + return Py_BuildValue("i",0); +} + +// method calculateI_Qxy(Qx,Qy) +char pypointsmodelpy_calculateI_Qxy__name__[] = "calculateI_Qxy"; +char pypointsmodelpy_calculateI_Qxy__doc__[] = "calculate scattering intensity on a 2D pixel"; + +PyObject * pypointsmodelpy_calculateI_Qxy(PyObject *, PyObject *args) +{ + PyObject *pylores = 0; + double qx = 0, qy = 0; + double I = 0; + + int ok = PyArg_ParseTuple(args, "Odd", &pylores, &qx,&qy); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pylores); + LORESModel * thislores = static_cast(temp); + + Py_BEGIN_ALLOW_THREADS + I = thislores->CalculateIQ_2D(qx,qy); + Py_END_ALLOW_THREADS + + return Py_BuildValue("d",I); +} + +// method 
calculateI_Qxy(poitns, Qx,Qy) +char pypointsmodelpy_calculateI_Qvxy__name__[] = "calculateI_Qvxy"; +char pypointsmodelpy_calculateI_Qvxy__doc__[] = "calculate scattering intensity on a 2D pixel"; + +PyObject * pypointsmodelpy_calculateI_Qvxy(PyObject *, PyObject *args) +{ + PyObject *pylores = 0, *pypoint3dvec = 0; + double qx = 0, qy = 0; + double I = 0; + + int ok = PyArg_ParseTuple(args, "OOdd", &pylores, &pypoint3dvec, &qx,&qy); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pylores); + LORESModel * thislores = static_cast(temp); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + vector * thisvec = static_cast *>(temp2); + + Py_BEGIN_ALLOW_THREADS + I = thislores->CalculateIQ_2D(*thisvec, qx,qy); + Py_END_ALLOW_THREADS + + return Py_BuildValue("d",I); +} + +// PDBModel method calculateIQ(iq) +char pypointsmodelpy_get_pdb_iq__name__[] = "get_pdb_iq"; +char pypointsmodelpy_get_pdb_iq__doc__[] = "calculate scattering intensity for PDB model"; + +PyObject * pypointsmodelpy_get_pdb_iq(PyObject *, PyObject *args) +{ + PyObject *pymodel = 0, *pyiq = 0; + int ok = PyArg_ParseTuple(args, "OO", &pymodel, &pyiq); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pymodel); + void *temp2 = PyCObject_AsVoidPtr(pyiq); + + PDBModel * thispdb = static_cast(temp); + IQ * thisiq = static_cast(temp2); + + Py_BEGIN_ALLOW_THREADS + thispdb->CalculateIQ(thisiq); + Py_END_ALLOW_THREADS + + return Py_BuildValue("i",0); +} + +// PDBModel method calculateIQ_2D(qx,qy) +char pypointsmodelpy_get_pdb_Iqxy__name__[] = "get_pdb_Iqxy"; +char pypointsmodelpy_get_pdb_Iqxy__doc__[] = "calculate scattering intensity by a given (qx,qy) for PDB model"; + +PyObject * pypointsmodelpy_get_pdb_Iqxy(PyObject *, PyObject *args) +{ + PyObject *pypdb = 0; + double qx = 0, qy = 0; + double I = 0; + + int ok = PyArg_ParseTuple(args, "Odd", &pypdb, &qx,&qy); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pypdb); + PDBModel * thispdb = static_cast(temp); + + Py_BEGIN_ALLOW_THREADS + I = thispdb->CalculateIQ_2D(qx,qy); + Py_END_ALLOW_THREADS + + return Py_BuildValue("d",I); +} + +// PDBModel method calculateIQ_2Dv(points,qx,qy) +char pypointsmodelpy_get_pdb_Iqvxy__name__[] = "get_pdb_Iqvxy"; +char pypointsmodelpy_get_pdb_Iqvxy__doc__[] = "calculate scattering intensity by a given (qx,qy) for PDB model"; + +PyObject * pypointsmodelpy_get_pdb_Iqvxy(PyObject *, PyObject *args) +{ + PyObject *pypdb = 0, *pypoint3dvec = 0; + double qx = 0, qy = 0; + double I = 0; + + int ok = PyArg_ParseTuple(args, "OOdd", &pypdb, &pypoint3dvec, &qx,&qy); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pypdb); + PDBModel * thispdb = static_cast(temp); + void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec); + vector * thisvec = static_cast *>(temp2); + + Py_BEGIN_ALLOW_THREADS + I = thispdb->CalculateIQ_2D(*thisvec,qx,qy); + Py_END_ALLOW_THREADS + + return Py_BuildValue("d",I); +} + +// ComplexModel method calculateIQ(iq) +char pypointsmodelpy_get_complex_iq__name__[] = "get_complex_iq"; +char pypointsmodelpy_get_complex_iq__doc__[] = "calculate scattering intensity for COMPLEX model"; + +PyObject * pypointsmodelpy_get_complex_iq(PyObject *, PyObject *args) +{ + PyObject *pymodel = 0, *pyiq = 0; + int ok = PyArg_ParseTuple(args, "OO", &pymodel, &pyiq); + if(!ok) return NULL; + + void *temp = PyCObject_AsVoidPtr(pymodel); + void *temp2 = PyCObject_AsVoidPtr(pyiq); + + ComplexModel * thiscomplex = static_cast(temp); + IQ * thisiq = static_cast(temp2); + + Py_BEGIN_ALLOW_THREADS + thiscomplex->CalculateIQ(thisiq); + 
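+  // (CalculateIQ runs with the GIL released -- Py_BEGIN_ALLOW_THREADS
+  //  above -- so other Python threads can proceed during the long P(r)
+  //  Fourier sum; no Python C API calls may be made in this window.)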
+    Py_END_ALLOW_THREADS
+
+    return Py_BuildValue("i",0);
+}
+
+//ComplexModel method CalculateIQ_2D(points,qx,qy)
+char pypointsmodelpy_get_complex_Iqxy__name__[] = "get_complex_iq_2D";
+char pypointsmodelpy_get_complex_Iqxy__doc__[] = "calculate scattering intensity at a given (qx,qy)";
+
+PyObject * pypointsmodelpy_get_complex_Iqxy(PyObject *, PyObject *args)
+{
+    PyObject *pylores = 0, *pypoint3dvec = 0;
+    double qx = 0, qy = 0;
+    int ok = PyArg_ParseTuple(args, "OOdd", &pylores, &pypoint3dvec, &qx, &qy);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pylores);
+    ComplexModel * thiscomplex = static_cast<ComplexModel *>(temp);
+    void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec);
+    vector<Point3D> * thisvec = static_cast<vector<Point3D> *>(temp2);
+
+    double I = 0.0;
+    Py_BEGIN_ALLOW_THREADS
+    I = thiscomplex->CalculateIQ_2D(*thisvec, qx, qy);
+    Py_END_ALLOW_THREADS
+
+    return Py_BuildValue("d",I);
+}
+
+//ComplexModel method CalculateIQ_2D_Error(points,qx,qy)
+char pypointsmodelpy_get_complex_Iqxy_err__name__[] = "get_complex_iq_2D_err";
+char pypointsmodelpy_get_complex_Iqxy_err__doc__[] = "calculate the error of the scattering intensity at a given (qx,qy)";
+
+PyObject * pypointsmodelpy_get_complex_Iqxy_err(PyObject *, PyObject *args)
+{
+    PyObject *pylores = 0, *pypoint3dvec = 0;
+    double qx = 0, qy = 0;
+    int ok = PyArg_ParseTuple(args, "OOdd", &pylores, &pypoint3dvec, &qx, &qy);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pylores);
+    ComplexModel * thiscomplex = static_cast<ComplexModel *>(temp);
+    void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec);
+    vector<Point3D> * thisvec = static_cast<vector<Point3D> *>(temp2);
+
+    double I = 0.0;
+    Py_BEGIN_ALLOW_THREADS
+    I = thiscomplex->CalculateIQ_2D_Error(*thisvec, qx, qy);
+    Py_END_ALLOW_THREADS
+
+    return Py_BuildValue("d",I);
+}
+
+//ComplexModel method CalculateIQ(q)
+char pypointsmodelpy_get_complex_i__name__[] = "get_complex_i";
+char pypointsmodelpy_get_complex_i__doc__[] = "calculate averaged scattering intensity from a single q";
+
+PyObject * pypointsmodelpy_get_complex_i(PyObject *, PyObject *args)
+{
+    PyObject *pylores = 0;
+    double q = 0;
+    int ok = PyArg_ParseTuple(args, "Od", &pylores, &q);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pylores);
+
+    ComplexModel * thiscomplex = static_cast<ComplexModel *>(temp);
+
+    double I = 0.0;
+    Py_BEGIN_ALLOW_THREADS
+    I = thiscomplex->CalculateIQ(q);
+    Py_END_ALLOW_THREADS
+
+    return Py_BuildValue("d",I);
+}
+
+//ComplexModel method CalculateIQError(q)
+char pypointsmodelpy_get_complex_i_error__name__[] = "get_complex_i_error";
+char pypointsmodelpy_get_complex_i_error__doc__[] = "calculate error on averaged scattering intensity from a single q";
+
+PyObject * pypointsmodelpy_get_complex_i_error(PyObject *, PyObject *args)
+{
+    PyObject *pylores = 0;
+    double q = 0;
+    int ok = PyArg_ParseTuple(args, "Od", &pylores, &q);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pylores);
+
+    ComplexModel * thiscomplex = static_cast<ComplexModel *>(temp);
+
+    double I = 0.0;
+    Py_BEGIN_ALLOW_THREADS
+    I = thiscomplex->CalculateIQError(q);
+    Py_END_ALLOW_THREADS
+
+    return Py_BuildValue("d",I);
+}
+
+//method outputPR(string filename)
+char pypointsmodelpy_outputPR__name__[] = "outputPR";
+char pypointsmodelpy_outputPR__doc__[] = "print out P(R) to a file";
+
+PyObject * pypointsmodelpy_outputPR(PyObject *, PyObject *args)
+{
+    PyObject *pymodel = 0;
+    char *outfile;
+    int ok = PyArg_ParseTuple(args, "Os", &pymodel, &outfile);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pymodel);
+
+    LORESModel * thislores = static_cast<LORESModel *>(temp);
+
+    thislores->OutputPR(outfile);
+
+    return Py_BuildValue("i", 0);
+}
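+
+// NOTE: get_pr below returns a *normalized* P(r). It first integrates the
+// histogram with a rectangle rule (sum of P(r_i)*dr, with dr taken from the
+// first two r values when more than two bins exist), then divides each
+// non-empty bin by that integral, so the returned distribution integrates
+// to approximately one regardless of how many points were sampled.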
+
+//method get_pr()
+char pypointsmodelpy_getPR__name__[] = "get_pr";
+char pypointsmodelpy_getPR__doc__[] = "Return P(r) as a list of points";
+
+PyObject * pypointsmodelpy_getPR(PyObject *, PyObject *args)
+{
+    PyObject *pymodel = 0;
+    int ok = PyArg_ParseTuple(args, "O", &pymodel);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pymodel);
+
+    LORESModel * thislores = static_cast<LORESModel *>(temp);
+
+    // Get the P(r) array
+    Array2D<double> pr_ = thislores->GetPr();
+
+    // Create two lists to store the r and P(r) values
+    PyObject* r_list = PyList_New(0);
+    PyObject* pr_list = PyList_New(0);
+
+    double sum = 0.0;
+    double r_stepsize = 1.0;
+    if (pr_.dim1()>2) r_stepsize = pr_[1][0] - pr_[0][0];
+
+    for (int i = 0; i < pr_.dim1(); ++i){
+        sum += pr_[i][1]*r_stepsize;
+    }
+
+    for (int i = 0; i < pr_.dim1(); ++i){
+        if (pr_[i][1]==0) continue;
+        int r_append = PyList_Append(r_list, Py_BuildValue("d", pr_[i][0]));
+        int pr_append = PyList_Append(pr_list, Py_BuildValue("d", pr_[i][1]/sum));
+        if (r_append+pr_append<0) return NULL;
+    }
+
+    return Py_BuildValue("OO", r_list, pr_list);
+}
+
+//method outputPR_xy(string filename)
+char pypointsmodelpy_outputPR_xy__name__[] = "outputPR_xy";
+char pypointsmodelpy_outputPR_xy__doc__[] = "print out P(R) to a file";
+
+PyObject * pypointsmodelpy_outputPR_xy(PyObject *, PyObject *args)
+{
+    PyObject *pyloresmodel = 0;
+    char *outfile;
+    int ok = PyArg_ParseTuple(args, "Os", &pyloresmodel, &outfile);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pyloresmodel);
+
+    LORESModel * thislores = static_cast<LORESModel *>(temp);
+
+    thislores->OutputPR_XY(outfile);
+
+    return Py_BuildValue("i", 0);
+}
+
+//PDBModel method outputPR(string filename)
+char pypointsmodelpy_save_pdb_pr__name__[] = "save_pdb_pr";
+char pypointsmodelpy_save_pdb_pr__doc__[] = "print out P(R) to a file";
+
+PyObject * pypointsmodelpy_save_pdb_pr(PyObject *, PyObject *args)
+{
+    PyObject *pymodel = 0;
+    char *outfile;
+    int ok = PyArg_ParseTuple(args, "Os", &pymodel, &outfile);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pymodel);
+
+    PDBModel * thispdb = static_cast<PDBModel *>(temp);
+
+    thispdb->OutputPR(outfile);
+
+    return Py_BuildValue("i", 0);
+}
+
+//ComplexModel method outputPR(string filename)
+char pypointsmodelpy_save_complex_pr__name__[] = "save_complex_pr";
+char pypointsmodelpy_save_complex_pr__doc__[] = "print out P(R) to a file";
+
+PyObject * pypointsmodelpy_save_complex_pr(PyObject *, PyObject *args)
+{
+    PyObject *pymodel = 0;
+    char *outfile;
+    int ok = PyArg_ParseTuple(args, "Os", &pymodel, &outfile);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pymodel);
+
+    ComplexModel * thiscomplex = static_cast<ComplexModel *>(temp);
+
+    thiscomplex->OutputPR(outfile);
+
+    return Py_BuildValue("i", 0);
+}
+
+//method outputPDB(string filename)
+char pypointsmodelpy_outputPDB__name__[] = "outputPDB";
+char pypointsmodelpy_outputPDB__doc__[] = "save the monte-carlo distributed points of the geomodel into a PDB format file.\
+    a .pdb extension will be automatically added";
+
+PyObject * pypointsmodelpy_outputPDB(PyObject *, PyObject *args)
+{
+    PyObject *pyloresmodel = 0, *pypoint3dvec = 0;
+    char *outfile;
+    int ok = PyArg_ParseTuple(args, "OOs", &pyloresmodel, &pypoint3dvec, &outfile);
+    if(!ok) return NULL;
+
+    void *temp = PyCObject_AsVoidPtr(pyloresmodel);
+
+    LORESModel * thislores = static_cast<LORESModel *>(temp);
+
+    void *temp2 = PyCObject_AsVoidPtr(pypoint3dvec);
+    vector<Point3D> * thisvec = static_cast<vector<Point3D> *>(temp2);
+
+    thislores->OutputPDB(*thisvec, outfile);
+
+    return Py_BuildValue("i", 0);
+}
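+
+// NOTE: a minimal sketch of how these bindings chain together from Python,
+// mirroring the test scripts later in this patch (e.g. testnegativepr.py):
+//
+//   lm = pointsmodelpy.new_loresmodel(0.0005)  # low-resolution model
+//   pointsmodelpy.lores_add(lm, shape, 1.0)    # add a geoshapespy shape + SLD
+//   vp = pointsmodelpy.new_point3dvec()
+//   pointsmodelpy.get_lorespoints(lm, vp)      # sample points into vp
+//   pointsmodelpy.get_lores_pr(lm, vp)         # P(r); get_lores_i(lm, q) for I(q)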
Py_BuildValue("i", 0); +} + +// version +// $Id$ + +// End of file diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/misc.h b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/misc.h new file mode 100755 index 000000000..2f0175eea --- /dev/null +++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/misc.h @@ -0,0 +1,238 @@ +#if !defined(pypointsmodelpy_misc_h) +#define pypointsmodelpy_misc_h + +// @copyright: University of Tennessee, for the DANSE project +extern char pypointsmodelpy_copyright__name__[]; +extern char pypointsmodelpy_copyright__doc__[]; +extern "C" +PyObject * pypointsmodelpy_copyright(PyObject *, PyObject *); + +// LORESModel constructor LORESModel(double density) +extern char pypointsmodelpy_new_loresmodel__name__[]; +extern char pypointsmodelpy_new_loresmodel__doc__[]; +extern "C" +PyObject * pypointsmodelpy_new_loresmodel(PyObject *, PyObject *); + +//Clean LORESModel constructor memory usage +static void PyDelLores(void *); + +//LORESModel methods add(GeoShapes &, double sld) +extern char pypointsmodelpy_lores_add__name__[]; +extern char pypointsmodelpy_lores_add__doc__[]; +extern "C" +PyObject * pypointsmodelpy_lores_add(PyObject *, PyObject *); + +//LORESModel methods GetPoints(vector &) +extern char pypointsmodelpy_get_lorespoints__name__[]; +extern char pypointsmodelpy_get_lorespoints__doc__[]; +extern "C" +PyObject * pypointsmodelpy_get_lorespoints(PyObject *, PyObject *); + +//PDBModel constructor PDBModel() +extern char pypointsmodelpy_new_pdbmodel__name__[]; +extern char pypointsmodelpy_new_pdbmodel__doc__[]; +extern "C" +PyObject * pypointsmodelpy_new_pdbmodel(PyObject *, PyObject *); + +//Clean PDBModel constructor memory usage +static void PyDelPDB(void *); + +//PDBModel method AddPDB(string) +extern char pypointsmodelpy_pdbmodel_add__name__[]; +extern char pypointsmodelpy_pdbmodel_add__doc__[]; +extern "C" +PyObject * pypointsmodelpy_pdbmodel_add(PyObject *, PyObject *); + +//PDBModel method GetPoints(Point3DVector &) +extern char pypointsmodelpy_get_pdbpoints__name__[]; +extern char pypointsmodelpy_get_pdbpoints__doc__[]; +extern "C" +PyObject * pypointsmodelpy_get_pdbpoints(PyObject *, PyObject *); + +//ComplexModel constructor ComplexModel() +extern char pypointsmodelpy_new_complexmodel__name__[]; +extern char pypointsmodelpy_new_complexmodel__doc__[]; +extern "C" +PyObject * pypointsmodelpy_new_complexmodel(PyObject *, PyObject *); + +//Clean Complexodel constructor memory usage +static void PyDelComplex(void *); + +//ComplexModel method AddComplex(string) +extern char pypointsmodelpy_complexmodel_add__name__[]; +extern char pypointsmodelpy_complexmodel_add__doc__[]; +extern "C" +PyObject * pypointsmodelpy_complexmodel_add(PyObject *, PyObject *); + +//ComplexModel method GetPoints(Point3DVector &) +extern char pypointsmodelpy_get_complexpoints__name__[]; +extern char pypointsmodelpy_get_complexpoints__doc__[]; +extern "C" +PyObject * pypointsmodelpy_get_complexpoints(PyObject *, PyObject *); + +//generate a new vector of points3d +extern char pypointsmodelpy_new_point3dvec__name__[]; +extern char pypointsmodelpy_new_point3dvec__doc__[]; +extern "C" +PyObject * pypointsmodelpy_new_point3dvec(PyObject *, PyObject *); + +//clean new_point3dvec +static void PyDelPoint3DVec(void *); + +// method FillPoints(loresmodel, point3dvec) +//extern char pypointsmodelpy_fillpoints__name__[]; +//extern char pypointsmodelpy_fillpoints__doc__[]; +//extern "C" +//PyObject * pypointsmodelpy_fillpoints(PyObject *, PyObject *); + +// 
LORESModel method distdistribution(point3dvec)
+extern char pypointsmodelpy_get_lores_pr__name__[];
+extern char pypointsmodelpy_get_lores_pr__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_lores_pr(PyObject *, PyObject *);
+
+// method distdistribution_xy(point3dvec)
+extern char pypointsmodelpy_distdistribution_xy__name__[];
+extern char pypointsmodelpy_distdistribution_xy__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_distdistribution_xy(PyObject *, PyObject *);
+
+// PDBModel method distdistribution(point3dvec)
+extern char pypointsmodelpy_get_pdb_pr__name__[];
+extern char pypointsmodelpy_get_pdb_pr__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_pdb_pr(PyObject *, PyObject *);
+
+// PDBModel method distdistribution_xy(point3dvec)
+extern char pypointsmodelpy_get_pdb_pr_xy__name__[];
+extern char pypointsmodelpy_get_pdb_pr_xy__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_pdb_pr_xy(PyObject *, PyObject *);
+
+// ComplexModel method distdistribution(point3dvec)
+extern char pypointsmodelpy_get_complex_pr__name__[];
+extern char pypointsmodelpy_get_complex_pr__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_complex_pr(PyObject *, PyObject *);
+
+// LORESModel method calculateIQ(iq)
+extern char pypointsmodelpy_get_lores_iq__name__[];
+extern char pypointsmodelpy_get_lores_iq__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_lores_iq(PyObject *, PyObject *);
+
+// LORESModel method CalculateIQ(q)
+extern char pypointsmodelpy_get_lores_i__name__[];
+extern char pypointsmodelpy_get_lores_i__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_lores_i(PyObject *, PyObject *);
+
+// ComplexModel method CalculateIQ(q)
+extern char pypointsmodelpy_get_complex_i__name__[];
+extern char pypointsmodelpy_get_complex_i__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_complex_i(PyObject *, PyObject *);
+
+// ComplexModel method CalculateIQError(q)
+extern char pypointsmodelpy_get_complex_i_error__name__[];
+extern char pypointsmodelpy_get_complex_i_error__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_complex_i_error(PyObject *, PyObject *);
+
+// method calculateIQ_2D(iq,theta)
+extern char pypointsmodelpy_calculateIQ_2D__name__[];
+extern char pypointsmodelpy_calculateIQ_2D__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_calculateIQ_2D(PyObject *, PyObject *);
+
+// method calculateI_Qxy(Qx,Qy)
+extern char pypointsmodelpy_calculateI_Qxy__name__[];
+extern char pypointsmodelpy_calculateI_Qxy__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_calculateI_Qxy(PyObject *, PyObject *);
+
+// method calculateI_Qvxy(points,Qx,Qy)
+extern char pypointsmodelpy_calculateI_Qvxy__name__[];
+extern char pypointsmodelpy_calculateI_Qvxy__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_calculateI_Qvxy(PyObject *, PyObject *);
+
+// PDBModel method calculateIQ(iq)
+extern char pypointsmodelpy_get_pdb_iq__name__[];
+extern char pypointsmodelpy_get_pdb_iq__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_pdb_iq(PyObject *, PyObject *);
+
+// PDBModel method calculateIQ_2D(qx,qy)
+extern char pypointsmodelpy_get_pdb_Iqxy__name__[];
+extern char pypointsmodelpy_get_pdb_Iqxy__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_pdb_Iqxy(PyObject *, PyObject *);
+
+// PDBModel method calculateIQ_2D(pts,qx,qy)
+extern char pypointsmodelpy_get_pdb_Iqvxy__name__[];
+extern char pypointsmodelpy_get_pdb_Iqvxy__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_pdb_Iqvxy(PyObject *, PyObject *);
+
+// ComplexModel method calculateIQ_2D(pts,qx,qy)
+extern char 
pypointsmodelpy_get_complex_Iqxy__name__[];
+extern char pypointsmodelpy_get_complex_Iqxy__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_complex_Iqxy(PyObject *, PyObject *);
+
+// ComplexModel method calculateIQ_2D_Error(pts,qx,qy)
+extern char pypointsmodelpy_get_complex_Iqxy_err__name__[];
+extern char pypointsmodelpy_get_complex_Iqxy_err__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_complex_Iqxy_err(PyObject *, PyObject *);
+
+// ComplexModel method calculateIQ(iq)
+extern char pypointsmodelpy_get_complex_iq__name__[];
+extern char pypointsmodelpy_get_complex_iq__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_get_complex_iq(PyObject *, PyObject *);
+
+// method outputPR
+extern char pypointsmodelpy_outputPR__name__[];
+extern char pypointsmodelpy_outputPR__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_outputPR(PyObject *, PyObject *);
+
+//method get_pr()
+extern char pypointsmodelpy_getPR__name__[];
+extern char pypointsmodelpy_getPR__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_getPR(PyObject *, PyObject *);
+
+// method outputPR_xy
+extern char pypointsmodelpy_outputPR_xy__name__[];
+extern char pypointsmodelpy_outputPR_xy__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_outputPR_xy(PyObject *, PyObject *);
+
+// PDBModel method outputPR
+extern char pypointsmodelpy_save_pdb_pr__name__[];
+extern char pypointsmodelpy_save_pdb_pr__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_save_pdb_pr(PyObject *, PyObject *);
+
+// ComplexModel method outputPR
+extern char pypointsmodelpy_save_complex_pr__name__[];
+extern char pypointsmodelpy_save_complex_pr__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_save_complex_pr(PyObject *, PyObject *);
+
+// method outputPDB
+extern char pypointsmodelpy_outputPDB__name__[];
+extern char pypointsmodelpy_outputPDB__doc__[];
+extern "C"
+PyObject * pypointsmodelpy_outputPDB(PyObject *, PyObject *);
+
+#endif
+
+// version
+// $Id$
+
+// End of file
diff --git a/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/pointsmodelpymodule.cc b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/pointsmodelpymodule.cc
new file mode 100755
index 000000000..49599b5df
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/pointsmodelpymodule/pointsmodelpymodule.cc
@@ -0,0 +1,52 @@
+// -*- C++ -*-
+//
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+//
+// Michael A.G. Aivazis
+// California Institute of Technology
+// (C) 1998-2005 All Rights Reserved
+//
+//
+//
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+//
+
+//#include <portinfo>
+
+#include <Python.h>
+
+#include "exceptions.h"
+#include "bindings.h"
+
+
+char pypointsmodelpy_module__doc__[] = "";
+
+// Initialization function for the module (*must* be called initpointsmodelpy)
+extern "C"
+void
+initpointsmodelpy()
+{
+    // create the module and add the functions
+    PyObject * m = Py_InitModule4(
+        "pointsmodelpy", pypointsmodelpy_methods,
+        pypointsmodelpy_module__doc__, 0, PYTHON_API_VERSION);
+
+    // get its dictionary
+    PyObject * d = PyModule_GetDict(m);
+
+    // check for errors
+    if (PyErr_Occurred()) {
+        Py_FatalError("can't initialize module pointsmodelpy");
+    }
+
+    // install the module exceptions
+    pypointsmodelpy_runtimeError = PyErr_NewException("pointsmodelpy.runtime", 0, 0);
+    PyDict_SetItemString(d, "RuntimeException", pypointsmodelpy_runtimeError);
+
+    return;
+}
+
+// version
+// $Id$
+
+// End of file
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/Make.mm b/sas/sascalc/simulation/pointsmodelpy/tests/Make.mm
new file mode 100755
index 000000000..d79164e19
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/Make.mm
@@ -0,0 +1,61 @@
+# -*- Makefile -*-
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Michael A.G. Aivazis
+# California Institute of Technology
+# (C) 1998-2005 All Rights Reserved
+#
+#
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+PROJECT = pointsmodelpy
+PACKAGE = tests
+
+PROJ_CLEAN += $(PROJ_CPPTESTS)
+
+PROJ_CXX_INCLUDES += ../../iqPy/libiqPy \
+                     ../../iqPy/libiqPy/tnt \
+                     ../../geoshapespy/libgeoshapespy \
+                     ../../analmodelpy/libanalmodelpy \
+                     ../libpointsmodelpy
+
+PROJ_PYTESTS = signon.py
+PROJ_CPPTESTS = testlores \
+                testpdb \
+                testcomplexmodel
+PROJ_TESTS = $(PROJ_PYTESTS) $(PROJ_CPPTESTS)
+PROJ_LIBRARIES = -L$(BLD_LIBDIR) -lpointsmodelpy -lgeoshapespy -liqPy
+
+
+#--------------------------------------------------------------------------
+#
+
+all: $(PROJ_TESTS)
+
+test:
+	for test in $(PROJ_TESTS) ; do $${test}; done
+
+release: tidy
+	cvs release .
+
+update: clean
+	cvs update .
+
+#--------------------------------------------------------------------------
+#
+
+testlores: testlores.cc $(BLD_LIBDIR)/libpointsmodelpy.$(EXT_SAR)
+	$(CXX) $(CXXFLAGS) $(LCXXFLAGS) -g -o $@ testlores.cc $(PROJ_LIBRARIES)
+
+testpdb: testpdb.cc $(BLD_LIBDIR)/libpointsmodelpy.$(EXT_SAR)
+	$(CXX) $(CXXFLAGS) $(LCXXFLAGS) -g -o $@ testpdb.cc $(PROJ_LIBRARIES)
+
+testcomplexmodel: testcomplexmodel.cc $(BLD_LIBDIR)/libpointsmodelpy.$(EXT_SAR)
+	$(CXX) $(CXXFLAGS) $(LCXXFLAGS) -g -o $@ testcomplexmodel.cc $(PROJ_LIBRARIES)
+
+# version
+# $Id$
+
+# End of file
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/signon.py b/sas/sascalc/simulation/pointsmodelpy/tests/signon.py
new file mode 100755
index 000000000..332eb949c
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/signon.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Michael A.G. Aivazis
+# California Institute of Technology
+# (C) 1998-2005 All Rights Reserved
+#
+#
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+from __future__ import print_function
+
+
+if __name__ == "__main__":
+
+    import pointsmodelpy
+    from pointsmodelpy import pointsmodelpy as pointsmodelpymodule
+
+    print("copyright information:")
+    print("   ", pointsmodelpy.copyright())
+    print("   ", pointsmodelpymodule.copyright())
+
+    print()
+    print("module information:")
+    print("    file:", pointsmodelpymodule.__file__)
+    print("    doc:", pointsmodelpymodule.__doc__)
+    print("    contents:", dir(pointsmodelpymodule))
+
+    print()
+    print(pointsmodelpymodule.hello())
+
+# version
+__id__ = "$Id$"
+
+# End of file
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/test.pdb b/sas/sascalc/simulation/pointsmodelpy/tests/test.pdb
new file mode 100755
index 000000000..50d49add5
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/test.pdb
@@ -0,0 +1,10 @@
+REMARK  FILENAME="if0.pdb"
+REMARK  TOPH19.pep -MACRO for protein sequence
+REMARK  DATE:17-Aug-02  11:05:39  created by user:
+ATOM      1  CB  SER     1      -1.115 -11.617 -11.368
+ATOM      2  OG  SER     1       0.202 -11.845 -10.894  1.00 54.67
+ATOM      3  C   SER     1      -2.076 -12.837  -9.369  1.00 44.10
+ATOM      4  O   SER     1      -3.108 -12.947  -8.704  1.00 43.77
+ATOM      5  N   SER     1      -3.465 -12.399 -11.363  1.00 83.41
+ATOM      6  CA  SER     1      -2.108 -12.709 -10.900  1.00 57.67
+END
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/test2.pdb b/sas/sascalc/simulation/pointsmodelpy/tests/test2.pdb
new file mode 100755
index 000000000..f68ed2fac
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/test2.pdb
@@ -0,0 +1,7 @@
+ATOM      1  CB  SER     1      -1.115 -11.617 -11.368  1.00 52.49
+ATOM      2  OG  SER     1       0.202 -11.845 -10.894  1.00 54.67
+ATOM      3  C   SER     1      -2.076 -12.837  -9.369  1.00 44.10
+ATOM      4  O   SER     1      -3.108 -12.947  -8.704  1.00 43.77
+ATOM      5  N   SER     1      -3.465 -12.399 -11.363  1.00 83.41
+ATOM      6  CA  SER     1      -2.108 -12.709 -10.900  1.00 57.67
+END
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/test2dui.py b/sas/sascalc/simulation/pointsmodelpy/tests/test2dui.py
new file mode 100755
index 000000000..6d076d79e
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/test2dui.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python
+"""
+Demonstration of drawing a 2D image plot using the "hot" colormap
+"""
+
+#--------------------------------------------------------------------------------
+# Imports:
+#--------------------------------------------------------------------------------
+from __future__ import print_function
+
+import wx
+
+from enthought.traits import Any, Instance
+from enthought.enable.wx import Window
+from enthought.pyface import ApplicationWindow, GUI
+from enthought.util.numerix import pi, concatenate, array, zeros, ones, \
+     arange, resize, ravel
+from enthought.util.numerix import Float as NumericFloat
+from math import sqrt, sin
+
+from enthought.chaco.plot_component import PlotComponent
+from enthought.chaco.plot_axis import PlotAxis
+from enthought.chaco.plot_canvas import PlotCanvas
+from enthought.chaco.plot_group import PlotGroup
+from enthought.chaco.image_plot_value import ImageData, CmapImagePlotValue
+from enthought.chaco.colormap import LinearColormap
+from enthought.chaco.colormap_legend import ColormapLegend
+from enthought.chaco.default_colormaps import hot, gray
+from enthought.chaco.demo.demo_base import PlotApplicationWindow
+
+
+class ImagePlotApplicationWindow( PlotApplicationWindow ):
+
+    ###########################################################################
+    # PlotApplicationWindow interface.
+    ###########################################################################
+
+    def _create_plot( self ):
+        """ Create the plot to be displayed. """
+
+        # Create the image data and the index values
+        #value_grid = zeros((100,100), NumericFloat)
+        from testlores2d import get2d_2
+        value_grid = get2d_2()
+        #self._compute_function(value_grid)
+        index_vals = (arange(value_grid.shape[0]), arange(value_grid.shape[1]))
+
+        data = ImageData(value_grid, index_vals)
+        print(value_grid, index_vals)
+
+        # Create the index axes
+        xaxis = PlotAxis(tick_visible=False, grid_visible=False)
+        #                bound_low = index_vals[0][0], bound_high = index_vals[0][-1])
+        yaxis = PlotAxis(tick_visible=False, grid_visible=False)
+        #                bound_low = index_vals[1][0], bound_high = index_vals[1][-1])
+        xaxis.visible = False
+        yaxis.visible = False
+
+        # Create the value axis (i.e. colormap)
+        cmap = hot(0,1)
+
+        # Create the Image PlotValue
+#        image = CmapImagePlotValue(data, cmap, axis_index = xaxis, axis = yaxis, type='image')
+        image = CmapImagePlotValue(data, cmap, type='image')
+        image.weight = 10
+
+        cmap_legend = ColormapLegend(cmap, margin_width=31, margin_height=31)
+        cmap_legend.weight = 0.4
+
+        group = PlotGroup(cmap_legend, image, orientation='horizontal')
+
+        return group
+
+    ###########################################################################
+    # Private interface.
+    ###########################################################################
+
+    def _compute_function(self, ary):
+        "Fills in ary with the sin(r)/r function"
+
+        width, height = ary.shape
+        for i in range(width):
+            for j in range(height):
+                x = i - width / 2.0
+                x = x / (width/2.0) * 15
+                y = j - height / 2.0
+                y = y / (height/2.0) * 15
+
+                radius = sqrt(x*x + y*y)
+                if radius == 0.0:
+                    ary[i,j] = 1
+                else:
+                    ary[i,j] = sin(radius) / radius
+
+        return
+
+
+def main():
+
+    # Create the GUI (this does NOT start the GUI event loop).
+    gui = GUI()
+
+    # Screen size:
+    screen_width = gui.system_metrics.screen_width or 1024
+    screen_height = gui.system_metrics.screen_height or 768
+
+    # Create and open the main window.
+    window = ImagePlotApplicationWindow( title = "Plot" )
+    #window.plot_item = object
+    window.size = ( 2 * screen_width / 3, 2 * screen_height / 3 )
+    window.open()
+
+    # Start the GUI event loop.
+    gui.start_event_loop()
+
+
+#===============================================================================
+# Program start-up:
+#===============================================================================
+
+if __name__ == '__main__':
+    main()
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/testcomplexmodel.cc b/sas/sascalc/simulation/pointsmodelpy/tests/testcomplexmodel.cc
new file mode 100755
index 000000000..27f180349
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/testcomplexmodel.cc
@@ -0,0 +1,41 @@
+#include "complex_model.h"
+#include "lores_model.h"
+#include "sphere.h"
+#include "pdb_model.h"
+#include <vector>
+#include "Point3D.h"
+#include <iostream>
+
+int main(){
+
+    //Get a few points from a lores model
+    vector<Point3D> vplores;
+    LORESModel lm(0.1);
+    Sphere s1(10);
+    lm.Add(s1,1);
+    lm.GetPoints(vplores);
+
+    //get a few points from a pdb model
+    vector<Point3D> vppdb;
+    string pdbfile("ff0.pdb");
+    PDBModel p1;
+    p1.AddPDB(pdbfile);
+    p1.GetPoints(vppdb);
+    //should be an error: the argument vector has to be empty
+    //p1.GetPoints(vppdb);
+
+    //merge points
+    vector<Point3D> vptotal;
+    ComplexModel cm;
+    cm.Add(&lm);
+    cm.Add(&p1);
+    cout << "boundary" << cm.GetDimBound() << endl;
+    cm.GetPoints(vptotal);
+
+    return 0;
+}
+
+// version
+// $Id$
+
+// End of file
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/testlores.cc b/sas/sascalc/simulation/pointsmodelpy/tests/testlores.cc
new file mode 100755
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/testlores.cc
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include "lores_model.h"
+#include "sphere.h"
+#include "cylinder.h"
+#include "ellipsoid.h"
+#include "Point3D.h"
+
+using namespace std;
+
+void test_calculateIQ(LORESModel &lm);
+
+void WritePointsCoor(vector<Point3D> &vp){
+    ofstream outfile("testcc.coor");
+    for(size_t i=0; i<vp.size(); ++i){
+        outfile << vp[i] << endl;
+    }
+}
+
+void test_lores(){
+    LORESModel lm(0.1);
+    Sphere s1(10);
+    lm.Add(s1,1.0);
+
+    vector<Point3D> vp;
+    lm.GetPoints(vp);
+    WritePointsCoor(vp);
+    cout << "vp size:" << vp.size() << endl;
+
+    //for(vector<Point3D>::iterator iter = vp.begin();
+    //    iter != vp.end(); ++iter){
+    //    cout << *iter << endl;
+    //}
+
+    lm.DistDistribution(vp);
+    Array2D<double> pr(lm.GetPr());
+    //for(int i = 0; i< pr.dim1(); ++i)
+    //    cout << pr[i][0] << " " << pr[i][1] << " " << pr[i][2] << endl;
+    lm.OutputPR("test.pr");
+    cout << "pass ddfunction, and print out the pr file" << endl;
+}
+
+void test_lores_xy(){
+    LORESModel lm(0.1);
+    Cylinder c1(5,20);
+    c1.SetOrientation(10,20,30);
+    lm.Add(c1,1.0);
+
+    vector<Point3D> vp;
+    lm.GetPoints(vp);
+
+    lm.DistDistributionXY(vp);
+    lm.OutputPR_XY("test2d.pr");
+
+    IQ iq(10,0.001,0.3);
+    lm.CalculateIQ_2D(&iq,10);
+
+    iq.OutputIQ("test2d.iq");
+}
+
+void test_lores2d_qxqy(){
+
+    LORESModel lm(0.1);
+
+    Cylinder c1(5,20);
+    c1.SetCenter(0,0,0);
+    c1.SetOrientation(10,20,30);
+    lm.Add(c1,1.0);
+
+    vector<Point3D> vp;
+    lm.GetPoints(vp);
+
+    lm.DistDistributionXY(vp);
+
+    double aI = lm.CalculateIQ_2D(0.1,0.2);
+
+    cout << " a single I is: " << aI << endl;
+}
+
+void test_center(){
+    LORESModel lm(0.1);
+
+    Sphere s1(10);
+    s1.SetCenter(-20,0,0);
+    lm.Add(s1,1.0);
+
+    Sphere s2(10);
+    s2.SetCenter(20,0,0);
+    lm.Add(s2,1.0);
+
+    vector<double> center(3);
+    center = lm.GetCenter();
+
+    cout << "center should be (0,0,0) after adding two spheres:" << center[0] << " " << center[1] << " " << center[2] << endl;
+}
+
+void test_calculateIQ(LORESModel &lm){
+    vector<Point3D> vp;
+    lm.GetPoints(vp);
+
+    lm.DistDistribution(vp);
+
+    double result = lm.CalculateIQ(0.1);
+    cout << "The I(0.1) is: " << result << endl;
+}
+
+int main(){
+
+    printf("this\n");
+    cout << "Start" << endl;
+    //test_lores();
+    cout << "testing DistDistributionXY" << endl;
+    test_lores_xy();
+
+    return 0;
+}
+
+// version
+// $Id$
+
+// End of file
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/testlores2d.py b/sas/sascalc/simulation/pointsmodelpy/tests/testlores2d.py
new file mode 100755
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/testlores2d.py
+from __future__ import print_function
+
+def test_lores2d(phi):
+    from enthought.util.numerix import Float,zeros
+    from sasModeling.pointsmodelpy import pointsmodelpy
+    from sasModeling.geoshapespy import geoshapespy
+
+    lm = pointsmodelpy.new_loresmodel(0.1)
+    cyn = geoshapespy.new_cylinder(5,20)
+    geoshapespy.set_orientation(cyn,0,0,phi)
+    pointsmodelpy.lores_add(lm,cyn,1.0)
+
+    vp = pointsmodelpy.new_point3dvec()
+    pointsmodelpy.get_lorespoints(lm,vp)
+
+    pointsmodelpy.distdistribution_xy(lm,vp)
+
+    value_grid = zeros((100,100),Float)
+    width, height = value_grid.shape
+
+    Imax = 0
+    for i in range(width):
+        for j in range(height):
+            qx = float(i-50)/200.0
+            qy = float(j-50)/200.0
+            value_grid[i][j] = pointsmodelpy.calculateI_Qxy(lm,qx,qy)
+            if value_grid[i][j] > Imax:
+                Imax = value_grid[i][j]
+
+    for i in range(width):
+        for j in range(height):
+            value_grid[i][j] = value_grid[i][j]/Imax
+
+    value_grid[50,50] = 1
+    return value_grid
+
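+# NOTE: in get2d_2 below, grid pixel (i,j) maps to reciprocal space as
+# qx = (i-50)/200 and qy = (j-50)/200, i.e. a 100x100 grid spanning roughly
+# -0.25..0.25 A^-1 in each direction; the intensity is then normalized by
+# its maximum, and the q~0 pixel is pinned to 1.
+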
+def get2d_2():
+    from math import pi
+    from Numeric import arange,zeros
+    from enthought.util.numerix import Float,zeros
+    from sasModeling.file2array import readfile2array
+    from sasModeling.pointsmodelpy import pointsmodelpy
+    from sasModeling.geoshapespy import geoshapespy
+
+    lm = pointsmodelpy.new_loresmodel(0.1)
+    cyn = geoshapespy.new_cylinder(5,20)
+    geoshapespy.set_orientation(cyn,0,0,90)
+    pointsmodelpy.lores_add(lm,cyn,1.0)
+
+    vp = pointsmodelpy.new_point3dvec()
+    pointsmodelpy.get_lorespoints(lm,vp)
+
+    pointsmodelpy.distdistribution_xy(lm,vp)
+
+    value_grid = zeros((100,100),Float)
+    width, height = value_grid.shape
+    print(width,height)
+
+    I = pointsmodelpy.calculateI_Qxy(lm,0.00001,0.000002)
+    print(I)
+
+    Imax = 0
+    for i in range(width):
+        for j in range(height):
+            qx = float(i-50)/200.0
+            qy = float(j-50)/200.0
+            value_grid[i,j] = pointsmodelpy.calculateI_Qxy(lm,qx,qy)
+            if value_grid[i][j] > Imax:
+                Imax = value_grid[i][j]
+
+    for i in range(width):
+        for j in range(height):
+            value_grid[i][j] = value_grid[i][j]/Imax
+
+    value_grid[50,50] = 1
+    return value_grid
+
+if __name__ == "__main__":
+
+    print("start to test lores 2D")
+#    test_lores2d(10)
+    value_grid = get2d_2()
+    print(value_grid)
+    print("pass")
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/testnegativepr.py b/sas/sascalc/simulation/pointsmodelpy/tests/testnegativepr.py
new file mode 100755
index 000000000..3476e989c
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/testnegativepr.py
@@ -0,0 +1,16 @@
+if __name__ == "__main__":
+    from sasModeling.pointsmodelpy import pointsmodelpy
+    from sasModeling.iqPy import iqPy
+    from sasModeling.geoshapespy import geoshapespy
+
+    a = geoshapespy.new_sphere(10)
+    lm = pointsmodelpy.new_loresmodel(0.0005)
+    pointsmodelpy.lores_add(lm,a,1.0)
+    b = geoshapespy.new_sphere(20)
+    geoshapespy.set_center(b,20,20,20)
+    pointsmodelpy.lores_add(lm,b,-1.0)
+
+    vp = pointsmodelpy.new_point3dvec()
+    pointsmodelpy.get_lorespoints(lm,vp)
+
+    pointsmodelpy.get_lores_pr(lm,vp)
diff --git a/sas/sascalc/simulation/pointsmodelpy/tests/testpdb.cc b/sas/sascalc/simulation/pointsmodelpy/tests/testpdb.cc
new file mode 100755
index 000000000..b4d93d2d8
--- /dev/null
+++ b/sas/sascalc/simulation/pointsmodelpy/tests/testpdb.cc
@@ -0,0 +1,42 @@
+#include "pdb_model.h"
+#include "Point3D.h"
+#include <vector>
+#include <iostream>
+#include "iq.h"
+
+typedef vector<Point3D> PointsVector;
+
+int main(){
+
+    //Test PDB with adding one normal pdb file
+    PointsVector vp;
+    IQ iq(100,0.001,0.3);
+
+    string name("test.pdb");
+
+    //initialize a PDBModel with a pdb file name
+    PDBModel mpdb;
+    mpdb.AddPDB(name);
+    //parse coordinates in pdb file into a vector
+    //of points
+    mpdb.GetPoints(vp);
+    //output the vector of points into a pseudo pdb
+    //file, for testing only
+    mpdb.OutputPDB(vp,"testpdb.pdb");
+    mpdb.DistDistribution(vp);
+    mpdb.OutputPR("test.pr");
+    mpdb.CalculateIQ(&iq);
+    iq.OutputIQ("test.iq");
+    cout << "check test.pr and test.iq files" << endl;
+
+    //test adding a second pdb file
+    vector<Point3D> vp2;
+    string second("test2.pdb");
+    mpdb.AddPDB(second);
+    mpdb.GetPoints(vp2);
+    mpdb.OutputPDB(vp2,"testpdb2.pdb");
+    cout << "check whether testpdb2.pdb has 12 points" << endl;
+
+    return 0;
+}
+
+// version
+// $Id$
+
+// End of file
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
-        install_requires=['docutils', 'scipy>=0.17,<1.3'],
+        install_requires=['docutils', 'scipy>=0.17,<1.3', 'bumps', 'sasmodels', 'lxml'],
         zip_safe=False,
         cmdclass={
             'externals': InstallExternals,