In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
class TestLayerNorm(nn.Module):
    """Conv -> LayerNorm (last dimension only, no affine) -> Conv test network.

    Used to exercise ONNX export of ``nn.LayerNorm``: with opset 9 the layer is
    traced into ReduceMean/Sub/Pow/Add/Sqrt/Div over the last axis.

    Args:
        channels: output channels of ``conv2d_1`` and input channels of
            ``conv2d_2``.
        normalized_shape: size of the last dimension normalised by ``ln``. It
            must equal the width of ``conv2d_1``'s output, i.e. ``W - 2`` for an
            input of width ``W`` (3x3 kernel, stride 1, no padding). The default
            23 matches the 25x25 inputs used in this notebook.
        in_channels: channels of the input image (default 3).
    """

    def __init__(self, channels, normalized_shape=23, in_channels=3):
        super().__init__()
        # Submodule names are kept stable: they become the ONNX initializer
        # names (e.g. "conv2d_1.weight").
        self.conv2d_1 = nn.Conv2d(in_channels, channels, 3)
        self.ln = nn.LayerNorm(normalized_shape, elementwise_affine=False)
        self.conv2d_2 = nn.Conv2d(channels, 1, 3)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, 1, H - 4, W - 4)."""
        features = self.conv2d_1(x)
        normalized = self.ln(features)
        return self.conv2d_2(normalized)

In [3]:
# Seed torch's global RNG before the first random draw. This makes the
# randomly initialised conv weights (and so the exported ONNX parameters)
# reproducible on Restart & Run All, along with later torch.rand calls.
torch.manual_seed(0)
model = TestLayerNorm(6)

In [4]:
# Show the module's repr (layer configuration) as the cell output.
model

TestLayerNorm(
  (conv2d_1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
  (ln): LayerNorm((23,), eps=1e-05, elementwise_affine=False)
  (conv2d_2): Conv2d(6, 1, kernel_size=(3, 3), stride=(1, 1))
)

In [5]:
# Random NCHW input batch: 2 images, 3 channels, 25x25, values in [0, 1).
x = torch.rand((2, 3, 25, 25))

In [26]:
# Export the LayerNorm test model with opset 9. verbose=True prints the traced
# ONNX graph as the cell output. A single-element tuple is passed as the model
# inputs (equivalent to passing the tensor directly).
torch.onnx.export(
    model,
    (x,),
    "layer_norm.onnx",
    input_names=["x"],
    output_names=["output"],
    export_params=True,
    opset_version=9,
    verbose=True,
)

graph(%x : Float(2, 3, 25, 25),
      %conv2d_1.weight : Float(6, 3, 3, 3),
      %conv2d_1.bias : Float(6),
      %conv2d_2.weight : Float(1, 6, 3, 3),
      %conv2d_2.bias : Float(1)):
  %5 : Float(2, 6, 23, 23) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%x, %conv2d_1.weight, %conv2d_1.bias) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/modules/conv.py:341:0
  %6 : Tensor = onnx::ReduceMean[axes=[-1]](%5)
  %7 : FloatTensor = onnx::Sub(%5, %6)
  %8 : Float() = onnx::Constant[value={2}]()
  %9 : FloatTensor = onnx::Pow(%7, %8)
  %10 : Tensor = onnx::ReduceMean[axes=[-1]](%9)
  %11 : Float() = onnx::Constant[value={1e-05}]()
  %12 : FloatTensor = onnx::Add(%10, %11)
  %13 : Tensor = onnx::Sqrt(%12)
  %14 : Float(2, 6, 23, 23) = onnx::Div(%7, %13) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/functional.py:1695:0
  %output : Float(2, 1, 21, 21) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 

In [8]:
# Write the input batch as raw float32 in NHWC order (batch, row, col, channel).
(
    x.detach()
    .permute(0, 2, 3, 1)
    .reshape(-1)
    .numpy()
    .astype(np.float32)
    .tofile("./ln_input.bin")
)

In [34]:
# Reference forward pass. Autograd is disabled because the result is only
# inspected and printed, so no graph is recorded and `out` carries no grad_fn.
# The model contains only Conv2d and LayerNorm layers, so train/eval mode does
# not change the result.
with torch.no_grad():
    out = model(x)

In [37]:
# Output size: the two unpadded 3x3 convs shrink each spatial dim by 4.
out.size()

torch.Size([2, 1, 21, 21])

In [36]:
# Model output flattened in NHWC order (same layout used for ln_input.bin).
out.permute(0, 2, 3, 1).reshape(-1)

tensor([ 7.4653e-01, -7.0797e-02, -1.0554e+00,  6.6920e-01,  1.1648e-03,
         1.7048e-01,  4.0424e-01, -1.9442e-01, -1.2650e+00, -2.6460e-01,
        -1.4857e+00, -4.7139e-01, -4.8694e-01,  7.1930e-01, -3.5905e-01,
        -7.1231e-01,  4.5180e-01,  3.6331e-01, -3.6636e-01,  6.0494e-01,
        -1.4274e-01,  2.8320e-01,  7.5432e-01,  2.7374e-02, -9.7898e-01,
         5.9263e-01,  3.0665e-01,  1.9617e-01, -4.1622e-02, -3.4469e-01,
        -5.8227e-01,  8.2753e-02, -6.0747e-01, -1.2255e+00,  9.6025e-01,
        -3.2731e-01,  1.4876e-01, -1.2712e-02,  7.1086e-01, -4.9523e-01,
         1.7897e-01, -7.3997e-01,  2.8574e-01,  9.2537e-01, -3.6571e-01,
        -1.3347e+00,  5.7401e-01, -2.2403e-01, -5.1377e-01, -1.1340e-01,
         1.6171e-01, -5.1462e-01,  1.9847e-01,  5.7926e-01, -3.7613e-01,
        -6.2534e-01, -2.6250e-01, -2.0063e-01,  6.6625e-01, -9.3325e-01,
        -1.0057e+00,  1.3716e+00, -1.4293e+00,  9.1001e-01, -4.6655e-01,
         7.2493e-03,  4.7748e-01, -5.8601e-01,  2.2

In [39]:
# Random float32 tensor of shape (1, 1, 1035, 80) with values in [0, 1).
bert_input = torch.rand(1, 1, 1035, 80, dtype=torch.float32)

In [44]:
# Write bert_input as raw float32 in NHWC order. With a single channel this
# byte layout is the same as NCHW.
(
    bert_input
    .permute(0, 2, 3, 1)
    .contiguous()
    .numpy()
    .tofile("bert_input.bin")
)

In [6]:
# Depthwise 3x3 convolution: groups == in_channels, so each of the 3 input
# channels gets its own single 3x3 filter (weight shape (3, 1, 3, 3)).
# Mirrors the depthwise pattern
#     nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding,
#               dilation, groups=in_channels, bias=bias)
# with in_channels=3, kernel_size=3 and nn.Conv2d defaults for the rest.
# (The misspelled name is kept because later cells reference it.)
dethwise_conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, groups=3)


In [7]:
# Export the standalone depthwise conv with opset 9. verbose=True prints the
# traced ONNX graph as the cell output.
torch.onnx.export(
    dethwise_conv,
    (x,),
    "dw3x3.onnx",
    input_names=["x"],
    output_names=["output"],
    export_params=True,
    opset_version=9,
    verbose=True,
)

graph(%x : Float(2, 3, 25, 25),
      %weight : Float(3, 1, 3, 3),
      %bias : Float(3)):
  %output : Float(2, 3, 23, 23) = onnx::Conv[dilations=[1, 1], group=3, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%x, %weight, %bias) # /opt/anaconda3/envs/tf2/lib/python3.8/site-packages/torch/nn/modules/conv.py:341:0
  return (%output)



In [12]:
# First 10 values of the depthwise output, flattened in NHWC order.
dethwise_conv(x).detach().permute(0, 2, 3, 1).reshape(-1).numpy()[:10]

array([-0.27840602, -0.8446287 ,  0.40537226, -0.1711365 , -1.0745807 ,
        0.2673544 , -0.477559  , -0.9469062 ,  0.09525683, -0.38811117],
      dtype=float32)

In [1]:
import torch

In [9]:
# Small int64 tensor whose values equal their row-major flat index, used to see
# how permute reorders elements. NOTE: rebinds `x`.
x = torch.arange(24).reshape(2, 2, 3, 2)

In [10]:
# Flat view of x: values appear in storage (row-major) order.
x.reshape(-1)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23])

In [11]:
# Move the last axis to the front: result shape (2, 2, 2, 3).
x.permute(3, 0, 1, 2)

tensor([[[[ 0,  2,  4],
          [ 6,  8, 10]],

         [[12, 14, 16],
          [18, 20, 22]]],


        [[[ 1,  3,  5],
          [ 7,  9, 11]],

         [[13, 15, 17],
          [19, 21, 23]]]])