In [7]:
## General Imports from all libraries
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys, os
import pathlib
import glob
import time
import math, random
import pprint
import collections
import numbers, string

import yaml
from tqdm import tqdm

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import SimpleITK as sitk
from PIL import Image

import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

np.set_printoptions(precision=3)
curr_path = pathlib.Path(os.getcwd()).absolute()

# GPU selection.
# `cards` keeps the shape the former `cards = !echo $SGE_HGR_gpu_card` line
# produced: a one-element list holding the whitespace-normalised value of the
# SGE_HGR_gpu_card environment variable ('' when the variable is unset).
cards = [' '.join(os.environ.get('SGE_HGR_gpu_card', '').split())]

# A `!export ...` line runs in a throw-away subshell and never reaches this
# kernel process, so the variable is set on the kernel's own environment
# instead, before the first CUDA query below. Nothing is written when no card
# is listed, so an empty value cannot hide every GPU.
if cards[0]:
    os.environ['CUDA_VISIBLE_DEVICES'] = cards[0].replace(' ', ',')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Import custom files for this project: make the directory two levels above
# the working directory importable (needed by the project imports that follow).
dest_path = str(curr_path.parent.parent)
if dest_path not in sys.path:
    print('Adding current path:', dest_path)
    sys.path.append(dest_path)

from run_experiment import batch_metrics, get_model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Device: cuda


# (0909) Inference with Overlap

In [2]:
## Inference Constants
# Batch size used for notebook inference; the validation/test inference cell
# assigns this value to cfg.test.batch_size before calling test_metrics.
infer_batch_size = 2

In [14]:
## ------------- Get Config ------------
from configs import get_config
# cfg = get_config('./configs/ftbcv_train.yaml', merge_default=False)

## ------------- Get Checkpoint -------------
cp_file = './(0911bcv-3g)TUNE_nnunet3d_nesterov_cedc_s4_ftbcv_ep175_best-val-dice_mean-0.762.pth'
cp_path = cp_file
# NOTE(security): torch.load unpickles the file, and this checkpoint carries a
# non-tensor `config` object, so only load checkpoints from a trusted source.
checkpoint = torch.load(cp_path, map_location='cpu')

# The run configuration comes from the checkpoint and is then overridden for
# single-process inference inside the notebook.
cfg = checkpoint['config']
cfg.experiment.distributed = False
cfg.experiment.rank = 0
cfg.experiment.device = device
# The card list may be space- or comma-separated (the setup cell converts
# spaces to commas for CUDA_VISIBLE_DEVICES), so split on both. An empty list
# falls back to [''], which is what a plain split(',') returned for ''.
cfg.experiment.gpu_idxs = cards[0].replace(',', ' ').split() or ['']

## -------------- Get Model --------------
# Drop the first dotted component of every parameter name so the keys match
# the bare model (presumably a wrapper prefix such as 'module.' -- TODO confirm).
state_dict = checkpoint['state_dict']
new_state_dict = collections.OrderedDict(
    (k.partition('.')[2], v) for k, v in state_dict.items())
del state_dict

model = get_model(cfg)['model']
print(model.load_state_dict(new_state_dict))  # prints the key-matching report
model = model.to(device)

## ----------------- Get Data ----------------
from experiments.ftbcv.data_setup import get_data_components
data_d = get_data_components(cfg)
val_set = data_d['val_set']
test_set = data_d['test_set']



💠 nnUNet3D model initiated with n_classes=14, 
   n_input=1, 
   params=16,478,616, trainable_params=16,478,616.
   (Model) Successfully initialized weights via kaiming.
<All keys matched successfully>
Collecting data samples:
    Took 36.01s for sample creation.
    Took 39.08s for sample creation.
    Took 39.18s for sample creation.
    Took 40.91s for sample creation.
    Took 41.24s for sample creation.
    Took 41.87s for sample creation.
    Took 41.94s for sample creation.
    Took 42.05s for sample creation.
    Took 43.38s for sample creation.
    Took 43.42s for sample creation.
    Took 43.70s for sample creation.
    Took 43.89s for sample creation.
    Took 44.08s for sample creation.
    Took 44.48s for sample creation.
    Took 45.32s for sample creation.
    Took 46.50s for sample creation.
    Took 48.79s for sample creation.
    Took 50.18s for sample creation.
    Took 53.15s for sample creation.
    Took 55.64s for sample creation.
    Took 58.30s for sample creati

In [15]:
## Run Val Inference
from run_experiment import test_metrics

# Show the configured test batch size next to the notebook override, then
# apply the override.
print('batch_size:', cfg.test.batch_size, '->', infer_batch_size)
cfg.test.batch_size = infer_batch_size
metrics_queue = torch.multiprocessing.Queue()

# Evaluate the validation split first and the test split second, using the
# same arguments for both.
for split_name, split_set in (('val', val_set), ('test', test_set)):
    split_result = test_metrics(cfg, model, split_set, 0, metrics_queue,
                                len(split_set), name=split_name,
                                overlap_perc=0.2)
    print(split_result)

batch_size: 2 -> 2
 🖼️  Inference for example 1.
     Getting predictions for 48 batches.
Completed inference for vol 1 (26.32 sec).

 🖼️  Inference for example 2.
     Getting predictions for 48 batches.
Completed inference for vol 2 (27.18 sec).

 🖼️  Inference for example 3.
     Getting predictions for 48 batches.
Completed inference for vol 3 (27.35 sec).

 🖼️  Inference for example 4.
     Getting predictions for 48 batches.
Completed inference for vol 4 (27.45 sec).

 🖼️  Inference for example 5.
     Getting predictions for 48 batches.
Completed inference for vol 5 (26.76 sec).

 🖼️  Inference for example 6.
     Getting predictions for 56 batches.
Completed inference for vol 6 (31.55 sec).

(Val) Sample 2, idx=1 
       Dice: 0.83 
        [0.954 0.887 0.919 0.783 0.765 0.96  0.909 0.88  0.876 0.79  0.747 0.715
 0.639] 
       Jaccard: 0.72 
        [0.913 0.797 0.851 0.643 0.619 0.923 0.832 0.785 0.78  0.653 0.596 0.556
 0.47 ]
(Val) Sample 8, idx=7 
       Dice: 0.71 
      

# (0901) Pengfei Inference Demo

## Data Collection and Components

In [3]:
def get_imgs_masks_info(args):
    """Load one image/mask pair as tensors together with image geometry.

    Args:
        args: ``(image_file, mask_file)`` or
            ``(image_file, mask_file, num_classes)``. When the class count is
            omitted, the module-level ``num_classes`` is used, which is what
            this function always relied on before.

    Returns:
        tuple: ``(image_tensor, mask_oh_tensor, info)`` where
            ``image_tensor`` is the clamped and normalized image as float32,
            ``mask_oh_tensor`` is an int32 one-hot mask with the class axis
            first (shape ``[n_classes] + list(image_tensor.shape)``), and
            ``info`` is a dict with the image 'origin', 'spacing' and
            'direction'.
    """
    image_f, mask_f, *extra = args
    n_classes = extra[0] if extra else num_classes

    # 1. Read image and preprocess (clamp + normalize)
    sitk_image = sitk.ReadImage(image_f, sitk.sitkInt16)
    sitk_image = sitk.Clamp(sitk_image, sitk.sitkInt16, -1024, 325)
    sitk_image = sitk.NormalizeImageFilter().Execute(sitk_image)
    image_tensor = torch.from_numpy(sitk.GetArrayFromImage(sitk_image))
    image_tensor = image_tensor.float()

    # 2. Read mask and convert it to one-hot
    sitk_mask = sitk.ReadImage(mask_f, sitk.sitkInt64)
    mask_tensor = torch.from_numpy(sitk.GetArrayFromImage(sitk_mask))

    shape = image_tensor.shape
    oh_shape = [n_classes] + list(shape)
    mask_oh_tensor = torch.zeros(oh_shape, dtype=torch.int32)
    # Write a 1 at each voxel's label index along the class axis.
    mask_oh_tensor.scatter_(0, mask_tensor.unsqueeze(0), 1)

    info = {
        'origin': sitk_image.GetOrigin(),
        'spacing': sitk_image.GetSpacing(),
        'direction': sitk_image.GetDirection()
    }
    return image_tensor, mask_oh_tensor, info


class MockDataset(torch.utils.data.Dataset):
    """In-memory dataset of preprocessed ``(image, one-hot mask, info)`` items.

    Every image/mask pair is read and preprocessed once in ``__init__`` via
    ``get_imgs_masks_info`` and kept in memory; ``__getitem__`` only indexes
    the stored lists.
    """

    def __init__(self, image_files, mask_files, num_classes):
        self.num_classes = num_classes
        self.image_files = image_files  # list of image files
        self.mask_files = mask_files  # list of mask files

        # Get image tensors and store permanently. Files are paired by
        # position; zip stops at the shorter of the two lists.
        self.images, self.masks, self.image_info = [], [], []
        for image_f, mask_f in zip(image_files, mask_files):
            image, mask, info = get_imgs_masks_info(
                (image_f, mask_f, num_classes))
            self.images.append(image)
            self.masks.append(mask)
            self.image_info.append(info)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        info_d = self.image_info[idx]
        return image, mask, info_d
        

In [4]:
# Locations of the BCV-2015 training images and label volumes.
bcv_dir = pathlib.Path('/afs/crc.nd.edu/user/y/yzhang46/datasets/BCV-2015')
train_image_dir = bcv_dir.joinpath('train', 'img_nii')
train_mask_dir = bcv_dir.joinpath('train', 'label_nii')

# Sorted file lists; images and masks are later paired by position.
images = sorted(glob.glob(f'{train_image_dir}/*.nii.gz'))
masks = sorted(glob.glob(f'{train_mask_dir}/*.nii.gz'))

# Load only the first `test_size` volumes into memory and time the load.
test_size = 6
num_classes = 14
start = time.time()
test_set = MockDataset(
    images[:test_size],
    masks[:test_size],
    num_classes,
)
print(f'Took {time.time() - start} sec to load {test_size} test images.')

Took 19.28532838821411 sec to load 6 test images.


## Model Setup

In [5]:
from lib.nets.volumetric.resunet3d import UNet3D
from experiments.finetune_bcv.ftbcv_unet3d import UNet3D as genesis_unet3d

# Checkpoint to restore, deserialized onto the CPU first.
file = '(0831)bcv-scratch_adam_dice_s3_finetune_bcv_ep499_last.pth'
checkpoint = torch.load(file, map_location='cpu')

# Alternative architecture kept for reference:
# model = UNet3D(1, 14, final_sigmoid=False, is_segmentation=False)
model = genesis_unet3d(
    n_input=1,
    n_class=14,
    act='relu',
)
# Last expression of the cell, so the key-matching report is displayed.
model.load_state_dict(checkpoint['state_dict'])

💠 UNet3D-PGL model initiated with n_classes=14, 
   n_input=1, activation=relu, 
   params=19,074,510, trainable_params=19,074,510.


<All keys matched successfully>

## Inference

In [8]:
from data.transforms.crops.inference import ChopBatchAggregate3d as CBA

num_classes = 14
# Use the GPU when present, otherwise fall back to the CPU instead of
# failing on the first .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = test_set

# NOTE(review): the output saved with this cell shows a Dice of 0 for every
# class except index 0; check that the checkpoint and preprocessing match
# before trusting these metrics.
model = model.to(device)
model.eval()
with torch.no_grad():
    for i in range(len(dataset)):
        start_vol = time.time()
        print(f'⭐ Inference for example {i+1}..')
        image, mask, info = dataset[i]  # image: DxHxW, mask: CxDxHxW
                                        #  image: float32, mask: int32

        # Create Chop, batch, aggregator object. The volume stays on the CPU;
        # only the crops are moved to `device` below.
        # Positional args are presumably crop size, overlap and batch size --
        # TODO confirm against ChopBatchAggregate3d.
        cba = CBA(image, (48, 160, 160), (0, 0, 0), 4, num_classes)

        # Run inference on batches of crops of image
        for bidx, batch in enumerate(cba):
            print(f'Batch {bidx + 1} / {len(cba)}')
            crops, locations = batch
            crops = crops.to(device)

            logits = model(crops)['out']

            cba.add_batch_predictions(logits.cpu(), locations, act='none')
                # NOTE: in this case, we are averaging logits, if you want
                #  to average probabilities instead, use act='softmax'

        # Get final predictions, calculate metrics
        agg_predictions = cba.aggregate(ret='one_hot', cpu=True, numpy=False)

        print('Getting image metrics..')
        start = time.time()
        mets = batch_metrics(agg_predictions.unsqueeze(0), mask.unsqueeze(0))
            # preds are CxDxHxW, but batch input takes 1xCxDxHxW
        print(mets['dice_mean'], mets['dice_class'])
        print(mets['jaccard_mean'], mets['jaccard_class'])
        print(f'Mets time: {time.time() - start:.2f}')

        # Convert from 1hot to id and save prediction volume, restoring the
        # geometry recorded when the image was loaded.
        print('Saving image..')
        id_preds = agg_predictions.argmax(0).numpy().astype(np.uint16)
        sitk_pred = sitk.GetImageFromArray(id_preds, isVector=False)
        sitk_pred.SetOrigin(info['origin'])
        sitk_pred.SetSpacing(info['spacing'])
        sitk_pred.SetDirection(info['direction'])
        # NOTE(review): the file name is constant, so each volume would
        # overwrite the previous one if the `break` below were removed.
        sitk.WriteImage(sitk_pred, 'prediction.nii.gz')

        elapsed = time.time() - start_vol
        print(f'Completed inference for test volume {i+1} ({elapsed:.2f}).')

        # Demo: stop after the first volume.
        break

⭐ Inference for example 1..
Batch 1 / 16
Batch 2 / 16
Batch 3 / 16
Batch 4 / 16
Batch 5 / 16
Batch 6 / 16
Batch 7 / 16
Batch 8 / 16
Batch 9 / 16
Batch 10 / 16
Batch 11 / 16
Batch 12 / 16
Batch 13 / 16
Batch 14 / 16
Batch 15 / 16
Batch 16 / 16
Aggregate (divide):  3.136298656463623
Aggregate:  14.577434301376343
Getting image metrics..
0.06910467892885208 [0.967 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.   ]
0.0669272169470787 [0.937 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.   ]
Mets time: 18.30
Saving image..
Completed inference for test volume 1 (82.26).


In [17]:
# Debug check: distinct argmax indices over dim 1 of `logits`, which still
# holds the last crop batch produced by the inference loop (presumably dim 1
# is the class axis -- TODO confirm). The recorded result contains only 0.
logits.argmax(1).unique()

tensor([0], device='cuda:0')