In [1]:
import os, torch, random
import numpy as np

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Central experiment configuration shared by the dataset, training and test cells.
CFG = {"SEED": 42,            # seed handed to seed_everything
       "TEST_PORTION": 0.3,   # fraction of the dataset held out as the test split
       "EPOCHS": 30,          # number of epochs run by the training loop
       "BATCH_SIZE": 16,      # batch size of both DataLoaders
       "LR": 1e-4,            # learning rate given to Adam
       # Checkpoint location read by the training and test cells. The key was
       # missing, so CFG["model_save_path"] raised KeyError at the first use.
       "model_save_path": "./liver_model.pth"}

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to `random`, `numpy.random`, and the
            PyTorch CPU/CUDA generators. It is also exported as
            PYTHONHASHSEED (which only affects interpreters started afterwards).

    Returns:
        None.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different (possibly
    # non-deterministic) kernels per run, which contradicts the line above.
    torch.backends.cudnn.benchmark = False
    return

In [2]:
# Root folder of the liver dataset (hard-coded local path). path_dict joins it
# directly with each ID via f"{root}{ID}/", hence the trailing slash.
root = "D:/Datasets/Liver/"

def path_dict(root):
    """Index the dataset folder as ``{ID: {date: [file paths]}}``.

    The layout walked is ``root/ID/date/<files>``. Inside each date folder
    every entry whose name contains ".nii" is recorded, and an entry named
    "Seg" is recorded as ``Seg/liver.nii.gz``.

    Entries are visited in case-insensitive alphabetical order so the position
    of each file in the returned lists does not depend on the order the
    operating system happens to report (``os.listdir`` order is arbitrary).

    Args:
        root: Dataset root directory. A trailing slash is added when missing.

    Returns:
        dict mapping each ID folder name to a dict that maps each date folder
        name to the list of collected paths (built with forward slashes).
    """
    # The paths below are built by plain concatenation, so root must end with
    # a separator.
    if root and not root.endswith(("/", "\\")):
        root = f"{root}/"

    def _listing(folder):
        # Case-insensitive order, with the raw name as a tie-breaker.
        return sorted(os.listdir(folder), key=lambda name: (name.lower(), name))

    total_dict = {}
    for ID in _listing(root):
        # Skip stray files next to the ID folders; listing them would raise.
        if not os.path.isdir(f"{root}{ID}"):
            continue
        date_dict = {}
        for date in _listing(f"{root}{ID}/"):
            if not os.path.isdir(f"{root}{ID}/{date}"):
                continue
            path_list = []
            for item in _listing(f"{root}{ID}/{date}/"):
                if item == "Seg":
                    path_list.append(f"{root}{ID}/{date}/{item}/liver.nii.gz")
                elif ".nii" in item:
                    path_list.append(f"{root}{ID}/{date}/{item}")
            date_dict[date] = path_list
        total_dict[ID] = date_dict
    return total_dict

In [4]:
# Preview the nested {ID: {date: [paths]}} mapping built from the dataset root.
path_dict(root)

{'1011830': {'20180723': ['D:/Datasets/Liver/1011830/20180723/A.nii.gz',
   'D:/Datasets/Liver/1011830/20180723/D.nii.gz',
   'D:/Datasets/Liver/1011830/20180723/label.nii.gz',
   'D:/Datasets/Liver/1011830/20180723/P.nii.gz',
   'D:/Datasets/Liver/1011830/20180723/Seg/liver.nii.gz'],
  '20181128': ['D:/Datasets/Liver/1011830/20181128/A.nii.gz',
   'D:/Datasets/Liver/1011830/20181128/D.nii.gz',
   'D:/Datasets/Liver/1011830/20181128/label.nii.gz',
   'D:/Datasets/Liver/1011830/20181128/P.nii.gz',
   'D:/Datasets/Liver/1011830/20181128/Seg/liver.nii.gz']},
 '1070799': {'20200513': ['D:/Datasets/Liver/1070799/20200513/A.nii.gz',
   'D:/Datasets/Liver/1070799/20200513/D.nii.gz',
   'D:/Datasets/Liver/1070799/20200513/label.nii.gz',
   'D:/Datasets/Liver/1070799/20200513/P.nii.gz',
   'D:/Datasets/Liver/1070799/20200513/Seg/liver.nii.gz'],
  '20200812': ['D:/Datasets/Liver/1070799/20200812/A.nii.gz',
   'D:/Datasets/Liver/1070799/20200812/D.nii.gz',
   'D:/Datasets/Liver/1070799/20200812/l

In [None]:
# Apply the configured seed through seed_everything.
seed_everything(CFG["SEED"])

In [10]:
import torch, torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import nibabel as nib
from basic_setup import seed_everything, CFG
from readpath import path_dict

# Seed before the dataset is split and the shuffling loaders are created below.
seed_everything(CFG["SEED"])

class LiverDataset(Dataset):
    """Dataset yielding one (merged A/P/D volume, label volume) pair per ID.

    Args:
        dict: Nested mapping ``{ID: {date: [paths]}}`` as returned by
            ``path_dict``. (The parameter name shadows the builtin but is kept
            for caller compatibility.)
        transform: Optional callable applied to the merged volume. When it is
            None the merged array is converted to a tensor as-is.
    """

    def __init__(self, dict, transform=None):
        self.dict = dict
        self.transform = transform

    def __len__(self):
        """Number of IDs (one sample per ID)."""
        return len(self.dict)

    def __get_voxel__(self, path):  # author's note: (384, 384, 96)
        """Load a NIfTI file and return its data as a floating-point array."""
        voxel = nib.load(path).get_fdata()
        return voxel

    def __get_date__(self, ID):
        """Return the date keys recorded for ``ID``, in insertion order."""
        return list(self.dict[ID].keys())

    def __merge_APD__(self, A, P, D):  # Needs improvement
        """Merge the three volumes by taking their element-wise mean."""
        return (A + P + D) / 3

    def __normalize__(self, voxel):
        """Min-max scale ``voxel`` to 0..255 and cast to uint8.

        A constant volume (max == min) would divide by zero and yield NaNs, so
        it is mapped to all zeros instead.
        """
        minimum = np.min(voxel)
        maximum = np.max(voxel)
        if maximum == minimum:
            return np.zeros(np.shape(voxel), dtype=np.uint8)
        return ((voxel - minimum) / (maximum - minimum) * 255).astype(np.uint8)

    def __find_path__(self, path_list, stem, fallback_index):
        """Pick the path whose file name (before the first dot) equals ``stem``.

        Falls back to ``path_list[fallback_index]`` - the original positional
        lookup - when no file name matches, so differently named datasets keep
        working exactly as before.
        """
        for path in path_list:
            file_name = path.rsplit("/", 1)[-1]
            if file_name.split(".", 1)[0] == stem:
                return path
        return path_list[fallback_index]

    def __getitem__(self, index):
        """Return ``(merged, label)`` tensors for the ``index``-th ID.

        Only the first recorded date of the ID is used. Both tensors get a
        leading axis of size 1 via ``unsqueeze(0)``.
        """
        ID = list(self.dict.keys())[index]
        date_list = self.__get_date__(ID)
        path_list = self.dict[ID][date_list[0]]  # Use data from the 1st date ONLY

        # Select by file name rather than by position, so the result does not
        # depend on the order in which the directory listing was produced.
        A = self.__get_voxel__(self.__find_path__(path_list, "A", 0))
        P = self.__get_voxel__(self.__find_path__(path_list, "P", 3))
        D = self.__get_voxel__(self.__find_path__(path_list, "D", 1))
        label = self.__get_voxel__(self.__find_path__(path_list, "label", 2))
        merged = self.__merge_APD__(A, P, D)

        if self.transform is not None:
            merged = self.transform(merged)
        # Without a transform (or with one returning an array) `merged` is a
        # NumPy array, which has no `unsqueeze`; convert it so the documented
        # default transform=None does not crash. Note the axis order is then
        # the raw array order, not the one ToTensor would produce.
        if not torch.is_tensor(merged):
            merged = torch.as_tensor(merged)

        return merged.unsqueeze(0), torchvision.transforms.ToTensor()(label).unsqueeze(0)

# The only preprocessing applied to the merged volume is conversion to a tensor.
train_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

test_ratio = CFG["TEST_PORTION"]
data_dict = path_dict("D:/Datasets/Liver/")
total_dataset = LiverDataset(data_dict, train_transform)

# Hold out int(N * test_ratio) samples for testing; the remainder is the training set.
n_total = len(total_dataset)
n_test = int(n_total * test_ratio)
train_set, test_set = random_split(total_dataset, [n_total - n_test, n_test])

train_loader = DataLoader(train_set, batch_size=CFG["BATCH_SIZE"], shuffle=True)
test_loader = DataLoader(test_set, batch_size=CFG["BATCH_SIZE"], shuffle=True)

In [18]:
import torch
import torch.nn as nn
from basic_setup import CFG, device

class LiTSModel(nn.Module):
    """3D CNN with three down-sampling residual stages and a 2-way softmax head.

    The network takes single-channel volumes (``Conv3d(in_channels=1, ...)``)
    with a leading batch axis, halves every spatial axis five times (initial
    max-pool, three stride-2 stages, final max-pool) and flattens the result
    into a linear layer with 2 outputs followed by a softmax.

    NOTE(review): ``fc`` expects 16*4*4*2 = 512 features, i.e. the spatial
    sizes left after the final pooling must multiply to 32 - confirm this
    against the volumes actually fed to the model. Spatial sizes should stay
    even at every stage, otherwise the stride-2 convolution and the max-pool
    of the skip path produce different shapes and cannot be added.

    Layer names and ``nn.Sequential`` positions are unchanged, so existing
    checkpoints (state_dict keys) still load.
    """

    @staticmethod
    def _res_block(in_ch, out_ch):
        """1x1 conv -> 7x7x7 stride-2 conv (in_ch -> out_ch) -> 1x1 conv, each with BN + ReLU, then dropout."""
        return nn.Sequential(
            nn.Conv3d(in_ch, in_ch, 1, 1),  # K=1, S=1, P=0 (O = I) -> just activation
            nn.BatchNorm3d(in_ch),
            nn.ReLU(),
            nn.Conv3d(in_ch, out_ch, 7, 2, 3),  # K=7, S=2, P=3 (O = I/2)
            nn.BatchNorm3d(out_ch),
            nn.ReLU(),
            nn.Conv3d(out_ch, out_ch, 1, 1),  # K=1, S=1, P=0 (O = I) -> just activation
            nn.BatchNorm3d(out_ch),
            nn.ReLU(),
            nn.Dropout(p=0.5)
        )

    @staticmethod
    def _skip_block(in_ch, out_ch):
        """Shortcut path: 1x1 conv to match channels, BN + ReLU, 2x max-pool to match size, dropout."""
        return nn.Sequential(
            nn.Conv3d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, stride=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(),
            nn.MaxPool3d(2, 2),
            nn.Dropout(p=0.5)
        )

    def __init__(self):
        super().__init__()
        self.InitialBlock = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=2, kernel_size=1, stride=1),  # Channel -> DOUBLE (1x1 Conv)
            nn.BatchNorm3d(num_features=2),
            nn.ReLU(),
            nn.MaxPool3d(2, 2)  # Img_size -> HALF
        )

        # Creation order is kept (res blocks, then skip connections, then fc)
        # so parameter initialisation under a fixed seed is unchanged.
        self.ResBlock1 = self._res_block(2, 4)
        self.ResBlock2 = self._res_block(4, 8)
        self.ResBlock3 = self._res_block(8, 16)

        self.SkipConnection1 = self._skip_block(2, 4)
        self.SkipConnection2 = self._skip_block(4, 8)
        self.SkipConnection3 = self._skip_block(8, 16)

        self.activation = nn.ReLU()  # not used in forward; kept for compatibility
        self.maxpool = nn.MaxPool3d(2, 2)
        self.fc = nn.Linear(in_features=16*4*4*2, out_features=2)
        # Explicit dim: the input here is (batch, 2) and the classes live on
        # axis 1. Softmax() without dim is deprecated and warns on every call.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-sample class probabilities of shape (batch, 2)."""
        x = self.InitialBlock(x)
        # Each stage adds the main path to a shortcut computed from the same input.
        x = self.ResBlock1(x) + self.SkipConnection1(x)
        x = self.ResBlock2(x) + self.SkipConnection2(x)
        x = self.ResBlock3(x) + self.SkipConnection3(x)

        x = self.maxpool(x)
        x = x.view(x.size(0), -1)  # flatten everything but the batch axis
        x = self.fc(x)
        x = self.softmax(x)

        return x

In [None]:
def calculate_iou(pred, gt, threshold=0.5):
    """Compute the intersection-over-union of each prediction/ground-truth pair.

    The function used to be an empty stub returning None, which made its
    consumer (a line plot of the returned scores) fail.

    Args:
        pred: Array-like of predictions; the first axis indexes samples and
            any remaining axes are flattened per sample.
        gt: Array-like of ground truth with the same shape as ``pred``.
        threshold: Values strictly greater than this count as foreground in
            both inputs (default 0.5).

    Returns:
        list[float]: One IoU score per sample. A sample whose prediction and
        ground truth are both empty scores 1.0.

    Raises:
        ValueError: If ``pred`` and ``gt`` have different shapes.
    """
    import numpy as np  # local import: the defining cell does not import numpy

    pred_arr = np.atleast_1d(np.asarray(pred, dtype=float))
    gt_arr = np.atleast_1d(np.asarray(gt, dtype=float))
    if pred_arr.shape != gt_arr.shape:
        raise ValueError(
            f"pred and gt must have the same shape, got {pred_arr.shape} and {gt_arr.shape}"
        )
    if pred_arr.size == 0:
        return []

    n_samples = pred_arr.shape[0]
    pred_mask = (pred_arr > threshold).reshape(n_samples, -1)
    gt_mask = (gt_arr > threshold).reshape(n_samples, -1)

    intersection = np.logical_and(pred_mask, gt_mask).sum(axis=1)
    union = np.logical_or(pred_mask, gt_mask).sum(axis=1)
    # Avoid 0/0: an empty union means both masks are empty -> perfect agreement.
    scores = np.where(union == 0, 1.0, intersection / np.maximum(union, 1))
    return scores.tolist()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def train_plot(loss_hist, iou_hist):
    """Plot the two training histories side by side and save them as a JPEG.

    Args:
        loss_hist: Sequence plotted in the left panel ('Train Loss').
        iou_hist: Sequence plotted in the right panel ('Train IOU').

    Returns:
        None. The figure is written to ./Figures/Train.jpg.
    """
    import os  # local import: the defining cell does not import os

    # savefig does not create missing directories and would raise otherwise.
    os.makedirs("./Figures", exist_ok=True)

    plt.figure(figsize=(20,10))
    plt.subplot(121)
    plt.plot(loss_hist, label='train_loss')
    plt.title('Train Loss')

    plt.subplot(122)
    plt.plot(iou_hist, label='train_iou')
    plt.title('Train IOU')
    plt.savefig("./Figures/Train.jpg")
    return

def test_plot(pred, gt):
    """Plot the IoU scores of ``pred`` against ``gt`` and save them as a JPEG.

    Args:
        pred: Predictions forwarded to ``calculate_iou``.
        gt: Ground truth forwarded to ``calculate_iou``.

    Returns:
        None. The figure is written to ./Figures/Test.jpg.
    """
    import os  # local import: the defining cell does not import os
    from utils import calculate_iou

    iou_scores = calculate_iou(pred, gt)

    # savefig does not create missing directories and would raise otherwise.
    os.makedirs("./Figures", exist_ok=True)

    plt.figure(figsize=(20,10))
    plt.plot(iou_scores, label='test_iou')
    # The title used to be the copy-pasted 'Train Loss', mislabelling the figure.
    plt.title('Test IOU')
    plt.savefig("./Figures/Test.jpg")
    return

In [None]:
import os, time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from basic_setup import CFG, device
from construct_dataset import train_loader
from construct_model import LiTSModel
from plot_results import train_plot

def train_loop(dataloader, model, optimizer, loss_fn, model_save_path, scheduler=None):
    """Train ``model`` for CFG["EPOCHS"] epochs and save its state_dict.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; must expose ``__len__``
            and a ``dataset`` attribute with ``__len__``.
        model: Module to optimise (already on ``device``).
        optimizer: Optimiser stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(output, y)`` returning a scalar tensor.
        model_save_path: Destination passed to ``torch.save`` after training.
        scheduler: Optional LR scheduler stepped once per epoch. When omitted,
            a module-level variable named ``scheduler`` is used if one exists
            (this is what the function silently relied on before); when
            neither is available no scheduling is performed.

    Returns:
        tuple[list, list]: Per-epoch mean batch loss and per-epoch accuracy
        in percent.
    """
    # Backward compatibility with callers that define a global `scheduler`
    # instead of passing one in.
    if scheduler is None:
        scheduler = globals().get("scheduler")

    model.train()

    size = len(dataloader)              # number of batches per epoch
    datasize = len(dataloader.dataset)  # number of samples per epoch

    loss_hist=[]
    acc_hist=[]

    for epoch in range(CFG["EPOCHS"]):
        epoch_start = time.time()

        loss_item=0
        correct=0
        print(f"Start epoch : {epoch+1}")
        for batch, (X,y) in enumerate(dataloader):
            X = X.to(device).float()
            y = y.to(device).float()

            output = model(X)

            loss = loss_fn(output, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_item += loss.item()

            # A sample counts as correct when the arg-max over axis 1 of the
            # output matches the arg-max over axis 1 of the target.
            correct+=(output.argmax(1)==y.argmax(1)).detach().cpu().sum().item()

            if batch % 20 == 0:
                print(f"Batch loss : {loss.item():>.5f} {batch}/{size}")

        if scheduler is not None:
            scheduler.step()

        loss_hist.append(loss_item/size)
        acc_hist.append(correct/datasize*100)

        print(f"Loss : {(loss_item/size):>.5f} ACC : {(correct/datasize*100):>.2f}%")

        epoch_end = time.time()
        print(f"End epoch : {epoch+1}")
        print(f"Epoch time : {(epoch_end-epoch_start)//60} min {(epoch_end-epoch_start)%60} sec")
        print()

    torch.save(model.state_dict(), model_save_path)

    return loss_hist, acc_hist

# Build the network, the criterion, the optimiser and its LR schedule.
model = LiTSModel().to(device)
loss = nn.BCELoss().to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CFG["LR"],
    weight_decay=0.1,  # Weight decay (L2 regularization)
)
# Multiply the learning rate by 0.9 at every scheduler step.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Train, then plot the returned loss and accuracy histories.
loss_hist, acc_hist = train_loop(
    dataloader=train_loader,
    model=model,
    optimizer=optimizer,
    loss_fn=loss,
    model_save_path=CFG["model_save_path"],
)

train_plot(loss_hist, acc_hist)

In [None]:
import torch
from basic_setup import CFG, device
from construct_dataset import test_loader
from construct_model import LiTSModel
from plot_results import test_plot

def test_loop(dataloader, model, model_path):
    """Load a checkpoint into ``model`` and collect its outputs on ``dataloader``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches.
        model: Module whose ``state_dict`` matches the checkpoint.
        model_path: Path of the checkpoint written by ``torch.save``.

    Returns:
        tuple[list, list]: ``(target, pred)`` where, for every sample,
        ``target`` holds element 1 of its label and ``pred`` holds element 1
        of the model output.
    """
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    pred=[]
    target=[]

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for (X,y) in dataloader:
            for t in y:
                target.append(t[1].detach().tolist())

            X = X.to(device).float()

            output = model(X)

            for o in output:
                print(f"Output: {o}")
                pred.append(o[1].detach().cpu().tolist())

    return target, pred

# Fresh model instance; test_loop restores the trained weights into it.
model = LiTSModel().to(device)

target, pred = test_loop(test_loader, model, CFG["model_save_path"])

# test_plot is declared as test_plot(pred, gt); the arguments were passed in
# the opposite order before.
test_plot(pred, target)