In [1]:
import torch
print("Cuda_available? ", torch.cuda.is_available())
print("Name: ", torch.cuda.get_device_name(0))
print("Device count: ", torch.cuda.device_count())
print("current device: ", torch.cuda.current_device())
print("Select device: ", torch.cuda.device(0))
print("get device name: ",torch.cuda.get_device_name(0))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #finally got the solution

Cuda_available?  True
Name:  NVIDIA RTX A5000
Device count:  1
current device:  0
Select device:  <torch.cuda.device object at 0x7f9aab468cd0>
get device name:  NVIDIA RTX A5000


In [2]:
data_dir = 'data/ZIP'

In [3]:
import glob
import os
import nibabel as nib
from monai.data import CacheDataset,    Dataset,    DataLoader,    decollate_batch
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.transforms import AsDiscrete
from monai.inferers import sliding_window_inference
from monai.transforms import Resize
import numpy as np

In [4]:
images_pt = sorted(glob.glob(os.path.join(data_dir, 'PETCT*', '*', "SUV*")))
images_ct = sorted(glob.glob(os.path.join(data_dir, 'PETCT*', '*', "CTres*")))
images_sg = sorted(glob.glob(os.path.join(data_dir, 'PETCT*', '*', "SEG*")))
data_dicts = [
            {"image_pt": image_name_pt, "image_ct": image_name_ct, "image_sg": image_name_sg}
            for image_name_pt, image_name_ct, image_name_sg in zip(images_pt, images_ct, images_sg)
        ]
keys = ["image_pt", "image_ct", "image_sg"]
train_files, val_files = data_dicts[:70], data_dicts[70:]

In [5]:
from monai.transforms import (
Compose,
LoadImaged,
EnsureChannelFirstd,
Spacingd,
Orientationd,
ScaleIntensityRanged,
CropForegroundd,
ConcatItemsd,
DivisiblePadd,
AddChanneld,
RandCropByPosNegLabeld,
RandFlipd,EnsureTyped,
RandRotate90d,
RandShiftIntensityd,
Resized
)

In [6]:
crop_val = 96
train_transforms = Compose(
    [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        ScaleIntensityRanged(
            keys=["image_ct"], a_min=-100, a_max=250,
             b_min=0.0, b_max=1.0, clip=False,
        ),
        ScaleIntensityRanged(
            keys=["image_pt"], a_min=0, a_max=15,
            b_min=0.0, b_max=1.0, clip=False,
        ),
        #CropForegroundd(keys=["image_ct", "image_pt"], source_key="image_ct"),
        Orientationd(keys=keys, axcodes="LAS"),
        # Resized(keys=keys, spatial_size=(256, 256, 256)), # resizing all keys
        ConcatItemsd(keys=["image_pt", "image_ct"], name="image_petct", dim=0),
        # Add some augmentations
        RandCropByPosNegLabeld(keys=["image_petct", "image_sg"], label_key="image_sg", spatial_size=(crop_val, crop_val, crop_val), pos=1, neg=1,
            num_samples=6,
            image_key="image_petct",
            image_threshold=0,
        ),
        RandRotate90d(
            keys=["image_petct", "image_sg"],
            prob=0.10,
            max_k=3,
        ),
        RandFlipd(
            keys=["image_petct", "image_sg"],
            spatial_axis=[0, 1, 2],
            prob=0.10,
        ),
        RandShiftIntensityd(
            keys=["image_petct"],
            offsets=0.10,
            prob=0.50,
        ),
        EnsureTyped(keys=["image_petct", "image_sg"]),
    ]
)



In [7]:
val_transforms = Compose(
    [
        LoadImaged(keys=keys),
        AddChanneld(keys=keys),
        ScaleIntensityRanged(
            keys=["image_ct"], a_min=-100, a_max=250,
            b_min=0.0, b_max=1.0, clip=False,
        ),
        ScaleIntensityRanged(
            keys=["image_pt"], a_min=0, a_max=15,
            b_min=0.0, b_max=1.0, clip=False,
        ),
        #CropForegroundd(keys=["image_ct", "image_pt"], source_key="image_ct"),
        Orientationd(keys=keys, axcodes="LAS"),
        # Resized(keys=keys, spatial_size=(256, 256, 256)), # resizing all keys
        ConcatItemsd(keys=["image_pt", "image_ct"], name="image_petct", dim=0),
        EnsureTyped(keys=["image_petct", "image_sg"]),
    ]
)


In [12]:
train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.03)
#train_ds = Dataset(data=train_files, transform=train_transforms)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.05)
#val_ds = Dataset(data=val_files, transform=val_transforms)

Loading dataset: 100%|██████████| 2/2 [00:05<00:00,  2.89s/it]
Loading dataset: 100%|██████████| 1/1 [00:03<00:00,  3.07s/it]


In [13]:
# Create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
model = UNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH
).to(device)

loss_function = DiceLoss(to_onehot_y=2, softmax=True)
lr = 1e-2
optimizer = torch.optim.Adam(model.parameters(), lr)
dice_metric = DiceMetric(include_background=True, reduction="mean")

In [14]:
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)
val_loader  = DataLoader(val_ds, batch_size=1, shuffle=True, num_workers=0)

In [15]:
#tranformation check
for batch_data in train_loader:
  label = batch_data["image_sg"]
  label = label != 0
        #print('The label values: ', np.unique(label))
  print(label.shape)

  inputs, labels = (
     batch_data["image_petct"].to(device),
     label.to(device),
  )
  print(inputs.shape)
  print(labels.shape)
  break



torch.Size([6, 1, 96, 96, 96])
torch.Size([6, 2, 96, 96, 96])
torch.Size([6, 1, 96, 96, 96])


In [17]:
max_epochs = 50
val_interval = 1
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])
checkpoint_path = os.path.join(data_dir, "checkpoint.pth")

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    best_metric = checkpoint['best_metric']
    best_metric_epoch = checkpoint['best_metric_epoch']
    epoch_loss_values = checkpoint['epoch_loss_values']
    metric_values = checkpoint['metric_values']
else:
    start_epoch = 0

for epoch in range(start_epoch,max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image_petct"].to(device),
            batch_data["image_sg"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"{step}/{len(train_ds) // train_loader.batch_size}, " f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    if epoch % 6 == 0
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_metric': best_metric,
            'best_metric_epoch': best_metric_epoch,
            'epoch_loss_values': epoch_loss_values,
            'metric_values': metric_values
        }, checkpoint_path)

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image_petct"].to(device),
                    val_data["image_sg"].to(device),
                )
                roi_size = (160, 160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()

            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(data_dir, "best_metric_model.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

----------
epoch 1/50




1/70, train_loss: 1.0000
2/70, train_loss: 0.9794
3/70, train_loss: 0.9922
4/70, train_loss: 0.9889
5/70, train_loss: 0.9563
6/70, train_loss: 1.0000
7/70, train_loss: 0.9981




8/70, train_loss: 1.0000
9/70, train_loss: 0.9998
10/70, train_loss: 0.2812




11/70, train_loss: 1.0000
12/70, train_loss: 0.9995
13/70, train_loss: 1.3508




14/70, train_loss: 1.0000
15/70, train_loss: 1.0073
16/70, train_loss: 0.9893
17/70, train_loss: 0.7405
18/70, train_loss: 0.7329
19/70, train_loss: 1.0026
20/70, train_loss: 1.0000
21/70, train_loss: 1.0008
22/70, train_loss: 0.9412
23/70, train_loss: 0.9995
24/70, train_loss: 0.9922
25/70, train_loss: 0.9785




26/70, train_loss: 1.0000
27/70, train_loss: 0.9991
28/70, train_loss: -1.9065
29/70, train_loss: 1.0070
30/70, train_loss: 1.0248
31/70, train_loss: 1.0078
32/70, train_loss: 0.9916




33/70, train_loss: 1.0000
34/70, train_loss: 0.9393
35/70, train_loss: 1.0575
36/70, train_loss: 0.9872
37/70, train_loss: 1.0129




38/70, train_loss: 1.0000
39/70, train_loss: 0.9276
40/70, train_loss: 1.0000
41/70, train_loss: 0.9718
42/70, train_loss: 1.0366
43/70, train_loss: 0.9945
44/70, train_loss: 0.9955
45/70, train_loss: 0.9980
46/70, train_loss: 0.8592
47/70, train_loss: 0.9997
48/70, train_loss: 1.0080
49/70, train_loss: 0.9997
50/70, train_loss: 0.9351
51/70, train_loss: 0.9978




52/70, train_loss: 1.0000
53/70, train_loss: 0.9795
54/70, train_loss: 1.0000
55/70, train_loss: 0.6547
56/70, train_loss: 0.9792
57/70, train_loss: 0.9640
58/70, train_loss: 0.9376




59/70, train_loss: 1.0000
60/70, train_loss: 0.9865
61/70, train_loss: 0.7852
62/70, train_loss: 0.9996
63/70, train_loss: 0.8490
64/70, train_loss: 0.9853
65/70, train_loss: 0.9683




66/70, train_loss: 1.0000
67/70, train_loss: 0.9993
68/70, train_loss: 0.9634
69/70, train_loss: 0.3831
70/70, train_loss: 0.9878
epoch 1 average loss: 0.9171
saved new best metric model
current epoch: 1 current mean dice: 0.5256
best mean dice: 0.5256 at epoch: 1
----------
epoch 2/50
1/70, train_loss: 1.0000
2/70, train_loss: 0.8804
3/70, train_loss: 0.9812
4/70, train_loss: 1.3338
5/70, train_loss: 0.9672
6/70, train_loss: 0.9978
7/70, train_loss: 1.0000
8/70, train_loss: 0.9999
9/70, train_loss: 0.9861
10/70, train_loss: 1.0000
11/70, train_loss: 1.0113
12/70, train_loss: 1.0000
13/70, train_loss: 0.9918
14/70, train_loss: 1.0554
15/70, train_loss: 0.8502
16/70, train_loss: 0.9962
17/70, train_loss: 0.9963
18/70, train_loss: 1.0083
19/70, train_loss: 0.9986
20/70, train_loss: 1.0128
21/70, train_loss: 0.9992
22/70, train_loss: 0.9812
23/70, train_loss: 0.9979
24/70, train_loss: 1.0000
25/70, train_loss: 0.9537
26/70, train_loss: 0.9602
27/70, train_loss: -1.3434
28/70, train_loss: 

KeyboardInterrupt: 

In [None]:
print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}")


In [None]:
import matplotlib.pyplot as plt

plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Epoch Average Loss")
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y)
plt.show()

In [None]:
model.load_state_dict(torch.load(os.path.join(data_dir, "best_metric_model.pth")))
model.eval()
with torch.no_grad():
    for i, val_data in enumerate(val_loader):
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        val_outputs = sliding_window_inference(val_data["image_petct"].to(device), roi_size, sw_batch_size, model)
        # plot the slice [:, :, 80]
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image_petct"][0, 0, :, :, 80], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["image_sg"][0, 0, :, :, 80])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(torch.argmax(val_outputs, dim=1).detach().cpu()[0, :, :, 80])
        plt.show()
        if i == 2:
            break