In [80]:
from IPython.display import display

import pandas as pd
import warnings
from model_drift import settings
from model_drift.data.utils import nested2series
import matplotlib.pylab as plt
import numpy as np
import seaborn as sns
from model_drift.drift.numeric import KSDriftCalculator, BasicDriftCalculator
from model_drift.drift.categorical import ChiSqDriftCalculator
from model_drift.drift.collection import DriftCollectionCalculator

from model_drift.data.padchest import PadChest
import plotly.graph_objects as go

warnings.filterwarnings("ignore")

# Real Valued Number Drift Detection

In [81]:
# Load padchest CSV
train, val, test = PadChest.splits(studydate_index=True)

In [82]:
pd.concat(
    {
        "train": train.df["StudyDate"].describe(datetime_is_numeric=True),
        "val": val.df["StudyDate"].describe(datetime_is_numeric=True),
        "test": test.df["StudyDate"].describe(datetime_is_numeric=True),
    },
    axis=1,
)

Unnamed: 0,train,val,test
count,91768,22176,46917
mean,2011-01-06 01:16:19.478685696,2013-06-16 00:14:59.999999744,2015-08-29 00:15:02.359485696
min,2007-05-03 00:00:00,2013-01-01 00:00:00,2014-01-01 00:00:00
25%,2010-01-19 00:00:00,2013-03-10 00:00:00,2014-08-08 00:00:00
50%,2011-01-18 00:00:00,2013-06-04 00:00:00,2015-06-09 00:00:00
75%,2012-01-11 00:00:00,2013-09-25 00:00:00,2016-09-13 00:00:00
max,2012-12-28 00:00:00,2013-12-31 00:00:00,2017-11-17 00:00:00


In [83]:
day = "2014-05-05"
window = "30D"

day_dt = pd.to_datetime(day)
delta = pd.tseries.frequencies.to_offset(window)
sample = test.df.loc[str(day_dt-delta):str(day_dt)]
sample['StudyDate'].describe()

count                    2255
unique                     26
top       2014-05-05 00:00:00
freq                      219
first     2014-04-07 00:00:00
last      2014-05-05 00:00:00
Name: StudyDate, dtype: object

In [84]:
data = pd.concat([val.df.assign(src="ref"), sample.assign(src="sample")]).reset_index(drop=True)
data['terb'] = np.random.uniform(size=len(data))+ (data['src']=='sample')*1.3
data['dist'] = ''

In [85]:
ks_test = KSDriftCalculator(val.df['age'].values)
stats_test = BasicDriftCalculator(val.df['age'].values)
rv_test = DriftCollectionCalculator([ks_test, stats_test])
stats = rv_test(sample['age'].values)

In [101]:

display(nested2series(stats, name=day).to_frame())
fig = go.Figure()

i = 0
x = day
y = "age"
ref = val.df.assign(src="ref")
fig.add_trace(
    go.Violin(
        x=[x] * len(ref),
        y=ref[y],
        legendgroup=day,
        scalegroup=day,
        name="Ref",
        # side="negative",
        line_color="blue",
    )
)

sample = sample.assign(src="sample")
fig.add_trace(
    go.Violin(
        x=[x] * len(sample),
        y=sample[y],
        legendgroup=x,
        scalegroup=x,
        name=day,
        hovertemplate = 'Price: %{y:$.2f}<extra></extra>',
        # side="positive",
        # line_color='blue'
    )
)
fig.update_layout(
    width=600,
)
fig.update_layout(hovermode="y unified", 
    hoverlabel=dict(
        bgcolor="white",
        font_size=10,
    ))
# fig.update_layout(violinmode='group')
# fig.update_traces(box_visible=True, meanline_visible=True)

ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().

In [8]:
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))

# sns.histplot(data, x='age', stat='density', kde=True, hue='src', common_norm=False, ax=ax1)
# ax1.set_title(f"Distance: {stats['ks']['distance']:.2e}, p-Value: {stats['ks']['pval']:.2e}")
# sns.violinplot(data=data, y='age', hue='src', x='dist', ax=ax2, inner="quartile", split=True)


In [87]:
col = "Projection"
chi2_test = ChiSqDriftCalculator(val.df[col].values)
stats = chi2_test(sample[col].values)
stats

{'distance': 6.751263401336422,
 'pval': 0.2398060810665436,
 'dof': 5,
 'critical_value': 9.236356899781123,
 'critical_diff': -2.4850934984447015}

In [96]:
fig = go.Figure()


fig.add_trace(go.Histogram(
    x=ref[col],
    histnorm='probability',
    name='ref', # name used in legend and hover labels
    # xbins=dict( # bins used for histogram
    #     start=-4.0,
    #     end=3.0,
    #     size=0.5
    # ),
    opacity=0.75
))
fig.add_trace(go.Histogram(
    x=sample[col],
    histnorm='probability',
    name=day,
    # xbins=dict(
    #     start=-3.0,
    #     end=4,
    #     size=0.5
    # ),
    opacity=0.75
))

fig.update_layout(hovermode="x unified", 
    hoverlabel=dict(
        bgcolor="white",
        font_size=10,
    ))

fig.show()

In [11]:
# display(nested2series(stats, name=day).to_frame())

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))
# ax1.set_title(f"Distance: {stats['distance']:.2e}, p-Value: {stats['pval']:.2e}")
# sns.histplot(data=data, x=col, hue="src", multiple="dodge", shrink=.8, ax=ax1, stat='density', common_norm=False)
# data2 = data.groupby('src')[col].value_counts(normalize=True).unstack().T.sort_values('ref', ascending=False).T
# data2.plot(kind='bar', stacked=True, ax=ax2)

In [12]:
from model_drift.drift.tabular import TabularDriftCalculator

dwc = TabularDriftCalculator(val.df)

dwc.add_drift_stat('age', KSDriftCalculator)
dwc.add_drift_stat('RelativeXRayExposure_DICOM', KSDriftCalculator)
dwc.add_drift_stat('WindowCenter_DICOM', KSDriftCalculator)
dwc.add_drift_stat('WindowWidth_DICOM', KSDriftCalculator)

dwc.add_drift_stat('Projection', ChiSqDriftCalculator)
dwc.add_drift_stat('PatientSex_DICOM', ChiSqDriftCalculator)
dwc.add_drift_stat('Modality_DICOM', ChiSqDriftCalculator)

dwc.prepare()

dwc._metric_collections
results = dwc.predict(sample)

In [13]:
output = dwc.rolling_window_predict(test.df, stride='D')
output

100%|██████████| 1417/1417 [01:58<00:00, 12.00it/s]


Unnamed: 0_level_0,Modality_DICOM,Modality_DICOM,Modality_DICOM,Modality_DICOM,Modality_DICOM,PatientSex_DICOM,PatientSex_DICOM,PatientSex_DICOM,PatientSex_DICOM,PatientSex_DICOM,...,WindowCenter_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,age,age,age,age,count
Unnamed: 0_level_1,chi2,chi2,chi2,chi2,chi2,chi2,chi2,chi2,chi2,chi2,...,ks,ks,ks,ks,ks,ks,ks,ks,ks,Unnamed: 21_level_1
Unnamed: 0_level_2,critical_diff,critical_value,distance,dof,pval,critical_diff,critical_value,distance,dof,pval,...,pval,critical_diff,critical_value,distance,pval,critical_diff,critical_value,distance,pval,Unnamed: 21_level_2
2014-01-01,,,0.000000,0.0,1.0,-4.236511,4.60517,0.368660,2.0,0.831661,...,0.107469,-0.019317,0.706637,0.687320,0.061141,-0.173854,0.706637,0.532783,0.261279,3.0
2014-01-02,,,0.000000,0.0,1.0,-4.600601,4.60517,0.004569,2.0,0.997718,...,0.565718,0.053807,0.353390,0.407197,0.026004,0.023774,0.353390,0.377165,0.048320,12.0
2014-01-03,,,0.000000,0.0,1.0,-4.022248,4.60517,0.582922,2.0,0.747171,...,0.334055,-0.035279,0.280890,0.245612,0.170414,0.043643,0.280890,0.324533,0.027848,19.0
2014-01-04,,,0.000000,0.0,1.0,-4.022248,4.60517,0.582922,2.0,0.747171,...,0.334055,-0.035279,0.280890,0.245612,0.170414,0.043643,0.280890,0.324533,0.027848,19.0
2014-01-05,,,0.000000,0.0,1.0,-4.022248,4.60517,0.582922,2.0,0.747171,...,0.334055,-0.035279,0.280890,0.245612,0.170414,0.043643,0.280890,0.324533,0.027848,19.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-11-13,16887.583783,2.705543,16890.289326,1.0,0.0,5.084967,4.60517,9.690137,2.0,0.007867,...,0.000000,0.648238,0.039082,0.687320,0.000000,0.032866,0.039082,0.071948,0.000073,1026.0
2017-11-14,17026.223308,2.705543,17028.928851,1.0,0.0,4.884923,4.60517,9.490093,2.0,0.008695,...,0.000000,0.648807,0.038513,0.687320,0.000000,0.032426,0.038513,0.070939,0.000073,1058.0
2017-11-15,17073.011322,2.705543,17075.716866,1.0,0.0,3.432346,4.60517,8.037517,2.0,0.017975,...,0.000000,0.649531,0.037788,0.687320,0.000000,0.033562,0.037788,0.071350,0.000043,1101.0
2017-11-16,17446.791812,2.705543,17449.497355,1.0,0.0,4.182499,4.60517,8.787669,2.0,0.012353,...,0.000000,0.649300,0.038019,0.687320,0.000000,0.035870,0.038019,0.073889,0.000023,1087.0


In [69]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots (rows=2, cols=1,shared_xaxes=True, vertical_spacing = 0.05)
graph_view = output.swaplevel(0,-1, axis=1, )
g = graph_view['distance']['ks']
for c in list(g.columns.to_flat_index()):
    y = g[c].rolling(3).mean()
    # fig.add_trace(go.Line(x=a.index, y=a[c], 
    # customdata=a.values,
    # hovertemplate="<br />".join([f"{c}: %{{customdata[{i}]:d}}" for i, c in enumerate(a)]),
    # name=c), row=1, col=1)

    fig.add_trace(go.Line(x=g.index, y=y, 
    # customdata=output_swap[c][hcols].values,
    # hovertemplate='distance: %{customdata[0]:.3f}, p_val: %{customdata[3]:.3f}, is_drift:%{customdata[1]:d}, nobs: %{customdata[2]:d}',
    name=str(c)), 
    row=1, col=1)


g = graph_view['distance']['chi2']
for c in list(g.columns.to_flat_index()):
    y = g[c].rolling(3).mean()
    # fig.add_trace(go.Line(x=a.index, y=a[c], 
    # customdata=a.values,
    # hovertemplate="<br />".join([f"{c}: %{{customdata[{i}]:d}}" for i, c in enumerate(a)]),
    # name=c), row=1, col=1)

    fig.add_trace(go.Line(x=g.index, y=y, 
    # customdata=output_swap[c][hcols].values,
    # hovertemplate='distance: %{customdata[0]:.3f}, p_val: %{customdata[3]:.3f}, is_drift:%{customdata[1]:d}, nobs: %{customdata[2]:d}',
    name=str(c)), 
    row=2, col=1)

    

fig.update_layout(title = "Input Data Drift, Statisical Distance")
fig.update_layout(hovermode="x unified")
fig.update_layout(height=600)
fig.show()


plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [18]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots (rows=2, cols=1, shared_xaxes=True, vertical_spacing = 0.05)
graph_view = output.swaplevel(0,-1, axis=1)
g = graph_view['pval']['ks']
for c in list(g.columns.to_flat_index()):
    y = g[c].rolling(3).mean()
    # fig.add_trace(go.Line(x=a.index, y=a[c], 
    # customdata=a.values,
    # hovertemplate="<br />".join([f"{c}: %{{customdata[{i}]:d}}" for i, c in enumerate(a)]),
    # name=c), row=1, col=1)

    fig.add_trace(go.Line(x=g.index, y=y, 
    # customdata=output_swap[c][hcols].values,
    # hovertemplate='distance: %{customdata[0]:.3f}, p_val: %{customdata[3]:.3f}, is_drift:%{customdata[1]:d}, nobs: %{customdata[2]:d}',
    name=str(c)), 
    row=1, col=1)


g = graph_view['pval']['chi2']
for c in list(g.columns.to_flat_index()):
    y = g[c].rolling(3).mean()
    # fig.add_trace(go.Line(x=a.index, y=a[c], 
    # customdata=a.values,
    # hovertemplate="<br />".join([f"{c}: %{{customdata[{i}]:d}}" for i, c in enumerate(a)]),
    # name=c), row=1, col=1)

    fig.add_trace(go.Line(x=g.index, y=y, 
    # customdata=output_swap[c][hcols].values,
    # hovertemplate='distance: %{customdata[0]:.3f}, p_val: %{customdata[3]:.3f}, is_drift:%{customdata[1]:d}, nobs: %{customdata[2]:d}',
    name=str(c)), 
    row=2, col=1)

    

fig.update_layout(title = "Input Data Drift, P Values")
fig.update_layout(hovermode="x unified")
fig.update_layout(height=600)
fig.show()

In [115]:
compare = ["2014-11-30", "2015-01-08"]

cols=["WindowWidth_DICOM", "WindowCenter_DICOM"]
stats, data = dwc.drilldown(test.df, compare, cols=cols)
for c in cols:
    data[c] = pd.to_numeric(data[c], errors='coerce')
display(stats.T)

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))
fig = make_subplots (rows=2, cols=1, shared_xaxes=True, vertical_spacing = 0.05)

ref = data[data['src'] == "_ref"]

ref
rem = data[data['src'] != "_ref"]


cycle = ['red', 'green', 'purple', 'orange']
y = cols[0]
for r, y in enumerate(cols):
    for i, (name, sample) in enumerate(rem.groupby('src')):
        fig.add_trace(
        go.Violin(
            x=[name] * len(ref),
            y=ref[y],
            legendgroup=name,
            # scalegroup=name,
            name="ref",
            side="negative",
            line_color="blue",
            scalemode="width",
            showlegend=r==0,
            ),
            row=r+1, col=1
        )

        fig.add_trace(
            go.Violin(
                x=[name] * len(sample),
                y=sample[y],
                legendgroup=name,
                # scalegroup=name,
                name=name,
                side="positive",
                line_color=cycle[i],
                scalemode="width",
                showlegend=r==0
            ), row=r+1, col=1
        )

fig.update_layout(
    width=800,
    height=400*2
)
fig.update_layout(hovermode="x unified", 
#     hoverlabel=dict(
#         bgcolor="white",
#         font_size=10,
    # )
    )
fig.update_layout(violinmode='overlay')
fig.update_traces(box_visible=False, meanline_visible=True)
fig.show()

# sns.histplot(data, x=cols[0], stat='density', kde=True, hue='src', common_norm=False, ax=ax1)
# # ax1.set_title(f"Distance: {stats['ks']['distance']:.2e}, p-Value: {stats['ks']['pval']:.2e}")
# sns.violinplot(data=data, y=cols[0], x='src', ax=ax2, inner="quartile", split=True)

Unnamed: 0_level_0,WindowCenter_DICOM,WindowCenter_DICOM,WindowCenter_DICOM,WindowCenter_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,WindowWidth_DICOM,count
Unnamed: 0_level_1,ks,ks,ks,ks,ks,ks,ks,ks,Unnamed: 9_level_1
Unnamed: 0_level_2,distance,pval,critical_value,critical_diff,distance,pval,critical_value,critical_diff,Unnamed: 9_level_2
2014-11-30,0.134623,2.521222e-21,0.03368,0.100943,0.138602,1.392474e-22,0.03368,0.104922,1404.0
2015-01-08,0.626396,8.049977999999999e-281,0.045034,0.581362,0.684702,0.0,0.045034,0.639668,764.0


In [None]:
compare = ["2015-12-23", "2016-02-15"]

cols=["Modality_DICOM"]
stats, data = dwc.compare_dates(test.df, compare, cols=cols)
display(stats)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8), sharey=True)
col="Modality_DICOM"
sns.histplot(data=data, x=col, hue="src", multiple="dodge", shrink=.8, ax=ax1, stat='density', common_norm=False)
data2 = data.groupby('src')[col].value_counts(normalize=True).unstack().T.sort_values('_ref', ascending=False).T
data2.plot(kind='bar', stacked=True, ax=ax2)

In [108]:
compare = ["2015-12-23", "2016-02-15"]

cols=["Modality_DICOM"]
stats, data = dwc.drilldown(test.df, compare, cols=cols)
display(stats)

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))
fig = make_subplots (rows=len(["Modality_DICOM"]), cols=1, shared_xaxes=True, vertical_spacing = 0.05)

ref = data[data['src'] == "_ref"]

ref
rem = data[data['src'] != "_ref"]




cycle = ['red', 'green', 'purple', 'orange']
y = cols[0]
for r, y in enumerate(cols):
    for i, (name, sample) in enumerate(rem.groupby('src')):
        fig.add_trace(
go.Histogram(
    x=ref[y],
    histnorm='percent',
    name='ref', # name used in legend and hover labels
    # xbins=dict( # bins used for histogram
    #     start=-4.0,
    #     end=3.0,
    #     size=0.5
    # ),
    opacity=0.75
),
            row=r+1, col=1
        )

        fig.add_trace(
            go.Histogram(
    x=sample[y],
    histnorm='percent',
    name=day,
    # xbins=dict(
    #     start=-3.0,
    #     end=4,
    #     size=0.5
    # ),
    opacity=0.75
), row=r+1, col=1
        )

fig.update_layout(
    width=600*2,
    height=400
)
fig.update_layout(hovermode="x unified", 
#     hoverlabel=dict(
#         bgcolor="white",
#         font_size=10,
    # )
    )
# fig.update_layout(violinmode='overlay')
# fig.update_traces(box_visible=False, meanline_visible=True)
fig.show()

# sns.histplot(data, x=cols[0], stat='density', kde=True, hue='src', common_norm=False, ax=ax1)
# # ax1.set_title(f"Distance: {stats['ks']['distance']:.2e}, p-Value: {stats['ks']['pval']:.2e}")
# sns.violinplot(data=data, y=cols[0], x='src', ax=ax2, inner="quartile", split=True)


Unnamed: 0,Unnamed: 1,Unnamed: 2,2015-12-23,2016-02-15
Modality_DICOM,chi2,distance,0.0,14204.568447
Modality_DICOM,chi2,pval,1.0,0.0
Modality_DICOM,chi2,dof,0.0,1.0
Modality_DICOM,chi2,critical_value,,2.705543
Modality_DICOM,chi2,critical_diff,,14201.862904
count,,,858.0,1203.0


In [None]:
FLOAT = "f"
CAT = 'c'

cols = {
'age': FLOAT,
'image_size': FLOAT,
'Projection': CAT,
"PatientSex_DICOM": CAT,
# "ViewPosition_DICOM": CAT,
"Modality_DICOM": CAT,
"Manufacturer_DICOM": CAT,
# "PhotometricInterpretation_DICOM": CAT,
# "PixelRepresentation_DICOM": CAT,
# "PixelAspectRatio_DICOM": CAT,
# "SpatialResolution_DICOM": CAT,
# "BitsStored_DICOM": CAT,
"WindowCenter_DICOM": FLOAT,
"WindowWidth_DICOM": FLOAT,
"Rows_DICOM": FLOAT,
"Columns_DICOM": FLOAT,
"XRayTubeCurrent_DICOM": CAT,
# "Exposure_DICOM": CAT,
# "ExposureInuAs_DICOM": FLOAT,
# "RelativeXRayExposure_DICOM": FLOAT,
}