diff --git a/torchsparse/utils/helpers.py b/torchsparse/utils/helpers.py index 0830295..76b6552 100644 --- a/torchsparse/utils/helpers.py +++ b/torchsparse/utils/helpers.py @@ -256,3 +256,10 @@ def make_tuple(inputs, dimension=3): elif isinstance(inputs, tuple): assert len(inputs) == dimension, 'Input length and dimension mismatch' return inputs + elif isinstance(inputs, torch.Tensor): + inputs = inputs.squeeze() + shape = inputs.shape + assert len(shape) == 1 and shape[0] == dimension, 'Input length and dimension mismatch' + if inputs.is_cuda: + inputs = inputs.cpu() + return tuple((t.item() for t in inputs))