In [None]:
import math
import pickle
import sklearn.metrics
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from metrics_tools import *

# Load the per-magnification result pickles: ROI-level results first, then
# FOV-level results. Index k of each list corresponds to the k-th magnification
# (10x, 20x, 40x).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
def _load_pickles(paths):
    """Unpickle each file in ``paths`` and return the objects in the same order."""
    loaded = []
    for path in paths:
        with open(path, 'rb') as handle:
            loaded.append(pickle.load(handle))
    return loaded


results_list = _load_pickles(["10x_results.pickle", "20x_results.pickle", "40x_results.pickle"])
results_fov_list = _load_pickles(["10x_results_fov.pickle", "20x_results_fov.pickle", "40x_results_fov.pickle"])

In [None]:
# ---------------------------------------------------------------------------
# Global plotting style shared by the figures in this notebook.
# ---------------------------------------------------------------------------

# Per-magnification styling: entry k of each list styles results_list[k].
res_list = ['LYNA 10x', 'LYNA 20x', 'LYNA 40x']    # legend / subplot labels
line_list = [None, 'dash', 'dot']                  # Scatter line dash per magnification
bar_list = [None, "x", "/"]                        # Bar pattern fill per magnification
color_list = ['#636EFA', '#EF553B', '#00CC96']     # trace colour per magnification

# Outer figure margin, in pixels.
margin = 10

# Tissue subclasses and their grouping into bar-chart panels.
short_names = ['BrCA', 'Fat', 'Lymphocytes', 'Capsule', 'Blood', 'Histiocytes',
               'Vein', 'Artery', 'GC', 'Sinus', 'Mantle', 'Nerve']
immune_cells = ["Histiocytes", "GC", "Mantle", "Lymphocytes"]
conn_tissue = ["Sinus", "Capsule", "Nerve", "Artery", "Vein", "Blood", "Fat"]
groups = [["BrCA"], immune_cells, conn_tissue]

# Per-subclass trace colour (None lets plotly pick one).
color_dict = dict([
    ('overall', '#2CA02C'),
    ('BrCA', '#00CC96'),
    ('Fat', 'rgb(220,171,2)'),
    ('Lymphocytes', '#AB63FA'),
    ('Capsule', '#FF97FF'),
    ('Blood', '#D62728'),
    ('Histiocytes', '#FF7F0E'),
    ('Vein', '#AF0038'),
    ('Artery', 'rgb(231,138,195)'),
    ('GC', 'rgb(22,03,228)'),
    ('Sinus', 'rgb(242, 183, 1)'),
    ('Mantle', '#620042'),
    ('Nerve', 'rgb(102,102,102)'),
    ('Marginal', None),
    ('Light Zone', None),
    ('Dark Zone', None),
])

# Per-subclass dash style: same keys (and order) as color_dict, solid (None)
# everywhere except a few subclasses drawn dashed -- presumably to separate
# them from similarly coloured neighbours.
line_dict = dict.fromkeys(color_dict)
line_dict.update(Capsule='dash', Vein='dash', Sinus='dash')

In [None]:
# ROI-level ROC curve ('overall' results) for each magnification, drawn over the
# black dashed chance diagonal. Exported to graphs/ROC_ROI.pdf.
fig = go.Figure()
fig.update_layout(
    height=600,
    width=600,
    margin={'l': margin, 'r': margin, 't': 40, 'b': margin},
    title="ROI level ROC curves per magnification",
    xaxis_title="False Positive Rate",
    yaxis_title="True Positive Rate",
    font=dict(size=18),
    legend=dict(yanchor="bottom", y=0.05, xanchor="right", x=0.95, font=dict(size=16)),
)
# Square plot area: x range locked to [0, 1] and scaled 1:1 with y.
fig.update_xaxes(range=[0, 1], scaleanchor="y", scaleratio=1)

# Chance line.
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], showlegend=False, line=dict(color='black', dash='dash')))

for mag_idx, mag_results in enumerate(results_list):
    roc = mag_results['overall']
    style = dict(color=color_list[mag_idx], dash=line_list[mag_idx])
    fig.add_trace(go.Scatter(x=roc['fpr'], y=roc['tpr'], name=res_list[mag_idx], line=style))

fig.write_image("graphs/ROC_ROI.pdf")
fig.show()

In [None]:
# FOV-level ROC curve for each magnification, drawn over the black dashed
# chance diagonal. Exported to graphs/ROC_FOV.pdf.
fig = go.Figure()
fig.update_layout(
    height=600,
    width=600,
    margin={'l': margin, 'r': margin, 't': 40, 'b': margin},
    title="FOV level ROC curves per magnification",
    xaxis_title="False Positive Rate",
    yaxis_title="True Positive Rate",
    font=dict(size=18),
    legend=dict(yanchor="bottom", y=0.05, xanchor="right", x=0.95, font=dict(size=16)),
)
# Square plot area: x range locked to [0, 1] and scaled 1:1 with y.
fig.update_xaxes(range=[0, 1], scaleanchor="y", scaleratio=1)

# Chance line.
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], showlegend=False, line=dict(color='black', dash='dash')))

for mag_idx, fov_roc in enumerate(results_fov_list):
    style = dict(color=color_list[mag_idx], dash=line_list[mag_idx])
    fig.add_trace(go.Scatter(x=fov_roc['fpr'], y=fov_roc['tpr'], name=res_list[mag_idx], line=style))

fig.write_image("graphs/ROC_FOV.pdf")
fig.show()

In [None]:
# One panel per magnification comparing the ROC curve over all subclasses
# ('overall') with the curve from the 'nofat' results ("Without Fat").
# Exported to graphs/ROC_nofat.png.
fig = make_subplots(rows=1, cols=3,
                    shared_yaxes=True,
                    x_title="False Positive Rate",
                    y_title="True Positive Rate",
                    subplot_titles=res_list)
fig.update_layout(
    height=370,
    width=900,
    margin={'l': margin, 'r': margin, 't': margin, 'b': margin},
    legend=dict(orientation="h", yanchor="bottom", y=-.27, xanchor="right", x=1),
)

# (results key, legend label, index into color_list / line_list)
curves = (('overall', 'All subclasses', 0), ('nofat', 'Without Fat', 2))

for panel_col, mag_results in enumerate(results_list, start=1):
    # Legend entries are emitted only from the first panel to avoid duplicates.
    first_panel = panel_col == 1
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], line=dict(dash='dash', color='black'), showlegend=False),
                  row=1, col=panel_col)
    for key, label, style_idx in curves:
        roc = mag_results[key]
        fig.add_trace(go.Scatter(x=roc['fpr'], y=roc['tpr'], name=label,
                                 line=dict(color=color_list[style_idx], dash=line_list[style_idx]),
                                 showlegend=first_panel),
                      row=1, col=panel_col)

fig.write_image("graphs/ROC_nofat.png")
fig.show()

In [None]:
# Grouped bar chart of subclass prevalence per magnification, on a log y-axis.
# Each bar is the subclass's 'total' as a percentage of r['overall']['total'].
# The 12-column grid is split into three panels -- BrCA | immune cells |
# connective tissue -- spanning 1, 4 and 7 columns, matching the sizes of the
# three entries in `groups`.
fig = make_subplots(rows=1, cols=12,
                    shared_yaxes=True,
                    x_title="Subclass",
                    y_title="Percent of total testing set (log scale)",
                    specs=[[{}, {"colspan":4}, None, None, None, {"colspan":7}, None, None, None, None, None, None]])
fig.update_yaxes(type="log")
fig.update_layout(title="Subclass prevalence per magnification")
fig.update_layout(height=400, width=900, margin = {'l':65,'r':margin,'t':40,'b':margin})
fig.update_layout(barmode='group')
fig.update_layout(legend=dict(
    orientation="h",
    yanchor="bottom",
    y=-.15,
    xanchor="right",
    x=1
))

for i, r in enumerate(results_list):
    # First grid column of the current panel; advances by the group size so the
    # panels land at columns 1, 2 and 6 (the spec'd subplot anchors).
    col = 1
    for g in groups:
        # Show each magnification in the legend only once (first panel).
        legend = col == 1
        per = [r[key]['total'] *100 / r['overall']['total'] for key in g]
        fig.add_trace(go.Bar(x=g, y=per, text=per, showlegend=legend, name=res_list[i], marker_color=color_list[i], marker_pattern_shape=bar_list[i], textposition = 'auto', textfont=dict(size=10,color='white')), row=1,col=col)
        col += len(g)

# Bar labels show the percentage with two decimals.
fig.update_traces(texttemplate='%{text:.2f}')
# NOTE: the output filename keeps its historical spelling on purpose so anything
# already referencing this file keeps working.
fig.write_image("graphs/subclass_prevalance.pdf")
fig.show()

In [None]:
# Grouped bar chart of model accuracy per subclass, one bar colour/pattern per
# magnification. Panels: BrCA | immune cells | connective tissue on a 12-column
# grid (1, 4 and 7 columns wide). Exported to graphs/subclass_accuracy.pdf.
# NOTE(review): the y-axis says "percent" but labels use 3 decimals, which looks
# like a 0-1 fraction -- confirm the units of r[key]['accuracy'].
fig = make_subplots(rows=1, cols=12,
                    shared_yaxes=True,
                    x_title="Subclass",
                    y_title="Accuracy (percent)",
                    specs=[[{}, {"colspan":4}, None, None, None, {"colspan":7}, None, None, None, None, None, None]])
fig.update_layout(
    title="Model accuracy per subclass and magnification",
    height=400,
    width=900,
    margin={'l': 65, 'r': margin, 't': 40, 'b': margin},
    barmode='group',
    legend=dict(orientation="h", yanchor="bottom", y=-.15, xanchor="right", x=1),
)

for mag_idx, mag_results in enumerate(results_list):
    panel_col = 1  # first grid column of the current panel
    for subclasses in groups:
        scores = [mag_results[name]['accuracy'] for name in subclasses]
        bars = go.Bar(
            x=subclasses,
            y=scores,
            text=scores,
            showlegend=(panel_col == 1),  # one legend entry per magnification
            name=res_list[mag_idx],
            marker_color=color_list[mag_idx],
            marker_pattern_shape=bar_list[mag_idx],
            textposition='auto',
            textfont=dict(size=10, color='white'),
        )
        fig.add_trace(bars, row=1, col=panel_col)
        panel_col += len(subclasses)

fig.update_traces(texttemplate='%{text:.3f}')
fig.write_image("graphs/subclass_accuracy.pdf")
fig.show()

In [None]:
# Cancer threshold vs. false negative rate (1 - TPR) for BrCA, one curve per
# magnification.
# NOTE(review): assumes 'tpr' is array-like supporting `1 - tpr` (e.g. numpy).
fig = go.Figure()
fig.update_xaxes(range=[0, 1])
fig.update_layout(
    height=600,
    width=600,
    title="BrCA",
    xaxis_title='Cancer threshold',
    yaxis_title='False negative rate',
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01, font=dict(size=16)),
)

for mag_idx, mag_results in enumerate(results_list):
    brca = mag_results['BrCA']
    fig.add_trace(go.Scatter(x=brca['thresh'], y=1 - brca['tpr'], name=res_list[mag_idx],
                             line=dict(color=color_list[mag_idx], dash=line_list[mag_idx])))

fig.show()

In [None]:
# Cancer-threshold sweeps, one trace per magnification, all axes fixed to [0, 1]:
#   * BrCA: false negative rate (1 - TPR)
#   * each immune-cell / connective-tissue subclass: false positive rate
# Multi-panel figures are filled left to right, 4 panels per row.

# --- BrCA ---
fig = make_subplots(rows=1, cols=1, subplot_titles=["BrCA"], x_title="Cancer threshold", y_title="False negative rate")
fig.update_layout(
    title="Cancer",
    height=340,
    width=420,
    legend=dict(yanchor="bottom", y=0.2, xanchor="left", x=1.25, font=dict(size=18)),
)
for mag_idx, mag_results in enumerate(results_list):
    brca = mag_results['BrCA']
    fig.add_trace(go.Scatter(x=brca['thresh'], y=1 - brca['tpr'], name=res_list[mag_idx],
                             line=dict(color=color_list[mag_idx], dash=line_list[mag_idx])),
                  row=1, col=1)
    fig.update_xaxes(range=[0, 1], row=1, col=1)
    fig.update_yaxes(range=[0, 1], row=1, col=1)
fig.show()

# --- Immune cells (legend hidden for the whole figure) ---
fig = make_subplots(rows=1, cols=4, subplot_titles=immune_cells, x_title="Cancer threshold", y_title="False positive rate")
fig.update_layout(title="Immune Cells", showlegend=False, height=340, width=900)
for panel_idx, key in enumerate(immune_cells):
    grid_row, grid_col = divmod(panel_idx, 4)
    grid_row += 1
    grid_col += 1
    for mag_idx, mag_results in enumerate(results_list):
        fig.add_trace(go.Scatter(x=mag_results[key]['thresh'], y=mag_results[key]['fpr'],
                                 line=dict(color=color_list[mag_idx], dash=line_list[mag_idx])),
                      row=grid_row, col=grid_col)
        fig.update_xaxes(range=[0, 1], row=grid_row, col=grid_col)
        fig.update_yaxes(range=[0, 1], row=grid_row, col=grid_col)
fig.show()

# --- Connective tissue (2 x 4 grid, legend entries suppressed per trace) ---
fig = make_subplots(rows=2, cols=4,
                    subplot_titles=conn_tissue, x_title="Cancer threshold",
                    y_title="False positive rate",
                    vertical_spacing=0.125)
fig.update_layout(title="Connective Tissue", height=540, width=900)
for panel_idx, key in enumerate(conn_tissue):
    grid_row, grid_col = divmod(panel_idx, 4)
    grid_row += 1
    grid_col += 1
    for mag_idx, mag_results in enumerate(results_list):
        fig.add_trace(go.Scatter(x=mag_results[key]['thresh'], y=mag_results[key]['fpr'], showlegend=False,
                                 name=res_list[mag_idx],
                                 line=dict(color=color_list[mag_idx], dash=line_list[mag_idx])),
                      row=grid_row, col=grid_col)
        fig.update_xaxes(range=[0, 1], row=grid_row, col=grid_col)
        fig.update_yaxes(range=[0, 1], row=grid_row, col=grid_col)
fig.show()

In [None]:
# Cancer threshold vs. false positive rate for every benign (non-BrCA) subclass,
# one figure per magnification, titled with that magnification's label.
benign_names = [name for name in short_names if name != 'BrCA']  # loop-invariant

for i, r in enumerate(results_list):
    fig = go.Figure()
    fig.update_xaxes(range=[0, 1])
    # Top margin of 40px (as in the other titled single-panel figures) leaves
    # room for the title, which a 10px margin would not.
    fig.update_layout(height=600, width=600, margin={'l': margin, 'r': margin, 't': 40, 'b': margin})
    # plotly Figures have no update_title(); the title is a layout property.
    fig.update_layout(title=res_list[i], xaxis_title='Cancer threshold', yaxis_title='False positive rate')
    for key in benign_names:
        fig.add_trace(go.Scatter(x=r[key]['thresh'], y=r[key]['fpr'], name=key,
                                 line=dict(color=color_dict[key], dash=line_dict[key])))
    fig.show()