In [14]:
import numpy as np
np.object = object    

In [15]:
%%writefile roi_extract.py

# Separated in .py file instead of a notebook cell for easier multiprocessing (e.g spawn)
import os
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
import sys
import cv2
import numpy as np
import torch
import torchvision
sys.path.append('/kaggle/tmp/libs/')
from torch2trt import TRTModule
from torch.nn import functional as F

_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
_TORCH11X = (_TORCH_VER >= [1, 10])


def meshgrid(*tensors):
    if _TORCH11X:
        return torch.meshgrid(*tensors, indexing="ij")
    else:
        return torch.meshgrid(*tensors)


def extract_roi_otsu(img, gkernel=(5, 5)):
    """WARNING: this function modify input image inplace."""
    ori_h, ori_w = img.shape[:2]
    # clip percentile: implant, white lines
    upper = np.percentile(img, 95)
    img[img > upper] = np.min(img)
    # Gaussian filtering to reduce noise (optional)
    if gkernel is not None:
        img = cv2.GaussianBlur(img, gkernel, 0)
    _, img_bin = cv2.threshold(img, 0, 255,
                               cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # dilation to improve contours connectivity
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3), (-1, -1))
    img_bin = cv2.dilate(img_bin, element)
    cnts, _ = cv2.findContours(img_bin, cv2.RETR_EXTERNAL,
                               cv2.CHAIN_APPROX_SIMPLE)
    if len(cnts) == 0:
        return None, None, None
    areas = np.array([cv2.contourArea(cnt) for cnt in cnts])
    select_idx = np.argmax(areas)
    cnt = cnts[select_idx]
    area_pct = areas[select_idx] / (img.shape[0] * img.shape[1])
    x0, y0, w, h = cv2.boundingRect(cnt)
    # min-max for safety only
    # x0, y0, x1, y1
    x1 = min(max(int(x0 + w), 0), ori_w)
    y1 = min(max(int(y0 + h), 0), ori_h)
    x0 = min(max(int(x0), 0), ori_w)
    y0 = min(max(int(y0), 0), ori_h)
    return [x0, y0, x1, y1], area_pct, None


class RoiExtractor:

    def __init__(self,
                 engine_path,
                 input_size,
                 num_classes,
                 conf_thres=0.5,
                 nms_thres=0.9,
                 class_agnostic=False,
                 area_pct_thres=0.04,
                 hw=None,
                 strides=None,
                 exp=None):
        self.input_size = input_size
        self.input_h, self.input_w = input_size
        self.num_classes = num_classes
        self.conf_thres = conf_thres
        self.nms_thres = nms_thres
        self.class_agnostic = class_agnostic
        self.area_pct_thres = area_pct_thres

        model = TRTModule()
        model.load_state_dict(torch.load(engine_path))
        self.model = model
        if hw is None or strides is None:
            assert exp is not None
            self._set_meta(exp)
        else:
            self.hw = hw
            self.strides = strides

    def _set_meta(self, exp):
        assert exp is not None
        print("Start probing model metadata..")
        # dummy infer
        torch_model = exp.get_model().cuda().eval()
        _dummy = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
        torch_model(_dummy)
        # set attributes
        self.hw = torch_model.head.hw
        self.strides = torch_model.head.strides
        # cleanup
        del torch_model, _dummy
        import gc
        gc.collect()
        torch.cuda.empty_cache()
        print('Done probbing model metadata..')

    def decode_outputs(self, outputs):
        dtype = outputs.type()
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(torch.full((*shape, 1), stride))

        grids = torch.cat(grids, dim=1).type(dtype)
        strides = torch.cat(strides, dim=1).type(dtype)

        outputs = torch.cat(
            [(outputs[..., 0:2] + grids) * strides,
             torch.exp(outputs[..., 2:4]) * strides, outputs[..., 4:]],
            dim=-1)
        return outputs

    def post_process(self,
                     pred,
                     conf_thres=0.5,
                     nms_thres=0.9,
                     class_agnostic=False):
        box_corner = pred.new(pred.shape)
        box_corner[:, :, 0] = pred[:, :, 0] - pred[:, :, 2] / 2
        box_corner[:, :, 1] = pred[:, :, 1] - pred[:, :, 3] / 2
        box_corner[:, :, 2] = pred[:, :, 0] + pred[:, :, 2] / 2
        box_corner[:, :, 3] = pred[:, :, 1] + pred[:, :, 3] / 2
        pred[:, :, :4] = box_corner[:, :, :4]

        output = [None for _ in range(len(pred))]
        for i, image_pred in enumerate(pred):

            # If none are remaining => process next image
            if not image_pred.size(0):
                continue
            # Get score and class with highest confidence
            class_conf, class_pred = torch.max(image_pred[:, 5:5 +
                                                          self.num_classes],
                                               1,
                                               keepdim=True)

            conf_mask = (image_pred[:, 4] * class_conf.squeeze() >=
                         conf_thres).squeeze()
            # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
            detections = torch.cat(
                (image_pred[:, :5], class_conf, class_pred.float()), 1)
            detections = detections[conf_mask]
            if not detections.size(0):
                continue

            if class_agnostic:
                nms_out_index = torchvision.ops.nms(
                    detections[:, :4],
                    detections[:, 4] * detections[:, 5],
                    nms_thres,
                )
            else:
                nms_out_index = torchvision.ops.batched_nms(
                    detections[:, :4],
                    detections[:, 4] * detections[:, 5],
                    detections[:, 6],
                    nms_thres,
                )
            detections = detections[nms_out_index]
            if output[i] is None:
                output[i] = detections
            else:
                output[i] = torch.cat((output[i], detections))
        return output

    def preprocess_single(self, img: torch.Tensor):
        ori_h = img.size(0)
        ori_w = img.size(1)
        ratio = min(self.input_h / ori_h, self.input_w / ori_w)
        # resize
        resized_img = F.interpolate(img.view(1, 1, ori_h, ori_w),
                                    mode="bilinear",
                                    scale_factor=ratio,
                                    recompute_scale_factor=True)[0, 0]
        # padding
        padded_img = torch.full((self.input_h, self.input_w),
                                114,
                                dtype=resized_img.dtype,
                                device='cuda')
        padded_img[:resized_img.size(0), :resized_img.size(1)] = resized_img
        # 1 channel --> 3 channels
        padded_img = padded_img.unsqueeze(-1).expand(-1, -1, 3)
        # HWC --> CHW
        padded_img = padded_img.permute(2, 0, 1)
        padded_img = padded_img.float()
        return padded_img, resized_img, ratio, ori_h, ori_w

    def detect_single(self, img):
        padded_img, resized_img, ratio, ori_h, ori_w = self.preprocess_single(
            img)
        padded_img = padded_img.unsqueeze(0)
        output = self.model(padded_img)
        output = self.decode_outputs(output)
        # x0, y0, x1, y1, box_conf, cls_conf, cls_id
        output = self.post_process(output, self.conf_thres, self.nms_thres)[0]
        if output is not None:
            output[:, :4] = output[:, :4] / ratio
            # re-compute: conf = box_conf * cls_conf
            output[:, 4] = output[:, 4] * output[:, 5]
            # select box with highest confident
            output = output[output[:, 4].argmax()]
            x0 = min(max(int(output[0]), 0), ori_w)
            y0 = min(max(int(output[1]), 0), ori_h)
            x1 = min(max(int(output[2]), 0), ori_w)
            y1 = min(max(int(output[3]), 0), ori_h)
            area_pct = (x1 - x0) * (y1 - y0) / (ori_h * ori_w)
            if area_pct >= self.area_pct_thres:
                # xyxy, area_pct, conf
                return [x0, y0, x1, y1], area_pct, output[4]

        # if YOLOX fail, try Otsu thresholding + find contours
        xyxy, area_pct, _ = extract_roi_otsu(
            resized_img.to(torch.uint8).cpu().numpy())
        # if both fail, use full frame
        if xyxy is not None:
            if area_pct >= self.area_pct_thres:
                print('ROI detection: using Otsu.')
                x0, y0, x1, y1 = xyxy
                x0 = min(max(int(x0 / ratio), 0), ori_w)
                y0 = min(max(int(y0 / ratio), 0), ori_h)
                x1 = min(max(int(x1 / ratio), 0), ori_w)
                y1 = min(max(int(y1 / ratio), 0), ori_h)
                return [x0, y0, x1, y1], area_pct, None
        print('ROI detection: both fail.')
        return None, area_pct, None

Overwriting roi_extract.py


In [16]:

import torch
from timm.data import resolve_data_config
from timm.models import create_model
from torch import nn


class KFoldEnsembleModel(nn.Module):

    def __init__(self, model_info, ckpt_paths):
        super(KFoldEnsembleModel, self).__init__()
        fmodels = []
        for i, ckpt_path in enumerate(ckpt_paths):
            print(f'Loading model from {ckpt_path}')
            fmodel = create_model(
                model_info['model_name'],
                num_classes=model_info['num_classes'],
                in_chans=model_info['in_chans'],
                pretrained=False,
                checkpoint_path=ckpt_path,
                global_pool=model_info['global_pool'],
            ).eval()
            data_config = resolve_data_config({}, model=fmodel)
            print('Data config:', data_config)
            mean = np.array(data_config['mean']) * 255
            std = np.array(data_config['std']) * 255
            print(f'mean={mean}, std={std}')
            fmodels.append(fmodel)
        self.fmodels = nn.ModuleList(fmodels)

        self.register_buffer('mean',
                             torch.FloatTensor(mean).reshape(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor(std).reshape(1, 3, 1, 1))

    def forward(self, x):
        #         x = x.sub(self.mean).div(self.std)
        x = (x - self.mean) / self.std
        probs = []
        for fmodel in self.fmodels:
            logits = fmodel(x)
            #             prob = logits.softmax(dim=1)[:, 1]
            prob = logits.sigmoid()[:, 0]
            probs.append(prob)
        probs = torch.stack(probs, dim=1)
        return probs


In [17]:
import roi_extract

# global vars
J2K_SUID = '1.2.840.10008.1.2.4.90'
J2K_HEADER = b"\x00\x00\x00\x0C"
JLL_SUID = '1.2.840.10008.1.2.4.70'
JLL_HEADER = b"\xff\xd8\xff\xe0"
SUID2HEADER = {J2K_SUID: J2K_HEADER, JLL_SUID: JLL_HEADER}
VOILUT_FUNCS_MAP = {'LINEAR': 0, 'LINEAR_EXACT': 1, 'SIGMOID': 2}
VOILUT_FUNCS_INV_MAP = {v: k for k, v in VOILUT_FUNCS_MAP.items()}

In [18]:
BATCH_SIZE = 2
# binarization threshold for classification
THRES = 0.31
AUTO_THRES = False
AUTO_THRES_PERCENTILE = 0.97935

# classification model
USE_TRT = True


# roi detection
ROI_YOLOX_INPUT_SIZE = [416, 416]
ROI_YOLOX_CONF_THRES = 0.5
ROI_YOLOX_NMS_THRES = 0.9
ROI_YOLOX_HW = [(52, 52), (26, 26), (13, 13)]
ROI_YOLOX_STRIDES = [8, 16, 32]
ROI_AREA_PCT_THRES = 0.04

# model
MODEL_INPUT_SIZE = [2048, 1024]

MODE = 'KAGGLE-TEST'
assert MODE in ['LOCAL-VAL', 'KAGGLE-VAL', 'KAGGLE-TEST']

# settings corresponding to each mode
if MODE == 'KAGGLE-VAL':
    TRT_MODEL_PATH = '/kaggle/input/rsna-breast-cancer-detection-best-ckpts/best_convnext_ensemble_batch2_fp32_torch2trt.engine'
    TORCH_MODEL_CKPT_PATHS = [
        f'/kaggle/input/rsna-breast-cancer-detection-best-ckpts/best_convnext_fold_{i}.pth.tar'
        for i in range(4)
    ]
    ROI_YOLOX_ENGINE_PATH = '/kaggle/input/rsna-breast-cancer-detection-best-ckpts/yolox_nano_416_roi_trt_p100.pth'
    CSV_PATH = '/kaggle/input/rsna-breast-cancer-detection-best-ckpts/_val_fold_0.csv'
    DCM_ROOT_DIR = '/kaggle/input/rsna-breast-cancer-detection/train_images'
    SAVE_IMG_ROOT_DIR = '/kaggle/tmp/pngs'
    N_CHUNKS = 2
    N_CPUS = 2
    RM_DONE_CHUNK = False
elif MODE == 'KAGGLE-TEST':
    TRT_MODEL_PATH = '/kaggle/input/rsna-breast-cancer-detection-best-ckpts/best_convnext_ensemble_batch2_fp32_torch2trt.engine'
    TORCH_MODEL_CKPT_PATHS = [
        f'/kaggle/input/rsna-breast-cancer-detection-best-ckpts/best_convnext_fold_{i}.pth.tar'
        for i in range(4)
    ]
    ROI_YOLOX_ENGINE_PATH = '/kaggle/input/rsna-breast-cancer-detection-best-ckpts/yolox_nano_416_roi_trt_p100.pth'
    CSV_PATH = '/kaggle/input/rsna-breast-cancer-detection/test.csv'
    DCM_ROOT_DIR = '/kaggle/input/rsna-breast-cancer-detection/test_images'
    SAVE_IMG_ROOT_DIR = '/kaggle/tmp/pngs'
    N_CHUNKS = 2
    N_CPUS = 2
    RM_DONE_CHUNK = True
elif MODE == 'LOCAL-VAL':
    TRT_MODEL_PATH = './assets/best_convnext_ensemble_batch2_fp32_torch2trt.engine'
    TORCH_MODEL_CKPT_PATHS = [
        f'./assets/best_convnext_fold_{i}.pth.tar'
        for i in range(4)
    ]
    ROI_YOLOX_ENGINE_PATH = '../roi_det/YOLOX/YOLOX_outputs/yolox_nano_bre_416/model_trt.pth'
    CSV_PATH = '../../datasets/cv/v1/val_fold_0.csv'
    DCM_ROOT_DIR = '../../datasets/train_images/'
    SAVE_IMG_ROOT_DIR = './temp_save'
    N_CHUNKS = 2
    N_CPUS = 2
    RM_DONE_CHUNK = False

In [19]:
class PydicomMetadata:

    def __init__(self, ds):
        if "WindowWidth" not in ds or "WindowCenter" not in ds:
            self.window_widths = []
            self.window_centers = []
        else:
            ww = ds['WindowWidth']
            wc = ds['WindowCenter']
            self.window_widths = [float(e) for e in ww
                                  ] if ww.VM > 1 else [float(ww.value)]

            self.window_centers = [float(e) for e in wc
                                   ] if wc.VM > 1 else [float(wc.value)]

        # if nan --> LINEAR
        self.voilut_func = str(ds.get('VOILUTFunction', 'LINEAR')).upper()
        self.invert = (ds.PhotometricInterpretation == 'MONOCHROME1')
        assert len(self.window_widths) == len(self.window_centers)


class DicomsdlMetadata:

    def __init__(self, ds):
        self.window_widths = ds.WindowWidth
        self.window_centers = ds.WindowCenter
        if self.window_widths is None or self.window_centers is None:
            self.window_widths = []
            self.window_centers = []
        else:
            try:
                if not isinstance(self.window_widths, list):
                    self.window_widths = [self.window_widths]
                self.window_widths = [float(e) for e in self.window_widths]
                if not isinstance(self.window_centers, list):
                    self.window_centers = [self.window_centers]
                self.window_centers = [float(e) for e in self.window_centers]
            except:
                self.window_widths = []
                self.window_centers = []

        # if nan --> LINEAR
        self.voilut_func = ds.VOILUTFunction
        if self.voilut_func is None:
            self.voilut_func = 'LINEAR'
        else:
            self.voilut_func = str(self.voilut_func).upper()
        self.invert = (ds.PhotometricInterpretation == 'MONOCHROME1')
        assert len(self.window_widths) == len(self.window_centers)

In [20]:
# slow
# from pydicom's source
def _apply_windowing_np_v1(arr,
                           window_width=None,
                           window_center=None,
                           voi_func='LINEAR',
                           y_min=0,
                           y_max=255):
    assert window_width > 0
    y_range = y_max - y_min
    # float64 needed (default) or just float32 ?
    # arr = arr.astype(np.float64)
    arr = arr.astype(np.float32)

    if voi_func in ['LINEAR', 'LINEAR_EXACT']:
        # PS3.3 C.11.2.1.2.1 and C.11.2.1.3.2
        if voi_func == 'LINEAR':
            if window_width < 1:
                raise ValueError(
                    "The (0028,1051) Window Width must be greater than or "
                    "equal to 1 for a 'LINEAR' windowing operation")
            window_center -= 0.5
            window_width -= 1
        below = arr <= (window_center - window_width / 2)
        above = arr > (window_center + window_width / 2)
        between = np.logical_and(~below, ~above)

        arr[below] = y_min
        arr[above] = y_max
        if between.any():
            arr[between] = ((
                (arr[between] - window_center) / window_width + 0.5) * y_range
                            + y_min)
    elif voi_func == 'SIGMOID':
        arr = y_range / (1 +
                         np.exp(-4 *
                                (arr - window_center) / window_width)) + y_min
    else:
        raise ValueError(
            f"Unsupported (0028,1056) VOI LUT Function value '{voi_func}'")
    return arr


def _apply_windowing_np_v2(arr,
                           window_width=None,
                           window_center=None,
                           voi_func='LINEAR',
                           y_min=0,
                           y_max=255):
    assert window_width > 0
    y_range = y_max - y_min
    # float64 needed (default) or just float32 ?
    # arr = arr.astype(np.float64)
    arr = arr.astype(np.float32)

    if voi_func == 'LINEAR' or voi_func == 'LINEAR_EXACT':
        # PS3.3 C.11.2.1.2.1 and C.11.2.1.3.2
        if voi_func == 'LINEAR':
            if window_width < 1:
                raise ValueError(
                    "The (0028,1051) Window Width must be greater than or "
                    "equal to 1 for a 'LINEAR' windowing operation")
            window_center -= 0.5
            window_width -= 1

        # simple trick to improve speed
        s = y_range / window_width
        b = (-window_center / window_width + 0.5) * y_range + y_min
        arr = arr * s + b
        arr = np.clip(arr, y_min, y_max)

    elif voi_func == 'SIGMOID':
        # simple trick to improve speed
        s = -4 / window_width
        arr = y_range / (1 + np.exp((arr - window_center) * s)) + y_min
    else:
        raise ValueError(
            f"Unsupported (0028,1056) VOI LUT Function value '{voi_func}'")
    return arr


def _apply_windowing_torch(arr,
                           window_width=None,
                           window_center=None,
                           voi_func='LINEAR',
                           y_min=0,
                           y_max=255):
    assert window_width > 0
    y_range = y_max - y_min
    # float64 needed (default) or just float32 ?
    # arr = arr.double()
    arr = arr.float()

    if voi_func == 'LINEAR' or voi_func == 'LINEAR_EXACT':
        # PS3.3 C.11.2.1.2.1 and C.11.2.1.3.2
        if voi_func == 'LINEAR':
            if window_width < 1:
                raise ValueError(
                    "The (0028,1051) Window Width must be greater than or "
                    "equal to 1 for a 'LINEAR' windowing operation")
            window_center -= 0.5
            window_width -= 1

        # simple trick to improve speed
        s = y_range / window_width
        b = (-window_center / window_width + 0.5) * y_range + y_min
        arr = arr * s + b
        arr = torch.clamp(arr, y_min, y_max)

    elif voi_func == 'SIGMOID':
        # simple trick to improve speed
        s = -4 / window_width
        arr = y_range / (1 + torch.exp((arr - window_center) * s)) + y_min
    else:
        raise ValueError(
            f"Unsupported (0028,1056) VOI LUT Function value '{voi_func}'")
    return arr


def apply_windowing(arr,
                    window_width=None,
                    window_center=None,
                    voi_func='LINEAR',
                    y_min=0,
                    y_max=255,
                    backend='np_v2'):
    if backend == 'torch':
        if isinstance(arr, torch.Tensor):
            pass
        elif isinstance(arr, np.ndarray):
            if arr.dtype == np.uint16:
                arr = torch.from_numpy(arr, torch.int16)
            else:
                arr = torch.from_numpy(arr)

    if backend == 'np_v1':
        windowing_func = _apply_windowing_np_v1
    elif backend == 'np_v2':
        windowing_func = _apply_windowing_np_v2
    elif backend == 'torch':
        windowing_func = _apply_windowing_torch
    else:
        raise ValueError(
            f'Invalid backend {backend}, must be one of ["np", "np_v2", "torch"]'
        )

    arr = windowing_func(arr,
                         window_width=window_width,
                         window_center=window_center,
                         voi_func=voi_func,
                         y_min=y_min,
                         y_max=y_max)
    return arr

In [21]:
def min_max_scale(img):
    maxv = img.max()
    minv = img.min()
    if maxv > minv:
        return (img - minv) / (maxv - minv)
    else:
        return img - minv  # ==0


#@TODO: percentile on both min-max?
# this version is not correctly implemented, but used in the winning submission
def percentile_min_max_scale(img, pct=99):
    if isinstance(img, np.ndarray):
        maxv = np.percentile(img, pct) - 1
        minv = img.min()
        assert maxv >= minv
        if maxv > minv:
            ret = (img - minv) / (maxv - minv)
        else:
            ret = img - minv  # ==0
        ret = np.clip(ret, 0, 1)
    elif isinstance(img, torch.Tensor):
        maxv = torch.quantile(img, pct / 100) - 1
        minv = img.min()
        assert maxv >= minv
        if maxv > minv:
            ret = (img - minv) / (maxv - minv)
        else:
            ret = img - minv  # ==0
        ret = torch.clamp(ret, 0, 1)
    else:
        raise ValueError(
            'Invalid img type, should be numpy array or torch.Tensor')
    return ret


def resize_and_pad(img, input_size=MODEL_INPUT_SIZE):
    input_h, input_w = input_size
    ori_h, ori_w = img.shape[:2]
    ratio = min(input_h / ori_h, input_w / ori_w)
    # resize
    img = F.interpolate(img.view(1, 1, ori_h, ori_w),
                        mode="bilinear",
                        scale_factor=ratio,
                        recompute_scale_factor=True)[0, 0]
    # padding
    padded_img = torch.zeros((input_h, input_w),
                             dtype=img.dtype,
                             device='cuda')
    cur_h, cur_w = img.shape
    y_start = (input_h - cur_h) // 2
    x_start = (input_w - cur_w) // 2
    padded_img[y_start:y_start + cur_h, x_start:x_start + cur_w] = img
    padded_img = padded_img.unsqueeze(-1).expand(-1, -1, 3)
    return padded_img


def save_img_to_file(save_path, img, backend='cv2'):
    file_ext = os.path.basename(save_path).split('.')[-1]
    if backend == 'cv2':
        if img.dtype == np.uint16:
            # https://docs.opencv.org/3.4/d4/da8/group__imgcodecs.html#gabbc7ef1aa2edfaa87772f1202d67e0ce
            assert file_ext in ['png', 'jp2', 'tiff', 'tif']
            cv2.imwrite(save_path, img)
        elif img.dtype == np.uint8:
            cv2.imwrite(save_path, img)
        else:
            raise ValueError(
                '`cv2` backend only support uint8 or uint16 images.')
    elif backend == 'np':
        assert file_ext == 'npy'
        np.save(save_path, img)
    else:
        raise ValueError(f'Unsupported backend `{backend}`.')


def load_img_from_file(img_path, backend='cv2'):
    if backend == 'cv2':
        return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH)
    elif backend == 'np':
        return np.load(img_path)
    else:
        raise ValueError()
        

def make_uid_transfer_dict(df, dcm_root_dir):
    machine_id_to_transfer = {}
    machine_id = df.machine_id.unique()
    for i in machine_id:
        row = df[df.machine_id == i].iloc[0]
        sample_dcm_path = os.path.join(dcm_root_dir, str(row.patient_id),
                                       f'{row.image_id}.dcm')
        dicom = pydicom.dcmread(sample_dcm_path)
        machine_id_to_transfer[i] = dicom.file_meta.TransferSyntaxUID
    return machine_id_to_transfer

In [23]:
# DALI patch for INT16 support
################################################################################
import types
import dali
DALI2TORCH_TYPES = {
    types.DALIDataType.FLOAT: torch.float32,
    types.DALIDataType.FLOAT64: torch.float64,
    types.DALIDataType.FLOAT16: torch.float16,
    types.DALIDataType.UINT8: torch.uint8,
    types.DALIDataType.INT8: torch.int8,
    types.DALIDataType.UINT16: torch.int16,
    types.DALIDataType.INT16: torch.int16,
    types.DALIDataType.INT32: torch.int32,
    types.DALIDataType.INT64: torch.int64
}

TORCH_DTYPES = {
    'uint8': torch.uint8,
    'float16': torch.float16,
    'float32': torch.float32,
    'float64': torch.float64,
}


# @TODO: dangerous to copy from UINT16 to INT16 (memory layout?)
# little/big endian ?
# @TODO: faster reuse memory without copying: https://github.com/NVIDIA/DALI/issues/4126
def feed_ndarray(dali_tensor, arr, cuda_stream=None):
    """
    Copy contents of DALI tensor to PyTorch's Tensor.

    Parameters
    ----------
    `dali_tensor` : nvidia.dali.backend.TensorCPU or nvidia.dali.backend.TensorGPU
                    Tensor from which to copy
    `arr` : torch.Tensor
            Destination of the copy
    `cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
                    CUDA stream to be used for the copy
                    (if not provided, an internal user stream will be selected)
                    In most cases, using pytorch's current stream is expected (for example,
                    if we are copying to a tensor allocated with torch.zeros(...))
    """
    dali_type = DALI2TORCH_TYPES[dali_tensor.dtype]

    assert dali_type == arr.dtype, (
        "The element type of DALI Tensor/TensorList"
        " doesn't match the element type of the target PyTorch Tensor: "
        "{} vs {}".format(dali_type, arr.dtype))
    assert dali_tensor.shape() == list(arr.size()), \
        ("Shapes do not match: DALI tensor has size {0}, but PyTorch Tensor has size {1}".
            format(dali_tensor.shape(), list(arr.size())))
    cuda_stream = types._raw_cuda_stream(cuda_stream)

    # turn raw int to a c void pointer
    c_type_pointer = ctypes.c_void_p(arr.data_ptr())
    if isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
        stream = None if cuda_stream is None else ctypes.c_void_p(cuda_stream)
        dali_tensor.copy_to_external(c_type_pointer, stream, non_blocking=True)
    else:
        dali_tensor.copy_to_external(c_type_pointer)
    return arr


class _JStreamExternalSource:
    """DALI External Source for in-memory dicom decoding"""

    def __init__(self, dcm_paths, batch_size=1):
        self.dcm_paths = dcm_paths
        self.len = len(dcm_paths)
        self.batch_size = batch_size

    def __call__(self, batch_info):
        idx = batch_info.iteration
        # print('IDX:', batch_info.iteration, batch_info.epoch_idx)
        start = idx * self.batch_size
        end = min(self.len, start + self.batch_size)
        if end <= start:
            raise StopIteration()

        batch_dcm_paths = self.dcm_paths[start:end]
        j_streams = []
        inverts = []
        windowing_params = []
        voilut_funcs = []

        for dcm_path in batch_dcm_paths:
            ds = pydicom.dcmread(dcm_path)
            pixel_data = ds.PixelData
            offset = pixel_data.find(
                SUID2HEADER[ds.file_meta.TransferSyntaxUID])
            j_stream = np.array(bytearray(pixel_data[offset:]), np.uint8)
            invert = (ds.PhotometricInterpretation == 'MONOCHROME1')
            meta = PydicomMetadata(ds)
            windowing_param = np.array(
                [meta.window_centers, meta.window_widths], np.float16)
            voilut_func = VOILUT_FUNCS_MAP[meta.voilut_func]
            j_streams.append(j_stream)
            inverts.append(invert)
            windowing_params.append(windowing_param)
            voilut_funcs.append(voilut_func)
        return j_streams, np.array(inverts, dtype=np.bool_), \
            windowing_params, np.array(voilut_funcs, dtype=np.uint8)


@dali.pipeline_def
def _dali_pipeline(eii):
    jpeg, invert, windowing_param, voilut_func = dali.fn.external_source(
        source=eii,
        num_outputs=4,
        dtype=[
            dali.types.UINT8, dali.types.BOOL, dali.types.FLOAT16,
            dali.types.UINT8
        ],
        batch=True,
        batch_info=True,
        parallel=True)
    ori_img = dali.fn.experimental.decoders.image(
        jpeg,
        device='mixed',
        output_type=dali.types.ANY_DATA,
        dtype=dali.types.UINT16)
    return ori_img, invert, windowing_param, voilut_func


def decode_crop_save_dali(roi_yolox_engine_path,
                          dcm_paths,
                          save_paths,
                          save_backend='cv2',
                          batch_size=1,
                          num_threads=1,
                          py_num_workers=1,
                          py_start_method='fork',
                          device_id=0):
    """DALI dicom decoding --> ROI cropping --> norm --> save as 8-bits PNG"""
    
    assert len(dcm_paths) == len(save_paths)
    assert save_backend in ['cv2', 'np']
    num_dcms = len(dcm_paths)

    # dali to process with chunk in-memory
    external_source = _JStreamExternalSource(dcm_paths, batch_size=batch_size)
    pipe = _dali_pipeline(
        external_source,
        py_num_workers=py_num_workers,
        py_start_method=py_start_method,
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=device_id,
        debug=False,
    )
    pipe.build()

    roi_extractor = roi_extract.RoiExtractor(engine_path=roi_yolox_engine_path,
                                             input_size=ROI_YOLOX_INPUT_SIZE,
                                             num_classes=1,
                                             conf_thres=ROI_YOLOX_CONF_THRES,
                                             nms_thres=ROI_YOLOX_NMS_THRES,
                                             class_agnostic=False,
                                             area_pct_thres=ROI_AREA_PCT_THRES,
                                             hw=ROI_YOLOX_HW,
                                             strides=ROI_YOLOX_STRIDES,
                                             exp=None)
    print('ROI extractor (YOLOX) loaded!')

    num_batchs = num_dcms // batch_size
    last_batch_size = batch_size
    if num_dcms % batch_size > 0:
        num_batchs += 1
        last_batch_size = num_dcms % batch_size

    cur_idx = -1
    for _batch_idx in tqdm(range(num_batchs)):
        try:
            outs = pipe.run()
        except Exception as e:
            #             print('DALI exception occur:', e)
            print(
                f'Exception: One of {dcm_paths[_batch_idx * batch_size: (_batch_idx + 1) * batch_size]} can not be decoded.'
            )
            # ignore this batch and re-build pipeline
            if _batch_idx < num_batchs - 1:
                cur_idx += batch_size
                del external_source, pipe
                gc.collect()
                torch.cuda.empty_cache()
                external_source = _JStreamExternalSource(
                    dcm_paths[(_batch_idx + 1) * batch_size:],
                    batch_size=batch_size)
                pipe = _dali_pipeline(
                    external_source,
                    py_num_workers=py_num_workers,
                    py_start_method=py_start_method,
                    batch_size=batch_size,
                    num_threads=num_threads,
                    device_id=device_id,
                    debug=False,
                )
                pipe.build()
            else:
                cur_idx += last_batch_size
            continue

        imgs = outs[0]
        inverts = outs[1]
        windowing_params = outs[2]
        voilut_funcs = outs[3]
        for j in range(len(inverts)):
            cur_idx += 1
            save_path = save_paths[cur_idx]
            img_dali = imgs[j]
            img_torch = torch.empty(img_dali.shape(),
                                    dtype=torch.int16,
                                    device='cuda')
            feed_ndarray(img_dali,
                         img_torch,
                         cuda_stream=torch.cuda.current_stream(device=0))
            # @TODO: test whether copy uint16 to int16 pointer is safe in this case
            if 0:
                img_np = img_dali.as_cpu().squeeze(-1)  # uint16
                print(type(img_np), img_np.shape)
                img_np = torch.from_numpy(img_np, dtype=torch.int16)
                diff = torch.max(torch.abs(img_np - img_torch))
                assert diff == 0, f'{img_torch.shape}, {img_np.shape}, {diff}'

            invert = inverts.at(j).item()
            windowing_param = windowing_params.at(j)
            voilut_func = voilut_funcs.at(j).item()
            voilut_func = VOILUT_FUNCS_INV_MAP[voilut_func]

            # YOLOX for ROI extraction
            img_yolox = min_max_scale(img_torch)
            img_yolox = (img_yolox * 255)  # float32
            if invert:
                img_yolox = 255 - img_yolox
            # YOLOX infer
            # who know if exception happen in hidden test ?
            try:
                xyxy, _area_pct, _conf = roi_extractor.detect_single(img_yolox)
                if xyxy is not None:
                    x0, y0, x1, y1 = xyxy
                    crop = img_torch[y0:y1, x0:x1]
                else:
                    crop = img_torch
            except:
                print('ROI extract exception!')
                crop = img_torch

            # apply windowing
            if windowing_param.shape[1] != 0:
                default_window_center = windowing_param[0, 0]
                default_window_width = windowing_param[1, 0]
                crop = apply_windowing(crop,
                                       window_width=default_window_width,
                                       window_center=default_window_center,
                                       voi_func=voilut_func,
                                       y_min=0,
                                       y_max=255,
                                       backend='torch')
            # if no window center/width in dcm file
            # do simple min-max scaling
            else:
                print('No windowing param!')
                crop = min_max_scale(crop)
                crop = crop * 255
            if invert:
                crop = 255 - crop
            crop = resize_and_pad(crop, MODEL_INPUT_SIZE)
            crop = crop.to(torch.uint8)
            crop = crop.cpu().numpy()
            save_img_to_file(save_path, crop, backend=save_backend)


#     assert cur_idx == len(
#         save_paths) - 1, f'{cur_idx} != {len(save_paths) - 1}'
    try:
        del external_source, pipe, roi_extractor
    except:
        pass
    gc.collect()
    torch.cuda.empty_cache()
    return


def decode_and_save_dali_parallel(
        roi_yolox_engine_path,
        dcm_paths,
        save_paths,
        save_backend='cv2',
        batch_size=1,
        num_threads=1,
        py_num_workers=1,
        py_start_method='fork',
        device_id=0,
        parallel_n_jobs=2,
        parallel_n_chunks=4,
        parallel_backend='joblib',  # joblib or multiprocessing
        joblib_backend='loky'):
    assert parallel_backend in ['joblib', 'multiprocessing']
    assert joblib_backend in ['threading', 'multiprocessing', 'loky']
    # py_num_workers > 0 means using multiprocessing worker
    # 'fork' multiprocessing after CUDA init is not work (we must use 'spawn' instead)
    # since our pipeline can be re-build (when a dicom can't be decoded on GPU),
    # 2 options:
    #       (py_num_workers = 0, py_start_method=?)
    #       (py_num_workers > 0, py_start_method = 'spawn')
    assert not (py_num_workers > 0 and py_start_method == 'fork')

    if parallel_n_jobs == 1:
        print('No parralel. Starting the tasks within current process.')
        return decode_crop_save_dali(roi_yolox_engine_path,
                                     dcm_paths,
                                     save_paths,
                                     save_backend=save_backend,
                                     batch_size=batch_size,
                                     num_threads=num_threads,
                                     py_num_workers=py_num_workers,
                                     py_start_method=py_start_method,
                                     device_id=device_id)
    else:
        num_samples = len(dcm_paths)
        num_samples_per_chunk = num_samples // parallel_n_chunks
        if num_samples % parallel_n_chunks > 0:
            num_samples_per_chunk += 1
        starts = [num_samples_per_chunk * i for i in range(parallel_n_chunks)]
        ends = [
            min(start + num_samples_per_chunk, num_samples) for start in starts
        ]
        if isinstance(device_id, list):
            assert len(device_id) == parallel_n_chunks
        elif isinstance(device_id, int):
            device_id = [device_id] * parallel_n_chunks

        print(
            f'Starting {parallel_n_jobs} jobs with backend `{parallel_backend}`, {parallel_n_chunks} chunks ...'
        )
        if parallel_backend == 'joblib':
            _ = Parallel(n_jobs=parallel_n_jobs, backend=joblib_backend)(
                delayed(decode_crop_save_dali)(
                    roi_yolox_engine_path,
                    dcm_paths[start:end],
                    save_paths[start:end],
                    save_backend=save_backend,
                    batch_size=batch_size,
                    num_threads=num_threads,
                    py_num_workers=py_num_workers,  # ram_v3
                    py_start_method=py_start_method,
                    device_id=worker_device_id,
                ) for start, end, worker_device_id in zip(
                    starts, ends, device_id))
        else:  # manually start multiprocessing's processes
            workers = []
            daemon = False if py_num_workers > 0 else True
            for i in range(parallel_n_jobs):
                start = starts[i]
                end = ends[i]
                worker_device_id = device_id[i]
                worker = mp.Process(group=None,
                                    target=decode_crop_save_dali,
                                    args=(
                                        roi_yolox_engine_path,
                                        dcm_paths[start:end],
                                        save_paths[start:end],
                                    ),
                                    kwargs={
                                        'save_backend': save_backend,
                                        'batch_size': batch_size,
                                        'num_threads': num_threads,
                                        'py_num_workers': py_num_workers,
                                        'py_start_method': py_start_method,
                                        'device_id': worker_device_id,
                                    },
                                    daemon=daemon)
                workers.append(worker)
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()
    return


def _single_decode_crop_save_sdl(roi_extractor,
                                 dcm_path,
                                 save_path,
                                 save_backend='cv2',
                                 index=0):
    dcm = dicomsdl.open(dcm_path)
    meta = DicomsdlMetadata(dcm)
    info = dcm.getPixelDataInfo()
    if info['SamplesPerPixel'] != 1:
        raise RuntimeError('SamplesPerPixel != 1')
    else:
        shape = [info['Rows'], info['Cols']]

    ori_dtype = info['dtype']
    img = np.empty(shape, dtype=ori_dtype)
    dcm.copyFrameData(index, img)
    img_torch = torch.from_numpy(img.astype(np.int16)).cuda()

    # YOLOX for ROI extraction
    img_yolox = min_max_scale(img_torch)
    img_yolox = (img_yolox * 255)  # float32
    # @TODO: subtract on large array --> should move after F.interpolate()
    if meta.invert:
        img_yolox = 255 - img_yolox
    # YOLOX infer
    try:
        xyxy, _area_pct, _conf = roi_extractor.detect_single(img_yolox)
        if xyxy is not None:
            x0, y0, x1, y1 = xyxy
            crop = img_torch[y0:y1, x0:x1]
        else:
            crop = img_torch
    except:
        print('ROI extract exception!')
        crop = img_torch

    # apply voi lut
    if meta.window_widths:
        crop = apply_windowing(crop,
                               window_width=meta.window_widths[0],
                               window_center=meta.window_centers[0],
                               voi_func=meta.voilut_func,
                               y_min=0,
                               y_max=255,
                               backend='torch')
    else:
        print('No windowing param!')
        crop = min_max_scale(crop)
        crop = crop * 255

    if meta.invert:
        crop = 255 - crop
    crop = resize_and_pad(crop, MODEL_INPUT_SIZE)
    crop = crop.to(torch.uint8)
    crop = crop.cpu().numpy()
    save_img_to_file(save_path, crop, backend=save_backend)


def decode_crop_save_sdl(roi_yolox_engine_path,
                         dcm_paths,
                         save_paths,
                         save_backend='cv2'):
    """DicomSDL decoding --> ROI cropping --> norm --> save as 8-bits PNG"""
    
    assert len(dcm_paths) == len(save_paths)
    roi_detector = roi_extract.RoiExtractor(engine_path=roi_yolox_engine_path,
                                            input_size=ROI_YOLOX_INPUT_SIZE,
                                            num_classes=1,
                                            conf_thres=ROI_YOLOX_CONF_THRES,
                                            nms_thres=ROI_YOLOX_NMS_THRES,
                                            class_agnostic=False,
                                            area_pct_thres=ROI_AREA_PCT_THRES,
                                            hw=ROI_YOLOX_HW,
                                            strides=ROI_YOLOX_STRIDES,
                                            exp=None)
    print('ROI extractor (YOLOX) loaded!')
    for i in tqdm(range(len(dcm_paths))):
        _single_decode_crop_save_sdl(roi_detector, dcm_paths[i], save_paths[i],
                                     save_backend)

    del roi_detector
    gc.collect()
    torch.cuda.empty_cache()
    return


def decode_crop_save_sdl_parallel(roi_yolox_engine_path,
                                  dcm_paths,
                                  save_paths,
                                  save_backend='cv2',
                                  parallel_n_jobs=2,
                                  parallel_n_chunks=4,
                                  joblib_backend='loky'):
    assert len(dcm_paths) == len(save_paths)
    if parallel_n_jobs == 1:
        print('No parralel. Starting the tasks within current process.')
        return decode_crop_save_sdl(roi_yolox_engine_path, dcm_paths,
                                    save_paths, save_backend)
    else:
        num_samples = len(dcm_paths)
        num_samples_per_chunk = num_samples // parallel_n_chunks
        if num_samples % parallel_n_chunks > 0:
            num_samples_per_chunk += 1
        starts = [num_samples_per_chunk * i for i in range(parallel_n_chunks)]
        ends = [
            min(start + num_samples_per_chunk, num_samples) for start in starts
        ]

        print(
            f'Starting {parallel_n_jobs} jobs with backend `{joblib_backend}`, {parallel_n_chunks} chunks...'
        )
        _ = Parallel(n_jobs=parallel_n_jobs, backend=joblib_backend)(
            delayed(decode_crop_save_sdl)(roi_yolox_engine_path,
                                          dcm_paths[start:end],
                                          save_paths[start:end], save_backend)
            for start, end in zip(starts, ends))

ModuleNotFoundError: No module named 'dali'