In [None]:
# List the GPUs visible to this runtime.
!nvidia-smi -L

GPU 0: Tesla K80 (UUID: GPU-bb40b437-d0ff-011a-d28b-3bcedc215c47)


In [None]:
# Install CUDA 11.0 with cuDNN pinned to 8.0.5.39 (runtime and dev packages).
# Background for this workaround: https://github.com/tensorflow/tensorflow/issues/46589
!sudo apt-get install --no-install-recommends --allow-change-held-packages cuda-11-0 libcudnn8=8.0.5.39-1+cuda11.0 libcudnn8-dev=8.0.5.39-1+cuda11.0

Reading package lists... Done
Building dependency tree       
Reading state information... Done
cuda-11-0 is already the newest version (11.0.3-1).
The following held packages will be changed:
  libcudnn8
The following packages will be upgraded:
  libcudnn8 libcudnn8-dev
2 upgraded, 0 newly installed, 0 to remove and 37 not upgraded.
Need to get 627 MB of archives.
After this operation, 258 MB of additional disk space will be used.
Get:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  libcudnn8-dev 8.0.5.39-1+cuda11.0 [272 MB]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  libcudnn8 8.0.5.39-1+cuda11.0 [356 MB]
Fetched 627 MB in 24s (26.4 MB/s)
debconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 76, <> line 2.)
debconf: falling back to frontend: Readline
debconf: unable to initial

In [None]:
# Print the version of the installed CUDA compiler.
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Wed_Jul_22_19:09:09_PDT_2020
Cuda compilation tools, release 11.0, V11.0.221
Build cuda_11.0_bu.TC445_37.28845127_0


In [None]:
# Third-party packages used by the code cells below.
!pip install SimpleITK
!pip install voxelmorph
# Optional: uncomment to install a specific TensorFlow release instead of
# the one already present in the runtime.
#!pip install tensorflow==2.3.0

Collecting SimpleITK
[?25l  Downloading https://files.pythonhosted.org/packages/9c/6b/85df5eb3a8059b23a53a9f224476e75473f9bcc0a8583ed1a9c34619f372/SimpleITK-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (47.4MB)
[K     |████████████████████████████████| 47.4MB 54kB/s 
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.0.2
Collecting voxelmorph
[?25l  Downloading https://files.pythonhosted.org/packages/3c/77/fdcf9ff2c8450d447ba760122b50575cfc037921b5870dac61c04a4609cc/voxelmorph-0.1-py3-none-any.whl (75kB)
[K     |████████████████████████████████| 81kB 5.6MB/s 
[?25hCollecting neurite
[?25l  Downloading https://files.pythonhosted.org/packages/a0/4b/705ff365b11bef90b73f5f680c66e34eb3053a7e9ab2bb0705be7b854f08/neurite-0.1-py3-none-any.whl (86kB)
[K     |████████████████████████████████| 92kB 7.4MB/s 
Collecting pystrum
  Downloading https://files.pythonhosted.org/packages/e4/3a/99e310f01f9e3ef4ab78d9e194c3b22bc53574c70c61c9c9bfc136161439/pystrum-0.1-py3-n

In [None]:
# Report which TensorFlow version this runtime provides.
import tensorflow
print(tensorflow.__version__)

2.5.0


In [None]:
# Mount Google Drive under /content/gdrive so the data archive can be read.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Path, on the mounted Google Drive, of the zip archive that holds the
# preprocessed data (data_cache). Change it to wherever your copy is stored.
data_zip_path = '/content/gdrive/MyDrive/mnms2_challenge/data_cache.zip'    # Example location

Copy the data from Google Drive to the Colab session (this step is slightly slow).

In [None]:
# Copy the archive into the session's working directory, extract it quietly,
# then delete the local copy of the archive.
!cp "{data_zip_path}" .
!unzip -q data_cache.zip
!rm data_cache.zip

In [None]:
# Session-local directory holding the extracted data cache. DataGenerator
# reads this global for both its data directory and its cache directory.
path_to_data_cache = '/content/data_cache/'

src/data/preprocess.py

In [None]:
from typing import List, Union, Tuple

from multiprocessing import Pool

import numpy as np
from scipy import ndimage

import SimpleITK as sitk


class Preprocess():
    """Stateless SimpleITK helpers for resampling and intensity scaling."""

    @staticmethod
    def resample_image(image: sitk.Image, out_spacing: Tuple[float, ...]=(1.0, 1.0, 1.0),
                       out_size: Union[None, Tuple[int, ...]]=None, is_label: bool=False,
                       pad_value: float=0) -> sitk.Image:
        """Resample ``image`` onto a grid that shares its centre and direction.

        Args:
            image: Image to resample.
            out_spacing: Spacing of the output grid, one value per image axis.
                If the input has a single sample along its last axis, the
                last entry is replaced by the input's own spacing.
            out_size: Size of the output grid in voxels. When ``None`` it is
                derived from the input's physical extent and ``out_spacing``.
            is_label: Use nearest-neighbour interpolation (suitable for label
                maps) instead of B-spline interpolation.
            pad_value: Value given to output voxels outside the input image.

        Returns:
            The resampled image.
        """
        original_spacing = np.array(image.GetSpacing())
        original_size = np.array(image.GetSize())

        # Work on a private list of plain floats: the caller's sequence is
        # never modified and no numpy scalar types reach SimpleITK.
        out_spacing = [float(value) for value in out_spacing]
        if original_size[-1] == 1:
            # Single-slice input: keep its native spacing along the last axis.
            out_spacing[-1] = float(original_spacing[-1])

        if out_size is None:
            # Preserve the physical extent of the input image.
            out_size = np.round(original_size * original_spacing / np.array(out_spacing)).astype(int)
        else:
            out_size = np.array(out_size)

        # Place the output origin so that the physical centres of the input
        # and output grids coincide (centres are rotated by the direction
        # matrix before being compared).
        original_direction = np.array(image.GetDirection()).reshape(len(original_spacing), -1)
        original_center = (np.array(original_size, dtype=float) - 1.0) / 2.0 * original_spacing
        out_center = (np.array(out_size, dtype=float) - 1.0) / 2.0 * np.array(out_spacing)

        original_center = np.matmul(original_direction, original_center)
        out_center = np.matmul(original_direction, out_center)
        out_origin = np.array(image.GetOrigin()) + (original_center - out_center)

        resample = sitk.ResampleImageFilter()
        resample.SetOutputSpacing(out_spacing)
        resample.SetSize(out_size.tolist())
        resample.SetOutputDirection(image.GetDirection())
        resample.SetOutputOrigin(out_origin.tolist())
        resample.SetTransform(sitk.Transform())
        resample.SetDefaultPixelValue(pad_value)
        resample.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline)

        return resample.Execute(image)


    @staticmethod
    def normalise_intensities(image: sitk.Image, intensity_scale: float=500.0) -> sitk.Image:
        """Cast ``image`` to float32 and divide it by ``intensity_scale``.

        With the default of 500 an image spanning a hypothetical 0-500 range
        is mapped to roughly 0-1. Values are not clipped.

        Args:
            image: Image to rescale.
            intensity_scale: Divisor applied to every voxel; must be non-zero.

        Returns:
            A float32 image holding the scaled intensities.

        Raises:
            ValueError: If ``intensity_scale`` is zero.
        """
        if intensity_scale == 0:
            raise ValueError('intensity_scale must be non-zero')

        return sitk.Cast(image, sitk.sitkFloat32) / float(intensity_scale)
    
    

class Registration():
    """Multi-resolution affine registration of a moving image onto a fixed image.

    ``register`` is the entry point. It runs two coarse-to-fine schedules
    (``_major_alignment`` and ``_minor_alignment``) and returns the result
    whose final metric value is lower. The remaining public static methods
    turn image geometry and registration transforms into 4x4 affine matrices.
    """

    # Metric settings shared by every level of both alignment schedules.
    _HISTOGRAM_BINS = 200
    _SAMPLING_RATE = 1.0
    _SEED = 12453

    def __init__(self):
        pass


    @staticmethod
    def _function_register(initial_transform, moving_image, fixed_image,
                           learning_rate, histogram_bins, sampling_rate,
                           seed) -> Tuple[sitk.Transform, float]:
        """Run one gradient-descent registration and report its final metric.

        Args:
            initial_transform: Transform to optimise; it is updated in place.
            moving_image: Image that is mapped onto ``fixed_image``.
            fixed_image: Reference image.
            learning_rate: Fixed optimiser learning rate, or ``None`` to let
                SimpleITK estimate it at each iteration.
            histogram_bins: Number of bins of the Mattes mutual information
                metric.
            sampling_rate: Fraction of voxels sampled by the metric.
            seed: Seed for the random metric sampling.

        Returns:
            ``(transform, metric_value)`` as produced by the registration.
        """
        sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(1)
        registration_method = sitk.ImageRegistrationMethod()

        # Similarity metric settings.
        registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=histogram_bins)
        registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
        registration_method.SetMetricSamplingPercentage(sampling_rate, seed=seed)

        registration_method.SetInterpolator(sitk.sitkLinear)

        # Optimizer settings.
        if learning_rate is None:
            # No fixed rate requested: have it estimated at each iteration.
            estimate_learning_rate = registration_method.EachIteration
            learning_rate = 0
        else:
            estimate_learning_rate = registration_method.Never

        registration_method.SetOptimizerAsGradientDescent(learningRate=learning_rate,
                                                          numberOfIterations=100,
                                                          convergenceMinimumValue=1e-12,
                                                          convergenceWindowSize=10,
                                                          estimateLearningRate=estimate_learning_rate)

        registration_method.SetOptimizerScalesFromPhysicalShift()

        registration_method.SetInitialTransform(initial_transform, inPlace=True)

        transform = registration_method.Execute(sitk.Cast(fixed_image, sitk.sitkFloat32),
                                                sitk.Cast(moving_image, sitk.sitkFloat32))

        return transform, registration_method.GetMetricValue()


    @staticmethod
    def _parallel_register(initial_transform, moving_image, fixed_image,
                           learning_rate_list, histogram_bins, sampling_rate,
                           seed) -> Tuple[Union[None, sitk.Transform], float]:
        """Register once per learning rate in a process pool and keep the best.

        Every run starts from its own copy of ``initial_transform``. The run
        with the lowest metric value wins.

        Returns:
            ``(transform, metric_value)`` of the best run. ``transform`` is
            ``None`` (and the metric ``inf``) when no run produced a metric
            value below infinity, e.g. for an empty ``learning_rate_list`` or
            when every metric is NaN.
        """
        function_input = [(sitk.AffineTransform(initial_transform),
                           moving_image,
                           fixed_image,
                           learning_rate,
                           histogram_bins,
                           sampling_rate,
                           seed) for learning_rate in learning_rate_list]

        with Pool() as pool:
            output_results = pool.starmap(Registration._function_register, function_input)

        selected_transform = None
        # Start from infinity so that any finite metric value can be selected.
        min_metric_value = float('inf')
        for transform, metric_value in output_results:
            if metric_value < min_metric_value:
                min_metric_value = metric_value
                selected_transform = transform

        return selected_transform, min_metric_value


    @staticmethod
    def _gaussian_smooth(image: sitk.Image, sigma: Tuple[float, float, float]) -> sitk.Image:
        """Blur ``image`` with a Gaussian and keep its spatial metadata.

        ``sigma`` is given in the axis order of the array returned by
        ``sitk.GetArrayFromImage`` (z, y, x).
        """
        # NOTE(review): gaussian_filter is called without an explicit output
        # dtype, so the result presumably keeps the input array's dtype --
        # confirm that this is acceptable for integer-valued images.
        smoothed_array = ndimage.gaussian_filter(sitk.GetArrayFromImage(image),
                                                 sigma=sigma,
                                                 mode='constant')
        smoothed_image = sitk.GetImageFromArray(smoothed_array)
        smoothed_image.CopyInformation(image)

        return smoothed_image


    @staticmethod
    def _multiscale_alignment(moving_image: sitk.Image, fixed_image: sitk.Image,
                              gaussian_sigma: List[float],
                              learning_rate_list: List[List[Union[None, float]]],
                              debug_output: int=0):
        """Shared coarse-to-fine registration loop of the alignment schedules.

        For every level both images are smoothed with that level's sigma and
        registered with ``_parallel_register``, starting from the transform
        found at the previous level (identity for the first level).

        Args:
            moving_image: Image that is mapped onto ``fixed_image``.
            fixed_image: Reference image.
            gaussian_sigma: In-plane smoothing sigma per level.
            learning_rate_list: Candidate learning rates per level.
            debug_output: When greater than 0, resampled debug images are
                collected; they are only returned when it equals 1.

        Returns:
            ``(transform, metric)`` or, when ``debug_output == 1``,
            ``(transform, metric, [moving_debug_images, fixed_debug_images])``.

        Raises:
            ValueError: If the schedule is empty or its two lists differ in
                length.
        """
        if not gaussian_sigma or len(gaussian_sigma) != len(learning_rate_list):
            raise ValueError('gaussian_sigma and learning_rate_list must be non-empty '
                             'and of equal length')

        sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(1)

        transform = sitk.AffineTransform(3)

        debug_image_moving = []
        debug_image_fixed = []
        if debug_output > 0:
            # Snapshot of the unregistered starting point.
            debug_image_moving.append(sitk.Resample(moving_image,
                                                    fixed_image,
                                                    transform,
                                                    sitk.sitkLinear,
                                                    0.0,
                                                    moving_image.GetPixelID()))
            debug_image_fixed.append(fixed_image)

        for sigma, learning_rates in zip(gaussian_sigma, learning_rate_list):
            # The fixed image always uses sigma 1 along the first array axis,
            # the moving image half the in-plane sigma.
            tmp_fixed_image = Registration._gaussian_smooth(fixed_image, (1, sigma, sigma))
            tmp_moving_image = Registration._gaussian_smooth(moving_image, (sigma / 2, sigma, sigma))

            transform, metric = Registration._parallel_register(sitk.AffineTransform(transform),
                                                                tmp_moving_image, tmp_fixed_image,
                                                                learning_rates,
                                                                Registration._HISTOGRAM_BINS,
                                                                Registration._SAMPLING_RATE,
                                                                Registration._SEED)

            if debug_output > 0:
                debug_image_moving.append(sitk.Resample(tmp_moving_image,
                                                        tmp_fixed_image,
                                                        transform,
                                                        sitk.sitkLinear,
                                                        0.0,
                                                        tmp_moving_image.GetPixelID()))
                debug_image_fixed.append(tmp_fixed_image)

        if debug_output == 1:
            return transform, metric, [debug_image_moving, debug_image_fixed]

        return transform, metric


    @staticmethod
    def _major_alignment(moving_image: sitk.Image, fixed_image: sitk.Image,
                         debug_output: int=0) -> Union[Tuple[sitk.Transform, float],
                                                       Tuple[sitk.Transform, float, List[List[sitk.Image]]]]:
        """Five-level schedule that starts from heavy smoothing (sigma 8 to 0).

        Returns ``(transform, metric)``, with the debug image lists appended
        as a third element when ``debug_output == 1``.
        """
        gaussian_sigma = [8, 4, 2, 1, 0]
        learning_rate_list = [[8.0, 4.0, 2.0, 1.0, None],
                              [4.0, 2.0, 1.0, 0.5, None],
                              [2.0, 1.0, 0.5, 0.25, None],
                              [1.0, 0.5, 0.25, 0.1, None],
                              [0.5, 0.25, 0.1, 0.05, None]]

        return Registration._multiscale_alignment(moving_image, fixed_image,
                                                  gaussian_sigma, learning_rate_list,
                                                  debug_output)


    @staticmethod
    def _minor_alignment(moving_image: sitk.Image, fixed_image: sitk.Image,
                         debug_output: int=0) -> Union[Tuple[sitk.Transform, float],
                                                       Tuple[sitk.Transform, float, List[List[sitk.Image]]]]:
        """Three-level schedule that starts from light smoothing (sigma 2 to 0).

        Returns ``(transform, metric)``, with the debug image lists appended
        as a third element when ``debug_output == 1``.
        """
        gaussian_sigma = [2, 1, 0]
        learning_rate_list = [[2.0, 2.0, 1.0, 0.5, None],
                              [2.0, 1.0, 0.5, 0.25, None],
                              [1.0, 0.5, 0.25, 0.1, None]]

        return Registration._multiscale_alignment(moving_image, fixed_image,
                                                  gaussian_sigma, learning_rate_list,
                                                  debug_output)


    @staticmethod
    def register(moving_image: sitk.Image, fixed_image: sitk.Image,
                 debug_output: int=0) -> Union[Tuple[sitk.Transform, float],
                                               Tuple[sitk.Transform, float, List[List[sitk.Image]]]]:
        """Register ``moving_image`` onto ``fixed_image``.

        Both alignment schedules are run and the output of the one with the
        lower final metric value is returned (the minor schedule wins ties).

        Returns:
            ``(transform, metric)``, extended by the debug image lists when
            ``debug_output == 1``.
        """
        major_output = Registration._major_alignment(moving_image, fixed_image, debug_output)
        minor_output = Registration._minor_alignment(moving_image, fixed_image, debug_output)

        if major_output[1] < minor_output[1]:
            return major_output

        return minor_output


    @staticmethod
    def _affine_from_index_mapping(index_to_point) -> np.ndarray:
        """Build a 4x4 RAS affine from a continuous-index to LPS-point mapping.

        ``index_to_point`` is evaluated at the three unit indices and at the
        zero index; the differences form the columns of the linear part and
        the image of the zero index forms the translation.
        """
        corners = np.array([index_to_point(index)
                            for index in ((1, 0, 0),
                                          (0, 1, 0),
                                          (0, 0, 1),
                                          (0, 0, 0))])
        affine = np.eye(4)
        affine[:3, :3] = np.transpose(corners[0:3] - corners[3])
        affine[:3, 3] = corners[3]
        # Convert from LPS to RAS to match nibabel etc.
        return np.matmul(np.diag([-1., -1., 1., 1.]), affine)


    @staticmethod
    def get_affine_matrix(image: sitk.Image) -> np.ndarray:
        """Return the 4x4 index-to-physical affine of ``image`` in RAS."""
        return Registration._affine_from_index_mapping(
            image.TransformContinuousIndexToPhysicalPoint)


    @staticmethod
    def get_affine_registration_matrix(moving_image: sitk.Image,
                                       registration_affine: sitk.Transform) -> np.ndarray:
        """Return the 4x4 affine of ``moving_image`` after ``registration_affine``.

        Each index is first mapped to its physical point in ``moving_image``
        and that point is then passed through ``registration_affine``; the
        result is expressed in RAS.
        """
        def index_to_registered_point(index):
            return registration_affine.TransformPoint(
                moving_image.TransformContinuousIndexToPhysicalPoint(index))

        return Registration._affine_from_index_mapping(index_to_registered_point)
    


src/data/loader.py

In [None]:
import os

from enum import Enum

from typing import Any, Dict, Tuple, List, Union
from pathlib import Path
from glob import glob

import numpy as np

import SimpleITK as sitk


class FileType(Enum):
    """File-name stems of the per-patient image and label volumes.

    Each value is combined with the ``.nii.gz`` extension to locate or write
    the corresponding file. SA/LA denote the short-axis and long-axis views,
    ED is the end-diastolic frame (ES presumably the end-systolic frame), and
    the ``_gt`` suffix marks the ground-truth labels.
    """
    sa_ed = 'SA_ED'
    sa_ed_gt = 'SA_ED_gt'
    sa_es = 'SA_ES'
    sa_es_gt = 'SA_ES_gt'
    la_ed = 'LA_ED'
    la_ed_gt = 'LA_ED_gt'
    la_es = 'LA_ES'
    la_es_gt = 'LA_ES_gt'
    
    
class ExtraType(Enum):
    """File-name stems of non-image patient data stored as ``.tfm`` files."""
    # Affine transform produced by registering the short-axis image to the
    # long-axis image.
    reg_affine = 'SA_to_LA_registration_affine'
    

class OutputAffine(Enum):
    """Dictionary keys under which the affine matrices are stored after the
    patient data has been converted to numpy arrays."""
    sa_affine = 'SA_Affine'
    la_affine = 'LA_Affine'    
    

class DataGenerator():

    
    def __init__(self, floating_precision: str = '32',
                 data_directory: Union[None, str, Path] = None,
                 cache_directory: Union[None, str, Path] = None) -> None:
        """Collect the patient lists and the fixed geometry of the data set.

        Args:
            floating_precision: ``'16'`` to emit float16 arrays; any other
                value results in float32 arrays.
            data_directory: Root directory containing the ``training`` and
                ``validation`` folders. Defaults to the notebook-level
                ``path_to_data_cache``.
            cache_directory: Root directory of the preprocessed cache.
                Defaults to the notebook-level ``path_to_data_cache``.
        """
        # Fall back to the notebook global only when no explicit location is
        # given, so the class can also be used outside this notebook session.
        if data_directory is None:
            data_directory = path_to_data_cache
        if cache_directory is None:
            cache_directory = path_to_data_cache

        self.data_directory = data_directory
        self.cache_directory = cache_directory

        self.train_directory = Path(os.path.join(self.data_directory, 'training'))
        # For the purposes of model development, the 'validation' set is treated
        # as the test set
        # (It does not have ground truth - validated on submission only)
        self.testing_directory = Path(os.path.join(self.data_directory, 'validation'))

        # Shuffle with a fixed seed, then hold out the last 20% of the
        # training patients for validation.
        self.train_list = self.get_patient_list(self.train_directory)
        self.train_list = self.randomise_list(self.train_list, seed=4516, inplace=True)
        self.train_list, self.validation_list = self.split_list(self.train_list, split_fraction=0.8)
        self.test_list = self.get_patient_list(self.testing_directory)

        self.target_spacing = (1.25, 1.25, 10)
        self.target_size = (160, 160, 17)

        self.n_classes = 4  # Including background

        self.floating_precision = floating_precision

        # Compute the shape for the inputs and outputs.
        # Short-axis: the input keeps the target size, the one-hot target
        # gains a trailing class axis.
        self.sa_shape = list(self.target_size)
        self.sa_target_shape = self.sa_shape + [self.n_classes]

        # Long-axis: a single slice as input; the one-hot target replaces the
        # last axis by the class axis.
        self.la_shape = list(self.target_size)
        self.la_shape[-1] = 1
        self.la_target_shape = list(self.target_size)
        self.la_target_shape[-1] = self.n_classes

        self.affine_shape = (4, 4)


    @staticmethod
    def get_patient_list(root_directory: Union[str, Path]) -> List[Path]:
        files = glob(os.path.join(root_directory, "**"))
        files = [Path(i) for i in files]
        
        return files
    
    
    @staticmethod
    def randomise_list(item_list: List[Any], seed: Union[None, int]=None,
                       inplace: bool=True) -> List[Any]:
        if not inplace:
            item_list = item_list.copy()
            
        random_generator = np.random.RandomState(seed)
        random_generator.shuffle(item_list)
        
        return item_list
    
    
    @staticmethod
    def split_list(item_list: List[Any], split_fraction: float) -> Tuple[List[Any]]:
        assert 0 < split_fraction < 1
        
        split_index = int(len(item_list) * split_fraction)
        
        split_1 = item_list[:split_index]
        split_2 = item_list[split_index:]
                
        return split_1, split_2

        
    @staticmethod
    def load_image(patient_directory: Union[str, Path], file_type: FileType) -> sitk.Image:
        """Read the single ``*<file_type>.nii.gz`` image in ``patient_directory``.

        Exactly one file must match the pattern; this is enforced with an
        assertion.
        """
        pattern = os.path.join(patient_directory, '*' + file_type.value + '.nii.gz')
        matches = glob(pattern)
        assert len(matches) == 1

        return sitk.ReadImage(matches[0])
    
    @staticmethod
    def load_transformation(patient_directory: Union[str, Path], file_type: ExtraType) -> sitk.Transform:
        """Read the single ``*<file_type>.tfm`` transform in ``patient_directory``.

        Exactly one file must match the pattern; this is enforced with an
        assertion.
        """
        pattern = os.path.join(patient_directory, '*' + file_type.value + '.tfm')
        matches = glob(pattern)
        assert len(matches) == 1

        return sitk.ReadTransform(matches[0])
    
    
    @staticmethod
    def load_patient_data(patient_directory: Union[str, Path], has_gt: bool = True) -> Dict[str, sitk.Image]:
        """Load the four image volumes of a patient, plus labels if requested.

        The returned dictionary is keyed by the ``FileType`` values. The
        images come first and, when ``has_gt`` is true, the four ground-truth
        volumes follow.
        """
        wanted = [FileType.sa_ed, FileType.sa_es, FileType.la_ed, FileType.la_es]
        if has_gt:
            wanted += [FileType.sa_ed_gt, FileType.sa_es_gt,
                       FileType.la_ed_gt, FileType.la_es_gt]

        return {file_type.value: DataGenerator.load_image(patient_directory, file_type)
                for file_type in wanted}
    
    
    @staticmethod
    def load_extra_patient_data(patient_directory: Union[str, Path],
                                patient_data: Dict[str, sitk.Image]) -> Dict[str, sitk.Image]:
        """Add the registration transform to ``patient_data`` and return it.

        The transform is read from ``patient_directory`` and stored, in
        place, under the ``ExtraType.reg_affine`` value.
        """
        registration_key = ExtraType.reg_affine
        patient_data[registration_key.value] = DataGenerator.load_transformation(
            patient_directory, registration_key)

        return patient_data

    
    @staticmethod
    def preprocess_patient_data(patient_data: Dict[str, sitk.Image], spacing: Tuple[float],
                                size: Tuple[int], has_gt: bool = True, register: bool = True) -> Dict[str, sitk.Image]:
        """Resample, optionally register, and normalise one patient's data.

        ``patient_data`` is modified in place and returned. Images and labels
        are resampled to ``spacing``/``size`` (long-axis data to a single
        slice with its native slice spacing). When ``register`` is true the
        short-axis to long-axis transform is stored under
        ``ExtraType.reg_affine``. Finally the image intensities are
        normalised.
        """
        def resample_entries(file_types, out_spacing, out_size, is_label):
            # Replace each listed entry of patient_data by its resampled version.
            for file_type in file_types:
                key = file_type.value
                patient_data[key] = Preprocess.resample_image(patient_data[key],
                                                              out_spacing, out_size,
                                                              is_label=is_label)

        # Short-axis: images, then labels.
        resample_entries((FileType.sa_ed, FileType.sa_es), spacing, size, False)
        if has_gt:
            resample_entries((FileType.sa_ed_gt, FileType.sa_es_gt), spacing, size, True)

        # Long-axis: one slice, keeping the slice spacing of the ED image as
        # it was before resampling.
        la_spacing = list(spacing)
        la_spacing[2] = patient_data[FileType.la_ed.value].GetSpacing()[2]
        la_size = list(size)
        la_size[2] = 1
        resample_entries((FileType.la_ed, FileType.la_es), la_spacing, la_size, False)
        if has_gt:
            resample_entries((FileType.la_ed_gt, FileType.la_es_gt), la_spacing, la_size, True)

        # Register short-axis to long axis (only for end diastolic for faster execution time)
        if register:
            affine_transform, _ = Registration.register(patient_data[FileType.sa_ed.value],
                                                        patient_data[FileType.la_ed.value])
            patient_data[ExtraType.reg_affine.value] = affine_transform

        # Normalise intensities so they are (roughly) in [0, 1]
        for file_type in (FileType.sa_ed, FileType.sa_es, FileType.la_ed, FileType.la_es):
            key = file_type.value
            patient_data[key] = Preprocess.normalise_intensities(patient_data[key])

        return patient_data
        

    def get_cache_directory(self, patient_directory: Union[str, Path]) -> Path:
        path = os.path.normpath(patient_directory)
        split_path = path.split(os.sep)
        # .. / data / training or vlaidation / patient ID
        # only last two are of interest
        cache_directory = Path(os.path.join(self.cache_directory,
                                            split_path[-2],
                                            split_path[-1]))
        
        return cache_directory

    
    def is_cached(self, patient_directory: Union[str, Path], has_gt: bool = True) -> bool:
        """Tell whether every expected cache file of a patient exists.

        The patient's cache folder must exist and contain one ``.nii.gz``
        file per ``FileType`` (ground-truth files are skipped when ``has_gt``
        is false) and one ``.tfm`` file per ``ExtraType``.
        """
        patient_cache_directory = self.get_cache_directory(patient_directory)

        if not os.path.isdir(patient_cache_directory):
            return False

        expected_names = [file_type.value + '.nii.gz'
                          for file_type in FileType
                          if has_gt or not file_type.value.endswith('_gt')]
        expected_names += [extra_type.value + '.tfm' for extra_type in ExtraType]

        return all(os.path.exists(os.path.join(patient_cache_directory, name))
                   for name in expected_names)

        
    def save_cache(self, patient_directory: Union[str, Path],
                    patient_data: Dict[str, sitk.Image]) -> None:
        """Write a patient's images and transforms into its cache folder.

        Entries keyed by a ``FileType`` value are written as ``.nii.gz``
        images, entries keyed by an ``ExtraType`` value as ``.tfm``
        transforms; any other entry is ignored. The folder is created when
        missing.
        """
        target_directory = self.get_cache_directory(patient_directory)
        os.makedirs(target_directory, exist_ok=True)

        image_keys = {file_type.value for file_type in FileType}
        transform_keys = {extra_type.value for extra_type in ExtraType}

        for key, data in patient_data.items():
            if key in image_keys:
                sitk.WriteImage(data, os.path.join(target_directory, key + '.nii.gz'))
            elif key in transform_keys:
                sitk.WriteTransform(data, os.path.join(target_directory, key + '.tfm'))
        
    
    def load_cache(self, patient_directory: Union[str, Path], has_gt: bool = True) -> Dict[str, sitk.Image]:
        """Load a patient's cached images and registration transform."""
        cache_location = self.get_cache_directory(patient_directory)
        cached_images = self.load_patient_data(cache_location, has_gt)

        return self.load_extra_patient_data(cache_location, cached_images)
    
    
    def to_numpy(self, patient_data: Dict[str, sitk.Image], has_affine_matrix: bool) -> Dict[str, np.ndarray]:
        """Convert a patient's SimpleITK data to numpy arrays, in place.

        Args:
            patient_data: Images and labels keyed by ``FileType`` values,
                optionally with the registration transform under
                ``ExtraType.reg_affine``. The transform is required when
                ``has_affine_matrix`` is true.
            has_affine_matrix: Also compute the short-axis and long-axis
                affine matrices and store them under the ``OutputAffine``
                values.

        Returns:
            ``patient_data`` with every image replaced by an array whose axes
            are ordered x, y, z, labels one-hot encoded along a trailing
            class axis, and the registration transform removed.
        """
        # Handle 'ExtraType' data first
        if has_affine_matrix:
            sa_affine = Registration.get_affine_registration_matrix(patient_data[FileType.sa_ed.value],
                                                                    patient_data[ExtraType.reg_affine.value])
            sa_affine = sa_affine.astype(np.float32)
            la_affine = Registration.get_affine_matrix(patient_data[FileType.la_ed.value])
            la_affine = la_affine.astype(np.float32)

        # Remove the transform so that only images remain for the loop below.
        # pop() with a default tolerates data that never had a transform
        # (e.g. preprocessed with register=False) instead of raising KeyError.
        patient_data.pop(ExtraType.reg_affine.value, None)

        # Handle original file data (images and segmentations)
        for key, image in patient_data.items():
            numpy_image = sitk.GetArrayFromImage(image)
            # Swap axes so ordering is x, y, z rather than z, y, x as stored
            # in sitk
            numpy_image = np.swapaxes(numpy_image, 0, -1)

            # Generate one-hot encoding of the labels
            if 'gt' in key:
                numpy_image = numpy_image.astype(np.uint8)
                if 'LA' in key:
                    # Use the 'depth' axis as the channel axis for the label
                    numpy_image = np.squeeze(numpy_image, axis=-1)
                # Indexing the identity matrix by label yields one-hot rows.
                numpy_image = np.eye(self.n_classes)[numpy_image]

            if self.floating_precision == '16':
                numpy_image = numpy_image.astype(np.float16)
            else:
                numpy_image = numpy_image.astype(np.float32)

            # Only the value of an existing key is replaced, which is safe
            # while iterating over the dictionary.
            patient_data[key] = numpy_image

        if has_affine_matrix:
            patient_data[OutputAffine.sa_affine.value] = sa_affine
            patient_data[OutputAffine.la_affine.value] = la_affine

        return patient_data
    
    @staticmethod
    def to_structure(patient_data: Dict[str, sitk.Image], has_affine_matrix: bool,
                     has_gt: bool = True):
        """Arrange a patient's data into one sample per cardiac phase.

        Args:
            patient_data: Mapping keyed by ``FileType`` values (and by
                ``OutputAffine`` values when ``has_affine_matrix`` is true).
            has_affine_matrix: When true, ``'input_sa_affine'`` and
                ``'input_la_affine'`` are added to the inputs of every sample.
            has_gt: When true each sample is ``(inputs, outputs)``; otherwise
                it is the one-element tuple ``(inputs,)``.

        Returns:
            A two-element list: the end-diastolic sample followed by the
            end-systolic sample.
        """
        phase_keys = (
            (FileType.sa_ed, FileType.la_ed, FileType.sa_ed_gt, FileType.la_ed_gt),
            (FileType.sa_es, FileType.la_es, FileType.sa_es_gt, FileType.la_es_gt),
        )

        samples = []
        for sa_image, la_image, sa_label, la_label in phase_keys:
            inputs = {'input_sa': patient_data[sa_image.value],
                      'input_la': patient_data[la_image.value]}
            if has_gt:
                outputs = {'output_sa': patient_data[sa_label.value],
                           'output_la': patient_data[la_label.value]}
                samples.append((inputs, outputs))
            else:
                samples.append((inputs,))

        if has_affine_matrix:
            # Both phases share the same pair of affine matrices
            affine_inputs = {'input_sa_affine': patient_data[OutputAffine.sa_affine.value],
                             'input_la_affine': patient_data[OutputAffine.la_affine.value]}
            for sample in samples:
                sample[0].update(affine_inputs)

        return samples
        

    def generator(self, patient_directory: Union[str, Path], affine_matrix: bool,
                  has_gt: bool = True) -> Tuple[Dict[str, np.ndarray]]:
        """Produce the NumPy samples of one patient.

        Preprocessed data is read from the cache when available; otherwise
        the raw data is loaded, preprocessed with the configured target
        spacing and size, and written to the cache.

        Args:
            patient_directory: Directory holding the patient's files.
            affine_matrix: Forwarded to preprocessing, ``to_numpy`` and
                ``to_structure`` to request the affine matrices.
            has_gt: Whether ground-truth segmentations are expected.

        Returns:
            Whatever ``self.to_structure`` builds from the converted data.
        """
        if not self.is_cached(patient_directory, has_gt):
            raw_data = DataGenerator.load_patient_data(patient_directory, has_gt)
            processed = DataGenerator.preprocess_patient_data(raw_data,
                                                              self.target_spacing,
                                                              self.target_size,
                                                              has_gt,
                                                              affine_matrix)
            self.save_cache(patient_directory, processed)
        else:
            processed = self.load_cache(patient_directory, has_gt)

        arrays = self.to_numpy(processed, affine_matrix)
        return self.to_structure(arrays, affine_matrix, has_gt)

    
    def sitk_generator(self, patient_directory: Union[str, Path], has_gt: bool = True) -> Tuple[Dict[str, np.ndarray]]:
        """Return a patient's data before and after preprocessing, as sitk images.

        The raw data is always loaded from ``patient_directory``. The
        preprocessed data comes from the cache when available; otherwise it is
        computed (without affine matrices) and then written to the cache.

        Args:
            patient_directory: Directory holding the patient's files.
            has_gt: Whether ground-truth segmentations are expected.

        Returns:
            A pair ``(pre_output_data, post_output_data)``, each one being the
            result of ``self.to_structure`` for the raw and the preprocessed
            data respectively.
        """
        pre_patient_data = DataGenerator.load_patient_data(patient_directory, has_gt)

        if self.is_cached(patient_directory, has_gt):
            post_patient_data = self.load_cache(patient_directory, has_gt)
        else:
            # Load a second, independent copy so that preprocessing cannot
            # touch the raw data returned to the caller
            post_patient_data = DataGenerator.load_patient_data(patient_directory, has_gt)
            post_patient_data = DataGenerator.preprocess_patient_data(post_patient_data,
                                                                      self.target_spacing,
                                                                      self.target_size,
                                                                      has_gt,
                                                                      False)
            # The cache holds *preprocessed* data (it is read back as such
            # above), so the processed images must be saved, not the raw ones
            self.save_cache(patient_directory, post_patient_data)

        pre_output_data = self.to_structure(pre_patient_data, False, has_gt)
        post_output_data = self.to_structure(post_patient_data, False, has_gt)

        return pre_output_data, post_output_data
        
        
    def train_generator(self, verbose: int = 0) -> Tuple[Dict[str, np.ndarray]]:
        """Yield two samples per patient in ``self.train_list``, without affines.

        Args:
            verbose: When greater than zero, print each patient directory
                before it is generated.

        Yields:
            The end-diastolic sample, then the end-systolic sample, of each
            patient in list order.
        """
        for directory in self.train_list:
            if verbose > 0:
                print('Generating patient: ', directory)
            phases = self.generator(directory, affine_matrix=False)

            # Index 0 is end diastolic, index 1 is end systolic
            for phase_index in (0, 1):
                yield phases[phase_index]
        
    
    def validation_generator(self, verbose: int = 0) -> Tuple[Dict[str, np.ndarray]]:
        """Yield two samples per patient in ``self.validation_list``, without affines.

        Args:
            verbose: When greater than zero, print each patient directory
                before it is generated.

        Yields:
            The first, then the second, sample returned by ``self.generator``
            for each patient in list order.
        """
        for directory in self.validation_list:
            if verbose > 0:
                print('Generating patient: ', directory)
            phases = self.generator(directory, affine_matrix=False)

            for phase_index in (0, 1):
                yield phases[phase_index]
            
    
    def test_generator(self, verbose: int = 0) -> Tuple[Dict[str, np.ndarray]]:
        """Yield two samples per patient in ``self.test_list``, without affines.

        Args:
            verbose: When greater than zero, print each patient directory
                before it is generated.

        Yields:
            The first, then the second, sample returned by ``self.generator``
            for each patient in list order.
        """
        for directory in self.test_list:
            if verbose > 0:
                print('Generating patient: ', directory)
            phases = self.generator(directory, affine_matrix=False)

            for phase_index in (0, 1):
                yield phases[phase_index]
            
    
    def test_generator_inference(self, verbose: int = 0) -> Tuple[Dict[str, np.ndarray]]:
        """Yield inference records for every patient in ``self.test_list``.

        No ground truth is requested and no affine matrices are produced.

        Args:
            verbose: When greater than zero, print each patient directory
                before it is generated.

        Yields:
            For ``'ed'`` and then ``'es'``, the 5-tuple ``(numpy sample,
            raw sitk sample, preprocessed sitk sample, patient directory,
            phase name)``.
        """
        for directory in self.test_list:
            if verbose > 0:
                print('Generating patient: ', directory)
            numpy_phases = self.generator(directory, affine_matrix=False, has_gt=False)
            raw_phases, processed_phases = self.sitk_generator(directory, has_gt=False)

            for position, phase in enumerate(('ed', 'es')):
                yield (numpy_phases[position], raw_phases[position],
                       processed_phases[position], directory, phase)
        
        
    def train_affine_generator(self, verbose: int = 0) -> Tuple[Dict[str, np.ndarray]]:
        """Yield two samples per patient in ``self.train_list``, with affines.

        Args:
            verbose: When greater than zero, print each patient directory
                before it is generated.

        Yields:
            The end-diastolic sample, then the end-systolic sample, of each
            patient in list order.
        """
        for directory in self.train_list:
            if verbose > 0:
                print('Generating patient: ', directory)
            phases = self.generator(directory, affine_matrix=True)

            # Index 0 is end diastolic, index 1 is end systolic
            for phase_index in (0, 1):
                yield phases[phase_index]
        
    
    def validation_affine_generator(self, verbose: int = 0) -> Tuple[Dict[str, np.ndarray]]:
        """Yield two samples per patient in ``self.validation_list``, with affines.

        Args:
            verbose: When greater than zero, print each patient directory
                before it is generated.

        Yields:
            The first, then the second, sample returned by ``self.generator``
            for each patient in list order.
        """
        for directory in self.validation_list:
            if verbose > 0:
                print('Generating patient: ', directory)
            phases = self.generator(directory, affine_matrix=True)

            for phase_index in (0, 1):
                yield phases[phase_index]
            
    
    def test_affine_generator(self, verbose: int = 0) -> Tuple[Dict[str, np.ndarray]]:
        """Yield two samples per patient in ``self.test_list``, with affines.

        Args:
            verbose: When greater than zero, print each patient directory
                before it is generated.

        Yields:
            The first, then the second, sample returned by ``self.generator``
            for each patient in list order.
        """
        for directory in self.test_list:
            if verbose > 0:
                print('Generating patient: ', directory)
            phases = self.generator(directory, affine_matrix=True)

            for phase_index in (0, 1):
                yield phases[phase_index]


    def test_affine_generator_inference(self, verbose: int = 0) -> Tuple[Dict[str, np.ndarray]]:
        """Yield inference records, with affines, for every patient in ``self.test_list``.

        No ground truth is requested. Only the NumPy sample is generated with
        affine matrices; the sitk samples come from ``self.sitk_generator``.

        Args:
            verbose: When greater than zero, print each patient directory
                before it is generated.

        Yields:
            For ``'ed'`` and then ``'es'``, the 5-tuple ``(numpy sample,
            raw sitk sample, preprocessed sitk sample, patient directory,
            phase name)``.
        """
        for directory in self.test_list:
            if verbose > 0:
                print('Generating patient: ', directory)
            numpy_phases = self.generator(directory, affine_matrix=True, has_gt=False)
            raw_phases, processed_phases = self.sitk_generator(directory, has_gt=False)

            for position, phase in enumerate(('ed', 'es')):
                yield (numpy_phases[position], raw_phases[position],
                       processed_phases[position], directory, phase)

src/data/tf_generator.py

In [None]:
from typing import Dict, Tuple, Union

import tensorflow as tf


class TensorFlowDataGenerator():
    """Builds batched train/validation/test ``tf.data.Dataset`` pipelines.

    Each public factory creates its own ``DataGenerator`` and wraps that
    object's Python generators with ``tf.data.Dataset.from_generator``.
    """

    # Fixed seed passed to the shuffle of the training dataset
    _SHUFFLE_SEED = 4875

    @staticmethod
    def _float_type(floating_precision: str) -> tf.dtypes.DType:
        """Return ``tf.float16`` for ``'16'`` and ``tf.float32`` otherwise."""
        return tf.float16 if floating_precision == '16' else tf.float32

    @staticmethod
    def _sample_signature(dg: DataGenerator, floating_precision: str, use_affine: bool):
        """Describe one ``(inputs, outputs)`` sample as shapes and dtypes.

        Args:
            dg: Generator whose ``sa_shape``, ``la_shape``, ``sa_target_shape``,
                ``la_target_shape`` (and ``affine_shape`` when ``use_affine``)
                attributes provide the shapes.
            floating_precision: ``'16'`` selects float16 images and labels,
                anything else float32.
            use_affine: Whether the two affine-matrix inputs are part of the
                sample. They are always declared as float32.

        Returns:
            ``(output_shapes, output_types)`` in the nested structure expected
            by ``tf.data.Dataset.from_generator``.
        """
        float_type = TensorFlowDataGenerator._float_type(floating_precision)

        input_shapes = {'input_sa': tf.TensorShape(dg.sa_shape),
                        'input_la': tf.TensorShape(dg.la_shape)}
        input_types = {'input_sa': float_type,
                       'input_la': float_type}
        if use_affine:
            input_shapes['input_sa_affine'] = tf.TensorShape(dg.affine_shape)
            input_shapes['input_la_affine'] = tf.TensorShape(dg.affine_shape)
            input_types['input_sa_affine'] = tf.float32
            input_types['input_la_affine'] = tf.float32

        # TODO: Change to dynamic input parameters
        output_shapes = (input_shapes,
                         {'output_sa': tf.TensorShape(dg.sa_target_shape),
                          'output_la': tf.TensorShape(dg.la_target_shape)})
        output_types = (input_types,
                        {'output_sa': float_type,
                         'output_la': float_type})
        return output_shapes, output_types

    @staticmethod
    def _prepare_generators(dg: DataGenerator, use_affine: bool, batch_size: int,
                            output_shapes: Tuple[Dict[str, tf.TensorShape]],
                            output_types: Tuple[Dict[str, tf.dtypes.DType]],
                            max_buffer_size: Union[int, None]=None,
                            floating_precision: str='32'
                            ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, DataGenerator]:
        """Wrap the generators of ``dg`` into batched datasets.

        The training dataset is shuffled (reshuffled every iteration), batched
        and prefetched; validation and test datasets are only batched. All
        three share ``output_types`` and ``output_shapes`` because the
        underlying generators yield the same sample structure.

        The ``*_inference`` generators of ``dg`` are intentionally not used
        here: they yield 5-tuples that do not match ``output_types``.

        Args:
            dg: Source of the Python generators and of ``train_list``.
            use_affine: Select the ``*_affine_generator`` variants.
            batch_size: Batch size of all three datasets.
            output_shapes: Static shapes of one sample.
            output_types: Dtypes of one sample.
            max_buffer_size: Optional upper bound for the shuffle buffer, which
                otherwise holds two samples per training patient.
            floating_precision: Unused; kept for signature compatibility.

        Returns:
            ``(train_dataset, validation_dataset, test_dataset, dg)``.
        """
        def dataset_from(python_generator):
            return tf.data.Dataset.from_generator(python_generator,
                                                  output_types=output_types,
                                                  output_shapes=output_shapes)

        buffer_size = len(dg.train_list) * 2
        if max_buffer_size is not None:
            buffer_size = min(buffer_size, max_buffer_size)
        # ``shuffle`` rejects a buffer size of zero (empty train list or a
        # zero ``max_buffer_size``); a buffer of one simply does not shuffle
        buffer_size = max(buffer_size, 1)

        if use_affine:
            train_source = dg.train_affine_generator
            validation_source = dg.validation_affine_generator
            test_source = dg.test_affine_generator
        else:
            train_source = dg.train_generator
            validation_source = dg.validation_generator
            test_source = dg.test_generator

        train_generator = dataset_from(train_source)
        train_generator = train_generator.shuffle(buffer_size=buffer_size,
                                                  seed=TensorFlowDataGenerator._SHUFFLE_SEED,
                                                  reshuffle_each_iteration=True
                                                  ).batch(batch_size).prefetch(2)

        validation_generator = dataset_from(validation_source).batch(batch_size)
        test_generator = dataset_from(test_source).batch(batch_size)

        return train_generator, validation_generator, test_generator, dg


    @staticmethod
    def get_generators(batch_size: int, max_buffer_size: Union[int, None]=None,
                       floating_precision: str='32'
                       ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, DataGenerator]:
        """Create the datasets whose samples carry no affine matrices.

        Args:
            batch_size: Batch size of all three datasets.
            max_buffer_size: Optional upper bound for the shuffle buffer.
            floating_precision: ``'16'`` for float16 data, otherwise float32.
                Also passed to the ``DataGenerator`` constructor.

        Returns:
            ``(train_dataset, validation_dataset, test_dataset, data_generator)``.
        """
        dg = DataGenerator(floating_precision)
        use_affine = False
        output_shapes, output_types = TensorFlowDataGenerator._sample_signature(
            dg, floating_precision, use_affine)

        return TensorFlowDataGenerator._prepare_generators(dg, use_affine, batch_size,
                                                           output_shapes,
                                                           output_types,
                                                           max_buffer_size,
                                                           floating_precision)


    @staticmethod
    def get_affine_generators(batch_size: int, max_buffer_size: Union[int, None]=None,
                              floating_precision: str='32'
                              ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, DataGenerator]:
        """Create the datasets whose inputs include the two affine matrices.

        Args:
            batch_size: Batch size of all three datasets.
            max_buffer_size: Optional upper bound for the shuffle buffer.
            floating_precision: ``'16'`` for float16 images and labels,
                otherwise float32. The affine inputs stay float32. Also passed
                to the ``DataGenerator`` constructor.

        Returns:
            ``(train_dataset, validation_dataset, test_dataset, data_generator)``.
        """
        dg = DataGenerator(floating_precision)
        use_affine = True
        output_shapes, output_types = TensorFlowDataGenerator._sample_signature(
            dg, floating_precision, use_affine)

        return TensorFlowDataGenerator._prepare_generators(dg, use_affine, batch_size,
                                                           output_shapes,
                                                           output_types,
                                                           max_buffer_size,
                                                           floating_precision)



src/tf/losses/loss.py

In [None]:
import tensorflow as tf


# Loss taken from here:
#    https://github.com/tensorflow/models/blob/master/official/vision/keras_cv/losses/focal_loss.py
class FocalLoss(tf.keras.losses.Loss):
    """Sigmoid focal loss computed from logits.

    Reference:
      [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002).
    """

    def __init__(self,
                 alpha,
                 gamma,
                 reduction=tf.keras.losses.Reduction.AUTO,
                 name=None):
        """Initializes `FocalLoss`.

        Args:
          alpha: Weight applied to elements whose label equals 1; elements
            with any other label are weighted by `1 - alpha`.
          gamma: Exponent of the modulating factor `(1 - p_t) ** gamma`.
          reduction: (Optional) `tf.keras.losses.Reduction` applied to the
            element-wise loss. Defaults to `AUTO`.
          name: (Optional) name forwarded to `tf.keras.losses.Loss`.
        """
        self._alpha = alpha
        self._gamma = gamma
        super().__init__(reduction=reduction, name=name)


    def call(self, y_true, y_pred):
        """Invokes the `FocalLoss`.

        Args:
          y_true: Label tensor; both arguments are cast to float32.
          y_pred: Logit tensor with the same shape as `y_true`.

        Returns:
          The alpha-weighted focal loss of every element (not yet reduced).
        """
        with tf.name_scope('focal_loss'):
            labels = tf.cast(y_true, dtype=tf.float32)
            logits = tf.cast(y_pred, dtype=tf.float32)
            is_positive = tf.equal(labels, 1.0)

            bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)

            # Probability the model assigns to the labelled outcome
            sigmoid_probs = tf.sigmoid(logits)
            prob_of_label = tf.where(is_positive, sigmoid_probs, 1.0 - sigmoid_probs)

            # With small gamma, the implementation could produce NaN during back prop.
            focal_factor = tf.pow(1.0 - prob_of_label, self._gamma)
            unweighted = focal_factor * bce

            return tf.where(is_positive,
                            self._alpha * unweighted,
                            (1.0 - self._alpha) * unweighted)


    def get_config(self):
        """Return the base loss configuration extended with `alpha` and `gamma`."""
        own_config = {
            'alpha': self._alpha,
            'gamma': self._gamma,
        }
        return {**super().get_config(), **own_config}



class TverskyLoss(tf.keras.losses.Loss):
    """Tversky loss computed from logits.

    Reference:
      [Tversky loss function for image segmentation using 3D fully convolutional
       deep networks](https://arxiv.org/abs/1706.05721).

      'In the case of α=β=0.5 the Tversky index simplifies to be the same as
       the Dice coefficient, which is also equal to the F1 score. With α=β=1,
       Equation 2 produces Tanimoto coefficient, and setting α+β=1 produces
       the set of Fβ scores. Larger βs weigh recall higher than precision (by
       placing more emphasis on false negatives)'
    """

    def __init__(self,
                 alpha,
                 beta,
                 reduction=tf.keras.losses.Reduction.AUTO,
                 name=None):
        """Initializes `TverskyLoss`.

        Args:
          alpha: Weight of the false-positive term in the Tversky index.
          beta: Weight of the false-negative term in the Tversky index.
          reduction: (Optional) `tf.keras.losses.Reduction` forwarded to
            `tf.keras.losses.Loss`. Defaults to `AUTO`.
          name: (Optional) name forwarded to `tf.keras.losses.Loss`.
        """
        self._alpha = alpha
        self._beta = beta
        super().__init__(reduction=reduction, name=name)


    def call(self, y_true, y_pred):
        """Invokes the `TverskyLoss`.

        Args:
          y_true: A tensor of size [batch, ..., num_classes]
          y_pred: A tensor of logits of size [batch, ..., num_classes]

        Returns:
          A scalar: one minus the Tversky index averaged over the batch.
        """
        with tf.name_scope('tversky_loss'):
            smooth = 1e-6
            truth = tf.cast(y_true, dtype=tf.float32)
            logits = tf.cast(y_pred, dtype=tf.float32)

            # TODO: softmax is unstable
            probs = tf.nn.softmax(logits, axis=-1)

            # Collapse everything but the batch axis into one vector per sample
            elements_per_sample = tf.reduce_prod(tf.shape(truth)[1:])
            truth_flat = tf.reshape(truth, [-1, elements_per_sample])
            probs_flat = tf.reshape(probs, [-1, elements_per_sample])

            true_pos = tf.math.reduce_sum(truth_flat * probs_flat, axis=-1)
            false_pos = tf.math.reduce_sum((1.0 - truth_flat) * probs_flat, axis=-1)
            false_neg = tf.math.reduce_sum(truth_flat * (1.0 - probs_flat), axis=-1)

            denominator = true_pos + self._alpha * false_pos + self._beta * false_neg + smooth
            tversky_index = (true_pos + smooth) / denominator

            return 1 - tf.reduce_mean(tversky_index)


    def get_config(self):
        """Return the base loss configuration extended with `alpha` and `beta`."""
        own_config = {
            'alpha': self._alpha,
            'beta': self._beta
        }
        return {**super().get_config(), **own_config}

src/tf/metrics/metrics.py

In [None]:
import tensorflow as tf


@tf.autograph.experimental.do_not_convert
def dice(y_true, y_pred):
    """Soft Dice coefficient between labels and logits, averaged over the batch.

    Args:
        y_true: Label tensor with a leading batch axis (presumably one-hot;
            verify against the caller).
        y_pred: Logits with the same shape; softmax is applied internally.

    Returns:
        A scalar float32 tensor with the mean per-sample Dice coefficient.
    """
    smooth = 1e-6

    truth = tf.cast(y_true, dtype=tf.float32)
    # Expected y_pred to be 'logits'
    probs = tf.nn.softmax(tf.cast(y_pred, dtype=tf.float32))

    # Collapse everything but the batch axis into one vector per sample
    elements_per_sample = tf.reduce_prod(tf.shape(truth)[1:])
    truth_flat = tf.reshape(truth, [-1, elements_per_sample])
    probs_flat = tf.reshape(probs, [-1, elements_per_sample])

    overlap = tf.math.reduce_sum(truth_flat * probs_flat, axis=-1)
    total = tf.math.reduce_sum(truth_flat, axis=-1) + \
        tf.math.reduce_sum(probs_flat, axis=-1)

    return tf.math.reduce_mean((2. * overlap + smooth) / (total + smooth))

    

src/tf/layers/transformer.py

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

from voxelmorph.tf.layers import SpatialTransformer


class TargetAffineLayer(Layer):
    """Computes, per batch element, the affine that maps an image onto a target.

    For each pair of matrices the result is ``inv(image_affine) @ target_affine``,
    or the 4x4 identity when both matrices are element-wise equal.
    """

    def __init__(self, **kwargs):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed
        super().__init__(**kwargs)


    @tf.autograph.experimental.do_not_convert
    def _get_transformation(self, inputs):
        """Return the transform for one ``(image_affine, target_affine)`` pair."""
        image_affine = inputs[0]
        target_affine = inputs[1]

        # Skip the matrix inversion when no transformation is needed
        affine_transform = tf.cond(tf.reduce_all(tf.math.equal(target_affine, image_affine)),
                                   lambda: tf.eye(4, dtype=image_affine.dtype),
                                   lambda: tf.tensordot(tf.linalg.inv(image_affine),
                                                        target_affine, axes=1))

        return affine_transform


    def get_config(self):
        """Return a copy of the base layer configuration (no extra arguments)."""
        config = super().get_config().copy()
        return config


    def call(self, inputs):
        """Compute the batch of affine transforms.

        Parameters
            inputs: list with two entries, ``[image_affine, target_affine]``.
                Both are cast to float32 and processed element by element
                along their first axis.

        Returns
            A float32 tensor holding one transform per batch element.
        """
        # check shapes
        assert len(inputs) == 2, "inputs has to be len 2, found: %d" % len(inputs)
        image_affine = tf.cast(inputs[0], dtype=tf.float32)
        target_affine = tf.cast(inputs[1], dtype=tf.float32)

        affine_transform = tf.map_fn(self._get_transformation,
                                     [image_affine, target_affine],
                                     dtype=tf.float32)

        return affine_transform
    


class TargetShapePad(Layer):
    """Zero-pads a volume before it is resampled.

    NOTE: the padding is currently hard-coded to zero on every axis, so the
    layer passes its input through ``ZeroPadding3D`` unchanged in size.
    ``image_shape`` and ``target_shape`` are only recorded for ``get_config``.
    """

    def __init__(self, image_shape, target_shape, **kwargs):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed
        super().__init__(**kwargs)

        # TODO: derive the padding from the shapes, i.e. pad the first two
        # axes by max(target_shape[i] - image_shape[i], 0) at their end
        self.paddings = [(0, 0),
                         (0, 0),
                         (0, 0)]

        self.init_config = {'image_shape': image_shape, 'target_shape': target_shape, **kwargs}


    def get_config(self):
        """Return the constructor arguments needed to rebuild the layer."""
        # Hand out a copy so callers cannot mutate the stored configuration
        return dict(self.init_config)


    def call(self, inputs):
        """Apply ``ZeroPadding3D`` with ``self.paddings`` to ``inputs``."""
        padded_image = tf.keras.layers.ZeroPadding3D(self.paddings)(inputs)

        return padded_image



class TargetShapeCrop(Layer):
    """Crops a resampled volume.

    NOTE: the cropping is currently hard-coded: nothing is removed from the
    first two spatial axes and 16 entries are removed from the end of the
    third. ``image_shape`` and ``target_shape`` are only recorded for
    ``get_config``.
    """

    def __init__(self, image_shape, target_shape, **kwargs):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed
        super().__init__(**kwargs)

        # TODO: derive the cropping from the shapes, i.e. crop every axis by
        # max(image_shape[i] - target_shape[i], 0) at its end
        self.cropping = [(0, 0),
                         (0, 0),
                         (0, 16)]

        self.init_config = {'image_shape': image_shape, 'target_shape': target_shape, **kwargs}


    def get_config(self):
        """Return the constructor arguments needed to rebuild the layer."""
        # Hand out a copy so callers cannot mutate the stored configuration
        return dict(self.init_config)


    def call(self, inputs):
        """Apply ``Cropping3D`` with ``self.cropping`` to ``inputs``."""
        cropped_image = tf.keras.layers.Cropping3D(self.cropping)(inputs)

        return cropped_image
    
    
def spatial_target_transformer(x, affine_matrix, target_affine_matrix,
                               image_shape, target_image_shape):
    """Resample ``x`` from its own affine space into the target's.

    The relative affine is computed by ``TargetAffineLayer``; ``x`` is then
    passed through ``TargetShapePad``, linearly resampled by voxelmorph's
    ``SpatialTransformer`` in float32, cast back to its previous dtype and
    finally passed through ``TargetShapeCrop``.

    Args:
        x: Tensor to resample.
        affine_matrix: Affine matrix (or batch of matrices) belonging to ``x``.
        target_affine_matrix: Affine matrix of the target space.
        image_shape: Shape handed to the pad and crop layers.
        target_image_shape: Target shape handed to the pad and crop layers.

    Returns:
        The resampled tensor, with the dtype ``x`` had after padding.
    """
    relative_affine = TargetAffineLayer()([affine_matrix, target_affine_matrix])

    padded = TargetShapePad(image_shape, target_image_shape)(x)

    # The resampling is done in float32; remember the dtype to restore it
    incoming_dtype = padded.dtype
    padded_float32 = tf.cast(padded, dtype=tf.float32)
    resampler = SpatialTransformer(interp_method='linear',
                                   indexing='ij',
                                   add_identity=False,
                                   shift_center=False,
                                   fill_value=0.0,
                                   dtype=tf.float32)
    resampled = resampler([padded_float32, relative_affine])
    restored = tf.cast(resampled, dtype=incoming_dtype)

    return TargetShapeCrop(image_shape, target_image_shape)(restored)

src/tf/models/multi_stage_model.py

In [None]:
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

#from tf.layers.transformer import spatial_target_transformer


def _inception_block_a(x, num_filters, kernel_initializer, suffix, index):
    """Build inception block 'a': four parallel 2D branches merged by a 1x1 conv.

    The input and the four branch outputs are concatenated along the last
    axis and reduced to ``num_filters // 2`` channels.

    Args:
        x: Input tensor of the block.
        num_filters: Every convolution in the block uses ``num_filters // 2``
            filters.
        kernel_initializer: Initializer of every convolution kernel.
        suffix: String placed at the start of every layer name.
        index: String placed at the end of every layer name.

    Returns:
        The output tensor of the merging 1x1 convolution (no activation).
    """
    branch_filters = num_filters // 2
    prefix = suffix + '_inception_a_'

    def conv(tensor, kernel_size, tag):
        return layers.Conv2D(branch_filters, kernel_size, (1, 1), padding='same',
                             kernel_initializer=kernel_initializer,
                             name=prefix + 'conv2d_' + tag + '_' + index)(tensor)

    def conv_relu(tensor, kernel_size, conv_tag, activation_tag):
        tensor = conv(tensor, kernel_size, conv_tag)
        return layers.Activation('relu',
                                 name=prefix + 'activation_' + activation_tag + '_' + index)(tensor)

    # Branch 1: stride-1 max pooling followed by a 1x1 convolution
    branch_1 = layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding='same',
                                   name=prefix + 'max_pooling_1_1_' + index)(x)
    branch_1 = conv_relu(branch_1, (1, 1), '1_2', '1_3')

    # Branch 2: 1x1 -> 3x3
    branch_2 = conv_relu(x, (1, 1), '2_1', '2_2')
    branch_2 = conv_relu(branch_2, (3, 3), '2_3', '2_4')

    # Branch 3: 1x1 -> 3x3 -> 3x3
    branch_3 = conv_relu(x, (1, 1), '3_1', '3_2')
    branch_3 = conv_relu(branch_3, (3, 3), '3_3', '3_4')
    branch_3 = conv_relu(branch_3, (3, 3), '3_5', '3_6')

    # Branch 4: 1x1 only
    branch_4 = conv_relu(x, (1, 1), '4_1', '4_2')

    # Concatenate the input with all branches, then reduce the filter size
    merged = layers.Concatenate(axis=-1, name=prefix + 'concatenate_' + index)(
        [x, branch_1, branch_2, branch_3, branch_4])
    return conv(merged, (1, 1), 'merge')


def _inception_block_b(x, num_filters, kernel_initializer, suffix, index):
    """Build inception block 'b': four parallel 2D branches merged by a 1x1 conv.

    Three of the branches end in a separable convolution with a 5x5, 7x7 and
    9x9 kernel respectively. The input and the four branch outputs are
    concatenated along the last axis and reduced to ``num_filters`` channels.

    Args:
        x: Input tensor of the block.
        num_filters: Branch convolutions use ``num_filters // 2`` filters; the
            merging convolution uses ``num_filters``.
        kernel_initializer: Initializer of every (depthwise, pointwise or
            regular) convolution kernel.
        suffix: String placed at the start of every layer name.
        index: String placed at the end of every layer name.

    Returns:
        The output tensor of the merging 1x1 convolution (no activation).
    """
    branch_filters = num_filters // 2
    prefix = suffix + '_inception_b_'

    def relu(tensor, tag):
        return layers.Activation('relu',
                                 name=prefix + 'activation_' + tag + '_' + index)(tensor)

    def conv_relu(tensor, kernel_size, conv_tag, activation_tag):
        tensor = layers.Conv2D(branch_filters, kernel_size, (1, 1), padding='same',
                               kernel_initializer=kernel_initializer,
                               name=prefix + 'conv2d_' + conv_tag + '_' + index)(tensor)
        return relu(tensor, activation_tag)

    def separable_relu(tensor, kernel_size, conv_tag, activation_tag):
        # The layer name keeps the historical 'seperable' spelling
        tensor = layers.SeparableConv2D(branch_filters, kernel_size, (1, 1), padding='same',
                                        depthwise_initializer=kernel_initializer,
                                        pointwise_initializer=kernel_initializer,
                                        name=prefix + 'seperable_conv2d_' + conv_tag + '_' + index)(tensor)
        return relu(tensor, activation_tag)

    # Branch 1: 1x1 only
    branch_1 = conv_relu(x, (1, 1), '1_1', '1_2')

    # Branch 2: 1x1 -> separable 5x5
    branch_2 = conv_relu(x, (1, 1), '2_1', '2_2')
    branch_2 = separable_relu(branch_2, (5, 5), '2_3', '2_4')

    # Branch 3: 1x1 -> 3x3 -> separable 7x7
    branch_3 = conv_relu(x, (1, 1), '3_1', '3_2')
    branch_3 = conv_relu(branch_3, (3, 3), '3_3', '3_4')
    branch_3 = separable_relu(branch_3, (7, 7), '3_5', '3_6')

    # Branch 4: 1x1 -> 3x3 -> separable 9x9
    branch_4 = conv_relu(x, (1, 1), '4_1', '4_2')
    branch_4 = conv_relu(branch_4, (3, 3), '4_3', '4_4')
    branch_4 = separable_relu(branch_4, (9, 9), '4_5', '4_6')

    # Concatenate the input with all branches, then reduce the filter size
    merged = layers.Concatenate(axis=-1, name=prefix + 'concatenate_' + index)(
        [x, branch_1, branch_2, branch_3, branch_4])
    return layers.Conv2D(num_filters, (1, 1), (1, 1), padding='same',
                         kernel_initializer=kernel_initializer,
                         name=prefix + 'conv2d_merge_' + index)(merged)
    
    
def _shared_feature_pyramid_layers(num_pyramid_layers, input_shape, num_filters,
                                   kernel_initializer, suffix, index):
    """Build the layer objects that every pass through the feature pyramid reuses.

    No layer is called here; the layers are only instantiated and grouped.

    Args:
        num_pyramid_layers: Number of down-sampling and up-sampling levels to create.
        input_shape: Indexable whose first two entries are used as the spatial size
            when deriving the padding and cropping amounts.
        num_filters: Filter count of every convolution created here.
        kernel_initializer: Initialiser passed to every convolution.
        suffix: Text placed at the start of every layer name.
        index: Text placed at the end of every layer name.

    Returns:
        Tuple ``(shared_down_level, shared_up_level, shared_skip)``. The first two
        are lists with ``num_pyramid_layers`` entries and the last has
        ``num_pyramid_layers - 1`` entries; each entry is a list of layers in the
        order in which they are meant to be applied.
    """
    def _layer_name(tag, level, position):
        # e.g. '<suffix>_pyramid_down_conv2d_<level>_1_<index>'
        return suffix + '_pyramid_' + tag + '_' + level + '_' + position + '_' + index

    def _down_block(level):
        # Conv -> ReLU -> 2x2 max-pool -> zero-pad by a quarter of the input size per side.
        # NOTE(review): presumably the padding restores the size halved by the pooling
        # (exact only when the spatial size is divisible by 4) - confirm.
        return [
            layers.Conv2D(num_filters, (3, 3), (1, 1), padding='same',
                          kernel_initializer=kernel_initializer,
                          name=_layer_name('down_conv2d', level, '1')),
            layers.Activation('relu', name=_layer_name('down_activation', level, '2')),
            layers.MaxPooling2D((2, 2), padding='same',
                                name=_layer_name('down_max_pooling', level, '3')),
            layers.ZeroPadding2D((input_shape[0] // 4, input_shape[1] // 4),
                                 name=_layer_name('down_padding', level, '4')),
        ]

    def _up_block(level):
        # Conv -> ReLU -> 2x bilinear up-sampling -> crop half of the input size per side.
        return [
            layers.Conv2D(num_filters, (3, 3), (1, 1), padding='same',
                          kernel_initializer=kernel_initializer,
                          name=_layer_name('up_conv2d', level, '1')),
            layers.Activation('relu', name=_layer_name('up_activation', level, '2')),
            layers.UpSampling2D((2, 2), interpolation='bilinear',
                                name=_layer_name('upsampling', level, '3')),
            layers.Cropping2D((input_shape[0] // 2, input_shape[1] // 2),
                              name=_layer_name('up_cropping', level, '4')),
        ]

    def _skip_block(level):
        # 1x1 conv -> ReLU on the skip tensor, then an Add that merges it with the main path.
        return [
            layers.Conv2D(num_filters, (1, 1), (1, 1), padding='same',
                          kernel_initializer=kernel_initializer,
                          name=_layer_name('skip_conv2d', level, '1')),
            layers.Activation('relu', name=_layer_name('skip_activation', level, '2')),
            layers.Add(name=_layer_name('skip_add', level, '3')),
        ]

    # Levels are numbered from 1 in the layer names.
    level_ids = [str(number) for number in range(1, num_pyramid_layers + 1)]

    shared_down_level = [_down_block(level) for level in level_ids]
    shared_up_level = [_up_block(level) for level in level_ids]
    # One skip block fewer than there are levels.
    shared_skip = [_skip_block(level) for level in level_ids[:-1]]

    return shared_down_level, shared_up_level, shared_skip

    
def feature_pyramid_layer(x, pyramid_layers, input_shape, num_filters, kernel_initializer,
                          suffix, index):
    """Apply a multi-depth feature pyramid to ``x`` and fuse the per-depth outputs.

    ``x`` is first projected with a 1x1 convolution and ReLU. That projection is then
    run through the pyramid once for every depth from ``pyramid_layers`` down to 1:
    a pass of depth ``d`` applies the first ``d`` down-sampling levels followed by the
    first ``d`` up-sampling levels, adding a transformed skip tensor before every
    up-sampling level except the first. All passes use the same layer instances (and
    therefore the same weights). The outputs of all passes are concatenated on the
    last axis and reduced with a 1x1 convolution and ReLU.

    Args:
        x: Input tensor.
        pyramid_layers: Depth of the deepest pass; must be at least 1.
        input_shape: Spatial size forwarded to ``_shared_feature_pyramid_layers``.
        num_filters: Filter count of every convolution.
        kernel_initializer: Initialiser of every convolution.
        suffix: Text placed at the start of every layer name.
        index: Text placed at the end of every layer name.

    Returns:
        The fused output tensor.

    Raises:
        ValueError: If ``pyramid_layers`` is smaller than 1.
    """
    # Fail early with a clear message; previously a non-positive depth surfaced as an
    # IndexError from deleting the last element of an empty skip list.
    if pyramid_layers < 1:
        raise ValueError('\'pyramid_layers\' must be at least 1. Given: {}'.format(pyramid_layers))

    x_input = layers.Conv2D(num_filters, (1, 1), (1, 1), padding='same',
                            kernel_initializer=kernel_initializer,
                            name=suffix + '_pyramid_input_conv2d_1_' + index)(x)
    x_input = layers.Activation('relu', name=suffix + '_pyramid_input_activation_2_' + index)(x_input)

    # Initialise shared layers for the pyramid
    shared_down_level, shared_up_level, shared_skip = _shared_feature_pyramid_layers(
        pyramid_layers, input_shape, num_filters, kernel_initializer, suffix, index)

    pyramid_output = []

    # One pass per depth, deepest first; the argument itself is no longer used as counter.
    for depth in range(pyramid_layers, 0, -1):
        x = x_input

        # Downsampling: remember the output of every level for the skip connections
        x_skip = []
        for down_layers in shared_down_level[:depth]:
            for layer in down_layers:
                x = layer(x)
            x_skip.append(x)

        # The deepest level has no skip connection; the remaining ones are consumed
        # from deepest to shallowest.
        x_skip.pop()
        x_skip.reverse()

        # Upsampling
        for level, up_layers in enumerate(shared_up_level[:depth]):
            if level > 0:
                # Every skip block ends with the merge layer; the layers before it
                # transform the skip tensor.
                *skip_transforms, skip_merge = shared_skip[level - 1]
                x_s = x_skip[level - 1]
                for layer in skip_transforms:
                    x_s = layer(x_s)
                x = skip_merge([x_s, x])

            for layer in up_layers:
                x = layer(x)

        pyramid_output.append(x)

    x = layers.Concatenate(axis=-1,
                           name=suffix + '_pyramid_output_concatenate_1_' + index)(pyramid_output)
    x = layers.Conv2D(num_filters, (1, 1), (1, 1), padding='same',
                      kernel_initializer=kernel_initializer,
                      name=suffix + '_pyramid_output_conv2d_2_' + index)(x)
    x = layers.Activation('relu', name=suffix + '_pyramid_output_activation_3_' + index)(x)

    return x
        
    
def _shared_2d_branch(input_shape, kernel_initializer, downsample=False) -> keras.Model:
    """Build the 2D feature extractor as a stand-alone Keras model.

    The input is fed to two pipelines, a stack of inception blocks and a feature
    pyramid, whose outputs are added element-wise.

    Args:
        input_shape: Shape (without batch axis) of the model input; its first two
            entries are used as the spatial size.
        kernel_initializer: Initialiser forwarded to every block.
        downsample: When true, the input is bilinearly resized to half its spatial
            size before both pipelines and the sum is resized back afterwards.

    Returns:
        A ``keras.Model`` mapping the input to the summed features.
    """
    suffix = 'shared_branch'

    branch_input = keras.layers.Input(shape=input_shape, name='input_' + suffix)

    if downsample:
        # Downsample image to reduce total memory requirement
        # NOTE(review): the resize/cast op names are built without the '_' separator
        # that the other names in this function use (e.g. '_add_1') - intentional?
        input_dtype = branch_input.dtype
        working_shape = (input_shape[0] // 2, input_shape[1] // 2)
        features = tf.image.resize(branch_input, working_shape,
                                   method=tf.image.ResizeMethod.BILINEAR,
                                   antialias=True, name=suffix + 'image_resize_down')
        # 'resize' returns float32, so cast back to the dtype of the input
        features = tf.cast(features, input_dtype, name=suffix + 'image_casting_down')
    else:
        working_shape = input_shape
        features = branch_input

    # Inception pipeline: three type-A blocks followed by one type-B block
    inception = features
    for filters, block_index in ((32, '1'), (64, '2'), (64, '3')):
        inception = _inception_block_a(inception, num_filters=filters,
                                       kernel_initializer=kernel_initializer,
                                       suffix=suffix, index=block_index)
    inception = _inception_block_b(inception, num_filters=128,
                                   kernel_initializer=kernel_initializer,
                                   suffix=suffix, index='4')

    # Multi-level feature pyramid pipeline, fed with the same tensor as the inception one
    pyramid = feature_pyramid_layer(features, pyramid_layers=3, input_shape=working_shape,
                                    num_filters=128, kernel_initializer=kernel_initializer,
                                    suffix=suffix, index='1')

    merged = layers.Add(name=suffix + '_add_1')([inception, pyramid])

    if downsample:
        # Upsample back to the original resolution
        full_shape = (input_shape[0], input_shape[1])
        merged = tf.image.resize(merged, full_shape, method=tf.image.ResizeMethod.BILINEAR,
                                 name=suffix + 'image_resize_up')
        merged = tf.cast(merged, input_dtype, name=suffix + 'image_casting_up')

    return keras.models.Model(branch_input, merged)


def get_model(sa_input_shape, la_input_shape, num_classes) -> keras.Model:
    """Assemble the two-branch (short-axis / long-axis) segmentation model.

    Both images go through the same shared 2D feature extractor. The short-axis
    volume is processed slice by slice and then refined with 3D convolutions; its
    prediction is passed through ``spatial_target_transformer`` together with both
    affine matrices and concatenated with the long-axis features before the final
    long-axis prediction.

    Args:
        sa_input_shape: Shape (without batch axis) of the short-axis input.
        la_input_shape: Shape (without batch axis) of the long-axis input; also used
            as the input shape of the shared 2D branch.
        num_classes: Number of output channels of both prediction heads.

    Returns:
        A ``keras.Model`` with inputs ``[input_sa, input_la, input_sa_affine,
        input_la_affine]`` and outputs ``[output_sa, output_la]``.
    """
    kernel_initializer = 'glorot_uniform'

    # The short-axis image is expected to have its 3rd dimension as channels: (B, W, H, C)
    input_sa = keras.Input(shape=sa_input_shape, name='input_sa')
    input_la = keras.Input(shape=la_input_shape, name='input_la')

    input_sa_affine = keras.Input(shape=(4, 4), name='input_sa_affine', dtype=tf.float32)
    input_la_affine = keras.Input(shape=(4, 4), name='input_la_affine', dtype=tf.float32)

    shared_branch = _shared_2d_branch(la_input_shape, kernel_initializer, downsample=False)

    # --- Short-axis branch ---
    # Add a trailing 'channel' axis that survives the unstacking below
    sa_volume = tf.expand_dims(input_sa, axis=-1)
    # Split the volume along its second-to-last axis and run every 2D slice through
    # the shared branch
    sa_slices = [shared_branch(sa_slice) for sa_slice in tf.unstack(sa_volume, axis=-2)]
    # Stack back into a 3D image (W, H, D, C)
    sa_features = tf.stack(sa_slices, axis=-2)

    for filters in (32, 32, 64):
        sa_features = layers.Conv3D(filters, (3, 3, 3), padding='same',
                                    kernel_initializer=kernel_initializer)(sa_features)
        sa_features = layers.Activation('relu')(sa_features)

    output_sa = layers.Conv3D(num_classes, (1, 1, 1), padding='same',
                              kernel_initializer=kernel_initializer,
                              name='output_sa')(sa_features)

    # --- Long-axis branch ---
    la_features = shared_branch(input_la)

    for filters in (32, 64, 64):
        la_features = layers.Conv2D(filters, (3, 3), padding='same',
                                    kernel_initializer=kernel_initializer)(la_features)
        la_features = layers.Activation('relu')(la_features)

    la_features = layers.Conv2D(num_classes, (1, 1), padding='same',
                                kernel_initializer=kernel_initializer)(la_features)

    # The short-axis prediction (output_sa) is what gets fed to the spatial transformer
    sa_in_la = spatial_target_transformer(output_sa, input_sa_affine, input_la_affine,
                                          sa_input_shape, la_input_shape)

    # Reshape from 3d to 2d (depth size is expected to be 1 after the spatial transformer)
    sa_in_la = layers.Reshape((la_input_shape[0], la_input_shape[1], -1))(sa_in_la)

    la_features = layers.Concatenate()([la_features, sa_in_la])

    la_features = layers.Conv2D(32, (3, 3), padding='same',
                                kernel_initializer=kernel_initializer)(la_features)
    la_features = layers.Activation('relu')(la_features)

    output_la = layers.Conv2D(num_classes, (1, 1), padding='same',
                              kernel_initializer=kernel_initializer,
                              name='output_la')(la_features)

    return keras.Model([input_sa, input_la, input_sa_affine, input_la_affine],
                       [output_sa, output_la])



src/configuration.py\
Used to select hyperparameter values

In [None]:
from sklearn.model_selection import ParameterGrid

from tensorboard.plugins.hparams import api as hp

# Sortable version of HParam
class HParamS(hp.HParam):
    """``hp.HParam`` that additionally supports ``<`` comparison by name.

    NOTE(review): presumably needed so that instances can be used as keys of the
    dictionary handed to ``ParameterGrid``, which may sort its keys - confirm.
    """

    def __init__(self, name, domain=None, display_name=None, description=None):
        # Nothing extra to store; simply defer to the base class.
        super().__init__(name, domain, display_name, description)

    def __lt__(self, other):
        """Return whether this name sorts before ``other``'s, ignoring case."""
        own_key = self.name.lower()
        other_key = other.name.lower()
        return own_key < other_key


class HyperParameters:
    """Hard-coded hyper-parameter search space that can be iterated over.

    Every hyper-parameter is exposed as an ``HP_*`` attribute holding an ``HParamS``
    with a discrete domain. Iterating over an instance yields every combination in
    ``parameter_space`` as a mapping keyed by those ``HParamS`` objects.
    """

    def __init__(self, search_type: str):
        """Create the search space.

        Args:
            search_type: Search strategy; only ``'grid'`` is accepted.

        Raises:
            ValueError: If ``search_type`` is anything other than ``'grid'``.
        """
        # TODO: Load from file rather than hard-coded in this file
        self.HP_FLOATING_POINT = HParamS('floating_point', hp.Discrete(['16']))
        self.HP_XLA = HParamS('xla_compiler', hp.Discrete([False]))
        self.HP_EPOCHS = HParamS('epochs', hp.Discrete([20]))
        self.HP_BATCH_SIZE = HParamS('batch_size', hp.Discrete([1]))
        # The attribute spelling ('LEANRING') is kept because the training code in
        # this file reads it under that name.
        self.HP_LEANRING_RATE = HParamS('learning_rate', hp.Discrete([0.0005]))
        self.HP_OPTIMISER = HParamS('optimiser', hp.Discrete(['adam']))
        self.HP_LOSS = HParamS('loss', hp.Discrete(['focal']))
        self.HP_DROPOUT = HParamS('drop_out', hp.Discrete([0.0]))

        searched_hparams = (
            self.HP_FLOATING_POINT,
            self.HP_XLA,
            self.HP_EPOCHS,
            self.HP_BATCH_SIZE,
            self.HP_LEANRING_RATE,
            self.HP_OPTIMISER,
            self.HP_LOSS,
            self.HP_DROPOUT,
        )
        # Maps every hyper-parameter to the values of its domain.
        self.parameter_dict = {hparam: hparam.domain.values for hparam in searched_hparams}

        if search_type != 'grid':
            raise ValueError("Invalid 'search_type' input. Given: {}".format(search_type))
        self.parameter_space = ParameterGrid(self.parameter_dict)

    def __iter__(self):
        """Yield every hyper-parameter combination of ``parameter_space``."""
        for combination in list(self.parameter_space):
            yield combination


"""
if __name__ == '__main__':
    config = HyperParameters(search_type='grid')
    for i in config:
        print(i)
"""

"\nif __name__ == '__main__':\n    config = HyperParameters(search_type='grid')\n    for i in config:\n        print(i)\n"

src/run_training.py

In [None]:
# Root directory under which the training code in this notebook places its
# TensorBoard logs ('logs/fit/...') and model checkpoints ('checkpoint/...').
# NOTE(review): the value looks like a Colab Google Drive mount - adjust to your setup.
base_output_path = '/content/gdrive/MyDrive/mnms2_challenge/'   # Example

In [None]:
import os

import datetime
import random

import numpy as np

import tensorflow as tf
from tensorflow import keras

from tensorboard.plugins.hparams import api as hp

#from configuration import HyperParameters
#from data import TensorFlowDataGenerator
#from tf.models import test_model
#from tf.losses.loss import FocalLoss, TverskyLoss
#from tf.metrics.metrics import dice


# Seed Python's, TensorFlow's and NumPy's global random number generators with the
# same fixed value.
__SEED = 1456
# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts, so setting
# it here presumably does not change hashing in the already running kernel - confirm.
os.environ['PYTHONHASHSEED'] = str(__SEED)
random.seed(__SEED)
tf.random.set_seed(__SEED)
np.random.seed(__SEED)


def get_callbacks(prefix: str, checkpoint_directory: str, hparams):
    """Create the Keras callbacks used for one training run.

    Args:
        prefix: Run name; combined with the current timestamp to form the log
            directory ``<base_output_path>/logs/fit/<prefix>_<YYYYmmdd-HHMMSS>/``.
        checkpoint_directory: Location handed to ``ModelCheckpoint`` as ``filepath``.
        hparams: Hyper-parameter values of the run, logged via ``hp.KerasCallback``.

    Returns:
        List with the checkpoint, TensorBoard and hyper-parameter callbacks, in
        that order.
    """
    # Keep only the model with the lowest validation loss (full model, not just weights).
    checkpoint_cb = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_directory,
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        save_weights_only=False)

    timestamp = datetime.datetime.now().strftime('_%Y%m%d-%H%M%S')
    run_log_dir = os.path.join(base_output_path, 'logs', 'fit', prefix + timestamp) + '/'
    # (A dedicated summary file writer for extra metrics used to be set up here; it is disabled.)

    return [
        checkpoint_cb,
        tf.keras.callbacks.TensorBoard(log_dir=run_log_dir, histogram_freq=1),
        hp.KerasCallback(run_log_dir, hparams),
    ]


if __name__ == '__main__':
    # Train one model per hyper-parameter combination of the grid.
    hyper_parameters = HyperParameters('grid')

    for hparams in hyper_parameters:
        keras.backend.clear_session()

        # The dtype policy is global state, so it is set on every iteration; otherwise
        # a mixed-precision run would leak into later runs that ask for full precision.
        fp = hparams[hyper_parameters.HP_FLOATING_POINT]
        policy_name = 'mixed_float16' if fp == '16' else 'float32'
        policy = keras.mixed_precision.experimental.Policy(policy_name)
        keras.mixed_precision.experimental.set_policy(policy)

        # Same for XLA: switch it off explicitly when this combination does not use it.
        use_xla = hparams[hyper_parameters.HP_XLA]
        if use_xla:
            tf.config.optimizer.set_jit('autoclustering')
        else:
            tf.config.optimizer.set_jit(False)

        batch_size = hparams[hyper_parameters.HP_BATCH_SIZE]
        (train_gen, validation_gen,
         test_gen, data_gen) = TensorFlowDataGenerator.get_affine_generators(batch_size,
                                                                             max_buffer_size=None,
                                                                             floating_precision=fp)

        model = get_model(data_gen.sa_shape, data_gen.la_shape, data_gen.n_classes)

        # Unknown optimiser/loss names are rejected instead of leaving the variable
        # unbound or silently reusing the object of the previous iteration.
        learning_rate = hparams[hyper_parameters.HP_LEANRING_RATE]
        optimiser_name = hparams[hyper_parameters.HP_OPTIMISER]
        if optimiser_name == 'adam':
            optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        else:
            raise ValueError('Unsupported optimiser. Given: {}'.format(optimiser_name))

        loss_name = hparams[hyper_parameters.HP_LOSS]
        if loss_name == 'focal':
            loss = FocalLoss(0.25, 2.0)
        elif loss_name == 'tversky':
            loss = TverskyLoss(0.5, 0.5)
        elif loss_name == 'crossentropy':
            loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
        else:
            raise ValueError('Unsupported loss. Given: {}'.format(loss_name))

        model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=[dice])

        epochs = hparams[hyper_parameters.HP_EPOCHS]
        prefix = 'multi_stage_model'
        # Every run gets its own time-stamped checkpoint directory.
        checkpoint_path = os.path.join(base_output_path, 'checkpoint', prefix + datetime.datetime.now().strftime('_%Y%m%d-%H%M%S')) + '/'
        model.fit(x=train_gen,
                  validation_data=validation_gen,
                  epochs=epochs,
                  callbacks=get_callbacks(prefix, checkpoint_path, hparams),
                  verbose=1)


Your GPU may run slowly with dtype policy mixed_float16 because it does not have compute capability of at least 7.0. Your GPU:
  Tesla K80, compute capability 3.7
See https://developer.nvidia.com/cuda-gpus for a list of GPUs and their compute capabilities.
Instructions for updating:
Use tf.keras.mixed_precision.LossScaleOptimizer instead. LossScaleOptimizer now has all the functionality of DynamicLossScale
Instructions for updating:
Use fn_output_signature instead
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
  opt = tf.keras.mixed_precision.LossScaleOptimizer(opt)





Epoch 1/20
    115/Unknown - 889s 7s/step - loss: 0.0490 - output_sa_loss: 0.0172 - output_la_loss: 0.0318 - output_sa_dice: 0.7999 - output_la_dice: 0.6380

KeyboardInterrupt: ignored