# Batch Whitening Layer
This notebook implements a batch whitening layer.   
The implementation follows the from-scratch BatchNorm layer described in [this reference](https://d2l.ai/chapter_convolutional-modern/batch-norm.html)


In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 

# Prefer the first CUDA device when one is available; otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
print("Device to use:", device)


Device to use: cuda:0


## Batch Normalization  
Let's start by implementing BatchNorm from scratch and testing it on a simple dataset.


### Option 1

In [None]:
def batch_norm(X, gamma, beta, running_mean, running_var, eps, momentum):
    """Functional batch normalization for 2-D ``(N, C)`` or 4-D ``(N, C, H, W)`` input.

    Training mode is inferred from ``torch.is_grad_enabled()``: with gradients
    enabled the batch statistics are used and the running statistics are
    updated; otherwise the running statistics are used and returned unchanged.

    Args:
        X: input tensor of rank 2 or 4; features/channels live on dim 1.
        gamma, beta: per-feature scale and shift with ``C`` elements.
        running_mean, running_var: running statistics with ``C`` elements
            (any shape that can be reshaped to ``(1, C)`` / ``(1, C, 1, 1)``).
        eps: constant added to the variance before the square root.
        momentum: weight of the current batch in the moving average.

    Returns:
        ``(Y, running_mean, running_var)``; the returned statistics are
        detached and keep the shape of the statistics that were passed in.
    """
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        shape = (1, X.shape[1])
        reduce_dims = (0,)
    else:
        shape = (1, X.shape[1], 1, 1)
        reduce_dims = (0, 2, 3)

    if not torch.is_grad_enabled():
        # Prediction mode: reshape the running statistics so that they
        # broadcast along the feature/channel dimension (dim 1) of X.
        X_hat = (X - running_mean.reshape(shape)) / torch.sqrt(
            running_var.reshape(shape) + eps)
    else:
        # Training mode: biased statistics of the current batch, kept
        # broadcastable against X.
        mean = X.mean(dim=reduce_dims, keepdim=True)
        var = ((X - mean) ** 2).mean(dim=reduce_dims, keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # Moving-average update. The batch statistics are reshaped to the
        # shape of the running statistics first; adding a (1, C, 1, 1) tensor
        # to a (C,) tensor would otherwise broadcast to (1, C, 1, C).
        with torch.no_grad():
            running_mean = (1.0 - momentum) * running_mean \
                + momentum * mean.reshape(running_mean.shape)
            running_var = (1.0 - momentum) * running_var \
                + momentum * var.reshape(running_var.shape)
    Y = gamma.view(shape) * X_hat + beta.view(shape)  # Scale and shift
    return Y, running_mean.detach(), running_var.detach()


class BatchNorm(nn.Module):
    """Batch normalization layer built on the functional ``batch_norm``.

    Args:
        num_features: number of outputs of a fully connected layer, or number
            of output channels of a convolutional layer.
        num_dims: 2 for a fully connected layer, 4 for a convolutional layer.
            Accepted for interface compatibility; the constructor does not
            read it because all tensors are stored flat with
            ``num_features`` elements.
    """

    def __init__(self, num_features, num_dims):
        super().__init__()
        # Learnable scale (initialized to 1) and shift (initialized to 0).
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # Running statistics are buffers, not parameters: zero mean and unit
        # variance to start with.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, X):
        """Normalize ``X`` and store the running statistics returned by ``batch_norm``."""
        # Keep the running statistics on the same device as the input.
        target = X.device
        if self.running_mean.device != target:
            self.running_mean = self.running_mean.to(target)
            self.running_var = self.running_var.to(target)
        outputs = batch_norm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_var, eps=1e-5, momentum=0.1)
        # Unpack the output and save the updated running statistics.
        Y, self.running_mean, self.running_var = outputs
        return Y


### Option 2

In [None]:
class BatchNorm2d(nn.Module):
    """Batch normalization over the channel dimension of ``(N, C, H, W)`` input.

    Args:
        num_features: number of channels ``C``.
        eps: constant added to the variance before the square root.
        momentum: weight of the current batch in the running-statistics update.

    Note:
        ``running_var`` is updated with the biased batch variance, exactly as
        in the original implementation.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running mean and variance, stored flat with shape (num_features,)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        """Normalize ``x`` with batch (training) or running (eval) statistics."""
        stat_shape = (1, self.num_features, 1, 1)
        if self.training:
            # Per-channel statistics of the batch, kept flat (C,) so that the
            # running buffers keep their registered shape; a keepdim=True
            # result would broadcast the update to (1, C, 1, C).
            mean = x.mean([0, 2, 3])
            var = x.var([0, 2, 3], unbiased=False)
            # The running statistics are bookkeeping only; keep them out of
            # the autograd graph.
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.detach()
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.detach()
        else:
            mean = self.running_mean
            var = self.running_var

        # Normalize; statistics are reshaped to broadcast along dim 1.
        x_normalized = (x - mean.reshape(stat_shape)) / torch.sqrt(var.reshape(stat_shape) + self.eps)
        # Scale and shift
        out = self.gamma.view(stat_shape) * x_normalized + self.beta.view(stat_shape)

        return out

### Validation


In [None]:
# Create a batch of 2D images (batch size, channels, height, width)
x = torch.randn(20, 10, 50, 50)

# Our custom batch normalization layer (Option 1). Swap in the commented line
# below to validate Option 2 instead.
custom_bn = BatchNorm(num_features=10, num_dims=4)
# custom_bn = BatchNorm2d(num_features=10)

# PyTorch's built-in batch normalization layer, used as the reference
torch_bn = nn.BatchNorm2d(num_features=10)


# Copy the parameters and running statistics from our custom layer to the
# built-in layer for a fair comparison
torch_bn.weight.data = custom_bn.gamma.data.clone()
torch_bn.bias.data = custom_bn.beta.data.clone()
torch_bn.running_mean = custom_bn.running_mean.clone()
torch_bn.running_var = custom_bn.running_var.clone()

# Forward pass through both layers on the same batch
custom_bn_output = custom_bn(x)
torch_bn_output = torch_bn(x)

# Check if the outputs are close.
# NOTE(review): only the forward outputs are compared; the running statistics
# updated by this forward pass are not checked.
assert torch.allclose(custom_bn_output, torch_bn_output, atol=1e-5), "The outputs are not close enough!"

print("Functional validation passed!")

In [None]:
# Inspect the shape of the validation batch `x`.
x.shape

In [None]:
# Shape of the biased variance reduced over dims (0, 2, 3) with keepdim=True.
x.var([0, 2, 3], keepdim=True, unbiased=False).shape

In [None]:
# Display the module's `training` flag.
# NOTE: the functional `batch_norm` used by `BatchNorm` picks its mode from
# torch.is_grad_enabled(), not from this flag.
custom_bn.training

# Batch Whitening

In [2]:
def batch_orthonorm_obsolete(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    """Whiten ``X`` with the *updated running* statistics (obsolete variant).

    The running mean/covariance are always updated with the statistics of
    ``X`` and the updated running values are then used for whitening,
    regardless of the gradient mode.

    Args:
        X: ``(N, C)`` or ``(N, C, H, W)`` tensor; features/channels on dim 1.
        gamma, beta: accepted for interface compatibility; not used because
            the scale-and-shift step is disabled.
        running_mean: running mean with ``C`` elements. Must be provided.
        running_cov: running ``(C, C)`` covariance. Must be provided.
        eps: multiple of the identity added before the Cholesky factorization.
        momentum: weight of the current batch in the moving average.

    Returns:
        ``(Y, running_mean, running_cov)`` where ``Y`` has the shape of ``X``.
    """
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # Fully connected case: statistics over the batch dimension.
        mean = X.mean(dim=0)
        cov = torch.cov(X.T, correction=0)
    else:
        # Convolutional case: every (sample, pixel) pair is one observation
        # of the C-dimensional channel vector. reshape (not view) so that
        # non-contiguous inputs are accepted.
        mean = X.mean(dim=(0, 2, 3))
        Xtmp = X.permute(1, 0, 2, 3).reshape(n_features, -1)
        cov = torch.cov(Xtmp, correction=0)
    # Update the mean and covariance using a moving average
    running_mean = (1.0 - momentum) * running_mean + momentum * mean
    running_cov = (1.0 - momentum) * running_cov + momentum * cov
    # Build the identity on the same device/dtype as the covariance so the
    # function also works for CUDA inputs.
    eye = torch.eye(n_features, device=running_cov.device, dtype=running_cov.dtype)
    L = torch.linalg.cholesky(running_cov + eps * eye)
    if len(X.shape) == 2:
        X_hat = (X - running_mean.view(1, n_features)).T
        Y = torch.linalg.solve_triangular(L, X_hat, upper=False).T
    else:
        X_hat = X - running_mean.view(1, n_features, 1, 1)
        X_hat = X_hat.permute(1, 0, 2, 3).reshape(n_features, -1)
        Y = torch.linalg.solve_triangular(L, X_hat, upper=False)
        Y = Y.reshape(n_features, X.shape[0], X.shape[2], X.shape[3]).permute(1, 0, 2, 3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift (disabled)
    return Y, running_mean.data, running_cov.data

def batch_orthonorm(X, gamma, beta, running_mean=None, running_cov=None, eps=1e-5, momentum=0.1):
    """Whiten ``X`` so that its features are zero-mean and decorrelated.

    Training mode is inferred from ``torch.is_grad_enabled()``. With gradients
    enabled, ``X`` is whitened with the statistics of the current batch and
    the running statistics are updated by a moving average. Otherwise ``X``
    is whitened with the running statistics, which are returned unchanged.

    Whitening solves ``L @ Y = X - mean`` where ``L`` is the lower Cholesky
    factor of ``cov + eps * I``.

    Args:
        X: ``(N, C)`` or ``(N, C, H, W)`` tensor; features/channels on dim 1.
        gamma, beta: accepted for interface compatibility; not used because
            the scale-and-shift step is disabled.
        running_mean: running mean with ``C`` elements; ``None`` is treated
            as zeros.
        running_cov: running ``(C, C)`` covariance; ``None`` is treated as
            the identity.
        eps: multiple of the identity added before the Cholesky factorization.
        momentum: weight of the current batch in the moving average.

    Returns:
        ``(Y, running_mean, running_cov)`` where ``Y`` has the shape of ``X``
        and the statistics are detached from the autograd graph.
    """
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]
    is_conv = len(X.shape) == 4

    # Arrange the data as (n_features, n_observations). For 4-D input every
    # (sample, pixel) pair is one observation. reshape (not view) is used so
    # that non-contiguous inputs are accepted.
    if is_conv:
        flat = X.permute(1, 0, 2, 3).reshape(n_features, -1)
    else:
        flat = X.T

    # The signature advertises None defaults; fall back to the same initial
    # values the BatchWhitening layer registers for its buffers.
    if running_mean is None:
        running_mean = torch.zeros(n_features, device=X.device, dtype=X.dtype)
    if running_cov is None:
        running_cov = torch.eye(n_features, device=X.device, dtype=X.dtype)

    if torch.is_grad_enabled():
        # Training: whiten with the batch statistics (biased covariance).
        mean = flat.mean(dim=1)
        cov = torch.cov(flat, correction=0)
        # The running statistics are bookkeeping only; keep them out of the
        # autograd graph.
        with torch.no_grad():
            running_mean = (1.0 - momentum) * running_mean + momentum * mean
            running_cov = (1.0 - momentum) * running_cov + momentum * cov
    else:
        # Prediction: whiten with the running statistics; no batch
        # statistics are computed.
        mean, cov = running_mean, running_cov

    # Identity on the covariance's device/dtype so CUDA inputs work.
    eye = torch.eye(n_features, device=cov.device, dtype=cov.dtype)
    L = torch.linalg.cholesky(cov + eps * eye)
    centered = flat - mean.reshape(n_features, 1)
    Y_flat = torch.linalg.solve_triangular(L, centered, upper=False)

    # Undo the flattening.
    if is_conv:
        Y = Y_flat.reshape(n_features, X.shape[0], X.shape[2], X.shape[3]).permute(1, 0, 2, 3)
    else:
        Y = Y_flat.T
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift (disabled)
    return Y, running_mean.detach(), running_cov.detach()

class BatchWhitening(nn.Module):
    """Batch whitening layer built on the functional ``batch_orthonorm``.

    Args:
        num_features: number of outputs of a fully connected layer, or number
            of output channels of a convolutional layer.
        momentum: weight of the current batch in the running-statistics
            update.
        eps: multiple of the identity added to the covariance before the
            Cholesky factorization (previously hard-coded to 1e-5).

    Note:
        ``gamma`` and ``beta`` are registered as parameters and passed to
        ``batch_orthonorm``, whose scale-and-shift step is currently disabled.
    """

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # The scale parameter and the shift parameter (model parameters) are
        # initialized to 1 and 0, respectively
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running statistics (buffers, not parameters): zero mean and
        # identity covariance
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_cov', torch.eye(num_features))

    def forward(self, X):
        """Whiten ``X`` and store the running statistics returned by ``batch_orthonorm``."""
        # If X lives on another device, move the running statistics there
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_cov = self.running_cov.to(X.device)
        # Save the updated running_mean and running_cov
        Y, self.running_mean, self.running_cov = batch_orthonorm(
            X, self.gamma, self.beta, self.running_mean,
            self.running_cov, eps=self.eps, momentum=self.momentum)

        return Y


### Validation

#### 4D tensor

In [5]:
# Create a batch of 2D images (batch size, channels, height, width) whose
# channels have a prescribed mean and covariance.
num_features = 3
# Target mean: random integer-valued row vector in [10, 100).
m = torch.randint(10,100,(1,num_features)).float()
# Target covariance: rank-one outer product of a random integer vector plus
# 0.1 * I, so that a Cholesky factor exists.
c = torch.randint(1,10,(1,num_features))
cov = c.T@c + 0.1*torch.eye(num_features)

print(f'generating a tensor of shape (B, {num_features},32,32) with mean {m} and covariance \n {cov}:')
# Standard-normal samples arranged as (20*32*32, num_features).
x = torch.randn(20, num_features, 32, 32).permute(1,0,2,3).reshape(num_features,-1).T
# Colour the samples: x @ L.T has covariance L @ L.T = cov; then shift by m.
L = torch.linalg.cholesky(cov.float())
xc= x@L.T + m
print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T,correction=0)}')
# Fold the sample matrix back into (20, num_features, 32, 32).
xc = xc.permute(1,0).reshape(num_features,20,32,32).permute(1,0,2,3)



generating a tensor of shape (B, 3,32,32) with mean tensor([[29., 27., 21.]]) and covariance 
 tensor([[64.1000, 64.0000, 40.0000],
        [64.0000, 64.1000, 40.0000],
        [40.0000, 40.0000, 25.1000]]):
actual mean and cov: tensor([29.0015, 27.0027, 20.9997]),
 tensor([[64.2156, 64.1060, 40.0807],
        [64.1060, 64.1925, 40.0733],
        [40.0807, 40.0733, 25.1551]])


In [6]:
# Batch whitening layer under test. With momentum=1 the moving-average update
# replaces the running statistics by those of the current batch.
bw_layer = BatchWhitening(num_features,momentum=1)

# Forward pass
x_w = bw_layer(xc)

# Channel covariance of the whitened output. `x` is the 2-D sample matrix
# from the data-generation cell, so x.shape[1] is the feature count.
x_w_cov = x_w.permute(1,0,2,3).reshape(x.shape[1],-1).cov()
print(x_w_cov)
# Check that the output covariance is close to the identity
assert torch.allclose(x_w_cov, torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

print("Functional validation passed!")

tensor([[ 1.0001e+00, -7.1795e-05, -2.0783e-05],
        [-7.1795e-05,  1.0022e+00, -3.2540e-04],
        [-2.0783e-05, -3.2540e-04,  1.0007e+00]])
Functional validation passed!


#### 2D tensor

In [36]:
# Create a batch of vectors (B, num_features) with a prescribed mean and
# covariance.
num_features = 10
# Target mean: random integer-valued row vector in [10, 100).
m = torch.randint(10,100,(1,num_features)).float()
# Target covariance: rank-one outer product plus 0.1 * I, so that a Cholesky
# factor exists.
c = torch.randint(1,10,(1,num_features))
cov = c.T@c + 0.1*torch.eye(num_features)

print(f'generating a tensor of shape (B, {num_features}) with mean {m} and covariance \n {cov}:')
# Standard-normal samples, coloured so that x @ L.T has covariance
# L @ L.T = cov, then shifted by m.
x = torch.randn(20000, num_features)
L = torch.linalg.cholesky(cov.float())
xc= x@L.T + m 

print(f'actual mean and cov: {xc.mean(0)},\n {torch.cov(xc.T,correction=0)}')


generating a tensor of shape (B, 10) with mean tensor([[25., 62., 16., 28., 47., 59., 94., 44., 80., 54.]]) and covariance 
 tensor([[81.1000, 27.0000, 63.0000, 36.0000, 63.0000, 18.0000, 27.0000, 18.0000,
         81.0000, 63.0000],
        [27.0000,  9.1000, 21.0000, 12.0000, 21.0000,  6.0000,  9.0000,  6.0000,
         27.0000, 21.0000],
        [63.0000, 21.0000, 49.1000, 28.0000, 49.0000, 14.0000, 21.0000, 14.0000,
         63.0000, 49.0000],
        [36.0000, 12.0000, 28.0000, 16.1000, 28.0000,  8.0000, 12.0000,  8.0000,
         36.0000, 28.0000],
        [63.0000, 21.0000, 49.0000, 28.0000, 49.1000, 14.0000, 21.0000, 14.0000,
         63.0000, 49.0000],
        [18.0000,  6.0000, 14.0000,  8.0000, 14.0000,  4.1000,  6.0000,  4.0000,
         18.0000, 14.0000],
        [27.0000,  9.0000, 21.0000, 12.0000, 21.0000,  6.0000,  9.1000,  6.0000,
         27.0000, 21.0000],
        [18.0000,  6.0000, 14.0000,  8.0000, 14.0000,  4.0000,  6.0000,  4.1000,
         18.0000, 14.0000],
   

In [37]:
# Batch whitening layer under test. With momentum=1 the moving-average update
# replaces the running statistics by those of the current batch.
bw_layer = BatchWhitening(num_features,momentum=1)

# Forward pass
x_w = bw_layer(xc)

# Biased covariance of the whitened output
print(torch.cov(x_w.T,correction=0))

# Check that the output covariance is close to the identity (here the
# default, unbiased estimator is used)
assert torch.allclose(x_w.T.cov(), torch.eye(num_features), atol=1e-2), "The outputs are not close enough!"

# print("Functional validation passed!")

tensor([[ 1.0000e+00, -3.1479e-05, -1.2375e-04, -1.7883e-05,  6.2787e-05,
         -1.2165e-05, -2.6292e-05, -6.3144e-06,  6.5099e-05, -7.6969e-05],
        [-3.1467e-05,  1.0001e+00,  1.1378e-03,  3.5838e-05, -8.6736e-04,
          4.6297e-05,  2.0276e-04,  1.8911e-04, -4.8607e-04,  6.3275e-04],
        [-1.2375e-04,  1.1378e-03,  1.0033e+00,  8.6579e-04, -6.9715e-05,
          1.0257e-04,  3.2663e-04,  8.0492e-05, -6.5207e-05,  2.0619e-03],
        [-1.7882e-05,  3.5836e-05,  8.6579e-04,  9.9959e-01, -5.7881e-04,
          1.5406e-04,  4.0285e-04, -4.8886e-05, -4.8925e-04,  2.9243e-04],
        [ 6.2787e-05, -8.6736e-04, -6.9715e-05, -5.7881e-04,  9.9747e-01,
          2.1352e-04,  1.9974e-06, -4.2332e-05, -1.8930e-03,  3.2424e-04],
        [-1.2162e-05,  4.6294e-05,  1.0257e-04,  1.5406e-04,  2.1352e-04,
          1.0000e+00,  2.5533e-04,  1.6552e-04, -5.6447e-04,  5.0011e-04],
        [-2.6292e-05,  2.0276e-04,  3.2663e-04,  4.0285e-04,  1.9974e-06,
          2.5533e-04,  1.0002e+0

In [None]:
# Channel covariance of the BatchNorm output from the earlier validation cell.
# NOTE(review): relies on x.shape[1] matching the channel count of
# custom_bn_output; `x` is reassigned by other cells, so confirm the
# execution order.
custom_bn_output.permute(1,0,2,3).reshape(x.shape[1],-1).cov()

In [None]:
mean = torch.arange(0, 10)
# TODO: define `cov` here. The cell originally ended with a dangling `cov = `,
# which is a SyntaxError and stops "Restart & Run All" at this cell.

In [12]:
# Random integer column vector with entries in [1, 10) and shape (3, 1).
c=torch.randint(1,10,(3,1))
c.shape

torch.Size([3, 1])

In [5]:
# Display the current value of `c`.
c

tensor([2, 1, 1])

In [13]:
# Outer product c @ c.T; with `c` of shape (3, 1) as sampled above this is a
# rank-one 3x3 matrix.
cov = c@c.T
cov

tensor([[25, 15, 35],
        [15,  9, 21],
        [35, 21, 49]])

In [None]:
# Sanity check of the channel-covariance computation on white noise.
X = torch.randn(20, 10, 50, 50)

mean = X.mean(dim=(0, 2, 3))
# Arrange the data as (channels, observations): every (sample, pixel) pair is
# one observation of the 10-dimensional channel vector.
Xtmp = X.permute(1, 0, 2, 3).reshape(X.shape[1], -1)
cov = torch.cov(Xtmp, correction=0)

# The sample covariance of 20*50*50 = 50,000 standard-normal observations
# deviates from the identity by roughly 1/sqrt(50,000) ~ 5e-3 per entry, so
# the tolerance has to sit well above that; atol=1e-5 could never pass.
assert torch.allclose(cov, torch.eye(10), atol=5e-2), "The outputs are not close enough!"
print("Functional validation passed!")

### Debug

In [None]:
# 4D input tensor
# Step-by-step replay of the 4-D whitening path on a random batch.
n_features = 10
X = torch.randn(20, n_features, 50, 50)

# gamma and beta are created but not used further in this cell.
shape = n_features
gamma = torch.ones(shape)
beta = torch.zeros(shape)

# Running statistics start at zero mean and identity covariance.
running_mean = torch.zeros(shape)
running_cov = torch.eye(shape)

# Per-channel mean, and biased channel covariance over the
# (n_features, 20*50*50) observation matrix.
mean = X.mean(dim=(0, 2, 3))
Xtmp = X.view(X.shape[0],X.shape[1],-1)
Xtmp = Xtmp.permute(1,0,2).reshape(X.shape[1],-1)
cov = torch.cov(Xtmp,correction=0) 


momentum = 0.1 
eps = 1e-5
# Moving-average update of the running statistics.
running_mean = (1.0 - momentum) * running_mean + momentum * mean
running_cov = (1.0 - momentum) * running_cov + momentum * cov

# Whiten with the updated running statistics: solve L @ Y = X - running_mean,
# with L the lower Cholesky factor of running_cov + eps * I, then fold the
# result back into the shape of X.
L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
X_hat = X-running_mean.view(1,n_features,1,1)
X_hat = X_hat.permute(1,0,2,3).reshape(X.shape[1],-1)
Y = torch.linalg.solve_triangular(L,X_hat,upper=False).reshape(X.shape[1],X.shape[0],X.shape[2],X.shape[3]).permute(1,0,2,3)


# assert torch.allclose(cov, torch.eye(10), atol=1e-5), "The outputs are not close enough!"
# print("Functional validation passed!")

In [None]:
# 2D input tensor
# Step-by-step replay of the 2-D whitening path on a random batch.
n_features = 10
X = torch.randn(20, n_features)

# gamma and beta are created but not used further in this cell.
shape = n_features
gamma = torch.ones(shape)
beta = torch.zeros(shape)

# Running statistics start at zero mean and identity covariance.
running_mean = torch.zeros(shape)
running_cov = torch.eye(shape)

# Per-feature mean and biased feature covariance of the batch.
mean = X.mean(dim=0)
cov = torch.cov(X.T,correction=0)

momentum = 0.1 
eps = 1e-5
# Moving-average update of the running statistics.
running_mean = (1.0 - momentum) * running_mean + momentum * mean
running_cov = (1.0 - momentum) * running_cov + momentum * cov

# Whiten with the updated running statistics: solve L @ Y.T = (X - running_mean).T,
# with L the lower Cholesky factor of running_cov + eps * I.
L = torch.linalg.cholesky(running_cov + eps*torch.eye(n_features))
X_hat = X-running_mean.view(1,n_features)
X_hat = X_hat.T
Y = torch.linalg.solve_triangular(L,X_hat,upper=False).T


In [None]:
def batch_orthonorm(X, gamma, beta, running_mean, running_cov, eps, momentum):
    """Debug variant: whiten ``X`` with the *updated running* statistics.

    NOTE(review): this cell redefines ``batch_orthonorm`` and therefore
    shadows the definition in the "Batch Whitening" section when the notebook
    is run top to bottom. Unlike that definition, this variant always updates
    the running statistics and always whitens with them, regardless of the
    gradient mode.

    Args:
        X: ``(N, C)`` or ``(N, C, H, W)`` tensor; features/channels on dim 1.
        gamma, beta: accepted for interface compatibility; not used because
            the scale-and-shift step is disabled.
        running_mean: running mean with ``C`` elements.
        running_cov: running ``(C, C)`` covariance.
        eps: multiple of the identity added before the Cholesky factorization.
        momentum: weight of the current batch in the moving average.

    Returns:
        ``(Y, running_mean, running_cov)`` where ``Y`` has the shape of ``X``
        and the statistics keep the shapes that were passed in.
    """
    assert len(X.shape) in (2, 4)
    n_features = X.shape[1]

    if len(X.shape) == 2:
        # Fully connected case: statistics over the batch dimension.
        mean = X.mean(dim=0)
        cov = torch.cov(X.T, correction=0)
    else:
        # Convolutional case: every (sample, pixel) pair is one observation
        # of the C-dimensional channel vector. reshape (not view) so that
        # non-contiguous inputs are accepted.
        mean = X.mean(dim=(0, 2, 3))
        Xtmp = X.permute(1, 0, 2, 3).reshape(n_features, -1)
        cov = torch.cov(Xtmp, correction=0)
    # Moving-average update. The batch mean is reshaped to the shape of the
    # running mean; adding a keepdim (1, C, 1, 1) mean to a (C,) running mean
    # would broadcast to (1, C, 1, C) and break the reshapes below.
    running_mean = (1.0 - momentum) * running_mean + momentum * mean.reshape(running_mean.shape)
    running_cov = (1.0 - momentum) * running_cov + momentum * cov
    # Identity on the covariance's device/dtype so CUDA inputs work.
    eye = torch.eye(n_features, device=running_cov.device, dtype=running_cov.dtype)
    L = torch.linalg.cholesky(running_cov + eps * eye)
    if len(X.shape) == 2:
        X_hat = (X - running_mean.reshape(1, n_features)).T
        Y = torch.linalg.solve_triangular(L, X_hat, upper=False).T
    else:
        X_hat = X - running_mean.reshape(1, n_features, 1, 1)
        X_hat = X_hat.permute(1, 0, 2, 3).reshape(n_features, -1)
        Y = torch.linalg.solve_triangular(L, X_hat, upper=False)
        Y = Y.reshape(n_features, X.shape[0], X.shape[2], X.shape[3]).permute(1, 0, 2, 3)
    # Y = gamma.view(shape) * Y + beta.view(shape)  # Scale and shift (disabled)
    return Y, running_mean.data, running_cov.data


In [None]:
# Inspect the shape of the most recently computed whitened output `Y`.
Y.shape