In [2]:
import time

import numpy as np
# implementation of norms using elementwise operations
def group_norm(x, group_size=2):
    """Group-normalize a 4-D feature tensor and flatten it per sample.

    The channel (last) axis is split into ``group_size`` contiguous groups. Each
    (sample, group) slice, covering all H*W positions and all channels of that
    group, is shifted to zero mean and scaled to unit variance.

    NOTE: despite its name, ``group_size`` is the NUMBER of groups G, not the
    number of channels per group. Each group holds ``C // group_size`` channels.

    Args:
        x: array of shape (N, H, W, C).
        group_size: number of channel groups G; must be a positive divisor of C.

    Returns:
        Array of shape (N, C*H*W) with the normalized values, kept in the
        original (H, W, C) order within each sample.

    Raises:
        ValueError: if ``x`` is not 4-D, or ``group_size`` does not divide C.
    """
    eps = 1e-9  # keeps the division finite for constant groups
    N, H, W, C = x.shape
    G = group_size
    if G < 1 or C % G != 0:
        raise ValueError(
            f"group_size must be a positive divisor of the channel count C={C}, got {G}"
        )

    grouped = np.reshape(x, (N, H, W, G, C // G))
    # Pool over H, W and the channels inside each group -> one mean/var per (n, g).
    mean = np.mean(grouped, axis=(1, 2, 4), keepdims=True)
    var = np.var(grouped, axis=(1, 2, 4), keepdims=True)
    normalized = (grouped - mean) / np.sqrt(var + eps)

    return np.reshape(normalized, (N, C * H * W))

def batch_norm(x):
    """Batch-normalize a (N, H, W, C) tensor per channel, then flatten each sample.

    Every channel is standardized with a mean and variance pooled over the batch
    and both spatial axes. The result has shape (N, C*H*W).
    """
    eps = 1e-9
    n_samples, height, width, channels = x.shape
    pooled_axes = (0, 1, 2)  # all axes except the channel axis

    mu = np.mean(x, axis=pooled_axes, keepdims=True)
    sigma = np.sqrt(np.var(x, axis=pooled_axes, keepdims=True) + eps)
    standardized = (x - mu) / sigma

    return np.reshape(standardized, (n_samples, channels * height * width))

def z_norm_my_implementation(x):
    """Z-score each flattened feature across the batch.

    The (N, H, W, C) input is flattened to (N, C*H*W). Each of the C*H*W
    columns is then standardized with its own batch mean and variance.
    """
    eps = 1e-9
    n_samples, height, width, channels = x.shape
    flat = np.reshape(x, (n_samples, channels * height * width))

    col_mean = np.mean(flat, axis=0, keepdims=True)
    col_std = np.sqrt(np.var(flat, axis=0, keepdims=True) + eps)

    return (flat - col_mean) / col_std

In [5]:
# Implementations of the same norms as matrix multiplications, derived in the Duality Diagram framework
def group_norm_matrix_form(x, group_size=2):
    """Group norm expressed as matrix products (centering operator + diagonal scaling).

    As in the elementwise version, ``group_size`` is used as the number of channel
    groups G. The input is rearranged into a data matrix with one column per
    (group, sample) pair and one row per element pooled into that group's
    statistics. Then:
        centered = D @ X, with D = I - (1/M) * ones((M, M))
        result   = centered @ diag(1 / sqrt(var + eps))
    The result is mapped back to shape (N, C*H*W).

    The M x M centering matrix is built densely, with M = (C // G) * H * W.
    """
    eps = 1e-9
    N, H, W, C = x.shape
    G = group_size
    per_group = C // G
    rows = per_group * H * W

    # (N, H, W, G, C//G) -> (C//G, H, W, G, N): the pooled axes lead, (G, N) trail.
    moved = np.swapaxes(np.reshape(x, (N, H, W, G, per_group)), 0, 4)
    data = np.reshape(moved, (rows, -1))  # one column per (group, sample)

    centering = np.identity(rows) - np.ones((rows, rows)) / rows
    centered = centering @ data
    scaling = np.diag(1 / np.sqrt(np.var(data, axis=0) + eps))
    scaled = centered @ scaling

    # Undo the layout change to get back to (N, H, W, C) ordering.
    restored = np.swapaxes(np.reshape(scaled, (per_group, H, W, G, N)), 0, 4)
    return np.reshape(restored, (N, C * H * W))

def batch_norm_matrix_form(x):
    """Batch norm expressed as matrix products (centering operator + diagonal scaling).

    The (N, H, W, C) input becomes an (M, C) data matrix, with M = N*H*W rows
    (observations) and one column per channel. Then:
        centered = D @ X, with D = I - (1/M) * ones((M, M))
        result   = centered @ diag(1 / sqrt(var + eps))
    The result is reshaped to (N, C*H*W).

    NOTE: I, ones((M, M)) and D are all materialized densely, so memory and time
    grow quadratically in N*H*W.
    """
    eps = 1e-9
    N, H, W, C = x.shape
    n_obs = N * H * W
    data = np.reshape(x, (n_obs, C))

    centering = np.identity(n_obs) - np.ones((n_obs, n_obs)) / n_obs
    inv_std = 1 / np.sqrt(np.var(data, axis=0) + eps)
    normalized = (centering @ data) @ np.diag(inv_std)

    return np.reshape(normalized, (N, C * H * W))

def z_norm_matrix_form(x):
    """Per-feature z-scoring across the batch, expressed as matrix products.

    The (N, H, W, C) input is flattened to an (N, C*H*W) data matrix. Then:
        centered = D @ X, with D = I - (1/N) * ones((N, N))
        result   = centered @ diag(1 / sqrt(var + eps))
    The output has shape (N, C*H*W).
    """
    eps = 1e-9
    n_samples, height, width, channels = x.shape
    data = np.reshape(x, (n_samples, channels * height * width))

    centering = np.identity(n_samples) - np.ones((n_samples, n_samples)) / n_samples
    scaling = np.diag(1 / np.sqrt(np.var(data, axis=0) + eps))

    return (centering @ data) @ scaling

In [8]:
# Test whether the elementwise and Duality Diagram (matrix-form) implementations of each norm produce the same transformation, and time both

N = 50
H, W = 16, 16  # both spatial sizes must be given; `H, W = 16` raises TypeError
C = 8
GROUP_SIZE = 2  # passed as `group_size`, i.e. the number of channel groups
SEED = 0

np.random.seed(SEED)  # fixed seed so Restart & Run All reproduces the same input
x = np.random.rand(N, H, W, C)  # random feature tensor in (N, H, W, C) layout


def _time_call(label, fn, *args, **kwargs):
    """Call fn(*args, **kwargs) once, print `label` and elapsed seconds, return the result."""
    start = time.perf_counter()  # monotonic, high-resolution clock for intervals
    result = fn(*args, **kwargs)
    print(label, time.perf_counter() - start)
    return result


def _same(a, b):
    """Tolerance-based equality.

    Comparing round(3) values can report a mismatch when two nearly equal
    numbers fall on opposite sides of a rounding boundary.
    """
    return np.allclose(a, b)


g_nx = _time_call("Time for group norm element wise", group_norm, x, group_size=GROUP_SIZE)
g_nmx = _time_call("Time for group norm matrix form", group_norm_matrix_form, x, group_size=GROUP_SIZE)
print("Are both equal? ", _same(g_nx, g_nmx))

b_nx = _time_call("Time for batch norm element wise", batch_norm, x)
# NOTE: the matrix form builds several dense (N*H*W)^2 float64 matrices
# (about 1.3 GB each at these sizes), hence its much larger runtime.
b_nmx = _time_call("Time for batch norm matrix form", batch_norm_matrix_form, x)
print("Are both equal? ", _same(b_nx, b_nmx))

z_nx = _time_call("Time for z norm element wise", z_norm_my_implementation, x)
z_nmx = _time_call("Time for z norm matrix form", z_norm_matrix_form, x)
print("Are both equal? ", _same(z_nx, z_nmx))


Time for group norm element wise 0.005983591079711914
Time for group norm matrix form 0.01898193359375
Are both equal?  True
Time for batch norm element wise 0.0019948482513427734
Time for batch norm matrix form 2.2768137454986572
Are both equal?  True
Time for z norm element wise 0.0019617080688476562
Time for z norm matrix form 0.015958070755004883
Are both equal?  True
