In [None]:
pip install numpy==1.26.0

In [1]:
import numpy as np

# Show which NumPy version the kernel actually imported.
print(f"{np.__version__}")

1.26.0


In [2]:
import logging
import os
import sys
import tempfile
from glob import glob

import torch
from PIL import Image
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ArrayDataset, create_test_image_2d, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImage,
    RandRotate90,
    RandSpatialCrop,
    ScaleIntensity,
)
from monai.visualize import plot_2d_or_3d_image

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [32]:
import torch

# Use the Metal (MPS) backend when this PyTorch build reports it as available;
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Trailing expression so the notebook displays the selected device.
device

device(type='mps')

In [33]:
import torch

# Start from the CPU and switch to MPS only when the backend is available.
device = torch.device("cpu")
if torch.backends.mps.is_available():
    device = torch.device("mps")

In [34]:
import torch

# Smoke test: allocate a one-element tensor directly on the MPS device and
# print it, or report that the backend is missing.
if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)

tensor([1.], device='mps:0')


In [42]:
import glob

def main(tempdir, max_epochs=10):
    """Train and validate a 2D MONAI UNet on PNG image/mask pairs.

    The function expects ``tempdir`` to contain two sub-directories, ``image``
    and ``mask``, each holding ``*.png`` files. Images and masks are paired
    purely by their position in the sorted file lists, so the two directories
    are assumed to use matching file names (verify against the dataset).

    The first 20 pairs are used for training and the last 20 for validation.
    The network is built with ``in_channels=3`` and ``out_channels=1``, so
    images must load as 3-channel and masks as 1-channel tensors.

    Side effects:
        * Prints progress to stdout.
        * Writes TensorBoard events through a default ``SummaryWriter()``.
        * Saves the best model weights to
          ``best_metric_model_segmentation2d_array.pth`` in the current
          working directory whenever the validation Dice improves.

    Args:
        tempdir: Root directory containing the ``image`` and ``mask`` folders.
        max_epochs: Number of training epochs to run (defaults to 10).

    Raises:
        ValueError: If the number of images and masks differ, or if no
            image/mask pairs are found at all.
    """
    # Define paths to image and mask directories
    image_dir = os.path.join(tempdir, "image")
    mask_dir = os.path.join(tempdir, "mask")

    # Read image and mask file paths (sorted so both lists share one ordering)
    images = sorted(glob.glob(os.path.join(image_dir, "*.png")))
    masks = sorted(glob.glob(os.path.join(mask_dir, "*.png")))

    if len(images) != len(masks):
        raise ValueError("The number of images and masks must be the same.")
    if not images:
        # Fail early with a clear message instead of an opaque unpacking error
        # when the sanity-check loader below yields nothing.
        raise ValueError(f"No PNG image/mask pairs found under {tempdir!r}.")

    # Pick the compute device once, up front, so the data loaders can be
    # configured consistently with it.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    # Pinned host memory is a CUDA feature; requesting it on MPS/CPU only
    # produces a warning, so enable it only for CUDA.
    pin_memory = device.type == "cuda"

    # define transforms for image and segmentation
    # NOTE(review): both pipelines contain random transforms; ArrayDataset is
    # documented to apply them with a shared seed per item so image and mask
    # stay aligned - confirm against the installed MONAI version.
    train_imtrans = Compose(
        [
            LoadImage(image_only=True, ensure_channel_first=True),
            ScaleIntensity(),
            RandSpatialCrop((128, 128), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ]
    )
    train_masktrans = Compose(
        [
            LoadImage(image_only=True, ensure_channel_first=True),
            ScaleIntensity(),
            RandSpatialCrop((128, 128), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ]
    )
    val_imtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])
    val_masktrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])

    # define array dataset, data loader; load one batch as a shape sanity check
    check_ds = ArrayDataset(images, train_imtrans, masks, train_masktrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=pin_memory)
    im, msk = monai.utils.misc.first(check_loader)
    print(im.shape, msk.shape)

    # create a training data loader
    # NOTE(review): with fewer than 40 pairs the [:20] training slice and the
    # [-20:] validation slice overlap, which leaks training data into validation.
    train_ds = ArrayDataset(images[:20], train_imtrans, masks[:20], train_masktrans)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=pin_memory)
    # create a validation data loader
    val_ds = ArrayDataset(images[-20:], val_imtrans, masks[-20:], val_masktrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=pin_memory)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=3,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Number of optimizer steps per epoch. len(train_loader) also counts a
    # final partial batch, so the printed totals and TensorBoard global steps
    # stay correct when the dataset size is not a multiple of the batch size.
    epoch_len = len(train_loader)
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            # Validate every `val_interval` epochs.
            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                        roi_size = (128, 128)
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        # sigmoid + 0.5 threshold turns logits into a binary mask per item
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        # compute metric for current iteration
                        dice_metric(y_pred=val_outputs, y=val_labels)
                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Always flush and release the TensorBoard writer, even if training fails.
        writer.close()

In [43]:
if __name__ == "__main__":
    # The dataset root can be overridden through the SEGMENTATION_DATA_DIR
    # environment variable so the notebook runs on other machines; the
    # original local path is kept as the default.
    data_root = os.environ.get(
        "SEGMENTATION_DATA_DIR",
        "/Users/nittin_murthi/Documents/VS_Code/MONAI-UNET/segmentation_data/malignant/",
    )
    main(data_root)

torch.Size([10, 3, 128, 128]) torch.Size([10, 1, 128, 128])
----------
epoch 1/10
1/5, train_loss: 0.8579
2/5, train_loss: 0.6299
3/5, train_loss: 0.9263
4/5, train_loss: 0.8615
5/5, train_loss: 0.8319
epoch 1 average loss: 0.8215
----------
epoch 2/10
1/5, train_loss: 0.8687
2/5, train_loss: 0.7825
3/5, train_loss: 0.8348
4/5, train_loss: 0.8459
5/5, train_loss: 0.9885
epoch 2 average loss: 0.8641
saved new best metric model
current epoch: 2 current mean dice: 0.2361 best mean dice: 0.2361 at epoch 2
----------
epoch 3/10
1/5, train_loss: 0.8936
2/5, train_loss: 0.8390
3/5, train_loss: 0.6793
4/5, train_loss: 0.4085
5/5, train_loss: 0.8012
epoch 3 average loss: 0.7243
----------
epoch 4/10
1/5, train_loss: 0.7385
2/5, train_loss: 0.8314
3/5, train_loss: 0.6632
4/5, train_loss: 0.8089
5/5, train_loss: 0.8965
epoch 4 average loss: 0.7877
saved new best metric model
current epoch: 4 current mean dice: 0.2508 best mean dice: 0.2508 at epoch 4
----------
epoch 5/10
1/5, train_loss: 0.9457
