In [None]:
from pathlib import Path

import torchio as tio
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import numpy as np

import logging
logging.getLogger().handlers.clear()

from unet import UNet

In [None]:
def change_img_to_label_path(path):
    """Return the label path that corresponds to an image path.

    The first ``imagesTr`` component of ``path`` is replaced by ``labelsTr``;
    every other component (including the file name) is kept unchanged.
    Raises ValueError if ``path`` has no ``imagesTr`` component.
    """
    components = list(path.parts)
    pos = components.index("imagesTr")
    return Path(*components[:pos], "labelsTr", *components[pos + 1:])

In [None]:
# Collect the training CT volumes.
# - Sorting makes the fixed train/val split below (subjects[:53] / subjects[53:])
#   reproducible; raw glob order depends on the file system.
# - The prefix match is case-insensitive on every OS (glob("Lung*") only is on
#   Windows), so both "Lung*" and "lung*" file names are found on POSIX too.
path = Path("./Task06_Lung/imagesTr/")
subjects_paths = sorted(p for p in path.glob("*") if p.name.lower().startswith("lung"))

In [None]:
# Pair each CT volume with its segmentation mask (same file name under labelsTr).
# zip + map computes each label path just before its Subject is built.
subjects = [
    tio.Subject(CT=tio.ScalarImage(img_path), Label=tio.LabelMap(lbl_path))
    for img_path, lbl_path in zip(subjects_paths, map(change_img_to_label_path, subjects_paths))
]

In [None]:
# Sanity check: every CT volume must be stored in ("L", "A", "S") orientation.
assert all(subj["CT"].orientation == ("L", "A", "S") for subj in subjects)

In [None]:
# Shape of the first subject's label tensor (TorchIO images are 4D: channels + 3 spatial axes).
subjects[0]["Label"].data.shape

In [None]:
# Deterministic preprocessing, shared by training and validation.
preprocess = tio.Compose(
    [
        tio.ToCanonical(),               # reorient to canonical (RAS+) orientation
        tio.CropOrPad((500, 500, 300)),  # fixed spatial size
        tio.RescaleIntensity((-1, 1)),   # intensities mapped to [-1, 1]
    ]
)

# Random affine augmentation, applied to the training data only.
augmentation = tio.RandomAffine(degrees=(-10, 10), scales=(0.9, 1.1))

train_transformation = tio.Compose([preprocess, augmentation])
val_transformation = preprocess  # no augmentation for validation

In [None]:
# Fixed split: the first 53 subjects are used for training (augmented), the rest for validation.
train_dataset = tio.SubjectsDataset(subjects[:53], transform=train_transformation)
val_dataset = tio.SubjectsDataset(subjects[53:], transform=val_transformation)

# 96^3 patches whose centre label is 1 with probability 0.99999 (0 otherwise).
sampler = tio.data.LabelSampler(
    patch_size=96,
    label_name="Label",
    label_probabilities={0: 0.00001, 1: 0.99999},
)

In [None]:
# One patch queue per split, both configured identically (40 patches buffered,
# 5 patches drawn per volume, 8 loader workers inside the queue).
train_patches_queue, val_patches_queue = (
    tio.Queue(
        dataset,
        max_length=40,
        samples_per_volume=5,
        sampler=sampler,
        num_workers=8,
    )
    for dataset in (train_dataset, val_dataset)
)

In [None]:
batch_size = 12
# num_workers stays 0: the tio.Queue objects already load patches in their own workers.
train_loader, val_loader = (
    torch.utils.data.DataLoader(patch_queue, batch_size=batch_size, num_workers=0)
    for patch_queue in (train_patches_queue, val_patches_queue)
)

In [None]:
class DoubleConv(torch.nn.Module):
    """Two 3x3x3 convolutions, each followed by a ReLU.

    padding=1 keeps the spatial size unchanged.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            torch.nn.Conv3d(in_channels, out_channels, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv3d(out_channels, out_channels, 3, padding=1),
            torch.nn.ReLU(),
        ]
        # The attribute name `step` and this layer order define the state_dict
        # keys (step.0.*, step.2.*); keep them so saved checkpoints still load.
        self.step = torch.nn.Sequential(*layers)

    def forward(self, X):
        out = self.step(X)
        return out

In [None]:
class UNet(torch.nn.Module):
    """3D U-Net for single-channel (binary) segmentation.

    Structure:
    - Three encoder stages (DoubleConv, then 2x max-pool).
    - A bottleneck DoubleConv.
    - Three decoder stages (2x trilinear upsampling, concatenation with the
      matching encoder map, DoubleConv).
    - A final 1x1x1 convolution to one output channel.

    The output is raw logits; no sigmoid is applied. Each spatial size of the
    input must be divisible by 8 (three 2x poolings) so upsampled maps line up
    with the skip connections.

    Note: this definition shadows the ``UNet`` imported from ``unet`` at the
    top of the notebook.
    """

    def __init__(self):
        """Set up the U-Net layers.

        The attribute names layer1..layer8 are part of the state_dict keys, so
        they are kept stable for checkpoint compatibility.
        """
        super().__init__()

        # Encoder.
        self.layer1 = DoubleConv(1, 32)
        self.layer2 = DoubleConv(32, 64)
        self.layer3 = DoubleConv(64, 128)
        self.layer4 = DoubleConv(128, 256)  # bottleneck

        # Decoder: input channels = upsampled channels + skip-connection channels.
        self.layer5 = DoubleConv(256 + 128, 128)
        self.layer6 = DoubleConv(128 + 64, 64)
        self.layer7 = DoubleConv(64 + 32, 32)
        self.layer8 = torch.nn.Conv3d(32, 1, 1)  # 1x1x1 projection to one logit channel

        self.maxpool = torch.nn.MaxPool3d(2)

    @staticmethod
    def _upsample(x):
        """Double every spatial dimension with trilinear interpolation."""
        # Same result as the nn.Upsample(scale_factor=2, mode="trilinear") used
        # before (its align_corners default resolves to False). Calling the
        # functional form avoids building a new module on every forward pass,
        # and passing align_corners explicitly silences PyTorch's warning.
        return torch.nn.functional.interpolate(
            x, scale_factor=2, mode="trilinear", align_corners=False
        )

    def _up_block(self, x, skip, layer):
        """Upsample ``x``, concatenate ``skip`` along channels, then apply ``layer``."""
        return layer(torch.cat([self._upsample(x), skip], dim=1))

    def forward(self, x):
        """Segment a batch of volumes.

        Args:
            x: presumably a (batch, 1, D1, D2, D3) tensor. The first conv expects
               one channel, and each spatial size must be divisible by 8.

        Returns:
            Logits with the same spatial size as ``x`` and one channel.
        """
        # Encoder: keep the pre-pooling maps for the skip connections.
        x1 = self.layer1(x)
        x2 = self.layer2(self.maxpool(x1))
        x3 = self.layer3(self.maxpool(x2))

        # Bottleneck.
        x4 = self.layer4(self.maxpool(x3))

        # Decoder with skip connections.
        x5 = self._up_block(x4, x3, self.layer5)
        x6 = self._up_block(x5, x2, self.layer6)
        x7 = self._up_block(x6, x1, self.layer7)

        # Predicted segmentation (logits).
        return self.layer8(x7)

In [None]:
class DiceLoss(torch.nn.Module):
    """Soft Dice loss, ``1 - dice``, computed over all elements of the batch at once.

    ``pred`` is expected to hold probabilities in [0, 1] (apply a sigmoid to
    logits first); ``mask`` holds the binary ground truth. Both are flattened
    before use, so they only need the same number of elements.

    ``weight`` and ``size_average`` are accepted for interface compatibility
    but are ignored.
    """

    def __init__(self, weight=None, size_average=True):
        super().__init__()

    def forward(self, pred, mask, smooth=1):
        """Return ``1 - (2*sum(pred*mask) + smooth) / (sum(pred) + sum(mask) + smooth)``.

        ``smooth`` keeps the ratio defined (and equal to 1) when both inputs
        are all zeros.
        """
        # reshape (unlike view) also works on non-contiguous inputs, e.g. permuted tensors.
        pred = pred.reshape(-1)
        # Cast the (possibly integer) mask to pred's dtype so the arithmetic stays floating point.
        mask = mask.reshape(-1).to(pred.dtype)

        intersection = (pred * mask).sum()
        dice = (2.0 * intersection + smooth) / (pred.sum() + mask.sum() + smooth)

        return 1 - dice

In [None]:
class Segmenter(pl.LightningModule):
    """Lightning wrapper that trains the 3D UNet with a soft Dice loss.

    The UNet returns logits. A sigmoid turns them into foreground
    probabilities before the Dice loss and before the 0.5 threshold used for
    the TensorBoard figures. The logged "Train Dice" / "Val Dice" values are
    Dice *losses* (1 - dice), so lower is better.
    """

    def __init__(self):
        super().__init__()

        self.model = UNet()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-6)
        self.loss_fn = DiceLoss()

    def forward(self, data):
        """Return the raw UNet logits for ``data``."""
        return self.model(data)

    def _shared_step(self, batch):
        """Run the model on a patch batch.

        Returns:
            (img, foreground probabilities, integer label, Dice loss).
        """
        img = batch["CT"]["data"]
        label = batch["Label"]["data"].long()
        # The UNet ends in a plain conv (logits); the Dice loss needs values in [0, 1].
        prob = torch.sigmoid(self(img))
        loss = self.loss_fn(prob, label)
        return img, prob, label, loss

    def training_step(self, batch, batch_idx):
        img, prob, label, loss = self._shared_step(batch)

        self.log("Train Dice", loss)
        if batch_idx % 25 == 0:
            self.log_images(img.cpu(), prob.detach().cpu(), label.cpu(), "Train")
        return loss

    def validation_step(self, batch, batch_idx):
        img, prob, label, loss = self._shared_step(batch)

        self.log("Val Dice", loss)
        self.log_images(img.cpu(), prob.detach().cpu(), label.cpu(), "Val")
        return loss

    def log_images(self, img, pred, label, name):
        """Log a two-panel figure (ground truth | prediction) to TensorBoard.

        Only the first sample and first channel are shown, at axial slice 50.
        ``pred`` holds probabilities and is thresholded at 0.5.
        """
        axial_slice = 50
        ct_slice = img[0, 0, :, :, axial_slice].numpy()
        label_slice = label[0, 0, :, :, axial_slice].numpy()
        pred_slice = (pred[0, 0, :, :, axial_slice] > 0.5).numpy()

        fig, axis = plt.subplots(1, 2)

        # Left panel: CT with the ground-truth label overlaid.
        axis[0].imshow(ct_slice, cmap="bone")
        axis[0].imshow(np.ma.masked_where(label_slice == 0, label_slice), alpha=0.5)

        # Right panel: CT with the thresholded prediction overlaid.
        axis[1].imshow(ct_slice, cmap="bone")
        axis[1].imshow(np.ma.masked_where(~pred_slice, pred_slice.astype(np.float32)), alpha=0.5)

        self.logger.experiment.add_figure(name, fig, self.global_step)

    def configure_optimizers(self):
        return [self.optimizer]

In [None]:
# Fresh, untrained model; Segmenter.__init__ builds the UNet, the Dice loss and the AdamW optimizer.
model = Segmenter()

In [None]:
# Both callbacks watch "Val Dice", which is the logged Dice *loss*, hence mode="min".
# Keep only the single best checkpoint.
checkpoint_callback = ModelCheckpoint(monitor="Val Dice", mode="min", save_top_k=1)

# Stop after 20 validations without an improvement of at least 0.001.
early_stop = EarlyStopping(
    monitor="Val Dice",
    mode="min",
    min_delta=0.001,
    patience=20,
)

In [None]:
# Train on one GPU when one is available; otherwise fall back to CPU (gpus=0).
# This matches the CUDA/CPU fallback used for inference further below, instead of
# failing on machines without a GPU.
gpus = 1 if torch.cuda.is_available() else 0
trainer = pl.Trainer(
    gpus=gpus,
    logger=TensorBoardLogger(save_dir="./logs"),
    log_every_n_steps=1,
    callbacks=[checkpoint_callback, early_stop],
    max_epochs=300,
)

In [None]:
# Train on the patch-queue loaders; the validation loss logged as "Val Dice"
# drives the checkpoint and early-stopping callbacks.
trainer.fit(model, train_loader, val_loader)

In [None]:
from IPython.display import HTML
from celluloid import Camera

In [None]:
# Prefer the best checkpoint from the training run in this kernel session.
# best_model_path is an empty string when fit() was not run, so fall back to
# a previously saved checkpoint (this path is specific to one earlier run).
ckpt_path = checkpoint_callback.best_model_path or r"./logs/lightning_logs/version_9/checkpoints/epoch=10-step=374.ckpt"
model2 = Segmenter.load_from_checkpoint(ckpt_path)
model2 = model2.eval()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model2.to(device);

In [None]:
IDX = 2
# Index the dataset only once. Every val_dataset[IDX] access reloads the volume
# and re-runs the preprocessing transform. Reusing one subject also guarantees
# that mask, image and grid sampler all come from the same transformed sample.
val_subject = val_dataset[IDX]
mask = val_subject["Label"]["data"]
imgs = val_subject["CT"]["data"]

# GridSampler: tile the whole volume into 96^3 patches that overlap by 8 voxels per axis.
grid_sampler = tio.inference.GridSampler(val_subject, 96, (8, 8, 8))

In [None]:
# Collects per-patch predictions and stitches them back into a full volume.
aggregator = tio.inference.GridAggregator(sampler=grid_sampler)

In [None]:
# Iterate over the grid patches in batches of 4 (no shuffling, so order follows the grid).
patch_loader = torch.utils.data.DataLoader(dataset=grid_sampler, batch_size=4)

In [None]:
# Predict every patch (no gradients needed) and give the raw outputs, plus
# where each patch sits in the volume, to the aggregator.
with torch.no_grad():
    for patches_batch in patch_loader:
        locations = patches_batch[tio.LOCATION]
        pred = model2(patches_batch["CT"]["data"].to(device))
        aggregator.add_batch(pred, locations)

In [None]:
# Full-volume tensor of the aggregated model outputs (one channel, raw logits as returned by model2).
output_tensor = aggregator.get_output_tensor()  

In [None]:
# Animate every 2nd axial slice of the CT with the predicted mask overlaid.
SHOW_LABEL = False  # set True to also overlay the ground-truth label

fig, ax = plt.subplots()
camera = Camera(fig)  # create the camera object from celluloid

# The UNet has a single output channel of logits, so argmax over channels would
# always be 0 (an empty mask). A logit > 0 is the same as sigmoid probability > 0.5.
pred = (output_tensor[0] > 0).cpu().numpy()

for i in range(0, output_tensor.shape[3], 2):  # axial view
    ax.imshow(imgs[0, :, :, i], cmap="bone")
    pred_slice = pred[:, :, i]
    # Mask background voxels so only the predicted region is tinted.
    ax.imshow(np.ma.masked_where(~pred_slice, pred_slice.astype(np.float32)), alpha=0.5, cmap="autumn")
    if SHOW_LABEL:
        label_slice = mask[0, :, :, i].numpy()
        ax.imshow(np.ma.masked_where(label_slice == 0, label_slice), alpha=0.5, cmap="jet")
    camera.snap()  # Store the current slice

animation = camera.animate()  # create the animation
plt.close(fig)  # avoid also rendering the last frame as a static image

In [None]:
# Embed the animation as an HTML5 video (matplotlib needs a movie writer such as ffmpeg installed).
HTML(animation.to_html5_video()) 