# Batch Whitening Layer
The purpose of this notebook is to implement the batch whitening layer.   
The implementation is inspired by the implementation of BatchNorm layer from [this reference](https://d2l.ai/chapter_convolutional-modern/batch-norm.html)


In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print("Device to use:", device)


Device to use: cuda:0


## Batch Normalization  
Lets start with implementing BatchNorm from scratch and test it on a simple dataset.


### Option 1

In [None]:
def batch_norm(X, gamma, beta, running_mean, running_var, eps, momentum):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        shape = (1, X.shape[1])
    else:
        shape = (1, X.shape[1], 1, 1)

    if not torch.is_grad_enabled():
        # In prediction mode, use mean and variance obtained by moving average
        X_hat = (X - running_mean) / torch.sqrt(running_var + eps)
    else:
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of X, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        running_mean = (1.0 - momentum) * running_mean + momentum * mean
        running_var = (1.0 - momentum) * running_var + momentum * var
    Y = gamma.view(shape) * X_hat + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_var.data


class BatchNorm(nn.Module):
    # num_features: the number of outputs for a fully connected layer or the
    # number of output channels for a convolutional layer. num_dims: 2 for a
    # fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        # if num_dims == 2:
        #     shape = (1, num_features)
        # else:
        #     shape = (1, num_features, 1, 1)
        shape = num_features
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        # The variables that are not model parameters are initialized to 0 and 1
        # self.running_mean = torch.zeros(shape)
        # self.running_var = torch.ones(shape)
        self.register_buffer('running_mean', torch.zeros(shape))
        self.register_buffer('running_var', torch.ones(shape))

    def forward(self, X):
        # If X is not on the main memory, copy moving_mean and moving_var to
        # the device where X is located
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_var = self.running_var.to(X.device)
        # Save the updated running_mean and moving_var
        Y, self.running_mean, self.running_var = batch_norm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_var, eps=1e-5, momentum=0.1)
        return Y


### Option 2

In [None]:
class BatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running mean and variance
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Calculate mean and variance for the batch
            mean = x.mean([0, 2, 3], keepdim=True)
            var = x.var([0, 2, 3], keepdim=True, unbiased=False)
            # Update running mean and variance
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        
        # Normalize
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        # Scale and shift
        out = self.gamma.view(1, self.num_features, 1, 1) * x_normalized + self.beta.view(1, self.num_features, 1, 1)
        
        return out

### Validation


In [None]:
# Create a batch of 2D images (batch size, channels, height, width)
x = torch.randn(20, 10, 50, 50)

# Our custom batch normalization layer
custom_bn = BatchNorm(num_features=10, num_dims=4)
# custom_bn = BatchNorm2d(num_features=10)

# PyTorch's built-in batch normalization layer
torch_bn = nn.BatchNorm2d(num_features=10)


# Copy the parameters from our custom layer to the built-in layer for a fair comparison
torch_bn.weight.data = custom_bn.gamma.data.clone()
torch_bn.bias.data = custom_bn.beta.data.clone()
torch_bn.running_mean = custom_bn.running_mean.clone()
torch_bn.running_var = custom_bn.running_var.clone()

# Forward pass
custom_bn_output = custom_bn(x)
torch_bn_output = torch_bn(x)

# Check if the outputs are close
assert torch.allclose(custom_bn_output, torch_bn_output, atol=1e-5), "The outputs are not close enough!"

print("Functional validation passed!")

In [None]:
x.shape

In [None]:
x.var([0, 2, 3], keepdim=True, unbiased=False).shape

In [None]:
custom_bn.training

# Batch Whitening

### Choleski

In [2]:
def batch_orthonorm_obsolete(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # When using a fully connected layer, calculate the mean and
        # variance on the feature dimension
        shape = (1, n_features)
        mean = X.mean(dim=0)
        cov = torch.cov(X.T,correction=0)        
        # var = ((X - mean) ** 2).mean(dim=0)
    else:
        # When using a two-dimensional convolutional layer, calculate the
        # mean and covariance on the channel dimension (axis=1). Here we
        # need to maintain the shape of X, so that the broadcasting
        # operation can be carried out later
        shape = (1, n_features, 1, 1)
        mean = X.mean(dim=(0, 2, 3))
        Xtmp = X.view(X.shape[0],X.shape[1],-1)
        Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
        cov = torch.cov(Xtmp,correction=0) 
        # var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    # In training mode, the current mean and variance are used
    # Update the mean and variance using moving average
    running_mean = (1.0 - momentum) * running_mean + momentum * mean
    running_cov = (1.0 - momentum) * running_cov + momentum * cov
    L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
    if len(X.shape) == 2:
        X_hat = (X-running_mean.view(1,n_features)).T
        Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
    else:
        X_hat = X-running_mean.view(1,n_features,1,1)
        X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
        Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_cov.data

def batch_orthonorm(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # When using a fully connected layer, calculate the mean and
        # variance on the feature dimension
        shape = (1, n_features)
        mean = X.mean(dim=0)
        cov = torch.cov(X.T,correction=0)        
        # var = ((X - mean) ** 2).mean(dim=0)
    else:
        # When using a two-dimensional convolutional layer, calculate the
        # mean and covariance on the channel dimension (axis=1). Here we
        # need to maintain the shape of X, so that the broadcasting
        # operation can be carried out later
        shape = (1, n_features, 1, 1)
        mean = X.mean(dim=(0, 2, 3))
        Xtmp = X.view(X.shape[0],X.shape[1],-1)
        Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
        cov = torch.cov(Xtmp,correction=0) 
        # var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    # In training mode, the current mean and variance are used
    # Update the mean and variance using moving average
    if torch.is_grad_enabled():
        running_mean = (1.0 - momentum) * running_mean + momentum * mean
        running_cov = (1.0 - momentum) * running_cov + momentum * cov
        L = torch.linalg.cholesky(cov + eps*torch.eye(n_features))
        if len(X.shape) == 2:
            X_hat = (X-mean.view(1,n_features)).T
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
        else:
            X_hat = X-mean.view(1,n_features,1,1)
            X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    else:
        L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
        if len(X.shape) == 2:
            X_hat = (X-running_mean.view(1,n_features)).T
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
        else:
            X_hat = X-running_mean.view(1,n_features,1,1)
            X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_cov.data

class BatchWhitening(nn.Module):
    # num_features: the number of outputs for a fully connected layer or the
    # number of output channels for a convolutional layer. num_dims: 2 for a
    # fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features,momentum=0.1):
        super().__init__()
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # The variables that are not model parameters are initialized to 0 and 1
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_cov', torch.eye(num_features))

    def forward(self, X):
        # If X is not on the main memory, copy moving_mean and moving_var to
        # the device where X is located
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_cov = self.running_cov.to(X.device)
        # Save the updated running_mean and moving_var
        # Y, self.running_mean, self.running_var = batch_orthonorm(
        #     X, self.gamma, self.beta, self.running_mean,
        #     self.running_cov, eps=1e-5, momentum=0.1)
        Y, self.running_mean, self.running_cov = batch_orthonorm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_cov, eps=1e-5, momentum=self.momentum)

        return Y


### IterNorm

In [3]:
from torch.nn import Parameter

class IterNorm(nn.Module):
    def __init__(self, num_features, num_groups=1, num_channels=-1, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True, *args, **kwargs):
        super(IterNorm, self).__init__()
        # assert dim == 4, 'IterNorm is not support 2D'
        self.T = T
        self.eps = eps
        self.momentum = momentum
        self.num_features = num_features
        self.affine = affine
        self.dim = dim

        if num_channels == -1:
            num_channels = (num_features - 1) // num_groups + 1
        num_groups = num_features // num_channels
        while num_features % num_channels != 0:
            num_channels //= 2
            num_groups = num_features // num_channels
        assert num_groups > 0 and num_features % num_groups == 0, "num features={}, num groups={}".format(num_features,
            num_groups)
        self.num_groups = num_groups
        self.num_channels = num_channels
        shape = [1] * dim
        shape[1] = self.num_features
        if self.affine:
            self.weight = Parameter(torch.Tensor(*shape))
            self.bias = Parameter(torch.Tensor(*shape))

        if not self.affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.register_buffer('running_mean', torch.zeros(num_groups, num_channels, 1))
        # running whiten matrix
        self.register_buffer('running_wm', torch.eye(num_channels).expand(num_groups, num_channels, num_channels))
        self.reset_parameters()

    def reset_parameters(self):
        # self.reset_running_stats()
        if self.affine:
            self.weight.data.fill_(1.0)
            self.bias.data.fill_(0.0)

    def forward(self, X: torch.Tensor):
        eps = 1e-5
        momentum = self.momentum

        nc = self.num_channels
        T = self.T
        g = X.size(1) // nc
        x = X.transpose(0, 1).contiguous().view(g, nc, -1)
        _, d, m = x.size()
        saved = []
        if self.training:
            # calculate centered activation by subtracted mini-batch mean
            mean = x.mean(-1, keepdim=True)
            xc = x - mean
            # calculate covariance matrix
            P = [None] * (T + 1)
            P[0] = torch.eye(d).to(X).expand(g, d, d)
            # Sigma = torch.baddbmm(eps, P[0], 1. / m, xc, xc.transpose(1, 2))
            Sigma = torch.baddbmm(P[0], xc, xc.transpose(1, 2), beta=eps, alpha=1. / m)
            # reciprocal of trace of Sigma: shape [g, 1, 1]
            rTr = (Sigma * P[0]).sum(1, keepdim=True).sum(2, keepdim=True).reciprocal_()
            Sigma_N = Sigma * rTr
            for k in range(T):
                # P[k + 1] = torch.baddbmm(1.5, P[k], -0.5, P[k].bmm(P[k]).bmm(P[k]), Sigma_N)
                P[k + 1] = torch.baddbmm(P[k], P[k].bmm(P[k]).bmm(P[k]), Sigma_N, beta=1.5, alpha=-0.5)
            wm = P[T].mul_(rTr.sqrt())  # whiten matrix: the matrix inverse of Sigma, i.e., Sigma^{-1/2}
            # self.running_mean += momentum * ( mean.detach() - self.running_mean)
            # self.running_wm += momentum * ( wm.detach() - self.running_wm)
            self.running_mean = (1-momentum)*self.running_mean + momentum * mean.detach()
            self.running_wm = (1-momentum)*self.running_wm + momentum * wm.detach() 
        else:
            xc = x - self.running_mean
            wm = self.running_wm
        xn = wm.matmul(xc)
        X_hat = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous()
        
        # affine
        if self.affine:
            X_hat = X_hat * self.weight
            X_hat = X_hat + self.bias
        return X_hat

    def extra_repr(self):
        return '{num_features}, num_channels={num_channels}, T={T}, eps={eps}, dim={dim}, ' \
               'momentum={momentum}, affine={affine}'.format(**self.__dict__)


## Validation

from [wikipedia](https://en.wikipedia.org/wiki/Covariance_matrix):  
The covariance matrix is given by:  
![cov](cov.png)

and the correlation matrix is given by:  
![corr](corr_mat.png)

In [4]:
def cov_to_corr(cov_matrix):
    # Compute the standard deviations
    std = torch.sqrt(torch.diag(cov_matrix))
    
    # Compute the correlation matrix
    corr_matrix = cov_matrix / torch.outer(std, std)
    
    # Extract upper triangular part (excluding diagonal)
    upper_tri = torch.triu(corr_matrix, diagonal=1)
    
    # Compute average of cross-correlation coefficients
    avg_corr = upper_tri.sum() / (upper_tri.numel() - upper_tri.diag().numel())
    
    return corr_matrix, avg_corr


In [30]:
def rank_and_avg_corr(x):
    # flatten x from [B,C,H,W] to [C,B*H*W]
    x_f= x.permute(1,0,2,3).reshape(x.shape[1],-1)
    # compute corr matrix
    corr_matrix = torch.corrcoef(x_f)
    # Extract upper triangular part (excluding diagonal)
    upper_tri = torch.triu(corr_matrix, diagonal=1)
    # Compute average of cross-correlation coefficients
    avg_corr = upper_tri.sum() / (upper_tri.numel() - upper_tri.diag().numel())
    rank = torch.linalg.matrix_rank(x_f)/x.shape[1]
    return rank,avg_corr


#### 4D tensor

In [5]:
# Create a batch of 2D images (batch size, channels, height, width)
num_features = 3
m = torch.randint(10,100,(1,num_features)).float()
c = torch.randint(1,10,(1,num_features))
cov = c.T@c + 0.1*torch.eye(num_features)

print(f'generating a tensor of shape (B, {num_features},32,32) with mean {m} and covariance \n {cov}:')
x = torch.randn(20, num_features, 32, 32).permute(1,0,2,3).reshape(num_features,-1).T
L = torch.linalg.cholesky(cov.float())
xc= x@L.T + m
print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T,correction=0)}')
x_c = xc.permute(1,0).reshape(num_features,20,32,32).permute(1,0,2,3)



generating a tensor of shape (B, 3,32,32) with mean tensor([[56., 57., 25.]]) and covariance 
 tensor([[ 1.1000,  5.0000,  6.0000],
        [ 5.0000, 25.1000, 30.0000],
        [ 6.0000, 30.0000, 36.1000]]):
actual mean and cov: tensor([55.9966, 56.9789, 24.9788]),
 tensor([[ 1.0954,  4.9635,  5.9524],
        [ 4.9635, 24.8639, 29.7060],
        [ 5.9524, 29.7060, 35.7314]])


In [6]:
# compute the rank
torch.linalg.matrix_rank(xc.T)/num_features

tensor(1.)

In [22]:
xc.shape

torch.Size([20480, 3])

In [28]:
# Our custom batch normalization layer
bw_layer = BatchWhitening(num_features,momentum=1)

# Forward pass
x_w = bw_layer(x_c)

x_w_cov = x_w.permute(1,0,2,3).reshape(x.shape[1],-1).cov()
print(x_w_cov)
# Check if the outputs are indeed orthonormal
assert torch.allclose(x_w_cov, torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

print("Functional validation passed!")

tensor([[ 1.0000e+00,  2.0510e-05,  4.4876e-05],
        [ 2.0510e-05,  1.0000e+00, -2.5317e-04],
        [ 4.4876e-05, -2.5317e-04,  1.0011e+00]], grad_fn=<SqueezeBackward0>)
Functional validation passed!


In [8]:
x_c_cov = x_c.permute(1,0,2,3).reshape(x.shape[1],-1).cov()
# we can also compute x_c_cov as follows
# torch.cov(x_c.permute(1,0,2,3).reshape(x.shape[1],-1))
print(x_c_cov)
cov_to_corr(x_c_cov)

tensor([[ 1.0954,  4.9638,  5.9527],
        [ 4.9638, 24.8651, 29.7074],
        [ 5.9527, 29.7074, 35.7332]])


(tensor([[1.0000, 0.9511, 0.9515],
         [0.9511, 1.0000, 0.9966],
         [0.9515, 0.9966, 1.0000]]),
 tensor(0.4832))

In [25]:
# compute the rank
torch.linalg.matrix_rank(x_c.permute(1,0,2,3).reshape(x.shape[1],-1))/x_c.shape[1]

tensor(1.)

In [21]:
# direct computation of corr matrix 
torch.corrcoef(x_c.permute(1,0,2,3).reshape(x_c.shape[1],-1))

tensor([[1.0000, 0.9511, 0.9515],
        [0.9511, 1.0000, 0.9966],
        [0.9515, 0.9966, 1.0000]], grad_fn=<ClampBackward1>)

In [27]:
# should be close to eye matrix
cov_to_corr(x_w_cov)

(tensor([[1.0000, 0.9511, 0.9515],
         [0.9511, 1.0000, 0.9966],
         [0.9515, 0.9966, 1.0000]], grad_fn=<DivBackward0>),
 tensor(0.4832, grad_fn=<DivBackward0>))

In [29]:
torch.corrcoef(x_w.permute(1,0,2,3).reshape(x_w.shape[1],-1))

tensor([[ 1.0000e+00,  2.0509e-05,  4.4851e-05],
        [ 2.0509e-05,  1.0000e+00, -2.5303e-04],
        [ 4.4851e-05, -2.5303e-04,  1.0000e+00]], grad_fn=<ClampBackward1>)

In [15]:
# TODO: understand how the algo works so that it whiten the matrix
ItN = IterNorm(3, num_groups=3, T=10, momentum=1, affine=False)
print(ItN)
ItN.train()
x_c.requires_grad_()
x_w = ItN(x_c)

x_w_cov = x_w.permute(1,0,2,3).reshape(x_w.shape[1],-1).cov()
print(x_w_cov)


IterNorm(3, num_channels=1, T=10, eps=1e-05, dim=4, momentum=1, affine=False)
tensor([[1.0000, 0.9511, 0.9515],
        [0.9511, 1.0000, 0.9967],
        [0.9515, 0.9967, 1.0000]], grad_fn=<SqueezeBackward0>)


In [14]:
cov_to_corr(x_w_cov)

(tensor([[1.0000, 0.9511, 0.9515],
         [0.9511, 1.0000, 0.9966],
         [0.9515, 0.9966, 1.0000]], grad_fn=<DivBackward0>),
 tensor(0.4832, grad_fn=<DivBackward0>))

In [32]:
rank_and_avg_corr(x_c)

(tensor(1.), tensor(0.4832, grad_fn=<DivBackward0>))

In [None]:
x_c.shape

#### 2D tensor

In [None]:
# Create a batch of vectors (B, num_features)
num_features = 10
m = torch.randint(10,100,(1,num_features)).float()
c = torch.randint(1,10,(1,num_features))
cov = c.T@c + 0.1*torch.eye(num_features)

print(f'generating a tensor of shape (B, {num_features}) with mean {m} and covariance \n {cov}:')
x = torch.randn(20000, num_features)
L = torch.linalg.cholesky(cov.float())
xc= x@L.T + m 

print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T,correction=0)}')


In [None]:
# Our custom batch normalization layer
bw_layer = BatchWhitening(num_features,momentum=1)

# Forward pass
x_w = bw_layer(xc)

print(torch.cov(x_w.T,correction=0))

# Check if the outputs are indeed orthonormal
assert torch.allclose(x_w.T.cov(), torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

# print("Functional validation passed!")

In [None]:
custom_bn_output.permute(1,0,2,3).reshape(x.shape[1],-1).cov()

In [None]:
c=torch.randint(1,10,(3,1))
c.shape

In [None]:
c

In [None]:
cov = c@c.T
cov

In [None]:
X = torch.randn(20, 10, 50, 50)

mean = X.mean(dim=(0, 2, 3))
Xtmp = X.view(X.shape[0],X.shape[1],-1)
Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
cov = torch.cov(Xtmp,correction=0) 

assert torch.allclose(cov, torch.eye(10), atol=1e-5), "The outputs are not close enough!"
print("Functional validation passed!")

### Debug

In [None]:
# 4D input tensor
n_features = 10
X = torch.randn(20, n_features, 50, 50)

shape = n_features
gamma = torch.ones(shape)
beta = torch.zeros(shape)

running_mean = torch.zeros(shape)
running_cov = torch.eye(shape)

mean = X.mean(dim=(0, 2, 3))
Xtmp = X.view(X.shape[0],X.shape[1],-1)
Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
cov = torch.cov(Xtmp,correction=0) 


momentum = 0.1 
eps = 1e-5
running_mean = (1.0 - momentum) * running_mean + momentum * mean
running_cov = (1.0 - momentum) * running_cov + momentum * cov

L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
X_hat = X-running_mean.view(1,n_features,1,1)
X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)


# assert torch.allclose(cov, torch.eye(10), atol=1e-5), "The outputs are not close enough!"
# print("Functional validation passed!")

In [None]:
# 2D input tensor
n_features = 10
X = torch.randn(20, n_features)

shape = n_features
gamma = torch.ones(shape)
beta = torch.zeros(shape)

running_mean = torch.zeros(shape)
running_cov = torch.eye(shape)

mean = X.mean(dim=0)
cov = torch.cov(X.T,correction=0)

momentum = 0.1 
eps = 1e-5
running_mean = (1.0 - momentum) * running_mean + momentum * mean
running_cov = (1.0 - momentum) * running_cov + momentum * cov

L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
X_hat = X-running_mean.view(1,n_features)
X_hat = X_hat.T
Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T


In [None]:
def batch_orthonorm(X, gamma, beta, running_mean, running_cov, eps, momentum):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # When using a fully connected layer, calculate the mean and
        # variance on the feature dimension
        shape = (1, n_features)
        mean = X.mean(dim=0)
        cov = torch.cov(X.T,correction=0)        
        # var = ((X - mean) ** 2).mean(dim=0)
    else:
        # When using a two-dimensional convolutional layer, calculate the
        # mean and covariance on the channel dimension (axis=1). Here we
        # need to maintain the shape of X, so that the broadcasting
        # operation can be carried out later
        shape = (1, n_features, 1, 1)
        mean = X.mean(dim=(0, 2, 3), keepdim=True)
        Xtmp = X.view(X.shape[0],X.shape[1],-1)
        Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
        cov = torch.cov(Xtmp,correction=0) 
        # var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    # In training mode, the current mean and variance are used
    # Update the mean and variance using moving average
    running_mean = (1.0 - momentum) * running_mean + momentum * mean
    running_cov = (1.0 - momentum) * running_cov + momentum * cov
    L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
    if len(X.shape) == 2:
        X_hat = (X-running_mean.view(1,n_features)).T
        Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
    else:
        X_hat = X-running_mean.view(1,n_features,1,1)
        X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
        Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_cov.data


In [None]:
import torch

def cov_to_corr(cov_matrix):
    # Compute the standard deviations
    std = torch.sqrt(torch.diag(cov_matrix))
    
    # Compute the correlation matrix
    corr_matrix = cov_matrix / torch.outer(std, std)
    
    # Extract upper triangular part (excluding diagonal)
    upper_tri = torch.triu(corr_matrix, diagonal=1)
    
    # Compute average of cross-correlation coefficients
    avg_corr = upper_tri.sum() / (upper_tri.numel() - upper_tri.diag().numel())
    
    return corr_matrix, avg_corr

# Example usage
D = 5  # Dimension of the covariance matrix
cov_matrix = torch.randn(D, D)
cov_matrix = torch.mm(cov_matrix, cov_matrix.t())  # Ensure positive semi-definite

corr_matrix, avg_cross_corr = cov_to_corr(cov_matrix)

print("Correlation Matrix:")
print(corr_matrix)
print("\nAverage Cross-Correlation Coefficient:", avg_cross_corr.item())

## IterNorm

In [None]:
ItN = IterNorm(64, num_groups=8, T=10, momentum=1, affine=False)
print(ItN)
ItN.train()
x = torch.randn(32, 64, 14, 14)
# x = torch.randn(128, 64)
x.requires_grad_()


In [None]:
x.shape

In [None]:
y = ItN(x)
z = y.transpose(0, 1).contiguous().view(x.size(1), -1)
print(z.matmul(z.t()) / z.size(1))



In [None]:
z.shape

In [None]:
print(torch.cov(z,correction=0))

In [None]:
y.sum().backward()
print('x grad', x.grad.size())



In [None]:
ItN.eval()
y = ItN(x)
z = y.transpose(0, 1).contiguous().view(x.size(1), -1)
print(z.matmul(z.t()) / z.size(1))
