## Requirements

In [None]:
!pip install -q albumentations==1.2.1 --no-index --find-links=/home/br/workspace/RSNA2023/input/rsna-bc-pip-requirements
!pip install -q pylibjpeg-libjpeg==1.3.1 --no-index --find-links=/home/br/workspace/RSNA2023/input/rsna-bc-pip-requirements
!pip install -q pydicom==2.0.0 --no-index --find-links=/home/br/workspace/RSNA2023/input/rsna-bc-pip-requirements
!pip install -q python-gdcm==3.0.20 --no-index --find-links=/home/br/workspace/RSNA2023/input/rsna-bc-pip-requirements
!pip install -q dicomsdl==0.109.1 --no-index --find-links=/home/br/workspace/RSNA2023/input/rsna-bc-pip-requirements

# dali
!pip install -q /home/br/workspace/RSNA2023/input/nvidia-dali-nightly-cuda110-1230dev/nvidia_dali_nightly_cuda110-1.23.0.dev20230203-7187866-py3-none-manylinux2014_x86_64.whl

 # Config

In [1]:
import os

# Base image size; PNGs are written square at SAVE_SIZE = int(IMG_SIZE * 1.125).
IMG_SIZE = 1536
# Bit depth of the written PNGs. Only 8 and 16 are supported by the writers below.
NBIT = 16
# Fail fast here instead of after decoding thousands of images.
assert NBIT in (8, 16), f"Unsupported NBIT value: {NBIT}"

In [2]:
import sys
sys.path.append("/home/br/workspace/RSNA2023/input/pytorch-image-models-main/")
import timm
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os
from copy import copy
import gc
import shutil 

import glob
from scipy.special import expit

import albumentations as A
import cv2
cv2.setNumThreads(0)

import dicomsdl
import pydicom
from pydicom.filebase import DicomBytesIO

from os.path import join

from tqdm import tqdm

from joblib import Parallel, delayed
import multiprocessing as mp

from types import SimpleNamespace
from typing import Any, Dict

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast


import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali import pipeline_def
from nvidia.dali.types import DALIDataType


from nvidia.dali.backend import TensorGPU, TensorListGPU
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
from nvidia.dali import types
from nvidia.dali.plugin.base_iterator import _DaliBaseIterator
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
import torch
import torch.utils.dlpack as torch_dlpack
import ctypes
import numpy as np
import torch.nn.functional as F
import pydicom

# DALI element type -> torch dtype that feed_ndarray requires for its destination tensor.
# NOTE: UINT16 maps to the same-width torch.int16, so after a byte copy any sample
# >= 32768 reads back as a negative int16 until it is reinterpreted as unsigned.
to_torch_type = {
    dali_t: torch_t
    for dali_t, torch_t in (
        (types.DALIDataType.FLOAT, torch.float32),
        (types.DALIDataType.FLOAT64, torch.float64),
        (types.DALIDataType.FLOAT16, torch.float16),
        (types.DALIDataType.UINT8, torch.uint8),
        (types.DALIDataType.INT8, torch.int8),
        (types.DALIDataType.UINT16, torch.int16),
        (types.DALIDataType.INT16, torch.int16),
        (types.DALIDataType.INT32, torch.int32),
        (types.DALIDataType.INT64, torch.int64),
    )
}


def feed_ndarray(dali_tensor, arr, cuda_stream=None):
    """
    Copy the contents of a DALI tensor into an already-allocated PyTorch tensor.

    Parameters
    ----------
    `dali_tensor` : nvidia.dali.backend.TensorCPU or nvidia.dali.backend.TensorGPU
                    Source tensor.
    `arr` : torch.Tensor
            Destination. Its dtype must equal ``to_torch_type[dali_tensor.dtype]``
            and its size must equal ``dali_tensor.shape()`` (both are asserted).
    `cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
                    Stream used for a GPU copy (if not provided, an internal user stream
                    will be selected). Usually PyTorch's current stream, e.g. when ``arr``
                    was allocated with torch.empty/torch.zeros on it.

    Returns
    -------
    torch.Tensor
        ``arr`` itself, now holding the copied data.
    """
    wanted_dtype = to_torch_type[dali_tensor.dtype]
    assert wanted_dtype == arr.dtype, ("The element type of DALI Tensor/TensorList"
                                       " doesn't match the element type of the target PyTorch Tensor: "
                                       "{} vs {}".format(wanted_dtype, arr.dtype))
    dst_shape = list(arr.size())
    assert dali_tensor.shape() == dst_shape, \
        ("Shapes do not match: DALI tensor has size {0}, but PyTorch Tensor has size {1}".
            format(dali_tensor.shape(), dst_shape))

    raw_stream = types._raw_cuda_stream(cuda_stream)
    # Raw address of arr's storage, wrapped as a C void pointer for DALI.
    dst_ptr = ctypes.c_void_p(arr.data_ptr())

    if not isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
        # CPU-side tensor: copied without any stream argument.
        dali_tensor.copy_to_external(dst_ptr)
        return arr

    stream_ptr = ctypes.c_void_p(raw_stream) if raw_stream is not None else None
    dali_tensor.copy_to_external(dst_ptr, stream_ptr, non_blocking=True)
    return arr

@pipeline_def
def jpg_decode_pipeline(jpgfiles):
    """DALI pipeline that reads ``jpgfiles`` in order and decodes each one to UINT16.

    The file reader's labels are discarded. Decoding uses the experimental image
    decoder with ``device='mixed'`` (CPU input, GPU output) and
    ``output_type=types.ANY_DATA``. Pipeline options such as ``batch_size``,
    ``num_threads`` and ``device_id`` are passed at call time via ``pipeline_def``.
    """
    encoded, _labels = fn.readers.file(files=jpgfiles)
    decoded = fn.experimental.decoders.image(
        encoded,
        device='mixed',
        output_type=types.ANY_DATA,
        dtype=DALIDataType.UINT16,
    )
    return decoded

def parse_window_element(elem):
    """Extract a single float from a DICOM WindowCenter/WindowWidth value.

    Accepts what the two readers used in this notebook hand back:
    a plain number, a numeric string, a list/tuple (first entry is used),
    a pydicom ``MultiValue`` (first entry) or a pydicom ``DataElement``
    (its ``.value`` is unwrapped and parsed recursively).

    Returns
    -------
    float or None
        None for missing/empty values, booleans or unrecognised types.
    """
    if elem is None or isinstance(elem, bool):
        return None
    # isinstance (not type()==) so float/int subclasses such as pydicom DSfloat/IS
    # and numpy scalars are accepted.
    if isinstance(elem, (int, float)):
        return float(elem)
    if isinstance(elem, str):
        text = elem.strip()
        return float(text) if text else None
    if isinstance(elem, (list, tuple)):
        return parse_window_element(elem[0]) if elem else None
    # pydicom types are only touched once the plain-Python cases are ruled out.
    if isinstance(elem, pydicom.dataelem.DataElement):
        return parse_window_element(elem.value)
    if isinstance(elem, pydicom.multival.MultiValue):
        # Multi-valued windows (several presets) previously crashed in float(elem.value).
        return parse_window_element(elem[0]) if len(elem) else None
    return None

def linear_window(data, center, width):
    """Clamp ``data`` into the window ``[center - width // 2, center + width // 2]``.

    Uses floor division for the half-width (so a float width loses its
    fractional half) and returns a new tensor from ``torch.clamp``.
    """
    half_width = width // 2
    return torch.clamp(data, min=center - half_width, max=center + half_width)



# Competition data (RSNA train set)

In [None]:
# RSNA competition metadata. "fns" holds "<patient_id>/<image_id>.dcm", relative to DATA_FOLDER.
DATA_FOLDER = '/home/br/workspace/RSNA2023/input/rsna-breast-cancer-detection/train_images/'

test_df = pd.read_csv('/home/br/workspace/RSNA2023/input/rsna-breast-cancer-detection/train.csv')
patient_part = test_df['patient_id'].astype(str)
image_part = test_df['image_id'].astype(str)
test_df["fns"] = patient_part + '/' + image_part + '.dcm'
test_df

In [None]:
# PNGs are written square at 1.125x the base IMG_SIZE.
SAVE_SIZE = int(IMG_SIZE * 1.125)
SAVE_FOLDER = f"/home/br/workspace/RSNA2023/input/images_gpugen/{IMG_SIZE}_{NBIT}bit/"
os.makedirs(SAVE_FOLDER, exist_ok=True)

# Work in ~2000-file chunks so the temporary JPG dump (deleted after each chunk) stays bounded.
n_files = len(test_df["fns"])
N_CHUNKS = n_files // 2000 if n_files > 2000 else 1
# Integer arithmetic keeps the bounds exact: the float form (n / N_CHUNKS * k).astype(int)
# can truncate the final bound to n - 1 and silently drop the last file.
CHUNKS = np.array(
    [(n_files * k // N_CHUNKS, n_files * (k + 1) // N_CHUNKS) for k in range(N_CHUNKS)],
    dtype=int,
)
JPG_FOLDER = f"/home/br/workspace/RSNA2023/input/images_gpugen/{IMG_SIZE}jpg/"
os.makedirs(JPG_FOLDER, exist_ok=True)

In [None]:
def convert_dicom_to_jpg(file, save_folder=""):
    """Dump the compressed codestream embedded in a DICOM to ``{save_folder}{patient}_{image}.jpg``.

    Only two transfer syntaxes are handled: JPEG 2000 (1.2.840.10008.1.2.4.90)
    and JPEG lossless (1.2.840.10008.1.2.4.70). For those, everything in
    PixelData from the first occurrence of the codec's byte signature onward
    is written out for GPU decoding. Other syntaxes, or a missing signature,
    write nothing, so the file gets no PNG from the GPU pass and is picked up
    by the CPU fallback instead.

    Parameters
    ----------
    file : str
        Path ending in ``<patient>/<image>.dcm``.
    save_folder : str
        Output prefix; concatenated directly, so it should end with '/'.

    Returns
    -------
    str or None
        The written path, or None when nothing was written.
    """
    patient = file.split('/')[-2]
    image = file.split('/')[-1][:-4]

    # Transfer syntax UID -> byte signature where the embedded codestream starts.
    signatures = {
        '1.2.840.10008.1.2.4.90': b"\x00\x00\x00\x0C",  # JPEG 2000: JP2 signature box (length 12)
        '1.2.840.10008.1.2.4.70': b"\xff\xd8\xff\xe0",  # JPEG lossless: SOI followed by APP0
    }

    # Read the file once (the previous version parsed it twice).
    with open(file, 'rb') as fp:
        ds = pydicom.dcmread(DicomBytesIO(fp.read()))

    signature = signatures.get(str(ds.file_meta.TransferSyntaxUID))
    if signature is None:
        return None
    offset = ds.PixelData.find(signature)
    if offset == -1:
        # Previously PixelData[-1:] (one stray byte) was written as a bogus .jpg here.
        return None

    out_path = save_folder + f"{patient}_{image}.jpg"
    with open(out_path, "wb") as binary_file:
        binary_file.write(ds.PixelData[offset:])
    return out_path


def process_dicom(img, dicom):
    """Window, min-max normalise and (for MONOCHROME1) invert a decoded image.

    Parameters
    ----------
    img : torch.Tensor
        Decoded pixel values.
    dicom : pydicom Dataset or dicomsdl DataSet
        Header source. It is read through ``getattr(dicom, "PhotometricInterpretation", None)``
        and ``dicom["WindowCenter"]`` / ``dicom["WindowWidth"]``.

    Returns
    -------
    torch.Tensor
        Image scaled to [0, 1], with 1 - x applied when the header says MONOCHROME1.
        A constant image (zero dynamic range) becomes all zeros (ones if inverted)
        instead of the NaNs the plain division produced.
    """
    try:
        invert = getattr(dicom, "PhotometricInterpretation", None) == "MONOCHROME1"
    except Exception:
        # Best effort: a header object that errors on attribute lookup counts as not inverted.
        invert = False

    center = parse_window_element(dicom["WindowCenter"])
    width = parse_window_element(dicom["WindowWidth"])
    if center is not None and width is not None:
        img = linear_window(img, center, width)

    lo = img.min()
    span = img.max() - lo
    if span > 0:
        img = (img - lo) / span
    else:
        img = torch.zeros_like(img, dtype=torch.float32)
    if invert:
        img = 1 - img
    return img

## GPU Decoding

In [None]:
# GPU path. For each chunk of DICOMs:
#   1. dump the embedded compressed codestream to JPG_FOLDER,
#   2. decode it on the GPU with DALI,
#   3. window / normalise it with the DICOM header,
#   4. resize to SAVE_SIZE x SAVE_SIZE and write a PNG into SAVE_FOLDER.
# Files that fail here get no PNG and are picked up by the CPU fallback below.
n_workers = 4
for ttt, chunk in enumerate(CHUNKS):
    print(f'chunk {ttt} of {len(CHUNKS)} chunks')
    os.makedirs(JPG_FOLDER, exist_ok=True)

    _ = Parallel(n_jobs=n_workers)(
        delayed(convert_dicom_to_jpg)(f'{DATA_FOLDER}/{img}', save_folder=JPG_FOLDER)
        for img in test_df["fns"].tolist()[chunk[0]: chunk[1]]
    )

    jpgfiles = glob.glob(JPG_FOLDER + "*.jpg")
    if not jpgfiles:
        # Nothing GPU-decodable in this chunk; don't build a file reader over an empty list.
        shutil.rmtree(JPG_FOLDER)
        continue

    pipe = jpg_decode_pipeline(jpgfiles, batch_size=1, num_threads=n_workers, device_id=0)
    pipe.build()

    for i, f in enumerate(tqdm(jpgfiles)):
        patient, dicom_id = f.split('/')[-1][:-4].split('_')
        try:
            # process_dicom only needs header tags, so skip loading PixelData.
            dicom = pydicom.dcmread(DATA_FOLDER + f"/{patient}/{dicom_id}.dcm", stop_before_pixels=True)
            out = pipe.run()
            # Dali -> Torch
            img = out[0][0]
            img_torch = torch.empty(img.shape(), dtype=torch.int16, device="cuda")
            feed_ndarray(img, img_torch, cuda_stream=torch.cuda.current_stream(device=0))
            # The decoder emits uint16, but the buffer is int16 (to_torch_type maps UINT16 -> int16).
            # Mask back to the unsigned range so samples >= 32768 do not turn negative.
            img = (img_torch.to(torch.int32) & 0xFFFF).float()

            # apply dicom preprocessing
            img = process_dicom(img, dicom)

            # resize the torch image
            img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (SAVE_SIZE, SAVE_SIZE), mode="bilinear")[0, 0]

            if NBIT == 8:
                img = (img * 255).clip(0,255).to(torch.uint8).cpu().numpy() # 8 bit image
            elif NBIT == 16:
                img = (img * 65535).clip(0,65535).cpu().numpy().astype(np.uint16) # 16 bit image
            else:
                raise ValueError(f"Unsupported NBIT value: {NBIT}")

            out_file_name = SAVE_FOLDER + f"{patient}_{dicom_id}.png"
            cv2.imwrite(out_file_name, img)

        except Exception as e:
            print(i, e)
            # Restart the reader just after the failing file so the next pipe.run()
            # output again corresponds to jpgfiles[i + 1]. Drop the old pipeline first.
            del pipe
            if i + 1 >= len(jpgfiles):
                # Last file failed: an empty file list would make the rebuild itself raise.
                break
            pipe = jpg_decode_pipeline(jpgfiles[i+1:], batch_size=1, num_threads=n_workers, device_id=0)
            pipe.build()

    shutil.rmtree(JPG_FOLDER)

fns = glob.glob(f'{SAVE_FOLDER}/*.png')
print(f'GPU Processed: {len(fns)}')

## CPU Decoding

In [None]:
# PNG names are "<patient>_<image>.png"; map them back to the "<patient>/<image>.dcm"
# keys used in test_df["fns"] to find what the GPU pass did not produce.
gpu_processed_files = [
    png_path.split('/')[-1].replace('_', '/').replace('png', 'dcm') for png_path in fns
]
# Set lookup: testing membership against the list made this O(len(test_df) * len(fns)).
gpu_done = set(gpu_processed_files)
to_process = [f for f in test_df["fns"].values if f not in gpu_done]
print(f"GPU processed files number: {len(gpu_processed_files)}, remain number: {len(to_process)}")

In [None]:
def process(f, save_folder=""):
    """CPU fallback: decode one DICOM with dicomsdl, window/normalise, resize and save it as a PNG.

    Parameters
    ----------
    f : str
        Path ending in ``<patient>/<image>.dcm``.
    save_folder : str
        Output prefix (concatenated directly, so it should end with '/'). This was
        previously ignored. An empty string still falls back to the notebook-level
        SAVE_FOLDER, which is what was always used before.

    Returns
    -------
    str or None
        Path of the written PNG, or None if ``cv2.imwrite`` reported failure.
        Falsy results are dropped by the caller's ``if f`` filter.
    """
    patient = f.split('/')[-2]
    dicom_id = f.split('/')[-1][:-4]
    out_dir = save_folder or SAVE_FOLDER

    dicom = dicomsdl.open(f)
    img = torch.from_numpy(dicom.pixelData())
    img = process_dicom(img, dicom)

    img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (SAVE_SIZE, SAVE_SIZE), mode="bilinear")[0, 0]

    if NBIT == 8:
        img = (img * 255).clip(0,255).to(torch.uint8).cpu().numpy() # 8 bit image
    elif NBIT == 16:
        img = (img * 65535).clip(0,65535).cpu().numpy().astype(np.uint16) # 16 bit image
    else:
        raise ValueError(f"Unsupported NBIT value: {NBIT}")

    out_file_name = out_dir + f"{patient}_{dicom_id}.png"
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(out_file_name, img):
        return None
    return out_file_name

In [None]:
# Decode the files the GPU pass could not handle, two worker processes at a time.
jobs = (
    delayed(process)(f'{DATA_FOLDER}/{img}', save_folder=SAVE_FOLDER)
    for img in tqdm(to_process)
)
results = Parallel(n_jobs=2)(jobs)
# Keep only truthy results (written file names).
cpu_processed_filenames = list(filter(None, results))
print(f'CPU Raw image load complete with {len(cpu_processed_filenames)} loaded')

In [None]:
# Sanity check: one PNG per row of train.csv.
png_on_disk = glob.glob(f'{SAVE_FOLDER}/*.png')
n_saved = len(png_on_disk)
print(f'Image on disk count : {n_saved}')
print(f"test_df length: {len(test_df)}")

gc.collect()
torch.cuda.empty_cache()
EXPECTED_ROWS = 54706  # expected number of rows in train.csv for this run
assert n_saved == len(test_df) == EXPECTED_ROWS

# VinDr


## VinDr metadata

In [None]:
# Inspect one VinDr DICOM to see its transfer syntax, which embedded codestream
# signature (if any) occurs in PixelData, and its photometric/window tags.
tmp_dcm = pydicom.dcmread('/home/br/workspace/RSNA2023/input/extra_data/vindr/images/57ef58281d90655693cd34628e4c8083/480a3198d3c64bcfb5bb193702500f3f.dicom')
pixel_bytes = tmp_dcm.PixelData
print(f"file_meta.TransferSyntaxUID: {tmp_dcm.file_meta.TransferSyntaxUID}")
print("jpeg2000 find:", pixel_bytes.find(b'\x00\x00\x00\x0C'))
print("jpeg lossless find:", pixel_bytes.find(b'\xff\xd8\xff\xe0'))
photometric = getattr(tmp_dcm, 'PhotometricInterpretation', None)
print(f"PhotometricInterpretation: {photometric}")
for window_tag in ("WindowCenter", "WindowWidth"):
    print(f"{window_tag}: {parse_window_element(tmp_dcm[window_tag])}")

In [None]:
# Collect per-image DICOM header facts, assign stratified group folds and write the VinDr metadata CSV.
# NOTE(review): this cell needs `breast_birads` / `study_id` columns and `fns` paths relative to
# DATA_FOLDER. In a top-to-bottom run, test_df / DATA_FOLDER at this point are still the RSNA
# train.csv / train_images set up above (fns end in .dcm). The cell presumably relied on a
# VinDr-loading cell that is no longer in this notebook — confirm before Restart & Run All.
from sklearn.model_selection import GroupKFold, StratifiedGroupKFold

N_FOLDS = 5
SAMPLE_RAND_SEED = 42  # makes `sample_rand` (and hence the written CSV) reproducible
JPEG2000_MARKER = b'\x00\x00\x00\x0C'
JPEG_LOSSLESS_MARKER = b'\xff\xd8\xff\xe0'


def _dicom_header_facts(path):
    """Read one DICOM (path relative to DATA_FOLDER) and return its metadata columns."""
    dcm = pydicom.dcmread(f'{DATA_FOLDER}/{path}')
    return {
        "transfersyntaxuids": dcm.file_meta.TransferSyntaxUID,
        "jpeg2000s": dcm.PixelData.find(JPEG2000_MARKER),
        "jpeglosslesss": dcm.PixelData.find(JPEG_LOSSLESS_MARKER),
        "photometricinterpretations": getattr(dcm, 'PhotometricInterpretation', None),
        "windowcenters": parse_window_element(dcm['WindowCenter']),
        "windowwidths": parse_window_element(dcm['WindowWidth']),
    }


header_facts = pd.DataFrame(
    [_dicom_header_facts(path) for path in tqdm(test_df["fns"])],
    index=test_df.index,
)
print(f"header rows: {len(header_facts)}; test_df rows: {len(test_df)}")
for col in header_facts.columns:
    test_df[col] = header_facts[col]

#### breast_birads
# StratifiedGroupKFold yields positional indices. Filling a positional array (instead of
# .loc, which is label-based) keeps folds correct even without a default RangeIndex.
folds = np.empty(len(test_df), dtype=int)
splitter = StratifiedGroupKFold(N_FOLDS)
for k, (_, test_idx) in enumerate(splitter.split(test_df, test_df.breast_birads, groups=test_df.study_id)):
    folds[test_idx] = k
test_df["split"] = folds
test_df["sample_rand"] = np.random.default_rng(SAMPLE_RAND_SEED).random(len(test_df))
# test_df.loc[test_df["breast_birads"]==1, "sample_rand"] = 0.0

# "BI-RADS 3" -> 3. astype(str) first so re-running the cell on an already-int column still works.
test_df["breast_birads"] = (
    test_df["breast_birads"].astype(str).str.replace("BI-RADS ", "", regex=False).astype(int)
)
test_df.to_csv('/home/br/workspace/RSNA2023/input/extra_data/vindr/vindr.csv', index=False)

test_df

In [None]:
# VinDr metadata table; its "fns" entries are paths relative to DATA_FOLDER.
test_df = pd.read_csv("../input/vindr/vindr.csv")
DATA_FOLDER = '../input/vindr/images'
test_df

In [4]:
# PNGs are written square at 1.125x the base IMG_SIZE.
SAVE_SIZE = int(IMG_SIZE * 1.125)
SAVE_FOLDER = f"../input/images_gpugen/vindr_{IMG_SIZE}_{NBIT}bit_2/"
os.makedirs(SAVE_FOLDER, exist_ok=True)

# ~2000-file chunks for the (currently disabled) GPU path.
n_files = len(test_df["fns"])
N_CHUNKS = n_files // 2000 if n_files > 2000 else 1
# Integer arithmetic keeps the bounds exact: the float form (n / N_CHUNKS * k).astype(int)
# can truncate the final bound to n - 1 and silently drop the last file.
CHUNKS = np.array(
    [(n_files * k // N_CHUNKS, n_files * (k + 1) // N_CHUNKS) for k in range(N_CHUNKS)],
    dtype=int,
)
JPG_FOLDER = f"../input/images_gpugen/vindr_{IMG_SIZE}jpg/"
os.makedirs(JPG_FOLDER, exist_ok=True)

## DICOM processing helpers

In [None]:
def convert_dicom_to_jpg(file, save_folder=""):
    """Dump the JPEG 2000 codestream embedded in a VinDr DICOM to ``{save_folder}{patient}_{image}.jpg``.

    Parameters
    ----------
    file : str
        Path ending in ``<patient>/<image>.dicom``.
    save_folder : str
        Output prefix; concatenated directly, so it should end with '/'.

    Returns
    -------
    str or None
        The written path, or None when the JPEG 2000 signature is absent.
    """
    patient = file.split('/')[-2]
    image = file.split('/')[-1].replace('.dicom', '')

    with open(file, 'rb') as fp:
        ds = pydicom.dcmread(DicomBytesIO(fp.read()))

    # Start of the JPEG 2000 (JP2) signature box (box length 12). Not a JPEG lossless header,
    # contrary to the old comment.
    # NOTE(review): there is no transfer-syntax check here, so for uncompressed pixel data this
    # 4-byte pattern could match by chance — consider checking ds.file_meta.TransferSyntaxUID.
    offset = ds.PixelData.find(b"\x00\x00\x00\x0C")
    if offset == -1:
        return None

    out_path = save_folder + f"{patient}_{image}.jpg"
    with open(out_path, "wb") as binary_file:
        binary_file.write(ds.PixelData[offset:])
    return out_path


def process_dicom(img, dicom):
    """Window, min-max normalise and (for MONOCHROME1) invert a decoded image.

    Parameters
    ----------
    img : torch.Tensor
        Decoded pixel values.
    dicom : pydicom Dataset or dicomsdl DataSet
        Header source. It is read through ``getattr(dicom, "PhotometricInterpretation", None)``
        and ``dicom["WindowCenter"]`` / ``dicom["WindowWidth"]``.

    Returns
    -------
    torch.Tensor
        Image scaled to [0, 1], with 1 - x applied when the header says MONOCHROME1.
        A constant image (zero dynamic range) becomes all zeros (ones if inverted)
        instead of the NaNs the plain division produced.
    """
    try:
        invert = getattr(dicom, "PhotometricInterpretation", None) == "MONOCHROME1"
    except Exception:
        # Best effort: a header object that errors on attribute lookup counts as not inverted.
        invert = False

    center = parse_window_element(dicom["WindowCenter"])
    width = parse_window_element(dicom["WindowWidth"])
    if center is not None and width is not None:
        img = linear_window(img, center, width)

    lo = img.min()
    span = img.max() - lo
    if span > 0:
        img = (img - lo) / span
    else:
        img = torch.zeros_like(img, dtype=torch.float32)
    if invert:
        img = 1 - img
    return img

### VinDr GPU decoding (disabled: code below is commented out)

In [None]:
# n_workers = 4
# for ttt, chunk in enumerate(CHUNKS):
#     print(f'chunk {ttt} of {len(CHUNKS)} chunks')
#     os.makedirs(JPG_FOLDER, exist_ok=True)

#     _ = Parallel(n_jobs=n_workers)(
#         delayed(convert_dicom_to_jpg)(f'{DATA_FOLDER}/{img}', save_folder=JPG_FOLDER)
#         for img in test_df["fns"].tolist()[chunk[0]: chunk[1]]
#     )
    
#     jpgfiles = glob.glob(JPG_FOLDER + "*.jpg")
#     print(f"len(jpgfiles): {len(jpgfiles)}, {len(jpgfiles)/(chunk[1]-chunk[0])}")

#     pipe = jpg_decode_pipeline(jpgfiles, batch_size=1, num_threads=n_workers, device_id=0)
#     pipe.build()

#     for i, f in enumerate(tqdm(jpgfiles)):
#         patient, dicom_id = f.split('/')[-1][:-4].split('_')
#         dicom = pydicom.dcmread(DATA_FOLDER + f"/{patient}/{dicom_id}.dicom")
#         try:
#             out = pipe.run()
#             # Dali -> Torch
#             img = out[0][0]
#             img_torch = torch.empty(img.shape(), dtype=torch.int16, device="cuda")
#             feed_ndarray(img, img_torch, cuda_stream=torch.cuda.current_stream(device=0))
#             img = img_torch.float()

#             #apply dicom preprocessing
#             img = process_dicom(img, dicom)

#             #resize the torch image
#             img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (SAVE_SIZE, SAVE_SIZE), mode="bilinear")[0, 0]

#             ####
#             if NBIT == 8:
#                 img = (img * 255).clip(0,255).to(torch.uint8).cpu().numpy() # 8 bit image
#             elif NBIT == 16:
#                 img = (img * 65535).clip(0,65535).cpu().numpy().astype(np.uint16) # 16 bit image
#             else:
#                 raise ValueError(f"Unsupported NBIT value: {NBIT}")
            
#             out_file_name = SAVE_FOLDER + f"{patient}_{dicom_id}.png"
#             cv2.imwrite(out_file_name, img)
    
#         except Exception as e:
#             print(i, e)
#             pipe = jpg_decode_pipeline(jpgfiles[i+1:], batch_size=1, num_threads=n_workers, device_id=0)
#             pipe.build()
#             continue

#     shutil.rmtree(JPG_FOLDER)
    
# fns = glob.glob(f'{SAVE_FOLDER}/*.png')
# print(f'GPU Processed: {len(fns)}')

### VinDr CPU decoding (disabled: code below is commented out)

In [None]:
# gpu_processed_files = [fn.split('/')[-1].replace('_','/').replace('png','dcm') for fn in fns]
# to_process = [f for f in test_df["fns"].values if f not in gpu_processed_files]
# print(f"GPU processed files number: {len(gpu_processed_files)}, remain number: {len(to_process)}")

In [None]:
# def process(f, save_folder=""):
#     patient = f.split('/')[-2]
#     dicom_id = f.split('/')[-1][:-4]
    
#     dicom = dicomsdl.open(f)
#     img = dicom.pixelData()
#     img = torch.from_numpy(img)
#     img = process_dicom(img, dicom)
    
#     img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (SAVE_SIZE, SAVE_SIZE), mode="bilinear")[0, 0]

#     ####
#     if NBIT == 8:
#         img = (img * 255).clip(0,255).to(torch.uint8).cpu().numpy() # 8 bit image
#     elif NBIT == 16:
#         img = (img * 65535).clip(0,65535).cpu().numpy().astype(np.uint16) # 16 bit image
#     else:
#         raise ValueError(f"Unsupported NBIT value: {NBIT}")

#     out_file_name = SAVE_FOLDER + f"{patient}_{dicom_id}.png"
#     cv2.imwrite(out_file_name, img)
#     return out_file_name

In [None]:
# cpu_processed_filenames = Parallel(n_jobs=4)(
#     delayed(process)(f'{DATA_FOLDER}/{img}', save_folder=SAVE_FOLDER)
#     for img in tqdm(to_process)
# )
# cpu_processed_filenames = [f for f in cpu_processed_filenames if f]
# print(f'CPU Raw image load complete with {len(cpu_processed_filenames)} loaded')

In [None]:
# n_saved = len(glob.glob(f'{SAVE_FOLDER}/*.png'))
# print(f'Image on disk count : {n_saved}')
# print(f"test_df length: {len(test_df)}")

# gc.collect()
# torch.cuda.empty_cache()
# assert n_saved == len(test_df) == 20000

## CPU DICOM decoding (active VinDr path)

In [5]:
def process_dicom(img, dicom):
    """Window, min-max normalise and (for MONOCHROME1) invert a decoded image.

    Parameters
    ----------
    img : torch.Tensor
        Decoded pixel values.
    dicom : pydicom Dataset or dicomsdl DataSet
        Header source. It is read through ``getattr(dicom, "PhotometricInterpretation", None)``
        and ``dicom["WindowCenter"]`` / ``dicom["WindowWidth"]``.

    Returns
    -------
    torch.Tensor
        Image scaled to [0, 1], with 1 - x applied when the header says MONOCHROME1.
        A constant image (zero dynamic range) becomes all zeros (ones if inverted)
        instead of the NaNs the plain division produced.
    """
    try:
        invert = getattr(dicom, "PhotometricInterpretation", None) == "MONOCHROME1"
    except Exception:
        # Best effort: a header object that errors on attribute lookup counts as not inverted.
        invert = False

    center = parse_window_element(dicom["WindowCenter"])
    width = parse_window_element(dicom["WindowWidth"])
    if center is not None and width is not None:
        img = linear_window(img, center, width)

    lo = img.min()
    span = img.max() - lo
    if span > 0:
        img = (img - lo) / span
    else:
        img = torch.zeros_like(img, dtype=torch.float32)
    if invert:
        img = 1 - img
    return img

def process2(f, size=SAVE_SIZE, data_folder=DATA_FOLDER, save_folder=SAVE_FOLDER):
    """CPU path: decode ``{data_folder}/{f}`` with dicomsdl, window/normalise, resize and save a PNG.

    Parameters
    ----------
    f : str
        Relative path ``<patient>/<image>.dicom``, as stored in ``test_df["fns"]``.
    size : int
        Output side length; the image is bilinearly resized to ``size x size``.
    data_folder : str
        Root that ``f`` is relative to.
    save_folder : str
        Output prefix (concatenated directly, so it should end with '/').

    NOTE: the defaults are captured when this cell runs, so pass them explicitly
    if DATA_FOLDER / SAVE_FOLDER / SAVE_SIZE change afterwards.

    Returns
    -------
    str or None
        Path of the written PNG, or None if ``cv2.imwrite`` reported failure
        (previously that failure was silently ignored).
    """
    patient = f.split('/')[-2]
    dicom_id = f.split('/')[-1].replace('.dicom', '')

    dicom = dicomsdl.open(f"{data_folder}/{f}")
    img = torch.from_numpy(dicom.pixelData())
    img = process_dicom(img, dicom)
    img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (size, size), mode="bilinear")[0, 0]

    if NBIT == 8:
        img = (img * 255).clip(0,255).to(torch.uint8).cpu().numpy()
    elif NBIT == 16:
        img = (img * 65535).clip(0,65535).cpu().numpy().astype(np.uint16)
    else:
        raise ValueError(f"Unsupported NBIT value: {NBIT}")

    out_file_name = save_folder + f"{patient}_{dicom_id}.png"
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(out_file_name, img):
        return None
    return out_file_name

In [None]:
# Convert every VinDr DICOM to PNG on the CPU, five worker processes at a time.
vindr_jobs = (
    delayed(process2)(uid, size=SAVE_SIZE, data_folder=DATA_FOLDER, save_folder=SAVE_FOLDER)
    for uid in tqdm(test_df["fns"])
)
_ = Parallel(n_jobs=5)(vindr_jobs)