Skip to content

Commit

Permalink
fix test logic and make py2.7 compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Aug 28, 2017
1 parent 880fbaf commit cca90a4
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tests/test_pytorch_np.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from tensorboardX import x2num
import torch
import numpy as np
shapes = [(3, 10, 10), (1, ), (1, 2, 3, 4, 5)]
tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)]

def test_pytorch_np():

for shape in shapes:
for tensor in tensors:
# regular tensor
assert(isinstance(x2num.makenp(torch.Tensor(*shape)), np.ndarray))
assert isinstance(x2num.makenp(tensor), np.ndarray)

# CUDA tensor
assert(isinstance(x2num.makenp(torch.Tensor(*shape).cuda()), np.ndarray))
if torch.cuda.device_count()>0:
assert isinstance(x2num.makenp(tensor.cuda()), np.ndarray)

# regular variable
assert(isinstance(x2num.makenp(torch.autograd.variable.Variable(torch.Tensor((*shape)))), np.ndarray))
assert isinstance(x2num.makenp(torch.autograd.variable.Variable(tensor)), np.ndarray)

# CUDA variable
assert(isinstance(x2num.makenp(torch.autograd.variable.Variable(torch.Tensor((*shape))).cuda()), np.ndarray))
if torch.cuda.device_count()>0:
assert isinstance(x2num.makenp(torch.autograd.variable.Variable(tensor)).cuda(), np.ndarray)

# python primitive type
assert(isinstance(x2num.makenp(0), np.ndarray))
Expand Down

0 comments on commit cca90a4

Please sign in to comment.