Skip to content

Commit

Permalink
Refs #7860 Refactor code with numpy and slicing.
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Jackson committed Apr 14, 2014
1 parent bf80fa0 commit 01ee9ec
Showing 1 changed file with 58 additions and 68 deletions.
Expand Up @@ -5,11 +5,15 @@
*WIKI*"""

from mantid import config, logger, mtd
from mantid import logger, mtd
from mantid.api import PythonAlgorithm, AlgorithmFactory, WorkspaceProperty
from mantid.kernel import StringListValidator, StringMandatoryValidator, Direction
from mantid.kernel import Direction
from mantid.simpleapi import *
import sys, platform, math, os.path, numpy as np

import math
import os.path
import numpy as np


class Symmetrise(PythonAlgorithm):

Expand All @@ -35,80 +39,52 @@ def PyExec(self):
StartTime('Symmetrise')
self._setup()
num_spectra, npt = CheckHistZero(self._sample)

sample_x = mtd[self._sample].readX(0)

if math.fabs(self._x_cut) < 1e-5:
raise ValueError('XCut point is Zero')

delta_x = sample_x[1]-sample_x[0]
# diff = np.absolute(sample_x - self._x_cut)
# ineg = np.where(diff < delta_x)[0]

for n in range(npt):
x = sample_x[n]-self._x_cut
if math.fabs(x) < delta_x:
ineg = n

if ineg <= 0:
error = 'Negative point('+str(ineg)+') < 0'
logger.notice('ERROR *** ' + error)
sys.exit(error)

if ineg >= npt:
error = type + 'Negative point('+str(ineg)+') > '+str(npt)
logger.notice('ERROR *** ' + error)
sys.exit(error)

for n in range(npt):
x = sample_x[n]+sample_x[ineg]
if math.fabs(x) < delta_x:
ipos = n

if ipos <= 0:
error = 'Positive point('+str(ipos)+') < 0'
logger.notice('ERROR *** ' + error)
sys.exit(error)

if ipos >= npt:
error = type + 'Positive point('+str(ipos)+') > '+str(npt)
logger.notice('ERROR *** ' + error)
sys.exit(error)

negative_diff = np.absolute(sample_x - self._x_cut)
ineg = np.where(negative_diff < delta_x)[0][-1]
self._check_bounds(ineg, npt, label='Negative')

positive_diff = np.absolute(sample_x + sample_x[ineg])
ipos = np.where(positive_diff < delta_x)[0][-1]
self._check_bounds(ipos, npt, label='Positive')

ncut = npt-ipos+1

if self._verbose:
logger.notice('No. points = '+str(npt))
logger.notice('Negative : at i ='+str(ineg)+' ; x = '+str(sample_x[ineg]))
logger.notice('Positive : at i ='+str(ipos)+' ; x = '+str(sample_x[ipos]))
logger.notice('Copy points = '+str(xcut))
logger.notice('No. points = %d' % npt)
logger.notice('Negative : at i =%d; x = %f' % (ineg, sample_x[ineg]))
logger.notice('Positive : at i =%d; x = %f' % (ipos, sample_x[ipos]))
logger.notice('Copy points = %d' % ncut)

for m in range(num_spectra):
sample_x = mtd[self._sample].readX(m)
Yin = mtd[self._sample].readY(m)
Ein = mtd[self._sample].readE(m)
Xout = []
Yout = []
Eout = []
for n in range(0,ncut):
icut = npt-n-1
Xout.append(-sample_x[icut])
Yout.append(Yin[icut])
Eout.append(Ein[icut])
for n in range(ncut,npt):
Xout.append(sample_x[n])
Yout.append(Yin[n])
Eout.append(Ein[n])

if m == 0:
CreateWorkspace(OutputWorkspace=self._output_workspace, DataX=Xout, DataY=Yout, DataE=Eout,
Nspec=1, UnitX='DeltaE')
else:
CreateWorkspace(OutputWorkspace='__tmp', DataX=Xout, DataY=Yout, DataE=Eout,
Nspec=1, UnitX='DeltaE')
ConjoinWorkspaces(InputWorkspace1=self._output_workspace, InputWorkspace2='__tmp',CheckOverlapping=False)
CloneWorkspace(InputWorkspace=self._sample, OutputWorkspace=self._output_workspace)

for m in xrange(num_spectra):
x = mtd[self._sample].readX(m)
y = mtd[self._output_workspace].readY(m)
e = mtd[self._output_workspace].readE(m)

x_out = np.zeros(x.size)
y_out = np.zeros(y.size)
e_out = np.zeros(e.size)

x_out[:ncut] = -x[npt:npt-ncut:-1]
y_out[:ncut] = y[npt:npt-ncut-1:-1]
e_out[:ncut] = e[npt:npt-ncut-1:-1]

x_out[ncut:] = x[ncut:]
y_out[ncut:] = y[ncut:]
e_out[ncut:] = e[ncut:]

mtd[self._output_workspace].setX(m, np.asarray(x_out))
mtd[self._output_workspace].setY(m, np.asarray(y_out))
mtd[self._output_workspace].setE(m, np.asarray(e_out))

if self._save:
workdir = getDefaultWorkingDirectory()
file_path = os.path.join(workdir,self._output_workspace+'.nxs')
Expand All @@ -123,7 +99,6 @@ def PyExec(self):
self.setProperty("OutputWorkspace", self._output_workspace)
EndTime('Symmetrise')


def _setup(self):
"""
Get the algorithm properties.
Expand All @@ -135,14 +110,29 @@ def _setup(self):
self._plot = self.getProperty('Plot').value
self._save = self.getProperty('Save').value

self._output_workspace = self.getPropertyValue('OutputWorkspace')
self._output_workspace = self.getPropertyValue('OutputWorkspace')

def _check_bounds(self, index, num_pts, label=''):
"""
Check if the index falls within the bounds of the x range.
Throws a ValueError if the x point falls outside of the range.
@param index - value of the index within the x range.
@param num_pts - total number of points in the range.
@param label - label to call the point if an error is thrown.
"""
if index <= 0:
raise ValueError('%s point %d < 0' % (label, index))
elif index >= num_pts:
raise ValueError('%s point %d > %d' % (label, index, num_pts))

def _plotSymmetrise(self):
"""
Plot the first spectrum of the input and output workspace together
"""
from IndirectImport import import_mantidplot
mp = import_mantidplot()
tot_plot = mp.plotSpectrum([self._output_workspace, self._sample],0)
mp.plotSpectrum([self._output_workspace, self._sample],0)

AlgorithmFactory.subscribe(Symmetrise) # Register algorithm with Mantid
# Register algorithm with Mantid
AlgorithmFactory.subscribe(Symmetrise)

0 comments on commit 01ee9ec

Please sign in to comment.