# Solution: Serving an ML Model with FastAPI

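This notebook contains the solution for the Iris API practical exercise (TP): it trains a small classifier on the Iris dataset, exposes it through a FastAPI `/predict` route, and sends a test request to the running server.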

In [None]:
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
import threading

: 

## 1. Prepare the Model

In [None]:
# Load the Iris dataset and train a simple model on all of it.
iris = load_iris()
# Label lookup table used to turn a predicted class index into a species name.
class_names = iris.target_names  # ['setosa', 'versicolor', 'virginica']

print(f"Training model on {len(iris.data)} samples...")
# random_state is fixed so re-running the cell rebuilds the same forest.
model = RandomForestClassifier(n_estimators=10, random_state=42)
model.fit(iris.data, iris.target)
# The status emoji was stored as mojibake (UTF-8 bytes decoded as cp1252); restored here.
print("✅ Model ready.")

## 2. Define Input Data Structure (Pydantic)

In [None]:
class IrisInput(BaseModel):
    """Request body accepted by the `/predict` route: four flower measurements."""

    # No defaults are declared, so every field must be supplied by the client.
    # Units are not stated in this notebook — presumably centimetres, as in
    # scikit-learn's Iris dataset; confirm before documenting them publicly.
    sepal_length: float
    sepal_width: float
    petal_length: float
    petal_width: float

## 3. Create FastAPI App and Predict Route

In [None]:
app = FastAPI(title="Iris Model API")


@app.post("/predict")
def predict_species(iris_data: IrisInput):
    """Predict the species for a single set of flower measurements.

    Args:
        iris_data: Validated request body carrying the four measurements.

    Returns:
        A dict with ``prediction_class`` (the predicted class index as a plain
        ``int``) and ``species`` (the matching entry of ``class_names``).
    """
    # scikit-learn expects a 2-D batch, so wrap the single sample in an outer
    # list. The feature order matches the order the fields are declared in.
    sample_batch = [[
        iris_data.sepal_length,
        iris_data.sepal_width,
        iris_data.petal_length,
        iris_data.petal_width,
    ]]

    # One row in, one prediction out: take the only element.
    predicted_idx = model.predict(sample_batch)[0]

    return {
        # Cast to a built-in int so the value serialises cleanly to JSON.
        "prediction_class": int(predicted_idx),
        "species": class_names[predicted_idx],
    }

## 4. Run the Server

In [None]:
# Name given to the background thread so a re-run of this cell can detect it.
SERVER_THREAD_NAME = "iris-api-server"


def run_server():
    """Serve `app` with uvicorn on 127.0.0.1:8001; blocks until the server stops."""
    uvicorn.run(app, host="127.0.0.1", port=8001)


# Run in a separate daemon thread so the server does not block the notebook.
# Re-running this cell must not start a second server: it would fail to bind
# the already-used port, so only start one when no live server thread exists.
server_already_running = any(
    thread.name == SERVER_THREAD_NAME and thread.is_alive()
    for thread in threading.enumerate()
)
if not server_already_running:
    threading.Thread(target=run_server, name=SERVER_THREAD_NAME, daemon=True).start()

# The status emoji was stored as mojibake (UTF-8 bytes decoded as cp1252); restored here.
print("🚀 API server running at http://127.0.0.1:8001")

## 5. Test the API

In [None]:
import time
import requests

PREDICT_URL = "http://127.0.0.1:8001/predict"
STARTUP_TIMEOUT_S = 10   # how long to keep retrying while the server boots
REQUEST_TIMEOUT_S = 5    # per-request timeout so a stuck server cannot hang the cell
RETRY_INTERVAL_S = 0.2

# Test data (should be 'setosa')
payload = {
    "sepal_length": 5.1,
    "sepal_width": 3.5,
    "petal_length": 1.4,
    "petal_width": 0.2
}

# The server starts in a background thread, so instead of sleeping a fixed
# amount of time, retry refused connections until the start-up deadline.
response = None
deadline = time.monotonic() + STARTUP_TIMEOUT_S
while response is None:
    try:
        response = requests.post(PREDICT_URL, json=payload, timeout=REQUEST_TIMEOUT_S)
    except requests.exceptions.ConnectionError as e:
        if time.monotonic() >= deadline:
            print("Error connecting to API:", e)
            break
        time.sleep(RETRY_INTERVAL_S)
    except requests.exceptions.RequestException as e:
        # Any other transport-level failure (e.g. a read timeout) is not retried.
        print("Error connecting to API:", e)
        break

if response is not None:
    print("Status Code:", response.status_code)
    try:
        print("Response:", response.json())
    except ValueError:
        # Body was not JSON (e.g. a plain-text error page); show it verbatim
        # rather than misreporting it as a connection problem.
        print("Response:", response.text)