lengstrom/tensorguard

tensorguard

Pretty runtime type checking for PyTorch and NumPy tensors!

Install

git clone git@github.com:lengstrom/tensorguard.git
pip install -e tensorguard

Example usage

As a decorator:

from tensorguard import tensorguard, Tensor as T
import torch as ch

@tensorguard
def inference(x: T(['bs', 3, 224, 224], 'float16', 'cpu'), y: T(['bs'], 'int64')):
    pass

# make examples with wrong dtype
x = ch.ones(128, 3, 224, 224, dtype=ch.float32)
# make labels with wrong batch size
y = ch.ones(256)

# checks happen at runtime with @tensorguard decorator
inference(x, y)
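To illustrate what a named dimension such as 'bs' implies, here is a minimal sketch of how shape specs could be checked across several arguments. It is not tensorguard's implementation: `check_shapes` is a hypothetical helper, it works on plain shape tuples, and it assumes that a string dimension must resolve to the same size everywhere it appears.

```python
def check_shapes(expected, actual):
    """Hypothetical helper: compare shape specs against actual shapes.

    Integers must match exactly, None matches anything, and a string
    names a dimension that must have the same size wherever it is used.
    Returns a list of human-readable mismatch messages.
    """
    bindings = {}  # dimension name -> size seen first
    errors = []
    for i, (spec, shape) in enumerate(zip(expected, actual)):
        if len(spec) != len(shape):
            errors.append(f"arg {i}: expected {len(spec)} dims, got {len(shape)}")
            continue
        for axis, (want, got) in enumerate(zip(spec, shape)):
            if want is None:
                continue  # wildcard dimension
            if isinstance(want, str):
                bound = bindings.setdefault(want, got)
                if bound != got:
                    errors.append(
                        f"arg {i}, axis {axis}: '{want}' was {bound}, got {got}"
                    )
            elif want != got:
                errors.append(f"arg {i}, axis {axis}: expected {want}, got {got}")
    return errors


# Same shapes as the example above: the batch sizes of x and y disagree.
errors = check_shapes(
    [["bs", 3, 224, 224], ["bs"]],
    [(128, 3, 224, 224), (256,)],
)
print(errors)  # ["arg 1, axis 0: 'bs' was 128, got 256"]
```

Collecting every mismatch in a list, rather than raising on the first one, is one way to report all problems in a single pass.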

As a standalone assertion:

from tensorguard import tensorcheck, Tensor
import torch as ch

x = ch.randn(4, 4).to(dtype=ch.float32)
x_expected = Tensor([4, 4])

# check one tensor at a time
tensorcheck(x, x_expected)

# or several at once
y = ch.ones(4, dtype=ch.int64)
y_expected = Tensor([4], 'int64')
tensorcheck([x, y], [x_expected, y_expected])

Leaving a field unspecified, or setting it to None, makes it a wildcard that matches anything; by default, every field is None. You can also check which library a tensor comes from, either 'numpy' or 'pytorch':

tensorcheck(x, Tensor([4, None], library='numpy', device=None))
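The wildcard rule can be sketched in a few lines. This is only an illustration under stated assumptions, not the library's code: `Spec` and `mismatched_fields` are hypothetical names, tensors are described by plain field values instead of real PyTorch or NumPy objects, and shapes are compared as whole tuples (per-dimension wildcards are left out for brevity).

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class Spec:
    # Hypothetical stand-in for a tensor description; None means "match anything".
    shape: Optional[tuple] = None
    dtype: Optional[str] = None
    device: Optional[str] = None
    library: Optional[str] = None


def mismatched_fields(expected, actual):
    """Return the names of fields that are set on `expected` and differ on `actual`."""
    bad = []
    for f in fields(expected):
        want = getattr(expected, f.name)
        if want is None:
            continue  # unset field: wildcard, always passes
        if want != getattr(actual, f.name):
            bad.append(f.name)
    return bad


actual = Spec(shape=(4, 4), dtype="float32", device="cpu", library="pytorch")

print(mismatched_fields(Spec(), actual))                 # []
print(mismatched_fields(Spec(library="numpy"), actual))  # ['library']
```

Because every field defaults to None, an empty spec accepts any tensor, and each field you set narrows the check.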

Related work

TODOs:

  • use a different color for each individual error found in the runtime type checking
