# Task 2: CNN Training for pixel-wise classification

In this task, you will be given a model that was pretrained on [BigEarthNet v2](https://arxiv.org/abs/1902.06148) for pixel-wise classification (i.e., semantic segmentation). We provide a checkpoint as well as the model definition. Your task is to load the model with these weights and finetune it with PyTorch Lightning on our target domain (forest segmentation) in our target location (the Amazon Rainforest). For this, we provide a finetuning dataset.

<img src="../../../data/Example_finetune.png" alt="Example from Dataset" width="600"/>


The goals of this task are as follows:
1. Load a pretrained pixel-wise segmentation model.
2. Adapt and finetune the model on a new domain (forest segmentation) and location (Amazon Rainforest).

## Imports

These are all the imports we used when solving the task. Please leave them as they are, even though you might not need all of them.

In [None]:
import os
import rootutils
root = rootutils.setup_root(os.path.abspath(''), dotenv=True, pythonpath=True, cwd=False)

data_path = root / "data"
data_path.mkdir(exist_ok=True)
output_dir = root / "output"
output_dir.mkdir(exist_ok=True)


In [None]:
import torch
from types import SimpleNamespace
from huggingface_hub import PyTorchModelHubMixin
import lightning as L
from configilm import ConfigILM # see https://lhackel-tub.github.io/ConfigILM/ for more information
from torchinfo import summary

from torchmetrics.segmentation import MeanIoU
import torch.nn as nn
import torch.nn.functional as F

import lmdb
from torch.utils.data import Dataset, DataLoader, random_split
from safetensors.numpy import load as load_np_safetensor
import torchvision.transforms as transforms

from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
import random

## 2.1 Dataset + DataModule definition

Before we can use our data, we need to wrap it in a PyTorch `Dataset` and then in a Lightning `DataModule` so that we can use it for model training.

### Dataset

For efficient data loading, we have put the images in the file `images.lmdb` and the segmentation masks (forest / no forest) in the file `mask.lmdb`. [LMDB](http://www.lmdb.tech/doc/) is a memory-mapped key-value store. For the images, the key is the image name (`1.tif`, `2.tif`, ...) and the value is the image pixels serialized as safetensors (tip: use `load_np_safetensor` to read them). For the masks, the key is the image name followed by `_mask` (`1_mask.tif`, `2_mask.tif`, ...), and the value is again the pixels serialized as safetensors (1 for forest, 0 for no forest). We provide the helper function `_open_lmdb`, which opens a connection to the image or mask LMDB if it is not open yet. You can read data from an LMDB through `with self.env_images.begin() as txn: txn.get(key)`, where `key` is a `bytes` object (e.g. `"1.tif".encode()`). Feel free to add functions and to adapt the existing ones. Because the data loaders use multiprocessing, open the LMDB environments only inside the `__getitem__` method, so that each worker process opens its own connection.
Use preprocessing and data augmentation where applicable.

In [None]:
# Per-band mean and standard deviation of the 10 input bands (one value per channel),
# used to normalize the images.
mean = [438.37207031, 614.05566406, 588.40960693, 2193.29199219, 942.84332275, 1769.93164062, 2049.55151367, 1568.22680664, 997.73248291, 2235.55664062]
std = [607.02685547, 603.29681396, 684.56884766, 1369.3717041, 738.43267822, 1100.45605469, 1275.80541992, 1070.16125488, 813.52764893, 1356.54406738]


class FinetuneDataset(Dataset):
    def __init__(self, images_lmdb_path=data_path / "images.lmdb", masks_lmdb_path=data_path / "mask.lmdb", transform=None):
        self.images_lmdb_path = images_lmdb_path
        self.masks_lmdb_path = masks_lmdb_path

        self.env_images = None
        self.env_masks = None
        self.transform = transform


    def _open_lmdb(self, env, path):
        # If the environment is already opened, simply return it
        if env is not None:
            return env
    
        # The path must be set
        if not path:
            raise ValueError("The LMDB path is not set")

        # Attempt to open the environment; if it fails, rewrap the exception.
        # `path` may be a pathlib.Path (see the defaults above), so convert it to str.
        try:
            return lmdb.open(str(path), readonly=True, lock=False)
        except lmdb.Error as e:
            raise RuntimeError(f"Failed to open LMDB at {path!r}") from e


    def __len__(self):
        pass

    def __getitem__(self, idx):
        # should return image, mask
        pass
        

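The preprocessing and augmentation mentioned above could look like the sketch below. It assumes images are stored channels-first as `(C, H, W)` with the 10 bands in the same order as the `mean` / `std` lists, and masks as `(H, W)`. The flips are sampled once and applied to image and mask together, because independently sampled torchvision random transforms would misalign the two. The function names are hypothetical.

```python
import torch

# Per-band statistics copied from the `mean` / `std` lists above.
MEAN = torch.tensor([438.37207031, 614.05566406, 588.40960693, 2193.29199219, 942.84332275,
                     1769.93164062, 2049.55151367, 1568.22680664, 997.73248291, 2235.55664062])
STD = torch.tensor([607.02685547, 603.29681396, 684.56884766, 1369.3717041, 738.43267822,
                    1100.45605469, 1275.80541992, 1070.16125488, 813.52764893, 1356.54406738])


def normalize(image: torch.Tensor) -> torch.Tensor:
    """Standardize each band: (x - mean) / std, broadcast over H and W."""
    return (image.float() - MEAN[:, None, None]) / STD[:, None, None]


def random_flip_pair(image: torch.Tensor, mask: torch.Tensor, generator=None):
    """Apply the same random horizontal/vertical flips to image (C, H, W) and mask (H, W)."""
    if torch.rand(1, generator=generator).item() < 0.5:
        image, mask = image.flip(-1), mask.flip(-1)  # horizontal flip
    if torch.rand(1, generator=generator).item() < 0.5:
        image, mask = image.flip(-2), mask.flip(-2)  # vertical flip
    return image, mask


def preprocess(image, mask, train: bool, generator=None):
    """Convert arrays to tensors, normalize, and augment only during training."""
    image = normalize(torch.as_tensor(image))
    mask = torch.as_tensor(mask).long()  # class indices, as expected by cross-entropy
    if train:
        image, mask = random_flip_pair(image, mask, generator)
    return image, mask
```

For normalization alone, `transforms.Normalize(mean, std)` (imported above) also works on tensors with any number of channels; the paired flips are what the hand-written version adds.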

### DataModule

Your DataModule needs to return a valid dataloader for training, validation, and testing. Implement the data part of the [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) training procedure.
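One way to create reproducible splits in `setup` is to shuffle the indices with a seeded generator and wrap separate dataset instances in `Subset`, so that augmentation can be enabled for the training split only (with a plain `random_split`, all three subsets share one dataset object and therefore one transform). The split fractions, seed, and helper names below are illustrative assumptions, not prescribed values.

```python
import torch
from torch.utils.data import Dataset, Subset


def split_indices(n: int, val_fraction: float = 0.15, test_fraction: float = 0.15, seed: int = 42):
    """Shuffle range(n) with a fixed seed and cut it into train/val/test index lists."""
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=generator).tolist()
    n_val = round(n * val_fraction)
    n_test = round(n * test_fraction)
    test_idx = perm[:n_test]
    val_idx = perm[n_test:n_test + n_val]
    train_idx = perm[n_test + n_val:]
    return train_idx, val_idx, test_idx


def make_splits(train_dataset: Dataset, eval_dataset: Dataset, seed: int = 42):
    """Training subset from the augmented dataset, val/test subsets from the plain one."""
    train_idx, val_idx, test_idx = split_indices(len(eval_dataset), seed=seed)
    return (Subset(train_dataset, train_idx),
            Subset(eval_dataset, val_idx),
            Subset(eval_dataset, test_idx))
```

In `setup`, this could be called with two `FinetuneDataset` instances that differ only in their (hypothetical) `transform`; because the seed is fixed, every call to `setup` yields the same split. The dataloaders would then pass `shuffle=True` for the training split only.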

In [None]:
class FinetuneDataModule(L.LightningDataModule):
    def __init__(self, images_lmdb_path=data_path / "images.lmdb", masks_lmdb_path=data_path / "mask.lmdb", batch_size=16, num_workers=0):
        super().__init__()
        self.images_lmdb_path = images_lmdb_path
        self.masks_lmdb_path = masks_lmdb_path
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        pass

    def train_dataloader(self):
        pass

    def val_dataloader(self):
        pass

    def test_dataloader(self):
        pass

## 2.2 Model Definition

In the following, we provide the definition of a ResNet18 pretrained on BigEarthNet, followed by an adaptation of this architecture for semantic segmentation. You need to complete the rest of the required model setup.

### BEN pretrained Resnet18

Here we provide the definition of a ResNet18 model pretrained on BigEarthNet (BEN).

In [None]:
class Resnet(L.LightningModule, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        self.config = SimpleNamespace(**config)
        self.model = ConfigILM.ConfigILM(self.config)

    def forward(self, x):
        return self.model(x)

### Fully convolutional adaptation

We have only defined the bare minimum (architecture and forward pass). You need to fill in the rest and add functions where appropriate so that the model can be used for training later on. As the evaluation metric, you can use mean Intersection over Union (mIoU); have a look at [`MeanIoU`](https://lightning.ai/docs/torchmetrics/stable/segmentation/mean_iou.html), imported above. Implement the [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) training steps.
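To make the metric concrete, the sketch below computes the cross-entropy loss and a mean IoU by hand from logits of shape `(N, C, H, W)` and integer masks of shape `(N, H, W)`. It is an illustration, not a replacement for `MeanIoU`: here, classes that appear neither in the prediction nor in the reference are skipped, and torchmetrics may treat such cases differently. The function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def segmentation_loss_and_miou(logits: torch.Tensor, target: torch.Tensor, num_classes: int):
    """Cross-entropy loss and hand-computed mean IoU for one batch.

    logits: (N, C, H, W) raw class scores; target: (N, H, W) int64 class indices.
    """
    loss = F.cross_entropy(logits, target)
    pred = logits.argmax(dim=1)  # (N, H, W) predicted class per pixel
    ious = []
    for c in range(num_classes):
        pred_c = pred == c
        target_c = target == c
        union = (pred_c | target_c).sum()
        if union == 0:
            continue  # class absent from both prediction and reference: skip it
        intersection = (pred_c & target_c).sum()
        ious.append(intersection.float() / union.float())
    miou = torch.stack(ious).mean() if ious else torch.tensor(float("nan"))
    return loss, miou
```

`F.cross_entropy` expects `int64` targets, so masks read as `uint8` or float need a `.long()` before this call.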

In [None]:
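import copy

pretrained_model = Resnet.from_pretrained("BIFOLD-BigEarthNetv2-0/resnet18-s2-v0.2.0").model.vision_encoder
# Drop the last two children (global pooling and classification layer) so that the
# backbone returns a spatial feature map with 512 channels.
backbone = nn.Sequential(*list(pretrained_model.children())[:-2])

class FCNResnet(L.LightningModule):
    def __init__(self, num_classes=19, learning_rate=1e-4):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_classes = num_classes
        self.mIoU = MeanIoU(num_classes=num_classes)
        self.val_outputs = []
        self.test_outputs = []

        # Copy the pretrained backbone so that each FCNResnet instance owns its weights
        # instead of sharing (and overwriting) the module-level `backbone`.
        self.backbone = copy.deepcopy(backbone)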

        # Decode the 512-channel feature map and upsample it to the 120x120 input size.
        self.decoder = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(size=(120,120), mode='bilinear', align_corners=False),

            nn.Conv2d(32, num_classes, kernel_size=1)
        )

    def forward(self, x):
        x = self.backbone(x)
        x = self.decoder(x)

        return x

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        pass

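To see why the decoder ends with a fixed-size `Upsample`, the sketch below traces the spatial size of a 120×120 input through the standard ResNet-18 downsampling stages (7×7 stride-2 stem convolution, 3×3 stride-2 max pooling, then a stride-2 convolution at the start of `layer2`, `layer3`, and `layer4`), using `out = (size + 2*padding - kernel) // stride + 1`. It assumes the BigEarthNet backbone keeps these standard strides; the helper names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a convolution or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet18_feature_size(size: int) -> int:
    """Spatial size after the standard ResNet-18 stem and layer1-layer4."""
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    # layer1 keeps the resolution; layer2-layer4 each downsample with a stride-2 3x3 conv
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size


def decoder_sizes(feature_size: int, output_size: int = 120) -> list[int]:
    """Sizes after each Upsample in the FCNResnet decoder: three x2 steps, then a fixed resize."""
    sizes = [feature_size]
    for _ in range(3):
        sizes.append(sizes[-1] * 2)
    sizes.append(output_size)
    return sizes
```

For a 120×120 input, the backbone goes 120 → 60 → 30 → 15 → 8 → 4, and `decoder_sizes(resnet18_feature_size(120))` gives `[4, 8, 16, 32, 120]`. The last step therefore stretches 32 pixels to 120 (a factor of 3.75), and because the size is hard-coded, the model always outputs 120×120 masks; other input sizes would need a different `size`, or e.g. `F.interpolate` to the input shape in `forward`.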
## 2.3 Finetuning

Please write the logic required to finetune the model using the DataModule you have defined above. The checkpoint we provide was finetuned for segmentation on BigEarthNet. Adapt the model if necessary. Briefly describe the results.
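If the number of output classes in the checkpoint differs from that of the target task (presumably 2 here: forest and no forest), one way to adapt the model is to swap the final 1×1 convolution of the decoder and to load only the checkpoint tensors whose shapes still match. Filtering is necessary because `load_state_dict` raises on shape mismatches even with `strict=False`. Both helper names are hypothetical.

```python
import torch
import torch.nn as nn


def replace_head(decoder: nn.Sequential, num_classes: int) -> nn.Sequential:
    """Swap the last 1x1 convolution for a freshly initialized one with `num_classes` outputs."""
    old_head = decoder[-1]
    if not isinstance(old_head, nn.Conv2d):
        raise TypeError("expected the decoder to end with an nn.Conv2d")
    decoder[-1] = nn.Conv2d(old_head.in_channels, num_classes, kernel_size=1)
    return decoder


def load_matching_weights(model: nn.Module, state_dict: dict) -> list[str]:
    """Load tensors whose name and shape match; return the names left at their initial values."""
    own = model.state_dict()
    compatible = {k: v for k, v in state_dict.items() if k in own and own[k].shape == v.shape}
    model.load_state_dict(compatible, strict=False)
    return sorted(set(own) - set(compatible))
```

A Lightning checkpoint stores the weights under the `"state_dict"` key of the dictionary returned by `torch.load(ckpt_path, map_location="cpu")`; depending on the PyTorch version, loading a full Lightning checkpoint from a trusted source may require `weights_only=False`. Assuming the checkpoint keys use the same attribute names as `FCNResnet` (`backbone.*`, `decoder.*`), loading into `FCNResnet(num_classes=2)` should leave only the new head's `weight` and `bias` in the returned list. `replace_head` is useful when the original 19-class model has been restored first; in that case the `MeanIoU` metric also needs the new class count.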


In [None]:
ckpt_path = data_path / "pretrained_model.ckpt"
model = None

trainer = None

**TODO:** Describe the results in a few paragraphs.

## 2.4 Training Visualization + Evaluation

In addition to the quantitative results obtained above, it is always good to visualize the training progress and some qualitative examples. In this task, you should:
1. Visualize model performance over the training epochs.
2. Visualize some example predictions.

Briefly describe the results.

### 2.4.1 Training Visualization

Please visualize the validation loss as well as the validation performance over the training epochs. We recommend using the Lightning `CSVLogger`. Plot the results.
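A minimal plotting sketch, assuming the metrics were logged under the hypothetical names `"val_loss"` and `"val_mIoU"` (use whatever names your `LightningModule` passes to `self.log`). `CSVLogger` writes a `metrics.csv` file into its `log_dir`; each row corresponds to one logging step, and metrics not logged in that step are left empty. The sketch therefore averages per epoch, which ignores the missing values.

```python
import matplotlib.pyplot as plt
import pandas as pd


def per_epoch_metrics(metrics, columns=("val_loss", "val_mIoU")) -> pd.DataFrame:
    """Collapse CSVLogger rows to one row per epoch; NaN cells are ignored by mean()."""
    df = metrics if isinstance(metrics, pd.DataFrame) else pd.read_csv(metrics)
    return df.groupby("epoch")[list(columns)].mean()


def plot_curves(per_epoch: pd.DataFrame):
    """One subplot per metric column, plotted against the epoch index."""
    n = len(per_epoch.columns)
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 4), squeeze=False)
    for ax, column in zip(axes[0], per_epoch.columns):
        ax.plot(per_epoch.index, per_epoch[column], marker="o")
        ax.set_xlabel("epoch")
        ax.set_ylabel(column)
        ax.set_title(column)
        ax.grid(True)
    fig.tight_layout()
    return fig
```

With the logger passed to the trainer, usage could be `plot_curves(per_epoch_metrics(os.path.join(logger.log_dir, "metrics.csv")))`.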

In [None]:
# Plot mIoU and loss over training epochs

**TODO:** Describe the results in a few paragraphs.

### 2.4.2 Qualitative Evaluation

Please visualize a few (at least 2) example outputs, each as three panels: (1) input image, (2) reference mask, (3) predicted mask.
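A sketch of the three-panel layout, assuming each input is a raw (unnormalized) `(C, H, W)` array, masks are `(H, W)` class-index arrays, and the red, green, and blue bands sit at indices 2, 1, and 0. The band order is an assumption; check it against the dataset's band list. Predicted masks can be obtained as `model(x).argmax(dim=1)`. The colors and helper names are hypothetical choices.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch

CLASS_COLORS = ["tan", "forestgreen"]  # index 0 = no forest, index 1 = forest
MASK_CMAP = ListedColormap(CLASS_COLORS)


def to_rgb(image, rgb_bands=(2, 1, 0), low=2, high=98):
    """Percentile-stretch three bands of a (C, H, W) array into an (H, W, 3) image in [0, 1]."""
    rgb = np.asarray(image, dtype=np.float32)[list(rgb_bands)].transpose(1, 2, 0)
    lo, hi = np.percentile(rgb, [low, high])
    return np.clip((rgb - lo) / (hi - lo + 1e-8), 0.0, 1.0)


def plot_examples(images, masks, preds, rgb_bands=(2, 1, 0)):
    """One row per example: input image, reference mask, predicted mask."""
    fig, axes = plt.subplots(len(images), 3, figsize=(9, 3 * len(images)), squeeze=False)
    titles = ("Input image", "Reference mask", "Predicted mask")
    for row, (image, mask, pred) in zip(axes, zip(images, masks, preds)):
        panels = (to_rgb(image, rgb_bands), np.asarray(mask), np.asarray(pred))
        for ax, panel, title in zip(row, panels, titles):
            if panel.ndim == 2:  # class-index mask
                ax.imshow(panel, cmap=MASK_CMAP, vmin=0, vmax=1, interpolation="nearest")
            else:
                ax.imshow(panel)
            ax.set_title(title)
            ax.axis("off")
    handles = [Patch(color=c, label=l) for c, l in zip(CLASS_COLORS, ("No forest", "Forest"))]
    fig.legend(handles=handles, loc="lower center", ncol=2)
    fig.tight_layout(rect=(0, 0.05, 1, 1))  # leave room for the legend
    return fig
```

Using fixed `vmin`/`vmax` and a two-color colormap keeps the class colors identical across panels, even when a mask happens to contain only one class.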


In [None]:
# Plot some (at least 2) example images
# Plot: Input Image - Reference Mask - Predicted Mask

**TODO:** Describe the results in a few paragraphs.