# Image Slice inferencing
Testing image-slice inferencing for the microservice setup.

## 1) Imports and mount

In [40]:
%load_ext autoreload
%autoreload 2

# This sets up the appropriate logging and path configs
from notebook_setup import * 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [41]:
import torch
from torchsummary import summary
import matplotlib.pyplot as plt

from monai.losses import DiceLoss
from src.features.build_features import train_transform, val_transform, test_transform

from src.visualization.visualize import view_slice

from monai.data import Dataset, DataLoader
from src.enums import DataDict

from src.settings.config import get_app_settings
from src.utils import load_yaml

## 2) Import Settings

In [42]:
MODEL_CONFIG_FILE = "model_configs.yaml"

# Application settings object and the model configuration parsed from YAML.
settings = get_app_settings()
model_config = load_yaml(MODEL_CONFIG_FILE)

In [43]:
from src.enums import INTERIM_TESTING_DATA_PATHS, INTERIM_TRAINING_DATA_PATHS
from src.pytorch_utils import get_interim_data_path


# Interim file paths; the testing split is reused for both validation and test datasets.
train_paths = get_interim_data_path(INTERIM_TRAINING_DATA_PATHS)
val_paths = get_interim_data_path(INTERIM_TESTING_DATA_PATHS)

train_dataset = Dataset(train_paths, train_transform)
validation_dataset = Dataset(val_paths, val_transform)
test_dataset = Dataset(val_paths, test_transform)

# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True, num_workers=2)
# validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE_VAL, shuffle=True)

# from monai.transforms import (LoadImaged)
# image_loader = LoadImaged(keys=[DataDict.ImageT1, DataDict.ImageFlair])
# image_loader(val_paths)
# [item for item in dir(nib.load(train_paths[0][DataDict.ImageFlair])) if not item.startswith('_')]

# Sanity check: shape of the transformed FLAIR image of the first test sample.
test_dataset[0][DataDict.ImageFlair].shape

## 3) Inference



### 3.1) Inference on a NIfTI (.nii) volume

In [54]:
# Get a sample volume
from src.data.make_dataset import get_test_raw_paths
from src.enums import INTERIM_TESTING_DATA_PATHS
from src.pytorch_utils import get_interim_data_path
from copy import copy
import nibabel as nib
from src.pytorch_utils import slice_tensor_volume

raw_test_paths = get_test_raw_paths()
# Shallow copy: replacing entries below leaves raw_test_paths[0] holding the original paths.
sample_inference_images = copy(raw_test_paths[0])

# Load the NIfTI files into in-memory arrays (not paths) to mimic serving.
t1_sample, flair_sample = (
    nib.load(sample_inference_images[image_key]).get_fdata()
    for image_key in (DataDict.ImageT1, DataDict.ImageFlair)
)
sample_inference_images.update(
    {
        DataDict.ImageT1: t1_sample,
        DataDict.ImageFlair: flair_sample,
    }
)
# Drop the label entry; as the last expression, the removed label path is displayed.
sample_inference_images.pop(DataDict.Label)

'data/raw/test/Singapore/70/wmh.nii.gz'

In [45]:


# sample = copy(raw_test_paths[0])
# import nibabel as nib
# import numpy as np

# t1_sample = nib.load(sample[DataDict.ImageT1]).get_fdata()
# flair_sample = nib.load(sample[DataDict.ImageFlair]).get_fdata()


# t1_sample = torch.Tensor([nib.Nifti1Image(t1_sample, affine=np.eye(4)).get_fdata()])
# flair_sample = torch.Tensor(
#     [nib.Nifti1Image(flair_sample, affine=np.eye(4)).get_fdata()]
# )


# # Set it in the sample
# sample[DataDict.ImageT1] = t1_sample
# sample[DataDict.ImageFlair] = flair_sample
# sample.pop(DataDict.Label)

'data/raw/test/Singapore/70/wmh.nii.gz'

In [None]:
# Actual dependency code: serving-time entry point (not implemented yet).

def predict():
    """Placeholder for the serving-time prediction entry point.

    The original cell declared ``def predict():`` with no body, which is a
    SyntaxError and stopped "Restart Kernel -> Run All" at this cell.

    Raises:
        NotImplementedError: always, until the serving flow is implemented.
    """
    # TODO: implement the end-to-end slice inference used by the microservice.
    raise NotImplementedError("predict() has not been implemented yet")

In [83]:
from src.pytorch_utils import normalize_img_intensity_range


def slice_inference_data(vol):
    """Split ``vol`` into slices with ``slice_tensor_volume`` and return its result unchanged.

    The caller below treats the result as a mapping of depth index -> slice
    (it uses ``.items()`` / ``.get``).
    """
    return slice_tensor_volume(vol)


# Single in-memory sample wrapped in a list for use as a MONAI Dataset source
# (despite the name, it holds arrays rather than file paths).
sample_data_paths = [sample_inference_images]

print(t1_sample.shape)

# Normalize into new names so re-running this cell never re-normalizes already-normalized
# arrays (the original rebound t1_sample / flair_sample in place).
t1_normalized = normalize_img_intensity_range(t1_sample)
flair_normalized = normalize_img_intensity_range(flair_sample)

t1_slices = slice_inference_data(t1_normalized)
flair_slices = slice_inference_data(flair_normalized)

# Both modalities must produce the same slice indices; otherwise a FLAIR entry would
# silently become None and fail much later inside the model.
assert t1_slices.keys() == flair_slices.keys(), "T1 and FLAIR volumes produced different slice indices"

# One record per depth slice (a list of dicts despite the name).
# Id -1 is a placeholder subject id — presumably no real id exists at serving time.
inference_dict = [
    {
        DataDict.Id: -1,
        DataDict.ImageT1: t1_slice,
        DataDict.ImageFlair: flair_slices[depth],
        DataDict.DepthZ: depth,
    }
    for depth, t1_slice in t1_slices.items()
]

(232, 256, 48)


In [87]:
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    Resized,
    Spacingd,
    ToMetaTensord,
    ToTensord,
)
from src.features.transforms import ImagesToMultiChannel

IMAGE_KEYS = [DataDict.ImageT1, DataDict.ImageFlair]

# Only channel-first conversion is active: the in-memory arrays carry no channel axis.
# Disabled stages kept for experimentation:
#   ToMetaTensord(keys=IMAGE_KEYS)
#   Spacingd(keys=IMAGE_KEYS, pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "bilinear", "nearest"))
#       2D variant: pixdim=(1.5, 1.5), mode=("bilinear", "bilinear")
#   Resized(keys=IMAGE_KEYS, spatial_size=[256, 256])
#   Orientationd(keys=IMAGE_KEYS, axcodes="RAS")
#   ToTensord(keys=IMAGE_KEYS)
#   ImagesToMultiChannel(keys=IMAGE_KEYS)
custom_transform = Compose(
    [EnsureChannelFirstd(keys=IMAGE_KEYS, channel_dim="no_channel")]
)

# NOTE: rebinds `test_dataset` (previously the interim test split) to the single in-memory sample.
test_dataset = Dataset(sample_data_paths, custom_transform)
# view_slice(test_dataset[0][DataDict.ImageT1][..., 30])
test_dataset[0][DataDict.ImageFlair].shape

torch.Size([1, 232, 256, 48])

In [22]:
# Keys present in the first transformed sample (iterating a dict yields its keys).
list(test_dataset[0])

[<DataDict.Id: 'subj_id'>,
 <DataDict.Image: 'image'>,
 <DataDict.ImageFlair: 'img_flair'>,
 <DataDict.ImageT1: 'img_t1'>,
 <DataDict.DepthZ: 'depth_z'>]

In [29]:
from src.models.predict_model import ImagePredictor
from src.models.train_model import model

MODEL_CHECKPOINT_PATH = "models/single_slice_t1_flair_v1.pt"

# `pred_network` is the same object as the imported `model`; loading weights mutates it.
pred_network = model
# torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(MODEL_CHECKPOINT_PATH, map_location=torch.device("cpu"))
pred_network.load_state_dict(checkpoint["model_state_dict"])
# Switch to inference mode (dropout off, batch-norm uses running stats) before predicting;
# harmless if ImagePredictor also does this internally.
pred_network.eval()

img_predictor = ImagePredictor(pred_network, test_dataset)
test_predictions = img_predictor.predict_handler()
test_predictions[0][DataDict.Prediction].shape

Predicting 1/1 slices
1 subjects to predict
Reconstructing 1/1


Orientation: spatial shape = (256, 256), channels = 1,please make sure the input is in the channel-first format.
  ret = func(*args, **kwargs)


torch.Size([256, 256, 1])

In [None]:
# Direct inference through MONAI's SimpleInferer, bypassing ImagePredictor.
# NOTE(review): this cell was never executed in the original notebook. The input layout
# below (one 2-channel T1+FLAIR image per depth slice, T1 first, no intensity
# normalization or resizing) is an assumption based on the checkpoint name and the
# commented-out ImagesToMultiChannel stage — TODO confirm against the training transforms.
from monai.inferers import SimpleInferer
from monai.transforms import Compose, LoadImage, ToTensor

device = torch.device("cpu")
model.to(device)
model.eval()

# LoadImage must run first: it turns a file path into an array that ToTensor converts.
# It needs file paths, so read from `raw_test_paths` (still holding paths) rather than the
# in-memory arrays stored in `sample_data_paths`.
load_volume = Compose([LoadImage(image_only=True), ToTensor()])
raw_sample_paths = raw_test_paths[0]
t1_volume = load_volume(raw_sample_paths[DataDict.ImageT1])
flair_volume = load_volume(raw_sample_paths[DataDict.ImageFlair])

# Two (H, W, D) volumes -> (D, 2, H, W): one batch entry per depth slice.
slice_batch = torch.stack([t1_volume, flair_volume]).permute(3, 0, 1, 2).float().to(device)

# `monai.inferers.inferer` is a module, not a callable; use an inferer instance instead.
simple_inferer = SimpleInferer()
with torch.no_grad():
    pred = simple_inferer(inputs=slice_batch, network=model)