In [None]:
!pip install pytorch_lightning

In [None]:
!pip install torchtext==0.6

In [None]:
pip install torchmetrics


In [None]:
pip install celluloid

In [None]:
%matplotlib notebook
from pathlib import Path
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from celluloid import Camera
from tqdm.notebook import tqdm
import cv2

In [None]:
# Mount Google Drive at /content/drive; the dataset and preprocessed-output paths used below live there.
from google.colab import drive
drive.mount("/content/drive",force_remount=True)

In [None]:
def change_img_to_label_path(path):
    """
    Map an image path to the matching label path.

    The first ``imagesTr`` directory component of ``path`` is swapped for
    ``labelsTr``; all other components are kept. Raises ``ValueError`` when
    ``path`` contains no ``imagesTr`` component.
    """
    components = list(path.parts)
    position = components.index("imagesTr")  # first occurrence only
    components[position] = "labelsTr"
    return Path(*components)


In [None]:
# Inspect some sample data
root = Path("/content/drive/MyDrive/ML/Task06_Lung/Task06_Lung/imagesTr")
label = Path("/content/drive/MyDrive/ML/Task06_Lung/Task06_Lung/labelsTr/")

# Path.glob returns entries in filesystem order, which is not guaranteed to be stable.
# Glob once and sort so that index 9 always refers to the same subject.
lung_paths = sorted(root.glob("lung*"))
print("Number of lung paths:", len(lung_paths))

sample_path = lung_paths[9]  # Choose a subject
sample_path_label = change_img_to_label_path(sample_path)

print(sample_path)
print(sample_path_label)

In [None]:

# Load one CT volume and its segmentation from NIfTI, then pull out the voxel arrays.
# `data` is kept as the image object because the next cell reads its affine.
data = nib.load(sample_path)
label = nib.load(sample_path_label)

ct, mask = data.get_fdata(), label.get_fdata()

print(mask.shape)

In [None]:

# Find out the orientation: axis codes derived from the affine of the sample volume loaded above
nib.aff2axcodes(data.affine)

In [None]:
root = Path("/content/drive/MyDrive/ML/Task06_Lung/Task06_Lung/imagesTr")
label = Path("/content/drive/MyDrive/ML/Task06_Lung/Task06_Lung/labelsTr/")

# Every subject volume, in sorted (deterministic) order
all_files = sorted(root.glob("lung_*"))

print(all_files[3])
print(len(all_files))

In [None]:

# Output folders for the preprocessed 2-D training slices and their masks

save_root = Path("/content/drive/MyDrive/Preprocessed")

train_slice_path = save_root / "train" / "data"
train_mask_path = save_root / "train" / "masks"

for directory in (train_slice_path, train_mask_path):
    directory.mkdir(parents=True, exist_ok=True)

In [None]:
all_data = []
all_lables = []

counter = 0  # running index used to build unique slice/mask file names

for path_to_ct_data in tqdm(all_files):

    path_to_label = change_img_to_label_path(path_to_ct_data)  # Get path to ground truth

    # Load and extract corresponding data
    ct_data = nib.load(path_to_ct_data).get_fdata()
    label_data = nib.load(path_to_label).get_fdata()

    # Crop volume and label: drop the first 30 slices along the last axis.
    # The CT is divided by 3071 to scale intensities down (presumably the max CT value; TODO confirm).
    ct_data = ct_data[:, :, 30:] / 3071
    new_label_data = label_data[:, :, 30:]

    # Save every slice and its mask into the train data/masks directories, recording the
    # file names in all_data and whether the slice contains tumour in all_lables.
    # NOTE(review): the train/val split later happens per *slice*, so slices of the same
    # subject can end up in both sets, which makes validation scores optimistic.
    for i in range(ct_data.shape[-1]):
        ct_slice = ct_data[:, :, i]
        mask = new_label_data[:, :, i]

        # Resize to a common resolution to reduce training time; nearest-neighbour
        # interpolation for the mask avoids introducing fractional label values.
        ct_slice = cv2.resize(ct_slice, (256, 256))
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

        # Generate name of slice and its corresponding mask
        slice_name = f'slice_000000{100000+counter}'
        mask_name = f'mask_000000{100000+counter}'

        # Keep file names and the slice-level tumour flag for the train/val split later
        all_data.append(np.array([slice_name, mask_name]))
        all_lables.append(mask.any())

        counter += 1
        # Plain float arrays: np.save needs no pickling
        np.save(train_slice_path / slice_name, ct_slice)
        np.save(train_mask_path / mask_name, mask)

In [None]:
all_data_np, all_lables_np = np.array(all_data), np.array(all_lables)

print(all_data_np.shape)
print(all_lables_np.shape)

# Class balance of the slice-level flags: counts of tumour-free vs. tumour-containing slices
print(np.unique(all_lables_np, return_counts=True))

In [None]:
from sklearn.model_selection import train_test_split

# 80/20 slice-level split. Stratifying on the tumour flag keeps the share of
# tumour-containing slices the same in both subsets.
X_train, X_test, y_train, y_test = train_test_split(
    all_data_np,
    all_lables_np,
    stratify=all_lables_np,
    test_size=0.2,
    random_state=13,
)

print("Shape of training set:", X_train.shape)
print("Shape of test set:", X_test.shape)

for split_labels in (y_train, y_test):
    print(np.unique(split_labels, return_counts=True))

In [None]:

# Destination folders for the validation slices/masks that will be moved out of train/

val_slice_path = save_root / "val" / "data"
val_mask_path = save_root / "val" / "masks"

for directory in (val_slice_path, val_mask_path):
    directory.mkdir(parents=True, exist_ok=True)

In [None]:

# Move each validation slice/mask pair from train/ into val/.
# A pair already present in val/ and missing from train/ was moved by an earlier run
# and is skipped, so the cell can be re-run; any other missing file still raises.
for test in X_test:
    for src_dir, dst_dir, name in ((train_slice_path, val_slice_path, test[0]),
                                   (train_mask_path, val_mask_path, test[1])):
        src = src_dir / f"{name}.npy"
        dst = dst_dir / f"{name}.npy"
        if dst.exists() and not src.exists():
            continue
        src.replace(dst)

In [None]:
# Positions in X_test / y_test of validation slices that contain tumour
tumor_slice_idxs = [idx for idx, has_tumor in enumerate(y_test) if has_tumor]

print(tumor_slice_idxs)

In [None]:
# Choose one validation sample by its position in X_test and load slice + mask.
# NOTE(review): 2683 is a hard-coded position; it must be smaller than len(X_test).
test = X_test[2683]

slice_path = val_slice_path / test[0]
mask_path = val_mask_path / test[1]

slice, mask = (np.load(f"{path}.npy") for path in (slice_path, mask_path))

print(slice.shape)
print(slice.min(), slice.max())

In [None]:
%matplotlib inline
fig, axis = plt.subplots(1, 2, figsize=(8, 8))
axis[0].imshow(slice, cmap="bone")
mask_ = np.ma.masked_where(mask==0, mask)
axis[1].imshow(slice, cmap="bone")
axis[1].imshow(mask_, cmap="autumn")

In [None]:

from pathlib import Path

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import imgaug
import imgaug.augmenters as iaa
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from celluloid import Camera
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

In [None]:
class LungDataset(torch.utils.data.Dataset):
    """
    2-D CT slice dataset backed by the ``.npy`` files written during preprocessing.

    Slices are read from ``root/data/*.npy``; the mask for ``data/slice_X.npy`` is
    ``masks/mask_X.npy``. Items are ``(slice, mask)`` numpy arrays, each with a leading
    channel axis added.

    Args:
        root: Path to one split directory (ends with ``train`` or ``val``).
        augment_params: imgaug augmenter applied jointly to slice and mask, or ``None``
            to return the stored arrays unchanged.
    """

    def __init__(self, root, augment_params):
        self.all_files = self.extract_files(root)
        self.augment_params = augment_params

    @staticmethod
    def extract_files(root):
        """
        Return the sorted paths of all ``.npy`` slices under ``root/data``.

        ``Path.glob`` yields entries in filesystem order; sorting keeps dataset indices
        stable across runs and machines.
        """
        return sorted((root / "data").glob("*.npy"))

    @staticmethod
    def change_img_to_label_path(path):
        """
        Map ``.../data/slice_X.npy`` to ``.../masks/mask_X.npy``.
        """
        parts = list(path.parts)
        parts[-2] = "masks"
        parts[-1] = parts[-1].replace('slice', 'mask')
        return Path(*parts)

    def augment(self, slice, mask):
        """
        Augment slice and segmentation mask in exactly the same way.

        The mask is cast to bool and wrapped as an imgaug ``SegmentationMapsOnImage`` so
        the augmenter transforms it together with the image. Returns the augmented slice
        and the augmented mask array (``get_arr()``).
        """
        # Re-seed imgaug's global RNG from torch's RNG so augmentation randomness follows
        # torch's seeding (presumably so DataLoader workers do not repeat augmentations).
        random_seed = torch.randint(0, 1000000, (1,))[0].item()
        imgaug.seed(random_seed)

        bool_mask = mask.astype(bool)
        segmap = SegmentationMapsOnImage(bool_mask, bool_mask.shape)
        slice_aug, mask_aug = self.augment_params(image=slice, segmentation_maps=segmap)
        return slice_aug, mask_aug.get_arr()

    def __len__(self):
        """
        Return the number of slices in the dataset.
        """
        return len(self.all_files)

    def __getitem__(self, idx):
        """
        Load slice ``idx`` and its mask, augment them if an augmenter is set, and return
        both with an extra leading channel axis for PyTorch.
        """
        slice_file = self.all_files[idx]
        mask_file = self.change_img_to_label_path(slice_file)
        ct_slice = np.load(slice_file)
        mask = np.load(mask_file)

        # Explicit None check: imgaug's Sequential is a list subclass, so its truthiness
        # depends on how many children it has rather than on whether it was provided.
        if self.augment_params is not None:
            ct_slice, mask = self.augment(ct_slice, mask)
        return np.expand_dims(ct_slice, 0), np.expand_dims(mask, 0)


In [None]:
import imgaug.augmenters as iaa
# Training-time augmentation, applied identically to slice and mask by LungDataset.augment.
# translate_percent must be a range: a bare number such as (0.15) is a scalar in Python,
# which imgaug treats as a constant shift of every image. The dict form samples
# the x and y shifts independently from [-0.15, 0.15].
seq = iaa.Sequential([
    iaa.Affine(translate_percent={"x": (-0.15, 0.15), "y": (-0.15, 0.15)},  # shift up to ±15 %
               scale=(0.85, 1.15),  # zoom in or out
               rotate=(-45, 45)),   # rotate up to ±45 degrees
    iaa.ElasticTransformation(),    # elastic deformation, imgaug default parameters
])

In [None]:
# Create the dataset objects: augmented slices for training, raw slices for validation
train_path, val_path = (Path("/content/drive/MyDrive/Preprocessed") / split for split in ("train", "val"))

train_dataset = LungDataset(train_path, augment_params=seq)
val_dataset = LungDataset(val_path, augment_params=None)

print(f"There are {len(train_dataset)} train images and {len(val_dataset)} val images")


In [None]:
# Per-slice binary target (1 = stored mask contains tumour) for the weighted sampler.
# The stored masks are read directly instead of iterating train_dataset: the dataset
# applies random augmentation, which is slow and can move tumour pixels out of the frame,
# mislabelling slices. Iterating all_files keeps target_list[i] aligned with train_dataset[i].
target_list = []

for slice_file in tqdm(train_dataset.all_files):
    mask_file = train_dataset.change_img_to_label_path(slice_file)
    target_list.append(1 if np.any(np.load(mask_file)) else 0)

In [None]:
# Class weights: tumour slices are up-weighted by the ratio of non-tumour to tumour slices
uniques = np.unique(target_list, return_counts=True)
print("Class distribution:", uniques)

n_negative = uniques[1][0]
n_positive = uniques[1][1]
fraction = n_negative / n_positive
print("Class imbalance ratio:", fraction)

weights = [1.0, fraction]
print("Class weights:", weights)


In [None]:
# Per-sample sampling weight: 1.0 for non-tumour slices, `fraction` for tumour slices,
# so both classes carry the same total weight in the WeightedRandomSampler.
weight_list = [1.0 if target == 0 else fraction for target in target_list]

# Print the first 50 weights as an example
print("Example of weights:", weight_list[:50])


In [None]:
from torch.utils.data import DataLoader
# Restrict sampling to the first N_TRAIN_SAMPLES slices to keep epochs short;
# min() avoids an IndexError when the training split is smaller than that.
N_TRAIN_SAMPLES = 2000
selected_indices = range(min(N_TRAIN_SAMPLES, len(weight_list)))

# Weights of the selected samples, in dataset order
selected_weight_list = [weight_list[i] for i in selected_indices]

# The sampler draws dataset indices in [0, len(selected_weight_list)), i.e. only from
# the selected slices, in proportion to their weights.
sampler = torch.utils.data.sampler.WeightedRandomSampler(selected_weight_list, len(selected_weight_list))

# NOTE(review): train_loader is rebuilt further down with num_workers; this is a first definition.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, sampler=sampler, shuffle=False)


In [None]:
from torch.utils.data import Subset
batch_size = 8
num_workers = 2

# Training batches are drawn by the weighted sampler (sampler and shuffle are mutually exclusive)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           num_workers=num_workers, sampler=sampler)

# Validate on the first 300 validation slices (fewer if the split is smaller)
subset_indices = range(min(300, len(val_dataset)))

subset_val_dataset = Subset(val_dataset, subset_indices)
val_loader = torch.utils.data.DataLoader(subset_val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [None]:
# Verify that the sampler works: take one batch and count how many of its slices contain tumour
verify_batch = next(iter(train_loader))  # Take one batch

# LungDataset yields (slice, mask), so the masks are element 1 of the collated batch
labels_in_batch = verify_batch[1]

# Flatten each mask and check whether any pixel is positive -> one flag per slice.
# Summing the raw comparison would count positive *pixels*, not slices.
count_positive_labels = (labels_in_batch.reshape(labels_in_batch.shape[0], -1) > 0).any(dim=1).sum().item()

print(f"Number of labels larger than zero: {count_positive_labels}")


In [None]:
# Print the shape of the masks in the verification batch built above
# (batch axis from the DataLoader, then the channel axis added by LungDataset)
print(verify_batch[1].shape)

# One flag per slice: does its mask contain any tumorous pixel?
verify_labels = np.any(np.array(verify_batch[1]), axis=(1, 2, 3))
print(verify_labels)

# Take the second slice of the batch and its mask, dropping singleton axes
slice = verify_batch[0][1].squeeze()
mask = verify_batch[1][1].squeeze()
print(slice.shape)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

class DoubleConv(nn.Module):
    """
    Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    Height and width are preserved; the channel count changes from ``in_channels``
    to ``out_channels`` in the first convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        return x


class UNet_3PlusModified(nn.Module):
    """
    UNet 3+ style segmentation network with full-scale skip connections.

    Encoder: five DoubleConv stages with ``filters[0..4]`` channels, separated by 2x2
    max-pooling, so encoder stage ``k`` runs at ``1 / 2**(k-1)`` of the input size.

    Decoder: for each decoder level ``d`` in 4, 3, 2, 1, five feature maps are gathered.
    The encoder maps ``h1..hd`` are max-pooled down to level ``d``. The deeper decoder
    maps ``hd{d+1}..hd5`` are bilinearly upsampled to level ``d``. Each map gets its own
    resample + DoubleConv branch producing ``filters[0]`` channels. The five results are
    concatenated (``5 * filters[0]`` channels) and fused by ``conv{d}d_1``.

    Output: raw logits (no sigmoid) with ``n_classes`` channels at the input resolution.
    Input height and width must be divisible by 16.

    Args:
        in_channels: Number of input image channels.
        n_classes: Number of output channels.
        feature_scale, is_deconv, is_batchnorm: Stored as attributes; not used by the
            architecture.
        filters: Optional sequence of five encoder widths; defaults to
            ``[64, 128, 256, 512, 1024]``.
    """

    def __init__(self, in_channels=1, n_classes=1, feature_scale=4, is_deconv=True, is_batchnorm=True,
                 filters=None):
        super(UNet_3PlusModified, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        filters = [64, 128, 256, 512, 1024] if filters is None else list(filters)
        if len(filters) != 5:
            raise ValueError(f"filters must contain 5 channel widths, got {len(filters)}")
        self.filters = filters

        # Encoder
        self.conv1 = DoubleConv(in_channels, filters[0])
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = DoubleConv(filters[0], filters[1])
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = DoubleConv(filters[1], filters[2])
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = DoubleConv(filters[2], filters[3])
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5 = DoubleConv(filters[3], filters[4])

        # Decoder
        self.CatChannels = filters[0]
        self.CatBlocks = 5
        self.UpChannels = self.CatChannels * self.CatBlocks

        # One resample + DoubleConv branch per (source, decoder level) pair, keyed
        # "h{src}_to_hd{dec}" for encoder sources and "hd{src}_to_hd{dec}" for decoder sources.
        self.branches = nn.ModuleDict()
        for dec in (4, 3, 2, 1):
            for src in range(1, 6):
                if src <= dec:
                    # Encoder map h{src}: pool down by 2**(dec - src) (identity at the same level)
                    factor = 2 ** (dec - src)
                    resample = nn.MaxPool2d(kernel_size=factor) if factor > 1 else nn.Identity()
                    in_ch = filters[src - 1]
                    key = f"h{src}_to_hd{dec}"
                else:
                    # Decoder map hd{src}: upsample by 2**(src - dec); hd5 is the bottleneck
                    resample = nn.Upsample(scale_factor=2 ** (src - dec), mode='bilinear')
                    in_ch = filters[4] if src == 5 else self.UpChannels
                    key = f"hd{src}_to_hd{dec}"
                self.branches[key] = nn.Sequential(resample, DoubleConv(in_ch, self.CatChannels))

        # Fusion of the five concatenated branches at each decoder level
        self.conv4d_1 = DoubleConv(self.UpChannels, self.UpChannels)
        self.conv3d_1 = DoubleConv(self.UpChannels, self.UpChannels)
        self.conv2d_1 = DoubleConv(self.UpChannels, self.UpChannels)
        self.conv1d_1 = DoubleConv(self.UpChannels, self.UpChannels)

        # Output
        self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, kernel_size=3, padding=1)

        # Initialize weights (init_weights only touches multi-dimensional weights)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                init_weights(m, init_type='kaiming')

    def forward(self, inputs):
        inputs = inputs.float()

        # Encoder (resolution relative to the input in the comments)
        h1 = self.conv1(inputs)                  # 1/1
        h2 = self.conv2(self.maxpool1(h1))       # 1/2
        h3 = self.conv3(self.maxpool2(h2))       # 1/4
        h4 = self.conv4(self.maxpool3(h3))       # 1/8
        hd5 = self.conv5(self.maxpool4(h4))      # 1/16

        encoder = {1: h1, 2: h2, 3: h3, 4: h4}
        decoder = {5: hd5}
        fuse = {4: self.conv4d_1, 3: self.conv3d_1, 2: self.conv2d_1, 1: self.conv1d_1}

        # Decoder, deepest level first: hd{d} needs hd{d+1}..hd5
        for dec in (4, 3, 2, 1):
            parts = [self.branches[f"h{src}_to_hd{dec}"](encoder[src]) for src in range(1, dec + 1)]
            parts += [self.branches[f"hd{src}_to_hd{dec}"](decoder[src]) for src in range(dec + 1, 6)]
            decoder[dec] = fuse[dec](torch.cat(parts, dim=1))

        # Output: raw logits
        return self.outconv1(decoder[1])


def init_weights(m, init_type='kaiming'):
    """
    Initialise the weight tensor of module ``m`` in place.

    Only weights with more than one dimension (e.g. convolution kernels) are changed;
    1-D weights such as BatchNorm scales, missing weights, and all biases are left as is.

    Raises:
        NotImplementedError: for an unknown ``init_type`` when ``m`` has a
            multi-dimensional weight.
    """
    weight = getattr(m, 'weight', None)
    if weight is not None and weight.dim() > 1:
        if init_type == 'kaiming':
            init.kaiming_normal_(weight, mode='fan_out', nonlinearity='relu')
        elif init_type == 'xavier':
            init.xavier_normal_(weight)
        elif init_type == 'orthogonal':
            init.orthogonal_(weight)
        else:
            raise NotImplementedError(f'Initialization method {init_type} is not implemented')

# Instantiate the model with its default configuration, only to inspect it here
model = UNet_3PlusModified()

# Print the model architecture (module tree)
print(model)


In [None]:
# Full Segmentation Model
class TumorSegmentation(pl.LightningModule):
    """
    Lightning wrapper that trains ``UNet_3PlusModified`` for binary tumour segmentation.

    The network outputs raw logits. Training uses ``BCEWithLogitsLoss`` with Adam (lr=1e-3).

    NOTE(review): the values logged as "Train Dice" / "Val Dice" are the BCE *loss*
    (lower is better), not a Dice score. The key names are kept because the
    checkpoint callback monitors "Val Dice".
    """

    def __init__(self):
        super().__init__()

        self.model = UNet_3PlusModified()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def forward(self, data):
        """Return per-pixel logits for a batch of slices."""
        return self.model(data)

    def _shared_step(self, batch, batch_idx, stage):
        """Compute and log the BCE loss for one batch; every 50th batch also logs a figure."""
        ct, mask = batch
        mask = mask.float()
        ct = ct.float()

        pred = self(ct)
        loss = self.loss_fn(pred, mask)

        # Logs
        self.log(f"{stage} Dice", loss)
        if batch_idx % 50 == 0:
            # detach: the figure only needs values, not the autograd graph
            self.log_images(ct.cpu(), pred.detach().cpu(), mask.cpu(), stage)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "Train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, "Val")

    def log_images(self, ct, pred, mask, name):
        """
        Log a Ground Truth vs. Prediction figure for the first sample of the batch
        to the TensorBoard logger.

        ``pred`` holds raw logits, so it is passed through a sigmoid before
        thresholding at 0.5.
        """
        pred = torch.sigmoid(pred) > 0.5

        fig, axis = plt.subplots(1, 2)
        axis[0].imshow(ct[0][0], cmap="bone")
        mask_ = np.ma.masked_where(mask[0][0] == 0, mask[0][0])
        axis[0].imshow(mask_, alpha=0.6)
        axis[0].set_title("Ground Truth")

        axis[1].imshow(ct[0][0], cmap="bone")
        mask_ = np.ma.masked_where(pred[0][0] == 0, pred[0][0])
        axis[1].imshow(mask_, alpha=0.6, cmap="autumn")
        axis[1].set_title("Pred")

        self.logger.experiment.add_figure(f"{name} Prediction vs Label", fig, self.global_step)

    def configure_optimizers(self):
        # Lightning accepts a list of optimizers; we have exactly one
        return [self.optimizer]



In [None]:
# Instantiate the model
model = TumorSegmentation()

# Keep the 30 checkpoints with the lowest logged "Val Dice" (the validation BCE loss)
checkpoint_callback = ModelCheckpoint(
    monitor='Val Dice',
    save_top_k=30,
    mode='min')

# Create the trainer; fall back to the CPU when no CUDA device is available
trainer = pl.Trainer(accelerator="cuda" if torch.cuda.is_available() else "cpu",
                     logger=TensorBoardLogger(save_dir="/content/drive/MyDrive/output"),
                     log_every_n_steps=1,
                     callbacks=[checkpoint_callback],
                     max_epochs=1)


In [None]:

# Train with the weighted-sampled train loader and validate on the fixed validation subset.
# Uncomment ckpt_path below to start from a previously saved checkpoint.
trainer.fit(model, train_loader, val_loader,
            # ckpt_path = "/content/drive/MyDrive/output/lightning_logs/version_1/checkpoints/epoch=20-step=33117.ckpt"
            )


In [None]:
class DiceScore(torch.nn.Module):
    """
    Dice coefficient ``2 * sum(pred * mask) / (sum(pred) + sum(mask))`` over all elements.

    Both inputs are flattened, so a whole batch or dataset yields a single score.
    ``pred`` may be binary or soft (probabilities). When both inputs are all zero,
    the score is defined as 1 (perfect agreement) instead of the NaN from 0 / 0.
    """
    def __init__(self):
        super().__init__()

    def forward(self, pred, mask):

        # Flatten label and prediction tensors
        pred = torch.flatten(pred)
        mask = torch.flatten(mask)

        intersection = (pred * mask).sum()
        denominator = pred.sum() + mask.sum()
        dice = (2 * intersection) / denominator
        if denominator == 0:
            # Nothing predicted and nothing labelled: treat as a perfect match
            dice = torch.ones_like(dice)

        return dice

In [None]:
# Load the trained weights onto the CPU, switch to eval mode, then move to the GPU if present
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint_path = "/content/drive/MyDrive/output/lightning_logs/version_0/checkpoints/epoch=0-step=404.ckpt"
model = (
    TumorSegmentation.load_from_checkpoint(checkpoint_path, map_location=torch.device('cpu'))
    .eval()
    .to(device)
)

In [None]:

# Run the model over every validation slice, collecting sigmoid probabilities and masks
preds, labels = [], []

with torch.no_grad():
    for slice, label in tqdm(val_dataset):
        batch = torch.tensor(slice).float().to(device).unsqueeze(0)  # add a batch axis
        preds.append(torch.sigmoid(model(batch)).cpu().numpy())
        labels.append(label)

preds = np.array(preds)
labels = np.array(labels)

In [None]:
# Soft Dice over the whole validation set (sigmoid probabilities vs. ground-truth masks).
# The extra leading axis on the labels does not change their flattened order.
pred_tensor = torch.from_numpy(preds)
label_tensor = torch.from_numpy(labels).unsqueeze(0).float()
dice_score = DiceScore()(pred_tensor, label_tensor)
print(f"The Val Dice Score is: {dice_score}")