In [3]:
# Pharmacovigilance signal detection: openFDA (FAERS) adverse-event time series
# forecast with Prophet and a PyTorch Temporal Fusion Transformer.
# Jack Baxter, 10/2025

In [2]:
# Load packages
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
import pandas as pd
import requests
import plotly.graph_objects as go
from prophet import Prophet
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import QuantileLoss
import torch

  from .autonotebook import tqdm as notebook_tqdm
Matplotlib is building the font cache; this may take a moment.


In [4]:
# Build the FastAPI application object; the route handlers defined in later
# cells attach themselves to it through the @app.get decorators.
app = FastAPI(
    title="Pharmacovigilance Signal Detection API",
)

In [5]:
# Ask which drug to investigate. In the cells shown here `drugchoice` is not
# read again (the API routes take the drug from the URL path), so fall back to
# a default instead of crashing when the notebook runs without an interactive
# stdin (e.g. Restart & Run All under nbconvert/papermill).
DEFAULT_DRUG = 'semaglutide'
try:
    # Strip stray whitespace: the name ends up inside a quoted openFDA phrase.
    drugchoice = input('Please enter the name of the drug to investigate: ').strip()
except (EOFError, NotImplementedError):
    # EOFError: stdin is closed. NotImplementedError also covers IPython's
    # StdinNotImplementedError raised when the frontend cannot answer input().
    drugchoice = ''
drugchoice = drugchoice or DEFAULT_DRUG

In [6]:
def get_fda_data(drugchoice, start='20140101', end='20251231', timeout=30):
    """Fetch quarterly FAERS adverse-event report counts for one drug from openFDA.

    Queries the openFDA ``drug/event`` endpoint for reports whose
    ``patient.drug.medicinalproduct`` matches ``drugchoice`` and whose
    ``receivedate`` lies in ``[start, end]``. openFDA is asked to count reports
    per receive date, and the daily counts are then summed into calendar
    quarters.

    Parameters
    ----------
    drugchoice : str
        Medicinal product name to search for. Double quotes are removed and
        surrounding whitespace stripped so the value stays inside the quoted
        phrase of the query.
    start, end : str
        Inclusive receive-date window, formatted ``YYYYMMDD``.
    timeout : float
        Seconds to wait for openFDA before ``requests`` gives up.

    Returns
    -------
    pandas.DataFrame
        Columns ``ds`` (quarter-end timestamp) and ``y`` (report count), sorted
        by ``ds`` -- the layout Prophet expects. Quarters ending today or later
        are dropped because their counts are still incomplete. The frame is
        empty (but still has both columns) when openFDA finds no matching
        reports.

    Raises
    ------
    requests.HTTPError
        For any openFDA error other than the 404 it uses for "no matches"
        (e.g. rate limiting or server errors).
    """
    url = 'https://api.fda.gov/drug/event.json'
    # Remove quotes so a user-supplied name cannot terminate the phrase early.
    product = str(drugchoice).replace('"', '').strip()
    # Use plain spaces: requests encodes them as '+', which is what openFDA
    # expects between query terms. A literal '+' would be sent as '%2B'.
    search = (
        f'patient.drug.medicinalproduct:"{product}"'
        f' AND receivedate:[{start} TO {end}]'
    )
    response = requests.get(
        url,
        params={'search': search, 'count': 'receivedate'},
        timeout=timeout,
    )
    if response.status_code == 404:
        # openFDA reports "no matching records" as a 404 with an error body.
        results = []
    else:
        response.raise_for_status()
        results = response.json().get('results', [])

    if not results:
        return pd.DataFrame({
            'ds': pd.Series(dtype='datetime64[ns]'),
            'y': pd.Series(dtype='int64'),
        })

    raw = pd.DataFrame(results)
    # Counts over a date field are keyed by 'time'; accept 'term' as well.
    date_col = 'time' if 'time' in raw.columns else 'term'
    daily = pd.DataFrame({
        'ds': pd.to_datetime(raw[date_col].astype(str), format='%Y%m%d'),
        'y': raw['count'],
    }).sort_values('ds')

    # QuarterEnd() is the same bucketing as the 'Q' alias, but does not depend
    # on the alias spelling that newer pandas renamed to 'QE'.
    quarterly = daily.resample(pd.offsets.QuarterEnd(), on='ds').sum().reset_index()

    # A quarter that has not ended yet is still accumulating reports; keeping
    # it would show up as an artificial drop in the most recent period.
    is_complete = quarterly['ds'] < pd.Timestamp.today().normalize()
    return quarterly[is_complete].reset_index(drop=True)

In [7]:
@app.get("/", response_class=HTMLResponse)
async def home():
    """Serve the landing page, linking to the forecast route for a few example drugs."""
    examples = (
        ("semaglutide", "Semaglutide (Ozempic)"),
        ("trastuzumab", "Trastuzumab (Herceptin)"),
        ("rizatriptan", "Rizatriptan (Maxalt)"),
    )
    # One indented <li> line per example drug, each newline-terminated.
    links = "".join(
        f'        <li><a href="/ml/{slug}">{label}</a></li>\n'
        for slug, label in examples
    )
    return (
        "\n"
        "    <h1>AEGIS Time-Series ML</h1>\n"
        "    <p>Live OpenFDA → Prophet + Temporal Fusion Transformer</p>\n"
        "    <ul>\n"
        f"{links}"
        "    </ul>\n"
        "    "
    )

In [8]:
# Forecast settings, all measured in quarters.
_ENCODER_QUARTERS = 8   # TFT look-back window
_HORIZON_QUARTERS = 4   # TFT decoder length (quarters predicted per sample)
_PROPHET_QUARTERS = 8   # Prophet forecast horizon


def _lightning():
    """Return the Lightning package whose LightningModule the TFT model subclasses.

    Recent pytorch_forecasting releases build on ``lightning.pytorch`` and older
    ones on ``pytorch_lightning``. A Trainer from the other package rejects the
    model, so pick whichever one ``TemporalFusionTransformer`` derives from.
    """
    try:
        import lightning.pytorch as lightning_pkg
        if issubclass(TemporalFusionTransformer, lightning_pkg.LightningModule):
            return lightning_pkg
    except ImportError:
        pass
    import pytorch_lightning as lightning_pkg
    return lightning_pkg


def _fit_prophet(df):
    """Fit Prophet to quarterly counts and forecast _PROPHET_QUARTERS quarters ahead."""
    # interval_width=0.95 so the band plotted as "95% CI" really is 95%
    # (Prophet's default is 0.8). No country holidays: single-day holiday
    # effects cannot be identified from quarter-end aggregates and would only
    # latch onto dates that coincidentally fall on a holiday.
    model = Prophet(
        yearly_seasonality=True,
        weekly_seasonality=False,
        daily_seasonality=False,
        interval_width=0.95,
    )
    model.fit(df)
    future = model.make_future_dataframe(
        periods=_PROPHET_QUARTERS, freq=pd.offsets.QuarterEnd()
    )
    return model.predict(future)


def _fit_tft_next_quarter(df, drug):
    """Train a small Temporal Fusion Transformer and return its next-quarter forecast."""
    pl = _lightning()

    df_tft = df.copy()
    df_tft["time_idx"] = range(len(df_tft))
    df_tft["group"] = drug
    df_tft["y"] = df_tft["y"].astype(float)  # regression target as float

    training = TimeSeriesDataSet(
        df_tft,
        time_idx="time_idx",
        target="y",
        group_ids=["group"],
        max_encoder_length=_ENCODER_QUARTERS,
        max_prediction_length=_HORIZON_QUARTERS,
        static_categoricals=["group"],
        time_varying_known_reals=["time_idx"],
        # Without the target listed here the encoder never sees past counts.
        time_varying_unknown_reals=["y"],
        target_normalizer=GroupNormalizer(groups=["group"]),
    )

    train_dataloader = training.to_dataloader(train=True, batch_size=16, num_workers=0)
    tft = TemporalFusionTransformer.from_dataset(
        training, learning_rate=0.03, hidden_size=16, dropout=0.1
    )
    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="cpu",          # replaces `gpus=0`, removed in Lightning 2
        devices=1,
        gradient_clip_val=0.1,
        logger=False,
        enable_checkpointing=False,  # don't write checkpoint files per request
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(tft, train_dataloaders=train_dataloader)

    # Append placeholder rows for the quarters to forecast. With predict=True
    # the dataset's decoder window is exactly these last rows; their `y` is
    # only a placeholder (the last observed value), never a model input.
    n_obs = len(df_tft)
    future_rows = pd.DataFrame({
        "ds": pd.date_range(
            df_tft["ds"].iloc[-1],
            periods=_HORIZON_QUARTERS + 1,
            freq=pd.offsets.QuarterEnd(),
        )[1:],
        "y": df_tft["y"].iloc[-1],
        "time_idx": range(n_obs, n_obs + _HORIZON_QUARTERS),
        "group": drug,
    })
    prediction_set = TimeSeriesDataSet.from_dataset(
        training,
        pd.concat([df_tft, future_rows], ignore_index=True),
        predict=True,
        stop_randomization=True,
    )
    pred = tft.predict(prediction_set, mode="prediction")
    # Some pytorch_forecasting versions wrap the tensor in a result with `.output`.
    pred = getattr(pred, "output", pred)
    # Row 0 = the single series, column 0 = the first forecast quarter.
    return float(pred[0][0])


def _predict_drug_sync(drug):
    """Blocking implementation of the /ml/{drug} route (network I/O + model fitting)."""
    df = get_fda_data(drug)
    # TimeSeriesDataSet needs a full encoder window plus a full horizon to
    # build even one training sample.
    if len(df) < _ENCODER_QUARTERS + _HORIZON_QUARTERS:
        raise HTTPException(404, f"Not enough data for {drug}")

    # 1. Prophet forecast
    forecast = _fit_prophet(df)
    future_fc = forecast[forecast["ds"] > df["ds"].max()]
    next_quarter_ds = future_fc["ds"].iloc[0]
    prophet_next = float(future_fc["yhat"].iloc[0])

    # 2. Temporal Fusion Transformer (light version)
    tft_forecast = _fit_tft_next_quarter(df, drug)

    # Plot
    fig = go.Figure()
    fig.add_trace(go.Bar(x=df['ds'], y=df['y'], name="Actual Reports"))
    fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], name="Prophet Forecast", line=dict(dash='dash')))
    fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_upper'], name="Prophet 95% CI", fill=None, mode='lines', line_color='rgba(0,0,0,0)'))
    fig.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_lower'], name="95% CI", fill='tonexty', mode='lines', fillcolor='rgba(100,100,255,0.2)'))
    # The TFT point belongs at the quarter it forecasts, not the last observed one.
    fig.add_trace(go.Scatter(x=[next_quarter_ds], y=[tft_forecast], mode='markers', name="TFT Next Quarter", marker=dict(size=15, color='red')))

    fig.update_layout(title=f"{drug.title()} – Live FAERS Time-Series Forecast (Prophet + TFT)", height=600)

    return {
        "plot_html": fig.to_html(include_plotlyjs="cdn"),
        "next_quarter_tft": round(tft_forecast, 0),
        "prophet_next_quarter": round(prophet_next, 0),
    }


@app.get("/ml/{drug}")
async def predict_drug(drug: str):
    """Forecast quarterly FAERS report counts for `drug` with Prophet and a TFT.

    Returns a JSON object with:

    - ``plot_html``: a standalone Plotly figure.
    - ``next_quarter_tft``: the TFT forecast for the first quarter after the data.
    - ``prophet_next_quarter``: Prophet's forecast for that same quarter.

    Raises HTTPException(404) when there are too few quarters to fit the models.
    """
    # The openFDA request and model training block for a long time; run them in
    # a worker thread so the event loop can keep serving other requests.
    from fastapi.concurrency import run_in_threadpool
    return await run_in_threadpool(_predict_drug_sync, drug)