In [64]:
import pandas as pd
import os

from vocab import Vocab
from constants import *

# <span style="text-decoration: underline">Statistics</span>
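This notebook computes summary statistics over the visit datasets: the fraction of visits dominated by out-of-vocabulary (`[unk]`) diagnosis codes, and the dataset statistics reported in Table 2 of the original paper.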

#### <u>Stat-1: Fraction of visits whose share of "[unk]" tokens exceeds a given threshold</u>

In [65]:
def percentage_visits_with_high_unk(vocab: Vocab, threshold: float = 0.3, for_single_vsit: bool = False):
    """Return the fraction of visits whose share of unknown ICD9 codes exceeds ``threshold``.

    A diagnosis code counts as unknown when ``vocab.word2idx`` has no entry for it.
    The ``HADM_ID`` of every visit above the threshold is printed as it is found.

    Args:
        vocab: Vocabulary object exposing a ``word2idx`` mapping (code -> index).
        threshold: A visit is counted when ``unknown / total`` codes is strictly
            greater than this value.
        for_single_vsit: When ``True`` the single-visit pickle is analysed,
            otherwise the multi-visit pickle.

    Returns:
        ``flagged_visits / total_visits`` as a fraction in ``[0, 1]`` (not a
        percentage); ``0.0`` when the dataset has no rows.
    """
    # Load only the dataset that is actually requested (the multi-visit file used
    # to be read unconditionally and then thrown away for single-visit runs).
    pkl_name = SINGLE_VISIT_PKL if for_single_vsit is True else MULTI_VISIT_PKL
    # NOTE(review): unpickling can execute arbitrary code - only load trusted files.
    visit = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, pkl_name))[["HADM_ID", "ICD9_CODE"]]

    total = visit.shape[0]
    if total == 0:
        # No visits at all: nothing can exceed the threshold.
        return 0.0

    count = 0
    for hadm_id, icd in zip(visit["HADM_ID"], visit["ICD9_CODE"]):
        if len(icd) == 0:
            # A visit without diagnosis codes has no unknown tokens; skipping it
            # also avoids a division by zero.
            continue
        unk = sum(1 for code in icd if vocab.word2idx.get(code) is None)
        if unk / len(icd) > threshold:
            print(hadm_id)
            count += 1
    return count / total

#### <u>Stat-2: Reproduce the dataset statistics of Table 2 from the original paper</u>

In [67]:
def table_2_stats():
    """Print the dataset statistics corresponding to Table 2 of the original paper.

    Reads the single-visit and the temporal multi-visit pickles and prints, for
    each of them: number of patients, average number of visits per patient,
    average number of diagnosis (dx) and medication (rx) codes per visit, and
    the number of unique dx / rx codes.

    In the single-visit frame ``ICD9_CODE`` / ``ATC4`` are iterated as one list
    of codes per row; in the temporal multi-visit frame they are iterated as a
    list of per-visit code lists per row.

    Returns:
        None. All results are written to stdout.
    """
    # NOTE(review): unpickling can execute arbitrary code - only load trusted files.
    single_visit = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, SINGLE_VISIT_PKL))
    multi_visit_temporal = pd.read_pickle(os.path.join(GLOBAL_DATA_PATH, MULTI_VISIT_TEMPORAL_PKL))

    def max_code_count_per_subject(column: str):
        """Sum over subjects of the largest per-row code count found in ``column``."""
        per_row = pd.DataFrame({
            "SUBJECT_ID": multi_visit_temporal["SUBJECT_ID"].to_numpy(),
            "CODE_LEN": [sum(len(codes) for codes in visits) for visits in multi_visit_temporal[column]],
        })
        return per_row.groupby("SUBJECT_ID")["CODE_LEN"].max().sum()

    def unique_nested_codes(column: str) -> set:
        """Collect every distinct code from the list-of-lists cells of ``column``."""
        return {code for visits in multi_visit_temporal[column] for codes in visits for code in codes}

    # Single visit stats
    single_visit_sbj_count = single_visit["SUBJECT_ID"].nunique()
    print("[table_2_stats] # of patients (Single-Visit)", single_visit_sbj_count)
    # Number of non-null HADM_ID entries, counted per subject and summed.
    total_visits = single_visit.groupby("SUBJECT_ID")["HADM_ID"].count().sum()
    print("[table_2_stats] avg. # of visits (Single-Visit)", total_visits/single_visit_sbj_count)
    total_dx = single_visit["ICD9_CODE"].map(len).sum()
    print("[table_2_stats] avg. # of dx per visit per patient (Single-Visit)", total_dx/total_visits)
    total_rx = single_visit["ATC4"].map(len).sum()
    print("[table_2_stats] avg. # of rx per visit per patient (Single-Visit)", total_rx/total_visits)
    # Original note: this count is large (~7k) because patient data is not
    # filtered here; in training/eval the less frequent DX codes (and similarly
    # RX codes) are replaced with "[UNK]".
    unique_dx = len({code for codes in single_visit["ICD9_CODE"] for code in codes})
    print("[table_2_stats] unique # of dx (Single-Visit)", unique_dx)
    unique_rx = len({code for codes in single_visit["ATC4"] for code in codes})
    print("[table_2_stats] unique # of drx (Single-Visit)", unique_rx)

    # Multiple visit stats
    multi_visit_sbj_count = multi_visit_temporal["SUBJECT_ID"].nunique()
    print("\n\n[table_2_stats] # of patients (Multi-Visit)", multi_visit_sbj_count)
    # Visits per subject are taken as (largest T_1 value + 1).
    total_multi_visits = (multi_visit_temporal.groupby(by=["SUBJECT_ID"])["T_1"].max() + 1).sum()
    print("[table_2_stats] avg. # of visits (Multi-Visit)", total_multi_visits/multi_visit_sbj_count)

    # NOTE(review): the original remark says this value deviates from the paper
    # because the first visit is included as well - TODO confirm.
    total_multi_dx = max_code_count_per_subject("ICD9_CODE")
    print("[table_2_stats] avg. # of dx per visit per patient (Multi-Visit)(This includes the first visit dx as well)"
        , total_multi_dx/total_multi_visits)
    # NOTE(review): same caveat as for dx above - TODO confirm.
    total_multi_rx = max_code_count_per_subject("ATC4")
    print("[table_2_stats] avg. # of rx per visit per patient (Multi-Visit)(This includes the first visit rx as well)"
        , total_multi_rx/total_multi_visits)

    unique_dx_multi = unique_nested_codes("ICD9_CODE")
    print("[table_2_stats] unique # of dx (Multi-Visit)", len(unique_dx_multi))
    unique_rx_multi = unique_nested_codes("ATC4")
    print("[table_2_stats] unique # of rx (Multi-Visit)", len(unique_rx_multi))

# Run the Table 2 statistics; the results are printed to the cell output.
table_2_stats()

[table_2_stats] # of patients (Single-Visit) 29189
[table_2_stats] avg. # of visits (Single-Visit) 1.0
[table_2_stats] avg. # of dx per visit per patient (Single-Visit) 11.444276953646922
[table_2_stats] avg. # of rx per visit per patient (Single-Visit) 21.20703004556511
[table_2_stats] unique # of dx (Single-Visit) 6191
[table_2_stats] unique # of drx (Single-Visit) 398


[table_2_stats] # of patients (Multi-Visit) 5917
[table_2_stats] avg. # of visits (Multi-Visit) 2.677539293560926
[table_2_stats] avg. # of dx per visit per patient (Multi-Visit)(This includes the first visit dx as well) 14.217572429464116
[table_2_stats] avg. # of rx per visit per patient (Multi-Visit)(This includes the first visit rx as well) 22.555071640472132
[table_2_stats] unique # of dx (Multi-Visit) 4551
[table_2_stats] unique # of rx (Multi-Visit) 385
