In [1]:
## General Imports from all libraries
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys, os
import pathlib
import glob
import time
import math, random
import pprint
import collections
import numbers, string

import yaml
from tqdm import tqdm

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import SimpleITK as sitk
from PIL import Image

import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

np.set_printoptions(precision=3)
curr_path = pathlib.Path(os.getcwd()).absolute()

# GPU card id(s) come from the SGE_HGR_gpu_card environment variable
#  (presumably set by the SGE scheduler for this job). Read it directly instead
#  of via `!echo`: an unset variable used to yield the invalid device string
#  'cuda:', and several listed cards yielded e.g. 'cuda:0 1'. Fall back to
#  card 0 and use only the first listed card.
cards = os.environ.get('SGE_HGR_gpu_card', '').split() or ['0']
device = torch.device(f"cuda:{cards[0]}" if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Device: cuda:0


In [2]:
# Import custom files for this project
# sys.path holds strings, so compare against str(curr_path): a pathlib.Path
#  never equals a str, which made the old check always true and re-appended
#  the directory on every re-run of this cell.
curr_path_str = str(curr_path)
if curr_path_str not in sys.path:
    print('Adding current path:', curr_path)
    sys.path.append(curr_path_str)

from metrics import batch_metrics

Adding current path: /afs/crc.nd.edu/user/y/yzhang46/_3DPRE/src/pengfei_inference


# Data Collection and Components

In [3]:
class MockDataset(torch.utils.data.Dataset):
    """In-memory dataset of preprocessed image volumes and one-hot masks.

    Every (image, mask) pair is read eagerly in ``__init__`` and kept in
    memory, so indexing is cheap but construction reads every file.

    Images are read as int16, clamped to [-1024, 325] and normalized with
    SimpleITK's NormalizeImageFilter, then converted to float32 tensors.
    Masks are read as int64 label maps and converted to int32 one-hot tensors.

    Args:
        image_files: paths of image volumes readable by SimpleITK.
        mask_files: paths of label volumes, paired with ``image_files`` by
            position.
        num_classes: number of label classes; every mask value must lie in
            ``[0, num_classes)``.

    Each item is ``(image, mask, info)``:
        image: float32 tensor in SimpleITK array order (DxHxW for volumes).
        mask: int32 one-hot tensor of shape ``[num_classes] + image.shape``.
        info: dict with the image's 'origin', 'spacing' and 'direction'.

    Raises:
        ValueError: if the two file lists differ in length, a mask's shape
            differs from its image's, or a mask holds a label outside
            ``[0, num_classes)``.
    """

    def __init__(self, image_files, mask_files, num_classes):
        if len(image_files) != len(mask_files):
            # zip() below would otherwise silently drop the unpaired tail.
            raise ValueError(
                f'Got {len(image_files)} image files but '
                f'{len(mask_files)} mask files.')
        self.num_classes = num_classes
        self.image_files = image_files  # list of image files
        self.mask_files = mask_files  # list of mask files

        # Get image tensors and store permanently
        self.images, self.masks, self.image_info = [], [], []
        for image_f, mask_f in zip(image_files, mask_files):
            # 1. Read image and preprocess (clamp + normalize)
            sitk_image = sitk.ReadImage(image_f, sitk.sitkInt16)
            sitk_image = sitk.Clamp(sitk_image, sitk.sitkInt16, -1024, 325)
            sitk_image = sitk.NormalizeImageFilter().Execute(sitk_image)
            image_tensor = torch.from_numpy(sitk.GetArrayFromImage(sitk_image))
            image_tensor = image_tensor.float()

            # 2. Read mask and convert it to one-hot
            sitk_mask = sitk.ReadImage(mask_f, sitk.sitkInt64)
            mask_tensor = torch.from_numpy(sitk.GetArrayFromImage(sitk_mask))
            if mask_tensor.shape != image_tensor.shape:
                # scatter_ accepts an index smaller than its target, so a
                #  smaller mask would otherwise give a silently zero-padded
                #  one-hot instead of an error.
                raise ValueError(
                    f'Mask {mask_f} has shape {tuple(mask_tensor.shape)} but '
                    f'image {image_f} has shape {tuple(image_tensor.shape)}.')

            self.images.append(image_tensor)
            self.masks.append(self._to_one_hot(mask_tensor, num_classes))
            self.image_info.append({
                'origin': sitk_image.GetOrigin(),
                'spacing': sitk_image.GetSpacing(),
                'direction': sitk_image.GetDirection()
            })

    @staticmethod
    def _to_one_hot(mask_tensor, num_classes):
        """Convert an integer label map to an int32 one-hot tensor.

        Args:
            mask_tensor: integer tensor of class ids, any shape S.
            num_classes: number of classes C.

        Returns:
            int32 tensor of shape ``[C] + S`` with a single 1 along dim 0 at
            each position.

        Raises:
            ValueError: if any label lies outside ``[0, num_classes)``.
        """
        if mask_tensor.numel() and (mask_tensor.min() < 0
                                    or mask_tensor.max() >= num_classes):
            raise ValueError(
                f'Mask labels must lie in [0, {num_classes}), got range '
                f'[{int(mask_tensor.min())}, {int(mask_tensor.max())}].')
        oh_shape = [num_classes] + list(mask_tensor.shape)
        mask_oh_tensor = torch.zeros(oh_shape, dtype=torch.int32)
        # scatter_ requires int64 indices.
        mask_oh_tensor.scatter_(0, mask_tensor.long().unsqueeze(0), 1)
        return mask_oh_tensor

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        info_d = self.image_info[idx]
        return image, mask, info_d
        

In [4]:
bcv_dir = pathlib.Path('/afs/crc.nd.edu/user/y/yzhang46/datasets/BCV-2015')
train_image_dir = bcv_dir / 'train' / 'img_nii'
train_mask_dir = bcv_dir / 'train' / 'label_nii'

images = sorted(glob.glob(str(train_image_dir) + '/*.nii.gz'))
masks = sorted(glob.glob(str(train_mask_dir) + '/*.nii.gz'))
# Images and masks are paired purely by sorted position, so both folders must
#  contain volumes and the same number of them.
assert images, f'No .nii.gz images found in {train_image_dir}'
assert len(images) == len(masks), \
    f'{len(images)} images but {len(masks)} masks'

start = time.time()
test_size = 6
num_classes = 14
test_set = MockDataset(images[:test_size], masks[:test_size], num_classes)
print(f'Took {time.time() - start} sec to load {test_size} test images.')

Took 19.44656014442444 sec to load 6 test images.


# Model Setup

In [5]:
# Placeholder network that returns its input unchanged; swap in the trained
#  segmentation model here.
model = nn.Identity()

# Inference

In [7]:
from inference import ChopBatchAggregate3d as CBA

num_classes = 14
dataset = test_set

# Simulation switch, used until `model` is a real network:
#   True  -> feed the gt mask crops back as logits (metrics should all be 1.0)
#   False -> use random logits
use_mask_as_logits = True

model = model.to(device)
model.eval()
with torch.no_grad():
    for i in range(len(dataset)):
        start_vol = time.time()
        print(f'⭐ Inference for example {i+1}..')
        image, mask, info = dataset[i]  # image: DxHxW, mask: CxDxHxW
                                        #  image: float32, mask: int32

        # Create Chop, batch, aggregator object. image and mask stay on the
        #  CPU; only the crops are moved to `device` below.
        # NOTE(review): the (20, 20, 20) and 20 arguments are interpreted by
        #  ChopBatchAggregate3d — confirm their meaning there.
        cba = CBA(image, (48, 160, 160), (20, 20, 20), 20, num_classes)

        # Run inference on batches of crops of image
        for bidx, batch in enumerate(cba):
            print(f'Batch {bidx + 1} / {len(cba)}')
            crops, locations = batch
            crops = crops.to(device)

            # NOTE: In real code replace the simulation below with
            #  logits = model(crops)
            # Both simulation modes need this shape, so compute it outside
            #  the branch (previously the random branch raised NameError).
            logits_shape = [crops.shape[0], num_classes] + list(crops.shape[2:])
            if use_mask_as_logits:
                #  1. Feed in the gt mask itself as logits (should return 100% acc)
                logits = torch.zeros(logits_shape, device=image.device)
                for n in range(locations.shape[0]):
                    # Each location row is assumed to be [lower xyz..., upper xyz...]
                    #  in voxel coordinates — consistent with the slicing below.
                    lower = locations[n, :3]
                    upper = locations[n, 3:]
                    logits[n] = mask[:, lower[0]:upper[0],
                                     lower[1]:upper[1],
                                     lower[2]:upper[2]]
            else:
                #  2. Take random floats
                logits = torch.randn(logits_shape, device=image.device)

            # NOTE: act='none' averages logits; to average probabilities
            #  instead, use act='softmax'
            cba.add_batch_predictions(logits.cpu(), locations, act='none')

        # Get final predictions, calculate metrics
        agg_predictions = cba.aggregate(ret='one_hot', cpu=True, numpy=False)

        print(f'Getting image metrics..')
        start_mets = time.time()
        # preds are CxDxHxW, but batch input takes 1xCxDxHxW
        mets = batch_metrics(agg_predictions.unsqueeze(0), mask.unsqueeze(0))
        print(mets['dice_mean'], mets['dice_class'])
        print(mets['jaccard_mean'], mets['jaccard_class'])
        print(f'Mets time: {time.time() - start_mets:.2f}')

        # Convert from 1hot to id and save prediction volume. Use one file per
        #  example (numbered like the progress messages) so later volumes no
        #  longer overwrite earlier ones.
        print(f'Saving image..')
        id_preds = agg_predictions.argmax(0).numpy().astype(np.uint16)
        sitk_pred = sitk.GetImageFromArray(id_preds, isVector=False)
        sitk_pred.SetOrigin(info['origin'])
        sitk_pred.SetSpacing(info['spacing'])
        sitk_pred.SetDirection(info['direction'])
        sitk.WriteImage(sitk_pred, f'prediction_{i + 1}.nii.gz')

        elapsed = time.time() - start_vol
        print(f'Completed inference for test volume {i+1} ({elapsed:.2f}).')

⭐ Inference for example 1..
Batch 1 / 4
Batch 2 / 4
Batch 3 / 4
Batch 4 / 4
Aggregate (divide):  0.5069081783294678
Aggregate:  2.1341826915740967
Getting image metrics..
1.0 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
1.0 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Mets time: 16.65
Saving image..
Completed inference for test volume 1 (40.24).
⭐ Inference for example 2..
Batch 1 / 4
Batch 2 / 4
Batch 3 / 4
Batch 4 / 4
Aggregate (divide):  0.4191896915435791
Aggregate:  1.9012672901153564
Getting image metrics..
1.0 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
1.0 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Mets time: 15.83
Saving image..
Completed inference for test volume 2 (39.73).
⭐ Inference for example 3..
Batch 1 / 6
Batch 2 / 6
Batch 3 / 6
Batch 4 / 6
Batch 5 / 6
Batch 6 / 6
Aggregate (divide):  0.5384798049926758
Aggregate:  3.03458571434021
Getting image metrics..
1.0 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
1.0 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Mets time: 22.13
Sa

KeyboardInterrupt: 