Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Merge pull request #194 from concept-to-clinic/prediction-fixes
Browse files Browse the repository at this point in the history
# Bug Fixes

Refactor and fix test skipping
Fix default model_paths

# Stylistic Fixes

Fix spacing and typos
Refactor dictionaries
Put dictionaries onto one line
Refactor/rearrange imports
Fix Config paths
Refactor list comprehensions
Use single quotes and shorten line length

# Refactoring

Refactor fixtures
Refactor tests
Refactor Config path usage
Consolidate tests
Rename tests
  • Loading branch information
reubano committed Nov 1, 2017
2 parents bb54879 + 1845d5f commit 03a6250
Show file tree
Hide file tree
Showing 35 changed files with 613 additions and 573 deletions.
18 changes: 12 additions & 6 deletions prediction/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@
"""
import os

from os import path

LIDC_WILDCARD = ['LIDC-IDRI-*', '**', '**']


class Config(object):
PROD_SERVER = os.getenv('PRODUCTION', False)
DEBUG = False
# The following paths are expanded at runtime
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
SEGMENT_ASSETS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, 'src', 'algorithms', 'segment', 'assets'))
DICOM_PATHS_DOCKER_WILDCARD = os.path.join('/images_full', 'LIDC-IDRI-*', '**', '**')
DICOM_PATHS_LOCAL_WILDCARD = os.path.join(CURRENT_DIR, 'src', 'tests', 'assets',
'test_image_data', 'full', 'LIDC-IDRI-*', '**', '**')
CURRENT_DIR = path.dirname(path.realpath(__file__))
PARENT_DIR = path.dirname(CURRENT_DIR)
ALGOS_DIR = path.abspath(path.join(CURRENT_DIR, 'src', 'algorithms'))
SEGMENT_ASSETS_DIR = path.abspath(path.join(ALGOS_DIR, 'segment', 'assets'))
FULL_DICOM_PATHS = path.join(PARENT_DIR, 'images_full')
SMALL_DICOM_PATHS = path.join(PARENT_DIR, 'images')
FULL_DICOM_PATHS_WILDCARD = path.join(FULL_DICOM_PATHS, *LIDC_WILDCARD)
SMALL_DICOM_PATHS_WILDCARD = path.join(FULL_DICOM_PATHS, *LIDC_WILDCARD)


class Production(Config):
Expand Down
42 changes: 29 additions & 13 deletions prediction/src/algorithms/classify/src/gtr123_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
from os import path

import numpy as np
import torch
from src.preprocess import load_ct, preprocess_ct, crop_patches

from torch import nn
from torch.autograd import Variable

from config import Config
from src.preprocess.crop_patches import patches_from_ct
from src.preprocess.load_ct import load_ct
from src.preprocess.preprocess_ct import PreprocessCT


""""
Classification model from team gtr123
Code adapted from https://github.com/lfz/DSB2017
Expand Down Expand Up @@ -48,16 +56,16 @@ def __init__(self, n_in, n_out, stride=1):
self.shortcut = None

def forward(self, x):

residual = x

if self.shortcut is not None:
residual = self.shortcut(x)

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)

out += residual
out = self.relu(out)
return out
Expand All @@ -79,7 +87,7 @@ def __init__(self):
nn.BatchNorm3d(24),
nn.ReLU(inplace=True))

# 3 poolings, each pooling downsamples the feature map by a factor 2.
# 3 poolings, each pooling down-samples the feature map by a factor 2.
# 3 groups of blocks. The first block of each group has one pooling.
num_blocks_forw = [2, 2, 3, 3]
num_blocks_back = [3, 3]
Expand Down Expand Up @@ -206,11 +214,12 @@ def forward(self, xlist, coordlist):

noduleFeat, nodulePred = self.NoduleNet(xlist, coordlist)
nodulePred = nodulePred.contiguous().view(corrdsize[0], corrdsize[1], -1)

featshape = noduleFeat.size() # nk x 128 x 24 x 24 x24

centerFeat = self.pool(noduleFeat[:, :, featshape[2] // 2 - 1:featshape[2] // 2 + 1,
featshape[3] // 2 - 1:featshape[3] // 2 + 1,
featshape[4] // 2 - 1:featshape[4] // 2 + 1])

centerFeat = centerFeat[:, :, 0, 0, 0]
out = self.dropout(centerFeat)
out = self.Relu(self.fc1(out))
Expand All @@ -221,7 +230,7 @@ def forward(self, xlist, coordlist):
return nodulePred, casePred, out


def predict(ct_path, nodule_list, model_path="src/algorithms/classify/assets/gtr123_model.ckpt"):
def predict(ct_path, nodule_list, model_path=None):
"""
Args:
Expand All @@ -233,24 +242,31 @@ def predict(ct_path, nodule_list, model_path="src/algorithms/classify/assets/gtr
List of nodules, and probabilities
"""
if not model_path:
CLASSIFY_DIR = path.join(Config.ALGOS_DIR, 'classify')
model_path = path.join(CLASSIFY_DIR, 'assets', 'gtr123_model.ckpt')

if not nodule_list:
return []
casenet = CaseNet()

casenet = CaseNet()
casenet.load_state_dict(torch.load(model_path))
casenet.eval()

if torch.cuda.is_available():
casenet = torch.nn.DataParallel(casenet).cuda()
# else:
# casenet = torch.nn.parallel.DistributedDataParallel(casenet)
# casenet = torch.nn.parallel.DistributedDataParallel(casenet)

preprocess = PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=1., order=1,
min_max_normalize=True, scale=255, dtype='uint8')

ct_array, meta = preprocess(*load_ct(ct_path))
patches = patches_from_ct(ct_array, meta, config['crop_size'], nodule_list,
stride=config['stride'], pad_value=config['filling_value'])

preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=1., order=1,
min_max_normalize=True, scale=255, dtype='uint8')
ct_array, meta = preprocess(*load_ct.load_ct(ct_path))
patches = crop_patches.patches_from_ct(ct_array, meta, config['crop_size'], nodule_list,
stride=config['stride'], pad_value=config['filling_value'])
results = []

for nodule, (cropped_image, coords) in zip(nodule_list, patches):
cropped_image = Variable(torch.from_numpy(cropped_image[np.newaxis, np.newaxis]).float())
cropped_image.volatile = True
Expand Down
6 changes: 6 additions & 0 deletions prediction/src/algorithms/identify/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@ def load_patient_images(patient_id, base_dir=EXTRACTED_IMAGE_DIR, wildcard="*.*"
exclude_wildcards = exclude_wildcards or []
src_dir = os.path.join(os.getcwd(), base_dir, patient_id)
src_img_paths = glob.glob(src_dir + wildcard)

for exclude_wildcard in exclude_wildcards:
exclude_img_paths = glob.glob(src_dir + exclude_wildcard)
src_img_paths = [im for im in src_img_paths if im not in exclude_img_paths]

src_img_paths.sort()
images = [cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) for img_path in src_img_paths]
images = [im.reshape((1,) + im.shape) for im in images]
Expand All @@ -55,6 +57,7 @@ def prepare_image_for_net3D(img):
def filter_patient_nodules_predictions(df_nodule_predictions: pandas.DataFrame, patient_id, view_size):
patient_mask = load_patient_images(patient_id, wildcard="*_m.png")
delete_indices = []

for index, row in df_nodule_predictions.iterrows():
z_perc = row["coord_z"]
y_perc = row["coord_y"]
Expand All @@ -66,6 +69,7 @@ def filter_patient_nodules_predictions(df_nodule_predictions: pandas.DataFrame,
start_y = center_y - view_size / 2
start_x = center_x - view_size / 2
nodule_in_mask = False

for z_index in [-1, 0, 1]:
img = patient_mask[z_index + center_z]
start_x = int(start_x)
Expand All @@ -83,8 +87,10 @@ def filter_patient_nodules_predictions(df_nodule_predictions: pandas.DataFrame,
else:
if center_z < 30:
logging.info("Z < 30: ", patient_id, " center z:", center_z, " y_perc: ", y_perc)

if mal_score > 0:
mal_score *= -1

df_nodule_predictions.loc[index, "diameter_mm"] = mal_score

if (z_perc > 0.75 or z_perc < 0.25) and y_perc > 0.85:
Expand Down
45 changes: 37 additions & 8 deletions prediction/src/algorithms/identify/src/gtr123_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from os import path

import numpy as np
import torch

from scipy.special import expit
from src.preprocess import preprocess_ct, load_ct
from src.preprocess.extract_lungs import extract_lungs
from torch import nn
from torch.autograd import Variable

from config import Config
from src.preprocess import preprocess_ct, load_ct
from src.preprocess.extract_lungs import extract_lungs

""""
Detector model from team gtr123
Code adapted from https://github.com/lfz/DSB2017
Expand Down Expand Up @@ -96,6 +101,7 @@ def __init__(self):
num_blocks_back = [3, 3]
self.featureNum_forw = [24, 32, 64, 64, 64]
self.featureNum_back = [128, 64, 64]

for i in range(len(num_blocks_forw)):
blocks = []
for j in range(num_blocks_forw[i]):
Expand Down Expand Up @@ -130,10 +136,12 @@ def __init__(self):
nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2),
nn.BatchNorm3d(64),
nn.ReLU(inplace=True))

self.path2 = nn.Sequential(
nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2),
nn.BatchNorm3d(64),
nn.ReLU(inplace=True))

self.drop = nn.Dropout3d(p=0.2, inplace=False)
self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size=1),
nn.ReLU(),
Expand Down Expand Up @@ -200,8 +208,8 @@ def __call__(self, output, thresh=-3, ismask=False):
output[:, :, :, :, 4] = np.exp(output[:, :, :, :, 4]) * anchors.reshape((1, 1, 1, -1))
mask = output[..., 0] > thresh
xx, yy, zz, aa = np.where(mask)

output = output[xx, yy, zz, aa]

if ismask:
return output, [xx, yy, zz, aa]
else:
Expand Down Expand Up @@ -232,8 +240,10 @@ def split(self, data, side_len=None, max_stride=None, margin=None):
"""
if side_len is None:
side_len = self.side_len

if max_stride is None:
max_stride = self.max_stride

if margin is None:
margin = self.margin

Expand All @@ -247,14 +257,14 @@ def split(self, data, side_len=None, max_stride=None, margin=None):
nz = int(np.ceil(float(z) / side_len))
nh = int(np.ceil(float(h) / side_len))
nw = int(np.ceil(float(w) / side_len))

nzhw = [nz, nh, nw]
self.nzhw = nzhw

pad = [[0, 0],
[margin, nz * side_len - z + margin],
[margin, nh * side_len - h + margin],
[margin, nw * side_len - w + margin]]

data = np.pad(data, pad, 'edge')

for iz in range(nz):
Expand Down Expand Up @@ -289,22 +299,27 @@ def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None):

if side_len is None:
side_len = self.side_len

if stride is None:
stride = self.stride

if margin is None:
margin = self.margin

if nzhw is None:
nz = self.nz
nh = self.nh
nw = self.nw
else:
nz, nh, nw = nzhw

assert (side_len % stride == 0)
assert (margin % stride == 0)
side_len //= stride
margin //= stride

splits = []

for i in range(len(output)):
splits.append(output[i])

Expand All @@ -316,6 +331,7 @@ def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None):
splits[0].shape[4]), np.float32)

idx = 0

for iz in range(nz):
for ih in range(nh):
for iw in range(nw):
Expand Down Expand Up @@ -349,18 +365,22 @@ def split_data(imgs, split_comber, stride=4):
pz = int(np.ceil(float(nz) / stride)) * stride
ph = int(np.ceil(float(nh) / stride)) * stride
pw = int(np.ceil(float(nw) / stride)) * stride

imgs = np.pad(imgs, [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant',
constant_values=split_comber.pad_value)

xx, yy, zz = np.meshgrid(np.linspace(-0.5, 0.5, imgs.shape[1] // stride),
np.linspace(-0.5, 0.5, imgs.shape[2] // stride),
np.linspace(-0.5, 0.5, imgs.shape[3] // stride), indexing='ij')

coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, :]], 0).astype('float32')
imgs, nzhw = split_comber.split(imgs)

coord2, nzhw2 = split_comber.split(coord,
side_len=split_comber.side_len // stride,
max_stride=split_comber.max_stride // stride,
margin=int(split_comber.margin // stride))

assert np.all(nzhw == nzhw2)
imgs = (imgs.astype(np.float32) - 128) / 128
return torch.from_numpy(imgs), torch.from_numpy(coord2), np.array(nzhw)
Expand All @@ -387,6 +407,7 @@ def iou(box0, box1):
e1 = box1[:3] + r1

overlap = []

for i in range(len(s0)):
overlap.append(max(0, min(e0[i], e1[i]) - max(s0[i], s1[i])))

Expand All @@ -410,13 +431,16 @@ def nms(predictions, nms_th=0.05):

predictions = predictions[np.argsort(-predictions[:, 0])]
bboxes = [predictions[0]]

for i in np.arange(1, len(predictions)):
bbox = predictions[i]
flag = 1

for j in range(len(bboxes)):
if iou(bbox[1:5], bboxes[j][1:5]) >= nms_th:
flag = -1
break

if flag == 1:
bboxes.append(bbox)

Expand All @@ -439,14 +463,12 @@ def filter_lungs(image, spacing=(1, 1, 1), fill_value=170):
"""

mask = extract_lungs(image, spacing)

extracted = np.array(image)
extracted[np.logical_not(mask)] = fill_value

return extracted, mask


def predict(ct_path, model_path="src/algorithms/identify/assets/dsb2017_detector.ckpt"):
def predict(ct_path, model_path=None):
"""
Args:
Expand All @@ -458,13 +480,19 @@ def predict(ct_path, model_path="src/algorithms/identify/assets/dsb2017_detector
List of Nodule locations and probabilities
"""
if not model_path:
INDENTIFY_DIR = path.join(Config.ALGOS_DIR, 'identify')
model_path = path.join(INDENTIFY_DIR, 'assets', 'dsb2017_detector.ckpt')

ct_array, meta = load_ct.load_ct(ct_path)
meta = load_ct.MetaImage(meta)
spacing = np.array(meta.spacing)
masked_image, mask = filter_lungs(ct_array)

# masked_image = image
net = Net()
net.load_state_dict(torch.load(model_path)["state_dict"])

if torch.cuda.is_available():
net = torch.nn.DataParallel(net).cuda()

Expand All @@ -475,11 +503,13 @@ def predict(ct_path, model_path="src/algorithms/identify/assets/dsb2017_detector
# Transform image to the 0-255 range and resample to 1x1x1mm
preprocess = preprocess_ct.PreprocessCT(clip_lower=-1200., clip_upper=600., spacing=1., order=1,
min_max_normalize=True, scale=255, dtype='uint8')

ct_array, meta = preprocess(ct_array, meta)
ct_array = ct_array[np.newaxis, ...]

imgT, coords, nzhw = split_data(ct_array, split_comber=split_comber)
results = []

# Loop over the image chunks
for img, coord in zip(imgT, coords):
var = Variable(img[np.newaxis])
Expand Down Expand Up @@ -510,5 +540,4 @@ def predict(ct_path, model_path="src/algorithms/identify/assets/dsb2017_detector

# Rescale back to image space coordinates
proposals[:, 1:4] /= spacing[np.newaxis]

return [{"x": int(p[3]), "y": int(p[2]), "z": int(p[1]), "p_nodule": float(p[0])} for p in proposals]
Loading

0 comments on commit 03a6250

Please sign in to comment.