In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

In [2]:
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cpu


In [3]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64*7*7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [4]:
class RandomizedSmoothing:
    def __init__(self, base_classifier, sigma=0.25, num_classes=10):
        self.base_classifier, self.sigma, self.num_classes = base_classifier, sigma, num_classes

    def predict(self, x, n_samples=1000, alpha=0.001, batch_size=100):
        self.base_classifier.eval()
        votes = torch.zeros(x.size(0), self.num_classes).to(device)
        with torch.no_grad():
            for i in range(0, n_samples, batch_size):
                current_batch_size = min(batch_size, n_samples - i)
                x_batch = x.repeat(current_batch_size, 1, 1, 1)
                noise = torch.randn_like(x_batch) * self.sigma
                x_noisy = x_batch + noise
                output = self.base_classifier(x_noisy)
                preds = torch.argmax(output, dim=1)
                for j in range(preds.size(0)):
                    sample_idx = j % x.size(0)
                    votes[sample_idx, preds[j]] += 1

        top_class = torch.argmax(votes, dim=1)
        top_count = votes[torch.arange(x.size(0)), top_class]
        p_lower = self._clopper_pearson_lower_bound(top_count, n_samples, alpha)
        certified_radius = self.sigma * stats.norm.ppf(p_lower.cpu().numpy())
        certified_radius = np.maximum(certified_radius, 0)
        return top_class, certified_radius

    def _clopper_pearson_lower_bound(self, counts, n, alpha):
        p_lower = torch.zeros_like(counts, dtype=torch.float32)
        for i in range(len(counts)):
            if counts[i] == n:
                p_lower[i] = 1.0
            elif counts[i] == 0:
                p_lower[i] = 0.0
            else: 
                p_lower[i] = stats.beta.ppf(alpha, counts[i].cpu(), n - counts[i].cpu() + 1)
        return p_lower
    
def train_model(model, train_loader, test_loader, epochs=5):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    train_losses, test_accuracies = [], []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            if epoch > 0:
                noise = 0.1 * torch.randn_like(data)
                data += noise

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        test_acc = evaluate_model(model, test_loader)
        test_accuracies.append(test_acc)
        train_losses.append(running_loss / len(train_loader))
        print(f"Epoch: {epoch + 1}/{epochs}: Loss {np.round(running_loss / len(train_loader), 3)} | Test accuracy = {np.round(test_acc, 3)}")
    return train_losses, test_accuracies

def evaluate_model(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct/total

def evaluate_certified_accuracy(smoothed_classifier, test_loader, radii, n_samples=1000):
    certified_accuracies = {radius: 0 for radius in radii}
    total = 0
    smoothed_classifier.base_classifier.eval()
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            for i in range(data.size(0)):
                x = data[i:i+1]
                target = targets[i:i+1]
                pred, cert_radius = smoothed_classifier.predict(x, n_samples=n_samples)
                correct = (pred == target).item()
                if correct:
                    for radius in radii:
                        if cert_radius >= radius:
                            certified_accuracies[radius] += 1
                total += 1

    for radius in radii:
        certified_accuracies[radius] = certified_accuracies[radius] / total * 100
    return certified_accuracies

if __name__ == "__main__":
    
    transform = transforms.Compose([transforms.ToTensor()])
    trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    train_loader = DataLoader(trainset, batch_size=128, shuffle=True)
    test_loader = DataLoader(testset, batch_size=100, shuffle=False)
    
    print("Training base classifier (f)")
    model = SimpleCNN().to(device)
    train_losses, test_accuracies = train_model(model, train_loader, test_loader, epochs=3)
    
    sigma = 0.25
    smoothed_classifier = RandomizedSmoothing(model, sigma=sigma)
    standard_accuracy = evaluate_model(model, test_loader)
    print(f"Base classifier accuracy: {np.round(standard_accuracy, 3)}")
    
    print("\nEvaluating the smoothed classifier (g)")
    smoothed_correct, total = 0, 0
    n_samples = 100
    
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            preds, _ = smoothed_classifier.predict(data, n_samples=100, batch_size=50)
            smoothed_correct += (preds == targets).sum().item()
            total += targets.size(0)
    
    smoothed_accuracy = smoothed_correct / total
    print(f"Assuming {n_samples} majority votes per prediction.")
    print(f"Smoothed classifier accuracy: {np.round(smoothed_accuracy, 3)}")

    radii = [0.0, 0.5, 1.0]
    certified_accuracies = evaluate_certified_accuracy(smoothed_classifier, DataLoader(testset, batch_size=1, shuffle=False), radii, n_samples=100)
    
    print("\nFinal results\n")
    print(f"Base classifier accuracy: {np.round(standard_accuracy, 3)}")
    print(f"Smoothed classifier accuracy: {np.round(smoothed_accuracy, 3)}")
    print("Certified Accuracies:")
    for radius in radii:
        print(f"Radius {radius}: {np.round(certified_accuracies[radius], 3)}")

Training base classifier (f)
Epoch: 1/3: Loss 0.255 | Test accuracy = 0.976
Epoch: 2/3: Loss 0.063 | Test accuracy = 0.982
Epoch: 3/3: Loss 0.045 | Test accuracy = 0.989
Base classifier accuracy: 0.989

Evaluating the smoothed classifier (g)
Assuming 100 majority votes per prediction.
Smoothed classifier accuracy: 0.976

Final results
Base classifier accuracy: 0.989
Smoothed classifier accuracy: 0.976
Certified Accuracies:
Radius 0.0: 97.54
Radius 0.5: 80.3
Radius 1.0: 80.3
