In [2]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import cv2
import resnet_V2
import data_util

In [3]:
# Seed every RNG library this notebook already imports so a fresh
# Restart & Run All starts from the same random state each time.
SEED = 42
np.random.seed(SEED)      # NumPy global RNG
tf.random.set_seed(SEED)  # TensorFlow global seed

In [4]:
# Loss objects shared by custom_loss.
# The class head ends in a softmax, so cross-entropy is computed on
# probabilities (from_logits=False is the Keras default, spelled out here).
ce = keras.losses.CategoricalCrossentropy(from_logits=False)
# Squared error for the landmark regression outputs.
mse = keras.losses.MeanSquaredError()

In [5]:
def custom_loss(y_true, y_pred, num_classes=133, class_weight=0.5, point_weight=0.5):
    """Weighted sum of the classification loss and the landmark regression loss.

    Both tensors are split along their second axis: the first ``num_classes``
    columns form the class part and every remaining column forms the landmark
    part.

    Args:
        y_true: 2-D targets laid out as ``[class columns | landmark columns]``.
        y_pred: Model output with the same column layout as ``y_true``.
        num_classes: Number of leading columns that belong to the class part.
        class_weight: Multiplier applied to the cross-entropy term.
        point_weight: Multiplier applied to the mean-squared-error term.

    Returns:
        ``point_weight * mse(points) + class_weight * ce(classes)``, using the
        module-level ``mse`` and ``ce`` loss objects.
    """
    # Class part: leading columns.
    y_class_true = y_true[:, :num_classes]
    y_class_pred = y_pred[:, :num_classes]

    # Landmark part: everything after the class columns.
    y_point_true = y_true[:, num_classes:]
    y_point_pred = y_pred[:, num_classes:]

    class_loss = ce(y_class_true, y_class_pred)
    point_loss = mse(y_point_true, y_point_pred)

    return point_weight * point_loss + class_weight * class_loss

In [6]:
# ImageNet-pretrained MobileNetV2 used as a convolutional feature extractor.
# include_top=False drops the classification head, and pooling=None keeps the
# final 4-D feature map. `classes` is not passed because it only applies when
# the top is included.
# NOTE(review): 256 is not among the resolutions Keras ships MobileNetV2
# ImageNet weights for, and the model summary names this encoder
# "mobilenetv2_1.00_224" — presumably the 224x224 weights are loaded; confirm
# that is acceptable.
encoder = tf.keras.applications.mobilenet_v2.MobileNetV2(
    input_shape=(256, 256, 3),
    alpha=1.0,
    include_top=False,
    weights='imagenet',
    pooling=None,
)



In [7]:
input_size = 256
batch_size = 16

# The input tensor is sized from input_size rather than a repeated literal, and
# is called `inputs` so the Python builtin `input` is not shadowed.
inputs = keras.layers.Input(shape=(input_size, input_size, 3))

# Shared trunk: run the encoder and flatten its feature map to one vector per image.
encoder_output = encoder(inputs)
flatten = keras.layers.Flatten()(encoder_output)

# Head 1: softmax probabilities over 133 classes.
branch1_class = keras.layers.Dense(133, activation='softmax', name='branch1_class')(flatten)

# Head 2: 16 linear regression outputs (no activation).
branch2_landmark = keras.layers.Dense(16, name='branch2_landmark')(flatten)

# Concatenate both heads into a single 149-wide output. custom_loss slices it
# apart again at column 133, so the class head must come first.
out = keras.layers.concatenate([branch1_class, branch2_landmark])

model = keras.models.Model(inputs, out)
model.compile(loss=custom_loss, optimizer='adam')
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 mobilenetv2_1.00_224 (Function  (None, 8, 8, 1280)  2257984     ['input_2[0][0]']                
 al)                                                                                              
                                                                                                  
 flatten (Flatten)              (None, 81920)        0           ['mobilenetv2_1.00_224[0][0]']   
                                                                                              

In [8]:
# Batched dataset built by the project's data_util helper.
# NOTE(review): train_type='class' is requested, yet the label batches printed
# in the next cell are 149 columns wide (133 + 16) — confirm this mode is the
# one that yields the combined class + landmark target.
ds = data_util.get_cu_dataset(
    train_type='class',
    batch_size=batch_size,
)

In [9]:
# Sanity check: print the shapes of the first two components of one batch.
for batch in ds.take(1):
    first_component, second_component = batch[0], batch[1]
    print(first_component.shape)
    print(second_component.shape)

(16, 256, 256, 3)
(16, 149)


In [10]:
# Count batches by streaming through the dataset once. Unlike len(list(ds)),
# this never holds more than one batch at a time.
ds_size = sum(1 for _ in ds)
ds_size

478

In [11]:
# Split sizes in batches: 70% train, 15% validation, 15% test (rounded down).
train_size = int(0.7 * ds_size)
val_size = test_size = int(0.15 * ds_size)

# NOTE(review): take/skip only gives disjoint splits if `ds` yields batches in
# the same order on every pass — confirm data_util does not reshuffle per
# iteration.
train_ds = ds.take(train_size)

# Everything after the training batches is shared between test and validation.
holdout_ds = ds.skip(train_size)
# Test takes the first test_size holdout batches.
test_ds = holdout_ds.take(test_size)
# Validation takes the rest, so it also absorbs any rounding leftovers.
val_ds = holdout_ds.skip(val_size)

In [12]:
# Batches per split (train, val, test), counted by streaming so that no split
# is materialised in memory.
tuple(sum(1 for _ in split) for split in (train_ds, val_ds, test_ds))

(334, 73, 71)

In [13]:
ckpt_path = 'mobilenetv2_multi'

# Save weights only, and only when the validation loss reaches a new minimum.
best_weights_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=ckpt_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=True,
)

# Early stopping is currently disabled; the previous configuration was:
# tf.keras.callbacks.EarlyStopping(
#    monitor='val_loss',
#    mode='min',
#    verbose=1,
#    patience=20
# )
callbacks_list = [best_weights_checkpoint]

In [14]:
# Train for 100 epochs, validating on val_ds after each one. The callbacks in
# callbacks_list run during training, and the returned History is kept in `hist`.
hist = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    callbacks=callbacks_list,
)

Epoch 1/100
     11/Unknown - 16s 1s/step - loss: 1809.8558