<a href="https://colab.research.google.com/github/leejooan/tumor_segmentation/blob/master/Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive so that the
# data and checkpoint paths used in the following cells resolve.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install MONAI (with the nibabel extra) only when `import monai` fails in a
# fresh interpreter; the `||` runs pip only if the import check exits non-zero.
!python -c "import monai" || pip install -q "monai-weekly[nibabel]"
# Same check-then-install pattern for matplotlib.
!python -c "import matplotlib" || pip install -q matplotlib

In [None]:
# Root folder that holds the test cases -- replace with your own data path.
path1 = '/content/drive/MyDrive/AI/project/brats18_test'

# NOTE(review): presumably a figure DPI setting; it is not referenced in the
# cells visible here -- confirm before removing.
dpiv = 80

In [None]:
from monai.transforms import (
    AddChanneld,
    Compose,
    LoadImaged,
    CenterSpatialCropd,
    NormalizeIntensityd,
    RandSpatialCropd,
    MapTransform,
    ToTensord,
)
from monai.config import print_config
from monai.data import DataLoader, Dataset
from monai.utils import first

from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.inferers import sliding_window_inference
from monai.networks.layers import Norm

import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt
import os
import glob
#add
from monai.transforms import(
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandRotate90d,
    ScaleIntensityd,
    CropForegroundd,
    AdjustContrastd, 
    Spacingd,
    ThresholdIntensityd,
     RandAdjustContrastd,
    Invertd,
   EnsureTyped,
   HistogramNormalized,
    EnsureChannelFirstd
)

In [None]:
# Collect every entry under path1 whose name starts with "Brats18".
# glob.glob returns matches in arbitrary order, so sort them: later cells
# select cases by list index and must pick the same cases on every run.
path_test = sorted(glob.glob(os.path.join(path1, 'Brats18*')))
len(path_test)

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(device)

In [None]:
# Build one {"image": <path to t1ce.nii.gz>} dictionary per test case.
# At most 20 cases are used (as before), but the index range is now bounded
# by the number of case folders actually found, so a smaller test set no
# longer raises IndexError.
max_test_cases = 20
test_ind = np.arange(0, min(max_test_cases, len(path_test)))
data_dicts = [
    {"image": os.path.join(path_test[idx], "t1ce.nii.gz")}
    for idx in test_ind
]
test_files = data_dicts

In [None]:
# Preprocessing applied to each test sample dictionary, in order:
#   1. load the file referenced by the "image" key,
#   2. make sure the image has a leading channel axis
#      (EnsureChannelFirstd replaces the deprecated AddChanneld),
#   3. normalize intensities using non-zero voxels only, per channel,
#   4. convert the result to a tensor.
test_transforms = Compose(
    [
        LoadImaged(keys="image"),
        EnsureChannelFirstd(keys="image"),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ToTensord(keys="image"),
    ]
)

In [None]:
# Pair the test file list with the preprocessing pipeline, then iterate it
# one case per batch.
test_ds = Dataset(data=test_files, transform=test_transforms)
test_loader = DataLoader(dataset=test_ds, batch_size=1)

In [None]:
import nibabel as nib

# The results will be saved in <path1>/test_out.
path_out = os.path.join(path1, 'test_out')
# exist_ok=True keeps the cell safe to re-run and avoids the separate
# "check, then create" step.
os.makedirs(path_out, exist_ok=True)

In [None]:
# Folder that contains the trained model checkpoint loaded in the next cell.
path2 = '/content/drive/MyDrive/AI/project/brats18_test'

In [None]:
# Use the GPU when available but keep a CPU fallback; hard-coding "cuda:0"
# here would make the notebook fail on a CPU-only runtime.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 3D U-Net: 1 input channel, 2 output channels.
# `spatial_dims` is the current name of the deprecated `dimensions` argument.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2, 2),
    num_res_units=3,
    norm=Norm.BATCH,
).to(device)
# Not used by the inference loop visible in this notebook; kept so that any
# later cell referring to `loss_function` still works.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)

root_dir = path2
# map_location lets a checkpoint saved on a GPU be restored on whichever
# device was selected above.
model.load_state_dict(torch.load(
    os.path.join(root_dir, "best_metric_model_epoch_208.pth"),
    map_location=device))


In [None]:
# Run sliding-window inference on every test case and, when flag_save is 1,
# write the predicted label map to path_out as <case folder name>.nii.gz.
model.eval()
flag_save = 1

# Sliding-window settings are identical for every case, so define them once
# instead of on each loop iteration.
roi_size = (160, 160, 64)
sw_batch_size = 4

with torch.no_grad():
    for val_data in test_loader:
        val_inputs = val_data["image"].to(device)
        val_outputs = sliding_window_inference(
            val_inputs, roi_size, sw_batch_size, model, overlap=0.75)
        # Softmax over dim 1 (the model's output-channel axis).
        val_preds = val_outputs.softmax(1)
        if flag_save == 1:
            # Most likely class per voxel, as a NumPy array on the CPU.
            val_seg = torch.argmax(val_preds.cpu(), dim=1).numpy()
            # Path of the source image for this batch (first and only item,
            # given batch_size=1 in the loader).
            src_file = val_data['image_meta_dict']['filename_or_obj'][0]
            # The case ID is the name of the folder containing the image;
            # os.path avoids depending on '/' as the path separator.
            pid = os.path.basename(os.path.dirname(src_file))
            # Reuse the source image's affine and header for the output file.
            h = nib.load(src_file)
            h_new = nib.Nifti1Image(val_seg[0], h.affine, h.header)
            nib.save(h_new, os.path.join(path_out, pid + '.nii.gz'))
            