In [26]:
from pathlib import Path
import pydicom
import matplotlib.pyplot as plt
import numpy as np
import ipywidgets as widgets
from IPython.display import display
import pandas as pd
import ast

In [53]:
# Annotation table (one row per localizer) and per-series metadata table.
DF_LOC = pd.read_csv("data/train_localizers.csv")
DF_TRAIN = pd.read_csv("data/train.csv")

# 'coordinates' is stored as text; literal_eval turns each cell back into a
# Python object (presumably a dict with 'x' and 'y' keys), from which the two
# entries are pulled out into their own columns.
DF_LOC[['x', 'y']] = DF_LOC['coordinates'].apply(
    lambda raw: pd.Series(ast.literal_eval(raw))[['x', 'y']]
)

# Attach each series' Modality to its localizer rows.
#   how='left'             -> every localizer row survives, even without a match
#   validate='many_to_one' -> pandas raises if a SeriesInstanceUID is
#                             duplicated on the DF_TRAIN side
DF_LOC = pd.merge(
    DF_LOC,
    DF_TRAIN[['SeriesInstanceUID', 'Modality']],
    how='left',
    on='SeriesInstanceUID',
    validate='many_to_one',
)

In [54]:
# Source: https://www.kaggle.com/code/redwankarimsony/visualizing-and-analyzing-dicoms-in-python
def load_series(slices):
    """Stack slices into one volume and apply the linear rescale.

    Computes ``pixel * RescaleSlope + RescaleIntercept`` for the whole stack.
    Slope and intercept are read from the first slice only and default to
    1.0 / 0.0 when the attribute is absent.

    Parameters
    ----------
    slices : sequence
        Objects exposing ``pixel_array`` (all of the same shape) and,
        optionally, ``RescaleSlope`` / ``RescaleIntercept``.

    Returns
    -------
    numpy.ndarray
        ``int16`` array of shape ``(len(slices), *pixel_array.shape)``.
        Results outside the int16 range are saturated, not wrapped.

    Raises
    ------
    ValueError
        If ``slices`` is empty.
    """
    if len(slices) == 0:
        raise ValueError("load_series() requires at least one slice")

    raw = np.stack([s.pixel_array for s in slices])

    slope = float(getattr(slices[0], "RescaleSlope", 1.0))
    intercept = float(getattr(slices[0], "RescaleIntercept", 0.0))

    # Do the arithmetic in a dtype wider than int16. Casting the raw data to
    # int16 first would wrap stored values >= 32768 to negative numbers
    # *before* the slope is applied, which scales the wrong number.
    if slope == 1 and intercept.is_integer():
        # Pure integer offset: int32 is exact and cheaper than float64.
        rescaled = raw.astype(np.int32) + int(intercept)
    else:
        # Round to nearest once, after slope and intercept are both applied,
        # instead of truncating the intermediate product and the intercept.
        rescaled = np.rint(raw.astype(np.float64) * slope + intercept)

    # Saturate so an out-of-range value cannot wrap around on the final cast.
    bounds = np.iinfo(np.int16)
    np.clip(rescaled, bounds.min, bounds.max, out=rescaled)
    return rescaled.astype(np.int16)

def load_scan(paths):
    """Read DICOM files and return them ordered by descending InstanceNumber.

    When at least two slices are available, the distance between the first
    two (after sorting) is measured and written to ``SliceThickness`` on every
    slice. The distance is taken from the third component of
    ``ImagePositionPatient``, falling back to ``SliceLocation`` when that
    cannot be used.

    Parameters
    ----------
    paths : iterable
        File paths accepted by ``pydicom.dcmread``.

    Returns
    -------
    list
        The datasets, sorted by ``int(InstanceNumber)`` in descending order.
        With fewer than two slices no spacing can be measured, so the list is
        returned with ``SliceThickness`` left untouched (an empty ``paths``
        yields an empty list).
    """
    slices = [pydicom.dcmread(path) for path in paths]
    slices.sort(key=lambda ds: int(ds.InstanceNumber), reverse=True)

    # Spacing needs two slices; indexing slices[1] would raise otherwise.
    if len(slices) < 2:
        return slices

    first, second = slices[0], slices[1]
    try:
        slice_thickness = np.abs(
            first.ImagePositionPatient[2] - second.ImagePositionPatient[2]
        )
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        # Position tag missing or unusable -> fall back to SliceLocation.
        # Only data-related errors are caught, so KeyboardInterrupt and
        # SystemExit are no longer swallowed as they were by a bare except.
        slice_thickness = np.abs(first.SliceLocation - second.SliceLocation)

    for ds in slices:
        ds.SliceThickness = slice_thickness

    return slices

In [None]:
class PatientSeries:
    """One imaging series loaded from a directory of ``.dcm`` files.

    Attributes set by ``__init__``:
        series_instance_id: name of the series directory.
        dcm_paths: sorted list of the ``*.dcm`` files found in the directory.
        dicoms: datasets returned by ``load_scan`` (its ordering, not the
            ordering of ``dcm_paths``).
        sop_instance_uids: ``SOPInstanceUID`` of each entry of ``dicoms``,
            in the same order.
        modality: ``Modality`` of the first dataset.
        pixels: volume returned by ``load_series(dicoms)``.
        df_loc_series: rows of the global ``DF_LOC`` for this series.
        bboxes: ``{SOPInstanceUID: {'x', 'y', 'class'}}`` built from
            ``df_loc_series``, or ``None`` when the series has no rows.
    """

    def __init__(self, series_path: Path):
        """Load every ``.dcm`` file under ``series_path``.

        Raises:
            FileNotFoundError: if the directory contains no ``.dcm`` file.
        """
        self.series_instance_id = series_path.name

        # sorted(): glob order is filesystem-dependent; a fixed order keeps
        # repeated runs identical.
        self.dcm_paths = sorted(series_path.glob("*.dcm"))
        if not self.dcm_paths:
            # Fail here with a readable message instead of an IndexError
            # raised later while indexing an empty slice list.
            raise FileNotFoundError(f"No .dcm files found in {series_path}")

        self.dicoms = load_scan(self.dcm_paths)
        self.sop_instance_uids = [dcm.SOPInstanceUID for dcm in self.dicoms]
        self.modality = self.dicoms[0].Modality
        self.pixels = load_series(self.dicoms)

        # NOTE(review): self.modality is read from the DICOM header, and
        # series_viewer() in this notebook compares the same attribute against
        # "CT". Confirm that "CTA" can actually occur here; if it cannot, this
        # branch never runs.
        if self.modality == "CTA":
            # Replace the -2000 sentinel (presumably out-of-scan padding) by 0.
            # NOTE(review): this runs on the already-rescaled volume; confirm
            # that -2000 is still the padding value at this point.
            self.pixels[self.pixels == -2000] = 0

        self.df_loc_series = DF_LOC[DF_LOC.SeriesInstanceUID == self.series_instance_id]

        if self.df_loc_series.empty:
            self.bboxes = None
        else:
            # Keyed by SOPInstanceUID: if several rows share one UID, the last
            # row wins.
            self.bboxes = {
                row.SOPInstanceUID: {'x': row.x, 'y': row.y, 'class': row.location}
                for row in self.df_loc_series.itertuples(index=False)
            }
        


In [49]:
# The viewer also uses the series modality, e.g. to choose the display range (vmin/vmax).
def series_viewer(patient: PatientSeries):
    """Show an interactive slice-by-slice viewer for ``patient``.

    The viewer is made of Prev / Next buttons, a slice slider, an
    "index / total" label and an output area holding the rendered figure.
    When the displayed slice has an entry in ``patient.bboxes``, a red
    10x10 square centred on its ``(x, y)`` is drawn on top of the image.

    Display range:
      * modality ``"CT"``: fixed window (centre 300, width 600, i.e. 0..600);
      * anything else: rescaled to each slice's own min/max.

    Parameters
    ----------
    patient : PatientSeries
        Reads ``pixels``, ``sop_instance_uids``, ``modality`` and ``bboxes``.
        ``bboxes`` may be ``None`` (series without annotations).
    """
    scans = patient.pixels
    instance_uids = patient.sop_instance_uids
    # PatientSeries stores None when a series has no localizer rows; a
    # membership test against None would raise TypeError on the first draw.
    bboxes = patient.bboxes or {}
    n = len(scans)
    fixed_window = patient.modality == "CT"
    box_half = 5  # half the side length of the marker square, in pixels

    fig, ax = plt.subplots(figsize=(5, 5))
    # Close it right away: the figure object stays usable and is shown
    # explicitly through display(fig) inside the Output widget below.
    plt.close(fig)

    if fixed_window:
        WC, WW = 300, 600   # vascular window
        vmin = WC - WW / 2  # 0
        vmax = WC + WW / 2  # 600
    else:
        vmin, vmax = np.min(scans[0]), np.max(scans[0])

    im = ax.imshow(scans[0], cmap=plt.cm.bone, vmin=vmin, vmax=vmax)
    ax.axis('off')

    slider = widgets.IntSlider(value=0, min=0, max=n - 1, step=1, description='Slice')
    prev_btn = widgets.Button(description='◀ Prev')
    next_btn = widgets.Button(description='Next ▶')
    idx_lbl = widgets.HTML(f"<b>1</b> / {n}")
    controls = widgets.HBox([prev_btn, next_btn, slider, idx_lbl])

    out = widgets.Output()

    def update_view(i):
        """Redraw the figure for slice ``i``."""
        uid = instance_uids[i]
        im.set_data(scans[i])
        ax.set_title(f"SOP Instance UID: {uid}", fontsize=10)

        if not fixed_window:
            # Re-stretch the colour range to this slice's own values.
            im.set_clim(vmin=float(np.min(scans[i])), vmax=float(np.max(scans[i])))

        # Drop the previous slice's marker before (maybe) adding a new one.
        for patch in list(ax.patches):
            patch.remove()

        bbox = bboxes.get(uid)
        if bbox is not None:
            ax.add_patch(
                plt.Rectangle(
                    (bbox['x'] - box_half, bbox['y'] - box_half),
                    2 * box_half,
                    2 * box_half,
                    linewidth=1,
                    edgecolor='r',
                    facecolor='none',
                )
            )

        with out:
            out.clear_output(wait=True)
            display(fig)
        idx_lbl.value = f"<b>{i+1}</b> / {n}"

    def on_slider(change):
        update_view(change["new"])

    def on_prev(_):
        if slider.value > slider.min:
            slider.value -= 1

    def on_next(_):
        if slider.value < slider.max:
            slider.value += 1

    # names="value": only react to the slider position, not to other traits.
    slider.observe(on_slider, names="value")
    prev_btn.on_click(on_prev)
    next_btn.on_click(on_next)

    display(controls, out)
    update_view(0)  # initial draw

In [None]:
# Load one series from its directory and open the interactive slice viewer on it.
series_path = Path("data/series/1.2.826.0.1.3680043.8.498.13789305723712362238118274295587312089/")
patient = PatientSeries(series_path)
series_viewer(patient)

HBox(children=(Button(description='◀ Prev', style=ButtonStyle()), Button(description='Next ▶', style=ButtonSty…

Output()

In [52]:
# Load another series and open the slice viewer on it.
# Note: this rebinds the notebook-level names `series_path` and `patient`.
series_path = Path("data/series/1.2.826.0.1.3680043.8.498.49640345168968922611291772802640560828")
patient = PatientSeries(series_path)
series_viewer(patient)

HBox(children=(Button(description='◀ Prev', style=ButtonStyle()), Button(description='Next ▶', style=ButtonSty…

Output()

In [57]:
# Paths of the two .nii files that share this series UID under
# data/segmentations/: "<uid>.nii" and "<uid>_cowseg.nii".
series_uid = "1.2.826.0.1.3680043.8.498.13789305723712362238118274295587312089"
seg_path_1 = Path(f"data/segmentations/{series_uid}.nii")
seg_path_2 = Path(f"data/segmentations/{series_uid}_cowseg.nii")

In [61]:
import nibabel as nib

# Read the "_cowseg" file with nibabel and pull its voxel data into memory.
img = nib.load(seg_path_2)   # seg_path_2 is the "<uid>_cowseg.nii" path
data = img.get_fdata()  # array later indexed as data[:, :, z] by explore_slices

In [None]:
import ipywidgets as widgets
from IPython.display import display

def explore_slices(volume):
    """Show a slider that browses ``volume`` along its third axis.

    Each slice ``volume[:, :, z]`` is drawn transposed, in grayscale, with
    the origin in the lower-left corner and the axes hidden. The slider
    starts on the middle slice.

    Parameters
    ----------
    volume : array-like
        Must expose ``shape`` and support ``volume[:, :, z]`` indexing.
    """
    depth = volume.shape[2]
    z_slider = widgets.IntSlider(
        min=0, max=depth - 1, step=1, value=depth // 2, description="z"
    )
    out = widgets.Output()

    def update(z):
        """Render slice ``z`` into the output area, replacing the previous one."""
        with out:
            out.clear_output(wait=True)
            fig, ax = plt.subplots()
            ax.imshow(volume[:, :, z].T, cmap="gray", origin="lower")
            ax.axis("off")
            plt.show()

    # Wire the slider directly rather than through a `widgets.interactive`
    # wrapper that is built only for its side effect and never displayed.
    z_slider.observe(lambda change: update(change["new"]), names="value")

    display(z_slider, out)
    update(z_slider.value)  # initial draw, independent of any wrapper widget

# Browse the loaded volume slice by slice.
explore_slices(data)

IntSlider(value=37, description='z', max=73)

Output()