Skip to content

Commit

Permalink
fixing numerical issues in celerite grad
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Nov 25, 2018
1 parent 62aea8c commit c283f7e
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions exoplanet/theano_ops/celerite/include/celerite.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ int factor (
for (int n = 1; n < N; ++n) {
// Update S = diag(P) * (S + d*W*W.T) * diag(P)
S_.noalias() += d(n-1) * W.row(n-1).transpose() * W.row(n-1);
S_.array() *= (P.row(n-1).transpose() * P.row(n-1)).array();
S_ = P.row(n-1).asDiagonal() * S_;
//S_.array() *= (P.row(n-1).transpose() * P.row(n-1)).array();
for (int j = 0; j < J; ++j)
for (int k = 0; k < J; ++k)
S(n, j*J+k) = S_(j, k);
S_ *= P.row(n-1).asDiagonal();

// Update d = a - U * S * U.T
tmp = U.row(n) * S_;
Expand Down Expand Up @@ -126,11 +128,10 @@ void factor_grad (

// Step 6
ba(n) -= W.row(n) * bV.row(n).transpose();
bU.row(n).noalias() = -(bV.row(n) + 2.0 * ba(n) * U.row(n)) * S_;
bU.row(n).noalias() = -(bV.row(n) + 2.0 * ba(n) * U.row(n)) * S_ * P.row(n-1).asDiagonal();
bS.noalias() -= U.row(n).transpose() * (bV.row(n) + ba(n) * U.row(n));

// Step 4
S_ *= P.row(n-1).asDiagonal().inverse();
bP.row(n-1).noalias() = (bS * S_ + S_.transpose() * bS).diagonal();

// Step 3
Expand Down Expand Up @@ -162,12 +163,11 @@ void solve (

for (int n = 1; n < N; ++n) {
F_.noalias() += W.row(n-1).transpose() * Z.row(n-1);
F_ = P.row(n-1).asDiagonal() * F_;
Z.row(n).noalias() -= U.row(n) * F_;

for (int j = 0; j < J; ++j)
for (int k = 0; k < nrhs; ++k)
F(n, j*nrhs+k) = F_(j, k);
F_ = P.row(n-1).asDiagonal() * F_;
Z.row(n).noalias() -= U.row(n) * F_;
}

Z.array().colwise() /= d.array();
Expand All @@ -176,12 +176,11 @@ void solve (
G.row(N-1).setZero();
for (int n = N-2; n >= 0; --n) {
F_.noalias() += U.row(n+1).transpose() * Z.row(n+1);
F_ = P.row(n).asDiagonal() * F_;
Z.row(n).noalias() -= W.row(n) * F_;

for (int j = 0; j < J; ++j)
for (int k = 0; k < nrhs; ++k)
G(n, j*nrhs+k) = F_(j, k);
F_ = P.row(n).asDiagonal() * F_;
Z.row(n).noalias() -= W.row(n) * F_;
}
}

Expand Down Expand Up @@ -214,14 +213,14 @@ void solve_grad (
F_(j, k) = G(n, j*nrhs+k);

// Grad of: Z.row(n).noalias() -= W.row(n) * G;
bW.row(n).noalias() -= bY.row(n) * F_.transpose();
bW.row(n).noalias() -= bY.row(n) * (P.row(n).asDiagonal() * F_).transpose();
bF.noalias() -= W.row(n).transpose() * bY.row(n);

// Inverse of: Z.row(n).noalias() -= W.row(n) * G;
Z_.row(n).noalias() += W.row(n) * F_;
Z_.row(n).noalias() += W.row(n) * (P.row(n).asDiagonal() * F_);

// Grad of: g = P.row(n).asDiagonal() * G;
bP.row(n).noalias() += P.row(n).asDiagonal().inverse() * (F_ * bF.transpose()).diagonal();
bP.row(n).noalias() += (F_ * bF.transpose()).diagonal();
bF = P.row(n).asDiagonal() * bF;

// Grad of: g.noalias() += U.row(n+1).transpose() * Z.row(n+1);
Expand All @@ -242,11 +241,11 @@ void solve_grad (
F_(j, k) = F(n, j*nrhs+k);

// Grad of: Z.row(n).noalias() -= U.row(n) * f;
bU.row(n).noalias() -= bY.row(n) * F_.transpose();
bU.row(n).noalias() -= bY.row(n) * (P.row(n-1).asDiagonal() * F_).transpose();
bF.noalias() -= U.row(n).transpose() * bY.row(n);

// Grad of: F = P.row(n-1).asDiagonal() * F;
bP.row(n-1).noalias() += P.row(n-1).asDiagonal().inverse() * (F_ * bF.transpose()).diagonal();
bP.row(n-1).noalias() += (F_ * bF.transpose()).diagonal();
bF = P.row(n-1).asDiagonal() * bF;

// Grad of: F.noalias() += W.row(n-1).transpose() * Z.row(n-1);
Expand Down

0 comments on commit c283f7e

Please sign in to comment.