Skip to content

Commit

Permalink
use imagenet pretrained model & add a baseline unet as reference
Browse files Browse the repository at this point in the history
- add a batch-norm layer for the tinycd model
- add a unet model for reference
- use efficientnet model pretrained on imagenet
- reduce the `pos_weight` for bce loss to 5.0
  • Loading branch information
SRM committed May 31, 2023
1 parent 1b59e8a commit 75c6d28
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 25 deletions.
48 changes: 29 additions & 19 deletions chabud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,17 @@
import torch
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import BinaryJaccardIndex
import trimesh.voxel.runlength
from pytorch_toolbelt.losses import BinaryFocalLoss, BinaryLovaszLoss, DiceLoss
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import (
FocalLoss,
LovaszLoss,
DiceLoss,
)

from chabud.tinycd_model import ChangeClassifier
from chabud.unet_model import UnetChangeClassifier


class ChaBuDNet(L.LightningModule):
Expand Down Expand Up @@ -71,32 +78,39 @@ def __init__(

# Loss functions
self.loss_bce = torch.nn.BCEWithLogitsLoss(
pos_weight=torch.tensor(32.0), reduction="mean"
pos_weight=torch.tensor(5.0), reduction="mean"
)
# self.loss_dice = DiceLoss(mode="binary", from_logits=True, smooth=0.1)
# self.loss_focal = BinaryFocalLoss(alpha=0.25, gamma=2.0)
# self.loss_focal = FocalLoss(mode="binary", alpha=0.25, gamma=2.0)

# Evaluation metrics to know how good the segmentation results are
self.iou = torchmetrics.JaccardIndex(
task="binary", threshold=0.5, num_classes=2
)
self.iou = BinaryJaccardIndex(threshold=0.5)

def _init_model(self, name):
    """
    Build the change detection network selected by ``name``.

    Parameters
    ----------
    name : str
        Either ``"tinycd"`` (``ChangeClassifier`` with an EfficientNet-B4
        backbone requested with ``pretrained=True``) or ``"unet"``
        (``UnetChangeClassifier`` with its default arguments).

    Returns
    -------
    torch.nn.Module
        The freshly constructed model.

    Raises
    ------
    NotImplementedError
        If ``name`` is not one of the supported model names.
    """
    if name == "tinycd":
        return ChangeClassifier(
            bkbn_name="efficientnet_b4",
            pretrained=True,
            output_layer_bkbn="3",
            freeze_backbone=False,
        )
    if name == "unet":
        return UnetChangeClassifier()
    # Raise (not return) so an unsupported name fails immediately instead of
    # handing an exception instance back to the caller as if it were a model.
    raise NotImplementedError(f"model {name} is not available")

def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """
    Forward pass (Inference/Prediction).

    Parameters
    ----------
    x1 : torch.Tensor
        First input image batch (callers in this file pass the pre-event
        image here).
    x2 : torch.Tensor
        Second input image batch (callers in this file pass the post-event
        image here).

    Returns
    -------
    torch.Tensor
        Whatever ``self.model(x1, x2)`` produces; callers in this file treat
        it as logits.

    Raises
    ------
    NotImplementedError
        If ``self.hparams.model_name`` is not a supported model name.
    """
    model_name = self.hparams.model_name
    # Both supported models share the same two-input call signature, so the
    # name is only validated here; the model is invoked exactly once.
    if model_name not in ("tinycd", "unet"):
        raise NotImplementedError(f"model {model_name} is not available")

    y_hat: torch.Tensor = self.model(x1, x2)
    return y_hat

Expand All @@ -111,14 +125,13 @@ def shared_step(
"""
# dtype = torch.float16 if "16" in self.trainer.precision else torch.float32
pre_img, post_img, mask, metadata = batch
# y_hat is logits
y_hat: torch.Tensor = self(x1=pre_img, x2=post_img).squeeze()
y_pred: torch.Tensor = F.sigmoid(y_hat).detach().byte()
logits: torch.Tensor = self(x1=pre_img, x2=post_img).squeeze()
y_pred: torch.Tensor = F.sigmoid(logits).detach()

# Compute loss and metrics
loss: torch.Tensor = self.loss_bce(input=y_hat, target=mask.float())
metric: torch.Tensor = self.iou(preds=y_pred, target=mask)
loss_and_metric: dict = {f"{phase}/loss_dice": loss, f"{phase}/iou": metric}
loss: torch.Tensor = self.loss_bce(logits, mask.float())
metric: torch.Tensor = self.iou(y_pred, mask)
loss_and_metric: dict = {f"{phase}/loss_bce": loss, f"{phase}/iou": metric}
# Report fit/val losses and Intersection over Union metric to the console
self.log_dict(
dictionary=loss_and_metric, on_step=True, on_epoch=False, prog_bar=True
Expand Down Expand Up @@ -165,14 +178,11 @@ def test_step(
- https://huggingface.co/datasets/chabud-team/chabud-ecml-pkdd2023/blob/main/create_sample_submission.py
- https://trimsh.org/trimesh.voxel.runlength.html#trimesh.voxel.runlength.dense_to_brle
"""
dtype = torch.float16 if "16" in self.trainer.precision else torch.float32
pre_img, post_img, mask, metadata = batch

# Pass the image through neural network model to get predicted images
# y_hat is logits
y_hat: torch.Tensor = self(x1=pre_img, x2=post_img).squeeze()
# y_pred: torch.Tensor = (F.sigmoid(y_hat) > 0.5).detach().byte()
# assert y_hat.shape == mask.shape == (32, 512, 512)
logits: torch.Tensor = self(x1=pre_img, x2=post_img).squeeze()
y_pred: torch.Tensor = F.sigmoid(logits).detach().cpu()

# Format predicted mask as binary run length encoding vector
result: list = []
Expand Down
18 changes: 13 additions & 5 deletions chabud/tinycd_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torchvision
from torch import Tensor
from torch.nn import Module, ModuleList, Sigmoid
from torch.nn import Module, ModuleList, Sigmoid, BatchNorm2d

from chabud.layers import MixingBlock, MixingMaskAttentionBlock, PixelwiseLinear, UpMask

Expand All @@ -22,17 +22,20 @@ class ChangeClassifier(Module):
def __init__(
self,
bkbn_name="efficientnet_b4",
weights=None, # not using pretrained weights
pretrained=True,
output_layer_bkbn="3",
freeze_backbone=False,
):
super().__init__()

# Load the pretrained backbone according to parameters:
self._backbone = _get_backbone(
bkbn_name, weights, output_layer_bkbn, freeze_backbone
bkbn_name, pretrained, output_layer_bkbn, freeze_backbone
)

# Normalize the input:
self._normalize = BatchNorm2d(3) # 3 number of bands

# Initialize mixing blocks:
self._first_mix = MixingMaskAttentionBlock(6, 3, [3, 10, 5], [10, 5, 1])
self._mixing_mask = ModuleList(
Expand All @@ -57,6 +60,7 @@ def __init__(
self._classify = PixelwiseLinear([32, 16, 8], [16, 8, 1])

def forward(self, ref: Tensor, test: Tensor) -> Tensor:
    """
    Run the change classifier on a pair of images.

    Both inputs go through the same ``self._normalize`` layer (reference
    image first, then test image) before being encoded together, decoded,
    and classified.

    Parameters
    ----------
    ref : Tensor
        Reference image batch.
    test : Tensor
        Image batch to compare against the reference.

    Returns
    -------
    Tensor
        Output of ``self._classify`` applied to the decoded features.
    """
    ref_normed = self._normalize(ref)
    test_normed = self._normalize(test)
    encoded = self._encode(ref_normed, test_normed)
    return self._classify(self._decode(encoded))
Expand All @@ -76,9 +80,13 @@ def _decode(self, features) -> Tensor:
return upping


def _get_backbone(bkbn_name, weights, output_layer_bkbn, freeze_backbone) -> ModuleList:
def _get_backbone(
bkbn_name, pretrained, output_layer_bkbn, freeze_backbone
) -> ModuleList:
# The whole model:
entire_model = getattr(torchvision.models, bkbn_name)(weights=weights).features
entire_model = getattr(torchvision.models, bkbn_name)(
pretrained=pretrained
).features

# Slicing it:
derived_model = ModuleList([])
Expand Down
21 changes: 21 additions & 0 deletions chabud/unet_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp


class UnetChangeClassifier(nn.Module):
    """
    Baseline U-Net change classifier used as a reference model.

    The two input images are concatenated along the channel axis, passed
    through a ``BatchNorm2d`` layer, and then fed to a
    ``segmentation_models_pytorch.Unet``.

    Parameters
    ----------
    in_channels : int
        Number of channels after concatenating both inputs (default 6).
    out_channels : int
        Number of output channels, passed to ``smp.Unet`` as ``classes``
        (default 1).
    encoder_name : str
        Encoder identifier forwarded to ``smp.Unet``
        (default ``"timm-efficientnet-b0"``).
    encoder_weights : str or None
        Pretrained weight set forwarded to ``smp.Unet``
        (default ``"imagenet"``). Passing ``None`` should skip loading
        pretrained weights -- see the smp documentation to confirm.
    """

    def __init__(
        self,
        in_channels=6,
        out_channels=1,
        encoder_name="timm-efficientnet-b0",
        encoder_weights="imagenet",
    ):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=out_channels,
            # No final activation is requested from smp.Unet here.
            activation=None,
        )
        # Normalization applied to the stacked input before the U-Net.
        self.normalize = nn.BatchNorm2d(in_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Concatenate ``x1`` and ``x2`` on dim 1, normalize, and run the U-Net.

        Parameters
        ----------
        x1 : torch.Tensor
            First image batch.
        x2 : torch.Tensor
            Second image batch; must be concatenable with ``x1`` on dim 1.

        Returns
        -------
        torch.Tensor
            Output of the wrapped ``smp.Unet``.
        """
        x = torch.cat([x1, x2], dim=1)
        x = self.normalize(x)
        return self.model(x)
3 changes: 2 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- xarray-datatree=0.0.12
- pip:
- typeshed-client==2.3.0
- pytorch_toolbelt==0.6.3
- segmentation-models-pytorch
- matplotlib
platforms:
- linux-64

0 comments on commit 75c6d28

Please sign in to comment.