In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
# Configuration: compute device and random seeds.
SEED = 42  # single source of truth for every RNG seeded below

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed every RNG the notebook uses so runs are repeatable. torch.manual_seed
# seeds all devices (CPU and CUDA); cuDNN is additionally pinned to
# deterministic, non-autotuned kernels, which otherwise vary run to run on GPU.
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Using device: cpu

In [None]:
class ESM_Tracker:
    """Bookkeeping for the Estimation Shift Magnitude (ESM) of BN layers.

    Three dictionaries are keyed by layer name; each value is a record of the
    form ``{'mean': [...], 'var': [...]}``:

    * ``expected_stats``  -- mean/variance of the tensor fed to the layer,
      measured on the batch passed to ``update_expected``;
    * ``estimated_stats`` -- copies of the layer's ``running_mean`` /
      ``running_var`` buffers taken by ``update_estimated``;
    * ``esm_history``     -- the scalar pairs produced by ``calculate_esm``.
    """

    def __init__(self):
        self.expected_stats = {}
        self.estimated_stats = {}
        self.esm_history = {}

    @staticmethod
    def _record(store, layer_name):
        """Fetch (creating if absent) the {'mean': [], 'var': []} record for a layer."""
        return store.setdefault(layer_name, {'mean': [], 'var': []})

    def update_expected(self, layer_name, input_data):
        """Append the batch mean and (unbiased) variance of ``input_data``.

        A 4-D tensor is reduced over dims 0, 2 and 3, leaving one value per
        dim-1 entry (per channel for NCHW input); any other tensor is reduced
        over dim 0 only.
        """
        record = self._record(self.expected_stats, layer_name)

        reduce_over = (0, 2, 3) if input_data.dim() == 4 else 0
        batch_mean = input_data.mean(dim=reduce_over)
        batch_var = input_data.var(dim=reduce_over)

        record['mean'].append(batch_mean.detach())
        record['var'].append(batch_var.detach())

    def update_estimated(self, layer_name, bn_layer):
        """Append detached copies of ``bn_layer``'s running mean and variance."""
        record = self._record(self.estimated_stats, layer_name)
        snapshot = (('mean', bn_layer.running_mean), ('var', bn_layer.running_var))
        for key, buffer in snapshot:
            # clone() so later in-place updates of the buffer don't alter history
            record[key].append(buffer.detach().clone())

    def calculate_esm(self, layer_name):
        """Compute ESM from the most recent expected and estimated statistics.

        Returns ``(esm_mean, esm_var)``: the L2 distance between the two mean
        vectors, and the L2 distance between ``sqrt(var + 1e-5)`` of both
        (i.e. a distance between standard deviations). The pair is also
        appended to ``esm_history[layer_name]``. Returns ``(0, 0)`` without
        recording anything if either side has never been updated.
        """
        tracked = layer_name in self.expected_stats and layer_name in self.estimated_stats
        if not tracked:
            return 0, 0

        expected = self.expected_stats[layer_name]
        estimated = self.estimated_stats[layer_name]
        eps = 1e-5

        mean_gap = expected['mean'][-1] - estimated['mean'][-1]
        std_gap = torch.sqrt(expected['var'][-1] + eps) - torch.sqrt(estimated['var'][-1] + eps)
        esm_mean = mean_gap.norm(p=2).item()
        esm_var = std_gap.norm(p=2).item()

        history = self._record(self.esm_history, layer_name)
        history['mean'].append(esm_mean)
        history['var'].append(esm_var)
        return esm_mean, esm_var

class XBNBlock(nn.Module):
    """Bottleneck residual block with a batch-free normalizer (BFN) at position P2.

    Layout: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BFN -> ReLU -> 1x1 conv -> BN,
    added to an identity (or 1x1 conv + BN projection) shortcut, then ReLU.
    The bottleneck width is ``out_channels // 4``.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        stride: stride of the 3x3 conv and of the projection shortcut.
        bfn_type: 'GN' selects GroupNorm; any other value selects InstanceNorm2d.
        gn_groups: requested number of GroupNorm groups (default 32). GroupNorm
            requires the channel count to be divisible by the group count, so
            ``gcd(gn_groups, out_channels // 4)`` groups are used. For bottleneck
            widths that are multiples of ``gn_groups`` this is exactly
            ``gn_groups``; narrower blocks (e.g. 16 channels) no longer crash.
    """

    def __init__(self, in_channels, out_channels, stride=1, bfn_type='GN', gn_groups=32):
        super(XBNBlock, self).__init__()
        mid_channels = out_channels // 4

        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)

        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1, bias=False)
        # Position P2 uses a batch-free normalizer instead of BN (XBNBlock-P2).
        if bfn_type == 'GN':
            self.norm2 = nn.GroupNorm(math.gcd(gn_groups, mid_channels), mid_channels)
        else:  # InstanceNorm
            self.norm2 = nn.InstanceNorm2d(mid_channels)

        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # Identity shortcut unless the spatial size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Apply the bottleneck branch, add the shortcut, and return ReLU of the sum."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.norm2(self.conv2(out)))  # BFN here
        out = self.bn3(self.conv3(out))
        out = out + self.shortcut(x)
        return F.relu(out)

In [None]:
class SimpleCNN(nn.Module):
    def __init__(self, use_xbn=False):
        super(SimpleCNN, self).__init__()
        self.use_xbn = use_xbn
        self.esm_tracker = ESM_Tracker()

        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        if use_xbn:
            self.block1 = XBNBlock(64, 128)
            self.block2 = XBNBlock(128, 256)
        else:
            #create proper residual blocks for vanilla model
            self.block1 = self._make_vanilla_res_block(64, 128)
            self.block2 = self._make_vanilla_res_block(128, 256)

        self.fc = nn.Linear(256 * 8 * 8, 10)
        self.layers_to_track = ['bn1', 'block1.bn1', 'block1.bn3', 'block2.bn1', 'block2.bn3']

    def _make_vanilla_res_block(self, in_channels, out_channels):
        """Create a residual block similar to XBNBlock but with all BNs"""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels//4, 1, bias=False),
            nn.BatchNorm2d(out_channels//4),
            nn.ReLU(),
            nn.Conv2d(out_channels//4, out_channels//4, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels//4),  # This will be position P2
            nn.ReLU(),
            nn.Conv2d(out_channels//4, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x, track_esm=False, epoch=0):
        #Layer 1
        out = self.conv1(x)
        if track_esm:
            self.esm_tracker.update_expected('bn1', out)
        out = self.bn1(out)
        if track_esm:
            self.esm_tracker.update_estimated('bn1', self.bn1)
        out = F.relu(out)
        out = F.max_pool2d(out, 2)

        #Blocks
        out = self.block1(out)
        out = self.block2(out)

        out = F.adaptive_avg_pool2d(out, (8, 8))
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

In [None]:
def _evaluate_accuracy(model, loader, model_device):
    """Return top-1 accuracy (percent) of `model` on `loader`; leaves `model` in eval mode.

    Returns 0.0 for an empty loader instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(model_device), target.to(model_device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total if total else 0.0


def _measure_bn1_esm(model, loader, model_device):
    """Measure the ESM of `model.bn1` on one batch of `loader` without mutating BN state.

    The probe forward runs in eval mode: in train mode BatchNorm would fold the
    probe batch into its running statistics, altering both the model being
    trained and the very estimate being measured. The caller's train/eval mode
    is restored afterwards.

    Returns:
        (esm_mean, esm_var) as computed by ``model.esm_tracker.calculate_esm``.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            probe = next(iter(loader))[0].to(model_device)
            model(probe, track_esm=True)
        return model.esm_tracker.calculate_esm('bn1')
    finally:
        model.train(was_training)


def train_model(model, train_loader, test_loader, epochs=50, model_name='vanilla',
                lr=0.001, esm_batch_interval=10, esm_epoch_interval=5):
    """Train `model` with Adam + cross-entropy, evaluating on `test_loader` each epoch.

    Args:
        model: a SimpleCNN-like module exposing ``layers_to_track``, ``esm_tracker``
            and a ``forward(x, track_esm=...)`` signature.
        train_loader: iterable of (inputs, targets) training batches.
        test_loader: iterable of (inputs, targets) evaluation batches.
        epochs: number of training epochs.
        model_name: label for the run; currently unused.
        lr: Adam learning rate.
        esm_batch_interval: record BN statistics on every N-th training batch.
        esm_epoch_interval: measure ``bn1`` ESM every N epochs (starting at epoch 0).

    Returns:
        (train_losses, test_accuracies, esm_history): per-epoch mean training
        loss, per-epoch test accuracy in percent, and a dict mapping each name
        in ``model.layers_to_track`` to its recorded ESM (variance) values.
    """
    # Follow the model's own device instead of a notebook global, so inputs and
    # parameters always live on the same device.
    model_device = next(model.parameters()).device
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    test_accuracies = []
    esm_history = {name: [] for name in model.layers_to_track}

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(model_device), target.to(model_device)
            optimizer.zero_grad()
            # Record BN statistics only periodically to keep the overhead low.
            output = model(data, track_esm=(batch_idx % esm_batch_interval == 0))
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})

        # Simple ESM tracking for just the first BN layer.
        if epoch % esm_epoch_interval == 0 and hasattr(model, 'bn1'):
            _, esm_var = _measure_bn1_esm(model, train_loader, model_device)
            esm_history.setdefault('bn1', []).append(esm_var)

        accuracy = _evaluate_accuracy(model, test_loader, model_device)
        test_accuracies.append(accuracy)
        train_losses.append(total_loss / max(len(train_loader), 1))

        print(f'Epoch {epoch+1}: Loss = {train_losses[-1]:.4f}, Test Acc = {accuracy:.2f}%')

    return train_losses, test_accuracies, esm_history

In [None]:
def test_robustness(model, test_loader, noise_levels=[0.0, 0.1, 0.2, 0.3, 0.4]):
    """Test model robustness to noise in BN statistics"""
    original_stats = {}

    # Save original BN statistics
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            original_stats[name] = {
                'running_mean': module.running_mean.clone(),
                'running_var': module.running_var.clone()
            }

    accuracies = []
    confidences = []  # We'll track predictive entropy

    for noise_magnitude in noise_levels:
        # Add noise to BN statistics
        for name, module in model.named_modules():
            if isinstance(module, nn.BatchNorm2d):
                module.running_mean = original_stats[name]['running_mean'] * (
                    1 + torch.randn_like(original_stats[name]['running_mean']) * noise_magnitude
                )
                module.running_var = original_stats[name]['running_var'] * (
                    1 + torch.randn_like(original_stats[name]['running_var']) * noise_magnitude
                )

        # Test accuracy with noisy stats
        model.eval()
        correct = 0
        total = 0
        all_entropies = []

        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                probabilities = F.softmax(output, dim=1)

                # Calculate predictive entropy
                entropy = -torch.sum(probabilities * torch.log(probabilities + 1e-8), dim=1)
                all_entropies.extend(entropy.cpu().numpy())

                _, predicted = torch.max(output.data, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        accuracy = 100 * correct / total
        avg_entropy = np.mean(all_entropies)

        accuracies.append(accuracy)
        confidences.append(avg_entropy)

        print(f'Noise {noise_magnitude:.1f}: Accuracy = {accuracy:.2f}%, Avg Entropy = {avg_entropy:.4f}')

    # Restore original statistics
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d):
            module.running_mean = original_stats[name]['running_mean']
            module.running_var = original_stats[name]['running_var']

    return accuracies, confidences, noise_levels

In [None]:
def main(epochs=10, batch_size=128, data_root='./data'):
    """Train a vanilla and an XBNBlock SimpleCNN on CIFAR-10 and compare them.

    Both models are trained first; the BN-noise robustness test then runs on
    each of them in the same order.

    Args:
        epochs: training epochs per model (10 keeps the run short).
        batch_size: mini-batch size for both the train and test loaders.
        data_root: directory CIFAR-10 is downloaded to / read from.

    Returns:
        dict with keys 'vanilla' and 'xbn' (each holding 'acc', 'loss', 'esm',
        'robust_acc', 'robust_entropy') plus the shared 'noise_levels'.
    """
    # Data loading
    print("Loading CIFAR-10 dataset...")
    # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps each channel to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

    testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    # (result key, banner title, use_xbn) for each architecture being compared
    variants = (('vanilla', 'Vanilla CNN', False), ('xbn', 'XBNBlock CNN', True))
    models = {}
    results = {}

    for key, title, use_xbn in variants:
        print(f"\n=== Training {title} ===")
        model = SimpleCNN(use_xbn=use_xbn).to(device)
        losses, accs, esm = train_model(model, train_loader, test_loader, epochs=epochs, model_name=key)
        models[key] = model
        results[key] = {'acc': accs, 'loss': losses, 'esm': esm}

    # Test robustness
    print("\n=== Testing Robustness ===")
    noise_levels = None
    for key, _, _ in variants:
        robust_acc, robust_entropy, noise_levels = test_robustness(models[key], test_loader)
        results[key]['robust_acc'] = robust_acc
        results[key]['robust_entropy'] = robust_entropy

    results['noise_levels'] = noise_levels
    return results

# Run the experiment: train both models, then run the BN-noise robustness test.
# `results` keeps the metrics dict returned by main() for use in later cells.
results = main()