In [None]:
# fashion_mnist_serving.ipynb
import os
import datetime
import tensorflow as tf    
import argparse
from tensorflow.python.keras.callbacks import Callback


class MyModel(object):
    """Trains a dense Fashion-MNIST classifier with tf.keras.

    Hyper-parameters come from command-line flags whose defaults are the best
    values found by the hyper-parameter tuner. Progress is printed by
    ``LoggingTrain`` and summaries are written to TensorBoard.
    """

    def train(self):
        """Parse hyper-parameters, train the classifier and return it.

        Returns:
            The fitted ``tf.keras`` Sequential model. It was trained on 28x28
            images whose pixel values were scaled to [0, 1].
        """
        # FAIRING_RUNTIME is set when this code runs as a Fairing job; when it is
        # absent we are inside an interactive notebook kernel.
        running_in_fairing = os.getenv('FAIRING_RUNTIME', None) is not None
        args = self._parse_args(running_in_fairing)

        (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
        print("x_train shape:", x_train.shape, "y_train shape:", y_train.shape)
        print("x_test shape:", x_test.shape, "y_test shape:", y_test.shape)

        # Scale pixels from 0-255 to [0, 1]; anything that feeds this model at
        # prediction time has to apply the same scaling.
        x_train, x_test = x_train / 255.0, x_test / 255.0

        model = self._build_model(args)

        log_dir = self._log_dir(running_in_fairing)
        print(f"tensorboard log dir : {log_dir}")
        tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                        histogram_freq=1)

        # verbose=0: progress is reported by LoggingTrain instead of the Keras bar.
        model.fit(x_train, y_train,
                  verbose=0,
                  validation_data=(x_test, y_test),
                  epochs=args.epoch,
                  callbacks=[LoggingTrain(), tensorboard_cb])
        return model

    @staticmethod
    def _parse_args(running_in_fairing):
        """Return the hyper-parameters (node_amount, epoch, dropout_rate, optimizer)."""
        parser = argparse.ArgumentParser()
        # Defaults are the best values found by the hyper-parameter tuner.
        parser.add_argument('--node_amount', required=False, type=int, default=256)
        parser.add_argument('--epoch', required=False, type=int, default=27)
        parser.add_argument('--dropout_rate', required=False, type=float, default=0.283)
        parser.add_argument('--optimizer', required=False, type=str, default="adam")
        # Inside a notebook kernel sys.argv belongs to the kernel launcher, so the
        # real command line is only parsed when running as a Fairing job.
        if running_in_fairing:
            return parser.parse_args()
        return parser.parse_args(args=[])

    @staticmethod
    def _build_model(args):
        """Build and compile the Flatten -> Dense -> Dropout -> softmax classifier."""
        model = tf.keras.models.Sequential([
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(args.node_amount, activation='relu'),
            tf.keras.layers.Dropout(args.dropout_rate),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        # The metric is registered as 'acc' ('val_acc' for validation), which are
        # the log keys LoggingTrain reads.
        model.compile(optimizer=args.optimizer,
                      loss='sparse_categorical_crossentropy',
                      metrics=['acc'])
        return model

    @staticmethod
    def _log_dir(running_in_fairing):
        """Return a timestamped TensorBoard log directory for this run."""
        date_folder = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        if running_in_fairing:
            return "/notebook/log/fit/" + date_folder
        return "log/fit/" + date_folder
    
def p(msg):
    """Print ``msg`` prefixed with the current UTC time and flush stdout.

    The prefix has the form ``YYYY-MM-DDTHH:MM:SSZ``. The trailing ``Z`` marks
    UTC, so the time is taken in UTC rather than naive local time.
    Flushing makes each line visible immediately when stdout is a pipe or log.
    """
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    print(f"{now_utc.strftime('%Y-%m-%dT%H:%M:%SZ')} {msg}", flush=True)
        
class LoggingTrain(tf.keras.callbacks.Callback):
    """Keras callback that prints training progress and validation metrics.

    Subclasses the public ``tf.keras.callbacks.Callback`` rather than the private
    ``tensorflow.python.keras`` module. Every line goes through ``p``, which adds
    a timestamp and flushes stdout. Validation metrics are printed as
    ``Validation-accuracy=<v>`` / ``Validation-loss=<v>``, presumably so a
    stdout metrics collector (e.g. Katib) can scrape them; verify this against
    the experiment spec.
    """

    def on_batch_end(self, batch, logs=None):
        """Print the batch index plus accuracy/loss every 100 training batches."""
        logs = logs or {}  # avoid a shared mutable default and tolerate logs=None
        if batch % 100 == 0:
            p(f"batch: {batch}")
            p(f"accuracy={logs.get('acc')} loss={logs.get('loss')}")

    def on_epoch_begin(self, epoch, logs=None):
        """Print the index of the epoch that is starting."""
        p(f"epoch: {epoch}")

    def on_epoch_end(self, epoch, logs=None):
        """Print the validation accuracy and loss reported for the finished epoch."""
        logs = logs or {}
        p(f"Validation-accuracy={logs.get('val_acc')}")
        p(f"Validation-loss={logs.get('val_loss')}")

In [None]:
# Train the classifier with the tuned default hyper-parameters. `model` holds the
# fitted Keras model that the packaging cell below packs into the BentoML service.
my_model = MyModel()
model = my_model.train() 

In [None]:
# Check whether BentoML is installed in this kernel's environment.
# Match case-insensitively: the distribution has been published as both
# "BentoML" and "bentoml", so a case-sensitive grep can miss it.
%pip freeze | grep -i bentoml

In [None]:
# Install BentoML into the running kernel's environment (%pip, not !pip).
# Pinned to the 0.x line: bento_fashion_mnist.py uses the BentoService /
# @artifacts / bentoml.frameworks.keras API, which BentoML 1.0 removed.
# Restart the kernel afterwards if another BentoML version was already imported.
%pip install "bentoml==0.13.1"

In [None]:
%%writefile bento_fashion_mnist.py

# API 서버 코드 만들기
from typing import List
import numpy as np
from PIL import Image
from bentoml import api, artifacts, env, BentoService
from bentoml.frameworks.keras import KerasModelArtifact
from bentoml.adapters import ImageInput

# Human-readable Fashion-MNIST labels, ordered by class id 0-9, so the model's
# predicted class index can be used directly as a position in this list.
class_names = [
    'T-shirt/top',  # 0
    'Trouser',      # 1
    'Pullover',     # 2
    'Dress',        # 3
    'Coat',         # 4
    'Sandal',       # 5
    'Shirt',        # 6
    'Sneaker',      # 7
    'Bag',          # 8
    'Ankle boot',   # 9
]
@env(docker_base_image="dudaji/cap-jupyterlab:tf2.0-cpu")
@artifacts([KerasModelArtifact('classifier')])
class KerasFashionMnistService(BentoService):
    """BentoML (0.x) service that classifies images into Fashion-MNIST classes.

    The packed ``classifier`` artifact is expected to be the Keras model trained
    in the notebook. That model was trained on 28x28 grayscale images with
    pixel values scaled to [0, 1].
    """

    @api(input=ImageInput(pilmode='L'), batch=True)
    def predict(self, imgs: List[np.ndarray]) -> List[str]:
        """Return one class name per input image, in input order.

        Args:
            imgs: decoded images; ``pilmode='L'`` requests 8-bit grayscale.
        Returns:
            Entries of ``class_names`` for the highest-scoring class of each image.
        """
        if not imgs:
            # np.stack raises on an empty sequence; an empty batch has no predictions.
            return []
        # Resize to the model's 28x28 input and apply the same /255.0 scaling the
        # training data received; without it, raw 0-255 inputs skew predictions.
        inputs = np.stack([
            np.asarray(Image.fromarray(img).resize((28, 28)), dtype=np.float32) / 255.0
            for img in imgs
        ])
        probabilities = self.artifacts.classifier.predict(inputs)
        # argmax over the softmax output replaces Sequential.predict_classes,
        # which newer tf.keras releases removed.
        class_idxs = np.argmax(probabilities, axis=-1)
        return [class_names[int(class_idx)] for class_idx in class_idxs]

In [None]:
# Package the trained model into the BentoML API service and save it.
from bento_fashion_mnist import KerasFashionMnistService

# Measure this model's accuracy on the Fashion-MNIST test split, scaled like
# the training data. The saved label then reflects the packed model instead of
# a number copied from an earlier run. evaluate returns [loss, acc] because
# the model was compiled with metrics=['acc'].
(_, _), (x_eval, y_eval) = tf.keras.datasets.fashion_mnist.load_data()
_, eval_accuracy = model.evaluate(x_eval / 255.0, y_eval, verbose=0)

fashion_mnist_svc = KerasFashionMnistService()
# The name must match the artifact declared with @artifacts on the service
# ('classifier'); predict() reads self.artifacts.classifier.
fashion_mnist_svc.pack('classifier', model)

saved_path = fashion_mnist_svc.save(
    labels={"Validation-accuracy": f"{eval_accuracy * 100:.2f}"})
print(saved_path)