In [1]:
import pandas as pd
import numpy as np
import json
import glob
import tqdm
import seaborn as sns
import matplotlib.pylab as plt

import warnings
# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): this blanket filter also hides deprecation and numerical
# warnings raised by later cells — consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

from model_drift import settings, helpers



In [106]:
from model_drift.data.padchest import LABEL_MAP

# One activation column is created per entry of LABEL_MAP, in its iteration order.
label_cols = list(LABEL_MAP)

# Predictions written by the fine-tuned classifier live under
# <TOP_DIR>/results/classifier/finetuned/preds.jsonl.
jsonl_dir = str(settings.TOP_DIR.joinpath("results", 'classifier', 'finetuned'))
jsonl_files = glob.glob(f"{jsonl_dir}/preds.jsonl")

df = helpers.jsonl_files2dataframe(jsonl_files)

# Expand the list-valued `activation` column into one named column per label.
# The expanded frame reuses df's index so the column-wise concat pairs each
# activation row with its source row regardless of what index the helper returns.
df = pd.concat(
    [
        df,
        pd.DataFrame(
            df['activation'].values.tolist(),
            columns=[f"activation.{c}" for c in label_cols],
            index=df.index,
        ),
    ],
    axis=1)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/160819 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:04<00:00,  4.11s/it]


In [16]:
from model_drift.data.padchest import PadChest

# Build the PadChest dataset wrapper; called with no arguments, so it
# presumably reads the CSV from a project-default location — TODO confirm.
pc = PadChest.from_csv()
# NOTE(review): prepare() is opaque from this notebook; assumed to derive the
# columns used by the later cells (e.g. StudyDate) — verify in PadChest.
pc.prepare()

In [17]:
# Join the classifier predictions onto the PadChest metadata, matching the
# metadata's "ImageID" to the predictions' "index" column. how='inner' is
# presumably pandas-style (only matching rows kept) — confirm in PadChest.merge.
# The preview below reads pc.df, so the merge result is stored on `pc` itself.
pc.merge(df, left_on="ImageID", right_on="index", how='inner')
pc.df.head()

Unnamed: 0,ImageID,ImageDir,StudyDate_DICOM,StudyID,PatientID,PatientBirth,PatientSex_DICOM,ViewPosition_DICOM,Projection,MethodProjection,...,activation.Atelectasis,activation.Cardiomegaly,activation.Consolidation,activation.Edema,activation.Lesion,activation.No Finding,activation.Opacity,activation.Pleural Abnormalities,activation.Pleural Effusion,activation.Pneumonia
0,20536686640136348236148679891455886468_k6ga29.png,0,20140915,20536686640136348236148679891455886468,839860488694292331637988235681460987,1930-01-01,F,POSTEROANTERIOR,PA,Manual review of DICOM fields,...,0.047886,0.505307,0.003428,0.000492,0.031599,0.040083,0.046906,0.223817,0.010098,0.01091
1,135803415504923515076821959678074435083_fzis7d...,0,20150914,135803415504923515076821959678074435083,313572750430997347502932654319389875966,1929-01-01,M,LATERAL,L,Manual review of DICOM fields,...,0.038255,0.611089,0.002242,0.001264,0.030819,0.01864,0.082885,0.146705,0.012119,0.005538
2,135803415504923515076821959678074435083_fzis7b...,0,20150914,135803415504923515076821959678074435083,313572750430997347502932654319389875966,1929-01-01,M,POSTEROANTERIOR,PA,Manual review of DICOM fields,...,0.105066,0.129081,0.007552,0.000427,0.025125,0.081793,0.154357,0.045849,0.005696,0.02303
3,113855343774216031107737439268243531979_3k951l...,0,20150717,113855343774216031107737439268243531979,50783093527901818115346441867348318648,1925-01-01,F,POSTEROANTERIOR,PA,Manual review of DICOM fields,...,0.040446,0.07421,0.001145,0.000125,0.024128,0.34073,0.028292,0.006541,0.000678,0.002352
4,113855343774216031107737439268243531979_3k951n...,0,20150717,113855343774216031107737439268243531979,50783093527901818115346441867348318648,1925-01-01,F,LATERAL,L,Manual review of DICOM fields,...,0.013773,0.037902,0.000529,9.7e-05,0.022773,0.269458,0.025679,0.003461,0.000689,0.003102


In [19]:
# Names of the per-label activation columns on the merged PadChest frame.
# Columns containing 'all' are excluded, and the raw list-valued `activation`
# column is skipped because it lacks the "activation." prefix.
cols = [c for c in list(pc.df) if c.startswith("activation.") and 'all' not in c]
cols

['activation.Atelectasis',
 'activation.Cardiomegaly',
 'activation.Consolidation',
 'activation.Edema',
 'activation.Lesion',
 'activation.No Finding',
 'activation.Opacity',
 'activation.Pleural Abnormalities',
 'activation.Pleural Effusion',
 'activation.Pneumonia']

In [20]:
# Partition the dataset into train / val / test using the boundaries in
# settings.PADCHEST_SPLIT_DATES.
# NOTE(review): studydate_index=True presumably indexes each split by study
# date — confirm against PadChest.split.
train, val, test = pc.split(settings.PADCHEST_SPLIT_DATES, studydate_index=True)


In [21]:
from model_drift.drift.tabular import TabularDriftCalculator
from model_drift.drift.numeric import KSDriftCalculator

# The validation split (with its index reset to a plain range) is passed to the
# calculator as the frame the later drilldown / rolling-window calls compare against.
dwc = TabularDriftCalculator(val.df.reset_index(drop=True))

# Register one KS drift statistic per activation column.
for c in cols:
    dwc.add_drift_stat(c, KSDriftCalculator)
dwc.prepare()

# Spot-check a single date in the test split.
results, data = dwc.drilldown(test.df, ["2017-01-05"])

results

Unnamed: 0,Unnamed: 1,Unnamed: 2,2017-01-05
activation.Atelectasis,ks,distance,0.08089889
activation.Atelectasis,ks,pval,7.359369e-05
activation.Atelectasis,ks,critical_value,0.04396364
activation.Atelectasis,ks,critical_diff,0.03693525
activation.Cardiomegaly,ks,distance,0.1224469
activation.Cardiomegaly,ks,pval,1.385567e-10
activation.Cardiomegaly,ks,critical_value,0.04396364
activation.Cardiomegaly,ks,critical_diff,0.07848331
activation.Consolidation,ks,distance,0.1903761
activation.Consolidation,ks,pval,4.516324e-25


In [74]:
# Run the drift statistics over the whole dataset, indexed by StudyDate, with a
# daily stride ('D'). The result has one row per day and a three-level column
# index (activation column, statistic family, metric) plus a `count` column.
output = dwc.rolling_window_predict(pc.df.set_index("StudyDate"), stride='D')
output

2007-05-03 - 2017-11-17: 100%|██████████| 3852/3852 [02:25<00:00, 26.42it/s, 2017-11-17]


Unnamed: 0_level_0,activation.Atelectasis,activation.Atelectasis,activation.Atelectasis,activation.Atelectasis,activation.Cardiomegaly,activation.Cardiomegaly,activation.Cardiomegaly,activation.Cardiomegaly,activation.Consolidation,activation.Consolidation,...,activation.Pleural Abnormalities,activation.Pleural Effusion,activation.Pleural Effusion,activation.Pleural Effusion,activation.Pleural Effusion,activation.Pneumonia,activation.Pneumonia,activation.Pneumonia,activation.Pneumonia,count
Unnamed: 0_level_1,ks,ks,ks,ks,ks,ks,ks,ks,ks,ks,...,ks,ks,ks,ks,ks,ks,ks,ks,ks,Unnamed: 21_level_1
Unnamed: 0_level_2,critical_diff,critical_value,distance,pval,critical_diff,critical_value,distance,pval,critical_diff,critical_value,...,pval,critical_diff,critical_value,distance,pval,critical_diff,critical_value,distance,pval,Unnamed: 21_level_2
2007-05-03,-0.031556,0.865430,0.833874,5.519539e-02,-0.189519,0.865430,0.675911,2.100675e-01,-0.170309,0.865430,...,3.834441e-01,-0.310776,0.865430,0.554654,3.966667e-01,-0.037869,0.865430,0.827561,5.947019e-02,2.0
2007-05-04,-0.031556,0.865430,0.833874,5.519539e-02,-0.189519,0.865430,0.675911,2.100675e-01,-0.170309,0.865430,...,3.834441e-01,-0.310776,0.865430,0.554654,3.966667e-01,-0.037869,0.865430,0.827561,5.947019e-02,2.0
2007-05-05,-0.031556,0.865430,0.833874,5.519539e-02,-0.189519,0.865430,0.675911,2.100675e-01,-0.170309,0.865430,...,3.834441e-01,-0.310776,0.865430,0.554654,3.966667e-01,-0.037869,0.865430,0.827561,5.947019e-02,2.0
2007-05-06,-0.031556,0.865430,0.833874,5.519539e-02,-0.189519,0.865430,0.675911,2.100675e-01,-0.170309,0.865430,...,3.834441e-01,-0.310776,0.865430,0.554654,3.966667e-01,-0.037869,0.865430,0.827561,5.947019e-02,2.0
2007-05-07,-0.031556,0.865430,0.833874,5.519539e-02,-0.189519,0.865430,0.675911,2.100675e-01,-0.170309,0.865430,...,3.834441e-01,-0.310776,0.865430,0.554654,3.966667e-01,-0.037869,0.865430,0.827561,5.947019e-02,2.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-11-13,0.050964,0.039082,0.090046,2.264725e-07,0.132086,0.039082,0.171167,1.339706e-25,0.113152,0.039082,...,4.644012e-12,0.087395,0.039082,0.126477,3.873818e-14,0.105611,0.039082,0.144693,2.186802e-18,1026.0
2017-11-14,0.053613,0.038513,0.092126,6.549408e-08,0.133480,0.038513,0.171993,1.355575e-26,0.113448,0.038513,...,1.333563e-12,0.088165,0.038513,0.126678,1.376534e-14,0.105399,0.038513,0.143912,1.022474e-18,1058.0
2017-11-15,0.053784,0.037788,0.091572,4.181493e-08,0.132059,0.037788,0.169847,6.323144e-27,0.111544,0.037788,...,5.964311e-13,0.088370,0.037788,0.126158,5.175621e-15,0.105709,0.037788,0.143497,2.597229e-19,1101.0
2017-11-16,0.064038,0.038019,0.102057,7.568099e-10,0.141085,0.038019,0.179105,1.512015e-29,0.106924,0.038019,...,5.567089e-14,0.093622,0.038019,0.131641,4.087149e-16,0.100247,0.038019,0.138266,9.719320e-18,1087.0


In [117]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go


def sigmoid(x):
    """Logistic function, choosing the exp(x) form for negative x and the exp(-x) form otherwise.

    Note that np.where evaluates both branches for every element before selecting.
    """
    return np.where(x < 0, np.exp(x)/(1 + np.exp(x)), 1/(1 + np.exp(-x)))


def tanh_clip(x):
    """Apply tanh and floor the result at 0."""
    return np.tanh(x).clip(0)


def clip(x):
    """Clamp values to the closed interval [0, 1]."""
    return np.array(x).clip(0, 1)


def none(x):
    """Identity activation: return the input unchanged."""
    return x


# Activation applied to the critical_diff values before plotting; swap in
# sigmoid / tanh_clip / none to compare. Its name is shown in the figure title.
act = clip


def y_func(y):
    """Apply the selected activation `act` to `y` via its `.apply` method."""
    return y.apply(act)


def smooth(y):
    """Exponentially-weighted moving average with span=30 (accepts a Series or DataFrame)."""
    return y.ewm(span=30).mean()


fig = make_subplots(rows=1, cols=1, shared_xaxes=True, vertical_spacing=0.01)

# Reorder the column levels so the metric name comes first, then keep only the
# `critical_diff` metric; the remaining column levels are (activation column, statistic).
graph_view = output.swaplevel(0, 2, axis=1).swaplevel(1, 2, axis=1).sort_index(axis=1)['critical_diff']
# `g` must be bound before the first trace is added, since that trace reads g.index.
g = graph_view

# Combined curve: row-wise mean of the activated values across all columns, smoothed.
y = smooth(y_func(g).mean(axis=1))
fig.add_trace(go.Scatter(x=g.index, y=y, mode="lines", showlegend=True, name="Combined",
                         hovertemplate="%{y: .5f}"), row=1, col=1)

# One thin dotted line per individual column.
single_disp = dict(line=dict(dash="dot", width=.8))
for c in list(g.columns.to_flat_index()):
    name = "{} ({})".format(*c)
    y = smooth(y_func(g[c]))
    fig.add_trace(go.Scatter(x=g.index, y=y,
                             mode="lines",
                             showlegend=True,
                             hovertemplate="%{y: .5f}",
                             legendgroup=str(name),
                             name=str(name),
                             **single_disp
                             ),
                  row=1, col=1)

# Vertical markers at the validation / test split boundaries. The label text for
# the test boundary is stripped of surrounding whitespace, as in the original cell.
split_markers = [
    ("Val Start", settings.PADCHEST_SPLIT_DATES[0], settings.PADCHEST_SPLIT_DATES[0]),
    ("Test Start", settings.PADCHEST_SPLIT_DATES[1], settings.PADCHEST_SPLIT_DATES[1].strip()),
]
for label, split_date, date_text in split_markers:
    fig.add_shape(type='line',
                  x0=split_date,
                  y0=0,
                  x1=split_date,
                  y1=1,
                  line=dict(color='black', dash='dot'),
                  xref='x',
                  yref='paper'
                  )
    fig.add_annotation(textangle=0,
                       xref="x",
                       yref="paper", x=split_date, y=1.08,
                       text=f"{label}<br />({date_text})", showarrow=False,)

fig.update_layout(title=f"Output Scores, Statisical Distance ({act.__name__})")
fig.update_layout(hovermode="x unified")
fig.update_layout(height=600)
fig.update_yaxes(range=[-0.1, 1.1])
fig.show()



plotly.graph_objs.Line is deprecated.
Please replace it with one of the following more specific types
  - plotly.graph_objs.scatter.Line
  - plotly.graph_objs.layout.shape.Line
  - etc.




In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Rolling-window length (in rows of `output`) used to smooth the drift fraction.
N = 10
fig = make_subplots(rows=1, cols=1, shared_xaxes=True, vertical_spacing=0.05)

# Bring the metric name to the outermost column level, then select the KS
# critical_diff values: one column per activation.
graph_view = output.swaplevel(0, -1, axis=1)
g = graph_view['critical_diff']['ks']

# Fraction of columns whose critical_diff is positive on each row, then a
# rolling mean over N rows.
y = (g > 0).mean(axis=1)
y = y.rolling(N).mean()

fig.add_trace(go.Scatter(x=g.index, y=y, mode="lines", name="Merged"),
              row=1, col=1)

fig.update_layout(title="Average Drift Across Scores")
fig.update_layout(hovermode="x unified")
fig.update_layout(height=600)
fig.show()

In [None]:
from scipy.stats import ks_2samp

from model_drift.stats import calc_p_real


# NOTE(review): `pc_df_vae` is not defined in any visible cell of this notebook;
# this cell needs a frame with a `View` column and `mu.0` … `mu.127` columns.
ref = pc_df_vae.query("View == 'Frontal'")
target = pc_df_vae.query("View != 'Frontal'")

# Two-sample KS test between frontal and non-frontal rows for each `mu.<i>` column.
pvals = []
for i_val in range(128):
    xcol = f"mu.{i_val}"
    samp1 = ref[xcol].values
    samp2 = target[xcol].values
    ks, p = ks_2samp(samp1, samp2)
    # NOTE(review): the intent of flipping p when the statistic exceeds 0.5 is
    # not evident from this cell — confirm.
    if ks>0.5:
        p = 1-p
    # NOTE(review): the list is named `pvals` but collects the KS statistic, and
    # the adjusted `p` above is never used — confirm which value was intended.
    pvals.append(ks)