다음을 리뷰 :
https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_training_dict.py

In [1]:
import os
import sys
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import nibabel as nib    # nifti 포맷 파일 생성때만 이용

import torch
from torch.utils.data import DataLoader
import monai
## decollate_batch : 배치 텐서를 리스트의 텐서로 변환
from monai.data import create_test_image_3d, decollate_batch
# from monai.transforms import RandSpatialCrop, ScaleIntensity, EnsureType
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.visualize import plot_2d_or_3d_image
# tensorboard가 읽을 수 있도록 loss, metric, out image, out_seg 를 기록 
from torch.utils.tensorboard import SummaryWriter

## 삭제
from monai.data import ImageDataset
from monai.transforms import AddChannel

# 이건 image, seg 파일 배열을 넣어주면 데이터셋을 만들어주는 API인듯하다. segmentation 한정 사용가능
# train_ds = ImageDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)

# AddChannel : 맨 앞단 1 차원 삽입 (ex) torch.Size([6]) -> torch.Size([1, 6]) 


## 새롭게 추가
from monai.data import Dataset   # dict에선 ImageDataset대신 이용
from monai.data import list_data_collate
from monai.transforms import (
    Activations,
    AsChannelFirstd,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,  # randomly crop patch samples from big image based on pos / neg ratio.
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType,
)


import logging

```python
# AddChannel demo: run the same 1-D sample through EnsureType alone and through
# EnsureType followed by AddChannel, then print each result's shape and values
# so the effect of AddChannel on the shape can be compared.
test = np.array([-1, -0.4, 0.2, 0.4, 0.8, 1.5])

post_trans_test1 = Compose([EnsureType(), AddChannel()])
with_channel = post_trans_test1(test)
print(with_channel.shape, with_channel)

post_trans_test2 = Compose([EnsureType()])
without_channel = post_trans_test2(test)
print(without_channel.shape, without_channel)
```

In [2]:
# Configure the root logger: INFO level and above, written to stdout.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

### making random 3D segmentation dataset

In [3]:
tempdir = './dataset'
monai.config.print_config()

# Create the output directory if it is missing; saving a volume into a
# directory that does not exist would fail on the very first file.
os.makedirs(tempdir, exist_ok=True)

# Write 40 random image / mask pairs into the directory.
print(f"generating synthetic data to {tempdir} (this may take a while)")
for i in range(40):
    # Synthetic volume plus its matching segmentation mask.
    # NOTE(review): channel_dim=-1 presumably appends a trailing channel axis
    # to both arrays — confirm against the create_test_image_3d documentation.
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)

    # np.eye(4) is a 4x4 identity matrix, passed as the NIfTI affine.
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

MONAI version: 0.9.dev2152
Numpy version: 1.21.2
Pytorch version: 1.10.0a0+0aef44c
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: c5bd8aff8ba461d7b349eb92427d452481a7eb72

Optional dependencies:
Pytorch Ignite version: 0.4.6
Nibabel version: 3.2.1
scikit-image version: 0.18.3
Pillow version: 8.4.0
Tensorboard version: 2.6.0
gdown version: 4.2.0
TorchVision version: 0.11.0a0
tqdm version: 4.62.3
lmdb version: 1.2.1
psutil version: 5.8.0
pandas version: 1.3.4
einops version: 0.3.2
transformers version: 4.12.5
mlflow version: 1.21.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies

generating synthetic data to ./dataset (this may take a while)


### image, seg 파일 dict Loading

* 참고: list form에선, 이미지리스트를 그대로 사용하여 ds를 만듦
```python 
# list form
train_ds = ImageDataset(images[:20], segs[:20], transform=train_imtrans, seg_transform=train_segtrans)
val_ds = ImageDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
```


In [None]:
# Collect the NIfTI image and segmentation paths. Both lists are sorted the
# same (lexicographic) way and the files share one numbering scheme, so
# position k in `images` and position k in `segs` refer to the same sample.
img_pattern = os.path.join(tempdir, "img*.nii.gz")
seg_pattern = os.path.join(tempdir, "seg*.nii.gz")
images = sorted(glob(img_pattern))
segs = sorted(glob(seg_pattern))

# Dict-style samples: the first 20 pairs are used for training ...
train_files = [
    {"img": image_path, "seg": seg_path}
    for image_path, seg_path in zip(images[:20], segs[:20])
]
# ... and the last 20 pairs for validation.
val_files = [
    {"img": image_path, "seg": seg_path}
    for image_path, seg_path in zip(images[-20:], segs[-20:])
]

### Transform 정의
* 참고: list form에선, img, seg에 해당하는 각각의 transform을 따로 정의하여 사용함
```python
train_imtrans = Compose(
    [
        ScaleIntensity(),
        AddChannel(),
        # aug
        RandSpatialCrop((96, 96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 2)),
        EnsureType(),
    ]
)
train_segtrans = Compose(
    [
        # 스케일링 필요없나 봄 (1또는 0이므로)
        AddChannel(),
        # aug (img와 같은 aug를 해주는가..? -> 맞음.. 왜인진.. 모름)
        RandSpatialCrop((96, 96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 2)),
        EnsureType(),
    ]
)
```

In [None]:

def _shared_preprocessing():
    """Return fresh instances of the deterministic steps used by both pipelines.

    The steps load both files of a sample, move the last axis to the front for
    both entries, and rescale the intensity of the image only (not the mask).
    """
    return [
        # The list-style variant relied on ImageDataset for loading; in dict
        # form the loading step has to be part of the pipeline.
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
    ]


# Training pipeline: shared preprocessing, then augmentation, then type conversion.
train_transforms = Compose(
    _shared_preprocessing()
    + [
        # Crop 96^3 patches centred on foreground or background voxels; a
        # foreground centre is picked with ratio pos / (pos + neg).
        # num_samples=4 yields four patches per input image.
        RandCropByPosNegLabeld(
            keys=["img", "seg"],
            label_key="seg",
            spatial_size=[96, 96, 96],
            pos=1,
            neg=1,
            num_samples=4,
        ),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        EnsureTyped(keys=["img", "seg"]),
    ]
)

# Validation pipeline: shared preprocessing and type conversion only, no augmentation.
val_transforms = Compose(
    _shared_preprocessing()
    + [
        EnsureTyped(keys=["img", "seg"]),
    ]
)

# Sanity-check the training pipeline on a single batch before training.
check_ds = monai.data.Dataset(transform=train_transforms, data=train_files)

# With batch_size=2 and four crops per image from RandCropByPosNegLabeld, each
# batch is expected to hold 2 x 4 = 8 patches: the crop step returns a list of
# samples per image, and list_data_collate merges those lists into one batch.
check_loader = DataLoader(
    check_ds,
    collate_fn=list_data_collate,
    num_workers=4,
    batch_size=2,
)

check_data = monai.utils.misc.first(check_loader)
img_batch, seg_batch = check_data["img"], check_data["seg"]
print(img_batch.shape, seg_batch.shape)

# torch.Size([8, 1, 96, 96, 96]) torch.Size([8, 1, 96, 96, 96])

In [None]:
# Training data: random-crop pipeline, sample order reshuffled every epoch.
# Host memory is pinned only when CUDA is available.
train_ds = monai.data.Dataset(transform=train_transforms, data=train_files)
train_loader = DataLoader(
    train_ds,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
    num_workers=4,
    shuffle=True,
    batch_size=2,
)

# Validation data: one sample per batch, run through the validation pipeline.
val_ds = monai.data.Dataset(transform=val_transforms, data=val_files)
val_loader = DataLoader(
    val_ds,
    collate_fn=list_data_collate,
    num_workers=4,
    batch_size=1,
)

### post-proc, metric, model, loss 정의
* 아예 차이 없음

### training 정의
* data 분리할 때 외엔 거의 차이 없음

In [None]:
# Post-processing for each raw network output: type conversion, sigmoid
# activation, then binarisation at a 0.5 threshold.
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

# GPU 6 is the preferred device. A bare "cuda:6" breaks on CUDA hosts with
# fewer than seven GPUs, so fall back to the first GPU when that ordinal does
# not exist, and to the CPU when CUDA is unavailable.
preferred_gpu = 6
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif preferred_gpu < torch.cuda.device_count():
    device = torch.device(f"cuda:{preferred_gpu}")
else:
    device = torch.device("cuda:0")

# 3D UNet with a single input channel and a single output channel.
model = monai.networks.nets.UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)
# sigmoid=True: the loss applies the sigmoid itself, so the model is fed in as
# raw outputs during training.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

epochs = 5
epoch_loss_values = []   # average training loss per epoch
val_interval = 1         # run validation every `val_interval` epochs
metric_values = []       # validation metric per validation run
best_metric = -1
best_epoch = -1
# Records loss, metric and image output in a form TensorBoard can read.
writer = SummaryWriter()

# Checkpoint location; create its directory up front so that saving the best
# model cannot fail on a missing "./models" folder.
checkpoint_path = "./models/best_metric_model_segmentation3d_dict.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Batches per epoch. Taken from the loader itself (and computed once) so the
# global TensorBoard step below never overlaps between epochs, even when the
# dataset size is not a multiple of the batch size.
epoch_len = len(train_loader)

# Sliding-window inference settings used during validation.
roi_size = (96, 96, 96)
sw_batch_size = 4

try:
    for epoch in range(epochs):
        print('-'*20)
        print(f'epoch: {epoch + 1}/{epochs} ')

        # ---- training ----
        model.train()
        epoch_loss = 0
        step = 0
        for batch in train_loader:
            step += 1
            # Dict-style batch: image and label are looked up by key.
            inputs, labels = batch["img"].to(device), batch["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            epoch_loss += step_loss
            print(f"{step}/{epoch_len}, train_loss: {step_loss:.4f}")
            writer.add_scalar("train_loss", step_loss, epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        # ---- validation ----
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                # Keep the last validation sample around for the TensorBoard plots.
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    # Split the batch into per-sample tensors and post-process each one.
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    # Accumulate the Dice score for this batch.
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # Reduce the accumulated scores to one value, then clear the buffer.
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                metric_values.append(metric)

                if metric > best_metric:
                    best_metric = metric
                    best_epoch = epoch + 1
                    torch.save(model.state_dict(), checkpoint_path)
                    print("saved new best metric model")

                print(
                    f"current epoch: {epoch + 1} current mean dice: {metric:.4f} "
                    f"best mean dice: {best_metric:.4f} at epoch {best_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_epoch}")
finally:
    # Always close the writer, even if training is interrupted or raises.
    writer.close()