diff --git a/Source/SGDLib/SimpleOutputWriter.h b/Source/SGDLib/SimpleOutputWriter.h index bdd3ecdec9e6..ceb534b3710e 100644 --- a/Source/SGDLib/SimpleOutputWriter.h +++ b/Source/SGDLib/SimpleOutputWriter.h @@ -46,6 +46,8 @@ class SimpleOutputWriter { return logP < rhs.logP; } + bool realValues = false; + unordered_map>> nameToParentNodeValues; unordered_map>> nameToNodeValues; }; typedef shared_ptr> ComputationNodePtr; @@ -285,9 +287,9 @@ class SimpleOutputWriter for (size_t i = 0; i < m_nodesToCache.size(); i++) { vector v; - oneSeq.nameToNodeValues[m_nodesToCache[i]] = make_shared>(deviceId, L"test"); + oneSeq.nameToNodeValues[m_nodesToCache[i]] = make_shared>(deviceId, m_nodesToCache[i]); + fprintf(stderr, "Scratch %ls %p \n", m_nodesToCache[i].c_str(), &oneSeq.nameToNodeValues[m_nodesToCache[i]]); } - return oneSeq; } Sequence newSeq(Sequence& a, DEVICEID_TYPE deviceId) @@ -300,23 +302,54 @@ class SimpleOutputWriter oneSeq.processlength = a.processlength; oneSeq.decodeoutput = make_shared>(a.decodeoutput->GetNumRows(), (size_t) 1, a.decodeoutput->GetDeviceId()); oneSeq.decodeoutput->SetValue(*(a.decodeoutput)); + unordered_map>>::iterator it; for (it = a.nameToNodeValues.begin(); it != a.nameToNodeValues.end(); it++) { + if (oneSeq.processlength > 0) + { + if (it->second->Value().GetNumElements() > 0 && a.realValues) + oneSeq.nameToParentNodeValues[it->first] = it->second; + else + oneSeq.nameToParentNodeValues[it->first] = a.nameToParentNodeValues[it->first]; + /*size_t ab = oneSeq.nameToParentNodeValues[it->first]->Value().GetNumElements(); + if (ab > 0) + fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/ + } auto itin = m_nameToPastValueNodeCache.find(it->first); if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0) { //vector>> mCache = m_nameToPastValueNodeCache[it->first]; oneSeq.nameToNodeValues[it->first] = m_nameToPastValueNodeCache[it->first].back(); m_nameToPastValueNodeCache[it->first].pop_back(); + /*size_t ab = oneSeq.nameToNodeValues[it->first]->Value().GetNumElements(); + if (ab > 0) + fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/ } else { oneSeq.nameToNodeValues[it->first] = make_shared>(deviceId, it->first); } - - it->second->CopyTo(oneSeq.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll); + fprintf(stderr, "newSeq %ls %p \n", it->first.c_str(), &oneSeq.nameToNodeValues[it->first]); } + + //unordered_map>>::iterator it; + //for (it = a.nameToNodeValues.begin(); it != a.nameToNodeValues.end(); it++) + //{ + // auto itin = m_nameToPastValueNodeCache.find(it->first); + // if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0) + // { + // //vector>> mCache = m_nameToPastValueNodeCache[it->first]; + // oneSeq.nameToNodeValues[it->first] = m_nameToPastValueNodeCache[it->first].back(); + // m_nameToPastValueNodeCache[it->first].pop_back(); + // } + // else + // { + // oneSeq.nameToNodeValues[it->first] = make_shared>(deviceId, it->first); + // } + + // it->second->CopyTo(oneSeq.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll); + //} return oneSeq; } @@ -329,6 +362,7 @@ class SimpleOutputWriter if (itin == m_nameToPastValueNodeCache.end()) m_nameToPastValueNodeCache[it->first] = vector>>(); m_nameToPastValueNodeCache[it->first].push_back(oneSeq.nameToNodeValues[it->first]); + fprintf(stderr, "deleteSeq %ls %p \n", it->first.c_str(), &m_nameToPastValueNodeCache[it->first]); } oneSeq.decodeoutput->ReleaseMemory(); vector().swap(oneSeq.labelseq); @@ -376,6 +410,41 @@ class SimpleOutputWriter } return nodes; } + void prepareSequence(Sequence& s) + { + if (s.nameToNodeValues.size() > 0) + { + unordered_map>>::iterator it; + for (it = s.nameToParentNodeValues.begin(); it != s.nameToParentNodeValues.end(); it++) + { + it->second->CopyTo(s.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll); + fprintf(stderr, "prepareSequence %ls %p \n", it->first.c_str(), & s.nameToNodeValues[it->first]); + /*size_t ab = s.nameToNodeValues[it->first]->Value().GetNumElements(); + ab = it->second->Value().GetNumElements(); + if (ab > 0) + fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/ + } + } + s.realValues = true; + + //unordered_map>>::iterator it; + //for (it = a.nameToNodeValues.begin(); it != a.nameToNodeValues.end(); it++) + //{ + // auto itin = m_nameToPastValueNodeCache.find(it->first); + // if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0) + // { + // //vector>> mCache = m_nameToPastValueNodeCache[it->first]; + // oneSeq.nameToNodeValues[it->first] = m_nameToPastValueNodeCache[it->first].back(); + // m_nameToPastValueNodeCache[it->first].pop_back(); + // } + // else + // { + // oneSeq.nameToNodeValues[it->first] = make_shared>(deviceId, it->first); + // } + + // it->second->CopyTo(oneSeq.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll); + //} + } void forward_decode(Sequence& oneSeq, StreamMinibatchInputs decodeinputMatrices, DEVICEID_TYPE deviceID, const std::vector& decodeOutputNodes, const std::vector& decodeinputNodes, size_t vocabSize, size_t plength) @@ -385,6 +454,16 @@ class SimpleOutputWriter LogicError("Current implementation assumes 1 step difference"); if (plength != oneSeq.processlength) { + m_logIndex = m_logIndex + 1; + wstring fileName = L"D:\\users\\vadimma\\cntk_3\\new_opt" + std::to_wstring(m_logIndex) + L".txt"; + std::ofstream out(fileName, std::ios::out); + out << fixed; + out.precision(3); + for (size_t li = 0; li < oneSeq.labelseq.size(); li++) + out << oneSeq.labelseq[li] << " "; + + out << "\n"; + Matrix lmin(deviceID); lmin.Resize(vocabSize, 1); @@ -400,20 +479,36 @@ class SimpleOutputWriter else { lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, SentinelValueIndicatingUnspecifedSequenceBeginIdx, 1); + for (size_t i = 0; i < m_nodesToCache.size(); i++) + { + auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]); + if (oneSeq.nameToNodeValues[m_nodesToCache[i]]->Value().GetNumElements() > 0) + { + oneSeq.nameToNodeValues[m_nodesToCache[i]]->CopyTo(nodePtr, m_nodesToCache[i], CopyNodeFlags::copyNodeAll); //copyNodeInputLinks + } + + shared_ptr> pLearnableNode = dynamic_pointer_cast>(nodePtr); + + pLearnableNode = dynamic_pointer_cast>(nodePtr); + Matrix& mat2 = pLearnableNode->Value(); + + for (size_t m_i = 0; m_i < mat2.GetNumRows(); m_i++) + { + for (size_t j = 0; j < mat2.GetNumCols(); j++) + { + out << mat2(m_i, j); + } + } + out << string("\n"); + } } ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes); - bool shallowCopy = false; - for (size_t i = 0; i < m_nodesToCache.size(); i++) + /* if (oneSeq.realValues) { - auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]); - if (oneSeq.nameToNodeValues[m_nodesToCache[i]]->Value().GetNumElements() > 0) - { - oneSeq.nameToNodeValues[m_nodesToCache[i]]->CopyTo(nodePtr, m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks); - shallowCopy = true; - } - } + + }*/ ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes); DataReaderHelpers::NotifyChangedNodes(m_net, decodeinputMatrices); @@ -425,11 +520,48 @@ class SimpleOutputWriter for (size_t i = 0; i < m_nodesToCache.size(); i++) { auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]); - if (shallowCopy) - nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks); - else + + if (plength == 1) + { nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll); + /*size_t ab = oneSeq.nameToNodeValues[m_nodesToCache[i]]->Value().GetNumElements(); + if (ab > 0) + fprintf(stderr, "test %ls %zu", m_nodesToCache[i].c_str(), ab);*/ + } + else + { + nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll); //copyNodeInputLinks + } + + shared_ptr> pLearnableNode = dynamic_pointer_cast>(nodePtr); + + Matrix& mat = pLearnableNode->Value(); + + for (size_t m_i = 0; m_i < mat.GetNumRows(); m_i++) + { + for (size_t j = 0; j < mat.GetNumCols(); j++) + { + out << mat(m_i, j); + } + } + out << string("\n"); + + /*if (shallowCopy) + + else + nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll);*/ } + //oneSeq.realValues = true; + for (size_t m_i = 0; m_i < oneSeq.decodeoutput->GetNumRows(); m_i++) + { + for (size_t j = 0; j < oneSeq.decodeoutput->GetNumCols(); j++) + { + out << (*oneSeq.decodeoutput)(m_i, j); + } + } + out << string("\n"); + + out.close(); lmin.ReleaseMemory(); } @@ -658,6 +790,7 @@ class SimpleOutputWriter Sequence tempSeq = newSeq(*maxSeq, deviceid); deleteSeq(*maxSeq); CurSequences.erase(maxSeq); + prepareSequence(tempSeq); forward_decode(tempSeq, decodeinputMatrices, deviceid, decodeOutputNodes, decodeinputNodes, vocabSize, tempSeq.labelseq.size()); forwardmerged(tempSeq, t, sumofENandDE, encodeOutput, decodeOutput, PlusNode, PlusTransNode, Plusnodes, Plustransnodes); @@ -665,7 +798,6 @@ class SimpleOutputWriter int iLabel; for (iLabel = 0; iLabel < expandBeam; iLabel++) { - Sequence seqK = newSeq(tempSeq, deviceid); ElemType newlogP = topN[iLabel].second + tempSeq.logP; seqK.logP = newlogP; @@ -691,6 +823,10 @@ class SimpleOutputWriter extendSeq(seqK, topN[iLabel].first, newlogP); CurSequences.push_back(seqK); + + auto maxSeq1 = std::max_element(CurSequences.begin(), CurSequences.end()); + Sequence tempSeq1 = newSeq(*maxSeq1, deviceid); + //fprintf(stderr, "test %zu", tempSeq1.length); } vector>().swap(topN); //delete topN; @@ -962,6 +1098,7 @@ class SimpleOutputWriter } private: + int m_logIndex = 0; ComputationNetworkPtr m_net; std::vector m_nodesToCache; int m_verbosity;