In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Define the complete path to your dataset
DATASET_PATH = '/content/drive/MyDrive/Computer Vision/50.035 CV Team 9'

# Change directory to the dataset location
%cd "/content/drive/MyDrive/Computer Vision/50.035 CV Team 9"

# Verify the path exists (optional check)
import os
assert os.path.exists(DATASET_PATH), "[!] Dataset path does not exist. Please check the path."

Mounted at /content/drive
/content/drive/.shortcut-targets-by-id/1Mdz9CpJD5zYhDk1e3Ch93fV4o95Ud7HJ/50.035 CV Team 9


In [2]:
%%capture
def is_running_in_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

if is_running_in_colab():
  # Normal packages
  %pip install lightning polars segmentation_models_pytorch
  # Dev packages
  %pip install icecream rich tqdm

In [3]:
from pathlib import Path

import polars as pl
import torch
import torch.nn as nn
from torchvision.io import decode_image
from torchvision.transforms import v2
from torchvision.tv_tensors import Image, Mask
import lightning as L
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.loggers import CSVLogger
import torchmetrics
import torchmetrics.segmentation
import segmentation_models_pytorch as smp

# Dev Imports
from icecream import ic

class SegmentationData(L.LightningDataModule):
    def __init__(self, ws_root: Path = Path("."), num_workers=0):
        super().__init__()
        self.data_path = ws_root / 'segmentation_dataset' / 'data'
        self.image_names = list(f.stem for f in (self.data_path / "masks").iterdir())

        self.dataloader_extras = dict(
            num_workers = num_workers,
            pin_memory = True,
            persistent_workers = num_workers > 0
        )

        self.n_classes = 1

    def setup(self, stage: str):
        train, val, test = torch.utils.data.random_split(self.image_names, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(42))
        self.train_ds = ImageDataset(train, self.data_path, training=True)
        self.val_ds = ImageDataset(val, self.data_path, training=True)
        self.test_ds = ImageDataset(test, self.data_path)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_ds, batch_size=16, shuffle=True, **self.dataloader_extras)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_ds, batch_size=64, **self.dataloader_extras)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.test_ds, batch_size=1, **self.dataloader_extras)

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, image_names, data_path, training=False):
        super().__init__()
        self.mask_path = data_path / "masks"
        self.image_path = data_path / "images"
        self.image_names = image_names

        self.training = training
        self.train_transforms = v2.Compose([
            v2.RandomHorizontalFlip(),
            v2.RandomVerticalFlip(),
        ])
        self.transforms = v2.Compose([
            v2.RandomResizedCrop((256, 256)),
            v2.ToDtype(torch.float32, scale=True),
            v2.ToPureTensor(),
        ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image = Image(decode_image(self.image_path / f"{self.image_names[idx]}.jpg", mode="RGB"))
        mask = Mask(decode_image(self.mask_path / f"{self.image_names[idx]}.png", mode="GRAY"))
        if self.training:
            image, mask = self.train_transforms(image, mask)
        image, mask = self.transforms(image, mask)
        mask = (mask > 37).to(torch.long).squeeze()
        return image, mask

class WrappedModel(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.model = smp.DeepLabV3(
            classes=n_classes,
            encoder_name="mobilenet_v2",
            encoder_weights="imagenet"
        )

    def forward(self, x):
        x = self.model(x)
        return x

class LitWrappedModel(L.LightningModule):
    def __init__(self, n_classes):
        super().__init__()
        self.model = WrappedModel(n_classes)
        self.n_classes = n_classes

        self.val_metrics = torchmetrics.MetricCollection(
            {
                "pixel_accuracy": torchmetrics.classification.Accuracy(task="binary", num_classes=n_classes),
                "pixel_f1": torchmetrics.classification.F1Score(task="binary", num_classes=n_classes),
                "DICE": torchmetrics.segmentation.GeneralizedDiceScore(num_classes=2, input_format="index"),
                "IOU": torchmetrics.segmentation.MeanIoU(num_classes=2, input_format="index"),
            },
            prefix="val_",
        )
        self.test_metrics = self.val_metrics.clone(prefix="test_")

        self.losses = [smp.losses.DiceLoss('binary'), smp.losses.SoftBCEWithLogitsLoss()]

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x).squeeze(1)
        loss = sum(loss(y_pred, y.to(torch.float32)) for loss in self.losses)
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x)
        y_pred = (y_pred > 0).to(torch.long).squeeze(1)
        self.log_dict(self.val_metrics(y_pred, y), prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self.model(x)
        y_pred = (y_pred > 0).to(torch.long).squeeze(1)
        self.log_dict(self.test_metrics(y_pred, y), prog_bar=True)

    def on_validation_epoch_end(self):
        L.pytorch.utilities.memory.garbage_collection_cuda()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Main execution
exp_name = "DeepLabV3_MobileNet"
segmentation_data = SegmentationData(num_workers=15)
lit_model = LitWrappedModel(segmentation_data.n_classes)

trainer = L.Trainer(
    max_epochs=200,
    accelerator='gpu',
    callbacks=[RichProgressBar()],
    logger=CSVLogger("csv_logs/segmentation", name=exp_name, version=0)
)
trainer.fit(model=lit_model, datamodule=segmentation_data)

# Save models
model_save_path = Path("models") / "segmentation" / exp_name
model_save_path.mkdir(exist_ok=True, parents=True)

model = lit_model.model
model = model.eval().cpu()

# Save weights
torch.save(model.state_dict(), model_save_path / f"weights_{exp_name}.pt")

# Save full model
torch.save(model, model_save_path / f"model_{exp_name}.pt")

# Try simpler TorchScript export
try:
    # Trace the model with example input
    example_input = torch.randn(1, 3, 512, 512)
    traced_model = torch.jit.trace(model, example_input)
    torch.jit.save(traced_model, model_save_path / f"traced_{exp_name}.pt")
    print("Successfully exported traced model")
except Exception as e:
    print(f"Tracing failed with error: {e}")

# Test the model
trainer.test(model=lit_model, datamodule=segmentation_data)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 186MB/s]
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory csv_logs/segmentation/DeepLabV3_MobileNet/version_0 exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
/usr/local/lib/python3.10/dist-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory csv_logs/segmentation/DeepLabV3_MobileNet/versi

Output()

INFO: `Trainer.fit` stopped: `max_epochs=200` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=200` reached.


  if h % output_stride != 0 or w % output_stride != 0:
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Successfully exported traced model


Output()

[{'test_DICE': 0.755911648273468,
  'test_IOU': 0.7534765005111694,
  'test_pixel_accuracy': 0.9089342355728149,
  'test_pixel_f1': 0.7363795638084412}]