In [None]:
from IPython.display import display

import pandas as pd
import warnings
from model_drift import settings, helpers
from model_drift.data.utils import nested2series
import matplotlib.pylab as plt
import numpy as np
import seaborn as sns
from model_drift.drift.tabular import TabularDriftCalculator
from model_drift.drift.numeric import KSDriftCalculator, BasicDriftCalculator
from model_drift.drift.categorical import ChiSqDriftCalculator
from model_drift.drift.collection import DriftCollectionCalculator
from model_drift.drift.performance import AUROCCalculator

from model_drift.drift.sampler import Sampler
from model_drift.data.padchest import PadChest
import plotly.graph_objects as go

warnings.filterwarnings("ignore")


In [None]:
from model_drift.data.padchest import LABEL_MAP

# One activation column per classifier label, in LABEL_MAP iteration order.
label_cols = list(LABEL_MAP)
activation_names = [f"activation.{label}" for label in label_cols]

# Predictions written by the PadChest-trained, frontal-only classifier run.
jsonl_file = str(
    settings.TOP_DIR.joinpath("results", 'classifier', 'padchest-trained', 'frontal_only', "preds.jsonl")
)
scores_df = helpers.jsonl_files2dataframe(jsonl_file)

# Spread each 'activation' entry over one column per label and append those
# columns to the right of the existing ones.
activation_df = pd.DataFrame(scores_df['activation'].values.tolist(), columns=activation_names)
scores_df = pd.concat([scores_df, activation_df], axis=1)
scores_df.head()


In [None]:

# Name of the VAE run to analyse; it is also used in output file names in later cells.
data = 'all-data'
jsonl_file = str(
    settings.TOP_DIR.joinpath("results", 'vae', 'padchest-trained', data, 'preds.jsonl')
)
vae_df = helpers.jsonl_files2dataframe(jsonl_file)

# Spread each 'mu' entry over 128 zero-padded columns: mu.000 ... mu.127.
mu_names = [f"mu.{dim:0>3}" for dim in range(128)]
mu_df = pd.DataFrame(vae_df['mu'].values.tolist(), columns=mu_names)
vae_df = pd.concat([vae_df, mu_df], axis=1)
vae_df.head()


In [None]:
# Build the PadChest table from the configured CSV file.
pc = PadChest(settings.PADCHEST_FILENAME)
pc.prepare()

# Attach the VAE outputs: PadChest "ImageID" is matched against the "index"
# field of the predictions, keeping only rows present on both sides.
# NOTE(review): the return value is discarded, so merge() presumably updates
# pc in place - confirm against PadChest.merge.
pc.merge(vae_df, how='inner', left_on="ImageID", right_on="index")
# Classifier activations can be attached the same way:
# pc.merge(scores_df, left_on="ImageID", right_on="index", how='inner')

# Date-based split using the configured split dates.
train, val, test = pc.split(settings.PADCHEST_SPLIT_DATES, studydate_index=True)


In [None]:
# Show every column available after the merge.
pc.df.columns.tolist()

In [None]:
def FLOAT(x):
    """Parse ``x`` as numbers (unparseable entries become NaN) and cast to float."""
    return pd.to_numeric(x, errors='coerce').astype(float)


def CAT(x):
    """Cast ``x`` to the pandas 'category' dtype."""
    return x.astype('category')


# Latent ("mu.*") and classifier ("activation.*") columns found in the merged
# table, skipping any whose name contains 'all'. Both map column name -> FLOAT.
vae_cols = dict.fromkeys(
    (name for name in list(pc.df) if name.startswith("mu.") and 'all' not in name),
    FLOAT,
)
score_cols = dict.fromkeys(
    (name for name in list(pc.df) if name.startswith("activation.") and 'all' not in name),
    FLOAT,
)

# Metadata column -> converter applied to it below. The insertion order is the
# order in which later cells iterate over these columns.
metadata_cols = {
    "age": FLOAT,
    "Projection": CAT,
    "PatientSex_DICOM": CAT,
    "ViewPosition_DICOM": CAT,
    "Modality_DICOM": CAT,
    "Manufacturer_DICOM": CAT,
    "PhotometricInterpretation_DICOM": CAT,
    "PixelRepresentation_DICOM": CAT,
    "PixelAspectRatio_DICOM": CAT,
    "SpatialResolution_DICOM": FLOAT,
    "BitsStored_DICOM": CAT,
    "WindowCenter_DICOM": FLOAT,
    "WindowWidth_DICOM": FLOAT,
    "Rows_DICOM": FLOAT,
    "Columns_DICOM": FLOAT,
    "XRayTubeCurrent_DICOM": CAT,
    "Exposure_DICOM": FLOAT,
    "ExposureInuAs_DICOM": FLOAT,
    "RelativeXRayExposure_DICOM": FLOAT,
    "Frontal": lambda x: x.astype(str),
    "Pediatric": CAT,
}

# Overwrite each metadata column with its converted version.
for name, convert in metadata_cols.items():
    pc.df[name] = convert(pc.df[name])

In [None]:
# Rows with age below this value are flagged 'Yes' in the "Peds" column.
PEDS_AGE_CUTOFF = 12

# Missing ages compare as False, so they end up as 'No' after the fill below.
below_cutoff = pc.df['age'].lt(PEDS_AGE_CUTOFF)
pc.df.loc[below_cutoff, "Peds"] = 'Yes'
pc.df['Peds'] = pc.df['Peds'].fillna('No')

pc.df['Peds'].value_counts()

In [None]:
import matplotlib.pylab as plt
from tqdm import tqdm

from plotly.subplots import make_subplots
import plotly.graph_objects as go




def make_plotly_graph(graphs):
    """Draw one line per entry of ``graphs`` on a single pair of axes.

    Args:
        graphs: mapping of trace name -> mapping providing ``'fpr'`` (plotted
            on x) and ``'tpr'`` (plotted on y).

    Returns:
        The plotly figure, 600x400 with both axes fixed to [0, 1.01].
    """
    fig = make_subplots(1, 1, vertical_spacing=0.05, horizontal_spacing=0.05)
    for name, curve in graphs.items():
        # go.Line is a deprecated alias in plotly; a Scatter trace in line
        # mode is the supported way to draw the same curve.
        fig.add_trace(
            go.Scatter(x=curve['fpr'], y=curve['tpr'], mode="lines", showlegend=True, name=name)
        )

    # fig.update_yaxes(scaleanchor="x", scaleratio=1)
    fig.update_yaxes(range=[0, 1.01], constrain="domain")
    fig.update_xaxes(range=[0, 1.01], constrain="domain")
    fig.update_layout(
        height=400,
        width=600,
        margin=go.layout.Margin(l=10, r=10, b=20, t=20),  # left, right, bottom, top
    )
    return fig

def chunkify(vals, n):
    """Yield consecutive slices of ``vals`` holding ``n`` items each.

    The final slice is shorter when ``len(vals)`` is not a multiple of ``n``;
    an empty ``vals`` yields nothing.

    Args:
        vals: any sliceable sequence with a length.
        n: chunk size, must be positive.

    Raises:
        ValueError: if ``n`` is not positive. Because this is a generator the
            error surfaces when iteration starts. Previously a non-positive
            ``n`` never advanced the cursor and looped forever.
    """
    if n <= 0:
        raise ValueError(f"chunk size must be positive, got {n!r}")
    for start in range(0, len(vals), n):
        yield vals[start:start + n]

# --- Parameters -------------------------------------------------------------
bins = 3        # number of buckets used when `col` is a float column
qrt = .95       # keep latent dimensions at or above this quantile of the spread
col = "Peds"    # metadata column compared against the VAE latent columns

# --- Correlate every latent column with each category of `col` ---------------
series = pc.df[col]
if "float" in series.dtype.name:
    # Continuous metadata is bucketed so it can be compared one bin at a time.
    series = pd.cut(series, bins=bins, retbins=False)

cats = series.unique()
try:
    # Intervals (from pd.cut) are ordered by their left edge, everything else by value.
    cats = sorted(cats, key=lambda x: x.left if hasattr(x, "left") else x)
except TypeError:
    # Values that cannot be compared with each other (e.g. strings next to
    # NaN) keep the order returned by unique().
    pass
print(cats)

# Select columns with a list of names: vae_cols is a dict, and pandas >= 2.0
# rejects a dict as a column indexer.
latent_df = pc.df[list(vae_cols)]
corr_by_cat = {}
for cat in tqdm(cats):
    # Correlation of each latent column with the "row belongs to cat" indicator.
    corr_by_cat[cat] = latent_df.corrwith(series == cat)

df = pd.concat(corr_by_cat, axis=1)
# Spread between the largest and smallest per-category correlation of each latent column.
top_corr = (df.max(axis=1) - df.min(axis=1)).sort_values().rename("corr")
bp = top_corr.quantile(qrt)
keep = top_corr[top_corr >= bp]

# --- Figure: spread per latent column / per-category bars / histogram --------
fig = make_subplots(3, 1, vertical_spacing=.15)
# go.Line is a deprecated alias; use a Scatter trace in line mode instead.
fig.add_trace(
    go.Scatter(x=top_corr.index, y=top_corr, mode="lines", name=str(col)),
    row=1, col=1,
)
# Categories whose correlations are NaN for any kept latent column are dropped.
b = df.loc[keep.index][cats].dropna(axis=1)
for c in list(b):
    fig.add_trace(
        go.Bar(x=b.index, y=b[c], name=str(c)), row=2, col=1
    )
try:
    hist = pd.cut(pc.df[col], bins=25).value_counts().sort_index()
except (TypeError, ValueError):
    # Columns that cannot be binned numerically are counted per distinct value.
    hist = pc.df[col].value_counts().sort_index()
fig.add_trace(go.Bar(x=hist.index.map(str), y=hist, name='Hist'), row=3, col=1)
fig.update_layout(
    height=900,
    margin=go.layout.Margin(l=10, r=10, b=20, t=20),  # left, right, bottom, top
)
fig.show()

In [None]:
from IPython.display import display_html, display_markdown, HTML, Markdown, display
import json
import os

output_dir = str(settings.TOP_DIR.joinpath('html'))
os.makedirs(output_dir, exist_ok=True)

# Number of images in each "Highest" / "Lowest" gallery.
N = 8

# A SAS token is a credential, so it is taken from the environment instead of
# being stored in the notebook. Export PADCHEST_SAS_TOKEN before running.
sas = os.environ.get("PADCHEST_SAS_TOKEN", "")
if not sas:
    print("PADCHEST_SAS_TOKEN is not set; gallery images will only load if the container allows anonymous reads.")
container_url = "https://mlopsday2datasets.blob.core.windows.net/padchest/png/"

fname = f'vae[{data}]-metadata[{col}]'


def _gallery_html(title, rows):
    """Return an ``<h2>`` heading followed by an image gallery for ``rows``.

    Each row must provide ``ImageDir`` and ``ImageID``; the image ``src`` is
    filled in client-side by the script at the end of the page.
    """
    parts = [f"""<h2>{title}</h2><div class="row gallery">\n"""]
    for i, (_, row) in enumerate(rows.iterrows()):
        parts.append("""
        <img data="{ImageDir}/{ImageID}" 
        class="padchest" id="{i}" >
      """.format(i=i, ImageDir=row["ImageDir"], ImageID=row["ImageID"]))
    parts.append("""</div> <!-- row -->""")
    return "".join(parts)


# <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/4.0.0-alpha.4/css/bootstrap.min.css">
html = """
<style>
.container {
    width: 80%;
    margin: 0 auto;
  }

.gallery {
    display: block; 
    line-height:0;
   -webkit-column-count:5; /* split it into 5 columns */
   -webkit-column-gap:5px; /* give it a 5px gap between columns */
   -moz-column-count:5;
   -moz-column-gap:5px;
   column-count:4;
   column-gap:5px;
   background-color: darkslategray;
}

.gallery img {
   width: 100% !important;
   height: auto !important;
   margin-bottom:5px; /* to match column gap */
}
    </style>
    
    
<div class="container">
"""

html += f"""<div class="row"><h1>{col}</h1></div>
"""

html += """<div class="row">""" + fig.to_html() + """</div>"""

# Walk the kept latent columns from the largest correlation spread downwards.
for c in keep.index[::-1]:
    # Sort once per column and take both ends of the ranking.
    ranked = pc.df.sort_values(c, ascending=False)
    html += _gallery_html(f"{c} (Highest)", ranked.iloc[:N].sort_index())
    html += _gallery_html(f"{c} (Lowest)", ranked.iloc[-N:].sort_index())
    html += """
    <hr size="8" width="100%" color="green">
    """

# The script turns every <img data="..."> into a full blob URL with the token appended.
html += f"""
</div> <!-- container -->
<script>
const url = "{container_url}"; 
const sas = "{sas}"; 
var x = document.getElementsByClassName("padchest");
var i;
for (i = 0; i < x.length; i++) {{
  let data = x[i].getAttribute("data");
  x[i].src = url + data + "?" + sas
}}
</script>
"""

print(fname)
with open(f"{output_dir}/{fname}.html", 'w', encoding="utf-8") as f:
    print(html, file=f)


In [None]:
from IPython.display import display_html, display_markdown, HTML, Markdown, display
import json
import os
from tqdm import tqdm

# sample_top = pc.df.sort_values('age', ascending=False).iloc[:N]

In [None]:
from IPython.display import display_html, display_markdown, HTML, Markdown, display
import json
import os

output_dir = str(settings.TOP_DIR.joinpath(f"html/vae[{data}]"))
os.makedirs(output_dir, exist_ok=True)

# Number of images shown for each summary statistic.
N = 8

# A SAS token is a credential, so it is taken from the environment instead of
# being stored in the notebook. Export PADCHEST_SAS_TOKEN before running.
sas = os.environ.get("PADCHEST_SAS_TOKEN", "")
if not sas:
    print("PADCHEST_SAS_TOKEN is not set; gallery images will only load if the container allows anonymous reads.")
container_url = "https://mlopsday2datasets.blob.core.windows.net/padchest/png/"

# The page header (styles + opening container) and footer (closing container +
# image-URL script) are the same for every column, so they are built once.
# <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/4.0.0-alpha.4/css/bootstrap.min.css">
page_head = """
  <style>
  .container {
      width: 80%;
      margin: 0 auto;
    }

  .gallery {
      display: block; 
      line-height:0;
    -webkit-column-count:5; /* split it into 5 columns */
    -webkit-column-gap:5px; /* give it a 5px gap between columns */
    -moz-column-count:5;
    -moz-column-gap:5px;
    column-count:4;
    column-gap:5px;
    background-color: darkslategray;
  }

  .gallery img {
    width: 100% !important;
    height: auto !important;
    margin-bottom:5px; /* to match column gap */
  }
      </style>
      
      
  <div class="container">
  """

page_tail = f"""
  </div> <!-- container -->
  <script>
  const url = "{container_url}"; 
  const sas = "{sas}"; 
  var x = document.getElementsByClassName("padchest");
  var i;
  for (i = 0; i < x.length; i++) {{
    let data = x[i].getAttribute("data");
    x[i].src = url + data + "?" + sas
  }}
  </script>
  """

# One HTML page per latent column.
for c in tqdm(vae_cols):
    fname = f'{c}'

    # describe() minus 'count' and 'std' leaves mean, min, quartiles and max;
    # they are visited in ascending order of value.
    vals = pc.df[c].describe().drop(['count', 'std']).sort_values()
    html = page_head + f"""<h1>{c}</h1>"""
    for t, val in vals.items():
        # Index labels of the N rows whose value of `c` is nearest to this statistic.
        # NOTE(review): the .loc lookup below returns exactly N rows only if
        # pc.df has a unique index - confirm.
        loc = (val - pc.df[c]).abs().sort_values(ascending=True).iloc[:N].index
        html += f"""<h2>Closest to {t} of {c}</h2><div class="row gallery">\n"""
        for i, (ix, row) in enumerate(pc.df.loc[loc].iterrows()):
            html += """
        <img data="{ImageDir}/{ImageID}" 
        class="padchest" id="{i}" >
      """.format(i=i, ImageDir=row["ImageDir"], ImageID=row["ImageID"])
        html += """</div> <!-- row -->"""

    html += """<hr size="8" width="100%" color="green">"""
    html += page_tail

    with open(f"{output_dir}/{fname}.html", 'w', encoding="utf-8") as f:
        print(html, file=f)


In [None]:
def do_coor(col, bins=3, qrt=0.95):
    """Summarise which VAE latent columns vary most across the values of ``col``.

    For every distinct value of ``pc.df[col]`` (float columns are first cut
    into ``bins`` buckets) each latent column is correlated with the indicator
    "row has this value". The spread (max - min) of those correlations is
    computed per latent column.

    Args:
        col: name of a column in ``pc.df``.
        bins: number of buckets used when the column has a float dtype.
        qrt: quantile of the spread; latent columns at or above it are kept.

    Returns:
        A one-row DataFrame indexed by ``col`` whose columns are the kept
        latent columns (ascending by spread) and whose values are the spreads.
    """
    series = pc.df[col]
    if "float" in series.dtype.name:
        series = pd.cut(series, bins=bins, retbins=False)

    cats = series.unique()
    try:
        # Intervals (from pd.cut) are ordered by their left edge, everything else by value.
        cats = sorted(cats, key=lambda x: x.left if hasattr(x, "left") else x)
    except TypeError:
        # Values that cannot be compared with each other keep unique()'s order.
        pass

    # Select columns with a list of names: vae_cols is a dict, and pandas >= 2.0
    # rejects a dict as a column indexer. Done once, outside the loop.
    latent_df = pc.df[list(vae_cols)]
    per_cat = {cat: latent_df.corrwith(series == cat) for cat in cats}

    df = pd.concat(per_cat, axis=1)
    top_corr = (df.max(axis=1) - df.min(axis=1)).sort_values().rename("corr")
    keep = top_corr[top_corr >= top_corr.quantile(qrt)]
    return keep.rename(col).to_frame().T


# One HTML table per metadata column, preceded by a (now properly closed) heading.
html = ["<h1>Max Difference Correlations</h1>"]
for cc in metadata_cols:
    tp = do_coor(cc)
    html.append(tp.to_html())

output_dir = str(settings.TOP_DIR.joinpath("html"))
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(output_dir, exist_ok=True)

with open(f"{output_dir}/vae[{data}]-correlations.html", 'w', encoding="utf-8") as f:
    print("<br />".join(html), file=f)


In [None]:
# Scratch notes on columns worth inspecting for each run (presumably the
# 'frontal_only' and 'all-data' variants - confirm). Only the `col` assignment
# changes state; the bare strings are expression statements, and just the last
# one is echoed as the cell output.
# Interesting for the frontal-only run:
col = "activation.No Finding"
"Manufacturer_DICOM"


# Interesting for the all-data run:
"Frontal"
"activation.No Finding"
