In [1]:
# Functions for processing STL and CT files and saving cross-sectional images of GTVs for clinician review

In [1]:
# Import Statements
import os
import sys
import numpy as np
import matplotlib.pyplot as plt, mpld3
from scipy import ndimage
import SimpleITK as sitk
import trimesh
import math
import meshcut
from shapely.geometry import Polygon
from shapely.ops import transform
import random

In [2]:
def get_all_CT_paths(path=""):
    """Scan the patient tree and return the CT folders usable for GTV screening.

    The expected layout is ``path/<patient>/<ct>/<files>``. Every CT folder is
    checked (case-insensitively, by substring of the file name) for:

    * at least one file containing "gtv",
    * exactly one file containing "mhd" and exactly one containing "raw.gz",
    * exactly two lung contours, i.e. files containing "lung" but not "gtv",
    * no files other than the ones counted above.

    The number of CT folders found is printed, followed by the count, folder
    path and file listing for each category of problem, in the order:
    no GTV, mhd issue, raw.gz issue, lung issue, file-count mismatch.

    Parameters
    ----------
    path : str
        Root folder holding one sub-folder per patient. Defaults to "" (the
        placeholder used in the original notebook); set it to the real root.

    Returns
    -------
    list of str
        CT folder paths (``path + "/" + patient + "/" + ct``) that have at
        least one GTV, exactly one mhd file and exactly one raw.gz file.
        Folders with lung-contour or file-count problems are kept, since lung
        contours are not needed for GTV screening.

    Raises
    ------
    FileNotFoundError
        If ``path`` is not an existing directory.
    """
    if not os.path.isdir(path):
        raise FileNotFoundError("Patient root folder not found: %r" % (path,))

    def list_folder(folder):
        # os.walk yields nothing for a folder it cannot read; treat that as empty
        # instead of letting next() raise StopIteration.
        _, dirs, files = next(os.walk(folder), (folder, [], []))
        return dirs, files

    all_ct_paths = []
    files_by_ct = {}  # listing of each CT folder, reused when reporting
    no_GTVs = []
    mhd_issue = []
    rawgz_issue = []
    lung_issue = []
    files_num_mismatch = []

    for pt in list_folder(path)[0]:
        pt_path = path + "/" + pt

        for ct in list_folder(pt_path)[0]:
            ct_path = pt_path + "/" + ct
            all_ct_paths.append(ct_path)

            item_list = list_folder(ct_path)[1]
            files_by_ct[ct_path] = item_list
            names = [x.lower() for x in item_list]

            # At least one GTV file is required
            GTVs = [s for s in names if "gtv" in s]
            if len(GTVs) < 1:
                no_GTVs.append(ct_path)

            # Exactly one mhd file and one raw.gz file are required
            mhd = [s for s in names if "mhd" in s]
            rawgz = [s for s in names if "raw.gz" in s]
            if len(mhd) != 1:
                mhd_issue.append(ct_path)
            if len(rawgz) != 1:
                rawgz_issue.append(ct_path)

            # Lung contours: files naming a lung that are not GTVs. Counting
            # "GTV ... lung" files here would hide folders with missing lungs.
            lung_contours = [s for s in names if "lung" in s and "gtv" not in s]
            if len(lung_contours) != 2:
                lung_issue.append(ct_path)

            # Lung contours + GTV files + mhd + raw.gz should account for every file
            if len(lung_contours) + len(GTVs) + len(mhd) + len(rawgz) != len(item_list):
                files_num_mismatch.append(ct_path)

    print(len(all_ct_paths))

    # Report each category: count, then every flagged folder with its files
    for flagged in (no_GTVs, mhd_issue, rawgz_issue, lung_issue, files_num_mismatch):
        print(len(flagged))
        for folder in flagged:
            print(folder)
            print(files_by_ct[folder])

    # Only folders with a GTV, one mhd and one raw.gz are used for subsequent work
    remove_paths = set(no_GTVs + mhd_issue + rawgz_issue)
    processed_ct_paths = [x for x in all_ct_paths if x not in remove_paths]

    return processed_ct_paths

In [4]:
def readCT(image_file):
    """Load a CT image with SimpleITK and return its voxels and geometry.

    Returns, in order:
      * the voxel array from ``sitk.GetArrayFromImage`` (image read as float32),
      * the image size as reported by ``GetSize()``,
      * ``[x_offset, y_offset, z_offset]`` taken from ``GetOrigin()``,
      * ``[x_spacing, y_spacing, z_spacing]`` taken from ``GetSpacing()``,
      * ``[x_aspect, y_aspect, z_aspect]`` spacing ratios, where
        x_aspect = x/z, y_aspect = y/z and z_aspect = x/y.
    """
    image = sitk.ReadImage(image_file, sitk.sitkFloat32)
    voxels = sitk.GetArrayFromImage(image)

    # Geometry of the volume
    size = image.GetSize()
    spacing = image.GetSpacing()
    origin = image.GetOrigin()

    offsets = [origin[0], origin[1], origin[2]]
    spacings = [spacing[0], spacing[1], spacing[2]]

    # Spacing ratios used later to set the aspect of the displayed slices
    aspects = [
        spacings[0] / spacings[2],  # x_aspect
        spacings[1] / spacings[2],  # y_aspect
        spacings[0] / spacings[1],  # z_aspect
    ]

    return voxels, size, offsets, spacings, aspects

In [5]:
def readGTVstl(contour_file, min_volume=65):
    """Load a GTV mesh and cut each of its bodies through its own centroid.

    The mesh is split into separate bodies with ``trimesh``; bodies whose
    volume is below ``min_volume`` are discarded. Each remaining body is cut
    with ``meshcut.cross_section`` by the three axis-aligned planes passing
    through its centroid.

    Parameters
    ----------
    contour_file : str
        Path of the mesh file passed to ``trimesh.load``.
    min_volume : float
        Smallest body volume to keep. The default of 65 was chosen to screen
        out contours around lesions with a diameter of less than 5 mm
        (i.e. 65 mm^3, assuming the mesh coordinates are in mm).

    Returns
    -------
    numpy.ndarray
        Object array of shape ``(n_bodies, 4)``. Each row is
        ``[centroid, z_cut, x_cut, y_cut]`` where every cut is a 1-D object
        array holding one ``(n_points, 3)`` float array per polygon.
    """
    GTV = trimesh.load(contour_file)

    # If multiple bodies exist in the mesh file, split the mesh and drop the
    # bodies that are too small to be worth reviewing
    GTV_bodies = [body for body in GTV.split() if body.volume >= min_volume]

    def as_object_array(items):
        # Polygons have different point counts, so np.array(items) would be a
        # ragged construction, which raises ValueError on NumPy >= 1.24.
        # Fill an object array explicitly instead.
        out = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            out[i] = item
        return out

    def cut(body, origin, normal):
        polygons = meshcut.cross_section(body.vertices, body.faces, origin, normal)
        return as_object_array([np.asarray(polygon, dtype=float) for polygon in polygons])

    # One row per body: centroid plus the cuts normal to z, x and y
    GTV_cross_sections = np.empty((len(GTV_bodies), 4), dtype=object)
    for row, body in enumerate(GTV_bodies):
        centroid = body.centroid
        GTV_cross_sections[row, 0] = centroid
        GTV_cross_sections[row, 1] = cut(body, (0, 0, centroid[2]), (0, 0, 1))
        GTV_cross_sections[row, 2] = cut(body, (centroid[0], 0, 0), (1, 0, 0))
        GTV_cross_sections[row, 3] = cut(body, (0, centroid[1], 0), (0, 1, 0))

    return GTV_cross_sections

In [18]:
def dispCentroid(GTV_cross_section, ct_array, ct_dimensions, offsets, spacings, aspects, crop = False):
    """Plot the CT slices through a GTV centroid with the GTV outline overlaid.

    A 2x3 figure is created: columns are the axial, coronal and sagittal
    slices through the centroid; the top row uses the display range labelled
    'Lung Window' and the bottom row the one labelled 'Mediastinal Window'.
    The outline of the GTV in each plane is drawn in red.

    Parameters
    ----------
    GTV_cross_section : sequence
        ``[centroid, z_cut, x_cut, y_cut]`` as produced by ``readGTVstl``;
        each cut is an iterable of ``(n_points, 3)`` polygons. It is not
        modified.
    ct_array : numpy.ndarray
        CT voxels, indexed as ``[z, y, x]``.
    ct_dimensions : sequence
        Image size ``(x, y, z)``; used only to pad the axial image.
    offsets, spacings : sequence
        ``[x, y, z]`` origin and voxel spacing of the CT.
    aspects : sequence
        ``[x_aspect, y_aspect, z_aspect]``; the coronal view uses
        ``1/x_aspect`` and the sagittal view ``1/y_aspect``.
    crop : bool
        If True, each view is cropped to at most ``int(100/spacing)`` voxels
        on either side of the centroid (presumably 100 mm; confirm units).
        The half-width is reduced to the centroid index when the centroid is
        closer than that to the low edge of the image.

    Returns
    -------
    (fig, ax, centroid)
        The matplotlib figure, its 2x3 axes array and the centroid.

    Raises
    ------
    IndexError
        If the centroid maps to a voxel index outside ``ct_array``.
    """
    x_offset, y_offset, z_offset = offsets[0], offsets[1], offsets[2]
    x_spacing, y_spacing, z_spacing = spacings[0], spacings[1], spacings[2]
    x_aspect, y_aspect = aspects[0], aspects[1]

    centroid = GTV_cross_section[0]
    z_cut = GTV_cross_section[1]
    x_cut = GTV_cross_section[2]
    y_cut = GTV_cross_section[3]

    # Voxel indices of the centroid (x and y are negated before the offset is removed)
    index_x = int(np.round((-centroid[0]-x_offset)/x_spacing))
    index_y = int(np.round((-centroid[1]-y_offset)/y_spacing))
    index_z = int(np.round((centroid[2]-z_offset)/z_spacing))

    # A negative index would silently wrap around and show the wrong slice,
    # so reject every out-of-range index explicitly.
    for axis_name, index, size in (("z", index_z, ct_array.shape[0]),
                                   ("y", index_y, ct_array.shape[1]),
                                   ("x", index_x, ct_array.shape[2])):
        if not 0 <= index < size:
            raise IndexError("GTV centroid lies outside the CT volume along %s (index %d, size %d)"
                             % (axis_name, index, size))

    # Crop half-widths in voxels, limited so the window never starts before index 0
    half_x = half_y = half_z = 0
    if crop:
        half_x = min(int(100/x_spacing), index_x)
        half_y = min(int(100/y_spacing), index_y)
        half_z = min(int(100/z_spacing), index_z)

    # Axial image, padded left and right with columns of ones so that it has the
    # same relative length and width as the sagittal/coronal images
    axial = ct_array[index_z,:,:]
    extend = int(((1/((z_spacing*ct_dimensions[2])/(y_spacing*ct_dimensions[1]))*ct_dimensions[0])-ct_dimensions[0])/2)
    if extend < 0:
        extend = 0
    padding = np.ones((axial.shape[0], extend), dtype=axial.dtype)
    axial_extended = np.concatenate([padding, axial, padding], axis=1)

    # The original passed these two slices through ndimage.rotate(..., 0); a
    # rotation by 0 degrees leaves the image unchanged, so they are used directly.
    coronal = ct_array[:,index_y,:]
    sagittal = ct_array[:,:,index_x]

    # Shifts map full-image voxel coordinates of the outline to displayed pixels
    if crop:
        axial_extended = axial_extended[index_y-half_y:index_y+half_y, index_x-half_x+extend:index_x+half_x+extend]
        coronal = coronal[index_z-half_z:index_z+half_z, index_x-half_x:index_x+half_x]
        sagittal = sagittal[index_z-half_z:index_z+half_z, index_y-half_y:index_y+half_y]
        axial_shift = (half_x-index_x, half_y-index_y)
        coronal_shift = (half_x-index_x, half_z-index_z)
        sagittal_shift = (half_y-index_y, half_z-index_z)
    else:
        axial_shift = (extend, 0)
        coronal_shift = (0, 0)
        sagittal_shift = (0, 0)

    # (column, image, cut, polygon columns to plot, outline shift, axes aspect)
    views = (
        (0, axial_extended, z_cut, (0, 1), axial_shift, None),
        (1, coronal, y_cut, (0, 2), coronal_shift, 1/x_aspect),
        (2, sagittal, x_cut, (1, 2), sagittal_shift, 1/y_aspect),
    )

    fig, ax = plt.subplots(2, 3, figsize=(15,10))
    try:
        for col, image, cut, columns, shift, aspect in views:
            ax[0,col].imshow(image, cmap="binary_r", interpolation="bilinear", vmin=-1100, vmax=0.75*1900)
            ax[1,col].imshow(image, cmap="binary_r", interpolation="bilinear", vmin=700, vmax=1400)

            _plot_cut_outline(ax[:,col], cut, offsets, spacings, columns, shift)

            # Coronal and sagittal views are flipped vertically and rescaled
            if aspect is not None:
                for axis in ax[:,col]:
                    axis.invert_yaxis()
                    axis.set_aspect(aspect)

        for axis, ylabel in zip(ax[:,0],['Lung Window','Mediastinal Window']):
            axis.set_ylabel(ylabel)
    except BaseException:
        # Do not leave a half-drawn figure open when plotting fails
        plt.close(fig)
        raise

    return fig, ax, centroid


def _plot_cut_outline(axes_pair, polygons, offsets, spacings, columns, shift):
    """Draw each polygon of one planar cut, converted to voxel indices, on every axis in axes_pair."""
    origin = np.array([offsets[0], offsets[1], offsets[2]], dtype=float)
    step = np.array([spacings[0], spacings[1], spacings[2]], dtype=float)
    # x and y are negated before the offset is subtracted, z is not
    # (presumably the mesh and CT use opposite x/y directions; confirm)
    flip = np.array([-1.0, -1.0, 1.0])

    for polygon in polygons:
        # Work on a converted copy so the caller's polygon is left untouched
        voxel = (flip * np.asarray(polygon, dtype=float)[:, :3] - origin) / step

        closed_polygon = Polygon(zip(voxel[:, columns[0]], voxel[:, columns[1]]))
        horizontal, vertical = closed_polygon.exterior.xy

        for axis in axes_pair:
            axis.plot(np.array(horizontal)+shift[0], np.array(vertical)+shift[1], color="red", linewidth=0.75)

In [19]:
def saveGTV(CTs, crop = False, out_dir = ""):
    """Render and save one screening image per GTV volume for every CT folder.

    For each CT folder the single mhd file is read with ``readCT`` and every
    file whose name contains "gtv" (case-insensitive) is processed with
    ``readGTVstl``. Each volume is drawn with ``dispCentroid`` and saved as
    ``<patient>__<ct>__<gtv name without .stl>__vol<k>.png``. The name of each
    saved image and the total number of GTV files seen are printed.

    Parameters
    ----------
    CTs : iterable of str
        CT folder paths of the form ``.../<patient>/<ct>``.
    crop : bool
        Passed through to ``dispCentroid``.
    out_dir : str
        Folder where the images are saved. The default "" saves into the
        current working directory.

    Returns
    -------
    (savedGTVs, many_cross_sec, cross_sec_fail)
        * savedGTVs: one ``[gtv_path, pt_id, ct_id, gtv_id, volume number,
          centroid, save_name]`` entry per saved image,
        * many_cross_sec: GTV paths skipped for having more than 9 volumes,
        * cross_sec_fail: GTV paths that raised while being processed.

    Raises
    ------
    Exception
        If a CT folder does not contain exactly one mhd file.
    """
    # GTVs with more volumes than this are excluded rather than displayed
    max_volumes = 9

    total = 0
    savedGTVs = []
    many_cross_sec = []
    cross_sec_fail = []

    for ct in CTs:

        # All files in one CT path (an unreadable/missing folder counts as empty)
        item_list = next(os.walk(ct), (ct, [], []))[2]

        # MHD file in CT path; there must be exactly one
        mhd = [s for s in item_list if "mhd" in s.lower()]
        if len(mhd) > 1:
            raise Exception("More than one mhd file")
        if not mhd:
            raise Exception("No mhd file in %s" % ct)

        # Read CT and get GTV names
        mhd_path = ct + "/" + str(mhd[0])
        ct_array, ct_dimensions, offsets, spacings, aspects = readCT(mhd_path)
        GTVs = [s for s in item_list if "gtv" in s.lower()]

        # Patient and CT ids are the last two folder names of the CT path, so
        # they do not depend on where the patient tree is located
        ct_dir = os.path.normpath(ct)
        ct_id = os.path.basename(ct_dir)
        pt_id = os.path.basename(os.path.dirname(ct_dir))

        for gtv in GTVs:
            total += 1
            gtv_id = str(gtv)
            gtv_path = ct + "/" + gtv_id

            try:
                GTV_cross_sections = readGTVstl(gtv_path)

                if len(GTV_cross_sections) > max_volumes:
                    # Too many separate volumes: decided to exclude these contours
                    many_cross_sec.append(gtv_path)
                    continue

                # Otherwise display each volume separately
                for cross_id, GTV_cross_section in enumerate(GTV_cross_sections):

                    fig, ax, centroid = dispCentroid(GTV_cross_section, ct_array, ct_dimensions, offsets, spacings, aspects, crop)

                    # Save GTV screening image; always release the figure
                    save_name = pt_id + "__" + ct_id + "__" + gtv_id.replace(".stl", "") + "__" + "vol" + str(cross_id+1) +".png"
                    print(save_name)
                    try:
                        fig.savefig(os.path.join(out_dir, save_name))
                    finally:
                        plt.close(fig)
                    savedGTVs.append([gtv_path, pt_id, ct_id, gtv_id, cross_id+1, centroid, save_name])

            except Exception as exc:
                # Best effort: record the failure and continue with the next GTV.
                # KeyboardInterrupt/SystemExit are deliberately not caught.
                cross_sec_fail.append(gtv_path)
                print("COULDN'T PROCESS GTV!")
                print(repr(exc))

    print(total)
    return savedGTVs, many_cross_sec, cross_sec_fail

In [None]:
# Generate the clinician-review images for every usable CT folder and keep
# the bookkeeping lists describing what was saved and what was skipped
crop = True
CTs = get_all_CT_paths()
savedGTVs, many_cross_sec, cross_sec_fail = saveGTV(CTs, crop=crop)

# GTVs excluded for having too many volumes, then GTVs that failed to process
for skipped in (many_cross_sec, cross_sec_fail):
    print(skipped)

In [10]:
# Save output
# Each bookkeeping list goes to its own text file, one entry per line
for out_name, entries in (
    ('savedGTVs.txt', savedGTVs),
    ('many_cross_sec.txt', many_cross_sec),
    ('cross_sec_fail.txt', cross_sec_fail),
):
    with open(out_name, 'w') as filehandle:
        filehandle.writelines('%s\n' % listitem for listitem in entries)