In [1]:
"""
This notebook is intended to show the full pipeline and test that everything works. 

Note that when running for the first time, it might take a while to download the training data. 
"""

import json
import math
import os
import random
import sys
from os.path import join

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image

sys.path.append('../calibratedvisionlora')

from utils import print_trainable_parameters
from dataloaders import create_food_data_loaders
from training import train_single_model
from calibration import fit_laplace_and_compute_predictions, eval_metrics

# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch) so that a
# Restart & Run All reproduces the same training run and calibration metrics.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the default generator on all devices (CPU and CUDA)

config = dict(
    model_name='sample_lora_model',
    n_classes=5,              # size of the classification head (lora_vit.fc)
    batch_size=32,            # Make it smaller if running out of memory
    lora_rank=1,              # inner dimension of each LoRA w_a / w_b pair
    lora_layers=[10, 11],     # transformer blocks whose attention q/v get LoRA adapters
    random_projections=False,  # NOTE(review): semantics defined in the training code — confirm
    learning_rate=0.001,
    epochs=3,
    train_model=True,         # NOTE(review): presumably toggles training vs. loading — confirm
)


In [2]:
# Train a LoRA model with the settings in `config`. The value returned by
# train_single_model is kept as `best_valid_loss` and echoed below.
best_valid_loss = train_single_model(config)
print(f'best_valid_loss {best_valid_loss!s}')

New Model Training ...
{'model_name': 'sample_lora_model', 'n_classes': 5, 'batch_size': 32, 'lora_rank': 1, 'lora_layers': [10, 11], 'random_projections': False, 'learning_rate': 0.001, 'epochs': 3, 'train_model': True}
Loaded pretrained weights.
Total number of parameters: 86100485
Number of trainable parameters: 9989
Trainable parameters per layer:
 9989
lora_vit 9989
lora_vit.transformer 6144
lora_vit.transformer.blocks 6144
lora_vit.transformer.blocks.10 3072
lora_vit.transformer.blocks.10.attn 3072
lora_vit.transformer.blocks.10.attn.proj_q 1536
lora_vit.transformer.blocks.10.attn.proj_q.w_a 768
lora_vit.transformer.blocks.10.attn.proj_q.w_b 768
lora_vit.transformer.blocks.10.attn.proj_v 1536
lora_vit.transformer.blocks.10.attn.proj_v.w_a 768
lora_vit.transformer.blocks.10.attn.proj_v.w_b 768
lora_vit.transformer.blocks.11 3072
lora_vit.transformer.blocks.11.attn 3072
lora_vit.transformer.blocks.11.attn.proj_q 1536
lora_vit.transformer.blocks.11.attn.proj_q.w_a 768
lora_vit.trans

In [4]:
"""
Calibration of the trained LoRA model.
"""

# Checkpoint path derived from the model name so the two cannot drift apart.
# NOTE(review): assumes train_single_model saves to ../models/<model_name>.pth — confirm.
model_path = join('../models', config['model_name'] + '.pth')

# Calibration overrides go into a *copy* of the training config. Mutating
# `config` in place would make a later re-run of the training cell silently use
# batch_size=2 (and carry an extra train_size key).
calib_config = dict(
    config,
    batch_size=2,    # presumably small to bound memory while fitting the Laplace approximation — confirm
    train_size=200,  # NOTE(review): presumably caps the samples used to fit Laplace — confirm
)

probs_laplace, probs_baseline, targets = fit_laplace_and_compute_predictions(
    model_path, calib_config, hessian_structure="kron"
)

# Compare the uncalibrated baseline against the Laplace-calibrated predictions.
eval_metrics(probs_baseline, targets, 'Baseline')
eval_metrics(probs_laplace, targets, 'Laplace')

Loaded pretrained weights.
device cuda:0
Total number of parameters: 86100485
Number of trainable parameters: 9989
Trainable parameters per layer:
 9989
lora_vit 9989
lora_vit.transformer 6144
lora_vit.transformer.blocks 6144
lora_vit.transformer.blocks.10 3072
lora_vit.transformer.blocks.10.attn 3072
lora_vit.transformer.blocks.10.attn.proj_q 1536
lora_vit.transformer.blocks.10.attn.proj_q.w_a 768
lora_vit.transformer.blocks.10.attn.proj_q.w_b 768
lora_vit.transformer.blocks.10.attn.proj_v 1536
lora_vit.transformer.blocks.10.attn.proj_v.w_a 768
lora_vit.transformer.blocks.10.attn.proj_v.w_b 768
lora_vit.transformer.blocks.11 3072
lora_vit.transformer.blocks.11.attn 3072
lora_vit.transformer.blocks.11.attn.proj_q 1536
lora_vit.transformer.blocks.11.attn.proj_q.w_a 768
lora_vit.transformer.blocks.11.attn.proj_q.w_b 768
lora_vit.transformer.blocks.11.attn.proj_v 1536
lora_vit.transformer.blocks.11.attn.proj_v.w_a 768
lora_vit.transformer.blocks.11.attn.proj_v.w_b 768
lora_vit.fc 3845
Tra



Fitting Laplace - Done
[Baseline] Acc.: 50.8%; ECE: 15.9%; NLL: 0.921
[Laplace] Acc.: 51.6%; ECE: 8.3%; NLL: 0.954
