In [None]:
import glob
import matplotlib.pyplot as plt

def _show_overlay_grid(image_batch, mask_batch, title, rows=2, cols=2):
    """Plot up to ``rows * cols`` images with their masks overlaid.

    Args:
        image_batch: channel-first batch of images, indexed along dim 0.
        mask_batch: batch of masks aligned with ``image_batch``.
        title: figure super-title.
        rows: number of subplot rows.
        cols: number of subplot columns.
    """
    fig, axes = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    n_available = min(len(image_batch), len(mask_batch))
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= n_available:
            # Batch is smaller than the grid: leave the remaining panels blank
            # instead of raising an IndexError.
            continue
        # The tensors are plotted as floats. They have already been through
        # ScaleIntensity (presumably into [0, 1] with its defaults), so casting
        # to uint8 would truncate almost every pixel to 0 and render black.
        ax.imshow(image_batch[idx].permute(1, 2, 0).numpy())
        ax.imshow(mask_batch[idx].squeeze().numpy(), alpha=0.5)
    fig.suptitle(title)
    plt.show()


def main(tempdir, max_epochs=10):
    """Train and validate a 2D MONAI UNet on paired PNG images and masks.

    Expects ``tempdir/image/*.png`` and ``tempdir/mask/*.png``; files are
    paired by sorted order. The first 80% of the pairs are used for training
    and the remainder for validation. The best model (by mean Dice) is saved
    to ``best_metric_model_segmentation2d_array.pth`` in the working directory
    and scalars/images are logged through a TensorBoard ``SummaryWriter``.

    Args:
        tempdir: directory containing the ``image`` and ``mask`` sub-folders.
        max_epochs: number of training epochs (defaults to 10).

    Raises:
        ValueError: if the image and mask counts differ, if no PNG files are
            found, or if there are too few samples to form a training split.
    """
    # Define paths to image and mask directories
    image_dir = os.path.join(tempdir, "image")
    mask_dir = os.path.join(tempdir, "mask")

    # Read image and mask file paths (sorted so that pairs line up by index)
    images = sorted(glob.glob(os.path.join(image_dir, "*.png")))
    masks = sorted(glob.glob(os.path.join(mask_dir, "*.png")))

    if len(images) != len(masks):
        raise ValueError("The number of images and masks must be the same.")

    total_samples = len(images)
    if total_samples == 0:
        # Fail early with a clear message rather than deep inside a DataLoader.
        raise ValueError(f"No PNG image/mask pairs found under {tempdir!r}.")
    print(f"Total number of samples in the dataset: {total_samples}")

    # 80/20 split; validated up front so a too-small dataset fails fast.
    train_size = int(0.8 * total_samples)
    if train_size == 0:
        raise ValueError("At least 2 image/mask pairs are required for a train/validation split.")

    # Define transforms for image and segmentation.
    # NOTE(review): the image and mask pipelines hold independent random
    # transforms; this relies on ArrayDataset seeding both identically so the
    # crop/rotation match - confirm against the installed MONAI version.
    train_imtrans = Compose(
        [
            LoadImage(image_only=True, ensure_channel_first=True),
            ScaleIntensity(),
            RandSpatialCrop((128, 128), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ]
    )
    train_masktrans = Compose(
        [
            LoadImage(image_only=True, ensure_channel_first=True),
            ScaleIntensity(),
            RandSpatialCrop((128, 128), random_size=False),
            RandRotate90(prob=0.5, spatial_axes=(0, 1)),
        ]
    )
    val_imtrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])
    val_masktrans = Compose([LoadImage(image_only=True, ensure_channel_first=True), ScaleIntensity()])

    # Pinned host memory only helps CUDA transfers; use one consistent flag
    # for every loader instead of mixing True / MPS / CUDA checks.
    pin_memory = torch.cuda.is_available()

    # Sanity-check loader over the full dataset
    check_ds = ArrayDataset(images, train_imtrans, masks, train_masktrans)
    check_loader = DataLoader(check_ds, batch_size=10, num_workers=2, pin_memory=pin_memory)
    im, msk = monai.utils.misc.first(check_loader)
    print(im.shape, msk.shape)

    # Display some images and their corresponding segmentation masks
    _show_overlay_grid(im, msk, "Raw Images and Segmentation Masks")

    # Split the dataset into training and validation sets
    train_images, val_image_paths = images[:train_size], images[train_size:]
    train_masks, val_mask_paths = masks[:train_size], masks[train_size:]

    # create a training data loader
    train_ds = ArrayDataset(train_images, train_imtrans, train_masks, train_masktrans)
    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=pin_memory)

    # create a validation data loader
    val_ds = ArrayDataset(val_image_paths, val_imtrans, val_mask_paths, val_masktrans)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=pin_memory)

    # Display some augmented images and their corresponding segmentation masks
    aug_images, aug_masks = next(iter(train_loader))
    _show_overlay_grid(aug_images, aug_masks, "Augmented Images and Segmentation Masks")

    # Metric, post-processing, device, model, loss and optimizer
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    model = monai.networks.nets.UNet(
        spatial_dims=2,
        in_channels=3,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Number of optimizer steps per epoch. len(train_loader) includes the last
    # partial batch, so the progress display and the TensorBoard global step
    # stay correct when the dataset size is not a multiple of the batch size.
    epoch_len = len(train_loader)
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
                print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    # Kept after the loop so the last batch can be logged as images.
                    val_inputs = None
                    val_labels = None
                    val_outputs = None
                    for val_data in val_loader:
                        val_inputs, val_labels = val_data[0].to(device), val_data[1].to(device)
                        roi_size = (128, 128)
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        # compute metric for current iteration
                        dice_metric(y_pred=val_outputs, y=val_labels)
                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best_metric_model_segmentation2d_array.pth")
                        print("saved new best metric model")
                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, metric, best_metric, best_metric_epoch
                        )
                    )
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                    plot_2d_or_3d_image(val_inputs, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Always flush/close the TensorBoard writer, even if training raised.
        writer.close()

In [None]:
if __name__ == "__main__":
    # The dataset location can be overridden with the SEGMENTATION_DATA_DIR
    # environment variable so the notebook is not tied to one machine's
    # filesystem; the original path is kept as the fallback default.
    default_data_dir = "/Users/nittin_murthi/Documents/VS_Code/MONAI-UNET/segmentation_data/malignant/"
    main(os.environ.get("SEGMENTATION_DATA_DIR", default_data_dir))