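## Train the Model

Train a small convolutional network on MNIST and save it as a TensorFlow SavedModel in the `mnist` directory.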

In [None]:
import numpy as np
import tensorflow.keras as keras

# --- Data: MNIST digits scaled to [0, 1] with an explicit channel axis ---
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train.astype("float32") / 255, -1)  # -> (N, 28, 28, 1)
x_test = np.expand_dims(x_test.astype("float32") / 255, -1)
# One-hot labels over the 10 digit classes, as categorical_crossentropy expects.
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# --- Model: two conv/pool stages, dropout, then a 10-way softmax ---
model = keras.models.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

# --- Train, holding out the last 10% of the training set for validation ---
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_split=0.1)

# Write the trained model to the 'mnist' directory in TensorFlow SavedModel format.
# NOTE(review): `save_format` is not accepted by Keras 3 (TF >= 2.16); there,
# `model.export('mnist')` is the SavedModel equivalent — confirm the TF version.
model.save('mnist', save_format='tf')

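## Deploy to Vertex AI

**TODO:** this section is a placeholder. Upload the saved `mnist` model to Vertex AI and deploy it to an endpoint before running the inference cells below.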

In [None]:
# Placeholder cell: deployment is not implemented in this notebook.
# NOTE(review): presumably the `mnist` model saved above is uploaded and deployed
# to a Vertex AI endpoint out of band — the inference cell below only needs the
# resulting endpoint ID (YOUR_ENDPOINT).
pass

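## Inference

The first cell prints the preprocessed image as a JSON `{"instances": [...]}` request body. The second cell sends the same image to the deployed endpoint with the Vertex AI Python SDK and prints the predicted digit.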

In [None]:
import json

import numpy as np
from PIL import Image

# Path of the digit image to classify; both inference cells read it.
YOUR_PIC = '2.jpg'

def preprocess_image(f, res=200, r=False):
    """Load an image and turn it into a single-image grayscale float batch.

    Args:
        f: Path or file object accepted by ``PIL.Image.open``.
        res: Side length in pixels; the image is resized to ``res`` x ``res``.
        r: If True, invert intensities (``255 - pixel``) before scaling, e.g.
            for a dark digit on a light background (MNIST digits are light
            strokes on a dark background).

    Returns:
        ``float32`` array of shape ``(1, res, res, 1)`` with values in [0, 1].
    """
    # Image.open is lazy and keeps the file handle open until the image is
    # garbage-collected; the context manager closes it once pixels are decoded.
    with Image.open(f) as src:
        gray = src.convert("L").resize((res, res))
    # Read the 8-bit luminance grid straight into float32, shape (res, res).
    pixels = np.asarray(gray, dtype=np.float32)
    if r:
        pixels = 255 - pixels
    pixels /= 255.
    # Add batch and channel axes: (res, res) -> (1, res, res, 1).
    return pixels.reshape(1, res, res, 1)

def dump_image(img):
    """Print ``img`` to stdout as a JSON object of the form ``{"instances": [...]}``.

    ``img`` must provide ``tolist()`` (e.g. a NumPy array). Returns None.
    """
    print(json.dumps({'instances': img.tolist()}))

# Build a one-image batch at the model's 28x28 input size and print it as a
# JSON {"instances": [...]} body (presumably for a REST predict request).
img = preprocess_image(YOUR_PIC, res=28)
dump_image(img)

In [None]:
import os

from PIL import Image
from google.cloud import aiplatform

# Placeholder settings: replace each value before running this cell.
YOUR_PROJECT = 'YOUR_PROJECT'      # GCP project ID
YOUR_LOCATION = 'YOUR_LOCATION'    # region of the endpoint, e.g. 'us-central1'
YOUR_ENDPOINT = 'YOUR_ENDPOINT'    # Vertex AI endpoint ID (or full resource name)
YOUR_SERVICE = 'YOUR_SERVICE'      # path to a service-account key JSON file

# Point Google client libraries at the service-account key. This changes the
# process environment for the rest of the kernel session.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = YOUR_SERVICE
aiplatform.init(project=YOUR_PROJECT, location=YOUR_LOCATION)
endpoint = aiplatform.Endpoint(YOUR_ENDPOINT)

# Reuses preprocess_image / YOUR_PIC from the previous cell.
img = preprocess_image(YOUR_PIC, res=28)  # img shape (1, 28, 28, 1)
# NOTE(review): presumably the serialized request must stay under Vertex AI's
# 1.5 MB online-prediction request size limit — confirm.
prediction = endpoint.predict(instances=img.tolist())
# Assumes the endpoint serves the 10-way softmax model above, so each instance
# yields one row of class scores. Take the argmax per row: a flat argmax over
# the whole batch is only correct when exactly one image is sent.
predicted_digits = np.argmax(prediction.predictions, axis=-1)
print(*predicted_digits)