# Change Detection using OSCD 

This notebook demonstrates how to train a change detection model on the Onera Satellite Change Detection (OSCD) dataset.

## Environment Setup 

Refer to README.md for environment setup. 

### Import and Init Env

In [None]:
import os

# If using LightningAI, change the current working directory to the directory containing this notebook.
REPO_DIR = "/teamspace/studios/this_studio/eda-bids-hackathon-prep/"  # Adjust as appropriate
NOTEBOOK_DIR = os.path.join(REPO_DIR, "sentinel2-modelling")
# Only change directory when the target really exists, so the cell is a no-op elsewhere.
if os.path.isdir(NOTEBOOK_DIR):
    os.chdir(NOTEBOOK_DIR)


def set_conda_data_dir(var_name, data_subdir, conda_prefix):
    """Point an environment variable at a data directory inside a conda env.

    An already-set variable is left untouched. Otherwise the Windows layout
    (``<prefix>/Library/share/<data_subdir>``) and the Linux/macOS layout
    (``<prefix>/share/<data_subdir>``) are probed in that order, and the
    variable is set to the first directory that exists.

    Args:
        var_name: Name of the environment variable, e.g. ``"GDAL_DATA"``.
        data_subdir: Name of the data directory, e.g. ``"gdal"``.
        conda_prefix: Root directory of the conda environment.

    Returns:
        The value of the variable after the call, or ``None`` when it was
        unset and no candidate directory exists.
    """
    current = os.environ.get(var_name)
    if current is not None:
        return current
    candidates = (
        os.path.join(conda_prefix, "Library", "share", data_subdir),
        os.path.join(conda_prefix, "share", data_subdir),
    )
    for candidate in candidates:
        if os.path.isdir(candidate):
            os.environ[var_name] = candidate
            return candidate
    return None


# If you encounter a warning regarding GDAL missing GDAL_DATA, this sets the
# variables. The original lines built the path but never assigned it.
CONDA_PREFIX = os.environ.get("CONDA_PREFIX")
if CONDA_PREFIX is not None:
    set_conda_data_dir("GDAL_DATA", "gdal", CONDA_PREFIX)
    set_conda_data_dir("PROJ_LIB", "proj", CONDA_PREFIX)

In [None]:
import os
import tempfile
from typing import Dict, Optional
from glob import glob

import torch
import torch.nn as nn
import torchvision.transforms as T

from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader

import torchmetrics as tm
from torchmetrics import Metric
from torchmetrics.classification import BinaryJaccardIndex, BinaryF1Score

from torchgeo.datasets import OSCD
from torchgeo.datamodules import OSCDDataModule
from torchgeo.trainers import SemanticSegmentationTask
from torchgeo.datasets.utils import unbind_samples

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger

import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision.transforms import Compose
from tqdm import tqdm

from sklearn.metrics import precision_score, recall_score

import lightning
print(lightning.__version__)

# Fixed seed passed to Lightning's seed_everything so that runs are repeatable.
seed_everything(543)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Load EDS credentials from .env file
from dotenv import load_dotenv
load_dotenv()

In [None]:
# (batch_size, num_workers) per device type; the CPU setting loads data in the main process.
_LOADER_SETTINGS = {
    "cuda": (8, 8),
    "cpu": (4, 0),
}
if device in _LOADER_SETTINGS:
    batch_size, num_workers = _LOADER_SETTINGS[device]

## Dataset Download
This is a large dataset, so download it on a CPU machine before switching to a GPU machine.

In [None]:
# Show every band name the OSCD dataset class exposes via `all_bands`.
print(OSCD.all_bands)

In [None]:
# Show the band names the OSCD dataset class lists as `rgb_bands`.
print(OSCD.rgb_bands)

Select the bands to experiment with

In [None]:
# Bands loaded for each image; the model's input channel count is derived from
# len(BANDS) * 2 because the two dates are concatenated channel-wise.
BANDS = ('B04', 'B03', 'B02', 'B8A')

In [None]:
# Collect the datamodule options in one place, then build the datamodule from them.
datamodule_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "download": True,  # fetch the dataset if it is not already present
    "bands": BANDS,
    "patch_size": 256,
}
datamodule = OSCDDataModule(**datamodule_options)

Visualise a sample. Remember that the images are split into patches during training.

In [None]:
# Prepare the datasets used for fitting so that `train_dataset` is available.
datamodule.setup(stage="fit")
# Plot the first training sample with the `plot` method of the dataset that
# `train_dataset` wraps (accessed through its `.dataset` attribute).
fig = datamodule.train_dataset.dataset.plot(datamodule.train_dataset[0])

## Training
We will use a custom semantic segmentation model for change detection

In [None]:
class CustomSemanticSegmentationTask(SemanticSegmentationTask):
    """Semantic segmentation task adapted to two-date change detection.

    Every batch must provide ``"image1"``, ``"image2"`` and ``"mask"`` entries.
    The two images are concatenated along dim 1 before being passed to the
    model, so the model must be built with twice the per-image channel count.
    Binary F1 and binary IoU are tracked separately for the train, validation
    and test stages and logged once per epoch.
    """

    def __init__(self, **kwargs):
        """Create the task.

        Args:
            **kwargs: Forwarded unchanged to ``SemanticSegmentationTask``
                (for example ``model``, ``weights``, ``num_classes``,
                ``in_channels`` and ``loss``).
        """
        # The keyword arguments must reach the parent; otherwise the model is
        # silently built from the parent's defaults instead of the caller's
        # settings.
        super().__init__(**kwargs)
        self.train_f1 = BinaryF1Score()
        self.val_f1 = BinaryF1Score()
        self.test_f1 = BinaryF1Score()
        self.train_iou = BinaryJaccardIndex()
        self.val_iou = BinaryJaccardIndex()
        self.test_iou = BinaryJaccardIndex()

    def plot(self, sample):
        """Plot both images, the mask and the prediction of one sample side by side.

        Args:
            sample: Mapping with ``"image1"``, ``"image2"``, ``"mask"`` and
                ``"prediction"`` entries. The images are permuted from
                channel-first to channel-last before display.

        Returns:
            The matplotlib figure holding the four panels.
        """
        image1 = sample["image1"]
        image2 = sample["image2"]
        mask = sample["mask"]
        prediction = sample["prediction"]

        # NOTE(review): the full channel stack is handed to imshow; with four
        # bands matplotlib treats the last channel as alpha - confirm intended.
        fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(4 * 5, 5))
        axs[0].imshow(image1.permute(1, 2, 0))
        axs[0].axis("off")
        axs[1].imshow(image2.permute(1, 2, 0))
        axs[1].axis("off")
        axs[2].imshow(mask)
        axs[2].axis("off")
        axs[3].imshow(prediction)
        axs[3].axis("off")

        axs[0].set_title("Image 1")
        axs[1].set_title("Image 2")
        axs[2].set_title("Mask")
        axs[3].set_title("Prediction")

        plt.tight_layout()

        return fig

    def _forward_batch(self, batch):
        """Run the model on one batch and compute the loss.

        The images are cast to float and the mask to long in every stage, so
        that validation and test see the same dtypes as training.

        Args:
            batch: Mapping with ``"image1"``, ``"image2"`` and ``"mask"`` tensors.

        Returns:
            Tuple ``(loss, y_hat, y)`` of the loss, the raw model output and
            the integer target mask.
        """
        image1 = batch["image1"].float()
        image2 = batch["image2"].float()
        x = torch.cat([image1, image2], dim=1)
        y = batch["mask"].long()
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        return loss, y_hat, y

    @staticmethod
    def _update_binary_metrics(f1_metric, iou_metric, y_hat, y):
        """Update the F1 and IoU metrics with hard (argmax) predictions."""
        y_hat_hard = y_hat.argmax(dim=1)  # convert to hard predictions, i.e. 0 or 1
        f1_metric.update(y_hat_hard, y)
        iou_metric.update(y_hat_hard, y)

    def training_step(self, *args, **kwargs):
        """Compute and log the training loss for the batch in ``args[0]``.

        Returns:
            The loss tensor used for back-propagation.
        """
        loss, y_hat, y = self._forward_batch(args[0])
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        self.train_metrics(y_hat, y)
        self._update_binary_metrics(self.train_f1, self.train_iou, y_hat, y)
        return loss

    def validation_step(self, *args, **kwargs):
        """Compute and log the validation loss for the batch in ``args[0]``.

        Returns:
            The loss tensor.
        """
        loss, y_hat, y = self._forward_batch(args[0])
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        self.val_metrics(y_hat, y)
        self._update_binary_metrics(self.val_f1, self.val_iou, y_hat, y)
        return loss

    def test_step(self, *args, **kwargs):
        """Compute and log the test loss for the batch in ``args[0]``.

        Returns:
            The loss tensor.
        """
        loss, y_hat, y = self._forward_batch(args[0])
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        self.test_metrics(y_hat, y)
        self._update_binary_metrics(self.test_f1, self.test_iou, y_hat, y)
        return loss

    def on_train_epoch_end(self):
        """Log the epoch-level training F1 and IoU, then reset both metrics."""
        self.log("train_f1", self.train_f1.compute())
        self.train_f1.reset()
        self.log("train_iou", self.train_iou.compute())
        self.train_iou.reset()

    def on_validation_epoch_end(self):
        """Log the epoch-level validation F1 and IoU, then reset both metrics."""
        self.log("val_f1", self.val_f1.compute())
        self.val_f1.reset()
        self.log("val_iou", self.val_iou.compute())
        self.val_iou.reset()

    def on_test_epoch_end(self):
        """Log the epoch-level test F1 and IoU, then reset both metrics."""
        self.log("test_f1", self.test_f1.compute())
        # Reset as the train/val hooks do, so a second test run starts from a
        # clean state instead of accumulating onto the previous one.
        self.test_f1.reset()
        self.log("test_iou", self.test_iou.compute())
        self.test_iou.reset()

In [None]:
# Two images are stacked channel-wise, so the model sees twice the band count.
n_input_channels = len(BANDS) * 2
task = CustomSemanticSegmentationTask(
    loss="ce",
    num_classes=2,
    in_channels=n_input_channels,
    weights=True,
    model="unet",
)

In [None]:
# Weights & Biases logger; log_model may also be set to 'all'.
wandb_options = {
    "project": "oscd",
    "log_model": True,
    "save_dir": "wandb_logs",
}
wandb_logger = WandbLogger(**wandb_options)

In [None]:
# Epoch bounds for the run: at least 20 epochs and at most 25.
epoch_limits = {"min_epochs": 20, "max_epochs": 25}
trainer = Trainer(logger=[wandb_logger], **epoch_limits)

In [None]:
# Train the task on the data supplied by the datamodule.
trainer.fit(model=task, datamodule=datamodule)

Note: the test cell below raises an error: `ReferenceError: weakly-referenced object no longer exists`.

Can you beat:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│          test_f1          │    0.33471593260765076    │
│         test_iou          │    0.2009963095188141     │
│         test_loss         │    0.14401671290397644    │
```

In [None]:
# Evaluate the task on the datamodule's test data; logs test_loss, test_f1 and test_iou.
trainer.test(model=task, datamodule=datamodule)

In [None]:
# Finish the WandB experiment run once training and testing are done.
wandb_logger.experiment.finish()