In [0]:
%python
%pip install torch
%pip install lightning
%pip install segmentation-models-pytorch
%pip install torchmetrics

In [0]:
import torch
from lightning.pytorch import LightningModule
import segmentation_models_pytorch as smp
import torch.nn.functional as F
from torchmetrics.classification import MulticlassJaccardIndex


class DeepLabV3Lightning(LightningModule):
    """LightningModule wrapping ``smp.DeepLabV3`` for multiclass segmentation.

    Args:
        config: Dict-like hyperparameters, stored via ``save_hyperparameters``.
            Keys read here, with their defaults: ``backbone`` ("resnet50"),
            ``encoder_weights`` ("imagenet"), ``in_channels`` (4),
            ``num_classes`` (2), ``lr`` (1e-3).
    """

    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters(config)

        num_classes = config.get("num_classes", 2)
        self.model = smp.DeepLabV3(
            encoder_name=config.get("backbone", "resnet50"),
            encoder_weights=config.get("encoder_weights", "imagenet"),
            in_channels=config.get("in_channels", 4),
            classes=num_classes,
        )

        self.lr = config.get("lr", 1e-3)
        self.loss_fn = smp.losses.DiceLoss(mode="multiclass")
        # One metric instance per split so training and validation batches
        # never accumulate into the same state. ``iou_metric`` keeps its
        # name and tracks the training split.
        self.iou_metric = MulticlassJaccardIndex(num_classes=num_classes)
        self.val_iou_metric = MulticlassJaccardIndex(num_classes=num_classes)

    def forward(self, x):
        """Run the wrapped DeepLabV3 model and return its raw logits.

        ``uint8`` inputs are converted to float and divided by 255 first;
        inputs of any other dtype are passed through unchanged.
        """
        if x.dtype == torch.uint8:
            x = x.float() / 255.0
        return self.model(x)

    def _shared_step(self, batch, metric):
        """Compute the Dice loss for ``batch``, update ``metric``, return the loss.

        Assumes ``batch`` is an ``(x, y)`` pair where ``y`` is an integer
        label map (presumably ``(N, H, W)``, as multiclass DiceLoss expects;
        confirm against the datamodule).
        """
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        # Taking argmax over the class dim gives the same label map the
        # metric would derive from softmax probabilities, without building them.
        # Calling the metric (forward) both updates its running state and
        # caches the batch value Lightning uses for step-level logging.
        metric(logits.argmax(dim=1), y)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss and log loss and IoU for this batch."""
        loss = self._shared_step(batch, self.iou_metric)
        self.log("train_loss", loss)
        # Log the Metric object itself (not a tensor). Lightning then computes
        # the epoch value over all accumulated batches and resets the metric.
        self.log("train_iou", self.iou_metric, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and IoU for this batch."""
        loss = self._shared_step(batch, self.val_iou_metric)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_iou", self.val_iou_metric, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
    



def get_deeplabv3_lightning(config):
    """Construct and return a ``DeepLabV3Lightning`` built from ``config``."""
    module = DeepLabV3Lightning(config=config)
    return module