In [1]:
#%%
#imports
import nibabel as nib
import numpy as np
from pathlib import Path
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from IPython.display import clear_output
import matplotlib.pyplot as plt
import tensorlayer as tl

In [20]:
def vizualise(img):
    """Display ``img`` as a grayscale picture without axes.

    The array is first converted to a PIL image with Keras'
    ``array_to_img`` and then shown with matplotlib.

    :param img: array-like image accepted by ``array_to_img``
    """
    pil_image = tf.keras.preprocessing.image.array_to_img(img)
    plt.imshow(pil_image, cmap='gray')
    plt.axis('off')
    plt.show()

In [21]:
def read_scan(path):
    """Load the file at ``path`` with nibabel and return its data as a NumPy array.

    :param path: location of the scan; converted to ``str`` before loading
    :return: ``np.ndarray`` copy of the floating point data of the image
    """
    voxels = nib.load(str(path)).get_fdata()
    return np.array(voxels)

In [32]:
# Relative path of the NIfTI file inspected in the following cells.
path = 'data/1/1OT.nii'

In [33]:
# Show the dimensions of the loaded volume.
read_scan(path).shape

(230, 230, 154)

In [38]:
# Keep only index 101 of the last axis; slicing with 101:102 keeps that axis (length 1).
img = read_scan(path)[:,:,101:102]

In [39]:
# List the distinct values present in the selected slice.
np.unique(img)

array([0., 1.])

In [None]:
#%%
#imports
import nibabel as nib
import numpy as np
from pathlib import Path
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

from IPython.display import clear_output
import matplotlib.pyplot as plt
import tensorlayer as tl
import pixtopix
#%%
# Set all the variables and the hyper-parameters used by the data pipeline below.
# Both directories are resolved relative to the parent of the current working directory.
# data_dir is globbed by get_paths for the .nii scans and masks.
data_dir = Path.cwd().parent / 'data'
# formatted_data_dir holds the TFRecord files, in one sub-directory per label.
formatted_data_dir = Path.cwd().parent / 'formatted_data'
# Spatial size every sample is resized / cropped to.
IMG_WIDTH = 96
IMG_HEIGHT = 96
# Number of leading channels that belong to the input image when a stacked sample is split.
NUM_FRAMES = 7
# Number of image channels in a generated sample (the generator output has MRI_SCAN_TYPES + 1 channels).
MRI_SCAN_TYPES = 7
BATCH_SIZE = 64
# Shuffle buffer size used for the tf.data pipelines.
BUFFER_SIZE = 1000
# NOTE(review): presumably 71 slices for each of the 30 scans iterated by data_generator - confirm.
TOTAL_LENGTH = 71*30
# Initial 90% / 1% / 9% split sizes; get_dataset_handle overwrites these globals.
TRAIN_LENGTH = int(0.90* TOTAL_LENGTH)
VAL_LENGTH = int(0.01 * TOTAL_LENGTH)
TEST_LENGTH = int(0.09* TOTAL_LENGTH)
# Computed once from the initial TRAIN_LENGTH; it is not refreshed when TRAIN_LENGTH changes later.
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
# Label selector: 'c', 'p' or 'm' (presumably core / penumbra / merged - confirm).
# It is also used as the first letter of the mask directory in get_paths.
LABEL = 'c'
VAL_BATCH_SIZE = 64

#Data loading, processing

def read_ct_scan(path):
    """Load the file at ``path`` with nibabel and return its data as a tensor.

    :param path: location of the scan; converted to ``str`` before loading
    :return: tensor built from the floating point data of the image
    """
    voxels = nib.load(str(path)).get_fdata()
    return tf.convert_to_tensor(voxels)

def load_preprocess(paths):
    '''
    Read every file in ``paths``, stack the volumes along a new last axis,
    move axis 2 to the front, resize to IMG_HEIGHT x IMG_WIDTH and split the
    result along the first axis.
    :param paths: iterable of scan / mask file paths
    :return: list of tensors, one per entry of the first axis
    '''
    volumes = [read_ct_scan(scan_path) for scan_path in paths]
    stacked = tf.stack(volumes, axis=-1, name='img_stack')
    # Bring the third axis to the front so that it can be unstacked below;
    # the stacked scan / mask axis stays last.
    frames_first = tf.transpose(stacked, perm=[2, 0, 1, 3])
    resized = resize(frames_first, IMG_HEIGHT, IMG_WIDTH)
    return tf.unstack(resized, axis=0)

#use it to get all the paths for all the images, masks
def get_paths(scan):
    """Collect the sorted scan files followed by the sorted mask files of one scan.

    Mask directories are matched by the current value of the global ``LABEL``.

    :param scan: identifier of the scan directory below ``data_dir``
    :return: list of paths, scans first and masks after them
    """
    scan_dir = str(scan)
    scan_files = sorted(data_dir.glob(scan_dir + '/V*/*.nii'))
    mask_files = sorted(data_dir.glob(scan_dir + '/[' + LABEL + ']*/V*/*.nii'))
    return scan_files + mask_files


def data_generator():
    '''
    Yield every preprocessed sample of the scans numbered 1 to 30.
    The file lists of all scans are collected before the first sample is produced.
    :return: generator of samples
    '''
    paths_db = [get_paths(scan_id) for scan_id in np.arange(1, 31)]
    for scan_paths in paths_db:
        yield from load_preprocess(scan_paths)


def one_hot_blob_merged(mask_image):
    """One-hot encode a single-channel mask into 3 classes and split it 1 + 2.

    :param mask_image: mask whose last axis has length 1
    :return: list of two tensors: the first one-hot channel, and the remaining two channels
    """
    class_ids = tf.cast(tf.squeeze(mask_image, axis=-1), dtype=tf.int64)
    encoded = tf.one_hot(class_ids, depth=3, axis=-1)
    return tf.split(encoded, [1, 2], axis=-1)

# image preprocessing functions for augmentation
def normalize(input_image):
    """Cast ``input_image`` to float32, divide by 128 and subtract 1.

    Only the image is handled here; masks are not passed through this function.

    :param input_image: image tensor
    :return: float32 tensor ``input_image / 128.0 - 1``
    """
    scaled = tf.cast(input_image, tf.float32) / 128.0
    return scaled - 1

def resize(input_image_stack, height, width):
    """Resize ``input_image_stack`` to ``height`` x ``width`` with nearest-neighbour sampling.

    :param input_image_stack: image tensor accepted by ``tf.image.resize``
    :param height: target height
    :param width: target width
    :return: resized tensor
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    return tf.image.resize(input_image_stack, target_size, method=nearest)


def random_crop(stacked_image):
    """Randomly crop ``stacked_image`` to IMG_HEIGHT x IMG_WIDTH, keeping all channels.

    :param stacked_image: tensor whose last axis is kept at its full length
    :return: cropped tensor
    """
    channel_count = tf.shape(stacked_image)[-1]
    crop_size = [IMG_HEIGHT, IMG_WIDTH, channel_count]
    return tf.image.random_crop(stacked_image, size=crop_size)

@tf.function
def load_image_train(real_image, mask_image):
    """Augment one training sample and normalise its image channels.

    Image and mask are concatenated along the last axis so that the resize,
    the random crop and the mirroring are applied identically to both, and
    are split apart again afterwards.

    :param real_image: image tensor whose last axis holds NUM_FRAMES channels
    :param mask_image: mask tensor that can be concatenated to ``real_image`` on the last axis
    :return: tuple ``(real_image, mask_image)``; only the image goes through ``normalize``
    """
    # do all this on a stack size of scan_types + mask_count
    datapoint = tf.concat([real_image, mask_image], axis=-1)

    # upscale to 108x108 (112.5% of the 96x96 target) so the crop below has room to move
    datapoint = resize(datapoint, 108, 108)

    # randomly crop it back to the desired IMG_HEIGHT x IMG_WIDTH size
    datapoint = random_crop(datapoint)

    # split the stack into input and mask
    frames, mask = tf.split(datapoint, [NUM_FRAMES, tf.shape(datapoint)[-1] - NUM_FRAMES], axis=-1)

    # Random mirroring. The decision has to be drawn with a TensorFlow op:
    # a NumPy draw is executed only once while the function is traced, which
    # would freeze the same flip / no-flip choice for every sample.
    should_flip = tf.random.uniform(()) > 0.5
    frames, mask = tf.cond(
        should_flip,
        lambda: (tf.image.flip_left_right(frames), tf.image.flip_left_right(mask)),
        lambda: (frames, mask))

    return normalize(frames), mask


def load_image_test(real_image, mask_image):
    """Resize one evaluation sample and normalise its image channels.

    No random augmentation is applied.

    :param real_image: image tensor whose last axis holds NUM_FRAMES channels
    :param mask_image: mask tensor that can be concatenated to ``real_image`` on the last axis
    :return: tuple ``(real_image, mask_image)``; only the image goes through ``normalize``
    """
    # Resize image and mask together as one stack.
    stacked = resize(tf.concat([real_image, mask_image], axis=-1), IMG_HEIGHT, IMG_WIDTH)

    # Everything after the first NUM_FRAMES channels is the mask.
    mask_channels = tf.shape(stacked)[-1] - NUM_FRAMES
    frames, mask = tf.split(stacked, [NUM_FRAMES, mask_channels], axis=-1)

    return normalize(frames), mask

def is_all_zero(tensor):
    """Tell whether the mask channel of a stacked sample contains only zeros.

    The previous implementation called ``normalize`` with two arguments and
    unpacked two results, although ``normalize`` takes and returns a single
    tensor, and then split the one-channel mask into two channels; it could
    not run. The mask is now inspected directly.

    :param tensor: stacked sample whose last axis holds NUM_FRAMES image
        channels followed by one mask channel
    :return: scalar boolean tensor, True when every mask value equals 0
    """
    _, mask_image = tf.split(tensor, num_or_size_splits=[NUM_FRAMES, 1], axis=-1)
    return tf.reduce_all(tf.equal(mask_image, 0))

def allow_mask_area(image, mask):
    """Filter predicate: True when at least one mask value equals 1.0.

    :param image: unused, present so the function matches ``(image, mask)`` dataset elements
    :param mask: mask tensor
    :return: scalar boolean tensor
    """
    ones = tf.cast(tf.math.equal(mask, 1.0), dtype=tf.float32)
    covered_fraction = tf.reduce_sum(ones) / (96.0 * 96.0)
    return tf.math.greater(covered_fraction, 0.0)

def barricade_mask_area(image, mask):
    """Filter predicate: True when no mask value equals 1.0.

    :param image: unused, present so the function matches ``(image, mask)`` dataset elements
    :param mask: mask tensor
    :return: scalar boolean tensor
    """
    ones = tf.cast(tf.math.equal(mask, 1.0), dtype=tf.float32)
    covered_fraction = tf.reduce_sum(ones) / (96.0 * 96.0)
    return tf.math.equal(covered_fraction, 0.0)


def just_load_images(datapoint):
    """Split a stacked sample into image and mask, encoding the mask per ``LABEL``.

    For labels 'c' and 'p' the mask is returned unchanged; for 'm' it is
    one-hot encoded and the last two channels are kept. For any other label
    the returned mask is ``None``.

    :param datapoint: stacked sample with NUM_FRAMES image channels followed by one mask channel
    :return: tuple ``(real_image, norm_mask)``
    """
    real_image, mask_image = tf.split(datapoint, [NUM_FRAMES, 1], axis=-1)

    if LABEL == 'm':
        _, norm_mask = one_hot_blob_merged(mask_image)
    elif LABEL in ('c', 'p'):
        # no need to transform, as the image is full of zeros and ones
        norm_mask = mask_image
    else:
        norm_mask = None

    return real_image, norm_mask


def get_dataset_handle(label):
    """Build disjoint train / validation / test datasets from the raw scans.

    Samples are generated from the .nii files, shuffled once, split into
    image and mask, and separated into samples whose mask contains a 1.0 and
    samples whose mask does not. Each split takes 90% / 1% / 9% of the
    samples with a non-empty mask plus half as many samples with an empty
    mask.

    Side effects: sets the global ``LABEL`` to ``label`` and overwrites the
    globals ``TRAIN_LENGTH``, ``VAL_LENGTH`` and ``TEST_LENGTH`` with the
    sizes of the returned splits.

    :param label: label selector stored in ``LABEL`` ('c', 'p' or 'm')
    :return: tuple ``(train, val, test)`` of datasets yielding ``(image, mask)``
    """
    global TRAIN_LENGTH, VAL_LENGTH, TEST_LENGTH, LABEL
    LABEL = label
    ds = tf.data.Dataset.from_generator(data_generator, output_shapes=(96, 96, MRI_SCAN_TYPES + 1),
                                        output_types=tf.float32)

    # The order must be identical every time the dataset is iterated,
    # otherwise take()/skip() below would select overlapping samples.
    shuffled_ds = ds.shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=False)
    shuffled_ds = shuffled_ds.map(just_load_images)

    with_mask = shuffled_ds.filter(allow_mask_area)
    without_mask = shuffled_ds.filter(barricade_mask_area)

    # Count the elements of both filtered sets.
    count_with_mask = with_mask.map(lambda image, mask: 1.0).reduce(
        np.float32(0), lambda x, y: x + y)
    count_without_mask = without_mask.map(lambda image, mask: 1.0).reduce(
        np.float32(0), lambda x, y: x + y)

    print('filtered set with non-zero mask', count_with_mask.numpy())
    print('filtered set with zero mask', count_without_mask.numpy())

    total_with_mask = float(count_with_mask.numpy())
    n_train = int(0.9 * total_with_mask)
    n_val = int(0.01 * total_with_mask)
    n_test = int(0.09 * total_with_mask)

    # Every split gets half as many empty-mask samples as non-empty ones.
    n_train_empty = int(0.5 * n_train)
    n_val_empty = int(0.5 * n_val)
    n_test_empty = int(0.5 * n_test)

    # Datasets are immutable: skip() returns a new dataset, so its result has
    # to be used. Discarding it made val and test start at the first element,
    # i.e. they were subsets of train.
    train = with_mask.take(n_train) \
        .concatenate(without_mask.take(n_train_empty))

    val = with_mask.skip(n_train).take(n_val) \
        .concatenate(without_mask.skip(n_train_empty).take(n_val_empty))

    test = with_mask.skip(n_train + n_val).take(n_test) \
        .concatenate(without_mask.skip(n_train_empty + n_val_empty).take(n_test_empty))

    TRAIN_LENGTH = n_train + n_train_empty
    VAL_LENGTH = n_val + n_val_empty
    TEST_LENGTH = n_test + n_test_empty

    return train, val, test


def parse(sample):
    """Decode the serialized 'image' and 'mask' tensors of one record.

    The image is reshaped to 96x96x7. The mask is reshaped to 96x96x1 for the
    labels 'c' and 'p' and to 96x96x2 for 'm'; for any other ``LABEL`` it is
    left as parsed.

    :param sample: mapping with the serialized tensors under 'image' and 'mask'
    :return: tuple ``(image, mask)`` of float32 tensors
    """
    image = tf.reshape(tf.io.parse_tensor(sample['image'], out_type=tf.float32), [96, 96, 7])

    mask = tf.io.parse_tensor(sample['mask'], out_type=tf.float32)
    mask_channels = {'c': 1, 'p': 1, 'm': 2}.get(LABEL)
    if mask_channels is not None:
        mask = tf.reshape(mask, [96, 96, mask_channels])

    return (image, mask)

def read_from_tensor_dataset(label):
    """Return batched train / validation / test datasets for ``label``.

    The datasets are read from the TFRecord files below
    ``formatted_data_dir / label``. When ``train.tfrecords`` does not exist
    yet, the records are first generated from the raw scans and written to
    disk.

    Side effect: sets the global ``LABEL`` to ``label``.

    :param label: label selector ('c', 'p' or 'm')
    :return: tuple ``(train_dataset, val_dataset, test_dataset)``; the training
        dataset is augmented, shuffled, repeated indefinitely and prefetched
    """
    global LABEL
    LABEL = label

    record_dir = formatted_data_dir / LABEL
    if not (record_dir / 'train.tfrecords').is_file():
        print('No existing records for label \'{}\''.format(label))
        # get_dataset_handle has a required ``label`` parameter; calling it
        # without arguments raised a TypeError on the very first run.
        train, val, test = get_dataset_handle(label)
        write_to_tensor_dataset(train, val, test)

    print('Reading off the TFRecord for label \'{}\''.format(label))
    train = tf.data.TFRecordDataset(str(record_dir / 'train.tfrecords'))
    val = tf.data.TFRecordDataset(str(record_dir / 'val.tfrecords'))
    test = tf.data.TFRecordDataset(str(record_dir / 'test.tfrecords'))

    # Create a description of the features.
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'mask': tf.io.FixedLenFeature([], tf.string, default_value=''),
    }

    def _parse_record(example_proto):
        return tf.io.parse_single_example(example_proto, feature_description)

    train_dataset = train.map(_parse_record).map(parse)
    val_dataset = val.map(_parse_record).map(parse)
    test_dataset = test.map(_parse_record).map(parse)

    # Cache the decoded records BEFORE the random augmentation. Caching after
    # it stores the first epoch's crops / flips and replays exactly those for
    # every following epoch.
    train_dataset = train_dataset.cache().map(load_image_train) \
        .shuffle(BUFFER_SIZE).repeat().batch(BATCH_SIZE)
    train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    test_dataset = test_dataset.map(load_image_test).batch(BATCH_SIZE)
    val_dataset = val_dataset.map(load_image_test).batch(BATCH_SIZE)

    return train_dataset, val_dataset, test_dataset

def _float_feature(value):
    """Wrap one float / double in a ``tf.train.Feature`` holding a float_list."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def serialize_example(img, mask):
    """Serialize one ``(img, mask)`` pair into a ``tf.train.Example`` byte string.

    Both tensors are serialized with ``tf.io.serialize_tensor`` and stored as
    bytes features under the keys 'image' and 'mask'.

    :param img: image tensor
    :param mask: mask tensor
    :return: the serialized example
    """
    image_feature = _bytes_feature(tf.io.serialize_tensor(img))
    mask_feature = _bytes_feature(tf.io.serialize_tensor(mask))
    features = tf.train.Features(feature={'image': image_feature, 'mask': mask_feature})
    return tf.train.Example(features=features).SerializeToString()

def _bytes_feature(value):
    """Wrap a string / byte value in a ``tf.train.Feature`` holding a bytes_list.

    A value of the same type as ``tf.constant(0)`` is converted with
    ``.numpy()`` first.
    """
    tensor_type = type(tf.constant(0))
    if isinstance(value, tensor_type):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def write_to_tensor_dataset(train, val, test):
    """Write the three splits as TFRecord files below ``formatted_data_dir / LABEL``.

    The files are named ``train.tfrecords``, ``val.tfrecords`` and
    ``test.tfrecords``; every ``(img, mask)`` element is stored as one
    serialized example.

    :param train: iterable of ``(img, mask)`` pairs for the training split
    :param val: iterable of ``(img, mask)`` pairs for the validation split
    :param test: iterable of ``(img, mask)`` pairs for the test split
    """
    output_dir = formatted_data_dir / LABEL
    # This function runs when no records exist yet, so the label directory
    # may be missing as well; the record writer does not create it.
    output_dir.mkdir(parents=True, exist_ok=True)

    for split_name, split in (('train', train), ('val', val), ('test', test)):
        record_path = output_dir / (split_name + '.tfrecords')
        with tf.io.TFRecordWriter(str(record_path)) as writer:
            for img, mask in split:
                writer.write(serialize_example(img, mask))


# Entry point: build (or read back) the datasets for label 'm'.
# The three returned datasets are discarded; the call is made for its side effects and as a check.
if __name__=='__main__':
    _, _, _ = read_from_tensor_dataset('m')

