In [1]:
from pathlib import Path
import warnings

from IPython.display import display
import matplotlib.pylab as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.colors import n_colors
from plotly.subplots import make_subplots
import seaborn as sns

from model_drift import settings
from model_drift.data.padchest import PadChest
from model_drift.data.utils import nested2series
from model_drift.drift.categorical import ChiSqDriftCalculator
from model_drift.drift.collection import DriftCollectionCalculator
from model_drift.drift.numeric import KSDriftCalculator, BasicDriftCalculator
from model_drift.drift.sampler import Sampler
from model_drift.drift.tabular import TabularDriftCalculator

# Silence every warning for the rest of the session so cell output stays compact.
# NOTE(review): this blanket filter also hides deprecation warnings raised by
# pandas/plotly calls further down — consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

# Real Valued Number Drift Detection

In [2]:
# Load the PadChest metadata CSV from the path configured in `settings`.
pc = PadChest(settings.PADCHEST_FILENAME)
# Partition into train/val/test at the configured split dates. The later cells
# slice `.df` by date strings, so studydate_index=True presumably indexes each
# split by StudyDate — TODO confirm against PadChest.split.
train, val, test = pc.split(settings.PADCHEST_SPLIT_DATES, studydate_index=True)

In [3]:
# Side-by-side StudyDate summary for the full dataset and each split.
splits = {"all": pc, "train": train, "val": val, "test": test}
pd.concat(
    {
        label: part.df["StudyDate"].describe(datetime_is_numeric=True)
        for label, part in splits.items()
    },
    axis=1,
)

Unnamed: 0,all,train,val,test
count,160819,91726,22176,46917
mean,2012-09-14 20:54:45.910246912,2011-01-06 03:16:23.616423168,2013-06-16 00:14:59.999999744,2015-08-29 00:15:02.359485696
min,2007-05-03 00:00:00,2007-05-03 00:00:00,2013-01-01 00:00:00,2014-01-01 00:00:00
25%,2010-10-27 00:00:00,2010-01-19 00:00:00,2013-03-10 00:00:00,2014-08-08 00:00:00
50%,2012-06-18 00:00:00,2011-01-18 00:00:00,2013-06-04 00:00:00,2015-06-09 00:00:00
75%,2014-05-28 00:00:00,2012-01-11 00:00:00,2013-09-25 00:00:00,2016-09-13 00:00:00
max,2017-11-17 00:00:00,2012-12-28 00:00:00,2013-12-31 00:00:00,2017-11-17 00:00:00


In [4]:
# Reference day and look-back window used for the single-window walkthrough.
day = "2014-05-05"
window = "30D"

day_dt = pd.to_datetime(day)
delta = pd.tseries.frequencies.to_offset(window)

# Label-based slice on the date index: from one window before `day` through `day`.
window_start = day_dt - delta
sample = test.df.loc[str(window_start):str(day_dt)]
sample['StudyDate'].describe()

count                    2255
unique                     26
top       2014-05-05 00:00:00
freq                      219
first     2014-04-07 00:00:00
last      2014-05-05 00:00:00
Name: StudyDate, dtype: object

In [5]:
# Both calculators are fitted against the validation-split ages.
ref_age = val.df['age'].values
ks_test = KSDriftCalculator(ref_age)
stats_test = BasicDriftCalculator(ref_age)
rv_test = DriftCollectionCalculator([ks_test, stats_test])

# Score the sampled window and flatten the nested result into a Series for display.
stats = nested2series(rv_test.predict(sample['age'].values))
stats

ks     distance          7.797949e-02
       pval              2.839598e-11
       critical_value    2.705099e-02
       critical_diff     5.092850e-02
stats  mean              6.290306e+01
       std               1.795997e+01
       median            6.628473e+01
dtype: float64

In [6]:
# Number of rows that fall inside the sampled window.
len(sample)

2255

In [7]:
# Resampler used for the repeated-draw estimates below (1000, with replacement).
# NOTE(review): no random seed is set anywhere visible, so the resampled numbers
# presumably change between runs — confirm whether Sampler accepts a seed.
sampler = Sampler(1000, replacement=True)

# agg=None keeps every one of the 10 repetitions (plus the un-resampled "obs"
# value) instead of summarising them; see the recorded output.
rv_test.predict(sample['age'].values, sampler=sampler, n_samples=10, agg=None).stack()

ks     distance  0       0.075132
                 1       0.067455
                 2       0.063605
                 3       0.064011
                 4       0.059747
                          ...    
stats  median    6      66.309370
                 7      65.307296
                 8      65.308665
                 9      65.802857
                 obs    66.284729
Length: 77, dtype: float64

In [8]:
# Same call with the default aggregation: the recorded output reports each
# statistic as mean/std over the repetitions next to the un-resampled "obs" value.
rv_test.predict(sample['age'].values, sampler=sampler, n_samples=10)

ks     distance        mean    8.974066e-02
                       std     1.774263e-02
                       obs     7.797949e-02
       pval            mean    1.808650e-03
                       std     5.684280e-03
                       obs     2.839598e-11
       critical_value  mean    3.956444e-02
                       std     7.314236e-18
                       obs     2.705099e-02
       critical_diff   mean    5.017621e-02
                       std     1.774263e-02
                       obs     5.092850e-02
stats  mean            mean    6.317733e+01
                       std     6.544876e-01
                       obs     6.290306e+01
       std             mean    1.795816e+01
                       std     2.918517e-01
                       obs     1.795997e+01
       median          mean    6.626173e+01
                       std     7.608907e-01
                       obs     6.628473e+01
dtype: float64

In [9]:

# Tabulate the single-window statistics computed earlier, then overlay the age
# distribution of the reference (validation) split and of the sampled window.
display(nested2series(stats, name=day).to_frame())

y = "age"
# Tag each frame with its origin; `ref` is reused by the categorical plot later on.
ref = val.df.assign(src="ref")
sample = sample.assign(src="sample")

fig = go.Figure()
fig.add_trace(go.Violin(x=ref[y], name="Ref"))
fig.add_trace(
    go.Violin(
        x=sample[y],
        name=day + '~',
        # The violins are drawn horizontally (see update_traces below), so the
        # plotted value lives on the x axis; label it with the column name.
        hovertemplate=y + ': %{x:.1f}<extra></extra>',
    )
)

fig.update_layout(
    hovermode="x unified",
    hoverlabel=dict(bgcolor="white", font_size=10),
    xaxis_showgrid=False,
    xaxis_zeroline=False,
)
fig.update_traces(
    meanline_visible=True,
    orientation='h',
    side='positive',
    width=3,
    points=False,
)
fig.show()

Unnamed: 0,Unnamed: 1,2014-05-05
ks,distance,0.07797949
ks,pval,2.839598e-11
ks,critical_value,0.02705099
ks,critical_diff,0.0509285
stats,mean,62.90306
stats,std,17.95997
stats,median,66.28473


In [10]:
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))

# sns.histplot(data, x='age', stat='density', kde=True, hue='src', common_norm=False, ax=ax1)
# ax1.set_title(f"Distance: {stats['ks']['distance']:.2e}, p-Value: {stats['ks']['pval']:.2e}")
# sns.violinplot(data=data, y='age', hue='src', x='dist', ax=ax2, inner="quartile", split=True)


In [11]:
# Categorical drift: chi-squared test of the window's Projection values against
# the validation split.
col = "Projection"
chi2_test = ChiSqDriftCalculator(val.df[col].values)
# NOTE: rebinds `stats`, which held the age statistics Series until this point.
stats = chi2_test.predict(sample[col].values)
stats

{'distance': 6.751263401336422,
 'pval': 0.2398060810665436,
 'dof': 5,
 'critical_value': 9.236356899781123,
 'critical_diff': -2.4850934984447015}

In [12]:
# Overlay the category frequencies of the reference split and the sampled window.
fig = go.Figure()

for series, label in ((ref[col], 'ref'), (sample[col], day)):
    fig.add_trace(
        go.Histogram(
            x=series,
            histnorm='probability',
            name=label,  # shown in the legend and in hover labels
            opacity=0.75,
        )
    )

fig.update_layout(
    hovermode="x unified",
    hoverlabel=dict(bgcolor="white", font_size=10),
)

fig.show()

In [13]:
# display(nested2series(stats, name=day).to_frame())

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8))
# ax1.set_title(f"Distance: {stats['distance']:.2e}, p-Value: {stats['pval']:.2e}")
# sns.histplot(data=data, x=col, hue="src", multiple="dodge", shrink=.8, ax=ax1, stat='density', common_norm=False)
# data2 = data.groupby('src')[col].value_counts(normalize=True).unstack().T.sort_values('ref', ascending=False).T
# data2.plot(kind='bar', stacked=True, ax=ax2)

In [14]:
# Short aliases for the calculator class applied to each kind of column:
# KS test for numeric columns, chi-squared for categorical ones.
FLOAT = KSDriftCalculator
CAT = ChiSqDriftCalculator

# Metadata column -> drift calculator class.
# NOTE(review): XRayTubeCurrent_DICOM and Exposure_DICOM are tagged categorical
# although their names suggest numeric quantities — confirm this is intended.
cols = {
    'age': FLOAT,
    'image_size': FLOAT,
    'Projection': CAT,
    "PatientSex_DICOM": CAT,
    "ViewPosition_DICOM": CAT,
    "Modality_DICOM": CAT,
    "Manufacturer_DICOM": CAT,
    "PhotometricInterpretation_DICOM": CAT,
    "PixelRepresentation_DICOM": CAT,
    "PixelAspectRatio_DICOM": CAT,
    "SpatialResolution_DICOM": CAT,
    "BitsStored_DICOM": CAT,
    "WindowCenter_DICOM": FLOAT,
    "WindowWidth_DICOM": FLOAT,
    "Rows_DICOM": FLOAT,
    "Columns_DICOM": FLOAT,
    "XRayTubeCurrent_DICOM": CAT,
    "Exposure_DICOM": CAT,
    "ExposureInuAs_DICOM": FLOAT,
    "RelativeXRayExposure_DICOM": FLOAT,
}


In [15]:
# Register one drift test per metadata column on a calculator built over the
# validation split: KS for the numeric columns, chi-squared for the categorical ones.
column_tests = {
    'age': KSDriftCalculator,
    'RelativeXRayExposure_DICOM': KSDriftCalculator,
    'WindowCenter_DICOM': KSDriftCalculator,
    'WindowWidth_DICOM': KSDriftCalculator,
    'Projection': ChiSqDriftCalculator,
    'PatientSex_DICOM': ChiSqDriftCalculator,
    'Modality_DICOM': ChiSqDriftCalculator,
}

dwc = TabularDriftCalculator(val.df)
for column, calculator in column_tests.items():
    dwc.add_drift_stat(column, calculator)
dwc.prepare()

# Score the single sampled window; the nested dict is keyed column -> test -> statistic.
results = dwc.predict(sample)
results

{'age': {'ks': {'distance': 0.07797949002217297,
   'pval': 2.8395980627922525e-11,
   'critical_value': 0.02705099092457315,
   'critical_diff': 0.05092849909759982}},
 'RelativeXRayExposure_DICOM': {'ks': {'distance': 0.06455451729841977,
   'pval': 7.358755986518096e-08,
   'critical_value': 0.02705099092457315,
   'critical_diff': 0.037503526373846616}},
 'WindowCenter_DICOM': {'ks': {'distance': 0.08747404357160449,
   'pval': 4.458090643470519e-14,
   'critical_value': 0.02705099092457315,
   'critical_diff': 0.060423052647031336}},
 'WindowWidth_DICOM': {'ks': {'distance': 0.19082946186604716,
   'pval': 9.341556502153338e-66,
   'critical_value': 0.02705099092457315,
   'critical_diff': 0.16377847094147402}},
 'Projection': {'chi2': {'distance': 6.751263401336422,
   'pval': 0.2398060810665436,
   'dof': 5,
   'critical_value': 9.236356899781123,
   'critical_diff': -2.4850934984447015}},
 'PatientSex_DICOM': {'chi2': {'distance': 2.0806253108547628,
   'pval': 0.35334418970803

In [16]:
# Repeat the tabular prediction with resampling (10 draws); the recorded output
# shows each statistic summarised as mean/std/obs plus a `count` entry.
results = dwc.predict(sample, sampler=sampler, n_samples=10)

results

age             ks    distance  mean       0.085853
                                std        0.009774
                                obs        0.077979
                      pval      mean       0.000022
                                std        0.000052
                                           ...     
Modality_DICOM  chi2  dof       std        0.000000
                                obs        0.000000
count                           mean    1000.000000
                                std        0.000000
                                obs     2255.000000
Length: 90, dtype: float64

In [39]:
# Slide a 30-day window over the full dataset with a daily stride, resampling
# each window 5 times.
# NOTE: the progress bar recorded below shows roughly 28 minutes for 3852
# windows — consider persisting `output` to disk so a full re-run stays feasible.
# NOTE(review): min_periods=50 presumably skips windows holding fewer than 50
# rows — confirm against rolling_window_predict.
output = dwc.rolling_window_predict(pc.df.set_index('StudyDate'),
                                    sampler=sampler, n_samples=5,
                                    stride='D', window='30D', min_periods=50)


2007-05-03 - 2017-11-17 window: 30D, stride: D: 100%|██████████| 3852/3852 [28:08<00:00,  2.28it/s, 2017-11-17]


In [59]:
# Two stacked panels sharing the date axis: standardised drift distances on top,
# number of rows per window underneath.
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05, row_heights=[.8, .2])

# The last column level of `output` holds the aggregate; keep a single one and drop that level.
which = 'obs'
_output = output.swaplevel(0, -1, axis=1)[[which]].swaplevel(0, -1, axis=1).droplevel(-1, axis=1).copy()
counts = _output['count']
_output["Modality_DICOM"] = _output["Modality_DICOM"].fillna(0)

# Standardise every statistic with its mean/std over the dates of the validation split.
stats = _output.loc[val.df.index.unique()].agg(["mean", "std"])
output_standard = ((_output-stats.loc['mean'])/(stats.loc["std"]))
# Bring the statistic level to the front and keep only the "distance" statistic;
# the remaining column levels are (column, test).
graph_view = output_standard.swaplevel(0, 2, axis=1).swaplevel(1, 2, axis=1).sort_index(axis=1)['distance']

# Columns to plot. The weights are only echoed in the hover text: the combined
# curve below is a plain (unweighted) mean across these columns.
weights = {
    # 'Modality_DICOM': 1,
    'PatientSex_DICOM': 1,
    'Projection': 1,
    'RelativeXRayExposure_DICOM': 1,
    'WindowCenter_DICOM': 1,
    'WindowWidth_DICOM': 1,
    'age': 1
}
# Select by an explicit list of labels rather than by the dict itself.
graph_view = graph_view[list(weights)]

span = 90


def smooth(y):
    """Exponentially smooth `y` and clip the result to [-1, 1].

    `y` is a pandas Series or DataFrame. The EWM span is read from the
    module-level `span`. Positions that are NaN in `y` stay NaN in the result
    so gaps in the data are not painted over by the moving average.
    """
    ys = y.ewm(span=span, ignore_na=False).mean()
    return ys.where(y.notna()).clip(-1, 1)


single_disp = dict(line=dict(dash="dot", width=1))
g = graph_view

y = graph_view.clip(-1, 1).mean(axis=1)

# go.Scatter(mode="lines") replaces the deprecated go.Line helper.
fig.add_trace(
    go.Scatter(x=g.index, y=smooth(y), mode="lines", showlegend=True,
               name="Combined", hovertemplate="%{y: .5f}"),
    row=1, col=1,
)
fig.add_trace(
    go.Bar(x=counts.index, y=counts, showlegend=True, legendgroup="Count",
           marker={"color": "green"}, name="num_samples"),
    row=2, col=1,
)

for c in list(g.columns.to_flat_index()):
    name = "{} ({})".format(*c)
    fig.add_trace(
        go.Scatter(
            x=g.index,
            y=smooth(g[c]),
            mode="lines",
            showlegend=True,
            hovertemplate=f"%{{y: .5f}}, weight={weights[c[0]]:.2f}",
            legendgroup=str(name),
            name=str(name),
            **single_disp
        ),
        row=1, col=1,
    )

# Mark where the validation and test periods begin. Each date is normalised once
# so the line position and the annotation text always agree.
for label, split_date in zip(("Val Start", "Test Start"), settings.PADCHEST_SPLIT_DATES):
    split_date = str(split_date).strip()
    fig.add_shape(type='line',
                  x0=split_date, y0=0, x1=split_date, y1=1,
                  line=dict(color='black', dash='dot'),
                  xref='x', yref='paper')
    fig.add_annotation(textangle=0,
                       xref="x", yref="paper", x=split_date, y=1.08,
                       text=f"{label}<br />({split_date})", showarrow=False)

fig.update_layout(
    title=f"Meta Data Drift ({which}), Statistical Distance (standardized)",
    hovermode="x unified",
    height=700,
)
fig.show()

# Create the export directory first so write_html cannot fail on a fresh checkout.
html_dir = Path("html")
html_dir.mkdir(parents=True, exist_ok=True)
fig.write_html(str(html_dir / f"metadata_{which}.html"))


In [None]:
# Row "std" of `stats` for every "distance" statistic.
# NOTE(review): the level positions used here fit the three-level `stats` built
# in the plotting cell; the following cell rebinds `stats`, so the result depends
# on execution order.
stats.swaplevel(0,2, axis=1)['distance'].loc["std"]

In [None]:

# Standardise the complete rolling output (every aggregate) against the dates of
# the validation split, then summarise the standardised distances.
val_dates = val.df.index.unique()
stats = output.loc[val_dates].agg(["mean", "std"])
output_standard = output.sub(stats.loc['mean']).div(stats.loc["std"])

output_standard.swaplevel(0,2, axis=1)["distance"].describe()

In [None]:
# Squash the raw distances through tanh and summarise the result.
act = np.tanh
# The last column level of `output` is the aggregate (the plotting cell selects
# 'obs' from it), so swapping levels 0 and -1 would put the aggregate — not the
# statistic — in front and "distance" could not be selected. The statistic name
# sits at level 2, hence swaplevel(0, 2).
output.swaplevel(0, 2, axis=1)["distance"].apply(act).describe()

In [None]:
# Rolling mean (3 rows) of the critical difference for every column, one panel
# per test family: KS on top, chi-squared below.
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05)

# The last column level of `output` is the aggregate; keep the 'obs' value and
# move the statistic level to the front so it can be selected by name.
# Remaining levels after the swap: (statistic, test, column).
graph_view = output.xs('obs', axis=1, level=-1).swaplevel(0, -1, axis=1)

for row, test_name in enumerate(('ks', 'chi2'), start=1):
    g = graph_view['critical_diff'][test_name]
    for c in list(g.columns.to_flat_index()):
        y = g[c].rolling(3).mean()
        # go.Scatter(mode="lines") replaces the deprecated go.Line helper.
        fig.add_trace(
            go.Scatter(x=g.index, y=y, mode="lines", name=str(c)),
            row=row, col=1,
        )

fig.update_layout(
    # The plotted statistic is critical_diff, not a p-value.
    title="Input Data Drift, Critical Difference (3-step rolling mean)",
    hovermode="x unified",
    height=600,
)
fig.show()

In [None]:
# Dates to drill into: every third day between the two endpoints.
start, end = pd.to_datetime(["2014-11-30", "2015-01-08"])
compare = pd.date_range(start, end, freq='3D')

cols = ["WindowWidth_DICOM", "WindowCenter_DICOM"]
stats, data = dwc.drilldown(test.df, compare, cols=cols)

# Coerce the drilled-down columns to numbers; unparseable entries become NaN.
for c in cols:
    data[c] = pd.to_numeric(data[c], errors='coerce')
display(stats.T)


In [None]:

# One row of horizontal violins per drill-down column: the reference
# distribution plus one violin per compared date, coloured along a gradient.
fig = make_subplots(rows=len(cols), cols=1, shared_xaxes=True, vertical_spacing=0.05)

# Rows tagged "_ref" form the reference; everything else is grouped by `src`.
ref = data[data['src'] == "_ref"]
rem = data[data['src'] != "_ref"]

# One colour per compared date.
colors = n_colors('rgb(5, 200, 200)', 'rgb(200, 10, 10)', len(compare), colortype='rgb')

for r, y in enumerate(cols):
    fig.add_trace(
        go.Violin(x=ref[y], name="ref", showlegend=False),
        row=r+1, col=1,
    )
    # The loop variable is `group`, not `sample`: rebinding the notebook-level
    # `sample` here would silently change what the earlier cells operate on
    # if they are re-run afterwards.
    for i, (name, group) in enumerate(rem.groupby('src')):
        fig.add_trace(
            go.Violin(
                x=group[y],
                name=str(name.date())+'.',
                line_color=colors[i],
                scalemode="width",
                showlegend=False,
            ),
            row=r+1, col=1,
        )

fig.update_layout(width=1600, height=400*len(cols), hovermode=False)
fig.update_traces(orientation='h', side='positive', width=3, points=False)
fig.show()

In [None]:
# Compare two individual dates on the Modality_DICOM column.
# NOTE(review): this cell calls dwc.compare_dates while the neighbouring cells
# call dwc.drilldown with the same arguments — confirm compare_dates exists.
compare = ["2015-12-23", "2016-02-15"]

cols=["Modality_DICOM"]
stats, data = dwc.compare_dates(test.df, compare, cols=cols)
display(stats)

# Left: per-source density histogram. Right: stacked bars of category shares.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10*2, 8), sharey=True)
col="Modality_DICOM"
sns.histplot(data=data, x=col, hue="src", multiple="dodge", shrink=.8, ax=ax1, stat='density', common_norm=False)
# Normalised value counts per source, with categories ordered by their share in
# the "_ref" source (largest first).
data2 = data.groupby('src')[col].value_counts(normalize=True).unstack().T.sort_values('_ref', ascending=False).T
data2.plot(kind='bar', stacked=True, ax=ax2)

In [None]:
# Drill into two individual dates on the Modality_DICOM column and compare
# their category shares with the reference.
compare = ["2015-12-23", "2016-02-15"]

cols = ["Modality_DICOM"]
stats, data = dwc.drilldown(test.df, compare, cols=cols)
display(stats)

fig = make_subplots(rows=len(cols), cols=1, shared_xaxes=True, vertical_spacing=0.05)

# Rows tagged "_ref" form the reference; everything else is grouped by `src`.
ref = data[data['src'] == "_ref"]
rem = data[data['src'] != "_ref"]

for r, y in enumerate(cols):
    # Draw the reference once per row; adding it inside the group loop would
    # stack one duplicate trace (and legend entry) per compared date.
    fig.add_trace(
        go.Histogram(
            x=ref[y],
            histnorm='percent',
            name='ref',  # shown in the legend and in hover labels
            opacity=0.75,
        ),
        row=r+1, col=1,
    )
    # `group` rather than `sample` so the notebook-level `sample` is not rebound.
    for name, group in rem.groupby('src'):
        fig.add_trace(
            go.Histogram(
                x=group[y],
                histnorm='percent',
                # Label each trace with its own source, not the unrelated global `day`.
                name=str(name),
                opacity=0.75,
            ),
            row=r+1, col=1,
        )

fig.update_layout(width=600*2, height=400*len(cols), hovermode="x unified")
fig.show()


In [None]:
# Plain string tags for the column kinds.
# NOTE: this rebinds FLOAT and CAT, which an earlier cell bound to the
# calculator classes, and rebinds `cols`, which the drill-down cells bound to a
# list — re-running those cells after this one will see the values set here.
FLOAT = "f"
CAT = 'c'

# Metadata column -> kind tag; entries left commented out are excluded from the map.
cols = {
'age': FLOAT,
'image_size': FLOAT,
'Projection': CAT,
"PatientSex_DICOM": CAT,
# "ViewPosition_DICOM": CAT,
"Modality_DICOM": CAT,
"Manufacturer_DICOM": CAT,
# "PhotometricInterpretation_DICOM": CAT,
# "PixelRepresentation_DICOM": CAT,
# "PixelAspectRatio_DICOM": CAT,
# "SpatialResolution_DICOM": CAT,
# "BitsStored_DICOM": CAT,
"WindowCenter_DICOM": FLOAT,
"WindowWidth_DICOM": FLOAT,
"Rows_DICOM": FLOAT,
"Columns_DICOM": FLOAT,
"XRayTubeCurrent_DICOM": CAT,
# "Exposure_DICOM": CAT,
# "ExposureInuAs_DICOM": FLOAT,
# "RelativeXRayExposure_DICOM": FLOAT,
}