In [1]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
    LabelToMaskd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm, Reshape
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
from torchvision.utils import save_image
import numpy as np

In [2]:
 def get_data(root_dir, resource, name: str):

    compressed_file = os.path.join(root_dir, name)
    
    if not os.path.exists(compressed_file):
        download_and_extract(resource, compressed_file, root_dir)

In [3]:
directory = os.environ.get('MONAI_DATA_DIRECTORY', "MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
data_dir = os.path.join(root_dir, "data")

get_data(root_dir, "https://huggingface.co/datasets/pytc/EM30/resolve/main/EM30-R-im.zip", "EM30-R-im.zip")

get_data(root_dir, "https://huggingface.co/datasets/pytc/MitoEM/resolve/main/EM30-R-mito-train-val-v2.zip?download=true", "EM30-R-mito-train-val-v2.zip")


In [4]:
images = sorted(glob.glob(os.path.join(root_dir, "im", "im*.png")))
labels = sorted(glob.glob(os.path.join(root_dir, "EM30-R-mito-train-val-v2/mito-train-v2", "seg*.tif")))
labels += sorted(glob.glob(os.path.join(root_dir, "EM30-R-mito-train-val-v2/mito-val-v2", "seg*.tif")))


In [5]:
data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(images, labels)
]

In [6]:
train_files, val_files = data_dicts[:166] + data_dicts[168:400], data_dicts[400:500]


print(len(train_files), len(val_files))

398 100


In [7]:
set_determinism(seed=0)

train_transforms = Compose(
    [LoadImaged(
        keys=["image", "label"]
    ),
    EnsureChannelFirstd(
        keys=["image", "label"]
    ),
    ScaleIntensityRanged(
        keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True
    ),
    CropForegroundd(
        keys=["image", "label"], source_key="image",allow_smaller=True
    ),
    Orientationd(
        keys=["image", "label"], axcodes="RAS"
    ),
    Spacingd(
        keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")
    ),
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label", spatial_size=(80, 80, 80), pos=1, neg=1, num_samples=4, image_key="image", image_threshold=0
    ),
    LabelToMaskd(keys=["label"], select_labels=[0, 1])  # Ensure labels are within the expected range
    ]   
)


val_transforms = Compose(
    [LoadImaged(
        keys=["image", "label"]
    ),
    EnsureChannelFirstd(
        keys=["image", "label"]
    ),
    ScaleIntensityRanged(
        keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True
    ),
    CropForegroundd(
        keys=["image", "label"], source_key="image", allow_smaller=True
    ),
    Orientationd(
        keys=["image", "label"], axcodes="RAS"
    ),
    Spacingd(
        keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")
    ),
    LabelToMaskd(keys=["label"], select_labels=[0, 1])  # Ensure labels are within the expected range
    ]
)


In [8]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        #Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(80, 80),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        LabelToMaskd(keys=["label"], select_labels=[0, 1]),  # Ensure labels are within the expected range

    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        #Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        LabelToMaskd(keys=["label"], select_labels=[0, 1]),  # Ensure labels are within the expected range
    ]
)

monai.transforms.croppad.dictionary CropForegroundd.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.


check_ds = Dataset(data=val_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
image, label = (check_data["image"][0][0], check_data["label"][0][0])


print(f"image shape: {image.shape}, label shape: {label.shape}")

save_image(image, "check/image1.png")
save_image(label, "check/label1.tif")

image = image.unsqueeze(-1).numpy()
label = label.unsqueeze(-1).numpy()
print(f"image shape after reshape: {image.shape}, label shape: {label.shape}")


# plot the slice [:, :, 80]
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:,:,0], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :,0])
plt.show()

In [9]:
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=2)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=2)

Loading dataset: 100%|██████████| 398/398 [07:02<00:00,  1.06s/it]
Loading dataset: 100%|██████████| 100/100 [01:48<00:00,  1.08s/it]


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(sigmoid=True, to_onehot_y=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")

In [11]:
from datetime import datetime

timestamp = datetime.now().strftime("%d%m%Y-%H%M%S")
save_dir = 'best_models'
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, f"best_model_{timestamp}.pth")

In [None]:
max_epochs = 600
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose(AsDiscrete(argmax=True, to_onehot=2))
post_label = Compose(AsDiscrete(to_onehot=2))

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        #print(f"{step}/{len(train_ds) // train_loader.batch_size}, train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    #print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = val_data["image"].to(device), val_data["label"].to(device)
                roi_size = (160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                dice_metric(y_pred=val_outputs, y=val_labels)
            metric = dice_metric.aggregate().item()
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(save_dir, f"best_model_{timestamp}.pth"))
                torch.save(val_data["image"], os.path.join("best_val_outputs", f"best_val_data_{timestamp}.pth"))
                print("saved new best metric model")
                print(
                    f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
                )
            #print(
             #   f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
              #  f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
            #)


----------
epoch 1/600
----------
epoch 2/600
saved new best metric model

best mean dice: 0.8973 at epoch: 2
----------
epoch 3/600
----------
epoch 4/600
saved new best metric model

best mean dice: 0.8996 at epoch: 4
----------
epoch 5/600
----------
epoch 6/600
saved new best metric model

best mean dice: 0.9640 at epoch: 6
----------
epoch 7/600
----------
epoch 8/600


In [None]:
print("end")