In [1]:
## Fast Import
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys, os
import pathlib
import logging
from pathlib import Path
import time
import math, random
import pprint
import collections
from collections import OrderedDict
import numbers, string
import nibabel as nib

import yaml
from tqdm import tqdm

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import SimpleITK as sitk
from PIL import Image

import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import albumentations as A

np.set_printoptions(precision=3)
curr_path = pathlib.Path(os.getcwd()).absolute()

cards = !echo $SGE_HGR_gpu_card
device = torch.device(f"cuda:{cards[0]}" if torch.cuda.is_available() else 'cpu')
print(device)

cuda:3


# DF

In [2]:
def natural_sort(l):
    import re
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(l, key=alphanum_key)

def collect_df(dataset_path, save=None):
    """ Collect df for MMWHS dataset (for now, only labeled training samples). 
    Note:
        - Only returns 'train' subset. You have to manually change. 
    """
    ds_path = Path(dataset_path)

    logging.info(f"Collecting MMWHS df.")
    start_time = time.time()

    train_path = ds_path / 'ct_train_images'
    train_images = natural_sort([str(f) for f in train_path.iterdir()
                                 if f.suffix == '.gz'])
    label_path = ds_path / 'ct_train_labels'
    label_images = natural_sort([str(f) for f in label_path.iterdir()
                                 if f.suffix == '.gz'])
    assert len(train_images) == len(label_images)

    # trains, vals, tests = split(range(len(train_iamges)))

    df_d = OrderedDict([
        ('id', []),
        ('image', []),
        ('mask', []),
        ('imgsize', []),
        ('subset', []),
    ])
    for i, img in enumerate(train_images):
        img_path = Path(img)
        mask_path = Path(label_images[i])
        assert img_path.name.split('_')[2] == mask_path.name.split('_')[2]

        vol = nib.load(img)

        df_d['id'].append(i + 1)
        df_d['image'].append(img)
        df_d['mask'].append(str(mask_path))
        df_d['subset'].append('train')
        df_d['imgsize'].append(vol.shape)

    df = pd.DataFrame(df_d)

    elapsed_time = time.time() - start_time
    logging.info(f"Done collecting MMWHS ({elapsed_time:.1f} sec).")
    if save:
        df.to_csv(save)
    return df

In [3]:
ds_dir = '/afs/crc.nd.edu/user/y/yzhang46/datasets/MMWHS-2017'
df = collect_df(ds_dir, save=None)
# df.to_csv('default_df.csv')

# 3, 6, 12, 15 are test
lab_df = df.copy()
for idx in (2, 5, 11, 14):
    lab_df.at[idx, 'subset'] = 'test'
# lab_df.to_csv('lab_df.csv')

## Image Information

In [37]:
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
np.set_printoptions(precision=3)

unique_props = collections.defaultdict(set)
for i, S in df.iterrows():
    image = S['image']
    mask = S['mask']
    
    sitk_image = sitk.ReadImage(image)
    spac = sitk_image.GetSpacing()
    direc = sitk_image.GetDirection()
    shape = sitk_image.GetSize()
    
    unique_props['shape'].add(tuple(shape))
    unique_props['spacing'].add(tuple(spac))
    unique_props['direction'].add(tuple(direc))
    
    print(f'Shape: {shape} | Spacing: {spac} | Direction: {direc}')
pprint.pprint(unique_props)

Shape: (512, 512, 363) | Spacing: (0.35546875, 0.35546875, 0.44999998807907104) | Direction: (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Shape: (512, 512, 239) | Spacing: (0.4882810115814209, 0.4882810115814209, 0.625) | Direction: (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Shape: (512, 512, 298) | Spacing: (0.302734375, 0.302734375, 0.44999998807907104) | Direction: (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Shape: (512, 512, 200) | Spacing: (0.3203119933605194, 0.3203119933605194, 0.625) | Direction: (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Shape: (512, 512, 177) | Spacing: (0.4882810115814209, 0.4882810115814209, 0.625) | Direction: (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Shape: (512, 512, 248) | Spacing: (0.4882810115814209, 0.4882810115814209, 0.625) | Direction: (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Shape: (512, 512, 243) | Spacing: (0.43554699420928955, 0.43554699420928955, 0.625) | Direction: (1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 1.0)
Shape:

### Masks

In [6]:
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
np.set_printoptions(precision=3)

unique_props = collections.defaultdict(set)
for i, S in df.iterrows():
    image = S['image']
    mask = S['mask']
    
    sitk_image = sitk.ReadImage(image)
    spac = sitk_image.GetSpacing()
    direc = sitk_image.GetDirection()
    shape = sitk_image.GetSize()
    
    unique_props['shape'].add(tuple(shape))
    unique_props['spacing'].add(tuple(spac))
    unique_props['direction'].add(tuple(direc))
    
    sitk_mask = sitk.ReadImage(mask)
    mask_arr = sitk.GetArrayFromImage(sitk_mask)
    for i, v in enumerate(np.unique(mask_arr)):
        print(v, '->', i)
        mask_arr[mask_arr == v] = i
    new_sitk_mask = sitk.GetImageFromArray(mask_arr)
    new_sitk_mask.CopyInformation(sitk_mask)
pprint.pprint(unique_props)

0 -> 0
205 -> 1
420 -> 2
500 -> 3
550 -> 4
600 -> 5
820 -> 6
850 -> 7
0 -> 0
205 -> 1
420 -> 2
500 -> 3
550 -> 4
600 -> 5
820 -> 6
850 -> 7
0 -> 0
205 -> 1
420 -> 2
500 -> 3
550 -> 4
600 -> 5
820 -> 6
850 -> 7
0 -> 0
205 -> 1
420 -> 2
500 -> 3
550 -> 4
600 -> 5
820 -> 6
850 -> 7
0 -> 0
205 -> 1
420 -> 2
500 -> 3
550 -> 4
600 -> 5
820 -> 6
850 -> 7


KeyboardInterrupt: 