In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import random
import time 
import PIL
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image

from tensorflow.keras.callbacks import CSVLogger

%matplotlib inline

In [None]:
def test_device():
    """Print the devices TensorFlow reports and configure the first GPU.

    Sets ``TF_CPP_MIN_LOG_LEVEL`` to ``'2'`` and ``CUDA_VISIBLE_DEVICES`` to
    ``"0"`` in the process environment, prints the physical GPUs and CPUs, and,
    when at least one GPU is reported, makes only the first one visible and
    enables memory growth on it. A ``RuntimeError`` raised while configuring
    the GPU is printed instead of being propagated.
    """
    # NOTE(review): these variables are set after `import tensorflow` has
    # already run in the first cell -- confirm they still take effect there.
    os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '2', "CUDA_VISIBLE_DEVICES": "0"})

    tf_config = tf.config.experimental
    cpus = tf_config.list_physical_devices(device_type='CPU')
    gpus = tf_config.list_physical_devices(device_type='GPU')

    # GPUs are reported first, then CPUs, one device per line.
    for heading, devices in (('GPU Info', gpus), ('CPU Info', cpus)):
        print(heading)
        for device in devices:
            print('   ', device)

    print()
    if not gpus:
        return

    first_gpu = gpus[0]
    try:
        tf_config.set_visible_devices(first_gpu, 'GPU')
        tf_config.set_memory_growth(first_gpu, True)
        logical_gpus = tf_config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as err:
        print(err)
        
# Report the available devices and configure the GPU before anything else is built.
test_device()
# Reset Keras' global state so the following cells start from a clean session.
tf.keras.backend.clear_session() 
# Sentinel passed as the prefetch buffer size of the training pipeline below.
AUTOTUNE = tf.data.experimental.AUTOTUNE 

In [None]:
# Directories whose immediate sub-directories are used as class labels.
data_root_orig = './data/sep_data/train'
val_data_root_orig = './data/sep_data/validation'
# Images are resized to img_size x img_size; also the model's spatial input size.
img_size = 224
# Number of images per batch in both the training and validation pipelines.
batch_size = 32
# NOTE(review): not read anywhere in the visible cells -- the fit call passes epochs=30.
epoch = 50
# Channel dimension of the MobileNetV2 input shape.
channel_size = 1
# NOTE(review): not read anywhere in the visible cells.
ver = 5

In [None]:
# Class labels are the names of the sub-directories of the training root,
# sorted so that the name -> index mapping is stable between runs.
data_root = pathlib.Path(data_root_orig)
print(data_root, list(data_root.iterdir()))
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
label_to_index = dict((name, index) for index, name in enumerate(label_names))
print(label_to_index)

# Same discovery for the validation root. The mapping is built from the
# validation directory's own names (it previously re-used `label_names`, so
# the printed mapping could never reveal a mismatch between the two splits).
val_data_root = pathlib.Path(val_data_root_orig)
print(val_data_root, list(val_data_root.iterdir()))
val_label_names = sorted(item.name for item in val_data_root.glob('*/') if item.is_dir())
val_label_to_index = dict((name, index) for index, name in enumerate(val_label_names))
print(val_label_to_index)

# Validation labels are encoded with `label_to_index` in the next cell, so a
# differing class list means missing classes or a KeyError there -- surface
# it here, where the cause is obvious.
if val_label_names != label_names:
    print("WARNING: validation classes {} differ from training classes {}".format(
        val_label_names, label_names))

In [None]:
# Training split: every entry two levels below the root, in random order.
all_image_paths = [str(path) for path in data_root.glob('*/*')]
random.shuffle(all_image_paths)
image_count = len(all_image_paths)
# The label of an image is the index of its parent directory's name.
all_image_labels = [
    label_to_index[pathlib.Path(image_path).parent.name]
    for image_path in all_image_paths
]

# Validation split: same procedure, encoded with the training label mapping.
val_all_image_paths = [str(path) for path in val_data_root.glob('*/*')]
val_image_count = len(val_all_image_paths)
random.shuffle(val_all_image_paths)
val_all_image_labels = [
    label_to_index[pathlib.Path(image_path).parent.name]
    for image_path in val_all_image_paths
]

print("Train image count      : {}".format(image_count))
print("Validation image count : {}".format(val_image_count))

In [None]:
def preprocess_image(image):
    """Decode encoded JPEG bytes into a resized, normalised image tensor.

    Args:
        image: Encoded JPEG file contents, as produced by ``tf.io.read_file``.

    Returns:
        The decoded image resized to ``img_size`` x ``img_size`` with
        ``channel_size`` channels and pixel values divided by 255.
    """
    # Decode with the configured channel count so the tensor always matches
    # the model input shape, which is built from the same `channel_size`.
    image = tf.image.decode_jpeg(image, channels=channel_size)
    image = tf.image.resize(image, [img_size, img_size])
    image /= 255.0
    return image

def load_and_preprocess_image(path):
    """Read the file at ``path`` and return it run through ``preprocess_image``."""
    return preprocess_image(tf.io.read_file(path))

def load_and_preprocess_from_path_label(path, label):
    """Map a ``(path, label)`` pair to ``(preprocessed image, label)``.

    The label is passed through unchanged.
    """
    image = load_and_preprocess_image(path)
    return image, label

In [None]:
# Pair every image path with its integer label, then load and preprocess the
# images lazily. `num_parallel_calls` lets tf.data process several files
# concurrently instead of one at a time; `map` keeps the input order by default.
ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))
image_label_ds = ds.map(load_and_preprocess_from_path_label,
                        num_parallel_calls=AUTOTUNE)
val_ds = tf.data.Dataset.from_tensor_slices((val_all_image_paths, val_all_image_labels))
val_image_label_ds = val_ds.map(load_and_preprocess_from_path_label,
                                num_parallel_calls=AUTOTUNE)

# Number of whole batches per pass over each split (a trailing partial batch
# is not counted). Floor division avoids the float round-trip of int(a / b).
steps_per_epoch = image_count // batch_size
val_steps_per_epoch = val_image_count // batch_size

In [None]:
 # Default number of batches `timeit` measures: two training epochs plus one batch.
 # NOTE(review): this cell starts with a stray leading space; presumably the
 # notebook front end tolerates it, but it would be an "unexpected indent"
 # error in a plain .py script -- consider removing it.
 default_timeit_steps = steps_per_epoch * 2 + 1

def timeit(ds, steps=default_timeit_steps):
    """Measure how quickly batches can be pulled from ``ds`` and print the result.

    The first batch is fetched before the timer starts, so it is excluded from
    the throughput figure but included in the printed total time.

    Args:
        ds: Dataset yielding ``(images, labels)`` batches.
        steps: Maximum number of batches to time after the first one. A
            dataset that ends earlier is timed over the batches it actually
            produced.
    """
    overall_start = time.time()
    it = iter(ds.take(steps+1))
    next(it)

    start = time.time()
    # Count what was really consumed: a finite dataset can run out before
    # `steps` batches, and reporting `steps` would overstate the throughput.
    batches = 0
    for i, (images, labels) in enumerate(it):
        batches = i + 1
        if i%10 == 0:
            print('.',end='')

    end = time.time()

    duration = end-start
    print("{} batches: {} s".format(batches, duration))
    # The rate assumes full batches, so a trailing partial batch makes it a
    # slight overestimate. Skipped when no measurable time elapsed.
    if duration > 0:
        print("{:0.5f} Images/s".format(batch_size*batches/duration))
    print("Total time: {}s".format(end-overall_start))

In [None]:
# Training input pipeline: cache the decoded images, reshuffle the whole set,
# repeat indefinitely, then batch and prefetch.
ds = (
    image_label_ds
    .cache()
    .shuffle(buffer_size=image_count)
    .repeat()
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)
timeit(ds)

In [None]:
# Validation input pipeline: cached and batched, with no shuffle or repeat.
# `prefetch` comes after `batch`, so its buffer is counted in batches; the
# previous size of `val_image_count` asked for far more batches than the
# dataset contains. AUTOTUNE matches the training pipeline.
val_ds = val_image_label_ds.cache()
val_ds = val_ds.batch(batch_size).prefetch(buffer_size=AUTOTUNE)
# The dataset is finite, so it ends before the requested number of steps and
# this call iterates the validation set exactly once.
timeit(val_ds, 2*val_steps_per_epoch+1)

In [None]:
def change_range(image, label):
    """Rescale ``image`` as ``2*image - 1`` and one-hot encode ``label`` (depth 6).

    With the [0, 1] pixels produced by the preprocessing step this yields
    values in [-1, 1].
    """
    rescaled = 2*image-1
    one_hot_label = tf.one_hot(label, depth=6)
    return rescaled, one_hot_label

# Datasets in the form the Keras models below are trained and validated on.
keras_ds = ds.map(change_range)
keras_val_ds = val_ds.map(change_range)

In [None]:
# Start from a clean Keras session, then build a MobileNetV2 backbone with
# randomly initialised weights (weights=None) and no classification head.
tf.keras.backend.clear_session()
mobile_net = tf.keras.applications.MobileNetV2(
    input_shape=(img_size, img_size, channel_size),
    include_top=False,
    weights=None,
)
# The whole backbone is trained.
mobile_net.trainable = True

In [None]:
# Classifier: backbone -> global average pooling -> 6 output logits.
model_mobile = tf.keras.Sequential()
model_mobile.add(mobile_net)
model_mobile.add(tf.keras.layers.GlobalAveragePooling2D())
model_mobile.add(tf.keras.layers.Dense(6))

In [None]:
# Two-branch variant: the input is passed through `mobile_net` twice and the
# two feature maps are averaged before pooling and the 6-logit head.
# NOTE(review): both branches call the same `mobile_net` object on the same
# input, so they presumably share weights and produce the same features --
# confirm whether two independent backbones were intended.
input1 = tf.keras.layers.Input(shape=(img_size, img_size, 1))
mergedOut1 = tf.keras.layers.average(
    [mobile_net(input1), mobile_net(input1)]
)
model1 = tf.keras.Model(inputs=input1, outputs=mergedOut1)

model_2_mobile = tf.keras.Sequential()
model_2_mobile.add(model1)
model_2_mobile.add(tf.keras.layers.GlobalAveragePooling2D())
model_2_mobile.add(tf.keras.layers.Dense(6))

In [None]:
# Checkpoint location; the name encodes image size, batch size and two fixed tags.
path = './model/checkpoints_{}_{}_{}_{}'.format(img_size, batch_size, 888, 0)

# Save the full model (not just weights) at the end of an epoch, but only
# when the validation loss improves on the best value seen so far.
cp = tf.keras.callbacks.ModelCheckpoint(
    filepath=path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
    save_freq='epoch',
    verbose=0,
)

In [None]:
# Model selected for this run.
model_ = model_mobile

# The model outputs raw logits and the labels are one-hot encoded, hence
# categorical cross-entropy with from_logits=True.
model_.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# NOTE(review): epochs is fixed at 30 here while the config cell defines
# `epoch = 50` -- confirm which value is intended.
history = model_.fit(
    keras_ds,
    epochs=30,
    steps_per_epoch=steps_per_epoch,
    validation_data=keras_val_ds,
    validation_steps=val_steps_per_epoch,
    callbacks=[cp],
)

In [None]:
# Append one line per run: checkpoint name, best validation loss, and the
# validation accuracy recorded in the same epoch as that loss.
log_ = 'log.log'
with open(log_, 'a+') as log:
    model_name = path.rsplit('/', 1)[-1]
    val_losses = history.history['val_loss']
    valloss = min(val_losses)
    # First epoch at which the minimum loss occurred.
    idx = val_losses.index(valloss)
    valacc = history.history['val_accuracy'][idx]
    log.write("{}, {}, {}\n".format(model_name, valloss, valacc))