In [1]:
import os
import sys
import tempfile
import shutil
from glob import glob
import logging
import nibabel as nib
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import NiftiDataset, create_test_image_3d
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.visualize.img2tensorboard import plot_2d_or_3d_image
from monai.transforms import \
    Compose, AddChannel, LoadNifti, \
    ScaleIntensity, RandSpatialCrop, \
    ToTensor, CastToType, SpatialPad

# Print MONAI / Python / NumPy / PyTorch and optional-dependency versions so the
# environment this notebook ran in is recorded in the cell output.
monai.config.print_config()

MONAI version: 0.1.0+460.g82a7933
Python version: 3.7.4 (default, Jul 18 2019, 19:34:02)  [GCC 5.4.0]
Numpy version: 1.18.1
Pytorch version: 1.5.0

Optional dependencies:
Pytorch Ignite version: 0.3.0
Nibabel version: 3.1.0
scikit-image version: 0.14.2
Pillow version: 7.0.0
Tensorboard version: 2.1.0


In [None]:
# Mask every EU CT volume with its segmentation and derive a per-case label.
#
# For each case directory under data_dir:
#   * segmentation values > 6 are remapped to 1 (in memory only; the file on
#     disk is unchanged),
#   * the case label is 1 if the remapped segmentation contains the value 6,
#     otherwise 0,
#   * the image is multiplied by the binarised segmentation and written to
#     image_masked.nii.gz using the source image's affine (the previous version
#     wrote np.eye(4), losing voxel spacing and orientation).
# Cases are visited in sorted order so eu_labels.npy lines up with any other
# sorted listing of data_dir (plain os.listdir order is arbitrary).
data_dir = '/home/marafath/scratch/eu_data'
eu_labels = []

for case in sorted(os.listdir(data_dir)):
    if case.startswith('.'):  # e.g. .ipynb_checkpoints
        continue
    case_dir = os.path.join(data_dir, case)
    img_nii = nib.load(os.path.join(case_dir, 'image.nii.gz'))
    seg_nii = nib.load(os.path.join(case_dir, 'segmentation.nii.gz'))
    img = img_nii.get_fdata()
    seg = seg_nii.get_fdata()
    seg[seg > 6] = 1

    # After the remap no value exceeds 6, so "contains 6" == "max is 6".
    eu_labels.append(1 if np.any(seg == 6) else 0)

    seg[seg > 0] = 1
    img_masked = np.multiply(img, seg)
    masked_nii = nib.Nifti1Image(img_masked, img_nii.affine)
    nib.save(masked_nii, os.path.join(case_dir, 'image_masked.nii.gz'))

labels = np.asarray(eu_labels, np.int64)
np.save('eu_labels.npy', labels)

In [15]:
# Sanity check on the masking cell's output: number of cases, then how many of
# them are labelled 1.
num_cases = len(labels)
num_positive = np.sum(labels)
print(num_cases)
print(num_positive)

96
52


In [16]:
# Full per-case label vector from the masking cell (1 = segmentation contains
# the value 6 after remapping values > 6 to 1, 0 otherwise).
print(labels)

[1 1 0 1 1 0 1 1 1 0 1 1 1 0 0 0 0 1 1 0 1 0 0 0 0 0 1 1 1 0 1 0 1 1 1 0 1
 0 0 0 1 1 1 1 1 0 1 0 0 0 0 1 0 1 0 0 1 0 1 1 1 1 0 0 1 1 1 0 1 1 0 1 0 0
 1 1 1 0 0 1 1 1 0 0 0 0 1 0 1 1 0 0 1 0 1 1]


In [None]:
# Count EU cases whose segmentation contains the value 6 ("Covid") versus the
# rest. The test is np.any(seg == 6), which matches the masking cell's
# criterion (it remaps values > 6 to 1 before checking for 6). The previous
# `np.max(s) == 6.0` check classified any case that also contained a value > 6
# as non-Covid.
data_dir = '/home/marafath/scratch/eu_data'

covid = 0
noncovid = 0
for case in os.listdir(data_dir):
    if case.startswith('.'):  # e.g. .ipynb_checkpoints
        continue
    seg = nib.load(os.path.join(data_dir, case, 'segmentation.nii.gz')).get_fdata()
    if np.any(seg == 6):
        covid += 1
    else:
        noncovid += 1

print('Covid: {}'.format(covid))
print('nonCovid: {}'.format(noncovid))

In [None]:
# For every EU case, write segmentation_no_infection.nii.gz: the original
# segmentation with every voxel of value 6 set to 0 (presumably the infection
# label, judging by the output file name). The output keeps the source
# segmentation's affine; the previous version wrote np.eye(4), which discarded
# voxel spacing and orientation.
data_dir = '/home/marafath/scratch/eu_data'

for case in os.listdir(data_dir):
    if case.startswith('.'):  # e.g. .ipynb_checkpoints
        continue
    case_dir = os.path.join(data_dir, case)
    seg_nii = nib.load(os.path.join(case_dir, 'segmentation.nii.gz'))
    seg = seg_nii.get_fdata()
    seg[seg == 6] = 0

    out_nii = nib.Nifti1Image(seg, seg_nii.affine)
    nib.save(out_nii, os.path.join(case_dir, 'segmentation_no_infection.nii.gz'))
    

In [None]:
# Test data (Iran dataset): one entry per <patient>/<series> directory,
# pairing image.nii.gz with its segmentation_lobes.nii.gz label map.
# Listings are sorted so the test order is reproducible. Hidden entries (e.g.
# .ipynb_checkpoints) and stray files are skipped; previously a file at the
# patient level made os.listdir raise.
data_dir = '/home/marafath/scratch/iran_organized_data/test'

test_images = []
test_labels = []
test_dir = []

for patient in sorted(os.listdir(data_dir)):
    patient_dir = os.path.join(data_dir, patient)
    if patient.startswith('.') or not os.path.isdir(patient_dir):
        continue
    for series in sorted(os.listdir(patient_dir)):
        series_dir = os.path.join(patient_dir, series)
        if series.startswith('.') or not os.path.isdir(series_dir):
            continue
        test_images.append(os.path.join(series_dir, 'image.nii.gz'))
        test_labels.append(os.path.join(series_dir, 'segmentation_lobes.nii.gz'))
        test_dir.append(series_dir)

In [None]:
class CrossentropyND(torch.nn.CrossEntropyLoss):
    """
    Cross-entropy for N-dimensional (e.g. 2D/3D) dense predictions.

    Network has to have NO NONLINEARITY! ``inp`` holds raw logits of shape
    (b, c, *spatial). ``target`` holds class indices and must contain
    b * prod(spatial) elements in total, e.g. (b, 1, *spatial) or (b, *spatial).
    It is cast to long before use.
    """
    def forward(self, inp, target):
        n_classes = inp.shape[1]
        # Move the channel axis last, (b, c, x, y, z) -> (b, x, y, z, c), so that
        # each voxel becomes one row of an (N, c) logits matrix.
        channels_last = (0,) + tuple(range(2, inp.dim())) + (1,)
        logits = inp.permute(*channels_last).contiguous().view(-1, n_classes)
        flat_target = target.long().view(-1)
        return super().forward(logits, flat_target)
    
class SoftDiceLoss(nn.Module):
    """Soft Dice loss that returns the *negative* mean Dice coefficient.

    Per class c: dice_c = (2*tp_c + smooth) / (2*tp_c + fp_c + fn_c + smooth),
    with the soft counts produced by ``get_tp_fp_fn_tn``.
    """

    def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.):
        """
        :param apply_nonlin: optional callable applied to the network output
            before the counts are computed (e.g. a softmax).
        :param batch_dice: if True, also sum over the batch axis, giving one
            Dice per class; otherwise one Dice per (sample, class).
        :param do_bg: if False, channel 0 is excluded from the average.
        :param smooth: additive smoothing in numerator and denominator.
        """
        super().__init__()
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x, y, loss_mask=None):
        spatial_axes = list(range(2, x.dim()))
        reduce_axes = [0] + spatial_axes if self.batch_dice else spatial_axes

        probs = x if self.apply_nonlin is None else self.apply_nonlin(x)
        tp, fp, fn, _ = get_tp_fp_fn_tn(probs, y, reduce_axes, loss_mask, False)

        dice = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        if not self.do_bg:
            # drop channel 0; its position depends on whether a batch axis remains
            dice = dice[1:] if self.batch_dice else dice[:, 1:]
        return -dice.mean()
    
def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
    """
    Soft confusion-matrix terms (tp, fp, fn, tn) between a prediction and a target.

    net_output must be (b, c, x, y(, z)))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z)))
    :param net_output: prediction tensor of shape (b, c, ...)
    :param gt: label map or one-hot target (see above)
    :param axes: axes to sum over; defaults to all axes from 2 on. can be (, ) = no summation
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels; only channel 0 is used
    :param square: if True then tp, fp, fn and tn will be squared before summation
    :return: tuple (tp, fp, fn, tn), each net_output-shaped with ``axes`` summed out
    """
    if axes is None:
        axes = tuple(range(2, net_output.dim()))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # label map without a channel axis: (b, x, y, z) -> (b, 1, x, y, z).
            # reshape (unlike view) also accepts non-contiguous targets.
            gt = gt.reshape((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(net_output.shape, gt.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            # Allocate the one-hot target directly on net_output's device. The
            # old code built it on the CPU and copied it only for CUDA, so any
            # other device failed. The dtype stays torch's default float dtype,
            # as before.
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt.long(), 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if mask is not None:
        # Channel 0 of the mask, kept as a size-1 axis, broadcasts over all c
        # channels. This equals multiplying each channel by mask[:, 0] separately.
        valid = mask[:, 0:1]
        tp = tp * valid
        fp = fp * valid
        fn = fn * valid
        tn = tn * valid

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
        tn = tn ** 2

    if len(axes) > 0:
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn

class DC_and_CE_loss(nn.Module):
    """Weighted sum of soft Dice loss (on softmax probabilities) and cross-entropy."""

    def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum", square_dice=False, weight_ce=1, weight_dice=1):
        """
        CAREFUL. Weights for CE and Dice do not need to sum to one. You can set whatever you want.
        :param soft_dice_kwargs: keyword arguments forwarded to SoftDiceLoss
            (its nonlinearity is always softmax_helper)
        :param ce_kwargs: keyword arguments forwarded to CrossentropyND
        :param aggregate: how the two terms are combined; only "sum" is
            implemented, and other values raise in forward()
        :param square_dice: accepted but not used by this class
        :param weight_ce: multiplier of the cross-entropy term; 0 skips computing it
        :param weight_dice: multiplier of the Dice term; 0 skips computing it
        """
        super().__init__()
        self.weight_dice = weight_dice
        self.weight_ce = weight_ce
        self.aggregate = aggregate
        self.ce = CrossentropyND(**ce_kwargs)
        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs)

    def forward(self, net_output, target):
        dc_loss = 0 if self.weight_dice == 0 else self.dc(net_output, target)
        ce_loss = 0 if self.weight_ce == 0 else self.ce(net_output, target)
        if self.aggregate != "sum":
            raise NotImplementedError("nah son")  # reserved for other stuff (later)
        return self.weight_ce * ce_loss + self.weight_dice * dc_loss

def softmax_helper(x):
    """Softmax over the channel axis (dim 1) of ``x``.

    Delegates to ``torch.softmax``. The previous hand-written max/exp/normalise
    version materialised ``repeat``-ed copies of the max and of the normaliser
    on every call. Results agree up to floating-point rounding, and large
    logits do not overflow.

    :param x: logits of shape (b, c, ...)
    :return: tensor of the same shape whose entries along dim 1 sum to 1
    """
    return torch.softmax(x, dim=1)

def sum_tensor(inp, axes, keepdim=False):
    """Sum ``inp`` over each distinct axis listed in ``axes``.

    :param inp: tensor to reduce
    :param axes: iterable of axis indices; duplicates are ignored
    :param keepdim: keep reduced axes as size-1 dimensions
    :return: the reduced tensor (``inp`` itself when ``axes`` is empty)
    """
    unique_axes = [int(a) for a in np.unique(axes)]  # ascending, de-duplicated
    if keepdim:
        for a in unique_axes:
            inp = inp.sum(a, keepdim=True)
        return inp
    # Reduce the highest axis first so the remaining indices stay valid as
    # dimensions disappear.
    for a in reversed(unique_axes):
        inp = inp.sum(a)
    return inp

In [None]:
# Train/validation split for the 3d_seg_ct dataset, keyed on the two-digit
# case number at positions 2-3 of each folder name (assumed naming scheme;
# int() raises if a folder does not follow it). Cases 17-32 inclusive go to
# validation and all other cases to training.
# NOTE(review): the eu_data cell that follows re-assigns train_*/val_*, so this
# split is overwritten when both cells run in order.
data_dir = '/home/marafath/scratch/3d_seg_ct'
VAL_CASE_IDS = range(17, 33)  # 17..32 inclusive

train_images = []
train_labels = []

val_images = []
val_labels = []

for case in sorted(os.listdir(data_dir)):
    if case.startswith('.'):  # .ipynb_checkpoints and other hidden entries
        continue
    case_id = int(case[2:4])
    case_dir = os.path.join(data_dir, case)
    if case_id in VAL_CASE_IDS:
        val_images.append(os.path.join(case_dir, 'image.nii.gz'))
        val_labels.append(os.path.join(case_dir, 'label.nii.gz'))
    else:
        train_images.append(os.path.join(case_dir, 'image.nii.gz'))
        train_labels.append(os.path.join(case_dir, 'label.nii.gz'))

In [None]:
# EU dataset split: the first N_VAL cases (in sorted folder order) are held out
# for validation and the rest are used for training. The previous version took
# the first 24 entries of os.listdir, whose order is arbitrary, so the split
# was not reproducible.
# The path lists use their own names; the old `labels = []` overwrote the
# per-case label array built by the masking cell.
data_dir = '/home/marafath/scratch/eu_data'
N_VAL = 24

case_dirs = [os.path.join(data_dir, case)
             for case in sorted(os.listdir(data_dir))
             if not case.startswith('.')]
image_paths = [os.path.join(d, 'image.nii.gz') for d in case_dirs]
label_paths = [os.path.join(d, 'segmentation.nii.gz') for d in case_dirs]

val_images = image_paths[:N_VAL]
val_labels = label_paths[:N_VAL]

train_images = image_paths[N_VAL:]
train_labels = label_paths[N_VAL:]

In [None]:
# Preprocessing pipelines for (image, segmentation) pairs.
# Training: scale intensities (image only), add a channel axis, CastToType
# (default dtype), take a random 96^3 crop, then SpatialPad to 96^3 for
# volumes smaller than the patch.
# NOTE(review): image and label use separate RandSpatialCrop instances, so
# alignment relies on NiftiDataset seeding both transforms identically per
# item. Confirm this for the installed MONAI version.
# Validation: the same steps without the random crop.
patch_size = (96, 96, 96)

train_imtrans = Compose([
    ScaleIntensity(),
    AddChannel(),
    CastToType(),
    RandSpatialCrop(patch_size, random_size=False),
    SpatialPad(patch_size, mode='constant'),
    ToTensor(),
])
train_segtrans = Compose([
    AddChannel(),
    CastToType(),
    RandSpatialCrop(patch_size, random_size=False),
    SpatialPad(patch_size, mode='constant'),
    ToTensor(),
])

val_imtrans = Compose([
    ScaleIntensity(),
    AddChannel(),
    CastToType(),
    SpatialPad(patch_size, mode='constant'),
    ToTensor(),
])
val_segtrans = Compose([
    AddChannel(),
    CastToType(),
    SpatialPad(patch_size, mode='constant'),
    ToTensor(),
])

In [None]:
# Test-time preprocessing: the same steps as the validation pipeline
# (intensity scaling for the image, channel axis, CastToType, SpatialPad to
# 96^3, no random crop).
patch_size = (96, 96, 96)

test_imtrans = Compose([
    ScaleIntensity(),
    AddChannel(),
    CastToType(),
    SpatialPad(patch_size, mode='constant'),
    ToTensor(),
])
test_segtrans = Compose([
    AddChannel(),
    CastToType(),
    SpatialPad(patch_size, mode='constant'),
    ToTensor(),
])

# Test data loader: one volume per batch, in list order (no shuffling).
test_ds = NiftiDataset(test_images, test_labels,
                       transform=test_imtrans, seg_transform=test_segtrans)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4,
                         pin_memory=torch.cuda.is_available())

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check of the validation pipeline. Take the first
# (image, segmentation) pair, squeeze away the size-1 batch and channel axes,
# and show slice 54 of the last axis side by side.
ds = NiftiDataset(val_images, val_labels, transform=val_imtrans, seg_transform=val_segtrans)
loader = DataLoader(ds, batch_size=1, num_workers=2, pin_memory=torch.cuda.is_available())
batch_im, batch_seg = monai.utils.misc.first(loader)
im, seg = np.squeeze(batch_im), np.squeeze(batch_seg)
print('image shape: {}, label shape: {}'.format(im.shape, seg.shape))

sl = 54
plt.figure('check', (12, 6))
panels = [('image', im, 'gray'), ('segmentation', seg, None)]
for pos, (title, volume, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 2, pos)
    plt.title(title)
    plt.imshow(volume[:, :, sl], cmap=cmap)
plt.show()

In [None]:
# Data loaders. Training draws shuffled batches of 4 randomly cropped patches.
# Validation iterates over the (uncropped) volumes one at a time in list order.
pin = torch.cuda.is_available()

train_ds = NiftiDataset(train_images, train_labels,
                        transform=train_imtrans, seg_transform=train_segtrans)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,
                          num_workers=8, pin_memory=pin)

val_ds = NiftiDataset(val_images, val_labels,
                      transform=val_imtrans, seg_transform=val_segtrans)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=pin)

In [None]:
# Model, loss and optimiser.
# 3D residual UNet with 7 output channels, trained with soft Dice (foreground
# channels only, do_bg=False) plus cross-entropy, using Adam at lr 1e-3.
# NOTE(review): 7 channels assumes label values 0..6. The masking cell remaps
# values > 6 only in memory, so any such values still present in
# segmentation.nii.gz would be out of range for the loss. Confirm.
# Falls back to CPU when no GPU is available instead of failing on 'cuda:0'.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=7,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2
).to(device)

loss_function = DC_and_CE_loss({'smooth': 1e-5, 'do_bg': False}, {})
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

In [None]:
# Train for `epc` epochs on random patches. Every `val_interval` epochs, run
# sliding-window inference on the whole validation volumes and save the
# weights with the best mean foreground Dice.
epc = 15
val_interval = 1
roi_size = (160, 160, 160)  # sliding-window ROI used for validation inference
sw_batch_size = 4
best_model_path = '/home/marafath/scratch/saved_models/UNet3D_eu_best_metric_model.pth'

best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
# Batches per epoch, counting a final partial batch. The old
# len(train_ds) // batch_size undercounted in that case, so TensorBoard steps of
# one epoch could collide with the next.
epoch_len = len(train_loader)
writer = SummaryWriter()
for epoch in range(epc):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epc}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        # `targets` rather than `labels`, so the per-case label array built
        # earlier in the notebook is not overwritten.
        inputs, targets = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            # Separate names so the `val_images` / `val_labels` path lists are
            # not replaced by tensors (that broke re-running the dataset cells).
            val_img_batch = None
            val_seg_batch = None
            val_outputs = None
            for val_data in val_loader:
                val_img_batch, val_seg_batch = val_data[0].to(device), val_data[1].to(device)
                val_outputs = sliding_window_inference(val_img_batch, roi_size, sw_batch_size, model)
                value = compute_meandice(y_pred=val_outputs, y=val_seg_batch, include_background=False,
                                         to_onehot_y=True, mutually_exclusive=True)
                # `value` presumably holds one Dice per (image, foreground class).
                # Average over all finite entries. MONAI reports NaN for classes
                # absent from the ground truth (NOTE(review): confirm for this
                # version); one NaN used to make the whole metric NaN, and the
                # old code divided the sum over classes by the number of images.
                valid = ~torch.isnan(value)
                metric_count += int(valid.sum().item())
                metric_sum += value[valid].sum().item()
            metric = metric_sum / metric_count if metric_count > 0 else float("nan")
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), best_model_path)
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            if val_outputs is not None:
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_img_batch, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_seg_batch, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()

In [None]:
import matplotlib.pyplot as plt

# Browse the test set. For each case, show slice 50 of the last axis of the
# image (grey) next to its label map, printing both array shapes.
i = 0
for test_batch in test_loader:
    im = np.squeeze(test_batch[0].cpu().detach().numpy())
    seg = np.squeeze(test_batch[1].cpu().detach().numpy())

    sl = 50
    plt.figure('check', (18, 6))
    plt.subplot(1, 2, 1)
    plt.title(f'image {i}')
    plt.imshow(im[:, :, sl], cmap='gray')
    print(im.shape)
    plt.subplot(1, 2, 2)
    plt.title(f'label {i}')
    plt.imshow(seg[:, :, sl])
    print(seg.shape)
    plt.show()

    i += 1