-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
58 lines (45 loc) · 1.65 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from CNNClassifier.utils.common import decodeImage
from CNNClassifier.pipeline.prediction import PredictionPipeline
import os
from pathlib import Path
# Force a UTF-8 locale for this process and for any child process it spawns.
# Assigning through os.environ (rather than calling os.putenv directly) keeps
# Python's own view of the environment in sync: os.putenv changes the C-level
# environment only, so os.environ / os.getenv would not reflect the values.
os.environ["LANG"] = "en_US.UTF-8"
os.environ["LC_ALL"] = "en_US.UTF-8"
# The ASGI application and the Jinja2 template loader for the "templates" directory.
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# CORS middleware to handle cross-origin requests: every origin, method and
# header is accepted, and credentials are allowed.
# NOTE(review): you might want to limit allow_origins to specific origins;
# FastAPI's CORS documentation advises against combining wildcard values with
# allow_credentials=True -- confirm against the installed version.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
class ClientApp:
    """Pair an image file name with the prediction pipeline built for it.

    Both attributes are assigned in ``__init__``:

    * ``filename``   -- the file name handed to ``PredictionPipeline``.
    * ``classifier`` -- the ``PredictionPipeline`` created for ``filename``.
    """

    def __init__(self, filename: str = "inputImage.jpg"):
        """Create the pipeline for *filename*.

        Args:
            filename: File name passed to ``PredictionPipeline``. Defaults to
                ``"inputImage.jpg"``, the value that used to be hard-coded,
                so ``ClientApp()`` behaves exactly as before.
        """
        self.filename = filename
        self.classifier = PredictionPipeline(filename)
# Module-level ClientApp instance, created once when this module is imported.
# predict_route reads its `filename` and calls its `classifier`.
clApp = ClientApp()
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the ``index.html`` template for the root URL.

    Args:
        request: The incoming request; it is placed in the template context
            under the ``"request"`` key.

    Returns:
        The response produced by ``templates.TemplateResponse``.
    """
    context = {"request": request}
    page = templates.TemplateResponse("index.html", context)
    return page
@app.get("/train")
async def train_route():
    """Run ``dvc repro`` and report whether it succeeded.

    The command is executed in a worker thread so the event loop is not
    blocked while it runs, and its exit status is checked instead of being
    discarded.

    Returns:
        ``{"message": "Training done successfully!"}`` when the command exits
        with status 0.

    Raises:
        HTTPException: 500 if ``dvc`` cannot be started or exits with a
            non-zero status.
    """
    # Imported here so the handler is self-contained with respect to the
    # module's import block.
    import asyncio
    import functools
    import subprocess

    # A direct `python main.py` call was previously used here in place of DVC.
    command = ["dvc", "repro"]  # argv list: no shell is involved

    loop = asyncio.get_running_loop()
    run_command = functools.partial(subprocess.run, command, check=False)
    try:
        completed = await loop.run_in_executor(None, run_command)
    except OSError as exc:
        # Raised e.g. when the `dvc` executable is not on PATH.
        raise HTTPException(
            status_code=500, detail=f"Could not start training: {exc}"
        ) from exc

    if completed.returncode != 0:
        raise HTTPException(
            status_code=500,
            detail=f"Training failed with exit code {completed.returncode}",
        )
    return {"message": "Training done successfully!"}
@app.post("/predict")
async def predict_route(request: Request):
    """Run the classifier on the image supplied in the JSON request body.

    The body must be a JSON object with an ``image`` entry. That value is
    passed to ``decodeImage`` together with ``clApp.filename``; afterwards
    ``clApp.classifier.predict()`` is called and its result is returned as a
    JSON response.

    Args:
        request: The incoming request whose body carries the JSON payload.

    Returns:
        A ``JSONResponse`` wrapping the value returned by ``predict()``.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or does not contain a non-empty ``image`` entry.
    """
    try:
        payload = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses;
        # a malformed body is a client error, not a server fault.
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from exc

    # Valid JSON may still be a list, string or number, none of which has .get().
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    image = payload.get("image")
    if not image:  # missing, null or empty
        raise HTTPException(status_code=400, detail="Image not provided")

    # NOTE(review): presumably decodeImage writes the decoded image to
    # clApp.filename, which the classifier then reads -- confirm in
    # CNNClassifier.utils.common.
    decodeImage(image, clApp.filename)
    result = clApp.classifier.predict()
    return JSONResponse(content=result)
if __name__ == "__main__":
    # Script entry point: serve the `app` object of module `app` on
    # localhost:8080 with uvicorn's reload option switched on.
    import uvicorn

    uvicorn.run("app:app", host="localhost", port=8080, reload=True)