In [None]:
%matplotlib inline
import os
import matplotlib.pyplot as plt
import numpy as np
import collections
import re
import pandas as pd
from typing import Dict, List, Any
from sklearn.metrics import roc_curve, auc, accuracy_score, roc_auc_score
from sklearn.preprocessing import label_binarize

def parse_log_line(line: str) -> Dict[str, Any]:
    """
    Extract pid, target label, and predicted label from a prediction log line.

    Each field is read as the token that follows its ``"<name>: "`` marker,
    up to the next whitespace character or comma.

    Args:
        line: a single log line containing ``"pid: "``, ``"target: "`` and
            ``"predicted: "`` markers.

    Returns:
        dict with keys ``"pid"`` (float when the token parses as a number,
        otherwise the raw string), ``"target"`` (int) and ``"predicted"`` (int).

    Raises:
        ValueError: if a marker is missing from ``line`` or a label token
            cannot be converted to an int.
    """
    key_mapping = {"pid: ": "pid",
                   "target: ": "target",
                   "predicted: ": "predicted"}
    info: Dict[str, Any] = {}
    for marker, field in key_mapping.items():
        parts = line.split(marker)
        if len(parts) < 2:
            raise ValueError(f"log line is missing field {marker!r}: {line!r}")
        # Token = text after the marker up to the first whitespace or comma.
        token = re.split(r"[\s,]", parts[1], maxsplit=1)[0]
        try:
            value: Any = float(token)
        except ValueError:
            # Non-numeric token (e.g. a dash-separated patient id): keep the text.
            value = token
        if field in ("target", "predicted"):
            # Labels are class indices; a non-numeric label raises ValueError here.
            value = int(value)
        info[field] = value
    return info

def calculate_patient_label(data: List[Dict[str, Any]], include_patient_ids: List[str] = None, num_classes: int = 2):
    """
    Calculate patient-level predictions for multiclass classification
    on tiles of an image.

    Tiles are grouped by patient id. String ids are truncated to their first
    four ``-``-separated fields; non-string ids (numeric ids parsed as floats)
    are used unchanged. A patient's score for class ``c`` is the fraction of
    its tiles predicted as ``c``.

    Args:
        data: tile records with keys ``"pid"``, ``"target"`` and ``"predicted"``
            (as produced by ``parse_log_line``).
        include_patient_ids: if given, only these patients are used, in this
            order. Requested patients with no tiles in ``data`` are skipped
            and reported with a warning.
        num_classes: number of classes; every ``"predicted"`` value must lie
            in ``[0, num_classes)``.

    Returns:
        tuple of
        - dict mapping patient id -> list of that patient's tile records,
        - patient prediction probabilities, shape (n, num_classes),
        - patient targets (n), taken from each patient's last tile,
        - patient ids (n), aligned with the rows of the probability array.
    """
    # Set lookup instead of O(len(include_patient_ids)) list scans per tile.
    wanted = None if include_patient_ids is None else set(include_patient_ids)

    # Consolidate all tile records of a patient into one list.
    consolidated_data = collections.defaultdict(list)
    for tile in data:
        pid = tile["pid"]
        # Long-format (string) ids are reduced to their first four fields;
        # short-format ids (parsed as numbers) are kept as they are.
        patient_id = "-".join(pid.split("-")[:4]) if isinstance(pid, str) else pid
        if wanted is not None and patient_id not in wanted:
            continue
        consolidated_data[patient_id].append(tile)

    requested = consolidated_data.keys() if include_patient_ids is None else include_patient_ids
    # `in` on a defaultdict does not create entries.
    patient_ids = [pid for pid in requested if pid in consolidated_data]
    missing = [pid for pid in requested if pid not in consolidated_data]
    if missing:
        import warnings
        # Previously these patients silently got another patient's target
        # (or crashed); skip them instead so outputs stay aligned.
        warnings.warn(
            f"{len(missing)} requested patient id(s) have no tiles in the logs "
            f"and were skipped: {missing[:10]}"
        )

    patient_prediction_scores = np.zeros((len(patient_ids), num_classes))
    patient_targets = []
    # Tabulate tile predictions per patient and convert counts to probabilities.
    for row, patient_id in enumerate(patient_ids):
        tiles = consolidated_data[patient_id]
        for tile in tiles:
            patient_prediction_scores[row, tile["predicted"]] += 1
        patient_prediction_scores[row] /= len(tiles)
        # All tiles of a patient are assumed to share one target; use the last.
        patient_targets.append(tiles[-1]["target"])
    return consolidated_data, patient_prediction_scores, patient_targets, patient_ids


def get_predictions_from_logs(path: str, include_patient_ids: List = None, num_classes: int = 2):
    """
    Extract tile-level predictions from logs, then return patient-level
    predictions after consolidating the predictions of each patient's tiles.

    Args:
        path: a single log file, or a directory whose regular files are all
            read as logs (in sorted filename order).
        include_patient_ids: optional patient ids to restrict the output to.
        num_classes: number of classes in the predictions.

    Returns:
        the 4-tuple returned by ``calculate_patient_label``: tiles grouped by
        patient, patient probabilities (n, c), patient targets, patient ids.
    """
    if os.path.isdir(path):
        # Sorted so patient order in the outputs does not depend on the
        # filesystem's listing order; sub-directories cannot be opened as logs.
        log_files = [os.path.join(path, name) for name in sorted(os.listdir(path))]
        log_files = [log_file for log_file in log_files if os.path.isfile(log_file)]
    else:
        log_files = [path]

    lines = []
    for log_file in log_files:
        # `with` closes each handle immediately instead of leaking it.
        with open(log_file) as handle:
            lines.extend(handle.read().split("\n"))

    # Only prediction-mode lines that carry a target are tile predictions.
    data = [parse_log_line(line) for line in lines
            if "mode: predict" in line and "target:" in line]
    return calculate_patient_label(data, include_patient_ids, num_classes)

def plot_roc(path: str, plot_label: str, include_patient_ids: List[str] = None, num_classes: int = 2, linestyle: str = None, class_labels: List = None):
    """
    Plot the ROC curve(s) for patient-level predictions read from logs.

    Patient-level probabilities are built from tile predictions. In the
    multiclass case (``num_classes > 2``), one one-vs-rest ROC curve is drawn
    per class, and macro/weighted AUROC plus accuracy are printed. In the
    binary case, a single ROC curve is drawn using the probability of class 1
    as the score.

    Args:
        path: log file or directory of log files.
        plot_label: prefix for legend entries and the printed summary.
        include_patient_ids: optional patient ids to restrict evaluation to.
        num_classes: number of classes.
        linestyle: matplotlib line style for the plotted curve(s).
        class_labels: display names per class index; defaults to
            ``["msi-h", "msi-l", "mss"]``. Missing names fall back to the index.

    Returns:
        the 4-tuple from ``get_predictions_from_logs``.
    """
    if class_labels is None:
        class_labels = ["msi-h", "msi-l", "mss"]

    labels, patient_prediction_scores, patient_targets, patient_ids = get_predictions_from_logs(path, include_patient_ids, num_classes)

    patient_predictions = np.argmax(patient_prediction_scores, axis=1)
    acc = accuracy_score(patient_targets, patient_predictions)

    if num_classes > 2:
        class_options = list(range(num_classes))
        # `classes` is keyword-only in current scikit-learn.
        binary_patient_targets = label_binarize(patient_targets, classes=class_options)
        for i in class_options:
            class_name = class_labels[i] if i < len(class_labels) else str(i)
            # One-vs-rest: the score is the probability of being in class i.
            fpr, tpr, _ = roc_curve(binary_patient_targets[:, i], patient_prediction_scores[:, i])
            plt.plot(fpr, tpr,
                     label=plot_label + f" | AUC={round(auc(fpr, tpr), 3)}" + f" | class={class_name}",
                     linewidth=3, linestyle=linestyle)

        # ovr: one-vs-rest.
        # macro: unweighted mean of per-class AUROCs (ignores class imbalance).
        # weighted: per-class AUROCs weighted by class prevalence.
        # `labels` keeps scoring valid when a class is absent from the targets.
        aurocs_macro_score = roc_auc_score(patient_targets, patient_prediction_scores, average='macro', multi_class='ovr', labels=class_options)
        aurocs_weighted_score = roc_auc_score(patient_targets, patient_prediction_scores, average='weighted', multi_class='ovr', labels=class_options)
        print(plot_label + f" | AUROC={round(aurocs_macro_score, 3)} | weighted AUROC={round(aurocs_weighted_score, 3)} | ACC={round(acc * 100, 2)}")

    else:
        # Use P(class 1) as a continuous score so the curve sweeps thresholds;
        # argmax hard labels would give only a single operating point.
        fpr, tpr, _ = roc_curve(patient_targets, patient_prediction_scores[:, 1])
        plt.plot(fpr, tpr, label=plot_label + f" | AUC={round(auc(fpr, tpr), 3)} | ACC={round(acc, 4)}",
                 linewidth=3, linestyle=linestyle)

    return labels, patient_prediction_scores, patient_targets, patient_ids

## Example Usage

In [None]:
# Create a 12x12-inch figure for the ROC curves drawn below.
fig = plt.figure(figsize=(12, 12))

# Prediction logs: either a single log file or a directory of log files.
prediction_logs_path = "path_to_predictions"
plot_roc(prediction_logs_path, "prediction_title")

# Axis labels, tick sizes, title and legend share one font size.
font_size = 16
plt.xlabel("FPR", fontsize=font_size)
plt.ylabel("TPR", fontsize=font_size)
plt.xticks(fontsize=font_size)
plt.yticks(fontsize=font_size)
plt.title("ROC 512 x 512: Class Balancing Runs | No dropout | LR=1e-5 | WD=1e-4", fontsize=18)
plt.legend(prop={'size': font_size})