From 2f9a48c71dc0a6097498cb7e90ac3b151ab536dd Mon Sep 17 00:00:00 2001 From: Frank Seide Date: Fri, 5 Feb 2016 11:06:20 -0800 Subject: [PATCH] InputNodes() now skips inputs that are only reachable through PreComputeNodes that have already been computed, addressing Issue #65; cleaned up some unnecessary NULL checks before delete --- Source/CNTK/NetworkDescriptionLanguage.h | 6 - .../ComputationNetwork.cpp | 108 ++++++++---------- .../ComputationNetwork.h | 6 +- .../ComputationNetworkLib/ComputationNode.h | 16 ++- .../ComputationNetworkLib/PreComputeNodes.h | 6 +- Source/Readers/Kaldi2Reader/HTKMLFReader.cpp | 84 +++----------- .../KaldiSequenceTrainingDerivative.cpp | 12 +- Source/SGDLib/SGD.cpp | 21 +--- 8 files changed, 91 insertions(+), 168 deletions(-) diff --git a/Source/CNTK/NetworkDescriptionLanguage.h b/Source/CNTK/NetworkDescriptionLanguage.h index f3dbbdaedd35..7d70651208af 100644 --- a/Source/CNTK/NetworkDescriptionLanguage.h +++ b/Source/CNTK/NetworkDescriptionLanguage.h @@ -470,9 +470,7 @@ class NDLScript : public ConfigParser { // need to free all the child nodes attached to this script node for (NDLNode* node : m_children) - { delete node; - } m_children.clear(); } @@ -576,14 +574,10 @@ class NDLScript : public ConfigParser { for (NDLNode* node : m_children) - { delete node; - } m_children.clear(); for (NDLNode* node : m_script) - { delete node; - } m_script.clear(); m_symbols.clear(); diff --git a/Source/ComputationNetworkLib/ComputationNetwork.cpp b/Source/ComputationNetworkLib/ComputationNetwork.cpp index e0bdfd1cf2f7..211fba46125e 100644 --- a/Source/ComputationNetworkLib/ComputationNetwork.cpp +++ b/Source/ComputationNetworkLib/ComputationNetwork.cpp @@ -383,82 +383,74 @@ bool ComputationNetwork::IsTypicalCriterionNode(ComputationNodeBasePtr nodePtr) return false; } -template -void ComputationNetwork::GetNodesRequiringX(list& nodesRequiringX, const ComputationNodeBasePtr& rootNode, bool checkComputed) +// return list of nodes that require 
precomputation and not precomputed yet +list ComputationNetwork::GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode, bool checkComputed) { - if (!rootNode) // find nodes from all available nodes + list nodes; + for (const auto& node : GetEvalOrder(rootNode)) { - for (const auto& nodep : m_nameToNodeMap) + auto pcnode = dynamic_pointer_cast(node); + if (pcnode) { - auto node = dynamic_pointer_cast(nodep.second); - if (node) - { - assert(node->RequiresPreCompute()); - if (!checkComputed || !node->HasComputed()) - nodesRequiringX.push_back(node); - } + assert(node->RequiresPreCompute()); + if (!checkComputed || !pcnode->HasComputed()) + nodes.push_back(node); } } - else // or for calculating a specific node - { - for (const auto& nodei : GetEvalOrder(rootNode)) - { - auto node = dynamic_pointer_cast(nodei); - if (node) - { - assert(node->RequiresPreCompute()); - if (!checkComputed || !node->HasComputed()) - nodesRequiringX.push_back(node); - } - } - } - nodesRequiringX.unique(); -} - -// return list of nodes that require precomputation and not precomputed yet -list ComputationNetwork::GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode, bool checkComputed) -{ - list nodesRequiringX; - GetNodesRequiringX>(nodesRequiringX, rootNode, checkComputed); - GetNodesRequiringX>(nodesRequiringX, rootNode, checkComputed); - return nodesRequiringX; + return nodes; } // create the m_inputValues[] and m_learnableParameters[] lists +// This enumerates all leaves reachable from rootNode. +// Leaves are: +// - inputs +// - learnable parameters +// It does not traverse disabled ones, i.e. 
+// - inputs that are only reachable through PrecomputeNodes that have completed computation +// - learnable parameters that are constants void ComputationNetwork::CollectInputAndLearnableParameters(const ComputationNodeBasePtr& rootNode) { assert(m_inputValues.find(rootNode) == m_inputValues.end()); // this function must only be called once assert(m_learnableParameters.find(rootNode) == m_learnableParameters.end()); - const list& nodes = GetEvalOrder(rootNode); + // gather the lists + set visited; + list inputs, learnableParameters; + if (rootNode) + CollectInputAndLearnableParametersRec(rootNode, visited, inputs, learnableParameters); + else + for (const auto& root : m_allRoots) + CollectInputAndLearnableParametersRec(root, visited, inputs, learnableParameters); - // collect input values for given root - // Note: This will not return nodes that are reached through a PrecomputeNode that has already been computed. - list inputs; - for (const auto& node : nodes) + // sort learnable parameters by name so that we get consistent order when loading them from a saved file + learnableParameters.sort([](const ComputationNodeBasePtr& a, const ComputationNodeBasePtr& b) { - if (node->OperationName() == OperationNameOf(InputValue) || node->OperationName() == OperationNameOf(SparseInputValue)) - inputs.push_back(node); - } - m_inputValues[rootNode] = inputs; + return a->NodeName() < b->NodeName(); + }); + + m_inputValues[rootNode] = move(inputs); + m_learnableParameters[rootNode] = move(learnableParameters); +} - // instead of collecting the nodes themselves, collect the names (they will be sorted below) - list learnableParameterNames; - for (auto nodeIter = nodes.begin(); nodeIter != nodes.end(); nodeIter++) +void ComputationNetwork::CollectInputAndLearnableParametersRec(const ComputationNodeBasePtr& node, set& visited, list& inputs, list& learnableParameters) +{ + if (visited.find(node) != visited.end()) // already got this one + return; + else if (node->OperationName() == 
OperationNameOf(InputValue) || node->OperationName() == OperationNameOf(SparseInputValue)) + inputs.push_back(node); + else if (node->OperationName() == OperationNameOf(LearnableParameter) && node->IsParameterUpdateRequired()) + learnableParameters.push_back(node); + else { - ComputationNodeBasePtr node = *nodeIter; - if (node->OperationName() == OperationNameOf(LearnableParameter) && node->IsParameterUpdateRequired()) - learnableParameterNames.push_back(node->NodeName()); + // PreComputeNodes that are already done should not be traversed + auto pcnode = dynamic_pointer_cast(node); + if (pcnode && pcnode->HasComputed()) + return; + // recurse + visited.insert(node); + for (const auto & input : node->GetInputs()) + CollectInputAndLearnableParametersRec(input, visited, inputs, learnableParameters); } - - // sort names so that we get consistent order when load it from saved file - learnableParameterNames.sort(); - - // now collect the actual nodes in the sort order of their node names - list learnableParameters; - for (const auto& nodeNameIter : learnableParameterNames) - learnableParameters.push_back(GetNodeFromName(nodeNameIter)); - m_learnableParameters[rootNode] = move(learnableParameters); } template diff --git a/Source/ComputationNetworkLib/ComputationNetwork.h b/Source/ComputationNetworkLib/ComputationNetwork.h index da383c9bb74f..2d32eecc995e 100644 --- a/Source/ComputationNetworkLib/ComputationNetwork.h +++ b/Source/ComputationNetworkLib/ComputationNetwork.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace Microsoft { namespace MSR { namespace CNTK { @@ -164,6 +165,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb private: void DetermineSetOfAllRoots(); void CollectInputAndLearnableParameters(const ComputationNodeBasePtr& rootNode); + void CollectInputAndLearnableParametersRec(const ComputationNodeBasePtr& node, set& visited, list& inputs, list& learnableParameters); bool IsCompiled() const { return 
m_isCompiled; } void VerifyIsCompiled(const char* where) const; public: @@ -506,10 +508,6 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb return nodesWithType; } -private: - template - void GetNodesRequiringX(std::list& nodesRequirePreComputation, const ComputationNodeBasePtr& rootNode, bool checkComputed); - public: // return list of nodes that require precomputation and not precomputed yet std::list GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true); diff --git a/Source/ComputationNetworkLib/ComputationNode.h b/Source/ComputationNetworkLib/ComputationNode.h index 99c6c230112c..5918e7de77b1 100644 --- a/Source/ComputationNetworkLib/ComputationNode.h +++ b/Source/ComputationNetworkLib/ComputationNode.h @@ -1673,11 +1673,25 @@ class LateAttachingNode : public N, public ILateAttachingNode }; // ======================================================================= -// IRecurrentNode -- helper wrapper class for ComputationNodes that can be recurrent +// IRecurrentNode -- interface implemented by ComputationNodes that can be recurrent // ======================================================================= struct IRecurrentNode { virtual int GetRecurrenceSteppingDirection() const = 0; }; +// ======================================================================= +// IPreComputeNode -- interface implemented by ComputationNodes that precompute +// TODO: We can use this interface in more places. +// ======================================================================= + +struct IPreComputeNode +{ + // check whether node has already undergone precomputation + virtual bool HasComputed() const = 0; + // call this with 'false' at start and with 'true' at end + // This is used for resetting and updating from accumulators. 
+ virtual void MarkComputed(const bool hasComputed) = 0; +}; + // ======================================================================= // helper macro to ease access to base members in presence of C++ two-phase name lookup // ======================================================================= diff --git a/Source/ComputationNetworkLib/PreComputeNodes.h b/Source/ComputationNetworkLib/PreComputeNodes.h index 4956680a68c3..e9ef10e8b3b4 100644 --- a/Source/ComputationNetworkLib/PreComputeNodes.h +++ b/Source/ComputationNetworkLib/PreComputeNodes.h @@ -25,7 +25,7 @@ namespace Microsoft { namespace MSR { namespace CNTK { // ----------------------------------------------------------------------- template -class PreComputedNodeBase : public ComputationNodeNonLooping /*ComputationNode*/ +class PreComputedNodeBase : public ComputationNodeNonLooping /*ComputationNode*/, public IPreComputeNode { typedef ComputationNodeNonLooping Base; UsingComputationNodeMembers; @@ -40,14 +40,14 @@ class PreComputedNodeBase : public ComputationNodeNonLooping /*ComputationNode*/ // interface through which this node is operated on are these two functions // check whether node has already undergone precomputation - virtual bool HasComputed() const + virtual bool /*IPreComputeNode::*/ HasComputed() const override { return m_hasComputed; } // call this with 'false' at start and with 'true' at end // This is used for resetting and updating from accumulators. 
- virtual void MarkComputed(const bool hasComputed) + virtual void /*IPreComputeNode::*/ MarkComputed(const bool hasComputed) override { m_hasComputed = hasComputed; CreateMatrixIfNull(m_value); diff --git a/Source/Readers/Kaldi2Reader/HTKMLFReader.cpp b/Source/Readers/Kaldi2Reader/HTKMLFReader.cpp index 6936c32734a2..e2c1e8ac349b 100644 --- a/Source/Readers/Kaldi2Reader/HTKMLFReader.cpp +++ b/Source/Readers/Kaldi2Reader/HTKMLFReader.cpp @@ -673,87 +673,29 @@ void HTKMLFReader::PrepareForWriting(const ConfigRecordType& readerCon template HTKMLFReader::~HTKMLFReader() { - if (m_mbiter != NULL) - { - delete m_mbiter; - m_mbiter = NULL; - } - if (m_frameSource != NULL) - { - delete m_frameSource; - m_frameSource = NULL; - } - if (m_lattices != NULL) - { - delete m_lattices; - m_lattices = NULL; - } - if (m_seqTrainDeriv != NULL) - { - delete m_seqTrainDeriv; - m_seqTrainDeriv = NULL; - } - if (m_uttDerivBuffer != NULL) - { - delete m_uttDerivBuffer; - m_uttDerivBuffer = NULL; - } + delete m_mbiter; + delete m_frameSource; + delete m_lattices; + delete m_seqTrainDeriv; + delete m_uttDerivBuffer; - if (!m_featuresBufferMultiIO.empty()) - { - foreach_index (i, m_featuresBufferMultiIO) - { - if (m_featuresBufferMultiIO[i] != NULL) - { - delete[] m_featuresBufferMultiIO[i]; - m_featuresBufferMultiIO[i] = NULL; - } - } - } + foreach_index(i, m_featuresBufferMultiIO) + delete[] m_featuresBufferMultiIO[i]; - if (!m_labelsBufferMultiIO.empty()) - { - foreach_index (i, m_labelsBufferMultiIO) - { - if (m_labelsBufferMultiIO[i] != NULL) - { - delete[] m_labelsBufferMultiIO[i]; - m_labelsBufferMultiIO[i] = NULL; - } - } - } + foreach_index(i, m_labelsBufferMultiIO) + delete[] m_labelsBufferMultiIO[i]; for (size_t i = 0; i < m_numberOfuttsPerMinibatch; i++) { - if (m_featuresBufferMultiUtt[i] != NULL) - { - delete[] m_featuresBufferMultiUtt[i]; - m_featuresBufferMultiUtt[i] = NULL; - } - if (m_labelsBufferMultiUtt[i] != NULL) - { - delete[] m_labelsBufferMultiUtt[i]; - 
m_labelsBufferMultiUtt[i] = NULL; - } + delete[] m_featuresBufferMultiUtt[i]; + delete[] m_labelsBufferMultiUtt[i]; } foreach_index (i, m_trainingOrTestingFeatureSections) - { - if (m_trainingOrTestingFeatureSections[i] != NULL) - { - delete m_trainingOrTestingFeatureSections[i]; - m_trainingOrTestingFeatureSections[i] = NULL; - } - } + delete m_trainingOrTestingFeatureSections[i]; foreach_index (i, m_writingFeatureSections) - { - if (m_writingFeatureSections[i] != NULL) - { - delete m_writingFeatureSections[i]; - m_writingFeatureSections[i] = NULL; - } - } + delete m_writingFeatureSections[i]; } // StartMinibatchLoop - Startup a minibatch loop diff --git a/Source/Readers/Kaldi2Reader/KaldiSequenceTrainingDerivative.cpp b/Source/Readers/Kaldi2Reader/KaldiSequenceTrainingDerivative.cpp index 218eab22229f..357dd2b65900 100644 --- a/Source/Readers/Kaldi2Reader/KaldiSequenceTrainingDerivative.cpp +++ b/Source/Readers/Kaldi2Reader/KaldiSequenceTrainingDerivative.cpp @@ -41,16 +41,8 @@ KaldiSequenceTrainingDerivative::KaldiSequenceTrainingDerivative( template KaldiSequenceTrainingDerivative::~KaldiSequenceTrainingDerivative() { - if (m_denlatReader != NULL) - { - delete m_denlatReader; - m_denlatReader = NULL; - } - if (m_aliReader != NULL) - { - delete m_aliReader; - m_aliReader = NULL; - } + delete m_denlatReader; + delete m_aliReader; } template diff --git a/Source/SGDLib/SGD.cpp b/Source/SGDLib/SGD.cpp index f537f66b9e4a..0141a2f72a63 100644 --- a/Source/SGDLib/SGD.cpp +++ b/Source/SGDLib/SGD.cpp @@ -5,7 +5,6 @@ #include "Basics.h" #include "SGD.h" #include "NonlinearityNodes.h" // for DropoutNode -#include "PreComputeNodes.h" // for PrecomputeNode #include "SpecialPurposeNodes.h" // for SequenceWithSoftmaxNode #include "DataReaderHelpers.h" #include "MatrixQuantizerImpl.h" @@ -1284,11 +1283,8 @@ bool SGD::PreCompute(ComputationNetworkPtr net, } fprintf(stderr, "\nPrecomputing --> %lu PreCompute nodes found.\n\n", nodes.size()); - for (auto nodeIter = nodes.begin(); 
nodeIter != nodes.end(); nodeIter++) - { - auto node = static_pointer_cast>(*nodeIter); + for (const auto & node : nodes) fprintf(stderr, "\tNodeName: %ls\n", (node->NodeName()).c_str()); - } // compute // trainSetDataReader->StartMinibatchLoop(m_mbSize[0], 0 , requestDataSize); @@ -1302,11 +1298,8 @@ bool SGD::PreCompute(ComputationNetworkPtr net, net->StartEvaluateMinibatchLoop(nodes); // initialize - for (auto nodeIter = nodes.begin(); nodeIter != nodes.end(); nodeIter++) - { - auto node = static_pointer_cast>(*nodeIter); - node->MarkComputed(false /*begin accumulating*/); - } + for (auto & node : nodes) + dynamic_pointer_cast(node)->MarkComputed(false /*begin accumulating*/); const size_t numIterationsBeforePrintingProgress = 100; size_t numItersSinceLastPrintOfProgress = 0; @@ -1332,11 +1325,9 @@ bool SGD::PreCompute(ComputationNetworkPtr net, } // finalize - for (auto nodeIter = nodes.begin(); nodeIter != nodes.end(); nodeIter++) - { - auto node = static_pointer_cast>(*nodeIter); - node->MarkComputed(true /*done accumulating*/); - } + for (auto & node : nodes) + dynamic_pointer_cast(node)->MarkComputed(true /*done accumulating*/); + fprintf(stderr, "\nPrecomputing --> Completed.\n\n"); return true;