# Batch Whitening Layer
The purpose of this notebook is to implement the batch whitening layer.   
The implementation is inspired by the implementation of BatchNorm layer from [this reference](https://d2l.ai/chapter_convolutional-modern/batch-norm.html)


In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import numpy as np 

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print("Device to use:", device)

%load_ext autoreload
%autoreload 2

Device to use: cuda:0


## Batch Normalization  
Lets start with implementing BatchNorm from scratch and test it on a simple dataset.


### Option 1

In [None]:
def batch_norm(X, gamma, beta, running_mean, running_var, eps, momentum):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        shape = (1, X.shape[1])
    else:
        shape = (1, X.shape[1], 1, 1)

    if not torch.is_grad_enabled():
        # In prediction mode, use mean and variance obtained by moving average
        X_hat = (X - running_mean) / torch.sqrt(running_var + eps)
    else:
        if len(X.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of X, so that the broadcasting
            # operation can be carried out later
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # In training mode, the current mean and variance are used
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Update the mean and variance using moving average
        running_mean = (1.0 - momentum) * running_mean + momentum * mean
        running_var = (1.0 - momentum) * running_var + momentum * var
    Y = gamma.view(shape) * X_hat + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_var.data


class BatchNorm(nn.Module):
    # num_features: the number of outputs for a fully connected layer or the
    # number of output channels for a convolutional layer. num_dims: 2 for a
    # fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features, num_dims):
        super().__init__()
        # if num_dims == 2:
        #     shape = (1, num_features)
        # else:
        #     shape = (1, num_features, 1, 1)
        shape = num_features
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        # The variables that are not model parameters are initialized to 0 and 1
        # self.running_mean = torch.zeros(shape)
        # self.running_var = torch.ones(shape)
        self.register_buffer('running_mean', torch.zeros(shape))
        self.register_buffer('running_var', torch.ones(shape))

    def forward(self, X):
        # If X is not on the main memory, copy moving_mean and moving_var to
        # the device where X is located
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_var = self.running_var.to(X.device)
        # Save the updated running_mean and moving_var
        Y, self.running_mean, self.running_var = batch_norm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_var, eps=1e-5, momentum=0.1)
        return Y


### Option 2

In [None]:
class BatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running mean and variance
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Calculate mean and variance for the batch
            mean = x.mean([0, 2, 3], keepdim=True)
            var = x.var([0, 2, 3], keepdim=True, unbiased=False)
            # Update running mean and variance
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var
        
        # Normalize
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        # Scale and shift
        out = self.gamma.view(1, self.num_features, 1, 1) * x_normalized + self.beta.view(1, self.num_features, 1, 1)
        
        return out

### Validation


In [None]:
# Create a batch of 2D images (batch size, channels, height, width)
x = torch.randn(20, 10, 50, 50)

# Our custom batch normalization layer
custom_bn = BatchNorm(num_features=10, num_dims=4)
# custom_bn = BatchNorm2d(num_features=10)

# PyTorch's built-in batch normalization layer
torch_bn = nn.BatchNorm2d(num_features=10)


# Copy the parameters from our custom layer to the built-in layer for a fair comparison
torch_bn.weight.data = custom_bn.gamma.data.clone()
torch_bn.bias.data = custom_bn.beta.data.clone()
torch_bn.running_mean = custom_bn.running_mean.clone()
torch_bn.running_var = custom_bn.running_var.clone()

# Forward pass
custom_bn_output = custom_bn(x)
torch_bn_output = torch_bn(x)

# Check if the outputs are close
assert torch.allclose(custom_bn_output, torch_bn_output, atol=1e-5), "The outputs are not close enough!"

print("Functional validation passed!")

In [None]:
x.shape

In [None]:
x.var([0, 2, 3], keepdim=True, unbiased=False).shape

In [None]:
custom_bn.training

# Batch Whitening

## implementation

In [2]:
def cov_to_corr(cov_matrix):
    # Compute the standard deviations
    std = torch.sqrt(torch.diag(cov_matrix))
    
    # Compute the correlation matrix
    corr_matrix = cov_matrix / torch.outer(std, std)
    
    # Extract upper triangular part (excluding diagonal)
    upper_tri = torch.triu(corr_matrix, diagonal=1)
    
    # Compute average of cross-correlation coefficients
    avg_corr = upper_tri.abs().sum() / ((upper_tri.numel() - upper_tri.diag().numel())/2)
    
    return corr_matrix, avg_corr


def corr_to_cov(corr_matrix,std):
    # Compute the standard deviations
    D = torch.diag(std)
    # D=torch.outer(std,std)
    # cov=corr_matrix * torch.outer(std,std)
    return D@corr_matrix@D


### Choleski

In [3]:
def batch_orthonorm_obsolete(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # When using a fully connected layer, calculate the mean and
        # variance on the feature dimension
        shape = (1, n_features)
        mean = X.mean(dim=0)
        cov = torch.cov(X.T,correction=0)        
        # var = ((X - mean) ** 2).mean(dim=0)
    else:
        # When using a two-dimensional convolutional layer, calculate the
        # mean and covariance on the channel dimension (axis=1). Here we
        # need to maintain the shape of X, so that the broadcasting
        # operation can be carried out later
        shape = (1, n_features, 1, 1)
        mean = X.mean(dim=(0, 2, 3))
        Xtmp = X.view(X.shape[0],X.shape[1],-1)
        Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
        cov = torch.cov(Xtmp,correction=0) 
        # var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    # In training mode, the current mean and variance are used
    # Update the mean and variance using moving average
    running_mean = (1.0 - momentum) * running_mean + momentum * mean
    running_cov = (1.0 - momentum) * running_cov + momentum * cov
    L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
    if len(X.shape) == 2:
        X_hat = (X-running_mean.view(1,n_features)).T
        Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
    else:
        X_hat = X-running_mean.view(1,n_features,1,1)
        X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
        Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_cov.data

def batch_orthonorm(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    # Use is_grad_enabled to determine whether we are in training mode
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # When using a fully connected layer, calculate the mean and
        # variance on the feature dimension
        shape = (1, n_features)
        mean = X.mean(dim=0)
        cov = torch.cov(X.T,correction=0)        
        # var = ((X - mean) ** 2).mean(dim=0)
    else:
        # When using a two-dimensional convolutional layer, calculate the
        # mean and covariance on the channel dimension (axis=1). Here we
        # need to maintain the shape of X, so that the broadcasting
        # operation can be carried out later
        shape = (1, n_features, 1, 1)
        mean = X.mean(dim=(0, 2, 3))
        Xtmp = X.view(X.shape[0],X.shape[1],-1)
        Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
        cov = torch.cov(Xtmp,correction=0) 
        # var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    # In training mode, the current mean and variance are used
    # Update the mean and variance using moving average
    if torch.is_grad_enabled():
        running_mean = (1.0 - momentum) * running_mean + momentum * mean
        running_cov = (1.0 - momentum) * running_cov + momentum * cov
        L = torch.linalg.cholesky(cov + eps*torch.eye(n_features))
        if len(X.shape) == 2:
            X_hat = (X-mean.view(1,n_features)).T
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
        else:
            X_hat = X-mean.view(1,n_features,1,1)
            X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    else:
        L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
        if len(X.shape) == 2:
            X_hat = (X-running_mean.view(1,n_features)).T
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T
        else:
            X_hat = X-running_mean.view(1,n_features,1,1)
            X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
            Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift
    return Y, running_mean.data, running_cov.data

class BatchWhitening(nn.Module):
    # num_features: the number of outputs for a fully connected layer or the
    # number of output channels for a convolutional layer. num_dims: 2 for a
    # fully connected layer and 4 for a convolutional layer
    def __init__(self, num_features,momentum=0.1):
        super().__init__()
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # The variables that are not model parameters are initialized to 0 and 1
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_cov', torch.eye(num_features))

    def forward(self, X):
        # If X is not on the main memory, copy moving_mean and moving_var to
        # the device where X is located
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_cov = self.running_cov.to(X.device)
        # Save the updated running_mean and moving_var
        # Y, self.running_mean, self.running_var = batch_orthonorm(
        #     X, self.gamma, self.beta, self.running_mean,
        #     self.running_cov, eps=1e-5, momentum=0.1)
        Y, self.running_mean, self.running_cov = batch_orthonorm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_cov, eps=1e-5, momentum=self.momentum)

        return Y


### IterNorm

The algorithm :  (taken from [the paper](https://arxiv.org/abs/1904.03441) )  
![iternorm](./iternorm_algo.png)

In [None]:
from torch.nn import Parameter

def fix_corr(corr):
    a=0.9+0.1*torch.exp(-(abs(corr)/0.9)**10)
    a=a.clone()  # so not to lose the gradients in backprop
    torch.diagonal(a).fill_(1.0)
    return a*corr

class IterNorm(nn.Module):
    def __init__(self, num_features, num_groups=1, num_channels=-1, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True, *args, **kwargs):
        super(IterNorm, self).__init__()
        # assert dim == 4, 'IterNorm is not support 2D'
        self.T = T
        self.eps = eps
        self.momentum = momentum
        self.num_features = num_features
        self.affine = affine
        self.dim = dim
        self.fix_cov=kwargs.get('fix_cov',False)

        if num_channels == -1:
            num_channels = (num_features - 1) // num_groups + 1
        num_groups = num_features // num_channels
        while num_features % num_channels != 0:
            num_channels //= 2
            num_groups = num_features // num_channels
        assert num_groups > 0 and num_features % num_groups == 0, "num features={}, num groups={}".format(num_features,
            num_groups)
        self.num_groups = num_groups
        self.num_channels = num_channels
        shape = [1] * dim
        shape[1] = self.num_features
        if self.affine:
            self.weight = Parameter(torch.Tensor(*shape))
            self.bias = Parameter(torch.Tensor(*shape))

        if not self.affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.register_buffer('running_mean', torch.zeros(num_groups, num_channels, 1))
        # running whiten matrix
        self.register_buffer('running_wm', torch.eye(num_channels).expand(num_groups, num_channels, num_channels))
        self.reset_parameters()

    def reset_parameters(self):
        # self.reset_running_stats()
        if self.affine:
            self.weight.data.fill_(1.0)
            self.bias.data.fill_(0.0)

    def forward(self, X: torch.Tensor):
        eps = 1e-5
        momentum = self.momentum

        nc = self.num_channels
        T = self.T
        g = X.size(1) // nc
        x = X.transpose(0, 1).contiguous().view(g, nc, -1)
        _, d, m = x.size()
        if self.training:
            # calculate centered activation by subtracted mini-batch mean
            mean = x.mean(-1, keepdim=True)
            xc = x - mean
            # calculate covariance matrix
            P = [None] * (T + 1)
            P[0] = torch.eye(d).to(X).expand(g, d, d)
            # Sigma = torch.baddbmm(eps, P[0], 1. / m, xc, xc.transpose(1, 2))
            Sigma = torch.baddbmm(P[0], xc, xc.transpose(1, 2), beta=eps, alpha=1. / m)  # =torch.cov(xc,correction=0)
            if self.fix_cov:
                Sigma[0]=fix_corr(Sigma[0])
            # reciprocal of trace of Sigma: shape [g, 1, 1]
            rTr = (Sigma * P[0]).sum(1, keepdim=True).sum(2, keepdim=True).reciprocal_()
            Sigma_N = Sigma * rTr
            for k in range(T):
                # P[k + 1] = torch.baddbmm(1.5, P[k], -0.5, P[k].bmm(P[k]).bmm(P[k]), Sigma_N)
                P[k + 1] = torch.baddbmm(P[k], P[k].bmm(P[k]).bmm(P[k]), Sigma_N, beta=1.5, alpha=-0.5)
            wm = P[T].mul_(rTr.sqrt())  # whiten matrix: the matrix inverse of Sigma, i.e., Sigma^{-1/2}
            # self.running_mean += momentum * ( mean.detach() - self.running_mean)
            # self.running_wm += momentum * ( wm.detach() - self.running_wm)
            self.running_mean = (1-momentum)*self.running_mean + momentum * mean.detach()
            self.running_wm = (1-momentum)*self.running_wm + momentum * wm.detach() 
        else:
            xc = x - self.running_mean
            wm = self.running_wm
        xn = wm.matmul(xc)
        X_hat = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous()
        
        # affine
        if self.affine:
            X_hat = X_hat * self.weight
            X_hat = X_hat + self.bias
        return X_hat

    def extra_repr(self):
        return '{num_features}, num_channels={num_channels}, T={T}, eps={eps}, dim={dim}, ' \
               'momentum={momentum}, affine={affine}'.format(**self.__dict__)

from batch_whitening import IterNormMod,iter_norm_batch



## Validation

from [wikipedia](https://en.wikipedia.org/wiki/Covariance_matrix):  
The covariance matrix is given by:  
![cov](cov.png)

and the correlation matrix is given by:  
![corr](corr_mat.png)

so if we have a matrix of `N` samples of a random vector $x \in \mathbb{R}^C$ , we can compute the covariance of the random vector as follows:
`corr_matrix = torch.corrcoef(x_f)`  
and we can also compute the covariance matrix with: `cov_matrix = x.cov()`   

we can also get `corr_matrix` from `cov_matrix` using the following function:

In [5]:
import numpy as np

def generate_well_conditioned_covariance(n, condition_number=2):
    # Step 1: Generate a random matrix
    # L = np.random.rand(n, n)
    L = torch.rand(n, n)
    
    # Step 2: Make it lower triangular with positive diagonal
    L = torch.tril(L) + n * torch.eye(n)
    
    # Step 3: Construct the symmetric positive definite matrix
    A = L@L.T
    
    # Step 4: Adjust the condition number
    # Scale the matrix to have a desired condition number
    u, s, vh = torch.linalg.svd(A)
    s = torch.linspace(s[0], s[0] / condition_number, len(s))
    A = (u * s)@vh
    
    return A


In [6]:
cov=generate_well_conditioned_covariance(3,10)
print(cov)
print(f"condition number: {torch.linalg.cond(cov)}")
torch.linalg.eig(cov)[0]

tensor([[ 3.4189,  3.6185,  1.6481],
        [ 3.6185,  9.7029,  2.8032],
        [ 1.6481,  2.8032, 15.9187]])
condition number: 10.000001907348633


tensor([ 1.7600+0.j, 17.6003+0.j,  9.6802+0.j])

In [None]:
# the following compute the rank of the matrix and the average of the correlation coefficients in the correlation matrix:
# x is expected to have the following shape: [B,C,H,W] if it has 4 dimensions, [B,C] if it has 2 dimensions 
def rank_and_avg_corr(x):
    if len(x.shape)==4:  # conv
        # flatten x from [B,C,H,W] to [C,B*H*W]
        x_f= x.permute(1,0,2,3).reshape(x.shape[1],-1)
    else:
        # change from [B,C] to [C,B]
        x_f=x.T
    # compute corr matrix
    corr_matrix = torch.corrcoef(x_f)
    # Extract upper triangular part (excluding diagonal)
    upper_tri = torch.triu(corr_matrix, diagonal=1)
    # Compute average of cross-correlation coefficients
    avg_corr = upper_tri.sum() / (upper_tri.numel() - upper_tri.diag().numel())
    rank = torch.linalg.matrix_rank(x_f)/x.shape[1]
    return rank,avg_corr


#### 4D tensor

In [7]:
# Create a batch of 2D images (batch size, channels, height, width)
num_features = 3
n_samples=100
m = torch.randint(10,100,(1,num_features)).float()
c = torch.randint(1,10,(1,num_features))
# cov = c.T@c + 0.1*torch.eye(num_features)       # ill conditioned 
cov = generate_well_conditioned_covariance(num_features,10000)  # well conditioned

print(f'generating a tensor of shape ({n_samples}, {num_features},32,32) with mean {m} and covariance \n {cov}:')
x = torch.randn(n_samples, num_features, 32, 32).permute(1,0,2,3).reshape(num_features,-1).T
L = torch.linalg.cholesky(cov.float())
xc= x@L.T + m
# xc[:,-1]=0.5*xc[:,-2]+0.5*xc[:,-3]
# print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T,correction=0)}')
print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T)}')
x_c = xc.permute(1,0).reshape(num_features,n_samples,32,32).permute(1,0,2,3)
# flatten x_c to [num_featurs,num_samples]:
x_c_f= x_c.permute(1,0,2,3).reshape(x_c.shape[1],-1)



generating a tensor of shape (100, 3,32,32) with mean tensor([[74., 33., 79.]]) and covariance 
 tensor([[13.8802,  1.7485,  4.8531],
        [ 1.7484,  6.5989,  6.8115],
        [ 4.8531,  6.8115,  7.7272]]):
actual mean and cov: tensor([74.0076, 32.9945, 78.9963]),
 tensor([[13.8793,  1.7194,  4.8242],
        [ 1.7194,  6.5937,  6.7992],
        [ 4.8242,  6.7992,  7.7081]])


In [8]:
# compute the cov from x_c:
x_c_cov = x_c_f.cov()
print(x_c_cov)
print(f"condition number: {torch.linalg.cond(x_c_cov)}")
print("\n \n correlation matrix and average cross correlation:") 
cov_to_corr(x_c_cov)

tensor([[13.8793,  1.7194,  4.8242],
        [ 1.7194,  6.5937,  6.7992],
        [ 4.8242,  6.7992,  7.7081]])
condition number: 9929.55859375

 
 correlation matrix and average cross correlation:


(tensor([[1.0000, 0.1797, 0.4664],
         [0.1797, 1.0000, 0.9537],
         [0.4664, 0.9537, 1.0000]]),
 tensor(0.5333))

In [9]:
# compute the correlation coefficient matrix
print(torch.corrcoef(x_c_f))
# test the conversion function: should be the same output
x_c_corr=cov_to_corr(x_c_cov)[0]
x_c_corr

tensor([[1.0000, 0.1797, 0.4664],
        [0.1797, 1.0000, 0.9537],
        [0.4664, 0.9537, 1.0000]])


tensor([[1.0000, 0.1797, 0.4664],
        [0.1797, 1.0000, 0.9537],
        [0.4664, 0.9537, 1.0000]])

In [10]:
corr_to_cov(x_c_corr,x_c_f.std(axis=-1))

tensor([[13.8793,  1.7194,  4.8242],
        [ 1.7194,  6.5937,  6.7992],
        [ 4.8242,  6.7992,  7.7081]])

In [None]:
print(xc.shape)
print(x_c.shape)
print(x_c_f.shape)

##### Validating Batch Whitening layer

In [17]:
# Our custom batch normalization layer
bw_layer = BatchWhitening(num_features,momentum=1)

# Forward pass
x_w = bw_layer(x_c)   # expecting x_c.shape=[N,C,H,W] 

x_w_cov = x_w.permute(1,0,2,3).reshape(x.shape[1],-1).cov()
print(x_w_cov)
# Check if the outputs are indeed orthonormal
assert torch.allclose(x_w_cov, torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

print("Functional validation passed!")
cov_to_corr(x_w_cov)

tensor([[1.0000e+00, 1.2562e-07, 1.0667e-05],
        [1.2562e-07, 1.0000e+00, 5.4605e-05],
        [1.0667e-05, 5.4605e-05, 9.9532e-01]], grad_fn=<SqueezeBackward0>)
Functional validation passed!


(tensor([[1.0000e+00, 1.2562e-07, 1.0692e-05],
         [1.2562e-07, 1.0000e+00, 5.4733e-05],
         [1.0692e-05, 5.4733e-05, 1.0000e+00]], grad_fn=<DivBackward0>),
 tensor(2.1850e-05, grad_fn=<DivBackward0>))

In [18]:
torch.linalg.cond(x_w_cov)

tensor(1.0047, grad_fn=<SqueezeBackward1>)

In [None]:
# since the layer also normalized the input, the correlation coefficient matrix should be nearly the same (up to approx errors) as the covariance matrix. lets check it:
print(cov_to_corr(x_w_cov))
# alternatice way to compute: - should produce the same output
print(torch.corrcoef(x_w.permute(1,0,2,3).reshape(x.shape[1],-1)))

In [None]:
x_c_cov = x_c.permute(1,0,2,3).reshape(x.shape[1],-1).cov()
# we can also compute x_c_cov as follows
# torch.cov(x_c.permute(1,0,2,3).reshape(x.shape[1],-1))
print(x_c_cov)
cov_to_corr(x_c_cov)

In [None]:
# compute the rank
torch.linalg.matrix_rank(x_c.permute(1,0,2,3).reshape(x.shape[1],-1))/x_c.shape[1]

##### Validating iter norm 

In [15]:
ItN = IterNormMod(3, num_groups=1, T=11, momentum=1, affine=False,fix_cov=False)
print(ItN)
ItN.train()
x_c.requires_grad_()
x_w = ItN(x_c)
print(x_w.shape)
x_w_cov = x_w.permute(1,0,2,3).reshape(x_w.shape[1],-1).cov()
cov_to_corr(x_w_cov)


IterNormMod(3, num_channels=3, T=11, eps=1e-05, dim=4, momentum=1, affine=False)
torch.Size([100, 3, 32, 32])


(tensor([[ 1.0000,  1.0000, -1.0000],
         [ 1.0000,  1.0000, -1.0000],
         [-1.0000, -1.0000,  1.0000]], grad_fn=<DivBackward0>),
 tensor(1.0000, grad_fn=<DivBackward0>))

In [16]:
torch.linalg.cond(x_w_cov)

tensor(48156036., grad_fn=<SqueezeBackward1>)

In [None]:
# torch.corrcoef(x_w.permute(1,0,2,3).reshape(x_w.shape[1],-1))
cov_to_corr(x_w_cov)

#### 2D tensor

In [None]:
# Create a batch of vectors (B, num_features)
num_features = 4
# num_features = 64
m = torch.randint(10,100,(1,num_features)).float()
c = torch.randint(1,10,(1,num_features))
# cov = c.T@c + 0.1*torch.eye(num_features)
cov = generate_well_conditioned_covariance(4,100)


print(f'generating a tensor of shape (B, {num_features}) with mean {m} and covariance \n {cov}:')
x = torch.randn(1000, num_features)
# x = torch.randn(128, num_features)
L = torch.linalg.cholesky(cov.float())
xc= x@L.T + m 
print(f'tensor xc.shape={xc.shape}')
print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T)}')

# print(f'rank and average correlation of x: {rank_and_avg_corr(x)}')
# print(f'rank and average correlation of xc: {rank_and_avg_corr(xc)}')


In [None]:
x_c_cov=torch.cov(xc.T)
print(f"condition number: {torch.linalg.cond(x_c_cov)}")
print("\n \n correlation matrix and average cross correlation:") 
cov_to_corr(x_c_cov)

In [None]:
print(f'eigen values of x.cov: {torch.linalg.eig(torch.cov(x.T))[0]}')
print(f'eigen values of xc.cov: {torch.linalg.eig(torch.cov(xc.T))[0]}')


In [None]:
C=torch.corrcoef(xc.T).detach().numpy()

##### Batch whitening (Cholesky)

In [None]:
# Our custom batch normalization layer
bw_layer = BatchWhitening(num_features,momentum=1)

print('cov matrix of xc.T')
print(torch.cov(xc.T))

# Forward pass - expecting xc.shape=[N,D]
x_w = bw_layer(xc)

print('cov matrix of x_w.T')
print(torch.cov(x_w.T))

# Check if the outputs are indeed orthonormal
assert torch.allclose(x_w.T.cov(), torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

print("Functional validation passed!")

In [None]:
# print the correlation matrix of the whitened signal and the average cross correlation
cov_to_corr(x_w.T.cov())

In [None]:
# should be the same correlation matrix as previous cell
torch.corrcoef(x_w.T)


##### Iternorm

using the same signal as above

In [None]:
# using the same signal 
print(f' xc.shape={xc.shape}')

In [None]:
ItN = IterNorm(num_features, num_groups=1, T=10,  momentum=1, affine=False)
print(ItN)
ItN.train()
xc.requires_grad_()
y = ItN(xc)  # y.shape=xc.shape=[N,D]
# the following is an alternative computation of y's cov matrix (implements what cov is doing behind the scene with correction=0)
z = y.transpose(0, 1).contiguous().view(x.size(1), -1) # z.shape=[D,N]
y_cov = z.matmul(z.t()) / z.size(1)
print("y_cov:")
print(y_cov)  # the outcome is the cov matrix at shape [D,D]

cov_to_corr(y_cov)

In [None]:
# alternative computation 
x_w = ItN(xc)
print(x_w.shape)
x_w_cov = x_w.T.cov(correction=0)
print(x_w_cov)

# torch.corrcoef(x_w.T)
cov_to_corr(x_w_cov)

In [None]:
xc.T.cov()

In [None]:
xc_corr=cov_to_corr(xc.T.cov())[0]
xc_corr

In [None]:
c2c=corr_to_cov(xc_corr,xc.std(axis=0))
c2c

In [None]:
# torch.std(xc,dim=0)
xc.std(axis=0)

### Debug

In [None]:
M = torch.randn(10, 3, 5)
A = torch.zeros_like(M)
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
B=torch.baddbmm(M, batch1, batch2,out=A)

In [None]:
def get_a(c):
    a=0.9+0.1*torch.exp(-(c/0.9)**10)
    torch.diagonal(a).fill_(1.0)
    return a


In [None]:
# C=torch.corrcoef(xc.T).detach().numpy()
C=torch.corrcoef(xc.T)
print(C)

In [None]:
C*get_a(C)

In [None]:
C*get_a(C)