In [3]:
from fastapi import FastAPI, HTTPException

In [6]:
from pydantic import BaseModel, Field, field_validator

In [8]:
from typing import Optional
import mlflow.pyfunc
import pandas as pd

In [9]:
class DonationPredictionRequest(BaseModel):
    """Input features for a single donation-count prediction.

    ``lag1`` .. ``lag7`` are the donation counts recorded 1-7 days ago; each
    must lie in ``[0, 7500)``.  The four holiday fields are 0/1 indicator
    flags describing the next day; each defaults to 0 and, being ``Optional``,
    may also be sent as ``null``.
    """

    lag1: int = Field(..., ge=0, lt=7500, description="No. of donations 1 day ago.")
    lag2: int = Field(..., ge=0, lt=7500, description="No. of donations 2 days ago.")
    lag3: int = Field(..., ge=0, lt=7500, description="No. of donations 3 days ago.")
    lag4: int = Field(..., ge=0, lt=7500, description="No. of donations 4 days ago.")
    lag5: int = Field(..., ge=0, lt=7500, description="No. of donations 5 days ago.")
    lag6: int = Field(..., ge=0, lt=7500, description="No. of donations 6 days ago.")
    lag7: int = Field(..., ge=0, lt=7500, description="No. of donations 7 days ago.")
    # The holiday fields are binary indicators, so reject anything but 0/1.
    high_donation_holiday: Optional[int] = Field(
        0, ge=0, le=1,
        description="1 if next day is Hari Malaysia, Hari Pekerja, or Hari Wesak.",
    )
    low_donation_holiday: Optional[int] = Field(
        0, ge=0, le=1,
        description="1 if next day is Hari Raya Puasa or Hari Raya Qurban.",
    )
    religion_or_culture_holiday: Optional[int] = Field(
        0, ge=0, le=1,
        description="1 if next day is a religious or cultural holiday.",
    )
    other_holiday: Optional[int] = Field(
        0, ge=0, le=1,
        description="1 if next day is any other national public holiday.",
    )

In [None]:
from abc import ABC
import numpy as np

In [None]:
def predict(request: DonationPredictionRequest):
    
    
    #TODO: Load scalers from pickle files
    
    pred_dict = request.dict()
    
    lag1, lag2, lag3, lag4, lag5, lag6, lag7, \
        high_donation_holiday, low_donation_holiday, religion_or_culture_holiday, other_holiday = pred_dict.values()
        
    scaler_y = ABC()
    scaler_x = ABC()
    
    x_seq = scaler_y.transform(np.ndarray(lag1, lag2, lag3, lag4, lag5, lag6, lag7))
    x_feat = scaler_x.transform(np.ndarray(high_donation_holiday, low_donation_holiday, religion_or_culture_holiday, other_holiday))
    
    #TODO: Load model 
    model = ABC()

    return scaler_y.inverse_transform(model.predict([x_seq, x_feat]).reshape(-1, 1)) 