# Clinical Trial Patient Demand Calculator

Interact with the Streamlit dashboard directly inside this notebook. Adjust trial and dosing parameters to see the total patient demand (in mg) update instantly.

In [None]:
# Install required packages
!pip install -r requirements.txt

In [None]:
# Enable Streamlit to render inline within Jupyter
from streamlit_jupyter import StreamlitPatcher

# Create the patcher, keep a module-level handle to it, and switch Streamlit
# calls over to their Jupyter-rendering mode in one step.
(patcher := StreamlitPatcher()).jupyter()

In [None]:
def product_inputs(prefix: str, index: int) -> "ProductParams | None":
    """Render the widgets for one product slot and return its dosing parameters.

    Args:
        prefix: Widget-key prefix identifying the trial/group the product
            belongs to (``group_inputs`` passes e.g. ``"trial1_group2"``).
        index: 1-based position of the product within the group; it is part
            of every widget key so slots do not collide.

    Returns:
        A ``ProductParams`` built from the widget values, or ``None`` (after
        showing a warning) when no product name has been selected or entered.
    """
    product_name = select_or_add_product(f"{prefix}_prod{index}")
    if not product_name:
        # Callers treat None as "incomplete setup" and drop the slot.
        st.warning("Please enter a product name")
        return None

    product_amount = st.number_input(
        "Product per administration (mg)",
        min_value=0.0, value=50.0, step=0.1, format="%.2f",
        key=f"{prefix}_product{index}_amount"
    )
    admin_points = st.number_input(
        "Administration points / day",
        min_value=1, max_value=24, value=1, step=1,
        key=f"{prefix}_product{index}_points"
    )
    days = st.number_input(
        "Days of administration",
        min_value=1, max_value=365, value=28, step=1,
        key=f"{prefix}_product{index}_days"
    )

    # Normalise widget return types so downstream arithmetic is predictable.
    return ProductParams(
        name=product_name,
        product_amount=float(product_amount),
        admin_points=int(admin_points),
        days=int(days)
    )


def calculate_product_demand(patients: int, params: ProductParams, buffer_pct: int = 0) -> float:
    """Return the demand for one product configuration across a group's patients.

    The unbuffered quantity is
    ``patients * product_amount * admin_points * days``; the contingency
    buffer then scales it by ``1 + buffer_pct / 100``.
    """
    demand = patients
    # Multiply strictly left to right, in the same order as the formula above.
    for factor in (params.product_amount, params.admin_points, params.days):
        demand *= factor
    return demand * (1 + buffer_pct / 100.0)


def calculate_group_demand(params: DosingParams) -> tuple[float, Dict[str, float]]:
    """Calculate total demand and per-product breakdown for a treatment group.

    Each product's demand comes from ``calculate_product_demand`` using the
    group's patient count and contingency buffer. Products that share a name
    are summed into a single breakdown entry.

    Args:
        params: Group configuration; ``params.products`` may contain ``None``
            placeholders for incomplete product setups, which are skipped.

    Returns:
        ``(total, by_product)``: the group's total demand (mg) and a mapping of
        product name to its demand (mg), in order of first appearance.
    """
    total = 0.0
    by_product: Dict[str, float] = {}

    for product in params.products:
        # Explicit None check: skip placeholders without relying on the
        # truthiness of the product object itself.
        if product is None:
            continue
        amount = calculate_product_demand(params.patients, product, params.buffer_pct)
        by_product[product.name] = by_product.get(product.name, 0.0) + amount
        total += amount

    return total, by_product


def group_inputs(prefix: str):
    """Render one treatment group's inputs and return them as ``DosingParams``.

    All widget keys are derived from ``prefix``. The group's patient count,
    number of distinct products and contingency buffer (%) come first, then
    one expander per product slot (only the first expanded). Slots for which
    ``product_inputs`` returns ``None`` are left out of ``products``.
    """
    patient_count = st.number_input(
        "Patients",
        min_value=1, max_value=10000, value=50, step=1,
        key=f"{prefix}_patients",
    )
    product_count = st.number_input(
        "Number of distinct products",
        min_value=1, max_value=20, value=1, step=1,
        key=f"{prefix}_num_products",
    )
    buffer = st.number_input(
        "Contingency buffer (%)",
        min_value=0, max_value=100, value=0, step=1,
        key=f"{prefix}_buffer",
    )

    # The explanatory header only makes sense with several products to configure.
    if product_count > 1:
        st.markdown("### Product-specific parameters")
        st.info("Configure dosing for each distinct product in this treatment group")

    configured = []
    for slot in range(1, int(product_count) + 1):
        with st.expander(f"Product {slot} Configuration", expanded=slot == 1):
            candidate = product_inputs(prefix, slot)
            # Incomplete setups come back as None and are dropped here.
            if candidate is not None:
                configured.append(candidate)

    return DosingParams(
        patients=int(patient_count),
        products=configured,
        buffer_pct=int(buffer),
    )


def _render_header() -> None:
    """Render the centred page title and subtitle."""
    st.markdown(
        """
        <div style="text-align:center; margin-bottom: 1.6rem;">
            <h1 style="color:#0b2948;">Clinical Trial Patient Demand Calculator</h1>
            <p style="color:#5b6b7f; font-size:1.05rem;">
                Model patient-level dosing needs across trials with flexible treatment configurations.
            </p>
        </div>
        """,
        unsafe_allow_html=True,
    )


def _render_group(trial_index: int, group_index: int) -> tuple[float, Dict[str, float]]:
    """Render one treatment-group card; return its demand and per-product split (mg)."""
    with card_container():
        # NOTE(review): Streamlit appears to render each st.markdown call as a
        # separate element, so this opening <div> likely does not wrap the
        # widgets below. Kept to preserve existing markup — confirm whether the
        # CSS targeting it has any effect.
        st.markdown(
            "<div class='group-card'>",
            unsafe_allow_html=True,
        )
        st.markdown(
            f"<h4 style='margin-top:0;color:#0b2948;'>Treatment Group {group_index}</h4>",
            unsafe_allow_html=True,
        )

        group_params = group_inputs(f"trial{trial_index}_group{group_index}")
        group_demand, group_by_product = calculate_group_demand(group_params)

        st.markdown(
            f"<p style='color:#5b6b7f;margin:0.8rem 0 0;'>Group demand: <strong>{group_demand:,.0f} mg</strong></p>",
            unsafe_allow_html=True,
        )

        if group_by_product:
            st.markdown("##### Product Breakdown:")
            for name, amount in group_by_product.items():
                st.markdown(f"- {name}: {amount:,.0f} mg")

        st.markdown("</div>", unsafe_allow_html=True)
    return group_demand, group_by_product


def _render_trial(trial_index: int) -> list[tuple[float, Dict[str, float]]]:
    """Render one trial expander; return ``(demand, per-product split)`` per group, in order."""
    with st.expander(
        f"Trial {trial_index} configuration",
        expanded=(trial_index == 1),
    ):
        st.markdown("<div class='trial-expander'>", unsafe_allow_html=True)

        num_groups = st.number_input(
            "Treatment Groups",
            min_value=1,
            max_value=12,
            value=2,
            step=1,
            key=f"trial_{trial_index}_groups",
            help="How many treatment arms or cohorts are in this trial?",
        )
        groups = [
            _render_group(trial_index, group_index)
            for group_index in range(1, int(num_groups) + 1)
        ]

        st.markdown("</div>", unsafe_allow_html=True)
    return groups


def _render_metric_grid(items, columns: int = 3) -> None:
    """Render ``(label, mg amount)`` pairs as ``st.metric`` tiles, ``columns`` per row."""
    pairs = list(items)
    for start in range(0, len(pairs), columns):
        row = pairs[start:start + columns]
        for col, (label, value) in zip(st.columns(len(row)), row):
            col.metric(label=label, value=f"{int(round(value)):,} mg")


def _render_results(
    total_demand: float,
    product_totals: Dict[str, float],
    trial_totals: Dict[str, float],
) -> None:
    """Render the grand-total card, then per-product and per-trial metric grids."""
    rounded_demand = int(round(total_demand))
    with card_container():
        st.markdown(
            f"""
            <div class='result-card'>
                <h2>&#128138; Total Patient Demand</h2>
                <h1>{rounded_demand:,} mg</h1>
                <p>Sum of all product administrations across trials, groups, and buffers.</p>
            </div>
            """,
            unsafe_allow_html=True,
        )

    # Guard both sections the same way so no empty headings are shown
    # (e.g. when every product name is still blank).
    if product_totals:
        st.markdown("### Product-Specific Totals")
        _render_metric_grid(
            (f"Total {name}", amount) for name, amount in product_totals.items()
        )

    if trial_totals:
        st.markdown("### Trial-Specific Totals")
        _render_metric_grid(trial_totals.items())


def run_app(*, embedded: bool = False) -> None:
    """Render the full demand-calculator dashboard.

    Lays out the page header, a sidebar choosing the number of trials, one
    expander per trial with its treatment-group cards, and finally the overall,
    per-product and per-trial totals (all in mg).

    Args:
        embedded: Forwarded to ``configure_page``; the notebook passes ``True``.
    """
    configure_page(embedded=embedded)
    inject_custom_css()
    _render_header()

    st.sidebar.markdown("## Planner Settings")
    num_trials = st.sidebar.number_input(
        "Number of Trials",
        min_value=1,
        max_value=25,
        value=1,
        step=1,
        help="How many clinical trials should be included in this demand model?",
    )

    st.sidebar.info(
        "Adjust baseline dosing for each trial and fine-tune group level assumptions for precision.",
    )

    total_demand = 0.0
    trial_totals: Dict[str, float] = {}
    product_totals: Dict[str, float] = {}

    for trial_index in range(1, int(num_trials) + 1):
        trial_total = 0.0
        # Accumulate group by group, in render order.
        for group_demand, group_by_product in _render_trial(trial_index):
            trial_total += group_demand
            total_demand += group_demand
            for prod_name, amount in group_by_product.items():
                product_totals[prod_name] = product_totals.get(prod_name, 0) + amount
        trial_totals[f"Trial {trial_index}"] = trial_total

    # Publish this run's per-product totals to session state. Previously the key
    # was initialised to {} and never filled. Totals are rebuilt from scratch on
    # every rerun, so they are assigned (not accumulated) to avoid double counting.
    st.session_state.product_totals = dict(product_totals)

    st.divider()
    _render_results(total_demand, product_totals, trial_totals)

    st.caption("Adjust protocol settings to refresh the demand estimates instantly.")


# Confirm the definitions above loaded; the next cell calls run_app to render the dashboard.
print("Demand calculator ready. Call run_app(embedded=True) to display it.")

In [None]:
# Display the app inline in this notebook; embedded=True is forwarded to configure_page.
run_app(embedded=True)

---
If the widgets fail to render, make sure the `streamlit-jupyter` package is installed, rerun the Streamlit patch cell above, and then rerun the cell that displays the app.