In [1]:
from theano import tensor as T
from theano import function

import numpy as np

### known_grads

`known_grads` is a dictionary that specifies the gradient of the cost w.r.t. some variables, to be used instead of the gradients automatically derived from the cost. The cost can even be omitted if a separating set is known.

In [2]:
# Symbolic graph: y = x^3 + x^2 and z = 6*y + x, so x reaches z both
# through y and directly.
x = T.vector('x')
y = x ** 3 + x ** 2
z = 6 * y + x

# Normal gradient of the scalar cost z.mean() w.r.t. x:
#   (6 * (3 * x ** 2 + 2 * x) + 1) / N,   N = number of elements of x
g1 = T.grad(z.mean(), x)

# Gradient with d(cost)/dy overridden to 2 * x through known_grads. Only the
# path through y is replaced; the direct path z -> x is still derived from the
# cost and contributes 1 / N:
#   d(cost)/dx = (2 * x) * (3 * x ** 2 + 2 * x) + 1 / N
g2 = T.grad(z.mean(), x, known_grads={y: 2 * x})

f1 = function([x], [y, z, g1])
f2 = function([x], [y, z, g2])

# x_val = [0., 1.], so N = 2 and the direct-path term is 0.5.
x_val = np.arange(2, dtype='float32')
print(f1(x_val))
print(f2(x_val))

[array([ 0.,  2.], dtype=float32), array([  0.,  13.], dtype=float32), array([  0.5,  15.5], dtype=float32)]
[array([ 0.,  2.], dtype=float32), array([  0.,  13.], dtype=float32), array([  0.5,  10.5], dtype=float32)]


### compute_test_value
In normal execution, an exception is only raised when the compiled function is actually called, so it is hard to see which line caused the error. `compute_test_value` allows you to plug in a test value in order to detect the first invalid expression.

In [6]:
from theano import config

# Turn on test-value evaluation: expressions are evaluated on the inputs'
# test values while the graph is being built.
config.compute_test_value = 'raise'

a = T.vector()
# Test value attached to the input: a length-3 vector of ones.
a.tag.test_value = np.ones((3,)).astype(a.dtype)
b = T.log(a)
c = T.nnet.sigmoid(b)
d = T.sqrt(c)
e = T.concatenate((d, c), axis=0)  # length 6 for the length-3 test value
f = b * c * d  # still length 3
g = e + f # the first bad line: adds a length-6 vector to a length-3 one
h = g / c
# With test values enabled the shape error surfaces at `g = e + f` above,
# so the two lines below are not reached.
fn = function([a], h)
fn(np.ones((3,)).astype(a.dtype))

ValueError: Input dimension mis-match. (input[0].shape[0] = 6, input[1].shape[0] = 3)

In [7]:
# Switch test-value evaluation back off so that the cells below build their
# graphs without it.
config.compute_test_value = 'off'

### mode - controlling compilation

In [11]:
import theano
from theano.compile import Mode

class NaNVariableError(Exception):
    """Raised when a NaN is found while evaluating a theano function."""

def get_nan_detection_mode():
    """
    Returns a theano Mode that detects if any nan value occurs in the
    evaluation of a theano function.

    Each node's thunk is run through a wrapper that checks the node's
    inputs before the thunk executes and its outputs afterwards;
    NaNVariableError is raised as soon as a NaN is found.
    """

    class NaNDetectionMode(Mode):
        def __init__(self):
            def do_nan_check(storage):
                # `storage` is the list held in a thunk's `inputs` or
                # `outputs`. Check its entries one at a time: reducing the
                # whole list at once forces numpy to merge every entry into
                # a single array, which breaks as soon as a node's variables
                # differ in shape. Testing isnan element-wise (instead of
                # isnan of the sum) also avoids reporting inf + (-inf) as NaN.
                for var in storage:
                    if np.any(np.isnan(var)):
                        raise NaNVariableError()

            def nan_check(i, node, fn):
                # The wrapper is responsible for running the thunk itself.
                do_nan_check(fn.inputs)
                fn()
                do_nan_check(fn.outputs)

            wrap_linker = theano.gof.WrapLinkerMany(
                [theano.gof.OpWiseCLinker()],
                [nan_check])
            super(NaNDetectionMode, self).__init__(wrap_linker,
                optimizer='fast_run')
    return NaNDetectionMode()

# Compile z = 2 * x under the NaN-detecting mode.
x = T.scalar('x')
z = 2. * x
mode = get_nan_detection_mode()
f = function([x], z, mode=mode)

# A NaN input is expected to trip the check ...
try:
    f(np.nan)
except NaNVariableError:
    print("Got it!")

# ... while an ordinary value is expected to evaluate without raising.
try:
    f(1.)
except NaNVariableError:
    print("Shouldn't go this way")

Got it!


### disrupted optimizations
Be careful not to break the optimizations while debugging. In the following example, the `Print` op disrupts the optimization, so the compiled code does not use the `softmax` op (thus, no normalization), which causes NaN.

In [15]:
from theano.printing import Print

X = T.matrix()
p_tilde = T.exp(X)
# try the code without Print (i.e. drop the next line) to compare; the text
# above attributes the NaN to the Print op getting in the optimizer's way.
# Print reports the min and max of p_tilde each time the function runs.
p_tilde = Print('p_tilde', attrs=['min', 'max'])(p_tilde)
denom = p_tilde.sum(axis=1, keepdims=True)
p = p_tilde / denom

f = function([X], p)

# NOTE: X is rebound here from the symbolic matrix to a numeric 2x2 input.
# exp(-1000) underflows to 0, which is why the printed min and max are 0.0.
X = -1000. * np.ones((2, 2)).astype(X.dtype)

output = f(X)

# All inputs are equal, so a correctly normalized result would be 0.5
# everywhere; with Print in the graph this assertion fails.
assert np.allclose(output, 0.5 * np.ones((2, 2)))

p_tilde min = 0.0
p_tilde max = 0.0


AssertionError: 