In [5]:
from IPython.display import display

import pandas as pd
import warnings
from model_drift import settings
from model_drift.data.utils import nested2series
import matplotlib.pylab as plt
import numpy as np
import seaborn as sns
from model_drift.drift.tabular import TabularDriftCalculator
from model_drift.drift.numeric import KSDriftCalculator, BasicDriftCalculator
from model_drift.drift.categorical import ChiSqDriftCalculator
from model_drift.drift.collection import DriftCollectionCalculator

from model_drift.data.padchest import PadChest
import plotly.graph_objects as go

# Silence every warning (all categories, all modules) for the rest of the kernel session.
# NOTE(review): this blanket filter also hides deprecation warnings from pandas and the
# drift code; consider narrowing it to a specific category/module once the noisy source
# is identified.
warnings.filterwarnings("ignore")


In [6]:
# Read the PadChest metadata CSV configured in `settings` and run the dataset's
# own preparation step before anything else touches `pc.df`.
pc = PadChest(settings.PADCHEST_FILENAME)
pc.prepare()

# Partition the studies into train / val / test using the configured split dates.
train, val, test = pc.split(
    settings.PADCHEST_SPLIT_DATES,
    studydate_index=True,
)


In [7]:
# Side-by-side StudyDate summary (count / mean / min / quartiles / max) for the full
# dataset and each split; one column per dataset, in the order listed below.
# NOTE(review): `datetime_is_numeric` was removed in pandas 2.0 (datetimes are always
# summarised numerically there), so this cell needs a pandas 1.x environment as written.
pd.concat(
    {
        split_name: dataset.df["StudyDate"].describe(datetime_is_numeric=True)
        for split_name, dataset in (
            ("all", pc),
            ("train", train),
            ("val", val),
            ("test", test),
        )
    },
    axis=1,
)


Unnamed: 0,all,train,val,test
count,160819,91726,22176,46917
mean,2012-09-14 20:54:45.910246912,2011-01-06 03:16:23.616423168,2013-06-16 00:14:59.999999744,2015-08-29 00:15:02.359485696
min,2007-05-03 00:00:00,2007-05-03 00:00:00,2013-01-01 00:00:00,2014-01-01 00:00:00
25%,2010-10-27 00:00:00,2010-01-19 00:00:00,2013-03-10 00:00:00,2014-08-08 00:00:00
50%,2012-06-18 00:00:00,2011-01-18 00:00:00,2013-06-04 00:00:00,2015-06-09 00:00:00
75%,2014-05-28 00:00:00,2012-01-11 00:00:00,2013-09-25 00:00:00,2016-09-13 00:00:00
max,2017-11-17 00:00:00,2012-12-28 00:00:00,2013-12-31 00:00:00,2017-11-17 00:00:00


In [8]:
# Short aliases for the two drift calculators used below: one for columns treated
# as continuous, one for columns treated as categorical.
FLOAT, CAT = KSDriftCalculator, ChiSqDriftCalculator

# Metadata column -> drift calculator class. Insertion order is the order in which
# the statistics are registered with the tabular drift calculator.
# NOTE(review): XRayTubeCurrent_DICOM and Exposure_DICOM are registered as categorical
# while the other exposure-related fields are continuous — confirm this is intended.
cols = dict(
    age=FLOAT,
    Projection=CAT,
    PatientSex_DICOM=CAT,
    ViewPosition_DICOM=CAT,
    Modality_DICOM=CAT,
    Manufacturer_DICOM=CAT,
    PhotometricInterpretation_DICOM=CAT,
    PixelRepresentation_DICOM=CAT,
    PixelAspectRatio_DICOM=CAT,
    SpatialResolution_DICOM=CAT,
    BitsStored_DICOM=CAT,
    WindowCenter_DICOM=FLOAT,
    WindowWidth_DICOM=FLOAT,
    Rows_DICOM=FLOAT,
    Columns_DICOM=FLOAT,
    XRayTubeCurrent_DICOM=CAT,
    Exposure_DICOM=CAT,
    ExposureInuAs_DICOM=FLOAT,
    RelativeXRayExposure_DICOM=FLOAT,
)


In [9]:
# Rolling-window parameters passed to `rolling_window_predict`; they are also embedded
# in the output CSV file name. Presumably pandas-style offset aliases (7 days / 1 day)
# — confirm against the drift calculator.
window, stride = "7D", "D"

# When True, the reference / target frames are filtered with `.query("Frontal")`.
ref_frontal_only = target_frontal_only = True

In [12]:
# --- Reference data -------------------------------------------------------------
# The validation split is the reference distribution; work on a copy so `val.df`
# itself is never modified.
refdf = val.df.copy()
if ref_frontal_only:
    # Keep only rows whose `Frontal` column evaluates to True.
    refdf = refdf.query("Frontal")

# Sanity check: rows kept in the reference vs. rows in the full validation split.
print(len(refdf), len(val.df))

# --- Drift calculator -----------------------------------------------------------
dwc = TabularDriftCalculator(refdf)

# Register one drift statistic per metadata column, using the calculator class
# chosen for that column in `cols`.
for c, kls in cols.items():
    dwc.add_drift_stat(c, kls)

dwc.prepare()

# --- Target data ----------------------------------------------------------------
# The whole dataset, indexed by study date, is the target for the rolling windows.
target_df = pc.df.set_index('StudyDate')

if target_frontal_only:
    target_df = target_df.query("Frontal")

# --- Rolling drift + persistence ------------------------------------------------
output = dwc.rolling_window_predict(target_df, stride=stride, window=window)

# The file name encodes the stride/window and both frontal-only flags so runs with
# different settings do not overwrite each other.
fname = settings.TOP_DIR.joinpath(
    "results", "drift_csvs", f"metadata_s{stride}-w{window}_frontalonly-ref{ref_frontal_only}-target{target_frontal_only}.csv")
print(fname)

# `to_csv` does not create missing directories; make sure the output folder exists
# so a fresh checkout does not fail after the expensive drift computation.
fname.parent.mkdir(parents=True, exist_ok=True)
output.to_csv(fname)


Unnamed: 0,0
PixelAspectRatio_DICOM.chi2.critical_diff,-2.329873
Manufacturer_DICOM.chi2.critical_diff,-1.906270
PatientSex_DICOM.chi2.critical_diff,-0.986193
PhotometricInterpretation_DICOM.chi2.dof,0.000000
SpatialResolution_DICOM.chi2.pval,0.000000
...,...
Modality_DICOM.chi2.critical_diff,
PhotometricInterpretation_DICOM.chi2.critical_value,
PhotometricInterpretation_DICOM.chi2.critical_diff,
PixelRepresentation_DICOM.chi2.critical_value,


In [14]:
# Rebuild the results path from the current stride/window/frontal-only settings and
# write the drift table to it.
# NOTE(review): this repeats the save performed at the end of the drift-computation
# cell; keep only one of the two once the notebook is cleaned up.
fname = settings.TOP_DIR.joinpath("results", "drift_csvs", f"metadata_s{stride}-w{window}_frontalonly-ref{ref_frontal_only}-target{target_frontal_only}.csv")
print(fname)

# `to_csv` does not create missing directories; ensure the output folder exists first.
fname.parent.mkdir(parents=True, exist_ok=True)
output.to_csv(fname)


D:\Code\MLOpsDay2\MedImaging-ModelDriftMonitoring\results\drift_csvs\metadata_sD-w7D_frontalonly-refTrue-targetTrue.csv
