In [2]:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.nn.functional import  binary_cross_entropy_with_logits, sigmoid
import pytorch_lightning as pl
import torchmetrics
from torchmetrics import Metric
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import segmentation_models_pytorch as smp

  from .autonotebook import tqdm as notebook_tqdm


# Import from self-written code

In [3]:
from tif_processor import SatelliteDataset
from metrics import MIoU
from loss import dice_bce_loss_with_logits, dice_loss_with_logits

ValueError: No subfolders found in the specified input folder.

# Data Path

In [3]:
# Training dataset
feature_dir_train = "../data/CN/feature_trimmed.tif"
label_dir_train = "../data/CN/label_trimmed.tif"
feature_tiles_train = "../data/CN/tiles/features"
label_tiles_train = "../data/CN/tiles/labels"
feature_tiles_mergeback_train = "../data/CN/tiles/merge/merged_feature.tif"
label_tiles_mergeback_train = "../data/CN/tiles/merge/merged_label.tif"

# Test dataset
feature_dir_test = "../data/BZ/feature.tif"
label_dir_test = "../data/BZ/label.tif"
feature_tiles_test = "../data/BZ/tiles/features"
label_tiles_test = "../data/BZ/tiles/labels"
feature_tiles_mergeback_test = "../data/BZ/tiles/merge/merged_feature.tif"
label_tiles_mergeback_test = "../data/BZ/tiles/merge/merged_label.tif"

# HyperParameters

In [4]:
batch_size=4
shuffle=True
EPOCHS=5

# Dataset & DataLoader

In [5]:
dataset= SatelliteDataset(
feature_dir=feature_tiles_train,
label_dir=label_tiles_train,
weight_dir=None,
tiles=range(0, 320),
mu=None,
sigma=None,
sample=None
)

In [6]:
# set up train_val ratio
train_ratio = 0.8
train_size = int(len(dataset) * train_ratio)
val_size = len(dataset) - train_size
print(train_size)
print(val_size)

256
64


In [7]:
from torch.utils.data import random_split, DataLoader
# split train_val
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# create DataLoader
train_loader = DataLoader(train_dataset, batch_size, shuffle)
val_loader = DataLoader(val_dataset, batch_size, shuffle)
# len(train_loader), len(val_loader)

In [8]:
test_dataset = SatelliteDataset(
feature_dir=feature_tiles_test,
label_dir=label_tiles_test,
weight_dir=None,
tiles=range(0, 10),
mu=None,
sigma=None,
sample=None
)
test_loader=DataLoader(test_dataset, batch_size, shuffle)
# print(len(test_loader))

# Experiment

In [9]:
def save_mask(segmentation_mask, filename):
    segmentation_mask_np = segmentation_mask.detach().numpy()[0,:,:]
    segmentation_mask_np_uint8 = (segmentation_mask_np * 255).astype(np.uint8)
    segmentation_mask_pil = Image.fromarray(segmentation_mask_np_uint8)
    segmentation_mask_pil.save(filename)

import os
def mkpath(path: str) -> None:
    if not os.path.exists(path):
        os.makedirs(path)

In [10]:
class Experiment(pl.LightningModule):
    def __init__(self, arch="UNet", encoder_name="resnet34", encoder_weights="imagenet", in_channels=4, out_classes=1, experiment_name="Experiment1", loss="bce"):
        super().__init__()
        self.save_hyperparameters()
        self.experiment_name=experiment_name
        # Create Model
        self.model=smp.create_model(
            arch,
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=out_classes
        )

        self.loss=self._get_loss(loss)

        # Metrics
        self.val_miou = MIoU(2)
        self.val_acc = torchmetrics.Accuracy(task="binary")
        self.val_precision = torchmetrics.Precision(task="binary")
        self.val_recall = torchmetrics.Recall(task="binary")

        # test_Metrics
        self.test_miou=MIoU(2)
        self.test_acc=torchmetrics.Accuracy(task="binary")
        self.test_precision=torchmetrics.Precision(task="binary")
        self.test_recall=torchmetrics.Recall(task="binary")
    
    def configure_optimizers(self):
        optimizer=torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss.detach(), prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Forward Pass
        x, y = batch
        y_hat_loss = self(x)
        y_hat = torch.sigmoid(y_hat_loss)
        loss = self.loss(y_hat_loss, y)

        # Log Loss and Accuracy
        self.val_acc(y_hat, y)
        self.val_miou(y_hat, y)
        self.val_precision(y_hat, y)
        self.val_recall(y_hat, y)
        self.log('val_loss', loss, prog_bar=True, logger=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_precision', self.val_precision, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_recall', self.val_recall, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_miou', self.val_miou, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = torch.sigmoid(self(x))

        # Log Loss and Accuracy
        self.test_acc(y_hat, y)
        self.test_miou(y_hat, y)
        self.test_precision(y_hat, y)
        self.test_recall(y_hat, y)
        self.log('test_acc', self.test_acc, logger=True)
        self.log('test_precision', self.test_precision, logger=True)
        self.log('test_recall', self.test_recall, logger=True)
        self.log('test_miou', self.test_miou, logger=True)

        # Save Prediction and Label Masks
        y = y[0,:,:,:].cpu()
        out = y_hat[0,:,:,:].cpu()
        mkpath(f"predictions/{self.experiment_name}/masks")
        mkpath(f"predictions/{self.experiment_name}/preds")
        save_mask(y.cpu(), f"predictions/{self.experiment_name}/masks/mask_{batch_idx}.png")
        save_mask(out.round(), f"predictions/{self.experiment_name}/preds/pred_{batch_idx}.png")


    @staticmethod
    def _get_loss(loss):
        if loss == "dice":
            return dice_loss_with_logits
        elif loss == "dice_bce":
            return dice_bce_loss_with_logits
        return binary_cross_entropy_with_logits

# Training_Process

In [11]:
model=Experiment(experiment_name="Experiment0")

In [12]:
trainer=pl.Trainer(max_epochs=EPOCHS, log_every_n_steps=1)
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
    )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type            | Params | Mode 
-----------------------------------------------------------
0 | model          | Unet            | 24.4 M | train
1 | val_miou       | MIoU            | 0      | train
2 | val_acc        | BinaryAccuracy  | 0      | train
3 | val_precision  | BinaryPrecision | 0      | train
4 | val_recall     | BinaryRecall    | 0      | train
5 | test_miou      | MIoU            | 0      | train
6 | test_acc       | BinaryA

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

e:\Program Files\anaconda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
e:\Program Files\anaconda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
e:\Program Files\anaconda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


# Validation and test metrics

In [13]:
# run validation dataset
valid_metrics = trainer.validate(model, dataloaders=val_loader, verbose=False)
print(valid_metrics)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.04803231358528137, 'val_acc': 0.9844863414764404, 'val_precision': 0.8115824460983276, 'val_recall': 0.6527876257896423, 'val_miou': 0.7755206823348999}]


In [14]:
import numpy as np
from PIL import Image
# run test dataset
test_metrics = trainer.test(model, dataloaders=test_loader, verbose=False)
print(test_metrics)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
e:\Program Files\anaconda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:475: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.
e:\Program Files\anaconda\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_acc': 0.9868423342704773, 'test_precision': 0.37917208671569824, 'test_recall': 0.40324345231056213, 'test_miou': 0.6148345470428467}]
