In [None]:
# Mount Google Drive (Colab only) so the dataset and the model checkpoint are reachable.
from google.colab import drive

DRIVE_MOUNT = '/content/drive'
drive.mount(DRIVE_MOUNT)
# Absolute path: the previous relative 'drive/MyDrive/...' only resolved while the
# kernel's working directory happened to be /content.
DATA_ROOT = f'{DRIVE_MOUNT}/MyDrive/medical_decathlon'

In [None]:
!pip install monai
!pip install datetime

In [None]:
import numpy as np
import json
from datetime import datetime

import torch
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch import sigmoid


from monai.utils import set_determinism
from monai.transforms import (
    Compose,
    Lambdad,
    LoadImaged,
    AddChanneld,
    ScaleIntensityRanged,
    BorderPadd,
    RandCropByPosNegLabeld,
    RandRotated,
    RandZoomd,
    AsDiscreted,
    SelectItemsd,
    Resized,
    ToTensord
)
from monai.data import (
    Dataset, 
    CacheDataset, 
    DataLoader, 
    partition_dataset_classes
)
from monai.networks.nets import UNet, VNet
from monai.losses import DiceLoss
from monai.metrics import compute_meandice
from monai.inferers import SlidingWindowInferer

from tqdm import trange, tqdm
import matplotlib.pyplot as plt
def imshow(x, cmap='gray', title=None):
    """Draw ``x`` with ``plt.imshow`` and render the figure immediately.

    Args:
        x: Array-like accepted by ``plt.imshow`` (a 2-D slice in this notebook).
        cmap: Matplotlib colormap name; defaults to grayscale as before.
        title: Optional title drawn above the image.
    """
    plt.imshow(x, cmap=cmap)
    if title is not None:
        plt.title(title)
    plt.show()

In [None]:
import os

# Pair each image with the mask that has the same file name; images without a mask are skipped.
# Paths are stored relative to DATA_ROOT (the preprocessing pipeline prefixes DATA_ROOT later).
image_list = sorted(os.listdir(f'{DATA_ROOT}/images/'))
label_list = sorted(os.listdir(f'{DATA_ROOT}/mask/'))
label_names = set(label_list)  # O(1) membership instead of scanning a list per image

datalist = [
    dict(image=f'images/{name}', label=f'mask/{name}')
    for name in image_list
    if name in label_names
]

n_unmatched = len(image_list) - len(datalist)
print(f'{len(datalist)} image/mask pairs ({n_unmatched} images without a mask skipped)')
datalist[:3]

In [None]:
# Rebuild the held-out split. This presumably mirrors the split used when the model was
# trained, so the seed and the legacy np.random.shuffle call are kept unchanged.
N_CASES = 50
N_TEST = 10
SPLIT_SEED = 42

# Shuffle a copy (slicing copies the list) rather than overwriting `datalist`:
# re-running this cell must not reshuffle an already-shuffled list, which would
# silently change which cases end up in the test set.
datalist_shuffled = datalist[:N_CASES]
np.random.seed(SPLIT_SEED)
np.random.shuffle(datalist_shuffled)

# NOTE(review): an earlier comment stated train:valid:test = 7:2:1, which for 50 cases
# means 5 test cases, but the first 10 are taken here — confirm against training.
test_ls = datalist_shuffled[:N_TEST]

In [None]:
# PREPROCESSING
# Deterministic evaluation pipeline (no random augmentation).

_io_keys = ['image', 'label']

valid_trans = Compose([
    # datalist stores paths relative to DATA_ROOT; turn them into loadable paths.
    Lambdad(keys=_io_keys, func=lambda rel_path: f'{DATA_ROOT}/{rel_path}'),
    LoadImaged(keys=_io_keys),
    AddChanneld(keys=_io_keys),
    # NOTE: no resampling step (e.g. Spacing with pixdim=[3.0, 3.0, 3.0]) is applied.
    # Image only: clip intensities to [-900, 300] and rescale linearly to [0, 1].
    ScaleIntensityRanged(keys=['image'], a_min=-900, a_max=300, b_min=0, b_max=1, clip=True),
    ToTensord(keys=_io_keys),
])

# Run the whole pipeline once up front and keep every test case in memory.
test_ds = CacheDataset(
    data=test_ls,
    transform=valid_trans,
    cache_rate=1.0,
    num_workers=2,
)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=2)

In [None]:
# Visual sanity check: middle slice along the last axis, image and mask side by side.
# Uses test_ds (built above); `train_ds` is never defined in this notebook.
sample = test_ds[0]
# A pipeline with RandCropByPosNegLabeld yields a list of crops per item, while the
# validation pipeline yields a single dict; normalise to a list so both work.
samples = sample if isinstance(sample, list) else [sample]
for item in samples:
    image = np.asarray(item['image'])
    label = np.asarray(item['label'])
    mid = image.shape[-1] // 2
    image_slice = image[0, :, :, mid]
    # Scale the mask by the slice maximum so it is visible on the same gray scale.
    label_slice = label[0, :, :, mid] * image_slice.max()
    imshow(np.hstack([image_slice, label_slice]))

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
set_determinism(seed=42)

# Architecture must match the checkpoint that is loaded afterwards:
# 3-D UNet, 1 input / 1 output channel, channels 16-32-64, no residual units.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=[16, 32, 64],
    strides=[2, 2],
    dropout=0.0,
    num_res_units=0,
).to(device)

# Alternative architecture kept for reference only (not built):
# model = VNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=1,
#     dropout_prob=0.5,
#     dropout_dim=3,
# ).to(device)

# Loss and optimizer are not used by the evaluation loop in this notebook.
loss = DiceLoss(sigmoid=True)
optimizer = Adam(params=model.parameters(), lr=1e-2)

# Patch-based inference: 64^3 windows, 25% overlap, 8 windows per forward pass,
# stitched with uniform ('constant') weighting.
infer = SlidingWindowInferer(
    roi_size=(64, 64, 64),
    overlap=0.25,
    sw_batch_size=8,
    mode='constant',
)

In [None]:
# Restore trained weights. The file is consumed by load_state_dict, i.e. it is a plain
# state_dict, so weights_only=True suffices and avoids unpickling arbitrary objects
# (requires torch >= 1.13).
MODEL_PATH = '/content/drive/MyDrive/small_dataset/model'
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval();  # trailing ';' suppresses the long module repr

In [None]:
# Evaluate on the held-out cases.
# - Runs under torch.no_grad(): no autograd graph is needed for inference, and keeping
#   one across every sliding-window patch wastes memory.
# - `test_loss` keeps its original meaning: 1 - Dice computed on sigmoid probabilities
#   (a soft Dice).
# - compute_meandice expects binarized predictions, so the conventional Dice metric is
#   also reported using a 0.5 threshold.
# - Cases for which compute_meandice returns NaN (e.g. empty masks, depending on the
#   MONAI version) are skipped by nanmean instead of turning the average into NaN.
model.eval()
soft_losses, hard_dices = [], []
with torch.no_grad():
    for batch in tqdm(test_loader, desc='test'):
        labels = batch['label'].to(device)
        images = batch['image'].to(device)
        probs = sigmoid(infer(images, model))
        soft_losses.append(1 - compute_meandice(probs, labels).mean().item())
        hard_dices.append(compute_meandice((probs > 0.5).float(), labels).mean().item())

assert soft_losses, 'test_loader yielded no batches'
test_loss = float(np.nanmean(soft_losses))
test_dice = float(np.nanmean(hard_dices))
n_undefined = int(np.isnan(hard_dices).sum())
print(f'test loss: {test_loss}')
print(f'test dice @0.5: {test_dice} ({n_undefined} undefined cases skipped)')