### Create a training pipeline WITH image preprocessing

In [1]:
## import necessary libraries

import numpy as np 
import pandas as pd
from pydicom import dcmread
import os
import scipy.ndimage
import matplotlib.pyplot as plt

from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch

from CT_Seg_Network.utils.dataset import *
from CT_Seg_Network.models.Unet import UNet
from CT_Seg_Network.utils.iou import IoU
from CT_Seg_Network.utils.Loss import *
from CT_Seg_Network.models.ResUnet import ResUNet
from tqdm import tqdm

### Process the dataset images and save to another folder
Processing:
- conversion into Hounsfield Units
- resampling into new spacing (optional)
- normalization (clip HU values to the [-1000, 400] window, which suppresses unwanted structures such as bone (around +700 HU), then rescale to [0, 1])

In [2]:
## load necessary functions

## loads the scans per patient
def load_scan(path, patient):
    """Load every DICOM slice of one patient, ordered along the z axis.

    Parameters
    ----------
    path : str
        Folder containing one sub-folder per patient. It is concatenated
        directly with ``patient``, so it must end with a path separator.
    patient : str
        Name of the patient sub-folder.

    Returns
    -------
    slices : list
        Datasets returned by ``dcmread``, sorted by ``ImagePositionPatient[2]``.
    scanID : list of str
        File name (minus its last four characters, e.g. ``'.dcm'``) of each
        slice, in the SAME order as ``slices``.
    """
    patient_path = path + patient

    # Keep every dataset paired with its ID so that sorting reorders both
    # together; sorting only the datasets would leave the IDs in directory
    # order and mislabel every slice downstream.
    pairs = []
    for fname in os.listdir(patient_path):
        pairs.append((fname[:-4], dcmread(patient_path + '/' + fname)))

    pairs.sort(key=lambda pair: float(pair[1].ImagePositionPatient[2]))
    scanID = [scan_id for scan_id, _ in pairs]
    slices = [ds for _, ds in pairs]

    # The spacing between neighbouring slices can only be measured when
    # at least two slices exist; otherwise the stored value is left untouched.
    if len(slices) >= 2:
        try:
            slice_thickness = np.abs(slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2])
        except (AttributeError, TypeError, ValueError):
            slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation)

        for s in slices:
            s.SliceThickness = slice_thickness

    return slices, scanID


## converts the value into HU (Hounsfield Unit) and cleans out of bounds pixels
def get_pixels_hu(slices):
    """Stack the slices' pixel data and rescale it with each slice's slope/intercept.

    Parameters
    ----------
    slices : sequence
        Objects exposing ``pixel_array``, ``RescaleSlope`` and
        ``RescaleIntercept``.

    Returns
    -------
    numpy.ndarray
        int16 array with one entry along axis 0 per slice.
    """
    # Work on an int16 copy of the stacked raw pixel data.
    volume = np.stack([ds.pixel_array for ds in slices]).astype(np.int16)

    # Raw value -2000 is treated as "outside the scan" and reset to 0
    # before the rescale is applied.
    volume[volume == -2000] = 0

    for idx, ds in enumerate(slices):
        slope = ds.RescaleSlope
        if slope != 1:
            # Scale in float64, then truncate back to int16.
            volume[idx] = (slope * volume[idx].astype(np.float64)).astype(np.int16)
        volume[idx] += np.int16(ds.RescaleIntercept)

    return np.array(volume, dtype=np.int16)

## resample to isotropic resolution to help convnets
def resample(image, scan, new_spacing=[1,1,1]):
    """Resample a volume to (approximately) the requested voxel spacing.

    Parameters
    ----------
    image : numpy.ndarray
        3-D volume; axis 0 is paired with ``SliceThickness`` and axes 1-2 with
        ``PixelSpacing``.
    scan : sequence
        Slice datasets; only ``scan[0].SliceThickness`` and
        ``scan[0].PixelSpacing`` are read.
    new_spacing : sequence of 3 numbers, optional
        Target spacing per axis. It is never mutated.

    Returns
    -------
    image : numpy.ndarray
        The resampled volume.
    new_spacing : numpy.ndarray
        Spacing actually obtained after rounding the output shape to integers.
    """
    # Determine current voxel spacing.
    spacing = np.array([scan[0].SliceThickness] + list(scan[0].PixelSpacing), dtype=np.float32)
    target_spacing = np.asarray(new_spacing, dtype=np.float64)
    current_shape = np.array(image.shape, dtype=np.float64)

    # The output shape must be integral, so round it and derive the zoom
    # factor (and resulting spacing) from the rounded shape.
    new_shape = np.round(current_shape * (spacing / target_spacing))
    real_resize_factor = new_shape / current_shape
    new_spacing = spacing / real_resize_factor

    # scipy.ndimage.zoom is the public entry point; the
    # scipy.ndimage.interpolation namespace is deprecated.
    image = scipy.ndimage.zoom(image, real_resize_factor, mode='nearest')

    return image, new_spacing


## intensity window (in HU) kept by normalize(); values outside it are clipped
MIN_BOUND = -1000.0
MAX_BOUND = 400.0

def normalize(image):
    """Map ``image`` linearly from [MIN_BOUND, MAX_BOUND] to [0, 1] and clip.

    Returns a new array; the input array is not modified.
    """
    window_width = MAX_BOUND - MIN_BOUND
    scaled = (image - MIN_BOUND) / window_width
    # Saturate everything that falls outside the window.
    scaled[scaled < 0] = 0.
    scaled[scaled > 1] = 1.
    return scaled

#### Load each image and process as described above

In [19]:
IMAGE_FOLDER = 'coding_test_files/dicom_series/'
PROCESSED_FOLDER = 'coding_test_files/processed_dicom/'
patients = os.listdir(IMAGE_FOLDER)
patients.sort()

for pt in patients:
    patient_scan, scanID = load_scan(IMAGE_FOLDER, pt)
    patient_scan_pixels = get_pixels_hu(patient_scan)
    ## optional resampling step, currently disabled
    #pix_resampled, spacing = resample(patient_scan_pixels, patient_scan, [1,1,1])
    #patient_scan_pixels = pix_resampled

    ## makedirs also creates PROCESSED_FOLDER itself on the first run and
    ## does not fail when the patient folder already exists
    os.makedirs(PROCESSED_FOLDER + pt, exist_ok=True)

    ## save each normalized slice into the processed folder, one .npz per slice
    for i in range(patient_scan_pixels.shape[0]):
        image = normalize(patient_scan_pixels[i,:,:])
        np.savez_compressed(PROCESSED_FOLDER + pt + '/' + scanID[i], image)

### Train the Network

In [20]:
## prepare the dataset by dividing the patients into a training and a validation set
IMAGE_FOLDER = 'coding_test_files/processed_dicom/'
SEG_FOLDER ='coding_test_files/segmentation_data/'
patients = os.listdir(IMAGE_FOLDER)
patients.sort()

## split percentage (fraction of patients used for training)
percentage = .8
## NOTE: deliberately not named `train`/`val` -- a function called `train` is
## defined further down, and sharing the name makes the notebook depend on
## the order in which cells are (re-)run.
train_patients, val_patients = np.split(patients, [int(len(patients)*percentage)])

def load_list(folder, list_patients):
    """Return the paths of all files found in each patient's sub-folder.

    ``folder`` is concatenated directly with the patient name, so it must end
    with a path separator. Files are listed in sorted order per patient so the
    result is reproducible (``os.listdir`` order is unspecified).
    """
    folder_list = []
    for patient in list_patients:
        for fname in sorted(os.listdir(folder + patient)):
            folder_list.append(folder + patient + '/' + fname)
    return folder_list

train_list_im = load_list(IMAGE_FOLDER, train_patients)
val_list_im = load_list(IMAGE_FOLDER, val_patients)

In [28]:
## create a dataset loader that will handle loading the images from processed dicom files directly
## and labels from segmentation data

class CSTO_CTDataset(Dataset):
    """Dataset pairing pre-processed CT slices (.npz) with segmentation labels.

    Each image path is expected to look like ``<anything>/<patient>/<scan>.npz``.
    The matching label is read from ``<seg_folder>/<patient>/<scan>.npz``; when
    that file does not exist an all-zero label with 4 channels is returned.

    Parameters
    ----------
    image_addr_list : sequence of str
        Paths of the pre-processed image files.
    seg_folder : str, optional
        Root folder of the segmentation files. Defaults to
        ``'coding_test_files/segmentation_data/'``.
    transforms : callable, optional
        Called as ``transforms((img, label))`` and must return the
        transformed ``(img, label)`` pair.
    """

    def __init__(self, image_addr_list, seg_folder=None, transforms=None):

        ## takes the address of the images
        self.images = image_addr_list
        if seg_folder is None:
            self.seg_folder = 'coding_test_files/segmentation_data/'
        else:
            self.seg_folder = seg_folder

        self.transform = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_addr = self.images[idx]

        ## the context manager closes the .npz archive once the array is read
        with np.load(image_addr) as data:
            img = data['arr_0']
        img = np.expand_dims(img, axis=0).astype(np.float32)
        img = torch.from_numpy(img).float()

        ## derive "<patient>/<scan>" from the path itself instead of searching
        ## for a hard-coded folder name, so any image root folder works
        patient = os.path.basename(os.path.dirname(image_addr))
        scan_id = os.path.splitext(os.path.basename(image_addr))[0]
        label_addr = os.path.join(self.seg_folder, patient, scan_id + '.npz')

        if os.path.isfile(label_addr):
            with np.load(label_addr) as data:
                ## stored channel-last; move the channel axis to the front
                label = torch.from_numpy(np.transpose(data['arr_0'], (-1, 0, 1)))#.to(torch.float32)
        else:
            label = torch.zeros([4, img.shape[1], img.shape[2]])#.to(torch.float32)

        if self.transform:
            img, label = self.transform((img, label))

        sample = {'img': img,
                  'label': label}

        return sample

In [29]:
## load custom transforms that manipulates both image and label

## Both pipelines resize to 128 and normalize with mean 0.5 / std 0.5;
## only the training pipeline adds a random flip in front.
## (RandomFlip / Resize / Normalize come from the project's wildcard imports;
## the dataset calls each pipeline on an (img, label) pair.)
transform_train = transforms.Compose(
    [RandomFlip(), Resize(128), Normalize(mean=[0.5], std=[0.5])]
)
transform_val = transforms.Compose(
    [Resize(128), Normalize(mean=[0.5], std=[0.5])]
)

## datasets for train and validation
trainset = CSTO_CTDataset(image_addr_list=train_list_im, transforms=transform_train)
valset = CSTO_CTDataset(image_addr_list=val_list_im, transforms=transform_val)

## data loaders; only the training loader shuffles
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)
valloader = DataLoader(dataset=valset, batch_size=16)

In [30]:
## create functions for the training and evaluation proper

def train(model,train_loader,optimizer,LOSS_FUNC,EPOCH,PRINT_INTERVAL, epoch, device):
    """Run one pass over ``train_loader`` and return the mean batch loss.

    Each batch is a dict with ``'img'`` and ``'label'`` tensors. ``EPOCH``
    (total epochs) and ``epoch`` (zero-based current epoch) are used only for
    the progress message written every ``PRINT_INTERVAL`` iterations.
    """
    batch_losses = []
    n_batches = len(train_loader)
    for step, batch in enumerate(tqdm(train_loader), start=1):
        inputs = batch['img'].to(device)
        targets = batch['label'].to(device)

        prediction = model(inputs)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss = LOSS_FUNC(prediction, targets)
        loss.backward()
        batch_losses.append(loss.item())
        optimizer.step()

        if step % PRINT_INTERVAL == 0:
            tqdm.write('Epoch [%d/%d], Iter [%d/%d], Loss: %.4f'
                       % (epoch + 1, EPOCH, step, n_batches, loss.item()))
    return np.mean(batch_losses)

def eval(model,val_loader,LOSS_FUNC, device):
    """Return the mean batch loss of ``model`` over ``val_loader``.

    The model is switched to evaluation mode and gradient tracking is
    disabled while the loader is consumed, so layers such as dropout and
    batch-norm behave deterministically and no autograd graph is built.
    The model's previous train/eval mode is restored before returning.

    Each batch is a dict with ``'img'`` and ``'label'`` tensors.
    """
    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for batch in val_loader:
                img, label = batch['img'].to(device), batch['label'].to(device)
                output = model(img)
                loss = LOSS_FUNC(output, label)
                losses.append(loss.item())
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return np.mean(losses)

In [32]:
## NOTE(review): batch_size is not read anywhere in this cell
batch_size = 16
print("Train set {}\nValidation set {}".format(len(trainset),len(valset)))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## initialize model type
model = ResUNet(out_classes=4).to(device)


if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    # Use the fully qualified name: `nn` is never imported explicitly here.
    model = torch.nn.DataParallel(model).to(device)

## init optimizer, scheduler, and loss function
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
lr_sheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
LOSS_FUNC = DiceCELoss().to(device)
PRINT_INTERVAL = 5
EPOCH= 50

## make sure the checkpoint folder exists before the first save
checkpoint_addr = 'CT_Seg_Network/weights/'
os.makedirs(checkpoint_addr, exist_ok=True)


## start training
val_loss_epoch = []
for epoch in range(EPOCH):

    ## train one epoch, then measure the validation loss
    model.train()
    train_loss = train(model, trainloader, optimizer, LOSS_FUNC, EPOCH, PRINT_INTERVAL, epoch, device)
    val_loss = eval(model, valloader, LOSS_FUNC, device)
    val_loss_epoch.append(val_loss)
    lr_sheduler.step()
    tqdm.write('Epoch [%d/%d], Average Train Loss: %.4f, Average Validation Loss: %.4f'
                % (epoch + 1, EPOCH, train_loss, val_loss))


    ## save the model with the best val result
    ## (the newest loss equals the running minimum exactly when it is the best so far)
    if val_loss == np.min(val_loss_epoch):
        print('Model saved')
        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, os.path.join(checkpoint_addr,'best.pth'))

Train set 3300
Validation set 1103


 10%|▉         | 5/52 [00:13<02:02,  2.60s/it]

Epoch [1/50], Iter [5/52], Loss: 1.0399


 19%|█▉        | 10/52 [00:26<01:51,  2.66s/it]

Epoch [1/50], Iter [10/52], Loss: 1.0474


 29%|██▉       | 15/52 [00:40<01:40,  2.71s/it]

Epoch [1/50], Iter [15/52], Loss: 1.0328


 38%|███▊      | 20/52 [00:53<01:27,  2.72s/it]

Epoch [1/50], Iter [20/52], Loss: 1.0182


 48%|████▊     | 25/52 [01:07<01:13,  2.71s/it]

Epoch [1/50], Iter [25/52], Loss: 1.0340


 58%|█████▊    | 30/52 [01:21<01:00,  2.77s/it]

Epoch [1/50], Iter [30/52], Loss: 1.0108


 67%|██████▋   | 35/52 [01:34<00:46,  2.74s/it]

Epoch [1/50], Iter [35/52], Loss: 1.0213


 77%|███████▋  | 40/52 [01:48<00:33,  2.81s/it]

Epoch [1/50], Iter [40/52], Loss: 1.0108


 87%|████████▋ | 45/52 [02:02<00:19,  2.75s/it]

Epoch [1/50], Iter [45/52], Loss: 1.0124


 88%|████████▊ | 46/52 [02:05<00:16,  2.78s/it]