# CT scan segmentation using multiple slices per image - Liver dataset

Because we use 2D images to describe 3D CT scans, we inevitably lose some of the information contained in the original scan.
In this part we try to give the model more of that information while keeping the 2D approach.
In the previous approach, we sliced the 3D scans and fed each slice to the model separately, so the model had no way of knowing that consecutive slices are connected and together form one volume.
Now, we merge each slice with its previous and next slice of the CT scan, so that the model gets more context during training.

In [1]:
%run ../Data-Preprocessing.ipynb
%run ../U-Net.ipynb
%run ../Train-Eval-Utils.ipynb

import os

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms

from torch.utils.data import DataLoader, ConcatDataset

The data is preprocessed the same way as before, so we still slice the 3D CT scans into 2D images:

In [2]:
# run if this is the first run of liver segmentation:
# convert_ct_dataset_to_slices('Task03_Liver', 'Liver_Train', 'Liver_Val', 'Liver_Test', val_split=0.1, test_split=0.1, negative_downsampling_rate=5)

In the previous approach, the model received a single slice: the one we want to segment. Now, in addition to that current slice, we also pass the previous and the next slice of the CT scan to the model.

As the slices are grayscale (single-channel), we can stack the three of them as the channels of one 3-channel image and pass that image to the model.

 <img src="../metadata/multipleSlicesForTrainingLiver.png" alt="multiple slices image" width="500" height="600"> 
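The stacking in the figure can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the code of `CTDatasetMultiSlices`: it assumes a CT volume stored as an array of shape `(depth, height, width)`, and both the function name `stack_adjacent_slices` and the choice to repeat the edge slice when the current slice is the first or last one of the scan are made up for this example.

```python
import numpy as np


def stack_adjacent_slices(volume: np.ndarray, index: int) -> np.ndarray:
    """Return slices (index-1, index, index+1) of `volume` as one 3-channel image.

    `volume` is assumed to have shape (depth, height, width). The first and last
    slices have no neighbour on one side, so the edge slice is repeated there
    (an assumption of this sketch; zero-padding would be another option).
    """
    depth = volume.shape[0]
    if not 0 <= index < depth:
        raise IndexError(f"slice index {index} out of range for depth {depth}")
    prev_idx = max(index - 1, 0)           # clamp at the top of the scan
    next_idx = min(index + 1, depth - 1)   # clamp at the bottom of the scan
    # Channel order: previous, current, next -> shape (3, height, width)
    return np.stack([volume[prev_idx], volume[index], volume[next_idx]], axis=0)


# Toy volume: 4 slices of size 2x2, each slice filled with its own index
volume = np.arange(4).reshape(4, 1, 1) * np.ones((4, 2, 2))
middle = stack_adjacent_slices(volume, 2)
first = stack_adjacent_slices(volume, 0)
print(middle.shape)      # (3, 2, 2)
print(middle[:, 0, 0])   # [1. 2. 3.]
print(first[:, 0, 0])    # [0. 0. 1.]
```

The result is channel-first, which is the `(C, H, W)` layout PyTorch models such as a U-Net with `in_channel=3` expect.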

Accordingly, we created the CTDatasetMultiSlices class, which prepares the data in the way described above.
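When the slices of many scans end up in one dataset, the "previous" and "next" slice must come from the same CT scan as the current one, otherwise the first slice of one patient would borrow the last slice of another. The helper below is a hypothetical sketch of that bookkeeping, not part of `CTDatasetMultiSlices`: it assumes the number of slices in each scan is known (`slices_per_scan`) and maps a flat dataset index to a scan and its three slice indices, repeating the edge slice at scan boundaries.

```python
from bisect import bisect_right
from itertools import accumulate


def locate_slice(slices_per_scan, flat_index):
    """Map a flat dataset index to (scan_id, [prev, current, next]).

    `slices_per_scan[k]` is the number of slices in scan k. Neighbour indices
    are clamped inside the same scan, so a slice never borrows a neighbour
    from a different patient.
    """
    # Cumulative end offsets, e.g. [3, 2] -> [3, 5]: scan 0 owns 0..2, scan 1 owns 3..4
    ends = list(accumulate(slices_per_scan))
    if not ends or not 0 <= flat_index < ends[-1]:
        raise IndexError(f"index {flat_index} out of range")
    scan_id = bisect_right(ends, flat_index)   # first scan whose end is past the index
    start = ends[scan_id - 1] if scan_id > 0 else 0
    local = flat_index - start                 # position inside that scan
    last = slices_per_scan[scan_id] - 1
    return scan_id, [max(local - 1, 0), local, min(local + 1, last)]


# Two scans with 3 and 2 slices: flat indices 0-2 belong to scan 0, 3-4 to scan 1
print(locate_slice([3, 2], 1))  # (0, [0, 1, 2])
print(locate_slice([3, 2], 3))  # (1, [0, 0, 1])
```

A map-style dataset built on this would call `locate_slice` in `__getitem__` and load the three returned slices of the returned scan.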

In [None]:
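# Images: resize to 128x128 and convert to float
image_transform = transforms.Compose([
    transforms.Resize((128, 128), antialias=False),
    transforms.ConvertImageDtype(torch.float)
])

# Labels: nearest-neighbour resizing keeps the masks binary
# (the default bilinear interpolation would blend mask values along the liver boundary)
label_transform = transforms.Compose([
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ConvertImageDtype(torch.float)
])

TRAIN_DIR = 'Liver_Train'

train_dataset = CTDatasetMultiSlices(root_dir=TRAIN_DIR, image_transform=image_transform, label_transform=label_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=os.cpu_count())

VAL_DIR = 'Liver_Val'

val_dataset = CTDatasetMultiSlices(root_dir=VAL_DIR, image_transform=image_transform, label_transform=label_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=os.cpu_count())

TEST_DIR = 'Liver_Test'

test_dataset = CTDatasetMultiSlices(root_dir=TEST_DIR, image_transform=image_transform, label_transform=label_transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True, num_workers=os.cpu_count())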

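If we look closely at any of the training instances, we notice that they look blurry. This is because each instance actually contains three images: the previous, current and next slice are stored as the three colour channels, so displaying the instance as an RGB image overlays three slightly different slices.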

In [None]:
# Look at one training instance
image, label = train_dataset[75]

plt.imshow(image.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C); no cmap for 3-channel images
plt.axis('off')  # hide axis ticks and labels
plt.show()

In [None]:
image.shape  # (channels, height, width): 3 stacked slices, 128x128 after the resize
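One way to check the explanation of the blurriness is to look at where the three channels disagree: where all three slices have the same value, the RGB composite shows a neutral grey, so any colour or smearing appears only where neighbouring slices differ. The NumPy sketch below computes that mask; the function name `channel_disagreement` and its `tol` parameter are illustrative assumptions, and it expects a channel-first `(3, H, W)` array such as `image.numpy()` from the cells above.

```python
import numpy as np


def channel_disagreement(image: np.ndarray, tol: float = 1e-6) -> np.ndarray:
    """Boolean (H, W) mask of pixels where the three stacked slices differ.

    Where the mask is False, R == G == B (within `tol`) and the composite looks grey.
    """
    if image.ndim != 3 or image.shape[0] != 3:
        raise ValueError(f"expected shape (3, H, W), got {image.shape}")
    spread = image.max(axis=0) - image.min(axis=0)  # per-pixel range across the channels
    return spread > tol


# Toy example: the three slices agree everywhere except the bottom-right pixel
toy = np.zeros((3, 2, 2))
toy[2, 1, 1] = 0.5  # only the "next" slice changes there
mask = channel_disagreement(toy)
print(mask)
print(f"{mask.mean():.0%} of pixels differ between neighbouring slices")
```

In the notebook the mask could be shown with `plt.imshow(channel_disagreement(image.numpy()))` to see where the overlaid slices diverge.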

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = DiceLoss()
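`DiceLoss` is defined in `Train-Eval-Utils.ipynb`, whose code is not shown here. For reference, the common soft Dice loss is `1 - (2 * sum(p * g) + smooth) / (sum(p) + sum(g) + smooth)`, where `p` are predicted probabilities, `g` is the ground-truth mask and `smooth` avoids division by zero on empty masks. The NumPy sketch below implements that formula; the smoothing value of 1 and the assumption that predictions are already probabilities (for example after a sigmoid) are choices of this sketch, and the notebook's own `DiceLoss` may differ in these details.

```python
import numpy as np


def soft_dice_loss(pred: np.ndarray, target: np.ndarray, smooth: float = 1.0) -> float:
    """Soft Dice loss: 1 - (2*sum(p*g) + smooth) / (sum(p) + sum(g) + smooth).

    `pred` holds probabilities in [0, 1], `target` the binary mask; same shape.
    """
    pred = pred.astype(np.float64).ravel()
    target = target.astype(np.float64).ravel()
    intersection = np.sum(pred * target)  # soft overlap between prediction and mask
    dice = (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
    return float(1.0 - dice)


target_mask = np.array([[0, 1], [1, 1]])
print(soft_dice_loss(target_mask, target_mask))       # 0.0 for a perfect prediction
print(soft_dice_loss(np.zeros((2, 2)), target_mask))  # 0.75 when nothing is predicted
```

Because the loss depends on the overlap relative to the total predicted and true area, it is less dominated by the many background pixels than a per-pixel loss would be.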

Next, we train U-Net models with different hyperparameters (network depth and learning rate) and, based on the loss on the validation set, decide which configuration to use.
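The depths tried below (3, 5 and 7) are limited by the 128x128 input size: in a typical U-Net, each encoder level below the top halves the spatial resolution with 2x2 pooling. Assuming the `UNet` from `U-Net.ipynb` follows this convention with `depth - 1` pooling steps (an assumption, since its code is not shown here), the hypothetical helper below computes the size of the feature map at the bottleneck.

```python
def bottleneck_size(input_size: int, depth: int) -> int:
    """Spatial size at the U-Net bottleneck, assuming depth - 1 halvings (2x2 pooling)."""
    size = input_size
    for _ in range(depth - 1):
        if size < 2:
            raise ValueError(f"depth {depth} is too deep for a {input_size}px input")
        size //= 2  # each encoder level halves height and width
    return size


for depth in (3, 5, 7):
    side = bottleneck_size(128, depth)
    print(f"depth={depth}: bottleneck feature map is {side}x{side}")
```

Under this assumption, depth 3 keeps a 32x32 map at the bottleneck, depth 5 an 8x8 map and depth 7 only a 2x2 map, which is close to the limit for 128x128 inputs.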

In [None]:
model_3 = UNet(depth=3, in_channel=3)
model_3.to(device)
optimizer = torch.optim.Adam(model_3.parameters(), lr=0.001)
_, best_loss = train_loop_with_validation(model_3, 30, train_loader, val_loader, optimizer, criterion)
print(f'Best loss achieved on the validation set: {best_loss}')

In [None]:
model_4 = UNet(depth=5, in_channel=3)
model_4.to(device)
optimizer = torch.optim.Adam(model_4.parameters(), lr=0.001)
_, best_loss = train_loop_with_validation(model_4, 30, train_loader, val_loader, optimizer, criterion)
print(f'Best loss achieved on the validation set: {best_loss}')

In [None]:
model_5 = UNet(depth=7, in_channel=3)
model_5.to(device)
optimizer = torch.optim.Adam(model_5.parameters(), lr=0.001)
_, best_loss = train_loop_with_validation(model_5, 30, train_loader, val_loader, optimizer, criterion)
print(f'Best loss achieved on the validation set: {best_loss}')

In [None]:
model_lower_lr = UNet(depth=3, in_channel=3)
model_lower_lr.to(device)
optimizer = torch.optim.Adam(model_lower_lr.parameters(), lr=0.0005)
_, best_loss = train_loop_with_validation(model_lower_lr, 60, train_loader, val_loader, optimizer, criterion)
print(f'Best loss achieved on the validation set: {best_loss}')

In [None]:
train_val_concat_dataset = ConcatDataset([train_dataset, val_dataset])
train_val_concat_loader = DataLoader(train_val_concat_dataset, batch_size=32, shuffle=True, num_workers=os.cpu_count())

In [None]:
model_final = UNet(in_channel=3, depth=3)
model_final.to(device)
optimizer = torch.optim.Adam(model_final.parameters(), lr=0.001)
best_model, best_loss = train_loop(model_final, 50, train_val_concat_loader, optimizer, criterion)

In [None]:
torch.save(best_model.state_dict(), '../models/modelLiverMultiSlices.pth')

In [None]:
display_predictions(best_model, train_loader, device, 32)

In [None]:
display_predictions(best_model, test_loader, device, 32)