Prevent degenerate variance values using abs + eps
Leif Johnson committed Jul 28, 2015
Parent: d0130fb · Commit: 1f28080
1 changed file: theanets/losses.py (3 additions, 2 deletions)

@@ -282,14 +282,15 @@ class GaussianLogLikelihood(Loss):
 
     __extra_registration_keys__ = ['GLL']
 
-    def __init__(self, mean_name='mean', covar_name='covar', **kwargs):
+    def __init__(self, mean_name='mean', covar_name='covar', covar_eps=1e-3, **kwargs):
         super(GaussianLogLikelihood, self).__init__(**kwargs)
         self.mean_name = mean_name
         if ':' not in self.mean_name:
             self.mean_name += ':out'
         self.covar_name = covar_name
         if ':' not in self.covar_name:
             self.covar_name += ':out'
+        self.covar_eps = covar_eps
 
     def __call__(self, outputs):
         '''Construct the computation graph for this loss function.
@@ -323,7 +324,7 @@ def __call__(self, outputs):
         # strange in the code below.
         mean = outputs[self.mean_name]
         covar = outputs[self.covar_name]
-        prec = 1 / TT.switch(TT.eq(covar, 0), 1.0, covar) # prevent nans!
+        prec = 1 / (abs(covar) + self.covar_eps) # prevent nans!
         eta = mean * prec
         logpi = TT.cast(mean.shape[-1] * np.log(2 * np.pi), 'float32')
         logdet = TT.log(prec.sum(axis=-1))

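For context, a minimal NumPy sketch (not part of the commit) of why the new guard is safer than the old one. The old expression only patched covariances that were exactly zero, so tiny or negative network outputs still produced huge or negative precisions; abs(covar) + covar_eps keeps every precision finite, strictly positive, and bounded above by 1 / covar_eps.

import numpy as np

covar = np.array([0.0, 1e-12, -0.5, 2.0])  # raw network outputs for the variance
eps = 1e-3                                 # mirrors the new covar_eps default

# Old guard: only exact zeros are patched; 1e-12 and -0.5 slip through.
old_prec = 1 / np.where(covar == 0, 1.0, covar)
print(old_prec)  # [ 1.e+00  1.e+12 -2.e+00  5.e-01]

# New guard: every precision is positive and at most 1 / eps.
new_prec = 1 / (np.abs(covar) + eps)
print(new_prec)  # approximately [1000.  1000.  1.996  0.4998]

In particular, the downstream logdet = TT.log(prec.sum(axis=-1)) term can no longer see a negative argument: under the old guard a negative covariance could drive the precision sum below zero and turn the loss into NaN.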