diff --git a/Code/Mantid/Framework/Crystal/inc/MantidCrystal/ConnectedComponentLabeling.h b/Code/Mantid/Framework/Crystal/inc/MantidCrystal/ConnectedComponentLabeling.h index 830be6e8d635..21075b2ae3b1 100644 --- a/Code/Mantid/Framework/Crystal/inc/MantidCrystal/ConnectedComponentLabeling.h +++ b/Code/Mantid/Framework/Crystal/inc/MantidCrystal/ConnectedComponentLabeling.h @@ -3,12 +3,25 @@ #include "MantidKernel/System.h" #include "MantidAPI/IMDHistoWorkspace.h" +#include "MantidCrystal/DisjointElement.h" #include +#include +#include namespace Mantid { namespace Crystal { + namespace ConnectedComponentMappingTypes + { + typedef boost::tuple SignalErrorSQPair; + typedef std::map LabelIdIntensityMap; + typedef std::map PositionToLabelIdMap; + typedef std::vector VecIndexes; + typedef std::vector VecElements; + typedef std::set SetIds; + } + class BackgroundStrategy; /** ConnectedComponentLabelling : Implements connected component labeling on MDHistoWorkspaces. @@ -35,15 +48,33 @@ namespace Crystal */ class DLLExport ConnectedComponentLabeling { + public: + /// Constructor ConnectedComponentLabeling(const size_t&id = 1, const bool runMultiThreaded=true); + /// Getter for the start label id size_t getStartLabelId() const; + /// Setter for the label id void startLabelingId(const size_t& id); + + /// Execute and return clusters boost::shared_ptr execute(Mantid::API::IMDHistoWorkspace_sptr ws, BackgroundStrategy * const strategy) const; + + /// Execute and return clusters, as well as maps to integrated label values + boost::shared_ptr executeAndIntegrate( + Mantid::API::IMDHistoWorkspace_sptr ws, BackgroundStrategy * const strategy, ConnectedComponentMappingTypes::LabelIdIntensityMap& labelMap, + ConnectedComponentMappingTypes::PositionToLabelIdMap& positionLabelMap) const; + + /// Destructor virtual ~ConnectedComponentLabeling(); private: + /// Get the number of threads to use. int getNThreads() const; + /// Calculate the disjoint element tree across the image. + void calculateDisjointTree(Mantid::API::IMDHistoWorkspace_sptr ws, BackgroundStrategy * const strategy, std::vector& neighbourElements) const; + /// Start labeling index size_t m_startId; + /// Run multithreaded const bool m_runMultiThreaded; }; diff --git a/Code/Mantid/Framework/Crystal/src/ConnectedComponentLabeling.cpp b/Code/Mantid/Framework/Crystal/src/ConnectedComponentLabeling.cpp index 394ae4a2f9ae..a7506c1b3409 100644 --- a/Code/Mantid/Framework/Crystal/src/ConnectedComponentLabeling.cpp +++ b/Code/Mantid/Framework/Crystal/src/ConnectedComponentLabeling.cpp @@ -1,4 +1,5 @@ #include "MantidKernel/MultiThreaded.h" +#include "MantidKernel/V3D.h" #include "MantidAPI/IMDHistoWorkspace.h" #include "MantidAPI/AlgorithmManager.h" #include "MantidAPI/FrameworkManager.h" @@ -7,10 +8,14 @@ #include "MantidCrystal/BackgroundStrategy.h" #include "MantidCrystal/DisjointElement.h" #include +#include +#include #include #include using namespace Mantid::API; +using namespace Mantid::Kernel; +using namespace Mantid::Crystal::ConnectedComponentMappingTypes; namespace Mantid { @@ -18,10 +23,6 @@ namespace Mantid { namespace { - typedef std::vector VecIndexes; - typedef std::vector VecElements; - typedef std::set SetIds; - size_t calculateMaxNeighbours(IMDHistoWorkspace const * const ws) { const size_t ndims = ws->getNumDims(); @@ -33,6 +34,22 @@ namespace Mantid maxNeighbours -= 1; return maxNeighbours; } + + boost::shared_ptr cloneInputWorkspace(IMDHistoWorkspace_sptr& inWS) + { + auto alg = AlgorithmManager::Instance().createUnmanaged("CloneWorkspace"); + alg->initialize(); + alg->setChild(true); + alg->setProperty("InputWorkspace", inWS); + alg->setPropertyValue("OutputWorkspace", "out_ws"); + alg->execute(); + Mantid::API::IMDHistoWorkspace_sptr outWS; + { + Mantid::API::Workspace_sptr temp = alg->getProperty("OutputWorkspace"); + outWS = boost::dynamic_pointer_cast(temp); + } + return outWS; + } } //---------------------------------------------------------------------------------------------- @@ -73,25 +90,9 @@ namespace Mantid return m_runMultiThreaded ? API::FrameworkManager::Instance().getNumOMPThreads() : 1; } - boost::shared_ptr ConnectedComponentLabeling::execute( - IMDHistoWorkspace_sptr ws, BackgroundStrategy * const strategy) const + void ConnectedComponentLabeling::calculateDisjointTree(IMDHistoWorkspace_sptr ws, BackgroundStrategy * const strategy, VecElements& neighbourElements) const { - auto alg = AlgorithmManager::Instance().createUnmanaged("CloneWorkspace"); - alg->initialize(); - alg->setChild(true); - alg->setProperty("InputWorkspace", ws); - alg->setPropertyValue("OutputWorkspace", "out_ws"); - alg->execute(); - - Mantid::API::IMDHistoWorkspace_sptr out_ws; - { - Mantid::API::Workspace_sptr temp = alg->getProperty("OutputWorkspace"); - out_ws = boost::dynamic_pointer_cast(temp); - } - - VecElements neighbourElements(ws->getNPoints()); - VecIndexes allNonBackgroundIndexes; allNonBackgroundIndexes.reserve(ws->getNPoints()); @@ -202,6 +203,18 @@ namespace Mantid } } + } + + boost::shared_ptr ConnectedComponentLabeling::execute( + IMDHistoWorkspace_sptr ws, BackgroundStrategy * const strategy) const + { + VecElements neighbourElements(ws->getNPoints()); + + // Perform the bulk of the connected component analysis, but don't collapse the elements yet. + calculateDisjointTree(ws, strategy, neighbourElements); + + // Create the output workspace from the input workspace + IMDHistoWorkspace_sptr outWS = cloneInputWorkspace(ws); // Set each pixel to the root of each disjointed element. PARALLEL_FOR_NO_WSP_CHECK() @@ -210,16 +223,66 @@ namespace Mantid //std::cout << "Element\t" << i << " Id: \t" << neighbourElements[i].getId() << " This location:\t"<< &neighbourElements[i] << " Root location:\t" << neighbourElements[i].getParent() << " Root Id:\t" << neighbourElements[i].getRoot() << std::endl; if(!neighbourElements[i].isEmpty()) { - out_ws->setSignalAt(i, neighbourElements[i].getRoot()); + outWS->setSignalAt(i, neighbourElements[i].getRoot()); + } + else + { + outWS->setSignalAt(i, 0); + } + outWS->setErrorSquaredAt(i, 0); + } + + return outWS; + } + + boost::shared_ptr ConnectedComponentLabeling::executeAndIntegrate( + IMDHistoWorkspace_sptr ws, BackgroundStrategy * const strategy, LabelIdIntensityMap& labelMap, + PositionToLabelIdMap& positionLabelMap) const + { + VecElements neighbourElements(ws->getNPoints()); + + // Perform the bulk of the connected component analysis, but don't collapse the elements yet. + calculateDisjointTree(ws, strategy, neighbourElements); + + // Create the output workspace from the input workspace + IMDHistoWorkspace_sptr outWS = cloneInputWorkspace(ws); + + // Set each pixel to the root of each disjointed element. + for (size_t i = 0; i < neighbourElements.size(); ++i) + { + if(!neighbourElements[i].isEmpty()) + { + const double& signal = ws->getSignalAt(i); // Intensity value at index + double errorSQ = ws->getErrorAt(i); + errorSQ *=errorSQ; // Error squared at index + const size_t& labelId = neighbourElements[i].getRoot(); + // Set the output cluster workspace signal value + outWS->setSignalAt(i, labelId); + + if(labelMap.find(labelId) != labelMap.end()) // Have we already started integrating over this label + { + SignalErrorSQPair current = labelMap[labelId]; + // Sum labels. This is integration! + labelMap[labelId] = SignalErrorSQPair(current.get<0>() + signal, current.get<1>() + errorSQ); + } + else // This label is unknown to us. + { + labelMap[labelId] = SignalErrorSQPair(signal, errorSQ); + + const VMD& center = ws->getCenter(i); + V3D temp(center[0], center[1], center[2]); + positionLabelMap[temp] = labelId; //Record charcteristic position of the cluster. + } + outWS->setSignalAt(i, neighbourElements[i].getRoot()); } else { - out_ws->setSignalAt(i, 0); + outWS->setSignalAt(i, 0); + outWS->setErrorSquaredAt(i, 0); } - out_ws->setErrorSquaredAt(i, 0); } - return out_ws; + return outWS; } } // namespace Crystal diff --git a/Code/Mantid/Framework/Crystal/src/IntegratePeaksUsingClusters.cpp b/Code/Mantid/Framework/Crystal/src/IntegratePeaksUsingClusters.cpp index 0a500d33973e..c5e314c0287f 100644 --- a/Code/Mantid/Framework/Crystal/src/IntegratePeaksUsingClusters.cpp +++ b/Code/Mantid/Framework/Crystal/src/IntegratePeaksUsingClusters.cpp @@ -26,12 +26,11 @@ Uses connected component analysis to integrate peaks in an PeaksWorkspace over a using namespace Mantid::API; using namespace Mantid::Kernel; using namespace Mantid::DataObjects; +using namespace Mantid::Crystal::ConnectedComponentMappingTypes; namespace { - typedef boost::tuple SignalErrorSQPair; - typedef std::map LabelIdIntensityMap; - typedef std::map PositionToLabelIdMap; + class IsNearPeak { @@ -145,44 +144,10 @@ namespace Mantid PeakBackground background(peakWS, radiusEstimate, threshold, NoNormalization, mdCoordinates); //HardThresholdBackground background(threshold, normalization); - ConnectedComponentLabeling analysis; - IMDHistoWorkspace_sptr clusters = analysis.execute(mdWS, &background); - - /* - Note that the following may be acheived better inside the clustering utility at the same time as the - cluster workspace is populated. - - Accumulate intesity values for each peak cluster and key by label_id - */ + ConnectedComponentLabeling analysis; LabelIdIntensityMap labelMap; PositionToLabelIdMap positionMap; - - // Go through the output workspace and perform the integration. by summing labels. - for(size_t i = 0; i < clusters->getNPoints(); ++i) - { - const size_t& label_id = static_cast(clusters->getSignalAt(i)); - - const double& signal = mdWS->getSignalAt(i); - double errorSQ = mdWS->getErrorAt(i); - errorSQ *=errorSQ; - if(label_id >= analysis.getStartLabelId()) - { - if(labelMap.find(label_id) != labelMap.end()) - { - SignalErrorSQPair current = labelMap[label_id]; - // Sum labels. - labelMap[label_id] = SignalErrorSQPair(current.get<0>() + signal, current.get<1>() + errorSQ); - } - else - { - labelMap[label_id] = SignalErrorSQPair(signal, errorSQ); - - const VMD& center = mdWS->getCenter(i); - V3D temp(center[0], center[1], center[2]); - positionMap[temp] = label_id; //Record charcteristic position of the cluster. - } - } - } + IMDHistoWorkspace_sptr clusters = analysis.executeAndIntegrate(mdWS, &background, labelMap, positionMap); // Link integrated values up with peaks. PARALLEL_FOR1(peakWS)