# Prepare Spleen Data as TFRecords

In [1]:
!pip install --quiet nibabel apache-beam[gcp,interactive]

import json
import os
from google.cloud import storage

# Cloud Storage client shared by the cells below. It is constructed without
# explicit arguments (presumably project and credentials are taken from the
# environment — confirm).
gcs_client = storage.Client()

Parameters that define where the dataset is located in Google Cloud Storage (bucket, folder, and descriptor file):

In [2]:
# Name of the Cloud Storage bucket that holds the dataset.
BUCKET = 'bayer-caip-poc-datasets'
# Folder (object-name prefix) inside the bucket that contains the dataset files.
DATASET_FOLDER = 'medical-segmentation-decathlon-spleen'
# File name of the JSON dataset descriptor, relative to DATASET_FOLDER.
DATASET_DESCR = 'dataset.json'

## Test Reading the Dataset Descriptor

In [3]:
def convert_to_dataset_file(filepath, dataset_folder=None):
    """Return the bucket-relative object name of a file inside the dataset folder.

    Args:
        filepath: Path of the file relative to the dataset folder, for
            example ``'./sub/file.ext'``.
        dataset_folder: Folder to prepend. When omitted, the notebook-level
            ``DATASET_FOLDER`` is used, which keeps existing one-argument
            calls working unchanged.

    Returns:
        The normalized, ``/``-separated path ``<dataset_folder>/<filepath>``
        (redundant ``./`` segments and duplicate separators are removed).
    """
    # posixpath instead of os.path: Cloud Storage object names always use '/',
    # whereas os.path would produce backslashes when run on Windows.
    import posixpath

    if dataset_folder is None:
        dataset_folder = DATASET_FOLDER
    return posixpath.normpath(posixpath.join(dataset_folder, filepath))

In [13]:
# Look up the bucket that stores the dataset.
gcs_bucket = gcs_client.get_bucket(BUCKET)

# download_as_bytes() is the supported replacement for the deprecated
# download_as_string(); json.loads() accepts the returned bytes directly.
data = json.loads(
    gcs_bucket.blob(convert_to_dataset_file(DATASET_DESCR)).download_as_bytes())

# gs:// URIs of the image and the label file of every entry under 'training'.
images = [os.path.join('gs://', BUCKET, convert_to_dataset_file(s['image'])) for s in data['training']]
labels = [os.path.join('gs://', BUCKET, convert_to_dataset_file(s['label'])) for s in data['training']]

## Apache Beam to Convert Spleen Dataset at Scale

In [15]:
def parse_json_dataset(file, key, bucket, folder):
    """List the gs:// URIs referenced by the 'training' section of a descriptor.

    Args:
        file: Object whose ``open()`` returns a context manager yielding a
            binary stream with the UTF-8 encoded JSON descriptor.
        key: Field to read from each training entry (the pipeline below
            passes 'image' or 'label').
        bucket: Bucket name placed directly after the ``gs://`` scheme.
        folder: Folder inside the bucket that the entry paths are relative to.

    Returns:
        A list with one URI per training entry, in descriptor order.
    """
    # Kept function-local, presumably so the function is self-contained when
    # it is shipped to remote Beam workers — confirm before hoisting.
    import json
    import os

    with file.open() as handle:
        descriptor = json.loads(handle.read().decode('utf-8'))

    uris = []
    for entry in descriptor['training']:
        # Normalize first so entries such as './dir/name' lose the leading './'.
        relative_path = os.path.normpath(entry[key])
        uris.append(os.path.join('gs://', bucket, folder, relative_path))
    return uris

def convert_nib_to_tensor(readable):
    """Load a ``.nii.gz`` file through nibabel and return it as a float32 tensor.

    The file content is first copied byte-for-byte into a local temporary
    file, which is then loaded with ``nib.load``.

    Args:
        readable: Object exposing ``metadata.path`` and
            ``open(compression_type=...)`` (in the pipeline below this is an
            element produced by ``fileio.ReadMatches``).

    Returns:
        A tuple ``(idx, image_tensor)`` where ``idx`` is the part of the path
        after its last underscore and ``image_tensor`` is the file's data
        cast to ``tf.float32``.
    """
    import os
    import shutil
    import tempfile

    from apache_beam.io import filesystem as beam_fs

    import nibabel as nib
    import tensorflow as tf

    # Text after the last '_' in the path; used as the join key for
    # matching an image with its label downstream.
    idx = readable.metadata.path.split('_')[-1]

    # mkstemp returns an already-open OS-level descriptor. Wrap it with
    # os.fdopen so it is closed instead of leaked, and remove the temporary
    # file even when copying or loading fails.
    fd, dlfilename = tempfile.mkstemp(suffix='.nii.gz')
    try:
        with os.fdopen(fd, 'wb') as dlf:
            # UNCOMPRESSED keeps Beam from gunzipping the stream, so the
            # temporary file is a valid .nii.gz file for nibabel.
            with readable.open(compression_type=beam_fs.CompressionTypes.UNCOMPRESSED) as nzf:
                shutil.copyfileobj(nzf, dlf)

        image_tensor = tf.cast(tf.convert_to_tensor(nib.load(dlfilename).get_fdata()), tf.float32)
    finally:
        os.remove(dlfilename)

    return (idx, image_tensor)

def construct_TFRecord(tensor_group):
    """Serialize one grouped image/label pair into a ``tf.train.Example``.

    Args:
        tensor_group: A ``(key, values)`` pair where ``values`` maps
            'image' and 'label' to the data to store. The key is not used.

    Returns:
        The serialized ``tf.train.Example`` (bytes) with two bytes features,
        'image' and 'label', each holding the output of
        ``tf.io.serialize_tensor``.
    """
    import tensorflow as tf

    def _bytes_feature(value):
        # Wrap the serialized tensor in a single-element bytes feature.
        payload = tf.io.serialize_tensor(value).numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

    _, grouped = tensor_group
    # NOTE(review): in the pipeline below the values come from CoGroupByKey,
    # so each entry is presumably a collection of tensors rather than a bare
    # tensor, and is serialized as such — confirm against the TFRecord readers.
    features = tf.train.Features(feature={
        'image': _bytes_feature(grouped['image']),
        'label': _bytes_feature(grouped['label']),
    })
    example = tf.train.Example(features=features)
    return example.SerializeToString()

In [16]:
import apache_beam as beam
from apache_beam.io import fileio as beam_fileio
from apache_beam.io import tfrecordio as beam_tfrecordio
from datetime import datetime

PROJECT = !gcloud config get-value project 2> /dev/null
JSON_INPUT = os.path.join('gs://', BUCKET, convert_to_dataset_file(DATASET_DESCR))
TFRECORD_OUTPUT = os.path.join('gs://', BUCKET, DATASET_FOLDER, 'tfrecords', 'spleen.tfrecord')

p_options = beam.options.pipeline_options.PipelineOptions(
    runner='DataflowRunner',
    project=PROJECT[0],
    job_name=f'nii-to-tfrecords-{int(datetime.now().timestamp())}',
    temp_location='gs://bayer-caip-poc-datasets/dataflow/temp',
    staging_location='gs://bayer-caip-poc-datasets/dataflow/staging',
    region='europe-west1',
    machine_type='n1-standard-16',
    disk_size_gb=200,
    service_account_email='dataflow-runner@bayer-caip-poc.iam.gserviceaccount.com',
    requirements_file='./dataflow-requirements.txt');

# The pipeline is built without a `with` block on purpose: leaving a
# `with beam.Pipeline(...)` block already runs the pipeline and waits for it,
# so combining it with an explicit p.run() submitted the job twice.
p = beam.Pipeline(options=p_options)

# Locate and open the JSON dataset descriptor.
json_descr = (p | "Find JSON dataset descriptor" >> beam_fileio.MatchFiles(JSON_INPUT)
                | "Read JSON dataset descriptor" >> beam_fileio.ReadMatches())

# Two parallel branches: one (key, tensor) pair per image and per label file.
# NOTE: `images` and `labels` rebind the lists of the same name created in the
# descriptor-reading cell above.
images = (json_descr | "Extract images from JSON dataset descriptor" >> beam.FlatMap(parse_json_dataset, key='image', bucket=BUCKET, folder=DATASET_FOLDER)
                     | "Load dataset images" >> beam_fileio.ReadMatches()
                     | "Convert nibabel images to tensors" >> beam.Map(convert_nib_to_tensor))
labels = (json_descr | "Extract labels from JSON dataset descriptor" >> beam.FlatMap(parse_json_dataset, key='label', bucket=BUCKET, folder=DATASET_FOLDER)
                     | "Load dataset labels" >> beam_fileio.ReadMatches()
                     | "Convert nibabel labels to tensors" >> beam.Map(convert_nib_to_tensor))

# Join images and labels on their shared key and write one record per key.
({'image': images, 'label': labels}
 | "Merge" >> beam.CoGroupByKey()
 | "Construct TFRecords" >> beam.Map(construct_TFRecord)
 | "Store TFRecords" >> beam_tfrecordio.WriteToTFRecord(TFRECORD_OUTPUT))

# Submit the job exactly once and block until it has finished.
result = p.run()
result.wait_until_finish()

