In [None]:
# model_test_pipeline.ipynb
import argparse
import datetime
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.callbacks import Callback


class MyModel(object):
    """Train or retrain a Fashion-MNIST classifier and save it as a TF SavedModel.

    All settings come from command-line flags parsed in ``train``. When the
    ``FAIRING_RUNTIME`` environment variable is unset (interactive notebook)
    the real argv is ignored and every flag takes its default.
    """

    # Base directory for the saved model when --model_path is not given.
    # NOTE(review): the previous code saved to "None/<version>" in that case;
    # confirm the intended location.
    DEFAULT_SAVE_ROOT = "model"

    def __init__(self):
        self.model_path = None     # directory of the last model saved by train()
        self.model = None          # tf.keras model built/loaded by train()
        self.model_version = None  # "<version>.<val_acc %>" assigned by train()

    def get_model_path(self):
        """Return the directory ``train`` saved the model to, or None before training."""
        return self.model_path

    def get_model(self):
        """Return the Keras model set by ``train``, or None before training."""
        return self.model

    def train(self):
        """Build or load a model, fit it on Fashion-MNIST (+ optional extra data), save it.

        Flags:
            --node_amount, --dropout_rate, --optimizer: architecture/compile
                settings, used only when a fresh model is built.
            --epoch: number of training epochs.
            --dataset_path: optional ``.npz`` file with ``x_train``/``y_train``
                arrays that are mixed into the training and test sets.
            --model_path: existing model to load and keep training; also the
                base directory the new version is saved under.
            --save_version: version string for the saved model (defaults to a
                local ``%Y%m%d%H%M%S`` timestamp).
            --train_version: accepted but currently unused.

        Returns:
            The trained ``tf.keras`` model. Also sets ``self.model``,
            ``self.model_version`` and ``self.model_path``.
        """
        args = self._parse_args()
        x_train, y_train, x_test, y_test = self._load_data(args.dataset_path)
        self.model = self._build_model(args)

        log_dir = self._log_dir()
        print(f"tensorboard log dir : {log_dir}")
        tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                        histogram_freq=1)

        print(f"Total epochs {args.epoch}")
        hist = self.model.fit(x_train, y_train,
                              verbose=0,
                              validation_data=(x_test, y_test),
                              epochs=args.epoch,
                              callbacks=[LoggingTrain(),
                                         tensorboard_cb])

        if args.save_version is None:
            model_ver = get_strftime('%Y%m%d%H%M%S')
        else:
            model_ver = args.save_version

        # The history key follows the metric name the model was compiled with:
        # 'acc' for models built here, possibly 'accuracy' for a loaded model.
        history = hist.history
        val_acc_key = 'val_acc' if 'val_acc' in history else 'val_accuracy'
        model_val_acc = int(float(history[val_acc_key][-1]) * 100)
        self.model_version = f"{model_ver}.{model_val_acc}"

        save_root = (args.model_path if args.model_path is not None
                     else self.DEFAULT_SAVE_ROOT)
        save_model_path = f"{save_root}/{self.model_version}"
        self.model.save(save_model_path, save_format='tf')
        self.model_path = save_model_path
        return self.model

    @staticmethod
    def _parse_args():
        """Declare the training flags and parse them for the current runtime."""
        parser = argparse.ArgumentParser()
        parser.add_argument('--node_amount', required=False, type=int, default=128)
        parser.add_argument('--epoch', required=False, type=int, default=10)
        parser.add_argument('--dropout_rate', required=False, type=float, default=0.2)
        parser.add_argument('--optimizer', required=False, type=str, default="sgd")
        parser.add_argument('--dataset_path', required=False, type=str, default=None)
        parser.add_argument('--model_path', required=False, type=str, default=None)
        parser.add_argument('--train_version', required=False, type=str, default=None)
        parser.add_argument('--save_version', required=False, type=str, default=None)
        if os.getenv('FAIRING_RUNTIME', None) is None:
            # Interactive session: sys.argv belongs to the kernel, not to us.
            return parser.parse_args(args=[])
        return parser.parse_args()

    def _load_data(self, dataset_path):
        """Load Fashion-MNIST, optionally mix in extra data, and scale pixels by 1/255.

        Returns:
            (x_train, y_train, x_test, y_test)
        """
        mnist = tf.keras.datasets.fashion_mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        print("x_train shape:", x_train.shape, "y_train shape:", y_train.shape)
        print("x_test shape:", x_test.shape, "y_test shape:", y_test.shape)

        if dataset_path is not None:
            # NpzFile is a context manager; close the underlying file handle.
            with np.load(dataset_path) as new_dataset:
                new_x = new_dataset['x_train']
                new_y = new_dataset['y_train']

            add_x_train, add_x_test, add_y_train, add_y_test = self._holdout_split(
                new_x, new_y, test_size=0.1, random_state=42)

            # Keep only as many original samples as the new split provides,
            # then append the new samples after them.
            # NOTE(review): confirm this truncation of the base set is intended.
            train_size = len(add_y_train)
            test_size = len(add_y_test)
            x_train = np.append(x_train[:train_size], add_x_train, axis=0)
            x_test = np.append(x_test[:test_size], add_x_test, axis=0)
            y_train = np.append(y_train[:train_size], add_y_train, axis=0)
            y_test = np.append(y_test[:test_size], add_y_test, axis=0)

        # Scaling happens after the merge, so the extra data is divided by 255 too.
        # NOTE(review): assumes the extra data uses the same 0-255 pixel scale.
        x_train, x_test = x_train / 255.0, x_test / 255.0
        return x_train, y_train, x_test, y_test

    @staticmethod
    def _holdout_split(x, y, test_size, random_state):
        """Shuffle (x, y) together and split off a test fraction.

        Stand-in for the ``sklearn.model_selection.train_test_split`` call the
        code originally made (sklearn is not imported by this file). The test
        part gets ``ceil(test_size * n)`` samples, chosen from
        ``np.random.RandomState(random_state).permutation(n)``.

        Returns:
            (x_train, x_test, y_train, y_test)

        Raises:
            ValueError: if x and y differ in length, or either part would be empty.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        n_samples = len(x)
        if len(y) != n_samples:
            raise ValueError(
                f"x and y have different lengths: {n_samples} != {len(y)}")
        n_test = int(np.ceil(test_size * n_samples))
        n_train = n_samples - n_test
        if n_test <= 0 or n_train <= 0:
            raise ValueError(
                f"cannot split {n_samples} samples with test_size={test_size}")
        permutation = np.random.RandomState(random_state).permutation(n_samples)
        test_idx = permutation[:n_test]
        train_idx = permutation[n_test:]
        return x[train_idx], x[test_idx], y[train_idx], y[test_idx]

    @staticmethod
    def _build_model(args):
        """Build and compile a fresh dense classifier, or load ``args.model_path``."""
        if args.model_path is None:
            model = tf.keras.models.Sequential([
                tf.keras.layers.Flatten(input_shape=(28, 28)),
                tf.keras.layers.Dense(args.node_amount, activation='relu'),
                tf.keras.layers.Dropout(args.dropout_rate),
                tf.keras.layers.Dense(10, activation='softmax')
            ])
            model.compile(optimizer=args.optimizer,
                          loss='sparse_categorical_crossentropy',
                          metrics=['acc'])
            return model
        # A loaded model keeps the layers and compile settings it was saved
        # with; --node_amount, --dropout_rate and --optimizer are not applied.
        return tf.keras.models.load_model(args.model_path)

    @staticmethod
    def _log_dir():
        """Return a timestamped TensorBoard log directory for this run."""
        date_folder = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if os.getenv('FAIRING_RUNTIME', None) is None:
            return "log/fit/" + date_folder
        return "/notebook/log/fit/" + date_folder
    
def get_strftime(time_format):
    """Return the current local time rendered with the ``strftime`` pattern *time_format*."""
    return datetime.datetime.now().strftime(time_format)

def p(msg):
    """Print *msg* prefixed with the current UTC time as ``YYYY-MM-DDTHH:MM:SSZ``.

    The trailing ``Z`` is the ISO 8601 UTC designator, so the timestamp is
    taken in UTC rather than local time. Output is flushed on every call so
    progress lines are not held back in the stdout buffer.
    """
    stamp = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    print(f"{stamp} {msg}", flush=True)
    
class LoggingTrain(tf.keras.callbacks.Callback):
    """Keras callback that prints timestamped training progress through ``p``.

    Every 100th batch it prints the batch index and the batch ``acc``/``loss``
    from ``logs``. It prints the epoch index when an epoch begins, and
    ``val_acc``/``val_loss`` when an epoch ends.
    """

    def on_batch_end(self, batch, logs=None):
        # Keras may pass logs=None; avoid a shared mutable default.
        logs = logs or {}
        if batch % 100 == 0:
            p(f"batch: {batch}")
            p(f"accuracy={logs.get('acc')} loss={logs.get('loss')}")

    def on_epoch_begin(self, epoch, logs=None):
        p(f"epoch: {epoch}")

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        p(f"Validation-accuracy={logs.get('val_acc')}")
        p(f"Validation-loss={logs.get('val_loss')}")

In [None]:
import os
from kubeflow.fairing.builders.append.append import AppendBuilder
from kubeflow.fairing.preprocessors.converted_notebook import ConvertNotebookPreprocessor

if __name__ == '__main__':
    if os.getenv('FAIRING_RUNTIME', None) is not None:
        # Running inside the image built below: execute the training job.
        remote_model = MyModel()
        remote_model.train()
    else:
        # Interactive session: package the notebook into an image and push it.
        # NOTE(review): the notebook converted here ("my_model_retrain.ipynb")
        # differs from the name in this file's header comment
        # ("model_test_pipeline.ipynb"); confirm which one is intended.
        notebook_preprocessor = ConvertNotebookPreprocessor(
            notebook_file="my_model_retrain.ipynb")

        DOCKER_REGISTRY = "ydh0924"
        builder = AppendBuilder(
            registry=DOCKER_REGISTRY,
            image_name="fashion-mnist-retrain",
            base_image="dudaji/cap-jupyterlab:tf2.0-cpu",
            tag="handson",
            preprocessor=notebook_preprocessor,
            push=True,
        )
        built_image = builder.build()
        print(built_image)