https://github.com/leelabcnbc/unsup-pytorch/issues/2

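The computation performed by `ConvTranspose2d` is actually a true convolution (i.e., one of the two inputs is flipped), rather than a cross-correlation.

However, depending on how you think about it, the kernel may not look flipped: each code element simply adds an unflipped copy of the kernel to the output at its own offset.

See the decomposition at the end of this notebook.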

In [1]:
import numpy as np
from torch import FloatTensor
from torch.autograd import Variable
from torch.nn.functional import mse_loss, conv2d

In [2]:
from torch.nn import ConvTranspose2d

In [3]:
from scipy.signal import convolve

In [4]:
import torch

In [5]:
# Seed PyTorch's global RNG so the random input and the randomly initialised
# layer weights used below are repeatable. The result is bound to `_` so the
# returned object is not echoed as cell output.
_ = torch.manual_seed(0)

In [6]:
# Number of code channels fed to the transposed convolution, and its kernel size.
num_channel_code, kernel_size = 4, 9
# Random code tensor: batch of 1, `num_channel_code` channels, 17x17 spatial grid,
# filled in place with samples from a normal distribution (mean 0, std 1).
input_this = FloatTensor(1, num_channel_code, 17, 17).normal_(0, 1)

In [7]:
def forward():
    """Check that `ConvTranspose2d` equals a sum of per-channel convolutions.

    Builds a bias-free `ConvTranspose2d` mapping `num_channel_code` channels to
    a single channel with a `kernel_size` kernel, applies it to the module-level
    `input_this`, and rebuilds the same output with `scipy.signal.convolve`,
    one code channel at a time.

    Prints, in order: the layer output shape, the per-channel weight shape, the
    input shape, and the largest absolute difference between both results.

    Raises:
        AssertionError: if the channel counts or output shapes disagree, or if
            the largest absolute difference is not below 1e-6.
    """
    decoder = ConvTranspose2d(num_channel_code, 1, kernel_size, bias=False)
    code_var = Variable(input_this, requires_grad=True)
    torch_result = decoder(code_var).data.numpy()[0, 0]
    print(torch_result.shape)

    # Reproduce the layer output with plain `convolve` from `scipy.signal`.
    # Index 0 on the second weight axis keeps one 2-D kernel per code channel.
    kernels = decoder.weight.data.numpy()[:, 0]
    print(kernels.shape)
    code_maps = code_var.data.numpy()[0]
    print(code_maps.shape)

    assert kernels.shape[0] == code_maps.shape[0]
    scipy_result = np.zeros_like(torch_result, dtype=np.float64)
    for kernel_2d, code_2d in zip(kernels, code_maps):
        # Argument order does not matter here:
        # convolve(code_2d, kernel_2d) gives the same result.
        scipy_result += convolve(kernel_2d, code_2d)
    assert torch_result.shape == scipy_result.shape

    max_abs_err = abs(torch_result - scipy_result).max()
    print(max_abs_err)
    assert max_abs_err < 1e-6

In [8]:
# Run the check: prints the three array shapes and the max absolute error,
# and raises AssertionError if that error is not below 1e-6.
forward()

(25, 25)
(4, 9, 9)
(4, 17, 17)
6.25848770142e-07


In [9]:
# Suppose the code is [[1,1], [1,1]]
# and the weight is [[1,2,3],
#                    [4,5,6],
#                    [7,8,9]].
# Let's see how each element of the code contributes to the output.

def forward_debug():
    """Walk through a tiny `ConvTranspose2d` example and verify it two ways.

    Uses an all-ones 2x2 code and a 3x3 kernel holding 1..9 (row-major). The
    layer output is compared against:

    1. a single `scipy.signal.convolve` of the kernel with the whole code, and
    2. the sum of convolutions of the kernel with each one-hot 2x2 code, which
       shows what every individual code element contributes.

    Prints the layer output, the kernel, the max absolute error of check 1,
    then for each code position its index, the one-hot code and its
    contribution, and finally the max absolute error of check 2.

    Raises:
        AssertionError: if the weight shape is not (1, 1, 3, 3), if shapes
            disagree, or if either max absolute error is not below 1e-6.
    """
    code_np = np.ones((1, 1, 2, 2), dtype=np.float64)
    kernel_np = np.arange(1, 10, dtype=np.float64).reshape(1, 1, 3, 3)

    decoder = ConvTranspose2d(1, 1, 3, bias=False)
    assert decoder.weight.size() == (1, 1, 3, 3)
    # Overwrite the randomly initialised weights with the known kernel.
    decoder.weight.data[...] = FloatTensor(kernel_np)

    code_var = Variable(FloatTensor(code_np))
    torch_result = decoder(code_var).data.numpy()[0, 0]
    print(torch_result)

    code_2d = code_np[0, 0]
    kernel_2d = kernel_np[0, 0]
    print(kernel_2d)

    # Check 1: one convolution of the kernel with the full code.
    whole_ref = np.zeros_like(torch_result, dtype=np.float64)
    whole_ref += convolve(kernel_2d, code_2d)
    assert torch_result.shape == whole_ref.shape
    whole_err = abs(torch_result - whole_ref).max()
    print(whole_err)
    assert whole_err < 1e-6

    # Check 2: decompose the code into one-hot pieces and sum their outputs.
    summed_ref = np.zeros_like(torch_result, dtype=np.float64)
    for idx in range(code_2d.size):
        one_hot = np.zeros(code_2d.shape, dtype=np.float64)
        one_hot.flat[idx] = 1  # row-major position `idx` of the 2x2 code
        contribution = convolve(kernel_2d, one_hot)
        print(idx)
        print(one_hot)
        print(contribution)
        summed_ref += contribution

    assert torch_result.shape == summed_ref.shape
    summed_err = abs(torch_result - summed_ref).max()
    print(summed_err)
    assert summed_err < 1e-6

In [10]:
# Run the small worked example: prints the layer output, the kernel, and the
# contribution of each one-hot code position, asserting both checks pass.
forward_debug()

[[  1.   3.   5.   3.]
 [  5.  12.  16.   9.]
 [ 11.  24.  28.  15.]
 [  7.  15.  17.   9.]]
[[ 1.  2.  3.]
 [ 4.  5.  6.]
 [ 7.  8.  9.]]
0.0
0
[[ 1.  0.]
 [ 0.  0.]]
[[ 1.  2.  3.  0.]
 [ 4.  5.  6.  0.]
 [ 7.  8.  9.  0.]
 [ 0.  0.  0.  0.]]
1
[[ 0.  1.]
 [ 0.  0.]]
[[ 0.  1.  2.  3.]
 [ 0.  4.  5.  6.]
 [ 0.  7.  8.  9.]
 [ 0.  0.  0.  0.]]
2
[[ 0.  0.]
 [ 1.  0.]]
[[ 0.  0.  0.  0.]
 [ 1.  2.  3.  0.]
 [ 4.  5.  6.  0.]
 [ 7.  8.  9.  0.]]
3
[[ 0.  0.]
 [ 0.  1.]]
[[ 0.  0.  0.  0.]
 [ 0.  1.  2.  3.]
 [ 0.  4.  5.  6.]
 [ 0.  7.  8.  9.]]
0.0
