# 数据集链接

https://www.kaggle.com/competitions/rsna-breast-cancer-detection

https://www.kaggle.com/datasets/remekkinas/rsnamodules

https://www.kaggle.com/code/vslaykovsky/rsna-2022-whl

本笔记本演示了如何将本次比赛的300GB以上的巨大数据集处理成TFRecords，以便在训练期间快速加载数据。

TFRecords的优点是可以加载包含许多样本的大量数据，而不是单独加载每个图像和标签。

所有图像都调整为768x1344（宽x高）并保存在100个TFRecords中，使每个TFRecord包含大约550张图像。

[RSNA EfficientNetV2 Training Tensorflow TPU](https://www.kaggle.com/code/markwijkhuizen/rsna-efficientnetv2-training-tensorflow-tpu)

**V2**

* 640x512 -> 1024x1024 resolution
* Cropping images
* Single image approach, not both CC and MLO image

**V3**

* 1024x1024 -> 768x1344 based on cropped image ratio
* using PNG encoded images instead of raw tensors to reduce disk space needed

**V5**

此版本中的更新：

* 纠正图像的线性/S 形归一化，非常感谢 [bobdegraaf](https://www.kaggle.com/bobdegraaf)（参见 [dicomsdl-voi-lut](https://www.kaggle.com/code/bobdegraaf/dicomsdl-voi-lut)）
* 从PNG切换到JPEG，压缩级别为95，以保持在20GB磁盘空间内
* 调整裁剪算法，从最大值到阈值搜索，而不是从边缘到阈值搜索
* 将裁剪偏移量填充到图像尺寸以保留图像信息，而不是零填充

## 安装并启用必要的DICOM处理库

In [None]:
%%capture
# Source: https://www.kaggle.com/code/remekkinas/fast-dicom-processing-1-6-2x-faster?scriptVersionId=113360473
!pip install /kaggle/input/rsnamodules/dicomsdl-0.109.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl 

try:
    import pylibjpeg
except:
   !pip install /kaggle/input/rsna-2022-whl/{pylibjpeg-1.4.0-py3-none-any.whl,python_gdcm-3.0.15-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl}

In [None]:
import numpy as np
import pandas as pd
import pylibjpeg
import pydicom
import matplotlib as mpl
import matplotlib.pyplot as plt
import tensorflow as tf
from joblib import Parallel, delayed
from tqdm.notebook import tqdm
from multiprocessing import cpu_count
import cv2
import glob
import importlib
import os
import joblib
import sys
import dicomsdl
# Report the library/interpreter versions in use
print(f'Tensorflow Version: {tf.__version__}')
print(f'Python Version: {sys.version}')

# Set the TensorFlow inter-op thread pool and the OpenCV thread count to 1;
# intended as a speedup for the parallel function mapping used later on
tf.config.threading.set_inter_op_parallelism_threads(num_threads=1)
cv2.setNumThreads(1)

# Config

In [None]:
# Reset matplotlib to its defaults, then enlarge tick/label/title fonts
mpl.rcParams.update(mpl.rcParamsDefault)
mpl.rcParams['xtick.labelsize'] = 16
mpl.rcParams['ytick.labelsize'] = 16
mpl.rcParams['axes.labelsize'] = 18
mpl.rcParams['axes.titlesize'] = 24

# Interactive flag used for debugging; .get() keeps this from raising KeyError
# when the KAGGLE_KERNEL_RUN_TYPE variable is not set (flag is then False)
IS_INTERACTIVE = os.environ.get('KAGGLE_KERNEL_RUN_TYPE') == 'Interactive'

# Dimensions of the processed images
TARGET_HEIGHT = 1344  # target height
TARGET_WIDTH = 768    # target width
N_CHANNELS = 1        # number of channels (single-channel grayscale)
TARGET_HEIGHT_WIDTH_RATIO = TARGET_HEIGHT / TARGET_WIDTH  # height / width ratio

# Image normalisation utilities; did not improve the LB score
# Tutorial: https://docs.opencv.org/4.x/d5/daf/tutorial_py_histogram_equalization.html
CLAHE = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(32, 32))  # CLAHE object
APPLY_CLAHE = False    # whether to apply CLAHE
APPLY_EQ_HIST = False  # whether to apply histogram equalisation

# Image format and encoding configuration
IMAGE_FORMAT = 'JPG'  # any value other than 'PNG' selects JPEG encoding
IMAGE_QUALITY = 95    # JPEG quality

# Seed for the random generator
SEED = 42

## Train

In [None]:
# Load the RSNA breast cancer detection training metadata;
# interactive sessions keep only the first 1024 rows
train = pd.read_csv('/kaggle/input/rsna-breast-cancer-detection/train.csv')
if IS_INTERACTIVE:
    train = train.head(1024)
    
def get_file_path(args):
    """Return the training DICOM path for a (patient_id, image_id) pair."""
    pid, iid = args
    return '/kaggle/input/rsna-breast-cancer-detection/train_images/' + f'{pid}/{iid}.dcm'
    
# Build the DICOM file path of every row from its patient_id/image_id
train['file_path'] = train[['patient_id', 'image_id']].apply(get_file_path, axis=1)

# DataFrame.info() prints its summary itself and returns None, so it is called
# directly instead of being wrapped in display() (which would show a stray "None")
train.info()
display(train.head())

# VOI_LUT

In [None]:
# Source: https://www.kaggle.com/code/bobdegraaf/dicomsdl-voi-lut
def voi_lut(image, dicom):
    """Apply the VOI LUT windowing described by the DICOM metadata to `image`.

    Args:
        image: numpy array with the raw pixel data. It is never modified; the
            linear branch works on a copy.
        dicom: mapping-like object providing 'WindowCenter', 'WindowWidth',
            'BitsStored', 'VOILUTFunction' and 'PhotometricInterpretation'.

    Returns:
        The windowed image, inverted for MONOCHROME1 so the background is 0.
        The sigmoid branch returns a float array; the linear branch keeps the
        dtype of `image` (computed values are cast on assignment).
    """
    # Only load the attributes we need
    center = dicom['WindowCenter']
    width = dicom['WindowWidth']
    bits_stored = dicom['BitsStored']
    voi_lut_function = dicom['VOILUTFunction']

    # Center and width can be lists (per the original author: for SIGMOID);
    # use the first entry in that case
    if isinstance(center, list):
        center = center[0]
    if isinstance(width, list):
        width = width[0]

    # Output range
    y_min = 0
    y_max = float(2**bits_stored - 1)
    y_range = y_max

    # Anything that is not SIGMOID (including NaN) is treated as linear
    if voi_lut_function == 'SIGMOID':
        image = y_range / (1 + np.exp(-4 * (image - center) / width)) + y_min
    else:
        center -= 0.5
        width -= 1

        # Classify pixels against the window bounds (computed on the raw values)
        below = image <= (center - width / 2)
        above = image > (center + width / 2)
        between = np.logical_and(~below, ~above)

        # Copy before writing so the caller's array is left untouched
        image = image.copy()
        image[below] = y_min
        image[above] = y_max
        # `between` is empty when width <= 0, which also avoids dividing by zero
        if between.any():
            image[between] = (
                ((image[between] - center) / width + 0.5) * y_range + y_min
            )

    # In MONOCHROME1 images 0 is the maximum intensity; invert so background is 0
    if dicom['PhotometricInterpretation'] == 'MONOCHROME1':
        image = np.max(image) - image

    return image

# Crop Image

In [None]:
# Smooth a vector; used to smooth the per-axis sum/std profiles
def smooth(l):
    """Return `l` smoothed with a moving-average kernel.

    The kernel spans 1% of the vector length, but never fewer than one sample:
    for vectors shorter than 100 the 1% rule would give an empty kernel, which
    np.convolve rejects. A one-sample kernel leaves the vector unchanged.
    The output has the same length as the input (mode='same').
    """
    kernel_size = max(1, int(len(l) * 0.01))
    kernel = np.ones(kernel_size) / kernel_size
    return np.convolve(l, kernel, mode='same')

# X crop offset: first column after the profile maximum whose smoothed
# sum * std falls below 5% (default) of the profile maximum
def get_x_offset(image, max_col_sum_ratio_threshold=0.05, debug=None):
    """Return the column index at which to crop `image` on the right.

    Args:
        image: 2-D array (H, W).
        max_col_sum_ratio_threshold: fraction of the maximum profile value used
            as the threshold.
        debug: array of matplotlib axes; when given, the search is drawn on
            debug[1]. Any non-ndarray value disables plotting.

    Returns:
        int offset in [0, W]. When no column falls below the threshold the full
        width W is returned, so no column is dropped.
    """
    # Image dimensions
    H, W = image.shape
    # Percentage margin added to the offset (currently 0%)
    margin = int(W * 0.00)
    # Smoothed sum * std per column captures columns with varying intensity
    vv = smooth(image.sum(axis=0).squeeze()) * smooth(image.std(axis=0).squeeze())
    # Maximum within the first 75% of the columns
    vv_argmax = vv[:int(W * 0.75)].argmax()
    # Threshold
    vv_threshold = vv.max() * max_col_sum_ratio_threshold

    # First column at/after the maximum that is below the threshold
    offset = W
    below = np.flatnonzero(vv[vv_argmax:] < vv_threshold)
    if below.size:
        offset = min(W, int(vv_argmax + below[0]) + margin)

    # Visualise the search when an axes array is passed
    if isinstance(debug, np.ndarray):
        debug[1].imshow(image)
        debug[1].set_title('X 偏移')
        vv_scale = H / vv.max() * 0.90
        # Profile values
        debug[1].plot(H - vv * vv_scale, c='red', label='vv')
        # Threshold line
        debug[1].hlines(H - vv_threshold * vv_scale, 0, W - 1, colors='orange', label='阈值')
        # Maximum (np.inf instead of np.PINF, which NumPy 2.0 removed)
        debug[1].scatter(vv_argmax, H - vv[vv_argmax] * vv_scale, c='blue', s=100, label='最大值', zorder=np.inf)
        # Chosen offset; clamp the lookup index because offset may equal W
        debug[1].scatter(offset, H - vv[min(offset, W - 1)] * vv_scale, c='purple', s=100, label='偏移', zorder=np.inf)
        debug[1].set_ylim(H, 0)
        debug[1].legend()
        debug[1].axis('off')

    return offset

# Draw one Y-offset debug panel: image, smoothed profile, threshold, maximum and chosen offset
def _plot_y_offset_debug(ax, image, vv, vv_threshold, vv_argmax, offset, title):
    H, W = image.shape
    vv_scale = W / vv.max() * 0.90
    ax.imshow(image)
    ax.set_title(title)
    # Profile values
    ax.plot(vv * vv_scale, np.arange(H), c='red', label='vv')
    # Threshold line
    ax.vlines(vv_threshold * vv_scale, 0, H - 1, colors='orange', label='阈值')
    # Maximum (np.inf instead of np.PINF, which NumPy 2.0 removed)
    ax.scatter(vv[vv_argmax] * vv_scale, vv_argmax, c='blue', s=100, label='最大值', zorder=np.inf)
    # Chosen offset; clamp the lookup index because offset may equal H
    ax.scatter(vv[min(offset, H - 1)] * vv_scale, offset, c='purple', s=100, label='偏移', zorder=np.inf)
    ax.set_ylim(H, 0)
    ax.legend()
    ax.axis('off')


# Y crop offsets: nearest rows on either side of the profile maximum whose
# smoothed sum * std falls below 5% (default) of the profile maximum
def get_y_offsets(image, max_row_sum_ratio_threshold=0.05, debug=None):
    """Return (offset_bottom, offset_top) row indices to crop `image` vertically.

    `offset_bottom` is the lower row index (searched from the maximum towards
    row 0) and `offset_top` the higher one (searched towards row H).

    Args:
        image: 2-D array (H, W).
        max_row_sum_ratio_threshold: fraction of the maximum profile value used
            as the threshold.
        debug: array of matplotlib axes; when given, the two searches are drawn
            on debug[2] and debug[3]. Any non-ndarray value disables plotting.

    Returns:
        Tuple of ints; defaults to (0, H) when no row falls below the threshold.
    """
    # Image dimensions
    H, W = image.shape
    # Margin added to the offsets
    margin = 0
    # Smoothed sum * std per row captures rows with varying intensity
    vv = smooth(image.sum(axis=1).squeeze()) * smooth(image.std(axis=1).squeeze())
    # Maximum within the middle half of the rows (25% .. 75%)
    vv_argmax = int(H * 0.25) + vv[int(H * 0.25):int(H * 0.75)].argmax()
    # Threshold
    vv_threshold = vv.max() * max_row_sum_ratio_threshold
    # Default crop offsets
    offset_bottom = 0
    offset_top = H

    # Closest row before the maximum that is below the threshold
    below_before = np.flatnonzero(vv[:vv_argmax] < vv_threshold)
    if below_before.size:
        offset_bottom = int(below_before[-1])

    # Closest row at/after the maximum that is below the threshold
    below_after = np.flatnonzero(vv[vv_argmax:] < vv_threshold)
    if below_after.size:
        offset_top = int(vv_argmax + below_after[0])

    if isinstance(debug, np.ndarray):
        _plot_y_offset_debug(debug[2], image, vv, vv_threshold, vv_argmax, offset_bottom, 'Y 底部偏移')
        _plot_y_offset_debug(debug[3], image, vv, vv_threshold, vv_argmax, offset_top, 'Y 顶部偏移')

    return max(0, offset_bottom - margin), min(H, offset_top + margin)

# Crop the image, widening the crop window towards the target height/width
# ratio so real image content is kept instead of padding
def crop(image, size=None, debug=False):
    """Crop `image` to its region of interest.

    The right edge comes from get_x_offset and the vertical bounds from
    get_y_offsets (evaluated on the horizontally cropped image). When `size`
    is not None the window is grown towards TARGET_HEIGHT_WIDTH_RATIO; slicing
    clips any part of the window that lies beyond the image.
    """
    right = get_x_offset(image, debug=debug)
    row_start, row_stop = get_y_offsets(image[:, :right], debug=debug)

    # Size of the tight crop
    crop_h = row_stop - row_start
    crop_w = right

    if size is not None:
        if (crop_h / crop_w) > TARGET_HEIGHT_WIDTH_RATIO:
            # Too tall: widen the crop
            right += int(crop_h / TARGET_HEIGHT_WIDTH_RATIO - crop_w)
        else:
            # Too short: grow the crop equally upwards and downwards
            grow = int(0.50 * (crop_w * TARGET_HEIGHT_WIDTH_RATIO - crop_h))
            row_start -= grow
            # Shift the window down by whatever went above row 0
            shift = max(0, -row_start)
            row_start += shift
            row_stop += grow + shift

    return image[row_start:row_stop, :right]

# Utility

In [None]:
# Based on: https://www.kaggle.com/code/remekkinas/fast-dicom-processing-1-6-2x-faster?scriptVersionId=113360473
def process(file_path, size=None, dicom_process=True, ret_target=False, crop_image=False, apply_clahe=APPLY_CLAHE, apply_eq_hist=APPLY_EQ_HIST, debug=False):
    """Read a DICOM file and convert it to a normalised uint8 image.

    Pipeline: read pixels -> VOI LUT -> min/max normalise to uint8 -> flip so
    the brighter side is on the left -> optional crop -> optional pad + resize
    -> optional CLAHE / histogram equalisation.

    Args:
        file_path: path of the form .../<patient_id>/<image_id>.dcm.
        size: optional (width, height) passed to cv2.resize. When given, the
            image is first zero-padded to TARGET_HEIGHT_WIDTH_RATIO.
        dicom_process: unused; kept so existing callers keep working.
        ret_target: also return the cancer label looked up in
            PATIENT_ID_IMAGE_ID2CANCER.
        crop_image: crop to the region of interest with `crop`.
        apply_clahe: apply the module-level CLAHE object.
        apply_eq_hist: apply cv2.equalizeHist.
        debug: plot the intermediate steps on a 1x5 figure.

    Returns:
        (image, target) when ret_target is set; otherwise (raw_image, image)
        when debug is set; otherwise image.
    """
    # Read the DICOM file and its pixel data
    dicom = dicomsdl.open(file_path)
    image = dicom.pixelData()

    # In debug mode keep a copy of the raw image and plot it
    if debug:
        fig, axes = plt.subplots(1, 5, figsize=(20, 10))
        image0 = np.copy(image)
        axes[0].imshow(image0)
        axes[0].set_title('原始图像')
        axes[0].axis('off')
    else:
        axes = False

    # Apply VOI LUT windowing
    image = voi_lut(image, dicom)

    # Normalise to [0, 1]; a constant image has no range to scale, so it maps
    # to all zeros instead of dividing by zero (which produced NaNs)
    lowest, highest = image.min(), image.max()
    if highest > lowest:
        image = (image - lowest) / (highest - lowest)
    else:
        image = np.zeros(image.shape, dtype=np.float64)

    # Convert to uint8, range [0, 255]
    image = (image * 255).astype(np.uint8)

    # Normalise orientation: flip when the right 10% of columns is brighter than the left 10%
    h0, w0 = image.shape
    if image[:, int(-w0 * 0.10):].sum() > image[:, :int(w0 * 0.10)].sum():
        image = np.flip(image, axis=1)

    # Optional crop to the region of interest
    if crop_image:
        image = crop(image, size=size, debug=axes)

    # Optional resize
    if size is not None:
        # Pad with black pixels to reach the target height/width ratio
        h, w = image.shape
        if (h / w) > TARGET_HEIGHT_WIDTH_RATIO:
            pad = int(h / TARGET_HEIGHT_WIDTH_RATIO - w)
            image = np.pad(image, [[0, 0], [0, pad]])  # pad on the right
        else:
            pad = int(0.50 * (w * TARGET_HEIGHT_WIDTH_RATIO - h))
            image = np.pad(image, [[pad, pad], [0, 0]])  # pad top and bottom
        image = cv2.resize(image, size, interpolation=cv2.INTER_AREA)

    # Optional CLAHE contrast enhancement
    # Reference: https://docs.opencv.org/4.x/d5/daf/tutorial_py_histogram_equalization.html
    if apply_clahe:
        image = CLAHE.apply(image)

    # Optional histogram equalisation
    # Reference: https://docs.opencv.org/4.x/d5/daf/tutorial_py_histogram_equalization.html
    if apply_eq_hist:
        image = cv2.equalizeHist(image)

    # In debug mode show the processed image
    if debug:
        axes[4].imshow(image)
        axes[4].set_title('处理后的图像')
        axes[4].axis('off')
        plt.show()

    # Return the image together with its cancer target
    if ret_target:
        patient_id = int(file_path.split('/')[-2])
        image_id = int(file_path.split('/')[-1].split('.')[0])

        target = PATIENT_ID_IMAGE_ID2CANCER[(patient_id, image_id)]

        return image, target

    # Return only the image(s)
    if debug:
        return image0, image
    return image

# Example Processing

In [None]:
# Number of examples to show: fewer in interactive sessions
N = 4 if IS_INTERACTIVE else 10
# Run the full pipeline in debug mode (cropping + resizing to the target size)
# on the first N training files
for fp in tqdm(train['file_path'].head(N)):
    process(fp, crop_image=True, size=(TARGET_WIDTH, TARGET_HEIGHT), debug=True)

# Example Processed Images

In [None]:
def plot_original_processed_examples(rows=48, cols=5):
    """Plot a rows x cols grid of processed training images.

    Images are taken from `train` by label 0 .. rows*cols-1, cropped and
    resized to (TARGET_WIDTH, TARGET_HEIGHT).
    """
    # squeeze=False keeps `axes` 2-D so axes[r, c] also works when rows or cols is 1
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, 6 * rows), squeeze=False)
    for r in tqdm(range(rows)):
        for c in range(cols):
            idx = (r * cols) + c
            image = process(
                    train.loc[idx, 'file_path'],
                    crop_image=True,
                    size=(TARGET_WIDTH, TARGET_HEIGHT),
                    apply_clahe=APPLY_CLAHE,
                    apply_eq_hist=APPLY_EQ_HIST,
                    debug=False,
                )
            axes[r, c].imshow(image)
            axes[r, c].set_title(f'{idx} | processed')

    plt.show()
    
plot_original_processed_examples(rows=8 if IS_INTERACTIVE else 32)

# Train

In [None]:
# Show every patient that lacks a CC view or an MLO view
# (skipped in interactive sessions)
if not IS_INTERACTIVE:
    required_views = {'CC', 'MLO'}
    for patient_id, patient_rows in tqdm(train.groupby('patient_id')):
        if not required_views.issubset(patient_rows['view'].values):
            display(patient_rows)

In [None]:
# Lookup table mapping (patient_id, image_id) to the cancer target;
# read by process() when it is called with ret_target=True
PATIENT_ID_IMAGE_ID2CANCER = train.set_index(['patient_id', 'image_id'])['cancer'].to_dict()  

# Test

In [None]:
# Load the test metadata
test = pd.read_csv('/kaggle/input/rsna-breast-cancer-detection/test.csv')
# info() prints its summary and returns None, so it is not wrapped in display()
test.info()
display(test.head())

# Sample Submission

In [None]:
# Load the sample submission
sample_submission = pd.read_csv('/kaggle/input/rsna-breast-cancer-detection/sample_submission.csv')
# info() prints its summary and returns None, so it is not wrapped in display()
sample_submission.info()
display(sample_submission.head())

# Train Meta Data

In [None]:
# Pie chart of the cancer label counts. Per the original author the data is
# imbalanced, with roughly 50x more negative than positive samples.
cancer_counts = train['cancer'].value_counts()
plt.figure(figsize=(8, 8))
cancer_counts.plot(kind='pie', autopct='%1.1f%%', title='Cancer Distribution')
plt.show()

# Image Statistics

In [None]:
# Patient folder paths; glob returns entries in arbitrary order, so sort for a stable listing
FOLDER_PATHS = sorted(glob.glob('/kaggle/input/rsna-breast-cancer-detection/train_images/*'))
print(f'Found {len(FOLDER_PATHS)} Train Folders')

In [None]:
# DICOM file paths. glob returns entries in arbitrary (filesystem) order, so the
# list is sorted to make the seeded np.random.choice sampling below reproducible
FILE_PATHS = sorted(glob.glob('/kaggle/input/rsna-breast-cancer-detection/train_images/*/*.dcm'))
print(f'Found {len(FILE_PATHS)} Train Files')

In [None]:
# Bar chart: how many (patient, laterality) groups have 1, 2, 3, ... scans
plt.figure(figsize=(15,8))
plt.title('Number of Scans Per Patient per Laterality (Left/Right side)')
# GroupBy.size() is the built-in equivalent of apply(len): rows per group
train.groupby(['patient_id', 'laterality']).size().value_counts().sort_index().plot(kind='bar')
plt.show()

In [None]:
# Number of images per view type. Per the original author, AT/LM/ML/LMO are
# rare, with too few samples to train on.
display(train['view'].value_counts().to_frame('count'))

In [None]:
# Counts of the (sorted) view combinations per patient and laterality.
# Per the original author, most patients only have a CC and an MLO scan.
display(train.sort_values('view').groupby(['patient_id', 'laterality'])['view'].apply(tuple).value_counts().to_frame('Count'))

# Image Dimensions

In [None]:
# Use the shared SEED constant rather than a second hard-coded seed value
np.random.seed(SEED)

# Collect height/width of N randomly chosen (sampled with replacement) uncropped, un-resized images
N = 16 if IS_INTERACTIVE else 1024
WIDTHS = []
HEIGHTS = []
for fp in tqdm(np.random.choice(FILE_PATHS, N)):
    h, w = process(fp).shape
    HEIGHTS.append(h)
    WIDTHS.append(w)

In [None]:
# Overlaid histograms of the sampled image heights and widths
# (original author's note: the images are insanely huge)
plt.figure(figsize=(15,8))
plt.title('Image Dimensions', size=24)
for values, label in ((HEIGHTS, 'heights'), (WIDTHS, 'widths')):
    pd.Series(values).plot(kind='hist', alpha=0.50, label=label)
plt.grid()
plt.legend()
plt.show()

In [None]:
# Height/width ratios of the sampled images: summary statistics and histogram.
# Original author's note: height is roughly 1.25x width, which motivated the
# earlier (512*1.25)x512 = 640x512 resize.
hw_ratios = pd.Series(np.array(HEIGHTS) / np.array(WIDTHS))
display(hw_ratios.describe().to_frame('Height/Width Ratio\'s'))

plt.figure(figsize=(15,8))
hw_ratios.plot(kind='hist')
plt.grid()
plt.show()

# Cropped Image Dimensions

In [None]:
np.random.seed(SEED)

N = int(16 if IS_INTERACTIVE else 1024)
WIDTHS_CROPPED = []
HEIGHTS_CROPPED = []

for fp in tqdm(np.random.choice(FILE_PATHS, N)):
    h, w = process(fp, crop_image=True).shape
    WIDTHS_CROPPED.append(h)
    HEIGHTS_CROPPED.append(w)

In [None]:
plt.figure(figsize=(15,8))
plt.title('Cropped Image Dimensions', size=24)
pd.Series(WIDTHS_CROPPED).plot(kind='hist', alpha=0.50, label='cropped heights')
pd.Series(HEIGHTS_CROPPED).plot(kind='hist', alpha=0.50, label='cropped widths')
plt.grid()
plt.legend()
plt.show()

In [None]:
# Height to Width Ratio's
display(pd.Series(np.array(WIDTHS_CROPPED) / np.array(HEIGHTS_CROPPED)).describe().to_frame('Cropped Height/Width Ratio\'s'))
plt.figure(figsize=(15,8))
pd.Series(np.array(WIDTHS_CROPPED) / np.array(HEIGHTS_CROPPED)).plot(kind='hist')
plt.grid()
plt.show()

# Chunk Generation

In [None]:
# One (patient_id, image_id) pair per training image; each pair is later turned
# into one sample by process_chunk. Selecting the two columns directly replaces
# the per-row iterrows() loop; astype(object) keeps the object dtype the loop produced.
FILE_PATHS_PAIRS = train[['patient_id', 'image_id']].values.astype(object)
print(f'FILE_PATHS_PAIRS shape: {FILE_PATHS_PAIRS.shape}')

In [None]:
# Split the (patient_id, image_id) pairs into N_CHUNKS roughly equal chunks;
# to_tf_records writes one TFRecord file per chunk
N_CHUNKS = 100
CHUNKS = np.array_split(FILE_PATHS_PAIRS, N_CHUNKS)

print(f'N_CHUNKS: {N_CHUNKS}, CHUNK len: {len(CHUNKS[0])}, shape: {CHUNKS[0].shape}')

In [None]:
# Process one sample of a chunk
def process_chunk(args):
    """Process one (patient_id, image_id) pair into an encoded image.

    Returns:
        (image_bytes, target, patient_id, image_id), where image_bytes is the
        PNG (IMAGE_FORMAT == 'PNG') or JPEG (otherwise) encoding of the
        cropped, resized grayscale image.
    """
    patient_id, image_id = args
    dicom_path = f'/kaggle/input/rsna-breast-cancer-detection/train_images/{patient_id}/{image_id}.dcm'

    # Processed image plus its cancer target
    image, target = process(dicom_path, size=(TARGET_WIDTH, TARGET_HEIGHT), ret_target=True, crop_image=True)

    # Add a trailing channel axis: (H, W) -> (H, W, 1)
    image = image[:, :, np.newaxis]

    if IMAGE_FORMAT == 'PNG':
        encoded = tf.io.encode_png(image, compression=9)
    else:
        encoded = tf.io.encode_jpeg(image, quality=IMAGE_QUALITY, optimize_size=True)

    return encoded.numpy(), target, patient_id, image_id

In [None]:
def to_tf_records(chunks):
    """Write one GZIP-compressed TFRecord file per chunk.

    Chunk i is written to 'batch_<i>.tfrecords' in the working directory. Each
    record holds the encoded image bytes and the int64 features 'target',
    'patient_id' and 'image_id'.

    Args:
        chunks: iterable of chunks, each an iterable of (patient_id, image_id).
    """
    for chunk_idx, chunk in enumerate(tqdm(chunks)):
        print(f'===== GENERATING TFRECORDS {chunk_idx} =====')
        tfrecord_name = f'batch_{chunk_idx}.tfrecords'

        # Create the actual TFRecords
        options = tf.io.TFRecordOptions(compression_type='GZIP', compression_level=9)
        with tf.io.TFRecordWriter(tfrecord_name, options=options) as file_writer:
            # Process the samples of the chunk in parallel, one worker per CPU.
            # The explicit backend makes the contradictory prefer='threads'
            # hint a no-op, so the hint was dropped.
            jobs = [joblib.delayed(process_chunk)(args) for args in chunk]
            chunk_processed = joblib.Parallel(
                n_jobs=cpu_count(),
                verbose=0,
                backend='multiprocessing',
            )(jobs)

            # Add the processed samples to the TFRecord
            for image, target, patient_id, image_id in chunk_processed:
                record_bytes = tf.train.Example(features=tf.train.Features(feature={
                    # Image
                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),

                    # target
                    'target': tf.train.Feature(int64_list=tf.train.Int64List(value=[target])),

                    # patient_id
                    'patient_id': tf.train.Feature(int64_list=tf.train.Int64List(value=[patient_id])),

                    # image_id
                    'image_id': tf.train.Feature(int64_list=tf.train.Int64List(value=[image_id])),
                })).SerializeToString()
                file_writer.write(record_bytes)
            
# Create the TFRecords; interactive sessions only write the first 10 chunks
to_tf_records(CHUNKS[:10] if IS_INTERACTIVE else CHUNKS)

# Check TFRecords

In [None]:
# Batch size used by get_train_dataset and default row count of show_batch
N = 16 if IS_INTERACTIVE else 32

In [None]:
# Decode one serialized TFRecord example
def decode_tfrecord(record_bytes):
    """Parse one serialized example into (image, target, patient_id, image_id).

    The image is decoded as PNG when IMAGE_FORMAT == 'PNG' and as JPEG
    otherwise, then reshaped to [TARGET_HEIGHT, TARGET_WIDTH, N_CHANNELS].
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.int64),
        'patient_id': tf.io.FixedLenFeature([], tf.int64),
        'image_id': tf.io.FixedLenFeature([], tf.int64),
    }
    features = tf.io.parse_single_example(record_bytes, feature_spec)

    # Pick the decoder matching the encoding used when writing
    decode_image = tf.io.decode_png if IMAGE_FORMAT == 'PNG' else tf.io.decode_jpeg
    image = decode_image(features['image'], channels=N_CHANNELS)
    image = tf.reshape(image, [TARGET_HEIGHT, TARGET_WIDTH, N_CHANNELS])

    return image, features['target'], features['patient_id'], features['image_id']

More on Tensorflow TFRecord Datasets: [TFRecordDataset](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset)

In [None]:
# Sample TFRecord Dataset
def get_train_dataset():
    """Return a dataset of decoded samples, batched by N, from ./*.tfrecords."""
    # All TFRecord files in the working directory
    tfrecord_paths = tf.io.gfile.glob('./*.tfrecords')
    # Read (GZIP), decode each record, then batch
    return (
        tf.data.TFRecordDataset(tfrecord_paths, num_parallel_reads=1, compression_type='GZIP')
        .map(decode_tfrecord)
        .batch(N)
    )

In [None]:
# Shows a batch of images
def show_batch(dataset, rows=N, cols=1):
    """Plot the first batch of `dataset`, one sample per row.

    Args:
        dataset: iterable yielding (images, targets, patient_ids, image_ids)
            batches; images are moved from channels-last to channels-first, so
            a 4-D batch is expected.
        rows: maximum number of samples to show; capped at the batch size.
        cols: number of channels to show per sample (column c shows channel c).
    """
    images, targets, patient_ids, image_ids = next(iter(dataset))
    # Convert to numpy so titles show plain values instead of tensor reprs
    images = np.moveaxis(np.asarray(images), 3, 1)
    targets = np.asarray(targets)
    patient_ids = np.asarray(patient_ids)
    image_ids = np.asarray(image_ids)

    # The batch can hold fewer samples than requested
    rows = min(rows, len(images))
    # squeeze=False keeps `axes` 2-D, so indexing also works for a single row/column
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*6, rows*10), squeeze=False)
    for r in range(rows):
        for c in range(cols):
            axes[r, c].imshow(images[r, c])
            if c == 0:
                target = targets[r]
                patient_id = patient_ids[r]
                image_id = image_ids[r]
                axes[r, c].set_title(f'target: {target}, patient_id: {patient_id}, image_id: {image_id}', fontsize=12, pad=16)

    plt.show()

In [None]:
# Build the dataset from the TFRecords written above and plot its first batch
train_dataset = get_train_dataset()
show_batch(train_dataset)