# Chip Classification using EuroSAT - Training

This notebook demonstrates training a chip classifier on [EuroSAT](https://github.com/phelber/EuroSAT), a Sentinel-2 dataset. Experiment with the choice of model, hyperparameters and pretrained weights to achieve the best performance you can. Using the [wandb logger](https://wandb.ai/) requires only a free account.

## Environment Setup 

Refer to README.md for environment setup. 

### Import and Init Env

In [None]:
import os

# If using LightningAI, change the current working directory to the directory containing this notebook. 
REPO_DIR = "/teamspace/studios/this_studio/eda-bids-hackathon-prep/"  # Adjust as appropriate
if os.path.exists(REPO_DIR):
    os.chdir(os.path.join(REPO_DIR, "sentinel2-modelling"))

# If you encounter a warning that GDAL is missing GDAL_DATA, set GDAL_DATA and PROJ_LIB
# from the active conda environment (under Library\share on Windows, share/ elsewhere).
conda_prefix = os.environ.get("CONDA_PREFIX")
if conda_prefix is not None:
    if os.name == "nt":
        share_dir = os.path.join(conda_prefix, "Library", "share")
    else:
        share_dir = os.path.join(conda_prefix, "share")
    os.environ.setdefault("GDAL_DATA", os.path.join(share_dir, "gdal"))
    os.environ.setdefault("PROJ_LIB", os.path.join(share_dir, "proj"))

In [None]:
import os
import tempfile
from typing import Dict, Optional, Any
from glob import glob

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader

from torchgeo.datasets import EuroSAT
from torchgeo.datamodules import EuroSATDataModule
from torchgeo.transforms import AugmentationSequential, indices
from torchgeo.trainers import ClassificationTask
from torchgeo.models import ResNet18_Weights, ResNet50_Weights

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

import lightning
print(lightning.__version__)

seed_everything(543)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Load EDS credentials from .env file
from dotenv import load_dotenv
load_dotenv()

Experiment with the bands used to train a classifier on Sentinel-2 imagery. The cells below list all available band names and the RGB subset.

In [None]:
EuroSAT.all_band_names

In [None]:
EuroSAT.rgb_bands

In [None]:
BANDS = ('B04', 'B03', 'B02', 'B8A')
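`BANDS` pairs the red band (B04) with the narrow near-infrared band (B8A). That combination is enough to derive a vegetation index such as NDVI = (NIR - Red) / (NIR + Red). The minimal NumPy sketch below appends NDVI as an extra channel to a chip whose channels follow the order of `BANDS`. The function name `append_ndvi` and the random chip are illustrative assumptions, not part of this notebook. torchgeo's `indices` module (imported above) offers transforms along these lines, such as `indices.AppendNDVI`; check its documentation for the exact arguments. If you append a channel, `in_channels` must grow by one.

```python
import numpy as np

# Channel order assumed to match BANDS in this notebook.
BANDS = ("B04", "B03", "B02", "B8A")


def append_ndvi(chip: np.ndarray, bands=BANDS, eps: float = 1e-6) -> np.ndarray:
    """Return a copy of a (C, H, W) chip with an NDVI channel appended.

    NDVI = (NIR - Red) / (NIR + Red), using B8A as NIR and B04 as red.
    `eps` guards against division by zero on dark pixels.
    """
    red = chip[bands.index("B04")].astype(np.float32)
    nir = chip[bands.index("B8A")].astype(np.float32)
    ndvi = (nir - red) / (nir + red + eps)
    return np.concatenate([chip.astype(np.float32), ndvi[None]], axis=0)


# Usage: a random 4-band 64x64 chip (EuroSAT chips are 64x64 pixels).
rng = np.random.default_rng(0)
chip = rng.integers(0, 4000, size=(len(BANDS), 64, 64)).astype(np.float32)
out = append_ndvi(chip)
print(out.shape)  # (5, 64, 64)
```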

Experiment with the bands, model parameters, pretrained weights, etc. Adjust the batch size and number of data-loading workers below for your hardware.

In [None]:
if device == "cuda":
    batch_size = 128 * 5  # vary for your GPU memory
    num_workers = 8       # data-loading worker processes; tune to your CPU count
else:  # CPU
    batch_size = 64
    num_workers = 0       # load data in the main process
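The batch size determines how many optimizer steps make up one epoch. The sketch below works through that arithmetic. It assumes the torchgeo EuroSAT training split holds 16,200 chips (60% of the 27,000 images). That figure is an assumption; you can confirm the real size with something like `len(datamodule.train_dataset)` after `datamodule.setup("fit")`.

```python
import math

TRAIN_CHIPS = 16_200  # assumed size of the EuroSAT train split (60% of 27,000)


def steps_per_epoch(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


for bs in (64, 128 * 5):
    print(bs, steps_per_epoch(TRAIN_CHIPS, bs))
# 64 254
# 640 26
```

With 640 chips per batch, an epoch is only about 26 steps, so `max_epochs=25` amounts to roughly 650 optimizer updates. If you change the batch size this much, you may also want to revisit the learning rate.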

## Dataset Download
EuroSAT is a large download, so download it on a CPU machine before switching to a GPU.

In [None]:

datamodule = EuroSATDataModule(
    batch_size=batch_size, 
    root="data", 
    num_workers=num_workers, 
    bands=BANDS,
    download=True,
)

## Experiment
Experiment with the model and pretrained weights; see the torchgeo [pretrained weights tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/pretrained_weights.html).
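With `weights=True` and `in_channels=len(BANDS)` (four bands here), the ImageNet-pretrained first convolution has only three input channels. Our understanding is that torchgeo builds this model with timm. timm, in turn, adapts pretrained RGB kernels to a different channel count by tiling them across the new channels and rescaling by 3 / in_channels, which keeps activations at a similar magnitude. The NumPy sketch below mimics that tile-and-rescale idea. `adapt_rgb_conv` is a hypothetical helper, not a timm or torchgeo API.

```python
import math

import numpy as np


def adapt_rgb_conv(weight: np.ndarray, in_chans: int) -> np.ndarray:
    """Adapt (out, 3, kh, kw) RGB conv weights to `in_chans` input channels.

    Hypothetical helper mirroring the tile-and-rescale idea: RGB kernels are
    repeated cyclically (R, G, B, R, ...) and scaled by 3 / in_chans.
    """
    out_c, rgb, kh, kw = weight.shape
    assert rgb == 3, "expects RGB-pretrained weights"
    repeats = math.ceil(in_chans / 3)
    tiled = np.tile(weight, (1, repeats, 1, 1))[:, :in_chans]
    return tiled * (3.0 / in_chans)


# Usage: a ResNet-18 stem has 64 filters of size 7x7.
rgb_weight = np.random.default_rng(0).normal(size=(64, 3, 7, 7)).astype(np.float32)
w4 = adapt_rgb_conv(rgb_weight, in_chans=4)
print(w4.shape)  # (64, 4, 7, 7)
```

Because `BANDS` lists B04, B03 and B02 first, the red, green and blue channels line up with the ImageNet kernels. Under this scheme, B8A receives a copy of the red kernel.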

In [None]:
task = ClassificationTask(
    model="resnet18",
    weights=True,  # standard ImageNet weights
    # weights=ResNet18_Weights.SENTINEL2_ALL_MOCO,  # or try Sentinel-2 all-band weights (expects all 13 bands; set BANDS to match)
    # weights=ResNet18_Weights.SENTINEL2_RGB_MOCO,  # or try Sentinel-2 RGB weights (expects B04, B03, B02)
    num_classes=10,
    in_channels=len(BANDS),
    loss="ce",
    patience=6,  # patience of the learning-rate scheduler
)

# tb_logger = TensorBoardLogger("tensorboard_logs", name="eurosat") # if you prefer tensorboard
wandb_logger = WandbLogger(
    project="eurosat", 
    name="resnet18_imagenet", 
    log_model=True, # or 'all' 
    save_dir="wandb_logs",
)

trainer = Trainer(
    logger=wandb_logger,
    min_epochs=5,
    max_epochs=25,
    enable_model_summary=False, # https://github.com/Lightning-AI/lightning/issues/12233
)
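As we understand torchgeo's documentation, `patience=6` in `ClassificationTask` is the patience of the learning-rate scheduler, not an early-stopping criterion. In the versions we have seen, that scheduler is a `ReduceLROnPlateau` on validation loss. The plain-Python sketch below imitates PyTorch's default plateau rule (relative threshold 1e-4, factor 0.1, no cooldown) to show when the learning rate would drop. `plateau_lrs` and the starting `lr=1e-3` are illustrative assumptions. Early stopping would instead need the `EarlyStopping` callback imported above, passed via `Trainer(callbacks=[...])`.

```python
def plateau_lrs(val_losses, lr=1e-3, patience=6, factor=0.1, threshold=1e-4):
    """Learning rate in effect after each epoch under a ReduceLROnPlateau-style rule.

    An epoch "improves" if its loss beats the best so far by a relative margin
    of `threshold`. After more than `patience` consecutive non-improving epochs,
    the learning rate is multiplied by `factor` and the counter resets.
    """
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best * (1 - threshold):
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs


# Usage: validation loss improves for 3 epochs, then stalls.
losses = [0.9, 0.6, 0.4] + [0.41] * 10
for epoch, lr in enumerate(plateau_lrs(losses)):
    print(epoch, f"{lr:.0e}")
```

Here the rate drops at epoch 9 (0-indexed), the seventh stalled epoch. With `max_epochs=25`, the loss must stop improving within the first 19 epochs for a reduction to happen at all.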

Note: when training on a CPU on GitHub, the cell below takes about 66 minutes.

In [None]:
trainer.fit(model=task, datamodule=datamodule)

Note: the cell below raises `ReferenceError: weakly-referenced object no longer exists`.

In [None]:
trainer.test(model=task, datamodule=datamodule)

Can you beat these test results?
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│   test_AverageAccuracy    │     0.950072169303894     │
│       test_F1Score        │    0.9520370364189148     │
│     test_JaccardIndex     │    0.9078730940818787     │
│   test_OverallAccuracy    │    0.9520370364189148     │
│         test_loss         │    0.16335640847682953    │
└───────────────────────────┴───────────────────────────┘
```
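The four metrics summarize the same predictions in different ways. To our understanding, torchgeo's `ClassificationTask` uses two averaging schemes:

* `OverallAccuracy` and `F1Score` are micro-averaged.
* `AverageAccuracy` and `JaccardIndex` are macro-averaged (a per-class mean).

The identical `test_F1Score` and `test_OverallAccuracy` values above are consistent with this, because micro F1 equals overall accuracy for single-label multiclass problems. The NumPy sketch below derives all four metrics from a confusion matrix. The function name and the toy two-class matrix are illustrative.

```python
import numpy as np


def summarize(confusion) -> dict:
    """Metrics from a confusion matrix where confusion[i, j] counts
    chips of true class i predicted as class j."""
    cm = np.asarray(confusion, dtype=float)
    tp = np.diag(cm)
    support = cm.sum(axis=1)    # chips per true class
    predicted = cm.sum(axis=0)  # chips per predicted class
    fp = predicted - tp
    fn = support - tp
    return {
        # micro: pool all chips, so large classes dominate
        "OverallAccuracy": tp.sum() / cm.sum(),
        # macro: mean of per-class recall, every class weighs the same
        "AverageAccuracy": np.mean(tp / support),
        # micro F1 pools TP/FP/FN; every error is one FP and one FN,
        # so this reduces to overall accuracy
        "F1Score": 2 * tp.sum() / (2 * tp.sum() + fp.sum() + fn.sum()),
        # macro IoU: per-class TP / (TP + FP + FN), averaged over classes
        "JaccardIndex": np.mean(tp / (tp + fp + fn)),
    }


# Usage: an imbalanced toy problem (50 chips of class 0, 10 of class 1).
print(summarize([[45, 5], [2, 8]]))
```

In this toy case, overall accuracy (about 0.883) exceeds average accuracy (0.85) because the larger class is also the easier one. A gap between `test_OverallAccuracy` and `test_AverageAccuracy` hints at uneven per-class performance in the same way.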

In [None]:
wandb_logger.experiment.finish()