<a href="https://colab.research.google.com/github/musicangora/MakeDataset_NAMIC/blob/main/MakeDatasets_NAMIC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## データセット入手先
[NAMIC: Brain Multimodality](https://www.insight-journal.org/midas/collection/view/190/1)

※t1w.nrrdとt2w.nrrdをダウンロード

## ディレクトリ構成
### pix2pix
```
T1T2_pix2pix/
    ├ train/
    │    └ T1-T2 image (T1 left, T2 right, concatenated side by side)
    └ test/
          └ T1-T2 image (T1 left, T2 right, concatenated side by side)
```

### CycleGAN
```
T1T2_cyclegan/
    ├ trainA/
    │    └ T1 image
    ├ testA/
    │    └ T1 image
    ├ trainB/
    │    └ T2 image
    └ testB/
          └ T2 image
```

In [1]:
# Extract the .zip into the current working directory; later cells read the
# volumes from /content/NAMIC (assumes Google Drive is mounted at /content/drive).
!unzip '/content/drive/My Drive/NAMIC.zip'

Archive:  /content/drive/My Drive/NAMIC.zip
   creating: NAMIC/
  inflating: NAMIC/PNL_3Tcaselist_NAMIC.xls  
   creating: NAMIC/t1/
  inflating: NAMIC/t1/desktop.ini    
   creating: NAMIC/t1/NC/
  inflating: NAMIC/t1/NC/01019-t1w.nrrd  
  inflating: NAMIC/t1/NC/01020-t1w.nrrd  
  inflating: NAMIC/t1/NC/01025-t1w.nrrd  
  inflating: NAMIC/t1/NC/01026-t1w.nrrd  
  inflating: NAMIC/t1/NC/01029-t1w.nrrd  
  inflating: NAMIC/t1/NC/01033-t1w.nrrd  
  inflating: NAMIC/t1/NC/01034-t1w.nrrd  
  inflating: NAMIC/t1/NC/01035-t1w.nrrd  
  inflating: NAMIC/t1/NC/01041-t1w.nrrd  
  inflating: NAMIC/t1/NC/01104-t1w.nrrd  
  inflating: NAMIC/t1/NC/desktop.ini  
   creating: NAMIC/t1/SZ/
  inflating: NAMIC/t1/SZ/01011-t1w.nrrd  
  inflating: NAMIC/t1/SZ/01015-t1w.nrrd  
  inflating: NAMIC/t1/SZ/01017-t1w.nrrd  
  inflating: NAMIC/t1/SZ/01018-t1w.nrrd  
  inflating: NAMIC/t1/SZ/01028-t1w.nrrd  
  inflating: NAMIC/t1/SZ/01039-t1w.nrrd  
  inflating: NAMIC/t1/SZ/01042-t1w.nrrd  
  inflating: NAMIC/t1/SZ

In [2]:
# SimpleITK is used below (imported as sitk) to read the .nrrd volumes.
!pip install simpleitk

Collecting simpleitk
  Downloading https://files.pythonhosted.org/packages/f3/cb/a15f4612af8e37f3627fc7fb2f91d07bb584968b0a47e3d5103d7014f93e/SimpleITK-2.0.1-cp36-cp36m-manylinux1_x86_64.whl (44.9MB)
     |████████████████████████████████| 44.9MB 98kB/s 
Installing collected packages: simpleitk
Successfully installed simpleitk-2.0.1


In [8]:
import os
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import shutil
import SimpleITK as sitk
from PIL import Image, ImageOps

class MakeDataset():
    """Build a pix2pix or CycleGAN dataset from two directories of NRRD volumes.

    Every ``*.nrrd`` volume is read with SimpleITK and the configured range of
    slices (along the first array axis) is extracted. Each 2-D slice is scaled
    so that its own maximum maps to the top of the output integer range, and the
    slices are saved as PNG files:

    * ``pix2pix``: A and B slices side by side (A left, B right) in
      ``train/`` and ``test/``. Files are paired by position in sorted order.
    * ``cyclegan``: A slices in ``trainA/``/``testA/`` and B slices in
      ``trainB/``/``testB/``. Each domain is split independently.

    The dataset directory is then zipped and, optionally, the archive is copied.

    Args:
        A_path: Directory containing the domain-A ``.nrrd`` files.
        B_path: Directory containing the domain-B ``.nrrd`` files.
        dataset_type: ``"pix2pix"`` or ``"cyclegan"``.
        img_res: Expected ``(height, width)`` of every slice. Slices are NOT
            resized; a volume with a different slice shape is rejected.
        dataset_slice: ``(start, end)`` slice-index range taken from each
            volume. ``end == -1`` or an ``end`` beyond the volume means
            "up to the last slice".
        is_16bit: Write 16-bit PNGs if True, 8-bit PNGs otherwise.
        train_ratio: Fraction of slices (in file order) assigned to train;
            the remainder goes to test.
        output_root: Directory in which ``T1T2_<dataset_type>/`` and its
            ``.zip`` archive are created.
        copy_to: Destination the ``.zip`` archive is copied to; ``None`` or an
            empty string skips the copy.

    Raises:
        ValueError: If ``dataset_type`` is unknown or ``train_ratio`` is not
            within ``[0, 1]``.
    """

    def __init__(self, A_path, B_path, dataset_type, img_res=(256, 256), dataset_slice=(0, -1), is_16bit=True,
                 train_ratio=0.8, output_root='/content', copy_to='/content/drive/My Drive/'):
        if dataset_type not in ("pix2pix", "cyclegan"):
            # ValueError subclasses Exception, so callers catching Exception still work.
            raise ValueError('dataset_type error! Please input pix2pix or cyclegan.')
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError('train_ratio must be between 0 and 1, got %r' % (train_ratio,))
        self.A_path = A_path
        self.B_path = B_path
        self.dataset_type = dataset_type
        self.dataset_slice = dataset_slice
        self.height = img_res[0]
        self.width = img_res[1]
        self.train_ratio = train_ratio
        self.output_root = output_root
        self.copy_to = copy_to
        self.format = np.uint16 if is_16bit else np.uint8
        # Largest value representable in the output PNG dtype.
        self.format_range = 65535 if is_16bit else 255

    def __dataset_dir(self):
        """Return the directory that receives this dataset's PNG files."""
        # Single source of truth so directory creation, saving and zipping agree.
        return os.path.join(self.output_root, 'T1T2_%s' % self.dataset_type)

    def __make_dir(self):
        """Create the split sub-directories under the dataset directory.

        NOTE: existing directories are reused (``exist_ok=True``) and files
        left over from a previous run are not removed.
        """
        if self.dataset_type == "pix2pix":
            subdirs = ('test', 'train')
        else:
            subdirs = ('testA', 'trainA', 'testB', 'trainB')
        for subdir in subdirs:
            os.makedirs(os.path.join(self.__dataset_dir(), subdir), exist_ok=True)

    def __load(self, path):
        """Read one volume and return the configured range of 2-D slices.

        Args:
            path: Path of a single ``.nrrd`` file.

        Returns:
            ndarray of shape ``(n_slices, height, width)``.

        Raises:
            ValueError: If the volume is not 3-D or its slices are not
                ``(height, width)``.
        """
        imgs = sitk.GetArrayFromImage(sitk.ReadImage(path))
        if imgs.ndim != 3 or tuple(imgs.shape[1:]) != (self.height, self.width):
            raise ValueError('%s: expected slices of shape (%d, %d), got volume shape %s'
                             % (path, self.height, self.width, imgs.shape))
        start, stop = self.dataset_slice
        if stop == -1 or stop > imgs.shape[0]:
            stop = imgs.shape[0]
        return imgs[start:stop]

    def __joint(self, A, B):
        """Concatenate paired A and B slices side by side for pix2pix.

        Each slice is normalized on its own before concatenation.

        Args:
            A: Stack of slices, shape ``(n, height, width)``.
            B: Stack of slices with the same ``n`` as ``A``.

        Returns:
            Float ndarray of shape ``(n, height, 2 * width)``; A fills the left
            half and B the right half.

        Raises:
            ValueError: If ``A`` and ``B`` hold different numbers of slices.
        """
        if len(A) != len(B):
            raise ValueError('cannot pair %d A slices with %d B slices' % (len(A), len(B)))
        pairs = [np.concatenate((self.__normalize(a), self.__normalize(b)), axis=1)
                 for a, b in zip(A, B)]
        if not pairs:
            return np.empty((0, self.height, self.width * 2))
        # One stack instead of repeated np.append avoids quadratic copying.
        return np.stack(pairs)

    def __save(self, data_list, dir_type):
        """Write each slice of ``data_list`` as ``<index>.png`` in ``dir_type``.

        Slices are cast to ``self.format`` (uint8 or uint16). pix2pix data is
        already normalized by ``__joint``; CycleGAN slices are normalized here.
        """
        out_dir = os.path.join(self.__dataset_dir(), dir_type)
        for i, data in enumerate(data_list):
            img = data if self.dataset_type == 'pix2pix' else self.__normalize(data)
            Image.fromarray(img.astype(self.format)).save(os.path.join(out_dir, '%s.png' % i))

    def __normalize(self, img):
        """Scale ``img`` linearly so that its maximum maps to ``format_range``.

        An image whose maximum is <= 0 (e.g. an empty slice) is returned as
        all zeros instead of the NaNs a division by zero would produce.
        Negative values are clipped to 0 so the later unsigned cast cannot
        wrap around.

        Returns:
            float64 ndarray with the same shape as ``img``.
        """
        img = np.asarray(img, dtype=np.float64)
        peak = img.max() if img.size else 0.0
        if peak <= 0:
            return np.zeros(img.shape, dtype=np.float64)
        return np.clip(img / peak, 0.0, 1.0) * self.format_range

    def validate(self):
        """Placeholder; currently does nothing."""
        pass

    def hist(self, data):
        """Show a 100-bin histogram of the pixel values in ``data``.

        Args:
            data: One image or a stack of images (flattened before plotting).

        Returns:
            0 (the plot is displayed as a side effect).
        """
        plt.hist(data.flatten(), bins=100)
        plt.title('histogram')
        plt.show()
        return 0

    def pack(self):
        """Create the dataset: load, split, save as PNG, zip and optionally copy.

        NOTE: the train/test split is taken over the concatenated slices in
        file order, so slices of the volume at the split boundary can land in
        both splits.

        Raises:
            FileNotFoundError: If ``A_path`` or ``B_path`` has no ``.nrrd`` file.
            ValueError: In pix2pix mode, if the two directories hold different
                numbers of files (positional pairing would be wrong).
        """
        A_path_list = sorted(glob(os.path.join(self.A_path, '*.nrrd')))
        B_path_list = sorted(glob(os.path.join(self.B_path, '*.nrrd')))
        if not A_path_list or not B_path_list:
            raise FileNotFoundError('no .nrrd files found in %s and/or %s' % (self.A_path, self.B_path))

        if self.dataset_type == 'pix2pix':
            if len(A_path_list) != len(B_path_list):
                # zip() would silently drop files and mis-pair the rest.
                raise ValueError('pix2pix needs the same number of files in A (%d) and B (%d)'
                                 % (len(A_path_list), len(B_path_list)))
            imgs = np.concatenate([self.__joint(self.__load(a_file), self.__load(b_file))
                                   for a_file, b_file in zip(A_path_list, B_path_list)])
        else:
            # CycleGAN data is unpaired: every file of each domain is used.
            imgsA = np.concatenate([self.__load(p) for p in A_path_list])
            imgsB = np.concatenate([self.__load(p) for p in B_path_list])

        self.__make_dir()
        if self.dataset_type == 'pix2pix':
            num = int(len(imgs) * self.train_ratio)
            self.__save(imgs[:num], 'train')
            self.__save(imgs[num:], 'test')
        else:
            # Split each domain by its own slice count.
            numA = int(len(imgsA) * self.train_ratio)
            numB = int(len(imgsB) * self.train_ratio)
            self.__save(imgsA[:numA], 'trainA')
            self.__save(imgsB[:numB], 'trainB')
            self.__save(imgsA[numA:], 'testA')
            self.__save(imgsB[numB:], 'testB')

        # Zip the dataset directory into <dataset dir>.zip.
        path = self.__dataset_dir()
        shutil.make_archive(path, 'zip', root_dir=path)
        # Copy the archive (e.g. to Google Drive) unless disabled.
        if self.copy_to:
            shutil.copy2(path + '.zip', self.copy_to)

        print("End!")


In [9]:
# Build an 8-bit CycleGAN dataset from the NC subjects: A = t1 volumes,
# B = t2 volumes, keeping slice indices 100..119 of every volume.
dataset = MakeDataset(
    A_path='/content/NAMIC/t1/NC',
    B_path='/content/NAMIC/t2/NC',
    dataset_type='cyclegan',
    dataset_slice=(100, 120),
    is_16bit=False,
)
dataset.pack()

End!


In [None]:
# Remove the generated dataset directories (paths relative to the working
# directory). rm prints an error for a directory that was never created.
!rm -r T1T2_cyclegan
!rm -r T1T2_pix2pix

rm: cannot remove 'T1T2_pix2pix': No such file or directory
