In [12]:
import matplotlib.pyplot as plt
import os
import pandas as pd

In [13]:
def compute_diff_counts(df, metric_name):
    """
    Return transported-target minus target score differences for one metric.

    Only rows whose source score is *strictly* greater than the target score
    are considered. Rows where the two scores are equal, or where either is
    NaN, are skipped. Despite the function name, the differences themselves
    are returned, not a count of them.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``source_<metric_name>``,
        ``target_<metric_name>`` and ``trans_target_<metric_name>``.
    metric_name : str
        Metric suffix used to build the column names, e.g. ``"f1"``,
        ``"precision"`` or ``"recall"``.

    Returns
    -------
    list
        ``trans_target - target`` for every eligible row, in row order.
        A positive value means the transported target scored higher than the
        target on that row.
    """
    source_scores = df["source_" + metric_name]
    target_scores = df["target_" + metric_name]
    trans_target_scores = df["trans_target_" + metric_name]

    # Element-wise comparison; NaN compares False, so NaN rows are dropped.
    eligible = source_scores > target_scores
    return (trans_target_scores[eligible] - target_scores[eligible]).tolist()


In [14]:
# For analyzing EHR-OT performance alone
score_dir = "/home/wanxinli/EHR-OT/outputs/mimic/"

# NOTE(review): these lists are never populated in this cell; they are kept
# only so that any later cell referencing these names still resolves.
f1_diff_counts = []
recall_diff_counts = []
precision_diff_counts = []

# Per metric: label codes whose transported-target score beats the target
# score in no more than half of the eligible rows (see compute_diff_counts).
degrade_codes_by_metric = {"f1": [], "precision": [], "recall": []}
code_count = 0

# os.listdir returns entries in arbitrary order; sort so the degrade lists
# come out in a reproducible order.
for file in sorted(os.listdir(score_dir)):
    # Keep only plain EHR-OT experiment-3 score files: skip TCA baselines and
    # any "ind" variants.
    is_ehr_ot_score = (
        file.endswith("_score.csv")
        and "exp3_" in file
        and "TCA" not in file
        and "ind" not in file
    )
    if not is_ehr_ot_score:
        continue

    code_count += 1
    label_code = file.split("_")[1]
    score_df = pd.read_csv(os.path.join(score_dir, file), index_col=None, header=0)

    for metric, degraded in degrade_codes_by_metric.items():
        diffs = compute_diff_counts(score_df, metric)
        improve_count = sum(diff > 0 for diff in diffs)
        # Exactly-half improvements count as degrade, and so does a file
        # with no eligible rows (0 <= 0).
        if improve_count <= len(diffs) / 2:
            degraded.append(label_code)

f1_degrade_code = degrade_codes_by_metric["f1"]
precision_degrade_code = degrade_codes_by_metric["precision"]
recall_degrade_code = degrade_codes_by_metric["recall"]


In [15]:
# Share of label codes that are NOT in a metric's degrade list, i.e. where the
# transported target beats the target on more than half of the eligible rows.
# NOTE(review): the TCA comparison cell below reassigns code_count and the
# f1/recall degrade lists; run this cell before that one (or Restart & Run All).
print("code count is:", code_count)
print(f"f1 improve code percent is: {1 - len(f1_degrade_code) / code_count:.1%}")
print(f"precision improve code percent is: {1 - len(precision_degrade_code) / code_count:.1%}")
print(f"recall improve code percent is: {1 - len(recall_degrade_code) / code_count:.1%}")

code count is: 244
f1 improve code percent is: 0.7049180327868853
precision improve code percent is: 0.5122950819672132
recall improve code percent is: 0.6598360655737705


In [19]:
# For comparing TCA versus EHR-OT
score_dir = "/home/wanxinli/EHR-OT/outputs/mimic/"

# NOTE(review): never populated below; kept only so later cells referencing
# these names still resolve.
f1_diff_counts = []
TCA_diff_counts = []

# NOTE(review): f1_degrade_code, recall_degrade_code and code_count reuse the
# names from the EHR-OT-only cell and overwrite its results. Here they cover
# only the label codes that also have a TCA result.
ehr_ot_degrade_by_metric = {"f1": [], "recall": []}
tca_degrade_by_metric = {"f1": [], "recall": []}
code_count = 0

# Sorted so the degrade lists have a reproducible order (os.listdir does not).
for file in sorted(os.listdir(score_dir)):
    # "ind" stands for independently applying TCA without the PCA step in EHR-OT
    if not (file.endswith("TCA_ind_score.csv") and "exp3_" in file):
        continue

    code_count += 1
    label_code = file.split("_")[1]

    # Pair each TCA result with the EHR-OT result for the same label code;
    # read_csv raises FileNotFoundError if the EHR-OT file is missing.
    score_df = pd.read_csv(
        os.path.join(score_dir, "exp3_" + label_code + "_score.csv"),
        index_col=None, header=0,
    )
    TCA_score_df = pd.read_csv(os.path.join(score_dir, file), index_col=None, header=0)

    for method_df, degrade_by_metric in (
        (score_df, ehr_ot_degrade_by_metric),
        (TCA_score_df, tca_degrade_by_metric),
    ):
        for metric, degraded in degrade_by_metric.items():
            diffs = compute_diff_counts(method_df, metric)
            # Improving on at most half of the eligible rows counts as degrade.
            if sum(diff > 0 for diff in diffs) <= len(diffs) / 2:
                degraded.append(label_code)

f1_degrade_code = ehr_ot_degrade_by_metric["f1"]
recall_degrade_code = ehr_ot_degrade_by_metric["recall"]
TCA_f1_degrade_code = tca_degrade_by_metric["f1"]
TCA_recall_degrade_code = tca_degrade_by_metric["recall"]
        


In [20]:
# F1: number and share of paired label codes on which each method degrades.
# (The count is len(list); the percent is that count over code_count.)
print(f"total count is: {code_count}")
print(f"TCA f1 degrade count: {len(TCA_f1_degrade_code)}, TCA f1 degrade percent: {len(TCA_f1_degrade_code) / code_count:.1%}")
print(f"EHR-OT f1 degrade count: {len(f1_degrade_code)}, EHR-OT f1 degrade percent: {len(f1_degrade_code) / code_count:.1%}")

total count is: 243
TCA f1 degrade percent: 44, TCA f1 degrade count: 0.18106995884773663
EHR-OT f1 degrade percent: 71, EHR-OT f1 degrade count: 0.29218106995884774


In [21]:
# Recall: number and share of paired label codes on which each method degrades.
print(f"TCA recall degrade count: {len(TCA_recall_degrade_code)}, TCA recall degrade percent: {len(TCA_recall_degrade_code) / code_count:.1%}")
print(f"EHR-OT recall degrade count: {len(recall_degrade_code)}, EHR-OT recall degrade percent: {len(recall_degrade_code) / code_count:.1%}")

TCA recall degrade percent: 148, TCA recall degrade count: 0.6090534979423868
EHR-OT recall degrade percent: 82, EHR-OT recall degrade count: 0.3374485596707819
