# Batch Whitening Layer
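The purpose of this notebook is to implement a batch whitening layer. Like batch normalization, it centres each batch; unlike batch normalization, it also decorrelates the features, so that the batch covariance of the output is (approximately) the identity matrix. Here the whitening transform is obtained from a Cholesky factorization of the batch covariance.  
The implementation follows the structure of the BatchNorm layer from [Dive into Deep Learning — Batch Normalization](https://d2l.ai/chapter_convolutional-modern/batch-norm.html).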


In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

# Pick the first GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print("Device to use:", device)


Device to use: cuda:0


# Batch Whitening

In [8]:
def batch_orthonorm(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    """Whiten ``X`` across its feature/channel dimension using a Cholesky factor.

    Every feature (dim 1) is centred and the features are decorrelated, so the
    population (``correction=0``) covariance of the output is close to the
    identity: ``Y = L^{-1} (X - mean)`` with ``L L^T = cov + eps * I``.

    Args:
        X: ``(N, C)`` input of a fully connected layer or ``(N, C, H, W)``
            input of a 2D convolution. For 4D inputs the statistics are pooled
            over the batch and spatial dimensions. Non-contiguous inputs are
            accepted.
        gamma, beta: per-feature scale and shift. Currently NOT applied (the
            affine step is disabled); kept for interface compatibility.
        running_mean: ``(C,)`` running mean. Required despite the ``None``
            default.
        running_cov: ``(C, C)`` running covariance. Required despite the
            ``None`` default.
        eps: diagonal jitter that keeps the Cholesky factorisation defined
            even for a singular batch covariance.
        momentum: weight of the batch statistics in the running-average update.

    Returns:
        ``(Y, running_mean, running_cov)``. ``Y`` has the shape of ``X``. The
        running statistics are returned detached (via ``.data``).

    Training vs. inference is inferred from ``torch.is_grad_enabled()``. With
    grad mode on, the batch statistics are used and the running averages are
    updated. Otherwise the running statistics are used unchanged.
    """
    assert X.dim() in (2, 4)
    n_features = X.shape[1]

    # Lay the data out as (C, M): one row per feature, one column per
    # observation (a sample, or a sample/pixel pair for 4D inputs).
    if X.dim() == 2:
        X_flat = X.T
    else:
        # reshape (not view) so that non-contiguous inputs also work
        X_flat = X.transpose(0, 1).reshape(n_features, -1)

    if torch.is_grad_enabled():
        # Training path: normalise with the current batch statistics and
        # fold them into the running averages.
        mean = X_flat.mean(dim=1)
        cov = torch.cov(X_flat, correction=0)
        running_mean = (1.0 - momentum) * running_mean + momentum * mean
        running_cov = (1.0 - momentum) * running_cov + momentum * cov
    else:
        # Inference path: use the stored running statistics as-is.
        mean, cov = running_mean, running_cov

    # Build the jitter on the same device/dtype as cov (a bare torch.eye lives
    # on the CPU and breaks CUDA inputs).
    eye = torch.eye(n_features, dtype=cov.dtype, device=cov.device)
    L = torch.linalg.cholesky(cov + eps * eye)
    # Solve L @ Y = (X - mean) instead of forming L^{-1} explicitly.
    Y_flat = torch.linalg.solve_triangular(L, X_flat - mean.view(-1, 1), upper=False)

    if X.dim() == 2:
        Y = Y_flat.T
    else:
        N, C, H, W = X.shape
        Y = Y_flat.reshape(C, N, H, W).transpose(0, 1)
    # TODO: the affine step `gamma * Y + beta` is intentionally disabled.
    return Y, running_mean.data, running_cov.data


class BatchWhitening(nn.Module):
    """Batch whitening layer for ``(N, C)`` or ``(N, C, H, W)`` inputs.

    Decorrelates the ``num_features`` features/channels of each batch with
    ``batch_orthonorm`` and tracks running estimates of the mean and
    covariance.

    Args:
        num_features: number of outputs of a fully connected layer, or number
            of output channels of a convolutional layer (dim 1 of the input).
        momentum: weight of the current batch statistics in the running
            averages (``1`` keeps only the latest batch).
        eps: diagonal jitter added to the covariance before the Cholesky
            factorisation.

    NOTE(review): ``batch_orthonorm`` chooses between batch and running
    statistics based on ``torch.is_grad_enabled()``, not ``self.training``.
    Calling ``.eval()`` alone therefore still uses (and updates) batch
    statistics; run inference under ``torch.no_grad()``.
    """

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # Learnable scale and shift, initialised to 1 and 0 respectively.
        # NOTE(review): batch_orthonorm does not currently apply them.
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running statistics are buffers, not parameters. They start as a zero
        # mean and identity covariance.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_cov', torch.eye(num_features))

    def forward(self, X):
        # Keep the running statistics on the input's device in case the
        # module itself was not moved with .to(device).
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_cov = self.running_cov.to(X.device)
        # batch_orthonorm returns the (possibly updated) running statistics;
        # they are written back into the registered buffers.
        Y, self.running_mean, self.running_cov = batch_orthonorm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_cov, eps=self.eps, momentum=self.momentum)
        return Y


### Validation

#### 4D tensor

In [9]:
# Create a batch of 2D images (batch size, channels, height, width) whose
# channels have a known mean `m` and a strongly correlated covariance `cov`.
torch.manual_seed(0)  # reproducible draws
num_features = 3
batch_size, height, width = 20, 32, 32
m = torch.randint(10,100,(1,num_features)).float()
c = torch.randint(1,10,(1,num_features))
# Rank-one matrix plus a small ridge: symmetric positive definite.
cov = c.T@c + 0.1*torch.eye(num_features)

print(f'generating a tensor of shape (B, {num_features},{height},{width}) with mean {m} and covariance \n {cov}:')
# Standard-normal samples laid out as (observations, channels), where every
# (image, pixel) pair is one observation.
x = torch.randn(batch_size, num_features, height, width).permute(1,0,2,3).reshape(num_features,-1).T
L = torch.linalg.cholesky(cov.float())
# Colour the samples: Cov(x @ L.T) = L @ L.T = cov.
xc = x@L.T + m
print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T,correction=0)}')
# Back to (B, C, H, W); the result is a non-contiguous view.
xc = xc.permute(1,0).reshape(num_features,batch_size,height,width).permute(1,0,2,3)



generating a tensor of shape (B, 3,32,32) with mean tensor([[59., 78., 62.]]) and covariance 
 tensor([[49.1000, 14.0000, 56.0000],
        [14.0000,  4.1000, 16.0000],
        [56.0000, 16.0000, 64.1000]]):
actual mean and cov: tensor([59.0727, 78.0227, 62.0803]),
 tensor([[49.8842, 14.2209, 56.8721],
        [14.2209,  4.1618, 16.2462],
        [56.8721, 16.2462, 65.0648]])


In [10]:
# Our custom batch whitening layer; momentum=1 makes the running statistics
# equal to this batch's statistics.
bw_layer = BatchWhitening(num_features,momentum=1)

# Forward pass (assumes grad mode is enabled, the default, so the layer uses
# the batch statistics)
x_w = bw_layer(xc)

# Per-channel covariance of the output, pooled over batch and spatial dims;
# correction=0 matches the estimator used inside batch_orthonorm.
x_w_cov = torch.cov(x_w.permute(1,0,2,3).reshape(num_features,-1), correction=0)
print(x_w_cov)
# Check that the outputs are decorrelated with unit variance
assert torch.allclose(x_w_cov, torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

print("Functional validation passed!")

tensor([[ 1.0001e+00, -1.7687e-05, -4.7574e-05],
        [-1.7687e-05,  1.0002e+00,  2.9081e-04],
        [-4.7574e-05,  2.9081e-04,  1.0020e+00]])
Functional validation passed!


In [11]:
# Grad mode decides which branch batch_orthonorm takes: True -> batch
# statistics (training path), False -> running statistics.
grad_mode_on = torch.is_grad_enabled()
grad_mode_on


True

#### 2D tensor

In [5]:
# Create a batch of vectors (B, num_features) with a known mean `m` and a
# strongly correlated covariance `cov`.
torch.manual_seed(0)  # reproducible draws
num_features = 10
n_samples = 20000
m = torch.randint(10,100,(1,num_features)).float()
c = torch.randint(1,10,(1,num_features))
# Rank-one matrix plus a small ridge: symmetric positive definite.
cov = c.T@c + 0.1*torch.eye(num_features)

print(f'generating a tensor of shape (B, {num_features}) with mean {m} and covariance \n {cov}:')
x = torch.randn(n_samples, num_features)
L = torch.linalg.cholesky(cov.float())
# Colour the samples: Cov(x @ L.T) = L @ L.T = cov.
xc = x@L.T + m

print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T,correction=0)}')


generating a tensor of shape (B, 10) with mean tensor([[66., 38., 73., 19., 88., 82., 68., 25., 33., 73.]]) and covariance 
 tensor([[25.1000, 10.0000, 25.0000, 10.0000, 40.0000, 10.0000,  5.0000,  5.0000,
         20.0000, 40.0000],
        [10.0000,  4.1000, 10.0000,  4.0000, 16.0000,  4.0000,  2.0000,  2.0000,
          8.0000, 16.0000],
        [25.0000, 10.0000, 25.1000, 10.0000, 40.0000, 10.0000,  5.0000,  5.0000,
         20.0000, 40.0000],
        [10.0000,  4.0000, 10.0000,  4.1000, 16.0000,  4.0000,  2.0000,  2.0000,
          8.0000, 16.0000],
        [40.0000, 16.0000, 40.0000, 16.0000, 64.1000, 16.0000,  8.0000,  8.0000,
         32.0000, 64.0000],
        [10.0000,  4.0000, 10.0000,  4.0000, 16.0000,  4.1000,  2.0000,  2.0000,
          8.0000, 16.0000],
        [ 5.0000,  2.0000,  5.0000,  2.0000,  8.0000,  2.0000,  1.1000,  1.0000,
          4.0000,  8.0000],
        [ 5.0000,  2.0000,  5.0000,  2.0000,  8.0000,  2.0000,  1.0000,  1.1000,
          4.0000,  8.0000],
   

In [6]:
# Our custom batch whitening layer; momentum=1 makes the running statistics
# equal to this batch's statistics.
bw_layer = BatchWhitening(num_features,momentum=1)

# Forward pass (assumes grad mode is enabled, the default, so the layer uses
# the batch statistics)
x_w = bw_layer(xc)

# Covariance of the output features; correction=0 matches the estimator used
# inside batch_orthonorm, and the same matrix is printed and checked.
x_w_cov = torch.cov(x_w.T,correction=0)
print(x_w_cov)

# Check that the outputs are decorrelated with unit variance
assert torch.allclose(x_w_cov, torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

print("Functional validation passed!")

tensor([[ 1.0000e+00,  5.3830e-06,  4.7240e-05, -7.9326e-06, -5.7434e-05,
          2.1321e-06,  6.7980e-06, -1.5218e-06, -2.1695e-06, -1.5783e-05],
        [ 5.3841e-06,  9.9998e-01,  3.5401e-05, -6.2590e-06,  5.5552e-04,
         -1.7931e-04, -6.6148e-05, -3.5519e-05, -1.2134e-04, -4.0477e-04],
        [ 4.7240e-05,  3.5401e-05,  9.9907e-01,  4.9824e-05, -1.6165e-04,
         -9.6805e-06, -6.3025e-05,  1.2189e-05, -2.2067e-04,  1.8161e-04],
        [-7.9219e-06, -6.2509e-06,  4.9824e-05,  1.0000e+00,  5.2157e-04,
         -1.2211e-04,  1.8204e-05,  4.2289e-05, -1.0851e-04, -1.6295e-04],
        [-5.7434e-05,  5.5552e-04, -1.6165e-04,  5.2157e-04,  1.0017e+00,
         -1.5310e-04, -7.7433e-05, -2.6723e-05, -2.6612e-04, -8.6174e-04],
        [ 2.1360e-06, -1.7930e-04, -9.6805e-06, -1.2211e-04, -1.5310e-04,
          1.0001e+00, -1.9027e-05, -1.3850e-05,  3.7746e-04,  6.4313e-04],
        [ 6.7980e-06, -6.6148e-05, -6.3025e-05,  1.8204e-05, -7.7433e-05,
         -1.9027e-05,  9.9991e-0