In [None]:
import logging
logging.disable(logging.WARNING)

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import pprint

# Sample Overview

In [None]:
import json
import nibabel as nib

# One image file and its matching label file, used as the running example.
img_path = '/kaggle/input/abdomen/imagesTr/img0001.nii'
mask_path = '/kaggle/input/abdomen/labelsTr/label0001.nii'

# Load both files with nibabel in a single unpacking step.
test_image_nib, test_mask_nib = (nib.load(p) for p in (img_path, mask_path))

# Cell output: the shapes of the image and of the label.
(test_image_nib.shape, test_mask_nib.shape)

In [None]:
import matplotlib.pyplot as plt

# Reverse the axis order to (2, 1, 0) and flip the last two axes, so that
# indexing the first axis yields one 2D slice to display.
test_image = test_image_nib.get_fdata().transpose(2, 1, 0)[:, ::-1, ::-1]
test_mask = test_mask_nib.get_fdata().transpose(2, 1, 0)[:, ::-1, ::-1]

# Side by side: the middle slice of the image (grayscale) and of the label.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.set_title(f'Image shape: {test_image.shape}')
ax1.imshow(test_image[test_image.shape[0]//2], cmap='gray')
ax2.set_title(f'Label shape: {test_mask.shape}')
ax2.imshow(test_mask[test_mask.shape[0]//2])
plt.show()

In [None]:
from skimage.util import montage

# Tile every slice of the image volume into one large grayscale figure.
fig, ax1 = plt.subplots(figsize=(20, 20))
image_grid = montage(test_image, padding_width=10, fill=1)
ax1.imshow(image_grid, cmap='gray')
ax1.axis('off')
plt.show()

In [None]:
# Same tiling as for the image, applied to the label volume (default colormap).
fig, ax1 = plt.subplots(figsize=(20, 20))
mask_grid = montage(test_mask, padding_width=10, fill=1)
ax1.imshow(mask_grid)
ax1.axis('off')
plt.show()

In [None]:
# Read the dataset description that lists the samples of each split.
# The encoding is pinned to UTF-8 so the result does not depend on the
# platform's locale default.
with open('/kaggle/input/abdomen/dataset_0.json', encoding='utf-8') as f:
    dataset_json = json.load(f)

# Top-level keys first, then the full structure for reference.
print(dataset_json.keys())
pprint.pprint(dataset_json)

# TFRecord Creation

In [None]:
import tensorflow as tf

def serialize_example(image_path, label_path, data_root='/kaggle/input/abdomen'):
    """Serialize one image/label NIfTI pair into a ``tf.train.Example``.

    Args:
        image_path: Path of the image file, joined onto ``data_root``.
        label_path: Path of the label file, joined onto ``data_root``.
        data_root: Directory the two paths are resolved against. Defaults to
            the Kaggle input directory that was previously hard-coded.

    Returns:
        The serialized example (bytes). For both the image and the label it
        stores the voxel data as raw float32 bytes, the array shape, the
        flattened affine matrix, and the header's ``pixdim`` field.
    """

    def _bytes_feature(payload):
        # Wrap a single bytes value.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

    def _int64_feature(values):
        # Wrap a sequence of integers.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    def _float_feature(values):
        # Wrap a sequence of floats.
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    image_nii = nib.load(os.path.join(data_root, image_path))
    label_nii = nib.load(os.path.join(data_root, label_path))

    # Both volumes are stored as float32; the parsers decode with that dtype.
    image = image_nii.get_fdata().astype(np.float32)
    label = label_nii.get_fdata().astype(np.float32)

    image_pixdim = np.array(image_nii.header['pixdim'], dtype=np.float32)
    label_pixdim = np.array(label_nii.header['pixdim'], dtype=np.float32)

    feature = {
        "image_raw": _bytes_feature(image.tobytes()),
        "label_raw": _bytes_feature(label.tobytes()),
        # The shape is stored alongside the bytes so the flat buffer can be
        # reshaped when the record is read back.
        "image_shape": _int64_feature(np.array(image.shape, dtype=np.int64)),
        "label_shape": _int64_feature(np.array(label.shape, dtype=np.int64)),
        "image_affine": _float_feature(image_nii.affine.flatten()),
        "label_affine": _float_feature(label_nii.affine.flatten()),
        "image_pixdim": _float_feature(image_pixdim.flatten()),
        "label_pixdim": _float_feature(label_pixdim.flatten()),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

In [None]:
from tqdm import tqdm

# Directory (relative to the working directory) that create_tfrec writes its
# shard files into.
output_dir = "tfrecords"
# exist_ok=True keeps this cell re-runnable when the directory already exists.
os.makedirs(output_dir, exist_ok=True)

def create_tfrec(dataset, shard_size=3, set='training'):
    """Write ``dataset`` to TFRecord shards of at most ``shard_size`` samples.

    Each shard is saved in the module-level ``output_dir`` under the name
    ``{set}_shard_{index}.tfrec``. Every sample must expose ``"image"`` and
    ``"label"`` entries, which are handed to ``serialize_example``.

    Args:
        dataset: Indexable collection of samples.
        shard_size: Maximum number of samples per shard file.
        set: Split name used as the file-name prefix and in the log line.
    """
    total = len(dataset)
    # Ceiling division: a trailing partial shard still gets its own file.
    num_shards = (total + shard_size - 1) // shard_size

    for shard_idx in range(num_shards):
        first = shard_size * shard_idx
        stop = min(first + shard_size, total)
        shard_path = os.path.join(output_dir, f"{set}_shard_{shard_idx}.tfrec")

        with tf.io.TFRecordWriter(shard_path) as writer:
            progress = tqdm(
                range(first, stop),
                desc=f"Writing Shard {shard_idx}/{num_shards-1}",
            )
            for sample_idx in progress:
                entry = dataset[sample_idx]
                writer.write(serialize_example(entry["image"], entry["label"]))

        print(f"Shard {shard_idx} for {set} saved: {shard_path}")

In [None]:
# Shard both splits listed in the dataset JSON, ten samples per shard file.
for split in ("training", "validation"):
    create_tfrec(dataset_json[split], shard_size=10, set=split)

In [None]:
def parse_tfrecord_fn(example_proto):
    """Decode one serialized example into volumes plus spatial metadata.

    Args:
        example_proto: A serialized ``tf.train.Example``.

    Returns:
        ``(image, label, image_affine, label_affine, image_pixdim,
        label_pixdim)``: the two float32 volumes reshaped to their stored
        3-element shapes, the affines reshaped to 4x4, and the 8-element
        pixdim vectors as stored.
    """
    raw_spec = tf.io.FixedLenFeature([], tf.string)
    shape_spec = tf.io.FixedLenFeature([3], tf.int64)
    affine_spec = tf.io.FixedLenFeature([16], tf.float32)
    pixdim_spec = tf.io.FixedLenFeature([8], tf.float32)

    parsed = tf.io.parse_single_example(
        example_proto,
        {
            "image_raw": raw_spec,
            "label_raw": raw_spec,
            "image_shape": shape_spec,
            "label_shape": shape_spec,
            "image_affine": affine_spec,
            "label_affine": affine_spec,
            "image_pixdim": pixdim_spec,
            "label_pixdim": pixdim_spec,
        },
    )

    # Raw bytes -> flat float32 tensor -> stored shape.
    image = tf.reshape(
        tf.io.decode_raw(parsed["image_raw"], tf.float32), parsed["image_shape"]
    )
    label = tf.reshape(
        tf.io.decode_raw(parsed["label_raw"], tf.float32), parsed["label_shape"]
    )

    # The affines were written flattened; restore the 4x4 layout.
    image_affine = tf.reshape(parsed["image_affine"], (4, 4))
    label_affine = tf.reshape(parsed["label_affine"], (4, 4))

    return (
        image,
        label,
        image_affine,
        label_affine,
        parsed["image_pixdim"],
        parsed["label_pixdim"],
    )

In [None]:
# Read the first training shard back to check what was written.
tfrecord_path = "tfrecords/training_shard_0.tfrec"
dataset = tf.data.TFRecordDataset(tfrecord_path).map(parse_tfrecord_fn)

# Only the first record is inspected.
for image, label, image_affine, label_affine, image_pixdim, label_pixdim in dataset.take(1):
    print("Image shape:", image.shape)
    print("Label shape:", label.shape)
    print("Image affine matrix:", image_affine.numpy())
    print("Label affine matrix:", label_affine.numpy())
    print("Image voxel spacing:", image_pixdim.numpy())
    print("Label voxel spacing:", label_pixdim.numpy())

# Load the TFRecord Dataset

In [None]:
def parse_tfrecord_fn(example):
    """Decode one serialized example into an ``(image, label)`` pair.

    This parser reads only the voxel data and the shapes. It rebinds the name
    of the metadata-returning parser defined earlier in this notebook, so
    cells run after this one get the two-element version.

    Args:
        example: A serialized ``tf.train.Example``.

    Returns:
        ``(image, label)``: float32 tensors reshaped to their stored
        3-element shapes.
    """
    feature_description = {
        "image_raw": tf.io.FixedLenFeature([], tf.string),
        "label_raw": tf.io.FixedLenFeature([], tf.string),
        "image_shape": tf.io.FixedLenFeature([3], tf.int64),
        "label_shape": tf.io.FixedLenFeature([3], tf.int64),
    }
    parsed_example = tf.io.parse_single_example(example, feature_description)

    image = tf.io.decode_raw(parsed_example["image_raw"], tf.float32)
    label = tf.io.decode_raw(parsed_example["label_raw"], tf.float32)

    # The shapes are already parsed as int64, so they feed tf.reshape directly
    # without a cast.
    image = tf.reshape(image, parsed_example["image_shape"])
    label = tf.reshape(label, parsed_example["label_shape"])
    return image, label

In [None]:
def load_tfrecord_dataset(tfrecord_pattern, batch_size=1):
    """Build a batched, prefetched dataset from the shards matching a pattern.

    Args:
        tfrecord_pattern: Glob pattern selecting the TFRecord shard files.
        batch_size: Number of parsed samples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding the batched output of
        ``parse_tfrecord_fn``.

    Raises:
        FileNotFoundError: If no file matches ``tfrecord_pattern``.
    """
    # Sort the matches so the record order does not depend on the order in
    # which glob happens to return them.
    shard_paths = sorted(tf.io.gfile.glob(tfrecord_pattern))
    if not shard_paths:
        # Fail here rather than hand back a dataset that yields nothing.
        raise FileNotFoundError(
            f"No TFRecord files match pattern: {tfrecord_pattern!r}"
        )

    dataset = tf.data.TFRecordDataset(shard_paths)
    dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

# "{}" is filled with the split name; "*" matches every shard index.
tfrecord_pattern = "tfrecords/{}_shard_*.tfrec"
train_ds, val_ds = (
    load_tfrecord_dataset(tfrecord_pattern.format(split_name), batch_size=1)
    for split_name in ("training", "validation")
)

In [None]:
# Pull the first (image, label) batch from the training dataset and show
# both shapes as the cell output.
x, y = next(iter(train_ds))
x.shape, y.shape

In [None]:
# Same check for the validation dataset. This rebinds x and y, so the cells
# below operate on this validation batch.
x, y = next(iter(val_ds))
x.shape, y.shape

In [None]:
# Convert the batch to NumPy and drop every length-1 axis (e.g. the batch axis).
x_temp = np.squeeze(x.numpy())
y_temp = np.squeeze(y.numpy())
(x_temp.shape, y_temp.shape)

In [None]:
# Apply the same axis reordering and flips as in the sample preview, this
# time to the batch read back from the TFRecords.
test_image = x_temp.transpose(2, 1, 0)[:, ::-1, ::-1]
test_mask = y_temp.transpose(2, 1, 0)[:, ::-1, ::-1]
# Distinct values present in the label volume.
print(np.unique(test_mask))

# Side by side: the middle slice of the image (grayscale) and of the label.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
ax1.set_title(f'Image shape: {test_image.shape}')
ax1.imshow(test_image[test_image.shape[0]//2], cmap='gray')
ax2.set_title(f'Label shape: {test_mask.shape}')
ax2.imshow(test_mask[test_mask.shape[0]//2])
plt.show()

In [None]:
# Tile every slice of the round-tripped label volume into one large figure.
fig, ax1 = plt.subplots(figsize=(20, 20))
mask_grid = montage(test_mask, padding_width=10, fill=1)
ax1.imshow(mask_grid)
ax1.axis('off')
plt.show()