# Every Model Learned by Gradient Descent Is Approximately a Kernel Machine
https://arxiv.org/pdf/2012.00152

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# Set random seeds for reproducibility: seed PyTorch's and NumPy's global
# RNGs so random draws (e.g. nn.Linear weight init, torch.randn data) repeat
# across runs.
torch.manual_seed(42)
np.random.seed(42)

In [4]:
class SimpleNet(nn.Module):
    """
    Two-layer perceptron: flatten -> Linear -> ReLU -> Linear (raw logits).

    Args:
        input_size: number of features per sample after flattening
            (784 = 28 * 28, matching the MNIST images used in this notebook).
        hidden_size: width of the single hidden layer.
        output_size: number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=128, output_size=10):
        super().__init__()
        # Creation order (fc1 then fc2) fixes both the RNG draws used for
        # initialization and the state_dict key order.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch ``x`` (leading dim = batch) to logits of shape (batch, output_size)."""
        batch = x.size(0)
        flat = x.view(batch, -1)
        hidden = self.fc1(flat).relu()
        return self.fc2(hidden)

In [8]:
class PathKernelTracker:
    """
    Record per-step training gradients and reuse them as a kernel machine.

    This demonstrates the paper's main theorem in simplified form. The
    paper's path kernel is K(x, x') = ∫ ∇_w y(x) · ∇_w y(x') dt along the
    training trajectory; here it is approximated by dotting
      * the gradient of mean(model(x_query)) at the model's *current* weights
    with
      * the flat loss gradient recorded at each training step.

    Args:
        model: the module being trained; its ``parameters()`` order defines
            the layout of every flat gradient vector.
        num_classes: length of the vector returned by
            ``predict_as_kernel_machine`` (default 10).
    """

    def __init__(self, model, num_classes=10):
        self.model = model
        self.num_classes = num_classes
        # One flat gradient vector per store_gradients_after_backward() call.
        self.gradient_history = []
        # Examples/labels in call order. Callers are expected to store one
        # example per stored gradient so index i refers to the same step.
        self.training_examples = []
        self.training_labels = []

    def _flatten_grads(self, grads):
        """
        Concatenate per-parameter gradients into one flat vector.

        ``grads`` must be aligned with ``self.model.parameters()``. A ``None``
        entry (frozen or unused parameter) becomes zeros, so every vector has
        the same length and layout no matter which parameters got a gradient.
        """
        pieces = [
            (torch.zeros_like(p) if g is None else g.detach()).reshape(-1)
            for p, g in zip(self.model.parameters(), grads)
        ]
        # torch.cat allocates a new tensor, so the result never aliases
        # param.grad (which later zero_grad() calls would overwrite).
        return torch.cat(pieces) if pieces else torch.zeros(0)

    def store_gradients_after_backward(self):
        """Store the current ``param.grad`` values (call after ``loss.backward()``)."""
        self.gradient_history.append(
            self._flatten_grads([p.grad for p in self.model.parameters()])
        )

    def store_training_example(self, x, y):
        """Store copies of a training input ``x`` and its label ``y``."""
        self.training_examples.append(x.clone())
        self.training_labels.append(y.clone())

    def _query_gradient(self, x_query):
        """
        Return the flat gradient of ``mean(model(x_query.unsqueeze(0)))`` with
        respect to the model parameters, at the current weights, in eval mode.

        ``x_query`` is a single example without a batch dimension (one is
        added here). The train/eval flag of every submodule is restored
        afterwards and ``param.grad`` is left untouched.
        """
        params = list(self.model.parameters())
        trainable = [p for p in params if p.requires_grad]
        saved_modes = [(m, m.training) for m in self.model.modules()]
        aligned = [None] * len(params)
        self.model.eval()
        try:
            # Callers (e.g. evaluation loops) may run under torch.no_grad().
            with torch.enable_grad():
                output = self.model(x_query.detach().unsqueeze(0))
                if trainable and output.requires_grad:
                    grads = iter(
                        torch.autograd.grad(
                            output.mean(), trainable, allow_unused=True
                        )
                    )
                    # Put each gradient back into its parameters() slot.
                    aligned = [
                        next(grads) if p.requires_grad else None for p in params
                    ]
        finally:
            for module, was_training in saved_modes:
                module.training = was_training
        return self._flatten_grads(aligned)

    def _kernel_value(self, query_grad, x_train_idx):
        """Dot ``query_grad`` with stored gradient ``x_train_idx``; 0.0 if there is none."""
        if x_train_idx >= len(self.gradient_history) or query_grad.numel() == 0:
            return 0.0
        return torch.dot(query_grad, self.gradient_history[x_train_idx]).item()

    def compute_path_kernel(self, x_query, x_train_idx):
        """
        Compute the (approximate) path kernel between a query point and a
        training step.

        K(x, x') = ∫ ∇y(x) · ∇y(x') dt over the path, approximated here as the
        dot product of the query's output gradient at the current weights with
        the loss gradient stored for step ``x_train_idx``.

        Returns:
            A Python float; 0.0 when no gradient was stored for that index.
        """
        if x_train_idx >= len(self.gradient_history):
            return 0.0
        return self._kernel_value(self._query_gradient(x_query), x_train_idx)

    def predict_as_kernel_machine(self, x_query):
        """
        Predict with the kernel-machine form y = Σ_i a_i * K(x, x_i) + b.

        Simplified: a_i is the one-hot label of stored example i divided by
        the number of stored examples, and b = 0 (in practice a_i depends on
        loss derivatives). Returns a tensor of length ``num_classes``; all
        zeros when no training examples have been stored.
        """
        if not self.training_examples:
            return torch.zeros(self.num_classes)

        n_examples = len(self.training_examples)
        # The query gradient does not depend on i: compute it once rather
        # than once per stored training step.
        query_grad = self._query_gradient(x_query)
        prediction = torch.zeros(self.num_classes)

        for i, (_, y_train) in enumerate(
            zip(self.training_examples, self.training_labels)
        ):
            kernel_value = self._kernel_value(query_grad, i)

            y_onehot = torch.zeros(self.num_classes)
            y_onehot[y_train] = 1.0

            weight = kernel_value / n_examples  # normalize by example count
            prediction += weight * y_onehot

        return prediction

In [9]:
def demonstrate_path_kernel(
    n_train=100, n_test=10, epochs=3, lr=0.01, n_report=6, data_dir="./data"
):
    """
    Train SimpleNet on a small MNIST subset while recording per-step
    gradients, then compare its predictions with the kernel-machine
    reconstruction built by PathKernelTracker.

    Args:
        n_train: number of leading MNIST training images to train on.
        n_test: size of the MNIST test subset that is loaded.
        epochs: passes over the training subset (batch size 1, plain SGD).
        lr: SGD learning rate.
        n_report: at most this many test examples are evaluated and printed.
        data_dir: directory where torchvision stores / downloads MNIST.

    Prints training losses, per-example predictions and confidences, and the
    accuracy of both predictors over the examples actually evaluated.
    """
    import itertools  # local import keeps this notebook cell self-contained

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    full_dataset = datasets.MNIST(
        data_dir, train=True, download=True, transform=transform
    )

    # First n_train examples in fixed order with batch size 1, so every SGD
    # step corresponds to exactly one stored example and gradient.
    small_dataset = Subset(full_dataset, range(n_train))
    train_loader = DataLoader(small_dataset, batch_size=1, shuffle=False)

    test_dataset = datasets.MNIST(
        data_dir, train=False, download=True, transform=transform
    )
    test_subset = Subset(test_dataset, range(n_test))
    test_loader = DataLoader(test_subset, batch_size=1, shuffle=False)

    model = SimpleNet()
    tracker = PathKernelTracker(model)
    optimizer = optim.SGD(model.parameters(), lr=lr)

    print("Training neural network and tracking path kernels...")

    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            # Store the example in lockstep with its gradient below, so
            # tracker.training_examples[i] matches tracker.gradient_history[i].
            tracker.store_training_example(data, target)

            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()

            # Record this step's gradients before the weights are updated.
            tracker.store_gradients_after_backward()

            optimizer.step()

            if batch_idx % 20 == 0:
                print(f"Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}")

    print("\nComparing neural network predictions vs kernel machine predictions...")

    model.eval()
    correct_nn = 0
    correct_km = 0
    n_evaluated = 0

    with torch.no_grad():
        # islice caps evaluation at n_report examples; the accuracy
        # denominator below counts what was actually evaluated.
        for i, (data, target) in enumerate(itertools.islice(test_loader, n_report)):
            nn_output = model(data)
            nn_pred = nn_output.argmax(dim=1)

            # Kernel-machine prediction; the tracker re-enables grad itself.
            km_output = tracker.predict_as_kernel_machine(data.squeeze(0))
            km_pred = km_output.argmax()

            correct_nn += (nn_pred == target).sum().item()
            correct_km += (km_pred == target).sum().item()
            n_evaluated += 1

            print(
                f"Example {i}: True={target.item()}, NN_pred={nn_pred.item()}, KM_pred={km_pred.item()}"
            )
            print(
                f"  NN confidence: {F.softmax(nn_output, dim=1)[0, nn_pred].item():.3f}"
            )
            print(
                f"  KM confidence: {F.softmax(km_output.unsqueeze(0), dim=1)[0, km_pred].item():.3f}"
            )

    print("\nAccuracy comparison (limited test set):")
    denom = max(n_evaluated, 1)  # avoid ZeroDivisionError if nothing was evaluated
    print(f"Neural Network: {correct_nn}/{n_evaluated} = {correct_nn / denom:.3f}")
    print(f"Kernel Machine: {correct_km}/{n_evaluated} = {correct_km / denom:.3f}")


In [10]:
# Run the end-to-end MNIST demo (downloads MNIST into ./data on first use).
demonstrate_path_kernel()

Training neural network and tracking path kernels...
Epoch: 0, Batch: 0, Loss: 2.3234
Epoch: 0, Batch: 20, Loss: 2.0697
Epoch: 0, Batch: 40, Loss: 1.3557
Epoch: 0, Batch: 60, Loss: 1.5063
Epoch: 0, Batch: 80, Loss: 3.3740
Epoch: 1, Batch: 0, Loss: 3.7886
Epoch: 1, Batch: 20, Loss: 0.1847
Epoch: 1, Batch: 40, Loss: 0.2206
Epoch: 1, Batch: 60, Loss: 0.3959
Epoch: 1, Batch: 80, Loss: 2.3313
Epoch: 2, Batch: 0, Loss: 2.6912
Epoch: 2, Batch: 20, Loss: 0.0251
Epoch: 2, Batch: 40, Loss: 0.0434
Epoch: 2, Batch: 60, Loss: 0.1141
Epoch: 2, Batch: 80, Loss: 0.6814

Comparing neural network predictions vs kernel machine predictions...
Example 0: True=7, NN_pred=7, KM_pred=7
  NN confidence: 0.967
  KM confidence: 0.114
Example 1: True=2, NN_pred=2, KM_pred=2
  NN confidence: 0.425
  KM confidence: 0.106
Example 2: True=1, NN_pred=1, KM_pred=1
  NN confidence: 0.964
  KM confidence: 0.111
Example 3: True=0, NN_pred=0, KM_pred=7
  NN confidence: 0.842
  KM confidence: 0.114
Example 4: True=4, NN_pre

In [11]:
def visualize_gradient_similarity():
    """
    Print an averaged gradient dot-product table for a toy 2-D problem.

    Trains ``nn.Linear(2, 2)`` with per-example SGD (lr 0.1, 10 epochs) on
    20 random points labelled ``x0 + x1 > 0``, recording the flattened loss
    gradient after every step. For the first five points, K(x_i, x_j) is the
    mean over epochs of the dot product between the gradients recorded for
    points i and j in that epoch. Reseeds torch's global RNG with 42; only
    prints and returns None.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("VISUALIZING GRADIENT SIMILARITY")
    print(rule)

    # Reseed so the toy data and the initial weights are reproducible; the
    # draws below must stay in this order (data first, then weights).
    torch.manual_seed(42)

    X = torch.randn(20, 2)
    y = (X[:, 0] + X[:, 1] > 0).long()  # label = side of the line x0 + x1 = 0

    model = nn.Linear(2, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    n_points = len(X)
    # history[epoch * n_points + k] is the gradient from point k in that epoch.
    history = []

    print("Training on 2D data and collecting gradients...")

    for _ in range(10):
        for x_k, y_k in zip(X.split(1), y.split(1)):
            optimizer.zero_grad()
            F.cross_entropy(model(x_k), y_k).backward()
            history.append(
                torch.cat([p.grad.clone().flatten() for p in model.parameters()])
            )
            optimizer.step()

    print("\nGradient similarity matrix (shows path kernel values):")
    n_samples = min(5, n_points)
    n_steps = len(history)
    epoch_starts = range(0, n_steps, n_points)

    for i in range(n_samples):
        for j in range(n_samples):
            dots = [
                torch.dot(history[s + i], history[s + j]).item()
                for s in epoch_starts
                if max(s + i, s + j) < n_steps
            ]
            avg_sim = np.mean(dots) if dots else 0
            print(f"K(x_{i}, x_{j}) = {avg_sim:.3f}", end="  ")
        print()

    print("\nData points and labels:")
    for i in range(n_samples):
        x0, x1 = X[i, 0].item(), X[i, 1].item()
        print(f"x_{i}: [{x0:.2f}, {x1:.2f}] -> label: {y[i].item()}")


In [12]:
# Train a tiny 2-D linear classifier and print averaged gradient dot products.
visualize_gradient_similarity()


VISUALIZING GRADIENT SIMILARITY
Training on 2D data and collecting gradients...

Gradient similarity matrix (shows path kernel values):
K(x_0, x_0) = 0.590  K(x_0, x_1) = 0.037  K(x_0, x_2) = -0.030  K(x_0, x_3) = 0.062  K(x_0, x_4) = 0.182  
K(x_1, x_0) = 0.037  K(x_1, x_1) = 0.869  K(x_1, x_2) = 0.565  K(x_1, x_3) = 0.257  K(x_1, x_4) = 0.483  
K(x_2, x_0) = -0.030  K(x_2, x_1) = 0.565  K(x_2, x_2) = 0.512  K(x_2, x_3) = 0.157  K(x_2, x_4) = 0.258  
K(x_3, x_0) = 0.062  K(x_3, x_1) = 0.257  K(x_3, x_2) = 0.157  K(x_3, x_3) = 0.092  K(x_3, x_4) = 0.103  
K(x_4, x_0) = 0.182  K(x_4, x_1) = 0.483  K(x_4, x_2) = 0.258  K(x_4, x_3) = 0.103  K(x_4, x_4) = 0.762  

Data points and labels:
x_0: [1.93, 1.49] -> label: 1
x_1: [0.90, -2.11] -> label: 0
x_2: [0.68, -1.23] -> label: 0
x_3: [-0.04, -1.60] -> label: 0
x_4: [-0.75, 1.65] -> label: 1
