In [22]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.transforms import ToTensor, Resize, Compose
from torchvision.datasets import VOCSegmentation

import pytorch_lightning as pl
from pl_bolts.models.vision.unet import UNet

import matplotlib.pyplot as plt

from argparse import ArgumentParser
import pdb

In [8]:
torch.cuda.is_available()

True

In [9]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [35]:
class VOCDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./dataset", batch_size=4) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.voc_test = VOCSegmentation(self.data_dir, year="2007", image_set="test", download=True, transform=Compose([ToTensor(), Resize((520, 520))]), target_transform=Compose([ToTensor(), Resize((520, 520))]))
        self.voc_val = VOCSegmentation(self.data_dir, year="2007", image_set="val", download=True, transform=Compose([ToTensor(), Resize((520, 520))]), target_transform=Compose([ToTensor(), Resize((520, 520))]))
        self.voc_train = VOCSegmentation(self.data_dir, year="2007", image_set="train", download=True, transform=Compose([ToTensor(), Resize((520, 520))]), target_transform=Compose([ToTensor(), Resize((520, 520))]))

    def train_dataloader(self):
        return DataLoader(self.voc_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.voc_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.voc_test, batch_size=self.batch_size)

In [11]:
# voc = VOCDataModule()

In [12]:
# for a, b in voc.train_dataloader():
#    plt.imshow(a[2].transpose(0, 2).transpose(0, 1))
#    plt.imshow(b[2].transpose(0, 2).transpose(0, 1), alpha=0.5)
#    break

In [40]:
class SemSegment(pl.LightningModule):
    def __init__(
        self,
        lr: float = 0.01,
        num_classes: int = 19,
        num_layers: int = 4,
        features_start: int = 64,
        bilinear: bool = False,
    ):
        """Basic model for semantic segmentation. Uses UNet architecture by default.

        The default parameters in this model are for the KITTI dataset. Note, if you'd like to use this model as is,
        you will first need to download the KITTI dataset yourself. You can download the dataset `here.
        <http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015>`_

        Implemented by:

            - `Annika Brundyn <https://github.com/annikabrundyn>`_

        Args:
            num_layers: number of layers in each side of U-net (default 5)
            features_start: number of features in first layer (default 64)
            bilinear: whether to use bilinear interpolation (True) or transposed convolutions (default) for upsampling.
            lr: learning (default 0.01)
        """
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = num_layers
        self.features_start = features_start
        self.bilinear = bilinear
        self.lr = lr

        self.net = UNet(
            num_classes=num_classes,
            num_layers=self.num_layers,
            features_start=self.features_start,
            bilinear=self.bilinear,
        )

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_nb):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        print(img.shape)
        print(mask.shape)
        out = self(img)
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
        log_dict = {"train_loss": loss_val}
        return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict}

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        img = img.float()
        mask = mask.long()
        mask = mask.squeeze(1)
        out = self(img)
        loss_val = F.cross_entropy(out, mask, ignore_index=250)
        return {"val_loss": loss_val}

    def validation_epoch_end(self, outputs):
        loss_val = torch.stack([x["val_loss"] for x in outputs]).mean()
        log_dict = {"val_loss": loss_val}
        return {"log": log_dict, "val_loss": log_dict["val_loss"], "progress_bar": log_dict}

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sch]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--lr", type=float, default=0.01, help="adam: learning rate")
        parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net")
        parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer")
        parser.add_argument(
            "--bilinear", action="store_true", default=False, help="whether to use bilinear interpolation or transposed"
        )

        return parser


In [41]:
voc = VOCDataModule()

In [42]:
model = SemSegment(num_classes=20)

In [44]:
trainer = pl.Trainer(max_epochs=10, accelerator="cpu", devices=1)
trainer.fit(model, voc)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /home/lijm1358/lightning_logs


Using downloaded and verified file: ./dataset/VOCtest_06-Nov-2007.tar
Extracting ./dataset/VOCtest_06-Nov-2007.tar to ./dataset
Using downloaded and verified file: ./dataset/VOCtrainval_06-Nov-2007.tar
Extracting ./dataset/VOCtrainval_06-Nov-2007.tar to ./dataset
Using downloaded and verified file: ./dataset/VOCtrainval_06-Nov-2007.tar
Extracting ./dataset/VOCtrainval_06-Nov-2007.tar to ./dataset



  | Name | Type | Params
------------------------------
0 | net  | UNet | 7.7 M 
------------------------------
7.7 M     Trainable params
0         Non-trainable params
7.7 M     Total params
30.817    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

torch.Size([4, 3, 520, 520])
torch.Size([4, 1, 520, 520])


RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4