In [None]:
import os
import sys
from collections import defaultdict

import pydicom
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from scipy.ndimage.morphology import binary_closing, binary_erosion  # Morphological operator

# --- Notebook configuration -------------------------------------------------
# Instance numbers are grouped in consecutive runs of this many frames.
MRI_FRAMES = 50

# Sizing of the structuring elements used to turn an overlay outline into a mask.
# The two factors scale the short side of the outline's bounding box.
MRI_MIN_RADIUS = 2               # lower bound for the small radius (big radius uses +1)
MRI_MAX_MYOCARDIUM = 20          # myocardium pixel count above which an empty ventricle is "rescued"
MRI_BIG_RADIUS_FACTOR = 0.9
MRI_SMALL_RADIUS_FACTOR = 0.19

# Integer label assigned to each tissue class in the anatomical mask.
MRI_SEGMENTED_CHANNEL_MAP = {
    'background': 0,
    'ventricle': 1,
    'myocardium': 2,
}

In [None]:
!mkdir ./dcm_scratch
!rm ./dcm_scratch/*

!cp /mnt/ml4cvd/projects/bulk/cardiac_mri/2467677_20209_2_0.zip ./dcm_scratch/
!unzip ./dcm_scratch/2467677_20209_2_0.zip -d ./dcm_scratch/

In [None]:
# Read every DICOM in the scratch directory and group the
# 'cine_segmented_sax_inlinevf' instances by their position key.
dcm_dir = './dcm_scratch/'
series = defaultdict(list)
# os.listdir returns entries in arbitrary order; sorting makes the grouping
# (and any order-dependent tie-breaking in later cells) reproducible.
for dcm_file in sorted(os.listdir(dcm_dir)):
    if not dcm_file.endswith('.dcm'):
        continue
    # dcmread is the supported reader; read_file is a deprecated alias that
    # newer pydicom releases no longer provide.
    dcm = pydicom.dcmread(os.path.join(dcm_dir, dcm_file))
    # A file without SeriesDescription is treated as non-matching instead of
    # aborting the whole scan with an AttributeError.
    description = str(getattr(dcm, 'SeriesDescription', '')).lower()
    if 'cine_segmented_sax_inlinevf' == description:
        # InstanceNumber is 1-based; each run of MRI_FRAMES instances shares a key.
        cur_angle = (dcm.InstanceNumber - 1) // MRI_FRAMES
        series[cur_angle].append(dcm)
print('len is:' , len(series))
for k in sorted(series):
    print(f'b series {k} has {len(series[k])} instances')

In [None]:
# NOTE: these repeat the values already assigned in the notebook's first cell;
# keep both places in sync if any of them is tuned.
MRI_MIN_RADIUS = 2               # lower bound for the small structuring-element radius
MRI_MAX_MYOCARDIUM = 20          # myocardium pixel threshold for the ventricle rescue step
MRI_BIG_RADIUS_FACTOR = 0.9      # big radius = factor * short side of the outline
MRI_SMALL_RADIUS_FACTOR = 0.19   # small radius = factor * short side of the outline
MRI_SEGMENTED_CHANNEL_MAP = {
    'background': 0,
    'ventricle': 1,
    'myocardium': 2,
}
def _is_mitral_valve_segmentation(d) -> bool:
    return d.SliceThickness == 6

def _get_overlay_from_dicom(d, debug=False):
    """Get an overlay from a DICOM file

    Morphological operators are used to transform the pixel outline of the myocardium
    to the labeled pixel masks for myocardium and left ventricle

    Arguments
        d: the dicom file
        stats: Counter to keep track of summary statistics

    Returns
        Tuple of two numpy arrays.
        The first is the raw overlay array with myocardium outline,
        The second is a pixel mask with 0 for background 1 for myocardium and 2 for ventricle
    """
    i_overlay = 0
    dicom_tag = 0x6000 + 2 * i_overlay
    overlay_raw = d[dicom_tag, 0x3000].value
    rows = d[dicom_tag, 0x0010].value  # rows = 512
    cols = d[dicom_tag, 0x0011].value  # cols = 512
    overlay_frames = d[dicom_tag, 0x0015].value
    bits_allocated = d[dicom_tag, 0x0100].value

    np_dtype = np.dtype('uint8')
    length_of_pixel_array = len(overlay_raw)
    expected_length = rows * cols
    if bits_allocated == 1:
        expected_bit_length = expected_length
        bit = 0
        overlay = np.ndarray(shape=(length_of_pixel_array * 8), dtype=np_dtype)
        for byte in overlay_raw:
            for bit in range(bit, bit + 8):
                overlay[bit] = byte & 0b1
                byte >>= 1
            bit += 1
        overlay = overlay[:expected_bit_length]
    if overlay_frames == 1:
        overlay = overlay.reshape(rows, cols)
        idx = np.where(overlay == 1)
        min_pos = (np.min(idx[0]), np.min(idx[1]))
        max_pos = (np.max(idx[0]), np.max(idx[1]))
        short_side = min((max_pos[0] - min_pos[0]), (max_pos[1] - min_pos[1]))
        small_radius = max(MRI_MIN_RADIUS, short_side * MRI_SMALL_RADIUS_FACTOR)
        big_radius = max(MRI_MIN_RADIUS+1, short_side * MRI_BIG_RADIUS_FACTOR)
        small_structure = _unit_disk(small_radius)
        m1 = binary_closing(overlay, small_structure).astype(np.int)
        big_structure = _unit_disk(big_radius)
        m2 = binary_closing(overlay, big_structure).astype(np.int)
        anatomical_mask = m1 + m2
        ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
        myocardium_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium'])
        if ventricle_pixels == 0 and myocardium_pixels > MRI_MAX_MYOCARDIUM:
            erode_structure = _unit_disk(small_radius*1.5)
            anatomical_mask = anatomical_mask - binary_erosion(m1, erode_structure).astype(np.int)
            ventricle_pixels = np.count_nonzero(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['ventricle'])
            print(f"rescue ventricle_pixels {ventricle_pixels} myo pixels: {myocardium_pixels} ")
        return overlay, anatomical_mask, ventricle_pixels
    
def _unit_disk(r) -> np.ndarray:
    y, x = np.ogrid[-r: r + 1, -r: r + 1]
    return (x ** 2 + y ** 2 <= r ** 2).astype(np.int)


def _outline_to_mask(labeled_outline, idx) -> np.ndarray:
    """Rasterize the pixels carrying one label as a filled polygon.

    Arguments
        labeled_outline: 2D array of labels
        idx: the label value whose pixels become the polygon's vertices

    Returns
        Array with the same height and width as labeled_outline, holding 1 on
        and inside the drawn polygon and 0 elsewhere
    """
    # NOTE(review): Image and ImageDraw are not imported in the cells shown here
    # (presumably PIL's Image / ImageDraw) - confirm the import before calling.
    matches = np.where(labeled_outline == idx)
    # Polygon vertices are (x, y) pairs, i.e. (column, row), in np.where's scan order.
    vertices = [(x, y) for x, y in zip(matches[1].tolist(), matches[0].tolist())]
    height = labeled_outline.shape[0]
    width = labeled_outline.shape[1]
    canvas = Image.new("L", [width, height], 0)
    ImageDraw.Draw(canvas).polygon(vertices, outline=1, fill=1)
    return np.array(canvas)

In [None]:
def plot_b_series(b_series, sides=7):
    """Plot the frames of one series on a sides x sides grid of axes.

    Frames are placed column by column according to their position within the
    run of MRI_FRAMES instances; frames that do not fit on the grid are skipped.

    Arguments
        b_series: iterable of DICOM datasets with InstanceNumber and pixel_array
        sides: number of rows and of columns of the grid
    """
    # squeeze=False keeps `axes` two-dimensional even when sides == 1.
    fig, axes = plt.subplots(sides, sides, figsize=(18, 24), squeeze=False)
    for dcm in b_series:
        idx = (dcm.InstanceNumber - 1) % MRI_FRAMES
        if idx >= sides * sides:
            continue
        ax = axes[idx % sides, idx // sides]
        if _is_mitral_valve_segmentation(dcm):
            ax.imshow(dcm.pixel_array, cmap='gray')
        else:
            try:
                # The returned mask is not drawn here; the call is kept so that
                # frames without an overlay are still reported below.
                _get_overlay_from_dicom(dcm)
                ax.imshow(dcm.pixel_array, cmap='gray')
            except KeyError:
                # Derive the angle from the instance itself rather than relying
                # on a global loop variable left over from another cell.
                angle = (dcm.InstanceNumber - 1) // MRI_FRAMES
                print(f'Could not get overlay at {dcm.InstanceNumber}, angle {angle}')
                ax.imshow(dcm.pixel_array)
        ax.set_yticklabels([])
        ax.set_xticklabels([])

In [None]:
# .get() avoids inserting an empty entry into the defaultdict when the key is absent.
plot_b_series(series.get(4, []), sides=2)

In [None]:
# .get() avoids inserting an empty entry into the defaultdict when the key is absent.
plot_b_series(series.get(8, []), sides=7)

In [None]:
# .get() avoids inserting an empty entry into the defaultdict when the key is absent.
plot_b_series(series.get(5, []), sides=7)

In [None]:
# For every angle: draw each frame and remember the frames with the largest
# (diastole) and smallest (systole) ventricle pixel count.
systoles = {}
diastoles = {}
systoles_pix = {}
diastoles_pix = {}

# One column per angle actually present, so keys outside 0..11 cannot overflow
# the grid; squeeze=False keeps `axes` two-dimensional for a single column.
angles = sorted(series)
column_of = {angle: col for col, angle in enumerate(angles)}
fig, axes = plt.subplots(MRI_FRAMES, max(len(angles), 1), figsize=(12, 36), squeeze=False)

for s in angles:
    for dcm in series[s]:
        ax = axes[(dcm.InstanceNumber - 1) % MRI_FRAMES, column_of[s]]
        if _is_mitral_valve_segmentation(dcm):
            ax.imshow(dcm.pixel_array)
            continue
        try:
            overlay, anatomical_mask, ventricle_pixels = _get_overlay_from_dicom(dcm)
        except KeyError:
            print(f'could not get overlay at {dcm.InstanceNumber}, angle {s}')
            ax.imshow(dcm.pixel_array)
            # No pixel count exists for this frame: skip the bookkeeping instead
            # of reusing the count left over from a previous frame.
            continue
        ax.imshow(np.ma.masked_where(anatomical_mask == MRI_SEGMENTED_CHANNEL_MAP['myocardium'], dcm.pixel_array))
        ax.set_yticklabels([])
        ax.set_xticklabels([])

        if s not in diastoles:
            # First measured frame of this angle seeds both extremes.
            diastoles[s] = dcm
            diastoles_pix[s] = ventricle_pixels
            systoles[s] = dcm
            systoles_pix[s] = ventricle_pixels
        else:
            if ventricle_pixels > diastoles_pix[s]:
                diastoles[s] = dcm
                diastoles_pix[s] = ventricle_pixels
            if ventricle_pixels < systoles_pix[s]:
                systoles[s] = dcm
                systoles_pix[s] = ventricle_pixels

for angle in diastoles:
    print(f'Found systole at instance {systoles[angle].InstanceNumber}  pix: {systoles_pix[angle]}')
    print(f'Found diastole at instance {diastoles[angle].InstanceNumber}   pix: {diastoles_pix[angle]}\n')

In [None]:
# Show which keys ended up in the series mapping.
print(series.keys())