In [1]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader, Subset, random_split

# Define the transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Download the dataset

n_classes = 10
batch_size = 32

trainset = torchvision.datasets.Food101(root='../data', split='train', download=True, transform=transform)
testset = torchvision.datasets.Food101(root='../data', split='test', download=True, transform=transform)

# Select top k most frequent classes
# classes, counts = np.unique(trainset._labels, return_counts=True)
# idx = np.argsort(counts)[-n_classes:]
# target_classes = classes[idx]
# print('target_classes', target_classes)

target_classes = list(range(n_classes))



# Filter the dataset to include only samples from the target classes
def filter_classes(dataset, target_classes):
    targets = np.array(dataset._labels)
    mask = np.isin(targets, target_classes)
    indices = np.where(mask)[0]
    return Subset(dataset, indices)

trainset_filtered = filter_classes(trainset, target_classes)
testset_filtered = filter_classes(testset, target_classes)

# Create a smaller subset for training
# subset_indices = torch.randperm(len(dataset)) # [:100]  # Using 100 samples for training
# trainset_subset = Subset(dataset, subset_indices)

trainloader = DataLoader(trainset_filtered, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(testset_filtered, batch_size=batch_size, shuffle=False, num_workers=2)


print(f'Training Samples - Train {len(trainloader) * batch_size} Test {len(testloader) * batch_size}')
print(f'Batch size: {batch_size}')

Training Samples - Train 7520 Test 2528
Batch size: 32


In [2]:
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18, ResNet18_Weights
import math

class LoRALayer(nn.Module):
    def __init__(self, base_layer, rank=4):
        super(LoRALayer, self).__init__()
        self.base_layer = base_layer
        self.rank = rank
        self.lora_A = nn.Parameter(torch.randn(base_layer.weight.size(0), rank))
        self.lora_B = nn.Parameter(torch.randn(rank, base_layer.weight.size(1)))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_B, a=math.sqrt(5))

    def forward(self, x):
        lora_weight = torch.matmul(self.lora_A, self.lora_B)
        new_weight = self.base_layer.weight + lora_weight
        return nn.functional.linear(x, new_weight, self.base_layer.bias)


def apply_lora(model, rank):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            lora_layer = LoRALayer(module, rank)
            setattr(model, name, lora_layer)
            lora_layer.lora_A.requires_grad = True
            lora_layer.lora_B.requires_grad = True
    return model
    

def print_trainable_parameters(model): 
    # Calculate total number of parameters and trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f'Total number of parameters: {total_params}')
    print(f'Number of trainable parameters: {trainable_params}')


lora_rank = 1  # when 0, only train the last layer 

# net = resnet18(pretrained=True)
net = resnet18(weights=ResNet18_Weights.DEFAULT)


# Freeze all parameters
for param in net.parameters():
    param.requires_grad = False

if lora_rank > 0: 
    net = apply_lora(net, rank=lora_rank)


net.fc = nn.Linear(512, n_classes) 
net.fc.requires_grad = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device', device)
net.to(device)


print_trainable_parameters(net)

print('done')


device cuda:0
Total number of parameters: 11181642
Number of trainable parameters: 5130
done


In [6]:
import os 
from os.path import join

lr = 0.01 #  0.001 


print('Training ...')

best_valid_loss = 9999
model_dir = './models'
export_path = join(model_dir, f'model_lorarank{lora_rank}.pth')

net.load_state_dict(torch.load(export_path))
net.eval() 

print('Model Loaded')

Training ...
Model Loaded


In [None]:

from laplace import Laplace

# la = Laplace(
#     net,
#     likelihood="classification",
#     subset_of_weights="all",
#     hessian_structure="kron",
# )

la = Laplace(
    net,
    likelihood="classification",
    subset_of_weights="all",
    hessian_structure="full",
)

print('Fitting Laplace ...')
la.fit(trainloader)

print('Optimization of HPs,,,')
la.optimize_prior_precision()

print('done')

Fitting Laplace ...


In [None]:
print('done')

In [None]:
@torch.no_grad()
def predict(dataloader, model, laplace=False):
    py = []

    for x, _ in dataloader:
        if laplace:
            py.append(model(x.cuda()))
        else:
            py.append(torch.softmax(model(x.cuda()), dim=-1))

    return torch.cat(py).cpu()


probs_laplace = predict(testloader, la, laplace=True)

print(probs_laplace)
print('done')

In [None]:
targets = torch.cat([y for x, y in testloader], dim=0).cpu()


# print(targets)

In [None]:
from netcal.metrics import ECE 
import torch.distributions as dists

# TODO also compare with the baseline that was not finetuned?

probs_baseline = predict(testloader, net, laplace=False)


def eval(probs, name): 
    
    acc_laplace = (probs.argmax(-1) == targets).float().mean()
    ece_laplace = ECE(bins=15).measure(probs.numpy(), targets.numpy())
    nll_laplace = -dists.Categorical(probs).log_prob(targets).mean()
    
    print(
        f"[{name}] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}"
    )

eval(probs_baseline, 'Baseline')
eval(probs_laplace, 'Laplace')