In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import cv2
import os
# Set CUDA_VISIBLE_DEVICES to "1" for this process, i.e. expose only the GPU
# with index 1 (assumes a multi-GPU host — TODO confirm the index is valid).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Short alias for tf.data's AUTOTUNE value, passed as num_parallel_calls.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Turn off the TensorFlow Datasets progress bar.
tfds.disable_progress_bar()

## Input Pipeline

This tutorial trains a model to translate from images of horses to images of zebras. You can find this dataset and similar ones [here](https://www.tensorflow.org/datasets/datasets#cycle_gan).

As mentioned in the [paper](https://arxiv.org/abs/1703.10593), apply random jittering and mirroring to the training dataset. These are image augmentation techniques that help avoid overfitting.

This is similar to what was done in [pix2pix](https://www.tensorflow.org/tutorials/generative/pix2pix#load_the_dataset)

* In random jittering, the image is resized to `286 x 286` and then randomly cropped to `256 x 256`.
* In random mirroring, the image is randomly flipped horizontally, i.e., left to right.

In [2]:
# Number of elements held in the shuffle buffer of each tf.data pipeline.
BUFFER_SIZE = 1000
# Batch size setting for the notebook.
BATCH_SIZE = 1
# Patches are square: height and width (in pixels) share the same value.
IMAGE_HEIGHT = IMAGE_WIDTH = 1024

In [3]:
def resize(input_image, height, width):
  """Resize an image tensor to ``height`` x ``width`` with nearest-neighbor sampling.

  Args:
    input_image: Image tensor accepted by ``tf.image.resize``.
    height: Target height in pixels.
    width: Target width in pixels.

  Returns:
    The resized image tensor.
  """
  target_size = [height, width]
  return tf.image.resize(
      input_image,
      target_size,
      method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

In [12]:
def crop_image_into_patches(image):
    """Cut a batch of images into overlapping patches.

    Windows of ``IMAGE_HEIGHT`` x ``IMAGE_WIDTH`` pixels are taken with a
    vertical stride of 920 and a horizontal stride of 784, without padding.
    A 1944x2592x3 image is cut into six 1024x1024x3 patches this way.

    Args:
        image: Batched image tensor ``[batch, height, width, 3]``. The leading
            batch axis is required by ``tf.image.extract_patches``.

    Returns:
        Tensor of shape ``[num_patches, IMAGE_HEIGHT, IMAGE_WIDTH, 3]`` holding
        the patches of the whole batch stacked along the first axis.
    """
    window = [1, IMAGE_HEIGHT, IMAGE_WIDTH, 1]
    step = [1, 920, 784, 1]
    # extract_patches returns every patch flattened into the last axis ...
    flat_patches = tf.image.extract_patches(
        images=image,
        sizes=window,
        strides=step,
        rates=[1, 1, 1, 1],
        padding='VALID')
    # ... so restore the height/width/channel layout of each patch.
    return tf.reshape(flat_patches, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])


## Input Pipeline

In [5]:
# Root folder of the unstained images; the train/valid/test split folders are
# derived from it so that changing the root updates all three paths at once.
root_image_path = '../../TfResearch/research/object_detection/dataset_tools/assets'
train_root_image_path = root_image_path + '/images_train/'
valid_root_image_path = root_image_path + '/images_valid/'
test_root_image_path = root_image_path + '/images_test/'

# Root folder of the stained images (holds the masks_train / masks_valid folders).
root_mask_path = '../assets/all-patients-stained'

In [6]:
def load_image(image_file):
    """Read a PNG file and return it as a float32 image tensor.

    Args:
        image_file: Path of the PNG file to read.

    Returns:
        ``tf.float32`` tensor decoded with three channels. The decoded pixel
        values are only cast to float; no rescaling is applied.
    """
    raw_bytes = tf.io.read_file(image_file)
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    return tf.cast(decoded, tf.float32)

In [7]:
# File-name datasets for the stained PNGs of the train and validation splits.
stained_train_dataset = tf.data.Dataset.list_files(root_mask_path + '/masks_train/*.png')
stained_valid_dataset = tf.data.Dataset.list_files(root_mask_path + '/masks_valid/*.png')

# Merge both splits, decode each file, cache the decoded images, shuffle, and
# emit one image per batch (the batch axis is needed by the patch extraction).
stained_dataset = (
    stained_train_dataset
    .concatenate(stained_valid_dataset)
    .map(load_image, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(1)
)

In [8]:
# File-name datasets for the unstained PNGs of the train, validation and test splits.
unstained_train_dataset_names = tf.data.Dataset.list_files(train_root_image_path + '*.png')
unstained_valid_dataset_names = tf.data.Dataset.list_files(valid_root_image_path + '*.png')
unstained_test_dataset_names = tf.data.Dataset.list_files(test_root_image_path + '*.png')

# Merge the three splits, decode each file, cache the decoded images, shuffle,
# and emit one image per batch (the batch axis is needed by the patch extraction).
unstained_dataset = (
    unstained_train_dataset_names
    .concatenate(unstained_valid_dataset_names)
    .concatenate(unstained_test_dataset_names)
    .map(load_image, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(1)
)

### Whole Image into Patches

In [13]:
def _to_patch_dataset(image_dataset):
    """Map a dataset of single-image batches to a dataset of their patches.

    Each element is cropped with ``crop_image_into_patches``; the result is
    cached, shuffled and batched by one, so every output element carries an
    extra leading axis of size 1 in front of the patch axis.
    """
    # NOTE(review): cache() and shuffle() keep cropped patches in memory; if the
    # datasets are only iterated once for export they may be unnecessary — confirm.
    return (
        image_dataset
        .map(crop_image_into_patches, num_parallel_calls=AUTOTUNE)
        .cache()
        .shuffle(BUFFER_SIZE)
        .batch(1)
    )


cropped_stained_dataset = _to_patch_dataset(stained_dataset)
cropped_unstained_dataset = _to_patch_dataset(unstained_dataset)

In [16]:
def save_image(input_image, save_path):
    """Write an RGB image tensor to ``save_path`` using OpenCV.

    Args:
        input_image: Object exposing ``.numpy()`` (e.g. an eager tensor) that
            holds a three-channel image in RGB order.
        save_path: Destination file path; missing parent directories are created.

    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    # cv2.imwrite neither creates missing directories nor raises on failure (it
    # only returns False), so create the folder first and check the result.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # OpenCV stores colour images in BGR order, so swap the channels before writing.
    # NOTE(review): callers in this notebook pass float32 pixels; confirm the
    # installed OpenCV version writes them to PNG as intended.
    bgr_pixels = cv2.cvtColor(input_image.numpy(), cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(save_path, bgr_pixels):
        raise OSError(f'Failed to write image to {save_path!r}')
  

In [17]:
# Output folders for the exported patches, relative to the working directory.
generate_path_stained = 'generated/training_stained/'
generate_path_unstained = 'generated/training_unstained/'

In [18]:
# Write every unstained patch to disk as '<image counter>_<patch number>.png'.
# The counter follows iteration order of the (shuffled) dataset.
image_counter = 0
for image in cropped_unstained_dataset:
    image = image[0]  # drop the outer size-1 batch axis, leaving the patch axis first
    for patch_num, patch in enumerate(image):
        patch_file = generate_path_unstained + f'{image_counter}_{patch_num}.png'
        save_image(patch, patch_file)
    image_counter += 1

In [19]:
# Write every stained patch to disk as '<image counter>_<patch number>.png'.
# The counter follows iteration order of the (shuffled) dataset.
image_counter = 0
for image in cropped_stained_dataset:
    image = image[0]  # drop the outer size-1 batch axis, leaving the patch axis first
    for patch_num, patch in enumerate(image):
        patch_file = generate_path_stained + f'{image_counter}_{patch_num}.png'
        save_image(patch, patch_file)
    image_counter += 1