In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
# three channel image size 32x32
image = torch.rand(1, 3, 32, 32, 32)
k_size = (3,3,3)
s_size = (1,1,1)

In [3]:
# Pytorch Conv2D
conv = nn.Conv3d(3, 3, k_size, s_size, bias=False)
print(conv.weight.shape)
expected = conv(image)
print(expected.shape)
print(expected.data.type())

torch.Size([3, 3, 3, 3, 3])
torch.Size([1, 3, 30, 30, 30])
torch.FloatTensor


In [73]:
# Manual Conv3D without padding and no bias - adding a padding should not be difficult

# image unfold - dimension, size and step

patches = image.unfold(2, k_size[0], s_size[0]).unfold(3, k_size[1], s_size[1]).unfold(4, k_size[2], s_size[2])
patches = patches.reshape(1, 3, -1, k_size[0], k_size[1], k_size[2])

# now we get tensor with the shape of (N, C, No. windows, D, H, W)
patches = patches.permute(0, 2, 1, 3, 4, 5)

print(patches.unsqueeze(2).shape)
print(conv.weight.unsqueeze(0).unsqueeze(1).shape)

result = (patches.unsqueeze(2) * conv.weight.unsqueeze(0).unsqueeze(1)).sum([3,4,5,6])

result = result.permute(0, 2, 1)

# out_d = out_h = out_w = int(result.size(2) ** (1/3))
result = result.view(1, 3, 30, 30, 30)

print(result.shape)


# def conv2d(feat, k_size=(3,3), s_size=(3,3), weights=None):
#     # extract all patches from feat
    

torch.Size([1, 27000, 1, 3, 3, 3, 3])
torch.Size([1, 1, 3, 3, 3, 3, 3])
torch.Size([1, 3, 30, 30, 30])


In [9]:
import time

def th_generate_grid(batch_size, input_depth, input_height, input_width):
    grid = np.meshgrid(
        range(input_depth), range(input_height), range(input_width), indexing='ij'
    )
    grid = np.stack(grid, axis=-1)
    grid = grid.reshape(-1, 3)
    
    # similar to unsqueeze
    grid = np.expand_dims(grid, 0)
    grid = np.tile(grid, [batch_size, 1, 1])
    grid = torch.from_numpy(grid)
    return grid

start_time = time.time()
test = th_generate_grid(5, 32, 32, 32)
print(test.shape)
print("--- %s seconds ---" % (time.time() - start_time))


torch.Size([5, 32768, 3])
--- 0.005708932876586914 seconds ---


In [242]:
import itertools
start_time = time.time()
a = np.array(np.meshgrid(range(32), range(32), range(32), indexing='ij')).T.reshape(-1,3)
a = np.tile(a, [5, 1, 1])

print("--- %s seconds ---" % (time.time() - start_time))

print((a == test).all())

--- 0.0024862289428710938 seconds ---
False


In [230]:
coords = torch.from_numpy(a).type(torch.FloatTensor)

def flatten(a): return a.reshape(a.numel())

def repeat(a, repeats): return flatten(torch.transpose(a.repeat(repeats, 1), 0, 1))

idx = repeat(torch.arange(0, 5), 10)


In [4]:
from deformable_conv_2d import ConvOffset2D

image = torch.rand(1, 3, 32, 32)
def_conv = ConvOffset2D(3)
result = def_conv(image)
print(result.shape)

indices  torch.Size([3072, 3])
inds  tensor([   0,    1,    2,  ..., 1052, 1053, 1055])
indices  torch.Size([3072, 3])
inds  tensor([  33,   34,   35,  ..., 1085, 1086, 1087])
indices  torch.Size([3072, 3])
inds  tensor([  32,   33,   34,  ..., 1084, 1085, 1087])
indices  torch.Size([3072, 3])
inds  tensor([   1,    2,    3,  ..., 1053, 1054, 1055])
torch.Size([1, 3, 32, 32])


In [None]:
import time
import itertools
from deformable_conv_3d import ConvOffset3D, deform_conv3d

image = torch.rand(1, 16, 64, 64, 64)
def_conv3 = deform_conv3d(16, 32, (3,3,3))

start_time = time.time()
result = def_conv3(image)
print("--- %s seconds ---" % (time.time() - start_time))
# print(result)