In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim

# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device {device}")

# Load the MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

  warn(


Using device cuda:0


In [2]:
# Define the neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.AvgPool2d(kernel_size=2)
        self.fc1 = nn.Linear(64 * 5 * 5, 64 * 5 * 5)
        self.fc2 = nn.Linear(64 * 5 * 5, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.tanh(self.conv1(x)))
        x = self.pool(torch.tanh(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = torch.softmax(self.fc3(x), dim=1)
        return x

In [3]:
# Initialize the network
net = Net().to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters())

# model parameters
sum(p.numel() for p in net.parameters())

2786634

In [4]:
# Train the network for 3 epochs
for epoch in range(3):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = net(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print statistics
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, 3, i + 1, len(train_loader), loss.item()))

Epoch [1/3], Step [100/938], Loss: 1.5934
Epoch [1/3], Step [200/938], Loss: 1.5985
Epoch [1/3], Step [300/938], Loss: 1.5435
Epoch [1/3], Step [400/938], Loss: 1.4922
Epoch [1/3], Step [500/938], Loss: 1.5134
Epoch [1/3], Step [600/938], Loss: 1.5082
Epoch [1/3], Step [700/938], Loss: 1.5027
Epoch [1/3], Step [800/938], Loss: 1.5760
Epoch [1/3], Step [900/938], Loss: 1.5309
Epoch [2/3], Step [100/938], Loss: 1.4932
Epoch [2/3], Step [200/938], Loss: 1.5181
Epoch [2/3], Step [300/938], Loss: 1.4820
Epoch [2/3], Step [400/938], Loss: 1.5169
Epoch [2/3], Step [500/938], Loss: 1.4741
Epoch [2/3], Step [600/938], Loss: 1.5075
Epoch [2/3], Step [700/938], Loss: 1.4953
Epoch [2/3], Step [800/938], Loss: 1.4702
Epoch [2/3], Step [900/938], Loss: 1.4894
Epoch [3/3], Step [100/938], Loss: 1.4758
Epoch [3/3], Step [200/938], Loss: 1.4915
Epoch [3/3], Step [300/938], Loss: 1.4864
Epoch [3/3], Step [400/938], Loss: 1.4841
Epoch [3/3], Step [500/938], Loss: 1.4663
Epoch [3/3], Step [600/938], Loss:

In [5]:
# Evaluate the network on the test set
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Test Accuracy: {}%'.format((correct / total) * 100))

# Save the model
torch.save(net.state_dict(), './mnist_net.pth')

# Load the model
model = Net().to(device)
model.load_state_dict(torch.load('./mnist_net.pth'))

Test Accuracy: 97.58%


<All keys matched successfully>

In [6]:
from laplace import Laplace

# Examples of different ways to specify the subnetwork
# via indices of the vectorized model parameters
#
# Example 1: select the 128 parameters with the largest magnitude
from laplace.utils import LargestMagnitudeSubnetMask
subnetwork_mask = LargestMagnitudeSubnetMask(model, n_params_subnet=128)
subnetwork_indices = subnetwork_mask.select()

In [7]:
subnetwork_indices = subnetwork_indices.type(torch.LongTensor)

In [8]:
subnetwork_indices

tensor([      6,       7,      11,      15,      16,      22,      23,      28,
             35,      41,      43,      47,      49,      51,      54,      58,
             64,      65,      74,      78,      79,      80,      81,      85,
             86,      87,      91,      93,      95,      99,     103,     107,
            114,     115,     118,     121,     124,     125,     130,     132,
            134,     137,     141,     142,     143,     144,     148,     150,
            155,     158,     160,     161,     164,     165,     167,     169,
            175,     177,     178,     184,     186,     187,     190,     191,
            197,     198,     201,     207,     212,     214,     217,     218,
            226,     227,     229,     232,     238,     241,     248,     259,
            263,     269,     274,     279,     280,     286,     287,    2622,
           8454,   12414,   12800,   12989,   14498, 1182795, 1366099, 1591699,
        1596399, 2017373, 2502173, 25494

In [9]:
# Define and fit subnetwork LA using the specified subnetwork indices
la = Laplace(model, 'classification',
             subset_of_weights='subnetwork',
             hessian_structure='full',
             subnetwork_indices=subnetwork_indices)
la.fit(train_loader)



In [12]:
la.H.shape

torch.Size([128, 128])