In [None]:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import pandas as pd
import numpy as np
from tensorflow.keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
import pickle
import os
import nest_asyncio
# Patch asyncio so uvicorn.run() can start inside an event loop that is already
# running (presumably because this code runs as a Jupyter notebook cell).
nest_asyncio.apply()

app = FastAPI()

# Data directory, resolved two levels above the current working directory.
# NOTE(review): this depends on where the process is launched from, not on the
# location of this file — confirm the expected launch directory.
WORKSPACE = os.path.abspath(os.path.join(os.getcwd(), '../../'))
DATA_FOLDER = os.path.join(WORKSPACE, 'data')

# Shared state, populated by the startup handler.
model = station_dict = data = scaler = None

@app.on_event("startup")
async def load_model_and_data():
    """Load the model and supporting data once, when the application starts.

    Populates the module-level ``model``, ``data``, ``station_dict`` and
    ``scaler`` globals from files in ``DATA_FOLDER``.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of lifespan handlers (the notebook output shows the warning).
    Migrating means passing ``lifespan=`` to ``FastAPI(...)`` where ``app`` is
    created.
    """
    global model, station_dict, data, scaler

    def _path(filename):
        # Same "<DATA_FOLDER>/<filename>" form the paths have always used.
        return f'{DATA_FOLDER}/{filename}'

    # compile=False: the model is only used for predict(), which does not need
    # the saved training configuration.
    model = load_model(_path('best_model_dnn.keras'), compile=False)

    data = pd.read_csv(_path('data_modelo.csv'))

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only deploy a station_dict.pkl produced by a trusted pipeline.
    with open(_path('station_dict.pkl'), 'rb') as handle:
        station_dict = pickle.load(handle)

    scaler = MinMaxScaler()

def obtener_predicciones_por_estacion(station, n_ultimos=1000):
    """Return model predictions alongside actual target values for one station.

    Reads the module-level ``data``, ``station_dict`` and ``model`` loaded at
    startup.

    Parameters
    ----------
    station
        Station code, matched against the ``CodigoEstacion`` column of
        ``data`` and used as a key of ``station_dict``.
    n_ultimos : int, optional
        Number of trailing rows (in the order they appear in ``data``) of the
        station's history to predict on. Defaults to 1000.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``Fecha``, with columns ``Prediccion`` (model output mapped
        back through a min-max scaler fitted on the window's
        ``PTPM_CON_shifted`` values) and ``Real`` (those ``PTPM_CON_shifted``
        values).

    Raises
    ------
    KeyError
        If ``station`` has no entry in ``station_dict``.
    ValueError
        If ``n_ultimos`` is not positive or ``data`` has no rows for
        ``station``.
    """
    if n_ultimos <= 0:
        raise ValueError(f"n_ultimos debe ser positivo, se recibió {n_ultimos}")
    if station not in station_dict:
        raise KeyError(f"Estación {station} no encontrada")

    # Rows for this station, indexed by date, keeping only the last n_ultimos.
    # NOTE(review): "last" means last in file order; nothing here sorts by
    # Fecha, so this assumes data_modelo.csv is already chronological.
    data_station = (
        data.loc[data['CodigoEstacion'] == station]
        .set_index('Fecha')
        .tail(n_ultimos)
    )
    if data_station.empty:
        raise ValueError(f"No hay datos para la estación {station}")

    y = data_station['PTPM_CON_shifted']
    X_features = data_station.drop(columns=['CodigoEstacion', 'PTPM_CON_shifted'])

    # Prepend the station's one-hot encoding to every feature row.
    onehot_encoded_station = station_dict[station]
    X = np.hstack([np.tile(onehot_encoded_station, (len(X_features), 1)), X_features.values])

    # NOTE(review): a fresh scaler is fitted on this inference window rather
    # than reusing the scaler from training, so feature scaling depends on
    # which rows are requested. Kept as-is because the training scaler is not
    # available here — confirm against the training code.
    X = MinMaxScaler().fit_transform(X)

    predictions = model.predict(X)

    # Map predictions back to the target's range. A request-local scaler is
    # used instead of re-fitting the shared global ``scaler``: FastAPI runs
    # sync endpoints in a thread pool, so concurrent requests could overwrite
    # each other's fitted min/max between fit() and inverse_transform().
    target_scaler = MinMaxScaler().fit(y.to_frame())
    pred_values = target_scaler.inverse_transform(predictions.reshape(-1, 1)).ravel()

    return pd.DataFrame(
        {'Prediccion': pred_values, 'Real': y.values},
        index=data_station.index,
    )

@app.post("/prediccion/{estacion_id}")
def prediccion(estacion_id: int):
    """Return predictions and actual values for one station, keyed by date.

    The response maps each ``Fecha`` to its row of the prediction table
    (``Prediccion`` and ``Real``). An unknown station gives HTTP 404. Any other
    failure while computing the predictions gives HTTP 500, with the exception
    message as the detail.
    """
    # Checked outside the try block: an HTTPException raised inside it would be
    # caught by the broad handler below and rewrapped as a 500.
    if station_dict is not None and estacion_id not in station_dict:
        raise HTTPException(status_code=404, detail=f"Estación {estacion_id} no encontrada")

    try:
        results = obtener_predicciones_por_estacion(estacion_id)
        # Starlette's JSONResponse serializes with allow_nan=False, so a NaN
        # anywhere in the table (e.g. a missing target value, if the data
        # contains one) would crash after this function returns. Send null.
        results = results.astype(object).where(results.notna(), None)
        return results.to_dict(orient='index')
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

# CORS configuration: accept cross-origin requests from any origin, with any
# method and any header.
from fastapi.middleware.cors import CORSMiddleware

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is not
# a valid pair under the CORS spec. Starlette's CORSMiddleware then reflects the
# caller's Origin instead of sending "*", so credentialed requests are accepted
# from any site. No endpoint here reads cookies or auth headers. Consider
# allow_credentials=False or an explicit origin list — verify against clients.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Health-check endpoint.
@app.get("/")
def read_root():
    """Report that the prediction API is running."""
    status = {"message": "API de predicciones activa"}
    return status

if __name__ == "__main__":
    # Serve the app directly on all interfaces, port 1304. Inside a notebook
    # this relies on the nest_asyncio.apply() call near the top of the file.
    import uvicorn
    uvicorn.run(app, port=1304, host="0.0.0.0")

2024-12-19 19:55:50.005119: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1734638150.018322    1637 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734638150.022252    1637 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-19 19:55:50.037136: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
        on_event is deprecated, use lifespan event handlers instead.

        Read more about it in the
        [FastAPI docs