# DICOM CT scan plotting with interactive 3D plots

A notebook collecting some functions I used to explore some head CT scan images.

Input data is expected to be in DICOM format, as a series of scan slices that can be combined into one 3D model.

Most CT scan machines seem to produce such formats as output. Some seem to produce multiple series, which likely are in slightly different formats. Generally one of those series seems to work for this visualization.

This notebook is based on a Kaggle notebook for 3D modelling and exploration of lung CT scan images. I made some modifications and tuning for my purposes. Thanks for the initial code though!

I think the original notebook I used as a basis is [here](https://www.kaggle.com/code/aravrs/3d-dicom-visualizations-with-interactive-plots/notebook). At least the competition is the right one; Kaggle notebooks get copied and branched a lot, so apologies if I linked the wrong one. That is where I got the basics.

In [None]:
import pandas as pd
import numpy as np
import os
import pydicom

from glob import glob
import scipy
import scipy.ndimage
from skimage import measure
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from collections.abc import Iterable

from plotly.figure_factory import create_trisurf


In [None]:
def load_scan(path, reverse=True):
    """Read every file in directory *path* as a DICOM slice and return them sorted.

    Slices are sorted by InstanceNumber (descending when *reverse* is True).
    If the first slice has no SliceThickness, it is estimated from the
    distance between the first two slices: from the third component of
    ImagePositionPatient, falling back to SliceLocation. The resulting
    value is written to SliceThickness on every slice, so later code
    (e.g. ``resample``) can rely on it.

    Raises:
        ValueError: if *path* contains no files.
    """
    # pydicom.read_file was deprecated and is removed in pydicom 3.0; dcmread is the supported reader.
    slices = [pydicom.dcmread(os.path.join(path, name)) for name in os.listdir(path)]
    if not slices:
        raise ValueError(f"No DICOM files found in {path!r}")
    slices.sort(key=lambda x: int(x.InstanceNumber), reverse=reverse)

    if "SliceThickness" in slices[0]:
        slice_thickness = slices[0].SliceThickness
    else:
        print("thickness not found")
        try:
            slice_thickness = np.abs(
                slices[0].ImagePositionPatient[2] - slices[1].ImagePositionPatient[2]
            )
        # Tag missing (AttributeError), only one slice (IndexError) or empty value (TypeError).
        except (AttributeError, IndexError, TypeError):
            slice_thickness = np.abs(slices[0].SliceLocation - slices[1].SliceLocation)

    for s in slices:
        s.SliceThickness = slice_thickness

    return slices

def get_pixels_hu(scans):
    """Stack the pixel data of *scans* into one volume converted with the rescale tags.

    Each slice is converted with its own RescaleSlope / RescaleIntercept
    (``raw * slope + intercept``), because DICOM stores these per instance
    and they are not guaranteed to be equal across a series. The arithmetic
    is done in float64, then rounded and clipped to the int16 range, so
    values cannot wrap around when cast.

    Raw values equal to -2000 are set to 0 before rescaling (presumably
    scanner padding outside the field of view -- confirm for your data).

    Returns:
        int16 array with one entry along axis 0 per element of *scans*.
    """
    int16_info = np.iinfo(np.int16)
    volume = []
    for s in scans:
        # np.array (not asarray) guarantees a copy, so pydicom's cached pixel data is never modified.
        raw = np.array(s.pixel_array, dtype=np.float64)
        raw[raw == -2000] = 0
        rescaled = raw * float(s.RescaleSlope) + float(s.RescaleIntercept)
        volume.append(np.clip(np.rint(rescaled), int16_info.min, int16_info.max))
    return np.stack(volume).astype(np.int16)


def resample(image, scan, new_spacing=(1, 1, 1)):
    """Resample *image* to approximately *new_spacing* per voxel.

    The current spacing is taken from ``scan[0]``: SliceThickness for axis 0,
    then the two PixelSpacing values for axes 1 and 2 (mm per the DICOM
    standard). The output shape is rounded to whole voxels, so the spacing
    actually achieved can differ slightly from *new_spacing*.

    Args:
        image: 3D volume (e.g. from ``get_pixels_hu``).
        scan: sequence of slices; only ``scan[0]`` is read.
        new_spacing: target spacing per axis (any 3-element sequence).

    Returns:
        (resampled_image, actual_spacing) where actual_spacing is a float array.
    """
    spacing = np.array(
        [float(scan[0].SliceThickness)] + [float(v) for v in scan[0].PixelSpacing]
    )

    resize_factor = spacing / np.asarray(new_spacing, dtype=float)
    new_shape = np.round(image.shape * resize_factor)
    # Recompute the factor from the rounded shape so the reported spacing is exact.
    real_resize_factor = new_shape / image.shape
    actual_spacing = spacing / real_resize_factor

    # scipy.ndimage.interpolation is a deprecated namespace; zoom lives in scipy.ndimage.
    image = scipy.ndimage.zoom(image, real_resize_factor)
    return image, actual_spacing


def make_mesh(image, threshold=-300, step_size=1):
    """Extract an iso-surface triangle mesh from a 3D volume with marching cubes.

    The volume's axes are reversed (``transpose(2, 1, 0)``) before meshing, so
    vertex coordinates come out in (axis 2, axis 1, axis 0) order of *image*.

    Args:
        image: 3D volume.
        threshold: iso-level at which the surface is extracted.
        step_size: marching-cubes step size; larger values give a coarser mesh.

    Returns:
        (verts, faces): the first two outputs of ``measure.marching_cubes``.
    """
    volume = image.transpose(2, 1, 0)
    surface = measure.marching_cubes(
        volume,
        level=threshold,
        step_size=step_size,
        allow_degenerate=True,
    )
    verts, faces = surface[0], surface[1]
    return verts, faces


def largest_label_volume(im, bg=-1):
    """Return the most frequent value in *im*, ignoring the background value *bg*.

    Ties go to the smallest value (the first one in sorted order).
    Returns None when *im* holds nothing but *bg*, or is empty.
    """
    labels, sizes = np.unique(im, return_counts=True)
    keep = labels != bg
    labels, sizes = labels[keep], sizes[keep]
    if not sizes.size:
        return None
    return labels[sizes.argmax()]

# This is called "lung" because of the original notebook used as a basis and the code there;
# here it is applied to the head CT volumes.
def segment_lung_mask(image, fill_lung_structures=True):
    """Return a 0/1 mask of the largest connected region of voxels below -700.

    Steps:
      1. Voxels < -700 are candidates; voxels >= -700 are "dense".
      2. The connected component containing voxel [0, 0, 0] is treated as
         dense, i.e. excluded (presumably the air surrounding the patient --
         confirm for your data).
      3. Optionally (``fill_lung_structures``), on every 2D slice along axis 0,
         all dense components except the largest one are turned into candidates.
      4. Only the largest connected candidate component is kept.

    Args:
        image: 3D volume; the -700 threshold assumes values from get_pixels_hu
            (resampled in this notebook).
        fill_lung_structures: whether to run step 3.

    Returns:
        int8 array, same shape as *image*, with 1 in the kept region and 0 elsewhere.
    """
    # 1 = below -700 (candidate), 2 = at or above -700 (dense).
    binary_image = np.array(image >= -700, dtype=np.int8) + 1
    labels = measure.label(binary_image)
    # Mark the component touching the corner voxel as dense so it is never selected.
    background_label = labels[0, 0, 0]
    binary_image[background_label == labels] = 2

    if fill_lung_structures:
        for i, axial_slice in enumerate(binary_image):
            # Shift to 0 (candidate) / 1 (dense) so label() treats candidates as background.
            axial_slice = axial_slice - 1
            labeling = measure.label(axial_slice)
            l_max = largest_label_volume(labeling, bg=0)

            if l_max is not None:
                # Everything except the largest dense component on this slice becomes a candidate.
                binary_image[i][labeling != l_max] = 1
    # Map to 0/1, then invert so candidates are 1 and dense voxels are 0.
    binary_image -= 1
    binary_image = 1 - binary_image

    # Keep only the largest connected candidate region.
    labels = measure.label(binary_image, background=0)
    l_max = largest_label_volume(labels, bg=0)
    if l_max is not None:
        binary_image[labels != l_max] = 0

    return binary_image

In [None]:
# Directory holding one DICOM series: load_scan reads every file in it as one slice.
data_path = "/".join(("DATADIR", "PATIENT1", "SERIES1"))

#load_scan(data_path)


In [None]:
def print_image_info(slice0):
    """Print a four-line summary of the device, patient and study metadata of one slice.

    Every attribute is looked up with a fallback, so a missing one prints as
    an empty string instead of raising AttributeError. Many of these tags
    are optional in DICOM and are often removed by anonymisation.

    NOTE: the output contains patient-identifying data (name, birth date, ID),
    so be careful when sharing or exporting notebooks that call this.
    """
    def field(keyword):
        return getattr(slice0, keyword, "")

    address = field("InstitutionAddress")
    device_manufacturer = field("Manufacturer")
    device_model = field("ManufacturerModelName")
    patient_age = field("PatientAge")
    patient_birthday = field("PatientBirthDate")
    patient_id = field("PatientID")
    patient_name = field("PatientName")
    patient_gender = field("PatientSex")
    protocol = field("ProtocolName")
    referring_physician = field("ReferringPhysicianName")
    procedure = field("RequestedProcedureDescription")
    station = field("StationName")
    comments = field("StudyComments")
    study_date = field("StudyDate")
    # NOTE(review): "timestamp" is not a DICOM keyword; presumably this is an attribute pydicom
    # sets on datasets read from file, not the acquisition/study time -- confirm StudyTime wasn't meant.
    timestamp = field("timestamp")
    print(f"Address: {address}, {device_manufacturer} {device_model}")
    print(f"Patient: {patient_name}, {patient_gender} {patient_age} {patient_birthday} {patient_id}")
    print(f"Referrer: {referring_physician}, Study: {station}, {protocol}, {procedure}")
    print(f"Date: {study_date} {timestamp}, {comments}")
    
    

In [None]:
def _draw_mesh_subplot(fig, position, verts, faces, face_color, pane_color, alpha=0.8):
    """Add a 3D subplot at *position* showing the triangle mesh (verts, faces); return its axes."""
    ax = fig.add_subplot(position, projection="3d")
    mesh = Poly3DCollection(verts[faces], alpha=alpha)
    mesh.set_facecolor(face_color)
    ax.add_collection3d(mesh)
    x_max, y_max, z_max = np.asarray(verts).max(axis=0)
    ax.set_xlim(0, x_max)
    ax.set_ylim(0, y_max)
    ax.set_zlim(0, z_max)
    # Axes3D.w_xaxis/w_yaxis/w_zaxis were removed in Matplotlib 3.8; xaxis/yaxis/zaxis replace them.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((*pane_color, 1))
    return ax


def plot_3d(data_path, reverse=False):
    """Load a DICOM series and show two static Matplotlib 3D meshes side by side.

    Left: the iso-surface of the resampled volume at 350 (``make_mesh``).
    Right: the "internal structures", i.e. the difference between
    ``segment_lung_mask`` with and without per-slice filling.

    Args:
        data_path: directory of the series' DICOM files; the header line assumes
            it has at least three '/'-separated components.
        reverse: passed to ``load_scan`` to reverse the InstanceNumber ordering.
    """
    print(f"{data_path.split('/')[-3].upper()} - {data_path.split('/')[-2]}")
    g = glob(data_path + "/*.dcm")
    print(f"Total of {len(g)} DICOM images.")

    patient = load_scan(data_path, reverse)
    print(f"Slice Thickness: {patient[0].SliceThickness}")

    imgs = get_pixels_hu(patient)
    print(f"Shape resampling: {imgs.shape}", end="")
    imgs_after_resamp, _ = resample(imgs, patient, [1, 1, 1])
    print(f" -> {imgs_after_resamp.shape}")

    v1, f1 = make_mesh(imgs_after_resamp, 350, 2)

    segmented_lungs = segment_lung_mask(imgs_after_resamp, fill_lung_structures=False)
    segmented_lungs_fill = segment_lung_mask(imgs_after_resamp, fill_lung_structures=True)
    internal_structures = segmented_lungs_fill - segmented_lungs
    v2, f2, _, _ = measure.marching_cubes(internal_structures.transpose(2, 1, 0))

    ### PLOTS
    fig = plt.figure(figsize=(20, 10))
    bg = np.array((30, 39, 46)) / 255.0

    # Ext
    print(".", end="")
    _draw_mesh_subplot(fig, 121, v1, f1, (1, 1, 0.9), bg)

    # Int
    print(".", end="")
    _draw_mesh_subplot(fig, 122, v2, f2, np.array((255, 107, 107)) / 255.0, bg)

    print(".", end="")
    fig.tight_layout()
    plt.show()

In [None]:
def show_img(img_path, colormap=None, extra_brightness=0):
    """Display one DICOM slice as an 8-bit image and print the intermediate value ranges.

    The raw pixel values are shifted by
    ``-WindowCenter + RescaleIntercept + extra_brightness`` (the first value is
    used when WindowCenter is multi-valued). Negative results are clipped to 0,
    and the rest is scaled linearly so the maximum maps to 255.
    RescaleSlope and WindowWidth are not applied.

    Args:
        img_path: path of a single DICOM file.
        colormap: Matplotlib colormap passed to ``imshow`` (None = default).
        extra_brightness: constant added to every pixel before clipping; larger
            values move more pixels above the 0 cut-off.
    """
    ds = pydicom.dcmread(img_path)

    # Convert to float to avoid overflow or underflow losses.
    image_2d = ds.pixel_array.astype(float)
    print(f"data min: {image_2d.min()}, max: {image_2d.max()}")
    print(f"window center: {ds.WindowCenter}, rescale intercept: {ds.RescaleIntercept}")
    window_center = ds.WindowCenter
    if isinstance(window_center, Iterable):
        window_center = window_center[0]
    intercept = -window_center + ds.RescaleIntercept + extra_brightness
    print(f"final intercept: {intercept}")
    image_2d += intercept
    print(f"after applying intercept, min: {image_2d.min()}, max: {image_2d.max()}")

    # Rescale grey levels to 0-255. When no pixel is above 0, dividing by the
    # maximum would be 0/0 = NaN (garbage after the uint8 cast); show black instead.
    clipped = np.maximum(image_2d, 0)
    peak = clipped.max()
    if peak > 0:
        image_2d_scaled = clipped / peak * 255.0
    else:
        image_2d_scaled = np.zeros_like(clipped)
    print(f"after scaling to 0-255, min: {image_2d_scaled.min()}, max: {image_2d_scaled.max()}")

    # Convert to uint
    image_2d_scaled = np.uint8(image_2d_scaled)

    plt.figure(figsize=(12, 8))
    plt.imshow(image_2d_scaled, cmap=colormap)
    plt.show()

In [None]:
# Every file of the series, sorted by file name; the cell displays how many there are.
files = sorted(f"{data_path}/{name}" for name in os.listdir(data_path))
len(files)

In [None]:
# `files` is sorted by file name (not InstanceNumber), so this is not necessarily the 361st slice;
# the index is clamped so that series with fewer than 361 files still show their last file.
show_img(files[min(360, len(files) - 1)], colormap=plt.cm.bone, extra_brightness=300)


In [None]:
# Static Matplotlib view of the series in its stored InstanceNumber order.
plot_3d(data_path, reverse=False)

In [None]:
### 3D interactive plotting helper
def plotly_3d(verts, faces, ext=True):
    """Show the triangle mesh (verts, faces) as an interactive plotly figure.

    ext=True draws the surface off-white (outer surface), ext=False draws it red.
    """
    surface_rgb = "rgb(236, 236, 212)" if ext else "rgb(255, 107, 107)"
    dark_rgb = "rgb(30, 39, 46)"
    xs, ys, zs = zip(*verts)

    figure = create_trisurf(
        x=xs,
        y=ys,
        z=zs,
        simplices=faces,
        colormap=[surface_rgb, surface_rgb],
        plot_edges=False,
        show_colorbar=False,
        showbackground=False,
        backgroundcolor=dark_rgb,
        gridcolor=dark_rgb,
        title="<b>Interactive Visualization</b>",
    )
    figure.layout.template = "plotly_dark"  # dark theme
    figure.show()

In [None]:
### Plotting functions

def plot3d_interactive_ext(data_path, threshold=350, reverse=False):
    """Load a DICOM series and show its iso-surface at *threshold* as an interactive plot.

    Args:
        data_path: directory of the series' DICOM files; the header line assumes
            it has at least three '/'-separated components.
        threshold: iso-level passed to ``make_mesh`` (values from ``get_pixels_hu``).
        reverse: passed to ``load_scan`` to reverse the InstanceNumber ordering.
    """
    print(f"{data_path.split('/')[-3].upper()} - {data_path.split('/')[-2]}")
    patient = load_scan(data_path, reverse)
    print_image_info(patient[0])
    imgs = get_pixels_hu(patient)
    imgs_after_resamp, _ = resample(imgs, patient, [1, 1, 1])

    v, f = make_mesh(imgs_after_resamp, threshold, 2)
    plotly_3d(v, f)

def plot3d_interactive_int(data_path, reverse=False):
    """Show the "internal structures" mesh of a DICOM series as an interactive plot.

    The internal structures are ``segment_lung_mask`` with per-slice filling
    minus ``segment_lung_mask`` without it, meshed at marching_cubes' default level.

    Args:
        data_path: directory of the series' DICOM files; the header line assumes
            it has at least three '/'-separated components.
        reverse: passed to ``load_scan`` to reverse the InstanceNumber ordering.
    """
    print(f"{data_path.split('/')[-3].upper()} - {data_path.split('/')[-2]}")
    patient = load_scan(data_path, reverse)
    print_image_info(patient[0])
    imgs = get_pixels_hu(patient)
    imgs_after_resamp, _ = resample(imgs, patient, [1, 1, 1])

    segmented_lungs = segment_lung_mask(imgs_after_resamp, fill_lung_structures=False)
    segmented_lungs_fill = segment_lung_mask(imgs_after_resamp, fill_lung_structures=True)
    internal_structures = segmented_lungs_fill - segmented_lungs

    p = internal_structures.transpose(2, 1, 0)
    # marching_cubes_lewiner was removed in scikit-image 0.19; marching_cubes replaces it.
    verts, faces, _, _ = measure.marching_cubes(p)
    plotly_3d(verts, faces, ext=False)

In [None]:
# Use reverse=True if the scanner data shows up "upside-down"; it flips it to the right direction.
plot3d_interactive_ext(data_path, threshold=350, reverse=True)


In [None]:
# A very low threshold such as -100 visualizes softer tissue, including skin.
# Higher thresholds (the function default is 350) mostly show the skull and other hard parts.
plot3d_interactive_ext(data_path, reverse=True, threshold=-100)


# Threshold Plotting

To see how the images look at different thresholds. Each threshold should show the CT scan at a different intensity level. A lower intensity response on the scan means softer tissue, e.g., skin/muscle vs. bone.

In [None]:
def plot_3d(data_path, reverse=False):
    """Load a DICOM series and show only its outer iso-surface (level 350) with Matplotlib.

    Redefines the earlier ``plot_3d``: the mesh is drawn opaque, with a fixed
    view angle, and the internal-structures panel is not drawn.

    Args:
        data_path: directory of the series' DICOM files; the header line assumes
            it has at least three '/'-separated components.
        reverse: passed to ``load_scan`` to reverse the InstanceNumber ordering.
    """
    print(f"{data_path.split('/')[-3].upper()} - {data_path.split('/')[-2]}")
    g = glob(data_path + "/*.dcm")
    print(f"Total of {len(g)} DICOM images.")

    patient = load_scan(data_path, reverse)
    print(f"Slice Thickness: {patient[0].SliceThickness}")

    imgs = get_pixels_hu(patient)
    print(f"Shape resampling: {imgs.shape}", end="")
    imgs_after_resamp, _ = resample(imgs, patient, [1, 1, 1])
    print(f" -> {imgs_after_resamp.shape}")

    v1, f1 = make_mesh(imgs_after_resamp, 350, 2)
    print(type(v1[0]))
    print(type(f1[0]))

    ### PLOTS
    fig = plt.figure(figsize=(20, 10))
    bg = np.array((30, 39, 46)) / 255.0

    # Ext
    print(".", end="")
    ax1 = fig.add_subplot(121, projection="3d")
    mesh = Poly3DCollection(v1[f1], alpha=1.0)
    mesh.set_facecolor((1, 1, 0.9))
    ax1.add_collection3d(mesh)
    x_max, y_max, z_max = np.asarray(v1).max(axis=0)
    ax1.set_xlim(0, x_max)
    ax1.set_ylim(0, y_max)
    ax1.set_zlim(0, z_max)
    # Axes3D.w_xaxis/w_yaxis/w_zaxis were removed in Matplotlib 3.8; xaxis/yaxis/zaxis replace them.
    for axis in (ax1.xaxis, ax1.yaxis, ax1.zaxis):
        axis.set_pane_color((*bg, 1))
    ax1.view_init(20, -40)

    # Int panel intentionally not drawn here; the progress dots are kept as before.
    print(".", end="")
    print(".", end="")
    fig.tight_layout()
    plt.show()

In [None]:
#data_path = ... #use another data path if you want to look at specific study images


In [None]:
### 3D interactive plotting helper
def plotly_3d(verts, faces, threshold=None, ext=True):
    """Show the triangle mesh (verts, faces) as an interactive plotly figure.

    Redefines the earlier ``plotly_3d`` helper, adding the threshold to the
    title and a fixed camera position. *threshold* defaults to None so that
    functions written for the earlier helper -- ``plotly_3d(v, f)`` in
    ``plot3d_interactive_ext`` and ``plotly_3d(verts, faces, ext=False)`` in
    ``plot3d_interactive_int`` -- keep working after this cell has run.

    Args:
        verts, faces: mesh, e.g. from ``make_mesh``.
        threshold: value shown in the title; omitted from the title when None.
        ext: True draws the surface off-white, False draws it red.
    """
    x, y, z = zip(*verts)
    if threshold is None:
        title = "<b>Interactive Visualization</b>"
    else:
        title = f"<b>Interactive Visualization: {threshold}</b>"

    fig = create_trisurf(
        x=x,
        y=y,
        z=z,
        plot_edges=False,
        show_colorbar=False,
        showbackground=False,
        colormap=["rgb(236, 236, 212)", "rgb(236, 236, 212)"] if ext else ["rgb(255, 107, 107)", "rgb(255, 107, 107)"],
        simplices=faces,
        backgroundcolor="rgb(30, 39, 46)",
        gridcolor="rgb(30, 39, 46)",
        title=title,
    )
    fig.layout.template = "plotly_dark"  # for dark theme
    # Camera eye closer than plotly's default (1.25 per axis) and level with the mesh (z=0).
    camera = dict(eye=dict(x=1, y=-1, z=0))

    fig.update_layout(scene_camera=camera)
    fig.show()

In [None]:
def create_3d_interactive_ext(data_path, threshold=350, reverse=False, quiet=False):
    """Load a DICOM series and return its iso-surface mesh at *threshold*.

    Args:
        data_path: directory of the series' DICOM files.
        threshold: iso-level passed to ``make_mesh`` (values from ``get_pixels_hu``).
        reverse: passed to ``load_scan`` to reverse the InstanceNumber ordering.
        quiet: if True, do not print the study/patient summary.

    Returns:
        (verts, faces) from ``make_mesh``, computed on the volume resampled to [1, 1, 1] spacing.
    """
    patient = load_scan(data_path, reverse)
    if not quiet:
        print_image_info(patient[0])
    imgs = get_pixels_hu(patient)
    imgs_after_resamp, _ = resample(imgs, patient, [1, 1, 1])

    v, f = make_mesh(imgs_after_resamp, threshold, 2)
    return v, f


In [None]:
def plot_3d_interactive_ext(v, f, threshold):
    """Render a precomputed mesh (v, f) with ``plotly_3d``, labelled with *threshold*."""
    mesh_args = (v, f, threshold)
    plotly_3d(*mesh_args)

In [None]:
# Mesh at threshold 100, in stored order, with the study summary printed.
v, f = create_3d_interactive_ext(data_path, threshold=100, reverse=False, quiet=False)


In [None]:
plot_3d_interactive_ext(v, f, threshold=100)

In [None]:
from tqdm import tqdm

# Mesh the series at thresholds -200, 0, 200, 400, ..., 1600.
# Loading and resampling do not depend on the threshold, so they are done once here
# instead of re-reading every DICOM file for each threshold.
patient = load_scan(data_path, reverse=False)
imgs_after_resamp = resample(get_pixels_hu(patient), patient, [1, 1, 1])[0]
del patient  # the raw datasets are no longer needed; free their memory

vfs = {}
for threshold in tqdm(range(-200, 1601, 200)):
    vfs[threshold] = make_mesh(imgs_after_resamp, threshold, 2)

In [None]:
# Build (but do not show) one plotly figure per precomputed threshold mesh.
figs = {}

for threshold, (verts, faces) in vfs.items():
    x, y, z = zip(*verts)
    print(f"threshold: {threshold}")

    fig = create_trisurf(
        x=x,
        y=y,
        z=z,
        simplices=faces,
        colormap=["rgb(236, 236, 212)", "rgb(236, 236, 212)"],
        plot_edges=False,
        show_colorbar=False,
        showbackground=False,
        backgroundcolor="rgb(30, 39, 46)",
        gridcolor="rgb(30, 39, 46)",
        title=f"<b>Interactive Visualization: {threshold}</b>",
    )
    fig.layout.template = "plotly_dark"  # dark theme
    fig.update_layout(scene_camera={"eye": {"x": 1, "y": -1, "z": 0}})
    figs[threshold] = fig


In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# This was an attempt to build an interactive threshold selector into the notebook.
# It (almost) worked, but there were lots of issues exporting to HTML etc.
def f(x):
    """Widget callback: show the precomputed figure for threshold *x* and return *x*.

    Only the thresholds in ``figs`` were rendered, but a slider can produce other
    values, so the closest rendered threshold is shown instead of raising KeyError.
    """
    nearest = min(figs, key=lambda t: abs(t - x))
    if nearest == x:
        print(f"threshold: {x}")
    else:
        print(f"threshold: {x} (closest rendered: {nearest})")
    figs[nearest].show()
    return x


In [None]:
min_threshold, max_threshold = min(figs), max(figs)

In [None]:
# x=10 made interact build an integer slider starting at 10, which is not a key of figs (KeyError).
# A SelectionSlider only offers the thresholds that were actually rendered.
interact(f, x=widgets.SelectionSlider(options=sorted(figs), description="x"));

In [None]:
#interact(f, x=widgets.IntSlider(min=min_threshold, max=max_threshold, step=50, value=350));

In [None]:
#interact(f, x=widgets.IntSlider(min=0, max=10, value=1))


In [None]:
thresholds = sorted(figs)
thresholds

In [None]:
# Finally show all the plots at the selected thresholds, in increasing threshold order.
# Doing this too often seems to crash the notebook, and the output images make it too big to write to disk,
# but it is nice for a quick view in the notebook.
for fig in (figs[t] for t in thresholds):
    fig.show()