In [1]:
# import resources
%matplotlib inline

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

In [2]:
# Read the source CSV and keep only the rows whose treatment_start_year is below 2023.
df = (
    pd.read_csv(r"../data_for_rupin_conference.csv")
    .query("treatment_start_year < 2023")
)

In [3]:
# Bar layout used as the figure's 'barmode' further down: 'stack' or 'group'.
grouping_mode = 'stack'

# Renumber the rows 0..n-1 and discard the previous index (not strictly required,
# it just makes the generated chart code easier to work with).
# drop=True replaces the reset_index() + drop('index') round trip, which would
# delete a genuine data column named 'index' and leave a stray 'level_0' column.
df = df.reset_index(drop=True)

# Force column labels to strings in case some of them are numbers.
df.columns = [str(c) for c in df.columns]


In [4]:
# Work on a copy holding only the two columns the chart needs, with the row
# labels duplicated into a leading '__index__' column so they can be counted later.
chart_columns = ['treatment_start_year', 'treatment_type_renamed']
chart_data = df.loc[:, chart_columns].copy()
chart_data.insert(0, '__index__', df.index)

In [5]:
# Keep the selected treatment types, order rows by type and start year, and
# expose the start year under the generic axis name 'x'.
# NOTE(review): the 'NaN' clause compares against the literal string 'NaN'; it
# does not select missing values — confirm whether rows with a missing
# treatment type were meant to be included.
chart_data = (
    chart_data
    .query("""(`treatment_type_renamed` == 'TAU') or (`treatment_type_renamed` == 'IPT-SCI') or (`treatment_type_renamed` == 'NaN')""")
    .sort_values(['treatment_type_renamed', 'treatment_start_year'])
    .rename(columns={'treatment_start_year': 'x'})
)

# Count the '__index__' entries within each (treatment type, year) group.
group_keys = ['treatment_type_renamed', 'x']
chart_data_count = chart_data.groupby(group_keys)[['__index__']].count()


In [6]:
import plotly.graph_objs as go

# Flatten the grouped counts into a plain frame with the columns
# 'treatment_type_renamed', 'x' and '__index__|count'.
chart_data_count.columns = ['__index__|count']
chart_data = chart_data_count.reset_index()
chart_data = chart_data.dropna()

# Build one bar trace per treatment type. Each trace only sees the rows of its
# own type, so the filter below has to be re-applied to replicate a single
# series outside of this loop.
charts = []
for treatment_type in ('TAU', 'IPT-SCI'):
    query_chart_data = chart_data[chart_data['treatment_type_renamed'] == treatment_type]
    charts.append(go.Bar(
        x=query_chart_data['x'],
        y=query_chart_data['__index__|count'],
        name=treatment_type,
    ))

# Assemble the figure; grouping_mode decides how bars sharing an x value are laid out.
# NOTE(review): '0:g' does not look like a standard d3-format specifier — confirm
# the tick labels render as intended.
figure = go.Figure(data=charts, layout=go.Layout({
    'barmode': grouping_mode,
    'legend': {'orientation': 'h'},
    'title': {'text': 'Treatment Type - Count by treatment_start_year'},
    'xaxis': {'tickformat': '0:g', 'title': {'text': 'treatment_start_year'}},
    'yaxis': {'tickformat': '0:g', 'title': {'text': 'Count'}, 'type': 'linear'}
}))



In [7]:
from plotly.offline import init_notebook_mode, iplot

# Initialise plotly's notebook integration before rendering anything.
init_notebook_mode(connected=True)

# iplot reportedly chokes on an 'id' entry, so remove it from every trace first.
for trace in charts:
    trace.pop('id', None)

# Render the figure inline.
iplot(figure)

In [8]:
# Save the interactive figure to an HTML file whose name includes the grouping mode.
html_path = f"Treatment Type - Count by treatment_start_year {grouping_mode}.html"
figure.write_html(html_path)

In [None]:
# Save the figure as a PNG whose name includes the grouping mode.
# NOTE(review): static image export presumably needs an export engine such as
# kaleido installed in the kernel environment — confirm before relying on this cell.
png_path = f"Treatment Type - Count by treatment_start_year {grouping_mode}.png"
figure.write_image(png_path)