In [None]:
from datetime import datetime
import os
from pathlib import Path
import tempfile
from glob import glob

import torch
from torch.utils.data import random_split, DataLoader
import monai
import pandas as pd
import torchio as tio
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import seaborn as sns

# Show MONAI's configuration report for this environment (handy when reproducing or debugging).
monai.config.print_config()

# Configuration and TensorBoard setup

In [2]:
# Plot styling. NOTE(review): the figure-size override is kept after sns.set()
# on the assumption that sns.set() rewrites matplotlib rcParams — confirm before reordering.
sns.set()
plt.rcParams["figure.figsize"] = 12, 8
# MONAI determinism helper called with its default arguments (intended to make runs repeatable).
monai.utils.set_determinism()

# Load the TensorBoard notebook extension so `%tensorboard` can be used in a later cell.
%load_ext tensorboard

# Setup Data Directory

In [3]:
# Use the directory named by MONAI_DATA_DIRECTORY when that variable is set;
# otherwise create a fresh temporary directory for this session.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

C:\Users\LESC\AppData\Local\Temp\tmpufciwlvz


# Data

In [None]:
# Helpers to read images and reorganise the DiaretDB1 dataset
def pil_loader(image_path, is_mask=False):
    """Open an image, halve its resolution and convert its colour mode.

    Args:
        image_path: Path of the image file to open.
        is_mask: When true the result is converted to mode ``'L'``
            (single-channel greyscale); otherwise to ``'RGB'``.

    Returns:
        The resized and converted PIL image.
    """
    # Imported locally because the notebook's import cell does not bring in PIL,
    # which made the bare `Image` name below undefined.
    from PIL import Image

    with open(image_path, 'rb') as f:
        img = Image.open(f)
        # PIL reports size as (width, height); resize() takes the same order.
        width, height = img.size
        mode = 'L' if is_mask else 'RGB'
        # resize() reads the pixel data, so the result no longer needs `f` open.
        return img.resize((width // 2, height // 2)).convert(mode)


def create_dir(path: Path):
    """Create ``path`` together with any missing parent directories.

    Nothing happens when ``path`` already exists (whatever kind of filesystem
    entry it is), so the function can safely be called repeatedly.

    Args:
        path: Directory to create.
    """
    if not path.exists():
        # parents=True replaces the hand-written recursion over path.parent;
        # exist_ok=True tolerates the directory appearing between the check and the call.
        path.mkdir(parents=True, exist_ok=True)


def adaptar_dataset(root_dir: Path, dir_fundus_imgs: Path, dir_groundtruths_imgs: Path, annotations_path: Path):
    """Copy the images listed in a diaretdb1_v1.1 annotation file into a new folder tree.

    The annotation ``.txt`` file is read as a header-less CSV whose first column
    holds image file names. For every listed name the fundus image and its four
    lesion masks are loaded through ``pil_loader`` and saved under
    ``root_dir/<ANNOTATION FILE STEM IN UPPER CASE>``, giving a separate folder
    per annotation file (e.g. train set / test set) that is easier to reuse.

    Args:
        root_dir: Directory under which the new dataset folder is created.
        dir_fundus_imgs: Directory holding the original fundus images.
        dir_groundtruths_imgs: Directory holding one sub-directory of masks per lesion type.
        annotations_path: Annotation file listing the image names to copy.
    """
    lesion_types = ['hardexudates', 'hemorrhages', 'redsmalldots', 'softexudates']

    path_base = Path(root_dir/str(annotations_path.stem).upper())
    fundus_out = path_base/dir_fundus_imgs.name
    create_dir(fundus_out)
    labels = pd.read_csv(annotations_path, header=None).sort_values(by=0, ascending=True)
    image_names = list(labels[0])

    # Save each fundus image once; doing it inside the lesion loop repeated
    # identical work for every lesion type.
    for name in image_names:
        pil_loader(dir_fundus_imgs/name).save(fundus_out/name)
    print("Nova pasta com fundoscopias criada.")

    for dir_masks in lesion_types:
        masks_out = path_base/'ddb1_groundtruth'/dir_masks
        create_dir(masks_out)
        for name in image_names:
            # is_mask=True keeps masks single-channel instead of converting them to RGB.
            mask = pil_loader(dir_groundtruths_imgs/dir_masks/name, is_mask=True)
            mask.save(masks_out/name)
        print("Nova pasta com mascaras de lesões criada.")

In [4]:
class MedicalDecathlonDataModule(pl.LightningDataModule):
    """Lightning data module serving one Medical Segmentation Decathlon task through TorchIO.

    The task archive is downloaded (if missing) into the notebook-level
    ``root_dir``, training images/labels and test images are wrapped in
    ``tio.Subject`` objects, and the training subjects are randomly split into
    training and validation sets.

    Args:
        task: Name of the task folder/archive, e.g. ``"Task04_Hippocampus"``.
        batch_size: Batch size used by all three dataloaders.
        train_val_ratio: Fraction of the labelled subjects used for training;
            the remainder is used for validation.
    """

    def __init__(self, task, batch_size, train_val_ratio):
        super().__init__()
        self.task = task
        self.batch_size = batch_size
        # `root_dir` is the module-level data directory defined in an earlier cell.
        self.base_dir = root_dir
        self.dataset_dir = os.path.join(root_dir, task)
        self.train_val_ratio = train_val_ratio
        self.subjects = None
        self.test_subjects = None
        self.preprocess = None
        self.transform = None
        self.train_set = None
        self.val_set = None
        self.test_set = None

    def get_max_shape(self, subjects):
        """Return the per-axis maximum spatial shape found among ``subjects``.

        Raises:
            ValueError: If ``subjects`` is empty, since no shape can be derived.
        """
        shapes = [subject.spatial_shape for subject in subjects]
        if not shapes:
            raise ValueError("Cannot compute the maximum shape of an empty list of subjects")
        return tuple(max(axis_sizes) for axis_sizes in zip(*shapes))

    def download_data(self):
        """Download the task archive when needed and list its NIfTI files.

        Returns:
            A tuple ``(image_training_paths, label_training_paths, image_test_paths)``
            of sorted path lists.

        Raises:
            FileNotFoundError: If no training images are found after the download step.
            ValueError: If the number of training images and labels differ.
        """
        if not os.path.isdir(self.dataset_dir):
            # The archive must unpack to `dataset_dir` (= base_dir/task), which is
            # where the glob patterns below look for the images.
            url = f"https://msd-for-monai.s3-us-west-2.amazonaws.com/{self.task}.tar"
            monai.apps.download_and_extract(url, output_dir=self.base_dir)

        image_training_paths = sorted(glob(os.path.join(self.dataset_dir, "imagesTr", "*.nii*")))
        label_training_paths = sorted(glob(os.path.join(self.dataset_dir, "labelsTr", "*.nii*")))
        image_test_paths = sorted(glob(os.path.join(self.dataset_dir, "imagesTs", "*.nii*")))

        if not image_training_paths:
            raise FileNotFoundError(f"No training images found under {self.dataset_dir}")
        if len(image_training_paths) != len(label_training_paths):
            # zip() in prepare_data would otherwise silently drop the unmatched files.
            raise ValueError(
                f"Found {len(image_training_paths)} training images but "
                f"{len(label_training_paths)} label maps under {self.dataset_dir}"
            )
        return image_training_paths, label_training_paths, image_test_paths

    def prepare_data(self):
        """Build the lists of labelled (training) and unlabelled (test) subjects."""
        image_training_paths, label_training_paths, image_test_paths = self.download_data()

        # 'image' and 'label' are arbitrary names for the images
        self.subjects = [
            tio.Subject(image=tio.ScalarImage(image_path), label=tio.LabelMap(label_path))
            for image_path, label_path in zip(image_training_paths, label_training_paths)
        ]
        self.test_subjects = [
            tio.Subject(image=tio.ScalarImage(image_path)) for image_path in image_test_paths
        ]

    def get_preprocessing_transform(self):
        """Return the deterministic preprocessing applied to every subject."""
        preprocess = tio.Compose(
            [
                tio.RescaleIntensity((-1, 1)),
                # Bring every volume to the largest shape seen in the whole dataset.
                tio.CropOrPad(self.get_max_shape(self.subjects + self.test_subjects)),
                tio.EnsureShapeMultiple(8),  # for the U-Net
                tio.OneHot(),
            ]
        )
        return preprocess

    def get_augmentation_transform(self):
        """Return the random augmentation applied to training subjects only."""
        augment = tio.Compose(
            [
                tio.RandomAffine(),
                tio.RandomGamma(p=0.5),
                tio.RandomNoise(p=0.5),
                tio.RandomMotion(p=0.1),
                tio.RandomBiasField(p=0.25),
            ]
        )
        return augment

    def setup(self, stage=None):
        """Split the labelled subjects and create the train/val/test datasets."""
        if self.subjects is None or self.test_subjects is None:
            # Allows setup() to work even when prepare_data() was not run on this object.
            self.prepare_data()

        num_subjects = len(self.subjects)
        num_train_subjects = int(round(num_subjects * self.train_val_ratio))
        num_val_subjects = num_subjects - num_train_subjects
        splits = num_train_subjects, num_val_subjects
        train_subjects, val_subjects = random_split(self.subjects, splits)

        self.preprocess = self.get_preprocessing_transform()
        augment = self.get_augmentation_transform()
        self.transform = tio.Compose([self.preprocess, augment])

        # Only the training set is augmented; validation and test are just preprocessed.
        self.train_set = tio.SubjectsDataset(train_subjects, transform=self.transform)
        self.val_set = tio.SubjectsDataset(val_subjects, transform=self.preprocess)
        self.test_set = tio.SubjectsDataset(self.test_subjects, transform=self.preprocess)

    def train_dataloader(self):
        """Return the training dataloader (reshuffled every epoch)."""
        return DataLoader(self.train_set, self.batch_size, shuffle=True, num_workers=2)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return DataLoader(self.val_set, self.batch_size, num_workers=2)

    def test_dataloader(self):
        """Return the test dataloader (fixed order)."""
        return DataLoader(self.test_set, self.batch_size, num_workers=2)

In [5]:
# Build the data module for the hippocampus task, prepare it, and report the split sizes.
data = MedicalDecathlonDataModule(task="Task04_Hippocampus", batch_size=16, train_val_ratio=0.8)

data.prepare_data()
data.setup()
for caption, subset in (
    ("Training:  ", data.train_set),
    ("Validation: ", data.val_set),
    ("Test:      ", data.test_set),
):
    print(caption, len(subset))

Task04_Hippocampus.tar: 27.1MB [00:15, 1.88MB/s]                                                                                                                                       

2023-03-23 14:09:32,273 - INFO - Downloaded: C:\Users\LESC\AppData\Local\Temp\tmpp7lydt4n\Task04_Hippocampus.tar





2023-03-23 14:09:32,274 - INFO - Expected md5 is None, skip md5 check for file C:\Users\LESC\AppData\Local\Temp\tmpp7lydt4n\Task04_Hippocampus.tar.
2023-03-23 14:09:32,275 - INFO - Writing into directory: C:\Users\LESC\AppData\Local\Temp\tmpufciwlvz.
Training:   208
Validation:  52
Test:       130


# Lightning model

In [6]:
class Model(pl.LightningModule):
    """Lightning wrapper pairing a network with a loss and an optimizer class.

    Args:
        net: Network called on the image tensor of each batch.
        criterion: Callable ``criterion(prediction, target)`` returning the loss.
        learning_rate: Learning rate handed to the optimizer.
        optimizer_class: Optimizer class, instantiated as
            ``optimizer_class(parameters, lr=learning_rate)``.
    """

    def __init__(self, net, criterion, learning_rate, optimizer_class):
        super().__init__()
        self.lr = learning_rate
        self.net = net
        self.criterion = criterion
        self.optimizer_class = optimizer_class

    def configure_optimizers(self):
        """Instantiate the optimizer over all parameters of this module."""
        return self.optimizer_class(self.parameters(), lr=self.lr)

    def prepare_batch(self, batch):
        """Extract the ``(image, label)`` data tensors from a batch."""
        images = batch["image"][tio.DATA]
        targets = batch["label"][tio.DATA]
        return images, targets

    def infer_batch(self, batch):
        """Run the network on a batch and return ``(prediction, target)``."""
        inputs, targets = self.prepare_batch(batch)
        return self.net(inputs), targets

    def _batch_loss(self, batch):
        """Compute the criterion value for one batch."""
        prediction, target = self.infer_batch(batch)
        return self.criterion(prediction, target)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss (also shown in the progress bar)."""
        loss = self._batch_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        loss = self._batch_loss(batch)
        self.log("val_loss", loss)
        return loss

In [7]:
# Lightning model: a 3D U-Net (1 input channel, 3 output channels) trained with
# Dice + cross-entropy loss and AdamW.
model = Model(
    net=monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=3,
        channels=(8, 16, 32, 64),
        strides=(2, 2, 2),
    ),
    criterion=monai.losses.DiceCELoss(softmax=True),
    learning_rate=1e-2,
    optimizer_class=torch.optim.AdamW,
)
unet = model.net  # keep the notebook-level handle on the bare network

# Early stopping watches the "val_loss" value logged by Model.validation_step.
early_stopping = pl.callbacks.EarlyStopping(monitor="val_loss")
trainer = pl.Trainer(
    callbacks=[early_stopping],
    gpus=0,
    precision='bf16',
)
trainer.logger._default_hp_metric = False

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


# Training

In [None]:
# Record the start time so the total training duration can be reported afterwards.
start = datetime.now()
print("Training started at", start)
# Fit `model` using the dataloaders supplied by the `data` module.
trainer.fit(model=model, datamodule=data)
print("Training duration:", datetime.now() - start)

In [None]:
# Open TensorBoard inline, reading logs from the `lightning_logs` directory.
%tensorboard --logdir lightning_logs

# Plot validation results

In [13]:
# Dice score on the validation set, kept per sample (reduction="none") and
# excluding the background channel.
model.to("cpu")
model.eval()  # run inference with the network in evaluation mode
get_dice = monai.metrics.DiceMetric(include_background=False, reduction="none")
with torch.no_grad():
    for batch in data.val_dataloader():
        inputs, targets = model.prepare_batch(batch)
        logits = model.net(inputs.to(model.device))
        labels = logits.argmax(dim=1)
        # Pass num_classes explicitly: by default one_hot() infers it from the largest
        # label present, so a batch that never predicts the last class would yield
        # fewer channels than `targets`.
        labels_one_hot = torch.nn.functional.one_hot(labels, num_classes=logits.shape[1])
        # one_hot() appends the class axis last; move it to the channel position.
        labels_one_hot = labels_one_hot.permute(0, 4, 1, 2, 3)
        get_dice(labels_one_hot.to(model.device), targets.to(model.device))
    # The metric object accumulates every batch, so aggregate once after the loop.
    all_dices = get_dice.aggregate()
    get_dice.reset()

In [None]:
# One record per (sample, structure) so seaborn can group the scores by label.
records = []
# tolist() yields plain Python floats, so the "Dice" column is numeric
# rather than an object column of 0-d tensors.
for ant, post in all_dices.tolist():
    records.append({"Dice": ant, "Label": "Anterior"})
    records.append({"Dice": post, "Label": "Posterior"})
df = pd.DataFrame.from_records(records)
ax = sns.stripplot(x="Label", y="Dice", data=df, size=10, alpha=0.5)
ax.set_title("Dice scores");

# Test

In [15]:
# Predict label maps for the first test batch only.
# next(iter(...)) raises if the loader is empty, instead of silently reusing the
# `batch` / `labels` variables left over from the validation cell.
batch = next(iter(data.test_dataloader()))
with torch.no_grad():
    inputs = batch["image"][tio.DATA].to(model.device)
    # keepdim=True keeps a channel axis so the predictions can be attached as images.
    labels = model.net(inputs).argmax(dim=1, keepdim=True).cpu()
# Rebuild the individual subjects and attach the predicted label map to each one.
batch_subjects = tio.utils.get_subjects_from_batch(batch)
tio.utils.add_images_from_batch(batch_subjects, labels, tio.LabelMap)

In [None]:
# Plot every subject of the first test batch (each now carries the label map added in the previous cell).
for subject in batch_subjects:
    subject.plot()