In [None]:
import os
import torch
import pandas as pd
from PIL import Image
from torchvision import transforms
from fastapi import FastAPI, UploadFile
from fastapi.responses import JSONResponse

In [54]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_FILE = os.path.join("..", "models", "emotion_cnn_best.pt")
DATA_PROCESSED_PATH = os.path.join("..", "data", "processed", "full_processed_dataset.csv")
IMAGE_SIZE = 48

In [58]:
df = pd.read_csv(DATA_PROCESSED_PATH)
label_mapping = dict(zip(df["label_encoded"], df["label"]))
num_classes = df["label_encoded"].nunique()

In [60]:
model = torch.load(MODEL_FILE, map_location=DEVICE)

In [61]:
preprocess = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor()
])

In [62]:
def predict_emotion(image_path):
    image = Image.open(image_path).convert("L")
    x = preprocess(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        outputs = model(x)
        _, pred = torch.max(outputs, 1)
    return label_mapping[pred.item()]

In [63]:
def predict_emotion_bytes(file_bytes):
    image = Image.open(file_bytes).convert("L")
    x = preprocess(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        outputs = model(x)
        _, pred = torch.max(outputs, 1)
    return label_mapping[pred.item()]

app = FastAPI()

In [64]:
@app.post("/predict_file")
async def predict_file(file: UploadFile):
    result = predict_emotion_bytes(file.file)
    return JSONResponse({"prediction": result})


In [65]:
@app.get("/predict_path")
def predict_path(path: str):
    if os.path.exists(path):
        result = predict_emotion(path)
        return JSONResponse({"prediction": result})
    return JSONResponse({"error": "File not found"}, status_code=404)

In [None]:
import nest_asyncio
import uvicorn

nest_asyncio.apply()
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)


INFO:     Will watch for changes in these directories: ['c:\\Users\\jobet\\emotion_detection\\notebooks']
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     Started reloader process [15048] using StatReload
