In [1]:
import os
import json
import torch
import pydicom
import warnings
import scipy.ndimage
import numpy as np
import pickle
import random
from sklearn.model_selection import train_test_split

In [2]:
# Suppress every Python warning for the rest of the session.
# NOTE(review): this is a blanket filter, so numeric warnings raised while
# preprocessing images are hidden as well.
warnings.filterwarnings('ignore')
# Seed Python's built-in `random` module so the random.shuffle calls made when
# the splits are built give the same order on every run.
random.seed(51)

In [3]:
# Print whether PyTorch reports an available CUDA device in this environment.
print(torch.cuda.is_available())

True


In [4]:
def find_files(dir_path, extension, keyword='C004'):
    """
    Recursively collect DLD files (.dcm, .json) below dir_path.

    Args:
        dir_path: root directory, walked recursively with os.walk.
        extension: suffix the file name must end with (e.g. '.dcm', '.json').
        keyword: substring that must occur in the file name. Defaults to
            'C004', the filter that used to be hard-coded.

    Returns:
        A list of full paths sorted by file name (ties broken by full path).
        os.walk yields entries in a filesystem-dependent order, so the result
        is sorted to make it reproducible between runs and machines.
    """
    matches = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(dir_path)
        for name in names
        if name.endswith(extension) and keyword in name
    ]
    matches.sort(key=lambda path: (os.path.basename(path), path))
    return matches

In [5]:
def format_image(dcm_file):
    """
    Load a DICOM slice, crop its border, resize it to 300x300 and rescale it to [0, 1].

    Args:
        dcm_file: path of the DICOM file, passed to pydicom.dcmread.

    Returns:
        A 2-D floating-point array whose smallest value is 0 and largest value
        is 1. An image in which every pixel has the same value is returned as
        all zeros instead of NaN.

    Raises:
        ValueError: if the pixel array is not two-dimensional (the shape
            cannot be unpacked into height and width).
    """
    dicom = pydicom.dcmread(dcm_file)
    image = dicom.pixel_array

    # Keep the central region: drop 10% of the rows/columns on every side.
    h, w = image.shape
    left, upper = int(0.1 * w), int(0.1 * h)
    right, lower = int(0.9 * w), int(0.9 * h)
    image = image[upper:lower, left:right]

    # Bilinear (order=1) resize of the cropped slice to 300x300.
    image = scipy.ndimage.zoom(image, (300 / image.shape[0], 300 / image.shape[1]), order=1)

    # Normalise in floating point: subtracting in the original integer dtype
    # wraps around when (max - min) does not fit into it (e.g. int16 data
    # spanning more than 32767 grey levels).
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)
    min_p = np.min(image)
    max_p = np.max(image)
    value_range = max_p - min_p
    if value_range == 0:
        # Flat image: avoid 0/0 (NaN) and return a well-defined result.
        return np.zeros_like(image)
    return (image - min_p) / value_range

In [6]:
def make_dataset(dcm_files, json_files, max_samples=6000):
    """
    Build patient and non-patient records from two index-aligned file lists.

    dcm_files[i] and json_files[i] are treated as the image and the label file
    of the same sample. Only the first max_samples pairs are processed.

    Record layout (a tuple):
        (pixel_array, age, sex, m_area, level, is_patient)
    where age and sex come from annotation.clinic, m_area from the first
    ANNOTATION_DATA entry (described as spinal cord width by the original
    docstring), and level / is_patient from DATA_CATEGORY (see
    _label_from_categories).

    Records for which no level can be derived are skipped. Previously such a
    record raised NameError (first sample) or silently reused the level of the
    preceding sample because the loop state was never reset.

    Args:
        dcm_files: list of DICOM paths.
        json_files: list of label JSON paths, aligned with dcm_files.
        max_samples: upper bound on the number of leading pairs to process
            (default 6000, the previously hard-coded limit).

    Returns:
        (patient, normal): two lists of record tuples.
    """
    patient = []
    normal = []
    size = min(len(dcm_files), max_samples)
    for idx in range(size):
        with open(json_files[idx], 'r') as f:
            annotation = json.load(f)['annotation']

        label = _label_from_categories(annotation['DATA_CATEGORY'])
        if label is None:
            # No usable level label: leave the sample out instead of guessing.
            continue
        level, is_patient = label

        record = (
            format_image(dcm_files[idx]),
            annotation['clinic']['age'],
            annotation['clinic']['sex'],
            annotation['ANNOTATION_DATA'][0]['m_area'],
            level,
            is_patient,
        )
        if is_patient:
            patient.append(record)
        else:
            normal.append(record)
    return patient, normal


def _label_from_categories(categories, max_entries=10, n_levels=5):
    """
    Derive (level, is_patient) from the DATA_CATEGORY list of one label file.

    Only the first value of each of the first max_entries entries is inspected:
      * the first value that is neither 2 nor 3 marks the sample as a patient
        and fixes the level to that entry's index modulo n_levels;
      * otherwise the sample is a non-patient and the level is the index of
        the last entry whose value is 2, modulo n_levels;
      * if no entry has the value 2 either, None is returned.
    """
    last_two = None
    for i, entry in enumerate(categories[:max_entries]):
        value = list(entry.values())[0]
        if value == 2:
            last_two = i
        elif value != 3:
            return i % n_levels, 1
    if last_two is None:
        return None
    return last_two % n_levels, 0

In [7]:
def get_files(dcm_dir, json_dir):
    """
    Collect the DICOM / label-JSON files to use and drop unusable samples.

    A sample is identified by its file stem (base name up to the first '.'),
    which the DICOM file and its JSON label file are expected to share.
    A stem is dropped when
      * the DICOM has no SeriesDescription, or the description (lower-cased)
        does not contain 't2' or does contain 'sag';
      * every value in the label's DATA_CATEGORY entries equals 3;
      * the label's ANNOTATION_DATA is empty;
      * it is present in only one of the two lists.

    Args:
        dcm_dir: root directory searched for '.dcm' files.
        json_dir: root directory searched for '.json' files.

    Returns:
        (dcm_files, json_files): both lists are sorted by stem and contain the
        same stems, so that position i in one list matches position i in the
        other (callers pair the two lists by index).
    """
    def stem(path):
        # os.path.basename handles the platform's separators, unlike split('/').
        return os.path.basename(path).split('.')[0]

    dcm_files = find_files(dcm_dir, '.dcm')
    json_files = find_files(json_dir, '.json')

    # Stage 1: filter on the DICOM header. Only the header is needed, so the
    # pixel data is not read.
    rejected = set()
    for file in dcm_files:
        dicom = pydicom.dcmread(file, force=True, stop_before_pixels=True)
        if 'SeriesDescription' in dicom:
            description = dicom.SeriesDescription.lower()
            if not ('t2' in description and 'sag' not in description):
                rejected.add(stem(file))
        else:
            rejected.add(stem(file))
    dcm_files = [file for file in dcm_files if stem(file) not in rejected]
    json_files = [file for file in json_files if stem(file) not in rejected]

    # Stage 2: filter on the label content.
    rejected = set()
    for file in json_files:
        with open(file, 'r') as f:
            json_data = json.load(f)
        categories = json_data["annotation"]["DATA_CATEGORY"]
        if all(int(value) == 3 for entry in categories for value in entry.values()):
            rejected.add(stem(file))
        elif not json_data['annotation']['ANNOTATION_DATA']:
            rejected.add(stem(file))
    json_files = [file for file in json_files if stem(file) not in rejected]
    dcm_files = [file for file in dcm_files if stem(file) not in rejected]

    # Stage 3: keep only samples that have both an image and a label, and put
    # both lists in the same order so index-based pairing is valid.
    paired = {stem(file) for file in dcm_files} & {stem(file) for file in json_files}
    dcm_files = sorted((file for file in dcm_files if stem(file) in paired), key=stem)
    json_files = sorted((file for file in json_files if stem(file) in paired), key=stem)
    return dcm_files, json_files

In [12]:
def get_traindata(dcm_dir, json_dir, test_dcm_dir, test_json_dir):
    """
    Build the train, validation and test sets.

    Samples from both directory pairs are pooled, then patients and
    non-patients are each split 60% / 20% / 20% separately, so every subset
    keeps roughly the overall patient / non-patient ratio. Each subset is
    shuffled with the global `random` state before it is returned.

    Returns:
        (train_data, val_data, test_data): three lists of sample records.
    """
    def three_way(samples):
        # 20% held out for test, then 25% of the remaining 80% for validation.
        train_val, test = train_test_split(samples, test_size=0.2, random_state=51)
        train, val = train_test_split(train_val, test_size=0.25, random_state=51)
        return train, val, test

    # Resolve the file lists of both sources first, then load their samples.
    sources = ((dcm_dir, json_dir), (test_dcm_dir, test_json_dir))
    file_lists = [get_files(dcm_root, json_root) for dcm_root, json_root in sources]

    patient_dataset, normal_dataset = [], []
    for dcm_paths, json_paths in file_lists:
        patients, normals = make_dataset(dcm_paths, json_paths)
        patient_dataset = patient_dataset + patients
        normal_dataset = normal_dataset + normals

    patients_train, patients_val, patients_test = three_way(patient_dataset)
    normal_train, normal_val, normal_test = three_way(normal_dataset)

    train_data = patients_train + normal_train
    val_data = patients_val + normal_val
    test_data = patients_test + normal_test

    # Shuffle order (train, val, test) matters: each call advances the shared RNG.
    for subset in (train_data, val_data, test_data):
        random.shuffle(subset)

    return train_data, val_data, test_data

In [13]:
# file paths, all relative to the notebook's working directory
data_root = '../202101n070/070.퇴행성 척추질환 진단 및 치료를 위한 멀티모달리티 데이터/01.데이터'

# "1.Training" part of the dataset: label JSON files and source DICOM files
json_dir = f'{data_root}/1.Training/2_라벨링데이터/DLD'
dcm_dir = f'{data_root}/1.Training/1_원천데이터/DLD'

# "2.Validation" part of the dataset, same layout
test_json_dir = f'{data_root}/2.Validation/2_라벨링데이터/DLD'
test_dcm_dir = f'{data_root}/2.Validation/1_원천데이터/DLD'

In [14]:
# Build the three splits from both directory pairs. This walks the directories,
# reads the DICOM and JSON files and preprocesses every selected image, so it is
# the expensive step of the notebook.
train_data, val_data, test_data = get_traindata(dcm_dir, json_dir, test_dcm_dir, test_json_dir)

In [18]:
# save data to data.pkl
# The file holds one list: [train_data, val_data, test_data], in that order.
# An existing data.pkl in the working directory is overwritten.
# Only load this file back with pickle from a trusted location.
with open('data.pkl', 'wb') as f:
    pickle.dump([train_data, val_data, test_data], f)