In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [1]:
%pip install lightning-flash

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Note: you may need to restart the kernel to use updated packages.


In [2]:
from flash.image import SemanticSegmentationData

dm = SemanticSegmentationData.from_folders(
    train_folder="img/google/img",
    train_target_folder="img/google/mask",
    val_split=0.1,
    image_size=(256, 256),
    num_classes=1,
)

ValueError: Found inconsistent files in input folder: img/google/img and mask folder: img/google/mask. All input files must have a corresponding mask file with the same name.

In [None]:
from flash.image import SemanticSegmentation

print(SemanticSegmentation.available_heads())
# ['deeplabv3', 'deeplabv3plus', 'fpn', ..., 'unetplusplus']

print(SemanticSegmentation.available_backbones("fpn"))
# ['densenet121', ..., 'xception'] # + 113 models

print(SemanticSegmentation.available_pretrained_weights("efficientnet-b0"))
# ['imagenet', 'advprop']

model = SemanticSegmentation(head="fpn", backbone="efficientnet-b0", pretrained="advprop", num_classes=dm.num_classes)

In [None]:
from flash import Trainer

trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
trainer.save_checkpoint("semantic_segmentation_model.pt")

In [None]:
import json
import torch
class NiiDataset(Dataset):
    def __init__(self, train: bool, data_dir: str):
        """
        Args:
            train: whether to use the training or the validation split
            data_dir: directory containing the data
        """
        with open(os.path.join(data_dir, 'dataset.json')) as f:
            content = json.load(f)['training']

            num_train_samples = int(len(content) * 0.9)

            # Split train data into training and validation,
            # since test data contains no ground truth
            if train:
                data = content[:num_train_samples]
            else:
                data = content[num_train_samples:]

            self.data = data
            self.data_dir = data_dir

    def __getitem__(self, item: int) -> dict:
        """
        Loads and Returns a single sample

        Args:
            item: index specifying which item to load

        Returns:
            dict: the loaded sample
        """
        sample = self.data[item]
        img = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(self.data_dir, sample['image'])))

        # add channel dim if necesary
        if img.ndim == 3:
            img = img[None]

        label = sitk.GetArrayFromImage(
            sitk.ReadImage(os.path.join(self.data_dir, sample['label'])))

        # convert multiclass to binary task by combining all positives
        label = label > 0

        # add channel dim if necessary
        if label.ndim == 3:
            label = label[None]
        return {'data': torch.from_numpy(img).float(),
                'label': torch.from_numpy(label).float()}

    def __len__(self) -> int:
        """
        Adds a length to the dataset

        Returns:
            int: dataset's length
        """
        return len(self.data)