In [10]:
import os
print(os.listdir('ACDC_dataset/training'))

['patient001', 'patient002', 'patient003', 'patient004', 'patient005', 'patient006', 'patient007', 'patient008', 'patient009', 'patient010', 'patient011', 'patient012', 'patient013', 'patient014', 'patient015', 'patient016', 'patient017', 'patient018', 'patient019', 'patient020', 'patient021', 'patient022', 'patient023', 'patient024', 'patient025', 'patient026', 'patient027', 'patient028', 'patient029', 'patient030', 'patient031', 'patient032', 'patient033', 'patient034', 'patient035', 'patient036', 'patient037', 'patient038', 'patient039', 'patient040', 'patient041', 'patient042', 'patient043', 'patient044', 'patient045', 'patient046', 'patient047', 'patient048', 'patient049', 'patient050', 'patient051', 'patient052', 'patient053', 'patient054', 'patient055', 'patient056', 'patient057', 'patient058', 'patient059', 'patient060', 'patient061', 'patient062', 'patient063', 'patient064', 'patient065', 'patient066', 'patient067', 'patient068', 'patient069', 'patient070', 'patient071', 'pati

In [11]:
import os
print(os.listdir('ACDC_dataset/testing'))

['patient101', 'patient102', 'patient103', 'patient104', 'patient105', 'patient106', 'patient107', 'patient108', 'patient109', 'patient110', 'patient111', 'patient112', 'patient113', 'patient114', 'patient115', 'patient116', 'patient117', 'patient118', 'patient119', 'patient120', 'patient121', 'patient122', 'patient123', 'patient124', 'patient125', 'patient126', 'patient127', 'patient128', 'patient129', 'patient130', 'patient131', 'patient132', 'patient133', 'patient134', 'patient135', 'patient136', 'patient137', 'patient138', 'patient139', 'patient140', 'patient141', 'patient142', 'patient143', 'patient144', 'patient145', 'patient146', 'patient147', 'patient148', 'patient149', 'patient150', 'test', 'train']


In [16]:
import os
import re
import time
import shutil
import errno
import pickle
import numpy as np
import nibabel as nib
import skimage.morphology as morph
import skimage.transform
import matplotlib.pyplot as plt
from matplotlib import animation
from tqdm import tqdm
from collections import OrderedDict
from scipy.fftpack import fftn, ifftn
from skimage.feature import peak_local_max, canny
from skimage.transform import hough_circle

np.random.seed(42)

# ================================================================
#                    HELPER FUNCTIONS
# ================================================================

def heart_metrics(seg_3Dmap, voxel_size, classes=[3, 1, 2]):
    """Compute the volumes of each class in mL."""
    volumes = []
    for c in classes:
        seg_copy = np.copy(seg_3Dmap)
        seg_copy[seg_copy != c] = 0
        seg_copy = np.clip(seg_copy, 0, 1)
        volume = seg_copy.sum() * np.prod(voxel_size) / 1000.
        volumes.append(volume)
    return volumes


def ejection_fraction(ed_vol, es_vol):
    """Calculate ejection fraction (in %)."""
    stroke_vol = ed_vol - es_vol
    return (float(stroke_vol) / float(ed_vol)) * 100 if ed_vol != 0 else 0


def myocardialmass(myocardvol):
    """Compute myocardial mass in grams (density = 1.05 g/mL)."""
    return myocardvol * 1.05


def imshow(*args, **kwargs):
    """Display multiple images in one row."""
    cmap = kwargs.get('cmap', 'gray')
    title = kwargs.get('title', '')
    if len(args) == 0:
        raise ValueError("No images given to imshow")

    n = len(args)
    if isinstance(cmap, str):
        cmap = [cmap] * n
    if isinstance(title, str):
        title = [title] * n

    plt.figure(figsize=(n * 5, 10))
    for i in range(n):
        plt.subplot(1, n, i + 1)
        plt.title(title[i])
        plt.imshow(args[i], cmap[i])
    plt.show()


# ================================================================
#                    ROI PLOTTING & ANIMATION
# ================================================================

def plot_roi(data4D, roi_center, roi_radii):
    """Animate slices with ROI overlay."""
    x_c, y_c = roi_center
    x_r, y_r = roi_radii
    zslices, tframes = data4D.shape[2], data4D.shape[3]

    for z in range(zslices):
        slice_data = np.swapaxes(np.swapaxes(data4D[:, :, z, :], 0, 2), 1, 2)
        roi_mask = np.zeros_like(slice_data[0])
        roi_mask[x_c - x_r:x_c + x_r, y_c - y_r:y_c + y_r] = 1

        slice_data[:, roi_mask > 0.5] *= 0.8
        fig = plt.figure(1)
        fig.canvas.manager.set_window_title(f'Slice {z}')

        def animate_out(i):
            im.set_data(slice_data[i])
            return im

        im = plt.imshow(slice_data[0], cmap='gray')
        anim = animation.FuncAnimation(fig, animate_out, frames=tframes, interval=50)
        anim.save(f'Cine_MRI_SAX_{z}.mp4', fps=50, extra_args=['-vcodec', 'libx264'])
        plt.show()


# ================================================================
#                    FILE I/O UTILS
# ================================================================

def save_data(data, filename, out_path):
    out_filename = os.path.join(out_path, filename)
    with open(out_filename, 'wb') as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"✅ Saved to {out_filename}")


def load_pkl(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def copy(src, dest):
    """Safe copy of files/directories."""
    try:
        shutil.copytree(src, dest, ignore=shutil.ignore_patterns())
    except OSError as e:
        if e.errno == errno.ENOTDIR:
            shutil.copy(src, dest)
        else:
            print(f"⚠️ Directory not copied. Error: {e}")


def read_patient_cfg(path):
    """Reads Info.cfg and returns patient metadata."""
    info = {}
    cfg_file = os.path.join(path, 'Info.cfg')
    if not os.path.exists(cfg_file):
        raise FileNotFoundError(f"Missing Info.cfg in {path}")
    with open(cfg_file) as f:
        for line in f:
            key, val = line.rstrip().split(": ")
            info[key] = val
    return info


# ================================================================
#                    DATA GROUPING
# ================================================================

NORMAL, MINF, DCM, HCM, RV = 'NOR', 'MINF', 'DCM', 'HCM', 'RV'


def group_patient_cases(src_path, out_path, force=False):
    """Group the ACDC patient data according to cardiac pathology."""
    walker = os.walk(src_path)
    try:
        cases = sorted(next(walker)[1])
    except StopIteration:
        raise FileNotFoundError(f"No subdirectories found in {src_path}")

    dest_path = os.path.join(out_path, 'Patient_Groups')
    if force and os.path.exists(dest_path):
        shutil.rmtree(dest_path)
    os.makedirs(dest_path, exist_ok=True)

    for group in [NORMAL, MINF, DCM, HCM, RV]:
        os.makedirs(os.path.join(dest_path, group), exist_ok=True)

    for case in cases:
        full_path = os.path.join(src_path, case)
        try:
            patient_group = read_patient_cfg(full_path)['Group']
            copy(full_path, os.path.join(dest_path, patient_group, case))
        except Exception as e:
            print(f"⚠️ Skipping {case}: {e}")
    return dest_path


def generate_train_validate_test_set(src_path, dest_path):
    """Split grouped data into train/val/test (70/15/15)."""
    SPLIT_TRAIN, SPLIT_VALID = 0.7, 0.15
    dest_path = os.path.join(dest_path, 'dataset')
    if os.path.exists(dest_path):
        shutil.rmtree(dest_path)

    os.makedirs(os.path.join(dest_path, 'train_set'))
    os.makedirs(os.path.join(dest_path, 'validation_set'))
    os.makedirs(os.path.join(dest_path, 'test_set'))

    for group in next(os.walk(src_path))[1]:
        group_path = os.path.join(src_path, group)
        patients = next(os.walk(group_path))[1]
        np.random.shuffle(patients)

        n_train = int(SPLIT_TRAIN * len(patients))
        n_valid = int((SPLIT_TRAIN + SPLIT_VALID) * len(patients))
        splits = {
            'train_set': patients[:n_train],
            'validation_set': patients[n_train:n_valid],
            'test_set': patients[n_valid:]
        }

        for split_name, patient_list in splits.items():
            for p in patient_list:
                copy(os.path.join(group_path, p),
                     os.path.join(dest_path, split_name, p))
    return dest_path


# ================================================================
#                    ROI EXTRACTION
# ================================================================

def extract_roi_stddev(data4D, pixel_spacing, minradius_mm=15, maxradius_mm=45,
                       kernel_width=5, center_margin=8, num_peaks=10, num_circles=20, radstep=2):
    """ROI extraction using StdDev + Hough Transform."""
    px, py, _, _ = pixel_spacing
    minr, maxr = int(minradius_mm / px), int(maxradius_mm / py)

    xsize, ysize, zslices, tframes = data4D.shape
    xsurf = np.tile(np.arange(xsize), (ysize, 1)).T
    ysurf = np.tile(np.arange(ysize), (xsize, 1))
    lsurf = np.zeros((xsize, ysize))
    centers, accums, radii = [], [], []

    for z in range(zslices):
        fh = np.std(np.array([data4D[:, :, z, t] for t in range(tframes)]), axis=0)
        fh[fh < 0.1 * np.max(fh)] = 0
        image = fh / np.max(fh)
        edges = canny(image, sigma=3)
        hough_radii = np.arange(minr, maxr, radstep)
        hough_res = hough_circle(edges, hough_radii)

        if not hough_res.any():
            continue

        for r, h in zip(hough_radii, hough_res):
            peaks = peak_local_max(h, num_peaks=num_peaks)
            centers.extend(peaks)
            accums.extend(h[peaks[:, 0], peaks[:, 1]])
            radii.extend([r] * len(peaks))

        sorted_idx = np.argsort(accums)[::-1][:num_circles]
        for idx in sorted_idx:
            cx, cy = centers[idx]
            brightness = accums[idx]
            lsurf += brightness * np.exp(-((xsurf - cx) ** 2 + (ysurf - cy) ** 2) / kernel_width ** 2)

    if lsurf.max() == 0:
        return (0, 0), (0, 0)

    lsurf /= lsurf.max()
    roi_center = np.unravel_index(lsurf.argmax(), lsurf.shape)

    roi_xr, roi_yr = 0, 0
    for i in range(len(centers)):
        dx, dy = abs(centers[i][0] - roi_center[0]), abs(centers[i][1] - roi_center[1])
        if dx <= center_margin and dy <= center_margin:
            roi_xr = max(roi_xr, radii[i] + dx)
            roi_yr = max(roi_yr, radii[i] + dy)

    return roi_center, (roi_xr, roi_yr) if roi_xr and roi_yr else None


# ================================================================
#                    DATASET CLASS (TRAIN/VAL/TEST)
# ================================================================

class Dataset:
    def __init__(self, directory, subdir):
        self.directory = directory
        self.name = subdir
        self.patient_data = {}

    def _filename(self, file):
        return os.path.join(self.directory, self.name, file)

    def load_nii(self, img_path):
        nimg = nib.load(self._filename(img_path))
        return nimg.get_fdata(), nimg.affine, nimg.header

    def read_patient_info_data(self):
        path = self._filename('Info.cfg')
        with open(path) as f:
            for line in f:
                key, val = line.rstrip().split(": ")
                self.patient_data[key] = val

    def read_patient_data(self, roi_detect=True):
        pid = int(re.match(r"patient(\d{3})", self.name).group(1))
        self.read_patient_info_data()
        ED = int(self.patient_data['ED'])
        ES = int(self.patient_data['ES'])

        ed_img, es_img = f"patient{pid:03d}_frame{ED:02d}.nii.gz", f"patient{pid:03d}_frame{ES:02d}.nii.gz"
        ed, affine, hdr = self.load_nii(ed_img)
        es, _, _ = self.load_nii(es_img)
        self.patient_data['ED_VOL'], self.patient_data['ES_VOL'] = ed, es
        self.patient_data['header'] = {'affine': affine, 'hdr': hdr}

        ed_gt, _, _ = self.load_nii(f"patient{pid:03d}_frame{ED:02d}_gt.nii.gz")
        es_gt, _, _ = self.load_nii(f"patient{pid:03d}_frame{ES:02d}_gt.nii.gz")
        self.patient_data['ED_GT'], self.patient_data['ES_GT'] = ed_gt, es_gt

        ed_lv, ed_rv, ed_myo = heart_metrics(ed_gt, hdr.get_zooms())
        es_lv, es_rv, es_myo = heart_metrics(es_gt, hdr.get_zooms())
        ef_lv, ef_rv = ejection_fraction(ed_lv, es_lv), ejection_fraction(ed_rv, es_rv)
        self.patient_data['HP'] = {'EDV_LV': ed_lv, 'EDV_RV': ed_rv,
                                   'ESV_LV': es_lv, 'ESV_RV': es_rv,
                                   'EF_LV': ef_lv, 'EF_RV': ef_rv}

        if roi_detect:
            img4d, _, hdr = self.load_nii(f"patient{pid:03d}_4d.nii.gz")
            self.patient_data['4D'] = img4d
            c, r = extract_roi_stddev(img4d, hdr.get_zooms())
            self.patient_data['roi_center'], self.patient_data['roi_radii'] = c, r


# ================================================================
#                    MAIN EXECUTION
# ================================================================

def convert_nii_np(data_path, roi_detect):
    patient_fulldata = OrderedDict()
    print(f"📁 Processing {data_path}")
    for patient in tqdm(sorted(next(os.walk(data_path))[1])):
        dset = Dataset(data_path, patient)
        dset.read_patient_data(roi_detect=roi_detect)
        patient_fulldata[dset.name] = dset.patient_data
    return patient_fulldata


if __name__ == '__main__':
    start_time = time.time()

    complete_data_path = 'ACDC_dataset/training'
    dest_path = 'processed_acdc_dataset'
    group_path = os.path.join(dest_path, 'Patient_Groups')

    train_dataset_path = os.path.join(dest_path, 'dataset', 'train_set')
    val_dataset_path = os.path.join(dest_path, 'dataset', 'validation_set')
    test_dataset_path = os.path.join(dest_path, 'dataset', 'test_set')
    out_path_train = os.path.join(dest_path, 'pickled', 'full_data')

    # === STRATIFY DATA ===
    group_patient_cases(complete_data_path, dest_path)
    generate_train_validate_test_set(group_path, dest_path)
    print(f"✅ Stratification done in {time.time() - start_time:.2f}s")

    # === ROI + PICKLE SAVE ===
    os.makedirs(out_path_train, exist_ok=True)

    for name, path in zip(['train_set', 'validation_set', 'test_set'],
                          [train_dataset_path, val_dataset_path, test_dataset_path]):
        data = convert_nii_np(path, roi_detect=True)
        save_data(data, f"{name}.pkl", out_path_train)

✅ Stratification done in 7.37s
📁 Processing processed_acdc_dataset\dataset\train_set


100%|██████████| 70/70 [01:54<00:00,  1.64s/it]


✅ Saved to processed_acdc_dataset\pickled\full_data\train_set.pkl
📁 Processing processed_acdc_dataset\dataset\validation_set


100%|██████████| 15/15 [00:19<00:00,  1.27s/it]


✅ Saved to processed_acdc_dataset\pickled\full_data\validation_set.pkl
📁 Processing processed_acdc_dataset\dataset\test_set


100%|██████████| 15/15 [00:12<00:00,  1.16it/s]


✅ Saved to processed_acdc_dataset\pickled\full_data\test_set.pkl
