In [2]:
import numpy as np
import tensorflow as tf
print(tf.__version__)

def arr_diff(x, y):
    assert list(x.shape) == list(y.shape)
    x = np.ravel(x)
    y = np.ravel(y)
    return (x - y) @ (x - y)

2.4.1


# Conv2D

In [3]:
x = np.random.randn(3, 32, 29, 7)
w = np.random.randn(5, 5, 7, 10)

@tf.function
def tf_conv(x, w):
    y = tf.nn.conv2d(input=x, filters=w, strides=[1, 1, 1, 1], padding='SAME')
    return y

fun = tf_conv.get_concrete_function(
    tf.TensorSpec(x.shape, tf.dtypes.float32),
    tf.TensorSpec(w.shape, tf.dtypes.float32),
)
mlir = tf.mlir.experimental.convert_function(fun)
print(mlir)

#y = tf_conv(x, w)
#print(y.shape)



module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 561 : i32}} {
  func @__inference_tf_conv_7(%arg0: tensor<3x32x29x7xf32> {tf._user_specified_name = "x"}, %arg1: tensor<5x5x7x10xf32> {tf._user_specified_name = "w"}) -> tensor<3x32x29x10xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "x,w", outputs = "identity_RetVal"}} {
    %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<3x32x29x7xf32>, tensor<5x5x7x10xf32>) -> tensor<3x32x29x10xf32>
    %1 = "tf.Identity"(%0) {device = ""} : (tensor<3x32x29x10xf32>) -> tensor<3x32x29x10xf32>
    return %1 : tensor<3x32x29x10xf32>
  }
}


In [4]:
class MyConv2DDescriptor:
    
    def get_explicit_padding(self):
        padding = self.padding
        pad_top, pad_bottom, pad_left, pad_right = (None, None, None, None)
        xHeight = self.x[1]
        xWidth = self.x[2]
        kHeight = self.w[0]
        kWidth = self.w[1]
        sh = self.strides[1]
        sw = self.strides[2]
        
        if not isinstance(padding, str):
            #explicit padding
            #[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
            assert len(padding) == 4
            for x in padding: assert len(x) == 2
            for x in padding[0]: assert x == 0
            for x in padding[3]: assert x == 0
            pad_top = padding[1][0]
            pad_bottom = padding[1][1]
            pad_left = padding[2][0]
            pad_right = padding[2][1]
        
        elif padding == 'SAME':
            if xHeight % sh == 0:
                pad_height = max(kHeight - sh, 0)
            else:
                pad_height = max(kHeight - (xHeight % sh), 0)
            if xWidth % sw == 0:
                pad_width = max(kWidth - sw, 0)
            else:
                pad_width = max(kWidth - (xWidth % sw), 0)
            
            pad_top = pad_height // 2
            pad_bottom = pad_height - pad_top
            pad_left = pad_width // 2
            pad_right = pad_width - pad_left
            
        
        elif padding == 'VALID':
            pad_top = 0
            pad_bottom = 0
            pad_left = 0
            pad_right = 0
            
        else:
            raise Exception('Unknown padding')
        
        self.padding = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
            
    def get_output_shape(self):
        pad_top, pad_bottom = self.padding[1]
        pad_left, pad_right = self.padding[2]
        xHeight = self.x[1]
        xWidth = self.x[2]
        kHeight = self.w[0]
        kWidth = self.w[1]
        sh = self.strides[1]
        sw = self.strides[2]
        
        assert self.x[3] == self.w[2]
        
        # Padded input dimensions
        xPHeight = xHeight + pad_top + pad_bottom
        xPWidth = xWidth + pad_left + pad_right
        
        outN = self.x[0]
        outHeight = (xPHeight - kHeight) // sh + 1
        outWidth = (xPWidth - kWidth) // sw + 1
        outC = self.w[3]
        
        self.output_shape = [outN, outHeight, outWidth, outC]
    
    def __init__(self, x, w, strides, padding):
        self.x = x
        self.w = w
        self.strides = strides
        self.padding = padding
        self.output_shape = None
                        
        # Check strides
        assert len(strides) == 4
        assert strides[0] == 1
        assert strides[3] == 1
        
        # Compute real paddings
        self.get_explicit_padding()
        print('Paddings:', self.padding)
        
        # Compute output shape
        self.get_output_shape()
        print('Output shape:', self.output_shape)
    

    
    

    
    
    # compute using tf version directly
    def compute_tf(self, x, w):
        y = tf.nn.conv2d(input=x, filters=w, strides=self.strides, padding=[[0,0],[0,0],[0,0],[0,0]])
        return y.numpy()

    # pad input tensor using paddings infos
    def pad_input(self, x):
        return np.pad(x, self.padding)
    
    # 7 for loop using conv2d formula
    def compute_naive(self, x, w):
        sh = self.strides[1]
        sw = self.strides[2]
        kHeight = self.w[0]
        kWidth = self.w[1]
        xChannels = self.w[2]
        res = np.zeros(self.output_shape)
        
        for b in range(self.output_shape[0]):
            for i in range(self.output_shape[1]):
                for j in range(self.output_shape[2]):
                    for k in range(self.output_shape[3]):
                        val = 0
                        for di in range(kHeight):
                            for dj in range(kWidth):
                                for q in range(xChannels):
                                    val += x[b, sh*i + di, sw * j + dj, q] * w[di, dj, q, k]
                        res[b, i, j, k] = val
                                
                                
                        
        return res
    
    # flatten and extend the input to run conv2d as a matmul
    # turned from (xB, xH, xW, xC) into (xB*yH*yW, kH*kW*xC)
    def matmul_flatten_x(self, x):
        sh, sw = self.strides[1:-1]
        xB, xHeight, xWidth, xC = self.x
        kHeight, kWidth = self.w[:2]
        yHeight, yWidth, yC  = self.output_shape[1:]
        
        # divide 1 DIM into 3 to build the mat more easily
        new_x = np.zeros((xB, yHeight, yWidth, kHeight*kWidth*xC))
        
        for b in range(xB):
            for i in range(yHeight):
                for j in range(yWidth):
                    # extract x slice used in the sum of naive implem to compute y[b, i, j, :]
                    new_x[b, i, j] = x[b, sh*i:sh*i+kHeight, sw*j:sw*j+kWidth, :].ravel()
            
        return new_x.reshape(xB * yHeight * yWidth, kHeight * kWidth * xC)
    
    # Compute by turning the conv into a matmul op
    def compute_matmul(self, x, w):
        sh, sw = self.strides[1:-1]
        xB, xHeight, xWidth, xC = self.x
        kHeight, kWidth = self.w[:2]
        yHeight, yWidth, yC  = self.output_shape[1:]
        
        # (xB, xH, xW, xC) -> (xB*yH*yW, kH*kW*xC)
        x = self.matmul_flatten_x(x)
        
        # (kH, kW, xC, yC) -> (kH*kW*xC, yC)
        w = w.reshape(kHeight * kWidth * xC, yC)
        
        # [(xB*yH*yW, kH*kW*xC), (kH*kW*xC, yC)] -> (xB*yH*yW, yC)
        y = np.matmul(x, w)
        
        # (xB*yH*yW, yC) -> (xB, yH, yW, yC)
        y = y.reshape(xB, yHeight, yWidth, yC)
        
        return y
    
    
        
        
    def compute(self, x, w):
        assert list(x.shape) == list(self.x)
        assert list(w.shape) == list(self.w)
        
        x = self.pad_input(x)
        
        #y = self.compute_tf(x, w)
        #y = self.compute_naive(x, w)
        y = self.compute_matmul(x, w)
        
        assert list(y.shape) == list(self.output_shape)
        return y
        
        
def my_conv2d(x, w, strides, padding):
    op = MyConv2DDescriptor(x.shape, w.shape, strides, padding)
    return op.compute(x, w)

x = np.random.randn(3, 32, 29, 7)
w = np.random.randn(5, 5, 7, 10)

y_ref = tf_conv(x, w).numpy()
y = my_conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')

print(arr_diff(y_ref, y))

Paddings: [[0, 0], [2, 2], [2, 2], [0, 0]]
Output shape: [3, 32, 29, 10]
4.535802833650465e-25


# Conv2D backprop data

In [29]:
x = tf.constant(np.random.randn(3, 32, 29, 7).astype(np.float32))
w = tf.constant(np.random.randn(5, 5, 7, 10).astype(np.float32))
yshape = tf_conv(x, w).shape
dy = tf.constant(np.random.randn(*yshape).astype(np.float32))


@tf.function
def tf_conv_dx(x, w, dy):
    y = tf.nn.conv2d(input=x, filters=w, strides=[1, 1, 1, 1], padding='SAME')
    loss = tf.reduce_sum(y*dy)
    return tf.gradients(loss, [x], stop_gradients=[x])[0]


fun = tf_conv_dx.get_concrete_function(
    tf.TensorSpec(x.shape, tf.dtypes.float32),
    tf.TensorSpec(w.shape, tf.dtypes.float32),
    tf.TensorSpec(dy.shape, tf.dtypes.float32),
)
mlir = tf.mlir.experimental.convert_function(fun)
print(mlir)

dx = tf_conv_dx(x, w, dy)
print(x.shape, w.shape, dy.shape)
print(dx.shape)



module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 561 : i32}} {
  func @__inference_tf_conv_dx_575(%arg0: tensor<3x32x29x7xf32> {tf._user_specified_name = "x"}, %arg1: tensor<5x5x7x10xf32> {tf._user_specified_name = "w"}, %arg2: tensor<3x32x29x10xf32> {tf._user_specified_name = "dy"}) -> tensor<3x32x29x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "x,w,dy", outputs = "identity_RetVal"}} {
    %0 = "tf.Const"() {value = dense<[3, 32, 29, 7]> : tensor<4xi32>} : () -> tensor<4xi32>
    %1 = "tf.Conv2DBackpropInput"(%0, %arg1, %arg2) {_class = ["loc:@Conv2D"], data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<5x5x7x10xf32>, tensor<3x32x29x10xf32>) -> tensor<3x32x29x7xf32>
    %2 = "tf.Identity"(%1) {device = ""} : (tensor<3x32x29x7xf32>) -> tensor<3x32x29x7xf32>
    return %2 : tensor<3x32x29x7xf32>


## Raw call

can call the raw function directly

In [38]:
x = tf.constant(np.random.randn(3, 32, 29, 7).astype(np.float32))
w = tf.constant(np.random.randn(5, 5, 7, 10).astype(np.float32))
yshape = tf_conv(x, w).shape
dy = tf.constant(np.random.randn(*yshape).astype(np.float32))


@tf.function
def tf_conv_backprop_input(x_shape, w, dy):
    dx = tf.raw_ops.Conv2DBackpropInput(
        input_sizes=x_shape, 
        filter=w, 
        out_backprop=dy,
        strides=[1, 1, 1, 1], 
        padding='SAME'
    )
    return dx


fun = tf_conv_backprop_input.get_concrete_function(
    list(x.shape),
    tf.TensorSpec(w.shape, tf.dtypes.float32),
    tf.TensorSpec(dy.shape, tf.dtypes.float32),
)
mlir = tf.mlir.experimental.convert_function(fun)
print(mlir)

dx = tf_conv_backprop_input(list(x.shape), w, dy)
dx_ref = tf_conv_dx(x, w, dy)
print(x.shape, w.shape, dy.shape)
print(dx.shape)
print(arr_diff(dx, dx_ref))



module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 561 : i32}} {
  func @__inference_tf_conv_backprop_input_689(%arg0: tensor<5x5x7x10xf32> {tf._user_specified_name = "w"}, %arg1: tensor<3x32x29x10xf32> {tf._user_specified_name = "dy"}) -> tensor<3x32x29x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "w,dy", outputs = "identity_RetVal"}} {
    %0 = "tf.Const"() {value = dense<[3, 32, 29, 7]> : tensor<4xi32>} : () -> tensor<4xi32>
    %1 = "tf.Conv2DBackpropInput"(%0, %arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<5x5x7x10xf32>, tensor<3x32x29x10xf32>) -> tensor<3x32x29x7xf32>
    %2 = "tf.Identity"(%1) {device = ""} : (tensor<3x32x29x7xf32>) -> tensor<3x32x29x7xf32>
    return %2 : tensor<3x32x29x7xf32>
  }
}
(3, 32, 29, 7) (5, 5, 7, 10) (3, 32, 29, 10)
(3, 32, 29, 7)
0.0


Output shape is needed because if stride is > 2, there can be multiple possible output values because of the formula to compute conv2D output

```
outHeight = (xPHeight - kHeight) // sh + 1
```

The number is rounded, so it might be more than one possible input.

In [46]:
strides = [1, 1, 1, 1]
padding='SAME'

x = tf.constant(np.random.randn(3, 32, 29, 7).astype(np.float32))
w = tf.constant(np.random.randn(5, 5, 7, 10).astype(np.float32))
yshape = tf.nn.conv2d(input=x, filters=w, strides=strides, padding=padding).shape
dy = tf.constant(np.random.randn(*yshape).astype(np.float32))

def tf_conv_backprop_input(x_shape, w, dy):
    dx = tf.raw_ops.Conv2DBackpropInput(
        input_sizes=x_shape, 
        filter=w, 
        out_backprop=dy,
        strides=strides, 
        padding=padding
    )
    return dx

fake_shape = [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]

# Valid
dx = tf_conv_backprop_input(fake_shape, w, dy)

# Because stride 1 only one valid shape
# Wrong
fake_shape[2] = fake_shape[2] + 1
dx = tf_conv_backprop_input(fake_shape, w, dy)

InvalidArgumentError: Conv2DCustomBackpropInput: Size of out_backprop doesn't match computed: actual = 29, computed = 30 spatial_dim: 2 input: 30 filter: 5 output: 29 stride: 1 dilation: 1 [Op:Conv2DBackpropInput]

In [57]:
strides = [1, 2, 2, 1]
padding='SAME'

x = tf.constant(np.random.randn(3, 32, 29, 7).astype(np.float32))
w = tf.constant(np.random.randn(5, 5, 7, 10).astype(np.float32))
yshape = tf.nn.conv2d(input=x, filters=w, strides=strides, padding=padding).shape
dy = tf.constant(np.random.randn(*yshape).astype(np.float32))

def tf_conv_backprop_input(x_shape, w, dy):
    dx = tf.raw_ops.Conv2DBackpropInput(
        input_sizes=x_shape, 
        filter=w, 
        out_backprop=dy,
        strides=strides, 
        padding=padding
    )
    return dx

fake_shape = [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]

# Valid
dx = tf_conv_backprop_input(fake_shape, w, dy)

# Valid
fake_shape[2] = fake_shape[2] + 1
dx = tf_conv_backprop_input(fake_shape, w, dy)

# Because stride 2 there can be 2 valid shapes (can add +1), but nore more
# Wrong
fake_shape[2] = fake_shape[2] + 2
dx = tf_conv_backprop_input(fake_shape, w, dy)

InvalidArgumentError: Conv2DCustomBackpropInput: Size of out_backprop doesn't match computed: actual = 15, computed = 16 spatial_dim: 2 input: 32 filter: 5 output: 15 stride: 2 dilation: 1 [Op:Conv2DBackpropInput]

In [62]:
strides = [1, 2, 2, 1]
padding='SAME'

x = tf.constant(np.random.randn(1, 5, 5, 1).astype(np.float32))
w = tf.constant(np.random.randn(2, 2, 1, 1).astype(np.float32))
yshape = tf.nn.conv2d(input=x, filters=w, strides=strides, padding=padding).shape
dy = tf.constant(np.random.randn(*yshape).astype(np.float32))

def tf_conv_backprop_input(x_shape, w, dy):
    dx = tf.raw_ops.Conv2DBackpropInput(
        input_sizes=x_shape, 
        filter=w, 
        out_backprop=dy,
        strides=strides, 
        padding=padding
    )
    return dx

fake_shape = [x.shape[0], x.shape[1], x.shape[2], x.shape[3]]

# Valid
dx = tf_conv_backprop_input(fake_shape, w, dy)
print(dx.numpy())
print('\n======\n')

# Valid

fake_shape[2] = fake_shape[2] + 1
dx = tf_conv_backprop_input(fake_shape, w, dy)
print(dx.numpy())
print('\n======\n')

[[[[-0.28893998]
   [ 0.6151089 ]
   [-0.45987356]
   [ 0.97900033]
   [-0.12023436]]

  [[ 2.409006  ]
   [ 0.40683782]
   [ 3.8341465 ]
   [ 0.6475184 ]
   [ 1.0024412 ]]

  [[-0.4638604 ]
   [ 0.9874877 ]
   [-0.57328576]
   [ 1.2204375 ]
   [-0.17393352]]

  [[ 3.8673863 ]
   [ 0.653132  ]
   [ 4.779709  ]
   [ 0.80720687]
   [ 1.4501522 ]]

  [[ 0.1213702 ]
   [-0.25837854]
   [ 0.04046171]
   [-0.08613677]
   [ 0.20346424]]]]


[[[[-0.28893998]
   [ 0.6151089 ]
   [-0.45987356]
   [ 0.97900033]
   [-0.12023436]
   [ 0.25596052]]

  [[ 2.409006  ]
   [ 0.40683782]
   [ 3.8341465 ]
   [ 0.6475184 ]
   [ 1.0024412 ]
   [ 0.16929428]]

  [[-0.4638604 ]
   [ 0.9874877 ]
   [-0.57328576]
   [ 1.2204375 ]
   [-0.17393352]
   [ 0.3702778 ]]

  [[ 3.8673863 ]
   [ 0.653132  ]
   [ 4.779709  ]
   [ 0.80720687]
   [ 1.4501522 ]
   [ 0.24490462]]

  [[ 0.1213702 ]
   [-0.25837854]
   [ 0.04046171]
   [-0.08613677]
   [ 0.20346424]
   [-0.43314418]]]]




In [56]:
class MyConv2DBackpropInputDescriptor:
    
    # Compute the input padding, same function than for Conv2D
    def get_explicit_padding(self):
        padding = self.x_padding
        pad_top, pad_bottom, pad_left, pad_right = (None, None, None, None)
        xHeight = self.x[1]
        xWidth = self.x[2]
        kHeight = self.w[0]
        kWidth = self.w[1]
        sh = self.x_strides[1]
        sw = self.x_strides[2]
        
        if not isinstance(padding, str):
            #explicit padding
            #[[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
            assert len(padding) == 4
            for x in padding: assert len(x) == 2
            for x in padding[0]: assert x == 0
            for x in padding[3]: assert x == 0
            pad_top = padding[1][0]
            pad_bottom = padding[1][1]
            pad_left = padding[2][0]
            pad_right = padding[2][1]
        
        elif padding == 'SAME':
            if xHeight % sh == 0:
                pad_height = max(kHeight - sh, 0)
            else:
                pad_height = max(kHeight - (xHeight % sh), 0)
            if xWidth % sw == 0:
                pad_width = max(kWidth - sw, 0)
            else:
                pad_width = max(kWidth - (xWidth % sw), 0)
            
            pad_top = pad_height // 2
            pad_bottom = pad_height - pad_top
            pad_left = pad_width // 2
            pad_right = pad_width - pad_left
            
        
        elif padding == 'VALID':
            pad_top = 0
            pad_bottom = 0
            pad_left = 0
            pad_right = 0
            
        else:
            raise Exception('Unknown padding')
        
        self.x_padding = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
            
    def get_output_shape(self):
        pad_top, pad_bottom = self.padding[1]
        pad_left, pad_right = self.padding[2]
        xHeight = self.x[1]
        xWidth = self.x[2]
        kHeight = self.w[0]
        kWidth = self.w[1]
        sh = self.strides[1]
        sw = self.strides[2]
        
        assert self.x[3] == self.w[2]
        
        # Padded input dimensions
        xPHeight = xHeight + pad_top + pad_bottom
        xPWidth = xWidth + pad_left + pad_right
        
        outN = self.x[0]
        outHeight = (xPHeight - kHeight) // sh + 1
        outWidth = (xPWidth - kWidth) // sw + 1
        outC = self.w[3]
        
        self.output_shape = [outN, outHeight, outWidth, outC]
    
    def __init__(self, input_shape, dy, w, strides, padding):
        self.x = input_shape
        self.w = w
        self.y = dy

        self.x_strides = strides
        self.x_padding = padding
                        
        # Check strides
        assert len(strides) == 4
        assert strides[0] == 1
        assert strides[3] == 1
        
        # Compute real paddings
        self.get_explicit_padding()
        print('Input Paddings:', self.x_padding)
        
        # Compute output shape
        #self.get_output_shape()
        #print('Output shape:', self.output_shape)
        
    
    
    def compute(self, dy, w):
        #TODO
        return np.zeros(self.x)
        
def my_conv2d_backprop_input(input_shape, grad, w, strides, padding):
    op = MyConv2DBackpropInputDescriptor(input_shape, grad.shape, w.shape, strides, padding)
    return op.compute(grad, w)

def ref_conv2d_backprop_input(input_shape, grad, w, strides, padding):
    dx = tf.raw_ops.Conv2DBackpropInput(
        input_sizes=input_shape, 
        filter=w, 
        out_backprop=grad,
        strides=strides, 
        padding=padding
    )
    return dx

strides = [1, 1, 1, 1]
padding = 'SAME'

x = tf.constant(np.random.randn(3, 32, 29, 7).astype(np.float32))
w = tf.constant(np.random.randn(5, 5, 7, 10).astype(np.float32))
yshape = tf.nn.conv2d(input=x, filters=w, strides=strides, padding=padding).shape
dy = tf.constant(np.random.randn(*yshape).astype(np.float32))
input_shape = list(x.shape)

dx_ref = ref_conv2d_backprop_input(input_shape, dy, w, strides, padding)
dx = my_conv2d_backprop_input(input_shape, dy, w, strides, padding)

print(arr_diff(dx_ref, dx))

Input Paddings: [[0, 0], [2, 2], [2, 2], [0, 0]]
4428097.685704555
