# In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
# Fix the NumPy and PyTorch RNG seeds so that runs are reproducible.
np.random.seed(42)
torch.manual_seed(42)
# Output: Using device: cpu

# In [None]:
class ESM_Tracker:
    """Track the Estimation Shift Magnitude (ESM) of normalization layers during training.

    For each named layer the tracker records two kinds of per-channel statistics:

    * "expected" statistics -- mean/variance computed directly from a batch of
      inputs (``update_expected``);
    * "estimated" statistics -- a snapshot of a BN layer's ``running_mean`` /
      ``running_var`` buffers (``update_estimated``).

    ``calculate_esm`` compares the most recent entry of each and appends the
    result to ``esm_history``.

    Attributes:
        expected_stats: ``{layer_name: {'mean': [Tensor, ...], 'var': [Tensor, ...]}}``
        estimated_stats: same layout as ``expected_stats``.
        esm_history: ``{layer_name: {'mean': [float, ...], 'var': [float, ...]}}``

    Note:
        Every recorded tensor is kept (nothing is evicted), so memory grows with
        the number of updates. Tensors stay on the device they were produced on.
    """

    def __init__(self):
        self.expected_stats = {}
        self.estimated_stats = {}
        self.esm_history = {}

    @staticmethod
    def _reduce_dims(input_data):
        """Return every dim except the channel dim (dim 1), i.e. the dims BatchNorm reduces over."""
        # (N, C) -> (0,); (N, C, L) -> (0, 2); (N, C, H, W) -> (0, 2, 3); ...
        return (0,) + tuple(range(2, input_data.dim()))

    def update_expected(self, layer_name, input_data):
        """Calculate and record per-channel statistics of the current batch.

        Args:
            layer_name: key under which the statistics are stored.
            input_data: tensor with the batch on dim 0 and channels on dim 1,
                e.g. (N, C), (N, C, L) or (N, C, H, W). Statistics are reduced
                over every dim except dim 1, so the result has shape (C,) and
                lines up with a BatchNorm layer's running buffers.
        """
        stats = self.expected_stats.setdefault(layer_name, {'mean': [], 'var': []})

        dims = self._reduce_dims(input_data)
        mean = input_data.mean(dim=dims)
        # torch.var defaults to the unbiased estimator (NaN if only one element
        # is reduced); PyTorch's BN also uses the unbiased estimator when
        # updating running_var, so the two are directly comparable.
        var = input_data.var(dim=dims)

        stats['mean'].append(mean.detach())
        stats['var'].append(var.detach())

    def update_estimated(self, layer_name, bn_layer):
        """Record a snapshot of a BN layer's running mean/variance buffers.

        Args:
            layer_name: key under which the snapshot is stored.
            bn_layer: module exposing ``running_mean`` and ``running_var``.
                The buffers are cloned so later in-place updates by the layer
                do not alter the stored snapshot.
        """
        stats = self.estimated_stats.setdefault(layer_name, {'mean': [], 'var': []})
        stats['mean'].append(bn_layer.running_mean.detach().clone())
        stats['var'].append(bn_layer.running_var.detach().clone())

    def calculate_esm(self, layer_name):
        """Calculate the Estimation Shift Magnitude for the latest recorded statistics.

        ESM of the mean is ``||mu_expected - mu_estimated||_2``. ESM of the
        variance is ``||sqrt(var_expected + eps) - sqrt(var_estimated + eps)||_2``
        with ``eps = 1e-5``. Both values are appended to ``esm_history``.

        Args:
            layer_name: layer whose statistics should be compared.

        Returns:
            ``(esm_mean, esm_var)`` as Python floats, or ``(0, 0)`` if either
            the expected or the estimated statistics have not been recorded yet
            (nothing is appended to the history in that case).
        """
        if layer_name not in self.expected_stats or layer_name not in self.estimated_stats:
            return 0, 0

        # Use the most recent statistics.
        exp_mean = self.expected_stats[layer_name]['mean'][-1]
        exp_var = self.expected_stats[layer_name]['var'][-1]
        est_mean = self.estimated_stats[layer_name]['mean'][-1]
        est_var = self.estimated_stats[layer_name]['var'][-1]

        esm_mean = torch.norm(exp_mean - est_mean, p=2).item()
        esm_var = torch.norm(torch.sqrt(exp_var + 1e-5) - torch.sqrt(est_var + 1e-5), p=2).item()

        history = self.esm_history.setdefault(layer_name, {'mean': [], 'var': []})
        history['mean'].append(esm_mean)
        history['var'].append(esm_var)

        return esm_mean, esm_var

class XBNBlock(nn.Module):
    """Bottleneck residual block with a batch-free normalizer (BFN) at position 2 (XBNBlock-P2).

    Layout: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BFN -> ReLU -> 1x1 conv -> BN,
    then added to the shortcut and passed through a final ReLU. Only the 3x3
    conv's normalization is replaced by a BFN (GroupNorm or InstanceNorm).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output. The bottleneck width is
            ``out_channels // 4``.
        stride: stride of the 3x3 conv and of the projection shortcut.
        bfn_type: ``'GN'`` selects GroupNorm. Any other value selects
            ``nn.InstanceNorm2d`` with its default arguments.
        gn_groups: preferred number of GroupNorm groups. The value actually used
            is ``gcd(gn_groups, out_channels // 4)``, so the bottleneck width is
            always divisible by it. With the default of 32 this is exactly 32
            whenever the width is a multiple of 32.
    """

    def __init__(self, in_channels, out_channels, stride=1, bfn_type='GN', gn_groups=32):
        super().__init__()
        mid_channels = out_channels // 4

        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)

        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1, bias=False)
        # Position 2: a batch-free normalizer instead of BN.
        if bfn_type == 'GN':
            # nn.GroupNorm requires num_channels % num_groups == 0; a fixed 32
            # groups raised ValueError for widths such as 16 or 48.
            self.norm2 = nn.GroupNorm(math.gcd(gn_groups, mid_channels), mid_channels)
        else:  # InstanceNorm
            self.norm2 = nn.InstanceNorm2d(mid_channels)

        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a strided 1x1 projection + BN is used.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply the bottleneck branch, add the shortcut, and apply ReLU."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.norm2(self.conv2(out)))  # BFN here
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out