In [3]:
import os
import numpy as np
import torch
from glob import glob
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from PIL import Image
from albumentations import Compose, HorizontalFlip, VerticalFlip
from albumentations.pytorch import ToTensorV2
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor

In [4]:
def splitter(images, random_state=0):
    """Split image paths into train/validation lists, grouped by patient.

    The patient id of an image is the name of the directory that directly
    contains it. 20% of the unique patients (rounded down) are drawn without
    replacement for validation; every image of a drawn patient goes to the
    validation list and all remaining images go to the training list, so a
    patient never appears in both splits.

    Args:
        images: Sequence of image paths (``str`` or path-like).
        random_state: Seed for the patient draw.

    Returns:
        Tuple ``(train_images, valid_images)``. Both are lists that keep the
        relative order of ``images``. With fewer than five unique patients
        the validation list is empty because ``int(n * 0.20)`` is 0.
    """
    # Resolve each image's patient id once. Matching on the exact directory
    # name (instead of a substring test on the whole path) keeps patient
    # "P_1" from also capturing the images of "P_10", "P_11", ...
    patient_ids = [os.path.basename(os.path.dirname(str(path))) for path in images]
    unique_patients = np.unique(np.array(patient_ids))

    # Local generator: the draw is reproducible without reseeding NumPy's
    # global RNG as a side effect of calling this function.
    rng = np.random.RandomState(random_state)
    n_valid = int(len(unique_patients) * 0.20)
    valid_patients = set(
        rng.choice(unique_patients, size=n_valid, replace=False).tolist()
    )

    # Single pass partition; preserves input order in both outputs so the
    # result is deterministic for a given input list and seed.
    train_images, valid_images = [], []
    for path, patient in zip(images, patient_ids):
        if patient in valid_patients:
            valid_images.append(path)
        else:
            train_images.append(path)

    return train_images, valid_images

In [2]:
class CBISDDSM_Dataset(Dataset):
    """Dataset that pairs each mammogram image with its segmentation mask.

    The mask path is derived from the image path by replacing ``"mammo"``
    with ``"mask"``. Each sample is optionally augmented and then run through
    a SegFormer image processor, so image and mask go through the same
    processor call.

    Args:
        images: Sequence of image file paths.
        transforms: Optional callable invoked as
            ``transforms(image=<ndarray>, mask=<ndarray>)`` that returns a
            mapping with ``'image'`` and ``'mask'`` entries.
        image_processor: Optional processor invoked as
            ``image_processor(images=..., segmentation_maps=...,
            return_tensors="pt")``. Defaults to ``SegformerImageProcessor()``.
    """

    def __init__(self, images, transforms=None, image_processor=None):
        self.images = images
        # __getitem__ reads ``self.transforms``; the original stored the value
        # only as ``self.transform``, which raised AttributeError on first use.
        self.transforms = transforms
        # Alias kept for any code that reads the original attribute name.
        self.transform = transforms
        # Built once here instead of once per sample in __getitem__.
        self.image_processor = (
            image_processor if image_processor is not None else SegformerImageProcessor()
        )

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, augment and preprocess the sample at position ``idx``.

        Returns:
            Dict with ``'image'`` (the processor's ``pixel_values`` for this
            sample) and ``'mask'`` (the processor's ``labels`` for this
            sample), each with the leading batch dimension removed.
        """
        img_path = self.images[idx]
        # NOTE(review): this replaces every "mammo" in the path, directory
        # names included - confirm that matches the on-disk layout.
        mask_path = img_path.replace("mammo", "mask")

        # Context managers close the underlying files once pixels are copied.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file)

        if self.transforms:
            transformed = self.transforms(image=image, mask=mask.astype(np.uint8))
            image = transformed['image']
            mask = transformed['mask']

        # Run image and mask through the processor together and return the
        # processed tensors (the original returned the raw processor output
        # as 'image' alongside the unprocessed mask).
        encoded = self.image_processor(
            images=image, segmentation_maps=mask, return_tensors="pt"
        )
        return {
            'image': encoded['pixel_values'].squeeze(0),
            'mask': encoded['labels'].squeeze(0),
        }

In [5]:
class CBISDDSMDataModule(pl.LightningDataModule):
    """Lightning data module that builds train/validation loaders from a folder.

    Args:
        image_dir: Root directory searched recursively for ``*_mammo.png``.
        batch_size: Batch size used by both dataloaders.
        transform: Transform passed to the training dataset.
        val_transform: Transform passed to the validation dataset.
    """

    def __init__(self, image_dir: str, batch_size: int = 8, transform=None, val_transform=None):
        super().__init__()
        self.image_dir = image_dir
        self.batch_size = batch_size
        self.transform = transform
        self.val_transform = val_transform

    def setup(self, stage=None):
        """Discover images, split them by patient and build both datasets.

        Raises:
            FileNotFoundError: If no ``*_mammo.png`` file exists under
                ``image_dir``.
        """
        # ``self.image_dir`` (the original referenced an undefined bare
        # ``image_dir``). Sorted because glob order depends on the filesystem
        # and would otherwise make the split input non-reproducible.
        images = sorted(glob(f"{self.image_dir}/**/*_mammo.png", recursive=True))
        if not images:
            raise FileNotFoundError(
                f"No '*_mammo.png' files found under {self.image_dir!r}"
            )
        train_images, val_images = splitter(images, random_state=123)

        self.train_dataset = CBISDDSM_Dataset(train_images, self.transform)
        self.val_dataset = CBISDDSM_Dataset(val_images, self.val_transform)

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

In [None]:
class SegmentationModel(pl.LightningModule):
    """Lightning wrapper that fine-tunes SegFormer with a cross-entropy loss.

    Args:
        num_classes: Number of segmentation classes; sets the size of the
            model's classification head.
        learning_rate: Learning rate for the Adam optimizer.
        pretrained_model_name: Checkpoint name or path handed to
            ``from_pretrained``. The default replaces the original
            ``'name-of-the-model'`` placeholder, which cannot be loaded.
    """

    def __init__(self, num_classes, learning_rate=2e-4, pretrained_model_name='nvidia/mit-b0'):
        super().__init__()

        # ``num_labels`` wires ``num_classes`` into the head (it was unused
        # before). ``ignore_mismatched_sizes`` lets a checkpoint whose head
        # has a different class count load with a freshly initialised head.
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_model_name,
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        self.learning_rate = learning_rate

    def forward(self, x):
        """Return the raw class logits for a batch of pixel values."""
        # The HF model returns an output object; the loss needs its tensor.
        return self.segformer(pixel_values=x).logits

    def _shared_step(self, batch):
        """Compute the cross-entropy loss for one batch (train or val)."""
        x, y = batch['image'], batch['mask']
        logits = self(x)
        # SegFormer emits logits at a reduced spatial resolution; resize them
        # to the label size so cross_entropy sees matching shapes.
        logits = torch.nn.functional.interpolate(
            logits, size=y.shape[-2:], mode="bilinear", align_corners=False
        )
        # cross_entropy with class-index targets requires an integer tensor.
        return torch.nn.functional.cross_entropy(logits, y.long())

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; returns the loss tensor."""
        loss = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; returns ``{'val_loss': loss}``."""
        loss = self._shared_step(batch)
        self.log('val_loss', loss, on_step=True, on_epoch=True, logger=True)
        return {'val_loss': loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

# Define transformations: random flips for training only; both pipelines end
# with ToTensorV2.
train_transforms = Compose([
    HorizontalFlip(p=0.5),
    VerticalFlip(p=0.5),
    ToTensorV2()
])

val_transforms = Compose([
    ToTensorV2()
])

# Initialize data module and model.
# CBISDDSMDataModule has no ``mask_dir`` parameter (passing it raised
# TypeError); mask paths are derived from the image paths by the dataset.
data_module = CBISDDSMDataModule(image_dir='path/to/images',
                                 batch_size=32,
                                 transform=train_transforms,
                                 val_transform=val_transforms)

model = SegmentationModel(num_classes=3)

# Initialize a trainer on a single GPU. ``accelerator``/``devices`` replace
# the ``gpus`` argument, which newer Lightning releases no longer accept.
trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)

# Train the model
trainer.fit(model, datamodule=data_module)