Skip to content

Commit

Permalink
LSTM model: Applied the new IndirectContext improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
GotthardtZ committed Jan 16, 2021
1 parent ae7a5b5 commit 4f1348d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
10 changes: 6 additions & 4 deletions lstm/LstmModel.hpp
Expand Up @@ -18,7 +18,7 @@ class LstmModel {
const Shared* const shared;
std::valarray<float> probs;
APM apm1, apm2, apm3;
IndirectContext<std::uint32_t> iCtx;
IndirectContext<std::uint16_t> iCtx;
std::size_t top, mid, bot;
std::uint8_t expected;
public:
Expand All @@ -31,11 +31,13 @@ class LstmModel {
float const gradient_clip) :
shared(sh),
probs(1.f / Size, Size),
apm1{ sh, 0x10000u, 24 }, apm2{ sh, 0x800u, 24 }, apm3{ sh, 0x20000u, 24 },
iCtx{ 11, 2 },
apm1{ sh, 0x10000u, 24 }, apm2{ sh, 0x800u, 24 }, apm3{ sh, 1024, 24 },
iCtx{ 11, 1, 9 },
top(Size - 1), mid(0), bot(0),
expected(0)
{}
{
iCtx.reset();
}
virtual ~LstmModel() = default;
virtual void mix(Mixer& m) = 0;
};
Expand Down
9 changes: 4 additions & 5 deletions lstm/SimdLstmModel.hpp
Expand Up @@ -105,18 +105,17 @@ class SIMDLstmModel :
}
}

this->iCtx += 2+y, this->iCtx = (bpos << 8) | this->expected;
std::uint32_t mask = 0u, i = 0u;
for (std::uint32_t ctx = this->iCtx(); ctx > 0u; mask |= (ctx & 1u) << i, i++, ctx >>= 2);
mask |= 1u << i;
this->iCtx += y;
this->iCtx = (bpos << 8) | this->expected;
std::uint32_t ctx = this->iCtx();

int const p = min(max(std::lround(prediction * 4096.0f), 1), 4095);
m.promote(stretch(p)/2);
m.add(stretch(p));
m.add((p - 2048) >> 2);
int const pr1 = this->apm1.p(p, (c0 << 8) | (this->shared->State.misses & 0xFF), 0xFF);
int const pr2 = this->apm2.p(p, (bpos << 8) | this->expected, 0xFF);
int const pr3 = this->apm3.p(pr2, mask, 0xFF);
int const pr3 = this->apm3.p(pr2, ctx, 0xFF);
m.add(stretch(pr1) >> 1);
m.add(stretch(pr2) >> 1);
m.add(stretch(pr3) >> 1);
Expand Down

0 comments on commit 4f1348d

Please sign in to comment.