In [2]:
from models import UNet3D
from torchsummary import summary
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.autograd import Variable
import glob
import nibabel as nib
import os
import matplotlib.pyplot as plt

from monai.utils import first

In [3]:
# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs end to end. To debug CUDA errors, export CUDA_LAUNCH_BLOCKING=1 in
# the environment before the kernel starts (writing it here has no effect).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All.
torch.manual_seed(42)

model = UNet3D().to(device)
summary(model, input_size=(1, 64, 64, 64), batch_size=1, device=device)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [1, 16, 64, 64, 64]             448
              ReLU-2        [1, 16, 64, 64, 64]               0
         Dropout3d-3        [1, 16, 64, 64, 64]               0
            Conv3d-4        [1, 16, 64, 64, 64]           6,928
              ReLU-5        [1, 16, 64, 64, 64]               0
         MaxPool3d-6        [1, 16, 32, 32, 32]               0
            Conv3d-7        [1, 32, 32, 32, 32]          13,856
              ReLU-8        [1, 32, 32, 32, 32]               0
         Dropout3d-9        [1, 32, 32, 32, 32]               0
           Conv3d-10        [1, 32, 32, 32, 32]          27,680
             ReLU-11        [1, 32, 32, 32, 32]               0
        MaxPool3d-12        [1, 32, 16, 16, 16]               0
           Conv3d-13        [1, 64, 16, 16, 16]          55,360
             ReLU-14        [1, 64, 16,

In [4]:
# Root of the pre-split dataset; the cells below read train/ and test/ under it.
data_dir = './splitted/'
print(data_dir)

# Base name and file name for model checkpoints (not referenced by the cells below yet).
saveFile = 'checkpoint'
checkpoint_path = f'{saveFile}.pth'

./splitted/


In [5]:
# Report how many image and mask volumes each split contains; the counts per
# split should match, since images and masks are paired one-to-one later.
# The on-disk "test" split is used as the validation set, hence the "val" label.
for split_dir, split_label in (("train", "train"), ("test", "val")):
    for sub_dir, item_label in (("images", "image"), ("masks", "mask")):
        n_files = len(glob.glob(data_dir + f"{split_dir}/{sub_dir}/*.nii"))
        print(f"Total {split_label} {item_label} Samples={n_files}")

Total train image Samples=172
Total train image Samples=172
Total val image Samples=20
Total val mask Samples=20


In [6]:
# Function to convert integer labels to one-hot encoding
def make_one_hot(labels, device, C=2):
    '''
    Convert integer label maps to one-hot encoding for semantic segmentation.

    Parameters
    ----------
    labels : torch.Tensor
        Shape ``N x 1 x *spatial`` (e.g. ``N x 1 x D x H x W`` for 3-D volumes,
        ``N x 1 x H x W`` for 2-D images). Each value is a class index in
        ``[0, C]``. Floating-point tensors are accepted and converted with
        ``.long()``.
    device : str or torch.device
        Device on which the one-hot tensor is allocated. ``labels`` is moved
        there as well so ``scatter_`` sees both tensors on the same device.
    C : int
        Number of foreground classes. One extra channel is added for the
        background class ``0``, so the output has ``C + 1`` channels.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``N x (C + 1) x *spatial``. Channel ``k`` is 1
        where ``labels == k`` and 0 elsewhere.

    Raises
    ------
    ValueError
        If ``labels`` contains a value outside ``[0, C]``. Without this check,
        ``scatter_`` fails with an index error, or a device-side assert on CUDA.
    '''
    labels = labels.long().to(device)

    # Foreground classes plus the background class 0; the background channel is kept.
    num_channels = C + 1

    if labels.numel() > 0 and (labels.min() < 0 or labels.max() >= num_channels):
        raise ValueError(
            f"labels must lie in [0, {C}] for C={C}, got range "
            f"[{int(labels.min())}, {int(labels.max())}]"
        )

    # Same shape as labels except the channel axis, which becomes num_channels.
    shape = list(labels.shape)
    shape[1] = num_channels
    one_hot = torch.zeros(shape, dtype=torch.float32, device=device)

    # Write a 1 at the channel given by each voxel's label.
    return one_hot.scatter_(1, labels, 1.0)

In [7]:
# Define a function to read NIfTI image from a given path
def readNifti_img(path):
    '''
    Load a NIfTI image as a max-normalised float32 tensor with a channel axis.

    The volume is divided by its maximum so that non-negative data ends up in
    ``[0, 1]``. When the maximum is not positive (e.g. an all-zero volume), the
    data is returned unscaled. Dividing would produce NaN/inf or flip the sign.

    Parameters
    ----------
    path : str
        Path to a file readable by ``nibabel.load``.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``1 x *volume_shape`` on the notebook-level
        ``device``.
    '''
    img_ = nib.load(path).get_fdata()

    max_val = img_.max()
    if max_val > 0:
        img_ = img_ / max_val

    # Convert to torch tensor and add a leading channel dimension.
    return torch.tensor(img_, dtype=torch.float32).unsqueeze(0).to(device)

# Define a function to read NIfTI mask from a given path
def readNifti_mask(path):
    '''
    Load a NIfTI label mask as a float32 tensor with a leading channel axis.

    Unlike ``readNifti_img``, the values are not normalised. Class labels are
    kept as returned by ``get_fdata`` so they can be one-hot encoded later.

    Returns a tensor of shape ``1 x *volume_shape`` on the notebook-level ``device``.
    '''
    volume = nib.load(path).get_fdata()
    mask_tensor = torch.tensor(volume, dtype=torch.float32)
    return mask_tensor.unsqueeze(0).to(device)

# Define a custom dataset class
class NiftiDataset(torch.utils.data.Dataset):
    '''
    Map-style dataset that pairs NIfTI images with their segmentation masks.

    Item ``idx`` loads ``image_paths[idx]`` and ``mask_paths[idx]``, so both
    sequences must be ordered consistently (e.g. both ``sorted``).

    Parameters
    ----------
    image_paths : sequence of str
        Paths to image volumes, loaded with ``readNifti_img``.
    mask_paths : sequence of str
        Paths to mask volumes, loaded with ``readNifti_mask``.
    transform : callable, optional
        Applied to the ``{'img': ..., 'mask': ...}`` dict before it is returned.

    Raises
    ------
    ValueError
        If ``image_paths`` and ``mask_paths`` have different lengths. Without
        this check, images and masks would silently be paired incorrectly.
    '''

    def __init__(self, image_paths, mask_paths, transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        # One sample per image/mask pair.
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Paths are used exactly as given. This keeps '.nii.gz' names intact and
        # leaves relative paths without a directory component unchanged.
        subject = {
            'img': readNifti_img(self.image_paths[idx]),
            'mask': readNifti_mask(self.mask_paths[idx]),
        }

        # Apply transformations if provided.
        if self.transform:
            subject = self.transform(subject)

        return subject


In [8]:
# Number of foreground classes in the masks. make_one_hot adds one background
# channel, giving NUM_CLASSES + 1 one-hot channels.
NUM_CLASSES = 5
BATCH_SIZE = 6

# Training loader: shuffled so batch composition varies between epochs.
train_loader = DataLoader(
    NiftiDataset(
        sorted(glob.glob(data_dir + "train/images/*.nii")),
        sorted(glob.glob(data_dir + "train/masks/*.nii")),
        transform=None,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Validation loader: not shuffled, so validation batches and any metrics
# computed on them are reproducible.
val_loader = DataLoader(
    NiftiDataset(
        sorted(glob.glob(data_dir + "test/images/*.nii")),
        sorted(glob.glob(data_dir + "test/masks/*.nii")),
        transform=None,
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)


def _one_hot_first_batch(loader):
    '''Fetch the first batch of ``loader``, print image/mask diagnostics, and
    return it with ``'mask'`` replaced by its one-hot encoding.'''
    sample = first(loader)
    print(f"size of mask before one hot encoding {sample['mask'].shape}")
    print(f"size of image {sample['img'].shape}")
    print(f"classes of mask {torch.unique(sample['mask'])}")
    sample['mask'] = make_one_hot(sample['mask'], device, C=NUM_CLASSES)
    print(f"size of mask after one hot encoding {sample['mask'].shape}")
    print(f"classes of mask after one hot encoding {torch.unique(sample['mask'])}")
    print()
    return sample


print()

# Inspect one batch from each split (masks are one-hot encoded in the returned dicts).
train_sample = _one_hot_first_batch(train_loader)
test_sample = _one_hot_first_batch(val_loader)


size of mask before one hot encoding torch.Size([6, 1, 64, 64, 64])
size of image torch.Size([6, 1, 64, 64, 64])
classes of mask tensor([0., 2., 3., 4., 5.], device='cuda:0')
size of mask after one hot encoding torch.Size([6, 6, 64, 64, 64])
classes of mask after one hot encodingtensor([0., 1.], device='cuda:0')

size of mask before one hot encoding torch.Size([6, 1, 64, 64, 64])
size of image torch.Size([6, 1, 64, 64, 64])
classes of mask tensor([0., 1., 2., 3., 4., 5.], device='cuda:0')
size of mask after one hot encoding torch.Size([6, 6, 64, 64, 64])
classes of mask after one hot encodingtensor([0., 1.], device='cuda:0')



In [9]:
# Distinct values in one-hot channel 1 (label 1) of the inspected validation batch.
test_sample['mask'][:, 1].unique()

tensor([0., 1.], device='cuda:0')

In [None]:
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]. This assumes UNet3D
# ends with a sigmoid and outputs C + 1 = 6 channels to match the one-hot
# target. Confirm against models.UNet3D; if it returns raw logits, switch to
# nn.BCEWithLogitsLoss.
loss_fn = nn.BCELoss()  # binary cross-entropy per voxel and per one-hot channel
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs = 100

# Enable training-mode behaviour (e.g. the Dropout3d layers in the summary above).
model.train()
for epoch in range(n_epochs):
    for batch in train_loader:
        images = batch['img'].to(device)
        # C=5 foreground classes plus background; must match the model's output channels.
        targets = make_one_hot(batch['mask'], device, C=5)

        y_pred = model(images)
        loss = loss_fn(y_pred, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Finished epoch {epoch}, latest loss {loss.item()}')