In [2]:
import torch
import numpy as np
import time

In [3]:
# Toy problem: a 1x1 convolution from 2 to 3 channels on one unbatched 5x5 image.
in_image = torch.randn(2, 5, 5)
nn = torch.nn.Conv2d(in_channels=2, out_channels=3, kernel_size=1, stride=1, padding=0)
out_image = nn(in_image)

# Layer hyper-parameters. Only the first entry of each (height, width) pair is
# read, so kernel, stride and padding are treated as square from here on.
in_channels, out_channels = nn.in_channels, nn.out_channels
kernel_size, stride, padding = nn.kernel_size[0], nn.stride[0], nn.padding[0]

# Spatial extents; dimension 1 is used as height and dimension 2 as width.
in_height, in_width = in_image.shape[1], in_image.shape[2]
out_height, out_width = out_image.shape[1], out_image.shape[2]

# Parameters: the weights are detached from autograd, the bias is the layer's own attribute.
weights = nn.weight.detach()
bias = nn.bias

print(weights.shape)
print("Output size: ", out_image.shape)

torch.Size([3, 2, 1, 1])
Output size:  torch.Size([3, 5, 5])


In [4]:
def run_their_implementation():
    """Build the dense matrix of the convolution with the reference ("their") algorithm.

    Reads the notebook-level globals ``in_channels``, ``out_channels``,
    ``in_height``, ``in_width``, ``out_height``, ``out_width``, ``kernel_size``,
    ``stride``, ``padding`` and ``weights``; prints the elapsed time.

    Returns:
        torch.Tensor: matrix of shape
        ``(out_channels * out_height * out_width, in_channels * in_height * in_width)``.
        Row ``o * out_height * out_width + y * out_width + x`` holds the kernel
        weights of output element ``(o, y, x)``, placed at the flattened
        positions of the unpadded input elements they multiply. The bias is
        not part of the returned matrix.
    """
    # perf_counter is monotonic and high resolution, which suits short timings.
    now = time.perf_counter()
    in_width_p = in_width + padding * 2
    in_height_p = in_height + padding * 2

    size_p = in_height_p * in_width_p
    in_dim = size_p * in_channels
    out_dim = out_height * out_width * out_channels
    # The matrix is first built against the zero-padded input.
    res = torch.zeros((out_dim, in_dim))

    # A "row filler" is the flattened footprint of one output channel's kernel
    # when its window sits at the top-left corner of the padded input. It spans
    # from the first tap of input channel 0 to the last tap of the last channel.
    len_rows = (in_channels - 1) * size_p + (kernel_size - 1) * in_width_p + kernel_size
    channels = torch.zeros((out_channels, len_rows))

    for i_out in range(out_channels):
        for i_in in range(in_channels):
            i_p = i_in * size_p
            for k in range(kernel_size):
                # Kernel row k lands one padded image row further down.
                start = i_p + k * in_width_p
                end = start + kernel_size
                channels[i_out, start:end] = weights[i_out, i_in, k]

        # Slide the footprint: each output position shifts it by the flattened
        # offset of the window's top-left corner in the padded input.
        for i_out_height in range(out_height):
            for i_out_width in range(out_width):
                start = i_out_height * stride * in_width_p + i_out_width * stride
                end = start + len_rows
                output = i_out * out_height * out_width + i_out_height * out_width + i_out_width
                res[output, start:end] = channels[i_out]

    # Remove padding: keep only the columns that correspond to real input
    # pixels, i.e. the interior of every padded channel plane.
    keep = torch.zeros((in_channels, in_height_p, in_width_p), dtype=torch.bool)
    keep[:, padding:padding + in_height, padding:padding + in_width] = True
    lc = res[:, keep.flatten()].detach()

    print("Their implementation took: ", time.perf_counter() - now)
    return lc

In [114]:
def run_my_implementation():
    """Build the dense matrix of the convolution by direct index arithmetic.

    Reads the notebook-level globals ``in_channels``, ``out_channels``,
    ``in_height``, ``in_width``, ``out_height``, ``out_width``, ``kernel_size``,
    ``stride``, ``padding`` and ``weights``; prints the elapsed time.

    Returns:
        torch.Tensor: matrix of shape
        ``(in_channels * in_height * in_width, out_channels * out_height * out_width)``.
        Entry ``[i, j]`` is the kernel weight that multiplies flattened input
        element ``i`` in flattened output element ``j``, or 0 if ``i`` is outside
        the window of ``j``. Taps that fall into the zero padding have no row
        and are skipped.
    """
    # perf_counter is monotonic and high resolution, which suits short timings.
    now = time.perf_counter()

    in_plane = in_height * in_width
    out_plane = out_height * out_width
    kernel_plane = kernel_size * kernel_size

    final_matrix = torch.zeros((in_plane * in_channels, out_plane * out_channels))
    flattened_weights = weights.flatten()

    # Flattened offset of every output channel's plane, and of every output
    # channel's filter inside the flattened (out, in, k, k) weight tensor.
    out_plane_offsets = torch.arange(out_channels) * out_plane
    filter_offsets = torch.arange(out_channels) * (in_channels * kernel_plane)

    # Top-left corners of all windows that fit inside the padded input; the
    # position of a corner in its range is the matching output coordinate.
    window_tops = range(-padding, in_height + padding - kernel_size + 1, stride)
    window_lefts = range(-padding, in_width + padding - kernel_size + 1, stride)

    for channel_in_idx in range(in_channels):
        row_base = channel_in_idx * in_plane
        # Weight offsets of this input channel's kernel for every output channel.
        channel_weight_offsets = filter_offsets + channel_in_idx * kernel_plane

        for out_y, top in enumerate(window_tops):
            for out_x, left in enumerate(window_lefts):
                # Columns of this output pixel in every output channel; they do
                # not depend on the kernel tap, so compute them once per window.
                col_indices = out_plane_offsets + (out_y * out_width + out_x)

                for k_y in range(kernel_size):
                    in_y = top + k_y
                    if in_y < 0 or in_y >= in_height:
                        continue  # tap lies in the vertical padding
                    for k_x in range(kernel_size):
                        in_x = left + k_x
                        if in_x < 0 or in_x >= in_width:
                            continue  # tap lies in the horizontal padding

                        row_idx = row_base + in_y * in_width + in_x
                        weight_indices = channel_weight_offsets + (k_y * kernel_size + k_x)
                        final_matrix[row_idx, col_indices] = flattened_weights[weight_indices]

    print("My implementation took: ", time.perf_counter() - now)
    return final_matrix

In [112]:
# Larger problem: 4x4 kernel, stride 2, padding 1, 16 -> 64 channels on a 14x14 image.
in_image = torch.randn(16, 14, 14)
nn = torch.nn.Conv2d(in_channels=16, out_channels=64, kernel_size=4, stride=2, padding=1)
out_image = nn(in_image)

# Layer hyper-parameters. Only the first entry of each (height, width) pair is
# read, so kernel, stride and padding are treated as square from here on.
in_channels, out_channels = nn.in_channels, nn.out_channels
kernel_size, stride, padding = nn.kernel_size[0], nn.stride[0], nn.padding[0]

# Spatial extents; dimension 1 is used as height and dimension 2 as width.
in_height, in_width = in_image.shape[1], in_image.shape[2]
out_height, out_width = out_image.shape[1], out_image.shape[2]

# Parameters: the weights are detached from autograd, the bias is the layer's own attribute.
weights = nn.weight.detach()
bias = nn.bias

print(weights.shape)
print("Output size: ", out_image.shape)
# Flattened input length and flattened output length.
print("Matrix size: ", in_width * in_height * in_channels, out_width * out_height * out_channels)

torch.Size([2, 2, 3, 3])
Output size:  torch.Size([2, 2, 2])
Matrix size:  32 8


In [113]:
# The reference matrix has one row per output element, while final_matrix has one
# row per input element; transpose the reference so both share one orientation.
lc = run_their_implementation().transpose(0, 1)
final_matrix = run_my_implementation()

# If this shows False, compare lc[i] with final_matrix[i] row by row
# (torch.allclose with the same tolerances) to find the rows that differ.
torch.allclose(lc, final_matrix, rtol=1e-06, atol=1e-06)

Their implementation took:  0.001859903335571289
kernel size:  3
Offset:  1
Channel:  0 Top bound height:  3 Top bound width:  3
Channel:  0 Col:  1 Row:  1 Column out:  0 Row out:  0
Final matrix row idx:  0
Final matrix row idx:  1
Final matrix row idx:  2
Final matrix row idx:  4
Final matrix row idx:  5
Final matrix row idx:  6
Final matrix row idx:  8
Final matrix row idx:  9
Final matrix row idx:  10
Channel:  0 Col:  1 Row:  2 Column out:  0 Row out:  1
Final matrix row idx:  1
Final matrix row idx:  2
Final matrix row idx:  3
Final matrix row idx:  5
Final matrix row idx:  6
Final matrix row idx:  7
Final matrix row idx:  9
Final matrix row idx:  10
Final matrix row idx:  11
Channel:  0 Col:  2 Row:  1 Column out:  1 Row out:  0
Final matrix row idx:  4
Final matrix row idx:  5
Final matrix row idx:  6
Final matrix row idx:  8
Final matrix row idx:  9
Final matrix row idx:  10
Final matrix row idx:  12
Final matrix row idx:  13
Final matrix row idx:  14
Channel:  0 Col:  2 Row:

True

In [95]:
# Shape of the reference matrix.
lc.shape

torch.Size([16, 18])

In [96]:
# Shape of my matrix; it should match the reference shape.
final_matrix.shape

torch.Size([16, 18])

In [103]:
# Spot check: row 3 of my matrix.
final_matrix[3, :]

tensor([ 0.0771,  0.0000,  0.0000,  0.1532,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000, -0.3893,  0.0000,  0.0000,  0.2367,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000])

In [104]:
# Spot check: row 3 of the reference matrix, to compare with final_matrix[3]
# (assumes lc has already been transposed to the same orientation).
lc[3, :]

tensor([ 0.0771,  0.0000,  0.0000,  0.1532,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000, -0.3893,  0.0000,  0.0000,  0.2367,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000])

In [46]:
# Kernel weights as one flat vector, in (out, in, kernel row, kernel col) order.
weights.reshape(-1)

tensor([ 0.2414,  0.0328,  0.0840,  0.1218, -0.0303, -0.1360, -0.0209,  0.0347,
        -0.0776,  0.1350,  0.0608,  0.2054,  0.2260, -0.0920,  0.1001,  0.0320,
         0.1944, -0.2222, -0.1359, -0.0866,  0.0855, -0.1493,  0.1561,  0.1355,
        -0.2109,  0.1817, -0.1664, -0.0834,  0.1701,  0.1444,  0.0270,  0.1022])

In [303]:
# Flat weights rounded to three decimals, for easier visual comparison.
(weights.flatten() * 1000).round() / 1000

tensor([-0.1110,  0.1750, -0.0280, -0.1190,  0.1040,  0.2290, -0.1070,  0.0580,
        -0.1630,  0.0610, -0.1210,  0.2310, -0.0050,  0.1360,  0.0930,  0.1820,
        -0.0690, -0.0980,  0.1880,  0.1880,  0.0360,  0.1740, -0.0710,  0.1190,
         0.2290, -0.1090, -0.0620, -0.0430,  0.1360, -0.1790, -0.0440, -0.0660,
        -0.1650, -0.0590,  0.0040,  0.1840,  0.1790, -0.0890, -0.0310,  0.1200,
        -0.1570, -0.2080, -0.1040, -0.1940, -0.1690,  0.1290, -0.0350, -0.0760,
         0.1770, -0.2070, -0.1460, -0.1580,  0.1880, -0.1750])

In [222]:
# Scratch check of the "every in_channels-th position" selection on a 6-element example.
final_matrix_weight_indices = torch.tensor([-0.0289, -0.1273,  0.0912,  0.1214, -0.1635,  0.0009])
# Only the tensor's length matters here: True marks positions 0, in_channels, 2 * in_channels, ...
(torch.arange(len(final_matrix_weight_indices)) % in_channels) == 0

tensor([ True, False,  True, False,  True, False])