Overview

* Train and Save a TensorFlow Model
* Create a Flask Application
* Integrate the Model with Flask and Run the Application
* Test the Flask API with requests
* Use Streamlit as a Front-End



**1. Train and Save a TensorFlow Model**

Let's train a simple TensorFlow model that classifies the handwritten digits 0 to 9 in the MNIST dataset, then save it to disk so the Flask app can load it.

Python Code: **train_model.py**

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.utils import to_categorical

# Load dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Preprocess data
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Build the model
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))

# Save the model
model.save('mnist_model.h5')

print("Model training complete and saved as 'mnist_model.h5'")
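To see what the network computes, here is a NumPy sketch of the same forward pass: Flatten turns each 28x28 image into 784 values, Dense(128) applies a ReLU hidden layer, and Dense(10) produces ten softmax probabilities. The weights are hypothetical random placeholders, not the trained values, so the predicted classes are meaningless; the point is the shapes and the parameter count, which should match the 101,770 total parameters that model.summary() reports for this architecture.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def softmax(z):
    # Subtract the row-wise max before exponentiating for numerical stability
    e = np.exp(z - np.max(z, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

rng = np.random.default_rng(0)

# Hypothetical random weights with the same shapes as the Keras layers (not trained values)
W1 = rng.normal(scale=0.05, size=(784, 128))  # Dense(128) kernel
b1 = np.zeros(128)                            # Dense(128) bias
W2 = rng.normal(scale=0.05, size=(128, 10))   # Dense(10) kernel
b2 = np.zeros(10)                             # Dense(10) bias

def forward(images):
    """images: array of shape (batch, 28, 28) with values already scaled to [0, 1]."""
    x = images.reshape(len(images), -1)  # Flatten: (batch, 784)
    h = relu(x @ W1 + b1)                # hidden layer: (batch, 128)
    return softmax(h @ W2 + b2)          # class probabilities: (batch, 10)

batch = rng.random((2, 28, 28))  # two fake "images" in [0, 1]
probs = forward(batch)
print(probs.shape)               # (2, 10)
print(np.argmax(probs, axis=1))  # predicted class per image (meaningless with random weights)

# Parameter count: 784*128 + 128 + 128*10 + 10
n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)                  # 101770
```

Training with categorical_crossentropy adjusts W1, b1, W2 and b2 so that the probability assigned to the correct digit goes up.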

**2. Create a Flask Application**

Python Code: **app.py**

In [None]:
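from flask import Flask, request, jsonify
import tensorflow as tf
import numpy as np

app = Flask(__name__)

# Load the trained model once at startup, not on every request
model = tf.keras.models.load_model('mnist_model.h5')

@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Get the JSON body of the POST request: {"image": <28x28 list of 0-255 values>}
        data = request.get_json()
        # Convert the data to a float32 numpy array
        image = np.array(data['image'], dtype=np.float32)
        if image.shape != (28, 28):
            return jsonify({'error': f'expected a 28x28 image, got shape {image.shape}'}), 400
        # Scale pixel values to [0, 1], matching the preprocessing used in training
        image = image / 255.0
        # Add a batch dimension: (28, 28) -> (1, 28, 28)
        image = np.expand_dims(image, axis=0)
        # Make a prediction: an array of shape (1, 10) with class probabilities
        predictions = model.predict(image, verbose=0)
        # Get the class with the highest probability
        predicted_class = np.argmax(predictions[0])
        # Cast to a Python int, since NumPy integers are not JSON-serializable
        return jsonify({'predicted_class': int(predicted_class)})
    except Exception as e:
        # Treat failures (e.g. a missing 'image' key) as a bad request instead of returning 200
        return jsonify({'error': str(e)}), 400

if __name__ == '__main__':
    # debug=True is for local development only
    app.run(debug=True)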

**3. Integrate the Model with Flask**

Here’s a step-by-step breakdown of what the Flask application does:

- Load the Model: The model is loaded once at startup with TensorFlow’s tf.keras.models.load_model() function, so every request reuses it.
- Predict Endpoint: The /predict endpoint expects a POST request whose JSON body has an "image" key holding a 28x28 array of raw pixel values (0-255). The handler divides the values by 255 so they match the training preprocessing, then adds a batch dimension to get the (1, 28, 28) shape the model expects.
- Make Prediction: The model returns ten class probabilities; the index of the largest one (np.argmax) is returned as "predicted_class" in a JSON response.
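Keeping the request preprocessing in a pure function lets you unit-test it without loading TensorFlow or starting the server. The sketch below mirrors the steps listed above; the helper names preprocess and top_class are hypothetical, and the fake prediction vector stands in for real model output.

```python
import numpy as np

def preprocess(pixels):
    """Turn a JSON-decoded 28x28 list of 0-255 pixel values into a model batch.

    Mirrors the /predict handler: check the shape, scale to [0, 1] as in
    training, then add a batch axis.
    """
    image = np.asarray(pixels, dtype=np.float32)
    if image.shape != (28, 28):
        raise ValueError(f"expected a 28x28 image, got shape {image.shape}")
    image = image / 255.0                  # same scaling as x_train / 255.0
    return np.expand_dims(image, axis=0)   # (28, 28) -> (1, 28, 28)

def top_class(predictions):
    """predictions: (1, 10) softmax output; returns the most probable digit as a Python int."""
    return int(np.argmax(predictions[0]))

batch = preprocess([[255] * 28 for _ in range(28)])
print(batch.shape)  # (1, 28, 28)
print(batch.max())  # 1.0

# Fake model output for illustration: class 3 has the highest probability
fake_predictions = np.array([[0.01, 0.02, 0.05, 0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02]])
print(top_class(fake_predictions))  # 3
```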

In [None]:
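# Start the Flask server. This command blocks until the server is stopped,
# so run it in a separate terminal (without the leading "!") rather than in the notebook.
! python app.py
# Testing the Flask API (from another terminal while the server is running).
# The "..." are placeholders: a valid request needs all 28 rows of 28 pixel values.
! curl -X POST http://127.0.0.1:5000/predict -H "Content-Type: application/json" -d '{"image": [[0, 0, 0, ..., 0], [0, 0, 0, ..., 0], ..., [0, 0, 0, ..., 0]]}'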
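Because the curl example elides most pixel values, it is not valid JSON as written. One way to get a complete request body is to generate it with Python's json module and let curl read it from a file. The file name payload.json and the all-black placeholder image are assumptions for illustration.

```python
import json

# All-black 28x28 placeholder image (replace with real pixel values, 0-255)
image = [[0] * 28 for _ in range(28)]

# Write the complete request body so curl can send it with -d @payload.json
with open("payload.json", "w") as f:
    json.dump({"image": image}, f)

print(len(image), len(image[0]))  # 28 28
```

The request can then be sent with: curl -X POST http://127.0.0.1:5000/predict -H "Content-Type: application/json" -d @payload.json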

**4. Test the Flask API with requests**

You can create a script to send a POST request to your Flask API and handle the response.

Python Code: **test_flask_api.py**

In [None]:
import requests
import numpy as np

# Example image data: a 28x28 array of zeros (you can replace this with actual image data)
example_image = np.zeros((28, 28)).tolist()

# Define the API endpoint
url = 'http://127.0.0.1:5000/predict'

# Create the payload
payload = {
    'image': example_image
}

# Send the POST request
response = requests.post(url, json=payload)

# Check the response
if response.status_code == 200:
    result = response.json()
    print('Response:', result)
else:
    print('Failed to get a response. Status code:', response.status_code)
    print('Response:', response.text)

In [None]:
# Requires the Flask server (app.py) to be running in another terminal
! python test_flask_api.py

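**5. Use Streamlit as a Front-End**

The Streamlit app lets a user upload a digit image, converts it to the 28x28 grayscale format the API expects, and displays the predicted class.

Python Code: **streamlit_app.py**
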
In [None]:
import streamlit as st
import requests
import numpy as np
from PIL import Image

# Define the API endpoint
API_URL = 'http://127.0.0.1:5000/predict'

# Streamlit app
st.title("MNIST Digit Classifier")

# Upload image
uploaded_file = st.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    # Load and preprocess image
    image = Image.open(uploaded_file).convert('L')  # Convert to grayscale
    image = image.resize((28, 28))  # Resize to 28x28
    image = np.array(image)  # Convert to a uint8 numpy array of 0-255 values

    # Display the 28x28 image that will be sent to the API
    st.image(image, caption='Uploaded image (resized to 28x28)', width=140)

    # Prepare payload (raw 0-255 values; the Flask app does the scaling)
    payload = {
        'image': image.tolist()  # Convert numpy array to a nested list for JSON
    }

    # Send POST request to Flask API, reporting connection errors in the UI
    try:
        response = requests.post(API_URL, json=payload, timeout=10)
    except requests.exceptions.RequestException as e:
        st.error(f"Could not reach the Flask API at {API_URL}: {e}")
        st.stop()

    if response.status_code == 200:
        result = response.json()
        st.write(f'Predicted Class: {result["predicted_class"]}')
    else:
        st.error(f"Failed to get prediction. Status code: {response.status_code}")
        st.error(response.text)


In [None]:
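# Save the code above as streamlit_app.py, then launch it from a terminal
# (with the Flask server from app.py still running)
! streamlit run streamlit_app.py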
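MNIST digits are light strokes on a dark (0) background, while most uploaded images are dark ink on a light background, which can lead to poor predictions. A minimal sketch of one fix, assuming a simple mean-brightness heuristic (a hypothetical helper, not part of the original app), inverts such images before they are sent to the API:

```python
import numpy as np

def match_mnist_polarity(image, threshold=127.5):
    """Invert a grayscale uint8 image if it looks like dark ink on a light background.

    Heuristic (an assumption): if the mean brightness is above the threshold,
    the background is probably light, so the image is inverted with 255 - image.
    """
    image = np.asarray(image, dtype=np.uint8)
    if image.mean() > threshold:
        image = 255 - image
    return image

# White "paper" with a black vertical stroke, like a typical scanned digit "1"
paper = np.full((28, 28), 255, dtype=np.uint8)
paper[4:24, 13:15] = 0

fixed = match_mnist_polarity(paper)
print(fixed[0, 0], fixed[10, 13])  # 0 255
```

In streamlit_app.py, such a helper would be called right after image = np.array(image), before the payload is built. The mean-brightness test works for images with mostly blank backgrounds; a user-facing checkbox is a more explicit alternative.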