# Baseten <> TensorFlow example deployment

<a href="https://colab.research.google.com/github/basetenlabs/demos/blob/main/deployment/baseten_tensorflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install tensorflow baseten

In [None]:
# Model definition (a ResNet50V2 with pretrained ImageNet weights; no training happens here)

import tensorflow as tf

def train_model():
    """Build the image classifier used throughout this demo.

    Despite its name, this function does not train anything: it returns a
    Keras ResNet50V2 with its classification head attached, initialised with
    the ImageNet weights, whose final layer applies a softmax (so the model
    outputs class probabilities).
    """
    resnet_options = {
        "include_top": True,
        "weights": "imagenet",
        "classifier_activation": "softmax",
    }
    return tf.keras.applications.ResNet50V2(**resnet_options)

In [None]:
# Model creation: build the pretrained ResNet50V2 classifier returned by
# train_model() (despite its name, that function only constructs the model).

my_model = train_model()

In [None]:
# Model deployment
#
# The API key is read from the BASETEN_API_KEY environment variable when it is
# set, so the secret does not have to be pasted into (and saved with) the
# notebook. Otherwise, replace the placeholder below with your key.

import os

import baseten

api_key = os.environ.get("BASETEN_API_KEY", "PASTE API KEY HERE")
if api_key == "PASTE API KEY HERE":
    # Fail fast with an actionable message instead of a confusing login error.
    raise ValueError(
        "No Baseten API key: set the BASETEN_API_KEY environment variable "
        "or paste your key into api_key."
    )
baseten.login(api_key)

baseten.deploy(
    my_model,
    model_name="My ResNet TensorFlow Model",
)

In [None]:
# Preprocessing and postprocessing functions

import requests
import tempfile
import numpy as np

from scipy.special import softmax

def preprocess(url, *, timeout=30):
    """Preprocess step for ResNet: download an image and build a model input batch.

    Args:
        url: Address of the image to classify. Any format TensorFlow can
            decode (JPEG, PNG, GIF, BMP) is accepted; the demo image is a JPEG.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A NumPy array of shape (1, 224, 224, 3): the image resized to
        224x224 and scaled with ResNetV2's ``preprocess_input``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Without this, an HTML error page would be handed to the image decoder
    # and fail with a confusing decode error.
    response.raise_for_status()
    # decode_image detects the format from the bytes, so no temp file and no
    # PNG-only decoder is needed. channels=3 drops alpha / expands grayscale
    # into the RGB input the network expects; expand_animations=False returns
    # a single 3-D frame even for GIFs.
    input_image = tf.io.decode_image(
        response.content, channels=3, expand_animations=False
    )
    resized_batch = tf.image.resize([input_image], (224, 224))
    preprocessed_image = tf.keras.applications.resnet_v2.preprocess_input(resized_batch)
    return np.array(preprocessed_image)

def postprocess(predictions, k=5, *, labels=None, from_logits=False):
    """Post process step for ResNet: map model scores to the top-k class labels.

    Args:
        predictions: Batched model output; only the first row
            (``predictions[0]``) is used, one score per class.
        k: Number of highest-scoring classes to return.
        labels: Optional sequence of class names. When omitted, the TensorFlow
            ImageNet label file is downloaded. A sequence with exactly one
            extra leading entry (that file starts with "background") is
            shifted so that ``labels[i]`` names model output ``i``.
        from_logits: Set to True when the scores are unnormalised logits; they
            are then passed through a softmax. The model built by
            ``train_model`` already ends in a softmax, so by default the
            scores are used as-is (a second softmax would flatten them
            toward a uniform distribution).

    Returns:
        dict mapping label -> probability in percent (rounded to 0.1), ordered
        from most to least likely.

    Raises:
        ValueError: if the number of labels is neither the number of classes
            nor the number of classes plus one.
        requests.HTTPError: if downloading the label file fails.
    """
    class_scores = np.asarray(predictions[0], dtype=float)
    num_classes = len(class_scores)

    if labels is None:
        response = requests.get(
            'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt',
            timeout=30,
        )
        response.raise_for_status()
        # splitlines() avoids the trailing empty entry split('\n') produced.
        labels = response.text.splitlines()
    labels = list(labels)

    if len(labels) == num_classes + 1:
        # The downloaded label file has a leading "background" class, while
        # Keras ImageNet classifiers emit one score per real class; dropping
        # it realigns every label with its score (otherwise all are off by one).
        labels = labels[1:]
    elif len(labels) != num_classes:
        raise ValueError(
            f"got {len(labels)} labels for {num_classes} class scores"
        )

    probabilities = softmax(class_scores) if from_logits else class_scores
    top_indices = probabilities.argsort()[::-1][:max(k, 0)].tolist()
    # round() after scaling avoids float noise like 89.60000000000001.
    return {
        labels[index]: round(100 * float(probabilities[index]), 1)
        for index in top_indices
    }

In [None]:
# Once the deployment has finished, call the new model!
# The version ID is listed on the deployed model's page in Baseten.
deployed_model_id = "PASTE VERSION ID HERE"

dog_image_url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
model_input = preprocess(dog_image_url)

deployed_model = baseten.deployed_model_version_id(deployed_model_id)
raw_predictions = deployed_model.predict(model_input)
postprocess(raw_predictions)