# Clinical Trial Patient Demand Calculator

Interact with the Streamlit dashboard directly inside this notebook. Adjust trial and dosing parameters to see total patient demand update instantly.

In [None]:
# Install required packages
!pip install -r requirements.txt

: 

In [None]:
# Enable Streamlit to render inline within Jupyter
from streamlit_jupyter import StreamlitPatcher

# This cell runs before any cell that calls into Streamlit; if widgets later
# fail to render, this is the cell to rerun.
# NOTE(review): presumably .jupyter() swaps Streamlit's API for notebook-aware
# wrappers -- confirm against the streamlit-jupyter documentation.
patcher = StreamlitPatcher()
patcher.jupyter()

In [None]:
# Define the demand calculator app
from dataclasses import dataclass
from typing import Dict

import streamlit as st

# Resolve the exception type that configure_page() suppresses around
# st.set_page_config().
# NOTE(review): when the import fails, the alias below widens to every
# Exception, so configure_page()'s `except StreamlitAPIException: pass` will
# then swallow any error and its RuntimeError branch becomes unreachable --
# confirm that this best-effort behaviour is intended.
try:
    from streamlit.errors import StreamlitAPIException
except ImportError:  # pragma: no cover
    StreamlitAPIException = Exception  # type: ignore


@dataclass
class DosingParams:
    """Inputs describing one treatment group's dosing schedule.

    ``calculate_group_demand`` multiplies the first five fields together and
    then scales the product by ``1 + buffer_pct / 100``.
    """

    patients: int  # number of patients in the group
    products: int  # multiplier for the number of products the group uses
    product_amount: float  # amount per administration (the UI labels this in mg)
    admin_points: int  # administrations per day
    days: int  # days of administration
    buffer_pct: int  # contingency overage as a whole-number percentage


def inject_custom_css() -> None:
    """Emit the dashboard's custom stylesheet as a raw HTML ``<style>`` block.

    The stylesheet declares the colour variables, the page background
    gradient, the main container padding/width, and the look of the trial,
    group and result cards, number-input labels, expander headers and
    metrics.

    NOTE(review): whether selectors such as ``.stMetric-label`` or
    ``.streamlit-expanderHeader`` match the DOM of the installed Streamlit
    version cannot be told from this file -- confirm in the browser.
    """
    # Keep the literal exactly as authored; it is sent to the page verbatim.
    stylesheet = """
        <style>
        :root {
            --primary-color: #0b5ea8;
            --accent-color: #2f9be0;
            --bg-gradient-top: #f6fbff;
            --bg-gradient-bottom: #eef5fb;
            --text-color: #0b2948;
            --muted-color: #5b6b7f;
        }

        .stApp {
            background: linear-gradient(180deg, var(--bg-gradient-top), var(--bg-gradient-bottom));
        }

        .main .block-container {
            padding-top: 2.5rem;
            padding-bottom: 3rem;
            max-width: 1100px;
        }

        .trial-expander {
            border-radius: 18px !important;
            border: 1px solid rgba(11, 94, 168, 0.15) !important;
            background-color: rgba(255, 255, 255, 0.88) !important;
            box-shadow: 0 6px 24px rgba(11, 41, 72, 0.08);
            padding: 0.4rem 0.2rem 1.2rem 0.2rem;
        }

        .group-card {
            border-radius: 16px;
            padding: 1.1rem 1.3rem;
            margin-bottom: 1.1rem;
            background-color: rgba(255, 255, 255, 0.92);
            border: 1px solid rgba(47, 155, 224, 0.24);
            box-shadow: 0 4px 16px rgba(11, 94, 168, 0.08);
        }

        .result-card {
            border-radius: 24px;
            padding: 2.4rem 2.6rem;
            background: linear-gradient(150deg, rgba(47, 155, 224, 0.18), rgba(11, 94, 168, 0.32));
            box-shadow: 0 18px 45px rgba(11, 41, 72, 0.18);
            text-align: center;
        }

        .result-card h2 {
            color: var(--text-color);
            margin-bottom: 0.4rem;
            font-weight: 600;
        }

        .result-card h1 {
            color: #05529b;
            font-size: clamp(2.8rem, 6vw, 3.6rem);
            font-weight: 700;
            margin-bottom: 0.3rem;
        }

        .result-card p {
            color: var(--muted-color);
            margin: 0;
        }

        .stNumberInput > label {
            font-weight: 600;
            color: var(--text-color);
        }

        .help-tooltip {
            color: var(--muted-color) !important;
        }

        .stExpander .streamlit-expanderHeader {
            font-weight: 600;
            color: var(--text-color);
            font-size: 1.02rem;
        }

        .stMetric-label {
            color: var(--muted-color) !important;
        }

        .stMetric-value {
            color: var(--text-color) !important;
        }
        </style>
        """
    st.markdown(stylesheet, unsafe_allow_html=True)


def configure_page(*, embedded: bool = False) -> None:
    """Apply the page-level Streamlit settings on a best-effort basis.

    Args:
        embedded: When true, a ``RuntimeError`` from ``st.set_page_config``
            is ignored instead of propagating.

    Raises:
        RuntimeError: Re-raised from ``st.set_page_config`` when ``embedded``
            is false.
    """
    page_settings = {
        "page_title": "Clinical Trial Patient Demand Calculator",
        "page_icon": "",
        "layout": "centered",
        "initial_sidebar_state": "expanded",
    }
    try:
        st.set_page_config(**page_settings)
    except StreamlitAPIException:
        # Deliberately ignored in every mode.
        # NOTE(review): presumably raised when the page config was already
        # set or Streamlit output has started -- confirm.
        return
    except RuntimeError:
        if embedded:
            return
        raise


def card_container():
    """Return a Streamlit container, bordered when ``border`` is accepted.

    Falls back to a plain ``st.container()`` if passing ``border=True``
    raises ``TypeError``.
    """
    try:
        bordered = st.container(border=True)
    except TypeError:
        # NOTE(review): presumably a Streamlit release whose container() has
        # no ``border`` keyword -- confirm the minimum supported version.
        return st.container()
    return bordered


def calculate_group_demand(params: DosingParams) -> float:
    """Return one treatment group's demand, including its buffer.

    The demand is ``patients * products * product_amount * admin_points *
    days`` scaled by ``1 + buffer_pct / 100``.

    Args:
        params: Dosing inputs for the group.

    Returns:
        The buffered demand for the group.
    """
    # Fold the factors strictly left to right so the floating-point result
    # does not depend on a different association of the multiplications.
    running = params.patients
    for factor in (
        params.products,
        params.product_amount,
        params.admin_points,
        params.days,
    ):
        running = running * factor

    uplift = 1 + params.buffer_pct / 100.0
    return running * uplift


def _render_header() -> None:
    """Render the centred page title and tagline."""
    st.markdown(
        """
        <div style="text-align:center; margin-bottom: 1.6rem;">
            <h1 style="color:#0b2948;">Clinical Trial Patient Demand Calculator</h1>
            <p style="color:#5b6b7f; font-size:1.05rem;">
                Model patient-level dosing needs across trials with flexible treatment configurations.
            </p>
        </div>
        """,
        unsafe_allow_html=True,
    )


def _render_sidebar() -> int:
    """Render the sidebar controls and return the requested number of trials."""
    st.sidebar.markdown("## Planner Settings")
    num_trials = st.sidebar.number_input(
        "Number of Trials",
        min_value=1,
        max_value=25,
        value=1,
        step=1,
        help="How many clinical trials should be included in this demand model?",
    )

    st.sidebar.info(
        "Adjust baseline dosing for each trial and fine-tune group level assumptions for precision.",
    )

    st.sidebar.markdown(
        """
        **Reminder**: Total demand reflects every administration event across trials and treatment groups.
        """
    )

    # The value feeds range(); coerce it like every other widget result here
    # so a non-int return from the widget cannot break the trial loop.
    return int(num_trials)


def _render_trial_defaults(trial_index: int) -> tuple:
    """Render one trial's baseline inputs and return them as plain numbers.

    Returns:
        ``(num_groups, products_per_trial, product_amount, admin_points,
        days)`` with ``product_amount`` as ``float`` and the rest as ``int``.
    """
    details_cols = st.columns(2)
    num_groups = details_cols[0].number_input(
        "Treatment Groups",
        min_value=1,
        max_value=12,
        value=2,
        step=1,
        key=f"trial_{trial_index}_groups",
        help="How many treatment arms or cohorts are in this trial?",
    )
    products_per_trial = details_cols[1].number_input(
        "Products per Trial",
        min_value=1,
        max_value=10,
        value=1,
        step=1,
        key=f"trial_{trial_index}_products",
        help="Count formulations, strengths, or presentations required for this trial.",
    )

    dosing_defaults = st.columns(3)
    default_product_amount = dosing_defaults[0].number_input(
        "Product per Administration (mg)",
        min_value=0.0,
        value=50.0,
        step=5.0,
        format="%.2f",
        key=f"trial_{trial_index}_product_amount",
        help="Baseline dose per administration event. Adjust per group when needed.",
    )
    default_admin_points = dosing_defaults[1].number_input(
        "Administration Points / Day",
        min_value=1,
        max_value=24,
        value=1,
        step=1,
        key=f"trial_{trial_index}_admin_points",
        help="Number of dosing events each day (e.g., BID = 2).",
    )
    default_days = dosing_defaults[2].number_input(
        "Days of Administration",
        min_value=1,
        max_value=365,
        value=28,
        step=1,
        key=f"trial_{trial_index}_days",
        help="Planned duration of treatment for the trial.",
    )

    st.caption(
        "Tune individual treatment group inputs below to capture protocol nuances."
    )

    # Normalise once so the group widgets receive defaults whose type matches
    # their own min_value/step types.
    return (
        int(num_groups),
        int(products_per_trial),
        float(default_product_amount),
        int(default_admin_points),
        int(default_days),
    )


def _render_group(
    trial_index: int,
    group_index: int,
    *,
    products_per_trial: int,
    product_amount: float,
    admin_points: int,
    days: int,
) -> float:
    """Render one treatment group's card and return that group's demand.

    The keyword arguments are the trial-level values used as the initial
    values of the group's own inputs.
    """
    key_prefix = f"trial_{trial_index}_group_{group_index}"

    with card_container():
        # NOTE(review): Streamlit appears to render each st.markdown call as a
        # separate element, so this opening tag probably does not wrap the
        # widgets below and the .group-card styling may never apply -- confirm.
        st.markdown(
            "<div class='group-card'>",
            unsafe_allow_html=True,
        )
        st.markdown(
            f"<h4 style='margin-top:0;color:#0b2948;'>Treatment Group {group_index}</h4>",
            unsafe_allow_html=True,
        )

        top_cols = st.columns(3)
        group_patients = top_cols[0].number_input(
            "Patients",
            min_value=1,
            max_value=1000,
            value=50,
            step=1,
            key=f"{key_prefix}_patients",
            help="Participants randomised or assigned to this treatment group.",
        )
        group_amount = top_cols[1].number_input(
            "Product per Administration (mg)",
            min_value=0.0,
            value=product_amount,
            step=5.0,
            format="%.2f",
            key=f"{key_prefix}_amount",
            help="Dose per administration for this group.",
        )
        group_points = top_cols[2].number_input(
            "Administration Points / Day",
            min_value=1,
            max_value=24,
            value=admin_points,
            step=1,
            key=f"{key_prefix}_points",
            help="Daily dosing frequency for this group.",
        )

        bottom_cols = st.columns(3)
        group_days = bottom_cols[0].number_input(
            "Days of Administration",
            min_value=1,
            max_value=365,
            value=days,
            step=1,
            key=f"{key_prefix}_days",
            help="Number of treatment days for this group.",
        )
        group_products = bottom_cols[1].number_input(
            "Products per Trial",
            min_value=1,
            max_value=10,
            value=products_per_trial,
            step=1,
            key=f"{key_prefix}_products",
            help="Use when a group requires multiple product presentations.",
        )
        group_buffer = bottom_cols[2].number_input(
            "Contingency Buffer (%)",
            min_value=0,
            max_value=100,
            value=0,
            step=5,
            key=f"{key_prefix}_buffer",
            help="Optional overage to cover wastage or resupply (applies to this group only).",
        )

        group_demand = calculate_group_demand(
            DosingParams(
                patients=int(group_patients),
                products=int(group_products),
                product_amount=float(group_amount),
                admin_points=int(group_points),
                days=int(group_days),
                buffer_pct=int(group_buffer),
            )
        )

        st.markdown(
            f"<p style='color:#5b6b7f;margin:0.8rem 0 0;'>Group demand contribution: <strong>{group_demand:,.0f}</strong></p>",
            unsafe_allow_html=True,
        )
        st.markdown("</div>", unsafe_allow_html=True)

    return group_demand


def _render_summary(total_demand: float, trial_totals: Dict[str, float]) -> None:
    """Render the grand-total card followed by per-trial metrics, three per row."""
    rounded_demand = int(round(total_demand))
    with card_container():
        st.markdown(
            """
            <div class='result-card'>
                <h2>&#128138; Total Patient Demand</h2>
                <h1>{:,}</h1>
                <p>Sum of all product administrations across trials, groups, and buffers.</p>
            </div>
            """.format(rounded_demand),
            unsafe_allow_html=True,
        )

    if trial_totals:
        items = list(trial_totals.items())
        for start in range(0, len(items), 3):
            row = items[start : start + 3]
            # Size each row to its own length so a short final row is not
            # padded with empty columns.
            metric_cols = st.columns(len(row))
            for col, (label, value) in zip(metric_cols, row):
                col.metric(label=label, value=f"{int(round(value)):,}")


def run_app(*, embedded: bool = False) -> None:
    """Render the whole demand-calculator dashboard.

    Widgets are emitted in this order: page config and CSS, header, sidebar,
    one expander per trial (baseline inputs followed by one card per
    treatment group), a divider, the total card, and the per-trial metrics.

    Args:
        embedded: Forwarded to ``configure_page``; when true, a
            ``RuntimeError`` from ``st.set_page_config`` is tolerated.
    """
    configure_page(embedded=embedded)
    inject_custom_css()

    _render_header()
    num_trials = _render_sidebar()

    total_demand = 0.0
    trial_totals: Dict[str, float] = {}

    for trial_index in range(1, num_trials + 1):
        with st.expander(
            f"Trial {trial_index} configuration",
            expanded=(trial_index == 1),
        ):
            # NOTE(review): this wrapper <div> is opened and closed in
            # separate st.markdown calls; it probably does not enclose the
            # widgets rendered in between -- confirm.
            st.markdown("<div class='trial-expander'>", unsafe_allow_html=True)

            (
                num_groups,
                products_per_trial,
                default_product_amount,
                default_admin_points,
                default_days,
            ) = _render_trial_defaults(trial_index)

            trial_total = 0.0
            for group_index in range(1, num_groups + 1):
                group_demand = _render_group(
                    trial_index,
                    group_index,
                    products_per_trial=products_per_trial,
                    product_amount=default_product_amount,
                    admin_points=default_admin_points,
                    days=default_days,
                )
                # Accumulate both totals group by group so the floating-point
                # summation order stays fixed.
                trial_total += group_demand
                total_demand += group_demand

            st.markdown("</div>", unsafe_allow_html=True)
            trial_totals[f"Trial {trial_index}"] = trial_total

    st.divider()
    _render_summary(total_demand, trial_totals)

    st.caption("Adjust protocol settings to refresh the demand estimate instantly.")


# Printed when this cell is executed, as a cue that the definitions above are
# loaded; the dashboard itself is only rendered once run_app() is called.
print("Demand calculator ready. Call run_app(embedded=True) to display it.")

In [None]:
# Display the app
# embedded=True makes configure_page() tolerate a RuntimeError raised by
# st.set_page_config() instead of re-raising it.
run_app(embedded=True)

---
If the widgets fail to render, ensure the `streamlit-jupyter` package is installed and rerun the Streamlit patch cell above.