In [2]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell execution; edits to the project's own modules are
# then picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import jupyter_black

# Activate jupyter_black for this notebook with a 100-character line length.
jupyter_black.load(
    line_length=100,
)

from model.infere import inference
from model.metrics import Metrics
from model.main import CNN
from utils.plotting.sparse_bagnet_visualization import (
    plot_heatmap_on_image,
    plot_image_and_patches,
)
from utils.plotting.get_images import GetPlotImages
from utils.plotting.util import add_subplot_labels
from utils.helpers import get_config

# Set the Weights & Biases service-wait setting through the environment.
# NOTE(review): presumably a start-up timeout in seconds — confirm against the wandb docs.
os.environ["WANDB__SERVICE_WAIT"] = "300"

In [5]:
# Path of the YAML run configuration. The resulting config object `c` is passed to
# the metrics, the model, the image loader and the inference call in later cells.
CONFIG = "configs/sparsebagnet_cox.yml"
c = get_config(CONFIG)

In [1]:
# NOTE(review): `device` is not read by any other cell visible in this notebook;
# confirm nothing external relies on it before removing it.
device = "cuda"
# Override config fields before building the model. The field names suggest:
# use GPU index 0, load the best checkpoint, and do not resume training
# (confirm the exact semantics in model.main).
c.cnn.gpu = 0
c.cnn.load_best_model = True
c.cnn.resume_training = False

# Build the metrics estimator and the CNN wrapper from the modified config.
estimator = Metrics(c)
cnn = CNN(c=c, estimator=estimator)

In [7]:
# File names of the images this notebook works with. The image table is restricted
# to these names, and the image lists of Fig 2 and Fig 3 are subsets of this list.
filter_for_images = [
    "2146_04_RE_F2_LS.jpg",
    "55684_06_F2_RE_LS.jpg",
    "55684_14_F2_RE_LS.jpg",
    "58684_08_F2_RE_RS.jpg",
    "58684_10_F2_RE_RS.jpg",
    "58684_12_F2_RE_RS.jpg",
    "3591_04_RE_F2_RS.jpg",
    "61691_06_F2_RE_RS.jpg",
    "3591_12_RE_F2_RS.jpg",
    "52832_08_F2_LE_RS.jpg"
]

In [8]:
# Load the image table (the printed output below reports it as the test split) and
# keep only the rows whose `image_file` is listed in `filter_for_images`.
df = GetPlotImages(c=c, converters_only=False, len=None).get_image_df()
df = df[df["image_file"].isin(filter_for_images)]

Len data (Number of images in test split): 10589


In [13]:
# Set target folder for plots: figures/<today as YYYY-MM-DD>/<run id>/.
# `p` keeps its trailing slash because later cells build file names with `p + ...`.
today = pd.Timestamp("today").strftime("%Y-%m-%d")
p = f"figures/{today}/{c.cnn.run_id}/"
os.makedirs(p, exist_ok=True)

# Set target file for joint logits, predictions, and ground truth
# (written and re-read by the inference cell).
all_logits_csv = os.path.join(p, "all_logits.csv")

In [103]:
# Run inference on the selected images and persist one row per prediction to the
# joint logit CSV (`all_logits_csv`). Re-running this cell replaces the rows of the
# re-inferred images instead of appending duplicates.
d = df.reset_index(drop=True)  # fresh 0..n-1 index so that d.loc[i, ...] is positional

## inference -> logits (plus activations, survival curves and images)
(
    logits,
    activations,
    survival_curves,
    images_cropped,
    images_cropped_hires,
    images_normalized,
    img_size,
    img_size_hires,
) = inference(d, c, cnn)

## Save logits to file
logit_columns = [
    "patient_id",
    "image_file",
    "logit",
    "survival_curve",
    "years_to_event",
    "grade_zero_based",
]

# Create or load joint logit file
if os.path.isfile(all_logits_csv):
    all_logits_df = pd.read_csv(all_logits_csv, index_col=0)
    # Drop stale rows of images that are predicted again in this run; otherwise every
    # re-execution of the cell would duplicate them in the CSV.
    all_logits_df = all_logits_df[~all_logits_df["image_file"].isin(d["image_file"])]
else:
    all_logits_df = pd.DataFrame(columns=logit_columns)

# Collect logit and survival predictions alongside the image path as plain records and
# build the frame once. Constructing a DataFrame from an all-scalar dict raises
# "If using all scalar values, you must pass an index", and concatenating inside the
# loop is quadratic.
logits = logits.cpu().numpy().tolist()
records = []
for i in range(len(d)):
    # NOTE(review): `duration / 2` presumably converts half-year visit steps to years — confirm.
    years_to_event = d.loc[i, "duration"] / 2 if d.loc[i, "event"] else np.nan
    # np.ravel accepts both a scalar logit and a (nested) list of logits and yields
    # one row per value, matching the broadcasting of the per-row DataFrame it replaces.
    for logit in np.ravel(logits[i]):
        records.append(
            {
                "patient_id": d.loc[i, "patient_id"],
                "image_file": d.loc[i, "image_file"],
                "logit": float(logit),
                "survival_curve": str(survival_curves[i]),
                "years_to_event": years_to_event,
                "grade_zero_based": int(d.loc[i, "diagnosis_amd_grade_12c"]),
            }
        )
new_logits_df = pd.DataFrame(records, columns=logit_columns)

# Concatenate once, skipping empty frames (pandas deprecates concatenating empty entries).
non_empty_frames = [frame for frame in (all_logits_df, new_logits_df) if not frame.empty]
if non_empty_frames:
    all_logits_df = pd.concat(non_empty_frames, ignore_index=True)

all_logits_df.sort_values(by="logit", ascending=False).reset_index(drop=True).to_csv(all_logits_csv)

# Read the file back so that downstream code sees exactly what was persisted.
all_testset_logits = pd.read_csv(all_logits_csv, index_col=0)

Create Plots

In [15]:
## Settings for plotting aside from figures/style.txt
# Figure width, used as the first entry of `figsize` (matplotlib figure sizes are in inches).
pwidth = 4.8
# Base font size passed as `fontsize` for axis text, legends and patch labels.
fsize = 6

In [4]:
# Fig 2

# Figure height derived from the shared figure width `pwidth`.
height = 2 / 3 * pwidth - 0.3

# Toggle for the conversion marker; read by the plotting code further down in this cell.
with_arrow = True


def plot_one_curve(i, row, ax, arrow=True):
    """Draw one predicted survival curve and, optionally, its conversion marker.

    Reads the notebook-level ``survival_curves``, ``plot_df`` and ``colors`` at call
    time, so these must be defined before the function is called.

    Args:
        i: Row label in ``plot_df`` and index into ``survival_curves`` of the
            prediction to draw.
        row: Index of the example (0-based); selects the colour from ``colors``
            and the number shown in the legend label.
        ax: Matplotlib axes to draw on.
        arrow: If True, add a caret-down marker at the conversion time.
    """
    # Work on a copy: padding the shared prediction in place would make the curve
    # grow by ten zeros every time this function is called for the same index.
    curve = list(survival_curves[i])

    # The x axis is offset by the visit's position relative to the reference visit.
    # NOTE(review): the division by 2 presumably converts half-year visit steps to years — confirm.
    start = 1 + plot_df.loc[i, "visit_number_rel"] / 2
    end = start + len(curve) - 1
    years = np.arange(start, end + 1)
    ax.plot(years, curve, color=colors[row], clip_on=False, label=f"Example {row + 1}", linewidth=1)

    if arrow:
        # Get time of conversion
        m = 7  # caretdown
        conversion_year = start + plot_df.loc[i, "duration"] / 2
        # Zero-pad so that a conversion beyond the predicted horizon still has a value.
        padded_curve = curve + [0] * 10
        # curve[k] is drawn at x = start + k, so the curve value under the marker sits
        # at index conversion_year - start. Indexing by int(conversion_year) - 1 is only
        # equivalent when start == 1.
        y = padded_curve[int(conversion_year - start)] + 0.008
        ax.scatter(
            conversion_year,
            y,
            color=colors[row],
            marker=m,
            clip_on=False,
            label=f"Conversion\nExample {row + 1}",
        )


# Opacity passed to the heatmap overlay.
plot_3_alpha = 0.75
# Which image of each example row gets its survival curve drawn: 0, 1 or 2.
pred_to_plot = 1  # 0, 1, 2

# The reference visit per patient is the min / median / max visit number, depending on
# `pred_to_plot`. Fail loudly on an invalid value instead of leaving `tr` undefined
# (or stale from an earlier execution of this cell).
visit_aggregations = {0: "min", 1: "median", 2: "max"}
if pred_to_plot not in visit_aggregations:
    raise ValueError(f"pred_to_plot must be one of {sorted(visit_aggregations)}, got {pred_to_plot!r}")
tr = visit_aggregations[pred_to_plot]


plot_imgs = [
    "2146_04_RE_F2_LS.jpg",
    "55684_06_F2_RE_LS.jpg",
    "55684_14_F2_RE_LS.jpg",
    "58684_08_F2_RE_RS.jpg",
    "58684_10_F2_RE_RS.jpg",
    "58684_12_F2_RE_RS.jpg",
    "3591_04_RE_F2_RS.jpg",
    "61691_06_F2_RE_RS.jpg",
    "3591_12_RE_F2_RS.jpg",
]
plot_df = df[df["image_file"].isin(plot_imgs)]
plot_df = plot_df.sort_values(["patient_id", "visit_number"], ascending=True).reset_index(drop=True)

# The fixed reordering below and the 3x3 image grid of the figure need exactly one
# row per requested image; report a missing or duplicated image explicitly.
if len(plot_df) != len(plot_imgs):
    raise ValueError(f"Expected {len(plot_imgs)} rows for Fig 2, found {len(plot_df)}")

# Swap the second and third group of three rows to obtain the row order used in the figure.
plot_df = plot_df.iloc[[0, 1, 2, 6, 7, 8, 3, 4, 5]].reset_index(drop=True)

# Visit number relative to the per-patient reference visit.
adj_visit_number = plot_df.groupby("patient_id")["visit_number"].transform(tr)
plot_df = plot_df.assign(
    adj_visit_number=adj_visit_number,
    visit_number_rel=plot_df["visit_number"] - adj_visit_number,
)

display(plot_df[["patient_id", "visit_number", "visit_number_rel", "duration"]])

# Re-run inference on the nine figure images; this overwrites the notebook-level
# prediction variables, which `plot_one_curve` and the plotting code read.
(
    logits,
    activations,
    survival_curves,
    images_cropped,
    images_cropped_hires,
    images_normalized,
    img_size,
    img_size_hires,
) = inference(plot_df, c, cnn)

with plt.style.context("figures/style.txt"):
    # Set layout for Fig 2
    plt.rcParams.update(
        {
            "figure.constrained_layout.h_pad": 0.001,  # padding around axes. default is 0.04167
            "figure.constrained_layout.w_pad": 0.01,
            "figure.constrained_layout.hspace": 0.0,  # padding between axes. default is 0.02
            "figure.constrained_layout.wspace": 0.0,
        }
    )
    # "w": narrow spacer column that is switched off, "A"-"I": 3x3 grid of heatmap
    # images, "#": survival-curve panel spanning all three rows.
    mosaic = """wABC#
                wDEF#
                wGHI#"""
    letter_to_img_index = {letter: index for index, letter in enumerate("ABCDEFGHI#w")}
    image_letters = "ABCDEFGHI"
    figsize = (pwidth, height)
    gridspec_kw = {
        "wspace": 0,
        "hspace": 0,
        "width_ratios": [0.15, 1, 1, 1, 1.5],
        "height_ratios": [1, 1, 1],
    }
    axd = plt.figure(figsize=figsize).subplot_mosaic(mosaic, gridspec_kw=gridspec_kw)

    # One colour per example row; also read by `plot_one_curve`.
    colors = [plt.get_cmap("tab10")(i) for i in [0, 1, 3]]

    max_year = 7

    # Squeeze height of curves plot
    aspect = height * pwidth * 0.81
    axd["#"].set_aspect(aspect)

    for letter, ax in axd.items():
        if letter == "w":
            ax.axis("off")
            continue

        if letter == "#":
            # Plot individual survival curves: the selected image of each example row.
            first, second, third = pred_to_plot, pred_to_plot + 3, pred_to_plot + 6

            # The first example is drawn without a conversion marker; the other two
            # follow the `with_arrow` toggle so that it controls both the markers and
            # their legend entry.
            plot_one_curve(first, 0, ax, arrow=False)
            plot_one_curve(second, 1, ax, arrow=with_arrow)
            plot_one_curve(third, 2, ax, arrow=with_arrow)

            # Refine plot
            ylim = (0, 1)
            yticks = [0, 1]
            xlim = (0, max_year)
            ax.set_ylim(*ylim)
            ax.set_xlim(*xlim)
            ax.set_xticks(list(range(0, max_year + 1)))
            ax.set_yticks(yticks)

            ax.set_xlabel("Year")
            ax.text(
                -0.08,
                0.5,
                "S(t)",
                horizontalalignment="left",
                verticalalignment="center",
                transform=ax.transAxes,
                fontsize=fsize,
                rotation=90,
            )

            # Create a mock line and a mock marker in gray for a colour-neutral legend.
            # Keep the artists and pass them to the legend explicitly: slicing the last
            # two entries of get_legend_handles_labels() depends on the artist ordering
            # and picks the wrong entries when `with_arrow` is False.
            (mock_line,) = ax.plot([0], [0], color="gray", label="Example\nprediction")
            legend_handles = [mock_line]
            if with_arrow:
                mock_marker = ax.scatter(
                    [-2],
                    [-2],
                    color="gray",
                    marker=7,
                    label="Conversion",
                    linewidth=1,
                    clip_on=False,
                )
                legend_handles.append(mock_marker)

            ax.legend(
                handles=legend_handles,
                frameon=False,
                fontsize=fsize,
                loc="upper right",
                bbox_to_anchor=(1.01, 0.9),
            )

        elif letter in image_letters:
            img_index = letter_to_img_index[letter]
            # Three images per example row share the row colour as border.
            color = colors[img_index // 3]
            # Only the last image of each row gets a colour bar.
            with_legend = letter in ["C", "F", "I"]

            plot_heatmap_on_image(
                images_cropped[img_index],
                activations[img_index],
                with_legend=with_legend,
                ax=ax,
                shrink_colorbar=0.8,
                alpha=plot_3_alpha,
                border_color=color,
                border_width=8,
            )

            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_xticks([])
            ax.axis("off")

        else:
            raise ValueError("Not implemented")

    # Add panel labels a and b
    add_subplot_labels([axd["A"]], to_left=-0.15, to_top=-0.06)
    add_subplot_labels([axd["#"]], to_left=-0.13, to_top=0.0675, i_start=1)

    for suff in ["pdf", "png"]:
        plt.savefig(os.path.join(p, f"Fig2.{suff}"))
    plt.close("all")

In [5]:
# Fig 3: for each selected image, the high-resolution crop next to a 4x2 grid of patches.

plot_imgs = ["52832_08_F2_LE_RS.jpg", "2146_04_RE_F2_LS.jpg"]

# Select the rows of the requested images and put them in the order of `plot_imgs`.
plot_df = df[df["image_file"].isin(plot_imgs)]
plot_df = plot_df.set_index("image_file").loc[plot_imgs].reset_index()
display(plot_df)

# Run inference on the figure images; overwrites the notebook-level prediction variables.
(
    logits,
    activations,
    survival_curves,
    images_cropped,
    images_cropped_hires,
    images_normalized,
    img_size,
    img_size_hires,
) = inference(plot_df, c, cnn)

with plt.style.context("figures/style.txt"):
    plt.rcParams.update(
        {
            "figure.constrained_layout.h_pad": 0.01,  # padding around axes. default is 0.04167
            "figure.constrained_layout.w_pad": 0.01,
            "figure.constrained_layout.hspace": 0.01,  # padding between axes. default is 0.02
            "figure.constrained_layout.wspace": 0.01,
        }
    )
    ncols = len(plot_imgs)
    figsize = (pwidth, 1.76)
    fig = plt.figure(figsize=figsize)
    # Two subfigures per image: a wide one for the image and a narrow one for its patches.
    figs = fig.subfigures(1, 2 * ncols, width_ratios=[2, 0.85] * ncols).flatten()

    grid_spacing = {"hspace": 0.025, "wspace": 0.025}
    for i_img in range(ncols):
        i_ax = 2 * i_img  # index of the image subfigure; the patch subfigure follows it
        ax0 = figs[i_ax].subplots(1, 1, gridspec_kw=dict(grid_spacing))
        axs1 = figs[i_ax + 1].subplots(4, 2, gridspec_kw=dict(grid_spacing))
        plot_image_and_patches(
            images_cropped_hires[i_img],
            activations[i_img],
            k=6,
            bb_size=66,
            labels=[0, 1],
            axs=[ax0, axs1],
            return_axs=False,
            labels_fontsize=fsize - 1,
            labels_above_image=False,
            add_colormarker=True,
        )

    # Save the figure in both formats.
    for suff in ["pdf", "png"]:
        plt.savefig(f"{p}Fig3.{suff}")
    plt.close("all")