In [1]:
import torch
import torch.nn.functional as F
import math
from torch import nn 

In [2]:
class GCT(nn.Module):
    """Channel-wise gating layer: ``y = x * (1 + tanh(embedding * norm + beta))``.

    (Presumably "Gated Channel Transformation" -- confirm against the
    reference this notebook follows.)

    For every sample and channel a scalar ``embedding`` is computed by reducing
    over dims 2 and 3 and scaling by ``alpha``. It is then normalised across
    the channel dim (dim 1) and scaled by ``gamma`` to give ``norm``. Because
    the reductions run over dims (2, 3) and dim 1, ``x`` is expected to be a
    4-D tensor with channels on dim 1.

    ``gamma`` and ``beta`` start at zero, so a freshly constructed layer has
    ``gate == 1`` and returns its input unchanged.

    Args:
        num_channels: Size of dim 1 of the input; the three learnable
            parameters have shape ``(1, num_channels, 1, 1)``.
        epsilon: Small constant added before the square root / division to
            avoid dividing by zero.
        mode: ``'l2'`` (root of the sum of squares) or ``'l1'`` (sum of
            absolute values) for the embedding.
        after_relu: Only used when ``mode == 'l1'``. When true the ``abs`` is
            skipped and ``x`` is summed as-is (assumes the caller guarantees
            non-negative input, e.g. the layer follows a ReLU -- TODO confirm).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    _SUPPORTED_MODES = ('l2', 'l1')

    def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False) -> None:
        super().__init__()

        # Fail at construction time instead of at the first forward pass.
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(
                f'Unknown mode {mode!r}; expected one of {self._SUPPORTED_MODES}'
            )

        self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.epsilon = epsilon
        self.mode = mode
        self.after_relu = after_relu

    def forward(self, x):
        """Return ``x`` multiplied element-wise by the per-channel gate."""
        if self.mode == 'l2':
            # Per-channel L2 norm over dims (2, 3), scaled by alpha.
            embedding = (x.pow(2).sum((2, 3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha
            # gamma divided by the RMS of the embeddings across channels.
            norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5)
        elif self.mode == 'l1':
            _x = x if self.after_relu else torch.abs(x)
            embedding = _x.sum((2, 3), keepdim=True) * self.alpha
            # gamma divided by the mean absolute embedding across channels.
            norm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
        else:
            # ``mode`` can still be reassigned after construction; raise a
            # clear error rather than failing later on an unbound local.
            raise ValueError(
                f'Unknown mode {self.mode!r}; expected one of {self._SUPPORTED_MODES}'
            )

        gate = 1. + torch.tanh(embedding * norm + self.beta)

        return x * gate
        

In [5]:
# Scratch check: build a parameter the same way GCT builds `alpha`
# (here for 12 channels) and display its shape.
alpha =  nn.Parameter(torch.ones(1, 12, 1, 1))
alpha.shape

torch.Size([1, 12, 1, 1])

In [10]:
# Scratch check: summing over dims (2, 3) with keepdim=True -- the reduction
# GCT uses for its embedding -- leaves one value per (sample, channel).
# NOTE(review): `x` is rebound to the reduced tensor, and no seed is set.
x = torch.randn(4, 12, 4, 4)
x = x.sum((2, 3), keepdim=True)
x.shape

torch.Size([4, 12, 1, 1])

In [11]:
# Scratch check: mean of squares across dim 1 with keepdim=True -- the
# reduction inside GCT's l2 `norm` -- collapses the channel dim to size 1.
# NOTE(review): depends on the `x` left behind by the previous cell and
# rebinds it, so this cell is order-dependent.
x = (x.pow(2).mean(dim=1, keepdim=True))
x.shape

torch.Size([4, 1, 1, 1])

In [13]:
# Scratch comparison: pass a random (4, 12, 4, 4) tensor through a freshly
# constructed BatchNorm2d over 12 channels and show the values before and after.
# NOTE(review): no seed is set, so the values differ on every run; both the
# print and the final expression dump the whole tensor -- consider showing
# only x.shape or a small slice.
x = torch.randn(4, 12, 4, 4)
print(f'{x}')
BN = nn.BatchNorm2d(12)
x = BN(x)
x.shape, x

tensor([[[[-1.0907e+00, -7.8579e-01,  1.3389e+00, -1.0994e+00],
          [-6.7382e-01, -8.5815e-02, -1.1039e-01, -2.2817e+00],
          [ 5.8559e-01, -6.0617e-01,  1.8198e-01, -7.0922e-02],
          [-1.2834e-01,  1.0458e+00, -2.3795e+00, -2.9397e-02]],

         [[ 3.3151e-01, -1.2638e+00, -5.2213e-01, -2.7057e+00],
          [-2.1774e-01, -2.3305e-01,  2.6542e-02,  1.4301e+00],
          [ 2.6670e-01,  3.7783e-01,  1.0204e+00, -1.1398e+00],
          [-2.8638e-01,  6.2652e-01,  2.3147e-01, -6.6035e-01]],

         [[ 1.2881e+00,  4.0754e-01, -1.4569e+00,  6.4815e-01],
          [ 4.4932e-01, -1.5754e+00, -1.6371e-01, -1.1207e-01],
          [ 1.3838e+00,  5.5745e-01, -4.6383e-01,  8.7705e-01],
          [ 3.0802e-01, -1.0052e+00, -1.1461e+00, -1.1693e-01]],

         [[ 3.6643e-01,  2.3681e-01,  5.4726e-01,  1.2393e+00],
          [ 1.0019e+00, -7.8929e-01,  4.9933e-01,  3.0524e-01],
          [ 7.8942e-01,  1.7082e+00, -4.1838e-01,  1.6946e+00],
          [ 1.9294e-01, -1.0551e+0

(torch.Size([4, 12, 4, 4]),
 tensor([[[[-0.9027, -0.6147,  1.3924, -0.9109],
           [-0.5089,  0.0465,  0.0233, -2.0277],
           [ 0.6807, -0.4450,  0.2995,  0.0606],
           [ 0.0064,  1.1155, -2.1201,  0.0998]],
 
          [[ 0.5086, -1.2223, -0.4176, -2.7869],
           [-0.0873, -0.1040,  0.1777,  1.7006],
           [ 0.4383,  0.5589,  1.2561, -1.0878],
           [-0.1618,  0.8287,  0.4001, -0.5676]],
 
          [[ 1.2333,  0.4723, -1.1391,  0.6802],
           [ 0.5084, -1.2415, -0.0214,  0.0232],
           [ 1.3160,  0.6018, -0.2808,  0.8780],
           [ 0.3862, -0.7487, -0.8705,  0.0190]],
 
          [[ 0.2369,  0.0982,  0.4304,  1.1710],
           [ 0.9169, -0.9999,  0.3791,  0.1714],
           [ 0.6895,  1.6727, -0.6030,  1.6583],
           [ 0.0512, -1.2844, -0.5234,  1.1015]],
 
          [[-0.9748,  0.0533, -0.1393, -0.5934],
           [-0.8402, -1.2593,  0.1502,  0.9971],
           [ 0.2800, -0.7241,  0.8243, -1.8757],
           [ 1.0069, -0.6624,