In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time

## Standard GRN Experiment

In [2]:
# Dense 2D sample in channels-last layout: (batch N, height H, width W, channels C).
N, H, W, C = 10, 6, 9, 512
x = torch.rand((N, H, W, C))
print(x.shape)

torch.Size([10, 6, 9, 512])


In [3]:
class GRN(nn.Module):
    """Global Response Normalization (GRN) layer for dense tensors.

    The forward pass reduces over dims 1 and 2 and averages over the last
    dim, and the parameters have shape (1, 1, 1, dim), so the input is
    expected in channels-last layout: (batch, H, W, dim).

    Args:
        dim: number of channels, i.e. the size of the last input axis.
    """

    def __init__(self, dim):
        super().__init__()
        # Scale and shift both start at zero and broadcast over (batch, H, W).
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        """Apply GRN to ``x`` and return a tensor of the same shape.

        Args:
            x: dense tensor laid out as (batch, H, W, dim).

        Returns:
            ``gamma * (x * Nx) + beta + x``, where ``Nx`` is the per-sample,
            per-channel L2 norm divided by its mean over channels.
        """
        # Per-sample, per-channel L2 norm over the two spatial axes,
        # e.g. torch.Size([10, 1, 1, 512]) for a [10, 6, 9, 512] input.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the channel-wise mean; 1e-6 guards against division by zero.
        channel_mean = Gx.mean(dim=-1, keepdim=True)
        Nx = Gx / (channel_mean + 1e-6)
        calibrated = self.gamma * (x * Nx)
        return calibrated + self.beta + x

In [4]:
# Step-by-step walk-through of the dense GRN forward pass on `x` ([batch, H, W, C]).

# 1) Gx: L2 norm of x over the height (dim=1) and width (dim=2) axes,
#    computed separately for every batch sample and channel.
#    shape => [batch_size, 1, 1, dim]
gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
print(gx.shape)

# 2) Mean of Gx over the channel axis (dim=-1).
#    shape => [batch_size, 1, 1, 1]
gx_channel_mean = gx.mean(dim=-1, keepdim=True)
print(gx_channel_mean.shape)

# 3) Nx: Gx divided by its channel-wise mean (1e-6 avoids division by zero).
#    shape => [batch_size, 1, 1, dim]
nx = gx / (gx_channel_mean + 1e-6)
print(nx.shape)

# 4) Run the full layer, which applies gamma and beta on top of the normalisation.
#    shape => [batch_size, H, W, C]
grn = GRN(C)
output = grn(x)
print(output.shape)

torch.Size([10, 1, 1, 512])
torch.Size([10, 1, 1, 1])
torch.Size([10, 1, 1, 512])
torch.Size([10, 6, 9, 512])


## Existing Sparse GRN Experiment

In [5]:
import MinkowskiEngine as ME

# 2D sparse tensor sample: every cell of a (batch_size, H, W) grid is an active point.

batch_size = 10
H = 6
W = 9
C = 512

# coord: (batch, h, w), enumerated batch-major, then row by row.
# Comprehensions keep the loop indices local, so the notebook-level dense
# tensor `x` (and any `y` / `b`) is no longer overwritten by a loop variable.
coords_list = [
    [b, h, w]
    for b in range(batch_size)
    for h in range(H)
    for w in range(W)
]
# One random C-dimensional feature vector per coordinate, drawn in coordinate order.
feats_list = [torch.randn(C) for _ in coords_list]

coords = torch.IntTensor(coords_list)  # (N, 3),  N = 10*6*9 = 540
feats = torch.stack(feats_list, dim=0) # (N, 512)

sparse_tensor = ME.SparseTensor(
    features=feats,
    coordinates=coords
)
print(f"coords shape = {sparse_tensor.C.shape}")
print(f"feats shape  = {sparse_tensor.F.shape}")


coords shape = torch.Size([540, 3])
feats shape  = torch.Size([540, 512])




In [6]:
class MinkowskiGRN(nn.Module):
    """GRN layer for sparse tensors (existing implementation).

    The norm is taken over dim 0 of the feature matrix ``x.F``, i.e. over
    every row of ``x.F`` at once; rows are not grouped by batch index.

    Args:
        dim: number of feature channels (number of columns of ``x.F``).
    """

    def __init__(self, dim):
        super().__init__()
        # Zero-initialised scale and shift, broadcast over the rows of x.F.
        param_shape = (1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        """Apply GRN to the features of ``x``.

        Args:
            x: sparse tensor exposing ``F``, ``coordinate_manager`` and
               ``coordinate_map_key``.

        Returns:
            A new ``ME.SparseTensor`` built on the same coordinate map key and
            coordinate manager as ``x``, with features
            ``gamma * (F * Nx) + beta + F``.
        """
        feats = x.F

        # Per-channel L2 norm over all rows of the feature matrix -> (1, C).
        Gx = torch.norm(feats, p=2, dim=0, keepdim=True)
        # Divide by the channel-wise mean; 1e-6 guards against division by zero.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        out_feats = self.gamma * (feats * Nx) + self.beta + feats

        # Reuse the input's coordinate bookkeeping so coordinates are unchanged.
        return ME.SparseTensor(
            out_feats,
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x.coordinate_manager,
        )

In [7]:
# Step-by-step walk-through of the existing sparse GRN forward pass.

# 0) sparse_tensor.F is the feature matrix of shape [N, C].
#    N is the number of valid points (not the batch size), so N = 10*6*9 = 540.
x = sparse_tensor
print(x.F.shape)

# 1) Gx: L2 norm of x.F over the point axis (dim=0), one value per channel.
#    All N points are reduced together, regardless of which batch sample they belong to.
#    shape => [1, C]
Gx = torch.norm(x.F, p=2, dim=0, keepdim=True)
print(Gx.shape)

# 2) Mean of Gx over the last (channel) dimension (dim=-1).
#    shape => [1, 1]
Gx_channel_mean = Gx.mean(dim=-1, keepdim=True)
print(Gx_channel_mean.shape)

# 3) Nx: Gx divided by its channel-wise mean (1e-6 avoids division by zero).
#    shape => [1, C]
Nx = Gx / (Gx_channel_mean + 1e-6)
print(Nx.shape)

# 4) Run the full layer, which applies gamma and beta on top of the normalisation.
#    shape => [N, C]
mgrn = MinkowskiGRN(C)
output = mgrn(x)
print(output.F.shape)

torch.Size([540, 512])
torch.Size([1, 512])
torch.Size([1, 1])
torch.Size([1, 512])
torch.Size([540, 512])


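## Existing Minkowski GRN != Standard GRN

With identical gamma/beta values, the existing sparse layer does not reproduce the dense layer's output: it takes the norm over all points of the whole batch, whereas the dense layer takes it per sample.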

In [8]:
# Shared random (rather than zero) scale/shift values, so the sparse and dense
# GRN layers can be compared with identical, non-trivial parameters.
dim = 512
random_init = torch.rand(1, dim)
random_init_bias = torch.rand(1, dim)
# The same values viewed as (1, 1, 1, dim) for the channels-last dense layer.
stand_random_init = random_init[:, None, None, :]
stand_random_init_bias = random_init_bias[:, None, None, :]

# (1, dim) parameters for the sparse layers.
gamma = nn.Parameter(random_init)
beta = nn.Parameter(random_init_bias)
# (1, 1, 1, dim) parameters for the dense layer.
stan_gamma = nn.Parameter(stand_random_init)
stan_beta = nn.Parameter(stand_random_init_bias)


In [9]:
# gamma, beta same value setting
class MinkowskiGRN(nn.Module):
    """GRN layer for sparse tensors, wired to the shared comparison parameters.

    ``dim`` is accepted for signature compatibility but not used: the layer
    takes its scale and shift from the module-level ``gamma`` and ``beta``.
    """

    def __init__(self, dim):
        super().__init__()
        # Shared parameters defined at notebook level, not created per instance.
        self.gamma = gamma
        self.beta = beta

    def forward(self, x):
        """Return a sparse tensor with features ``gamma * (F * Nx) + beta + F``.

        The norm behind ``Nx`` is taken over dim 0 of ``x.F``, i.e. over every
        row of the feature matrix at once.
        """
        feats = x.F

        # Per-channel L2 norm over all rows of the feature matrix.
        Gx = torch.norm(feats, p=2, dim=0, keepdim=True)
        # Divide by the channel-wise mean; 1e-6 guards against division by zero.
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        out_feats = self.gamma * (feats * Nx) + self.beta + feats

        # Reuse the input's coordinate bookkeeping so coordinates are unchanged.
        return ME.SparseTensor(
            out_feats,
            coordinate_map_key=x.coordinate_map_key,
            coordinate_manager=x.coordinate_manager,
        )
class GRN(nn.Module):
    """Dense GRN (Global Response Normalization) layer using the shared parameters.

    ``dim`` is accepted for signature compatibility but not used: the layer
    takes its scale and shift from the module-level ``stan_gamma`` and
    ``stan_beta``.
    """

    def __init__(self, dim):
        super().__init__()
        # Shared parameters defined at notebook level, not created per instance.
        self.gamma = stan_gamma
        self.beta = stan_beta

    def forward(self, x):
        """Return ``gamma * (x * Nx) + beta + x`` for a channels-last input ``x``."""
        # Per-sample, per-channel L2 norm over dims 1 and 2,
        # e.g. torch.Size([10, 1, 1, 512]) for a [10, 6, 9, 512] input.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the channel-wise mean; 1e-6 guards against division by zero.
        channel_mean = Gx.mean(dim=-1, keepdim=True)
        Nx = Gx / (channel_mean + 1e-6)
        calibrated = self.gamma * (x * Nx)
        return calibrated + self.beta + x

In [10]:
from MinkowskiOps import (
    to_sparse,
)
# Random float64 sample in channels-first layout (N, C, H, W).
N, H, W, C = 10, 6, 9, 512
sample = torch.rand(N, C, H, W).double()

# The dense GRN works on channels-last data; the sparse tensor is built from
# the channels-first sample directly.
Dense_Sample = sample.permute(0, 2, 3, 1)
print("Dense sample shape: ", Dense_Sample.shape)
Sparse_Sample = to_sparse(sample)
print("Sparse sample.F shape: ", Sparse_Sample.F.shape)

# Dense reference: run GRN, then permute back to channels-first for comparison.
standard_grn = GRN(512)
output_dense_ = standard_grn(Dense_Sample)
output_dense = output_dense_.permute(0, 3, 1, 2)
print("Dense GRN output shape: ", output_dense.shape)

# Existing sparse GRN: run it, then densify (first element of the dense() result).
sparse_grn = MinkowskiGRN(512)
output_sparse_ = sparse_grn(Sparse_Sample)
output_sparse = output_sparse_.dense()[0]
print("Sparse GRN output shape: ", output_sparse.shape)

# Mean absolute difference between the dense and sparse outputs.
print("result: ")
print("stand vs sparse :", (output_dense - output_sparse).abs().mean().item())


Dense sample shape:  torch.Size([10, 6, 9, 512])
Sparse sample.F shape:  torch.Size([540, 512])
Dense GRN output shape:  torch.Size([10, 512, 6, 9])
Sparse GRN output shape:  torch.Size([10, 512, 6, 9])
result: 
stand vs sparse : 0.011198032016820122


## Minkowski GRN update

The updated layer computes the norm separately for each batch sample, matching the dense GRN.

In [11]:
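# Reference version with its own zero-initialised gamma/beta, kept commented out;
# the active definition in the next cell uses the shared module-level gamma/beta instead.
# class MinkowskiGRN_updated(nn.Module):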
#     def __init__(self, dim):
#         super().__init__()
#         self.gamma = nn.Parameter(torch.zeros(1, dim))
#         self.beta = nn.Parameter(torch.zeros(1, dim))

#     def forward(self, x):
#         batch_index = x.C[:, 0] 
#         num_batches = batch_index.max().item() + 1
#         num_channels = x.F.shape[1]

#         sum_of_squares = torch.zeros([num_batches, num_channels], device=x.F.device, dtype=x.F.dtype)
#         sum_of_squares.index_add_(dim=0, index=batch_index, source=x.F.square())
#         Gx_per_batch = sum_of_squares.sqrt()

#         Nx_per_batch = Gx_per_batch / (Gx_per_batch.mean(dim=-1, keepdim=True) + 1e-6)
#         out_feat = x.F + self.gamma * (x.F * Nx_per_batch[batch_index]) + self.beta

#         out = ME.SparseTensor(
#             features=out_feat,
#             coordinate_manager=x.coordinate_manager,
#             coordinate_map_key=x.coordinate_map_key
#         )

#         return out

In [12]:
class MinkowskiGRN_updated(nn.Module):
    """Sparse GRN that normalises each batch sample separately.

    The per-channel L2 norm is accumulated per batch index (column 0 of
    ``x.C``), so every sample gets its own normalisation statistics.

    ``dim`` is accepted for signature compatibility but not used: the layer
    takes its scale and shift from the module-level ``gamma`` and ``beta``.
    """

    def __init__(self, dim):
        super().__init__()
        # Shared parameters defined at notebook level, not created per instance.
        self.gamma = gamma
        self.beta = beta

    def forward(self, x):
        """Apply per-sample GRN to the features of ``x``.

        Args:
            x: sparse tensor exposing ``C`` (coordinates with the batch index
               in column 0), ``F`` (features of shape [num_points, channels]),
               ``coordinate_manager`` and ``coordinate_map_key``.

        Returns:
            A new ``ME.SparseTensor`` on the same coordinate map as ``x`` with
            features ``F + gamma * (F * Nx[batch]) + beta``.
        """
        feats = x.F
        # Use an int64 index on the feature device: this is valid for both
        # index_add_ and advanced indexing regardless of the coordinate dtype.
        batch_index = x.C[:, 0].to(device=feats.device, dtype=torch.long)
        # max() raises on an empty tensor, so a tensor with no points has 0 batches.
        num_batches = int(batch_index.max()) + 1 if batch_index.numel() > 0 else 0
        num_channels = feats.shape[1]

        # Scatter-add the squared features into one row per batch sample,
        # then take the square root to get the per-sample, per-channel L2 norm.
        sum_of_squares = torch.zeros([num_batches, num_channels], device=feats.device, dtype=feats.dtype)
        sum_of_squares.index_add_(dim=0, index=batch_index, source=feats.square())
        Gx_per_batch = sum_of_squares.sqrt()

        # Divide by the channel-wise mean; 1e-6 guards against division by zero.
        Nx_per_batch = Gx_per_batch / (Gx_per_batch.mean(dim=-1, keepdim=True) + 1e-6)
        # Gather each point's own sample statistics back to [num_points, channels].
        out_feat = feats + self.gamma * (feats * Nx_per_batch[batch_index, :]) + self.beta

        out = ME.SparseTensor(
            features=out_feat,
            coordinate_manager=x.coordinate_manager,
            coordinate_map_key=x.coordinate_map_key
        )

        return out

In [13]:
from MinkowskiOps import (
    to_sparse,
)
# Random float64 sample in channels-first layout (N, C, H, W).
N, H, W, C = 10, 6, 9, 512
sample = torch.rand(N, C, H, W).double()

# The dense GRN works on channels-last data; the sparse tensor is built from
# the channels-first sample directly.
Dense_Sample = sample.permute(0, 2, 3, 1)
print("Dense sample shape: ", Dense_Sample.shape)
Sparse_Sample = to_sparse(sample)
print("Sparse sample.F shape: ", Sparse_Sample.F.shape)

# Dense reference: run GRN, then permute back to channels-first for comparison.
standard_grn = GRN(512)
output_dense_ = standard_grn(Dense_Sample)
output_dense = output_dense_.permute(0, 3, 1, 2)
print("Dense GRN output shape: ", output_dense.shape)

# Updated sparse GRN: run it, then densify (first element of the dense() result).
sparse_grn_update = MinkowskiGRN_updated(512)
output_sparse_update = sparse_grn_update(Sparse_Sample)
output_sparse_update = output_sparse_update.dense()[0]
print("Sparse GRN output shape: ", output_sparse_update.shape)

# Mean absolute difference between the dense and updated sparse outputs.
print("result: ")
print("stand vs sparse", (output_dense - output_sparse_update).abs().mean().item())

Dense sample shape:  torch.Size([10, 6, 9, 512])
Sparse sample.F shape:  torch.Size([540, 512])
Dense GRN output shape:  torch.Size([10, 512, 6, 9])
Sparse GRN output shape:  torch.Size([10, 512, 6, 9])
result: 
stand vs sparse 5.8333715984234266e-18
