In [None]:
%run default-imports.ipynb
from sklearn.preprocessing import KBinsDiscretizer
from sklearn.metrics import normalized_mutual_info_score

In [None]:
# Cohort registry: display title and experiment file of each cohort.
# Every cohort must read its own file. Previously 'SINAI' loaded the MIMIC file
# and 'DHZB' loaded the Sinai file, so two cohorts were computed from data that
# belonged to another cohort.
# NOTE(review): the DHZB file name follows the experiments_rfe_<cohort>.d
# pattern of the other entries -- confirm it matches the file on disk.
cohorts = {
    'MIMIC' : { 'title' : "MIMIC-III", 'filename' : "./experiments/experiments_rfe_mimic.d" },
    'SINAI' : { 'title' : "Mt. Sinai", 'filename' : "./experiments/experiments_rfe_sinai.d" },
    'DHZB'  : { 'title' : "German Heart Center", 'filename': "./experiments/experiments_rfe_dhzb.d" },
}

In [None]:
# The MIMIC file defines the reference set (and order) of columns.
reference_columns = Load().execute(filename=cohorts['MIMIC']['filename']).columns

# Positional indices of the numeric columns, according to the MIMIC column order.
numeric_columns = [1, 3] + list(range(34, 104))

# Cached results needed for plotting; start empty when nothing is cached.
# NOTE(review): relies on unpickle() returning a falsy value when the cache
# file does not exist -- confirm against the helper.
nmi_matrices = unpickle('./experiments/nmi_matrices.d') or {}
# The differences are pickled as a plain dict, so a reloaded cache must be
# wrapped again; otherwise nmi_differences[cohort][...] = value raises KeyError
# for any cohort that is not yet in the cache.
nmi_differences = defaultdict(dict, unpickle('./experiments/nmi_differences.d') or {})

In [None]:
def compute_nmi(data):
    """Compute the pairwise normalized mutual information (NMI) matrix.

    The frame is aligned to ``reference_columns``; reference columns that are
    missing from ``data`` are added as all-zero columns. The columns at the
    positions listed in ``numeric_columns`` are quantile-binned into 20 ordinal
    bins, then NMI (normalized by the larger of the two entropies) is computed
    for every pair of columns.

    Parameters
    ----------
    data : pandas.DataFrame
        Cohort data. The frame passed in is not modified.

    Returns
    -------
    numpy.ndarray
        Square, symmetric matrix with one row/column per reference column;
        entry ``[i, j]`` is the NMI between reference columns ``i`` and ``j``.
    """
    discretizer = KBinsDiscretizer(n_bins=20, encode='ordinal', strategy='quantile')

    # Align to the reference layout on a fresh frame: missing columns become
    # zero-filled, and the caller's frame is left untouched.
    data_discretized = data.reindex(columns=reference_columns, fill_value=0.0)

    # numeric_columns holds COLUMN positions in the reference order, so they
    # must be resolved on the aligned frame and along the column axis
    # (a bare .iloc[list] would select rows instead).
    numeric_names = list(data_discretized.columns[numeric_columns])
    data_discretized[numeric_names] = discretizer.fit_transform(data_discretized[numeric_names])

    n_columns = data_discretized.shape[1]
    column_values = [data_discretized.iloc[:, k].to_numpy() for k in range(n_columns)]
    nmi_matrix = np.zeros((n_columns, n_columns))

    # NMI is symmetric, so each pair is scored once and mirrored.
    for i in range(n_columns):
        for j in range(i, n_columns):
            # average_method is keyword-only in current scikit-learn releases.
            score = normalized_mutual_info_score(column_values[i], column_values[j], average_method='max')
            nmi_matrix[i, j] = score
            nmi_matrix[j, i] = score
    return nmi_matrix

def compute_mean_nmi_difference(nmi_matrix1, nmi_matrix2):
    """Return the mean absolute element-wise difference between two NMI matrices."""
    absolute_gap = np.abs(nmi_matrix1 - nmi_matrix2)
    return np.mean(absolute_gap)

In [None]:
# Obtain the NMI matrix of every cohort whose data can be loaded.
for cohort, spec in cohorts.items():
    # Only loading/imputing is best-effort (a cohort file may be absent on this
    # machine). The real error is reported instead of always claiming "file not
    # found", and failures inside compute_nmi are no longer silently swallowed.
    try:
        data, _ = Impute().execute(Load().execute(filename=spec['filename']))
    except Exception as error:
        print(f"Cohort {cohort} not available. Could not load '{spec['filename']}': {error!r}")
        continue
    nmi_matrices[cohort] = compute_nmi(data)

In [None]:
# Obtain the mean NMI difference for every pair of cohorts.
# The mean absolute difference is symmetric, so only one orientation of each
# pair is stored (the first one encountered); the diagonal is 0.0 by definition.
# NOTE(review): values already present (e.g. loaded from the cache) are kept
# as-is and are not refreshed when the matrices are recomputed.
for cohort1, cohort2 in product(cohorts, repeat=2):
    # setdefault works for both a defaultdict and a plain dict loaded from disk.
    row = nmi_differences.setdefault(cohort1, {})
    if cohort1 == cohort2:
        row[cohort2] = 0.0
    elif cohort2 in row or cohort1 in nmi_differences.get(cohort2, {}):
        # This exact pair (in either orientation) is already known.
        continue
    elif cohort1 not in nmi_matrices or cohort2 not in nmi_matrices:
        # A cohort whose data could not be loaded has no matrix to compare.
        print(f"Skipped {cohort1} x {cohort2}: NMI matrix not available")
    else:
        row[cohort2] = compute_mean_nmi_difference(nmi_matrices[cohort1], nmi_matrices[cohort2])
        print(f"Calculated {cohort1} x {cohort2}")
print('Finished calculating mean differences')

In [None]:
# Store everything for later use.
# The matrices cache is read back from './experiments/nmi_matrices.d', so it
# has to be written to that same path; writing to the working directory meant
# the saved matrices were never found again.
if pickle(nmi_matrices, './experiments/nmi_matrices.d'):
    print('Successfully saved nmi_matrices.')

# Drop the defaultdict wrapper: its factory may be a lambda, which cannot be pickled.
nmi_differences = dict(nmi_differences)
if pickle(nmi_differences, './experiments/nmi_differences.d'):
    print('Successfully saved nmi_differences.')

In [None]:
# Plot the NMI matrices, one panel per cohort.
BIGGER_SIZE = 14
# rc settings only affect artists created afterwards, so they are set before
# the figure is built; otherwise the first run on a fresh kernel would not use
# them for the axis labels.
plt.rc('font', size=BIGGER_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=BIGGER_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=BIGGER_SIZE)     # fontsize of the x and y labels
plt.rc('xtick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels

fig, axs = plt.subplots(1, len(cohorts), figsize=(15,5))
# With a single cohort plt.subplots returns a bare Axes; normalise to an array.
axs = np.atleast_1d(axs)

for ax, cohort in zip(axs, cohorts):
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    ax.set_xlabel('Features')
    ax.set_ylabel('Features')
    ax.set_title(cohorts[cohort]['title'], y=-0.12)  # place the title below the panel
    if cohort in nmi_matrices:
        ax.imshow(nmi_matrices[cohort], cmap='Blues')
    else:
        # Cohort data was not available: keep the panel, but say so.
        ax.text(0.5, 0.5, 'not available', ha='center', va='center', transform=ax.transAxes)

plt.tight_layout()
plt.savefig('nmi_matrices.pdf')