# Batch Whitening Layer
The purpose of this notebook is to implement the batch whitening layer.   
The implementation is inspired by the implementation of BatchNorm layer from [this reference](https://d2l.ai/chapter_convolutional-modern/batch-norm.html)


In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import numpy as np 

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print("Device to use:", device)


Device to use: cuda:0


## Batch Normalization  
Lets start with implementing BatchNorm from scratch and test it on a simple dataset.


### Option 1

In [None]:
def batch_norm(X, gamma, beta, running_mean, running_var, eps, momentum):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        shape = (1, X.shape[1])
    else:
        shape = (1, X.shape[1], 1, 1)

    if not torch.is_grad_enabled():
        # In prediction mode, use mean and variance obtained by moving average
        X_hat = (X - running_mean) / torch.sqrt(running_var + eps)
    else:
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of X, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        running_mean = (1.0 - momentum) * running_mean + momentum * mean
        running_var = (1.0 - momentum) * running_var + momentum * var
    Y = gamma.view(shape) * X_hat + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_var.data


class BatchNorm(nn.Module):
    # num_features: the number of outputs for a fully connected layer or the
    # number of output channels for a convolutional layer. num_dims: 2 for a
    # fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        # if num_dims == 2:
        #     shape = (1, num_features)
        # else:
        #     shape = (1, num_features, 1, 1)
        shape = num_features
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        # The variables that are not model parameters are initialized to 0 and 1
        # self.running_mean = torch.zeros(shape)
        # self.running_var = torch.ones(shape)
        self.register_buffer('running_mean', torch.zeros(shape))
        self.register_buffer('running_var', torch.ones(shape))

    def forward(self, X):
        # If X is not on the main memory, copy moving_mean and moving_var to
        # the device where X is located
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_var = self.running_var.to(X.device)
        # Save the updated running_mean and moving_var
        Y, self.running_mean, self.running_var = batch_norm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_var, eps=1e-5, momentum=0.1)
        return Y


### Option 2

In [None]:
class BatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running mean and variance
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Calculate mean and variance for the batch
            mean = x.mean([0, 2, 3], keepdim=True)
            var = x.var([0, 2, 3], keepdim=True, unbiased=False)
            # Update running mean and variance
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        
        # Normalize
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        # Scale and shift
        out = self.gamma.view(1, self.num_features, 1, 1) * x_normalized + self.beta.view(1, self.num_features, 1, 1)
        
        return out

### Validation


In [None]:
# Create a batch of 2D images (batch size, channels, height, width)
x = torch.randn(20, 10, 50, 50)

# Our custom batch normalization layer
custom_bn = BatchNorm(num_features=10, num_dims=4)
# custom_bn = BatchNorm2d(num_features=10)

# PyTorch's built-in batch normalization layer
torch_bn = nn.BatchNorm2d(num_features=10)


# Copy the parameters from our custom layer to the built-in layer for a fair comparison
torch_bn.weight.data = custom_bn.gamma.data.clone()
torch_bn.bias.data = custom_bn.beta.data.clone()
torch_bn.running_mean = custom_bn.running_mean.clone()
torch_bn.running_var = custom_bn.running_var.clone()

# Forward pass
custom_bn_output = custom_bn(x)
torch_bn_output = torch_bn(x)

# Check if the outputs are close
assert torch.allclose(custom_bn_output, torch_bn_output, atol=1e-5), "The outputs are not close enough!"

print("Functional validation passed!")

In [None]:
x.shape

In [None]:
x.var([0, 2, 3], keepdim=True, unbiased=False).shape

In [None]:
custom_bn.training

# Batch Whitening

## implementation

### Choleski

In [2]:
def batch_orthonorm_obsolete(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # When using a fully connected layer, calculate the mean and
        # variance on the feature dimension
        shape = (1, n_features)
        mean = X.mean(dim=0)
        cov = torch.cov(X.T,correction=0)        
        # var = ((X - mean) ** 2).mean(dim=0)
    else:
        # When using a two-dimensional convolutional layer, calculate the
        # mean and covariance on the channel dimension (axis=1). Here we
        # need to maintain the shape of X, so that the broadcasting
        # operation can be carried out later
        shape = (1, n_features, 1, 1)
        mean = X.mean(dim=(0, 2, 3))
        Xtmp = X.view(X.shape[0],X.shape[1],-1)
        Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
        cov = torch.cov(Xtmp,correction=0) 
        # var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    # In training mode, the current mean and variance are used
    # Update the mean and variance using moving average
    running_mean = (1.0 - momentum) * running_mean + momentum * mean
    running_cov = (1.0 - momentum) * running_cov + momentum * cov
    L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
    if len(X.shape) == 2:
        X_hat = (X-running_mean.view(1,n_features)).T
        Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
    else:
        X_hat = X-running_mean.view(1,n_features,1,1)
        X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
        Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_cov.data

def batch_orthonorm(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # When using a fully connected layer, calculate the mean and
        # variance on the feature dimension
        shape = (1, n_features)
        mean = X.mean(dim=0)
        cov = torch.cov(X.T,correction=0)        
        # var = ((X - mean) ** 2).mean(dim=0)
    else:
        # When using a two-dimensional convolutional layer, calculate the
        # mean and covariance on the channel dimension (axis=1). Here we
        # need to maintain the shape of X, so that the broadcasting
        # operation can be carried out later
        shape = (1, n_features, 1, 1)
        mean = X.mean(dim=(0, 2, 3))
        Xtmp = X.view(X.shape[0],X.shape[1],-1)
        Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
        cov = torch.cov(Xtmp,correction=0) 
        # var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    # In training mode, the current mean and variance are used
    # Update the mean and variance using moving average
    if torch.is_grad_enabled():
        running_mean = (1.0 - momentum) * running_mean + momentum * mean
        running_cov = (1.0 - momentum) * running_cov + momentum * cov
        L = torch.linalg.cholesky(cov + eps*torch.eye(n_features))
        if len(X.shape) == 2:
            X_hat = (X-mean.view(1,n_features)).T
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
        else:
            X_hat = X-mean.view(1,n_features,1,1)
            X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    else:
        L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
        if len(X.shape) == 2:
            X_hat = (X-running_mean.view(1,n_features)).T
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
        else:
            X_hat = X-running_mean.view(1,n_features,1,1)
            X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_cov.data

class BatchWhitening(nn.Module):
    # num_features: the number of outputs for a fully connected layer or the
    # number of output channels for a convolutional layer. num_dims: 2 for a
    # fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features,momentum=0.1):
        super().__init__()
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # The variables that are not model parameters are initialized to 0 and 1
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_cov', torch.eye(num_features))

    def forward(self, X):
        # If X is not on the main memory, copy moving_mean and moving_var to
        # the device where X is located
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_cov = self.running_cov.to(X.device)
        # Save the updated running_mean and moving_var
        # Y, self.running_mean, self.running_var = batch_orthonorm(
        #     X, self.gamma, self.beta, self.running_mean,
        #     self.running_cov, eps=1e-5, momentum=0.1)
        Y, self.running_mean, self.running_cov = batch_orthonorm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_cov, eps=1e-5, momentum=self.momentum)

        return Y


### IterNorm

In [130]:
from torch.nn import Parameter

def fix_cov(c):
    a=0.9+0.1*torch.exp(-(c/0.9)**10)
    torch.diagonal(a).fill_(1.0)
    return a*c

class IterNorm(nn.Module):
    def __init__(self, num_features, num_groups=1, num_channels=-1, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True, *args, **kwargs):
        super(IterNorm, self).__init__()
        # assert dim == 4, 'IterNorm is not support 2D'
        self.T = T
        self.eps = eps
        self.momentum = momentum
        self.num_features = num_features
        self.affine = affine
        self.dim = dim

        if num_channels == -1:
            num_channels = (num_features - 1) // num_groups + 1
        num_groups = num_features // num_channels
        while num_features % num_channels != 0:
            num_channels //= 2
            num_groups = num_features // num_channels
        assert num_groups > 0 and num_features % num_groups == 0, "num features={}, num groups={}".format(num_features,
            num_groups)
        self.num_groups = num_groups
        self.num_channels = num_channels
        shape = [1] * dim
        shape[1] = self.num_features
        if self.affine:
            self.weight = Parameter(torch.Tensor(*shape))
            self.bias = Parameter(torch.Tensor(*shape))

        if not self.affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.register_buffer('running_mean', torch.zeros(num_groups, num_channels, 1))
        # running whiten matrix
        self.register_buffer('running_wm', torch.eye(num_channels).expand(num_groups, num_channels, num_channels))
        self.reset_parameters()

    def reset_parameters(self):
        # self.reset_running_stats()
        if self.affine:
            self.weight.data.fill_(1.0)
            self.bias.data.fill_(0.0)

    def forward(self, X: torch.Tensor):
        eps = 1e-5
        momentum = self.momentum

        nc = self.num_channels
        T = self.T
        g = X.size(1) // nc
        x = X.transpose(0, 1).contiguous().view(g, nc, -1)
        _, d, m = x.size()
        if self.training:
            # calculate centered activation by subtracted mini-batch mean
            mean = x.mean(-1, keepdim=True)
            xc = x - mean
            # calculate covariance matrix
            P = [None] * (T + 1)
            P[0] = torch.eye(d).to(X).expand(g, d, d)
            # Sigma = torch.baddbmm(eps, P[0], 1. / m, xc, xc.transpose(1, 2))
            Sigma = torch.baddbmm(P[0], xc, xc.transpose(1, 2), beta=eps, alpha=1. / m)  # =torch.cov(xc,correction=0)
            Sigma[0]=fix_cov(Sigma[0])
            # reciprocal of trace of Sigma: shape [g, 1, 1]
            rTr = (Sigma * P[0]).sum(1, keepdim=True).sum(2, keepdim=True).reciprocal_()
            Sigma_N = Sigma * rTr
            for k in range(T):
                # P[k + 1] = torch.baddbmm(1.5, P[k], -0.5, P[k].bmm(P[k]).bmm(P[k]), Sigma_N)
                P[k + 1] = torch.baddbmm(P[k], P[k].bmm(P[k]).bmm(P[k]), Sigma_N, beta=1.5, alpha=-0.5)
            wm = P[T].mul_(rTr.sqrt())  # whiten matrix: the matrix inverse of Sigma, i.e., Sigma^{-1/2}
            # self.running_mean += momentum * ( mean.detach() - self.running_mean)
            # self.running_wm += momentum * ( wm.detach() - self.running_wm)
            self.running_mean = (1-momentum)*self.running_mean + momentum * mean.detach()
            self.running_wm = (1-momentum)*self.running_wm + momentum * wm.detach() 
        else:
            xc = x - self.running_mean
            wm = self.running_wm
        xn = wm.matmul(xc)
        X_hat = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous()
        
        # affine
        if self.affine:
            X_hat = X_hat * self.weight
            X_hat = X_hat + self.bias
        return X_hat

    def extra_repr(self):
        return '{num_features}, num_channels={num_channels}, T={T}, eps={eps}, dim={dim}, ' \
               'momentum={momentum}, affine={affine}'.format(**self.__dict__)


## Validation

from [wikipedia](https://en.wikipedia.org/wiki/Covariance_matrix):  
The covariance matrix is given by:  
![cov](cov.png)

and the correlation matrix is given by:  
![corr](corr_mat.png)

so if we have a matrix of `N` samples of a random vector $x \in \mathbb{R}^C$ , we can compute the covariance of the random vector as follows:
`corr_matrix = torch.corrcoef(x_f)`  
and we can also compute the covariance matrix with: `cov_matrix = x.cov()`   

we can also get `corr_matrix` from `cov_matrix` using the following function:

In [4]:
import numpy as np

def generate_well_conditioned_covariance(n, condition_number=2):
    # Step 1: Generate a random matrix
    # L = np.random.rand(n, n)
    L = torch.rand(n, n)
    
    # Step 2: Make it lower triangular with positive diagonal
    L = torch.tril(L) + n * torch.eye(n)
    
    # Step 3: Construct the symmetric positive definite matrix
    A = L@L.T
    
    # Step 4: Adjust the condition number
    # Scale the matrix to have a desired condition number
    u, s, vh = torch.linalg.svd(A)
    s = torch.linspace(s[0], s[0] / condition_number, len(s))
    A = (u * s)@vh
    
    return A


In [7]:
cov=generate_well_conditioned_covariance(3,10)
print(cov)
torch.linalg.eig(cov)[0]

tensor([[11.6849,  4.2934,  2.5641],
        [ 4.2934,  4.7783,  4.6226],
        [ 2.5641,  4.6226, 14.7474]])


tensor([18.9155+0.j, 10.4035+0.j,  1.8916+0.j])

In [8]:
torch.linalg.cond(cov)

tensor(10.0000)

In [9]:
def cov_to_corr(cov_matrix):
    # Compute the standard deviations
    std = torch.sqrt(torch.diag(cov_matrix))
    
    # Compute the correlation matrix
    corr_matrix = cov_matrix / torch.outer(std, std)
    
    # Extract upper triangular part (excluding diagonal)
    upper_tri = torch.triu(corr_matrix, diagonal=1)
    
    # Compute average of cross-correlation coefficients
    avg_corr = upper_tri.sum() / (upper_tri.numel() - upper_tri.diag().numel())
    
    return corr_matrix, avg_corr


def corr_to_cov(corr_matrix,std):
    # Compute the standard deviations
    D = torch.diag(std)
    # D=torch.outer(std,std)
    # cov=corr_matrix * torch.outer(std,std)
    return D@corr_matrix@D


In [4]:
# the following compute the rank of the matrix and the average of the correlation coefficients in the correlation matrix:
# x is expected to have the following shape: [B,C,H,W] if it has 4 dimensions, [B,C] if it has 2 dimensions 
def rank_and_avg_corr(x):
    if len(x.shape)==4:  # conv
        # flatten x from [B,C,H,W] to [C,B*H*W]
        x_f= x.permute(1,0,2,3).reshape(x.shape[1],-1)
    else:
        # change from [B,C] to [C,B]
        x_f=x.T
    # compute corr matrix
    corr_matrix = torch.corrcoef(x_f)
    # Extract upper triangular part (excluding diagonal)
    upper_tri = torch.triu(corr_matrix, diagonal=1)
    # Compute average of cross-correlation coefficients
    avg_corr = upper_tri.sum() / (upper_tri.numel() - upper_tri.diag().numel())
    rank = torch.linalg.matrix_rank(x_f)/x.shape[1]
    return rank,avg_corr


#### 4D tensor

In [138]:
# Create a batch of 2D images (batch size, channels, height, width)
num_features = 3
n_samples=100
m = torch.randint(10,100,(1,num_features)).float()
c = torch.randint(1,10,(1,num_features))
# cov = c.T@c + 0.1*torch.eye(num_features)       # ill conditioned 
cov = generate_well_conditioned_covariance(num_features,10000)  # well conditioned

print(f'generating a tensor of shape ({n_samples}, {num_features},32,32) with mean {m} and covariance \n {cov}:')
x = torch.randn(n_samples, num_features, 32, 32).permute(1,0,2,3).reshape(num_features,-1).T
L = torch.linalg.cholesky(cov.float())
xc= x@L.T + m
# print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T,correction=0)}')
print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T)}')
x_c = xc.permute(1,0).reshape(num_features,n_samples,32,32).permute(1,0,2,3)
# flatten x_c to [num_featurs,num_samples]:
x_c_f= x_c.permute(1,0,2,3).reshape(x_c.shape[1],-1)



generating a tensor of shape (100, 3,32,32) with mean tensor([[19., 47., 18.]]) and covariance 
 tensor([[ 8.6111,  4.5469, -0.9660],
        [ 4.5469, 11.1936,  7.9774],
        [-0.9660,  7.9774,  8.3056]]):
actual mean and cov: tensor([18.9974, 47.0011, 18.0025]),
 tensor([[ 8.6051,  4.5244, -0.9843],
        [ 4.5244, 11.1625,  7.9613],
        [-0.9843,  7.9613,  8.3015]])


In [139]:
# compute the cov from x_c:
x_c_cov = x_c_f.cov()
print(x_c_cov)
print(f"condition number: {torch.linalg.cond(x_c_cov)}")
print("\n \n correlation matrix and average cross correlation:") 
cov_to_corr(x_c_cov)

tensor([[ 8.6051,  4.5244, -0.9843],
        [ 4.5244, 11.1625,  7.9613],
        [-0.9843,  7.9613,  8.3015]])
condition number: 9925.451171875

 
 correlation matrix and average cross correlation:


(tensor([[ 1.0000,  0.4616, -0.1165],
         [ 0.4616,  1.0000,  0.8270],
         [-0.1165,  0.8270,  1.0000]]),
 tensor(0.1954))

In [140]:
# compute the correlation coefficient matrix
print(torch.corrcoef(x_c_f))
# test the conversion function: should be the same output
cov_to_corr(x_c_cov)[0]

tensor([[ 1.0000,  0.4616, -0.1165],
        [ 0.4616,  1.0000,  0.8270],
        [-0.1165,  0.8270,  1.0000]])


tensor([[ 1.0000,  0.4616, -0.1165],
        [ 0.4616,  1.0000,  0.8270],
        [-0.1165,  0.8270,  1.0000]])

In [14]:
print(xc.shape)
print(x_c.shape)
print(x_c_f.shape)

torch.Size([102400, 3])
torch.Size([100, 3, 32, 32])
torch.Size([3, 102400])


##### Validating Batch Whitening layer

In [30]:
# Our custom batch normalization layer
bw_layer = BatchWhitening(num_features,momentum=1)

# Forward pass
x_w = bw_layer(x_c)   # expecting x_c.shape=[N,C,H,W] 

x_w_cov = x_w.permute(1,0,2,3).reshape(x.shape[1],-1).cov()
print(x_w_cov)
# Check if the outputs are indeed orthonormal
assert torch.allclose(x_w_cov, torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

print("Functional validation passed!")
cov_to_corr(x_w_cov)

tensor([[ 1.0000e+00,  3.2923e-05, -1.8035e-06],
        [ 3.2923e-05,  9.9937e-01,  3.9545e-05],
        [-1.8035e-06,  3.9545e-05,  1.0000e+00]], grad_fn=<SqueezeBackward0>)
Functional validation passed!


(tensor([[ 1.0000e+00,  3.2933e-05, -1.8035e-06],
         [ 3.2933e-05,  1.0000e+00,  3.9557e-05],
         [-1.8035e-06,  3.9557e-05,  1.0000e+00]], grad_fn=<DivBackward0>),
 tensor(1.1781e-05, grad_fn=<DivBackward0>))

In [28]:
# since the layer also normalized the input, the correlation coefficient matrix should be nearly the same (up to approx errors) as the covariance matrix. lets check it:
print(cov_to_corr(x_w_cov))
# alternatice way to compute: - should produce the same output
print(torch.corrcoef(x_w.permute(1,0,2,3).reshape(x.shape[1],-1)))

(tensor([[ 1.0000e+00,  3.2933e-05, -1.8035e-06],
        [ 3.2933e-05,  1.0000e+00,  3.9557e-05],
        [-1.8035e-06,  3.9557e-05,  1.0000e+00]], grad_fn=<DivBackward0>), tensor(1.1781e-05, grad_fn=<DivBackward0>))
tensor([[ 1.0000e+00,  3.2934e-05, -1.8035e-06],
        [ 3.2934e-05,  1.0000e+00,  3.9557e-05],
        [-1.8035e-06,  3.9557e-05,  1.0000e+00]], grad_fn=<ClampBackward1>)


In [29]:
x_c_cov = x_c.permute(1,0,2,3).reshape(x.shape[1],-1).cov()
# we can also compute x_c_cov as follows
# torch.cov(x_c.permute(1,0,2,3).reshape(x.shape[1],-1))
print(x_c_cov)
cov_to_corr(x_c_cov)

tensor([[ 3.3637,  4.0622,  1.1951],
        [ 4.0622,  4.9443,  1.4957],
        [ 1.1951,  1.4957, 15.0755]], grad_fn=<SqueezeBackward0>)


(tensor([[1.0000, 0.9961, 0.1678],
         [0.9961, 1.0000, 0.1732],
         [0.1678, 0.1732, 1.0000]], grad_fn=<DivBackward0>),
 tensor(0.2229, grad_fn=<DivBackward0>))

In [None]:
# compute the rank
torch.linalg.matrix_rank(x_c.permute(1,0,2,3).reshape(x.shape[1],-1))/x_c.shape[1]

##### Validating iter norm 

In [141]:
# TODO: understand how the algo works so that it whiten the matrix
ItN = IterNorm(3, num_groups=1, T=11, momentum=1, affine=False)
print(ItN)
ItN.train()
x_c.requires_grad_()
x_w = ItN(x_c)
print(x_w.shape)
x_w_cov = x_w.permute(1,0,2,3).reshape(x_w.shape[1],-1).cov()
cov_to_corr(x_w_cov)
# print(x_w_cov)


IterNorm(3, num_channels=3, T=11, eps=1e-05, dim=4, momentum=1, affine=False)
torch.Size([100, 3, 32, 32])


(tensor([[ 1.0000, -0.0227, -0.9968],
         [-0.0227,  1.0000,  0.0201],
         [-0.9968,  0.0201,  1.0000]], grad_fn=<DivBackward0>),
 tensor(-0.1666, grad_fn=<DivBackward0>))

#### 2D tensor

In [53]:
# Create a batch of vectors (B, num_features)
num_features = 4
# num_features = 64
m = torch.randint(10,100,(1,num_features)).float()
c = torch.randint(1,10,(1,num_features))
# cov = c.T@c + 0.1*torch.eye(num_features)
cov = generate_well_conditioned_covariance(4,100)


print(f'generating a tensor of shape (B, {num_features}) with mean {m} and covariance \n {cov}:')
x = torch.randn(1000, num_features)
# x = torch.randn(128, num_features)
L = torch.linalg.cholesky(cov.float())
xc= x@L.T + m 
print(f'tensor xc.shape={xc.shape}')
print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T)}')

# print(f'rank and average correlation of x: {rank_and_avg_corr(x)}')
# print(f'rank and average correlation of xc: {rank_and_avg_corr(xc)}')


generating a tensor of shape (B, 4) with mean tensor([[44., 16., 80., 31.]]) and covariance 
 tensor([[ 8.7027,  1.4486,  4.8319,  2.8655],
        [ 1.4486, 22.2658,  9.1006,  0.7148],
        [ 4.8319,  9.1006,  6.3153,  0.2747],
        [ 2.8655,  0.7148,  0.2747, 17.6371]]):
tensor xc.shape=torch.Size([1000, 4])
actual mean and cov: tensor([43.9412, 16.1113, 79.9863, 31.0422]),
 tensor([[ 8.1434,  1.3770,  4.4945,  2.1811],
        [ 1.3770, 23.0514,  9.5046, -0.1960],
        [ 4.4945,  9.5046,  6.3290, -0.4049],
        [ 2.1811, -0.1960, -0.4049, 16.2121]])


In [54]:
x_c_cov=torch.cov(xc.T)
print(f"condition number: {torch.linalg.cond(x_c_cov)}")
print("\n \n correlation matrix and average cross correlation:") 
cov_to_corr(x_c_cov)

condition number: 106.389892578125

 
 correlation matrix and average cross correlation:


(tensor([[ 1.0000,  0.1005,  0.6261,  0.1898],
         [ 0.1005,  1.0000,  0.7869, -0.0101],
         [ 0.6261,  0.7869,  1.0000, -0.0400],
         [ 0.1898, -0.0101, -0.0400,  1.0000]]),
 tensor(0.1378))

In [None]:
print(f'eigen values of x.cov: {torch.linalg.eig(torch.cov(x.T))[0]}')
print(f'eigen values of xc.cov: {torch.linalg.eig(torch.cov(xc.T))[0]}')


In [72]:
C=torch.corrcoef(xc.T).detach().numpy()

array([[ 1.        ,  0.10050666,  0.6260576 ,  0.18982305],
       [ 0.10050666,  1.        ,  0.78689915, -0.01013907],
       [ 0.6260577 ,  0.7868992 ,  1.        , -0.03996942],
       [ 0.18982306, -0.01013907, -0.03996942,  1.        ]],
      dtype=float32)

##### Batch whitening (Cholesky)

In [36]:
# Our custom batch normalization layer
bw_layer = BatchWhitening(num_features,momentum=1)

print('cov matrix of xc.T')
print(torch.cov(xc.T))

# Forward pass - expecting xc.shape=[N,D]
x_w = bw_layer(xc)

print('cov matrix of x_w.T')
print(torch.cov(x_w.T))

# Check if the outputs are indeed orthonormal
assert torch.allclose(x_w.T.cov(), torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

print("Functional validation passed!")

cov matrix of xc.T
tensor([[19.5924,  2.7735,  5.2357, -0.5478],
        [ 2.7735, 20.4336,  3.9708,  6.3781],
        [ 5.2357,  3.9708, 11.4728,  6.4239],
        [-0.5478,  6.3781,  6.4239, 10.1352]])
cov matrix of x_w.T
tensor([[ 1.0010e+00,  5.7994e-08,  4.1194e-08, -6.6347e-08],
        [ 5.7994e-08,  1.0010e+00,  1.4367e-07,  3.1407e-07],
        [ 4.1194e-08,  1.4367e-07,  1.0010e+00,  9.2122e-07],
        [-6.6347e-08,  3.1407e-07,  9.2122e-07,  1.0010e+00]])
Functional validation passed!


In [37]:
# print the correlation matrix of the whitened signal and the average cross correlation
cov_to_corr(x_w.T.cov())

(tensor([[ 1.0000e+00,  5.7936e-08,  4.1152e-08, -6.6280e-08],
         [ 5.7936e-08,  1.0000e+00,  1.4353e-07,  3.1376e-07],
         [ 4.1152e-08,  1.4353e-07,  1.0000e+00,  9.2030e-07],
         [-6.6280e-08,  3.1376e-07,  9.2030e-07,  1.0000e+00]]),
 tensor(1.1753e-07))

In [38]:
# should be the same correlation matrix as previous cell
torch.corrcoef(x_w.T)


tensor([[ 1.0000e+00,  5.7936e-08,  4.1152e-08, -6.6280e-08],
        [ 5.7936e-08,  1.0000e+00,  1.4353e-07,  3.1376e-07],
        [ 4.1152e-08,  1.4353e-07,  1.0000e+00,  9.2030e-07],
        [-6.6280e-08,  3.1376e-07,  9.2030e-07,  1.0000e+00]])

##### Iternorm

using the same signal as above

In [None]:
# using the same signal 
print(f' xc.shape={xc.shape}')

In [81]:
ItN = IterNorm(num_features, num_groups=1, T=10,  momentum=1, affine=False)
print(ItN)
ItN.train()
xc.requires_grad_()
y = ItN(xc)  # y.shape=xc.shape=[N,D]
# the following is an alternative computation of y's cov matrix (implements what cov is doing behind the scene with correction=0)
z = y.transpose(0, 1).contiguous().view(x.size(1), -1) # z.shape=[D,N]
y_cov = z.matmul(z.t()) / z.size(1)
print("y_cov:")
print(y_cov)  # the outcome is the cov matrix at shape [D,D]

cov_to_corr(y_cov)

IterNorm(4, num_channels=4, T=10, eps=1e-05, dim=4, momentum=1, affine=False)
y_cov:
tensor([[ 4180.2402,  3087.0859, -7855.9429,  -733.3395],
        [ 3087.0859,  2280.7432, -5803.5220,  -541.7139],
        [-7855.9429, -5803.5220, 14767.7178,  1378.4817],
        [ -733.3395,  -541.7139,  1378.4817,   129.6800]],
       grad_fn=<DivBackward0>)


(tensor([[ 1.0000,  0.9998, -0.9999, -0.9960],
         [ 0.9998,  1.0000, -1.0000, -0.9961],
         [-0.9999, -1.0000,  1.0000,  0.9961],
         [-0.9960, -0.9961,  0.9961,  1.0000]], grad_fn=<DivBackward0>),
 tensor(-0.1663, grad_fn=<DivBackward0>))

In [56]:
# alternative computation 
x_w = ItN(xc)
print(x_w.shape)
x_w_cov = x_w.T.cov(correction=0)
print(x_w_cov)

# torch.corrcoef(x_w.T)
cov_to_corr(x_w_cov)

torch.Size([1000, 4])
tensor([[ 4180.2388,  3087.0857, -7855.9424,  -733.3392],
        [ 3087.0857,  2280.7434, -5803.5215,  -541.7140],
        [-7855.9424, -5803.5215, 14767.7139,  1378.4822],
        [ -733.3392,  -541.7140,  1378.4822,   129.6800]],
       grad_fn=<SqueezeBackward0>)


(tensor([[ 1.0000,  0.9998, -0.9999, -0.9960],
         [ 0.9998,  1.0000, -1.0000, -0.9961],
         [-0.9999, -1.0000,  1.0000,  0.9961],
         [-0.9960, -0.9961,  0.9961,  1.0000]], grad_fn=<DivBackward0>),
 tensor(-0.1663, grad_fn=<DivBackward0>))

In [95]:
xc.T.cov()

tensor([[ 8.1434,  1.3770,  4.4945,  2.1811],
        [ 1.3770, 23.0514,  9.5046, -0.1960],
        [ 4.4945,  9.5046,  6.3290, -0.4049],
        [ 2.1811, -0.1960, -0.4049, 16.2121]], grad_fn=<SqueezeBackward0>)

In [117]:
xc_corr=cov_to_corr(xc.T.cov())[0]
xc_corr

tensor([[ 1.0000,  0.1005,  0.6261,  0.1898],
        [ 0.1005,  1.0000,  0.7869, -0.0101],
        [ 0.6261,  0.7869,  1.0000, -0.0400],
        [ 0.1898, -0.0101, -0.0400,  1.0000]], grad_fn=<DivBackward0>)

In [118]:
c2c=corr_to_cov(xc_corr,xc.std(axis=0))
c2c

tensor([[ 8.1434,  1.3770,  4.4945,  2.1811],
        [ 1.3770, 23.0514,  9.5046, -0.1960],
        [ 4.4945,  9.5046,  6.3290, -0.4049],
        [ 2.1811, -0.1960, -0.4049, 16.2121]], grad_fn=<MmBackward0>)

In [87]:
# torch.std(xc,dim=0)
xc.std(axis=0)

tensor([2.8537, 4.8012, 2.5157, 4.0264], grad_fn=<StdBackward0>)

### Debug

In [19]:
M = torch.randn(10, 3, 5)
A = torch.zeros_like(M)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
B=torch.baddbmm(M, batch1, batch2,out=A)

In [77]:
def get_a(c):
    a=0.9+0.1*torch.exp(-(c/0.9)**10)
    torch.diagonal(a).fill_(1.0)
    return a


In [127]:
# C=torch.corrcoef(xc.T).detach().numpy()
C=torch.corrcoef(xc.T)
print(C)

tensor([[1.0000, 0.9795, 0.3806],
        [0.9795, 1.0000, 0.2026],
        [0.3806, 0.2026, 1.0000]])


In [129]:
C*get_a(C)

tensor([[1.0000, 0.8911, 0.3805],
        [0.8911, 1.0000, 0.2026],
        [0.3805, 0.2026, 1.0000]])

In [79]:
C*get_a(C)

array([[ 1.        ,  0.10050666,  0.6244186 ,  0.18982303],
       [ 0.10050666,  1.        ,  0.768818  , -0.01013907],
       [ 0.6244187 ,  0.768818  ,  1.        , -0.03996942],
       [ 0.18982305, -0.01013907, -0.03996942,  1.        ]],
      dtype=float32)