In [10]:
from utils import get_weights
import torch
from resnetv2 import ResNetV2

import numpy as np
import math

%matplotlib inline
import matplotlib.pyplot as plt

import pickle
import argparse
import time
import itertools
from copy import deepcopy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import csv
import sys
#sys.path.append('/content/KD')
# Import the module
import networks
import utils

%load_ext autoreload
%autoreload 2


use_gpu = True    # set use_gpu to True if system has gpu
gpu_id = 0        # id of gpu to be used
cpu_device = torch.device('cpu')
# fast_device is where computation (training, inference) happens
fast_device = torch.device('cpu')
if use_gpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2'    # set visible devices depending on system configuration
    fast_device = torch.device('cuda:' + str(gpu_id))

In [11]:
def reproducibilitySeed():
    """
    Ensure reproducibility of results; Seeds to 0
    """
    torch_init_seed = 0
    torch.manual_seed(torch_init_seed)
    numpy_init_seed = 0
    np.random.seed(numpy_init_seed)
    if use_gpu:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

reproducibilitySeed()

In [14]:
import torchvision
import torchvision.transforms as transforms

# Set up transformations for CIFAR-10
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),  # Augment training data by padding 4 and random cropping
        transforms.RandomHorizontalFlip(),     # Randomly flip images horizontally
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # Normalization for CIFAR-10
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # Normalization for CIFAR-10
    ]
)

val_tx = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load CIFAR-10 dataset
train_val_dataset = torchvision.datasets.CIFAR10(root='./CIFAR10_dataset/', train=True,
                                            download=True, transform=transform_train)

test_dataset = torchvision.datasets.CIFAR10(root='./CIFAR10_dataset/', train=False,
                                            download=True, transform=val_tx)

# Split the training dataset into training and validation
num_train = int(0.95 * len(train_val_dataset))  # 95% of the dataset for training
num_val = len(train_val_dataset) - num_train  # Remaining 5% for validation
train_dataset, val_dataset = torch.utils.data.random_split(train_val_dataset, [num_train, num_val])

# DataLoader setup
batch_size = 128
train_val_loader = torch.utils.data.DataLoader(train_val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [17]:
model = ResNetV2(ResNetV2.BLOCK_UNITS['r101'], width_factor=1, head_size=10)  # NOTE: No new head.
checkpoint = torch.load('BiT-M-R101x1_step5000.tar')
state_dict = checkpoint['model']
new_state_dict = {}

for key in state_dict.keys():
    new_key = key.replace("module.", "")  # Remove "module." prefix
    new_state_dict[new_key] = state_dict[key]

model.load_state_dict(new_state_dict)

model.to(fast_device)

  checkpoint = torch.load('BiT-M-R101x1_step5000.tar')


ResNetV2(
  (root): Sequential(
    (conv): StdConv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (padp): ConstantPad2d(padding=(1, 1, 1, 1), value=0)
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (body): Sequential(
    (block1): Sequential(
      (unit01): PreActBottleneck(
        (gn1): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv1): StdConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (gn2): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv2): StdConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (gn3): GroupNorm(32, 64, eps=1e-05, affine=True)
        (conv3): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu): ReLU(inplace=True)
        (downsample): StdConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (unit02): PreActBottleneck(
        (gn1): GroupNorm(32, 256, eps=1e-05, af

In [None]:
model = torchvision.models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 10)  # CIFAR-10 has 10 classes
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()
model.load_state_dict(torch.load('resnet18_pretrained.bin'))

for name, param in model.named_parameters():
    if not name.startswith("fc."):
        param.requires_grad = False

# Put model in eval mode (we're not actually training, so no dropout/batchnorm updates).
model.eval()

# ------------------------
# 4. GET (RANDOM) LAST-LAYER WEIGHTS & APPROXIMATE HESSIAN
# ------------------------
# Because we replaced the last layer, it's randomly initialized at this point and 
# DOES NOT match the data at all. We'll compute cross-entropy loss and Hessian
# on that mismatch. This is purely to demonstrate the Laplace pipeline.

# We'll do a "dummy" pass with cross-entropy:
criterion_sum = nn.CrossEntropyLoss(reduction='sum')

# Collect the final-layer parameters
fc_params = list(model.fc.parameters())  # [weight, bias]
assert len(fc_params) == 2, "Expected the final layer to have weight & bias."

# (a) Flatten them into a single vector for convenience
with torch.no_grad():
    w_mean_list = [p.flatten() for p in fc_params]
w_mean = torch.cat(w_mean_list)  # This is the 'MAP' weights (randomly init, untrained)

# (b) Hessian diagonal approximation via sum of grad^2
hessian_diag = torch.zeros_like(w_mean)

# We won't do any optimizer steps, just forward/backward to gather Hessian info.
for batch_x, batch_y in train_loader:
    # Forward
    batch_x = batch_x.to(fast_device)
    batch_x.requires_grad = False  # we only want gradient w.r.t. final-layer weights
    logits = model(batch_x)
    loss_sum = criterion_sum(logits, batch_y)  # sum of NLL over the batch

    # Zero grads for last layer
    model.fc.zero_grad()
    loss_sum.backward(retain_graph=True)

    # Gather gradient
    grad_list = [p.grad.detach().flatten() for p in fc_params]
    grad_vec = torch.cat(grad_list)

    # Accumulate squared gradients for diagonal approximation
    hessian_diag += grad_vec**2

# (c) Add prior precision (1 / sigma^2_0)
sigma2_0 = 5.0
hessian_diag += 1.0 / sigma2_0

# (d) The diagonal of the posterior covariance
cov_diag = 1.0 / hessian_diag

print("\nLaplace approximation set up with random final layer.")
print("Final layer param dim:", w_mean.shape[0])

# ------------------------
# 5. DEFINE UTILS FOR SAMPLING & PREDICTION
# ------------------------
def resnet_penultimate(model, x):
    """
    Runs ResNet up to the final average pool & flatten (i.e., penultimate features).
    """
    # replicate the forward logic for ResNet
    x = model.conv1(x)
    x = model.bn1(x)
    x = model.relu(x)
    x = model.maxpool(x)

    x = model.layer1(x)
    x = model.layer2(x)
    x = model.layer3(x)
    x = model.layer4(x)

    x = model.avgpool(x)
    x = torch.flatten(x, 1)
    return x

def predict_laplace(model, x, cov_diag, n_samples=10):
    """
    Make a prediction by averaging softmax outputs over
    multiple samples from N(w_MAP, diag(cov_diag)) for the final layer.
    """
    model.eval()
    # Extract MAP (random) final-layer parameters
    fc_params = list(model.fc.parameters())
    with torch.no_grad():
        w_map_list = [p.detach().flatten() for p in fc_params]
    w_map = torch.cat(w_map_list)

    # Reshape info
    W_shape = model.fc.weight.shape  # [num_classes, 2048]
    b_shape = model.fc.bias.shape
    W_numel = W_shape[0] * W_shape[1]
    b_numel = b_shape[0]

    # Compute penultimate features
    with torch.no_grad():
        feats = resnet_penultimate(model, x)  # shape [batch_size, 2048]

    # Monte Carlo sample from last-layer posterior
    logits_samples = []
    for _ in range(n_samples):
        eps = torch.randn_like(w_map)
        w_samp = w_map + eps * torch.sqrt(cov_diag)

        # Extract W, b
        W_samp = w_samp[:W_numel].reshape(W_shape)
        b_samp = w_samp[W_numel:W_numel + b_numel]

        # logits = feats @ W_samp.T + b_samp
        logits_s = feats @ W_samp.t() + b_samp
        logits_samples.append(logits_s.unsqueeze(0))

    logits_mc = torch.cat(logits_samples, dim=0)  # [n_samples, batch_size, num_classes]
    probs_mc = torch.softmax(logits_mc, dim=-1)   # [n_samples, batch_size, num_classes]
    return probs_mc.mean(dim=0)


batch_x_test, batch_y_test = next(iter(test_loader))  # get first batch

# Standard forward (random final layer)
with torch.no_grad():
    logits_test = model(batch_x_test)
y_pred_test_softmax = torch.softmax(logits_test, dim=-1)

# Laplace-based predictions
y_pred_test_laplace = predict_laplace(model, batch_x_test, cov_diag, n_samples=10)

print("\nRandom final layer => predictions are essentially random.")
print("Softmax preds (no sampling):\n", y_pred_test_softmax)
print("Laplace preds (with sampling):\n", y_pred_test_laplace)

  model.load_state_dict(torch.load('resnet18_pretrained.bin'))


KeyboardInterrupt: 

In [18]:
reproducibilitySeed()
_, test_accuracy = utils.getLossAccuracyOnDataset(model, test_loader, fast_device)
print('test accuracy: ', test_accuracy)

test accuracy:  0.9803
