In [1]:
import matplotlib.pyplot as plt
import os
import pandas as pd

In [2]:
def compute_diff_counts(df, metric_name):
    """
    Collects the change in the metric_name score caused by transport
    (transported target score minus target score).

    Only rows whose source score is strictly greater than the target score
    contribute; rows where the target score matches or beats the source score
    are skipped.

    df is expected to provide the columns "source_<metric_name>",
    "target_<metric_name>" and "trans_target_<metric_name>".
    Returns a list with one difference per contributing row, in row order.
    """

    score_rows = zip(
        df["source_" + metric_name],
        df["target_" + metric_name],
        df["trans_target_" + metric_name],
    )
    return [
        transported - target
        for source, target, transported in score_rows
        if source > target
    ]


In [3]:
# Analyzing EHR-OT performance individually
score_dir = "/home/wanxinli/EHR-OT/outputs/mimic/"

# Initialised here as in the original cell; nothing in this cell fills them.
f1_diff_counts = []
recall_diff_counts = []
precision_diff_counts = []

# Label codes flagged as degraded, one list per evaluation metric.
f1_degrade_code = []
precision_degrade_code = []
recall_degrade_code = []
code_count = 0

# Maps each evaluation metric to the list that collects its degraded codes,
# so the per-metric logic is written once instead of three times.
degrade_codes_by_metric = {
    "f1": f1_degrade_code,
    "precision": precision_degrade_code,
    "recall": recall_degrade_code,
}

# Score files whose names contain any of these substrings are skipped.
excluded_tags = ("TCA", "ind", "MMD")

# os.listdir returns entries in arbitrary order; sorting keeps the order of
# the collected code lists reproducible between runs and machines.
for file in sorted(os.listdir(score_dir)):
    if not (file.endswith("_score.csv") and "exp3_" in file):
        continue
    if any(tag in file for tag in excluded_tags):
        continue

    code_count += 1
    # File names look like exp3_<label_code>_..._score.csv
    label_code = file.split("_")[1]
    score_df = pd.read_csv(os.path.join(score_dir, file), index_col=None, header=0)

    for metric_name, degrade_codes in degrade_codes_by_metric.items():
        diffs = compute_diff_counts(score_df, metric_name)
        improve_count = sum(1 for diff in diffs if diff > 0)

        # A code is flagged unless strictly more than half of the compared
        # rows improved. NOTE: a file with no compared rows (empty diffs) is
        # therefore flagged as well, because 0 <= 0.
        if improve_count <= len(diffs) / 2:
            degrade_codes.append(label_code)


In [4]:
# Summary of the EHR-OT run: number of label codes and, per metric, the share
# of codes that were not flagged as degraded.
print("code count is:", code_count)
for metric_label, degraded_codes in (
    ("f1", f1_degrade_code),
    ("precision", precision_degrade_code),
    ("recall", recall_degrade_code),
):
    # The printed value is a fraction (not multiplied by 100) even though the
    # label says "percent".
    print(f"{metric_label} improve code percent is:", 1-len(degraded_codes)/code_count)
# print(sorted(f1_degrade_code))

code count is: 243
f1 improve code percent is: 0.7078189300411523
precision improve code percent is: 0.51440329218107
recall improve code percent is: 0.6625514403292181


In [5]:
def determine_degrade(label_code, trans_metric, eval_metric):
    """
    Determine whether a label code (label_code) has degraded performance when
    using the transporting method (trans_metric), judged by the evaluation
    metric (eval_metric).

    Reads exp3_<label_code>_<trans_metric>_score.csv from the module-level
    score_dir, which must be defined before this function is called.

    Returns True when at most half of the score differences produced by
    compute_diff_counts are positive, and False otherwise.
    """

    csv_name = f"exp3_{label_code}_{trans_metric}_score.csv"
    scores = pd.read_csv(os.path.join(score_dir, csv_name), index_col=None, header=0)

    score_changes = compute_diff_counts(scores, eval_metric)
    n_improved = sum(1 for change in score_changes if change > 0)

    # Degraded unless strictly more than half of the changes are positive.
    return n_improved <= len(score_changes) / 2

In [6]:
# Comparing EHR-OT (OT) against TCA and MMD
score_dir = "/home/wanxinli/EHR-OT/outputs/mimic/"

# Initialised here as in the original cell; nothing in this cell fills them.
f1_diff_counts = []
TCA_diff_counts = []

# Label codes flagged as degraded, one list per (method, evaluation metric).
OT_f1_degrade_codes = []
OT_recall_degrade_codes = []
TCA_f1_degrade_codes = []
TCA_recall_degrade_codes = []
MMD_f1_degrade_codes = []
MMD_recall_degrade_codes = []

# Maps (transport method, evaluation metric) to the list collecting its
# degraded codes. Insertion order fixes the order of determine_degrade calls.
degrade_codes_by_run = {
    ("OT", "f1"): OT_f1_degrade_codes,
    ("OT", "recall"): OT_recall_degrade_codes,
    ("TCA", "f1"): TCA_f1_degrade_codes,
    ("TCA", "recall"): TCA_recall_degrade_codes,
    ("MMD", "f1"): MMD_f1_degrade_codes,
    ("MMD", "recall"): MMD_recall_degrade_codes,
}

code_count = 0

# os.listdir returns entries in arbitrary order; sorting keeps the order of
# the collected code lists reproducible between runs and machines.
for file in sorted(os.listdir(score_dir)):
    # In score file names, "ind" stands for independently applying TCA without
    # the PCA step in EHR-OT.
    # The MMD score files drive the iteration: each one contributes a label
    # code, and determine_degrade then reads the OT, TCA and MMD score files
    # for that code.
    if file.endswith("MMD_score.csv") and 'exp3_' in file:
        code_count += 1
        label_code = file.split('_')[1]

        for (trans_metric, eval_metric), degrade_codes in degrade_codes_by_run.items():
            if determine_degrade(label_code, trans_metric, eval_metric):
                degrade_codes.append(label_code)


In [7]:
# Per-method summary of how many label codes were flagged as degraded.
print(f"total count is: {code_count}")

# Each entry: (label used for the count, label used for the ratios,
#              f1 degraded codes, recall degraded codes).
# The MMD row previously labelled its f1 ratio as "TCA"; every row now uses
# its own method name.
for count_label, ratio_label, f1_codes, recall_codes in (
    ("OT", "EHR-OT", OT_f1_degrade_codes, OT_recall_degrade_codes),
    ("TCA", "TCA", TCA_f1_degrade_codes, TCA_recall_degrade_codes),
    ("MMD", "MMD", MMD_f1_degrade_codes, MMD_recall_degrade_codes),
):
    # The "percent" values are plain fractions of code_count (not multiplied by 100).
    print(
        f"{count_label} f1 degrade count: {len(f1_codes)}, "
        f"{ratio_label} f1 degrade percent: {len(f1_codes)/code_count}, "
        f"{ratio_label} recall degrade percent: {len(recall_codes)/code_count}"
    )



total count is: 243
OT f1 degrade count: 71, EHR-OT f1 degrade percent: 0.29218106995884774, EHR-OT recall degrade percent: 0.3374485596707819
TCA f1 degrade count: 38, TCA f1 degrade percent: 0.15637860082304528, TCA recall degrade percent: 0.6419753086419753
MMD f1 degrade count: 135, MMD f1 degrade percent: 0.5555555555555556, MMD recall degrade percent: 0.5102880658436214
