# Post-hoc Laplace Approximation with the EuroSAT100 Dataset

Since all implemented UQ methods are Lightning modules that can wrap any underlying model, it is straightforward to apply them to a wide range of use cases. Let's do that for a classification task on EuroSAT100, a 100-image subset of the EuroSAT dataset of Sentinel-2 imagery with 13 spectral bands and 10 land use and land cover classes. We will use the [TorchGeo library](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) to handle data loading. Here we use the Laplace approximation as a post-hoc method: it equips an already trained network with a notion of uncertainty without retraining it.

## Imports

In [18]:
import os
import tempfile
from functools import partial

import timm
import torch
from laplace import Laplace
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import CSVLogger
from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask

from lightning_uq_box.uq_methods import LaplaceClassification
from lightning_uq_box.viz_utils import plot_training_metrics

%load_ext autoreload
%autoreload 2


In [2]:
seed_everything(0)

Global seed set to 0


0

In [3]:
batch_size = 16
num_workers = 4

## Datamodule

In [4]:
root = os.path.join(tempfile.gettempdir(), "eurosat100")
datamodule = EuroSAT100DataModule(
    root=root, batch_size=batch_size, num_workers=num_workers, download=True
)

## Classification Model and Training

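We will use pretrained weights from TorchGeo, here self-supervised MoCo weights for a ResNet-18 trained on all 13 Sentinel-2 bands. `ClassificationTask` loads them into a [timm](https://github.com/huggingface/pytorch-image-models) ResNet-18 architecture with 13 input channels and 10 output classes.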

In [33]:
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
task = ClassificationTask(
    model="resnet18",
    loss="ce",
    weights=weights,
    in_channels=13,
    num_classes=10,
    learning_rate=0.001,
    learning_rate_schedule_patience=5,
)

In [34]:
trainer = Trainer(
    max_epochs=1,
    logger=CSVLogger(root),
    log_every_n_steps=1,
    devices=[0],
    accelerator="gpu",
)

trainer.fit(task, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ResNet           | 11.2 M
1 | loss          | CrossEntropyLoss | 0     
2 | train_metrics | MetricCollection | 0     
3 | val_metrics   | MetricCollection | 0     
4 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.852    Total estimated model params size (MB)


`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 4/4 [00:01<00:00,  3.70it/s, v_num=2]


Next, we set up a Laplace approximation with the [Laplace](https://github.com/AlexImmer/Laplace) package. We use a last-layer approximation, which places a Gaussian posterior only over the weights of the final linear layer and keeps the rest of the network fixed at its trained values; see the package [documentation](https://aleximmer.github.io/Laplace/) for the other possible configurations. lightning-uq-box simply provides a wrapper, so you can stay within the Lightning modelling framework.

In [35]:
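# Last-layer Laplace approximation around the trained weights of task.model
la = Laplace(task.model, likelihood="classification", subset_of_weights="last_layer")

# Wrap the Laplace object in a LightningModule provided by lightning-uq-box
laplace_model = LaplaceClassification(model=la)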
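To make the last-layer idea concrete, here is a minimal NumPy sketch under simplifying assumptions: the network body is frozen and summarized by a fixed feature matrix `phi`, the last layer is a bias-free linear map with trained weights `w_map`, the prior is an isotropic Gaussian with precision `prior_prec`, and the curvature is the full generalized Gauss-Newton (GGN) matrix of the cross-entropy loss. All function and variable names are hypothetical, and this is not the Laplace package's implementation, which by default uses a Kronecker-factored Hessian approximation rather than the full matrix.

```python
import numpy as np


def softmax(logits):
    # Numerically stable softmax over the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def last_layer_laplace(phi, w_map, prior_prec=1.0):
    """Gaussian posterior N(mean, cov) over the flattened last-layer weights.

    phi:   (N, D) features from the frozen network body (assumed given).
    w_map: (D, C) trained (MAP) weights of a bias-free linear last layer.
    """
    d, c = w_map.shape
    probs = softmax(phi @ w_map)  # (N, C) predicted class probabilities
    # Posterior precision = prior precision + GGN of the cross-entropy loss.
    precision = prior_prec * np.eye(d * c)
    for x, p in zip(phi, probs):
        lam = np.diag(p) - np.outer(p, p)  # Hessian of CE w.r.t. the logits
        # J^T lam J with J = x^T kron I_C, matching row-major flattening of w_map
        precision += np.kron(np.outer(x, x), lam)
    cov = np.linalg.inv(precision)
    cov = 0.5 * (cov + cov.T)  # symmetrize away round-off error
    return w_map.reshape(-1), cov


def mc_predictive(phi_test, mean, cov, n_samples=500, seed=0):
    """Monte Carlo estimate of the predictive class probabilities."""
    rng = np.random.default_rng(seed)
    d = phi_test.shape[1]
    chol = np.linalg.cholesky(cov)
    eps = rng.standard_normal((n_samples, mean.size))
    w_samples = (mean + eps @ chol.T).reshape(n_samples, d, -1)  # (S, D, C)
    logits = np.einsum("nd,sdc->snc", phi_test, w_samples)  # (S, N, C)
    return softmax(logits).mean(axis=0)  # average probabilities over samples


# Toy usage with random stand-ins for features and trained weights.
rng = np.random.default_rng(0)
phi = rng.normal(size=(60, 4))   # e.g. 60 training images, 4 features each
w_map = rng.normal(size=(4, 3))  # 3 classes
mean, cov = last_layer_laplace(phi, w_map, prior_prec=1.0)
probs = mc_predictive(phi[:5], mean, cov)
print(cov.shape, probs.shape)  # (12, 12) (5, 3)
```

The posterior precision accumulates one curvature term per training example, which is why fitting the approximation needs a pass over the training data. Predictions then average the softmax outputs over sampled last-layer weights; the Laplace package also offers cheaper deterministic alternatives such as the probit approximation.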

Different datasets use different names for inputs and targets, especially across domains and tasks. By default, lightning-uq-box expects each batch to be a dictionary with the keys "inputs" and "targets", but these key names can be changed as shown below. The EuroSAT100 dataset uses the keys "image" and "label".

In [37]:
laplace_model.input_key = "image"
laplace_model.target_key = "label"
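The key lookup can be pictured with a tiny stand-in: a module stores the names of the input and target keys as attributes and reads them from each dictionary batch. The `KeyedBatchReader` class below is hypothetical and only illustrates the idea; it is not lightning-uq-box's source. The default key names "inputs" and "targets" are taken from the text above.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class KeyedBatchReader:
    """Hypothetical stand-in for a UQ module's configurable batch keys."""

    input_key: str = "inputs"
    target_key: str = "targets"

    def unpack(self, batch: Dict[str, Any]) -> Tuple[Any, Any]:
        # Look up inputs and targets under whatever names this dataset uses.
        return batch[self.input_key], batch[self.target_key]


reader = KeyedBatchReader()
reader.input_key = "image"  # EuroSAT100 naming convention
reader.target_key = "label"
x, y = reader.unpack({"image": [[0.1, 0.2]], "label": [3]})
print(x, y)  # [[0.1, 0.2]] [3]
```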

## Fit Laplace and Make Predictions

In [38]:
trainer = Trainer(
    logger=CSVLogger(root), log_every_n_steps=1, devices=[0], accelerator="gpu"
)
datamodule.setup("fit")  # give access to training dataloader used to fit laplace
trainer.test(laplace_model, datamodule=datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

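RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

The error means that the test batches were moved to the GPU while the network weights stayed on the CPU. A likely cause is that the `Laplace` object is not a `torch.nn.Module`, so when Lightning moves `laplace_model` to the GPU it does not move the wrapped `task.model`, which Lightning had already returned to the CPU at the end of `trainer.fit`. Running this cell with `accelerator="cpu", devices=1` should avoid the mismatch. The cells below sidestep the trainer instead and call `predict_step` directly on a CPU batch.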

In [68]:
sample = next(iter(datamodule.test_dataloader()))
preds = laplace_model.predict_step(sample["image"])

In [72]:
preds["pred"].shape, preds["pred_uct"].shape

(torch.Size([16, 10]), torch.Size([16]))
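Judging by the shapes, `preds["pred"]` holds 10 class values for each of the 16 images and `preds["pred_uct"]` holds one scalar uncertainty per image. A common way to reduce a predictive distribution to a single score is its entropy. The NumPy sketch below computes it, assuming the rows are class probabilities; the helper name `predictive_entropy` is hypothetical, and lightning-uq-box may define `pred_uct` differently.

```python
import numpy as np


def predictive_entropy(probs, eps=1e-12):
    """Entropy (in nats) of each row of an (N, C) probability matrix."""
    probs = np.clip(probs, eps, 1.0)  # avoid log(0)
    return -(probs * np.log(probs)).sum(axis=-1)


probs = np.array([
    [1.0, 0.0, 0.0],        # confident prediction
    [1 / 3, 1 / 3, 1 / 3],  # maximally uncertain over 3 classes
])
h = predictive_entropy(probs)
print(h)  # first value ~0, second ~log(3) = 1.0986...
```

For the 10 EuroSAT classes, this score ranges from 0 for a one-hot prediction to log(10) ≈ 2.30 nats for a uniform one.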