* 多输出模型示例
* 函数式API建立模型

In [30]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import random
import os
%matplotlib inline

In [25]:
# Location of the image dataset, relative to the notebook's working directory.
data_dir = './dataset/moc'

In [26]:
# Wrap the dataset directory in a Path object so it can be globbed later.
data_root = pathlib.Path(data_dir)
data_root  # bare expression: shows the Path as the cell output

PosixPath('dataset/moc')

In [27]:
# Collect every entry exactly two levels below the dataset root
# (one level of sub-directories, then whatever they contain).
all_img_path = [entry for entry in data_root.glob('*/*')]

In [28]:
# Convert the Path objects to plain strings.
all_img_path = list(map(str, all_img_path))

In [29]:
# Put the paths into a reproducible random order.
# - Sort first: glob() results come back in no guaranteed order, so the
#   starting point must be normalised before shuffling.
# - Shuffle with a private, fixed-seed generator: the train/test split is
#   taken positionally from this list, so an unseeded shuffle would produce a
#   different split on every run.
# Sorting + seeded shuffling also makes this cell idempotent: re-running it
# after the label lists have been built from `all_img_path` no longer
# scrambles the path/label pairing.
all_img_path.sort()
random.Random(42).shuffle(all_img_path)

In [31]:
# Class names are the names of the directories directly under the dataset
# root, kept in sorted order.
label_names = [item.name for item in data_root.glob('*/') if item.is_dir()]
label_names.sort()

In [49]:
# The label of each image is the name of the directory that contains it.
# Only the textual path is inspected, so a PurePath is sufficient.
all_img_label = [pathlib.PurePath(p).parent.name for p in all_img_path]

In [44]:
# The color is the part of each class-directory name before the first '_'.
# The unique names are sorted before being numbered: iterating a bare set of
# strings yields an order that can change between interpreter runs (string
# hash randomization), which would silently renumber the classes and make
# previously trained weights or saved predictions meaningless.
color_label_names = sorted(set(name.split('_')[0] for name in label_names))
# Map each color name to its integer class index.
color_label_index = {name: index for index, name in enumerate(color_label_names)}

In [47]:
# The item type is the second '_'-separated field of each class-directory name.
# The unique names are sorted before being numbered: iterating a bare set of
# strings yields an order that can change between interpreter runs (string
# hash randomization), which would silently renumber the classes and make
# previously trained weights or saved predictions meaningless.
item_label_names = sorted(set(name.split('_')[1] for name in label_names))
# Map each item name to its integer class index.
item_label_index = {name: index for index, name in enumerate(item_label_names)}

In [51]:
# Integer color index for every image, in the same order as all_img_label.
# partition('_')[0] is the text before the first underscore.
color_labels = [color_label_index[label.partition('_')[0]] for label in all_img_label]

In [55]:
# Integer item index for every image, in the same order as all_img_label.
# The item name is the second '_'-separated field of the label.
item_labels = [
    item_label_index[fields[1]]
    for fields in (label.split('_') for label in all_img_label)
]

### 加载图片

In [56]:
def load_preprocess_img(path):
    """Load one image file and return it as a normalized float32 tensor.

    Args:
        path: Path of the image file, as accepted by ``tf.io.read_file``.

    Returns:
        The image decoded as 3-channel JPEG, resized to 224x224, cast to
        float32, divided by 255 and then mapped with ``2*x - 1``.
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.cast(tf.image.resize(decoded, [224, 224]), tf.float32)
    unit_scaled = resized/255.0
    # Shift and stretch so the values are centred on zero.
    return 2*unit_scaled - 1

In [57]:
 # Dataset whose elements are the image path strings, in list order.
 # NOTE(review): this cell carries a stray leading space. IPython tolerates
 # it, but it is an IndentationError if the code is run as a plain script.
 path_ds = tf.data.Dataset.from_tensor_slices(all_img_path)

In [58]:
# Let tf.data pick the degree of parallelism for the map below.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Turn every path element into a preprocessed image tensor.
img_ds = path_ds.map(
    map_func=load_preprocess_img,
    num_parallel_calls=AUTOTUNE,
)

In [59]:
# Each element is one (color_index, item_index) pair; the two lists are
# sliced in parallel.
label_pairs = (color_labels, item_labels)
label_ds = tf.data.Dataset.from_tensor_slices(label_pairs)

In [61]:
# Pair every image with its label tuple, element by element.
paired_sources = (img_ds, label_ds)
img_label_ds = tf.data.Dataset.zip(paired_sources)

In [62]:
# Display the zipped dataset's repr (element shapes and dtypes) as a sanity check.
img_label_ds

<ZipDataset shapes: ((224, 224, 3), ((), ())), types: (tf.float32, (tf.int32, tf.int32))>

In [70]:
# Total number of collected image paths; used to size the train/test split.
all_count = len(all_img_path)

In [71]:
# Split the dataset: the first 20% of the elements become the test split,
# everything after them becomes the training split.
test_count = int(0.2*all_count)
train_count = all_count - test_count
train_data = img_label_ds.skip(count=test_count)
test_data = img_label_ds.take(count=test_count)

In [72]:
# Number of examples per batch for both the training and test pipelines.
BATCH_SIZE: int = 32

In [74]:
# Build both input pipelines directly from the un-batched `img_label_ds`
# instead of re-binding `train_data` / `test_data` from themselves. The
# result on a first run is the same, but the cell can now be re-run safely:
# previously a second execution batched already-batched data.
#
# Training: skip the test portion, shuffle over the whole training split,
# batch, repeat indefinitely (epoch length is set via steps_per_epoch), and
# prefetch so input preparation overlaps with training.
train_data = (
    img_label_ds
    .skip(test_count)
    .shuffle(train_count)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(AUTOTUNE)
)

# Test: the first `test_count` elements, batched once, no shuffle or repeat.
test_data = img_label_ds.take(test_count).batch(BATCH_SIZE)

### 建立模型

In [78]:
# MobileNetV2 backbone for 224x224 RGB input, built without its top
# classification layers (include_top=False); all other arguments are left at
# the library defaults.
mobile_net = tf.keras.applications.MobileNetV2(
    include_top=False,
    input_shape=(224, 224, 3),
)

In [79]:
# Symbolic model input: one 224x224 image with 3 channels (batch axis implicit).
inputs = tf.keras.Input(shape=(224, 224, 3))

In [81]:
# Run the backbone on the input, then average over the spatial axes so each
# image is summarised by a single feature vector shared by both heads.
backbone_features = mobile_net(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(backbone_features)

In [87]:
# Color head: one hidden layer on the shared features, followed by a softmax
# with one unit per color class.
color_hidden = tf.keras.layers.Dense(1024, activation="relu")
x1 = color_hidden(x)
color_classifier = tf.keras.layers.Dense(
    len(color_label_names),
    activation="softmax",
    name='output_color',
)
output_color = color_classifier(x1)

In [88]:
# Item head: one hidden layer on the shared features, followed by a softmax
# with one unit per item class.
item_hidden = tf.keras.layers.Dense(1024, activation="relu")
x2 = item_hidden(x)
item_classifier = tf.keras.layers.Dense(
    len(item_label_names),
    activation="softmax",
    name='output_item',
)
output_item = item_classifier(x2)

In [89]:
# Functional-API model: a single image input and two outputs, ordered
# (color, item) to match the label tuples in the dataset.
model_outputs = [output_color, output_item]
model = tf.keras.Model(inputs=inputs, outputs=model_outputs)

In [90]:
# Print the layer-by-layer summary (output shapes, parameter counts, connectivity).
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
mobilenetv2_1.00_224 (Model)    (None, 7, 7, 1280)   2257984     input_3[0][0]                    
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 1280)         0           mobilenetv2_1.00_224[2][0]       
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1024)         1311744     global_average_pooling2d[0][0]   
____________________________________________________________________________________________

In [91]:
# Both heads use the same loss, keyed by the output layers' names.
head_losses = dict.fromkeys(
    ("output_color", "output_item"),
    "sparse_categorical_crossentropy",
)
model.compile(optimizer='adam', loss=head_losses, metrics=['acc'])

In [92]:
# The training pipeline repeats forever, so an "epoch" is defined here as the
# number of full batches in the training split. Clamp to at least one step so
# a split smaller than one batch does not yield steps_per_epoch == 0.
train_steps = max(1, train_count//BATCH_SIZE)
# The test pipeline is finite and keeps its final partial batch, so round UP:
# floor division made validation skip up to BATCH_SIZE - 1 images and gave
# 0 validation steps whenever the test split was smaller than one batch.
test_steps = (test_count + BATCH_SIZE - 1)//BATCH_SIZE

In [None]:
# Train for 15 epochs. steps_per_epoch is required because the training
# dataset repeats indefinitely; validation runs on the test split.
model.fit(
    x=train_data,
    steps_per_epoch=train_steps,
    epochs=15,
    validation_data=test_data,
    validation_steps=test_steps,
)