# MC-Dropout with EuroSAT100 dataset

Since all implemented UQ methods are Lightning modules that can wrap any underlying model, it is straightforward to apply them to a wide range of use cases. Let's do that for a classification task on the EuroSAT100 dataset, a dataset of Sentinel-2 imagery with 13 spectral bands and 10 classes. We will use the [TorchGeo library](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#eurosat) to handle the data loading for us.

## Imports

In [55]:
import os
import tempfile
from functools import partial

import timm
import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import CSVLogger
from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.models import ResNet18_Weights
from torchgeo.trainers import ClassificationTask

from lightning_uq_box.uq_methods import MCDropoutClassification
from lightning_uq_box.viz_utils import plot_training_metrics

%load_ext autoreload
%autoreload 2



In [None]:
seed_everything(0)

In [2]:
batch_size = 16
num_workers = 4

## Datamodule

In [3]:
root = os.path.join(tempfile.gettempdir(), "eurosat100")
datamodule = EuroSAT100DataModule(
    root=root, batch_size=batch_size, num_workers=num_workers, download=True
)

## Classification Model

We will use pretrained weights from TorchGeo, which can be loaded into a [timm](https://github.com/huggingface/pytorch-image-models) model architecture.

In [25]:
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
in_chans = weights.meta["in_chans"]
model = timm.create_model("resnet18", in_chans=in_chans, num_classes=10, drop_rate=0.2)
model.load_state_dict(weights.get_state_dict(progress=True), strict=False)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])
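The `_IncompatibleKeys` output above reports that the checkpoint has no entries for `fc.weight` and `fc.bias`, so the classification head keeps its random initialization and is learned during training. The following minimal sketch reproduces that behavior of `load_state_dict(..., strict=False)` on a toy network. `TinyNet` and the filtered `pretrained` dictionary are hypothetical stand-ins, not the TorchGeo weights.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone followed by a classification head."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc = nn.Linear(4, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.relu(self.backbone(x)))


# Mimic a pretrained checkpoint that ships backbone weights but no head.
pretrained = {
    name: tensor
    for name, tensor in TinyNet(num_classes=3).state_dict().items()
    if name.startswith("backbone.")
}

net = TinyNet(num_classes=10)
# strict=False tolerates keys that are absent from the checkpoint and reports them.
result = net.load_state_dict(pretrained, strict=False)

print(result.missing_keys)     # ['fc.weight', 'fc.bias']
print(result.unexpected_keys)  # []
```

With `strict=True` (the PyTorch default) the same call would raise a `RuntimeError` because of the missing head parameters.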

In [26]:
mc_dropout_model = MCDropoutClassification(
    model=model,
    optimizer=partial(torch.optim.Adam, lr=1e-3),
    loss_fn=torch.nn.CrossEntropyLoss(),
    num_mc_samples=10,
)
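To make the role of `num_mc_samples` concrete, here is a minimal sketch of the general MC-Dropout idea: keep dropout active at prediction time, run several stochastic forward passes, average the softmax outputs, and use the entropy of the averaged distribution as an uncertainty score. The toy model and the helper `mc_dropout_predict` are hypothetical and written for illustration only; the sketch does not claim to mirror the internals of `MCDropoutClassification`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy classifier with a dropout layer (13 input features, 10 classes).
toy_model = nn.Sequential(
    nn.Linear(13, 32), nn.ReLU(), nn.Dropout(p=0.2), nn.Linear(32, 10)
)


def mc_dropout_predict(model: nn.Module, x: torch.Tensor, num_mc_samples: int = 10):
    """Average softmax outputs over several forward passes with dropout enabled."""
    model.eval()
    # Re-enable only the dropout layers so every pass samples a different mask.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    with torch.no_grad():
        probs = torch.stack(
            [torch.softmax(model(x), dim=-1) for _ in range(num_mc_samples)], dim=0
        )
    mean_probs = probs.mean(dim=0)
    # Entropy of the averaged distribution; clamp avoids log(0).
    entropy = -(mean_probs * torch.log(mean_probs.clamp_min(1e-12))).sum(dim=-1)
    return mean_probs, entropy


x = torch.randn(16, 13)
mean_probs, entropy = mc_dropout_predict(toy_model, x, num_mc_samples=10)
print(mean_probs.shape, entropy.shape)  # torch.Size([16, 10]) torch.Size([16])
```

More samples give a smoother estimate of the averaged distribution at the cost of proportionally more forward passes.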

Datasets follow different conventions for naming inputs and targets, especially across domains and tasks. By default, lightning-uq-box expects each batch to be a dictionary with the keys "inputs" and "targets", but both keys can be changed through attributes on the model. The EuroSAT100 dataset uses the keys "image" and "label", so we set them as follows:

In [33]:
mc_dropout_model.input_key = "image"
mc_dropout_model.target_key = "label"
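The effect of these two attributes can be pictured with a small sketch in plain Python. The helper `unpack_batch` is hypothetical and only illustrates the idea of looking up inputs and targets in a batch dictionary by configurable key; it is not part of lightning-uq-box.

```python
from typing import Any


def unpack_batch(
    batch: dict[str, Any], input_key: str = "inputs", target_key: str = "targets"
) -> tuple[Any, Any]:
    """Hypothetical helper: fetch inputs and targets from a batch dictionary."""
    missing = [key for key in (input_key, target_key) if key not in batch]
    if missing:
        raise KeyError(
            f"Batch is missing keys {missing}; available keys: {sorted(batch)}"
        )
    return batch[input_key], batch[target_key]


# A batch shaped like the EuroSAT100 one, with placeholder values.
eurosat_like_batch = {"image": [[0.1, 0.2]], "label": [3]}

x, y = unpack_batch(eurosat_like_batch, input_key="image", target_key="label")
print(y)  # [3]
```

Calling the helper with its default keys on this batch raises a `KeyError`, which is the kind of mismatch that setting `input_key` and `target_key` avoids.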

## Training

This is just standard training with Lightning, where you can configure the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html) to your needs.

In [53]:
trainer = Trainer(
    max_epochs=20,
    logger=CSVLogger(root, name="mc_dropout_eurosat"),
    log_every_n_steps=5,
    devices=[0],
    accelerator="gpu",
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [54]:
trainer.fit(mc_dropout_model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name          | Type             | Params
---------------------------------------------------
0 | model         | ResNet           | 11.2 M
1 | loss_fn       | CrossEntropyLoss | 0     
2 | train_metrics | MetricCollection | 0     
3 | val_metrics   | MetricCollection | 0     
4 | test_metrics  | MetricCollection | 0     
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.852    Total estimated model params size (MB)


                                                                            



Epoch 19: 100%|██████████| 4/4 [00:00<00:00,  5.13it/s, v_num=4]   

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 4/4 [00:01<00:00,  3.50it/s, v_num=4]


## Predictions

In [60]:
trainer.test(mc_dropout_model, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing DataLoader 0: 100%|██████████| 2/2 [00:00<00:00, 35.83it/s]


[{'testAcc': 0.5, 'testF1Score': 0.5}]

In [68]:
sample = next(iter(datamodule.test_dataloader()))
preds = mc_dropout_model.predict_step(sample["image"])

In [72]:
preds["pred"].shape, preds["pred_uct"].shape

(torch.Size([16, 10]), torch.Size([16]))
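The shapes above show one score per class and one uncertainty value per sample in the batch of 16. As a minimal sketch of how such outputs might be used, the snippet below derives predicted classes and ranks samples by uncertainty. It builds a synthetic `preds` dictionary with the same keys and shapes, under the assumption that `"pred"` holds class probabilities and `"pred_uct"` a scalar uncertainty such as predictive entropy; the values are random and not results from the trained model.

```python
import torch

torch.manual_seed(0)

# Synthetic stand-in with the same keys and shapes as the output above.
mean_probs = torch.softmax(torch.randn(16, 10), dim=-1)
preds = {
    "pred": mean_probs,
    "pred_uct": -(mean_probs * mean_probs.log()).sum(dim=-1),
}

# Predicted class per sample: index of the highest score.
predicted_class = preds["pred"].argmax(dim=-1)

# Indices of the three samples the model is least certain about.
most_uncertain = torch.argsort(preds["pred_uct"], descending=True)[:3]

print(predicted_class.shape, most_uncertain.shape)  # torch.Size([16]) torch.Size([3])
```

Ranking by uncertainty is a common way to pick samples for manual inspection, for example to check whether the least certain predictions are also the misclassified ones.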