# Plot decile charts of each measure for each group

In [26]:
from  ebmdatalab import charts
import pandas as pd
from os import listdir,path,environ
from measures import measures_kwargs
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import itertools
import numpy as np

In [84]:
#Load input files
# Every file in the measures directory whose name starts with 'input' is read;
# the text left after stripping 'input_' and '.csv.gz' from the file name is
# attached to each row as its `date`.
measures_dir = path.join('..','output','measures')
# listdir order is arbitrary, so sort to make the row order reproducible.
input_files = sorted(f for f in listdir(measures_dir) if f.startswith('input'))
if not input_files:
    raise FileNotFoundError(f"No 'input*' files found in {measures_dir}")

# Read everything first and concatenate once, instead of re-copying the
# growing frame on every iteration.
df = pd.concat(
    [
        pd.read_csv(path.join(measures_dir, f)).assign(
            date=f.replace('input_','').replace('.csv.gz','')
        )
        for f in input_files
    ]
)
df['date'] = pd.to_datetime(df['date'])

In [89]:
#get infection and antibiotic prescription date columns
# Maps (measure, 'date' | 'code') to the sorted names of the df columns that
# start with '<measure>_<date|code>_'.
infection_antibiotic_measures = ['infection','antibiotic_prescription']
infection_antibiotic_cols = {
    (measure_name, field): sorted(
        col for col in df.columns if col.startswith(f'{measure_name}_{field}_')
    )
    for measure_name, field in itertools.product(
        infection_antibiotic_measures, ['date','code']
    )
}

# One list of date-column names per measure; the lists themselves are ordered
# by ordinary list comparison.
date_cols = sorted(
    col_names
    for (_, field), col_names in infection_antibiotic_cols.items()
    if field == 'date'
)

In [99]:
#count intersections of infection and antibiotic prescription dates
# Each entry of date_cols is a list of column names, so both selections are
# 2-D string arrays with one row per record in df.
A = df[date_cols[0]].to_numpy(dtype=str)
I = df[date_cols[1]].to_numpy(dtype=str)


def intersection_count(x):
    """Count the distinct values shared by x[0] and x[1], ignoring 'nan'.

    x is a pair of 1-D string arrays. np.intersect1d returns the unique
    common values; the literal string 'nan' is excluded from the count
    (presumably the string form of a missing date after the str
    conversion above -- confirm against the input data).
    """
    return np.count_nonzero(np.intersect1d(x[0], x[1]) != 'nan')


# Pair the rows with zip rather than np.stack: the result is the same, but the
# two column sets no longer need the same width and no 3-D copy is built.
df['infection_antibiotic_intersection'] = [
    intersection_count(row_pair) for row_pair in zip(A, I)
]

In [None]:
#Check pivot n is adequate
# pivot_n is the largest numeric suffix (the text after the final '_') found
# on any of the collected date/code column names.
pivot_n = max(
    int(col_name.rsplit('_', 1)[-1])
    for col_name in itertools.chain.from_iterable(infection_antibiotic_cols.values())
)
print(f'n for pivot operations: {pivot_n}')
print('Max record counts for pivoted columns:')
# The per-measure count columns are named '<measure>s'.
count_columns = [f'{m}s' for m in infection_antibiotic_measures]
print(df[count_columns].max())

In [None]:
def plot_decile_group(df, group, measure):
    """Plot decile charts of a measure, one panel per value of `group`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain "practice", "date", the measure's numerator and
        denominator columns and, when grouping, the `group` column.
    group : str or None
        Column to split the panels by. None or "practice" produces a single
        panel covering all rows.
    measure : dict
        Provides the keys "id", "numerator" and "denominator".

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the panels, 12 wide and 8 high per panel.
    """
    group = None if group == "practice" else group
    plot_groups = ["practice", group, "date"] if group else ["practice", "date"]
    measure_id = measure["id"]
    numerator = measure["numerator"]
    denominator = measure["denominator"]

    df_to_plot = (
        df.groupby(plot_groups)[[numerator, denominator]]
        .sum()
        .reset_index()
    )

    # Vectorised ratio: rows whose denominator is not positive get 0 rather
    # than the result of the division.
    ratio = (df_to_plot[numerator] / df_to_plot[denominator]).where(
        df_to_plot[denominator] > 0, 0
    )

    # With no backend set, or with the "expectations" backend, ratios above 1
    # are replaced by their reciprocal so every plotted value is at most 1
    # (presumably because dummy data can have numerator > denominator --
    # confirm).
    if environ.get("OPENSAFELY_BACKEND", "expectations") == "expectations":
        ratio = ratio.fillna(0)
        ratio = ratio.where(ratio <= 1, 1 / ratio)
    df_to_plot[measure_id] = ratio

    # Work out the panels as (title, group value) pairs; the ungrouped case is
    # one panel with no group filter.
    base_title = measure_id.replace("_", " ").title()
    if group:
        panels = [
            (f"{base_title} - {group.title()}:{groupval}", groupval)
            for groupval in df_to_plot[group].drop_duplicates()
        ]
    else:
        panels = [(base_title, None)]
    n_groups = len(panels)

    fig = plt.figure(figsize=(12, 8 * n_groups))
    # NOTE(review): this runs before any axes have been added to the figure;
    # confirm whether it was meant to run after the charts are drawn.
    fig.autofmt_xdate()
    layout = gridspec.GridSpec(n_groups, 1, figure=fig)

    for position, (title, groupval) in enumerate(panels):
        ax = plt.subplot(layout[position])
        panel_df = df_to_plot[df_to_plot[group] == groupval] if group else df_to_plot
        charts.deciles_chart(
            df=panel_df,
            period_column="date",
            column=measure_id,
            title=title,
            ax=ax,
        )
    return fig

In [None]:
# Draw every measure: first across all practices (group=None), then once for
# each of the measure's "group_by" columns.
plt.ioff()
# Value 0 turns off matplotlib's warning about many open figures.
plt.rcParams.update({'figure.max_open_warning': 0})
for measure in measures_kwargs:
    df[measure["id"]] = df[measure["numerator"]] / df[measure["denominator"]]
    for group in [None, *measure["group_by"]]:
        # Keep a handle on each figure so that this figure is the one closed.
        # Previously the per-group figures were never assigned, so the stale
        # ungrouped figure was closed again and the group figures stayed open.
        fig = plot_decile_group(df=df,group=group,measure=measure)
        plt.show()
        plt.close(fig)
        