# RAPS with EuroSAT

In this notebook, we will demonstrate how to apply a post-hoc conformal method, Regularized Adaptive Prediction Sets (RAPS) [Angelopoulos et al. 2021](https://arxiv.org/abs/2009.14193), to an Earth Observation (EO) classification task, namely the popular EuroSAT dataset. For data loading we will use the [TorchGeo library](https://torchgeo.readthedocs.io/en/stable/), which you will need to install to run this tutorial (`pip install torchgeo`). For demonstration purposes we use the smaller `EuroSAT100` version of the dataset. Additionally, we will show how to take a model pretrained specifically on EO data and apply RAPS to it for improved uncertainty quantification (UQ).

## Imports

In [24]:
import tempfile

import timm
import torch
from lightning import Trainer
from lightning_uq_box.uq_methods import DeterministicClassification, RAPS
from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.models import ResNet18_Weights

## Datamodule

In [8]:
datamodule = EuroSAT100DataModule(root=".", num_workers=4, download=True)
# call setup manually so that datamodule.val_dataloader() can be used outside a Trainer
datamodule.setup("fit")

## Pretrained Model

We will use pretrained weights for Sentinel-2 data from the SSL4EO paper [Wang et al. 2022](https://arxiv.org/abs/2211.07044) that are accessible through TorchGeo.

In [26]:
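weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
# number of input bands the weights were pretrained on (all Sentinel-2 bands)
in_chans = weights.meta["in_chans"]
resnet18 = timm.create_model("resnet18", in_chans=in_chans, num_classes=10)
# strict=False: the pretrained weights only cover the backbone, so the new
# 10-class head (fc) keeps its random initialization
resnet18.load_state_dict(weights.get_state_dict(progress=True), strict=False)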

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])

## Predictions with Pretrained Model

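Let's first look at the predictions from the pretrained model so that we can later see the impact of RAPS. Note that the pretrained weights do not include a classification head: as the missing keys `fc.weight` and `fc.bias` above show, the 10-class head is randomly initialized, and since we do not fine-tune here, we should expect roughly chance-level performance. We wrap the network in the Lightning-based class `DeterministicClassification`; calling `Trainer.validate` then iterates over the validation dataloader and logs the loss, accuracy and F1 score.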

In [28]:
base_dir = tempfile.mkdtemp()
# TODO: implement a torchmetrics empirical coverage metric and use it, together with accuracy, to compare results
base_model = DeterministicClassification(resnet18, loss_fn=torch.nn.CrossEntropyLoss())
# TorchGeo batches are dictionaries, so tell the module which keys hold inputs and targets
base_model.input_key = "image"
base_model.target_key = "label"

base_trainer = Trainer(accelerator="cpu", default_root_dir=base_dir)

base_trainer.validate(base_model, datamodule=datamodule)


[{'val_loss': 2.304800510406494,
  'valAcc': 0.10000000149011612,
  'valF1Score': 0.10000000149011612}]
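These numbers are consistent with the untrained head: for 10 classes, near-uniform predicted probabilities give a cross-entropy close to ln(10), and random guessing gives an expected accuracy of 1/10. A quick check in plain Python, using the values reported above:

```python
import math

num_classes = 10
# cross-entropy of a uniform prediction over all classes: -log(1/K) = log(K)
uniform_ce = -math.log(1.0 / num_classes)
# expected accuracy when guessing uniformly at random
chance_acc = 1.0 / num_classes

# values reported by base_trainer.validate above
val_loss, val_acc = 2.304800510406494, 0.10000000149011612

print(f"uniform cross-entropy: {uniform_ce:.4f}, observed val_loss: {val_loss:.4f}")
print(f"chance accuracy: {chance_acc:.2f}, observed valAcc: {val_acc:.2f}")
```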

## Apply RAPS
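RAPS turns a classifier's probabilities into a set of candidate classes. Following Angelopoulos et al., each class is scored by the probability mass of all classes ranked at or above it, plus a penalty `lambda * max(0, rank - k_reg)` that discourages large sets; a class enters the prediction set if its score does not exceed a threshold calibrated on held-out data. Below is a minimal pure-Python sketch of the non-randomized score and set construction under these definitions. The function names are hypothetical, and this is not the lightning-uq-box implementation, which may additionally temperature-scale the logits and randomize the score.

```python
def raps_scores(probs, lam, k_reg):
    """Non-randomized RAPS score for every class of one sample.

    probs: list of class probabilities (e.g. softmax outputs) summing to 1.
    Returns the score of each class, in the original class order.
    """
    # class indices sorted by decreasing probability
    order = sorted(range(len(probs)), key=lambda c: probs[c], reverse=True)
    scores = [0.0] * len(probs)
    cumulative = 0.0
    for rank, c in enumerate(order, start=1):  # rank o(y) is 1-based
        cumulative += probs[c]                 # mass of classes ranked at or above c
        penalty = lam * max(0, rank - k_reg)   # regularization on large ranks
        scores[c] = cumulative + penalty
    return scores


def raps_prediction_set(probs, tau, lam, k_reg):
    """Classes whose RAPS score does not exceed the calibrated threshold tau.

    With this simple rule the set can be empty if tau is below the top-class score.
    """
    scores = raps_scores(probs, lam, k_reg)
    return sorted(c for c, s in enumerate(scores) if s <= tau)


# toy 4-class example with hypothetical probabilities and threshold
probs = [0.1, 0.6, 0.25, 0.05]
print(raps_prediction_set(probs, tau=0.9, lam=0.0, k_reg=2))  # no penalty
print(raps_prediction_set(probs, tau=0.9, lam=0.5, k_reg=1))  # penalized, smaller set
```

With `lam=0.0` the penalty vanishes and the score reduces to the cumulative probability mass, i.e. the Adaptive Prediction Sets (APS) score that RAPS builds on. This corresponds to the `lamda_param=0` setting used below, under which the value of `kreg` has no effect.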

In [29]:
raps_dir = tempfile.mkdtemp()
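# note: with lamda_param=0 the RAPS size penalty is zero (equivalent to APS),
# so kreg has no effect; choose lamda_param > 0 for the regularized variant
raps = RAPS(model=base_model.model, kreg=7, lamda_param=0)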
raps.input_key = "image"
raps.target_key = "label"

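# inference_mode=False keeps gradients available during calibration
# (presumably needed to fit the temperature-scaling parameter RAPS applies to the logits)
raps_trainer = Trainer(accelerator="cpu", default_root_dir=raps_dir, inference_mode=False)

# the validation pass calibrates RAPS on the validation split
raps_trainer.validate(raps, dataloaders=datamodule.val_dataloader())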


[{}]
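The empty result `[{}]` shows that no metrics are logged in this pass; here the validation pass is used to calibrate RAPS. In split conformal prediction, calibration computes the score of the true label for every calibration sample and takes a finite-sample-corrected quantile of those scores as the threshold. A minimal pure-Python sketch of that quantile step, assuming the scores are already computed; the function name, `alpha` and the score values are hypothetical, and this is not the lightning-uq-box implementation:

```python
import math


def conformal_quantile(cal_scores, alpha):
    """Split-conformal threshold: the ceil((n + 1) * (1 - alpha))-th smallest score.

    Prediction sets {y : score(x, y) <= qhat} then contain the true label with
    probability at least 1 - alpha, assuming calibration and test data are exchangeable.
    """
    n = len(cal_scores)
    k = math.ceil((n + 1) * (1 - alpha))
    if k > n:
        # too few calibration points for this alpha: the threshold is unbounded
        return math.inf
    return sorted(cal_scores)[k - 1]


# hypothetical true-label scores for 20 calibration images
cal_scores = [0.12, 0.35, 0.41, 0.50, 0.55, 0.58, 0.61, 0.66, 0.70, 0.72,
              0.75, 0.78, 0.81, 0.84, 0.86, 0.89, 0.91, 0.94, 0.97, 0.99]
qhat = conformal_quantile(cal_scores, alpha=0.1)
print(qhat)
```

The `(n + 1)` correction is what makes the coverage guarantee hold for a finite calibration set; with only 20 calibration images, very small `alpha` values cannot be supported and the threshold becomes infinite (every class is included).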

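## Example Visualization

We now predict on a single batch. Keep in mind that this batch comes from the same validation split that RAPS was just calibrated on; for an honest estimate of coverage, predictions should be made on held-out data such as the test split.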

In [31]:
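raps.eval()  # put BatchNorm layers into inference mode for manual prediction
predict_batch = next(iter(datamodule.val_dataloader()))
# TorchGeo normalizes batches in on_after_batch_transfer, which is skipped when
# iterating manually; apply it here (assumes the transform is datamodule.aug)
predict_batch = datamodule.aug(predict_batch)
preds = raps.predict_step(predict_batch["image"])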

In [33]:
preds["pred"].shape

torch.Size([20, 10])
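To follow up on the comment in the baseline cell about an empirical coverage metric: the two quantities usually reported for conformal classifiers are empirical coverage (the fraction of samples whose true label lies in the prediction set) and the average set size. A minimal pure-Python sketch, with prediction sets represented as lists of class indices; this is not a torchmetrics metric, the function names are hypothetical, and the example sets are made up. Converting the RAPS output to this form depends on the output format of `predict_step`, which is not shown here.

```python
def empirical_coverage(pred_sets, labels):
    """Fraction of samples whose true label is contained in its prediction set."""
    hits = sum(1 for s, y in zip(pred_sets, labels) if y in s)
    return hits / len(labels)


def average_set_size(pred_sets):
    """Mean number of classes per prediction set (smaller is more informative)."""
    return sum(len(s) for s in pred_sets) / len(pred_sets)


# hypothetical prediction sets and true labels for five images
pred_sets = [[1], [1, 2], [0, 3, 7], [4], [5, 6]]
labels = [1, 2, 3, 9, 6]
print(empirical_coverage(pred_sets, labels))  # 4 of 5 labels are covered
print(average_set_size(pred_sets))
```

Reporting both numbers matters: coverage alone can be made trivially high by returning all 10 classes, so the set size shows how informative the sets actually are.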