In [1]:
import torch

In [3]:
weights_map = torch.zeros((3, 5, 8))
weights_map

tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]])

In [64]:
weights_map[2, 1, 6] = 1  # 94
weights_map[0, 3, 5] = 1  # 29
weights_map[1, 4, 7] = 1  # 79
weights_map[0, 0, 0] = 1  # 94

print(torch.nonzero(weights_map))
print(torch.nonzero(weights_map.flatten()))

height, width, depth = weights_map.shape

# x * (W * D) + y * D + z
# elem_(-3) * (dim_(-2) * dim(-1)) + elem_(-2) * dim(-1) + elem_(-1)

0 * (width * depth) + 0 * depth + 0

tensor([[0, 0, 0],
        [0, 3, 5],
        [1, 4, 7],
        [2, 1, 6]])
tensor([[ 0],
        [29],
        [79],
        [94]])


0

In [68]:
# Pick multinomial in whole tensor
# flatten = view(-1)
K = torch.multinomial(weights_map.flatten(), 4, replacement=False)

# Write a function that converts 1D indices to 3D indices
def get_indices_3d(indices, shape):
    indices_3d = torch.zeros((len(indices), len(shape)), dtype=torch.long)
    for i in range(len(indices)):
        for j in range(len(shape)):
            indices_3d[i, j] = indices[i] % shape[j]
            indices[i] = indices[i] // shape[j]
    return indices_3d

index_3d = (K // (weights_map.size(1) * weights_map.size(2)),
            (K // weights_map.size(2)) % weights_map.size(1),
            K % weights_map.size(2))

print(torch.stack(index_3d, dim=1))

tensor([[0, 0, 0],
        [0, 3, 5],
        [1, 4, 7],
        [2, 1, 6]])


In [69]:
num_samples = 4
samples = torch.multinomial(weights_map.flatten(), num_samples, replacement=False)

# Convert 1D indices to 3D indices
# [num_images, H, W]
batch_idx = torch.stack((K // (weights_map.size(1) * weights_map.size(2)),
                        (K // weights_map.size(2)) % weights_map.size(1),
                        K % weights_map.size(2)), 
                        dim=1)

print(batch_idx)

tensor([[0, 0, 0],
        [0, 3, 5],
        [1, 4, 7],
        [2, 1, 6]])


In [40]:
weights_map = torch.zeros((400, 540, 960))
# Set 10% of values of weights_map to 1
weights_map[torch.rand(weights_map.shape) < 0.1] = 1

num_samples = 400

In [33]:
%%timeit -n 1 -r 1

# about 1s for 400 samples over 1000 540p images.

indices = torch.zeros((num_samples, 3), dtype=torch.long)
sampled_pixels = 0

pixels_per_iter = 4
nb_iters = int(num_samples / pixels_per_iter)

# Old sampling
for i in range(nb_iters):
    weights_map_i = weights_map[i]
    samples = torch.multinomial(weights_map_i.flatten(), pixels_per_iter, replacement=False)
    h, w = torch.div(samples, 960, rounding_mode="floor"), samples % 960
    indices[sampled_pixels : (sampled_pixels + pixels_per_iter), 0] = i
    indices[sampled_pixels : (sampled_pixels + pixels_per_iter), 1] = h
    indices[sampled_pixels : (sampled_pixels + pixels_per_iter), 2] = w
    sampled_pixels += pixels_per_iter

print(indices[-10:, :])


tensor([[ 97, 429, 225],
        [ 97,  84,  93],
        [ 98, 327, 832],
        [ 98,  84, 861],
        [ 98,  26, 680],
        [ 98, 200, 496],
        [ 99, 451, 874],
        [ 99, 232, 299],
        [ 99, 437, 599],
        [ 99, 181, 446]])
936 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [41]:
#%%timeit -n 1 -r 1

# New sampling
# NB: Multinomial is for max 2^24 values = 16777216
# 1000*960*540 = about 2^29 (> 28, < 29)
num_images, height, width = weights_map.shape
images_per_iter = 20

indices = torch.zeros((num_samples, 3), dtype=torch.long)

for i in range(0, num_images, images_per_iter):
    weights_map_i = weights_map[i : i + images_per_iter]
    samples = torch.multinomial(weights_map_i.flatten(), int(num_samples / images_per_iter), replacement=False)
    indices = torch.stack((samples // (weights_map_i.size(1) * weights_map_i.size(2)),
                            (samples // weights_map_i.size(2)) % weights_map_i.size(1),
                            samples % weights_map_i.size(2)), 
                            dim=1)
    indices[:, 0] += i

# 3.6 seconds, not worth it.

print(indices.shape)

#print(indices[-10:, :])

torch.Size([20, 3])
