다음 MONAI 튜토리얼 코드를 리뷰한다:
https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_array.py

In [None]:
import torch
from torch.utils.data import DataLoader


import monai
## decollate_batch : 배치 텐서를 리스트의 텐서로 변환
from monai.data import ImageDataset, create_test_image_3d, decollate_batch
from monai.transforms import Activations, AddChannel, AsDiscrete, Compose, RandRotate90
from monai.transforms import RandSpatialCrop, ScaleIntensity, EnsureType
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.visualize import plot_2d_or_3d_image

# tensorboard가 읽을 수 있도록 loss, metric, out image, out_seg 를 기록 
from torch.utils.tensorboard import SummaryWriter

import nibabel as nib

import logging
import os
import sys
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

**ImageDataset**
* image, seg 파일 경로 리스트를 넘겨주면 데이터셋을 만들어 주는 API이다.
* 여기서는 segmentation 용도로 사용하지만, `seg_files` 대신 `labels`를 넘기면 classification에도 사용할 수 있다.

In [None]:
# Directory the synthetic NIfTI images/masks are written to and later globbed from.
tempdir = './dataset'

In [None]:
monai.config.print_config()
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Generate 40 random synthetic image/mask pairs inside `tempdir`.
# Create the output directory first so the nib.save calls below do not fail
# on a fresh checkout where `tempdir` does not exist yet.
os.makedirs(tempdir, exist_ok=True)
print(f"generating synthetic data to {tempdir} (this may take a while)")
for i in range(40):
    # Original note: both results are np.ndarray with the same shape
    # (128, 128, 128), i.e. the mask has the same dimensions as the 3D image.
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

    # np.eye(4) is the 4x4 identity matrix, passed as the NIfTI affine.
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

In [None]:
# Route log records of level INFO and above to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

### image, seg 파일 리스트 Loading

In [None]:
# Collect the generated NIfTI files (40 of each kind). Both lists are sorted
# with the same rule, so images[k] and segs[k] carry the same file index.
image_pattern = os.path.join(tempdir, "img*.nii.gz")
seg_pattern = os.path.join(tempdir, "seg*.nii.gz")
images = sorted(glob(image_pattern))
segs = sorted(glob(seg_pattern))
len(images), len(segs)

In [None]:
# Value range of the last generated image/mask pair.
# TODO: check whether real MRI volumes have a comparable range.
im.max(), im.min(), seg.max(), seg.min()

### Transform 정의

In [None]:
# define transforms for image and segmentation
# Training transforms for the image: intensity scaling, channel dimension,
# then augmentation (fixed-size random crop and random 90-degree rotation).
image_train_steps = [
    ScaleIntensity(),
    AddChannel(),
    RandSpatialCrop((96, 96, 96), random_size=False),
    RandRotate90(prob=0.5, spatial_axes=(0, 2)),
    EnsureType(),
]
train_imtrans = Compose(image_train_steps)

# Training transforms for the segmentation mask. No intensity scaling here
# (original note: the mask only holds 0 or 1). The augmentation steps mirror
# the image pipeline; per the MONAI ImageDataset documentation, the dataset
# gives `transform` and `seg_transform` the same random seed for each item, so
# the image and its mask receive the same crop and rotation.
seg_train_steps = [
    AddChannel(),
    RandSpatialCrop((96, 96, 96), random_size=False),
    RandRotate90(prob=0.5, spatial_axes=(0, 2)),
    EnsureType(),
]
train_segtrans = Compose(seg_train_steps)

# Validation transforms: no random augmentation, full-size volumes.
val_imtrans = Compose([ScaleIntensity(), AddChannel(), EnsureType()])
val_segtrans = Compose([AddChannel(), EnsureType()])

# pin_memory: if True, the data loader copies tensors into CUDA pinned memory
# before returning them, so it is only enabled when CUDA is available.
use_pinned_memory = torch.cuda.is_available()

# Training set: the first 20 image/mask pairs.
train_ds = ImageDataset(
    image_files=images[:20],
    seg_files=segs[:20],
    transform=train_imtrans,
    seg_transform=train_segtrans,
)
# num_workers: number of worker processes used for data loading.
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=use_pinned_memory,
)

# Validation set: the last 20 image/mask pairs.
val_ds = ImageDataset(
    image_files=images[-20:],
    seg_files=segs[-20:],
    transform=val_imtrans,
    seg_transform=val_segtrans,
)
val_loader = DataLoader(
    dataset=val_ds,
    batch_size=2,
    num_workers=3,
    pin_memory=use_pinned_memory,
)

ImageDataset은 파일 경로 리스트로부터 데이터를 로딩하는 데이터셋 API이다.

train_ds는 데이터셋 객체이지만 인덱스로 접근할 수 있으며, `train_ds[i]`는 `(image, seg)` 형태의 텐서 쌍 `(tensor([[...]]), tensor([[...]]))`을 반환한다.

### 데이터셋 확인
그려 본다.

In [None]:
# Disabled preview cell: plots one slice of an image/mask pair per training batch.
# plt.subplots(1, 2, figsize=(8, 8))
# for each in train_loader:
#     img = each[0].numpy()
#     seg = each[1].numpy()
#     print(type(img), img.shape)    # <class 'numpy.ndarray'> (4, 1, 96, 96, 96) batch, C, W, H, D
#     print(type(seg), seg.shape)    # for segmentation the ground truth has the same shape as the image
#     img = img[0, 0, 0, :, :]
#     seg = seg[0, 0, 0, :, :]
#     print(img.shape)
# #     plt.imshow(img, cmap="gray", vmin=0, vmax=255)
# #     plt.imshow(img, cmap="gray")
#     plt.subplot(1, 2, 1)
#     plt.xlabel('img')
#     plt.imshow(img)
#     plt.subplot(1, 2, 2)
#     plt.xlabel('seg')
#     plt.imshow(seg)
# #     raise AssertionError("!!!")
# plt.tight_layout()
# plt.show()

### post-process, metrics, model 정의

In [None]:
# 0~1로 맞추고 0.5로 threshold
# AsDiscrete는 onehot도 가능
# 하나씩 빼보면서 더 해볼필요
# Post-processing for model outputs: sigmoid maps logits to 0~1, then values
# are binarised with a 0.5 threshold.
# Original notes: AsDiscrete can also do one-hot encoding; worth experimenting
# by removing the steps one at a time.
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
# Dice metric accumulated over the validation batches.
# TODO: the meaning of the individual DiceMetric parameters still needs checking.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

# os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# Prefer GPU 6, but fall back to GPU 0 when the machine exposes fewer devices,
# and to the CPU when CUDA is not available at all. Requesting a non-existent
# "cuda:6" would otherwise fail as soon as the model is moved to the device.
preferred_gpu_index = 6
if torch.cuda.is_available():
    if preferred_gpu_index >= torch.cuda.device_count():
        preferred_gpu_index = 0
    device = torch.device(f"cuda:{preferred_gpu_index}")
else:
    device = torch.device("cpu")

# 3D U-Net with one input channel and one output channel.
model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
# sigmoid=True: the loss applies the sigmoid itself, so the model outputs raw logits.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

```python
## Quick check of Activations(sigmoid=True) and AsDiscrete(threshold=0.5)

test = [-1, -0.4, 0.2, 0.4, 0.8, 1.5]
post_trans_test1 = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
# Fixed: the pipeline defined above is `post_trans_test1`; `post_trans_test` does not exist.
print(post_trans_test1(test))
# [tensor(0.), tensor(0.), tensor(1.), tensor(1.), tensor(1.), tensor(1.)]

post_trans_test2 = Compose([EnsureType(), Activations(sigmoid=True)])
print(post_trans_test2(test))
# [tensor(0), tensor(0.4013), tensor(0.5498), tensor(0.5987), tensor(0.6900), tensor(0.8176)]
# NOTE(review): the first entry is tensor(0) rather than sigmoid(-1) ~= 0.2689,
# presumably because the input -1 is an int and the result is cast back to an
# integer dtype; use -1.0 to see the real sigmoid value -- confirm.
```

In [None]:
# Show the network structure as the notebook cell output.
model

### model 학습

In [None]:
## train set 사이즈와 안맞춰주면 정확도 확 감소

# First training run: validation feeds the full-size volumes straight into the
# model. Original note: accuracy drops sharply when the validation input size
# does not match the size used for training.
model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

epochs = 5
epoch_loss_values = list()
val_interval = 2
# Number of batches per epoch. len(train_loader) also counts a final partial
# batch, unlike len(train_ds) // batch_size, and does not change between steps.
epoch_len = len(train_loader)

for epoch in range(epochs):
    print('-'*20)
    # Report 1-based epoch numbers, consistent with the other messages below.
    print(f'epoch: {epoch + 1}/{epochs} ')

    epoch_loss = 0.0
    model.train()
    step = 0
    for batch in train_loader:
        step += 1
        inputs, labels = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()

        optimizer.step()

        # Accumulate a plain Python float: adding the tensor itself would keep
        # autograd history and device memory alive for the whole epoch.
        epoch_loss += loss.item()
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:  # validate every `val_interval` epochs
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                val_outputs = model(val_images)
                # Split the batch into per-sample tensors and binarise each one.
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_outputs, y=val_labels)
                # Metric aggregated over the batches seen so far.
                print(dice_metric.aggregate())
            # Metric aggregated over the whole validation set.
            print(dice_metric.aggregate())
            dice_metric.reset()
        

-----------------

In [None]:
# Second training run: adds TensorBoard logging, sliding-window validation and
# saving of the best-scoring model weights.
model = monai.networks.nets.UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

epochs = 5
epoch_loss_values = list()
val_interval = 2
metric_values = list()
best_metric = -1
best_epoch = -1

# Number of batches per epoch (also counts a final partial batch); used for the
# progress print and as the stride of the TensorBoard global step.
epoch_len = len(train_loader)

# Sliding-window settings for validation; they do not change between batches.
roi_size = (96, 96, 96)
sw_batch_size = 4

# Make sure the checkpoint directory exists before the first torch.save.
best_model_path = "./models/best_metric_model_segmentation3d_array.pth"
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

writer = SummaryWriter()

# try/finally so the writer is closed even if training is interrupted.
try:
    for epoch in range(epochs):
        print('-'*20)
        print(f'epoch: {epoch + 1}/{epochs} ')

        epoch_loss = 0.0
        model.train()
        step = 0
        for batch in train_loader:
            step += 1
            inputs, labels = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()

            optimizer.step()

            # Accumulate a plain Python float: adding the tensor itself would
            # keep autograd history and device memory alive for the whole epoch.
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:  # validate every `val_interval` epochs
            model.eval()
            with torch.no_grad():
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)

                    # The validation volumes are larger than the crops used
                    # for training, so inference runs window by window with
                    # roi_size. Original note: the output keeps the size of
                    # the original validation image and the prediction is
                    # more accurate this way.
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)

                    # decollate_batch is used together with the post transform.
                    # Original note: applying post_trans while iterating the
                    # batched tensor directly gave the same metric.
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    dice_metric(y_pred=val_outputs, y=val_labels)

                # aggregate the final mean dice result
                metric = dice_metric.aggregate().item()
                dice_metric.reset()

                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_epoch = epoch + 1
                    torch.save(model.state_dict(), best_model_path)
                    print("saved new best metric model")

                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_epoch}")
finally:
    writer.close()
                