Skip to content

Commit

Permalink
hierarchical cache
Browse files Browse the repository at this point in the history
  • Loading branch information
vmazalov committed Sep 10, 2019
1 parent 9e84d1a commit f84a069
Showing 1 changed file with 154 additions and 17 deletions.
171 changes: 154 additions & 17 deletions Source/SGDLib/SimpleOutputWriter.h
Expand Up @@ -46,6 +46,8 @@ class SimpleOutputWriter
{
return logP < rhs.logP;
}
bool realValues = false;
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>> nameToParentNodeValues;
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>> nameToNodeValues;
};
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
Expand Down Expand Up @@ -285,9 +287,9 @@ class SimpleOutputWriter
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
vector<ElemType> v;
oneSeq.nameToNodeValues[m_nodesToCache[i]] = make_shared<PastValueNode<ElemType>>(deviceId, L"test");
oneSeq.nameToNodeValues[m_nodesToCache[i]] = make_shared<PastValueNode<ElemType>>(deviceId, m_nodesToCache[i]);
fprintf(stderr, "Scratch %ls %p \n", m_nodesToCache[i].c_str(), &oneSeq.nameToNodeValues[m_nodesToCache[i]]);
}

return oneSeq;
}
Sequence newSeq(Sequence& a, DEVICEID_TYPE deviceId)
Expand All @@ -300,23 +302,54 @@ class SimpleOutputWriter
oneSeq.processlength = a.processlength;
oneSeq.decodeoutput = make_shared<Matrix<ElemType>>(a.decodeoutput->GetNumRows(), (size_t) 1, a.decodeoutput->GetDeviceId());
oneSeq.decodeoutput->SetValue(*(a.decodeoutput));

unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
for (it = a.nameToNodeValues.begin(); it != a.nameToNodeValues.end(); it++)
{
if (oneSeq.processlength > 0)
{
if (it->second->Value().GetNumElements() > 0 && a.realValues)
oneSeq.nameToParentNodeValues[it->first] = it->second;
else
oneSeq.nameToParentNodeValues[it->first] = a.nameToParentNodeValues[it->first];
/*size_t ab = oneSeq.nameToParentNodeValues[it->first]->Value().GetNumElements();
if (ab > 0)
fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/
}
auto itin = m_nameToPastValueNodeCache.find(it->first);
if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0)
{
//vector<shared_ptr<PastValueNode<ElemType>>> mCache = m_nameToPastValueNodeCache[it->first];
oneSeq.nameToNodeValues[it->first] = m_nameToPastValueNodeCache[it->first].back();
m_nameToPastValueNodeCache[it->first].pop_back();
/*size_t ab = oneSeq.nameToNodeValues[it->first]->Value().GetNumElements();
if (ab > 0)
fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/
}
else
{
oneSeq.nameToNodeValues[it->first] = make_shared<PastValueNode<ElemType>>(deviceId, it->first);
}

it->second->CopyTo(oneSeq.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
fprintf(stderr, "newSeq %ls %p \n", it->first.c_str(), &oneSeq.nameToNodeValues[it->first]);
}

//unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
//for (it = a.nameToNodeValues.begin(); it != a.nameToNodeValues.end(); it++)
//{
// auto itin = m_nameToPastValueNodeCache.find(it->first);
// if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0)
// {
// //vector<shared_ptr<PastValueNode<ElemType>>> mCache = m_nameToPastValueNodeCache[it->first];
// oneSeq.nameToNodeValues[it->first] = m_nameToPastValueNodeCache[it->first].back();
// m_nameToPastValueNodeCache[it->first].pop_back();
// }
// else
// {
// oneSeq.nameToNodeValues[it->first] = make_shared<PastValueNode<ElemType>>(deviceId, it->first);
// }

// it->second->CopyTo(oneSeq.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
//}
return oneSeq;
}

Expand All @@ -329,6 +362,7 @@ class SimpleOutputWriter
if (itin == m_nameToPastValueNodeCache.end())
m_nameToPastValueNodeCache[it->first] = vector<shared_ptr<PastValueNode<ElemType>>>();
m_nameToPastValueNodeCache[it->first].push_back(oneSeq.nameToNodeValues[it->first]);
fprintf(stderr, "deleteSeq %ls %p \n", it->first.c_str(), &m_nameToPastValueNodeCache[it->first]);
}
oneSeq.decodeoutput->ReleaseMemory();
vector<size_t>().swap(oneSeq.labelseq);
Expand Down Expand Up @@ -376,6 +410,41 @@ class SimpleOutputWriter
}
return nodes;
}
void prepareSequence(Sequence& s)
{
if (s.nameToNodeValues.size() > 0)
{
unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
for (it = s.nameToParentNodeValues.begin(); it != s.nameToParentNodeValues.end(); it++)
{
it->second->CopyTo(s.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
fprintf(stderr, "prepareSequence %ls %p \n", it->first.c_str(), & s.nameToNodeValues[it->first]);
/*size_t ab = s.nameToNodeValues[it->first]->Value().GetNumElements();
ab = it->second->Value().GetNumElements();
if (ab > 0)
fprintf(stderr, "test %ls %zu", it->first.c_str(), ab);*/
}
}
s.realValues = true;

//unordered_map<wstring, shared_ptr<PastValueNode<ElemType>>>::iterator it;
//for (it = a.nameToNodeValues.begin(); it != a.nameToNodeValues.end(); it++)
//{
// auto itin = m_nameToPastValueNodeCache.find(it->first);
// if (itin != m_nameToPastValueNodeCache.end() && m_nameToPastValueNodeCache[it->first].size() > 0)
// {
// //vector<shared_ptr<PastValueNode<ElemType>>> mCache = m_nameToPastValueNodeCache[it->first];
// oneSeq.nameToNodeValues[it->first] = m_nameToPastValueNodeCache[it->first].back();
// m_nameToPastValueNodeCache[it->first].pop_back();
// }
// else
// {
// oneSeq.nameToNodeValues[it->first] = make_shared<PastValueNode<ElemType>>(deviceId, it->first);
// }

// it->second->CopyTo(oneSeq.nameToNodeValues[it->first], it->first, CopyNodeFlags::copyNodeAll);
//}
}

void forward_decode(Sequence& oneSeq, StreamMinibatchInputs decodeinputMatrices, DEVICEID_TYPE deviceID, const std::vector<ComputationNodeBasePtr>& decodeOutputNodes,
const std::vector<ComputationNodeBasePtr>& decodeinputNodes, size_t vocabSize, size_t plength)
Expand All @@ -385,6 +454,16 @@ class SimpleOutputWriter
LogicError("Current implementation assumes 1 step difference");
if (plength != oneSeq.processlength)
{
m_logIndex = m_logIndex + 1;
wstring fileName = L"D:\\users\\vadimma\\cntk_3\\new_opt" + std::to_wstring(m_logIndex) + L".txt";
std::ofstream out(fileName, std::ios::out);
out << fixed;
out.precision(3);
for (size_t li = 0; li < oneSeq.labelseq.size(); li++)
out << oneSeq.labelseq[li] << " ";

out << "\n";

Matrix<ElemType> lmin(deviceID);

lmin.Resize(vocabSize, 1);
Expand All @@ -400,20 +479,36 @@ class SimpleOutputWriter
else
{
lminput->second.pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, SentinelValueIndicatingUnspecifedSequenceBeginIdx, 1);
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]);
if (oneSeq.nameToNodeValues[m_nodesToCache[i]]->Value().GetNumElements() > 0)
{
oneSeq.nameToNodeValues[m_nodesToCache[i]]->CopyTo(nodePtr, m_nodesToCache[i], CopyNodeFlags::copyNodeAll); //copyNodeInputLinks
}

shared_ptr<ComputationNode<ElemType>> pLearnableNode = dynamic_pointer_cast<ComputationNode<ElemType>>(nodePtr);

pLearnableNode = dynamic_pointer_cast<ComputationNode<ElemType>>(nodePtr);
Matrix<ElemType>& mat2 = pLearnableNode->Value();

for (size_t m_i = 0; m_i < mat2.GetNumRows(); m_i++)
{
for (size_t j = 0; j < mat2.GetNumCols(); j++)
{
out << mat2(m_i, j);
}
}
out << string("\n");
}
}

ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);

bool shallowCopy = false;
for (size_t i = 0; i < m_nodesToCache.size(); i++)
/* if (oneSeq.realValues)
{
auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]);
if (oneSeq.nameToNodeValues[m_nodesToCache[i]]->Value().GetNumElements() > 0)
{
oneSeq.nameToNodeValues[m_nodesToCache[i]]->CopyTo(nodePtr, m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks);
shallowCopy = true;
}
}
}*/

ComputationNetwork::BumpEvalTimeStamp(decodeinputNodes);
DataReaderHelpers::NotifyChangedNodes<ElemType>(m_net, decodeinputMatrices);
Expand All @@ -425,11 +520,48 @@ class SimpleOutputWriter
for (size_t i = 0; i < m_nodesToCache.size(); i++)
{
auto nodePtr = m_net->GetNodeFromName(m_nodesToCache[i]);
if (shallowCopy)
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeInputLinks);
else

if (plength == 1)
{
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll);
/*size_t ab = oneSeq.nameToNodeValues[m_nodesToCache[i]]->Value().GetNumElements();
if (ab > 0)
fprintf(stderr, "test %ls %zu", m_nodesToCache[i].c_str(), ab);*/
}
else
{
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll); //copyNodeInputLinks
}

shared_ptr<ComputationNode<ElemType>> pLearnableNode = dynamic_pointer_cast<ComputationNode<ElemType>>(nodePtr);

Matrix<ElemType>& mat = pLearnableNode->Value();

for (size_t m_i = 0; m_i < mat.GetNumRows(); m_i++)
{
for (size_t j = 0; j < mat.GetNumCols(); j++)
{
out << mat(m_i, j);
}
}
out << string("\n");

/*if (shallowCopy)
else
nodePtr->CopyTo(oneSeq.nameToNodeValues[m_nodesToCache[i]], m_nodesToCache[i], CopyNodeFlags::copyNodeAll);*/
}
//oneSeq.realValues = true;
for (size_t m_i = 0; m_i < oneSeq.decodeoutput->GetNumRows(); m_i++)
{
for (size_t j = 0; j < oneSeq.decodeoutput->GetNumCols(); j++)
{
out << (*oneSeq.decodeoutput)(m_i, j);
}
}
out << string("\n");

out.close();

lmin.ReleaseMemory();
}
Expand Down Expand Up @@ -658,14 +790,14 @@ class SimpleOutputWriter
Sequence tempSeq = newSeq(*maxSeq, deviceid);
deleteSeq(*maxSeq);
CurSequences.erase(maxSeq);
prepareSequence(tempSeq);
forward_decode(tempSeq, decodeinputMatrices, deviceid, decodeOutputNodes, decodeinputNodes, vocabSize, tempSeq.labelseq.size());
forwardmerged(tempSeq, t, sumofENandDE, encodeOutput, decodeOutput, PlusNode, PlusTransNode, Plusnodes, Plustransnodes);

vector<pair<size_t, ElemType>> topN = getTopN(decodeOutput, expandBeam);
int iLabel;
for (iLabel = 0; iLabel < expandBeam; iLabel++)
{

Sequence seqK = newSeq(tempSeq, deviceid);
ElemType newlogP = topN[iLabel].second + tempSeq.logP;
seqK.logP = newlogP;
Expand All @@ -691,6 +823,10 @@ class SimpleOutputWriter
extendSeq(seqK, topN[iLabel].first, newlogP);

CurSequences.push_back(seqK);

auto maxSeq1 = std::max_element(CurSequences.begin(), CurSequences.end());
Sequence tempSeq1 = newSeq(*maxSeq1, deviceid);
//fprintf(stderr, "test %zu", tempSeq1.length);
}
vector<pair<size_t, ElemType>>().swap(topN);
//delete topN;
Expand Down Expand Up @@ -962,6 +1098,7 @@ class SimpleOutputWriter
}

private:
int m_logIndex = 0;
ComputationNetworkPtr m_net;
std::vector<wstring> m_nodesToCache;
int m_verbosity;
Expand Down

0 comments on commit f84a069

Please sign in to comment.