# MONAI : Tutorial : 2d_segmentation
## UNet training / evaluation

参照URL:
- https://github.com/Project-MONAI/tutorials/tree/main/2d_segmentation/torch

## 0. 準備

In [1]:
# パッケージのインポート
import os
import sys
from glob import glob

import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline

from monai.data import create_test_image_2d, list_data_collate, decollate_batch, DataLoader, Dataset
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    EnsureChannelFirstd,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
)
from monai.visualize import plot_2d_or_3d_image
from monai.utils import set_determinism
from monai.config import print_config

print_config()

MONAI version: 1.2.dev2302
Numpy version: 1.24.1
Pytorch version: 1.12.0+cu113
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 708e1a1cf4a1d5516eaf65b8a0bee8887cdee494
MONAI __file__: /home/aska/anaconda3/envs/monai/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 9.4.0
Tensorboard version: 2.11.2
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: NOT INSTALLED or UNKNOWN VERSION.
tqdm version: 4.64.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.4
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For de

In [2]:
# 乱数シードの設定
set_determinism(seed=2023)

In [3]:
# データフォルダ
temp_dir = os.path.realpath('./data/temp_test_data')
os.makedirs(temp_dir, exist_ok=True)
print(root_dir)

NameError: name 'root_dir' is not defined

## 1. 学習
### 1.1 データ作成

In [None]:
# データ作成
for i in range(40):
    im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
    Image.fromarray((im * 255).astype('uint8')).save(os.path.join(temp_dir, f'img{i:d}.png'))
    Image.fromarray((seg * 255).astype('uint8')).save(os.path.join(temp_dir, f'seg{i:d}.png'))

In [None]:
# 学習 / 評価 = 20 / 20
images = sorted(glob(os.path.join(temp_dir, 'img*.png')))
segs = sorted(glob(os.path.join(temp_dir, 'seg*.png')))
train_files = [{'img': img, 'seg': seg} for img, seg in zip(images[:20], segs[:20])]
val_files = [{'img': img, 'seg': seg} for img, seg in zip(images[20:], segs[20:])]

In [None]:
# オリジナル画像
img = Image.open(train_files[0]['img'])
plt.imshow(img)
plt.axis('off')
plt.show()

In [None]:
# ラベル画像
seg = Image.open(train_files[0]['seg'])
plt.imshow(seg)
plt.axis('off')
plt.show()

## 1.2 データセット, データローダ

In [None]:
train_transforms = Compose(
    [
        LoadImaged(keys=['img', 'seg']),
        EnsureChannelFirstd(keys=['img', 'seg']),
        ScaleIntensityd(keys=['img', 'seg']),
        RandCropByPosNegLabeld(
            keys=['img', 'seg'], label_key='seg', spatial_size=[96, 96], pos=1, neg=1, num_samples=4
        ),
        RandRotate90d(keys=['img', 'seg'], prob=0.5, spatial_axes=[0, 1]),
    ]
)

val_transforms = Compose(
    [
        LoadImaged(keys=['img', 'seg']),
        EnsureChannelFirstd(keys=['img', 'seg']),
        ScaleIntensityd(keys=['img', 'seg']),
    ]
)

In [None]:
train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    train_ds, 
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)

val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)

## 1.3 モデル構築

In [None]:
dice_metric = DiceMetric(include_background=True, reduction='mean', get_not_nans=False)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# モデル
model = monai.networks.nets.UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

print(model)

In [None]:
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

## 1.4 モデル学習

In [None]:
max_epochs = 10
val_interval = 2

model_dir = './models'

In [None]:
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []

writer = SummaryWriter()
epoch_len = len(train_ds) // train_loader.batch_size

# epochループ
for epoch in range(max_epochs):
    print('-' * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    # mini batchループ
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data['img'].to(device), batch_data['seg'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar('train_loss', loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    
    # 評価
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = val_data['img'].to(device), val_data['seg'].to(device)
                roi_size = (96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                dice_metric(y_pred=val_outputs, y=val_labels)
            
            metric = dice_metric.aggregate().item()
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(model_dir, 'best_metric_model_segmentation2d_dict.pth'))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f} "
                f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )
            writer.add_scalar('val_mean_dice', metric, epoch + 1)
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag='image')
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag='label')
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag='output')
            
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()