In [71]:
import tensorflow as tf

class EmulateMultiGPUBatchNorm(tf.keras.layers.Layer):

    def __init__(self, num_gpus, axis=1, *args, **kwargs):
        if axis != 1 and axis != 3:
            raise NotImplementedError(
                "Currently, EmulateMultiGPUBatchNorm just supports axis==1 or axis==3")
        # Only can do axis==3 as otherwise will get error:
        # "InternalError: The CPU implementation of FusedBatchNorm only supports
        #  NHWC tensor format for now. [Op:FusedBatchNormV3]"
        super(EmulateMultiGPUBatchNorm, self).__init__()
        self.bn_layer = tf.keras.layers.BatchNormalization(axis=3, *args, **kwargs)
        self.num_gpus = num_gpus
        self.axis = axis

    def call(self, inputs, training=None):
        # Either NHWC (means axis=3) or NCHW (means axis=1)
        # First, for reshaping, we need NCHW:
        if self.axis == 3:
            inputs = tf.transpose(inputs, [0, 3, 1, 2])
        input_shape = tf.keras.backend.int_shape(inputs)
        tensor_input_shape = tf.shape(inputs)
        reshaped_inputs = self._reshape_into_groups(inputs, input_shape, tensor_input_shape)
        normalized_inputs = self.bn_layer(reshaped_inputs, training=training)
        outputs = tf.reshape(normalized_inputs, tensor_input_shape)
        if self.axis == 3:
            outputs = tf.transpose(outputs, [0, 2, 3, 1])
        return outputs

    def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):
        # N,C,H,W --> N // G, C * G, H, W
        group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
        group_shape[1] = group_shape[1] * self.num_gpus
        group_shape[0] = group_shape[0] // self.num_gpus
        group_shape = tf.stack(group_shape)
        reshaped_inputs = tf.reshape(inputs, group_shape)
        # Back to NHWC
        reshaped_inputs = tf.transpose(reshaped_inputs, [0, 2, 3, 1])
        return reshaped_inputs
    
    def get_moving_mean_and_var_for_regular_bn(self):
        return [tf.reduce_mean(tf.reshape(v, (2, -1)), 0)
                for v in (self.bn_layer.moving_mean, self.bn_layer.moving_variance)]
            
        

x = tf.random.normal([1, 1, 1, 3])
x = tf.tile(x, [8, 1, 1, 1])
x = tf.concat([x[-4:], x[-4:] + tf.random.normal([4, 1, 1, 3]) * 0.05], 0)
bn = EmulateMultiGPUBatchNorm(2, axis=3, momentum=0.)

In [72]:
tf.reduce_mean(tf.transpose(x, [0, 3, 1, 2]), [0, 2, 3])

<tf.Tensor: id=2932, shape=(3,), dtype=float32, numpy=array([-0.57972205,  0.6254021 ,  1.7347976 ], dtype=float32)>

In [73]:
bn(x, training=True).shape


TensorShape([8, 1, 1, 3])

In [74]:
bn.get_moving_mean_and_var_for_regular_bn()

[<tf.Tensor: id=3027, shape=(3,), dtype=float32, numpy=array([-0.57972205,  0.6254021 ,  1.7347977 ], dtype=float32)>,
 <tf.Tensor: id=3032, shape=(3,), dtype=float32, numpy=array([2.3970008e-04, 7.6025724e-05, 9.2053413e-04], dtype=float32)>]

In [76]:
bn.bn_layer.moving_mean.numpy().reshape(2, -1)

array([[-0.5907143,  0.6301867,  1.7384589],
       [-0.5687299,  0.6206175,  1.7311364]], dtype=float32)

In [None]:
# tf.reduce_mean(tf.reshape(x, [2, 1, 1, 8]), [0, 1, 2])

In [2]:
import tensorflow_addons as tfa

In [4]:
gn_layer = tfa.layers.normalizations.GroupNormalization(groups=2, center=False, scale=False)

In [5]:
gn_layer(x, training=True)

<tf.Tensor: id=41, shape=(32, 16, 16, 64), dtype=float32, numpy=
array([[[[-1.10276651e+00, -7.46356428e-01,  5.14362693e-01, ...,
           6.40418410e-01, -3.66490096e-01, -6.74697518e-01],
         [-4.23000902e-01,  1.82651150e+00,  9.97774661e-01, ...,
          -9.20963585e-01,  3.52181047e-01,  5.80095947e-01],
         [ 5.02809286e-02, -1.18407562e-01, -2.30862290e-01, ...,
          -3.14711705e-02, -3.94361198e-01,  7.43216336e-01],
         ...,
         [ 1.04548037e+00,  8.77989113e-01, -1.35087645e+00, ...,
          -8.11927259e-01,  1.24564394e-01, -1.14744806e+00],
         [ 2.84271657e-01,  3.93960387e-01, -7.70990491e-01, ...,
           6.69490039e-01,  2.22289309e-01,  5.22983313e-01],
         [-3.10084134e-01, -7.16256499e-01, -5.00112474e-01, ...,
           1.17364812e+00,  7.29153097e-01,  9.58405316e-01]],

        [[-6.44800901e-01,  3.40748668e-01, -2.71560997e-01, ...,
           1.00251007e+00,  2.84485966e-01, -5.44930160e-01],
         [ 1.85434341e+

In [26]:
gn_layer.non_trainable_variables

[]

In [30]:
gn_layer.center

False

In [77]:
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm


def group_norm(input, group, running_mean, running_var, weight=None, bias=None,
                  use_input_stats=True, momentum=0.1, eps=1e-5):
    r"""Applies Group Normalization for channels in the same group in each data sample in a
    batch.
    See :class:`~torch.nn.GroupNorm1d`, :class:`~torch.nn.GroupNorm2d`,
    :class:`~torch.nn.GroupNorm3d` for details.
    """
    if not use_input_stats and (running_mean is None or running_var is None):
        raise ValueError('Expected running_mean and running_var to be not None when use_input_stats=False')

    b, c = input.size(0), input.size(1)
    if weight is not None:
        weight = weight.repeat(b)
    if bias is not None:
        bias = bias.repeat(b)

    def _instance_norm(input, group, running_mean=None, running_var=None, weight=None,
                       bias=None, use_input_stats=None, momentum=None, eps=None):
        # Repeat stored stats and affine transform params if necessary
        if running_mean is not None:
            running_mean_orig = running_mean
            running_mean = running_mean_orig.repeat(b)
        if running_var is not None:
            running_var_orig = running_var
            running_var = running_var_orig.repeat(b)

        #norm_shape = [1, b * c / group, group]
        #print(norm_shape)
        # Apply instance norm
        input_reshaped = input.contiguous().view(1, int(b * c/group), group, *input.size()[2:])

        out = F.batch_norm(
            input_reshaped, running_mean, running_var, weight=weight, bias=bias,
            training=use_input_stats, momentum=momentum, eps=eps)

        # Reshape back
        if running_mean is not None:
            running_mean_orig.copy_(running_mean.view(b, int(c/group)).mean(0, keepdim=False))
        if running_var is not None:
            running_var_orig.copy_(running_var.view(b, int(c/group)).mean(0, keepdim=False))

        return out.view(b, c, *input.size()[2:])
    return _instance_norm(input, group, running_mean=running_mean,
                          running_var=running_var, weight=weight, bias=bias,
                          use_input_stats=use_input_stats, momentum=momentum,
                          eps=eps)


class _GroupNorm(_BatchNorm):
    def __init__(self, num_features, num_groups=1, eps=1e-5, momentum=0.1,
                 affine=False, track_running_stats=False):
        self.num_groups = num_groups
        self.track_running_stats = track_running_stats
        super(_GroupNorm, self).__init__(int(num_features/num_groups), eps,
                                         momentum, affine, track_running_stats)

    def _check_input_dim(self, input):
        return NotImplemented

    def forward(self, input):
        self._check_input_dim(input)

        return group_norm(
            input, self.num_groups, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats, self.momentum, self.eps)


class GroupNorm2d(_GroupNorm):
    r"""Applies Group Normalization over a 4D input (a mini-batch of 2D inputs
    with additional channel dimension) as described in the paper
    https://arxiv.org/pdf/1803.08494.pdf
    `Group Normalization`_ .
    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        num_groups:
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        momentum: the value used for the running_mean and running_var computation. Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``False``
    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)
    Examples:
        >>> # Without Learnable Parameters
        >>> m = GroupNorm2d(100, 4)
        >>> # With Learnable Parameters
        >>> m = GroupNorm2d(100, 4, affine=True)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))

In [104]:
import torch

x = torch.randn(4, 8, 2, 2)

In [117]:
gn_layer = GroupNorm2d(8, 1, track_running_stats=True, momentum=1.0)

In [118]:
r = gn_layer.train()(x)

In [119]:
r.view(-1)

tensor([ 0.0942, -1.5481,  0.2085,  1.2454,  0.1638,  0.1814,  1.2203, -1.5655,
        -0.1845, -1.0018, -0.4696,  1.6558, -0.5385, -0.6480, -0.5439,  1.7304,
        -0.2527,  1.6447, -0.3340, -1.0580, -0.5533,  1.7060, -0.8203, -0.3324,
         1.4989,  0.0683, -1.2937, -0.2735, -0.3271,  1.2400, -1.4410,  0.5282,
        -0.9680, -0.3372,  1.6767, -0.3715,  1.1037,  0.8166, -1.3284, -0.5919,
         1.6538, -1.0004, -0.1675, -0.4858, -0.6700, -0.2811, -0.7530,  1.7041,
         0.4193,  1.4106, -0.6320, -1.1979,  1.2421,  0.0547, -1.5466,  0.2498,
         0.0440,  0.1899, -1.5196,  1.2856, -0.4855,  0.9338, -1.4056,  0.9573,
         0.4523,  1.3596, -0.5152, -1.2967,  0.6326, -1.3541, -0.5070,  1.2285,
        -0.8039,  0.2893,  1.5098, -0.9952, -1.3170,  1.3681,  0.4173, -0.4684,
        -0.7910,  1.1991,  0.7586, -1.1667, -1.4167,  1.1831, -0.4152,  0.6488,
         1.4916, -1.3230, -0.0137, -0.1549, -1.2567,  0.8562,  1.0967, -0.6962,
         1.1827,  0.7798, -1.1651, -0.79

In [120]:
x.mean([0,2,3])

tensor([ 0.1027,  0.0362, -0.1015,  0.2301, -0.0703, -0.0454, -0.2904,  0.3127])

In [121]:
gn_layer.running_mean

tensor([ 0.1027,  0.0362, -0.1015,  0.2301, -0.0703, -0.0454, -0.2904,  0.3127])

In [15]:
gn_layer.state_dict()

OrderedDict([('weight',
              tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])),
             ('bias',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))])