In [None]:
pip install python-multipart


In [None]:


from fastapi import FastAPI, HTTPException, Query, File, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import pandas as pd
import numpy as np
import io
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler
import joblib
import os
import logging
from typing import List, Optional, Dict, Any

# --- Config ---
# CSV that load_data() reads and that /upload_csv overwrites.
DATA_PATH = '/mnt/data/cleaned_solar_data.csv'
# Directory for the persisted model; created at import time so /train can write to it.
MODEL_DIR = '/tmp/solar_models'
os.makedirs(MODEL_DIR, exist_ok=True)
# Single joblib artefact written by /train and read back by /predict.
MODEL_PATH = os.path.join(MODEL_DIR, 'solar_model.joblib')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('solar_backend')

# NOTE: the name `app` is rebound further down in this file to a second FastAPI
# instance (the one created with a lifespan handler). All routes are declared
# after that point, so they attach to the later instance rather than this one,
# and middleware registered on this instance does not carry over to it.
app = FastAPI(title='Solar Data Backend', version='1.1')
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS policy — confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory dataset shared by all endpoints; None until load_data() succeeds.
_df: Optional[pd.DataFrame] = None
# Names of the columns detect_and_parse_dates() accepted as datetime columns.
_date_cols: List[str] = []

# --- Utility functions ---

def detect_and_parse_dates(df: pd.DataFrame) -> pd.DataFrame:
    """Detect date-like columns, parse them to datetime and record their names.

    A column is accepted when it already has a datetime dtype (timezone-aware
    or not), or when it is a non-numeric column that either holds text or has
    'date'/'time' in its name and more than 60% of its rows parse to a valid
    timestamp.

    Numeric columns are never converted: ``pd.to_datetime`` reads numbers as
    nanoseconds since the epoch, so a measurement column such as ``uptime``
    would otherwise be silently replaced by 1970 timestamps.

    Args:
        df: Frame to inspect. Accepted columns are converted in place.

    Returns:
        The same ``df`` object, with accepted columns converted to datetime.

    Side effects:
        Rebinds the module-level ``_date_cols`` to the list of accepted columns.
    """
    global _date_cols
    log = logging.getLogger('solar_backend')
    detected = []
    for col in df.columns:
        series = df[col]
        # pandas' dtype helpers also understand extension dtypes (tz-aware
        # datetimes, string dtype), which np.issubdtype cannot interpret.
        if pd.api.types.is_datetime64_any_dtype(series):
            detected.append(col)
            continue
        if pd.api.types.is_numeric_dtype(series):
            continue

        name = str(col).lower()
        looks_textual = (
            pd.api.types.is_object_dtype(series)
            or pd.api.types.is_string_dtype(series)
        )
        if not ('date' in name or 'time' in name or looks_textual):
            continue

        try:
            parsed = pd.to_datetime(series, errors='coerce')
        except Exception as exc:
            # Detection is best effort: an unparseable column simply stays as it is.
            log.debug('Column %r could not be parsed as dates: %s', col, exc)
            continue

        # Accept only when most rows survived parsing.
        if parsed.notna().mean() > 0.6:
            df[col] = parsed
            detected.append(col)
    _date_cols = detected
    return df


def ensure_df():
    """Abort the current request with HTTP 500 when no dataset is in memory."""
    if _df is not None:
        return
    message = f'Data not loaded. Expected CSV at {DATA_PATH}'
    raise HTTPException(status_code=500, detail=message)

from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the dataset once when the application starts.

    Nothing runs after the ``yield``; there is no shutdown work.
    """
    load_data()   # call your existing loader
    yield

# This rebinds ``app``: every route declared below registers on THIS instance,
# not on the one created earlier in the file. Middleware added to the earlier
# instance is therefore lost, so the CORS policy has to be attached here for
# browser clients to be able to call the API.
app = FastAPI(title='Solar Data Backend', version='1.1', lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def load_data():
    """Populate the in-memory dataset from ``DATA_PATH``.

    On success ``_df`` holds the parsed frame. When the file is missing or
    cannot be read, ``_df`` is reset to ``None`` and the problem is logged;
    no exception propagates to the caller.
    """
    global _df
    if not os.path.exists(DATA_PATH):
        _df = None
        logger.warning('Data file not found at %s', DATA_PATH)
        return

    try:
        # Dates are detected heuristically afterwards instead of forcing parse_dates.
        raw = pd.read_csv(DATA_PATH)
        _df = detect_and_parse_dates(raw)
        logger.info(f'Loaded dataframe with {len(_df)} rows and {len(_df.columns)} columns')
    except Exception:
        _df = None
        logger.exception('Failed to load CSV')

class TrainRequest(BaseModel):
    # Request body for POST /train.
    #
    # feature_columns: dataset columns used as model inputs.
    # target_column:   dataset column the model is fitted to predict.
    # test_size:       forwarded to train_test_split as ``test_size``.
    # random_state:    seed forwarded to train_test_split and RandomForestRegressor.
    # model_type:      'linear' or 'random_forest'; /train rejects anything else.
    # scale:           when true, /train standardises features with StandardScaler.
    feature_columns: List[str]
    target_column: str
    test_size: float = 0.2
    random_state: int = 42
    model_type: Optional[str] = 'linear'
    scale: Optional[bool] = False

    # NOTE(review): presumably clears pydantic's protected "model_" namespace so
    # the `model_type` field does not trigger a warning — confirm against the
    # installed pydantic version.
    model_config = {
        "protected_namespaces": ()
    }


class PredictRequest(BaseModel):
    # Request body for POST /predict: maps feature name -> numeric value.
    # /predict requires every feature stored with the trained model to be present.
    features: Dict[str, float]

# --- Endpoints ---

@app.get('/health')
def health():
    """Liveness probe that also reports whether a dataset is currently in memory."""
    data_loaded = _df is not None
    return {'status': 'ok', 'data_loaded': data_loaded}

@app.get('/features')
def features():
    """List every column with its dtype, plus the row count and detected date columns."""
    ensure_df()
    columns = []
    for name in _df.columns:
        columns.append({'name': name, 'dtype': str(_df[name].dtype)})
    return {'n_rows': len(_df), 'columns': columns, 'date_columns': _date_cols}

@app.get('/info')
def info(sample: int = 3):
    """Return the dataset's shape and its first ``sample`` rows.

    Args:
        sample: Number of leading rows to include; must be non-negative.

    Returns:
        Dict with ``n_rows``, ``n_cols`` and ``sample`` (a list of row dicts).

    Raises:
        HTTPException: 500 when no data is loaded, 400 when ``sample`` is negative.
    """
    ensure_df()
    if sample < 0:
        # head(-n) would return "all but the last n rows", i.e. almost the whole frame.
        raise HTTPException(status_code=400, detail='sample must be non-negative')
    head = _df.head(sample)
    # The JSON response encoder rejects NaN, so missing values (NaN/NaT) are
    # sent as null instead of failing the whole request.
    records = head.astype(object).where(head.notna(), None).to_dict(orient='records')
    return {'n_rows': len(_df), 'n_cols': len(_df.columns), 'sample': records}

@app.get('/data')
def get_data(columns: Optional[str] = Query(None, description='Comma separated column names'),
             offset: int = 0,
             limit: int = 100):
    """Return one page of rows, optionally restricted to selected columns.

    Args:
        columns: Comma separated column names; all columns when omitted.
            Names listed more than once are used once.
        offset: Index of the first row to return; must be non-negative.
        limit: Maximum number of rows to return; must be non-negative.

    Returns:
        Dict with ``total`` (row count of the dataset), the echoed ``offset``
        and ``limit``, and ``rows`` (a list of row dicts).

    Raises:
        HTTPException: 500 when no data is loaded; 400 for unknown columns or
            a negative ``offset``/``limit``.
    """
    ensure_df()
    if offset < 0 or limit < 0:
        # Negative values would be read as "count from the end" by iloc.
        raise HTTPException(status_code=400, detail='offset and limit must be non-negative')
    df = _df
    if columns:
        # dict.fromkeys drops repeated names while keeping the requested order.
        cols = list(dict.fromkeys(c.strip() for c in columns.split(',') if c.strip()))
        missing = [c for c in cols if c not in df.columns]
        if missing:
            raise HTTPException(status_code=400, detail=f'Missing columns: {missing}')
        df = df[cols]
    total = len(df)
    page = df.iloc[offset:offset + limit]
    # The JSON response encoder rejects NaN, so missing values (NaN/NaT) are
    # sent as null instead of failing the whole request.
    rows = page.astype(object).where(page.notna(), None).to_dict(orient='records')
    return {'total': total, 'offset': offset, 'limit': limit, 'rows': rows}

@app.get('/summary')
def summary(columns: Optional[str] = Query(None, description='Comma separated column names')):
    """Return ``describe()`` statistics for the numeric columns.

    Args:
        columns: Comma separated column names to restrict the summary to;
            all columns when omitted. Names listed more than once are used once.

    Returns:
        Dict mapping each numeric column to its statistics. Statistics that
        are undefined (for example the standard deviation of a single row)
        are returned as null.

    Raises:
        HTTPException: 500 when no data is loaded; 400 for unknown columns or
            when none of the selected columns is numeric.
    """
    ensure_df()
    df = _df
    if columns:
        # dict.fromkeys drops repeated names while keeping the requested order.
        cols = list(dict.fromkeys(c.strip() for c in columns.split(',') if c.strip()))
        missing = [c for c in cols if c not in df.columns]
        if missing:
            raise HTTPException(status_code=400, detail=f'Missing columns: {missing}')
        df = df[cols]
    numeric = df.select_dtypes(include=[np.number])
    if numeric.empty:
        raise HTTPException(status_code=400, detail='No numeric columns to summarize')
    stats = numeric.describe()
    # describe() yields NaN for undefined statistics, which the JSON response
    # encoder rejects; send them as null instead.
    return stats.astype(object).where(stats.notna(), None).to_dict()

@app.get('/timeseries')
def timeseries(date_col: str, value_col: str, freq: str = 'D', agg: str = 'mean',
               start: Optional[str] = None, end: Optional[str] = None):
    """Resample one numeric column over time and return the aggregated series.

    Args:
        date_col: Column holding the timestamps; parsed on the fly when it is
            not already a datetime column.
        value_col: Numeric column to aggregate.
        freq: pandas resampling frequency passed to ``resample``.
        agg: One of 'mean', 'sum' or 'median'.
        start: Optional inclusive lower bound, parsed with ``pd.to_datetime``.
        end: Optional inclusive upper bound, parsed with ``pd.to_datetime``.

    Returns:
        Dict with ``count``, ``freq``, ``agg`` and ``series`` (a list of
        ``{'timestamp': ISO string, 'value': float}`` items, empty bins dropped).

    Raises:
        HTTPException: 500 when no data is loaded; 400 for unknown or
            unusable columns, an unsupported ``agg``, unparseable or
            incompatible ``start``/``end`` values, or a failing resample.
    """
    ensure_df()
    if date_col not in _df.columns:
        raise HTTPException(status_code=400, detail=f'date_col {date_col} not in columns')
    if value_col not in _df.columns:
        raise HTTPException(status_code=400, detail=f'value_col {value_col} not in columns')
    if date_col == value_col:
        # The date column becomes the index below, so it cannot also be the value.
        raise HTTPException(status_code=400, detail='date_col and value_col must be different columns')

    allowed_aggs = {'mean': 'mean', 'sum': 'sum', 'median': 'median'}
    if agg not in allowed_aggs:
        raise HTTPException(status_code=400, detail=f'Unsupported agg {agg}. Use one of {list(allowed_aggs.keys())}')
    if not pd.api.types.is_numeric_dtype(_df[value_col]):
        raise HTTPException(status_code=400, detail=f'value_col {value_col} is not numeric')

    # Only the two columns involved are copied, not the whole dataset.
    df = _df[[date_col, value_col]].copy()
    # is_datetime64_any_dtype also recognises timezone-aware columns, which
    # np.issubdtype cannot interpret.
    if not pd.api.types.is_datetime64_any_dtype(df[date_col]):
        df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
    df = df.dropna(subset=[date_col])

    try:
        if start:
            df = df[df[date_col] >= pd.to_datetime(start)]
        if end:
            df = df[df[date_col] <= pd.to_datetime(end)]
    except (ValueError, TypeError) as exc:
        # ValueError: unparseable bound; TypeError: tz-aware vs. naive comparison.
        raise HTTPException(status_code=400, detail=f'Invalid start/end: {exc}') from exc
    df = df.set_index(date_col)

    try:
        res = getattr(df[value_col].resample(freq), allowed_aggs[agg])()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f'Resampling failed: {e}')
    res = res.dropna()
    out = [{'timestamp': idx.isoformat(), 'value': float(val)} for idx, val in res.items()]
    return {'count': len(out), 'freq': freq, 'agg': agg, 'series': out}

@app.get('/plot')
def plot_col(col: str, kind: str = 'line', max_points: int = 1000):
    """Render one column as a PNG chart.

    Args:
        col: Column to plot.
        kind: 'line', 'hist' or 'box'.
        max_points: Upper bound on plotted rows; larger inputs are sampled
            evenly. Must be at least 1.

    Returns:
        StreamingResponse carrying the PNG image.

    Raises:
        HTTPException: 500 when no data is loaded; 400 for an unknown column,
            an unsupported ``kind`` or a ``max_points`` below 1.
    """
    ensure_df()
    df = _df
    if col not in df.columns:
        raise HTTPException(status_code=400, detail=f'Column {col} not found')
    # Reject bad input before a figure exists, so nothing is left open.
    if kind not in ('line', 'hist', 'box'):
        raise HTTPException(status_code=400, detail='Unsupported kind')
    if max_points < 1:
        raise HTTPException(status_code=400, detail='max_points must be at least 1')

    # find x axis: prefer first detected date column
    x = None
    if _date_cols:
        x = df[_date_cols[0]]
    elif 'date' in df.columns:
        x = pd.to_datetime(df['date'], errors='coerce')
    elif pd.api.types.is_datetime64_any_dtype(df.index):
        x = df.index

    plot_df = pd.DataFrame({'y': df[col]})
    if x is not None:
        plot_df['x'] = pd.to_datetime(x, errors='coerce')
        plot_df = plot_df.dropna(subset=['y'])
        # Rows whose x failed to parse are kept; sort_values places them last.
        plot_df = plot_df.sort_values('x')
    else:
        plot_df = plot_df.dropna(subset=['y']).reset_index(drop=True)

    if len(plot_df) > max_points:
        # sample evenly to preserve trend
        idx = np.linspace(0, len(plot_df) - 1, max_points).astype(int)
        plot_df = plot_df.iloc[idx]

    # Draw through explicit Figure/Axes handles rather than pyplot's implicit
    # "current figure", and always close the figure, even when plotting raises.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        if kind == 'line':
            if 'x' in plot_df.columns:
                ax.plot(plot_df['x'], plot_df['y'])
            else:
                ax.plot(plot_df['y'].values)
        elif kind == 'hist':
            ax.hist(plot_df['y'].dropna().values, bins=50)
        else:  # 'box'
            ax.boxplot(plot_df['y'].dropna().values)
        ax.set_title(f'{kind} of {col}')
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    return StreamingResponse(buf, media_type='image/png')

@app.post('/train')
def train_model(req: TrainRequest):
    """Fit a regression model on the loaded dataset and persist it.

    Rows with a missing value in any feature or in the target are dropped.
    The fitted model, the feature order and (when scaling was requested) the
    fitted scaler are written to ``MODEL_PATH`` for /predict.

    Args:
        req: Training options; see ``TrainRequest``. Feature names listed
            more than once are used once.

    Returns:
        Dict with ``status``, test-set ``mse``, ``n_train``, ``n_test`` and
        ``model_path``.

    Raises:
        HTTPException: 500 when no data is loaded; 400 for unknown,
            non-numeric or conflicting columns, an unsupported ``model_type``,
            no usable rows, or a split/fit rejected by scikit-learn.
    """
    ensure_df()

    # dict.fromkeys drops repeated names while keeping the requested order.
    feature_columns = list(dict.fromkeys(req.feature_columns))
    if not feature_columns:
        raise HTTPException(status_code=400, detail='feature_columns must not be empty')

    # validate columns
    wanted = feature_columns + [req.target_column]
    missing = [c for c in wanted if c not in _df.columns]
    if missing:
        raise HTTPException(status_code=400, detail=f'Missing columns: {missing}')
    if req.target_column in feature_columns:
        raise HTTPException(status_code=400, detail='target_column must not be one of feature_columns')
    non_numeric = [c for c in wanted if not pd.api.types.is_numeric_dtype(_df[c])]
    if non_numeric:
        raise HTTPException(status_code=400, detail=f'Non-numeric columns cannot be used for training: {non_numeric}')

    # Validate the model choice before doing any expensive work.
    if req.model_type == 'linear':
        model = LinearRegression()
    elif req.model_type == 'random_forest':
        model = RandomForestRegressor(n_estimators=100, random_state=req.random_state)
    else:
        raise HTTPException(status_code=400, detail='Unsupported model_type; use "linear" or "random_forest"')

    # drop rows with NA in features/target
    data = _df[wanted].dropna()
    if data.empty:
        raise HTTPException(status_code=400, detail='No numeric training data after dropping NA')
    # Plain arrays match what /predict feeds the model: a single unnamed row.
    X = data[feature_columns].to_numpy(dtype=float)
    y = data[req.target_column].to_numpy(dtype=float)

    scaler = None
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=req.test_size, random_state=req.random_state)
        if req.scale:
            scaler = StandardScaler()
            X_train = scaler.fit_transform(X_train)
            X_test = scaler.transform(X_test)
        model.fit(X_train, y_train)
        preds = model.predict(X_test)
    except ValueError as exc:
        # scikit-learn reports an invalid test_size, too few rows or
        # non-finite values as ValueError; these are caller errors.
        raise HTTPException(status_code=400, detail=f'Training failed: {exc}') from exc
    mse = float(mean_squared_error(y_test, preds))

    saved = {'model': model, 'features': feature_columns, 'model_type': req.model_type, 'scale': req.scale}
    if scaler is not None:
        saved['scaler'] = scaler
    joblib.dump(saved, MODEL_PATH)
    logger.info('Saved model to %s', MODEL_PATH)

    return {'status': 'trained', 'mse': mse, 'n_train': len(X_train), 'n_test': len(X_test), 'model_path': MODEL_PATH}

@app.post('/predict')
def predict(req: PredictRequest):
    """Predict the target for one observation using the model stored by /train.

    Args:
        req: Feature values keyed by name; every feature stored with the
            model must be present. Extra keys are ignored.

    Returns:
        Dict with the float ``prediction`` and the ``features`` order used.

    Raises:
        HTTPException: 400 when no model has been trained, when features are
            missing or when a value is NaN/infinite; 500 when the stored
            model cannot be loaded.
    """
    if not os.path.exists(MODEL_PATH):
        raise HTTPException(status_code=400, detail='Model not trained. Call /train first.')
    # SECURITY NOTE: joblib.load unpickles arbitrary Python objects, and
    # MODEL_PATH lives under /tmp. Anyone able to write that file can execute
    # code in this process — keep the directory writable by this service only.
    try:
        saved = joblib.load(MODEL_PATH)
    except Exception as exc:
        logger.exception('Failed to load model from %s', MODEL_PATH)
        raise HTTPException(status_code=500, detail='Stored model could not be loaded. Call /train again.') from exc
    model = saved.get('model')
    features = saved.get('features')
    scaler = saved.get('scaler') if saved.get('scale') else None

    # Ensure all features are present
    missing = [f for f in features if f not in req.features]
    if missing:
        raise HTTPException(status_code=400, detail=f'Missing features in request: {missing}')

    # preserve the feature order used at training time
    values = [float(req.features[f]) for f in features]
    invalid = [f for f, v in zip(features, values) if not np.isfinite(v)]
    if invalid:
        # The estimators reject NaN/inf with an exception; report it as a client error.
        raise HTTPException(status_code=400, detail=f'Non-finite feature values: {invalid}')
    x_arr = np.array(values).reshape(1, -1)
    if scaler is not None:
        x_arr = scaler.transform(x_arr)
    pred = model.predict(x_arr)
    return {'prediction': float(pred[0]), 'features': features}

@app.post('/upload_csv')
async def upload_csv(file: UploadFile = File(...)):
    """Replace the dataset with an uploaded CSV file and reload it.

    The upload is parsed before anything on disk is touched and is then moved
    into place with a single rename, so a broken upload cannot destroy the
    dataset that is currently being served.

    Args:
        file: Uploaded file; its name must end in '.csv'.

    Returns:
        Dict with ``status`` and the number of ``rows`` now loaded.

    Raises:
        HTTPException: 400 when the file is not a CSV or cannot be parsed;
            500 when it cannot be stored or reloaded.
    """
    # filename may be absent on some uploads; treat that as "not a CSV".
    filename = file.filename or ''
    if not filename.lower().endswith('.csv'):
        raise HTTPException(status_code=400, detail='Only CSV supported')
    content = await file.read()

    # Validate first: load_data() logs and swallows parse errors, so a bad file
    # would otherwise overwrite good data and still be reported as uploaded.
    try:
        pd.read_csv(io.BytesIO(content))
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f'Uploaded file is not a readable CSV: {exc}') from exc

    tmp_path = f'{DATA_PATH}.tmp'
    try:
        os.makedirs(os.path.dirname(DATA_PATH) or '.', exist_ok=True)
        with open(tmp_path, 'wb') as f:
            f.write(content)
        # Rename over the old file so readers never see a half-written CSV.
        os.replace(tmp_path, DATA_PATH)
    except OSError as exc:
        raise HTTPException(status_code=500, detail=f'Failed to store uploaded file: {exc}') from exc

    load_data()
    if _df is None:
        raise HTTPException(status_code=500, detail='Failed to reload data: see server log')
    return {'status': 'uploaded', 'rows': len(_df)}

if __name__ == '__main__':
    # Development entry point: serve on all interfaces, port 8000, with auto-reload.
    # NOTE(review): the import string assumes this file is importable as the
    # module `fastapi_solar_backend` — confirm the file name matches.
    import uvicorn

    server_options = {'host': '0.0.0.0', 'port': 8000, 'reload': True}
    uvicorn.run('fastapi_solar_backend:app', **server_options)
