In [19]:
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import pathlib
from time import time
from tensorflow_docs import modeling
import tensorflow_io as tfio

In [20]:
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#   try:
#     # Currently, memory growth needs to be the same across GPUs
#     for gpu in gpus:
#       tf.config.experimental.set_memory_growth(gpu, True)
#     logical_gpus = tf.config.experimental.list_logical_devices('GPU')
#     print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
#   except RuntimeError as e:
#     # Memory growth must be set before GPUs have been initialized
#     print(e)

In [21]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
print("TensorFlow Version: ", tf.__version__)
print("Number of GPU available: ", len(tf.config.experimental.list_physical_devices("GPU")))

TensorFlow Version:  2.4.1
Number of GPU available:  1


In [22]:
IMG_HEIGHT = 36
IMG_WIDTH = 36
BATCH_SIZE = 64
val_fraction = 15
max_epochs = 100
augment_degree = 0.10
shuffle_buffer_size = 2500

In [23]:
# file_path = r'\\babyserverdw3\Pei-Hsun Wu\Collaboration\Amit Agarwal\210120 testing stitching algorith on mouse brain fluorescence image\for training\set4\Ast\trainim00001.tif'
# img = tf.io.read_file(file_path)
# img = tfio.experimental.image.decode_tiff(img)
# img.shape

In [24]:
# img2 = tfio.experimental.color.rgba_to_rgb(img)
# img2.shape

In [25]:
# channels = tf.unstack (img,num=4, axis=-1)

In [26]:
# image    = tf.stack   ([channels[0], channels[1], channels[2]], axis=-1)

In [27]:
# images = os.path.join(*[train_data_dir,'Ast','*.tif'])
# list_ds = tf.data.Dataset.list_files(images)

In [28]:
# for file_path in list_ds.take(2):
#     img = tf.io.read_file(file_path)
#     img = tfio.experimental.image.decode_tiff(img)
#     img = tfio.experimental.color.rgba_to_rgb(img)

In [29]:
# labeled_ds = list_ds.map(read_and_label, num_parallel_calls=AUTOTUNE)

In [30]:
# file_path = r'\\babyserverdw3\Pei-Hsun Wu\Collaboration\Amit Agarwal\210120 testing stitching algorith on mouse brain fluorescence image\for training\set4\Ast\trainim00001.tif'
# parts = tf.strings.split(file_path, os.path.sep)

In [31]:
# parts[-2]

In [32]:
# one_hot = parts[-2] == CLASS_NAMES
# tf.argmax(one_hot)

# ans = tf.reshape(tf.where(parts[-2] == CLASS_NAMES), [])
# ans

In [33]:
def read_and_label(file_path):
    label = get_label(file_path)
    img = tf.io.read_file(file_path)
    # img = tf.image.decode_jpeg(img, channels=3)
    img = tfio.experimental.image.decode_tiff(img)
    # img = tfio.experimental.color.rgba_to_rgb(img)
    channels = tf.unstack (img, num=4, axis=-1)
    img  = tf.stack   ([channels[0], channels[1], channels[2]], axis=-1)
    # img = tf.image.random_hue(img, max_delta=augment_degree, seed=5)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, [IMG_WIDTH, IMG_HEIGHT])
    return img, label


def get_label(file_path):
    parts = tf.strings.split(file_path, os.path.sep)
    one_hot = parts[-2] == CLASS_NAMES
    return tf.argmax(one_hot)

    # return tf.reshape(tf.where(parts[-2] == CLASS_NAMES), [])

# def augment(image, label):
#     degree = augment_degree
#     if degree == 0:
#         return image, label
#     image = tf.image.random_hue(image, max_delta=degree, seed=5)
#     image = tf.image.random_contrast(image, 1-degree, 1+degree, seed=5)  # tissue quality
#     image = tf.image.random_saturation(image, 1-degree, 1+degree, seed=5)  # stain quality
#     image = tf.image.random_brightness(image, max_delta=degree)  # tissue thickness, glass transparency (clean)
#     image = tf.image.random_flip_left_right(image, seed=5)  # cell orientation
#     image = tf.image.random_flip_up_down(image, seed=5)  # cell orientation
#     image = tf.image.random_crop(image, [96,96,3])
#     return image, label

def prepare(data_dir):
    tmp = [0]
    for idx,CLASS in enumerate(CLASS_NAMES):
        images = os.path.join(*[data_dir,CLASS,'*.tif'])
        list_ds = tf.data.Dataset.list_files(images)
        labeled_ds = (list_ds
                      .map(read_and_label, num_parallel_calls=AUTOTUNE)
                      )
        if tmp[0] == 0:
            tmp[0] = labeled_ds
        else:
            labeled_ds = tmp[0].concatenate(labeled_ds)
            tmp[0] = labeled_ds
    return tmp[0].shuffle(shuffle_buffer_size)

In [34]:
train_data_dir = r'\\babyserverdw3\Pei-Hsun Wu\Collaboration\Amit Agarwal\210120 testing stitching algorith on mouse brain fluorescence image\for training\set14_scale9to36_cortex_puff'

train_data_dir = pathlib.Path(train_data_dir)
CLASS_NAMES = np.array([item.name for item in train_data_dir.glob('*') if not item.name.endswith((".mat",".DS_store",".png"))])
CLASS_NAMES = sorted(CLASS_NAMES, key=str.lower) #sort alphabetically case-insensitive
CLASS_NAMES

['Ast', 'Neu', 'Oth']

In [35]:
train_labeled_ds = prepare(train_data_dir)

In [36]:
# train_image_count = len(list(train_labeled_ds))
train_image_count = tf.data.experimental.cardinality(train_labeled_ds).numpy()
print('training set size : ', train_image_count)
val_image_count = train_image_count // 100 * val_fraction
print('validation size: ', val_image_count)
train_image_count2 = train_image_count-val_image_count
print('training set size after split : ', train_image_count2)

training set size :  18000
validation size:  2700
training set size after split :  15300


In [37]:
STEPS_PER_EPOCH = train_image_count2 // BATCH_SIZE
VALIDATION_STEPS = val_image_count // BATCH_SIZE
print('train step #',STEPS_PER_EPOCH)
print('validation step #',VALIDATION_STEPS)

train step # 239
validation step # 42


In [None]:
plt.figure(figsize=(8,8))
for idx, elem in enumerate(train_labeled_ds.take(64)):
    img = elem[0]
    label = elem[1]
    ax = plt.subplot(8,8,idx+1)
    plt.imshow(img)
    plt.title(CLASS_NAMES[label].title())
    plt.axis('off')
target = 'logs'
if not os.path.exists(target): os.mkdir(target)
figname = 'augmented_dataset_{}.png'.format(str(round(augment_degree*100)))
plt.savefig(os.path.join(target,figname))

In [None]:
normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)
train_labeled_ds = train_labeled_ds.map(lambda x, y: (normalization_layer(x), y))
train_ds = (train_labeled_ds
            .skip(val_image_count)
            .shuffle(buffer_size=shuffle_buffer_size)
            # .repeat()
            .batch(BATCH_SIZE)
            .prefetch(buffer_size=AUTOTUNE)
            )
val_ds = (train_labeled_ds
          .take(val_image_count)
          # .repeat()
          .batch(BATCH_SIZE)
          .prefetch(buffer_size=AUTOTUNE))

In [None]:
log_dir = pathlib.Path('logs')

def get_callbacks(name):
    return [
        modeling.EpochDots(),
        tf.keras.callbacks.EarlyStopping(monitor='val_categorical_crossentropy',
                                         patience=50, restore_best_weights=True),
        tf.keras.callbacks.TensorBoard(log_dir/name, histogram_freq=1),
        tf.keras.callbacks.ModelCheckpoint(filepath=log_dir/name/"{}/cp.ckpt".format(name),
                                           verbose=0,
                                           monitor='val_sparse_categorical_crossentropy',
                                           save_weights_only=True,
                                           save_best_only=True),
        tf.keras.callbacks.ReduceLROnPlateau(monitor='val_categorical_crossentropy',
                                             factor=0.1, patience=10, verbose=0, mode='auto',
                                             min_delta=0.0001, cooldown=0, min_lr=0),
    ]

In [None]:
def compilefit(model, name, max_epochs, train_ds, val_ds):
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.losses.CategoricalCrossentropy(from_logits=True), 'accuracy'])
    model_history = model.fit(train_ds,
                              steps_per_epoch=STEPS_PER_EPOCH,
                              epochs=max_epochs,
                              verbose=1,
                              validation_data=val_ds,
                              callbacks=get_callbacks(name),
                              validation_steps=VALIDATION_STEPS,
                              use_multiprocessing=True
                              )
    # namename = os.path.dirname(name)
    # if not os.path.isdir(os.path.abspath(namename)):
    #     os.mkdir(os.path.abspath(namename))
    # if not os.path.isdir(os.path.abspath(name)):
    #     os.mkdir(os.path.abspath(name))
    # if not os.path.isfile(pathlib.Path(name) / 'full_model.h5'):
    #     try:
    #         model.save(pathlib.Path(name) / 'full_model.h5')
    #     except:
    #         print('model not saved?')
    return model_history

In [None]:
num_classes = 3
from tensorflow.keras import layers
model = tf.keras.Sequential([
  layers.experimental.preprocessing.Rescaling(1./255),
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(3)
])
model.compile(
  optimizer='adam',
  loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
  metrics=['accuracy'])
model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=3
)

In [None]:
model.evaluate(val_ds)

In [None]:
model.evaluate(train_ds)

In [None]:
start = time()
# #min input size 76x76
MobileNetV2_base = tf.keras.applications.MobileNetV2(input_shape=(96, 96, 3),
                                            pooling=None,
                                            include_top=False,
                                            weights='imagenet'
                                            )

MobileNetV2 = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.Rescaling(1./255),
    MobileNetV2_base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(3, activation='softmax')
])

histories={}
histories['mobilenetv2_01'] = compilefit(MobileNetV2, 'mobilenetv2_01', max_epochs, train_ds, val_ds)

end = time()