In [1]:
import torch
import nvidia.nvcomp as nvcomp

def encode(codec, x):
    """Compress a uint8 tensor with an nvcomp codec.

    The tensor is flattened to a dense 1-D buffer before compression. The
    original shape is therefore returned alongside the compressed data and must
    be passed back to ``decode`` to restore it.

    Args:
        codec: An ``nvcomp.Codec`` (presumably; any object whose ``encode``
            accepts the array produced by ``nvcomp.as_array`` will do).
        x: Tensor to compress. Must have dtype ``torch.uint8``.

    Returns:
        A ``(c, shape)`` tuple. ``c`` is whatever ``codec.encode`` returns, and
        ``shape`` is ``x.shape``.

    Raises:
        TypeError: If ``x.dtype`` is not ``torch.uint8``.
    """
    # Use an explicit check rather than ``assert``. Assertions are stripped
    # under ``python -O``, which would silently let non-uint8 data reach a codec
    # configured for 1-byte elements (such as the '|u1' codec in this notebook).
    if x.dtype != torch.uint8:
        raise TypeError(f"Input tensor must be of type torch.uint8, got {x.dtype}")
    shape = x.shape
    # contiguous() ensures view(-1) is a dense 1-D buffer. It is a no-op when
    # x is already contiguous.
    x_flat = x.contiguous().view(-1)
    nvcomp_array = nvcomp.as_array(x_flat)
    c = codec.encode(nvcomp_array)
    return c, shape

def decode(codec, c, shape):
    """Decompress data produced by ``encode`` and restore its original shape.

    Args:
        codec: The codec used for compression. Its ``decode`` result must
            expose ``to_dlpack()``.
        c: Compressed data as returned by ``encode``.
        shape: The original tensor shape, as returned by ``encode``.

    Returns:
        A torch tensor viewing the decompressed buffer, reshaped to ``shape``.
    """
    # Hand the decoded buffer to torch through a DLPack capsule, which shares
    # memory with it instead of copying.
    capsule = codec.decode(c).to_dlpack()
    flat = torch.utils.dlpack.from_dlpack(capsule)
    return flat.view(shape)

In [2]:
# Test input: an all-ones uint8 tensor of shape (1, 3, 256, 256) on the GPU.
# A constant tensor like this should compress extremely well.
x = torch.ones(
    (1, 3, 256, 256),
    dtype=torch.uint8,
    device='cuda',
)
# Bitcomp codec for 1-byte unsigned elements. '|u1' is the NumPy
# array-interface typestr for uint8, matching the dtype `encode` requires.
# NOTE(review): the meaning of algorithm_type=0 comes from nvcomp's Bitcomp
# options; confirm against the nvcomp docs before changing it.
codec = nvcomp.Codec(
    algorithm="bitcomp",
    data_type='|u1',
    algorithm_type=0,
)

In [3]:
# Round trip: compress x, report the compression ratio, then verify that
# decompression reproduces the input exactly.
c, shape = encode(codec, x)
# encode requires uint8, so x.numel() is also x's size in bytes.
# NOTE(review): buffer_size is presumably the compressed size in bytes; confirm
# in the nvcomp docs.
print("Ratio:", x.numel() / c.buffer_size)
xhat = decode(codec, c, shape)
assert torch.equal(xhat, x)

Ratio: 327.68
