In [5]:
import nibabel as nib
import numpy as np
import os
from patchify import patchify, unpatchify
from scipy import ndimage
import re
import cv2


class Resize3D:
    """Resize a 3-D volume, and optionally its label map, to a fixed shape.

    The image is resampled with ``scipy.ndimage.zoom`` using the configured
    spline ``order`` and boundary mode. The label is resampled with order 0
    and ``mode='nearest'`` so label values are not blended by interpolation.

    Args:
        target_size: List or tuple of three positive sizes. Element ``i`` is
            the desired length of axis ``i`` of the input array.
        model: Boundary mode forwarded to ``ndimage.zoom`` as ``mode`` when
            resizing the image.
        order: Spline interpolation order used for the image.

    Raises:
        TypeError: If ``target_size`` is not a list or tuple.
        ValueError: If ``target_size`` does not hold exactly three elements,
            or if any element is not positive.
    """

    def __init__(self, target_size, model='constant', order=1):
        self.model = model
        self.order = order

        if not isinstance(target_size, (list, tuple)):
            raise TypeError(
                "Type of `target_size` is invalid. It should be list or tuple, but it is {}"
                .format(type(target_size)))
        if len(target_size) != 3:
            raise ValueError(
                '`target_size` should include 3 elements, but it is {}'.
                format(target_size))
        # A zero or negative size cannot be turned into a usable zoom factor;
        # reject it here instead of failing later inside __call__.
        if any(size <= 0 for size in target_size):
            raise ValueError(
                '`target_size` elements should be positive, but it is {}'.
                format(target_size))

        self.target_size = target_size

    def __call__(self, im, label=None):
        """Resize ``im`` (and ``label``, when given) to ``self.target_size``.

        Args:
            im: 3-D ``numpy.ndarray`` to resize.
            label: Optional array with the same shape as ``im``.

        Returns:
            The resized image when ``label`` is None, otherwise the tuple
            ``(resized_image, resized_label)``.

        Raises:
            TypeError: If ``im`` is not a ``numpy.ndarray``.
            ValueError: If ``im`` is not 3-dimensional, or if ``label`` does
                not have the same shape as ``im``.
        """
        if not isinstance(im, np.ndarray):
            raise TypeError("Resize: image type is not numpy.")
        if im.ndim != 3:
            raise ValueError('Resize: image is not 3-dimensional.')
        # The label is zoomed with the factors derived from the image, so a
        # label of another shape would silently end up with the wrong size.
        if label is not None and np.shape(label) != im.shape:
            raise ValueError(
                'Resize: label shape {} does not match image shape {}.'
                .format(np.shape(label), im.shape))

        # One zoom factor per axis: desired length / current length.
        zoom_factors = tuple(
            desired / current
            for desired, current in zip(self.target_size, im.shape))

        im = ndimage.zoom(im, zoom_factors, order=self.order, mode=self.model)
        if label is None:
            return im

        label = ndimage.zoom(label, zoom_factors, order=0, mode='nearest')
        return (im, label)


def crop(image, patch_size):
    """Cut ``image`` into patches and return them as a flat list.

    ``patchify`` is called with a step equal to ``patch_size``, so successive
    patches start one full patch apart along every axis.

    Args:
        image: Array to split; it is passed to ``patchify`` unchanged.
        patch_size: Patch shape, used both as window size and as step.

    Returns:
        List of patches taken from the first three axes of the ``patchify``
        result, ordered with the first axis varying slowest.
    """
    grid = patchify(image, patch_size, step=patch_size)
    count_0 = grid.shape[0]
    count_1 = grid.shape[1]
    count_2 = grid.shape[2]
    # Flatten the 3-D grid of patches into a single list.
    return [
        grid[i, j, k]
        for i in range(count_0)
        for j in range(count_1)
        for k in range(count_2)
    ]


def equalization(image):
    """Histogram-equalize a 3-D volume as one flat sequence of voxels.

    The volume is flattened before ``cv2.equalizeHist`` is applied, so a
    single histogram covers every voxel, and the result is reshaped back to
    the input shape.

    Args:
        image: 3-D array; unpacking its shape into three values fails for any
            other rank. NOTE(review): ``cv2.equalizeHist`` is documented to
            require 8-bit single-channel input - confirm callers pass uint8.

    Returns:
        The equalized data with the same three dimensions as ``image``.
    """
    size_0, size_1, size_2 = image.shape
    voxels = image.flatten()
    equalized = cv2.equalizeHist(voxels)
    return equalized.reshape(size_0, size_1, size_2)


def nii_save(imlist, filedir, index, affine):
    """Save each array in ``imlist`` as a ``.nii.gz`` file inside ``filedir``.

    Files are named ``CT{index}_{NNNN}.nii.gz`` where ``NNNN`` is the 1-based,
    zero-padded position of the array in ``imlist``. A confirmation line is
    printed for every file written.

    Args:
        imlist: Sequence of arrays; each is squeezed before saving.
        filedir: Output directory, created when it does not exist.
        index: Value embedded in the filename after the ``CT`` prefix.
        affine: Affine matrix passed to every ``nib.Nifti1Image``.
    """
    os.makedirs(filedir, exist_ok=True)  # create the output directory
    for number, patch in enumerate(imlist, start=1):
        volume = nib.Nifti1Image(patch.squeeze(), affine)
        filename = f'CT{index}_{number:04d}.nii.gz'  # zero-padded running number
        nib.save(volume, os.path.join(filedir, filename))
        print(f'{filename} saved successfully!')



# Preprocessing driver: for every CT volume and its label map, normalise and
# equalize the image, resize both to a common shape, cut them into patches and
# save the patches as individual NIfTI files.
pattern = re.compile(r".*\.nii\.gz$")
image_dir = r'/home/chae/segmentation/dataset/ct_train'
label_dir = r'/home/chae/segmentation/dataset/ct_label'
image_list = [file for file in sorted(os.listdir(image_dir)) if pattern.match(file)]
label_list = [file for file in sorted(os.listdir(label_dir)) if pattern.match(file)]

# Images and labels are paired purely by their position in the sorted
# listings, so differing counts would pair the wrong files.
if len(image_list) != len(label_list):
    raise ValueError(
        f'Found {len(image_list)} images in {image_dir} but '
        f'{len(label_list)} labels in {label_dir}; cannot pair them.')

# These do not depend on the volume being processed; build them once.
target_size = (512, 512, 384)
output_size = (64, 64, 64)
resize3d = Resize3D(target_size=target_size)

for idx, (image_name, label_name) in enumerate(zip(image_list, label_list)):
    image_path = os.path.join(image_dir, image_name)
    label_path = os.path.join(label_dir, label_name)

    img = nib.load(image_path)
    img_data = img.get_fdata()
    affine1 = img.affine if hasattr(img, 'affine') else np.eye(4)

    # Min-max scale to [0, 255] before the uint8 cast. Dividing by the maximum
    # alone leaves negative intensities negative, and those do not survive the
    # cast to an unsigned type.
    intensity_min = img_data.min()
    intensity_range = img_data.max() - intensity_min
    if intensity_range > 0:
        img_data = (img_data - intensity_min) / intensity_range * 255
    else:
        # Constant volume: nothing to scale, and avoids dividing by zero.
        img_data = np.zeros_like(img_data)
    img_data = img_data.astype(np.uint8)
    img_data = equalization(img_data)

    label = nib.load(label_path)
    label_data = label.get_fdata()
    # Use the label's own affine (previously this read the image's affine).
    affine2 = label.affine if hasattr(label, 'affine') else np.eye(4)

    img_data, label_data = resize3d(img_data, label=label_data)

    img_list = crop(img_data, output_size)
    lab_list = crop(label_data, output_size)

    # NOTE(review): patches are saved with the affine of the original volume
    # although the data has been resized and cropped - confirm that nothing
    # downstream relies on the spatial metadata of the patches.
    nii_save(img_list, r'/home/chae/segmentation/dataset/ct_train_patch', f'train_{idx+1}', affine1)
    nii_save(lab_list, r'/home/chae/segmentation/dataset/ct_label_patch', f'label_{idx+1}', affine2)


CTtrain_1_0001.nii.gz saved successfully!
CTtrain_1_0002.nii.gz saved successfully!
CTtrain_1_0003.nii.gz saved successfully!
CTtrain_1_0004.nii.gz saved successfully!
CTtrain_1_0005.nii.gz saved successfully!
CTtrain_1_0006.nii.gz saved successfully!
CTtrain_1_0007.nii.gz saved successfully!
CTtrain_1_0008.nii.gz saved successfully!
CTtrain_1_0009.nii.gz saved successfully!
CTtrain_1_0010.nii.gz saved successfully!
CTtrain_1_0011.nii.gz saved successfully!
CTtrain_1_0012.nii.gz saved successfully!
CTtrain_1_0013.nii.gz saved successfully!
CTtrain_1_0014.nii.gz saved successfully!
CTtrain_1_0015.nii.gz saved successfully!
CTtrain_1_0016.nii.gz saved successfully!
CTtrain_1_0017.nii.gz saved successfully!
CTtrain_1_0018.nii.gz saved successfully!
CTtrain_1_0019.nii.gz saved successfully!
CTtrain_1_0020.nii.gz saved successfully!
CTtrain_1_0021.nii.gz saved successfully!
CTtrain_1_0022.nii.gz saved successfully!
CTtrain_1_0023.nii.gz saved successfully!
CTtrain_1_0024.nii.gz saved succes