In [24]:
import numpy as np
def group_norm(x,group_size=2):
    # group size (G)  = channel_dim (C) for instance normalization
    # group size (G)  = 1 for layer normalization
    eps = 1e-9
    N, H, W, C = x.shape
    #print("The shape of features are ",N, H, W, C)
    G = group_size
    x = np.reshape(x, (N, H, W,G, C // G ))
    mean = np.mean(x,axis = (1, 2, 4), keepdims=True)
    var = np.var(x,axis = (1, 2, 4), keepdims=True)
    #print("Mean statistics, mean,min, max")
    #print(mean.mean(),mean.min(),mean.max())
    x = (x - mean) / np.sqrt(var + eps)
    #print("VAR statistics, mean,min, max")
    #print(var.mean(),var.min(),var.max())
   
    normalized_x = np.reshape(x, (N, C*H*W))

    return normalized_x

def batch_norm(x):
    eps = 1e-9
    N, H, W, C = x.shape
    #print("The shape of features are ",N, H, W, C)
    mean = np.mean(x,axis = (0, 1, 2), keepdims=True)
    var = np.var(x,axis = (0, 1, 2), keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)
    #print(mean.shape,var.shape)
    normalized_x = np.reshape(x, (N, C*H*W))

    return normalized_x

def z_norm_my_implementation(x):
    eps = 1e-9
    N, H, W, C = x.shape
    x = np.reshape(x, (N, C*H*W))
    mean = np.mean(x,axis = 0, keepdims=True)
    var = np.var(x,axis = 0, keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)
    normalized_x = np.reshape(x, (N, C*H*W))

    return normalized_x

In [25]:
x=np.zeros((200,16,16,8))
group_norm(x)
group_norm(x,group_size=1)
group_norm(x,group_size=8)
batch_norm(x)

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])

In [19]:
def group_norm_matrix_form(x,group_size=2):
    # group size (G)  = channel_dim (C) for instance normalization
    # group size (G)  = 1 for layer normalization
    eps = 1e-9
    N, H, W, C = x.shape
    #print("The shape of features are ",N, H, W, C)
    G = group_size
    x = np.reshape(x, (N, H, W,G, C // G ))
    x = np.swapaxes(x,0,4)
    x = np.reshape(x,((C // G)*H*W,-1))
    ones = np.ones(((C // G)*H*W,(C // G)*H*W))/((C // G)*H*W)
    identity = np.identity((C // G)*H*W)
    D = identity -ones    
    #mean = np.mean(x,axis = 0, keepdims=True)
    #centered_x = x-mean
    centered_x_mat = np.matmul(D,x)
    #print("Is centering working?")
    #print((centered_x_mat==centered_x).all())
    #print(centered_x_mat[0,5],centered_x[0,5])
    
    
    #var_with_axis = np.var(x,axis = 0, keepdims=True)
    var = np.var(x,axis = 0)
    #standard_x = centered_x/np.sqrt(var_with_axis + eps)
    #print("Mean statistics, mean,min, max")
    #print(mean.mean(),mean.min(),mean.max())
    #print("VAR statistics, mean,min, max")
    #print(var.mean(),var.min(),var.max())
    #print(mean.shape,var.shape)
    Q = np.diag(1 / np.sqrt(var + eps))
    #print(D.shape,x.shape,Q.shape)
    normalized_x =  np.matmul(centered_x_mat,Q)
    
    #print("Is division working?")
    #print((standard_x==normalized_x).all())
    #print(standard_x[0,5],normalized_x[0,5])
    
    
    normalized_x = np.reshape(normalized_x,((C // G),H,W,G,N))
    normalized_x = np.swapaxes(normalized_x,0,4)
    normalized_x = np.reshape(normalized_x, (N, C*H*W))

    return normalized_x

def batch_norm_matrix_form(x):
    eps = 1e-9
    N, H, W, C = x.shape
    x = np.reshape(x, (N*H*W,C))
    
    ones = np.ones((N*H*W,N*H*W))/(N*H*W)
    identity = np.identity((N*H*W))
    D = identity - ones 
    centered_x_mat = np.matmul(D,x)
    var = np.var(x,axis = 0)
    Q = np.diag(1 / np.sqrt(var + eps))
    normalized_x =  np.matmul(centered_x_mat,Q)
    normalized_x = np.reshape(normalized_x, (N, C*H*W))

    return normalized_x

def z_norm_matrix_form(x):
    eps = 1e-9
    N, H, W, C = x.shape
    x = np.reshape(x, (N, C*H*W))
    ones = np.ones((N,N))/(N)
    identity = np.identity((N))
    D = identity - ones 
    print(D.shape, x.shape)
    centered_x_mat = np.matmul(D,x)
    var = np.var(x,axis = 0)
    Q = np.diag(1 / np.sqrt(var + eps))
    normalized_x =  np.matmul(centered_x_mat,Q)

    return normalized_x

In [23]:
import time

x = np.random.rand(50,16,16,8)
start = time.time()
g_nx = group_norm(x,group_size=2)
end = time.time()
print("Time for group norm element wise", end - start)

start = time.time()
g_nmx = group_norm_matrix_form(x,group_size=2)
end = time.time()
print("Time for group norm matrix form", end - start)

(g_nx==g_nmx).all()
print(g_nx[0,:4])
print(g_nmx[0,:4])

start = time.time()
g_nx = batch_norm(x)
end = time.time()
print("Time for batch norm element wise", end - start)
start = time.time()
g_nmx = batch_norm_matrix_form(x)
end = time.time()
print("Time for batch norm matrix form", end - start)

(g_nx==g_nmx).all()

print(g_nx[0,:4])
print(g_nmx[0,:4])

start = time.time()
g_nx = z_norm_my_implementation(x)
end = time.time()
print("Time for z norm element wise", end - start)
start = time.time()
g_nmx = z_norm_matrix_form(x)
end = time.time()
print("Time for z norm matrix form", end - start)
(g_nx==g_nmx).all()
print(g_nx[0,:4])
print(g_nmx[0,:4])

Mean statistics, mean,min, max
(0.5009777440107738, 0.47795526534612404, 0.5273885167859876)
VAR statistics, mean,min, max
(0.08365572860105852, 0.07558882979945152, 0.08997844892046611)
('Time for group norm element wise', 0.008234977722167969)
('Time for group norm matrix form', 0.018539905548095703)
[-1.13148856 -1.26799944 -0.39184533 -0.44017047]
[-1.13148856 -1.26799944 -0.39184533 -0.44017047]
((1, 1, 1, 8), (1, 1, 1, 8))
('Time for batch norm element wise', 0.0029439926147460938)
('Time for batch norm matrix form', 1.7661070823669434)
[-1.15867566 -1.29140539 -0.426883   -0.4448229 ]
[-1.15867566 -1.29140539 -0.426883   -0.4448229 ]
('Time for z norm element wise', 0.0009720325469970703)
((50, 50), (50, 2048))
('Time for z norm matrix form', 0.01367497444152832)
[-1.33002925 -1.30859849 -0.43515395 -0.31703515]
[-1.33002925 -1.30859849 -0.43515395 -0.31703515]


In [36]:
N, H, W, C = x.shape
orig_x = np.reshape(x, (N, C*H*W)) 
    #print("The shape of features are ",N, H, W, C)
G = 2
x = np.reshape(x, (N, H, W,G, C // G ))
x = np.swapaxes(x,0,4)
x = np.reshape(x,((C // G)*H*W,-1))
normalized_x = np.reshape(x,((C // G),H,W,G,N))
normalized_x = np.swapaxes(normalized_x,0,4)
normalized_x = np.reshape(normalized_x, (N, C*H*W))
print("Is reshaping and swapaxes working?")
print((orig_x==normalized_x).all())

Is reshaping and swapaxes working?
True
