In [17]:
import os
from glob import glob

import pandas as pd
import numpy as np

import cv2
import pydicom                    #   used to load DICOM image data
import nibabel as nib             #   used to load the segmentation (NIfTI) dataset
import matplotlib.pyplot as plt
import itertools
plt.style.use('dark_background')

In [18]:
# Kaggle input layout for the RSNA 2022 cervical-spine fracture competition.
data_folder = '../input/rsna-2022-cervical-spine-fracture-detection/'
nii_dir = os.path.join(data_folder, 'segmentations/')
# glob() returns paths in arbitrary, filesystem-dependent order; sort them so
# that "sample (i)" in the plots refers to the same file on every run.
nii_paths = sorted(glob(os.path.join(nii_dir, '*.nii')))

In [19]:
class DrawMaskSample:
    """Load one NIfTI segmentation volume and draw its outline as a 3D scatter.

    For every (x, z) row of the volume only the first and last non-zero voxel
    along the y axis is kept, giving a light-weight outline instead of
    plotting every labelled voxel.

    Parameters
    ----------
    nii_paths : sequence of str
        Paths to ``.nii`` segmentation files.
    i : int
        Index into ``nii_paths`` of the sample to draw (shown 1-based in output).
    ax : matplotlib Axes
        Axes created with ``projection='3d'``; drawn on in place.
    """

    def __init__(self, nii_paths, i, ax):
        self.i = i
        self.ax = ax
        # Assumed to be a 3-D volume: draw_sample_3d unpacks exactly three dims.
        self.nii_sample = nib.load(nii_paths[self.i]).get_fdata()
        print(f'sample({i+1}) shape: {self.nii_sample.shape}')
        self.get_xyz()
        self.draw_sample_3d()

    def get_xyz(self):
        """Collect the y-extremes of the mask for every (x, z) row.

        Populates ``self.xyz_li`` with one list per row that contains at least
        one non-zero voxel. Each list holds ``(x, y, z)`` tuples for the
        smallest and largest non-zero y index (a single tuple when they
        coincide). Progress is printed every 100 slices and at the last one.
        """
        self.xyz_li = []
        max_cnt = self.nii_sample.shape[-1]
        # transpose(2, 0, 1) iterates over z slices, each an (x, y) image.
        for iter_z, iter_img in enumerate(self.nii_sample.transpose(2, 0, 1)):
            for iter_x, iter_row in enumerate(iter_img):
                # flatnonzero returns ascending indices, so [0]/[-1] are min/max.
                nonzero_y = np.flatnonzero(iter_row)
                if nonzero_y.size:
                    y_min, y_max = int(nonzero_y[0]), int(nonzero_y[-1])
                    # y == 0 is a valid voxel position and must not be filtered out.
                    ys = [y_min] if y_min == y_max else [y_min, y_max]
                    self.xyz_li.append([(iter_x, y, iter_z) for y in ys])
            cnt = iter_z + 1
            if cnt % 100 == 0 or cnt == max_cnt:
                print(f'iteration: ({cnt} / {max_cnt})')

    def draw_sample_3d(self):
        """Scatter the points gathered by ``get_xyz`` and frame the axes to the volume."""
        points = [xyz for row in self.xyz_li for xyz in row]
        # reshape(-1, 3) keeps an empty mask as a (0, 3) array; a bare
        # np.array([]) is 1-D and the column slices below would raise.
        xyz_matrix = np.array(points, dtype=int).reshape(-1, 3)
        X, Y, Z = xyz_matrix[:, 0], xyz_matrix[:, 1], xyz_matrix[:, 2]

        self.ax.scatter(X, Y, Z, s=1, alpha=0.04, color='beige')
        xlim, ylim, zlim = self.nii_sample.shape
        self.ax.set_xlim(0, xlim)
        self.ax.set_ylim(0, ylim)
        self.ax.set_zlim(0, zlim)
        self.ax.set_xlabel('x')
        self.ax.set_ylabel('y')
        self.ax.set_zlabel('z')
        self.ax.set_title(f'sample - ({self.i+1})')

In [20]:
# Draw up to N_ROWS * N_COLS segmentation masks, one per 3D subplot.
# Capped by len(nii_paths) so the cell does not raise IndexError when fewer
# masks are available.
N_ROWS, N_COLS = 3, 3
fig = plt.figure(figsize=(16, 16))
for i in range(min(N_ROWS * N_COLS, len(nii_paths))):
    # add_subplot(nrows, ncols, index) instead of encoding "33{i}" as an int.
    ax = fig.add_subplot(N_ROWS, N_COLS, i + 1, projection='3d')
    DrawMaskSample(nii_paths, i, ax)

plt.tight_layout()
plt.show()