# Serving an ML Model with FastAPI

This practical session focuses on creating a REST API to expose the Iris classification model.

**Objectives**:
- Define a strict data schema using Pydantic.
- Implement a prediction endpoint.
- Understand synchronous vs asynchronous execution for ML models.

In [None]:
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import threading

## 1. Prepare the Model
For this exercise, we train a small model directly in the notebook. In production, you would instead load a previously trained model from disk (e.g. a `.pkl` file) when the API starts.

In [None]:
# Load data and train a simple model
iris = load_iris()
class_names = iris.target_names  # ['setosa', 'versicolor', 'virginica']

model = RandomForestClassifier(n_estimators=10, random_state=42)
model.fit(iris.data, iris.target)
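
To illustrate the "load a saved model" path mentioned above, here is a minimal sketch that round-trips a model through `pickle`. The file name `iris_model.pkl` and the temporary directory are assumptions made for this example, and the model is retrained only to keep the block self-contained. Only unpickle files you trust.

```python
import pickle
import tempfile
from pathlib import Path

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
trained = RandomForestClassifier(n_estimators=10, random_state=42)
trained.fit(iris.data, iris.target)

# Hypothetical location; a real service would use a fixed, versioned path.
model_path = Path(tempfile.mkdtemp()) / "iris_model.pkl"

# Save once, e.g. at the end of training.
with model_path.open("wb") as f:
    pickle.dump(trained, f)

# Load at API startup instead of retraining.
with model_path.open("rb") as f:
    restored = pickle.load(f)

# The restored model should behave exactly like the original one.
same_predictions = bool(
    (restored.predict(iris.data) == trained.predict(iris.data)).all()
)
print(same_predictions)
```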

## 2. Input Data Structure (Pydantic)
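We define the input schema so that FastAPI can validate each request: all four numeric measurements are required, and a request with a missing or non-numeric field is rejected with a 422 error before it reaches the model.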

In [None]:
# TODO: Create a Pydantic model named 'IrisInput' with the 4 required fields:
# sepal_length, sepal_width, petal_length, petal_width (all floats)

class IrisInput(BaseModel):
    # TODO: Define fields here
    pass
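
As a hedged illustration of how Pydantic validation behaves, the sketch below uses a hypothetical two-field model, `PetalSize`, rather than the `IrisInput` you are asked to write. It only shows the pattern: annotated fields are required, and invalid input raises `ValidationError`.

```python
from pydantic import BaseModel, ValidationError


class PetalSize(BaseModel):
    # Hypothetical example model, not part of the exercise solution.
    length: float
    width: float


# Valid input: both fields present and numeric.
ok = PetalSize(length=1.4, width=0.2)
print(ok.length, ok.width)

# Invalid input: 'width' is missing, so validation fails.
try:
    PetalSize(length=1.4)
    missing_rejected = False
except ValidationError as exc:
    missing_rejected = True
    print(len(exc.errors()), "validation error(s)")

# Invalid input: a non-numeric string cannot be parsed as a float.
try:
    PetalSize(length="long", width=0.2)
    bad_type_rejected = False
except ValidationError:
    bad_type_rejected = True
```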

## 3. Create FastAPI App and Predict Route

**Synchronous execution (`def`) vs Asynchronous (`async def`)**
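- Scikit-learn prediction is CPU-bound and blocking: it does not yield control while it runs.
- Inside an `async def` endpoint, calling it directly would block the event loop, so the server could not handle any other request until the prediction finishes.
- With a standard `def` endpoint, FastAPI runs the function in a worker thread from its thread pool, so the event loop stays free to accept other requests.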
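
The effect can be observed without FastAPI. The sketch below is a simplified stand-in, not how FastAPI is implemented: `slow_predict` is a hypothetical function that uses `time.sleep` to imitate a blocking `model.predict()` call, and a background task counts how often the event loop gets to run while it executes. The script calls `asyncio.run`, so run it as a standalone file; inside a notebook, use `await measure(...)` instead.

```python
import asyncio
import time


def slow_predict() -> str:
    """Hypothetical stand-in for a blocking model.predict() call."""
    time.sleep(0.3)
    return "setosa"


async def count_ticks(stop: asyncio.Event) -> int:
    """Count how many times the event loop wakes this task up."""
    ticks = 0
    while not stop.is_set():
        await asyncio.sleep(0.02)
        ticks += 1
    return ticks


async def measure(run_in_thread: bool) -> int:
    stop = asyncio.Event()
    ticker = asyncio.create_task(count_ticks(stop))
    await asyncio.sleep(0)  # give the ticker a chance to start
    if run_in_thread:
        # Roughly what happens for a plain `def` endpoint: the work runs
        # in a worker thread and the loop keeps going.
        await asyncio.to_thread(slow_predict)
    else:
        # What happens when blocking code is called inside `async def`:
        # nothing else on the loop can run until it returns.
        slow_predict()
    stop.set()
    return await ticker


ticks_blocked = asyncio.run(measure(run_in_thread=False))
ticks_threaded = asyncio.run(measure(run_in_thread=True))
print("ticks while blocking the loop:", ticks_blocked)
print("ticks while using a thread:  ", ticks_threaded)
```

The second count should be clearly higher than the first, because the ticker keeps running only when the blocking call is moved off the event loop.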

In [None]:
app = FastAPI(title="Iris Model API")

# TODO: Create a POST endpoint at "/predict"
# TODO: The function should take 'iris_data: IrisInput' as argument
# TODO: Inside the function:
#   1. Extract data from iris_data into a list or numpy array
#   2. Use model.predict() to get the class index (0, 1, or 2)
#   3. Convert the index to the string label (using class_names)
#   4. Return a dictionary with the prediction and the species name

@app.post("/predict")
def predict_species(iris_data: IrisInput):
    # TODO: Implement prediction logic
    return {"message": "TODO: not implemented yet"}
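
The four steps listed in the TODO can be tried outside the endpoint first. The sketch below is one possible approach, not the reference solution: the helper name `predict_label` is hypothetical, and the model is retrained here only to keep the block self-contained.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
class_names = iris.target_names
model = RandomForestClassifier(n_estimators=10, random_state=42)
model.fit(iris.data, iris.target)


def predict_label(sepal_length, sepal_width, petal_length, petal_width):
    """Hypothetical helper: return (class index, species name) for one flower."""
    # scikit-learn expects a 2D array of shape (n_samples, n_features).
    features = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
    # predict() returns an array; take its single element as a plain int.
    class_index = int(model.predict(features)[0])
    # Map the index to its label and convert it to a plain str.
    species = str(class_names[class_index])
    return class_index, species


result = predict_label(5.1, 3.5, 1.4, 0.2)
print(result)
```

Converting to built-in `int` and `str` matters when the values are returned from an endpoint, because NumPy integer types such as `numpy.int64` are not serializable by the standard `json` module.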

## 4. Run the Server

In [None]:
def run_server():
    uvicorn.run(app, host="127.0.0.1", port=8001)

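# Run in a background daemon thread so the notebook stays responsive.
# The thread (and the server) stops when the kernel shuts down.
threading.Thread(target=run_server, daemon=True).start()
print("Server starting at http://127.0.0.1:8001")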
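
Because the server starts in a background thread, it may not be accepting connections yet when the cell returns. One option, sketched below with only the standard library, is to poll the port until it answers. The helper name `wait_for_port` and its default timeout are assumptions made for this example.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 5.0) -> bool:
    """Return True once host:port accepts TCP connections, False on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # A successful connection means something is listening.
            with socket.create_connection((host, port), timeout=0.5):
                return True
        except OSError:
            time.sleep(0.1)  # not listening yet; retry shortly
    return False
```

In this notebook, calling `wait_for_port("127.0.0.1", 8001)` before the first request could stand in for a fixed `time.sleep(2)`.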

## 5. Test the API
We can send a request directly from Python to test our running server.

In [None]:
import time
import requests

# Wait a bit for server to start
time.sleep(2)

# Test data (should be 'setosa')
payload = {
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2
}

try:
    # The timeout keeps the cell from hanging if the server is unreachable
    response = requests.post("http://127.0.0.1:8001/predict", json=payload, timeout=5)
    print("Status Code:", response.status_code)
    print("Response:", response.json())
except Exception as e:
    print("Error connecting to API:", e)