In [15]:
import os
import tempfile
from typing import Dict, Optional, Any
from glob import glob

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader

from torchgeo.datasets import EuroSAT
from torchgeo.datamodules import NonGeoDataModule
from torchgeo.transforms import AugmentationSequential, indices
from torchgeo.trainers import ClassificationTask
from torchgeo.models import ResNet18_Weights, ResNet50_Weights

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

# Seed Python, NumPy and torch RNGs. workers=True additionally makes Lightning
# derive distinct, reproducible seeds for DataLoader worker processes
# (relevant because the GPU configuration below uses multiple workers).
SEED = 543
seed_everything(SEED, workers=True)

# Prefer the GPU whenever torch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

Global seed set to 543


cuda


Train on all 13 Sentinel-2 bands — experiment with model architectures, hyperparameters, and pretrained weights.

In [16]:
## Local copy of torchgeo's EuroSATDataModule with the per-band mean/std
## "nullified" (mean 0, std 1) so images are not rescaled by default.

class EuroSATDataModule(NonGeoDataModule):
    """LightningDataModule for the EuroSAT dataset with configurable normalization.

    Uses the train/val/test splits from the dataset. Differs from
    :class:`torchgeo.datamodules.EuroSATDataModule` only in its normalization
    statistics: by default ``mean`` is 0 and ``std`` is 1 for all 13 bands, so a
    ``(x - mean) / std`` normalization leaves pixel values unchanged.

    NOTE(review): torchgeo's Sentinel-2 MoCo weights document a
    ``Normalize(mean=0, std=10000)`` preprocessing; passing
    ``std=torch.full((13,), 10000.0)`` may match them better — TODO confirm
    against the installed torchgeo version.
    """

    # Class-level defaults: identity normalization for the 13 bands. The torchgeo
    # base class presumably reads ``self.mean``/``self.std`` in ``__init__`` to
    # build its normalization transform — TODO confirm for the installed version.
    mean = torch.zeros(13)
    std = torch.ones(13)

    def __init__(
        self,
        batch_size: int = 64,
        num_workers: int = 0,
        mean: Optional[Tensor] = None,
        std: Optional[Tensor] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize a new EuroSATDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            mean: Optional per-band mean; overrides the class default (zeros).
            std: Optional per-band std; overrides the class default (ones).
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.EuroSAT`.
        """
        # Assign before super().__init__() so the base class sees the override
        # when it sets up normalization. Instance attributes shadow the class
        # defaults for this datamodule only.
        if mean is not None:
            self.mean = mean
        if std is not None:
            self.std = std
        super().__init__(EuroSAT, batch_size, num_workers, **kwargs)

In [17]:
# Loader settings per device: large batches and parallel workers on the GPU,
# small batches with in-process loading on the CPU.
if device == "cuda":
    batch_size = 128 * 5
    num_workers = 8
elif device == "cpu":
    batch_size = 64
    num_workers = 0
else:
    # Fail loudly: falling through would leave batch_size/num_workers undefined,
    # or silently reuse stale values from an earlier run of this cell.
    raise ValueError(f"unknown device: {device!r}")

datamodule = EuroSATDataModule(
    batch_size=batch_size,
    root="data",
    num_workers=num_workers,
    # download=True,  # presumably needed on a first run if ./data lacks EuroSAT
)

## Experiment
Experiment with different models and pretrained weights (see the torchgeo pretrained-weights tutorial: https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).

In [18]:
# Backbone + pretrained weights. The weights enum must match `model`
# (e.g. ResNet18_Weights only with "resnet18") and `in_channels`. Alternatives:
#   weights=True                                 # standard ImageNet weights
#   weights=ResNet18_Weights.SENTINEL2_ALL_MOCO  # Sentinel-2, all bands (resnet18)
#   weights=ResNet18_Weights.SENTINEL2_RGB_MOCO  # Sentinel-2, RGB bands (resnet18)
model_name = "resnet50"
weights = ResNet50_Weights.SENTINEL2_ALL_MOCO

task = ClassificationTask(
    model=model_name,
    weights=weights,
    num_classes=10,
    in_channels=13,  # all Sentinel-2 bands, matching the *_ALL_* weights
    loss="ce",
    patience=6,  # learning-rate scheduler patience inside ClassificationTask
)

# Derive the W&B run name from the config so it cannot drift out of sync with
# the model/weights actually trained. Weight enums expose `.name`; the bool/None
# cases mirror the `weights=True` (ImageNet) option above.
weights_tag = getattr(weights, "name", "imagenet" if weights is True else "scratch")

wandb_logger = WandbLogger(
    project="eurosat",
    name=f"{model_name}_{weights_tag}",
    log_model="all",  # upload every saved checkpoint as a W&B artifact
    save_dir="wandb_logs",
)
# Alternative logger: TensorBoardLogger("tensorboard_logs", name="eurosat")

# NOTE(review): early-stopping patience equals the LR-scheduler patience above.
# If the scheduler only lowers the LR after `patience` bad epochs, training stops
# before the reduced LR is ever used; consider a larger early-stopping patience.
early_stopping_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=6)

# No val_loss-monitoring ModelCheckpoint is configured, so Lightning's default
# checkpointing applies and does not track the best validation loss.
trainer = Trainer(
    logger=wandb_logger,
    callbacks=[early_stopping_callback],
    min_epochs=5,
    max_epochs=10,
)

Downloading: "https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/resolve/main/resnet50_sentinel2_all_moco-df8b932e.pth" to /teamspace/studios/this_studio/.cache/torch/hub/checkpoints/resnet50_sentinel2_all_moco-df8b932e.pth
100%|██████████| 90.1M/90.1M [00:01<00:00, 51.9MB/s]


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [19]:
# Train on the train split with per-epoch validation; EarlyStopping may end the
# run before max_epochs.
trainer.fit(task, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | ResNet           | 23.6 M
---------------------------------------------------
23.6 M    Trainable params
0         Non-trainable params
23.6 M    Total params
94.240    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Evaluate on the test split. Passing the model without `ckpt_path` tests the
# weights currently in memory (those left after the final training epoch),
# not a best-val_loss checkpoint.
trainer.test(task, datamodule=datamodule)

In [None]:
# Close the W&B run started by the logger so its logs/artifacts are finalized
# and a subsequent experiment in this kernel gets its own run.
wandb_logger.experiment.finish()

In [None]:
# Display the W&B run object (e.g. to grab its URL).
# NOTE(review): in notebook order this runs after .finish(); confirm that
# WandbLogger.experiment returns the cached (finished) run here rather than
# starting a new one in the installed Lightning version.
wandb_logger.experiment