
bodin-e/tensorcheck


tensorcheck

tensorcheck is a (very small) library for validating tensor shapes when using TensorFlow or PyTorch. It is intended as a demonstration of a suggested approach rather than as a library in its own right, although contributions and suggestions for improvement are welcome.

Authors: Erik Bodin, Andrew Lawrence

Update: this functionality has since been contributed to TensorFlow as tf.debugging.assert_shapes, see: https://www.tensorflow.org/api_docs/python/tf/debugging/assert_shapes
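
As a rough illustration of how the two APIs line up: tf.debugging.assert_shapes accepts a list of (tensor, shape) pairs, whereas tensorcheck keys each entry by a (name, tensor) pair. The helper below, to_tf_assert_shapes_args, is hypothetical (part of neither tensorcheck nor TensorFlow). It is a minimal sketch that drops the display names and wraps the 1-d shorthand ('Q', 2) into one-element tuples, the form used in the TensorFlow documentation's examples.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

Dim = Union[str, int]
Spec = Union[Dim, Sequence[Dim]]


def to_tf_assert_shapes_args(
    specs: Dict[Tuple[str, Any], Spec]
) -> List[Tuple[Any, Tuple[Dim, ...]]]:
    """Hypothetical helper: tensorcheck-style dict -> (tensor, shape) pairs.

    The display name in each key is dropped, since the TensorFlow function
    takes tensors directly. A bare symbol or size (the 1-d shorthand) is
    wrapped into a one-element tuple; tuples and lists are kept as tuples.
    """
    pairs = []
    for (_name, tensor), spec in specs.items():
        shape = tuple(spec) if isinstance(spec, (tuple, list)) else (spec,)
        pairs.append((tensor, shape))
    return pairs
```

Assuming TensorFlow 2.x, the result would be passed straight through, e.g. tf.debugging.assert_shapes(to_tf_assert_shapes_args({("x", x): ('N', 'Q'), ("param", param): 'Q'})).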

Example

import tensorflow as tf
import torch
import numpy as np
from tensorcheck.static import assert_shapes

def model(x, y, param, other_param):
    assert_shapes({
        ("x", x): ('N', 'Q'),
        ("y", y): ('N', 'D'),
        ("param", param): 'Q',
        ("other param", other_param): 2,
    })
    ...

# asserts true
m = model(
    x=tf.ones([10, 2]),
    y=tf.ones([10, 1]),
    param=tf.ones([2]),
    other_param=tf.ones([2])
)

# asserts true
m = model(
    x=torch.ones([10, 2]),
    y=torch.ones([10, 1]),
    param=torch.ones([2]),
    other_param=torch.ones([2])
)

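# asserts true
# (tf.placeholder is TensorFlow 1.x graph-mode API; in TensorFlow 2.x use
# tf.compat.v1.placeholder after tf.compat.v1.disable_eager_execution())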
m = model(
    x=tf.placeholder(shape=[None, 3], dtype=tf.float32),
    y=tf.placeholder(shape=[None, 1], dtype=tf.float32),
    param=tf.ones([3]),
    other_param=tf.ones([2])
)

# asserts false
m = model(
    x=tf.constant(np.ones([10, 2])),
    y=tf.constant(np.ones([1, 10])),
    param=tf.ones([2]),
    other_param=tf.ones([2])
)
# => AssertionError: Tensor 'y' dim 0 was of size 1 but was expected to be 10 as declared by 'x' dim 0

# asserts false
m = model(
    x=tf.constant(np.ones([10, 2])),
    y=tf.constant(np.ones([10, 1])),
    param=tf.ones([2]),
    other_param=tf.ones([3])
)
# => AssertionError: Tensor 'other param' dim 0 was of size 3 but was expected to be 2 as declared directly

# asserts false
m = model(
    x=tf.ones([10, 2]),
    y=tf.ones([10, 1]),
    param=tf.ones([2, 1]),
    other_param=tf.ones([2])
)
# => AssertionError: Tensor 'param' was declared to have 1 tensor dim(s) but had 2.

# asserts false
m = model(
    x=tf.ones([10, 2]),
    y=tf.ones([10, 1]),
    param=tf.ones([1]),
    other_param=tf.ones([2])
)
# => AssertionError: Tensor 'param' dim 0 was of size 1 but was expected to be 2 as declared by 'x' dim 1
...
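
The checks exercised above can be approximated in a few lines of plain Python. The sketch below is a hypothetical re-implementation for illustration, not tensorcheck's actual code: check_shapes binds each string symbol to the first size it sees, compares later occurrences against that size, treats integers as literal sizes and a bare (non-tuple) spec as a 1-d shape, and skips dimensions whose size is None. Its error messages are modeled on the ones shown in the example.

```python
from typing import Any, Dict, Sequence, Tuple, Union

Dim = Union[str, int]
Spec = Union[Dim, Sequence[Dim]]


def check_shapes(specs: Dict[Tuple[str, Any], Spec]) -> None:
    """Hypothetical sketch of a symbolic shape check (not tensorcheck's code).

    Keys are (display name, tensor) pairs. Values are either a tuple of dims
    or a single dim, which declares a 1-d tensor. String dims are symbols
    that must agree across tensors; int dims are literal sizes.
    """
    # Symbol -> (size it was bound to, description of where it was bound).
    bound: Dict[str, Tuple[int, str]] = {}
    for (name, tensor), spec in specs.items():
        dims = tuple(spec) if isinstance(spec, (tuple, list)) else (spec,)
        shape = tuple(tensor.shape)
        # Rank check first, so per-dim messages always refer to valid dims.
        if len(shape) != len(dims):
            raise AssertionError(
                f"Tensor '{name}' was declared to have {len(dims)} tensor dim(s) "
                f"but had {len(shape)}."
            )
        for i, (dim, size) in enumerate(zip(dims, shape)):
            if size is None:
                continue  # unknown size, e.g. a placeholder's batch dim
            size = int(size)
            if isinstance(dim, int):
                if size != dim:
                    raise AssertionError(
                        f"Tensor '{name}' dim {i} was of size {size} "
                        f"but was expected to be {dim} as declared directly"
                    )
            elif dim in bound:
                expected, origin = bound[dim]
                if size != expected:
                    raise AssertionError(
                        f"Tensor '{name}' dim {i} was of size {size} "
                        f"but was expected to be {expected} as declared by {origin}"
                    )
            else:
                # First sighting of this symbol: bind it.
                bound[dim] = (size, f"'{name}' dim {i}")
```

Because it only reads .shape, the same function should work with NumPy arrays, PyTorch tensors and eager TensorFlow 2.x tensors, assuming each dimension converts to int or is None. It relies on dict insertion order (Python 3.7+), so the first tensor listed for a symbol is the one reported as declaring it.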
