In [1]:
from IPython.display import display

import pandas as pd
import warnings
from model_drift import settings
from model_drift.data.utils import nested2series
import matplotlib.pylab as plt
import numpy as np
import seaborn as sns
from model_drift.drift.numeric import KSDriftCalculator, BasicDriftCalculator
from model_drift.drift.categorical import ChiSqDriftCalculator
from model_drift.drift.collection import DriftCollectionCalculator

from model_drift.data.padchest import PadChest
import plotly.graph_objects as go

warnings.filterwarnings("ignore")

# Real Valued Number Drift Detection

In [2]:
# Load padchest CSV
pc = PadChest(settings.PADCHEST_FILENAME)
train, val, test = pc.split(settings.PADCHEST_SPLIT_DATES, studydate_index=True)

In [3]:
pd.concat(
    {
        "all": pc.df["StudyDate"].describe(datetime_is_numeric=True),
        "train": train.df["StudyDate"].describe(datetime_is_numeric=True),
        "val": val.df["StudyDate"].describe(datetime_is_numeric=True),
        "test": test.df["StudyDate"].describe(datetime_is_numeric=True),
    },
    axis=1,
)

Unnamed: 0,all,train,val,test
count,160819,91726,22176,46917
mean,2012-09-14 20:54:45.910246912,2011-01-06 03:16:23.616423168,2013-06-16 00:14:59.999999744,2015-08-29 00:15:02.359485696
min,2007-05-03 00:00:00,2007-05-03 00:00:00,2013-01-01 00:00:00,2014-01-01 00:00:00
25%,2010-10-27 00:00:00,2010-01-19 00:00:00,2013-03-10 00:00:00,2014-08-08 00:00:00
50%,2012-06-18 00:00:00,2011-01-18 00:00:00,2013-06-04 00:00:00,2015-06-09 00:00:00
75%,2014-05-28 00:00:00,2012-01-11 00:00:00,2013-09-25 00:00:00,2016-09-13 00:00:00
max,2017-11-17 00:00:00,2012-12-28 00:00:00,2013-12-31 00:00:00,2017-11-17 00:00:00


In [4]:
day = "2014-05-05"
window = "30D"

day_dt = pd.to_datetime(day)
delta = pd.tseries.frequencies.to_offset(window)
sample = test.df.loc[str(day_dt-delta):str(day_dt)]
sample['StudyDate'].describe()

count                    2255
unique                     26
top       2014-05-05 00:00:00
freq                      219
first     2014-04-07 00:00:00
last      2014-05-05 00:00:00
Name: StudyDate, dtype: object

In [6]:
ks_test = KSDriftCalculator(val.df['age'].values)
stats_test = BasicDriftCalculator(val.df['age'].values)


rv_test = DriftCollectionCalculator([ks_test, stats_test])
stats = rv_test(sample['age'].values)
stats

{'ks': {'distance': 0.07797949002217297,
  'pval': 2.8395980627922525e-11,
  'critical_value': 0.02705099092457315,
  'critical_diff': 0.05092849909759982},
 'stats': {'mean': 62.90305773977247,
  'std': 17.959973185546414,
  'median': 66.28472863919177}}

In [66]:

display(nested2series(stats, name=day).to_frame())
fig = go.Figure()

i = 0
x = day
y = "age"
ref = val.df.assign(src="ref")
fig.add_trace(
    go.Violin(
        # x=['ref'] * len(ref),
        x=ref[y],
        # legendgroup=day,
        # scalegroup=day,
        name="Ref",
        # side="positive",
        # line_color="blue",
    )
)

sample = sample.assign(src="sample")
fig.add_trace(
    go.Violin(
        # x=[x] * len(sample),
        x=sample[y],
        # legendgroup=x,
        # scalegroup=x,
        name=day+'~',
        hovertemplate = 'Price: %{y:$.2f}<extra></extra>',
        # side="positive",
        # line_color='blue'
    )
)
fig.update_layout(hovermode="x unified", 
    hoverlabel=dict(
        bgcolor="white",
        font_size=10,
    ))
# fig.update_layout(violinmode='group')
fig.update_traces(meanline_visible=True)
fig.update_traces(orientation='h', side='positive', width=3, points=False)
fig.update_layout(xaxis_showgrid=False, xaxis_zeroline=False)
fig.show()

Unnamed: 0,2014-05-05
distance,6.751263
pval,0.239806
dof,5.0
critical_value,9.236357
critical_diff,-2.485093


In [8]:
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))

# sns.histplot(data, x='age', stat='density', kde=True, hue='src', common_norm=False, ax=ax1)
# ax1.set_title(f"Distance: {stats['ks']['distance']:.2e}, p-Value: {stats['ks']['pval']:.2e}")
# sns.violinplot(data=data, y='age', hue='src', x='dist', ax=ax2, inner="quartile", split=True)


In [9]:
col = "Projection"
chi2_test = ChiSqDriftCalculator(val.df[col].values)
stats = chi2_test(sample[col].values)
stats

{'distance': 6.751263401336422,
 'pval': 0.2398060810665436,
 'dof': 5,
 'critical_value': 9.236356899781123,
 'critical_diff': -2.4850934984447015}

In [10]:
fig = go.Figure()


fig.add_trace(go.Histogram(
    x=ref[col],
    histnorm='probability',
    name='ref', # name used in legend and hover labels
    # xbins=dict( # bins used for histogram
    #     start=-4.0,
    #     end=3.0,
    #     size=0.5
    # ),
    opacity=0.75
))
fig.add_trace(go.Histogram(
    x=sample[col],
    histnorm='probability',
    name=day,
    # xbins=dict(
    #     start=-3.0,
    #     end=4,
    #     size=0.5
    # ),
    opacity=0.75
))

fig.update_layout(hovermode="x unified", 
    hoverlabel=dict(
        bgcolor="white",
        font_size=10,
    ))

fig.show()

In [11]:
# display(nested2series(stats, name=day).to_frame())

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))
# ax1.set_title(f"Distance: {stats['distance']:.2e}, p-Value: {stats['pval']:.2e}")
# sns.histplot(data=data, x=col, hue="src", multiple="dodge", shrink=.8, ax=ax1, stat='density', common_norm=False)
# data2 = data.groupby('src')[col].value_counts(normalize=True).unstack().T.sort_values('ref', ascending=False).T
# data2.plot(kind='bar', stacked=True, ax=ax2)

In [12]:
from model_drift.drift.tabular import TabularDriftCalculator

dwc = TabularDriftCalculator(val.df)

dwc.add_drift_stat('age', KSDriftCalculator)
dwc.add_drift_stat('RelativeXRayExposure_DICOM', KSDriftCalculator)
dwc.add_drift_stat('WindowCenter_DICOM', KSDriftCalculator)
dwc.add_drift_stat('WindowWidth_DICOM', KSDriftCalculator)
# dwc.add_drift_stat('WindowWidth_DICOM', BasicDriftCalculator)
# dwc.add_drift_stat('WindowCenter_DICOM', BasicDriftCalculator)

dwc.add_drift_stat('Projection', ChiSqDriftCalculator, )
dwc.add_drift_stat('PatientSex_DICOM', ChiSqDriftCalculator)
dwc.add_drift_stat('Modality_DICOM', ChiSqDriftCalculator)

dwc.prepare()

dwc._metric_collections
results = dwc.predict(sample)

In [13]:
settings.PADCHEST_SPLIT_DATES

['2012-01-01', ' 2013-01-01']

In [14]:
pd.concat([val.df, test.df]).sort_index()

Unnamed: 0_level_0,ImageID,ImageDir,StudyDate_DICOM,StudyID,PatientID,PatientBirth,PatientSex_DICOM,ViewPosition_DICOM,Projection,MethodProjection,...,Edema,Lesion,No Finding,Opacity,Pleural Abnormalities,Pleural Effusion,Pneumonia,StudyDate,Frontal,age
StudyDate,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2013-01-01,216840111366964013534861372972013001133538714_...,18,20130101,216840111366964013534861372972013001133538714,131217184544060144428102249424455424427,1993-01-01,M,,PA,Manual review of DICOM fields,...,0,0,1,0,0,0,0,2013-01-01,True,20.000411
2013-01-02,216840111366964013534861372972013002095150191_...,12,20130102,216840111366964013534861372972013002095150191,20530356097946148636945669523456588004,1978-01-01,M,LATERAL,L,Manual review of DICOM fields,...,0,0,1,0,0,0,0,2013-01-02,False,35.004141
2013-01-02,216840111366964013534861372972013002084858300_...,12,20130102,216840111366964013534861372972013002084858300,330092038443399311375446356483168560535,1963-01-01,F,,AP_horizontal,Manual review of DICOM fields,...,0,0,0,0,0,0,0,2013-01-02,True,50.005134
2013-01-02,216840111366964013515091760022012318104233667_...,12,20130102,216840111366964013515091760022012318104233667,332813282591181625959886132259604742938,1952-01-01,F,LATERAL,L,Manual review of DICOM fields,...,0,0,0,0,0,0,0,2013-01-02,False,61.006044
2013-01-02,12752243479320242082624_02-011-127.png,12,20130102,12752243479320242082624,125588790726016741103762189093973766431,1938-01-01,F,POSTEROANTERIOR,PA,Manual review of DICOM fields,...,0,0,0,0,0,0,0,2013-01-02,True,75.004962
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-11-17,16195291732581043929483414230565603291-2_pm3a7...,7,20171117,16195291732581043929483414230565603291-2,254654689860114856846890774049753268092,1997-01-01,M,LL,L,Manual review of DICOM fields,...,0,0,1,0,0,0,0,2017-11-17,False,20.876541
2017-11-17,16195291732581043929483414230565603291-2_plkqu...,7,20171117,16195291732581043929483414230565603291-2,254654689860114856846890774049753268092,1997-01-01,M,PA,PA,Manual review of DICOM fields,...,0,0,1,0,0,0,0,2017-11-17,True,20.876541
2017-11-17,7599320594485360758641047593811181914_xksbmi.png,5,20171117,7599320594485360758641047593811181914,302138234998745245365531693724342029459,1959-01-01,F,PA,PA,Manual review of DICOM fields,...,0,0,1,0,0,0,0,2017-11-17,True,58.878690
2017-11-17,46789753258781872305172598914730827377_6l0g1x.png,4,20171117,46789753258781872305172598914730827377,78859491215125139641170526670617564733,1954-01-01,M,PA,PA,Manual review of DICOM fields,...,0,0,1,0,0,0,0,2017-11-17,True,63.878108


In [15]:
output = dwc.rolling_window_predict(pc.df.set_index('StudyDate'), stride='D', window='30D')
output

2007-05-03 - 2017-11-17: 100%|██████████| 3852/3852 [05:46<00:00, 11.12it/s, 2017-11-17]


Unnamed: 0_level_0,Modality_DICOM,Modality_DICOM,Modality_DICOM,Modality_DICOM,Modality_DICOM,PatientSex_DICOM,PatientSex_DICOM,PatientSex_DICOM,PatientSex_DICOM,PatientSex_DICOM,...,WindowCenter_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,age,age,age,age,count
Unnamed: 0_level_1,chi2,chi2,chi2,chi2,chi2,chi2,chi2,chi2,chi2,chi2,...,ks,ks,ks,ks,ks,ks,ks,ks,ks,Unnamed: 21_level_1
Unnamed: 0_level_2,critical_diff,critical_value,distance,dof,pval,critical_diff,critical_value,distance,dof,pval,...,pval,critical_diff,critical_value,distance,pval,critical_diff,critical_value,distance,pval,Unnamed: 21_level_2
2007-05-03,,,0.000000,0.0,1.0,-2.672029,4.60517,1.933141,2.0,0.380385,...,0.366883,-0.009234,0.865430,0.856196,0.041359,-0.285163,0.865430,0.580267,0.352352,2.0
2007-05-04,,,0.000000,0.0,1.0,-2.672029,4.60517,1.933141,2.0,0.380385,...,0.366883,-0.009234,0.865430,0.856196,0.041359,-0.285163,0.865430,0.580267,0.352352,2.0
2007-05-05,,,0.000000,0.0,1.0,-2.672029,4.60517,1.933141,2.0,0.380385,...,0.366883,-0.009234,0.865430,0.856196,0.041359,-0.285163,0.865430,0.580267,0.352352,2.0
2007-05-06,,,0.000000,0.0,1.0,-2.672029,4.60517,1.933141,2.0,0.380385,...,0.366883,-0.009234,0.865430,0.856196,0.041359,-0.285163,0.865430,0.580267,0.352352,2.0
2007-05-07,,,0.000000,0.0,1.0,-2.672029,4.60517,1.933141,2.0,0.380385,...,0.366883,-0.009234,0.865430,0.856196,0.041359,-0.285163,0.865430,0.580267,0.352352,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-11-13,16887.583783,2.705543,16890.289326,1.0,0.0,5.084967,4.60517,9.690137,2.0,0.007867,...,0.000000,0.648238,0.039082,0.687320,0.000000,0.032866,0.039082,0.071948,0.000073,1026.0
2017-11-14,17026.223308,2.705543,17028.928851,1.0,0.0,4.884923,4.60517,9.490093,2.0,0.008695,...,0.000000,0.648807,0.038513,0.687320,0.000000,0.032426,0.038513,0.070939,0.000073,1058.0
2017-11-15,17073.011322,2.705543,17075.716866,1.0,0.0,3.432346,4.60517,8.037517,2.0,0.017975,...,0.000000,0.649531,0.037788,0.687320,0.000000,0.033562,0.037788,0.071350,0.000043,1101.0
2017-11-16,17446.791812,2.705543,17449.497355,1.0,0.0,4.182499,4.60517,8.787669,2.0,0.012353,...,0.000000,0.649300,0.038019,0.687320,0.000000,0.035870,0.038019,0.073889,0.000023,1087.0


In [112]:
output.swaplevel(0, 2, axis=1).swaplevel(1, 2, axis=1).sort_index(axis=1)


Unnamed: 0_level_0,Unnamed: 1_level_0,critical_diff,critical_diff,critical_diff,critical_diff,critical_diff,critical_diff,critical_diff,critical_value,critical_value,...,dof,dof,dof,pval,pval,pval,pval,pval,pval,pval
Unnamed: 0_level_1,count,Modality_DICOM,PatientSex_DICOM,Projection,RelativeXRayExposure_DICOM,WindowCenter_DICOM,WindowWidth_DICOM,age,Modality_DICOM,PatientSex_DICOM,...,Modality_DICOM,PatientSex_DICOM,Projection,Modality_DICOM,PatientSex_DICOM,Projection,RelativeXRayExposure_DICOM,WindowCenter_DICOM,WindowWidth_DICOM,age
Unnamed: 0_level_2,Unnamed: 1_level_2,chi2,chi2,chi2,ks,ks,ks,ks,chi2,chi2,...,chi2,chi2,chi2,chi2,chi2,chi2,ks,ks,ks,ks
2007-05-03,2.0,0.000000,-2.672029,-8.732325,-0.421933,-0.293731,-0.009234,-0.285163,0.000000,4.60517,...,0.0,2.0,5.0,1.0,0.380385,9.919749e-01,7.004698e-01,0.366883,0.041359,0.352352
2007-05-04,2.0,0.000000,-2.672029,-8.732325,-0.421933,-0.293731,-0.009234,-0.285163,0.000000,4.60517,...,0.0,2.0,5.0,1.0,0.380385,9.919749e-01,7.004698e-01,0.366883,0.041359,0.352352
2007-05-05,2.0,0.000000,-2.672029,-8.732325,-0.421933,-0.293731,-0.009234,-0.285163,0.000000,4.60517,...,0.0,2.0,5.0,1.0,0.380385,9.919749e-01,7.004698e-01,0.366883,0.041359,0.352352
2007-05-06,2.0,0.000000,-2.672029,-8.732325,-0.421933,-0.293731,-0.009234,-0.285163,0.000000,4.60517,...,0.0,2.0,5.0,1.0,0.380385,9.919749e-01,7.004698e-01,0.366883,0.041359,0.352352
2007-05-07,2.0,0.000000,-2.672029,-8.732325,-0.421933,-0.293731,-0.009234,-0.285163,0.000000,4.60517,...,0.0,2.0,5.0,1.0,0.380385,9.919749e-01,7.004698e-01,0.366883,0.041359,0.352352
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-11-13,1026.0,16887.583783,5.084967,105.489975,0.542871,0.589932,0.648238,0.032866,2.705543,4.60517,...,1.0,2.0,5.0,0.0,0.007867,4.103110e-23,4.859811e-316,0.000000,0.000000,0.000073
2017-11-14,1058.0,17026.223308,4.884923,110.036481,0.543470,0.590501,0.648807,0.032426,2.705543,4.60517,...,1.0,2.0,5.0,0.0,0.008695,4.474460e-24,0.000000e+00,0.000000,0.000000,0.000073
2017-11-15,1101.0,17073.011322,3.432346,114.876367,0.544231,0.591225,0.649531,0.033562,2.705543,4.60517,...,1.0,2.0,5.0,0.0,0.017975,4.219512e-25,0.000000e+00,0.000000,0.000000,0.000043
2017-11-16,1087.0,17446.791812,4.182499,117.806477,0.543988,0.590994,0.649300,0.035870,2.705543,4.60517,...,1.0,2.0,5.0,0.0,0.012353,1.009154e-25,0.000000e+00,0.000000,0.000000,0.000023


Index(['Modality_DICOM', 'PatientSex_DICOM', 'Projection',
       'RelativeXRayExposure_DICOM', 'WindowCenter_DICOM', 'WindowWidth_DICOM',
       'age'],
      dtype='object')

In [177]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=1, cols=1, shared_xaxes=True, vertical_spacing=0.05,
                    horizontal_spacing=.025, shared_yaxes=True)

output["Modality_DICOM"] = output["Modality_DICOM"].fillna(0)

graph_view = output.swaplevel(0, 2, axis=1).swaplevel(1, 2, axis=1).sort_index(axis=1)['critical_diff']

weights = {
    'Modality_DICOM': 1,
    'PatientSex_DICOM': .1,
    'Projection': .1,
    'RelativeXRayExposure_DICOM': 1,
    'WindowCenter_DICOM': 1,
    'WindowWidth_DICOM': 1,
    'age': 1
}
graph_view = graph_view[weights]


def w_avg(df, weights):
    cols = df.columns.get_level_values(0)
    cols = [c for c in weights if c in cols]
    weights = np.array([weights[c] for c in cols])
    weights = weights/weights.sum()
    tmp = df[cols].copy()
    for c, w in zip(cols, weights):
        tmp[c] = tmp[c]*w
    return tmp.sum(axis=1)


def sigmoid(x):
    return np.where(x < 0, np.exp(x)/(1 + np.exp(x)), 1/(1 + np.exp(-x)))


def tanh_clip(x):
    return np.tanh(x).clip(0)


act = tanh_clip


def y_func(y):
    # return y.apply(lambda x: np.clip(x, -1, 1))
    return y.apply(act)

def smooth(y: pd.DataFrame):
    return y.ewm(span=30).mean()

y = smooth(
    w_avg(y_func(graph_view), weights)
)

fig.add_trace(go.Line(x=g.index, y=y, showlegend=True, name="Combined", hovertemplate="%{y: .5f}"), row=1, col=1)

single_disp = dict(line=dict(dash="dot", width=.8))
g = graph_view
for c in list(g.columns.to_flat_index()):
    name = "{} ({})".format(*c)
    y = smooth(y_func(g[c]))
    fig.add_trace(go.Line(x=g.index, y=y,
                          showlegend=True,
                          hovertemplate=f"%{{y: .5f}}, weight={weights[c[0]]:.2f}",
                          legendgroup=str(name),
                          name=str(name),
                          **single_disp
                          ),
                  row=1, col=1)


fig.update_layout(title=f"Input Data Drift, Statisical Distance ({act.__name__})")
fig.update_layout(hovermode="x unified")

fig.add_shape(type='line',
              x0=settings.PADCHEST_SPLIT_DATES[0],
              y0=0,
              x1=settings.PADCHEST_SPLIT_DATES[0],
              y1=1,
              line=dict(color='black', dash='dot'),
              xref='x',
              yref='paper'
              )
fig.add_annotation(textangle=0,
                   xref="x",
                   yref="paper", x=settings.PADCHEST_SPLIT_DATES[0], y=1.08,
                   text=f"Val Start<br />({settings.PADCHEST_SPLIT_DATES[0]})", showarrow=False,)
fig.add_shape(type='line', x0=settings.PADCHEST_SPLIT_DATES[1], y0=0, x1=settings.PADCHEST_SPLIT_DATES[1], y1=1,
              line=dict(color='black', dash='dot'), xref='x', yref='paper'
              )
fig.add_annotation(textangle=0,
                   xref="x",
                   yref="paper", x=settings.PADCHEST_SPLIT_DATES[1], y=1.08,
                   text=f"Test Start<br />({settings.PADCHEST_SPLIT_DATES[1].strip()})", showarrow=False,)
fig.update_layout(height=600)
fig.show()


In [67]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots (rows=2, cols=1, shared_xaxes=True, vertical_spacing = 0.05)
graph_view = output.swaplevel(0,-1, axis=1)
g = graph_view['pval']['ks']
for c in list(g.columns.to_flat_index()):
    y = g[c].rolling(3).mean()
    # fig.add_trace(go.Line(x=a.index, y=a[c], 
    # customdata=a.values,
    # hovertemplate="<br />".join([f"{c}: %{{customdata[{i}]:d}}" for i, c in enumerate(a)]),
    # name=c), row=1, col=1)

    fig.add_trace(go.Line(x=g.index, y=y, 
    # customdata=output_swap[c][hcols].values,
    # hovertemplate='distance: %{customdata[0]:.3f}, p_val: %{customdata[3]:.3f}, is_drift:%{customdata[1]:d}, nobs: %{customdata[2]:d}',
    name=str(c)), 
    row=1, col=1)


g = graph_view['pval']['chi2']
for c in list(g.columns.to_flat_index()):
    y = g[c].rolling(3).mean()
    # fig.add_trace(go.Line(x=a.index, y=a[c], 
    # customdata=a.values,
    # hovertemplate="<br />".join([f"{c}: %{{customdata[{i}]:d}}" for i, c in enumerate(a)]),
    # name=c), row=1, col=1)

    fig.add_trace(go.Line(x=g.index, y=y, 
    # customdata=output_swap[c][hcols].values,
    # hovertemplate='distance: %{customdata[0]:.3f}, p_val: %{customdata[3]:.3f}, is_drift:%{customdata[1]:d}, nobs: %{customdata[2]:d}',
    name=str(c)), 
    row=2, col=1)

    

fig.update_layout(title = "Input Data Drift, P Values")
fig.update_layout(hovermode="x unified")
fig.update_layout(height=600)
fig.show()


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [None]:
compare = ["2014-11-30", "2015-01-08"]

compare = pd.to_datetime(compare)
compare = pd.date_range(
    compare[0], compare[1], freq='3D'
)

cols=["WindowWidth_DICOM", "WindowCenter_DICOM"]
stats, data = dwc.drilldown(test.df, compare, cols=cols)
for c in cols:
    data[c] = pd.to_numeric(data[c], errors='coerce')
display(stats.T)


In [None]:

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))
fig = make_subplots (rows=2, cols=1, shared_xaxes=True, vertical_spacing = 0.05)

ref = data[data['src'] == "_ref"]

ref
rem = data[data['src'] != "_ref"]


cycle = ['red', 'green', 'purple', 'orange']

colors = n_colors('rgb(5, 200, 200)', 'rgb(200, 10, 10)', len(compare), colortype='rgb')

for r, y in enumerate(cols):
    fig.add_trace(
        go.Violin(
            x=ref[y],
            name="ref",
            showlegend=False,
            ),
            row=r+1, col=1
        )
    for i, (name, sample) in enumerate(rem.groupby('src')):
        fig.add_trace(
            go.Violin(
                x=sample[y],
                name=str(name.date())+'.',
                line_color=colors[i],
                scalemode="width",
                showlegend=False
            ), row=r+1, col=1
        )

fig.update_layout(
    width=1600,
    height=400*2
)
fig.update_layout(hovermode=False, 
#     hoverlabel=dict(
#         bgcolor="white",
#         font_size=10,
    # )
    )
fig.update_traces(orientation='h', side='positive', width=3, points=False)
# fig.update_layout(xaxis_showgrid=False, xaxis_zeroline=False)
fig.show()
# sns.histplot(data, x=cols[0], stat='density', kde=True, hue='src', common_norm=False, ax=ax1)
# # ax1.set_title(f"Distance: {stats['ks']['distance']:.2e}, p-Value: {stats['ks']['pval']:.2e}")
# sns.violinplot(data=data, y=cols[0], x='src', ax=ax2, inner="quartile", split=True)

In [None]:
compare = ["2015-12-23", "2016-02-15"]

cols=["Modality_DICOM"]
stats, data = dwc.compare_dates(test.df, compare, cols=cols)
display(stats)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8), sharey=True)
col="Modality_DICOM"
sns.histplot(data=data, x=col, hue="src", multiple="dodge", shrink=.8, ax=ax1, stat='density', common_norm=False)
data2 = data.groupby('src')[col].value_counts(normalize=True).unstack().T.sort_values('_ref', ascending=False).T
data2.plot(kind='bar', stacked=True, ax=ax2)

In [None]:
compare = ["2015-12-23", "2016-02-15"]

cols=["Modality_DICOM"]
stats, data = dwc.drilldown(test.df, compare, cols=cols)
display(stats)

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))
fig = make_subplots (rows=len(["Modality_DICOM"]), cols=1, shared_xaxes=True, vertical_spacing = 0.05)

ref = data[data['src'] == "_ref"]

ref
rem = data[data['src'] != "_ref"]




cycle = ['red', 'green', 'purple', 'orange']
y = cols[0]
for r, y in enumerate(cols):
    for i, (name, sample) in enumerate(rem.groupby('src')):
        fig.add_trace(
go.Histogram(
    x=ref[y],
    histnorm='percent',
    name='ref', # name used in legend and hover labels
    # xbins=dict( # bins used for histogram
    #     start=-4.0,
    #     end=3.0,
    #     size=0.5
    # ),
    opacity=0.75
),
            row=r+1, col=1
        )

        fig.add_trace(
            go.Histogram(
    x=sample[y],
    histnorm='percent',
    name=day,
    # xbins=dict(
    #     start=-3.0,
    #     end=4,
    #     size=0.5
    # ),
    opacity=0.75
), row=r+1, col=1
        )

fig.update_layout(
    width=600*2,
    height=400
)
fig.update_layout(hovermode="x unified", 
#     hoverlabel=dict(
#         bgcolor="white",
#         font_size=10,
    # )
    )
# fig.update_layout(violinmode='overlay')
# fig.update_traces(box_visible=False, meanline_visible=True)
fig.show()

# sns.histplot(data, x=cols[0], stat='density', kde=True, hue='src', common_norm=False, ax=ax1)
# # ax1.set_title(f"Distance: {stats['ks']['distance']:.2e}, p-Value: {stats['ks']['pval']:.2e}")
# sns.violinplot(data=data, y=cols[0], x='src', ax=ax2, inner="quartile", split=True)


In [None]:
FLOAT = "f"
CAT = 'c'

cols = {
'age': FLOAT,
'image_size': FLOAT,
'Projection': CAT,
"PatientSex_DICOM": CAT,
# "ViewPosition_DICOM": CAT,
"Modality_DICOM": CAT,
"Manufacturer_DICOM": CAT,
# "PhotometricInterpretation_DICOM": CAT,
# "PixelRepresentation_DICOM": CAT,
# "PixelAspectRatio_DICOM": CAT,
# "SpatialResolution_DICOM": CAT,
# "BitsStored_DICOM": CAT,
"WindowCenter_DICOM": FLOAT,
"WindowWidth_DICOM": FLOAT,
"Rows_DICOM": FLOAT,
"Columns_DICOM": FLOAT,
"XRayTubeCurrent_DICOM": CAT,
# "Exposure_DICOM": CAT,
# "ExposureInuAs_DICOM": FLOAT,
# "RelativeXRayExposure_DICOM": FLOAT,
}