Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
Fix softmax overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
suzusuzu committed Mar 13, 2018
1 parent d72f0ef commit 0032fbf
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/nupic/algorithms/SDRClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,15 @@ void SDRClassifier::infer_(const vector<UInt> &patternNZ,
add(likelihoods->begin(), likelihoods->end(), weights.begin(bit),
weights.begin(bit + 1));
}
// compute softmax of raw scores
// TODO: fix potential overflow problem by shifting scores by their
// maximal value across buckets
Real64 maxLikelihoods = *max_element(likelihoods->begin(), likelihoods->end());
for (auto likelihood : *likelihoods) {
likelihood -= maxLikelihoods;
}
range_exp(1.0, *likelihoods);
normalize(*likelihoods, 1.0, 1.0);
Real64 sumLikelihoods = accumulate(likelihoods->begin(), likelihoods->end(), 0);
for (auto likelihood : *likelihoods) {
likelihood /= sumLikelihoods;
}
}
}

Expand All @@ -222,8 +226,15 @@ vector<Real64> SDRClassifier::calculateError_(const vector<UInt> &bucketIdxList,
add(likelihoods.begin(), likelihoods.end(), weights.begin(bit),
weights.begin(bit + 1));
}
Real64 maxLikelihoods = *max_element(likelihoods.begin(), likelihoods.end());
for (auto likelihood : likelihoods) {
likelihood -= maxLikelihoods;
}
range_exp(1.0, likelihoods);
normalize(likelihoods, 1.0, 1.0);
Real64 sumLikelihoods = accumulate(likelihoods.begin(), likelihoods.end(), 0);
for (auto likelihood : likelihoods) {
likelihood /= sumLikelihoods;
}

// compute target likelihoods
vector<Real64> targetDistribution(maxBucketIdx_ + 1, 0.0);
Expand Down

0 comments on commit 0032fbf

Please sign in to comment.