# Train a baseline U-Net on the fastMRI dataset

In [None]:
import functools
import glob
import os

import tensorflow as tf
import tensorflow_io as tfio
import tensorflow_mri as tfmri

In [None]:
# Locations of the fastMRI brain multi-coil splits. Edit the root (or the
# individual split names) if your copy of the data lives elsewhere.
_fastmri_root = "fastmri"
data_path_train = _fastmri_root + "/brain_multicoil_train"
data_path_val = _fastmri_root + "/brain_multicoil_val"
data_path_test = _fastmri_root + "/brain_multicoil_test"

In [None]:
def _list_hdf5_files(directory):
  """Returns the sorted paths of all `*.h5` files directly inside `directory`.

  `glob.glob(..., root_dir=directory)` yields names relative to `directory`,
  so each match is joined back onto `directory` to obtain a path that can be
  opened regardless of the current working directory. The names are sorted
  because `glob` returns matches in filesystem-dependent order.
  """
  return [os.path.join(directory, name)
          for name in sorted(glob.glob("*.h5", root_dir=directory))]

files_train = _list_hdf5_files(data_path_train)
files_val = _list_hdf5_files(data_path_val)
files_test = _list_hdf5_files(data_path_test)

In [None]:
def read_hdf5(filename, spec=None):
  """Reads an HDF5 file into a `dict` of `tf.Tensor`s.

  Args:
    filename: A string, the filename of an HDF5 file.
    spec: An optional dict of `dataset:tf.TensorSpec` or `dataset:dtype`
      pairs that specify the HDF5 datasets selected and the `tf.TensorSpec`
      or dtype of each dataset. In eager mode the spec is probed
      automatically. In graph mode `spec` has to be specified.

  Returns:
    A `dict` mapping each key reported by the HDF5 IO tensor to its contents
    as a dense `tf.Tensor`. Entries whose spec is a `tf.TensorSpec` have
    their shape enforced with `tf.ensure_shape`; entries with a dtype-only
    spec (or no spec at all) are returned as read.
  """
  io_tensor = tfio.IOTensor.from_hdf5(filename, spec=spec)
  tensors = {k: io_tensor(k).to_tensor() for k in io_tensor.keys}
  if spec is None:
    # No spec means there are no shapes to enforce (and nothing to index).
    return tensors

  def _maybe_ensure_shape(key, value):
    # Only `tf.TensorSpec` entries carry a shape; plain dtypes do not.
    entry = spec.get(key)
    if isinstance(entry, tf.TensorSpec):
      return tf.ensure_shape(value, entry.shape)
    return value

  return {k: _maybe_ensure_shape(k, v) for k, v in tensors.items()}

def create_fastmri_dataset(files,
                           element_spec=None,
                           batch_size=1,
                           shuffle=False,
                           shuffle_buffer_size=100):
  """Creates a `tf.data.Dataset` from a list of fastMRI HDF5 files.

  Each file is read with `read_hdf5` and split along its first (slice)
  dimension into single-slice elements. The elements are then optionally
  shuffled, batched and prefetched.

  Args:
    files: A list of strings, the filenames of the HDF5 files. Must not be
      empty.
    element_spec: The spec of an element of the dataset. See `read_hdf5` for
      more details.
    batch_size: An int, the batch size.
    shuffle: A boolean, whether to shuffle the dataset. When `True`, both the
      order of the files and the order of the individual slices are
      shuffled.
    shuffle_buffer_size: An int, the size of the slice-level shuffle buffer.
      Only used if `shuffle` is `True`.

  Returns:
    A batched, prefetched `tf.data.Dataset`.

  Raises:
    ValueError: If `files` is empty.
  """
  if len(files) == 0:
    # `from_tensor_slices([])` would infer a non-string dtype and fail later
    # with an obscure error inside `read_hdf5`; fail early and clearly.
    raise ValueError(
        "`files` is empty: no HDF5 files were found. Check the data paths.")

  # Make a `tf.data.Dataset` from the list of files.
  ds = tf.data.Dataset.from_tensor_slices(files)
  if shuffle:
    # Slices from the same file are contiguous after `flat_map`, so a bounded
    # slice-level buffer alone can only mix neighbouring files. Shuffling the
    # (cheap) filenames first randomizes which files end up adjacent.
    ds = ds.shuffle(buffer_size=len(files))
  # Read the k-space data from the files, in parallel.
  ds = ds.map(functools.partial(read_hdf5, spec=element_spec),
              num_parallel_calls=tf.data.AUTOTUNE)
  # The first dimension of the inputs is the slice dimension. Split each
  # multi-slice element into multiple single-slice elements, as the
  # reconstruction is performed on a slice-by-slice basis.
  ds = ds.flat_map(tf.data.Dataset.from_tensor_slices)
  # TODO: create mask.

  # TODO: create labels.
  if shuffle:
    ds = ds.shuffle(buffer_size=shuffle_buffer_size)
  # Batch the elements.
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
  return ds

In [None]:
# NOTE(review): `read_hdf5`'s docstring says `spec` has to be specified in
# graph mode, and `Dataset.map` traces `read_hdf5` as a graph function, so
# leaving `element_spec` as `None` is expected to fail. TODO: supply the
# HDF5 dataset spec.
element_spec = None
batch_size = 1

# Settings shared by all three splits; only the file list and shuffling vary.
_common_kwargs = dict(element_spec=element_spec, batch_size=batch_size)

ds_train = create_fastmri_dataset(files_train, shuffle=True, **_common_kwargs)
ds_val = create_fastmri_dataset(files_val, shuffle=False, **_common_kwargs)
ds_test = create_fastmri_dataset(files_test, shuffle=False, **_common_kwargs)

In [None]:
# Baseline 2D U-Net with three resolution levels of 32/64/128 filters and
# 3x3 kernels, trained with an MSE loss and monitored with PSNR and SSIM.
model = tfmri.models.UNet2D(
    filters=[32, 64, 128],
    kernel_size=3,
)

model.compile(
    optimizer="rmsprop",
    loss="mse",
    metrics=[
        tfmri.metrics.PSNR(),
        tfmri.metrics.SSIM(),
    ],
)

In [None]:
# Train for a single epoch, passing `ds_val` as validation data. Left as a
# bare expression so the notebook still displays the returned value.
model.fit(
    ds_train,
    validation_data=ds_val,
    epochs=1,
)