In [None]:
# Uncomment to install dependencies. Prefer `%pip` over `!pip` so packages are installed into
# this kernel's environment, and adjust the path: it is specific to this Lightning studio.
# !pip install -r /teamspace/studios/this_studio/eda-bids-hackathon-prep/sentinel2-modelling/requirements.txt

In [None]:
import os
import tempfile
from typing import Dict, Optional, Any
from glob import glob

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader

from torchgeo.datasets import EuroSAT
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.transforms import AugmentationSequential, indices
from torchgeo.trainers import ClassificationTask
from torchgeo.models import ResNet18_Weights, ResNet50_Weights

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

import lightning

# Record the Lightning version in the output for reproducibility.
print(lightning.__version__)

# Seed all RNGs once, up front. workers=True additionally gives every DataLoader
# worker process its own reproducible seed (num_workers > 0 is used on CUDA below).
SEED = 543
seed_everything(SEED, workers=True)

# `device` is only used below to choose batch size / worker count; it is not
# passed to the Trainer.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

Experiment with the Sentinel-2 bands used to train a classifier, and with different spectral indices.

In [None]:
# Every band name the EuroSAT dataset class knows about.
all_band_names = EuroSAT.all_band_names
all_band_names

In [None]:
# The subset of bands EuroSAT treats as RGB.
rgb_band_names = EuroSAT.rgb_bands
rgb_band_names

In [None]:
# Sentinel-2 band IDs to train on: B04/B03/B02 (red, green, blue) plus B8A (narrow NIR).
# The classifier's in_channels is derived from len(BANDS), so this tuple drives both
# the data and the model.
BANDS = ('B04', 'B03', 'B02', 'B8A')

# Fail fast on a typo'd band name here rather than deep inside the datamodule.
_unknown_bands = set(BANDS) - set(EuroSAT.all_band_names)
assert not _unknown_bands, f"Unknown EuroSAT band(s): {sorted(_unknown_bands)}"

Experiment with the band selection, model hyperparameters, pretrained weights, etc.

In [None]:
# DataLoader settings depend on the device chosen in the setup cell.
if device == "cuda":
    batch_size = 128 * 5  # vary for your GPU's memory
    num_workers = 8
else:
    # CPU (or any other device string): smaller batches, load in the main process.
    # Using `else` guarantees both names are always defined.
    batch_size = 64
    num_workers = 0

# Never request more worker processes than the machine has CPUs;
# os.cpu_count() may return None, hence the fallback to 1.
num_workers = min(num_workers, os.cpu_count() or 1)

# EuroSAT restricted to BANDS; downloaded into ./data on first use.
datamodule = EuroSATDataModule(
    batch_size=batch_size,
    root="data",
    num_workers=num_workers,
    bands=BANDS,
    download=True,
)

## Experiment
Try different model architectures and pretrained weights; see the [torchgeo pretrained weights tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).

In [None]:
# ResNet-18 classifier over len(BANDS) input channels and the 10 EuroSAT classes.
# Weight options:
#   weights=True                          -> standard ImageNet weights
#   ResNet18_Weights.SENTINEL2_ALL_MOCO   -> pretrained on all Sentinel-2 bands
#   ResNet18_Weights.SENTINEL2_RGB_MOCO   -> pretrained on Sentinel-2 RGB bands
# NOTE(review): the Sentinel-2 weights were trained on a fixed band set, so BANDS
# should match it (check `weights.meta["in_chans"]`) — confirm before switching.
task = ClassificationTask(
    model="resnet18",
    weights=True,  # standard ImageNet
    # weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,  # or try sentinel 2 all bands
    # weights=ResNet18_Weights.SENTINEL2_RGB_MOCO,  # or try sentinel 2 rgb bands
    num_classes=10,
    in_channels=len(BANDS),
    loss="ce",
    patience=6,  # presumably the LR-scheduler patience — confirm in torchgeo docs
)

# tb_logger = TensorBoardLogger("tensorboard_logs", name="eurosat")  # if you prefer tensorboard
wandb_logger = WandbLogger(
    project="eurosat",
    name="resnet18_imagenet",  # update when changing model/weights
    log_model=True,  # or 'all' to upload every checkpoint during training
    save_dir="wandb_logs",
)

# min_epochs only matters if something can end training early; without a stopping
# callback every run simply goes to max_epochs. Stop when validation loss plateaus,
# giving it longer than the task's `patience` to recover first.
# assumes ClassificationTask logs "val_loss" — confirm for your torchgeo version.
EARLY_STOPPING_PATIENCE = 10
early_stopping = EarlyStopping(
    monitor="val_loss", mode="min", patience=EARLY_STOPPING_PATIENCE
)
# Track the checkpoint with the lowest validation loss.
best_checkpoint = ModelCheckpoint(monitor="val_loss", mode="min")

trainer = Trainer(
    logger=wandb_logger,
    callbacks=[early_stopping, best_checkpoint],
    min_epochs=5,
    max_epochs=25,
    enable_model_summary=False,  # https://github.com/Lightning-AI/lightning/issues/12233
)

In [None]:
# Train the classifier on the EuroSAT datamodule; metrics go to the configured logger.
trainer.fit(task, datamodule=datamodule)

In [None]:
# Evaluate on the datamodule's test split using the best checkpoint tracked by the
# Trainer during fit, not whatever weights the model holds after the final epoch.
# (With an unmonitored ModelCheckpoint, "best" is simply the most recent checkpoint.)
trainer.test(model=task, datamodule=datamodule, ckpt_path="best")

In [None]:
# Close the W&B run explicitly so the next experiment in this kernel starts a new run.
wandb_run = wandb_logger.experiment
wandb_run.finish()