In [None]:
# Copyright (c) 2023 Allen Institute for AI

# Demo using the torchgeo package to initialize a SatlasPretrain model and finetune
# on the UCMerced dataset.
#
# SETUP - this demo requires a DIFFERENT conda environment than the SatlasPretrain demo
# conda create --name torchgeodemo python=3.12
# conda activate torchgeodemo
# NOTE: Satlas weights will be a part of the 0.6.0 release and the current version is 0.5.1, so install from git for now.
# pip install git+https://github.com/microsoft/torchgeo 

In [1]:
import os
import torch
import tempfile
from typing import Optional
from lightning.pytorch import Trainer

from torchgeo.models import Swin_V2_B_Weights, swin_v2_b
from torchgeo.datamodules import UCMercedDataModule
from torchgeo.trainers import ClassificationTask

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment Arguments
# Tunable settings for the demo; each value is forwarded unchanged to the
# datamodule or Trainer constructed in the cells below.
batch_size = 8  # passed to UCMercedDataModule(batch_size=...)
num_workers = 2  # passed to UCMercedDataModule(num_workers=...)
max_epochs = 10  # passed to Trainer(max_epochs=...)
fast_dev_run = False  # passed to Trainer(fast_dev_run=...)

In [3]:
# Build the torchgeo Lightning datamodule for UC Merced. The dataset root is a
# "ucm" folder under the system temp directory, and download=True asks torchgeo
# to fetch the data into that folder.
root = os.path.join(tempfile.gettempdir(), "ucm")
datamodule = UCMercedDataModule(
    root=root,
    batch_size=batch_size,
    num_workers=num_workers,
    download=True,
)

In [11]:
# Custom ClassificationTask to load in the SatlasPretrain model
class SatlasClassificationTask(ClassificationTask):
    """ClassificationTask that uses the SatlasPretrain Swin-v2-B backbone.

    ``configure_models`` is overridden so that ``self.model`` is a ``swin_v2_b``
    network initialised from the ``SENTINEL2_RGB_SI_SATLAS`` weights, with its
    classification head resized to ``hparams["num_classes"]``. Validation loss
    and accuracy are printed at the end of every validation epoch.
    """

    def configure_models(self):
        """Create the pretrained backbone and adapt its first and last layers."""
        weights = Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS
        self.model = swin_v2_b(weights)

        # Replace the first layer only when the task's number of input channels
        # differs from the pretrained layer. Swapping it unconditionally would
        # throw away the pretrained patch-embedding weights for plain RGB input.
        first_layer = self.model.features[0][0]
        in_channels = self.hparams.get("in_channels") or first_layer.in_channels
        if first_layer.in_channels != in_channels:
            self.model.features[0][0] = torch.nn.Conv2d(
                in_channels,
                first_layer.out_channels,
                kernel_size=first_layer.kernel_size,
                stride=first_layer.stride,
                padding=first_layer.padding,
                bias=(first_layer.bias is not None),
            )

        # Replace the last layer's output features with the number of classes.
        # The input width is read from the existing head rather than hard-coded.
        self.model.head = torch.nn.Linear(
            in_features=self.model.head.in_features,
            out_features=self.hparams["num_classes"],
            bias=True,
        )

    def on_validation_epoch_end(self):
        """Print the loss and accuracy logged during the current validation epoch."""
        # Keep whatever the parent hook does (a no-op in plain Lightning).
        super().on_validation_epoch_end()

        # 'N/A' is shown when a metric has not been logged yet.
        val_loss = self.trainer.callback_metrics.get('val_loss', 'N/A')
        val_acc = self.trainer.callback_metrics.get('val_OverallAccuracy', 'N/A')
        print(f"Epoch {self.current_epoch} Validation - Loss: {val_loss}, Accuracy: {val_acc}")


In [12]:
# Initialize the classification task. num_classes sets the output size of the
# replacement head built in SatlasClassificationTask.configure_models
# (21 presumably matches the UC Merced label set -- confirm against the dataset docs).
task = SatlasClassificationTask(num_classes=21)

In [13]:
# Set up the Lightning Trainer: run on the GPU when CUDA is available, otherwise
# on the CPU, and write run artifacts to an "experiments" folder under the
# system temp directory.
default_root_dir = os.path.join(tempfile.gettempdir(), "experiments")
if torch.cuda.is_available():
    accelerator = "gpu"
else:
    accelerator = "cpu"

trainer = Trainer(
    accelerator=accelerator,
    default_root_dir=default_root_dir,
    fast_dev_run=fast_dev_run,
    log_every_n_steps=1,
    min_epochs=1,
    max_epochs=max_epochs,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
# Train: fit the task on the batches supplied by the UC Merced datamodule,
# using the Trainer configured in the previous cell.
trainer.fit(model=task, datamodule=datamodule)


  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | SwinTransformer  | 86.9 M
---------------------------------------------------
86.9 M    Trainable params
0         Non-trainable params
86.9 M    Total params
347.709   Total estimated model params size (MB)


Sanity Checking DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  0.60it/s]Epoch 0 Validation - Loss: 3.1346988677978516, Accuracy: 0.0
                                                                                                                                                                                                               

/Users/piperw/opt/anaconda3/envs/torchgeotest/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0:  78%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                | 124/158 [10:06<02:46,  0.20it/s, v_num=1]

/Users/piperw/opt/anaconda3/envs/torchgeotest/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
