In [2]:
from helpers import fetch_data, aggregate_data, base_query
import polars as pl
import pandas as pd
from polars import col as c
import polars as pl
from polars import col as c
import polars.selectors as cs
import plotly.express as px

In [None]:
base = base_query(28)
base_data = fetch_data(base)
state_data = aggregate_data(base_data, "state").collect()

def map_fig(data):
    fig = px.choropleth(
        state_data,
        locations="state",
        locationmode="USA-states",
        color="diff_per_rx",
        hover_name="state",
        title="Average Difference in Cost per Prescription by State",
        scope="usa",
    )

    return fig



In [None]:
def fetch_data(date_id: int, state_filter: str, group_by_col: list[str], **kwargs) -> pl.LazyFrame:
    # add context for group_by_col to be either 'product' or 'product_group'
    # if group_by_col not in ['product', 'product_group']:
    #     raise ValueError("group_by_col must be either 'product' or 'product_group'")

    unit_price_change = (c.unit_price - c.previous_unit_price).alias('unit_price_change')
    new_nadac = (c.unit_price * c.units).round(2).alias('new_nadac')
    old_nadac = (c.previous_unit_price * c.units).round(2).alias('old_nadac')
    total_diff = (new_nadac - old_nadac).alias('total_diff')
    diff_per_rx = (c.total_diff / c.rx_count).round(2).alias('diff_per_rx')


    # avg nadac change per unit
    def avg_unit_change() -> pl.Expr:
        return (c.avg_new_nadac - c.avg_old_nadac).round(2).alias('avg_unit_change')
    
    def classification()-> pl.Expr:
        # when total_diff is less than 0, return 'Decrease', otherwise return 'Increase'
        return pl.when(c.total_diff < 0).then(pl.lit('Decrease')).otherwise(pl.lit('Increase')).alias('classification')
    # add absolut columns for charting
    def abs_diff_col() -> pl.Expr:
        return cs.matches('(?i)diff').abs().name.suffix('_abs')
    
    # calculate percent change from old nadac to new nadac
    def percent_change() -> pl.Expr:
        return ((c.new_nadac - c.old_nadac) / c.old_nadac).round(4).alias('percent_change')    

    avg_new_nadac = (c.new_nadac / c.units).round(4).alias('avg_new_nadac')
    avg_old_nadac = (c.old_nadac / c.units).round(4).alias('avg_old_nadac')
    
    sdud = load_sdud().filter(c.state == state_filter)
    nadac = load_nadac().filter(c.date_id == date_id)

    data = sdud.join(nadac, on='ndc').with_columns(unit_price_change.round(4),new_nadac, old_nadac).with_columns((c.new_nadac - c.old_nadac).alias('total_diff'))
    data = data.join(load_medispan().select(c.ndc, c.generic_name.alias('product'), c.gpi_10_generic_name.alias('product_group')), on='ndc')

    
    if kwargs.get('product_group_filter'):
        data = data.filter(c.product_group == kwargs['product_group_filter'])
    
    data = (
        data
        .group_by(group_by_col)
        .agg(pl.col(['units','rx_count','total_diff', 'new_nadac', 'old_nadac']).sum())
        .with_columns(avg_new_nadac, avg_old_nadac)
        .with_columns(diff_per_rx)
        .with_columns(abs_diff_col(), classification(), avg_unit_change(), percent_change())
    )
    return data

In [2]:
date_id = 120
state_filter = 'XX'  # Example state filter, replace with actual state code if needed
df = fetch_data(date_id, 'XX', ['product_group']).collect()

def scatter_plot(df: pl.DataFrame):
    fig = px.scatter(
        df,
        x='rx_count',
        y='total_diff_abs',
        size='total_diff_abs',
        color='percent_change',
        title='Prescription Analysis: Total Difference vs Prescription Count',
        log_x=True,
        log_y=True,
        size_max=60,  # Slightly smaller max size to reduce overlap
        color_continuous_scale='Spectral_r',  # High contrast color scale for better visibility
    )

    # Custom hovertemplate
    hovertemplate = (
        "<b>Product:</b> %{customdata[0]}<br>"
        "<b>Classification:</b> %{customdata[1]}<br>"
        "<b>Avg Unit Change:</b> %{customdata[2]:$.2f}<br>"
        "<b>Avg New NADAC:</b> %{customdata[3]:$.2f}<br>"
        "<b>Avg Old NADAC:</b> %{customdata[4]:$.2f}<br>"
        "<b>Total Diff:</b> %{customdata[5]:$,.0f}<br>"
        "<b>Avg Diff Per Rx:</b> %{customdata[6]:$.2f}<br>"
        "<b>Rx Count:</b> %{customdata[7]:,.0f}<br>"
        "<b>Units:</b> %{customdata[8]:,.0f}<br>"
        "<b>Avg Percent Change:</b> %{customdata[9]:.1%}<br>"
        "<extra></extra>"
    )    # Prepare customdata for hovertemplate
    fig.update_traces(
        customdata=df[[
            'product_group',
            'classification',
            'avg_unit_change',
            'avg_new_nadac',
            'avg_old_nadac',
            'total_diff',
            'diff_per_rx',
            'rx_count',
            'units',
            'percent_change'
        ]].to_numpy(),        hovertemplate=hovertemplate,
        marker={
            'sizemin': 5,
              # Slightly smaller max size to reduce overlap
            'line': {'width': 2, 'color': 'rgba(44, 62, 80, 0.9)'},  # Darker, thicker borders
            'opacity': 0.7  # Reduced opacity to see overlapping bubbles
        },
    )

    fig.update_layout(
        title={
            'text': '<b>Prescription Analysis</b><br><sub>Total Difference vs Prescription Count</sub>',
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 24, 'family': 'Inter, Segoe UI, Arial, sans-serif', 'color': '#2c3e50'}
        },
        xaxis={
            'title': {
                'text': '<b>Prescription Count</b> (log scale)', 
                'font': {'size': 16, 'family': 'Inter, Segoe UI, Arial, sans-serif', 'color': '#34495e'}
            },
            'tickformat': '~s',
            'gridcolor': 'rgba(189, 195, 199, 0.3)',
            'gridwidth': 1,
            'showline': True,
            'linecolor': 'rgba(149, 165, 166, 0.5)',
            'linewidth': 1,
            'tickfont': {'size': 12, 'color': '#7f8c8d'}
        },
        yaxis={
            'title': {
                'text': '<b>Total Difference</b> ($, log scale)', 
                'font': {'size': 16, 'family': 'Inter, Segoe UI, Arial, sans-serif', 'color': '#34495e'}
            },
            'tickformat': '$~s',
            'gridcolor': 'rgba(189, 195, 199, 0.3)',
            'gridwidth': 1,
            'showline': True,
            'linecolor': 'rgba(149, 165, 166, 0.5)',
            'linewidth': 1,
            'tickfont': {'size': 12, 'color': '#7f8c8d'}
        },        coloraxis={
            'cmin': -1,
            'cmax': 1,
            'colorscale': 'Spectral_r',  # High contrast color scale
            'colorbar': {
                'title': {
                    'text': '<b>Percent Change</b>', 
                    'font': {'size': 14, 'family': 'Inter, Segoe UI, Arial, sans-serif', 'color': '#2c3e50'}
                },
                'tickformat': '.0%',
                'orientation': 'h',
                'x': 0.5,
                'y': -0.25,
                'len': 0.8,
                'thickness': 20,
                'tickfont': {'size': 11, 'color': '#7f8c8d'},
                'bordercolor': 'rgba(149, 165, 166, 0.3)',
                'borderwidth': 1
            }
        },
        plot_bgcolor='rgba(255, 255, 255, 0.95)',  # Nearly white background for better contrast
        paper_bgcolor='white',
        font={'family': 'Inter, Segoe UI, Arial, sans-serif', 'size': 12, 'color': '#2c3e50'},
        margin={'l': 90, 'r': 90, 't': 120, 'b': 120},        width=1400,
        height=700,
        showlegend=False,
        # Configure hover mode for better interaction with overlapping points
        hovermode='closest',
        # Add subtle shadow effect
        annotations=[
            dict(
                text="",
                showarrow=False,
                xref="paper", yref="paper",
                x=0, y=0, xanchor='left', yanchor='bottom',
                xshift=-5, yshift=-5,
                bordercolor="rgba(0,0,0,0.1)",
                borderwidth=1,
                bgcolor="rgba(0,0,0,0.02)",
                width=1410, height=710
            )
        ]
    )

    return fig
scatter_plot(df)

In [None]:
date_id = 120
state_filter = 'XX'
product_group_filter = 'Levothyroxine'  # Example state filter, replace with actual state code if needed
df = fetch_data(date_id, 'XX', ['product', ], product_group_filter=product_group_filter).collect()

def bar_chart(df):
    fig = px.bar(
        df.to_pandas(),
        x="total_diff",
        y="product",
        color="percent_change",
        color_continuous_scale="Spectral_r",
        orientation="h",
        title="Product-level Total Difference (Diverging Bar Chart)",
        text="total_diff",
    )

    fig.update_traces(
        texttemplate="%{x:$,.0f}",
        textposition="outside",
        marker_line_width=2,
        marker_line_color="rgba(44, 62, 80, 0.9)",
        opacity=0.8,
        customdata=df.select([
            "product",
            "classification",
            "avg_unit_change",
            "avg_new_nadac",
            "avg_old_nadac",
            "total_diff",
            "diff_per_rx",
            "rx_count",
            "units",
            "percent_change"
        ]).to_numpy(),
        hovertemplate=(
            "<b>Product:</b> %{customdata[0]}<br>"
            "<b>Classification:</b> %{customdata[1]}<br>"
            "<b>Avg Unit Change:</b> %{customdata[2]:$.2f}<br>"
            "<b>Avg New NADAC:</b> %{customdata[3]:$.2f}<br>"
            "<b>Avg Old NADAC:</b> %{customdata[4]:$.2f}<br>"
            "<b>Total Diff:</b> %{customdata[5]:$,.0f}<br>"
            "<b>Avg Diff Per Rx:</b> %{customdata[6]:$.2f}<br>"
            "<b>Rx Count:</b> %{customdata[7]:,.0f}<br>"
            "<b>Units:</b> %{customdata[8]:,.0f}<br>"
            "<b>Avg Percent Change:</b> %{customdata[9]:.1%}<br>"
            "<extra></extra>"
        )
    )

    fig.update_layout(
        title={
            "text": "<b>Product-level Total Difference</b><br><sub>Diverging Bar Chart</sub>",
            "x": 0.5,
            "xanchor": "center",
            "font": {"size": 22, "family": "Inter, Segoe UI, Arial, sans-serif", "color": "#2c3e50"}
        },
        xaxis={
            "title": {
                "text": "<b>Total Difference</b> ($)",
                "font": {"size": 16, "family": "Inter, Segoe UI, Arial, sans-serif", "color": "#34495e"}
            },
            "tickformat": "$~s",
            "gridcolor": "rgba(189, 195, 199, 0.3)",
            "gridwidth": 1,
            "showline": True,
            "linecolor": "rgba(149, 165, 166, 0.5)",
            "linewidth": 1,
            "tickfont": {"size": 12, "color": "#7f8c8d"}
        },
        yaxis={
            "title": {
                "text": "<b>Product</b>",
                "font": {"size": 16, "family": "Inter, Segoe UI, Arial, sans-serif", "color": "#34495e"}
            },
            "tickfont": {"size": 12, "color": "#7f8c8d"},
            "categoryorder": "total ascending"
        },
        coloraxis={
            "cmin": -1,
            "cmax": 1,
            "colorscale": "Spectral_r",
            "colorbar": {
                "title": {
                    "text": "<b>Percent Change</b>",
                    "font": {"size": 14, "family": "Inter, Segoe UI, Arial, sans-serif", "color": "#2c3e50"}
                },
                "tickformat": ".0%",
                "orientation": "h",
                "x": 0.5,
                "y": -0.25,
                "len": 0.8,
                "thickness": 20,
                "tickfont": {"size": 11, "color": "#7f8c8d"},
                "bordercolor": "rgba(149, 165, 166, 0.3)",
                "borderwidth": 1
            }
        },
        plot_bgcolor="rgba(255, 255, 255, 0.95)",
        paper_bgcolor="white",
        font={"family": "Inter, Segoe UI, Arial, sans-serif", "size": 12, "color": "#2c3e50"},
        margin={"l": 120, "r": 60, "t": 100, "b": 80},
        width=1200,
        height=600,
        showlegend=False,
        hovermode="closest"
    )

    fig.show()

bar_chart()


TypeError: bar_chart() missing 1 required positional argument: 'df'

In [5]:
df.schema.to_python()

{'product': str,
 'units': float,
 'rx_count': float,
 'total_diff': int,
 'new_nadac': float,
 'old_nadac': float,
 'avg_new_nadac': float,
 'avg_old_nadac': float,
 'diff_per_rx': float,
 'total_diff_abs': int,
 'diff_per_rx_abs': float,
 'classification': str,
 'avg_unit_change': float,
 'percent_change': float}

In [9]:
date_id = 146
state_filter = 'XX'  # Example state filter, replace with actual state code if needed
df = fetch_data(date_id, 'XX', ['product_group'], product_group_filter='Disulfiram').sort('total_diff',descending=True).collect()
df.to_pandas()

Unnamed: 0,product_group,units,rx_count,total_diff,new_nadac,old_nadac,avg_new_nadac,avg_old_nadac,diff_per_rx,total_diff_abs,diff_per_rx_abs,classification,avg_unit_change,percent_change
0,Disulfiram,122653.0,4734.0,-4,1633110.02,219247.42,13.3149,1.7875,-0.0,4,0.0,Decrease,11.53,6.4487


In [14]:
from helpers import load_sdud, load_nadac, load_medispan
sdud = load_sdud().filter(c.state == 'XX')
nadac = load_nadac().filter(c.date_id == 146)
m = load_medispan().select(c.ndc, c.gpi_14_name.alias('product'), c.gpi_10_generic_name.alias('product_group'))
sdud.join(m,on='ndc').filter(c.product_group == 'Disulfiram').join(nadac, on='ndc').collect().to_pandas()

Unnamed: 0,state,ndc,units,rx_count,product,product_group,unit_price,effective_date,year,previous_unit_price,date_id,effective_date_right
0,XX,42794002908,16466.0,645.0,Disulfiram Tab 500 MG,Disulfiram,13.31488,2025-04-23,2025,13.31516,146,2025-04-01
1,XX,62135043230,84139.0,3255.0,Disulfiram Tab 500 MG,Disulfiram,13.31488,2025-04-23,2025,,146,2025-04-01
2,XX,62135043290,22048.0,834.0,Disulfiram Tab 500 MG,Disulfiram,13.31488,2025-04-23,2025,,146,2025-04-01


In [21]:
load_nadac().filter(c.ndc == '62135043230').sort(c.date_id).collect()

ndc,unit_price,effective_date,year,previous_unit_price,date_id,effective_date_right
str,f32,date,i32,f32,u32,date
"""62135043230""",12.97255,2024-04-17,2024,,134,2024-04-01
"""62135043230""",13.31433,2025-02-19,2025,12.97255,144,2025-02-01
"""62135043230""",13.31473,2025-03-19,2025,13.31433,145,2025-03-01
"""62135043230""",13.31488,2025-04-23,2025,13.31473,146,2025-04-01
"""62135043230""",13.31516,2025-05-21,2025,13.31488,147,2025-05-01
