In [1]:
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import random

In [2]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
from segmentation_models_pytorch.utils.metrics import IoU, Fscore, Accuracy

In [3]:
import matplotlib.pyplot as plt

# helper function for data visualization
def visualize(**images):
    """Show every keyword-supplied image side by side in a single row.

    Each keyword name becomes its panel title: underscores are replaced by
    spaces and the result is title-cased (``ct_slice`` -> ``Ct Slice``).
    Images are drawn with the ``'gray'`` colormap and without axis ticks,
    on a 16x5 inch figure that is shown immediately.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture, 'gray')
    plt.show()

In [4]:
import random

SEED = 0
# Seed every RNG the notebook draws from: Python's stdlib, NumPy's legacy
# global generator, and PyTorch.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Sanity check: these draws should reproduce the saved output on every run.
print(np.random.rand(5), torch.randn(5))

[0.5488135  0.71518937 0.60276338 0.54488318 0.4236548 ] tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845])


# Settings

In [5]:
root = os.getcwd()
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

BATCH = 1
SEGMENT = 1  # forwarded as `segment=` to the dataset constructors

# Feature switches forwarded to the datasets as electron/g_coord/l_coord.
# ELECTRON also selects which intensity bounds the next cell uses.
ELECTRON = G_COORD = L_COORD = False

EPOCH = 200

In [6]:
# Intensity windows: the overall viewing range plus the intervals named for
# air and bone. The default values appear to be HU (they match the -500..500
# mapping used by hu_histogram); with ELECTRON set, the alternative
# (presumably electron-density) values are used — NOTE(review): confirm units.
if ELECTRON:
    VIEW_BOUND = (0.5, 1.5)
    AIR_BOUND = (0.5, 0.5009)
    BONE_BOUND = (1.2, 1.2009)
else:
    VIEW_BOUND = (-500, 500)
    AIR_BOUND = (-500, -426)
    BONE_BOUND = (400, 500)

# Read Data

In [7]:
from codes.dataset import DicomsSegmentDataset, DicomSegmentDataset
import codes.augmentation as aug

In [8]:
# Glob patterns (relative to the working directory) for the per-scan
# directories of each pelvic split.
pelvic_train_case_path, pelvic_train_id_case_path, pelvic_test_case_path = (
    f"raw/{split}/*_*" for split in ("train", "train_id", "test")
)

In [9]:
# Both pelvic training splits use identical options; each constructor call
# still gets its own freshly created validation (geometry) augmentation.
trainset_pelvic, trainset_id_pelvic = (
    DicomsSegmentDataset(
        case_glob,
        intensity_aug=None,
        geometry_aug=aug.get_validation_augmentation(),
        identity=False,
        electron=ELECTRON,
        position="pelvic",
        segment=SEGMENT,
        g_coord=G_COORD,
        l_coord=L_COORD,
    )
    for case_glob in (pelvic_train_case_path, pelvic_train_id_case_path)
)

In [10]:
# Held-out pelvic cases.
# NOTE(review): unlike the training splits, `intensity_aug` is not passed
# here, so the constructor's default applies — confirm that is intended.
testset = DicomsSegmentDataset(
    pelvic_test_case_path,
    geometry_aug=aug.get_validation_augmentation(),
    identity=False,
    electron=ELECTRON,
    position="pelvic",
    segment=SEGMENT,
    g_coord=G_COORD,
    l_coord=L_COORD,
)

In [11]:
tuple(len(split) for split in (trainset_pelvic, trainset_id_pelvic, testset))

(784, 560, 191)

In [12]:
# NOTE(review): this concatenation includes the test split, so use it only for
# whole-dataset statistics (e.g. the HU histogram), never for training.
dataset = ConcatDataset((trainset_pelvic, trainset_id_pelvic, testset))

In [13]:
# Per-patient slice counts for the training split.
# The glob must yield two directories per patient. After sorting, the first of
# each pair is passed as the CT scan and the second as the CBCT scan.
# NOTE(review): that pairing relies purely on the sort order of the directory
# names — confirm it matches the naming scheme on disk.
paths = sorted(glob.glob(pelvic_train_case_path))
if len(paths) % 2:
    # An unpaired directory used to surface as an IndexError halfway through
    # the loop, after part of the table had already been printed.
    raise ValueError(
        f"expected CT/CBCT directory pairs for {pelvic_train_case_path!r}, "
        f"found {len(paths)} paths"
    )
for ct_path, cbct_path in zip(paths[0::2], paths[1::2]):
    scans = DicomSegmentDataset(cbct_path=cbct_path, ct_path=ct_path,
                                geometry_aug=None, intensity_aug=None,
                                identity=False, electron=ELECTRON, position="pelvic",
                                segment=SEGMENT, g_coord=G_COORD, l_coord=L_COORD)
    patient_id = scans.patientID()
    print(patient_id, len(scans))

046 27
047 28
050 27
051 27
052 31
053 26
054 28
055 29
057 29
058 29
059 30
060 29
061 27
062 29
064 28
065 26
066 27
067 26
078 30
080 26
083 29
084 30
089 27
094 29
096 27
099 27
100 29
101 27


In [14]:
# Per-patient slice counts for the train_id split.
# The glob must yield two directories per patient. After sorting, the first of
# each pair is passed as the CT scan and the second as the CBCT scan.
# NOTE(review): that pairing relies purely on the sort order of the directory
# names — confirm it matches the naming scheme on disk.
paths = sorted(glob.glob(pelvic_train_id_case_path))
if len(paths) % 2:
    # An unpaired directory used to surface as an IndexError halfway through
    # the loop, after part of the table had already been printed.
    raise ValueError(
        f"expected CT/CBCT directory pairs for {pelvic_train_id_case_path!r}, "
        f"found {len(paths)} paths"
    )
for ct_path, cbct_path in zip(paths[0::2], paths[1::2]):
    scans = DicomSegmentDataset(cbct_path=cbct_path, ct_path=ct_path,
                                geometry_aug=None, intensity_aug=None,
                                identity=False, electron=ELECTRON, position="pelvic",
                                segment=SEGMENT, g_coord=G_COORD, l_coord=L_COORD)
    patient_id = scans.patientID()
    print(patient_id, len(scans))

004 27
006 28
007 28
012 27
014 28
017 28
018 28
019 32
021 28
025 28
027 26
028 28
029 27
030 28
033 26
039 28
040 32
042 27
044 27
045 29


In [15]:
# Per-patient slice counts for the test split.
# The glob must yield two directories per patient. After sorting, the first of
# each pair is passed as the CT scan and the second as the CBCT scan.
# NOTE(review): that pairing relies purely on the sort order of the directory
# names — confirm it matches the naming scheme on disk.
paths = sorted(glob.glob(pelvic_test_case_path))
if len(paths) % 2:
    # An unpaired directory used to surface as an IndexError halfway through
    # the loop, after part of the table had already been printed.
    raise ValueError(
        f"expected CT/CBCT directory pairs for {pelvic_test_case_path!r}, "
        f"found {len(paths)} paths"
    )
for ct_path, cbct_path in zip(paths[0::2], paths[1::2]):
    scans = DicomSegmentDataset(cbct_path=cbct_path, ct_path=ct_path,
                                geometry_aug=None, intensity_aug=None,
                                identity=False, electron=ELECTRON, position="pelvic",
                                segment=SEGMENT, g_coord=G_COORD, l_coord=L_COORD)
    patient_id = scans.patientID()
    print(patient_id, len(scans))

068 28
069 27
070 26
071 28
073 27
074 28
076 27


In [22]:
def hu_histogram(dataset, ys, xs, bins=1000, hu_range=(-500, 500),
                 out_path="Combined hu histogram.png"):
    """Plot the mean per-slice HU histograms of the CBCT inputs and CT targets.

    Each item of ``dataset`` is unpacked as ``(x, y, *rest)``. ``x`` is plotted
    as "CBCT" and ``y`` as "CT". Both are mapped back to HU with
    ``v * (hi - lo) + lo``, where ``(lo, hi) = hu_range``. This assumes the
    dataset yields intensities scaled to [0, 1] over that window — TODO
    confirm against ``codes.dataset``. Values that fall outside ``hu_range``
    after the mapping are not counted.

    Args:
        dataset: Iterable of ``(x, y, ...)`` samples (arrays or CPU tensors).
        ys: Initial CT accumulator of length ``bins`` (normally zeros).
            It is not modified in place.
        xs: Initial CBCT accumulator of length ``bins`` (normally zeros).
            It is not modified in place.
        bins: Number of equal-width histogram bins spanning ``hu_range``.
        hu_range: ``(lo, hi)`` HU window for both the mapping and the bins.
        out_path: File the figure is saved to.

    Raises:
        ValueError: If ``xs`` or ``ys`` does not have length ``bins``.
    """
    lo, hi = hu_range
    # Shared, fixed bin edges. Without them np.histogram derives edges from
    # each slice's own min/max, and summing those counts would add up bins
    # that cover different HU intervals.
    edges = np.linspace(lo, hi, bins + 1)

    xs = np.asarray(xs, dtype=np.float64)
    ys = np.asarray(ys, dtype=np.float64)
    if xs.shape != (bins,) or ys.shape != (bins,):
        raise ValueError(
            f"accumulators must have shape ({bins},), got {xs.shape} and {ys.shape}"
        )

    n_slices = 0
    # tqdm picks up len(dataset) itself when available, so progress shows a total.
    for data in tqdm(dataset):
        x, y, *_ = data
        x = np.asarray(x, dtype=np.float64).squeeze() * (hi - lo) + lo
        y = np.asarray(y, dtype=np.float64).squeeze() * (hi - lo) + lo
        # Rebinding (not +=) keeps the caller's arrays untouched.
        xs = xs + np.histogram(x, bins=edges)[0]
        ys = ys + np.histogram(y, bins=edges)[0]
        n_slices += 1

    # Mean counts per slice. The previous code divided by the number of bins,
    # which only rescaled the curves by a constant.
    xs = xs / max(n_slices, 1)
    ys = ys / max(n_slices, 1)

    # Left bin edges; for the defaults these are -500, -499, ..., 499.
    left_edges = edges[:-1]

    plt.style.use("ggplot")
    fig, ax = plt.subplots()

    ax.plot(left_edges, xs, label="CBCT", alpha=0.7)
    ax.plot(left_edges, ys, label="CT", alpha=0.7)
    ax.set_yscale("log")

    ax.legend(loc="upper center")
    ax.fill_between(left_edges, 0, xs, alpha=0.5)
    ax.fill_between(left_edges, 0, ys, alpha=0.5)
    ax.set_xlabel("HU (CT number)")
    ax.set_ylabel("Mean voxel count per slice")
    fig.suptitle("HU histogram of combined dataset")

    fig.savefig(out_path)

In [None]:
# Start both accumulators at zero; their length must match the 1000 histogram bins.
hu_histogram(dataset, ys=np.zeros(1000), xs=np.zeros(1000))

396it [00:14, 25.43it/s]