In [2]:
import torch
import torchvision
import lightning
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.datasets import Cityscapes
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
DATA = '/home/gtangg12/auto-augment/data/cityscapes'

In [4]:
class MaskRCNNModule(lightning.LightningModule):
    """
    """
    def __init__(self, num_classes):
        super().__init__()
        self.model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

        # Replace the classifier and mask predictor
        in_features      = self.model.roi_heads.box_predictor.cls_score  .in_features
        in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels
        self.model.roi_heads.box_predictor  = FastRCNNPredictor(in_features          , num_classes)
        self.model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, 256, num_classes)

    def forward(self, images, targets=None):
        return self.model(images, targets)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        loss_dict = self(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        loss_dict = self(images, targets)
        loss = sum(loss for loss in loss_dict.values())
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0005)
        return optimizer

In [7]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std =[0.229, 0.224, 0.225])
])
transform_target = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

train_dataset, val_dataset = \
    Cityscapes(DATA, split='train', mode='fine', target_type='semantic', transform=transform, target_transform=transform_target), \
    Cityscapes(DATA, split='val'  , mode='fine', target_type='semantic', transform=transform, target_transform=transform_target)

In [10]:
for x, y in train_dataset:
    print(torch.unique(y), y.shape)
    break

tensor([0.0039, 0.0078, 0.0118, 0.0157, 0.0196, 0.0275, 0.0314, 0.0431, 0.0667,
        0.0706, 0.0784, 0.0824, 0.0902, 0.0941, 0.1020])


In [15]:
train_loader, val_loader = \
    DataLoader(train_dataset, batch_size=8, shuffle=True , num_workers=4), \
    DataLoader(val_dataset  , batch_size=8, shuffle=False, num_workers=4)

In [16]:
model = MaskRCNNModule(num_classes=21)
trainer = lightning.Trainer(accelerator='gpu', max_epochs=30)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type     | Params
-----------------------------------
0 | model | MaskRCNN | 44.0 M
-----------------------------------
43.8 M    Trainable params
222 K     Non-trainable params
44.0 M    Total params
176.099   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/gtangg12/mambaforge/envs/auto-augment/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/gtangg12/mambaforge/envs/auto-augment/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/gtangg12/mambaforge/envs/auto-augment/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/gtangg12/mambaforge/envs/auto-augment/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 142, in collate
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/home/gtangg12/mambaforge/envs/auto-augment/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 142, in <listcomp>
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/home/gtangg12/mambaforge/envs/auto-augment/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 150, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.PngImagePlugin.PngImageFile'>
